tower_resilience_cache/config.rs
//! Configuration and builder for the cache layer.
2
3use crate::events::CacheEvent;
4use std::hash::Hash;
5use std::sync::Arc;
6use std::time::Duration;
7use tower_resilience_core::{EventListeners, FnListener};
8
/// Shared closure that maps a borrowed request to the key under which its
/// response is cached.
///
/// The closure must be `Send + Sync` so the `Arc` holding it can be shared
/// across threads; cloning the alias only bumps the reference count.
pub type KeyExtractor<Req, K> = Arc<dyn for<'r> Fn(&'r Req) -> K + Send + Sync + 'static>;
11
/// Configuration for the cache pattern.
///
/// Values are assembled by [`CacheConfigBuilder`] and passed to
/// `crate::CacheLayer::new` when [`CacheConfigBuilder::build`] is called.
pub struct CacheConfig<Req, K> {
    /// Maximum number of entries the cache may hold.
    pub(crate) max_size: usize,
    /// Time-to-live for cached entries; `None` means entries do not expire.
    pub(crate) ttl: Option<Duration>,
    /// Function that derives the cache key from each request.
    pub(crate) key_extractor: KeyExtractor<Req, K>,
    /// Listeners registered via the builder's `on_hit`, `on_miss` and `on_eviction`.
    pub(crate) event_listeners: EventListeners<CacheEvent>,
    /// Name of this cache instance, used for observability.
    pub(crate) name: String,
}
20
21impl<Req, K> CacheConfig<Req, K>
22where
23 K: Hash + Eq + Clone + Send + 'static,
24{
25 /// Creates a new configuration builder.
26 pub fn builder() -> CacheConfigBuilder<Req, K> {
27 CacheConfigBuilder::new()
28 }
29}
30
31/// Builder for configuring and constructing a cache.
32pub struct CacheConfigBuilder<Req, K> {
33 max_size: usize,
34 ttl: Option<Duration>,
35 key_extractor: Option<KeyExtractor<Req, K>>,
36 event_listeners: EventListeners<CacheEvent>,
37 name: String,
38}
39
40impl<Req, K> CacheConfigBuilder<Req, K>
41where
42 K: Hash + Eq + Clone + Send + 'static,
43{
44 /// Creates a new builder with default values.
45 pub fn new() -> Self {
46 Self {
47 max_size: 100,
48 ttl: None,
49 key_extractor: None,
50 event_listeners: EventListeners::new(),
51 name: String::from("<unnamed>"),
52 }
53 }
54
55 /// Sets the maximum number of entries in the cache.
56 ///
57 /// Default: 100
58 pub fn max_size(mut self, size: usize) -> Self {
59 self.max_size = size;
60 self
61 }
62
63 /// Sets the time-to-live for cached entries.
64 ///
65 /// If set, entries will expire after the specified duration.
66 /// Default: None (no expiration)
67 pub fn ttl(mut self, ttl: Duration) -> Self {
68 self.ttl = Some(ttl);
69 self
70 }
71
72 /// Sets the function that extracts a cache key from a request.
73 ///
74 /// This function must be provided before building.
75 pub fn key_extractor<F>(mut self, f: F) -> Self
76 where
77 F: Fn(&Req) -> K + Send + Sync + 'static,
78 {
79 self.key_extractor = Some(Arc::new(f));
80 self
81 }
82
83 /// Sets the name of this cache instance for observability.
84 ///
85 /// Default: `"<unnamed>"`
86 pub fn name(mut self, name: impl Into<String>) -> Self {
87 self.name = name.into();
88 self
89 }
90
91 /// Registers a callback when a cache hit occurs.
92 ///
93 /// A cache hit occurs when a requested entry is found in the cache and has not expired.
94 ///
95 /// # Callback Signature
96 /// `Fn()` - Called with no parameters when a cache hit is detected.
97 ///
98 /// # Example
99 /// ```rust,no_run
100 /// use tower_resilience_cache::CacheConfig;
101 /// use std::sync::atomic::{AtomicUsize, Ordering};
102 /// use std::sync::Arc;
103 ///
104 /// #[derive(Clone, Hash, Eq, PartialEq)]
105 /// struct Request {
106 /// id: String,
107 /// }
108 ///
109 /// let hit_count = Arc::new(AtomicUsize::new(0));
110 /// let counter = Arc::clone(&hit_count);
111 ///
112 /// let config = CacheConfig::<Request, String>::builder()
113 /// .key_extractor(|req| req.id.clone())
114 /// .on_hit(move || {
115 /// let count = counter.fetch_add(1, Ordering::SeqCst);
116 /// println!("Cache hit #{}", count + 1);
117 /// })
118 /// .build();
119 /// ```
120 pub fn on_hit<F>(mut self, f: F) -> Self
121 where
122 F: Fn() + Send + Sync + 'static,
123 {
124 self.event_listeners.add(FnListener::new(move |event| {
125 if matches!(event, CacheEvent::Hit { .. }) {
126 f();
127 }
128 }));
129 self
130 }
131
132 /// Registers a callback when a cache miss occurs.
133 ///
134 /// A cache miss occurs when a requested entry is not found in the cache or has expired.
135 /// The underlying service will be called to fetch the value, which will then be cached.
136 ///
137 /// # Callback Signature
138 /// `Fn()` - Called with no parameters when a cache miss is detected.
139 ///
140 /// # Example
141 /// ```rust,no_run
142 /// use tower_resilience_cache::CacheConfig;
143 /// use std::sync::atomic::{AtomicUsize, Ordering};
144 /// use std::sync::Arc;
145 ///
146 /// #[derive(Clone, Hash, Eq, PartialEq)]
147 /// struct Request {
148 /// id: String,
149 /// }
150 ///
151 /// let miss_count = Arc::new(AtomicUsize::new(0));
152 /// let counter = Arc::clone(&miss_count);
153 ///
154 /// let config = CacheConfig::<Request, String>::builder()
155 /// .key_extractor(|req| req.id.clone())
156 /// .on_miss(move || {
157 /// let count = counter.fetch_add(1, Ordering::SeqCst);
158 /// println!("Cache miss #{} - fetching from service", count + 1);
159 /// })
160 /// .build();
161 /// ```
162 pub fn on_miss<F>(mut self, f: F) -> Self
163 where
164 F: Fn() + Send + Sync + 'static,
165 {
166 self.event_listeners.add(FnListener::new(move |event| {
167 if matches!(event, CacheEvent::Miss { .. }) {
168 f();
169 }
170 }));
171 self
172 }
173
174 /// Registers a callback when an entry is evicted from the cache.
175 ///
176 /// Eviction occurs when:
177 /// - The cache reaches its maximum size and needs to make room for new entries
178 /// - An entry expires due to TTL (time-to-live) configuration
179 ///
180 /// # Callback Signature
181 /// `Fn()` - Called with no parameters when a cache eviction occurs.
182 ///
183 /// # Example
184 /// ```rust,no_run
185 /// use tower_resilience_cache::CacheConfig;
186 /// use std::sync::atomic::{AtomicUsize, Ordering};
187 /// use std::sync::Arc;
188 /// use std::time::Duration;
189 ///
190 /// #[derive(Clone, Hash, Eq, PartialEq)]
191 /// struct Request {
192 /// id: String,
193 /// }
194 ///
195 /// let eviction_count = Arc::new(AtomicUsize::new(0));
196 /// let counter = Arc::clone(&eviction_count);
197 ///
198 /// let config = CacheConfig::<Request, String>::builder()
199 /// .key_extractor(|req| req.id.clone())
200 /// .max_size(100)
201 /// .ttl(Duration::from_secs(300))
202 /// .on_eviction(move || {
203 /// let count = counter.fetch_add(1, Ordering::SeqCst);
204 /// println!("Entry evicted (total: {})", count + 1);
205 /// })
206 /// .build();
207 /// ```
208 pub fn on_eviction<F>(mut self, f: F) -> Self
209 where
210 F: Fn() + Send + Sync + 'static,
211 {
212 self.event_listeners.add(FnListener::new(move |event| {
213 if matches!(event, CacheEvent::Eviction { .. }) {
214 f();
215 }
216 }));
217 self
218 }
219
220 /// Builds the cache layer.
221 ///
222 /// # Panics
223 ///
224 /// Panics if `key_extractor` was not set.
225 pub fn build(self) -> crate::CacheLayer<Req, K> {
226 let key_extractor = self
227 .key_extractor
228 .expect("key_extractor must be set before building");
229
230 let config = CacheConfig {
231 max_size: self.max_size,
232 ttl: self.ttl,
233 key_extractor,
234 event_listeners: self.event_listeners,
235 name: self.name,
236 };
237
238 crate::CacheLayer::new(config)
239 }
240}
241
242impl<Req, K> Default for CacheConfigBuilder<Req, K>
243where
244 K: Hash + Eq + Clone + Send + 'static,
245{
246 fn default() -> Self {
247 Self::new()
248 }
249}
250
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Clone, Hash, Eq, PartialEq)]
    struct TestRequest {
        id: String,
    }

    /// Convenience constructor for a request with the given id.
    fn request(id: &str) -> TestRequest {
        TestRequest { id: id.to_string() }
    }

    #[test]
    fn test_builder_defaults() {
        // As a child module we can read the builder's private fields, so the
        // documented defaults are asserted directly instead of only checking
        // that `build()` does not panic.
        let builder = CacheConfig::<TestRequest, String>::builder();
        assert_eq!(builder.max_size, 100);
        assert_eq!(builder.ttl, None);
        assert!(builder.key_extractor.is_none());
        assert_eq!(builder.name, "<unnamed>");

        let _layer = builder.key_extractor(|req| req.id.clone()).build();
    }

    #[test]
    fn test_default_matches_new() {
        let builder = CacheConfigBuilder::<TestRequest, String>::default();
        assert_eq!(builder.max_size, 100);
        assert_eq!(builder.ttl, None);
        assert!(builder.key_extractor.is_none());
        assert_eq!(builder.name, "<unnamed>");
    }

    #[test]
    fn test_builder_custom_values() {
        let builder = CacheConfig::<TestRequest, String>::builder()
            .max_size(500)
            .ttl(Duration::from_secs(60))
            .key_extractor(|req| req.id.clone())
            .name("my-cache");

        assert_eq!(builder.max_size, 500);
        assert_eq!(builder.ttl, Some(Duration::from_secs(60)));
        assert_eq!(builder.name, "my-cache");
        let extract = builder
            .key_extractor
            .as_ref()
            .expect("key_extractor was set above");
        assert_eq!(extract(&request("abc")), "abc");

        let _layer = builder.build();
    }

    #[test]
    fn test_later_setters_override_earlier_ones() {
        let builder = CacheConfig::<TestRequest, String>::builder()
            .max_size(1)
            .max_size(2)
            .name("first")
            .name("second")
            .key_extractor(|req| req.id.clone())
            .key_extractor(|req| format!("k:{}", req.id));

        assert_eq!(builder.max_size, 2);
        assert_eq!(builder.name, "second");
        let extract = builder
            .key_extractor
            .as_ref()
            .expect("key_extractor was set above");
        assert_eq!(extract(&request("x")), "k:x");
    }

    #[test]
    fn test_event_listeners() {
        let _layer = CacheConfig::<TestRequest, String>::builder()
            .key_extractor(|req| req.id.clone())
            .on_hit(|| {})
            .on_miss(|| {})
            .on_eviction(|| {})
            .build();
        // Registration of all three listener kinds must not panic.
    }

    #[test]
    #[should_panic(expected = "key_extractor must be set")]
    fn test_builder_panics_without_key_extractor() {
        let _config = CacheConfig::<TestRequest, String>::builder().build();
    }
}