Skip to main content

reinhardt_middleware/
cache.rs

1//! Cache Middleware
2//!
3//! Provides caching for HTTP responses.
//! Responses are kept in an in-memory [`CacheStore`], which can be shared between middleware instances via [`CacheMiddleware::from_arc`].
5
6use async_trait::async_trait;
7use hyper::StatusCode;
8use reinhardt_http::{Handler, Middleware, Request, Response, Result};
9use serde::{Deserialize, Serialize};
10use sha2::{Digest, Sha256};
11use std::collections::HashMap;
12use std::sync::{Arc, RwLock};
13use std::time::{Duration, Instant};
14
15/// Cache Entry
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct CacheEntry {
18	/// Status code
19	status: u16,
20	/// Headers
21	headers: HashMap<String, String>,
22	/// Body
23	body: Vec<u8>,
24	/// Cached timestamp
25	#[serde(skip)]
26	cached_at: Option<Instant>,
27	/// TTL (seconds)
28	ttl_secs: u64,
29}
30
31impl CacheEntry {
32	/// Create a new entry
33	fn new(response: &Response, ttl: Duration) -> Self {
34		let mut headers = HashMap::new();
35		for (key, value) in response.headers.iter() {
36			if let Ok(value_str) = value.to_str() {
37				headers.insert(key.to_string(), value_str.to_string());
38			}
39		}
40
41		Self {
42			status: response.status.as_u16(),
43			headers,
44			body: response.body.to_vec(),
45			cached_at: Some(Instant::now()),
46			ttl_secs: ttl.as_secs(),
47		}
48	}
49
50	/// Check if expired
51	fn is_expired(&self) -> bool {
52		if let Some(cached_at) = self.cached_at {
53			cached_at.elapsed().as_secs() >= self.ttl_secs
54		} else {
55			true
56		}
57	}
58
59	/// Convert to response
60	fn to_response(&self) -> Response {
61		let status = StatusCode::from_u16(self.status).unwrap_or(StatusCode::OK);
62		let mut response = Response::new(status).with_body(self.body.clone());
63
64		for (key, value) in &self.headers {
65			if let (Ok(header_name), Ok(header_value)) =
66				(hyper::header::HeaderName::try_from(key), value.parse())
67			{
68				response.headers.insert(header_name, header_value);
69			}
70		}
71
72		// Add cache header
73		response.headers.insert(
74			hyper::header::HeaderName::from_static("x-cache"),
75			hyper::header::HeaderValue::from_static("HIT"),
76		);
77
78		response
79	}
80}
81
82/// Cache Storage
83#[derive(Debug, Default)]
84pub struct CacheStore {
85	/// Entries
86	entries: RwLock<HashMap<String, CacheEntry>>,
87}
88
89impl CacheStore {
90	/// Create a new store
91	pub fn new() -> Self {
92		Self::default()
93	}
94
95	/// Get an entry
96	pub fn get(&self, key: &str) -> Option<CacheEntry> {
97		let entries = self.entries.read().unwrap_or_else(|e| e.into_inner());
98		entries.get(key).cloned()
99	}
100
101	/// Set an entry
102	pub fn set(&self, key: String, entry: CacheEntry) {
103		let mut entries = self.entries.write().unwrap_or_else(|e| e.into_inner());
104		entries.insert(key, entry);
105	}
106
107	/// Delete an entry
108	pub fn delete(&self, key: &str) {
109		let mut entries = self.entries.write().unwrap_or_else(|e| e.into_inner());
110		entries.remove(key);
111	}
112
113	/// Clean up expired entries
114	pub fn cleanup(&self) {
115		let mut entries = self.entries.write().unwrap_or_else(|e| e.into_inner());
116		entries.retain(|_, entry| !entry.is_expired());
117	}
118
119	/// Clear the store
120	pub fn clear(&self) {
121		let mut entries = self.entries.write().unwrap_or_else(|e| e.into_inner());
122		entries.clear();
123	}
124
125	/// Get the number of entries
126	pub fn len(&self) -> usize {
127		let entries = self.entries.read().unwrap_or_else(|e| e.into_inner());
128		entries.len()
129	}
130
131	/// Check if the store is empty
132	pub fn is_empty(&self) -> bool {
133		let entries = self.entries.read().unwrap_or_else(|e| e.into_inner());
134		entries.is_empty()
135	}
136}
137
/// Which parts of a request make up its cache key.
///
/// Requests that yield the same key share one cache entry, so the chosen
/// strategy must cover every request attribute that changes the response.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CacheKeyStrategy {
	/// Keyed on the request path; the query string and headers are ignored
	UrlOnly,
	/// Keyed on the request method and path
	UrlAndMethod,
	/// Keyed on the request method, path and query string
	UrlAndQuery,
	/// Keyed on the request method, path and every request header
	UrlAndHeaders,
}
150
151/// Cache configuration
152#[non_exhaustive]
153#[derive(Debug, Clone)]
154pub struct CacheConfig {
155	/// Default TTL
156	pub default_ttl: Duration,
157	/// Cache key generation strategy
158	pub key_strategy: CacheKeyStrategy,
159	/// Cacheable methods
160	pub cacheable_methods: Vec<String>,
161	/// Cacheable status codes
162	pub cacheable_status_codes: Vec<u16>,
163	/// Paths to exclude
164	pub exclude_paths: Vec<String>,
165	/// Maximum cache size
166	pub max_entries: Option<usize>,
167}
168
169impl CacheConfig {
170	/// Create a new configuration
171	///
172	/// # Examples
173	///
174	/// ```
175	/// use std::time::Duration;
176	/// use reinhardt_middleware::cache::{CacheConfig, CacheKeyStrategy};
177	///
178	/// let config = CacheConfig::new(Duration::from_secs(300), CacheKeyStrategy::UrlOnly);
179	/// assert_eq!(config.default_ttl, Duration::from_secs(300));
180	/// ```
181	pub fn new(default_ttl: Duration, key_strategy: CacheKeyStrategy) -> Self {
182		Self {
183			default_ttl,
184			key_strategy,
185			cacheable_methods: vec!["GET".to_string(), "HEAD".to_string()],
186			cacheable_status_codes: vec![200, 203, 204, 206, 300, 301, 404, 405, 410, 414, 501],
187			exclude_paths: Vec::new(),
188			max_entries: Some(1000),
189		}
190	}
191
192	/// Set cacheable methods
193	///
194	/// # Examples
195	///
196	/// ```
197	/// use std::time::Duration;
198	/// use reinhardt_middleware::cache::{CacheConfig, CacheKeyStrategy};
199	///
200	/// let config = CacheConfig::new(Duration::from_secs(300), CacheKeyStrategy::UrlOnly)
201	///     .with_cacheable_methods(vec!["GET".to_string()]);
202	/// ```
203	pub fn with_cacheable_methods(mut self, methods: Vec<String>) -> Self {
204		self.cacheable_methods = methods;
205		self
206	}
207
208	/// Add paths to exclude
209	///
210	/// # Examples
211	///
212	/// ```
213	/// use std::time::Duration;
214	/// use reinhardt_middleware::cache::{CacheConfig, CacheKeyStrategy};
215	///
216	/// let config = CacheConfig::new(Duration::from_secs(300), CacheKeyStrategy::UrlOnly)
217	///     .with_excluded_paths(vec!["/admin".to_string()]);
218	/// ```
219	pub fn with_excluded_paths(mut self, paths: Vec<String>) -> Self {
220		self.exclude_paths.extend(paths);
221		self
222	}
223
224	/// Set maximum number of entries
225	///
226	/// # Examples
227	///
228	/// ```
229	/// use std::time::Duration;
230	/// use reinhardt_middleware::cache::{CacheConfig, CacheKeyStrategy};
231	///
232	/// let config = CacheConfig::new(Duration::from_secs(300), CacheKeyStrategy::UrlOnly)
233	///     .with_max_entries(5000);
234	/// ```
235	pub fn with_max_entries(mut self, max_entries: usize) -> Self {
236		self.max_entries = Some(max_entries);
237		self
238	}
239}
240
241impl Default for CacheConfig {
242	fn default() -> Self {
243		Self::new(Duration::from_secs(300), CacheKeyStrategy::UrlOnly)
244	}
245}
246
247/// Cache Middleware
248///
249/// # Examples
250///
251/// ```
252/// use std::sync::Arc;
253/// use std::time::Duration;
254/// use reinhardt_middleware::cache::{CacheMiddleware, CacheConfig, CacheKeyStrategy};
255/// use reinhardt_http::{Handler, Middleware, Request, Response};
256/// use hyper::{StatusCode, Method, Version, HeaderMap};
257/// use bytes::Bytes;
258///
259/// struct TestHandler;
260///
261/// #[async_trait::async_trait]
262/// impl Handler for TestHandler {
263///     async fn handle(&self, _request: Request) -> reinhardt_core::exception::Result<Response> {
264///         Ok(Response::new(StatusCode::OK).with_body(Bytes::from("OK")))
265///     }
266/// }
267///
268/// # tokio_test::block_on(async {
269/// let config = CacheConfig::new(Duration::from_secs(60), CacheKeyStrategy::UrlOnly);
270/// let middleware = CacheMiddleware::new(config);
271/// let handler = Arc::new(TestHandler);
272///
273/// let request = Request::builder()
274///     .method(Method::GET)
275///     .uri("/api/data")
276///     .version(Version::HTTP_11)
277///     .headers(HeaderMap::new())
278///     .body(Bytes::new())
279///     .build()
280///     .unwrap();
281///
282/// let response = middleware.process(request, handler).await.unwrap();
283/// assert_eq!(response.status, StatusCode::OK);
284/// # });
285/// ```
286pub struct CacheMiddleware {
287	config: CacheConfig,
288	store: Arc<CacheStore>,
289}
290
291impl CacheMiddleware {
292	/// Create a new cache middleware
293	///
294	/// # Examples
295	///
296	/// ```
297	/// use std::time::Duration;
298	/// use reinhardt_middleware::cache::{CacheMiddleware, CacheConfig, CacheKeyStrategy};
299	///
300	/// let config = CacheConfig::new(Duration::from_secs(300), CacheKeyStrategy::UrlOnly);
301	/// let middleware = CacheMiddleware::new(config);
302	/// ```
303	pub fn new(config: CacheConfig) -> Self {
304		Self {
305			config,
306			store: Arc::new(CacheStore::new()),
307		}
308	}
309
310	/// Create with default configuration
311	pub fn with_defaults() -> Self {
312		Self::new(CacheConfig::default())
313	}
314
315	/// Create from an existing Arc-wrapped cache store
316	///
317	/// This is provided for cases where you already have an `Arc<CacheStore>`.
318	/// In most cases, you should use `new()` instead, which creates the store internally.
319	pub fn from_arc(config: CacheConfig, store: Arc<CacheStore>) -> Self {
320		Self { config, store }
321	}
322
323	/// Get a reference to the cache store
324	///
325	/// # Examples
326	///
327	/// ```
328	/// use std::time::Duration;
329	/// use reinhardt_middleware::cache::{CacheMiddleware, CacheConfig, CacheKeyStrategy};
330	///
331	/// let middleware = CacheMiddleware::new(
332	///     CacheConfig::new(Duration::from_secs(300), CacheKeyStrategy::UrlOnly)
333	/// );
334	///
335	/// // Access the store
336	/// let store = middleware.store();
337	/// assert_eq!(store.len(), 0);
338	/// ```
339	pub fn store(&self) -> &CacheStore {
340		&self.store
341	}
342
343	/// Get a cloned Arc of the store (for cases where you need ownership)
344	///
345	/// In most cases, you should use `store()` instead to get a reference.
346	pub fn store_arc(&self) -> Arc<CacheStore> {
347		Arc::clone(&self.store)
348	}
349
350	/// Check if path should be excluded
351	fn should_exclude(&self, path: &str) -> bool {
352		self.config
353			.exclude_paths
354			.iter()
355			.any(|p| path.starts_with(p))
356	}
357
358	/// Check if method is cacheable
359	fn is_cacheable_method(&self, method: &str) -> bool {
360		self.config.cacheable_methods.iter().any(|m| m == method)
361	}
362
363	/// Check if status code is cacheable
364	fn is_cacheable_status(&self, status: u16) -> bool {
365		self.config.cacheable_status_codes.contains(&status)
366	}
367
368	/// Generate cache key
369	fn generate_cache_key(&self, request: &Request) -> String {
370		let base = match self.config.key_strategy {
371			CacheKeyStrategy::UrlOnly => request.uri.path().to_string(),
372			CacheKeyStrategy::UrlAndMethod => {
373				format!("{}:{}", request.method.as_str(), request.uri.path())
374			}
375			CacheKeyStrategy::UrlAndQuery => {
376				let query = request.uri.query().unwrap_or("");
377				format!(
378					"{}:{}?{}",
379					request.method.as_str(),
380					request.uri.path(),
381					query
382				)
383			}
384			CacheKeyStrategy::UrlAndHeaders => {
385				let headers_str = request
386					.headers
387					.iter()
388					.map(|(k, v)| format!("{}={}", k, v.to_str().unwrap_or("")))
389					.collect::<Vec<_>>()
390					.join("&");
391				format!(
392					"{}:{}:{}",
393					request.method.as_str(),
394					request.uri.path(),
395					headers_str
396				)
397			}
398		};
399
400		// Hash with SHA256
401		let mut hasher = Sha256::new();
402		hasher.update(base.as_bytes());
403		let result = hasher.finalize();
404		hex::encode(result)
405	}
406}
407
408impl Default for CacheMiddleware {
409	fn default() -> Self {
410		Self::with_defaults()
411	}
412}
413
414#[async_trait]
415impl Middleware for CacheMiddleware {
416	async fn process(&self, request: Request, handler: Arc<dyn Handler>) -> Result<Response> {
417		let path = request.uri.path().to_string();
418		let method = request.method.as_str().to_string();
419
420		// Skip excluded paths
421		if self.should_exclude(&path) {
422			return handler.handle(request).await;
423		}
424
425		// Skip non-cacheable methods
426		if !self.is_cacheable_method(&method) {
427			return handler.handle(request).await;
428		}
429
430		// Generate cache key
431		let cache_key = self.generate_cache_key(&request);
432
433		// Check cache
434		if let Some(entry) = self.store.get(&cache_key) {
435			if !entry.is_expired() {
436				// Cache hit
437				return Ok(entry.to_response());
438			} else {
439				// Delete expired entry
440				self.store.delete(&cache_key);
441			}
442		}
443
444		// Call handler
445		let response = handler.handle(request).await?;
446
447		// Save to cache if status code is cacheable
448		if self.is_cacheable_status(response.status.as_u16()) {
449			let entry = CacheEntry::new(&response, self.config.default_ttl);
450			self.store.set(cache_key, entry);
451
452			// Clean up expired entries if max entries exceeded
453			if let Some(max_entries) = self.config.max_entries
454				&& self.store.len() > max_entries
455			{
456				self.store.cleanup();
457			}
458		}
459
460		// Add X-Cache header
461		let mut response = response;
462		response.headers.insert(
463			hyper::header::HeaderName::from_static("x-cache"),
464			hyper::header::HeaderValue::from_static("MISS"),
465		);
466
467		Ok(response)
468	}
469}
470
#[cfg(test)]
mod tests {
	use super::*;
	use bytes::Bytes;
	use hyper::{HeaderMap, Method, StatusCode, Version};
	use std::sync::atomic::{AtomicUsize, Ordering};

