rama_http/layer/cors/
allow_origin.rs1use crate::dep::http::{
2 header::{self, HeaderName, HeaderValue},
3 request::Parts as RequestParts,
4};
5use pin_project_lite::pin_project;
6use std::{
7 array, fmt,
8 pin::Pin,
9 sync::Arc,
10 task::{Context, Poll},
11};
12
13use super::{Any, WILDCARD};
14
15#[derive(Clone, Default)]
22#[must_use]
23pub struct AllowOrigin(OriginInner);
24
25impl AllowOrigin {
26 pub fn any() -> Self {
32 Self(OriginInner::Const(WILDCARD))
33 }
34
35 pub fn exact(origin: HeaderValue) -> Self {
41 Self(OriginInner::Const(origin))
42 }
43
44 #[allow(clippy::borrow_interior_mutable_const)]
54 pub fn list<I>(origins: I) -> Self
55 where
56 I: IntoIterator<Item = HeaderValue>,
57 {
58 let origins = origins.into_iter().collect::<Vec<_>>();
59 if origins.contains(&WILDCARD) {
60 panic!(
61 "Wildcard origin (`*`) cannot be passed to `AllowOrigin::list`. \
62 Use `AllowOrigin::any()` instead"
63 );
64 }
65
66 Self(OriginInner::List(origins))
67 }
68
69 pub fn predicate<F>(f: F) -> Self
75 where
76 F: Fn(&HeaderValue, &RequestParts) -> bool + Send + Sync + 'static,
77 {
78 Self(OriginInner::Predicate(Arc::new(f)))
79 }
80
81 pub fn async_predicate<F, Fut>(f: F) -> Self
87 where
88 F: FnOnce(HeaderValue, &RequestParts) -> Fut + Send + Sync + 'static + Clone,
89 Fut: Future<Output = bool> + Send + 'static,
90 {
91 Self(OriginInner::AsyncPredicate(Arc::new(move |v, p| {
92 Box::pin((f.clone())(v, p))
93 })))
94 }
95
96 pub fn mirror_request() -> Self {
105 Self::predicate(|_, _| true)
106 }
107
108 #[allow(clippy::borrow_interior_mutable_const)]
109 pub(super) fn is_wildcard(&self) -> bool {
110 matches!(&self.0, OriginInner::Const(v) if v == WILDCARD)
111 }
112
113 pub(super) fn to_future(
114 &self,
115 origin: Option<&HeaderValue>,
116 parts: &RequestParts,
117 ) -> AllowOriginFuture {
118 let name = header::ACCESS_CONTROL_ALLOW_ORIGIN;
119
120 match &self.0 {
121 OriginInner::Const(v) => AllowOriginFuture::ok(Some((name, v.clone()))),
122 OriginInner::List(l) => {
123 AllowOriginFuture::ok(origin.filter(|o| l.contains(o)).map(|o| (name, o.clone())))
124 }
125 OriginInner::Predicate(c) => AllowOriginFuture::ok(
126 origin
127 .filter(|origin| c(origin, parts))
128 .map(|o| (name, o.clone())),
129 ),
130 OriginInner::AsyncPredicate(f) => {
131 if let Some(origin) = origin.cloned() {
132 let fut = f(origin.clone(), parts);
133 AllowOriginFuture::fut(async move { fut.await.then_some((name, origin)) })
134 } else {
135 AllowOriginFuture::ok(None)
136 }
137 }
138 }
139 }
140}
141
142pin_project! {
143 #[project = AllowOriginFutureProj]
144 pub(super) enum AllowOriginFuture {
145 Ok{
146 res: Option<(HeaderName, HeaderValue)>
147 },
148 Future{
149 #[pin]
150 future: Pin<Box<dyn Future<Output = Option<(HeaderName, HeaderValue)>> + Send + 'static>>
151 },
152 }
153}
154
155impl AllowOriginFuture {
156 fn ok(res: Option<(HeaderName, HeaderValue)>) -> Self {
157 Self::Ok { res }
158 }
159
160 fn fut<F: Future<Output = Option<(HeaderName, HeaderValue)>> + Send + 'static>(
161 future: F,
162 ) -> Self {
163 Self::Future {
164 future: Box::pin(future),
165 }
166 }
167}
168
169impl Future for AllowOriginFuture {
170 type Output = Option<(HeaderName, HeaderValue)>;
171
172 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
173 match self.project() {
174 AllowOriginFutureProj::Ok { res } => Poll::Ready(res.take()),
175 AllowOriginFutureProj::Future { future } => future.poll(cx),
176 }
177 }
178}
179
180impl fmt::Debug for AllowOrigin {
181 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
182 match &self.0 {
183 OriginInner::Const(inner) => f.debug_tuple("Const").field(inner).finish(),
184 OriginInner::List(inner) => f.debug_tuple("List").field(inner).finish(),
185 OriginInner::Predicate(_) => f.debug_tuple("Predicate").finish(),
186 OriginInner::AsyncPredicate(_) => f.debug_tuple("AsyncPredicate").finish(),
187 }
188 }
189}
190
191impl From<Any> for AllowOrigin {
192 fn from(_: Any) -> Self {
193 Self::any()
194 }
195}
196
197impl From<HeaderValue> for AllowOrigin {
198 fn from(val: HeaderValue) -> Self {
199 Self::exact(val)
200 }
201}
202
203impl<const N: usize> From<[HeaderValue; N]> for AllowOrigin {
204 fn from(arr: [HeaderValue; N]) -> Self {
205 #[allow(deprecated)] Self::list(array::IntoIter::new(arr))
207 }
208}
209
210impl From<Vec<HeaderValue>> for AllowOrigin {
211 fn from(vec: Vec<HeaderValue>) -> Self {
212 Self::list(vec)
213 }
214}
215
216#[derive(Clone)]
217enum OriginInner {
218 Const(HeaderValue),
219 List(Vec<HeaderValue>),
220 Predicate(
221 Arc<dyn for<'a> Fn(&'a HeaderValue, &'a RequestParts) -> bool + Send + Sync + 'static>,
222 ),
223 AsyncPredicate(
224 Arc<
225 dyn for<'a> Fn(
226 HeaderValue,
227 &'a RequestParts,
228 ) -> Pin<Box<dyn Future<Output = bool> + Send + 'static>>
229 + Send
230 + Sync
231 + 'static,
232 >,
233 ),
234}
235
236impl Default for OriginInner {
237 fn default() -> Self {
238 Self::List(Vec::new())
239 }
240}