1use async_trait::async_trait;
7use hyper::StatusCode;
8use reinhardt_http::{Handler, Middleware, Request, Response, Result};
9use serde::{Deserialize, Serialize};
10use sha2::{Digest, Sha256};
11use std::collections::HashMap;
12use std::sync::{Arc, RwLock};
13use std::time::{Duration, Instant};
14
15#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct CacheEntry {
18 status: u16,
20 headers: HashMap<String, String>,
22 body: Vec<u8>,
24 #[serde(skip)]
26 cached_at: Option<Instant>,
27 ttl_secs: u64,
29}
30
31impl CacheEntry {
32 fn new(response: &Response, ttl: Duration) -> Self {
34 let mut headers = HashMap::new();
35 for (key, value) in response.headers.iter() {
36 if let Ok(value_str) = value.to_str() {
37 headers.insert(key.to_string(), value_str.to_string());
38 }
39 }
40
41 Self {
42 status: response.status.as_u16(),
43 headers,
44 body: response.body.to_vec(),
45 cached_at: Some(Instant::now()),
46 ttl_secs: ttl.as_secs(),
47 }
48 }
49
50 fn is_expired(&self) -> bool {
52 if let Some(cached_at) = self.cached_at {
53 cached_at.elapsed().as_secs() >= self.ttl_secs
54 } else {
55 true
56 }
57 }
58
59 fn to_response(&self) -> Response {
61 let status = StatusCode::from_u16(self.status).unwrap_or(StatusCode::OK);
62 let mut response = Response::new(status).with_body(self.body.clone());
63
64 for (key, value) in &self.headers {
65 if let (Ok(header_name), Ok(header_value)) =
66 (hyper::header::HeaderName::try_from(key), value.parse())
67 {
68 response.headers.insert(header_name, header_value);
69 }
70 }
71
72 response.headers.insert(
74 hyper::header::HeaderName::from_static("x-cache"),
75 hyper::header::HeaderValue::from_static("HIT"),
76 );
77
78 response
79 }
80}
81
82#[derive(Debug, Default)]
84pub struct CacheStore {
85 entries: RwLock<HashMap<String, CacheEntry>>,
87}
88
89impl CacheStore {
90 pub fn new() -> Self {
92 Self::default()
93 }
94
95 pub fn get(&self, key: &str) -> Option<CacheEntry> {
97 let entries = self.entries.read().unwrap_or_else(|e| e.into_inner());
98 entries.get(key).cloned()
99 }
100
101 pub fn set(&self, key: String, entry: CacheEntry) {
103 let mut entries = self.entries.write().unwrap_or_else(|e| e.into_inner());
104 entries.insert(key, entry);
105 }
106
107 pub fn delete(&self, key: &str) {
109 let mut entries = self.entries.write().unwrap_or_else(|e| e.into_inner());
110 entries.remove(key);
111 }
112
113 pub fn cleanup(&self) {
115 let mut entries = self.entries.write().unwrap_or_else(|e| e.into_inner());
116 entries.retain(|_, entry| !entry.is_expired());
117 }
118
119 pub fn clear(&self) {
121 let mut entries = self.entries.write().unwrap_or_else(|e| e.into_inner());
122 entries.clear();
123 }
124
125 pub fn len(&self) -> usize {
127 let entries = self.entries.read().unwrap_or_else(|e| e.into_inner());
128 entries.len()
129 }
130
131 pub fn is_empty(&self) -> bool {
133 let entries = self.entries.read().unwrap_or_else(|e| e.into_inner());
134 entries.is_empty()
135 }
136}
137
/// Selects which parts of a request are hashed into the cache key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CacheKeyStrategy {
    /// Request path only; the method and query string are ignored.
    UrlOnly,
    /// Request method and path.
    UrlAndMethod,
    /// Request method, path and query string.
    UrlAndQuery,
    /// Request method, path and request headers.
    UrlAndHeaders,
}

