jacquard_axum/
lib.rs

1//! # Axum helpers for jacquard XRPC server implementations
2//!
3//! ## Usage
4//!
5//! ```no_run
6//! use axum::{Router, routing::get, http::StatusCode, response::IntoResponse,  Json};
7//! use jacquard_axum::{ ExtractXrpc, IntoRouter };
8//! use std::collections::BTreeMap;
9//! use miette::{IntoDiagnostic, Result};
10//! use jacquard::api::com_atproto::identity::resolve_handle::{ResolveHandle, ResolveHandleRequest, ResolveHandleOutput};
11//! use jacquard_common::types::string::Did;
12//!
13//! async fn handle_resolve(
14//!     ExtractXrpc(req): ExtractXrpc<ResolveHandleRequest>
15//! ) -> Result<Json<ResolveHandleOutput<'static>>, StatusCode> {
16//!     // req is ResolveHandle<'static>, ready to use
17//!     let handle = req.handle;
18//!     // ... resolve logic
19//! #   let output = ResolveHandleOutput { did: Did::new_static("did:plc:test").unwrap(), extra_data: None  };
20//!     Ok(Json(output))
21//! }
22//!
23//! #[tokio::main]
24//! async fn main() -> Result<()> {
25//!     let app = Router::new()
26//!          .route("/", axum::routing::get(|| async { "hello world!" }))
27//!          .merge(ResolveHandleRequest::into_router(handle_resolve));
28//!
29//!     let listener = tokio::net::TcpListener::bind("0.0.0.0:3000")
30//!         .await
31//!         .into_diagnostic()?;
//!     axum::serve(listener, app).await.into_diagnostic()?;
33//!     Ok(())
34//! }
35//! ```
36//!
37//!
//! The extractor uses the [`XrpcEndpoint`] trait to determine the request type:
39//! - **Query**: Deserializes from query string parameters
40//! - **Procedure**: Deserializes from request body (supports custom encodings via `decode_body`)
41//!
42//! Deserialization errors return a 400 Bad Request with a JSON error body matching
43//! the XRPC error format.
44//!
//! The extractor deserializes to borrowed types first, then converts them to `'static` via
//! [`IntoStatic`]. This avoids the `DeserializeOwned` bound required by axum's `Json` extractor and similar extractors.
47
48pub mod did_web;
49#[cfg(feature = "service-auth")]
50pub mod service_auth;
51
52use axum::{
53    Json, Router,
54    body::Bytes,
55    extract::{FromRequest, Request},
56    http::StatusCode,
57    response::{IntoResponse, Response},
58};
59use jacquard::{
60    IntoStatic,
61    xrpc::{XrpcEndpoint, XrpcError, XrpcMethod, XrpcRequest},
62};
63use serde_json::json;
64
65/// Axum extractor for XRPC requests
66///
67/// Deserializes incoming requests based on the endpoint's method type (Query or Procedure)
68/// and returns the owned (`'static`) request type ready for handler logic.
69pub struct ExtractXrpc<E: XrpcEndpoint>(pub E::Request<'static>);
70
71impl<S, R> FromRequest<S> for ExtractXrpc<R>
72where
73    S: Send + Sync,
74    R: XrpcEndpoint,
75    for<'a> R::Request<'a>: IntoStatic<Output = R::Request<'static>>,
76{
77    type Rejection = Response;
78
79    fn from_request(
80        req: Request,
81        state: &S,
82    ) -> impl Future<Output = Result<Self, Self::Rejection>> + Send {
83        async {
84            match R::METHOD {
85                XrpcMethod::Procedure(_) => {
86                    let body = Bytes::from_request(req, state)
87                        .await
88                        .map_err(IntoResponse::into_response)?;
89                    let decoded = R::Request::decode_body(&body);
90                    match decoded {
91                        Ok(value) => Ok(ExtractXrpc(*value.into_static())),
92                        Err(err) => Err((
93                            StatusCode::BAD_REQUEST,
94                            Json(json!({
95                                "error": "InvalidRequest",
96                                "message": format!("failed to decode request: {}", err)
97                            })),
98                        )
99                            .into_response()),
100                    }
101                }
102                XrpcMethod::Query => {
103                    if let Some(path_query) = req.uri().path_and_query() {
104                        let query = path_query.query().unwrap_or("");
105                        let value: R::Request<'_> =
106                            serde_html_form::from_str::<R::Request<'_>>(query).map_err(|e| {
107                                (
108                                    StatusCode::BAD_REQUEST,
109                                    Json(json!({
110                                        "error": "InvalidRequest",
111                                        "message": format!("failed to decode request: {}", e)
112                                    })),
113                                )
114                                    .into_response()
115                            })?;
116                        Ok(ExtractXrpc(value.into_static()))
117                    } else {
118                        Err((
119                            StatusCode::BAD_REQUEST,
120                            Json(json!({
121                                "error": "InvalidRequest",
122                                "message": "wrong path"
123                            })),
124                        )
125                            .into_response())
126                    }
127                }
128            }
129        }
130    }
131}
132
133/// Conversion trait to turn an XrpcEndpoint and a handler into an axum Router
134pub trait IntoRouter {
135    fn into_router<T, S, U>(handler: U) -> Router<S>
136    where
137        T: 'static,
138        S: Clone + Send + Sync + 'static,
139        U: axum::handler::Handler<T, S>;
140}
141
142impl<X> IntoRouter for X
143where
144    X: XrpcEndpoint,
145{
146    /// Creates an axum router that will invoke `handler` in response to xrpc
147    /// request `X`.
148    fn into_router<T, S, U>(handler: U) -> Router<S>
149    where
150        T: 'static,
151        S: Clone + Send + Sync + 'static,
152        U: axum::handler::Handler<T, S>,
153    {
154        Router::new().route(
155            X::PATH,
156            (match X::METHOD {
157                XrpcMethod::Query => axum::routing::get,
158                XrpcMethod::Procedure(_) => axum::routing::post,
159            })(handler),
160        )
161    }
162}
163
164/// Axum-compatible Xrpc error wrapper
165///
166/// Implements IntoResponse, and does some mildly opinionated mapping.
167///
168/// Currently assumes that the internal xrpc errors are well-formed and
169/// compatible with [the spec](https://atproto.com/specs/xrpc#error-responses).
170#[derive(Debug, thiserror::Error, miette::Diagnostic)]
171#[error("Xrpc error: {error}")]
172pub struct XrpcErrorResponse<E>
173where
174    E: std::error::Error + IntoStatic,
175{
176    pub status: StatusCode,
177    #[diagnostic_source]
178    pub error: XrpcError<E>,
179}
180
181impl<E> XrpcErrorResponse<E>
182where
183    E: std::error::Error + IntoStatic + serde::Serialize,
184{
185    /// Creates a new XrpcErrorResponse from the given status code and error.
186    pub fn new(status: StatusCode, error: XrpcError<E>) -> Self {
187        Self { status, error }
188    }
189
190    /// Changes the status code of the error response.
191    pub fn with_status(self, status: StatusCode) -> Self {
192        Self {
193            status,
194            error: self.error,
195        }
196    }
197}
198
199impl<E> IntoResponse for XrpcErrorResponse<E>
200where
201    E: std::error::Error + IntoStatic + serde::Serialize,
202{
203    fn into_response(self) -> Response {
204        let (status, json) = match self.error {
205            XrpcError::Xrpc(error) => (
206                self.status,
207                serde_json::to_value(&error).unwrap_or(json!({
208                    "error": "InternalError",
209                    "message": format!("{error}")
210                })),
211            ),
212            XrpcError::Auth(auth_error) => (
213                self.status,
214                json!({
215                    "error": "Authentication",
216                    "message": format!("{auth_error}")
217                }),
218            ),
219            XrpcError::Generic(generic) => (
220                self.status,
221                serde_json::to_value(&generic).unwrap_or(json!({
222                    "error": "InternalError",
223                    "message": format!("{generic}", )
224                })),
225            ),
226            XrpcError::Decode(error) => (
227                self.status,
228                json!({
229                    "error": "InvalidRequest",
230                    "message": format!("failed to decode request: {error}", )
231                }),
232            ),
233        };
234        (status, Json(json)).into_response()
235    }
236}
237
238impl<E> From<XrpcError<E>> for XrpcErrorResponse<E>
239where
240    E: std::error::Error + IntoStatic,
241{
242    fn from(value: XrpcError<E>) -> Self {
243        Self {
244            status: StatusCode::INTERNAL_SERVER_ERROR,
245            error: value,
246        }
247    }
248}
249
250impl<E> From<XrpcErrorResponse<E>> for XrpcError<E>
251where
252    E: std::error::Error + IntoStatic,
253{
254    fn from(value: XrpcErrorResponse<E>) -> Self {
255        value.error
256    }
257}