api_tools/server/axum/layers/
http_errors.rs

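//! Tower middleware that replaces selected HTTP error responses (such as 405 Method Not Allowed and 422 Unprocessable Entity) with the crate's `ApiError` responses.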
2
3use crate::server::axum::response::ApiError;
4use axum::body::Body;
5use axum::http::{Request, StatusCode};
6use axum::response::{IntoResponse, Response};
7use futures::future::BoxFuture;
8use std::task::{Context, Poll};
9use tower::{Layer, Service};
10
/// Settings shared by [`HttpErrorsLayer`] and the middleware it builds.
#[derive(Clone, Debug)]
pub struct HttpErrorsConfig {
    /// Upper bound, in bytes, on how much of a response body the middleware will buffer.
    /// A body that cannot be read within this limit is replaced by `ApiError::PayloadTooLarge`.
    pub body_max_size: usize,
}
15
16#[derive(Clone)]
17pub struct HttpErrorsLayer {
18    pub config: HttpErrorsConfig,
19}
20
21impl HttpErrorsLayer {
22    /// Create a new `HttpErrorsLayer`
23    pub fn new(config: &HttpErrorsConfig) -> Self {
24        Self { config: config.clone() }
25    }
26}
27
28impl<S> Layer<S> for HttpErrorsLayer {
29    type Service = HttpErrorsMiddleware<S>;
30
31    fn layer(&self, inner: S) -> Self::Service {
32        HttpErrorsMiddleware {
33            inner,
34            config: self.config.clone(),
35        }
36    }
37}
38
39#[derive(Clone)]
40pub struct HttpErrorsMiddleware<S> {
41    inner: S,
42    config: HttpErrorsConfig,
43}
44
45impl<S> Service<Request<Body>> for HttpErrorsMiddleware<S>
46where
47    S: Service<Request<Body>, Response = Response> + Send + Clone + 'static,
48    S::Future: Send + 'static,
49{
50    type Response = S::Response;
51    type Error = S::Error;
52    // `BoxFuture` is a type alias for `Pin<Box<dyn Future + Send + 'a>>`
53    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
54
55    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
56        self.inner.poll_ready(cx)
57    }
58
59    fn call(&mut self, request: Request<Body>) -> Self::Future {
60        let mut inner = self.inner.clone();
61        let config = self.config.clone();
62
63        Box::pin(async move {
64            let response: Response = inner.call(request).await?;
65
66            // Vérifie le content-type
67            let headers = response.headers();
68            if let Some(content_type) = headers.get("content-type") {
69                let content_type = content_type.to_str().unwrap_or_default();
70                if content_type.starts_with("image/")
71                    || content_type.starts_with("audio/")
72                    || content_type.starts_with("video/")
73                {
74                    return Ok(response);
75                }
76            }
77
78            let (parts, body) = response.into_parts();
79            match axum::body::to_bytes(body, config.body_max_size).await {
80                Ok(body) => match String::from_utf8(body.to_vec()) {
81                    Ok(body) => match parts.status {
82                        StatusCode::METHOD_NOT_ALLOWED => Ok(ApiError::MethodNotAllowed.into_response()),
83                        StatusCode::UNPROCESSABLE_ENTITY => Ok(ApiError::UnprocessableEntity(body).into_response()),
84                        _ => Ok(Response::from_parts(parts, Body::from(body))),
85                    },
86                    Err(err) => Ok(ApiError::InternalServerError(err.to_string()).into_response()),
87                },
88                Err(_) => Ok(ApiError::PayloadTooLarge.into_response()),
89            }
90        })
91    }
92}