	/// Handler that always answers with `status` and body `"OK"`, counting how
	/// many times it was invoked.
	struct TestHandler {
		status: StatusCode,
		call_count: AtomicUsize,
	}

	impl TestHandler {
		fn new(status: StatusCode) -> Self {
			Self {
				status,
				call_count: AtomicUsize::new(0),
			}
		}

		fn get_call_count(&self) -> usize {
			self.call_count.load(Ordering::SeqCst)
		}
	}

	#[async_trait]
	impl Handler for TestHandler {
		async fn handle(&self, _request: Request) -> Result<Response> {
			self.call_count.fetch_add(1, Ordering::SeqCst);
			Ok(Response::new(self.status).with_body(Bytes::from("OK")))
		}
	}

	/// Build an HTTP/1.1 request with no headers and an empty body.
	fn build_request(method: Method, uri: &'static str) -> Request {
		Request::builder()
			.method(method)
			.uri(uri)
			.version(Version::HTTP_11)
			.headers(HeaderMap::new())
			.body(Bytes::new())
			.build()
			.unwrap()
	}

	/// The `x-cache` header of `response`, if present and textual.
	fn x_cache(response: &Response) -> Option<&str> {
		response
			.headers
			.get("x-cache")
			.and_then(|value| value.to_str().ok())
	}

