modo/middleware/
error_handler.rs1use std::future::Future;
2use std::pin::Pin;
3use std::task::{Context, Poll};
4
5use axum::response::{IntoResponse, Response};
6use http::request::Parts;
7use tower::{Layer, Service};
8
9use crate::error::render_error_body;
10
11pub fn error_handler<F, Fut>(handler: F) -> ErrorHandlerLayer<F>
38where
39 F: Fn(crate::error::Error, Parts) -> Fut + Clone + Send + Sync + 'static,
40 Fut: Future<Output = Response> + Send + 'static,
41{
42 ErrorHandlerLayer { handler }
43}
44
45pub async fn default_error_handler(err: crate::error::Error, parts: Parts) -> Response {
84 let status = err.status();
85 let details = err.details().cloned();
86
87 let message = match (
88 err.locale_key(),
89 parts.extensions.get::<crate::i18n::Translator>(),
90 ) {
91 (Some(key), Some(tr)) => tr.t(key, &[]),
92 _ => err.message().to_string(),
93 };
94
95 let body = render_error_body(status, &message, details.as_ref());
96 (status, axum::Json(body)).into_response()
97}
98
/// Tower layer returned by [`error_handler`].
///
/// Each call to `Layer::layer` wraps the inner service in an
/// [`ErrorHandlerService`] holding its own clone of `handler`.
#[derive(Clone)]
pub struct ErrorHandlerLayer<F> {
    /// Callback that turns an error-carrying response into a replacement.
    handler: F,
}
104
105impl<S, F> Layer<S> for ErrorHandlerLayer<F>
106where
107 F: Clone,
108{
109 type Service = ErrorHandlerService<S, F>;
110
111 fn layer(&self, inner: S) -> Self::Service {
112 ErrorHandlerService {
113 inner,
114 handler: self.handler.clone(),
115 }
116 }
117}
118
/// Service produced by [`ErrorHandlerLayer`].
///
/// Requests are forwarded to `inner`. If the inner response carries a
/// `crate::error::Error` extension, that response is replaced with the one
/// returned by `handler`.
#[derive(Clone)]
pub struct ErrorHandlerService<S, F> {
    /// The wrapped service that actually handles the request.
    inner: S,
    /// Callback that builds the replacement response for error responses.
    handler: F,
}
126
127impl<S, F, Fut> Service<http::Request<axum::body::Body>> for ErrorHandlerService<S, F>
128where
129 S: Service<http::Request<axum::body::Body>, Response = Response> + Clone + Send + 'static,
130 S::Future: Send,
131 S::Error: Into<std::convert::Infallible>,
132 F: Fn(crate::error::Error, Parts) -> Fut + Clone + Send + Sync + 'static,
133 Fut: Future<Output = Response> + Send + 'static,
134{
135 type Response = Response;
136 type Error = S::Error;
137 type Future = Pin<Box<dyn Future<Output = Result<Response, S::Error>> + Send>>;
138
139 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
140 self.inner.poll_ready(cx)
141 }
142
143 fn call(&mut self, req: http::Request<axum::body::Body>) -> Self::Future {
144 let (parts, body) = req.into_parts();
147 let saved_parts = parts.clone();
153 let req = http::Request::from_parts(parts, body);
154
155 let handler = self.handler.clone();
156 let future = self.inner.call(req);
157
158 Box::pin(async move {
159 let response = future.await?;
160
161 if let Some(error) = response.extensions().get::<crate::error::Error>() {
162 let error = error.clone();
163 let new_response = handler(error, saved_parts).await;
164 Ok(new_response)
165 } else {
166 Ok(response)
167 }
168 })
169 }
170}
171
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::Error as ModoError;
    use crate::i18n::{I18n, I18nConfig};
    use axum::body::Body;
    use axum::{Router, routing::get};
    use http::{Request, StatusCode};
    use tower::ServiceExt;

    /// Writes `en` and `uk` `errors.yaml` fixtures under `dir`.
    ///
    /// Returns an [`I18n`] rooted at `dir` with `en` as the default locale.
    fn test_i18n(dir: &std::path::Path) -> I18n {
        let en_dir = dir.join("en");
        let uk_dir = dir.join("uk");
        std::fs::create_dir_all(&en_dir).unwrap();
        std::fs::create_dir_all(&uk_dir).unwrap();
        std::fs::write(
            en_dir.join("errors.yaml"),
            "user:\n not_found: User not found\n",
        )
        .unwrap();
        std::fs::write(
            uk_dir.join("errors.yaml"),
            "user:\n not_found: Користувача не знайдено\n",
        )
        .unwrap();

        let config = I18nConfig {
            locales_path: dir.to_str().unwrap().to_string(),
            default_locale: "en".into(),
            ..I18nConfig::default()
        };
        I18n::new(&config).unwrap()
    }

    /// Collects the whole response body and parses it as JSON.
    async fn decode_json(resp: axum::response::Response) -> serde_json::Value {
        let bytes = axum::body::to_bytes(resp.into_body(), usize::MAX)
            .await
            .unwrap();
        serde_json::from_slice(&bytes).unwrap()
    }

    /// Route handler that fails with a localized (translatable) error.
    async fn localized_handler() -> Result<&'static str, ModoError> {
        Err(ModoError::localized(
            StatusCode::NOT_FOUND,
            "errors.user.not_found",
        ))
    }

    /// Route handler that fails with a plain, non-localized error.
    async fn plain_handler() -> Result<&'static str, ModoError> {
        Err(ModoError::bad_request("boom"))
    }

    /// Route handler that succeeds; used to check pass-through behaviour.
    async fn ok_handler() -> &'static str {
        "fine"
    }

    /// Custom error handler that echoes the request path and the error
    /// message under a distinctive status code.
    ///
    /// The 418 status and plain-text body make it obvious whether this
    /// handler ran.
    async fn echo_handler(err: ModoError, parts: Parts) -> Response {
        let text = format!("{} {}", parts.uri.path(), err.message());
        (StatusCode::IM_A_TEAPOT, text).into_response()
    }

    #[tokio::test]
    async fn default_handler_uses_translator_when_present() {
        let dir = tempfile::tempdir().unwrap();
        let i18n = test_i18n(dir.path());

        let app = Router::new()
            .route("/", get(localized_handler))
            .layer(error_handler(default_error_handler))
            .layer(i18n.layer());

        let req = Request::builder().uri("/").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::NOT_FOUND);

        let body = decode_json(resp).await;
        assert_eq!(body["error"]["status"], 404);
        assert_eq!(body["error"]["message"], "User not found");
    }

    #[tokio::test]
    async fn default_handler_translates_using_resolved_locale() {
        let dir = tempfile::tempdir().unwrap();
        let i18n = test_i18n(dir.path());

        let app = Router::new()
            .route("/", get(localized_handler))
            .layer(error_handler(default_error_handler))
            .layer(i18n.layer());

        let req = Request::builder()
            .uri("/?lang=uk")
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::NOT_FOUND);

        let body = decode_json(resp).await;
        assert_eq!(body["error"]["message"], "Користувача не знайдено");
    }

    #[tokio::test]
    async fn default_handler_falls_back_to_key_without_translator() {
        // No i18n layer: the translator is absent, so the raw key is rendered.
        let app = Router::new()
            .route("/", get(localized_handler))
            .layer(error_handler(default_error_handler));

        let req = Request::builder().uri("/").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::NOT_FOUND);

        let body = decode_json(resp).await;
        assert_eq!(body["error"]["message"], "errors.user.not_found");
    }

    #[tokio::test]
    async fn default_handler_passes_through_plain_errors() {
        let dir = tempfile::tempdir().unwrap();
        // With a translator present, plain errors must still keep their message.
        let i18n = test_i18n(dir.path());

        let app = Router::new()
            .route("/", get(plain_handler))
            .layer(error_handler(default_error_handler))
            .layer(i18n.layer());

        let req = Request::builder().uri("/").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
        let body = decode_json(resp).await;
        assert_eq!(body["error"]["message"], "boom");

        // Same expectation without any translator in the stack.
        let app = Router::new()
            .route("/", get(plain_handler))
            .layer(error_handler(default_error_handler));
        let req = Request::builder().uri("/").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
        let body = decode_json(resp).await;
        assert_eq!(body["error"]["message"], "boom");
    }

    /// Responses without an error extension must reach the client untouched.
    ///
    /// If `echo_handler` were consulted, the status would be 418.
    #[tokio::test]
    async fn successful_responses_pass_through_untouched() {
        let app = Router::new()
            .route("/", get(ok_handler))
            .layer(error_handler(echo_handler));

        let req = Request::builder().uri("/").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);

        let bytes = axum::body::to_bytes(resp.into_body(), usize::MAX)
            .await
            .unwrap();
        assert_eq!(&bytes[..], &b"fine"[..]);
    }

    /// A custom handler replaces the error response.
    ///
    /// It receives both the error and the original request parts.
    #[tokio::test]
    async fn custom_handler_receives_error_and_request_parts() {
        let app = Router::new()
            .route("/boom", get(plain_handler))
            .layer(error_handler(echo_handler));

        let req = Request::builder().uri("/boom").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::IM_A_TEAPOT);

        let bytes = axum::body::to_bytes(resp.into_body(), usize::MAX)
            .await
            .unwrap();
        assert_eq!(&bytes[..], &b"/boom boom"[..]);
    }
}