graphql_starter/axum/
extract.rs

1//! Wrappers over axum's [extract](https://docs.rs/axum/latest/axum/extract/index.html), providing custom error responses.
2//!
3//! It avoids having to use [WithRejection](https://docs.rs/axum-extra/latest/axum_extra/extract/struct.WithRejection.html)
4//! every time
5
6use std::{convert::Infallible, sync::Arc};
7
8use axum::{
9    extract::{FromRequest, FromRequestParts, OptionalFromRequest, OptionalFromRequestParts, Request},
10    response::{IntoResponse, Response},
11};
12use bytes::{BufMut, BytesMut};
13use error_info::ErrorInfo;
14use http::{header, request::Parts, HeaderValue, StatusCode};
15use serde::{de::DeserializeOwned, Serialize};
16
17use crate::error::{ApiError, GenericErrorCode, MapToErr};
18
/// JSON extractor / response wrapper mirroring [axum::Json].
///
/// The wrapped value is exposed through the public tuple field. The trait implementations for this type in
/// this module report failures as [ApiError] instead of axum's default rejections.
#[must_use]
#[derive(Clone, Copy, Debug, Default)]
pub struct Json<T>(pub T);
23
24impl<S, T> FromRequest<S> for Json<T>
25where
26    T: DeserializeOwned,
27    S: Send + Sync,
28{
29    type Rejection = Box<ApiError>;
30
31    async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
32        <::axum::Json<T> as FromRequest<S>>::from_request(req, state)
33            .await
34            .map(|::axum::Json(value)| Json(value))
35            .map_err(|err| {
36                tracing::info!("Couldn't parse json request: {err}");
37                ApiError::new(err.status(), err.body_text())
38            })
39    }
40}
41
42impl<T, S> OptionalFromRequest<S> for Json<T>
43where
44    T: DeserializeOwned,
45    S: Send + Sync,
46{
47    type Rejection = Box<ApiError>;
48
49    async fn from_request(req: Request, state: &S) -> Result<Option<Self>, Self::Rejection> {
50        <::axum::Json<T> as OptionalFromRequest<S>>::from_request(req, state)
51            .await
52            .map(|v| v.map(|::axum::Json(value)| Json(value)))
53            .map_err(|err| {
54                tracing::info!("Couldn't parse json request: {err}");
55                ApiError::new(err.status(), err.body_text())
56            })
57    }
58}
59
60impl<T> IntoResponse for Json<T>
61where
62    T: Serialize,
63{
64    fn into_response(self) -> Response {
65        // Mimic ::axum::Json::into_response with custom error
66        let mut buf = BytesMut::with_capacity(128).writer();
67        match serde_json::to_writer(&mut buf, &self.0).map_to_internal_err("Error serializing response") {
68            Ok(()) => (
69                [(
70                    header::CONTENT_TYPE,
71                    HeaderValue::from_static(mime::APPLICATION_JSON.as_ref()),
72                )],
73                buf.into_inner().freeze(),
74            )
75                .into_response(),
76            Err(err) => ApiError::from_err(err).into_response(),
77        }
78    }
79}
80
/// Query-string extractor mirroring [axum::extract::Query].
///
/// The deserialized value is exposed through the public tuple field; extraction failures are reported as
/// [ApiError] by the extractor implementation in this module.
#[derive(Clone, Copy, Debug, Default)]
pub struct Query<T>(pub T);
84
85impl<T, S> FromRequestParts<S> for Query<T>
86where
87    T: DeserializeOwned,
88    S: Send + Sync,
89{
90    type Rejection = Box<ApiError>;
91
92    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
93        ::axum::extract::Query::<T>::from_request_parts(parts, state)
94            .await
95            .map(|::axum::extract::Query(value)| Query(value))
96            .map_err(|err| {
97                tracing::warn!(
98                    "[{} {}] Couldn't parse request query: {err}",
99                    err.status().as_str(),
100                    err.status().canonical_reason().unwrap_or("Unknown")
101                );
102                ApiError::new(err.status(), err.body_text())
103            })
104    }
105}
106
/// Path-parameter extractor mirroring [axum::extract::Path].
///
/// The deserialized value is exposed through the public tuple field; extraction failures are reported as
/// [ApiError] by the extractor implementations in this module.
#[derive(Clone, Copy, Debug, Default)]
pub struct Path<T>(pub T);
110
111impl<T, S> FromRequestParts<S> for Path<T>
112where
113    T: DeserializeOwned + Send,
114    S: Send + Sync,
115{
116    type Rejection = Box<ApiError>;
117
118    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
119        <::axum::extract::Path<T> as FromRequestParts<S>>::from_request_parts(parts, state)
120            .await
121            .map(|::axum::extract::Path(value)| Path(value))
122            .map_err(|err| {
123                tracing::error!(
124                    "[{} {}] Couldn't extract request path: {err}",
125                    err.status().as_str(),
126                    err.status().canonical_reason().unwrap_or("Unknown")
127                );
128                ApiError::new(err.status(), err.body_text())
129            })
130    }
131}
132
133impl<T, S> OptionalFromRequestParts<S> for Path<T>
134where
135    T: DeserializeOwned + Send + 'static,
136    S: Send + Sync,
137{
138    type Rejection = Box<ApiError>;
139
140    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Option<Self>, Self::Rejection> {
141        <::axum::extract::Path<T> as OptionalFromRequestParts<S>>::from_request_parts(parts, state)
142            .await
143            .map(|v| v.map(|::axum::extract::Path(value)| Path(value)))
144            .map_err(|err| {
145                tracing::error!(
146                    "[{} {}] Couldn't extract request path: {err}",
147                    err.status().as_str(),
148                    err.status().canonical_reason().unwrap_or("Unknown")
149                );
150                ApiError::new(err.status(), err.body_text())
151            })
152    }
153}
154
/// Request-extension extractor mirroring [axum::Extension].
///
/// The extracted value is exposed through the public tuple field; extraction failures are reported as
/// [ApiError] by the extractor implementations in this module.
#[derive(Clone, Copy, Debug, Default)]
pub struct Extension<T>(pub T);
158
159impl<T, S> FromRequestParts<S> for Extension<T>
160where
161    T: Clone + Send + Sync + 'static,
162    S: Send + Sync,
163{
164    type Rejection = Box<ApiError>;
165
166    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
167        <::axum::Extension<T> as FromRequestParts<S>>::from_request_parts(parts, state)
168            .await
169            .map(|::axum::Extension(value)| Extension(value))
170            .map_err(|err| {
171                tracing::error!("[500 Internal Server Error] Couldn't extract extension: {err}");
172                ApiError::new(err.status(), GenericErrorCode::InternalServerError.raw_message())
173            })
174    }
175}
176
177impl<T, S> OptionalFromRequestParts<S> for Extension<T>
178where
179    T: Clone + Send + Sync + 'static,
180    S: Send + Sync,
181{
182    type Rejection = Box<ApiError>;
183
184    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Option<Self>, Self::Rejection> {
185        <::axum::Extension<T> as OptionalFromRequestParts<S>>::from_request_parts(parts, state)
186            .await
187            .map(|v| v.map(|::axum::Extension(value)| Extension(value)))
188            .map_err(|_: Infallible| {
189                ApiError::new(
190                    StatusCode::INTERNAL_SERVER_ERROR,
191                    GenericErrorCode::InternalServerError.raw_message(),
192                )
193            })
194    }
195}
196
/// Extractor for the optional `Accept-Language` header.
///
/// Holds `None` when no usable language list is available, otherwise a shared list of language tags.
#[derive(Clone, Debug, Default)]
pub struct AcceptLanguage(pub Option<Arc<Vec<String>>>);

impl AcceptLanguage {
    /// Returns the stored list of accepted languages, if any.
    ///
    /// The list is returned in stored order, which is intended to be by descending quality.
    pub fn accepted_languages(&self) -> Option<&[String]> {
        let shared = self.0.as_ref()?;
        Some(shared.as_slice())
    }

    /// Returns the first accepted language, i.e. the one with the highest quality, if any.
    pub fn preferred_language(&self) -> Option<&str> {
        let languages = self.accepted_languages()?;
        languages.first().map(String::as_str)
    }
}
211
212impl<S> FromRequestParts<S> for AcceptLanguage
213where
214    S: Send + Sync,
215{
216    type Rejection = Box<ApiError>;
217
218    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
219        // Extract the header and parse it (if any)
220        let accept_language = parts
221            .headers
222            .get(http::header::ACCEPT_LANGUAGE)
223            .and_then(|v| v.to_str().ok().map(accept_language::parse).map(Arc::new))
224            .filter(|v| !v.is_empty());
225
226        Ok(AcceptLanguage(accept_language))
227    }
228}