// api_tools/server/axum/layers/http_errors.rs

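//! Middleware that overrides some HTTP error responses (405 Method Not Allowed,
//! 422 Unprocessable Entity) with [`ApiError`] responses.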
2
3use crate::server::axum::response::ApiError;
4use axum::body::Body;
5use axum::http::{Request, StatusCode};
6use axum::response::{IntoResponse, Response};
7use futures::future::BoxFuture;
8use std::task::{Context, Poll};
9use tower::{Layer, Service};
10
/// Settings used by [`HttpErrorsLayer`] and the middleware it builds.
#[derive(Clone, Debug)]
pub struct HttpErrorsConfig {
    /// Upper bound, in bytes, on how much of a response body may be buffered
    /// when the middleware needs to read it.
    pub body_max_size: usize,
}
17
18#[derive(Clone)]
19pub struct HttpErrorsLayer {
20    pub config: HttpErrorsConfig,
21}
22
23impl HttpErrorsLayer {
24    /// Create a new `HttpErrorsLayer`
25    pub fn new(config: &HttpErrorsConfig) -> Self {
26        Self { config: config.clone() }
27    }
28}
29
30impl<S> Layer<S> for HttpErrorsLayer {
31    type Service = HttpErrorsMiddleware<S>;
32
33    fn layer(&self, inner: S) -> Self::Service {
34        HttpErrorsMiddleware {
35            inner,
36            config: self.config.clone(),
37        }
38    }
39}
40
41#[derive(Clone)]
42pub struct HttpErrorsMiddleware<S> {
43    inner: S,
44    config: HttpErrorsConfig,
45}
46
47impl<S> Service<Request<Body>> for HttpErrorsMiddleware<S>
48where
49    S: Service<Request<Body>, Response = Response> + Send + Clone + 'static,
50    S::Future: Send + 'static,
51{
52    type Response = S::Response;
53    type Error = S::Error;
54    // `BoxFuture` is a type alias for `Pin<Box<dyn Future + Send + 'a>>`
55    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
56
57    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
58        self.inner.poll_ready(cx)
59    }
60
61    fn call(&mut self, request: Request<Body>) -> Self::Future {
62        let mut inner = self.inner.clone();
63        let config = self.config.clone();
64
65        Box::pin(async move {
66            let response: Response = inner.call(request).await?;
67
68            // Vérifie le content-type
69            let headers = response.headers();
70            if let Some(content_type) = headers.get("content-type") {
71                let content_type = content_type.to_str().unwrap_or_default();
72                if content_type.starts_with("image/")
73                    || content_type.starts_with("audio/")
74                    || content_type.starts_with("video/")
75                {
76                    return Ok(response);
77                }
78            }
79
80            let (parts, body) = response.into_parts();
81            match axum::body::to_bytes(body, config.body_max_size).await {
82                Ok(body) => match String::from_utf8(body.to_vec()) {
83                    Ok(body) => match parts.status {
84                        StatusCode::METHOD_NOT_ALLOWED => Ok(ApiError::MethodNotAllowed.into_response()),
85                        StatusCode::UNPROCESSABLE_ENTITY => Ok(ApiError::UnprocessableEntity(body).into_response()),
86                        _ => Ok(Response::from_parts(parts, Body::from(body))),
87                    },
88                    Err(err) => Ok(ApiError::InternalServerError(err.to_string()).into_response()),
89                },
90                Err(_) => Ok(ApiError::PayloadTooLarge.into_response()),
91            }
92        })
93    }
94}