reinhardt_http/
middleware.rs1use async_trait::async_trait;
45use reinhardt_core::exception::Result;
46use std::sync::Arc;
47
48use crate::{Request, Response};
49
50#[async_trait]
55pub trait Handler: Send + Sync {
56 async fn handle(&self, request: Request) -> Result<Response>;
62}
63
64#[async_trait]
69impl<T: Handler + ?Sized> Handler for Arc<T> {
70 async fn handle(&self, request: Request) -> Result<Response> {
71 (**self).handle(request).await
72 }
73}
74
75#[async_trait]
81pub trait Middleware: Send + Sync {
82 async fn process(&self, request: Request, next: Arc<dyn Handler>) -> Result<Response>;
93
94 fn should_continue(&self, _request: &Request) -> bool {
116 true
117 }
118}
119
120pub struct MiddlewareChain {
125 middlewares: Vec<Arc<dyn Middleware>>,
126 handler: Arc<dyn Handler>,
127}
128
129impl MiddlewareChain {
130 pub fn new(handler: Arc<dyn Handler>) -> Self {
151 Self {
152 middlewares: Vec::new(),
153 handler,
154 }
155 }
156
157 pub fn with_middleware(mut self, middleware: Arc<dyn Middleware>) -> Self {
185 self.middlewares.push(middleware);
186 self
187 }
188
189 pub fn add_middleware(&mut self, middleware: Arc<dyn Middleware>) {
217 self.middlewares.push(middleware);
218 }
219}
220
221#[async_trait]
222impl Handler for MiddlewareChain {
223 async fn handle(&self, request: Request) -> Result<Response> {
224 if self.middlewares.is_empty() {
225 return self.handler.handle(request).await;
226 }
227
228 let mut current_handler = self.handler.clone();
237
238 let active_middlewares: Vec<_> = self
241 .middlewares
242 .iter()
243 .rev()
244 .filter(|mw| mw.should_continue(&request))
245 .collect();
246
247 for middleware in active_middlewares {
248 let mw = middleware.clone();
249 let handler = current_handler.clone();
250
251 current_handler = Arc::new(ConditionalComposedHandler {
252 middleware: mw,
253 next: handler,
254 });
255 }
256
257 current_handler.handle(request).await
258 }
259}
260
261struct ConditionalComposedHandler {
265 middleware: Arc<dyn Middleware>,
266 next: Arc<dyn Handler>,
267}
268
269#[async_trait]
270impl Handler for ConditionalComposedHandler {
271 async fn handle(&self, request: Request) -> Result<Response> {
272 let response = self.middleware.process(request, self.next.clone()).await?;
274
275 if response.should_stop_chain() {
278 return Ok(response);
279 }
280
281 Ok(response)
282 }
283}
284
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use hyper::{HeaderMap, Method, Version};

    /// Terminal handler that always answers `200 OK` with a fixed body.
    struct MockHandler {
        response_body: String,
    }

    #[async_trait]
    impl Handler for MockHandler {
        async fn handle(&self, _request: Request) -> Result<Response> {
            Ok(Response::ok().with_body(self.response_body.clone()))
        }
    }

    /// Returns `prefix` followed by the response body decoded as UTF-8.
    /// A body that is not valid UTF-8 is treated as empty.
    fn prefixed(prefix: &str, response: &Response) -> String {
        let current_body = String::from_utf8(response.body.to_vec()).unwrap_or_default();
        format!("{}{}", prefix, current_body)
    }

    /// Middleware that calls `next` and prepends `prefix` to the body it gets back.
    struct MockMiddleware {
        prefix: String,
    }

    #[async_trait]
    impl Middleware for MockMiddleware {
        async fn process(&self, request: Request, next: Arc<dyn Handler>) -> Result<Response> {
            let response = next.handle(request).await?;
            Ok(Response::ok().with_body(prefixed(&self.prefix, &response)))
        }
    }

    /// Same prefixing behaviour as `MockMiddleware`, but it only participates
    /// for request paths that start with `/api/`.
    struct ConditionalMiddleware {
        prefix: String,
    }

    #[async_trait]
    impl Middleware for ConditionalMiddleware {
        async fn process(&self, request: Request, next: Arc<dyn Handler>) -> Result<Response> {
            let response = next.handle(request).await?;
            Ok(Response::ok().with_body(prefixed(&self.prefix, &response)))
        }

        fn should_continue(&self, request: &Request) -> bool {
            request.uri.path().starts_with("/api/")
        }
    }

    /// When `should_stop` is set, answers `401` itself without calling `next`;
    /// otherwise it delegates unchanged.
    struct ShortCircuitMiddleware {
        should_stop: bool,
    }

    #[async_trait]
    impl Middleware for ShortCircuitMiddleware {
        async fn process(&self, request: Request, next: Arc<dyn Handler>) -> Result<Response> {
            if self.should_stop {
                return Ok(Response::unauthorized()
                    .with_body("Auth required")
                    .with_stop_chain(true));
            }
            next.handle(request).await
        }
    }

    /// Builds a bodiless HTTP/1.1 `GET` request for `path`.
    fn create_test_request_for(path: &'static str) -> Request {
        Request::builder()
            .method(Method::GET)
            .uri(path)
            .version(Version::HTTP_11)
            .headers(HeaderMap::new())
            .body(Bytes::new())
            .build()
            .unwrap()
    }

    /// Builds a bodiless HTTP/1.1 `GET /` request.
    fn create_test_request() -> Request {
        create_test_request_for("/")
    }

    /// Decodes a response body as UTF-8, panicking if it is not valid UTF-8.
    fn body_text(response: &Response) -> String {
        String::from_utf8(response.body.to_vec()).unwrap()
    }

    #[tokio::test]
    async fn test_handler_basic() {
        let handler = MockHandler {
            response_body: "Hello".to_string(),
        };

        let response = handler.handle(create_test_request()).await.unwrap();

        assert_eq!(body_text(&response), "Hello");
    }

    #[tokio::test]
    async fn test_arc_handler_forwards_to_inner() {
        let shared = Arc::new(MockHandler {
            response_body: "Shared".to_string(),
        });

        // Call through the blanket `impl Handler for Arc<T>` explicitly.
        let response = <Arc<MockHandler> as Handler>::handle(&shared, create_test_request())
            .await
            .unwrap();

        assert_eq!(body_text(&response), "Shared");
    }

    #[tokio::test]
    async fn test_middleware_basic() {
        let handler = Arc::new(MockHandler {
            response_body: "World".to_string(),
        });

        let middleware = MockMiddleware {
            prefix: "Hello, ".to_string(),
        };

        let response = middleware
            .process(create_test_request(), handler)
            .await
            .unwrap();

        assert_eq!(body_text(&response), "Hello, World");
    }

    #[tokio::test]
    async fn test_should_continue_defaults_to_true() {
        let middleware = MockMiddleware {
            prefix: String::new(),
        };

        assert!(middleware.should_continue(&create_test_request()));
    }

    #[tokio::test]
    async fn test_middleware_chain_empty() {
        let handler = Arc::new(MockHandler {
            response_body: "Test".to_string(),
        });

        let chain = MiddlewareChain::new(handler);

        let response = chain.handle(create_test_request()).await.unwrap();

        assert_eq!(body_text(&response), "Test");
    }

    #[tokio::test]
    async fn test_middleware_chain_single() {
        let handler = Arc::new(MockHandler {
            response_body: "Handler".to_string(),
        });

        let middleware1 = Arc::new(MockMiddleware {
            prefix: "MW1:".to_string(),
        });

        let chain = MiddlewareChain::new(handler).with_middleware(middleware1);

        let response = chain.handle(create_test_request()).await.unwrap();

        assert_eq!(body_text(&response), "MW1:Handler");
    }

    #[tokio::test]
    async fn test_middleware_chain_multiple() {
        let handler = Arc::new(MockHandler {
            response_body: "Data".to_string(),
        });

        let middleware1 = Arc::new(MockMiddleware {
            prefix: "M1:".to_string(),
        });

        let middleware2 = Arc::new(MockMiddleware {
            prefix: "M2:".to_string(),
        });

        let chain = MiddlewareChain::new(handler)
            .with_middleware(middleware1)
            .with_middleware(middleware2);

        let response = chain.handle(create_test_request()).await.unwrap();

        // The first registered middleware is outermost, so its prefix lands first.
        assert_eq!(body_text(&response), "M1:M2:Data");
    }

    #[tokio::test]
    async fn test_middleware_chain_add_middleware() {
        let handler = Arc::new(MockHandler {
            response_body: "Result".to_string(),
        });

        let middleware = Arc::new(MockMiddleware {
            prefix: "Prefix:".to_string(),
        });

        let mut chain = MiddlewareChain::new(handler);
        chain.add_middleware(middleware);

        let response = chain.handle(create_test_request()).await.unwrap();

        assert_eq!(body_text(&response), "Prefix:Result");
    }

    #[tokio::test]
    async fn test_middleware_chain_nested() {
        let handler = Arc::new(MockHandler {
            response_body: "Core".to_string(),
        });

        let inner = MiddlewareChain::new(handler).with_middleware(Arc::new(MockMiddleware {
            prefix: "Inner:".to_string(),
        }));

        let outer = MiddlewareChain::new(Arc::new(inner)).with_middleware(Arc::new(
            MockMiddleware {
                prefix: "Outer:".to_string(),
            },
        ));

        let response = outer.handle(create_test_request()).await.unwrap();

        assert_eq!(body_text(&response), "Outer:Inner:Core");
    }

    #[tokio::test]
    async fn test_middleware_conditional_skip() {
        let handler = Arc::new(MockHandler {
            response_body: "Response".to_string(),
        });

        let conditional_mw = Arc::new(ConditionalMiddleware {
            prefix: "API:".to_string(),
        });

        let chain = MiddlewareChain::new(handler).with_middleware(conditional_mw);

        let response = chain
            .handle(create_test_request_for("/api/users"))
            .await
            .unwrap();
        assert_eq!(body_text(&response), "API:Response");

        let response = chain
            .handle(create_test_request_for("/public"))
            .await
            .unwrap();
        assert_eq!(body_text(&response), "Response");
    }

    #[tokio::test]
    async fn test_middleware_short_circuit() {
        let handler = Arc::new(MockHandler {
            response_body: "Handler Response".to_string(),
        });

        let short_circuit_mw = Arc::new(ShortCircuitMiddleware { should_stop: true });
        let normal_mw = Arc::new(MockMiddleware {
            prefix: "Normal:".to_string(),
        });

        let chain = MiddlewareChain::new(handler)
            .with_middleware(short_circuit_mw)
            .with_middleware(normal_mw);

        let response = chain.handle(create_test_request()).await.unwrap();

        assert_eq!(response.status, hyper::StatusCode::UNAUTHORIZED);
        assert_eq!(body_text(&response), "Auth required");
    }

    #[tokio::test]
    async fn test_middleware_no_short_circuit() {
        let handler = Arc::new(MockHandler {
            response_body: "Handler Response".to_string(),
        });

        let short_circuit_mw = Arc::new(ShortCircuitMiddleware { should_stop: false });
        let normal_mw = Arc::new(MockMiddleware {
            prefix: "Normal:".to_string(),
        });

        let chain = MiddlewareChain::new(handler)
            .with_middleware(short_circuit_mw)
            .with_middleware(normal_mw);

        let response = chain.handle(create_test_request()).await.unwrap();

        assert_eq!(response.status, hyper::StatusCode::OK);
        assert_eq!(body_text(&response), "Normal:Handler Response");
    }

    #[tokio::test]
    async fn test_middleware_multiple_conditions() {
        let handler = Arc::new(MockHandler {
            response_body: "Base".to_string(),
        });

        let api_mw = Arc::new(ConditionalMiddleware {
            prefix: "API:".to_string(),
        });

        let always_mw = Arc::new(MockMiddleware {
            prefix: "Always:".to_string(),
        });

        let chain = MiddlewareChain::new(handler)
            .with_middleware(api_mw)
            .with_middleware(always_mw);

        let response = chain
            .handle(create_test_request_for("/api/test"))
            .await
            .unwrap();
        assert_eq!(body_text(&response), "API:Always:Base");

        let response = chain
            .handle(create_test_request_for("/public"))
            .await
            .unwrap();
        assert_eq!(body_text(&response), "Always:Base");
    }

    #[tokio::test]
    async fn test_response_should_stop_chain() {
        let response = Response::ok();
        assert!(!response.should_stop_chain());

        let stopping_response = Response::unauthorized().with_stop_chain(true);
        assert!(stopping_response.should_stop_chain());
    }
}