	#[tokio::test]
	async fn test_cache_miss() {
		let config = CacheConfig::new(Duration::from_secs(60), CacheKeyStrategy::UrlOnly);
		let middleware = CacheMiddleware::new(config);
		let handler = Arc::new(TestHandler::new(StatusCode::OK));

		let response = middleware
			.process(build_request(Method::GET, "/test"), handler)
			.await
			.unwrap();

		assert_eq!(response.status, StatusCode::OK);
		assert_eq!(x_cache(&response), Some("MISS"));
	}

	#[tokio::test]
	async fn test_cache_hit() {
		let config = CacheConfig::new(Duration::from_secs(60), CacheKeyStrategy::UrlOnly);
		let middleware = CacheMiddleware::new(config);
		let handler = Arc::new(TestHandler::new(StatusCode::OK));

		// First request (cache miss)
		let response1 = middleware
			.process(build_request(Method::GET, "/test"), handler.clone())
			.await
			.unwrap();
		assert_eq!(x_cache(&response1), Some("MISS"));
		assert_eq!(handler.get_call_count(), 1);

		// Second request (cache hit)
		let response2 = middleware
			.process(build_request(Method::GET, "/test"), handler.clone())
			.await
			.unwrap();
		assert_eq!(x_cache(&response2), Some("HIT"));
		assert_eq!(handler.get_call_count(), 1); // Handler is not called
	}