/// Configuration for the response-caching middleware.
#[non_exhaustive]
#[derive(Debug, Clone)]
pub struct CacheConfig {
    /// Time-to-live applied to every stored response.
    pub default_ttl: Duration,
    /// How cache keys are derived from requests.
    pub key_strategy: CacheKeyStrategy,
    /// HTTP methods (exact, case-sensitive match) whose responses may be cached.
    pub cacheable_methods: Vec<String>,
    /// Response status codes that may be stored.
    pub cacheable_status_codes: Vec<u16>,
    /// Path prefixes that bypass the cache entirely.
    pub exclude_paths: Vec<String>,
    /// Entry-count limit consulted by the middleware when storing responses;
    /// `None` disables it.
    pub max_entries: Option<usize>,
}

impl CacheConfig {
    /// Creates a configuration with the given TTL and key strategy.
    ///
    /// Defaults:
    /// - cacheable methods: `GET`, `HEAD`;
    /// - cacheable statuses: the heuristically cacheable codes of RFC 7231 §6.1
    ///   **except** `206 Partial Content`. The default key does not include the
    ///   `Range` header, so a stored partial body would be replayed to requests
    ///   that never asked for that range. Opt back in with
    ///   [`CacheConfig::with_cacheable_status_codes`] if your key covers `Range`;
    /// - no excluded paths;
    /// - `max_entries = Some(1000)`.
    pub fn new(default_ttl: Duration, key_strategy: CacheKeyStrategy) -> Self {
        Self {
            default_ttl,
            key_strategy,
            cacheable_methods: vec!["GET".to_string(), "HEAD".to_string()],
            cacheable_status_codes: vec![200, 203, 204, 300, 301, 404, 405, 410, 414, 501],
            exclude_paths: Vec::new(),
            max_entries: Some(1000),
        }
    }

    /// Replaces the list of cacheable HTTP methods.
    pub fn with_cacheable_methods(mut self, methods: Vec<String>) -> Self {
        self.cacheable_methods = methods;
        self
    }

    /// Replaces the list of status codes whose responses may be stored.
    pub fn with_cacheable_status_codes(mut self, status_codes: Vec<u16>) -> Self {
        self.cacheable_status_codes = status_codes;
        self
    }

    /// Appends `paths` to the excluded path prefixes (existing ones are kept,
    /// unlike [`CacheConfig::with_cacheable_methods`], which replaces).
    pub fn with_excluded_paths(mut self, paths: Vec<String>) -> Self {
        self.exclude_paths.extend(paths);
        self
    }

    /// Sets the entry-count limit.
    pub fn with_max_entries(mut self, max_entries: usize) -> Self {
        self.max_entries = Some(max_entries);
        self
    }
}

