use axum::body::Body;
use http::Request as HttpRequest;
use tower::ServiceExt;
// With no `#[taut(...)]` overrides, the derived `TautError` should report the
// serde `snake_case` variant name as the code and 400 as the HTTP status.
#[test]
fn derive_taut_error_emits_default_snake_case_codes() {
    use taut_rpc::TautError;

    #[derive(serde::Serialize, taut_rpc::TautError, Debug)]
    #[serde(tag = "code", content = "payload", rename_all = "snake_case")]
    #[allow(dead_code)] // NOTE(review): kept from the original; confirm it is still needed.
    enum E {
        NotFound,
        BadRequest,
        ServerError,
    }

    // Each variant paired with the code the derive is expected to emit.
    let expected = [
        (E::NotFound, "not_found"),
        (E::BadRequest, "bad_request"),
        (E::ServerError, "server_error"),
    ];

    for (variant, code) in &expected {
        assert_eq!(<E as TautError>::code(variant), *code);
    }
    for (variant, _) in &expected {
        assert_eq!(<E as TautError>::http_status(variant), 400);
    }
}
// A `#[taut(status = N)]` attribute overrides the HTTP status of its variant;
// variants without the attribute keep the default status of 400.
#[test]
fn derive_taut_error_honors_per_variant_status_override() {
    use taut_rpc::TautError;

    #[derive(serde::Serialize, taut_rpc::TautError, Debug)]
    #[serde(tag = "code", content = "payload", rename_all = "snake_case")]
    #[allow(dead_code)] // NOTE(review): kept from the original; confirm it is still needed.
    enum E {
        #[taut(status = 401)]
        Unauth,
        #[taut(status = 429)]
        RateLimited,
        Generic,
    }

    let cases = [(E::Unauth, 401), (E::RateLimited, 429), (E::Generic, 400)];
    for (variant, expected_status) in cases {
        assert_eq!(<E as TautError>::http_status(&variant), expected_status);
    }
}
// A `#[taut(code = "...")]` attribute replaces the derived code for its variant;
// other variants still use their `snake_case` name.
#[test]
fn derive_taut_error_honors_code_override() {
    use taut_rpc::TautError;

    #[derive(serde::Serialize, taut_rpc::TautError, Debug)]
    #[serde(tag = "code", content = "payload", rename_all = "snake_case")]
    #[allow(dead_code)] // NOTE(review): kept from the original; confirm it is still needed.
    enum E {
        #[taut(code = "auth_required")]
        Unauth,
        Plain,
    }

    let overridden = <E as TautError>::code(&E::Unauth);
    let defaulted = <E as TautError>::code(&E::Plain);

    assert_eq!(overridden, "auth_required");
    assert_eq!(defaulted, "plain");
}
// The derive must also handle tuple and struct variants: codes come from the
// `snake_case` variant name and the status falls back to 400.
#[test]
fn derive_taut_error_works_on_variants_with_payloads() {
    use taut_rpc::TautError;

    #[derive(serde::Serialize, taut_rpc::TautError, Debug)]
    #[serde(tag = "code", content = "payload", rename_all = "snake_case")]
    // NOTE(review): the payload fields are never read directly by this test;
    // presumably that is why `dead_code` is allowed here — confirm.
    #[allow(dead_code)]
    enum E {
        Tuple(String),
        Named { reason: String },
    }

    let tuple = E::Tuple(String::from("hi"));
    let named = E::Named {
        reason: String::from("x"),
    };

    for (err, expected_code) in [(&tuple, "tuple"), (&named, "named")] {
        assert_eq!(<E as TautError>::code(err), expected_code);
    }
    for err in [&tuple, &named] {
        assert_eq!(<E as TautError>::http_status(err), 400);
    }
}
// Error type returned by `protected_proc`. Its only variant overrides the HTTP
// status to 401, which `rpc_macro_uses_taut_error_for_status_and_code` expects
// together with the wire code "unauth" (the `snake_case` form of `Unauth`).
// Plain `//` comments are used on purpose: `///` would become a `#[doc]`
// attribute visible to the `taut_rpc` derives.
#[derive(serde::Serialize, taut_rpc::Type, taut_rpc::TautError, Debug)]
#[serde(tag = "code", content = "payload", rename_all = "snake_case")]
// NOTE(review): kept from the original; confirm `dead_code` is still needed.
#[allow(dead_code)]
enum RpcE {
    #[taut(status = 401)]
    Unauth,
}
// Procedure that always fails with `RpcE::Unauth`. The test below registers it
// through `__taut_proc_protected_proc()`, which is presumably generated by the
// `#[taut_rpc::rpc]` expansion.
// The clippy lints are allowed because this is a fixed stub: it never awaits and
// always returns the same `Err`.
#[taut_rpc::rpc]
#[allow(clippy::unnecessary_wraps, clippy::unused_async)]
async fn protected_proc() -> Result<i32, RpcE> {
    Err(RpcE::Unauth)
}
// End-to-end check that a procedure's `TautError` decides the HTTP status (401
// here) and that the JSON body carries `err.code` plus an `err.payload` key.
#[tokio::test]
async fn rpc_macro_uses_taut_error_for_status_and_code() {
    let router = taut_rpc::Router::new().procedure(__taut_proc_protected_proc());
    let app = router.into_axum();

    let request = HttpRequest::builder()
        .method("POST")
        .uri("/rpc/protected_proc")
        .header("content-type", "application/json")
        .body(Body::from(r#"{"input":null}"#))
        .unwrap();
    let response = app.oneshot(request).await.unwrap();

    assert_eq!(response.status(), http::StatusCode::UNAUTHORIZED);

    let body = response.into_body();
    let bytes = axum::body::to_bytes(body, usize::MAX).await.unwrap();
    let v: serde_json::Value = serde_json::from_slice(&bytes).unwrap();

    let err = &v["err"];
    assert_eq!(err["code"], serde_json::json!("unauth"));
    assert!(
        err.get("payload").is_some(),
        "expected `payload` key in err envelope, got {v}"
    );
}
// Trivial procedure that always returns "hi". The `Router::layer` tests below
// register it through `__taut_proc_echo_layer()`, which is presumably generated
// by the `#[taut_rpc::rpc]` expansion.
// `clippy::unused_async` is allowed because the body never awaits.
#[taut_rpc::rpc]
#[allow(clippy::unused_async)]
async fn echo_layer() -> String {
    "hi".to_string()
}
// A middleware installed with `Router::layer` must run for procedure calls: it
// stamps an `x-taut-test: hit` header onto the response.
#[tokio::test]
async fn router_layer_applies_middleware() {
    use axum::middleware::{from_fn, Next};
    use axum::response::Response;
    use http::HeaderValue;

    // Runs the inner service, then adds the marker header to its response.
    async fn stamp_test_header(req: axum::extract::Request, next: Next) -> Response {
        let mut response = next.run(req).await;
        let headers = response.headers_mut();
        headers.insert("x-taut-test", HeaderValue::from_static("hit"));
        response
    }

    let app = taut_rpc::Router::new()
        .procedure(__taut_proc_echo_layer())
        .layer(from_fn(stamp_test_header))
        .into_axum();

    let request = HttpRequest::builder()
        .method("POST")
        .uri("/rpc/echo_layer")
        .header("content-type", "application/json")
        .body(Body::from(r#"{"input":null}"#))
        .unwrap();
    let response = app.oneshot(request).await.unwrap();

    assert_eq!(response.status(), http::StatusCode::OK);

    let marker = response
        .headers()
        .get("x-taut-test")
        .map(|value| value.to_str().unwrap());
    assert_eq!(
        marker,
        Some("hit"),
        "expected `x-taut-test: hit` header injected by Router::layer"
    );
}
// Two successive `Router::layer` calls must both take effect: each middleware
// adds its own header to the same response. Only presence is checked here, not
// the order in which the layers run.
#[tokio::test]
async fn router_multiple_layers_compose() {
    use axum::middleware::{from_fn, Next};
    use axum::response::Response;
    use http::HeaderValue;

    // First middleware: appends `x-a: 1` after the inner service responds.
    async fn stamp_a(req: axum::extract::Request, next: Next) -> Response {
        let mut response = next.run(req).await;
        let headers = response.headers_mut();
        headers.insert("x-a", HeaderValue::from_static("1"));
        response
    }

    // Second middleware: appends `x-b: 2` after the inner service responds.
    async fn stamp_b(req: axum::extract::Request, next: Next) -> Response {
        let mut response = next.run(req).await;
        let headers = response.headers_mut();
        headers.insert("x-b", HeaderValue::from_static("2"));
        response
    }

    let app = taut_rpc::Router::new()
        .procedure(__taut_proc_echo_layer())
        .layer(from_fn(stamp_a))
        .layer(from_fn(stamp_b))
        .into_axum();

    let request = HttpRequest::builder()
        .method("POST")
        .uri("/rpc/echo_layer")
        .header("content-type", "application/json")
        .body(Body::from(r#"{"input":null}"#))
        .unwrap();
    let response = app.oneshot(request).await.unwrap();

    assert_eq!(response.status(), http::StatusCode::OK);

    let headers = response.headers();
    assert_eq!(
        headers.get("x-a").map(|value| value.to_str().unwrap()),
        Some("1"),
        "expected `x-a` header from first layer"
    );
    assert_eq!(
        headers.get("x-b").map(|value| value.to_str().unwrap()),
        Some("2"),
        "expected `x-b` header from second layer"
    );
}
// `taut_rpc::TautError` must work both as a derive (`X`) and as a hand-written
// trait impl (`Y`) that supplies only `code` and inherits `http_status`.
// Besides the type-level checks, pin the codes so that a regression in either
// the derive or the manual impl is caught instead of silently discarded.
#[test]
fn taut_error_re_export_resolves() {
    use taut_rpc::TautError;

    #[derive(serde::Serialize, taut_rpc::TautError, Debug)]
    #[serde(tag = "code", content = "payload", rename_all = "snake_case")]
    // NOTE(review): kept from the original; confirm `dead_code` is still needed.
    #[allow(dead_code)]
    enum X {
        A,
    }

    // Manual implementor: serializes as the bare string "y".
    struct Y;

    impl serde::Serialize for Y {
        fn serialize<S: serde::Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
            s.serialize_str("y")
        }
    }

    impl TautError for Y {
        fn code(&self) -> &'static str {
            "y"
        }
    }

    // Derived code follows the same `snake_case` rule the other derive tests rely on.
    let derived_code: &'static str = <X as TautError>::code(&X::A);
    assert_eq!(derived_code, "a");

    // Only the type is checked: the trait's default status is defined outside
    // this file, so no specific value is asserted here.
    let _: u16 = <Y as TautError>::http_status(&Y);

    let manual_code: &'static str = <Y as TautError>::code(&Y);
    assert_eq!(manual_code, "y");
}
// The built-in `StandardError::Unauthenticated` maps to HTTP 401 with the code
// "unauthenticated".
#[test]
fn standard_error_unauthenticated_has_401_status() {
    use taut_rpc::{StandardError, TautError};

    let unauthenticated = StandardError::Unauthenticated;
    let (status, code) = (unauthenticated.http_status(), unauthenticated.code());

    assert_eq!(status, 401);
    assert_eq!(code, "unauthenticated");
}