	#[tokio::test]
	async fn test_cache_expiration() {
		let config = CacheConfig::new(Duration::from_millis(100), CacheKeyStrategy::UrlOnly);
		let middleware = CacheMiddleware::new(config);
		let handler = Arc::new(TestHandler::new(StatusCode::OK));

		// First request populates the cache
		let _response1 = middleware
			.process(build_request(Method::GET, "/test"), handler.clone())
			.await
			.unwrap();

		// Age the stored entry past its TTL instead of sleeping: a blocking sleep
		// would stall the async runtime and make the test timing-dependent.
		let key = middleware.generate_cache_key(&build_request(Method::GET, "/test"));
		let mut entry = middleware
			.store()
			.get(&key)
			.expect("first response should have been cached");
		entry.cached_at = Some(Instant::now() - Duration::from_millis(150));
		middleware.store().set(key, entry);

		// Request after expiration (cache miss)
		let response2 = middleware
			.process(build_request(Method::GET, "/test"), handler.clone())
			.await
			.unwrap();
		assert_eq!(x_cache(&response2), Some("MISS"));
		assert_eq!(handler.get_call_count(), 2);
	}

	#[tokio::test]
	async fn test_non_cacheable_method() {
		let config = CacheConfig::new(Duration::from_secs(60), CacheKeyStrategy::UrlOnly);
		let middleware = CacheMiddleware::new(config);
		let handler = Arc::new(TestHandler::new(StatusCode::OK));

		let response = middleware
			.process(build_request(Method::POST, "/test"), handler)
			.await
			.unwrap();

		assert_eq!(response.status, StatusCode::OK);
		assert_eq!(x_cache(&response), None);
	}

