api_tools/server/axum/
extractors.rs

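//! Custom Axum extractors: a request-ID header extractor, plus `Path` and `Query`
//! wrappers that report failures as `(StatusCode, ApiError)`.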
2
3use crate::server::axum::layers::request_id::REQUEST_ID_HEADER;
4use crate::server::axum::response::ApiError;
5use axum::extract::FromRequestParts;
6use axum::extract::path::ErrorKind;
7use axum::extract::rejection::PathRejection;
8use axum::http::request::Parts;
9use axum::http::{HeaderValue, StatusCode};
10use serde::de::DeserializeOwned;
11
12/// Request ID extractor from HTTP headers
13pub struct RequestId(pub HeaderValue);
14
15impl<S> FromRequestParts<S> for RequestId
16where
17    S: Send + Sync,
18{
19    type Rejection = ();
20
21    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
22        match parts.headers.get(REQUEST_ID_HEADER.clone()) {
23            Some(id) => Ok(RequestId(id.clone())),
24            _ => Ok(RequestId(HeaderValue::from_static(""))),
25        }
26    }
27}
28
29/// `Path` extractor customizes the error from `axum::extract::Path`
30pub struct Path<T>(pub T);
31
32impl<S, T> FromRequestParts<S> for Path<T>
33where
34    T: DeserializeOwned + Send,
35    S: Send + Sync,
36{
37    type Rejection = (StatusCode, ApiError);
38
39    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
40        match axum::extract::Path::<T>::from_request_parts(parts, state).await {
41            Ok(value) => Ok(Self(value.0)),
42            Err(rejection) => {
43                let (status, body) = match rejection {
44                    PathRejection::FailedToDeserializePathParams(inner) => {
45                        let mut status = StatusCode::BAD_REQUEST;
46
47                        let kind = inner.into_kind();
48                        let body = match &kind {
49                            ErrorKind::WrongNumberOfParameters { .. } => ApiError::BadRequest(kind.to_string()),
50                            ErrorKind::ParseErrorAtKey { .. } => ApiError::BadRequest(kind.to_string()),
51                            ErrorKind::ParseErrorAtIndex { .. } => ApiError::BadRequest(kind.to_string()),
52                            ErrorKind::ParseError { .. } => ApiError::BadRequest(kind.to_string()),
53                            ErrorKind::InvalidUtf8InPathParam { .. } => ApiError::BadRequest(kind.to_string()),
54                            ErrorKind::UnsupportedType { .. } => {
55                                // this error is caused by the programmer using an unsupported type
56                                // (such as nested maps) so respond with `500` instead
57                                status = StatusCode::INTERNAL_SERVER_ERROR;
58                                ApiError::InternalServerError(kind.to_string())
59                            }
60                            ErrorKind::Message(msg) => ApiError::BadRequest(msg.clone()),
61                            _ => ApiError::BadRequest(format!("Unhandled deserialization error: {kind}")),
62                        };
63
64                        (status, body)
65                    }
66                    PathRejection::MissingPathParams(error) => (
67                        StatusCode::INTERNAL_SERVER_ERROR,
68                        ApiError::InternalServerError(error.to_string()),
69                    ),
70                    _ => (
71                        StatusCode::INTERNAL_SERVER_ERROR,
72                        ApiError::InternalServerError(format!("Unhandled path rejection: {rejection}")),
73                    ),
74                };
75
76                Err((status, body))
77            }
78        }
79    }
80}
81
82/// `Query` extractor customizes the error from `axum::extract::Query`
83pub struct Query<T>(pub T);
84
85impl<T, S> FromRequestParts<S> for Query<T>
86where
87    T: DeserializeOwned,
88    S: Send + Sync,
89{
90    type Rejection = (StatusCode, ApiError);
91
92    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
93        let query = parts.uri.query().unwrap_or_default();
94        let value = serde_urlencoded::from_str(query)
95            .map_err(|err| (StatusCode::BAD_REQUEST, ApiError::BadRequest(err.to_string())))?;
96
97        Ok(Query(value))
98    }
99}