1use std::collections::HashSet;
16use std::future::Future;
17use std::pin::Pin;
18use std::result::Result;
19use std::sync::Arc;
20use std::task::{Context, Poll};
21
22use http::{Method, Request, Response, Uri};
23use tower::{BoxError, Layer, Service};
24use tracing::{debug, instrument, trace};
25use url::Url;
26
27#[derive(thiserror::Error, Debug)]
29pub enum ConfigError {
30 #[error(transparent)]
32 InvalidOriginUrl(#[from] url::ParseError),
33
34 #[error("invalid origin {origin:?}: path, query, and fragment are not allowed")]
36 InvalidOriginUrlComponents { origin: String },
37}
38
39#[derive(thiserror::Error, Debug, PartialEq)]
44pub enum ProtectionError {
45 #[error("Cross-Origin request detected")]
47 CrossOriginRequest,
48
49 #[error("Cross-Origin request from old browser detected")]
51 CrossOriginRequestFromOldBrowser,
52
53 #[error("Host header cannot be parsed")]
55 MalformedHost(#[source] url::ParseError),
56
57 #[error("Origin header cannot be parsed")]
59 MalformedOrigin(#[source] url::ParseError),
60}
61
62struct Bypass<T: Fn(&Method, &Uri) -> bool>(T);
63
64impl<T: Fn(&Method, &Uri) -> bool> std::fmt::Debug for Bypass<T> {
65 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
66 f.debug_struct("<fn>").finish()
67 }
68}
69
70trait Filter: std::fmt::Debug + Send + Sync {
71 fn is_bypassed(&self, method: &Method, uri: &Uri) -> bool;
72}
73
74impl<T: Fn(&Method, &Uri) -> bool> Filter for Option<Bypass<T>>
75where
76 T: Send + Sync,
77{
78 fn is_bypassed(&self, method: &Method, uri: &Uri) -> bool {
79 match self {
80 Some(ref p) => p.0(method, uri),
81 None => false,
82 }
83 }
84}
85
86#[derive(Clone, Debug, Default)]
87struct Origins(Arc<HashSet<String>>);
88
89impl Origins {
90 fn contains(&self, origin: &str) -> bool {
91 self.0.contains(origin)
92 }
93
94 fn insert(&mut self, origin: impl Into<String>) {
95 Arc::make_mut(&mut self.0).insert(origin.into());
96 }
97}
98
99#[derive(Clone, Debug)]
101pub struct CrossOriginProtectionLayer {
102 insecure_bypass: Arc<dyn Filter>,
103 trusted_origins: Origins,
104}
105
106impl Default for CrossOriginProtectionLayer {
107 fn default() -> Self {
108 CrossOriginProtectionLayer {
109 insecure_bypass: Arc::new(Option::<Bypass<fn(&Method, &Uri) -> bool>>::default()),
110 trusted_origins: Origins::default(),
111 }
112 }
113}
114
115impl CrossOriginProtectionLayer {
116 pub fn add_trusted_origin<S: Into<String>>(mut self, origin: S) -> Result<Self, ConfigError> {
121 let origin = origin.into();
122
123 let url = Url::parse(&origin)?;
125
126 if url.path() != "/" || url.query().is_some() || url.fragment().is_some() {
128 return Err(ConfigError::InvalidOriginUrlComponents { origin });
129 }
130
131 debug!(origin = %origin, "added trusted origin");
132
133 self.trusted_origins.insert(origin);
134
135 Ok(self)
136 }
137
138 pub fn with_insecure_bypass<F>(self, predicate: F) -> CrossOriginProtectionLayer
141 where
142 F: Fn(&Method, &Uri) -> bool + Send + Sync + 'static,
143 {
144 debug!("added insecure bypass");
145
146 CrossOriginProtectionLayer {
147 insecure_bypass: Arc::new(Some(Bypass(predicate))),
148 trusted_origins: self.trusted_origins,
149 }
150 }
151}
152
153impl<S> Layer<S> for CrossOriginProtectionLayer {
154 type Service = CrossOriginProtectionMiddleware<S>;
155
156 fn layer(&self, inner: S) -> Self::Service {
157 CrossOriginProtectionMiddleware {
158 inner,
159 insecure_bypass: self.insecure_bypass.clone(),
160 trusted_origins: self.trusted_origins.clone(),
161 }
162 }
163}
164
165#[derive(Clone, Debug)]
167pub struct CrossOriginProtectionMiddleware<S> {
168 inner: S,
169 insecure_bypass: Arc<dyn Filter>,
170 trusted_origins: Origins,
171}
172
173impl<S: Default> Default for CrossOriginProtectionMiddleware<S> {
174 fn default() -> Self {
175 Self {
176 inner: S::default(),
177 insecure_bypass: Arc::new(Option::<Bypass<fn(&Method, &Uri) -> bool>>::default()),
178 trusted_origins: Origins::default(),
179 }
180 }
181}
182
183impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for CrossOriginProtectionMiddleware<S>
184where
185 S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone + Send + 'static,
186 S::Error: Into<BoxError> + Send,
187 S::Future: Future<Output = Result<Response<ResBody>, S::Error>> + Send,
188 ReqBody: Send + 'static,
189 ResBody: Send + 'static,
190{
191 type Error = BoxError;
192 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
193 type Response = Response<ResBody>;
194
195 fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
196 let clone = self.inner.clone();
197 let mut inner = std::mem::replace(&mut self.inner, clone);
198
199 match self.verify(&req) {
200 Ok(_) => Box::pin(async move { inner.call(req).await.map_err(Into::into) }),
201 Err(err) => Box::pin(async move { Err(err.into()) }),
202 }
203 }
204
205 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
206 self.inner.poll_ready(cx).map_err(Into::into)
207 }
208}
209
210impl<S> CrossOriginProtectionMiddleware<S> {
211 #[instrument(skip(self, req), fields(uri = %req.uri()))]
212 fn is_exempt<Body>(&self, req: &Request<Body>) -> bool {
213 if self.insecure_bypass.is_bypassed(req.method(), req.uri()) {
214 trace!("request passed: bypassed");
215 return true;
216 }
217
218 if let Some(origin) = req.headers().get("origin") {
219 if self
220 .trusted_origins
221 .contains(origin.to_str().unwrap_or_default())
222 {
223 trace!("request passed: trusted origin");
224 return true;
225 }
226 }
227
228 false
229 }
230
231 #[instrument(skip(self, req), fields(uri = %req.uri()))]
232 fn verify<Body>(&self, req: &Request<Body>) -> Result<(), ProtectionError> {
233 if matches!(*req.method(), Method::GET | Method::HEAD | Method::OPTIONS) {
234 trace!("request passed: safe method");
235 return Ok(());
236 }
237
238 if let Some(sec_fetch_site) = req
239 .headers()
240 .get("sec-fetch-site")
241 .and_then(|h| h.to_str().ok())
242 {
243 if matches!(sec_fetch_site, "same-origin" | "none") {
244 trace!("request passed: sec-fetch-site is same-origin or none");
245 return Ok(());
246 } else if self.is_exempt(req) {
247 return Ok(());
248 } else {
249 return Err(ProtectionError::CrossOriginRequest);
250 }
251 }
252
253 match req.headers().get("origin").and_then(|h| h.to_str().ok()) {
254 Some("null") => {}
255 Some(origin) => {
256 let origin = Url::parse(origin).map_err(ProtectionError::MalformedOrigin)?;
257
258 let origin_host = origin.host_str();
259 let host = req.headers().get("host").and_then(|h| h.to_str().ok());
260
261 match (origin_host, host) {
269 (Some(origin_host), Some(host)) if origin_host == host => {
270 trace!("request passed: origin is same as host - ");
271 return Ok(());
272 }
273 _ => {}
274 }
275 }
276 None => {
277 trace!("request passed: neither sec-fetch-site nor origin header (same-origin or not a browser request)");
278 return Ok(());
279 }
280 }
281
282 if self.is_exempt(req) {
283 return Ok(());
284 }
285
286 Err(ProtectionError::CrossOriginRequestFromOldBrowser)
287 }
288}
289
290#[cfg(test)]
291mod tests {
292 use tracing::Level;
293
294 use super::*;
295 use std::sync::Once;
296
297 static INIT: Once = Once::new();
298
299 fn init() {
300 INIT.call_once(|| {
301 tracing_subscriber::fmt()
302 .with_max_level(Level::TRACE)
303 .init();
304 });
305 }
306
307 #[test]
308 fn test_url_path_normalization() {
309 for url in ["https://example.com/", "https://example.com"] {
310 let url = Url::parse(url).unwrap();
311 assert_eq!(url.path(), "/");
312 }
313 }
314
315 #[test]
316 fn test_layer_add_trusted_origin() {
317 init();
318
319 assert!(matches!(
320 CrossOriginProtectionLayer::default().add_trusted_origin("https://example.com"),
321 Ok(_)
322 ));
323
324 for origin in ["not a valid url", "example.com", "https://"] {
325 assert!(matches!(
326 CrossOriginProtectionLayer::default().add_trusted_origin(origin),
327 Err(ConfigError::InvalidOriginUrl(_))
328 ));
329 }
330
331 for origin in [
332 "https://example.com/path",
333 "https://example.com/path?query=value",
334 "https://example.com/path#fragment",
335 ] {
336 assert!(matches!(
337 CrossOriginProtectionLayer::default().add_trusted_origin(origin),
338 Err(ConfigError::InvalidOriginUrlComponents { origin }) if origin == origin
339 ));
340 }
341 }
342
343 #[test]
344 fn test_middleware_debug_trait() {
345 init();
346
347 let layer = CrossOriginProtectionLayer::default();
348
349 let middleware = layer
350 .clone()
351 .with_insecure_bypass(|method, uri| method == Method::POST && uri.path() == "/bypass")
352 .layer(());
353
354 assert_eq!(
355 format!("{:?}", middleware),
356 "CrossOriginProtectionMiddleware { inner: (), insecure_bypass: Some(<fn>), trusted_origins: Origins({}) }"
357 );
358
359 let middleware = layer.layer(());
360
361 assert_eq!(
362 format!("{:?}", middleware),
363 "CrossOriginProtectionMiddleware { inner: (), insecure_bypass: None, trusted_origins: Origins({}) }"
364 );
365 }
366
367 #[test]
368 fn test_middleware_sec_fetch_site() {
369 init();
370
371 let middleware: CrossOriginProtectionMiddleware<()> = Default::default();
372
373 struct Test {
374 name: &'static str,
375 method: http::Method,
376 sec_fetch_site: Option<&'static str>,
377 origin: Option<&'static str>,
378 result: Result<(), ProtectionError>,
379 }
380
381 let tests = [
382 Test {
383 name: "same-origin allowed",
384 method: Method::GET,
385 sec_fetch_site: Some("same-origin"),
386 origin: None,
387 result: Ok(()),
388 },
389 Test {
390 name: "none allowed",
391 method: Method::POST,
392 sec_fetch_site: Some("none"),
393 origin: None,
394 result: Ok(()),
395 },
396 Test {
397 name: "cross-site blocked",
398 method: Method::POST,
399 sec_fetch_site: Some("cross-site"),
400 origin: None,
401 result: Err(ProtectionError::CrossOriginRequest),
402 },
403 Test {
404 name: "same-site blocked",
405 method: Method::POST,
406 sec_fetch_site: Some("same-site"),
407 origin: None,
408 result: Err(ProtectionError::CrossOriginRequest),
409 },
410 Test {
411 name: "no header with no origin",
412 method: Method::POST,
413 sec_fetch_site: None,
414 origin: None,
415 result: Ok(()),
416 },
417 Test {
418 name: "no header with matching origin",
419 method: Method::POST,
420 sec_fetch_site: None,
421 origin: Some("https://example.com"),
422 result: Ok(()),
423 },
424 Test {
425 name: "no header with mismatched origin",
426 method: Method::POST,
427 sec_fetch_site: None,
428 origin: Some("https://attacker.example"),
429 result: Err(ProtectionError::CrossOriginRequestFromOldBrowser),
430 },
431 Test {
432 name: "no header with null origin",
433 method: Method::POST,
434 sec_fetch_site: None,
435 origin: Some("null"),
436 result: Err(ProtectionError::CrossOriginRequestFromOldBrowser),
437 },
438 Test {
439 name: "GET allowed",
440 method: Method::GET,
441 sec_fetch_site: Some("cross-site"),
442 origin: None,
443 result: Ok(()),
444 },
445 Test {
446 name: "HEAD allowed",
447 method: Method::HEAD,
448 sec_fetch_site: Some("cross-site"),
449 origin: None,
450 result: Ok(()),
451 },
452 Test {
453 name: "OPTIONS allowed",
454 method: Method::OPTIONS,
455 sec_fetch_site: Some("cross-site"),
456 origin: None,
457 result: Ok(()),
458 },
459 Test {
460 name: "PUT allowed",
461 method: Method::PUT,
462 sec_fetch_site: Some("cross-site"),
463 origin: None,
464 result: Err(ProtectionError::CrossOriginRequest),
465 },
466 ];
467
468 for test in tests {
469 let mut req = Request::builder()
470 .method(test.method)
471 .header("host", "example.com");
472
473 if let Some(sec_fetch_site) = test.sec_fetch_site {
474 req = req.header("sec-fetch-site", sec_fetch_site);
475 }
476
477 if let Some(origin) = test.origin {
478 req = req.header("origin", origin);
479 }
480
481 let req = req.body(()).unwrap();
482
483 assert_eq!(middleware.verify(&req), test.result, "{}", test.name);
484 }
485 }
486
487 #[test]
488 fn test_middleware_trusted_origin_bypass() {
489 init();
490
491 let layer = CrossOriginProtectionLayer::default()
492 .add_trusted_origin("https://trusted.example")
493 .unwrap();
494
495 let middleware = layer.layer(());
496
497 struct Test {
498 name: &'static str,
499 sec_fetch_site: Option<&'static str>,
500 origin: Option<&'static str>,
501 result: Result<(), ProtectionError>,
502 }
503
504 let tests = [
505 Test {
506 name: "trusted origin without sec-fetch-site",
507 origin: Some("https://trusted.example"),
508 sec_fetch_site: None,
509 result: Ok(()),
510 },
511 Test {
512 name: "trusted origin with cross-site",
513 origin: Some("https://trusted.example"),
514 sec_fetch_site: Some("cross-site"),
515 result: Ok(()),
516 },
517 Test {
518 name: "untrusted origin without sec-fetch-site",
519 origin: Some("https://attacker.example"),
520 sec_fetch_site: None,
521 result: Err(ProtectionError::CrossOriginRequestFromOldBrowser),
522 },
523 Test {
524 name: "untrusted origin with cross-site",
525 origin: Some("https://attacker.example"),
526 sec_fetch_site: Some("cross-site"),
527 result: Err(ProtectionError::CrossOriginRequest),
528 },
529 ];
530
531 for test in tests {
532 let mut req = Request::builder()
533 .method("POST")
534 .header("host", "example.com");
535
536 if let Some(sec_fetch_site) = test.sec_fetch_site {
537 req = req.header("sec-fetch-site", sec_fetch_site);
538 }
539
540 if let Some(origin) = test.origin {
541 req = req.header("origin", origin);
542 }
543
544 let req = req.body(()).unwrap();
545
546 assert_eq!(middleware.verify(&req), test.result, "{}", test.name);
547 }
548 }
549
550 #[test]
551 fn test_middleware_bypass() {
552 init();
553
554 let layer = CrossOriginProtectionLayer::default()
555 .with_insecure_bypass(|_method, uri| -> bool { uri.path() == "/bypass" });
556
557 let middleware = layer.layer(());
558
559 struct Test {
560 name: &'static str,
561 path: &'static str,
562 sec_fetch_site: Option<&'static str>,
563 result: Result<(), ProtectionError>,
564 }
565
566 let tests = [
567 Test {
568 name: "bypass path without sec-fetch-site",
569 path: "/bypass",
570 sec_fetch_site: None,
571 result: Ok(()),
572 },
573 Test {
574 name: "bypass path with cross-site",
575 path: "/bypass",
576 sec_fetch_site: Some("cross-site"),
577 result: Ok(()),
578 },
579 Test {
580 name: "non-bypass path without sec-fetch-site",
581 path: "/api",
582 sec_fetch_site: None,
583 result: Err(ProtectionError::CrossOriginRequestFromOldBrowser),
584 },
585 Test {
586 name: "non-bypass path with cross-site",
587 path: "/api",
588 sec_fetch_site: Some("cross-site"),
589 result: Err(ProtectionError::CrossOriginRequest),
590 },
591 ];
592
593 for test in tests {
594 let mut req = Request::builder()
595 .method("POST")
596 .header("host", "example.com")
597 .header("origin", "https://attacker.example")
598 .uri(format!("https://example.com{}", test.path));
599
600 if let Some(sec_fetch_site) = test.sec_fetch_site {
601 req = req.header("sec-fetch-site", sec_fetch_site);
602 }
603
604 let req = req.body(()).unwrap();
605
606 assert_eq!(middleware.verify(&req), test.result, "{}", test.name);
607 }
608 }
609}