	#[tokio::test]
	async fn test_exclude_paths() {
		let config = CacheConfig::new(Duration::from_secs(60), CacheKeyStrategy::UrlOnly)
			.with_excluded_paths(vec!["/admin".to_string()]);
		let middleware = CacheMiddleware::new(config);
		let handler = Arc::new(TestHandler::new(StatusCode::OK));

		let response = middleware
			.process(build_request(Method::GET, "/admin/users"), handler)
			.await
			.unwrap();

		assert_eq!(response.status, StatusCode::OK);
		assert_eq!(x_cache(&response), None);
	}

	#[tokio::test]
	async fn test_different_urls() {
		let config = CacheConfig::new(Duration::from_secs(60), CacheKeyStrategy::UrlOnly);
		let middleware = CacheMiddleware::new(config);
		let handler = Arc::new(TestHandler::new(StatusCode::OK));

		// Request to /test1
		let _response1 = middleware
			.process(build_request(Method::GET, "/test1"), handler.clone())
			.await
			.unwrap();

		// Request to /test2 (different cache entry)
		let response2 = middleware
			.process(build_request(Method::GET, "/test2"), handler.clone())
			.await
			.unwrap();

		assert_eq!(x_cache(&response2), Some("MISS"));
		assert_eq!(handler.get_call_count(), 2);
	}

	#[test]
	fn test_cache_store() {
		let store = CacheStore::new();

		let response = Response::new(StatusCode::OK).with_body(Bytes::from("test"));
		let entry = CacheEntry::new(&response, Duration::from_secs(60));

		store.set("key1".to_string(), entry.clone());

		assert_eq!(store.len(), 1);
		assert!(!store.is_empty());

		let retrieved = store.get("key1").unwrap();
		assert_eq!(retrieved.status, 200);
		assert_eq!(retrieved.body, b"test");
	}

