// axum_util/cors.rs

use std::{
    fmt,
    marker::PhantomData,
    pin::Pin,
    task::{Context, Poll},
};

use axum::body::BoxBody;
use futures::Future;
use http::{header, HeaderValue, Method, Request, Response, StatusCode};
use http_body::{Body, Empty};
use tower_layer::Layer;
use tower_service::Service;

14#[derive(Clone)]
15pub struct CorsLayer;
16
17impl<S> Layer<S> for CorsLayer {
18    type Service = Cors<S>;
19
20    fn layer(&self, service: S) -> Self::Service {
21        Cors::new(service)
22    }
23}
24
/// CORS middleware service, normally produced by [`CorsLayer`].
///
/// It answers `OPTIONS` requests under `/api/v1/` itself. Every successful
/// response from the wrapped service gets CORS headers added.
#[derive(Clone, Debug)]
pub struct Cors<S> {
    inner: S,
}

impl<S> Cors<S> {
    /// Wraps `inner` in the CORS middleware.
    pub fn new(inner: S) -> Self {
        Self { inner }
    }
}

36#[pin_project::pin_project]
37pub struct CorsFuture<S, ReqBody, ResBody>
38where
39    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
40    S::Error: fmt::Display + 'static,
41{
42    #[pin]
43    inner: S::Future,
44}
45
46impl<S, ReqBody, ResBody> Future for CorsFuture<S, ReqBody, ResBody>
47where
48    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
49    S::Error: fmt::Display + 'static,
50{
51    type Output = <S::Future as Future>::Output;
52
53    fn poll(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
54        let this = self.project();
55        match this.inner.poll(cx) {
56            Poll::Pending => Poll::Pending,
57            Poll::Ready(Ok(mut response)) => {
58                response
59                    .headers_mut()
60                    .insert("access-control-allow-origin", HeaderValue::from_static("*"));
61                response.headers_mut().insert(
62                    "access-control-allow-methods",
63                    HeaderValue::from_static("POST, GET, OPTIONS, PATCH, DELETE"),
64                );
65                response.headers_mut().insert(
66                    "access-control-allow-headers",
67                    HeaderValue::from_static("content-type"),
68                );
69                Poll::Ready(Ok(response))
70            }
71            Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
72        }
73    }
74}
75
76impl<S, ReqBody> Service<Request<ReqBody>> for Cors<S>
77where
78    S: Service<Request<ReqBody>, Response = Response<BoxBody>>,
79    ReqBody: Body + 'static,
80    S: 'static,
81    S::Error: fmt::Display + 'static,
82    S::Future: Send,
83{
84    type Response = Response<BoxBody>;
85    type Error = S::Error;
86    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>; //CorsFuture<S, ReqBody, ResBody>;
87
88    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
89        self.inner.poll_ready(cx)
90    }
91
92    fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
93        if req.method() == Method::OPTIONS && req.uri().path().starts_with("/api/v1/") {
94            return Box::pin(async move {
95                let mut response: Response<BoxBody> =
96                    Response::new(axum::body::boxed(Empty::new()));
97                *response.status_mut() = StatusCode::OK;
98                response
99                    .headers_mut()
100                    .insert("access-control-allow-origin", HeaderValue::from_static("*"));
101                response.headers_mut().insert(
102                    "access-control-allow-methods",
103                    HeaderValue::from_static("POST, GET, OPTIONS, PATCH, DELETE"),
104                );
105                response.headers_mut().insert(
106                    "access-control-allow-headers",
107                    HeaderValue::from_static("*"),
108                );
109                response
110                    .headers_mut()
111                    .insert("access-control-max-age", HeaderValue::from_static("86400"));
112                Ok(response)
113            });
114        }
115        let future = self.inner.call(req);
116
117        Box::pin(CorsFuture::<S, ReqBody, BoxBody> { inner: future })
118    }
119}