// api_tools/server/axum/layers/http_errors.rs
use crate::server::axum::response::ApiError;

use axum::body::Body;
use axum::http::{Request, StatusCode};
use axum::response::{IntoResponse, Response};
use futures::future::BoxFuture;
use std::mem;
use std::task::{Context, Poll};
use tower::{Layer, Service};

/// Settings shared by [`HttpErrorsLayer`] and the middleware it produces.
#[derive(Debug, Clone)]
pub struct HttpErrorsConfig {
    /// Maximum number of bytes of a response body the middleware will buffer
    /// when it needs to inspect or rewrite that body.
    pub body_max_size: usize,
}

18#[derive(Clone)]
19pub struct HttpErrorsLayer {
20 pub config: HttpErrorsConfig,
21}
22
23impl HttpErrorsLayer {
24 pub fn new(config: &HttpErrorsConfig) -> Self {
26 Self { config: config.clone() }
27 }
28}
29
30impl<S> Layer<S> for HttpErrorsLayer {
31 type Service = HttpErrorsMiddleware<S>;
32
33 fn layer(&self, inner: S) -> Self::Service {
34 HttpErrorsMiddleware {
35 inner,
36 config: self.config.clone(),
37 }
38 }
39}
40
41#[derive(Clone)]
42pub struct HttpErrorsMiddleware<S> {
43 inner: S,
44 config: HttpErrorsConfig,
45}
46
47impl<S> Service<Request<Body>> for HttpErrorsMiddleware<S>
48where
49 S: Service<Request<Body>, Response = Response> + Send + Clone + 'static,
50 S::Future: Send + 'static,
51{
52 type Response = S::Response;
53 type Error = S::Error;
54 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
56
57 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
58 self.inner.poll_ready(cx)
59 }
60
61 fn call(&mut self, request: Request<Body>) -> Self::Future {
62 let mut inner = self.inner.clone();
63 let config = self.config.clone();
64
65 Box::pin(async move {
66 let response: Response = inner.call(request).await?;
67
68 let headers = response.headers();
70 if let Some(content_type) = headers.get("content-type") {
71 let content_type = content_type.to_str().unwrap_or_default();
72 if content_type.starts_with("image/")
73 || content_type.starts_with("audio/")
74 || content_type.starts_with("video/")
75 {
76 return Ok(response);
77 }
78 }
79
80 let (parts, body) = response.into_parts();
81 match axum::body::to_bytes(body, config.body_max_size).await {
82 Ok(body) => match String::from_utf8(body.to_vec()) {
83 Ok(body) => match parts.status {
84 StatusCode::METHOD_NOT_ALLOWED => Ok(ApiError::MethodNotAllowed.into_response()),
85 StatusCode::UNPROCESSABLE_ENTITY => Ok(ApiError::UnprocessableEntity(body).into_response()),
86 _ => Ok(Response::from_parts(parts, Body::from(body))),
87 },
88 Err(err) => Ok(ApiError::InternalServerError(err.to_string()).into_response()),
89 },
90 Err(_) => Ok(ApiError::PayloadTooLarge.into_response()),
91 }
92 })
93 }
94}