api_tools/server/axum/layers/
http_errors.rs
use std::task::{Context, Poll};

use axum::body::Body;
use axum::http::{header, Request, StatusCode};
use axum::response::{IntoResponse, Response};
use futures::future::BoxFuture;
use tower::{Layer, Service};

use crate::server::axum::response::ApiError;
10
/// Settings for the HTTP error-rewriting middleware.
#[derive(Debug, Clone)]
pub struct HttpErrorsConfig {
    /// Byte limit handed to `axum::body::to_bytes` whenever the middleware buffers a
    /// response body for inspection.
    pub body_max_size: usize,
}
15
16#[derive(Clone)]
17pub struct HttpErrorsLayer {
18 pub config: HttpErrorsConfig,
19}
20
21impl HttpErrorsLayer {
22 pub fn new(config: &HttpErrorsConfig) -> Self {
24 Self { config: config.clone() }
25 }
26}
27
28impl<S> Layer<S> for HttpErrorsLayer {
29 type Service = HttpErrorsMiddleware<S>;
30
31 fn layer(&self, inner: S) -> Self::Service {
32 HttpErrorsMiddleware {
33 inner,
34 config: self.config.clone(),
35 }
36 }
37}
38
39#[derive(Clone)]
40pub struct HttpErrorsMiddleware<S> {
41 inner: S,
42 config: HttpErrorsConfig,
43}
44
45impl<S> Service<Request<Body>> for HttpErrorsMiddleware<S>
46where
47 S: Service<Request<Body>, Response = Response> + Send + Clone + 'static,
48 S::Future: Send + 'static,
49{
50 type Response = S::Response;
51 type Error = S::Error;
52 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
54
55 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
56 self.inner.poll_ready(cx)
57 }
58
59 fn call(&mut self, request: Request<Body>) -> Self::Future {
60 let mut inner = self.inner.clone();
61 let config = self.config.clone();
62
63 Box::pin(async move {
64 let response: Response = inner.call(request).await?;
65
66 let headers = response.headers();
68 if let Some(content_type) = headers.get("content-type") {
69 let content_type = content_type.to_str().unwrap_or_default();
70 if content_type.starts_with("image/")
71 || content_type.starts_with("audio/")
72 || content_type.starts_with("video/")
73 {
74 return Ok(response);
75 }
76 }
77
78 let (parts, body) = response.into_parts();
79 match axum::body::to_bytes(body, config.body_max_size).await {
80 Ok(body) => match String::from_utf8(body.to_vec()) {
81 Ok(body) => match parts.status {
82 StatusCode::METHOD_NOT_ALLOWED => Ok(ApiError::MethodNotAllowed.into_response()),
83 StatusCode::UNPROCESSABLE_ENTITY => Ok(ApiError::UnprocessableEntity(body).into_response()),
84 _ => Ok(Response::from_parts(parts, Body::from(body))),
85 },
86 Err(err) => Ok(ApiError::InternalServerError(err.to_string()).into_response()),
87 },
88 Err(_) => Ok(ApiError::PayloadTooLarge.into_response()),
89 }
90 })
91 }
92}