rama_http/layer/cors/
allow_origin.rs

1use crate::dep::http::{
2    header::{self, HeaderName, HeaderValue},
3    request::Parts as RequestParts,
4};
5use pin_project_lite::pin_project;
6use std::{
7    array, fmt,
8    pin::Pin,
9    sync::Arc,
10    task::{Context, Poll},
11};
12
13use super::{Any, WILDCARD};
14
15/// Holds configuration for how to set the [`Access-Control-Allow-Origin`][mdn] header.
16///
17/// See [`CorsLayer::allow_origin`] for more details.
18///
19/// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin
20/// [`CorsLayer::allow_origin`]: super::CorsLayer::allow_origin
21#[derive(Clone, Default)]
22#[must_use]
23pub struct AllowOrigin(OriginInner);
24
25impl AllowOrigin {
26    /// Allow any origin by sending a wildcard (`*`)
27    ///
28    /// See [`CorsLayer::allow_origin`] for more details.
29    ///
30    /// [`CorsLayer::allow_origin`]: super::CorsLayer::allow_origin
31    pub fn any() -> Self {
32        Self(OriginInner::Const(WILDCARD))
33    }
34
35    /// Set a single allowed origin
36    ///
37    /// See [`CorsLayer::allow_origin`] for more details.
38    ///
39    /// [`CorsLayer::allow_origin`]: super::CorsLayer::allow_origin
40    pub fn exact(origin: HeaderValue) -> Self {
41        Self(OriginInner::Const(origin))
42    }
43
44    /// Set multiple allowed origins
45    ///
46    /// See [`CorsLayer::allow_origin`] for more details.
47    ///
48    /// # Panics
49    ///
50    /// If the iterator contains a wildcard (`*`).
51    ///
52    /// [`CorsLayer::allow_origin`]: super::CorsLayer::allow_origin
53    #[allow(clippy::borrow_interior_mutable_const)]
54    pub fn list<I>(origins: I) -> Self
55    where
56        I: IntoIterator<Item = HeaderValue>,
57    {
58        let origins = origins.into_iter().collect::<Vec<_>>();
59        if origins.contains(&WILDCARD) {
60            panic!(
61                "Wildcard origin (`*`) cannot be passed to `AllowOrigin::list`. \
62                 Use `AllowOrigin::any()` instead"
63            );
64        }
65
66        Self(OriginInner::List(origins))
67    }
68
69    /// Set the allowed origins from a predicate
70    ///
71    /// See [`CorsLayer::allow_origin`] for more details.
72    ///
73    /// [`CorsLayer::allow_origin`]: super::CorsLayer::allow_origin
74    pub fn predicate<F>(f: F) -> Self
75    where
76        F: Fn(&HeaderValue, &RequestParts) -> bool + Send + Sync + 'static,
77    {
78        Self(OriginInner::Predicate(Arc::new(f)))
79    }
80
81    /// Set the allowed origins from an async predicate
82    ///
83    /// See [`CorsLayer::allow_origin`] for more details.
84    ///
85    /// [`CorsLayer::allow_origin`]: super::CorsLayer::allow_origin
86    pub fn async_predicate<F, Fut>(f: F) -> Self
87    where
88        F: FnOnce(HeaderValue, &RequestParts) -> Fut + Send + Sync + 'static + Clone,
89        Fut: Future<Output = bool> + Send + 'static,
90    {
91        Self(OriginInner::AsyncPredicate(Arc::new(move |v, p| {
92            Box::pin((f.clone())(v, p))
93        })))
94    }
95
96    /// Allow any origin, by mirroring the request origin
97    ///
98    /// This is equivalent to
99    /// [`AllowOrigin::predicate(|_, _| true)`][Self::predicate].
100    ///
101    /// See [`CorsLayer::allow_origin`] for more details.
102    ///
103    /// [`CorsLayer::allow_origin`]: super::CorsLayer::allow_origin
104    pub fn mirror_request() -> Self {
105        Self::predicate(|_, _| true)
106    }
107
108    #[allow(clippy::borrow_interior_mutable_const)]
109    pub(super) fn is_wildcard(&self) -> bool {
110        matches!(&self.0, OriginInner::Const(v) if v == WILDCARD)
111    }
112
113    pub(super) fn to_future(
114        &self,
115        origin: Option<&HeaderValue>,
116        parts: &RequestParts,
117    ) -> AllowOriginFuture {
118        let name = header::ACCESS_CONTROL_ALLOW_ORIGIN;
119
120        match &self.0 {
121            OriginInner::Const(v) => AllowOriginFuture::ok(Some((name, v.clone()))),
122            OriginInner::List(l) => {
123                AllowOriginFuture::ok(origin.filter(|o| l.contains(o)).map(|o| (name, o.clone())))
124            }
125            OriginInner::Predicate(c) => AllowOriginFuture::ok(
126                origin
127                    .filter(|origin| c(origin, parts))
128                    .map(|o| (name, o.clone())),
129            ),
130            OriginInner::AsyncPredicate(f) => {
131                if let Some(origin) = origin.cloned() {
132                    let fut = f(origin.clone(), parts);
133                    AllowOriginFuture::fut(async move { fut.await.then_some((name, origin)) })
134                } else {
135                    AllowOriginFuture::ok(None)
136                }
137            }
138        }
139    }
140}
141
142pin_project! {
143    #[project = AllowOriginFutureProj]
144    pub(super) enum AllowOriginFuture {
145        Ok{
146            res: Option<(HeaderName, HeaderValue)>
147        },
148        Future{
149            #[pin]
150            future: Pin<Box<dyn Future<Output = Option<(HeaderName, HeaderValue)>> + Send + 'static>>
151        },
152    }
153}
154
155impl AllowOriginFuture {
156    fn ok(res: Option<(HeaderName, HeaderValue)>) -> Self {
157        Self::Ok { res }
158    }
159
160    fn fut<F: Future<Output = Option<(HeaderName, HeaderValue)>> + Send + 'static>(
161        future: F,
162    ) -> Self {
163        Self::Future {
164            future: Box::pin(future),
165        }
166    }
167}
168
169impl Future for AllowOriginFuture {
170    type Output = Option<(HeaderName, HeaderValue)>;
171
172    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
173        match self.project() {
174            AllowOriginFutureProj::Ok { res } => Poll::Ready(res.take()),
175            AllowOriginFutureProj::Future { future } => future.poll(cx),
176        }
177    }
178}
179
180impl fmt::Debug for AllowOrigin {
181    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
182        match &self.0 {
183            OriginInner::Const(inner) => f.debug_tuple("Const").field(inner).finish(),
184            OriginInner::List(inner) => f.debug_tuple("List").field(inner).finish(),
185            OriginInner::Predicate(_) => f.debug_tuple("Predicate").finish(),
186            OriginInner::AsyncPredicate(_) => f.debug_tuple("AsyncPredicate").finish(),
187        }
188    }
189}
190
191impl From<Any> for AllowOrigin {
192    fn from(_: Any) -> Self {
193        Self::any()
194    }
195}
196
197impl From<HeaderValue> for AllowOrigin {
198    fn from(val: HeaderValue) -> Self {
199        Self::exact(val)
200    }
201}
202
203impl<const N: usize> From<[HeaderValue; N]> for AllowOrigin {
204    fn from(arr: [HeaderValue; N]) -> Self {
205        #[allow(deprecated)] // Can be changed when MSRV >= 1.53
206        Self::list(array::IntoIter::new(arr))
207    }
208}
209
210impl From<Vec<HeaderValue>> for AllowOrigin {
211    fn from(vec: Vec<HeaderValue>) -> Self {
212        Self::list(vec)
213    }
214}
215
216#[derive(Clone)]
217enum OriginInner {
218    Const(HeaderValue),
219    List(Vec<HeaderValue>),
220    Predicate(
221        Arc<dyn for<'a> Fn(&'a HeaderValue, &'a RequestParts) -> bool + Send + Sync + 'static>,
222    ),
223    AsyncPredicate(
224        Arc<
225            dyn for<'a> Fn(
226                    HeaderValue,
227                    &'a RequestParts,
228                ) -> Pin<Box<dyn Future<Output = bool> + Send + 'static>>
229                + Send
230                + Sync
231                + 'static,
232        >,
233    ),
234}
235
236impl Default for OriginInner {
237    fn default() -> Self {
238        Self::List(Vec::new())
239    }
240}