// taut-rpc 0.1.0
//
// End-to-end type-safe RPC between Rust (axum) and TypeScript clients.
// Documentation
//! Phase 2 integration tests for taut-rpc.
//!
//! These tests cover the new Phase 2 surface area:
//!
//! - `#[derive(taut_rpc::TautError)]` — default `snake_case` `code` from variant
//!   names, default `http_status` of 400, and per-variant
//!   `#[taut(code = "...", status = N)]` overrides.
//! - The updated `#[rpc]` error path that consults
//!   `<E as TautError>::http_status` / `<E as TautError>::code` to build the
//!   wire envelope.
//! - `Router::layer(layer)` for composing tower middleware.
//!
//! HTTP-level assertions mirror `tests/integration.rs` (use `oneshot` and
//! `axum::body::to_bytes`).

use axum::body::Body;
use http::Request as HttpRequest;
use tower::ServiceExt;

// ---------------------------------------------------------------------------
// 1. Default snake_case `code` and default 400 `http_status`.
// ---------------------------------------------------------------------------

#[test]
fn derive_taut_error_emits_default_snake_case_codes() {
    use taut_rpc::TautError;

    // No variant carries a `#[taut(...)]` attribute, so each one must fall
    // back to the derive's defaults: snake_case'd variant name, HTTP 400.
    #[derive(serde::Serialize, taut_rpc::TautError, Debug)]
    #[serde(tag = "code", content = "payload", rename_all = "snake_case")]
    #[allow(dead_code)]
    enum E {
        NotFound,
        BadRequest,
        ServerError,
    }

    let expectations = [
        (E::NotFound, "not_found"),
        (E::BadRequest, "bad_request"),
        (E::ServerError, "server_error"),
    ];

    for (variant, want_code) in &expectations {
        assert_eq!(<E as TautError>::code(variant), *want_code);
        // Default `http_status` is 400 for every variant when no override is set.
        assert_eq!(<E as TautError>::http_status(variant), 400);
    }
}

// ---------------------------------------------------------------------------
// 2. Per-variant `#[taut(status = N)]` override.
// ---------------------------------------------------------------------------

#[test]
fn derive_taut_error_honors_per_variant_status_override() {
    use taut_rpc::TautError;

    // Two variants pin their own status; `Generic` is left on the default.
    #[derive(serde::Serialize, taut_rpc::TautError, Debug)]
    #[serde(tag = "code", content = "payload", rename_all = "snake_case")]
    #[allow(dead_code)]
    enum E {
        #[taut(status = 401)]
        Unauth,
        #[taut(status = 429)]
        RateLimited,
        Generic,
    }

    let expectations = [
        (E::Unauth, 401),
        (E::RateLimited, 429),
        // No override → default 400.
        (E::Generic, 400),
    ];

    for (variant, want_status) in &expectations {
        assert_eq!(<E as TautError>::http_status(variant), *want_status);
    }
}

// ---------------------------------------------------------------------------
// 3. Per-variant `#[taut(code = "...")]` override.
// ---------------------------------------------------------------------------

#[test]
fn derive_taut_error_honors_code_override() {
    use taut_rpc::TautError;

    // `Unauth` replaces its wire code explicitly; `Plain` keeps the default.
    #[derive(serde::Serialize, taut_rpc::TautError, Debug)]
    #[serde(tag = "code", content = "payload", rename_all = "snake_case")]
    #[allow(dead_code)]
    enum E {
        #[taut(code = "auth_required")]
        Unauth,
        Plain,
    }

    let expectations = [
        // Explicit `#[taut(code = "...")]` wins over the variant name.
        (E::Unauth, "auth_required"),
        // No override → default snake_case'd variant name.
        (E::Plain, "plain"),
    ];

    for (variant, want_code) in &expectations {
        assert_eq!(<E as TautError>::code(variant), *want_code);
    }
}

// ---------------------------------------------------------------------------
// 4. Derive works on variants that carry payloads (tuple + named).
// ---------------------------------------------------------------------------

