1mod config;
43mod error;
44mod events;
45mod eviction;
46mod layer;
47mod store;
48
49pub use config::{CacheConfig, CacheConfigBuilder, KeyExtractor};
50pub use error::CacheError;
51pub use events::CacheEvent;
52pub use eviction::EvictionPolicy;
53pub use layer::CacheLayer;
54
55use futures::future::BoxFuture;
56use std::hash::Hash;
57use std::sync::{Arc, Mutex};
58use std::task::{Context, Poll};
59use std::time::Instant;
60use store::CacheStore;
61use tower::Service;
62
63#[cfg(feature = "metrics")]
64use metrics::{counter, describe_counter, describe_gauge, gauge};
65
66#[cfg(feature = "tracing")]
67use tracing::{debug, info};
68
69pub struct Cache<S, Req, K, Resp> {
78 inner: S,
79 config: Arc<CacheConfig<Req, K>>,
80 store: Arc<Mutex<CacheStore<K, Resp>>>,
81}
82
83impl<S, Req, K, Resp> Cache<S, Req, K, Resp>
84where
85 K: Hash + Eq + Clone + Send + 'static,
86 Resp: Clone + Send + 'static,
87{
88 pub fn new(inner: S, config: Arc<CacheConfig<Req, K>>) -> Self {
90 #[cfg(feature = "metrics")]
91 {
92 describe_counter!(
93 "cache_requests_total",
94 "Total number of cache requests (hits and misses)"
95 );
96 describe_counter!("cache_evictions_total", "Total number of cache evictions");
97 describe_gauge!("cache_size", "Current number of entries in the cache");
98 }
99
100 let store = Arc::new(Mutex::new(CacheStore::new(
101 config.max_size,
102 config.ttl,
103 config.eviction_policy,
104 )));
105 Self {
106 inner,
107 config,
108 store,
109 }
110 }
111}
112
113impl<S, Req, K, Resp> Clone for Cache<S, Req, K, Resp>
114where
115 S: Clone,
116{
117 fn clone(&self) -> Self {
118 Self {
119 inner: self.inner.clone(),
120 config: Arc::clone(&self.config),
121 store: Arc::clone(&self.store),
122 }
123 }
124}
125
126impl<S, Req, K> Service<Req> for Cache<S, Req, K, S::Response>
127where
128 S: Service<Req>,
129 S::Response: Clone + Send + 'static,
130 K: Hash + Eq + Clone + Send + 'static,
131 Req: Send + 'static,
132 S::Future: Send + 'static,
133{
134 type Response = S::Response;
135 type Error = CacheError<S::Error>;
136 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
137
138 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
139 self.inner.poll_ready(cx).map_err(CacheError::Inner)
140 }
141
142 fn call(&mut self, req: Req) -> Self::Future {
143 let key = (self.config.key_extractor)(&req);
144 let cache_name = self.config.name.clone();
145
146 let cached = {
148 let mut store = self.store.lock().unwrap();
149 store.get(&key)
150 };
151
152 if let Some(response) = cached {
153 #[cfg(feature = "metrics")]
155 {
156 counter!("cache_requests_total", "cache" => cache_name.clone(), "result" => "hit")
157 .increment(1);
158 }
159
160 #[cfg(feature = "tracing")]
161 debug!(cache = %cache_name, "Cache hit");
162
163 let event = CacheEvent::Hit {
164 pattern_name: cache_name,
165 timestamp: Instant::now(),
166 };
167 self.config.event_listeners.emit(&event);
168 return Box::pin(async move { Ok(response) });
169 }
170
171 #[cfg(feature = "metrics")]
173 {
174 counter!("cache_requests_total", "cache" => cache_name.clone(), "result" => "miss")
175 .increment(1);
176 }
177
178 #[cfg(feature = "tracing")]
179 debug!(cache = %cache_name, "Cache miss");
180
181 let miss_event = CacheEvent::Miss {
182 pattern_name: cache_name.clone(),
183 timestamp: Instant::now(),
184 };
185 self.config.event_listeners.emit(&miss_event);
186
187 let future = self.inner.call(req);
188 let store = Arc::clone(&self.store);
189 let config = Arc::clone(&self.config);
190
191 Box::pin(async move {
192 let response = future.await.map_err(CacheError::Inner)?;
193
194 let was_evicted = {
196 let mut store = store.lock().unwrap();
197 let was_full = store.len() >= config.max_size;
198 store.insert(key, response.clone());
199
200 #[cfg(feature = "metrics")]
202 {
203 let new_size = store.len();
204 gauge!("cache_size", "cache" => config.name.clone()).set(new_size as f64);
205 }
206
207 was_full
208 };
209
210 if was_evicted {
211 #[cfg(feature = "metrics")]
212 {
213 counter!("cache_evictions_total", "cache" => config.name.clone()).increment(1);
214 }
215
216 #[cfg(feature = "tracing")]
217 info!(cache = %config.name, "Cache eviction occurred");
218
219 let event = CacheEvent::Eviction {
220 pattern_name: config.name.clone(),
221 timestamp: Instant::now(),
222 };
223 config.event_listeners.emit(&event);
224 }
225
226 Ok(response)
227 })
228 }
229}
230
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::Duration;
    use tower::service_fn;
    use tower::Layer;
    use tower::ServiceExt;

    /// Waits for `service` to become ready, then sends it `req` as an owned `String`.
    async fn send<S>(service: &mut S, req: &str) -> Result<S::Response, S::Error>
    where
        S: Service<String>,
    {
        service.ready().await?.call(req.to_string()).await
    }

    #[tokio::test]
    async fn cache_hit_returns_cached_response() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let cc = Arc::clone(&call_count);

        let service = service_fn(move |req: String| {
            let cc = Arc::clone(&cc);
            async move {
                cc.fetch_add(1, Ordering::SeqCst);
                Ok::<_, std::io::Error>(format!("Response: {}", req))
            }
        });

        let layer = CacheLayer::builder()
            .max_size(10)
            .key_extractor(|req: &String| req.clone())
            .build();

        let mut service = layer.layer(service);

        let response1 = send(&mut service, "test").await.unwrap();
        assert_eq!(response1, "Response: test");
        assert_eq!(call_count.load(Ordering::SeqCst), 1);

        // The second request with the same key must be served without calling the inner service.
        let response2 = send(&mut service, "test").await.unwrap();
        assert_eq!(response2, "Response: test");
        assert_eq!(call_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn cache_miss_calls_inner_service() {
        let service = service_fn(|req: String| async move {
            Ok::<_, std::io::Error>(format!("Response: {}", req))
        });

        let layer = CacheLayer::builder()
            .max_size(10)
            .key_extractor(|req: &String| req.clone())
            .build();

        let mut service = layer.layer(service);

        let response = send(&mut service, "test").await.unwrap();
        assert_eq!(response, "Response: test");
    }

    #[tokio::test]
    async fn different_keys_not_cached_together() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let cc = Arc::clone(&call_count);

        let service = service_fn(move |req: String| {
            let cc = Arc::clone(&cc);
            async move {
                cc.fetch_add(1, Ordering::SeqCst);
                Ok::<_, std::io::Error>(format!("Response: {}", req))
            }
        });

        let layer = CacheLayer::builder()
            .max_size(10)
            .key_extractor(|req: &String| req.clone())
            .build();

        let mut service = layer.layer(service);

        send(&mut service, "test1").await.unwrap();
        send(&mut service, "test2").await.unwrap();

        assert_eq!(call_count.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn ttl_expiration_causes_cache_miss() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let cc = Arc::clone(&call_count);

        let service = service_fn(move |req: String| {
            let cc = Arc::clone(&cc);
            async move {
                cc.fetch_add(1, Ordering::SeqCst);
                Ok::<_, std::io::Error>(format!("Response: {}", req))
            }
        });

        let layer = CacheLayer::builder()
            .max_size(10)
            .ttl(Duration::from_millis(50))
            .key_extractor(|req: &String| req.clone())
            .build();

        let mut service = layer.layer(service);

        send(&mut service, "test").await.unwrap();
        assert_eq!(call_count.load(Ordering::SeqCst), 1);

        // Sleep well past the 50ms TTL so the entry is stale on the next lookup.
        tokio::time::sleep(Duration::from_millis(100)).await;

        send(&mut service, "test").await.unwrap();
        assert_eq!(call_count.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn lru_eviction_removes_least_recently_used() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let cc = Arc::clone(&call_count);

        let service = service_fn(move |req: String| {
            let cc = Arc::clone(&cc);
            async move {
                cc.fetch_add(1, Ordering::SeqCst);
                Ok::<_, std::io::Error>(format!("Response: {}", req))
            }
        });

        let layer = CacheLayer::builder()
            .max_size(2)
            .key_extractor(|req: &String| req.clone())
            .build();

        // Eviction is observed through a single service instance. `Cache::new` allocates a
        // fresh store, so a second wrapped service could start empty and would miss every key
        // whether or not anything was evicted.
        let mut service = layer.layer(service);

        for key in ["key1", "key2", "key3"] {
            send(&mut service, key).await.unwrap();
        }
        assert_eq!(call_count.load(Ordering::SeqCst), 3);

        // key3 was inserted last, so it must still be cached.
        send(&mut service, "key3").await.unwrap();
        assert_eq!(call_count.load(Ordering::SeqCst), 3);

        // key1 was the least recently used entry when key3 arrived, so it must have been evicted.
        send(&mut service, "key1").await.unwrap();
        assert_eq!(call_count.load(Ordering::SeqCst), 4);
    }

    #[tokio::test]
    async fn event_listeners_called() {
        let hit_count = Arc::new(AtomicUsize::new(0));
        let miss_count = Arc::new(AtomicUsize::new(0));
        let eviction_count = Arc::new(AtomicUsize::new(0));

        let hc = Arc::clone(&hit_count);
        let mc = Arc::clone(&miss_count);
        let ec = Arc::clone(&eviction_count);

        let service = service_fn(|req: String| async move {
            Ok::<_, std::io::Error>(format!("Response: {}", req))
        });

        let layer = CacheLayer::builder()
            .max_size(1)
            .key_extractor(|req: &String| req.clone())
            .on_hit(move || {
                hc.fetch_add(1, Ordering::SeqCst);
            })
            .on_miss(move || {
                mc.fetch_add(1, Ordering::SeqCst);
            })
            .on_eviction(move || {
                ec.fetch_add(1, Ordering::SeqCst);
            })
            .build();

        let mut service = layer.layer(service);

        send(&mut service, "test").await.unwrap();
        assert_eq!(miss_count.load(Ordering::SeqCst), 1);
        assert_eq!(hit_count.load(Ordering::SeqCst), 0);

        send(&mut service, "test").await.unwrap();
        assert_eq!(hit_count.load(Ordering::SeqCst), 1);
        assert_eq!(miss_count.load(Ordering::SeqCst), 1);

        // With capacity 1, caching a second key must evict the first.
        send(&mut service, "other").await.unwrap();
        assert_eq!(eviction_count.load(Ordering::SeqCst), 1);
        assert_eq!(miss_count.load(Ordering::SeqCst), 2);
        assert_eq!(hit_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn errors_not_cached() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let cc = Arc::clone(&call_count);

        let service = service_fn(move |_req: String| {
            let cc = Arc::clone(&cc);
            async move {
                cc.fetch_add(1, Ordering::SeqCst);
                Err::<String, _>(std::io::Error::other("error"))
            }
        });

        let layer = CacheLayer::builder()
            .max_size(10)
            .key_extractor(|req: &String| req.clone())
            .build();

        let mut service = layer.layer(service);

        assert!(send(&mut service, "test").await.is_err());
        assert_eq!(call_count.load(Ordering::SeqCst), 1);

        // The failed response was not stored, so the inner service is called again.
        assert!(send(&mut service, "test").await.is_err());
        assert_eq!(call_count.load(Ordering::SeqCst), 2);
    }
}