use std::collections::HashMap;
use std::sync::Arc;
use axum::extract::rejection::JsonRejection;
use axum::extract::{FromRequest, Query, Request};
use axum::http::StatusCode;
use axum::response::sse::{Event, KeepAlive, KeepAliveStream, Sse};
use axum::response::{IntoResponse, Response};
use axum::Router as AxumRouter;
use futures::stream::StreamExt;
use crate::procedure::{ProcedureBody, ProcedureDescriptor, ProcedureResult, StreamFrame};
use crate::wire::RpcRequest;
#[cfg(feature = "ws")]
#[path = "ws.rs"]
mod ws;
/// Runtime classification of a registered procedure.
///
/// [`Router::into_axum`] uses it to decide how the procedure is mounted:
/// `Query` and `Mutation` are served as `POST /rpc/<name>` with a JSON body,
/// while `Subscription` is served as `GET /rpc/<name>` streaming Server-Sent
/// Events.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ProcKindRuntime {
    /// Request/response call served over `POST`.
    Query,
    /// Request/response call served over `POST`.
    Mutation,
    /// Streaming call served over `GET` as Server-Sent Events.
    Subscription,
}
type LayerApply = Box<dyn FnMut(AxumRouter) -> AxumRouter + Send + Sync>;
#[derive(Default)]
pub struct Router {
procedures: Vec<ProcedureDescriptor>,
layers: Vec<LayerApply>,
}
impl Router {
#[must_use]
pub fn new() -> Self {
Self {
procedures: Vec::new(),
layers: Vec::new(),
}
}
#[must_use]
pub fn procedure(mut self, desc: ProcedureDescriptor) -> Self {
assert!(
!self.procedures.iter().any(|p| p.name == desc.name),
"taut-rpc: procedure `{}` is already registered on this Router",
desc.name
);
self.procedures.push(desc);
self
}
#[must_use]
pub fn layer<L>(mut self, layer: L) -> Self
where
L: tower::Layer<axum::routing::Route> + Clone + Send + Sync + 'static,
L::Service: tower::Service<axum::http::Request<axum::body::Body>, Error = std::convert::Infallible>
+ Clone
+ Send
+ Sync
+ 'static,
<L::Service as tower::Service<axum::http::Request<axum::body::Body>>>::Response:
axum::response::IntoResponse + 'static,
<L::Service as tower::Service<axum::http::Request<axum::body::Body>>>::Future:
Send + 'static,
{
let mut slot = Some(layer);
self.layers.push(Box::new(move |r: AxumRouter| {
let layer = slot
.take()
.expect("taut-rpc: layer adapter invoked more than once");
r.layer(layer)
}));
self
}
#[must_use]
pub fn ir(&self) -> crate::ir::Ir {
let mut procedures = Vec::with_capacity(self.procedures.len());
let mut types: Vec<crate::ir::TypeDef> = Vec::new();
let mut seen_type_names: std::collections::HashSet<String> =
std::collections::HashSet::new();
for desc in &self.procedures {
procedures.push(desc.ir.clone());
for td in &desc.type_defs {
if seen_type_names.insert(td.name.clone()) {
types.push(td.clone());
}
}
}
crate::ir::Ir {
ir_version: crate::ir::Ir::CURRENT_VERSION,
procedures,
types,
}
}
pub fn into_axum(mut self) -> AxumRouter {
#[cfg_attr(not(feature = "ir-export"), allow(unused_variables))]
let _ir: Arc<crate::ir::Ir> = Arc::new(self.ir());
let names: Arc<Vec<String>> =
Arc::new(self.procedures.iter().map(|p| p.name.to_string()).collect());
let mut app = AxumRouter::new()
.route("/rpc/_health", axum::routing::get(|| async { "ok" }))
.route(
"/rpc/_version",
axum::routing::get(|| async {
axum::Json(serde_json::json!({
"taut_rpc": env!("CARGO_PKG_VERSION"),
"ir_version": crate::IR_VERSION,
}))
}),
)
.route(
"/rpc/_procedures",
axum::routing::get(move || {
let names = names.clone();
async move { axum::Json((*names).clone()) }
}),
);
#[cfg(feature = "ir-export")]
{
let ir_for_route = _ir.clone();
app = app.route(
"/rpc/_ir",
axum::routing::get(move || {
let ir = ir_for_route.clone();
async move { axum::Json((*ir).clone()) }
}),
);
}
#[cfg(feature = "ws")]
{
let descriptors_arc: Arc<Vec<ProcedureDescriptor>> = Arc::new(self.procedures.clone());
app = app.route(
"/rpc/_ws",
axum::routing::get(crate::router::ws::ws_route::ws_handler(descriptors_arc)),
);
}
for desc in std::mem::take(&mut self.procedures) {
let path = format!("/rpc/{}", desc.name);
match desc.kind {
ProcKindRuntime::Query | ProcKindRuntime::Mutation => {
let handler = match desc.body {
ProcedureBody::Unary(h) => h,
ProcedureBody::Stream(_) => {
unreachable!(
"taut-rpc: query/mutation `{}` was registered with a streaming body",
desc.name
)
}
};
app = app.route(
&path,
axum::routing::post(move |input: RpcInput| {
let handler = handler.clone();
async move {
let RpcInput(value) = input;
let result = handler(value).await;
procedure_result_into_response(result)
}
}),
);
}
ProcKindRuntime::Subscription => {
let handler = match desc.body {
ProcedureBody::Stream(h) => h,
ProcedureBody::Unary(_) => {
unreachable!(
"taut-rpc: subscription `{}` was registered with a unary body",
desc.name
)
}
};
app = app.route(
&path,
axum::routing::get(move |Query(params): Query<HashMap<String, String>>| {
let handler = handler.clone();
async move {
sse_response_for(handler, params.get("input").map(String::as_str))
}
}),
);
}
}
}
let mut router = app.fallback(not_found_fallback);
for mut apply in self.layers.drain(..) {
router = apply(router);
}
router
}
}
/// Extractor for the `{"input": ...}` request envelope used by queries and
/// mutations.
///
/// Any failure to read or decode the JSON body is turned into the crate's
/// `decode_error` envelope (via `decode_error_response`) instead of axum's
/// default rejection response.
struct RpcInput(serde_json::Value);