#[test]
fn derive_taut_error_works_on_variants_with_payloads() {
    use taut_rpc::TautError;

    // One tuple variant and one struct-like variant: the match arms the derive
    // generates must ignore the payload in both shapes.
    #[derive(serde::Serialize, taut_rpc::TautError, Debug)]
    #[serde(tag = "code", content = "payload", rename_all = "snake_case")]
    #[allow(dead_code)]
    enum E {
        Tuple(String),
        Named { reason: String },
    }

    let samples = [
        (E::Tuple("hi".into()), "tuple"),
        (E::Named { reason: "x".into() }, "named"),
    ];

    for (variant, want_code) in &samples {
        // Code is the snake_case'd variant name, independent of the payload.
        assert_eq!(<E as TautError>::code(variant), *want_code);
        // Default status applies regardless of payload.
        assert_eq!(<E as TautError>::http_status(variant), 400);
    }
}

// ---------------------------------------------------------------------------
// 5. `#[rpc]` consults `TautError` for both HTTP status and the body `code`.
// ---------------------------------------------------------------------------

// Error type returned by `protected_proc` below. Besides `Serialize` and
// `TautError`, it derives `taut_rpc::Type` — presumably required for any type
// appearing in an `#[rpc]` signature (NOTE(review): confirm).
// `#[taut(status = 401)]` makes `Unauth` report HTTP 401 instead of the
// derive's default 400; the HTTP test asserts exactly that.
#[derive(serde::Serialize, taut_rpc::Type, taut_rpc::TautError, Debug)]
#[serde(tag = "code", content = "payload", rename_all = "snake_case")]
#[allow(dead_code)]
enum RpcE {
    #[taut(status = 401)]
    Unauth,
}

// Procedure that unconditionally fails with `RpcE::Unauth`, so the test below
// can observe how `#[rpc]` turns a `TautError` into an HTTP response.
// `#[rpc]` provides the `__taut_proc_protected_proc()` constructor the test
// registers; the test then calls it at `/rpc/protected_proc`.
#[taut_rpc::rpc]
#[allow(clippy::unnecessary_wraps, clippy::unused_async)] // `#[rpc]` requires `async fn` signatures
async fn protected_proc() -> Result<i32, RpcE> {
    Err(RpcE::Unauth)
}

