1use std::{convert::Infallible, sync::Arc};
7
8use axum::{
9 extract::{FromRequest, FromRequestParts, OptionalFromRequest, OptionalFromRequestParts, Request},
10 response::{IntoResponse, Response},
11};
12use bytes::{BufMut, BytesMut};
13use error_info::ErrorInfo;
14use http::{header, request::Parts, HeaderValue, StatusCode};
15use serde::{de::DeserializeOwned, Serialize};
16
17use crate::error::{ApiError, GenericErrorCode, MapToErr};
18
/// JSON request-body extractor and JSON response wrapper.
///
/// As an extractor it delegates to [`axum::Json`] and turns any rejection into an
/// [`ApiError`]; as a response it serializes the wrapped value with `serde_json`.
#[must_use]
#[derive(Clone, Copy, Debug, Default)]
pub struct Json<T>(pub T);
23
24impl<S, T> FromRequest<S> for Json<T>
25where
26 T: DeserializeOwned,
27 S: Send + Sync,
28{
29 type Rejection = Box<ApiError>;
30
31 async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
32 <::axum::Json<T> as FromRequest<S>>::from_request(req, state)
33 .await
34 .map(|::axum::Json(value)| Json(value))
35 .map_err(|err| {
36 tracing::info!("Couldn't parse json request: {err}");
37 ApiError::new(err.status(), err.body_text())
38 })
39 }
40}
41
42impl<T, S> OptionalFromRequest<S> for Json<T>
43where
44 T: DeserializeOwned,
45 S: Send + Sync,
46{
47 type Rejection = Box<ApiError>;
48
49 async fn from_request(req: Request, state: &S) -> Result<Option<Self>, Self::Rejection> {
50 <::axum::Json<T> as OptionalFromRequest<S>>::from_request(req, state)
51 .await
52 .map(|v| v.map(|::axum::Json(value)| Json(value)))
53 .map_err(|err| {
54 tracing::info!("Couldn't parse json request: {err}");
55 ApiError::new(err.status(), err.body_text())
56 })
57 }
58}
59
60impl<T> IntoResponse for Json<T>
61where
62 T: Serialize,
63{
64 fn into_response(self) -> Response {
65 let mut buf = BytesMut::with_capacity(128).writer();
67 match serde_json::to_writer(&mut buf, &self.0).map_to_internal_err("Error serializing response") {
68 Ok(()) => (
69 [(
70 header::CONTENT_TYPE,
71 HeaderValue::from_static(mime::APPLICATION_JSON.as_ref()),
72 )],
73 buf.into_inner().freeze(),
74 )
75 .into_response(),
76 Err(err) => ApiError::from_err(err).into_response(),
77 }
78 }
79}
80
/// Query-string extractor whose rejections are reported as [`ApiError`]s.
///
/// Wraps the value of type `T` deserialized from the request's query string.
#[derive(Default, Debug, Copy, Clone)]
pub struct Query<T>(pub T);
84
85impl<T, S> FromRequestParts<S> for Query<T>
86where
87 T: DeserializeOwned,
88 S: Send + Sync,
89{
90 type Rejection = Box<ApiError>;
91
92 async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
93 ::axum::extract::Query::<T>::from_request_parts(parts, state)
94 .await
95 .map(|::axum::extract::Query(value)| Query(value))
96 .map_err(|err| {
97 tracing::warn!(
98 "[{} {}] Couldn't parse request query: {err}",
99 err.status().as_str(),
100 err.status().canonical_reason().unwrap_or("Unknown")
101 );
102 ApiError::new(err.status(), err.body_text())
103 })
104 }
105}
106
/// URL path-parameter extractor whose rejections are reported as [`ApiError`]s.
///
/// Wraps the value of type `T` deserialized from the matched route's path params.
#[derive(Default, Debug, Copy, Clone)]
pub struct Path<T>(pub T);
110
111impl<T, S> FromRequestParts<S> for Path<T>
112where
113 T: DeserializeOwned + Send,
114 S: Send + Sync,
115{
116 type Rejection = Box<ApiError>;
117
118 async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
119 <::axum::extract::Path<T> as FromRequestParts<S>>::from_request_parts(parts, state)
120 .await
121 .map(|::axum::extract::Path(value)| Path(value))
122 .map_err(|err| {
123 tracing::error!(
124 "[{} {}] Couldn't extract request path: {err}",
125 err.status().as_str(),
126 err.status().canonical_reason().unwrap_or("Unknown")
127 );
128 ApiError::new(err.status(), err.body_text())
129 })
130 }
131}
132
133impl<T, S> OptionalFromRequestParts<S> for Path<T>
134where
135 T: DeserializeOwned + Send + 'static,
136 S: Send + Sync,
137{
138 type Rejection = Box<ApiError>;
139
140 async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Option<Self>, Self::Rejection> {
141 <::axum::extract::Path<T> as OptionalFromRequestParts<S>>::from_request_parts(parts, state)
142 .await
143 .map(|v| v.map(|::axum::extract::Path(value)| Path(value)))
144 .map_err(|err| {
145 tracing::error!(
146 "[{} {}] Couldn't extract request path: {err}",
147 err.status().as_str(),
148 err.status().canonical_reason().unwrap_or("Unknown")
149 );
150 ApiError::new(err.status(), err.body_text())
151 })
152 }
153}
154
/// Request-extension extractor whose rejections are reported as [`ApiError`]s.
///
/// Wraps the extension value of type `T` taken from the request.
#[derive(Default, Debug, Copy, Clone)]
pub struct Extension<T>(pub T);
158
159impl<T, S> FromRequestParts<S> for Extension<T>
160where
161 T: Clone + Send + Sync + 'static,
162 S: Send + Sync,
163{
164 type Rejection = Box<ApiError>;
165
166 async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
167 <::axum::Extension<T> as FromRequestParts<S>>::from_request_parts(parts, state)
168 .await
169 .map(|::axum::Extension(value)| Extension(value))
170 .map_err(|err| {
171 tracing::error!("[500 Internal Server Error] Couldn't extract extension: {err}");
172 ApiError::new(err.status(), GenericErrorCode::InternalServerError.raw_message())
173 })
174 }
175}
176
177impl<T, S> OptionalFromRequestParts<S> for Extension<T>
178where
179 T: Clone + Send + Sync + 'static,
180 S: Send + Sync,
181{
182 type Rejection = Box<ApiError>;
183
184 async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Option<Self>, Self::Rejection> {
185 <::axum::Extension<T> as OptionalFromRequestParts<S>>::from_request_parts(parts, state)
186 .await
187 .map(|v| v.map(|::axum::Extension(value)| Extension(value)))
188 .map_err(|_: Infallible| {
189 ApiError::new(
190 StatusCode::INTERNAL_SERVER_ERROR,
191 GenericErrorCode::InternalServerError.raw_message(),
192 )
193 })
194 }
195}
196
/// Languages taken from the request's `Accept-Language` header.
///
/// The list is the output of `accept_language::parse`. The extractor stores `None`
/// when the header is missing, unreadable, or parses to an empty list. The list sits
/// behind an [`Arc`], so cloning this value only bumps a reference count.
#[derive(Debug, Clone, Default)]
pub struct AcceptLanguage(pub Option<Arc<Vec<String>>>);
impl AcceptLanguage {
    /// Returns all accepted languages, or `None` if no list is stored.
    pub fn accepted_languages(&self) -> Option<&[String]> {
        let languages = self.0.as_ref()?;
        Some(languages.as_slice())
    }

    /// Returns the first accepted language, or `None` if there is none.
    pub fn preferred_language(&self) -> Option<&str> {
        self.accepted_languages()?.first().map(String::as_str)
    }
}
211
212impl<S> FromRequestParts<S> for AcceptLanguage
213where
214 S: Send + Sync,
215{
216 type Rejection = Box<ApiError>;
217
218 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
219 let accept_language = parts
221 .headers
222 .get(http::header::ACCEPT_LANGUAGE)
223 .and_then(|v| v.to_str().ok().map(accept_language::parse).map(Arc::new))
224 .filter(|v| !v.is_empty());
225
226 Ok(AcceptLanguage(accept_language))
227 }
228}