	#[test]
	fn test_cache_cleanup() {
		let store = CacheStore::new();

		let response = Response::new(StatusCode::OK).with_body(Bytes::from("test"));
		let mut entry = CacheEntry::new(&response, Duration::from_millis(10));
		entry.cached_at = Some(Instant::now() - Duration::from_millis(20));

		store.set("key1".to_string(), entry);

		store.cleanup();

		assert_eq!(store.len(), 0);
		assert!(store.is_empty());
	}

	#[tokio::test]
	async fn test_multiple_status_codes_cached() {
		let config = CacheConfig::new(Duration::from_secs(60), CacheKeyStrategy::UrlOnly);
		let middleware = CacheMiddleware::new(config);

		// 404 is in the default cacheable list
		let handler_404 = Arc::new(TestHandler::new(StatusCode::NOT_FOUND));
		let response1 = middleware
			.process(build_request(Method::GET, "/not-found"), handler_404.clone())
			.await
			.unwrap();
		assert_eq!(response1.status, StatusCode::NOT_FOUND);
		assert_eq!(x_cache(&response1), Some("MISS"));
		assert_eq!(handler_404.get_call_count(), 1);

		// Second request to same 404 URL (cache hit)
		let response1b = middleware
			.process(build_request(Method::GET, "/not-found"), handler_404.clone())
			.await
			.unwrap();
		assert_eq!(response1b.status, StatusCode::NOT_FOUND);
		assert_eq!(x_cache(&response1b), Some("HIT"));
		assert_eq!(handler_404.get_call_count(), 1); // Not called again

		// 500 is NOT in the default cacheable list, so it is never stored
		let handler_500 = Arc::new(TestHandler::new(StatusCode::INTERNAL_SERVER_ERROR));
		let response2 = middleware
			.process(build_request(Method::GET, "/error"), handler_500.clone())
			.await
			.unwrap();
		assert_eq!(response2.status, StatusCode::INTERNAL_SERVER_ERROR);
		assert_eq!(x_cache(&response2), Some("MISS"));

		let response2b = middleware
			.process(build_request(Method::GET, "/error"), handler_500.clone())
			.await
			.unwrap();
		assert_eq!(response2b.status, StatusCode::INTERNAL_SERVER_ERROR);
		assert_eq!(x_cache(&response2b), Some("MISS"));
		assert_eq!(handler_500.get_call_count(), 2);
	}

	#[tokio::test]
	async fn test_cache_key_strategy_url_and_method() {
		let config = CacheConfig::new(Duration::from_secs(60), CacheKeyStrategy::UrlAndMethod);
		let middleware = CacheMiddleware::new(config);
		let handler = Arc::new(TestHandler::new(StatusCode::OK));

		// GET request to /api
		let response1 = middleware
			.process(build_request(Method::GET, "/api"), handler.clone())
			.await
			.unwrap();
		assert_eq!(x_cache(&response1), Some("MISS"));
		assert_eq!(handler.get_call_count(), 1);

		// HEAD request to same URL (different cache key due to method)
		let handler2 = Arc::new(TestHandler::new(StatusCode::OK));
		let response2 = middleware
			.process(build_request(Method::HEAD, "/api"), handler2.clone())
			.await
			.unwrap();
		// Different method should result in cache miss
		assert_eq!(x_cache(&response2), Some("MISS"));
		assert_eq!(handler2.get_call_count(), 1);
	}

	#[rstest::rstest]
	fn test_rwlock_poison_recovery_cache_store() {
		// Arrange
		let store = Arc::new(CacheStore::new());

		// Act - poison the RwLock by panicking while holding a write guard
		let store_clone = Arc::clone(&store);
		let _ = std::thread::spawn(move || {
			let _guard = store_clone.entries.write().unwrap();
			panic!("intentional panic to poison lock");
		})
		.join();

		// Assert - operations still work after poison recovery
		let response = Response::new(StatusCode::OK).with_body(Bytes::from("test"));
		let entry = CacheEntry::new(&response, Duration::from_secs(60));
		store.set("key1".to_string(), entry);
		assert_eq!(store.len(), 1);
		assert!(!store.is_empty());
		assert!(store.get("key1").is_some());
		store.delete("key1");
		assert_eq!(store.len(), 0);
	}
}