jacquard_axum/lib.rs
1//! # Axum helpers for jacquard XRPC server implementations
2//!
3//! ## Usage
4//!
5//! ```no_run
6//! use axum::{Router, routing::get, http::StatusCode, response::IntoResponse, Json};
7//! use jacquard_axum::{ ExtractXrpc, IntoRouter };
8//! use std::collections::BTreeMap;
9//! use miette::{IntoDiagnostic, Result};
10//! use jacquard::api::com_atproto::identity::resolve_handle::{ResolveHandle, ResolveHandleRequest, ResolveHandleOutput};
11//! use jacquard_common::types::string::Did;
12//!
13//! async fn handle_resolve(
14//! ExtractXrpc(req): ExtractXrpc<ResolveHandleRequest>
15//! ) -> Result<Json<ResolveHandleOutput<'static>>, StatusCode> {
16//! // req is ResolveHandle<'static>, ready to use
17//! let handle = req.handle;
18//! // ... resolve logic
19//! # let output = ResolveHandleOutput { did: Did::new_static("did:plc:test").unwrap(), extra_data: None };
20//! Ok(Json(output))
21//! }
22//!
23//! #[tokio::main]
24//! async fn main() -> Result<()> {
25//! let app = Router::new()
26//! .route("/", axum::routing::get(|| async { "hello world!" }))
27//! .merge(ResolveHandleRequest::into_router(handle_resolve));
28//!
29//! let listener = tokio::net::TcpListener::bind("0.0.0.0:3000")
30//! .await
31//! .into_diagnostic()?;
32//! axum::serve(listener, app).await.unwrap();
33//! Ok(())
34//! }
35//! ```
36//!
//! The extractor uses the [`XrpcEndpoint`] trait to decide how to decode the request:
//! - **Query**: deserializes from the URI query string parameters
//! - **Procedure**: deserializes from the request body (custom encodings are supported via `decode_body`)
//!
//! Deserialization errors return a 400 Bad Request with a JSON error body matching
//! the XRPC error format.
//!
//! The extractor deserializes to borrowed types first, then converts them to `'static` via
//! [`IntoStatic`]. This avoids the `DeserializeOwned` requirement imposed by axum's `Json`
//! extractor and similar extractors.
47
48pub mod did_web;
49#[cfg(feature = "service-auth")]
50pub mod service_auth;
51
52use axum::{
53 Json, Router,
54 body::Bytes,
55 extract::{FromRequest, Request},
56 http::StatusCode,
57 response::{IntoResponse, Response},
58};
59use jacquard::{
60 IntoStatic,
61 xrpc::{XrpcEndpoint, XrpcError, XrpcMethod, XrpcRequest},
62};
63use serde_json::json;
64
65/// Axum extractor for XRPC requests
66///
67/// Deserializes incoming requests based on the endpoint's method type (Query or Procedure)
68/// and returns the owned (`'static`) request type ready for handler logic.
69pub struct ExtractXrpc<E: XrpcEndpoint>(pub E::Request<'static>);
70
71impl<S, R> FromRequest<S> for ExtractXrpc<R>
72where
73 S: Send + Sync,
74 R: XrpcEndpoint,
75 for<'a> R::Request<'a>: IntoStatic<Output = R::Request<'static>>,
76{
77 type Rejection = Response;
78
79 fn from_request(
80 req: Request,
81 state: &S,
82 ) -> impl Future<Output = Result<Self, Self::Rejection>> + Send {
83 async {
84 match R::METHOD {
85 XrpcMethod::Procedure(_) => {
86 let body = Bytes::from_request(req, state)
87 .await
88 .map_err(IntoResponse::into_response)?;
89 let decoded = R::Request::decode_body(&body);
90 match decoded {
91 Ok(value) => Ok(ExtractXrpc(*value.into_static())),
92 Err(err) => Err((
93 StatusCode::BAD_REQUEST,
94 Json(json!({
95 "error": "InvalidRequest",
96 "message": format!("failed to decode request: {}", err)
97 })),
98 )
99 .into_response()),
100 }
101 }
102 XrpcMethod::Query => {
103 if let Some(path_query) = req.uri().path_and_query() {
104 let query = path_query.query().unwrap_or("");
105 let value: R::Request<'_> =
106 serde_html_form::from_str::<R::Request<'_>>(query).map_err(|e| {
107 (
108 StatusCode::BAD_REQUEST,
109 Json(json!({
110 "error": "InvalidRequest",
111 "message": format!("failed to decode request: {}", e)
112 })),
113 )
114 .into_response()
115 })?;
116 Ok(ExtractXrpc(value.into_static()))
117 } else {
118 Err((
119 StatusCode::BAD_REQUEST,
120 Json(json!({
121 "error": "InvalidRequest",
122 "message": "wrong path"
123 })),
124 )
125 .into_response())
126 }
127 }
128 }
129 }
130 }
131}
132
133/// Conversion trait to turn an XrpcEndpoint and a handler into an axum Router
134pub trait IntoRouter {
135 fn into_router<T, S, U>(handler: U) -> Router<S>
136 where
137 T: 'static,
138 S: Clone + Send + Sync + 'static,
139 U: axum::handler::Handler<T, S>;
140}
141
142impl<X> IntoRouter for X
143where
144 X: XrpcEndpoint,
145{
146 /// Creates an axum router that will invoke `handler` in response to xrpc
147 /// request `X`.
148 fn into_router<T, S, U>(handler: U) -> Router<S>
149 where
150 T: 'static,
151 S: Clone + Send + Sync + 'static,
152 U: axum::handler::Handler<T, S>,
153 {
154 Router::new().route(
155 X::PATH,
156 (match X::METHOD {
157 XrpcMethod::Query => axum::routing::get,
158 XrpcMethod::Procedure(_) => axum::routing::post,
159 })(handler),
160 )
161 }
162}
163
164/// Axum-compatible Xrpc error wrapper
165///
166/// Implements IntoResponse, and does some mildly opinionated mapping.
167///
168/// Currently assumes that the internal xrpc errors are well-formed and
169/// compatible with [the spec](https://atproto.com/specs/xrpc#error-responses).
170#[derive(Debug, thiserror::Error, miette::Diagnostic)]
171#[error("Xrpc error: {error}")]
172pub struct XrpcErrorResponse<E>
173where
174 E: std::error::Error + IntoStatic,
175{
176 pub status: StatusCode,
177 #[diagnostic_source]
178 pub error: XrpcError<E>,
179}
180
181impl<E> XrpcErrorResponse<E>
182where
183 E: std::error::Error + IntoStatic + serde::Serialize,
184{
185 /// Creates a new XrpcErrorResponse from the given status code and error.
186 pub fn new(status: StatusCode, error: XrpcError<E>) -> Self {
187 Self { status, error }
188 }
189
190 /// Changes the status code of the error response.
191 pub fn with_status(self, status: StatusCode) -> Self {
192 Self {
193 status,
194 error: self.error,
195 }
196 }
197}
198
199impl<E> IntoResponse for XrpcErrorResponse<E>
200where
201 E: std::error::Error + IntoStatic + serde::Serialize,
202{
203 fn into_response(self) -> Response {
204 let (status, json) = match self.error {
205 XrpcError::Xrpc(error) => (
206 self.status,
207 serde_json::to_value(&error).unwrap_or(json!({
208 "error": "InternalError",
209 "message": format!("{error}")
210 })),
211 ),
212 XrpcError::Auth(auth_error) => (
213 self.status,
214 json!({
215 "error": "Authentication",
216 "message": format!("{auth_error}")
217 }),
218 ),
219 XrpcError::Generic(generic) => (
220 self.status,
221 serde_json::to_value(&generic).unwrap_or(json!({
222 "error": "InternalError",
223 "message": format!("{generic}", )
224 })),
225 ),
226 XrpcError::Decode(error) => (
227 self.status,
228 json!({
229 "error": "InvalidRequest",
230 "message": format!("failed to decode request: {error}", )
231 }),
232 ),
233 };
234 (status, Json(json)).into_response()
235 }
236}
237
238impl<E> From<XrpcError<E>> for XrpcErrorResponse<E>
239where
240 E: std::error::Error + IntoStatic,
241{
242 fn from(value: XrpcError<E>) -> Self {
243 Self {
244 status: StatusCode::INTERNAL_SERVER_ERROR,
245 error: value,
246 }
247 }
248}
249
250impl<E> From<XrpcErrorResponse<E>> for XrpcError<E>
251where
252 E: std::error::Error + IntoStatic,
253{
254 fn from(value: XrpcErrorResponse<E>) -> Self {
255 value.error
256 }
257}