1use std::{
2 fmt,
3 pin::Pin,
4 task::{Context, Poll},
5};
6
7use axum::body::BoxBody;
8use futures::Future;
9use http::{HeaderValue, Method, Request, Response, StatusCode};
10use http_body::{Body, Empty};
11use tower_layer::Layer;
12use tower_service::Service;
13
/// Tower layer that wraps a service in the `Cors` middleware.
///
/// The layer carries no configuration: the allowed origin, methods and
/// headers are fixed inside the middleware itself.
#[derive(Debug, Clone, Copy, Default)]
pub struct CorsLayer;
16
17impl<S> Layer<S> for CorsLayer {
18 type Service = Cors<S>;
19
20 fn layer(&self, service: S) -> Self::Service {
21 Cors::new(service)
22 }
23}
24
/// Tower middleware that answers CORS preflight (`OPTIONS`) requests under
/// `/api/v1/` itself and forwards everything else to `inner`, adding
/// permissive CORS headers to successful responses.
#[derive(Clone, Debug)]
pub struct Cors<S> {
    /// The wrapped service that handles all non-preflight requests.
    inner: S,
}
29
30impl<S> Cors<S> {
31 pub fn new(inner: S) -> Self {
32 Self { inner }
33 }
34}
35
36#[pin_project::pin_project]
37pub struct CorsFuture<S, ReqBody, ResBody>
38where
39 S: Service<Request<ReqBody>, Response = Response<ResBody>>,
40 S::Error: fmt::Display + 'static,
41{
42 #[pin]
43 inner: S::Future,
44}
45
46impl<S, ReqBody, ResBody> Future for CorsFuture<S, ReqBody, ResBody>
47where
48 S: Service<Request<ReqBody>, Response = Response<ResBody>>,
49 S::Error: fmt::Display + 'static,
50{
51 type Output = <S::Future as Future>::Output;
52
53 fn poll(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
54 let this = self.project();
55 match this.inner.poll(cx) {
56 Poll::Pending => Poll::Pending,
57 Poll::Ready(Ok(mut response)) => {
58 response
59 .headers_mut()
60 .insert("access-control-allow-origin", HeaderValue::from_static("*"));
61 response.headers_mut().insert(
62 "access-control-allow-methods",
63 HeaderValue::from_static("POST, GET, OPTIONS, PATCH, DELETE"),
64 );
65 response.headers_mut().insert(
66 "access-control-allow-headers",
67 HeaderValue::from_static("content-type"),
68 );
69 Poll::Ready(Ok(response))
70 }
71 Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
72 }
73 }
74}
75
76impl<S, ReqBody> Service<Request<ReqBody>> for Cors<S>
77where
78 S: Service<Request<ReqBody>, Response = Response<BoxBody>>,
79 ReqBody: Body + 'static,
80 S: 'static,
81 S::Error: fmt::Display + 'static,
82 S::Future: Send,
83{
84 type Response = Response<BoxBody>;
85 type Error = S::Error;
86 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
89 self.inner.poll_ready(cx)
90 }
91
92 fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
93 if req.method() == Method::OPTIONS && req.uri().path().starts_with("/api/v1/") {
94 return Box::pin(async move {
95 let mut response: Response<BoxBody> =
96 Response::new(axum::body::boxed(Empty::new()));
97 *response.status_mut() = StatusCode::OK;
98 response
99 .headers_mut()
100 .insert("access-control-allow-origin", HeaderValue::from_static("*"));
101 response.headers_mut().insert(
102 "access-control-allow-methods",
103 HeaderValue::from_static("POST, GET, OPTIONS, PATCH, DELETE"),
104 );
105 response.headers_mut().insert(
106 "access-control-allow-headers",
107 HeaderValue::from_static("*"),
108 );
109 response
110 .headers_mut()
111 .insert("access-control-max-age", HeaderValue::from_static("86400"));
112 Ok(response)
113 });
114 }
115 let future = self.inner.call(req);
116
117 Box::pin(CorsFuture::<S, ReqBody, BoxBody> { inner: future })
118 }
119}