1use async_trait::async_trait;
45use reinhardt_core::exception::Result;
46use std::sync::Arc;
47
48use crate::{Request, Response};
49
/// An asynchronous request handler that turns a [`Request`] into a [`Response`].
///
/// The `Send + Sync` bound allows handlers to be shared as `Arc<dyn Handler>`
/// (as the middleware chain in this module does).
#[async_trait]
pub trait Handler: Send + Sync {
    /// Handles `request`, producing a response or an error.
    async fn handle(&self, request: Request) -> Result<Response>;
}
63
64#[async_trait]
69impl<T: Handler + ?Sized> Handler for Arc<T> {
70 async fn handle(&self, request: Request) -> Result<Response> {
71 (**self).handle(request).await
72 }
73}
74
/// A component that wraps request handling, running before and/or after the
/// rest of the chain.
///
/// The `Send + Sync` bound allows middlewares to be shared as
/// `Arc<dyn Middleware>`.
#[async_trait]
pub trait Middleware: Send + Sync {
    /// Processes `request`.
    ///
    /// `next` is the remainder of the chain (further middlewares and finally
    /// the handler). An implementation may call `next.handle(request)` and
    /// transform the result, or return its own response without calling
    /// `next` at all, in which case nothing further down the chain runs.
    async fn process(&self, request: Request, next: Arc<dyn Handler>) -> Result<Response>;