impl<S> FromRequest<S> for RpcInput
where
    S: Send + Sync,
{
    type Rejection = Response;

    async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
        axum::Json::<RpcRequest<serde_json::Value>>::from_request(req, state)
            .await
            .map(|axum::Json(envelope)| Self(envelope.input))
            .map_err(|rejection| decode_error_response(&rejection))
    }
}
/// Renders a JSON body rejection as the `decode_error` error envelope.
///
/// Always responds `400 Bad Request`. axum's rejection text is carried as
/// `payload.message`.
fn decode_error_response(rej: &JsonRejection) -> Response {
    let message = rej.body_text();
    let envelope = serde_json::json!({
        "err": {
            "code": "decode_error",
            "payload": { "message": message },
        }
    });
    (StatusCode::BAD_REQUEST, axum::Json(envelope)).into_response()
}
/// Converts a unary handler's [`ProcedureResult`] into the wire envelope.
///
/// - `Ok(value)` becomes `200 OK` with body `{"ok": value}`.
/// - `Err { http_status, code, payload }` becomes `http_status` with body
///   `{"err": {"code": code, "payload": payload}}`. If `StatusCode::from_u16`
///   rejects `http_status`, the status falls back to `500 Internal Server Error`.
fn procedure_result_into_response(result: ProcedureResult) -> Response {
    let (status, envelope) = match result {
        ProcedureResult::Ok(value) => (StatusCode::OK, serde_json::json!({ "ok": value })),
        ProcedureResult::Err {
            http_status,
            code,
            payload,
        } => (
            StatusCode::from_u16(http_status).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR),
            serde_json::json!({
                "err": { "code": code, "payload": payload }
            }),
        ),
    };
    (status, axum::Json(envelope)).into_response()
}
/// Decodes the `input` query parameter of a subscription request.
///
/// An absent or empty parameter means "no input" and yields JSON `null`.
/// Any other value must be valid JSON, and the parse error is returned
/// otherwise.
fn parse_input_param(raw: Option<&str>) -> Result<serde_json::Value, serde_json::Error> {
    let Some(text) = raw.filter(|s| !s.is_empty()) else {
        return Ok(serde_json::Value::Null);
    };
    serde_json::from_str(text)
}
#[allow(clippy::needless_pass_by_value)] fn sse_response_for(
handler: crate::procedure::StreamHandler,
raw_input: Option<&str>,
) -> Sse<
KeepAliveStream<futures::stream::BoxStream<'static, Result<Event, std::convert::Infallible>>>,
> {
use futures::stream;
let event_stream: futures::stream::BoxStream<'static, Result<Event, std::convert::Infallible>> =
match parse_input_param(raw_input) {
Err(e) => {
let event = Event::default()
.event("error")
.json_data(serde_json::json!({
"code": "decode_error",
"payload": { "message": e.to_string() },
}))
.expect("valid json");
stream::once(async move { Ok(event) }).boxed()
}
Ok(input_json) => {
let frames = handler(input_json);
frames
.map(|frame| {
let event = match frame {
StreamFrame::Data(v) => Event::default()
.event("data")
.json_data(v)
.expect("valid json"),
StreamFrame::Error { code, payload } => Event::default()
.event("error")
.json_data(serde_json::json!({
"code": code,
"payload": payload,
}))
.expect("valid json"),
};
Ok::<Event, std::convert::Infallible>(event)
})
.boxed()
}
};
let end = stream::once(async {
Ok::<Event, std::convert::Infallible>(Event::default().event("end").data(""))
});
Sse::new(event_stream.chain(end).boxed()).keep_alive(KeepAlive::default())
}
/// Fallback for every unmatched request. Responds `404 Not Found` with a
/// `not_found` envelope.
///
/// `payload.procedure` is the request path with a leading `/rpc/` removed.
/// Paths outside `/rpc/` are reported verbatim.
async fn not_found_fallback(req: Request) -> Response {
    const RPC_PREFIX: &str = "/rpc/";

    let path = req.uri().path();
    let procedure = path.strip_prefix(RPC_PREFIX).unwrap_or(path);
    let envelope = serde_json::json!({
        "err": {
            "code": "not_found",
            "payload": { "procedure": procedure },
        }
    });
    let reply = (StatusCode::NOT_FOUND, axum::Json(envelope));
    reply.into_response()
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::{HttpMethod, Primitive, ProcKind, Procedure, TypeRef};
    use crate::procedure::{StreamHandler, UnaryHandler};
    use axum::body::{Body, Bytes};
    use axum::http::HeaderMap;
    use futures::future::BoxFuture;
    use futures::stream::{self, BoxStream};
    use http::Request as HttpRequest;
    use tower::ServiceExt;

    // ---------------------------------------------------------------------
    // Fixtures
    // ---------------------------------------------------------------------

    /// IR entry shared by every test descriptor: unit input and output, no
    /// declared errors, no docs.
    fn unit_procedure_ir(name: &str, kind: ProcKind, http_method: HttpMethod) -> Procedure {
        Procedure {
            name: name.to_string(),
            kind,
            input: TypeRef::Primitive(Primitive::Unit),
            output: TypeRef::Primitive(Primitive::Unit),
            errors: vec![],
            http_method,
            doc: None,
        }
    }

    /// Unary descriptor of the given runtime kind. The IR method is always `POST`.
    fn make_descriptor(
        name: &'static str,
        kind: ProcKindRuntime,
        handler: UnaryHandler,
    ) -> ProcedureDescriptor {
        let ir_kind = match kind {
            ProcKindRuntime::Query => ProcKind::Query,
            ProcKindRuntime::Mutation => ProcKind::Mutation,
            ProcKindRuntime::Subscription => ProcKind::Subscription,
        };
        ProcedureDescriptor {
            name,
            kind,
            ir: unit_procedure_ir(name, ir_kind, HttpMethod::Post),
            type_defs: vec![],
            body: ProcedureBody::Unary(handler),
        }
    }

    /// Subscription descriptor backed by a streaming handler.
    fn make_stream_descriptor(name: &'static str, handler: StreamHandler) -> ProcedureDescriptor {
        ProcedureDescriptor {
            name,
            kind: ProcKindRuntime::Subscription,
            ir: unit_procedure_ir(name, ProcKind::Subscription, HttpMethod::Get),
            type_defs: vec![],
            body: ProcedureBody::Stream(handler),
        }
    }

    /// Handler that returns its input unchanged as `Ok`.
    fn echo_handler() -> UnaryHandler {
        Arc::new(
            |input: serde_json::Value| -> BoxFuture<'static, ProcedureResult> {
                Box::pin(async move { ProcedureResult::Ok(input) })
            },
        )
    }

    /// Handler that always fails with HTTP 404 / `not_found` / null payload.
    fn not_found_handler() -> UnaryHandler {
        Arc::new(
            |_input: serde_json::Value| -> BoxFuture<'static, ProcedureResult> {
                Box::pin(async move {
                    ProcedureResult::Err {
                        http_status: 404,
                        code: "not_found".to_string(),
                        payload: serde_json::Value::Null,
                    }
                })
            },
        )
    }

    /// Stream handler that yields `Data(0) .. Data(n - 1)` and then ends.
    fn counting_stream_handler(n: usize) -> StreamHandler {
        Arc::new(
            move |_input: serde_json::Value| -> BoxStream<'static, StreamFrame> {
                stream::iter((0..n).map(|i| StreamFrame::Data(serde_json::json!(i)))).boxed()
            },
        )
    }

    /// Wraps `router` in a middleware that stamps `x-taut-test: <marker>` on
    /// every response.
    fn with_marker_layer(router: Router, marker: &'static str) -> Router {
        router.layer(axum::middleware::from_fn(
            move |req: axum::extract::Request, next: axum::middleware::Next| async move {
                let mut resp = next.run(req).await;
                resp.headers_mut()
                    .insert("x-taut-test", marker.parse().unwrap());
                resp
            },
        ))
    }

    // ---------------------------------------------------------------------
    // Request/response helpers
    // ---------------------------------------------------------------------

    /// A response with its body fully buffered.
    struct Captured {
        status: StatusCode,
        headers: HeaderMap,
        body: Bytes,
    }

    impl Captured {
        /// Body parsed as JSON. Fails the test if it is not JSON.
        fn json(&self) -> serde_json::Value {
            serde_json::from_slice(&self.body).expect("body must parse as JSON")
        }

        /// Body as UTF-8 text. Fails the test if it is not valid UTF-8.
        fn text(&self) -> &str {
            std::str::from_utf8(&self.body).expect("utf8 body")
        }

        /// `content-type` header, or `""` when absent or not visible ASCII.
        fn content_type(&self) -> &str {
            self.headers
                .get("content-type")
                .and_then(|v| v.to_str().ok())
                .unwrap_or("")
        }

        /// `x-taut-test` marker header set by [`with_marker_layer`], if any.
        fn marker(&self) -> Option<&str> {
            self.headers
                .get("x-taut-test")
                .map(|v| v.to_str().unwrap())
        }
    }

    /// Drives one request through `app` and buffers the whole response.
    async fn send(app: AxumRouter, request: HttpRequest<Body>) -> Captured {
        let response = app.oneshot(request).await.unwrap();
        let status = response.status();
        let headers = response.headers().clone();
        let body = axum::body::to_bytes(response.into_body(), usize::MAX)
            .await
            .unwrap();
        Captured {
            status,
            headers,
            body,
        }
    }

    /// `GET uri` with an empty body.
    fn get_request(uri: &str) -> HttpRequest<Body> {
        HttpRequest::builder()
            .method("GET")
            .uri(uri)
            .body(Body::empty())
            .unwrap()
    }

    /// `POST uri` with `content-type: application/json` and the given body.
    fn post_json_request(uri: &str, body: &'static str) -> HttpRequest<Body> {
        HttpRequest::builder()
            .method("POST")
            .uri(uri)
            .header("content-type", "application/json")
            .body(Body::from(body))
            .unwrap()
    }

    // ---------------------------------------------------------------------
    // Built-in endpoints
    // ---------------------------------------------------------------------

    #[test]
    fn empty_router_builds() {
        let _: AxumRouter = Router::new().into_axum();
    }

    #[tokio::test]
    async fn health_endpoint_returns_ok() {
        let res = send(Router::new().into_axum(), get_request("/rpc/_health")).await;
        assert_eq!(res.status, StatusCode::OK);
        assert_eq!(&res.body[..], b"ok");
    }

    #[tokio::test]
    async fn health_endpoint_unchanged() {
        let res = send(Router::new().into_axum(), get_request("/rpc/_health")).await;
        assert_eq!(res.status, StatusCode::OK);
        let ct = res.content_type();
        assert!(
            ct.starts_with("text/plain"),
            "expected text/plain content-type, got {ct:?}"
        );
        assert_eq!(&res.body[..], b"ok");
    }

    #[tokio::test]
    async fn version_endpoint_returns_json_with_version_and_ir_version() {
        let res = send(Router::new().into_axum(), get_request("/rpc/_version")).await;
        assert_eq!(res.status, StatusCode::OK);
        let v = res.json();
        assert!(
            v["taut_rpc"].is_string(),
            "taut_rpc field must be a string; got {v}"
        );
        assert_eq!(
            v["taut_rpc"].as_str().unwrap(),
            env!("CARGO_PKG_VERSION"),
            "taut_rpc must match CARGO_PKG_VERSION",
        );
        assert_eq!(
            v["ir_version"].as_u64(),
            Some(u64::from(crate::IR_VERSION)),
            "ir_version must match crate::IR_VERSION",
        );
    }

    #[tokio::test]
    async fn procedures_endpoint_lists_registered_names() {
        let app = Router::new()
            .procedure(make_descriptor("alpha", ProcKindRuntime::Query, echo_handler()))
            .procedure(make_descriptor("beta", ProcKindRuntime::Mutation, echo_handler()))
            .into_axum();
        let res = send(app, get_request("/rpc/_procedures")).await;
        assert_eq!(res.status, StatusCode::OK);
        let names: Vec<String> = serde_json::from_slice(&res.body).unwrap();
        assert_eq!(names, vec!["alpha".to_string(), "beta".to_string()]);
    }

    // ---------------------------------------------------------------------
    // Unary dispatch and error envelopes
    // ---------------------------------------------------------------------

    #[tokio::test]
    async fn registered_query_dispatches_through_handler() {
        let app = Router::new()
            .procedure(make_descriptor("echo", ProcKindRuntime::Query, echo_handler()))
            .into_axum();
        let res = send(app, post_json_request("/rpc/echo", r#"{"input":42}"#)).await;
        assert_eq!(res.status, StatusCode::OK);
        assert_eq!(res.json(), serde_json::json!({"ok": 42}));
    }

    #[tokio::test]
    async fn handler_error_surfaces_with_envelope_and_status() {
        let app = Router::new()
            .procedure(make_descriptor(
                "echo",
                ProcKindRuntime::Query,
                not_found_handler(),
            ))
            .into_axum();
        let res = send(app, post_json_request("/rpc/echo", r#"{"input":null}"#)).await;
        assert_eq!(res.status, StatusCode::NOT_FOUND);
        assert_eq!(
            res.json(),
            serde_json::json!({
                "err": { "code": "not_found", "payload": null }
            })
        );
    }

    #[tokio::test]
    async fn malformed_json_returns_decode_error_envelope() {
        let app = Router::new()
            .procedure(make_descriptor("echo", ProcKindRuntime::Query, echo_handler()))
            .into_axum();
        let res = send(app, post_json_request("/rpc/echo", "not json")).await;
        assert_eq!(res.status, StatusCode::BAD_REQUEST);
        let v = res.json();
        assert_eq!(v["err"]["code"], serde_json::json!("decode_error"));
        assert!(v["err"]["payload"]["message"].is_string());
        assert!(!v["err"]["payload"]["message"].as_str().unwrap().is_empty());
    }

    #[tokio::test]
    async fn unknown_procedure_returns_not_found_envelope() {
        let res = send(
            Router::new().into_axum(),
            post_json_request("/rpc/nonexistent", r#"{"input":null}"#),
        )
        .await;
        assert_eq!(res.status, StatusCode::NOT_FOUND);
        assert_eq!(
            res.json(),
            serde_json::json!({
                "err": {
                    "code": "not_found",
                    "payload": { "procedure": "nonexistent" }
                }
            })
        );
    }

    // ---------------------------------------------------------------------
    // Builder behaviour
    // ---------------------------------------------------------------------

    #[test]
    #[should_panic(expected = "already registered")]
    fn duplicate_procedure_name_panics() {
        let _ = Router::new()
            .procedure(make_descriptor("dup", ProcKindRuntime::Query, echo_handler()))
            .procedure(make_descriptor("dup", ProcKindRuntime::Query, echo_handler()));
    }

    #[test]
    fn ir_snapshot_contains_registered_procedures() {
        let router = Router::new()
            .procedure(make_descriptor("a", ProcKindRuntime::Query, echo_handler()))
            .procedure(make_descriptor("b", ProcKindRuntime::Mutation, echo_handler()));
        let ir = router.ir();
        assert_eq!(ir.ir_version, crate::ir::Ir::CURRENT_VERSION);
        let names: Vec<&str> = ir.procedures.iter().map(|p| p.name.as_str()).collect();
        assert_eq!(names, vec!["a", "b"]);
    }

    // ---------------------------------------------------------------------
    // Layers
    // ---------------------------------------------------------------------

    #[tokio::test]
    async fn router_with_no_layers_builds_unchanged() {
        let res = send(Router::new().into_axum(), get_request("/rpc/_health")).await;
        assert_eq!(res.status, StatusCode::OK);
        assert!(res.marker().is_none());
    }

    #[tokio::test]
    async fn router_with_a_simple_layer_applies_it() {
        let router = Router::new().procedure(make_descriptor(
            "ping",
            ProcKindRuntime::Query,
            echo_handler(),
        ));
        let app = with_marker_layer(router, "hit").into_axum();
        let res = send(app, post_json_request("/rpc/ping", r#"{"input":null}"#)).await;
        assert_eq!(res.status, StatusCode::OK);
        assert_eq!(res.marker(), Some("hit"));
    }

    #[tokio::test]
    async fn multiple_layers_compose_in_outer_first_order() {
        let router = Router::new().procedure(make_descriptor(
            "ping",
            ProcKindRuntime::Query,
            echo_handler(),
        ));
        let router = with_marker_layer(router, "inner");
        let router = with_marker_layer(router, "outer");
        let res = send(
            router.into_axum(),
            post_json_request("/rpc/ping", r#"{"input":null}"#),
        )
        .await;
        assert_eq!(res.status, StatusCode::OK);
        // The outer layer touches the response last, so its marker wins.
        assert_eq!(res.marker(), Some("outer"));
    }

    // ---------------------------------------------------------------------
    // Subscriptions (SSE)
    // ---------------------------------------------------------------------

    #[tokio::test]
    async fn subscription_route_emits_three_data_frames_then_end() {
        let app = Router::new()
            .procedure(make_stream_descriptor("ticks", counting_stream_handler(3)))
            .into_axum();
        let res = send(app, get_request("/rpc/ticks?input=null")).await;
        assert_eq!(res.status, StatusCode::OK);
        let ct = res.content_type();
        assert!(
            ct.starts_with("text/event-stream"),
            "expected text/event-stream content-type, got {ct:?}"
        );
        let body = res.text();
        for i in 0..3 {
            let needle = format!("event: data\ndata: {i}\n\n");
            assert!(
                body.contains(&needle),
                "missing data frame {i}; body was:\n{body}"
            );
        }
        assert!(
            body.contains("event: end"),
            "missing terminating event:end frame; body was:\n{body}"
        );
        let last_data_idx = body.rfind("event: data").expect("at least one data frame");
        let end_idx = body.find("event: end").expect("end frame present");
        assert!(
            end_idx > last_data_idx,
            "event:end must follow the last data frame; body was:\n{body}"
        );
    }

    #[tokio::test]
    async fn subscription_decode_error_emits_error_frame_then_end() {
        let app = Router::new()
            .procedure(make_stream_descriptor("ticks", counting_stream_handler(0)))
            .into_axum();
        let res = send(app, get_request("/rpc/ticks?input=not-json")).await;
        assert_eq!(res.status, StatusCode::OK);
        let body = res.text();
        assert!(
            body.contains("event: error"),
            "missing event:error frame; body was:\n{body}"
        );
        let error_section = body
            .split("event: error\n")
            .nth(1)
            .expect("error event present");
        let data_line = error_section
            .lines()
            .find(|l| l.starts_with("data: "))
            .expect("data line under error event");
        let json_str = data_line.strip_prefix("data: ").unwrap();
        let v: serde_json::Value =
            serde_json::from_str(json_str).expect("valid json in error frame");
        assert_eq!(v["code"], serde_json::json!("decode_error"));
        assert!(
            v["payload"]["message"].is_string(),
            "payload.message should be a string; got {v}"
        );
        let error_idx = body.find("event: error").unwrap();
        let end_idx = body
            .find("event: end")
            .expect("missing event:end after event:error");
        assert!(
            end_idx > error_idx,
            "event:end must follow event:error; body was:\n{body}"
        );
    }

    #[tokio::test]
    async fn subscription_with_no_input_param_decodes_as_null() {
        let app = Router::new()
            .procedure(make_stream_descriptor("ticks", counting_stream_handler(1)))
            .into_axum();
        let res = send(app, get_request("/rpc/ticks")).await;
        assert_eq!(res.status, StatusCode::OK);
        let body = res.text();
        assert!(
            body.contains("event: data\ndata: 0\n\n"),
            "expected single data frame for null input; body was:\n{body}"
        );
        assert!(
            body.contains("event: end"),
            "missing event:end; body was:\n{body}"
        );
    }
}