1use async_trait::async_trait;
45use reinhardt_core::exception::Result;
46use std::any::{Any, TypeId};
47use std::sync::Arc;
48
49use crate::{Request, Response};
50
/// A dependency-injection registration contributed by a middleware: a
/// [`TypeId`] key paired with a type-erased, thread-safe shared value.
///
/// NOTE(review): presumably the `TypeId` is that of the concrete value stored
/// inside the `Arc` — confirm against the code that consumes these
/// registrations (it is not part of this module).
pub type MiddlewareDiRegistration = (TypeId, Arc<dyn Any + Send + Sync>);
59
60#[async_trait]
65pub trait Handler: Send + Sync {
66 async fn handle(&self, request: Request) -> Result<Response>;
72}
73
74#[async_trait]
79impl<T: Handler + ?Sized> Handler for Arc<T> {
80 async fn handle(&self, request: Request) -> Result<Response> {
81 (**self).handle(request).await
82 }
83}
84
85#[async_trait]
91pub trait Middleware: Send + Sync {
92 async fn process(&self, request: Request, next: Arc<dyn Handler>) -> Result<Response>;
103
104 fn should_continue(&self, _request: &Request) -> bool {
126 true
127 }
128
129 fn di_registrations(&self) -> Vec<MiddlewareDiRegistration> {
163 Vec::new()
164 }
165}
166
167pub struct MiddlewareChain {
172 middlewares: Vec<Arc<dyn Middleware>>,
173 handler: Arc<dyn Handler>,
174}
175
176impl MiddlewareChain {
177 pub fn new(handler: Arc<dyn Handler>) -> Self {
198 Self {
199 middlewares: Vec::new(),
200 handler,
201 }
202 }
203
204 pub fn with_middleware(mut self, middleware: Arc<dyn Middleware>) -> Self {
232 self.middlewares.push(middleware);
233 self
234 }
235
236 pub fn add_middleware(&mut self, middleware: Arc<dyn Middleware>) {
264 self.middlewares.push(middleware);
265 }
266}
267
268#[async_trait]
269impl Handler for MiddlewareChain {
270 async fn handle(&self, request: Request) -> Result<Response> {
271 if self.middlewares.is_empty() {
272 return self.handler.handle(request).await;
273 }
274
275 let mut current_handler: Arc<dyn Handler> = Arc::new(ErrorToResponseHandler {
286 inner: self.handler.clone(),
287 });
288
289 let active_middlewares: Vec<_> = self
292 .middlewares
293 .iter()
294 .rev()
295 .filter(|mw| mw.should_continue(&request))
296 .collect();
297
298 for middleware in active_middlewares {
299 let mw = middleware.clone();
300 let handler = current_handler.clone();
301
302 current_handler = Arc::new(ConditionalComposedHandler {
303 middleware: mw,
304 next: handler,
305 });
306 }
307
308 current_handler.handle(request).await
309 }
310}
311
312pub struct ExcludeMiddleware {
352 inner: Arc<dyn Middleware>,
353 exclusions: Vec<String>,
354}
355
356impl ExcludeMiddleware {
357 pub fn new(inner: Arc<dyn Middleware>) -> Self {
359 Self {
360 inner,
361 exclusions: Vec::new(),
362 }
363 }
364
365 pub fn add_exclusion(mut self, pattern: &str) -> Self {
369 self.exclusions.push(pattern.to_string());
370 self
371 }
372
373 pub fn add_exclusion_mut(&mut self, pattern: &str) {
377 self.exclusions.push(pattern.to_string());
378 }
379
380 fn is_excluded(&self, path: &str) -> bool {
382 self.exclusions.iter().any(|pattern| {
383 if pattern.ends_with('/') {
384 path.starts_with(pattern.as_str())
386 } else {
387 path == pattern
389 }
390 })
391 }
392}
393
394#[async_trait]
395impl Middleware for ExcludeMiddleware {
396 async fn process(&self, request: Request, next: Arc<dyn Handler>) -> Result<Response> {
397 self.inner.process(request, next).await
398 }
399
400 fn should_continue(&self, request: &Request) -> bool {
401 if self.is_excluded(request.uri.path()) {
402 return false;
403 }
404 self.inner.should_continue(request)
405 }
406}
407
408struct ErrorToResponseHandler {
415 inner: Arc<dyn Handler>,
416}
417
418#[async_trait]
419impl Handler for ErrorToResponseHandler {
420 async fn handle(&self, request: Request) -> Result<Response> {
421 match self.inner.handle(request).await {
422 Ok(response) => Ok(response),
423 Err(e) => Ok(Response::from(e)),
424 }
425 }
426}
427
428struct ConditionalComposedHandler {
433 middleware: Arc<dyn Middleware>,
434 next: Arc<dyn Handler>,
435}
436
437#[async_trait]
438impl Handler for ConditionalComposedHandler {
439 async fn handle(&self, request: Request) -> Result<Response> {
440 let response = match self.middleware.process(request, self.next.clone()).await {
445 Ok(response) => response,
446 Err(e) => Response::from(e),
447 };
448
449 Ok(response)
450 }
451}
452
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use hyper::{HeaderMap, Method, Version};

    /// Handler that always answers `200 OK` with a fixed body.
    struct MockHandler {
        response_body: String,
    }

    #[async_trait]
    impl Handler for MockHandler {
        async fn handle(&self, _request: Request) -> Result<Response> {
            Ok(Response::ok().with_body(self.response_body.clone()))
        }
    }

    /// Middleware that prefixes the downstream response body with `prefix`.
    struct MockMiddleware {
        prefix: String,
    }

    #[async_trait]
    impl Middleware for MockMiddleware {
        async fn process(&self, request: Request, next: Arc<dyn Handler>) -> Result<Response> {
            let response = next.handle(request).await?;

            let current_body = String::from_utf8(response.body.to_vec()).unwrap_or_default();
            let new_body = format!("{}{}", self.prefix, current_body);

            Ok(Response::ok().with_body(new_body))
        }
    }

    /// Builds a bodiless HTTP/1.1 `GET` request for `path`.
    fn create_request_with_path(path: &str) -> Request {
        Request::builder()
            .method(Method::GET)
            .uri(path)
            .version(Version::HTTP_11)
            .headers(HeaderMap::new())
            .body(Bytes::new())
            .build()
            .unwrap()
    }

    /// Builds a bodiless HTTP/1.1 `GET` request for `/`.
    fn create_test_request() -> Request {
        create_request_with_path("/")
    }

    /// Decodes a response body as UTF-8, panicking on invalid data.
    fn body_text(response: &Response) -> String {
        String::from_utf8(response.body.to_vec()).unwrap()
    }

    #[tokio::test]
    async fn test_handler_basic() {
        let handler = MockHandler {
            response_body: "Hello".to_string(),
        };

        let response = handler.handle(create_test_request()).await.unwrap();

        assert_eq!(body_text(&response), "Hello");
    }

    #[tokio::test]
    async fn test_middleware_basic() {
        let handler = Arc::new(MockHandler {
            response_body: "World".to_string(),
        });

        let middleware = MockMiddleware {
            prefix: "Hello, ".to_string(),
        };

        let response = middleware
            .process(create_test_request(), handler)
            .await
            .unwrap();

        assert_eq!(body_text(&response), "Hello, World");
    }

    #[tokio::test]
    async fn test_middleware_chain_empty() {
        let handler = Arc::new(MockHandler {
            response_body: "Test".to_string(),
        });

        let chain = MiddlewareChain::new(handler);

        let response = chain.handle(create_test_request()).await.unwrap();

        assert_eq!(body_text(&response), "Test");
    }

    #[tokio::test]
    async fn test_middleware_chain_single() {
        let handler = Arc::new(MockHandler {
            response_body: "Handler".to_string(),
        });

        let middleware1 = Arc::new(MockMiddleware {
            prefix: "MW1:".to_string(),
        });

        let chain = MiddlewareChain::new(handler).with_middleware(middleware1);

        let response = chain.handle(create_test_request()).await.unwrap();

        assert_eq!(body_text(&response), "MW1:Handler");
    }

    #[tokio::test]
    async fn test_middleware_chain_multiple() {
        let handler = Arc::new(MockHandler {
            response_body: "Data".to_string(),
        });

        let middleware1 = Arc::new(MockMiddleware {
            prefix: "M1:".to_string(),
        });

        let middleware2 = Arc::new(MockMiddleware {
            prefix: "M2:".to_string(),
        });

        let chain = MiddlewareChain::new(handler)
            .with_middleware(middleware1)
            .with_middleware(middleware2);

        let response = chain.handle(create_test_request()).await.unwrap();

        // The first middleware added is the outermost, so its prefix comes first.
        assert_eq!(body_text(&response), "M1:M2:Data");
    }

    #[tokio::test]
    async fn test_middleware_chain_add_middleware() {
        let handler = Arc::new(MockHandler {
            response_body: "Result".to_string(),
        });

        let middleware = Arc::new(MockMiddleware {
            prefix: "Prefix:".to_string(),
        });

        let mut chain = MiddlewareChain::new(handler);
        chain.add_middleware(middleware);

        let response = chain.handle(create_test_request()).await.unwrap();

        assert_eq!(body_text(&response), "Prefix:Result");
    }

    /// Prefixing middleware that only runs for paths under `/api/`.
    struct ConditionalMiddleware {
        prefix: String,
    }

    #[async_trait]
    impl Middleware for ConditionalMiddleware {
        async fn process(&self, request: Request, next: Arc<dyn Handler>) -> Result<Response> {
            let response = next.handle(request).await?;
            let current_body = String::from_utf8(response.body.to_vec()).unwrap_or_default();
            let new_body = format!("{}{}", self.prefix, current_body);
            Ok(Response::ok().with_body(new_body))
        }

        fn should_continue(&self, request: &Request) -> bool {
            request.uri.path().starts_with("/api/")
        }
    }

    #[tokio::test]
    async fn test_middleware_conditional_skip() {
        let handler = Arc::new(MockHandler {
            response_body: "Response".to_string(),
        });

        let conditional_mw = Arc::new(ConditionalMiddleware {
            prefix: "API:".to_string(),
        });

        let chain = MiddlewareChain::new(handler).with_middleware(conditional_mw);

        // Matching path: the middleware runs.
        let response = chain
            .handle(create_request_with_path("/api/users"))
            .await
            .unwrap();
        assert_eq!(body_text(&response), "API:Response");

        // Non-matching path: the middleware is skipped.
        let response = chain
            .handle(create_request_with_path("/public"))
            .await
            .unwrap();
        assert_eq!(body_text(&response), "Response");
    }

    /// Middleware that, when `should_stop` is set, answers `401` without
    /// calling the rest of the chain.
    struct ShortCircuitMiddleware {
        should_stop: bool,
    }

    #[async_trait]
    impl Middleware for ShortCircuitMiddleware {
        async fn process(&self, request: Request, next: Arc<dyn Handler>) -> Result<Response> {
            if self.should_stop {
                return Ok(Response::unauthorized()
                    .with_body("Auth required")
                    .with_stop_chain(true));
            }
            next.handle(request).await
        }
    }

    #[tokio::test]
    async fn test_middleware_short_circuit() {
        let handler = Arc::new(MockHandler {
            response_body: "Handler Response".to_string(),
        });

        let short_circuit_mw = Arc::new(ShortCircuitMiddleware { should_stop: true });
        let normal_mw = Arc::new(MockMiddleware {
            prefix: "Normal:".to_string(),
        });

        let chain = MiddlewareChain::new(handler)
            .with_middleware(short_circuit_mw)
            .with_middleware(normal_mw);

        let response = chain.handle(create_test_request()).await.unwrap();

        assert_eq!(response.status, hyper::StatusCode::UNAUTHORIZED);
        assert_eq!(body_text(&response), "Auth required");
    }

    #[tokio::test]
    async fn test_middleware_no_short_circuit() {
        let handler = Arc::new(MockHandler {
            response_body: "Handler Response".to_string(),
        });

        let short_circuit_mw = Arc::new(ShortCircuitMiddleware { should_stop: false });
        let normal_mw = Arc::new(MockMiddleware {
            prefix: "Normal:".to_string(),
        });

        let chain = MiddlewareChain::new(handler)
            .with_middleware(short_circuit_mw)
            .with_middleware(normal_mw);

        let response = chain.handle(create_test_request()).await.unwrap();

        assert_eq!(response.status, hyper::StatusCode::OK);
        assert_eq!(body_text(&response), "Normal:Handler Response");
    }

    #[tokio::test]
    async fn test_middleware_multiple_conditions() {
        let handler = Arc::new(MockHandler {
            response_body: "Base".to_string(),
        });

        let api_mw = Arc::new(ConditionalMiddleware {
            prefix: "API:".to_string(),
        });

        let always_mw = Arc::new(MockMiddleware {
            prefix: "Always:".to_string(),
        });

        let chain = MiddlewareChain::new(handler)
            .with_middleware(api_mw)
            .with_middleware(always_mw);

        // Both middleware run for API paths.
        let response = chain
            .handle(create_request_with_path("/api/test"))
            .await
            .unwrap();
        assert_eq!(body_text(&response), "API:Always:Base");

        // Only the unconditional middleware runs elsewhere.
        let response = chain
            .handle(create_request_with_path("/public"))
            .await
            .unwrap();
        assert_eq!(body_text(&response), "Always:Base");
    }

    #[tokio::test]
    async fn test_response_should_stop_chain() {
        let response = Response::ok();
        assert!(!response.should_stop_chain());

        let stopping_response = Response::unauthorized().with_stop_chain(true);
        assert!(stopping_response.should_stop_chain());
    }

    #[rstest::rstest]
    #[case("/api/auth/login", true)]
    #[case("/api/auth/register", true)]
    #[case("/api/auth/", true)]
    #[case("/api/users", false)]
    #[case("/public", false)]
    fn test_exclude_middleware_prefix_match(#[case] path: &str, #[case] should_exclude: bool) {
        let inner: Arc<dyn Middleware> = Arc::new(MockMiddleware {
            prefix: "MW:".to_string(),
        });
        let exclude_mw = ExcludeMiddleware::new(inner).add_exclusion("/api/auth/");

        let request = create_request_with_path(path);
        let result = exclude_mw.should_continue(&request);

        assert_eq!(result, !should_exclude);
    }

    #[rstest::rstest]
    #[case("/health", true)]
    #[case("/health/check", false)]
    #[case("/healthz", false)]
    #[case("/api/health", false)]
    fn test_exclude_middleware_exact_match(#[case] path: &str, #[case] should_exclude: bool) {
        let inner: Arc<dyn Middleware> = Arc::new(MockMiddleware {
            prefix: "MW:".to_string(),
        });
        let exclude_mw = ExcludeMiddleware::new(inner).add_exclusion("/health");

        let request = create_request_with_path(path);
        let result = exclude_mw.should_continue(&request);

        assert_eq!(result, !should_exclude);
    }

    #[rstest::rstest]
    fn test_exclude_middleware_no_match_passes_through() {
        let inner: Arc<dyn Middleware> = Arc::new(MockMiddleware {
            prefix: "MW:".to_string(),
        });
        let exclude_mw = ExcludeMiddleware::new(inner)
            .add_exclusion("/api/auth/")
            .add_exclusion("/health");

        let request = create_request_with_path("/api/users");
        let result = exclude_mw.should_continue(&request);

        assert!(result);
    }

    #[rstest::rstest]
    #[tokio::test]
    async fn test_exclude_middleware_delegates_process() {
        let inner: Arc<dyn Middleware> = Arc::new(MockMiddleware {
            prefix: "INNER:".to_string(),
        });
        let exclude_mw = ExcludeMiddleware::new(inner).add_exclusion("/excluded/");

        let handler = Arc::new(MockHandler {
            response_body: "Response".to_string(),
        });

        let request = create_request_with_path("/api/test");
        let response = exclude_mw.process(request, handler).await.unwrap();

        assert_eq!(body_text(&response), "INNER:Response");
    }

    #[rstest::rstest]
    fn test_exclude_middleware_multiple_exclusions() {
        let inner: Arc<dyn Middleware> = Arc::new(MockMiddleware {
            prefix: "MW:".to_string(),
        });
        let mut exclude_mw = ExcludeMiddleware::new(inner);
        exclude_mw.add_exclusion_mut("/api/auth/");
        exclude_mw.add_exclusion_mut("/admin/");
        exclude_mw.add_exclusion_mut("/health");

        assert!(!exclude_mw.should_continue(&create_request_with_path("/api/auth/login")));
        assert!(!exclude_mw.should_continue(&create_request_with_path("/admin/dashboard")));
        assert!(!exclude_mw.should_continue(&create_request_with_path("/health")));
        assert!(exclude_mw.should_continue(&create_request_with_path("/api/users")));
    }

    #[rstest::rstest]
    fn test_exclude_middleware_respects_inner_should_continue() {
        let inner: Arc<dyn Middleware> = Arc::new(ConditionalMiddleware {
            prefix: "API:".to_string(),
        });
        let exclude_mw = ExcludeMiddleware::new(inner).add_exclusion("/api/auth/");

        // Excluded by the wrapper.
        assert!(!exclude_mw.should_continue(&create_request_with_path("/api/auth/login")));
        // Not excluded, but rejected by the inner middleware's own condition.
        assert!(!exclude_mw.should_continue(&create_request_with_path("/public")));
        // Neither excluded nor rejected.
        assert!(exclude_mw.should_continue(&create_request_with_path("/api/users")));
    }

    /// Handler that always fails with `Error::NotFound`.
    struct NotFoundHandler;

    #[async_trait]
    impl Handler for NotFoundHandler {
        async fn handle(&self, _request: Request) -> Result<Response> {
            Err(reinhardt_core::exception::Error::NotFound(
                "not found".into(),
            ))
        }
    }

    /// Handler that always fails with `Error::Authentication`.
    struct UnauthorizedHandler;

    #[async_trait]
    impl Handler for UnauthorizedHandler {
        async fn handle(&self, _request: Request) -> Result<Response> {
            Err(reinhardt_core::exception::Error::Authentication(
                "unauthorized".into(),
            ))
        }
    }

    /// Middleware that adds a fixed header to the downstream response.
    struct HeaderAddingMiddleware {
        header_name: &'static str,
        header_value: &'static str,
    }

    #[async_trait]
    impl Middleware for HeaderAddingMiddleware {
        async fn process(&self, request: Request, next: Arc<dyn Handler>) -> Result<Response> {
            let response = next.handle(request).await?;
            Ok(response.with_header(self.header_name, self.header_value))
        }
    }

    /// Middleware that always fails with `Error::Authorization`.
    struct RejectingMiddleware;

    #[async_trait]
    impl Middleware for RejectingMiddleware {
        async fn process(&self, _request: Request, _next: Arc<dyn Handler>) -> Result<Response> {
            Err(reinhardt_core::exception::Error::Authorization(
                "CSRF check failed".into(),
            ))
        }
    }

    #[rstest::rstest]
    #[tokio::test]
    async fn test_chain_post_processing_runs_on_handler_error() {
        let handler: Arc<dyn Handler> = Arc::new(NotFoundHandler);
        let mut chain = MiddlewareChain::new(handler);
        chain.add_middleware(Arc::new(HeaderAddingMiddleware {
            header_name: "X-Custom-Security",
            header_value: "applied",
        }));

        let response = chain.handle(create_test_request()).await.unwrap();

        assert_eq!(response.status, hyper::StatusCode::NOT_FOUND);
        assert_eq!(
            response
                .headers
                .get("X-Custom-Security")
                .map(|v| v.to_str().unwrap()),
            Some("applied")
        );
    }

    #[rstest::rstest]
    #[tokio::test]
    async fn test_chain_post_processing_runs_on_middleware_error() {
        let handler = Arc::new(MockHandler {
            response_body: "OK".into(),
        });
        let mut chain = MiddlewareChain::new(handler);
        // The outer middleware must still see the inner middleware's error response.
        chain.add_middleware(Arc::new(HeaderAddingMiddleware {
            header_name: "X-Frame-Options",
            header_value: "DENY",
        }));
        chain.add_middleware(Arc::new(RejectingMiddleware));

        let response = chain.handle(create_test_request()).await.unwrap();

        assert_eq!(response.status, hyper::StatusCode::FORBIDDEN);
        assert_eq!(
            response
                .headers
                .get("X-Frame-Options")
                .map(|v| v.to_str().unwrap()),
            Some("DENY")
        );
    }

    /// Middleware that forwards to `next` untouched.
    struct PassthroughMiddleware;

    #[async_trait]
    impl Middleware for PassthroughMiddleware {
        async fn process(&self, request: Request, next: Arc<dyn Handler>) -> Result<Response> {
            next.handle(request).await
        }
    }

    #[rstest::rstest]
    #[tokio::test]
    async fn test_chain_error_preserves_correct_status_code() {
        let handler: Arc<dyn Handler> = Arc::new(UnauthorizedHandler);
        let mut chain = MiddlewareChain::new(handler);
        chain.add_middleware(Arc::new(PassthroughMiddleware));

        let response = chain.handle(create_test_request()).await.unwrap();

        assert_eq!(response.status, hyper::StatusCode::UNAUTHORIZED);
    }
}