1use async_trait::async_trait;
45use reinhardt_core::exception::Result;
46use std::sync::Arc;
47
48use crate::{Request, Response};
49
50#[async_trait]
55pub trait Handler: Send + Sync {
56 async fn handle(&self, request: Request) -> Result<Response>;
62}
63
64#[async_trait]
69impl<T: Handler + ?Sized> Handler for Arc<T> {
70 async fn handle(&self, request: Request) -> Result<Response> {
71 (**self).handle(request).await
72 }
73}
74
75#[async_trait]
81pub trait Middleware: Send + Sync {
82 async fn process(&self, request: Request, next: Arc<dyn Handler>) -> Result<Response>;
93
94 fn should_continue(&self, _request: &Request) -> bool {
116 true
117 }
118}
119
120pub struct MiddlewareChain {
125 middlewares: Vec<Arc<dyn Middleware>>,
126 handler: Arc<dyn Handler>,
127}
128
129impl MiddlewareChain {
130 pub fn new(handler: Arc<dyn Handler>) -> Self {
151 Self {
152 middlewares: Vec::new(),
153 handler,
154 }
155 }
156
157 pub fn with_middleware(mut self, middleware: Arc<dyn Middleware>) -> Self {
185 self.middlewares.push(middleware);
186 self
187 }
188
189 pub fn add_middleware(&mut self, middleware: Arc<dyn Middleware>) {
217 self.middlewares.push(middleware);
218 }
219}
220
221#[async_trait]
222impl Handler for MiddlewareChain {
223 async fn handle(&self, request: Request) -> Result<Response> {
224 if self.middlewares.is_empty() {
225 return self.handler.handle(request).await;
226 }
227
228 let mut current_handler: Arc<dyn Handler> = Arc::new(ErrorToResponseHandler {
239 inner: self.handler.clone(),
240 });
241
242 let active_middlewares: Vec<_> = self
245 .middlewares
246 .iter()
247 .rev()
248 .filter(|mw| mw.should_continue(&request))
249 .collect();
250
251 for middleware in active_middlewares {
252 let mw = middleware.clone();
253 let handler = current_handler.clone();
254
255 current_handler = Arc::new(ConditionalComposedHandler {
256 middleware: mw,
257 next: handler,
258 });
259 }
260
261 current_handler.handle(request).await
262 }
263}
264
265pub struct ExcludeMiddleware {
305 inner: Arc<dyn Middleware>,
306 exclusions: Vec<String>,
307}
308
309impl ExcludeMiddleware {
310 pub fn new(inner: Arc<dyn Middleware>) -> Self {
312 Self {
313 inner,
314 exclusions: Vec::new(),
315 }
316 }
317
318 pub fn add_exclusion(mut self, pattern: &str) -> Self {
322 self.exclusions.push(pattern.to_string());
323 self
324 }
325
326 pub fn add_exclusion_mut(&mut self, pattern: &str) {
330 self.exclusions.push(pattern.to_string());
331 }
332
333 fn is_excluded(&self, path: &str) -> bool {
335 self.exclusions.iter().any(|pattern| {
336 if pattern.ends_with('/') {
337 path.starts_with(pattern.as_str())
339 } else {
340 path == pattern
342 }
343 })
344 }
345}
346
347#[async_trait]
348impl Middleware for ExcludeMiddleware {
349 async fn process(&self, request: Request, next: Arc<dyn Handler>) -> Result<Response> {
350 self.inner.process(request, next).await
351 }
352
353 fn should_continue(&self, request: &Request) -> bool {
354 if self.is_excluded(request.uri.path()) {
355 return false;
356 }
357 self.inner.should_continue(request)
358 }
359}
360
361struct ErrorToResponseHandler {
368 inner: Arc<dyn Handler>,
369}
370
371#[async_trait]
372impl Handler for ErrorToResponseHandler {
373 async fn handle(&self, request: Request) -> Result<Response> {
374 match self.inner.handle(request).await {
375 Ok(response) => Ok(response),
376 Err(e) => Ok(Response::from(e)),
377 }
378 }
379}
380
381struct ConditionalComposedHandler {
386 middleware: Arc<dyn Middleware>,
387 next: Arc<dyn Handler>,
388}
389
390#[async_trait]
391impl Handler for ConditionalComposedHandler {
392 async fn handle(&self, request: Request) -> Result<Response> {
393 let response = match self.middleware.process(request, self.next.clone()).await {
398 Ok(response) => response,
399 Err(e) => Response::from(e),
400 };
401
402 Ok(response)
403 }
404}
405
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use hyper::{HeaderMap, Method, Version};

    // ------------------------------------------------------------------
    // Request / response helpers
    // ------------------------------------------------------------------

    /// Builds a body-less HTTP/1.1 `GET` request for `path` with no headers.
    fn create_request_with_path(path: &str) -> Request {
        Request::builder()
            .method(Method::GET)
            .uri(path)
            .version(Version::HTTP_11)
            .headers(HeaderMap::new())
            .body(Bytes::new())
            .build()
            .unwrap()
    }

    /// Builds the default test request, `GET /`.
    fn create_test_request() -> Request {
        create_request_with_path("/")
    }

    /// Decodes a response body as UTF-8, panicking if it is not valid UTF-8.
    fn body_text(response: &Response) -> String {
        String::from_utf8(response.body.to_vec()).unwrap()
    }

    /// Looks up a response header and renders it as `&str`.
    fn header_text<'a>(response: &'a Response, name: &str) -> Option<&'a str> {
        response
            .headers
            .get(name)
            .map(|value| value.to_str().unwrap())
    }

    /// Shorthand for an `Arc<MockHandler>` answering with `body`.
    fn mock_handler(body: &str) -> Arc<MockHandler> {
        Arc::new(MockHandler {
            response_body: body.to_string(),
        })
    }

    /// Shorthand for an `Arc<MockMiddleware>` prepending `prefix`.
    fn mock_middleware(prefix: &str) -> Arc<MockMiddleware> {
        Arc::new(MockMiddleware {
            prefix: prefix.to_string(),
        })
    }

    /// Shared behaviour of the prefixing middlewares. It calls `next`, then
    /// answers with a fresh `200 OK` whose body is `prefix` followed by the
    /// downstream body. A non-UTF-8 downstream body is treated as empty.
    async fn prepend_to_body(
        prefix: &str,
        request: Request,
        next: Arc<dyn Handler>,
    ) -> Result<Response> {
        let downstream = next.handle(request).await?;
        let downstream_body = String::from_utf8(downstream.body.to_vec()).unwrap_or_default();
        Ok(Response::ok().with_body(format!("{}{}", prefix, downstream_body)))
    }

    // ------------------------------------------------------------------
    // Test doubles
    // ------------------------------------------------------------------

    /// Handler that always succeeds with a fixed body.
    struct MockHandler {
        response_body: String,
    }

    #[async_trait]
    impl Handler for MockHandler {
        async fn handle(&self, _request: Request) -> Result<Response> {
            Ok(Response::ok().with_body(self.response_body.clone()))
        }
    }

    /// Middleware that prefixes the downstream body on every request.
    struct MockMiddleware {
        prefix: String,
    }

    #[async_trait]
    impl Middleware for MockMiddleware {
        async fn process(&self, request: Request, next: Arc<dyn Handler>) -> Result<Response> {
            prepend_to_body(&self.prefix, request, next).await
        }
    }

    /// Like `MockMiddleware`, but only participates for paths under `/api/`.
    struct ConditionalMiddleware {
        prefix: String,
    }

    #[async_trait]
    impl Middleware for ConditionalMiddleware {
        async fn process(&self, request: Request, next: Arc<dyn Handler>) -> Result<Response> {
            prepend_to_body(&self.prefix, request, next).await
        }

        fn should_continue(&self, request: &Request) -> bool {
            request.uri.path().starts_with("/api/")
        }
    }

    /// Middleware that, when `should_stop` is set, answers 401 without
    /// calling `next`.
    struct ShortCircuitMiddleware {
        should_stop: bool,
    }

    #[async_trait]
    impl Middleware for ShortCircuitMiddleware {
        async fn process(&self, request: Request, next: Arc<dyn Handler>) -> Result<Response> {
            if !self.should_stop {
                return next.handle(request).await;
            }
            Ok(Response::unauthorized()
                .with_body("Auth required")
                .with_stop_chain(true))
        }
    }

    /// Handler that always fails with a not-found error.
    struct NotFoundHandler;

    #[async_trait]
    impl Handler for NotFoundHandler {
        async fn handle(&self, _request: Request) -> Result<Response> {
            Err(reinhardt_core::exception::Error::NotFound(
                "not found".into(),
            ))
        }
    }

    /// Handler that always fails with an authentication error.
    struct UnauthorizedHandler;

    #[async_trait]
    impl Handler for UnauthorizedHandler {
        async fn handle(&self, _request: Request) -> Result<Response> {
            Err(reinhardt_core::exception::Error::Authentication(
                "unauthorized".into(),
            ))
        }
    }

    /// Middleware that adds a fixed header to a successful downstream response.
    struct HeaderAddingMiddleware {
        header_name: &'static str,
        header_value: &'static str,
    }

    #[async_trait]
    impl Middleware for HeaderAddingMiddleware {
        async fn process(&self, request: Request, next: Arc<dyn Handler>) -> Result<Response> {
            let response = next.handle(request).await?;
            Ok(response.with_header(self.header_name, self.header_value))
        }
    }

    /// Middleware that always fails with an authorization error.
    struct RejectingMiddleware;

    #[async_trait]
    impl Middleware for RejectingMiddleware {
        async fn process(&self, _request: Request, _next: Arc<dyn Handler>) -> Result<Response> {
            Err(reinhardt_core::exception::Error::Authorization(
                "CSRF check failed".into(),
            ))
        }
    }

    /// Middleware that just forwards to `next`.
    struct PassthroughMiddleware;

    #[async_trait]
    impl Middleware for PassthroughMiddleware {
        async fn process(&self, request: Request, next: Arc<dyn Handler>) -> Result<Response> {
            next.handle(request).await
        }
    }

    // ------------------------------------------------------------------
    // Handler / chain basics
    // ------------------------------------------------------------------

    #[tokio::test]
    async fn test_handler_basic() {
        let handler = MockHandler {
            response_body: "Hello".to_string(),
        };

        let response = handler.handle(create_test_request()).await.unwrap();

        assert_eq!(body_text(&response), "Hello");
    }

    #[tokio::test]
    async fn test_middleware_basic() {
        let middleware = MockMiddleware {
            prefix: "Hello, ".to_string(),
        };

        let response = middleware
            .process(create_test_request(), mock_handler("World"))
            .await
            .unwrap();

        assert_eq!(body_text(&response), "Hello, World");
    }

    #[tokio::test]
    async fn test_middleware_chain_empty() {
        let chain = MiddlewareChain::new(mock_handler("Test"));

        let response = chain.handle(create_test_request()).await.unwrap();

        assert_eq!(body_text(&response), "Test");
    }

    #[tokio::test]
    async fn test_middleware_chain_single() {
        let chain =
            MiddlewareChain::new(mock_handler("Handler")).with_middleware(mock_middleware("MW1:"));

        let response = chain.handle(create_test_request()).await.unwrap();

        assert_eq!(body_text(&response), "MW1:Handler");
    }

    #[tokio::test]
    async fn test_middleware_chain_multiple() {
        // The first-registered middleware is outermost, so its prefix comes first.
        let chain = MiddlewareChain::new(mock_handler("Data"))
            .with_middleware(mock_middleware("M1:"))
            .with_middleware(mock_middleware("M2:"));

        let response = chain.handle(create_test_request()).await.unwrap();

        assert_eq!(body_text(&response), "M1:M2:Data");
    }

    #[tokio::test]
    async fn test_middleware_chain_add_middleware() {
        let mut chain = MiddlewareChain::new(mock_handler("Result"));
        chain.add_middleware(mock_middleware("Prefix:"));

        let response = chain.handle(create_test_request()).await.unwrap();

        assert_eq!(body_text(&response), "Prefix:Result");
    }

    // ------------------------------------------------------------------
    // Conditional participation and short-circuiting
    // ------------------------------------------------------------------

    #[tokio::test]
    async fn test_middleware_conditional_skip() {
        let chain = MiddlewareChain::new(mock_handler("Response")).with_middleware(Arc::new(
            ConditionalMiddleware {
                prefix: "API:".to_string(),
            },
        ));

        let expectations = [("/api/users", "API:Response"), ("/public", "Response")];
        for (path, expected_body) in expectations {
            let response = chain
                .handle(create_request_with_path(path))
                .await
                .unwrap();
            assert_eq!(body_text(&response), expected_body);
        }
    }

    #[tokio::test]
    async fn test_middleware_short_circuit() {
        let chain = MiddlewareChain::new(mock_handler("Handler Response"))
            .with_middleware(Arc::new(ShortCircuitMiddleware { should_stop: true }))
            .with_middleware(mock_middleware("Normal:"));

        let response = chain.handle(create_test_request()).await.unwrap();

        assert_eq!(response.status, hyper::StatusCode::UNAUTHORIZED);
        assert_eq!(body_text(&response), "Auth required");
    }

    #[tokio::test]
    async fn test_middleware_no_short_circuit() {
        let chain = MiddlewareChain::new(mock_handler("Handler Response"))
            .with_middleware(Arc::new(ShortCircuitMiddleware { should_stop: false }))
            .with_middleware(mock_middleware("Normal:"));

        let response = chain.handle(create_test_request()).await.unwrap();

        assert_eq!(response.status, hyper::StatusCode::OK);
        assert_eq!(body_text(&response), "Normal:Handler Response");
    }

    #[tokio::test]
    async fn test_middleware_multiple_conditions() {
        let chain = MiddlewareChain::new(mock_handler("Base"))
            .with_middleware(Arc::new(ConditionalMiddleware {
                prefix: "API:".to_string(),
            }))
            .with_middleware(mock_middleware("Always:"));

        let expectations = [("/api/test", "API:Always:Base"), ("/public", "Always:Base")];
        for (path, expected_body) in expectations {
            let response = chain
                .handle(create_request_with_path(path))
                .await
                .unwrap();
            assert_eq!(body_text(&response), expected_body);
        }
    }

    #[tokio::test]
    async fn test_response_should_stop_chain() {
        assert!(!Response::ok().should_stop_chain());
        assert!(Response::unauthorized()
            .with_stop_chain(true)
            .should_stop_chain());
    }

    // ------------------------------------------------------------------
    // ExcludeMiddleware
    // ------------------------------------------------------------------

    #[rstest::rstest]
    #[case("/api/auth/login", true)]
    #[case("/api/auth/register", true)]
    #[case("/api/auth/", true)]
    #[case("/api/users", false)]
    #[case("/public", false)]
    fn test_exclude_middleware_prefix_match(#[case] path: &str, #[case] should_exclude: bool) {
        let inner: Arc<dyn Middleware> = mock_middleware("MW:");
        let exclude_mw = ExcludeMiddleware::new(inner).add_exclusion("/api/auth/");

        let continues = exclude_mw.should_continue(&create_request_with_path(path));

        assert_eq!(continues, !should_exclude);
    }

    #[rstest::rstest]
    #[case("/health", true)]
    #[case("/health/check", false)]
    #[case("/healthz", false)]
    #[case("/api/health", false)]
    fn test_exclude_middleware_exact_match(#[case] path: &str, #[case] should_exclude: bool) {
        let inner: Arc<dyn Middleware> = mock_middleware("MW:");
        let exclude_mw = ExcludeMiddleware::new(inner).add_exclusion("/health");

        let continues = exclude_mw.should_continue(&create_request_with_path(path));

        assert_eq!(continues, !should_exclude);
    }

    #[rstest::rstest]
    fn test_exclude_middleware_no_match_passes_through() {
        let inner: Arc<dyn Middleware> = mock_middleware("MW:");
        let exclude_mw = ExcludeMiddleware::new(inner)
            .add_exclusion("/api/auth/")
            .add_exclusion("/health");

        assert!(exclude_mw.should_continue(&create_request_with_path("/api/users")));
    }

    #[rstest::rstest]
    #[tokio::test]
    async fn test_exclude_middleware_delegates_process() {
        let inner: Arc<dyn Middleware> = mock_middleware("INNER:");
        let exclude_mw = ExcludeMiddleware::new(inner).add_exclusion("/excluded/");

        let response = exclude_mw
            .process(create_request_with_path("/api/test"), mock_handler("Response"))
            .await
            .unwrap();

        assert_eq!(body_text(&response), "INNER:Response");
    }

    #[rstest::rstest]
    fn test_exclude_middleware_multiple_exclusions() {
        let inner: Arc<dyn Middleware> = mock_middleware("MW:");
        let mut exclude_mw = ExcludeMiddleware::new(inner);
        for pattern in ["/api/auth/", "/admin/", "/health"] {
            exclude_mw.add_exclusion_mut(pattern);
        }

        let expectations = [
            ("/api/auth/login", false),
            ("/admin/dashboard", false),
            ("/health", false),
            ("/api/users", true),
        ];
        for (path, expected) in expectations {
            assert_eq!(
                exclude_mw.should_continue(&create_request_with_path(path)),
                expected,
                "unexpected should_continue result for {}",
                path
            );
        }
    }

    #[rstest::rstest]
    fn test_exclude_middleware_respects_inner_should_continue() {
        let inner: Arc<dyn Middleware> = Arc::new(ConditionalMiddleware {
            prefix: "API:".to_string(),
        });
        let exclude_mw = ExcludeMiddleware::new(inner).add_exclusion("/api/auth/");

        // Excluded by the wrapper itself.
        assert!(!exclude_mw.should_continue(&create_request_with_path("/api/auth/login")));
        // Not excluded, but rejected by the inner middleware's own condition.
        assert!(!exclude_mw.should_continue(&create_request_with_path("/public")));
        // Neither excluded nor rejected.
        assert!(exclude_mw.should_continue(&create_request_with_path("/api/users")));
    }

    // ------------------------------------------------------------------
    // Error-to-response conversion inside the chain
    // ------------------------------------------------------------------

    #[rstest::rstest]
    #[tokio::test]
    async fn test_chain_post_processing_runs_on_handler_error() {
        let handler: Arc<dyn Handler> = Arc::new(NotFoundHandler);
        let mut chain = MiddlewareChain::new(handler);
        chain.add_middleware(Arc::new(HeaderAddingMiddleware {
            header_name: "X-Custom-Security",
            header_value: "applied",
        }));

        let response = chain.handle(create_test_request()).await.unwrap();

        assert_eq!(response.status, hyper::StatusCode::NOT_FOUND);
        assert_eq!(header_text(&response, "X-Custom-Security"), Some("applied"));
    }

    #[rstest::rstest]
    #[tokio::test]
    async fn test_chain_post_processing_runs_on_middleware_error() {
        let mut chain = MiddlewareChain::new(mock_handler("OK"));
        chain.add_middleware(Arc::new(HeaderAddingMiddleware {
            header_name: "X-Frame-Options",
            header_value: "DENY",
        }));
        chain.add_middleware(Arc::new(RejectingMiddleware));

        let response = chain.handle(create_test_request()).await.unwrap();

        assert_eq!(response.status, hyper::StatusCode::FORBIDDEN);
        assert_eq!(header_text(&response, "X-Frame-Options"), Some("DENY"));
    }

    #[rstest::rstest]
    #[tokio::test]
    async fn test_chain_error_preserves_correct_status_code() {
        let handler: Arc<dyn Handler> = Arc::new(UnauthorizedHandler);
        let mut chain = MiddlewareChain::new(handler);
        chain.add_middleware(Arc::new(PassthroughMiddleware));

        let response = chain.handle(create_test_request()).await.unwrap();

        assert_eq!(response.status, hyper::StatusCode::UNAUTHORIZED);
    }
}