impl Default for CacheConfig {
    /// Five-minute TTL keyed by path only.
    fn default() -> Self {
        Self::new(Duration::from_secs(300), CacheKeyStrategy::UrlOnly)
    }
}
246
247pub struct CacheMiddleware {
287 config: CacheConfig,
288 store: Arc<CacheStore>,
289}
290
291impl CacheMiddleware {
292 pub fn new(config: CacheConfig) -> Self {
304 Self {
305 config,
306 store: Arc::new(CacheStore::new()),
307 }
308 }
309
310 pub fn with_defaults() -> Self {
312 Self::new(CacheConfig::default())
313 }
314
315 pub fn from_arc(config: CacheConfig, store: Arc<CacheStore>) -> Self {
320 Self { config, store }
321 }
322
323 pub fn store(&self) -> &CacheStore {
340 &self.store
341 }
342
343 pub fn store_arc(&self) -> Arc<CacheStore> {
347 Arc::clone(&self.store)
348 }
349
350 fn should_exclude(&self, path: &str) -> bool {
352 self.config
353 .exclude_paths
354 .iter()
355 .any(|p| path.starts_with(p))
356 }
357
358 fn is_cacheable_method(&self, method: &str) -> bool {
360 self.config.cacheable_methods.iter().any(|m| m == method)
361 }
362
363 fn is_cacheable_status(&self, status: u16) -> bool {
365 self.config.cacheable_status_codes.contains(&status)
366 }
367
368 fn generate_cache_key(&self, request: &Request) -> String {
370 let base = match self.config.key_strategy {
371 CacheKeyStrategy::UrlOnly => request.uri.path().to_string(),
372 CacheKeyStrategy::UrlAndMethod => {
373 format!("{}:{}", request.method.as_str(), request.uri.path())
374 }
375 CacheKeyStrategy::UrlAndQuery => {
376 let query = request.uri.query().unwrap_or("");
377 format!(
378 "{}:{}?{}",
379 request.method.as_str(),
380 request.uri.path(),
381 query
382 )
383 }
384 CacheKeyStrategy::UrlAndHeaders => {
385 let headers_str = request
386 .headers
387 .iter()
388 .map(|(k, v)| format!("{}={}", k, v.to_str().unwrap_or("")))
389 .collect::<Vec<_>>()
390 .join("&");
391 format!(
392 "{}:{}:{}",
393 request.method.as_str(),
394 request.uri.path(),
395 headers_str
396 )
397 }
398 };
399
400 let mut hasher = Sha256::new();
402 hasher.update(base.as_bytes());
403 let result = hasher.finalize();
404 hex::encode(result)
405 }
406}
407
408impl Default for CacheMiddleware {
409 fn default() -> Self {
410 Self::with_defaults()
411 }
412}
413
414#[async_trait]
415impl Middleware for CacheMiddleware {
416 async fn process(&self, request: Request, handler: Arc<dyn Handler>) -> Result<Response> {
417 let path = request.uri.path().to_string();
418 let method = request.method.as_str().to_string();
419
420 if self.should_exclude(&path) {
422 return handler.handle(request).await;
423 }
424
425 if !self.is_cacheable_method(&method) {
427 return handler.handle(request).await;
428 }
429
430 let cache_key = self.generate_cache_key(&request);
432
433 if let Some(entry) = self.store.get(&cache_key) {
435 if !entry.is_expired() {
436 return Ok(entry.to_response());
438 } else {
439 self.store.delete(&cache_key);
441 }
442 }
443
444 let response = handler.handle(request).await?;
446
447 if self.is_cacheable_status(response.status.as_u16()) {
449 let entry = CacheEntry::new(&response, self.config.default_ttl);
450 self.store.set(cache_key, entry);
451
452 if let Some(max_entries) = self.config.max_entries
454 && self.store.len() > max_entries
455 {
456 self.store.cleanup();
457 }
458 }
459
460 let mut response = response;
462 response.headers.insert(
463 hyper::header::HeaderName::from_static("x-cache"),
464 hyper::header::HeaderValue::from_static("MISS"),
465 );
466
467 Ok(response)
468 }
469}
470
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use hyper::{HeaderMap, Method, StatusCode, Version};

    /// Handler returning a fixed status with body `"OK"` and counting invocations.
    struct TestHandler {
        status: StatusCode,
        call_count: Arc<RwLock<usize>>,
    }

    impl TestHandler {
        fn new(status: StatusCode) -> Self {
            Self {
                status,
                call_count: Arc::new(RwLock::new(0)),
            }
        }

        fn get_call_count(&self) -> usize {
            *self.call_count.read().unwrap()
        }
    }

    #[async_trait]
    impl Handler for TestHandler {
        async fn handle(&self, _request: Request) -> Result<Response> {
            *self.call_count.write().unwrap() += 1;
            Ok(Response::new(self.status).with_body(Bytes::from("OK")))
        }
    }

    /// Builds an HTTP/1.1 request with no headers and an empty body.
    fn build_request(method: Method, uri: &'static str) -> Request {
        Request::builder()
            .method(method)
            .uri(uri)
            .version(Version::HTTP_11)
            .headers(HeaderMap::new())
            .body(Bytes::new())
            .build()
            .unwrap()
    }

    #[tokio::test]
    async fn test_cache_miss() {
        let config = CacheConfig::new(Duration::from_secs(60), CacheKeyStrategy::UrlOnly);
        let middleware = CacheMiddleware::new(config);
        let handler = Arc::new(TestHandler::new(StatusCode::OK));

        let response = middleware
            .process(build_request(Method::GET, "/test"), handler)
            .await
            .unwrap();

        assert_eq!(response.status, StatusCode::OK);
        assert_eq!(response.headers.get("x-cache").unwrap(), "MISS");
    }

    #[tokio::test]
    async fn test_cache_hit() {
        let config = CacheConfig::new(Duration::from_secs(60), CacheKeyStrategy::UrlOnly);
        let middleware = Arc::new(CacheMiddleware::new(config));
        let handler = Arc::new(TestHandler::new(StatusCode::OK));

        let first = middleware
            .process(build_request(Method::GET, "/test"), handler.clone())
            .await
            .unwrap();
        assert_eq!(first.headers.get("x-cache").unwrap(), "MISS");
        assert_eq!(handler.get_call_count(), 1);

        let second = middleware
            .process(build_request(Method::GET, "/test"), handler.clone())
            .await
            .unwrap();
        assert_eq!(second.headers.get("x-cache").unwrap(), "HIT");
        // Served from cache: the handler was not called again.
        assert_eq!(handler.get_call_count(), 1);
    }

    #[tokio::test]
    async fn test_cache_expiration() {
        let config = CacheConfig::new(Duration::from_millis(100), CacheKeyStrategy::UrlOnly);
        let middleware = Arc::new(CacheMiddleware::new(config));
        let handler = Arc::new(TestHandler::new(StatusCode::OK));

        let _first = middleware
            .process(build_request(Method::GET, "/test"), handler.clone())
            .await
            .unwrap();

        // Blocking sleep is acceptable here: nothing else runs on this test's runtime.
        std::thread::sleep(Duration::from_millis(150));

        let second = middleware
            .process(build_request(Method::GET, "/test"), handler.clone())
            .await
            .unwrap();
        assert_eq!(second.headers.get("x-cache").unwrap(), "MISS");
        assert_eq!(handler.get_call_count(), 2);
    }

    #[tokio::test]
    async fn test_non_cacheable_method() {
        let config = CacheConfig::new(Duration::from_secs(60), CacheKeyStrategy::UrlOnly);
        let middleware = CacheMiddleware::new(config);
        let handler = Arc::new(TestHandler::new(StatusCode::OK));

        let response = middleware
            .process(build_request(Method::POST, "/test"), handler)
            .await
            .unwrap();

        assert_eq!(response.status, StatusCode::OK);
        assert!(!response.headers.contains_key("x-cache"));
    }

    #[tokio::test]
    async fn test_exclude_paths() {
        let config = CacheConfig::new(Duration::from_secs(60), CacheKeyStrategy::UrlOnly)
            .with_excluded_paths(vec!["/admin".to_string()]);
        let middleware = CacheMiddleware::new(config);
        let handler = Arc::new(TestHandler::new(StatusCode::OK));

        let response = middleware
            .process(build_request(Method::GET, "/admin/users"), handler)
            .await
            .unwrap();

        assert_eq!(response.status, StatusCode::OK);
        assert!(!response.headers.contains_key("x-cache"));
    }

    #[tokio::test]
    async fn test_different_urls() {
        let config = CacheConfig::new(Duration::from_secs(60), CacheKeyStrategy::UrlOnly);
        let middleware = Arc::new(CacheMiddleware::new(config));
        let handler = Arc::new(TestHandler::new(StatusCode::OK));

        let _first = middleware
            .process(build_request(Method::GET, "/test1"), handler.clone())
            .await
            .unwrap();
        let second = middleware
            .process(build_request(Method::GET, "/test2"), handler.clone())
            .await
            .unwrap();

        assert_eq!(second.headers.get("x-cache").unwrap(), "MISS");
        assert_eq!(handler.get_call_count(), 2);
    }

    #[tokio::test]
    async fn test_cache_store() {
        let store = CacheStore::new();

        let response = Response::new(StatusCode::OK).with_body(Bytes::from("test"));
        let entry = CacheEntry::new(&response, Duration::from_secs(60));

        store.set("key1".to_string(), entry.clone());

        assert_eq!(store.len(), 1);
        assert!(!store.is_empty());

        let retrieved = store.get("key1").unwrap();
        assert_eq!(retrieved.status, 200);
        assert_eq!(retrieved.body, b"test");
    }

    #[tokio::test]
    async fn test_cache_cleanup() {
        let store = CacheStore::new();

        let response = Response::new(StatusCode::OK).with_body(Bytes::from("test"));
        let mut entry = CacheEntry::new(&response, Duration::from_millis(10));
        // Backdate the entry so it is already past its TTL.
        entry.cached_at = Some(Instant::now() - Duration::from_millis(20));

        store.set("key1".to_string(), entry);

        store.cleanup();

        assert_eq!(store.len(), 0);
        assert!(store.is_empty());
    }

    #[tokio::test]
    async fn test_multiple_status_codes_cached() {
        let config = CacheConfig::new(Duration::from_secs(60), CacheKeyStrategy::UrlOnly);
        let middleware = Arc::new(CacheMiddleware::new(config));

        // 404 is in the default cacheable set: the second request is a HIT.
        let handler_404 = Arc::new(TestHandler::new(StatusCode::NOT_FOUND));
        let response1 = middleware
            .process(build_request(Method::GET, "/not-found"), handler_404.clone())
            .await
            .unwrap();
        assert_eq!(response1.status, StatusCode::NOT_FOUND);
        assert_eq!(response1.headers.get("x-cache").unwrap(), "MISS");
        assert_eq!(handler_404.get_call_count(), 1);

        let response1b = middleware
            .process(build_request(Method::GET, "/not-found"), handler_404.clone())
            .await
            .unwrap();
        assert_eq!(response1b.status, StatusCode::NOT_FOUND);
        assert_eq!(response1b.headers.get("x-cache").unwrap(), "HIT");
        assert_eq!(handler_404.get_call_count(), 1);

        // 500 is not cacheable: every request reaches the handler.
        let handler_500 = Arc::new(TestHandler::new(StatusCode::INTERNAL_SERVER_ERROR));
        let response2 = middleware
            .process(build_request(Method::GET, "/error"), handler_500.clone())
            .await
            .unwrap();
        assert_eq!(response2.status, StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(response2.headers.get("x-cache").unwrap(), "MISS");
        assert_eq!(handler_500.get_call_count(), 1);

        let response2b = middleware
            .process(build_request(Method::GET, "/error"), handler_500.clone())
            .await
            .unwrap();
        assert_eq!(response2b.status, StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(response2b.headers.get("x-cache").unwrap(), "MISS");
        assert_eq!(handler_500.get_call_count(), 2);
    }

    #[tokio::test]
    async fn test_cache_key_strategy_url_and_method() {
        let config = CacheConfig::new(Duration::from_secs(60), CacheKeyStrategy::UrlAndMethod);
        let middleware = Arc::new(CacheMiddleware::new(config));
        let handler = Arc::new(TestHandler::new(StatusCode::OK));

        let response1 = middleware
            .process(build_request(Method::GET, "/api"), handler.clone())
            .await
            .unwrap();
        assert_eq!(response1.headers.get("x-cache").unwrap(), "MISS");
        assert_eq!(handler.get_call_count(), 1);

        // Same path, different method: a separate key, so a MISS.
        let handler2 = Arc::new(TestHandler::new(StatusCode::OK));
        let response2 = middleware
            .process(build_request(Method::HEAD, "/api"), handler2.clone())
            .await
            .unwrap();
        assert_eq!(response2.headers.get("x-cache").unwrap(), "MISS");
        assert_eq!(handler2.get_call_count(), 1);
    }

    #[rstest::rstest]
    fn test_rwlock_poison_recovery_cache_store() {
        let store = Arc::new(CacheStore::new());

        // Poison the lock by panicking while holding the write guard.
        let store_clone = Arc::clone(&store);
        let _ = std::thread::spawn(move || {
            let _guard = store_clone.entries.write().unwrap();
            panic!("intentional panic to poison lock");
        })
        .join();

        let response = Response::new(StatusCode::OK).with_body(Bytes::from("test"));
        let entry = CacheEntry::new(&response, Duration::from_secs(60));
        store.set("key1".to_string(), entry);
        assert_eq!(store.len(), 1);
        assert!(!store.is_empty());
        assert!(store.get("key1").is_some());
        store.delete("key1");
        assert_eq!(store.len(), 0);
    }
}