1mod config;
43mod error;
44mod events;
45mod eviction;
46mod layer;
47mod shared_layer;
48mod store;
49
50pub use config::{CacheConfig, CacheConfigBuilder, KeyExtractor};
51pub use error::CacheError;
52pub use events::CacheEvent;
53pub use eviction::EvictionPolicy;
54pub use layer::CacheLayer;
55pub use shared_layer::{SharedCacheConfigBuilder, SharedCacheLayer};
56
57use futures::future::BoxFuture;
58use std::hash::Hash;
59use std::sync::{Arc, Mutex};
60use std::task::{Context, Poll};
61use std::time::Instant;
62use store::CacheStore;
63use tower::Service;
64
65#[cfg(feature = "metrics")]
66use metrics::{counter, describe_counter, describe_gauge, gauge};
67
68#[cfg(feature = "tracing")]
69use tracing::{debug, info};
70
/// Tower middleware that caches successful responses of an inner service.
///
/// Each request is mapped to a key by the configured key extractor. A later
/// request whose key is found in the store is answered from the store without
/// calling the inner service. Error responses are never stored.
pub struct Cache<S, Req, K, Resp> {
    /// The wrapped service; only called on cache misses.
    inner: S,
    /// Shared configuration (key extractor, capacity, TTL, eviction policy,
    /// cache name and event listeners).
    config: Arc<CacheConfig<Req, K>>,
    /// Backing store. It is reference-counted, so clones of this service share
    /// one cache; the mutex is only held for individual lookups/inserts.
    store: Arc<Mutex<CacheStore<K, Resp>>>,
}
84
85impl<S, Req, K, Resp> Cache<S, Req, K, Resp>
86where
87 K: Hash + Eq + Clone + Send + 'static,
88 Resp: Clone + Send + 'static,
89{
90 pub fn new(inner: S, config: Arc<CacheConfig<Req, K>>) -> Self {
92 #[cfg(feature = "metrics")]
93 {
94 describe_counter!(
95 "cache_requests_total",
96 "Total number of cache requests (hits and misses)"
97 );
98 describe_counter!("cache_evictions_total", "Total number of cache evictions");
99 describe_gauge!("cache_size", "Current number of entries in the cache");
100 }
101
102 let store = Arc::new(Mutex::new(CacheStore::new(
103 config.max_size,
104 config.ttl,
105 config.eviction_policy,
106 )));
107 Self {
108 inner,
109 config,
110 store,
111 }
112 }
113
114 pub(crate) fn with_store(
119 inner: S,
120 config: Arc<CacheConfig<Req, K>>,
121 store: Arc<Mutex<CacheStore<K, Resp>>>,
122 ) -> Self {
123 #[cfg(feature = "metrics")]
124 {
125 describe_counter!(
126 "cache_requests_total",
127 "Total number of cache requests (hits and misses)"
128 );
129 describe_counter!("cache_evictions_total", "Total number of cache evictions");
130 describe_gauge!("cache_size", "Current number of entries in the cache");
131 }
132
133 Self {
134 inner,
135 config,
136 store,
137 }
138 }
139}
140
141impl<S, Req, K, Resp> Clone for Cache<S, Req, K, Resp>
142where
143 S: Clone,
144{
145 fn clone(&self) -> Self {
146 Self {
147 inner: self.inner.clone(),
148 config: Arc::clone(&self.config),
149 store: Arc::clone(&self.store),
150 }
151 }
152}
153
154impl<S, Req, K> Service<Req> for Cache<S, Req, K, S::Response>
155where
156 S: Service<Req>,
157 S::Response: Clone + Send + 'static,
158 K: Hash + Eq + Clone + Send + 'static,
159 Req: Send + 'static,
160 S::Future: Send + 'static,
161{
162 type Response = S::Response;
163 type Error = CacheError<S::Error>;
164 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
165
166 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
167 self.inner.poll_ready(cx).map_err(CacheError::Inner)
168 }
169
170 fn call(&mut self, req: Req) -> Self::Future {
171 let key = (self.config.key_extractor)(&req);
172 let cache_name = self.config.name.clone();
173
174 let cached = {
176 let mut store = self.store.lock().unwrap();
177 store.get(&key)
178 };
179
180 if let Some(response) = cached {
181 #[cfg(feature = "metrics")]
183 {
184 counter!("cache_requests_total", "cache" => cache_name.clone(), "result" => "hit")
185 .increment(1);
186 }
187
188 #[cfg(feature = "tracing")]
189 debug!(cache = %cache_name, "Cache hit");
190
191 let event = CacheEvent::Hit {
192 pattern_name: cache_name,
193 timestamp: Instant::now(),
194 };
195 self.config.event_listeners.emit(&event);
196 return Box::pin(async move { Ok(response) });
197 }
198
199 #[cfg(feature = "metrics")]
201 {
202 counter!("cache_requests_total", "cache" => cache_name.clone(), "result" => "miss")
203 .increment(1);
204 }
205
206 #[cfg(feature = "tracing")]
207 debug!(cache = %cache_name, "Cache miss");
208
209 let miss_event = CacheEvent::Miss {
210 pattern_name: cache_name.clone(),
211 timestamp: Instant::now(),
212 };
213 self.config.event_listeners.emit(&miss_event);
214
215 let future = self.inner.call(req);
216 let store = Arc::clone(&self.store);
217 let config = Arc::clone(&self.config);
218
219 Box::pin(async move {
220 let response = future.await.map_err(CacheError::Inner)?;
221
222 let was_evicted = {
224 let mut store = store.lock().unwrap();
225 let was_full = store.len() >= config.max_size;
226 store.insert(key, response.clone());
227
228 #[cfg(feature = "metrics")]
230 {
231 let new_size = store.len();
232 gauge!("cache_size", "cache" => config.name.clone()).set(new_size as f64);
233 }
234
235 was_full
236 };
237
238 if was_evicted {
239 #[cfg(feature = "metrics")]
240 {
241 counter!("cache_evictions_total", "cache" => config.name.clone()).increment(1);
242 }
243
244 #[cfg(feature = "tracing")]
245 info!(cache = %config.name, "Cache eviction occurred");
246
247 let event = CacheEvent::Eviction {
248 pattern_name: config.name.clone(),
249 timestamp: Instant::now(),
250 };
251 config.event_listeners.emit(&event);
252 }
253
254 Ok(response)
255 })
256 }
257}
258
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::Duration;
    use tower::service_fn;
    use tower::Layer;
    use tower::ServiceExt;

    #[tokio::test]
    async fn cache_hit_returns_cached_response() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let cc = Arc::clone(&call_count);

        let service = service_fn(move |req: String| {
            let cc = Arc::clone(&cc);
            async move {
                cc.fetch_add(1, Ordering::SeqCst);
                Ok::<_, std::io::Error>(format!("Response: {}", req))
            }
        });

        let layer = CacheLayer::builder()
            .max_size(10)
            .key_extractor(|req: &String| req.clone())
            .build();

        let mut service = layer.layer(service);

        let response1 = service
            .ready()
            .await
            .unwrap()
            .call("test".to_string())
            .await
            .unwrap();
        assert_eq!(response1, "Response: test");
        assert_eq!(call_count.load(Ordering::SeqCst), 1);

        let response2 = service
            .ready()
            .await
            .unwrap()
            .call("test".to_string())
            .await
            .unwrap();
        assert_eq!(response2, "Response: test");
        // Served from the cache: the inner service must not be called again.
        assert_eq!(call_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn cache_miss_calls_inner_service() {
        let service = service_fn(|req: String| async move {
            Ok::<_, std::io::Error>(format!("Response: {}", req))
        });

        let layer = CacheLayer::builder()
            .max_size(10)
            .key_extractor(|req: &String| req.clone())
            .build();

        let mut service = layer.layer(service);

        let response = service
            .ready()
            .await
            .unwrap()
            .call("test".to_string())
            .await
            .unwrap();
        assert_eq!(response, "Response: test");
    }

    #[tokio::test]
    async fn different_keys_not_cached_together() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let cc = Arc::clone(&call_count);

        let service = service_fn(move |req: String| {
            let cc = Arc::clone(&cc);
            async move {
                cc.fetch_add(1, Ordering::SeqCst);
                Ok::<_, std::io::Error>(format!("Response: {}", req))
            }
        });

        let layer = CacheLayer::builder()
            .max_size(10)
            .key_extractor(|req: &String| req.clone())
            .build();

        let mut service = layer.layer(service);

        service
            .ready()
            .await
            .unwrap()
            .call("test1".to_string())
            .await
            .unwrap();
        service
            .ready()
            .await
            .unwrap()
            .call("test2".to_string())
            .await
            .unwrap();

        assert_eq!(call_count.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn ttl_expiration_causes_cache_miss() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let cc = Arc::clone(&call_count);

        let service = service_fn(move |req: String| {
            let cc = Arc::clone(&cc);
            async move {
                cc.fetch_add(1, Ordering::SeqCst);
                Ok::<_, std::io::Error>(format!("Response: {}", req))
            }
        });

        let layer = CacheLayer::builder()
            .max_size(10)
            .ttl(Duration::from_millis(50))
            .key_extractor(|req: &String| req.clone())
            .build();

        let mut service = layer.layer(service);

        service
            .ready()
            .await
            .unwrap()
            .call("test".to_string())
            .await
            .unwrap();
        assert_eq!(call_count.load(Ordering::SeqCst), 1);

        // Sleep past the 50 ms TTL so the entry expires.
        tokio::time::sleep(Duration::from_millis(100)).await;

        service
            .ready()
            .await
            .unwrap()
            .call("test".to_string())
            .await
            .unwrap();
        // Expired entry: the inner service must be called again.
        assert_eq!(call_count.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn lru_eviction_removes_least_recently_used() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let cc = Arc::clone(&call_count);

        let service = service_fn(move |req: String| {
            let cc = Arc::clone(&cc);
            async move {
                cc.fetch_add(1, Ordering::SeqCst);
                Ok::<_, std::io::Error>(format!("Response: {}", req))
            }
        });

        let layer = CacheLayer::builder()
            .max_size(2)
            .key_extractor(|req: &String| req.clone())
            .build();

        // Eviction has to be observed through the *same* service. A separately
        // layered service may get its own empty store, and it would then miss
        // on every key whether or not anything was evicted.
        let mut service = layer.layer(service);

        for key in ["key1", "key2", "key3"] {
            service
                .ready()
                .await
                .unwrap()
                .call(key.to_string())
                .await
                .unwrap();
        }
        assert_eq!(call_count.load(Ordering::SeqCst), 3);

        // key2 was used more recently than key1, so it must still be cached.
        service
            .ready()
            .await
            .unwrap()
            .call("key2".to_string())
            .await
            .unwrap();
        assert_eq!(call_count.load(Ordering::SeqCst), 3);

        // key1 was the least recently used entry when key3 was inserted.
        service
            .ready()
            .await
            .unwrap()
            .call("key1".to_string())
            .await
            .unwrap();
        assert_eq!(call_count.load(Ordering::SeqCst), 4);
    }

    #[tokio::test]
    async fn event_listeners_called() {
        let hit_count = Arc::new(AtomicUsize::new(0));
        let miss_count = Arc::new(AtomicUsize::new(0));
        let eviction_count = Arc::new(AtomicUsize::new(0));

        let hc = Arc::clone(&hit_count);
        let mc = Arc::clone(&miss_count);
        let ec = Arc::clone(&eviction_count);

        let service = service_fn(|req: String| async move {
            Ok::<_, std::io::Error>(format!("Response: {}", req))
        });

        let layer = CacheLayer::builder()
            .max_size(1)
            .key_extractor(|req: &String| req.clone())
            .on_hit(move || {
                hc.fetch_add(1, Ordering::SeqCst);
            })
            .on_miss(move || {
                mc.fetch_add(1, Ordering::SeqCst);
            })
            .on_eviction(move || {
                ec.fetch_add(1, Ordering::SeqCst);
            })
            .build();

        let mut service = layer.layer(service);

        service
            .ready()
            .await
            .unwrap()
            .call("test".to_string())
            .await
            .unwrap();
        assert_eq!(miss_count.load(Ordering::SeqCst), 1);
        assert_eq!(hit_count.load(Ordering::SeqCst), 0);

        service
            .ready()
            .await
            .unwrap()
            .call("test".to_string())
            .await
            .unwrap();
        assert_eq!(hit_count.load(Ordering::SeqCst), 1);
        assert_eq!(miss_count.load(Ordering::SeqCst), 1);

        // Capacity is 1, so inserting a second key must evict the first.
        service
            .ready()
            .await
            .unwrap()
            .call("other".to_string())
            .await
            .unwrap();
        assert_eq!(eviction_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn errors_not_cached() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let cc = Arc::clone(&call_count);

        let service = service_fn(move |_req: String| {
            let cc = Arc::clone(&cc);
            async move {
                cc.fetch_add(1, Ordering::SeqCst);
                Err::<String, _>(std::io::Error::other("error"))
            }
        });

        let layer = CacheLayer::builder()
            .max_size(10)
            .key_extractor(|req: &String| req.clone())
            .build();

        let mut service = layer.layer(service);

        let _ = service
            .ready()
            .await
            .unwrap()
            .call("test".to_string())
            .await;
        assert_eq!(call_count.load(Ordering::SeqCst), 1);

        let _ = service
            .ready()
            .await
            .unwrap()
            .call("test".to_string())
            .await;
        // The error was not cached, so the inner service is called again.
        assert_eq!(call_count.load(Ordering::SeqCst), 2);
    }
}