#[tokio::test]
async fn rpc_macro_uses_taut_error_for_status_and_code() {
    let app = taut_rpc::Router::new()
        .procedure(__taut_proc_protected_proc())
        .into_axum();

    let response = app
        .oneshot(
            HttpRequest::builder()
                .method("POST")
                .uri("/rpc/protected_proc")
                .header("content-type", "application/json")
                .body(Body::from(r#"{"input":null}"#))
                .unwrap(),
        )
        .await
        .unwrap();

    // The `#[taut(status = 401)]` annotation must drive the HTTP status.
    assert_eq!(response.status(), http::StatusCode::UNAUTHORIZED);

    let bytes = axum::body::to_bytes(response.into_body(), usize::MAX)
        .await
        .unwrap();
    let v: serde_json::Value = serde_json::from_slice(&bytes).unwrap();

    // Body shape: `{"err": {"code":"unauth","payload":...}}`.
    // - "unauth" comes from snake_case'ing `Unauth` (which is already lower).
    // - The derive's `<RpcE as TautError>::code` and serde's tag agree.
    assert_eq!(v["err"]["code"], serde_json::json!("unauth"));
    // The payload field exists because of `#[serde(content = "payload")]`,
    // even when the variant is unit (serde emits `null`).
    assert!(
        v["err"].get("payload").is_some(),
        "expected `payload` key in err envelope, got {v}"
    );
}

// ---------------------------------------------------------------------------
// 6. `Router::layer` applies a tower middleware to all responses.
// ---------------------------------------------------------------------------

// Infallible procedure that returns the constant string "hi". The
// `Router::layer` tests register it (via the `#[rpc]`-provided
// `__taut_proc_echo_layer()`) as a trivial target behind their middleware.
#[taut_rpc::rpc]
#[allow(clippy::unused_async)] // `#[rpc]` requires `async fn` signatures
async fn echo_layer() -> String {
    "hi".to_string()
}

/// A single `Router::layer` middleware must see (and be able to edit) the
/// response produced by a registered procedure.
#[tokio::test]
async fn router_layer_applies_middleware() {
    use axum::middleware::{from_fn, Next};
    use axum::response::Response;
    use http::HeaderValue;

    // Lets the inner service run, then stamps a marker header on its response.
    async fn stamp_marker(req: axum::extract::Request, next: Next) -> Response {
        let mut response = next.run(req).await;
        let marker = HeaderValue::from_static("hit");
        response.headers_mut().insert("x-taut-test", marker);
        response
    }

    let request = HttpRequest::builder()
        .method("POST")
        .uri("/rpc/echo_layer")
        .header("content-type", "application/json")
        .body(Body::from(r#"{"input":null}"#))
        .unwrap();

    let app = taut_rpc::Router::new()
        .procedure(__taut_proc_echo_layer())
        .layer(from_fn(stamp_marker))
        .into_axum();

    let response = app.oneshot(request).await.unwrap();
    assert_eq!(response.status(), http::StatusCode::OK);

    let marker = response
        .headers()
        .get("x-taut-test")
        .map(|value| value.to_str().unwrap());
    assert_eq!(
        marker,
        Some("hit"),
        "expected `x-taut-test: hit` header injected by Router::layer"
    );
}

// ---------------------------------------------------------------------------
// 7. Multiple `Router::layer` calls compose — both layers' effects are visible.
// ---------------------------------------------------------------------------

/// Two successive `Router::layer` calls must both take effect on the same
/// response (only presence is checked, not the wrapping order).
#[tokio::test]
async fn router_multiple_layers_compose() {
    use axum::middleware::{from_fn, Next};
    use axum::response::Response;
    use http::HeaderValue;

    // Each middleware runs the inner service and then adds one header.
    async fn tag_a(req: axum::extract::Request, next: Next) -> Response {
        let mut response = next.run(req).await;
        let value = HeaderValue::from_static("1");
        response.headers_mut().insert("x-a", value);
        response
    }

    async fn tag_b(req: axum::extract::Request, next: Next) -> Response {
        let mut response = next.run(req).await;
        let value = HeaderValue::from_static("2");
        response.headers_mut().insert("x-b", value);
        response
    }

    let request = HttpRequest::builder()
        .method("POST")
        .uri("/rpc/echo_layer")
        .header("content-type", "application/json")
        .body(Body::from(r#"{"input":null}"#))
        .unwrap();

    let app = taut_rpc::Router::new()
        .procedure(__taut_proc_echo_layer())
        .layer(from_fn(tag_a))
        .layer(from_fn(tag_b))
        .into_axum();

    let response = app.oneshot(request).await.unwrap();
    assert_eq!(response.status(), http::StatusCode::OK);

    // (header name, expected value, which `.layer` call installed it)
    let expected = [("x-a", "1", "first"), ("x-b", "2", "second")];
    for (name, value, which) in expected {
        assert_eq!(
            response.headers().get(name).map(|h| h.to_str().unwrap()),
            Some(value),
            "expected `{name}` header from {which} layer"
        );
    }
}

// ---------------------------------------------------------------------------
// 8. `taut_rpc::TautError` resolves both as a trait import AND as a derive
//    macro (compile-only check — if rustc accepts this file, the names work).
// ---------------------------------------------------------------------------

#[test]
fn taut_error_re_export_resolves() {
    // Trait position: bring the re-exported trait into scope by name.
    use taut_rpc::TautError;

    // Derive position: the very same path used as a derive macro.
    #[derive(serde::Serialize, taut_rpc::TautError, Debug)]
    #[serde(tag = "code", content = "payload", rename_all = "snake_case")]
    #[allow(dead_code)]
    enum X {
        A,
    }

    // Hand-written impl through the imported trait name. Only `code` is
    // supplied, so the `http_status` call below resolves to the trait's
    // default method.
    struct Y;

    impl serde::Serialize for Y {
        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
        where
            S: serde::Serializer,
        {
            serializer.serialize_str("y")
        }
    }

    impl TautError for Y {
        fn code(&self) -> &'static str {
            "y"
        }
    }

    // The type annotation pins the trait methods' return types at compile time.
    let (_x_code, _y_status, _y_code): (&'static str, u16, &'static str) = (
        <X as TautError>::code(&X::A),
        <Y as TautError>::http_status(&Y),
        <Y as TautError>::code(&Y),
    );
}

// ---------------------------------------------------------------------------
// 9. The shipped `StandardError::Unauthenticated` maps to HTTP 401 and code
//    "unauthenticated".
// ---------------------------------------------------------------------------

#[test]
fn standard_error_unauthenticated_has_401_status() {
    use taut_rpc::{StandardError, TautError};

    // The shipped error type must map `Unauthenticated` to 401 / "unauthenticated".
    let err = StandardError::Unauthenticated;
    let (status, code) = (err.http_status(), err.code());

    assert_eq!(status, 401);
    assert_eq!(code, "unauthenticated");
}