1use std::collections::HashSet;
91use std::fmt::{self, Debug, Formatter};
92use std::sync::Arc;
93
94use http::{Method, Uri};
95
96mod future;
97mod layer;
98mod response;
99mod service;
100mod url;
101
102pub use self::future::ResponseFuture;
103pub use self::layer::CsrfLayer;
104pub use self::response::{DefaultResponseForProtectionError, ResponseForProtectionError};
105pub use self::service::Csrf;
106
/// Error returned when an origin supplied to the CSRF layer's configuration (for example via
/// `CsrfLayer::add_trusted_origin`) is rejected.
///
/// Every variant carries the offending `origin` string so it can be reported back.
#[derive(Clone, Debug, PartialEq, Eq)]
#[non_exhaustive]
pub enum ConfigError {
    /// The origin could not be parsed as a URL.
    InvalidOriginUrl {
        /// The rejected origin string.
        origin: String,
        /// Human-readable reason the origin failed to parse.
        message: String,
    },

    /// The origin contains a path, query, or fragment; only scheme, host and port are allowed.
    InvalidOriginUrlComponents {
        /// The rejected origin string.
        origin: String,
    },

    /// The origin does not use the `http` or `https` scheme.
    OpaqueOrigin {
        /// The rejected origin string.
        origin: String,
    },

    /// The origin's hostname is not ASCII; such hostnames must be given in punycode (`xn--…`).
    NonAsciiHostname {
        /// The rejected origin string.
        origin: String,
    },
}
143
144impl fmt::Display for ConfigError {
145 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
146 match self {
147 ConfigError::InvalidOriginUrl { origin, message } => {
148 write!(f, "invalid origin {origin:?}: {message}")
149 }
150 ConfigError::InvalidOriginUrlComponents { origin } => write!(
151 f,
152 "invalid origin {origin:?}: path, query, and fragment are not allowed"
153 ),
154 ConfigError::OpaqueOrigin { origin } => write!(
155 f,
156 "invalid origin {origin:?}: scheme must be http or https"
157 ),
158 ConfigError::NonAsciiHostname { origin } => write!(
159 f,
160 "invalid origin {origin:?}: non-ASCII hostnames must be supplied in punycode (xn--…)"
161 ),
162 }
163 }
164}
165
166impl std::error::Error for ConfigError {}
167
168#[derive(Clone, Debug)]
180pub struct ProtectionError {
181 kind: ProtectionErrorKind,
182}
183
184impl ProtectionError {
185 pub(crate) fn new(kind: ProtectionErrorKind) -> Self {
186 Self { kind }
187 }
188
189 pub fn kind(&self) -> ProtectionErrorKind {
191 self.kind
192 }
193}
194
195impl fmt::Display for ProtectionError {
196 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
197 match self.kind {
198 ProtectionErrorKind::CrossOriginRequest => f.write_str("Cross-Origin request detected"),
199 ProtectionErrorKind::CrossOriginRequestFromOldBrowser => {
200 f.write_str("Cross-Origin request from old browser detected")
201 }
202 }
203 }
204}
205
206impl std::error::Error for ProtectionError {}
207
/// The reason a request was rejected, as reported by [`ProtectionError::kind`].
///
/// `Hash` is derived alongside `Eq` so kinds can be used as set or map keys
/// (e.g. for counting rejections per reason).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum ProtectionErrorKind {
    /// The request sent a `Sec-Fetch-Site` header marking it as not same-origin (for example
    /// `cross-site` or `same-site`). It was not exempted by a trusted origin or an insecure
    /// bypass.
    CrossOriginRequest,

    /// The request sent no (or an empty) `Sec-Fetch-Site` header, as browsers predating that
    /// header do. Its `Origin` header matched neither the request's authority nor a trusted
    /// origin, and no insecure bypass applied.
    CrossOriginRequestFromOldBrowser,
}
220
/// Predicate deciding, from a request's method and URI, whether the request is exempted from
/// the cross-origin checks; returning `true` exempts it (see `CsrfLayer::with_insecure_bypass`).
type BypassFn = dyn Fn(&Method, &Uri) -> bool + Send + Sync + 'static;
222
/// Placeholder whose `Debug` output is the literal `<fn>`.
///
/// Used when printing fields that hold closures, which cannot implement `Debug` themselves.
struct DebugFn;

