api_tools/server/axum/
extractors.rs1use crate::server::axum::layers::request_id::REQUEST_ID_HEADER;
4use crate::server::axum::response::ApiError;
5use axum::extract::FromRequestParts;
6use axum::extract::path::ErrorKind;
7use axum::extract::rejection::PathRejection;
8use axum::http::request::Parts;
9use axum::http::{HeaderValue, StatusCode};
10use serde::de::DeserializeOwned;
11
12pub struct RequestId(pub HeaderValue);
14
15impl<S> FromRequestParts<S> for RequestId
16where
17 S: Send + Sync,
18{
19 type Rejection = ();
20
21 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
22 match parts.headers.get(REQUEST_ID_HEADER.clone()) {
23 Some(id) => Ok(RequestId(id.clone())),
24 _ => Ok(RequestId(HeaderValue::from_static(""))),
25 }
26 }
27}
28
29pub struct Path<T>(pub T);
31
32impl<S, T> FromRequestParts<S> for Path<T>
33where
34 T: DeserializeOwned + Send,
35 S: Send + Sync,
36{
37 type Rejection = (StatusCode, ApiError);
38
39 async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
40 match axum::extract::Path::<T>::from_request_parts(parts, state).await {
41 Ok(value) => Ok(Self(value.0)),
42 Err(rejection) => {
43 let (status, body) = match rejection {
44 PathRejection::FailedToDeserializePathParams(inner) => {
45 let mut status = StatusCode::BAD_REQUEST;
46
47 let kind = inner.into_kind();
48 let body = match &kind {
49 ErrorKind::WrongNumberOfParameters { .. } => ApiError::BadRequest(kind.to_string()),
50 ErrorKind::ParseErrorAtKey { .. } => ApiError::BadRequest(kind.to_string()),
51 ErrorKind::ParseErrorAtIndex { .. } => ApiError::BadRequest(kind.to_string()),
52 ErrorKind::ParseError { .. } => ApiError::BadRequest(kind.to_string()),
53 ErrorKind::InvalidUtf8InPathParam { .. } => ApiError::BadRequest(kind.to_string()),
54 ErrorKind::UnsupportedType { .. } => {
55 status = StatusCode::INTERNAL_SERVER_ERROR;
58 ApiError::InternalServerError(kind.to_string())
59 }
60 ErrorKind::Message(msg) => ApiError::BadRequest(msg.clone()),
61 _ => ApiError::BadRequest(format!("Unhandled deserialization error: {kind}")),
62 };
63
64 (status, body)
65 }
66 PathRejection::MissingPathParams(error) => (
67 StatusCode::INTERNAL_SERVER_ERROR,
68 ApiError::InternalServerError(error.to_string()),
69 ),
70 _ => (
71 StatusCode::INTERNAL_SERVER_ERROR,
72 ApiError::InternalServerError(format!("Unhandled path rejection: {rejection}")),
73 ),
74 };
75
76 Err((status, body))
77 }
78 }
79 }
80}
81
82pub struct Query<T>(pub T);
84
85impl<T, S> FromRequestParts<S> for Query<T>
86where
87 T: DeserializeOwned,
88 S: Send + Sync,
89{
90 type Rejection = (StatusCode, ApiError);
91
92 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
93 let query = parts.uri.query().unwrap_or_default();
94 let value = serde_urlencoded::from_str(query)
95 .map_err(|err| (StatusCode::BAD_REQUEST, ApiError::BadRequest(err.to_string())))?;
96
97 Ok(Query(value))
98 }
99}