    /// Returns whether this middleware should take part in handling `request`.
    ///
    /// `MiddlewareChain` evaluates this for every registered middleware before
    /// building the per-request chain and leaves out those returning `false`.
    /// The default implementation always returns `true`.
    fn should_continue(&self, _request: &Request) -> bool {
        true
    }
}
119
/// An ordered list of middlewares wrapped around a terminal handler.
///
/// Middlewares run in insertion order: the first one added is the outermost,
/// so it sees the request first and the response last.
pub struct MiddlewareChain {
    /// Registered middlewares, in the order they were added.
    middlewares: Vec<Arc<dyn Middleware>>,
    /// The handler at the innermost end of the chain.
    handler: Arc<dyn Handler>,
}
128
129impl MiddlewareChain {
130 pub fn new(handler: Arc<dyn Handler>) -> Self {
151 Self {
152 middlewares: Vec::new(),
153 handler,
154 }
155 }
156
157 pub fn with_middleware(mut self, middleware: Arc<dyn Middleware>) -> Self {
185 self.middlewares.push(middleware);
186 self
187 }
188
189 pub fn add_middleware(&mut self, middleware: Arc<dyn Middleware>) {
217 self.middlewares.push(middleware);
218 }
219}
220
221#[async_trait]
222impl Handler for MiddlewareChain {
223 async fn handle(&self, request: Request) -> Result<Response> {
224 if self.middlewares.is_empty() {
225 return self.handler.handle(request).await;
226 }
227
228 let mut current_handler = self.handler.clone();
237
238 let active_middlewares: Vec<_> = self
241 .middlewares
242 .iter()
243 .rev()
244 .filter(|mw| mw.should_continue(&request))
245 .collect();
246
247 for middleware in active_middlewares {
248 let mw = middleware.clone();
249 let handler = current_handler.clone();
250
251 current_handler = Arc::new(ConditionalComposedHandler {
252 middleware: mw,
253 next: handler,
254 });
255 }
256
257 current_handler.handle(request).await
258 }
259}
260
/// Wraps another middleware and opts it out for requests whose URI path
/// matches any configured exclusion pattern.
///
/// A pattern ending in `/` matches every path that starts with it (prefix
/// match); any other pattern must equal the path exactly.
pub struct ExcludeMiddleware {
    /// The wrapped middleware, used for requests that are not excluded.
    inner: Arc<dyn Middleware>,
    /// Exclusion patterns, compared against the request's URI path.
    exclusions: Vec<String>,
}
304
305impl ExcludeMiddleware {
306 pub fn new(inner: Arc<dyn Middleware>) -> Self {
308 Self {
309 inner,
310 exclusions: Vec::new(),
311 }
312 }
313
314 pub fn add_exclusion(mut self, pattern: &str) -> Self {
318 self.exclusions.push(pattern.to_string());
319 self
320 }
321
322 pub fn add_exclusion_mut(&mut self, pattern: &str) {
326 self.exclusions.push(pattern.to_string());
327 }
328
329 fn is_excluded(&self, path: &str) -> bool {
331 self.exclusions.iter().any(|pattern| {
332 if pattern.ends_with('/') {
333 path.starts_with(pattern.as_str())
335 } else {
336 path == pattern
338 }
339 })
340 }
341}
342
343#[async_trait]
344impl Middleware for ExcludeMiddleware {
345 async fn process(&self, request: Request, next: Arc<dyn Handler>) -> Result<Response> {
346 self.inner.process(request, next).await
347 }
348
349 fn should_continue(&self, request: &Request) -> bool {
350 if self.is_excluded(request.uri.path()) {
351 return false;
352 }
353 self.inner.should_continue(request)
354 }
355}
356
/// One link of a per-request middleware chain: runs `middleware`, giving it
/// `next` as the remainder of the chain.
struct ConditionalComposedHandler {
    /// The middleware executed at this link.
    middleware: Arc<dyn Middleware>,
    /// The rest of the chain (another link or the terminal handler).
    next: Arc<dyn Handler>,
}
364
365#[async_trait]
366impl Handler for ConditionalComposedHandler {
367 async fn handle(&self, request: Request) -> Result<Response> {
368 let response = self.middleware.process(request, self.next.clone()).await?;
370
371 if response.should_stop_chain() {
374 return Ok(response);
375 }
376
377 Ok(response)
378 }
379}
380
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use hyper::{HeaderMap, Method, Version};

    /// Terminal handler that always answers `200 OK` with a fixed body.
    struct MockHandler {
        response_body: String,
    }

    #[async_trait]
    impl Handler for MockHandler {
        async fn handle(&self, _request: Request) -> Result<Response> {
            Ok(Response::ok().with_body(self.response_body.clone()))
        }
    }

    /// Middleware that calls `next` and prepends `prefix` to the body it returns.
    struct MockMiddleware {
        prefix: String,
    }

    #[async_trait]
    impl Middleware for MockMiddleware {
        async fn process(&self, request: Request, next: Arc<dyn Handler>) -> Result<Response> {
            let response = next.handle(request).await?;

            let current_body = String::from_utf8(response.body.to_vec()).unwrap_or_default();
            let new_body = format!("{}{}", self.prefix, current_body);

            Ok(Response::ok().with_body(new_body))
        }
    }

    /// Builds a `GET` HTTP/1.1 request for `path` with no headers and an empty body.
    fn create_request_with_path(path: &str) -> Request {
        Request::builder()
            .method(Method::GET)
            .uri(path)
            .version(Version::HTTP_11)
            .headers(HeaderMap::new())
            .body(Bytes::new())
            .build()
            .unwrap()
    }

    /// Builds the default test request for the root path `/`.
    fn create_test_request() -> Request {
        create_request_with_path("/")
    }

    /// Decodes a response body as UTF-8, panicking on invalid data.
    fn body_string(response: &Response) -> String {
        String::from_utf8(response.body.to_vec()).unwrap()
    }

    #[tokio::test]
    async fn test_handler_basic() {
        let handler = MockHandler {
            response_body: "Hello".to_string(),
        };

        let response = handler.handle(create_test_request()).await.unwrap();

        assert_eq!(body_string(&response), "Hello");
    }

    #[tokio::test]
    async fn test_middleware_basic() {
        let handler = Arc::new(MockHandler {
            response_body: "World".to_string(),
        });

        let middleware = MockMiddleware {
            prefix: "Hello, ".to_string(),
        };

        let response = middleware
            .process(create_test_request(), handler)
            .await
            .unwrap();

        assert_eq!(body_string(&response), "Hello, World");
    }

    #[tokio::test]
    async fn test_middleware_chain_empty() {
        let handler = Arc::new(MockHandler {
            response_body: "Test".to_string(),
        });

        let chain = MiddlewareChain::new(handler);

        let response = chain.handle(create_test_request()).await.unwrap();

        assert_eq!(body_string(&response), "Test");
    }

    #[tokio::test]
    async fn test_middleware_chain_single() {
        let handler = Arc::new(MockHandler {
            response_body: "Handler".to_string(),
        });

        let middleware1 = Arc::new(MockMiddleware {
            prefix: "MW1:".to_string(),
        });

        let chain = MiddlewareChain::new(handler).with_middleware(middleware1);

        let response = chain.handle(create_test_request()).await.unwrap();

        assert_eq!(body_string(&response), "MW1:Handler");
    }

    #[tokio::test]
    async fn test_middleware_chain_multiple() {
        let handler = Arc::new(MockHandler {
            response_body: "Data".to_string(),
        });

        let middleware1 = Arc::new(MockMiddleware {
            prefix: "M1:".to_string(),
        });

        let middleware2 = Arc::new(MockMiddleware {
            prefix: "M2:".to_string(),
        });

        let chain = MiddlewareChain::new(handler)
            .with_middleware(middleware1)
            .with_middleware(middleware2);

        let response = chain.handle(create_test_request()).await.unwrap();

        // The first added middleware is outermost, so its prefix comes first.
        assert_eq!(body_string(&response), "M1:M2:Data");
    }

    #[tokio::test]
    async fn test_middleware_chain_add_middleware() {
        let handler = Arc::new(MockHandler {
            response_body: "Result".to_string(),
        });

        let middleware = Arc::new(MockMiddleware {
            prefix: "Prefix:".to_string(),
        });

        let mut chain = MiddlewareChain::new(handler);
        chain.add_middleware(middleware);

        let response = chain.handle(create_test_request()).await.unwrap();

        assert_eq!(body_string(&response), "Prefix:Result");
    }

    /// Like `MockMiddleware`, but only participates for paths under `/api/`.
    struct ConditionalMiddleware {
        prefix: String,
    }

    #[async_trait]
    impl Middleware for ConditionalMiddleware {
        async fn process(&self, request: Request, next: Arc<dyn Handler>) -> Result<Response> {
            let response = next.handle(request).await?;
            let current_body = String::from_utf8(response.body.to_vec()).unwrap_or_default();
            let new_body = format!("{}{}", self.prefix, current_body);
            Ok(Response::ok().with_body(new_body))
        }

        fn should_continue(&self, request: &Request) -> bool {
            request.uri.path().starts_with("/api/")
        }
    }

    #[tokio::test]
    async fn test_middleware_conditional_skip() {
        let handler = Arc::new(MockHandler {
            response_body: "Response".to_string(),
        });

        let conditional_mw = Arc::new(ConditionalMiddleware {
            prefix: "API:".to_string(),
        });

        let chain = MiddlewareChain::new(handler).with_middleware(conditional_mw);

        // Matching path: the middleware runs.
        let response = chain
            .handle(create_request_with_path("/api/users"))
            .await
            .unwrap();
        assert_eq!(body_string(&response), "API:Response");

        // Non-matching path: the middleware is skipped.
        let response = chain
            .handle(create_request_with_path("/public"))
            .await
            .unwrap();
        assert_eq!(body_string(&response), "Response");
    }

    /// Middleware that either rejects with 401 (without calling `next`) or passes through.
    struct ShortCircuitMiddleware {
        should_stop: bool,
    }

    #[async_trait]
    impl Middleware for ShortCircuitMiddleware {
        async fn process(&self, request: Request, next: Arc<dyn Handler>) -> Result<Response> {
            if self.should_stop {
                return Ok(Response::unauthorized()
                    .with_body("Auth required")
                    .with_stop_chain(true));
            }
            next.handle(request).await
        }
    }

    #[tokio::test]
    async fn test_middleware_short_circuit() {
        let handler = Arc::new(MockHandler {
            response_body: "Handler Response".to_string(),
        });

        let short_circuit_mw = Arc::new(ShortCircuitMiddleware { should_stop: true });
        let normal_mw = Arc::new(MockMiddleware {
            prefix: "Normal:".to_string(),
        });

        let chain = MiddlewareChain::new(handler)
            .with_middleware(short_circuit_mw)
            .with_middleware(normal_mw);

        let response = chain.handle(create_test_request()).await.unwrap();

        assert_eq!(response.status, hyper::StatusCode::UNAUTHORIZED);
        assert_eq!(body_string(&response), "Auth required");
    }

    #[tokio::test]
    async fn test_middleware_no_short_circuit() {
        let handler = Arc::new(MockHandler {
            response_body: "Handler Response".to_string(),
        });

        let short_circuit_mw = Arc::new(ShortCircuitMiddleware { should_stop: false });
        let normal_mw = Arc::new(MockMiddleware {
            prefix: "Normal:".to_string(),
        });

        let chain = MiddlewareChain::new(handler)
            .with_middleware(short_circuit_mw)
            .with_middleware(normal_mw);

        let response = chain.handle(create_test_request()).await.unwrap();

        assert_eq!(response.status, hyper::StatusCode::OK);
        assert_eq!(body_string(&response), "Normal:Handler Response");
    }

    #[tokio::test]
    async fn test_middleware_multiple_conditions() {
        let handler = Arc::new(MockHandler {
            response_body: "Base".to_string(),
        });

        let api_mw = Arc::new(ConditionalMiddleware {
            prefix: "API:".to_string(),
        });

        let always_mw = Arc::new(MockMiddleware {
            prefix: "Always:".to_string(),
        });

        let chain = MiddlewareChain::new(handler)
            .with_middleware(api_mw)
            .with_middleware(always_mw);

        let response = chain
            .handle(create_request_with_path("/api/test"))
            .await
            .unwrap();
        assert_eq!(body_string(&response), "API:Always:Base");

        let response = chain
            .handle(create_request_with_path("/public"))
            .await
            .unwrap();
        assert_eq!(body_string(&response), "Always:Base");
    }

    #[tokio::test]
    async fn test_response_should_stop_chain() {
        let response = Response::ok();
        assert!(!response.should_stop_chain());

        let stopping_response = Response::unauthorized().with_stop_chain(true);
        assert!(stopping_response.should_stop_chain());
    }

    #[rstest::rstest]
    #[case("/api/auth/login", true)]
    #[case("/api/auth/register", true)]
    #[case("/api/auth/", true)]
    #[case("/api/users", false)]
    #[case("/public", false)]
    fn test_exclude_middleware_prefix_match(#[case] path: &str, #[case] should_exclude: bool) {
        let inner: Arc<dyn Middleware> = Arc::new(MockMiddleware {
            prefix: "MW:".to_string(),
        });
        let exclude_mw = ExcludeMiddleware::new(inner).add_exclusion("/api/auth/");

        let result = exclude_mw.should_continue(&create_request_with_path(path));

        assert_eq!(result, !should_exclude);
    }

    #[rstest::rstest]
    #[case("/health", true)]
    #[case("/health/check", false)]
    #[case("/healthz", false)]
    #[case("/api/health", false)]
    fn test_exclude_middleware_exact_match(#[case] path: &str, #[case] should_exclude: bool) {
        let inner: Arc<dyn Middleware> = Arc::new(MockMiddleware {
            prefix: "MW:".to_string(),
        });
        let exclude_mw = ExcludeMiddleware::new(inner).add_exclusion("/health");

        let result = exclude_mw.should_continue(&create_request_with_path(path));

        assert_eq!(result, !should_exclude);
    }

    #[rstest::rstest]
    fn test_exclude_middleware_no_match_passes_through() {
        let inner: Arc<dyn Middleware> = Arc::new(MockMiddleware {
            prefix: "MW:".to_string(),
        });
        let exclude_mw = ExcludeMiddleware::new(inner)
            .add_exclusion("/api/auth/")
            .add_exclusion("/health");

        let result = exclude_mw.should_continue(&create_request_with_path("/api/users"));

        assert!(result);
    }

    #[rstest::rstest]
    #[tokio::test]
    async fn test_exclude_middleware_delegates_process() {
        let inner: Arc<dyn Middleware> = Arc::new(MockMiddleware {
            prefix: "INNER:".to_string(),
        });
        let exclude_mw = ExcludeMiddleware::new(inner).add_exclusion("/excluded/");

        let handler = Arc::new(MockHandler {
            response_body: "Response".to_string(),
        });

        let response = exclude_mw
            .process(create_request_with_path("/api/test"), handler)
            .await
            .unwrap();

        assert_eq!(body_string(&response), "INNER:Response");
    }

    #[rstest::rstest]
    fn test_exclude_middleware_multiple_exclusions() {
        let inner: Arc<dyn Middleware> = Arc::new(MockMiddleware {
            prefix: "MW:".to_string(),
        });
        let mut exclude_mw = ExcludeMiddleware::new(inner);
        exclude_mw.add_exclusion_mut("/api/auth/");
        exclude_mw.add_exclusion_mut("/admin/");
        exclude_mw.add_exclusion_mut("/health");

        assert!(!exclude_mw.should_continue(&create_request_with_path("/api/auth/login")));
        assert!(!exclude_mw.should_continue(&create_request_with_path("/admin/dashboard")));
        assert!(!exclude_mw.should_continue(&create_request_with_path("/health")));
        assert!(exclude_mw.should_continue(&create_request_with_path("/api/users")));
    }

    #[rstest::rstest]
    fn test_exclude_middleware_respects_inner_should_continue() {
        let inner: Arc<dyn Middleware> = Arc::new(ConditionalMiddleware {
            prefix: "API:".to_string(),
        });
        let exclude_mw = ExcludeMiddleware::new(inner).add_exclusion("/api/auth/");

        // Excluded by the wrapper.
        assert!(!exclude_mw.should_continue(&create_request_with_path("/api/auth/login")));
        // Not excluded, but rejected by the inner middleware's own condition.
        assert!(!exclude_mw.should_continue(&create_request_with_path("/public")));
        // Neither excluded nor rejected.
        assert!(exclude_mw.should_continue(&create_request_with_path("/api/users")));
    }
}