impl Debug for DebugFn {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        write!(f, "<fn>")
    }
}
230
/// Set of trusted origins, stored as raw bytes and matched exactly.
///
/// The set lives behind an [`Arc`], so cloning `Origins` is cheap. Insertion goes through
/// [`Arc::make_mut`], which clones the underlying set first only when it is shared
/// (clone-on-write).
#[derive(Clone, Default)]
struct Origins(Arc<HashSet<Vec<u8>>>);

impl Origins {
    /// Returns `true` if `origin` is byte-for-byte equal to a stored origin.
    fn contains(&self, origin: &[u8]) -> bool {
        let set: &HashSet<Vec<u8>> = &self.0;
        set.contains(origin)
    }

    /// Adds `origin` to the set, detaching from any other clone that shares it.
    fn insert(&mut self, origin: impl Into<Vec<u8>>) {
        let bytes: Vec<u8> = origin.into();
        let set = Arc::make_mut(&mut self.0);
        set.insert(bytes);
    }
}

/// Formats as `Origins({...})`, showing each origin as lossily decoded UTF-8 text.
impl Debug for Origins {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.write_str("Origins(")?;
        let mut set = f.debug_set();
        for origin in self.0.iter() {
            set.entry(&String::from_utf8_lossy(origin));
        }
        set.finish()?;
        f.write_str(")")
    }
}
254
255#[cfg(test)]
256mod tests {
257 use std::convert::Infallible;
258
259 use http::{Request, Response, StatusCode};
260 use tower::{service_fn, ServiceExt};
261 use tower_layer::Layer;
262
263 use super::*;
264 use crate::test_helpers::{to_bytes, Body};
265
266 impl PartialEq for super::ProtectionError {
267 fn eq(&self, other: &Self) -> bool {
268 self.kind == other.kind
269 }
270 }
271
272 fn echo_service() -> impl tower::Service<
273 Request<Body>,
274 Response = Response<Body>,
275 Error = Infallible,
276 Future = impl std::future::Future<Output = Result<Response<Body>, Infallible>>,
277 > + Clone {
278 service_fn(|req: Request<Body>| async move {
279 let body: Body = match req.uri().path() {
280 "/foo" => "foo".into(),
281 "/bar" => "bar".into(),
282 _ => Body::empty(),
283 };
284 Ok::<_, Infallible>(Response::new(body))
285 })
286 }
287
288 #[tokio::test]
289 async fn test_service_allows_safe_method() {
290 let svc = CsrfLayer::new()
291 .add_trusted_origin("https://example.com")
292 .unwrap()
293 .layer(echo_service());
294
295 let req = Request::builder()
296 .method("GET")
297 .uri("/foo")
298 .body(Body::empty())
299 .unwrap();
300
301 let res = svc.oneshot(req).await.unwrap();
302
303 assert_eq!(res.status(), StatusCode::OK);
304
305 let body = to_bytes(res.into_body()).await.unwrap();
306 assert_eq!(&body[..], b"foo");
307 }
308
309 #[tokio::test]
310 async fn test_service_allows_post_from_trusted_origin() {
311 let svc = CsrfLayer::new()
312 .add_trusted_origin("https://example.com")
313 .unwrap()
314 .layer(echo_service());
315
316 let req = Request::builder()
317 .method("POST")
318 .uri("/bar")
319 .header("origin", "https://example.com")
320 .body(Body::empty())
321 .unwrap();
322
323 let res = svc.oneshot(req).await.unwrap();
324
325 assert_eq!(res.status(), StatusCode::OK);
326
327 let body = to_bytes(res.into_body()).await.unwrap();
328 assert_eq!(&body[..], b"bar");
329 }
330
331 #[tokio::test]
332 async fn test_service_rejects_post_from_untrusted_origin() {
333 let svc = CsrfLayer::new()
334 .add_trusted_origin("https://example.com")
335 .unwrap()
336 .layer(echo_service());
337
338 let req = Request::builder()
339 .method("POST")
340 .uri("/bar")
341 .header("origin", "https://malicious.example")
342 .body(Body::empty())
343 .unwrap();
344
345 let res = svc.oneshot(req).await.unwrap();
346
347 assert_eq!(res.status(), StatusCode::FORBIDDEN);
348 assert_eq!(
349 res.extensions().get::<ProtectionError>(),
350 Some(&ProtectionError::new(
351 ProtectionErrorKind::CrossOriginRequestFromOldBrowser
352 )),
353 );
354 }
355
356 #[tokio::test]
357 async fn test_service_uses_custom_rejection_response() {
358 let svc = CsrfLayer::new()
359 .with_rejection_response(|_err: ProtectionError| {
360 let mut res = Response::new(Body::from("denied"));
361 *res.status_mut() = StatusCode::IM_A_TEAPOT;
362 res
363 })
364 .layer(echo_service());
365
366 let req = Request::builder()
367 .method("POST")
368 .uri("/bar")
369 .header("origin", "https://malicious.example")
370 .body(Body::empty())
371 .unwrap();
372
373 let res = svc.oneshot(req).await.unwrap();
374
375 assert_eq!(res.status(), StatusCode::IM_A_TEAPOT);
376 assert_ne!(res.status(), StatusCode::OK);
377 assert_eq!(
380 res.extensions().get::<ProtectionError>(),
381 Some(&ProtectionError::new(
382 ProtectionErrorKind::CrossOriginRequestFromOldBrowser
383 )),
384 );
385
386 let body = to_bytes(res.into_body()).await.unwrap();
387 assert_eq!(&body[..], b"denied");
388 }
389
390 #[tokio::test]
391 async fn test_service_custom_rejection_response_not_invoked_when_allowed() {
392 let svc = CsrfLayer::new()
393 .add_trusted_origin("https://example.com")
394 .unwrap()
395 .with_rejection_response(|_err: ProtectionError| {
396 let mut res = Response::new(Body::from("denied"));
397 *res.status_mut() = StatusCode::IM_A_TEAPOT;
398 res
399 })
400 .layer(echo_service());
401
402 let req = Request::builder()
403 .method("POST")
404 .uri("/bar")
405 .header("origin", "https://example.com")
406 .body(Body::empty())
407 .unwrap();
408
409 let res = svc.oneshot(req).await.unwrap();
410
411 assert_eq!(res.status(), StatusCode::OK);
412 assert_ne!(res.status(), StatusCode::IM_A_TEAPOT);
413 assert!(res.extensions().get::<ProtectionError>().is_none());
414
415 let body = to_bytes(res.into_body()).await.unwrap();
416 assert_eq!(&body[..], b"bar");
417 }
418
#[test]
fn test_layer_add_trusted_origin() {
    // A well-formed origin is accepted; a string that is not a URL is rejected.
    let accepted = CsrfLayer::new().add_trusted_origin("https://example.com");
    assert!(accepted.is_ok());

    let rejected = CsrfLayer::new().add_trusted_origin("not a valid url");
    assert!(matches!(rejected, Err(ConfigError::InvalidOriginUrl { .. })));
}
431
#[test]
fn test_middleware_bypass() {
    // Requests to `/bypass` are exempted; every other path is still checked.
    let middleware = CsrfLayer::new()
        .with_insecure_bypass(|_method, uri| -> bool { uri.path() == "/bypass" })
        .layer(());

    let reject = |kind: ProtectionErrorKind| -> Result<(), ProtectionError> {
        Err(ProtectionError::new(kind))
    };

    // (name, request path, Sec-Fetch-Site header, expected)
    let cases = [
        ("bypass path without sec-fetch-site", "/bypass", None, Ok(())),
        ("bypass path with cross-site", "/bypass", Some("cross-site"), Ok(())),
        (
            "non-bypass path without sec-fetch-site",
            "/api",
            None,
            reject(ProtectionErrorKind::CrossOriginRequestFromOldBrowser),
        ),
        (
            "non-bypass path with cross-site",
            "/api",
            Some("cross-site"),
            reject(ProtectionErrorKind::CrossOriginRequest),
        ),
    ];

    for (name, path, sec_fetch_site, expected) in cases {
        let builder = Request::post(format!("https://example.com{path}"))
            .header("host", "example.com")
            .header("origin", "https://attacker.example");
        let builder = match sec_fetch_site {
            Some(value) => builder.header("sec-fetch-site", value),
            None => builder,
        };
        let req = builder.body(()).unwrap();

        assert_eq!(middleware.verify(&req), expected, "{}", name);
    }
}
493
#[test]
fn test_middleware_bypass_applies_when_origin_unparseable() {
    // An Origin header that is not valid UTF-8 must not stop the bypass from applying.
    let unparseable_origin = http::HeaderValue::from_bytes(&[0xFF, 0xFE]).unwrap();

    let layer = CsrfLayer::new().with_insecure_bypass(|_method, uri| uri.path() == "/bypass");
    let middleware = layer.layer(());

    let req = Request::post("https://example.com/bypass")
        .header("host", "example.com")
        .header("origin", unparseable_origin)
        .body(())
        .unwrap();

    assert_eq!(middleware.verify(&req), Ok(()));
}
513
#[test]
fn test_middleware_debug_trait() {
    // Closures render as `<fn>`, and the empty trusted-origin set renders as `Origins({})`.
    let base = CsrfLayer::new();

    let with_bypass = base
        .clone()
        .with_insecure_bypass(|method, uri| method == Method::POST && uri.path() == "/bypass")
        .layer(());
    let without_bypass = base.layer(());

    assert_eq!(
        format!("{with_bypass:?}"),
        "Csrf { inner: (), insecure_bypass: Some(<fn>), trusted_origins: Origins({}), rejection_response: <fn> }"
    );
    assert_eq!(
        format!("{without_bypass:?}"),
        "Csrf { inner: (), insecure_bypass: None, trusted_origins: Origins({}), rejection_response: <fn> }"
    );
}
535
#[test]
fn test_middleware_origin_host_port_match() {
    // Without Sec-Fetch-Site, the Origin must match the request's authority exactly, including
    // how the port is spelled. The authority comes from the request-target if present, else
    // from Host. Every mismatch is reported as `CrossOriginRequestFromOldBrowser`.
    let middleware: Csrf<()> = Default::default();

    let old_browser = || -> Result<(), ProtectionError> {
        Err(ProtectionError::new(
            ProtectionErrorKind::CrossOriginRequestFromOldBrowser,
        ))
    };

    // (name, request URI, Host header, Origin header, expected)
    let cases = [
        (
            "default port both sides",
            "/",
            Some("example.com"),
            "https://example.com",
            Ok(()),
        ),
        (
            "same non-default port both sides",
            "/",
            Some("example.com:8443"),
            "https://example.com:8443",
            Ok(()),
        ),
        (
            "explicit default port both sides",
            "/",
            Some("example.com:443"),
            "https://example.com:443",
            Ok(()),
        ),
        (
            "mismatched non-default ports",
            "/",
            Some("example.com:8443"),
            "https://example.com:8444",
            old_browser(),
        ),
        (
            "origin has explicit default, host implicit",
            "/",
            Some("example.com"),
            "https://example.com:443",
            old_browser(),
        ),
        (
            "host has explicit default, origin implicit",
            "/",
            Some("example.com:443"),
            "https://example.com",
            old_browser(),
        ),
        (
            "host implicit, origin explicit non-default",
            "/",
            Some("example.com"),
            "https://example.com:8443",
            old_browser(),
        ),
        (
            "missing host, uri authority implicit, origin explicit non-default",
            "https://example.com/path",
            None,
            "https://example.com:8443",
            old_browser(),
        ),
        (
            "malformed host header compared verbatim",
            "/path",
            Some("not a valid authority"),
            "https://example.com",
            old_browser(),
        ),
        (
            "request-target authority wins over host header (match)",
            "https://example.com/path",
            Some("other.example"),
            "https://example.com",
            Ok(()),
        ),
        (
            "origin matching host header but not authority is rejected",
            "https://example.com/path",
            Some("other.example"),
            "https://other.example",
            old_browser(),
        ),
        (
            "missing host, uri carries authority (match)",
            "https://example.com/path",
            None,
            "https://example.com",
            Ok(()),
        ),
        (
            "missing host, uri authority mismatch",
            "https://other.example/path",
            None,
            "https://example.com",
            old_browser(),
        ),
        (
            "missing host and no uri authority",
            "/path",
            None,
            "https://example.com",
            old_browser(),
        ),
        (
            "scheme-less origin does not match host even if bytes agree",
            "/",
            Some("example.com:8443"),
            "example.com:8443",
            old_browser(),
        ),
        (
            "non-http origin scheme does not enter host fallback",
            "/",
            Some("example.com:8443"),
            "ftp://example.com:8443",
            old_browser(),
        ),
    ];

    for (name, uri, host, origin, expected) in cases {
        let builder = Request::post(uri);
        let builder = match host {
            Some(host) => builder.header("host", host),
            None => builder,
        };
        let req = builder.header("origin", origin).body(()).unwrap();

        assert_eq!(middleware.verify(&req), expected, "{}", name);
    }
}
705
#[test]
fn test_middleware_sec_fetch_site() {
    // Table-driven check of how Sec-Fetch-Site and Origin decide acceptance for the default
    // middleware (no trusted origins, no bypass).
    let middleware: Csrf<()> = Default::default();

    // Bytes that form a valid header value but are not valid UTF-8.
    const NON_DECODABLE: &[u8] = &[0xFF, 0xFE];
    assert!(
        http::HeaderValue::from_bytes(NON_DECODABLE)
            .expect("NON_DECODABLE must be a valid HeaderValue")
            .to_str()
            .is_err(),
        "NON_DECODABLE must fail HeaderValue::to_str()"
    );

    let cross_origin = || -> Result<(), ProtectionError> {
        Err(ProtectionError::new(ProtectionErrorKind::CrossOriginRequest))
    };
    let old_browser = || -> Result<(), ProtectionError> {
        Err(ProtectionError::new(
            ProtectionErrorKind::CrossOriginRequestFromOldBrowser,
        ))
    };

    struct Test {
        name: &'static str,
        method: http::Method,
        sec_fetch_site: Option<&'static [u8]>,
        origin: Option<&'static [u8]>,
        result: Result<(), ProtectionError>,
    }

    let tests = [
        // Must use an unsafe method. A GET is allowed by the safe-method short-circuit before
        // Sec-Fetch-Site is consulted (see "GET allowed" below), so it would never exercise
        // the `same-origin` handling this case is named for.
        Test {
            name: "same-origin allowed",
            method: Method::POST,
            sec_fetch_site: Some(b"same-origin"),
            origin: None,
            result: Ok(()),
        },
        Test {
            name: "none allowed",
            method: Method::POST,
            sec_fetch_site: Some(b"none"),
            origin: None,
            result: Ok(()),
        },
        Test {
            name: "cross-site blocked",
            method: Method::POST,
            sec_fetch_site: Some(b"cross-site"),
            origin: None,
            result: cross_origin(),
        },
        Test {
            name: "same-site blocked",
            method: Method::POST,
            sec_fetch_site: Some(b"same-site"),
            origin: None,
            result: cross_origin(),
        },
        Test {
            name: "no header with no origin",
            method: Method::POST,
            sec_fetch_site: None,
            origin: None,
            result: Ok(()),
        },
        Test {
            name: "no header with matching origin",
            method: Method::POST,
            sec_fetch_site: None,
            origin: Some(b"https://example.com"),
            result: Ok(()),
        },
        Test {
            name: "no header with mismatched origin",
            method: Method::POST,
            sec_fetch_site: None,
            origin: Some(b"https://attacker.example"),
            result: old_browser(),
        },
        Test {
            name: "no header with null origin",
            method: Method::POST,
            sec_fetch_site: None,
            origin: Some(b"null"),
            result: old_browser(),
        },
        Test {
            name: "GET allowed",
            method: Method::GET,
            sec_fetch_site: Some(b"cross-site"),
            origin: None,
            result: Ok(()),
        },
        Test {
            name: "HEAD allowed",
            method: Method::HEAD,
            sec_fetch_site: Some(b"cross-site"),
            origin: None,
            result: Ok(()),
        },
        Test {
            name: "OPTIONS allowed",
            method: Method::OPTIONS,
            sec_fetch_site: Some(b"cross-site"),
            origin: None,
            result: Ok(()),
        },
        Test {
            name: "PUT blocked",
            method: Method::PUT,
            sec_fetch_site: Some(b"cross-site"),
            origin: None,
            result: cross_origin(),
        },
        Test {
            name: "non-decodable origin without sec-fetch-site rejected",
            method: Method::POST,
            sec_fetch_site: None,
            origin: Some(NON_DECODABLE),
            result: old_browser(),
        },
        Test {
            name: "non-decodable sec-fetch-site without origin rejected",
            method: Method::POST,
            sec_fetch_site: Some(NON_DECODABLE),
            origin: None,
            result: cross_origin(),
        },
        Test {
            name: "empty sec-fetch-site without origin allowed",
            method: Method::POST,
            sec_fetch_site: Some(b""),
            origin: None,
            result: Ok(()),
        },
        Test {
            name: "empty origin without sec-fetch-site allowed",
            method: Method::POST,
            sec_fetch_site: None,
            origin: Some(b""),
            result: Ok(()),
        },
    ];

    for test in tests {
        let mut req = Request::builder()
            .method(test.method)
            .header("host", "example.com");

        if let Some(sec_fetch_site) = test.sec_fetch_site {
            req = req.header("sec-fetch-site", sec_fetch_site);
        }

        if let Some(origin) = test.origin {
            req = req.header("origin", origin);
        }

        let req = req.body(()).unwrap();

        assert_eq!(middleware.verify(&req), test.result, "{}", test.name);
    }
}
874
#[test]
fn test_middleware_trusted_origin_bypass() {
    // A trusted Origin is accepted whether or not Sec-Fetch-Site is sent. An untrusted one is
    // rejected, with a kind that depends on whether Sec-Fetch-Site was present.
    const TRUSTED: &str = "https://trusted.example";
    const ATTACKER: &str = "https://attacker.example";

    let middleware = CsrfLayer::new()
        .add_trusted_origin(TRUSTED)
        .unwrap()
        .layer(());

    let reject = |kind: ProtectionErrorKind| -> Result<(), ProtectionError> {
        Err(ProtectionError::new(kind))
    };

    // (name, Sec-Fetch-Site header, Origin header, expected)
    let cases = [
        ("trusted origin without sec-fetch-site", None, TRUSTED, Ok(())),
        ("trusted origin with cross-site", Some("cross-site"), TRUSTED, Ok(())),
        (
            "untrusted origin without sec-fetch-site",
            None,
            ATTACKER,
            reject(ProtectionErrorKind::CrossOriginRequestFromOldBrowser),
        ),
        (
            "untrusted origin with cross-site",
            Some("cross-site"),
            ATTACKER,
            reject(ProtectionErrorKind::CrossOriginRequest),
        ),
    ];

    for (name, sec_fetch_site, origin, expected) in cases {
        let builder = Request::builder()
            .method(Method::POST)
            .header("host", "example.com");
        let builder = match sec_fetch_site {
            Some(value) => builder.header("sec-fetch-site", value),
            None => builder,
        };
        let req = builder.header("origin", origin).body(()).unwrap();

        assert_eq!(middleware.verify(&req), expected, "{}", name);
    }
}
939
#[test]
fn test_middleware_trusted_origin_strict_byte_match() {
    // Trusted origins are compared byte-for-byte with the Origin header, with no case folding
    // and no default-port normalisation. `Sec-Fetch-Site: cross-site` is sent, so a miss is
    // rejected as `CrossOriginRequest`.
    let cross_origin = || -> Result<(), ProtectionError> {
        Err(ProtectionError::new(ProtectionErrorKind::CrossOriginRequest))
    };

    // (name, trusted origin, Origin header, expected)
    let cases = [
        ("exact match trusted", "https://example.com", "https://example.com", Ok(())),
        (
            "exact match with non-default port",
            "https://example.com:8443",
            "https://example.com:8443",
            Ok(()),
        ),
        (
            "host case mismatch not trusted",
            "https://Example.COM",
            "https://example.com",
            cross_origin(),
        ),
        (
            "explicit default port not trusted against bare origin",
            "https://example.com:443",
            "https://example.com",
            cross_origin(),
        ),
        (
            "bare trusted not matched by explicit-default-port origin",
            "https://example.com",
            "https://example.com:443",
            cross_origin(),
        ),
    ];

    for (name, trusted, origin, expected) in cases {
        let middleware = CsrfLayer::new()
            .add_trusted_origin(trusted)
            .unwrap_or_else(|e| panic!("{}: add_trusted_origin failed: {e}", name))
            .layer(());

        let req = Request::builder()
            .method(Method::POST)
            .header("host", "other.example")
            .header("origin", origin)
            .header("sec-fetch-site", "cross-site")
            .body(())
            .unwrap();

        assert_eq!(middleware.verify(&req), expected, "{}", name);
    }
}
1008}