// ferro_cache/invalidator.rs

use std::sync::Arc;

use ferro_events::{global_dispatcher, Event, EventDispatcher};

use crate::cache::Cache;

54pub fn register_invalidator_on<E, F>(dispatcher: &EventDispatcher, cache: Arc<Cache>, key_fn: F)
80where
81 E: Event,
82 F: Fn(&E) -> Vec<String> + Send + Sync + 'static,
83{
84 let key_fn = Arc::new(key_fn);
85 dispatcher.on::<E, _, _>(move |event: E| {
86 let cache = cache.clone();
87 let key_fn = Arc::clone(&key_fn);
88 async move {
89 let tags = key_fn(&event);
90 for tag in tags {
91 if let Err(e) = cache.tags(&[tag.as_str()]).flush().await {
92 tracing::warn!(
93 error = %e,
94 tag = %tag,
95 "ferro-cache invalidator: tag flush failed"
96 );
97 }
98 }
99 Ok(())
102 }
103 });
104}
105
106pub fn register_invalidator<E, F>(cache: Arc<Cache>, key_fn: F)
118where
119 E: Event,
120 F: Fn(&E) -> Vec<String> + Send + Sync + 'static,
121{
122 register_invalidator_on::<E, F>(global_dispatcher(), cache, key_fn);
123}
124
#[cfg(all(test, feature = "memory"))]
mod tests {
    // `Arc`, `Cache`, `Event` and `EventDispatcher` all come in through this glob.
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::Duration;

    /// TTL used for every seeded entry.
    const TTL: Duration = Duration::from_secs(60);

    /// Stores `value` under `key`, tagged with the single tag `tag`; panics on error.
    async fn seed(cache: &Cache, tag: &'static str, key: &'static str, value: &'static str) {
        cache.tags(&[tag]).put(key, &value, TTL).await.unwrap();
    }

    /// Returns whether `key` is present under the single tag `tag`; panics on error.
    async fn is_cached(cache: &Cache, tag: &'static str, key: &'static str) -> bool {
        cache.tags(&[tag]).has(key).await.unwrap()
    }

    // Every test that uses the global dispatcher gets its own event type, so
    // handlers registered by one test are never triggered by another.

    #[derive(Clone)]
    struct EvtFlushSingle {
        product: i64,
    }

    impl Event for EvtFlushSingle {
        fn name(&self) -> &'static str {
            "EvtFlushSingle"
        }
    }

    #[tokio::test]
    async fn flushes_matching_tag() {
        const TAG: &str = "business:1:product:7";
        const KEY: &str = "availability:foo";

        let cache = Arc::new(Cache::memory());
        seed(&cache, TAG, KEY, "slot-grid-blob").await;
        assert!(
            is_cached(&cache, TAG, KEY).await,
            "precondition: entry exists before invalidator runs"
        );

        register_invalidator::<EvtFlushSingle, _>(cache.clone(), |e| {
            vec![format!("business:1:product:{}", e.product)]
        });

        EvtFlushSingle { product: 7 }.dispatch().await.unwrap();

        assert!(
            !is_cached(&cache, TAG, KEY).await,
            "entry should be evicted after matching event"
        );
    }

    #[derive(Clone)]
    struct EvtFlushNonMatching {
        product: i64,
    }

    impl Event for EvtFlushNonMatching {
        fn name(&self) -> &'static str {
            "EvtFlushNonMatching"
        }
    }

    #[tokio::test]
    async fn does_not_flush_unrelated_tags() {
        const KEPT_TAG: &str = "business:1:product:7";
        const FLUSHED_TAG: &str = "business:1:product:99";

        let cache = Arc::new(Cache::memory());
        seed(&cache, KEPT_TAG, "a", "kept").await;
        seed(&cache, FLUSHED_TAG, "b", "evicted").await;

        register_invalidator::<EvtFlushNonMatching, _>(cache.clone(), |e| {
            vec![format!("business:1:product:{}", e.product)]
        });

        EvtFlushNonMatching { product: 99 }.dispatch().await.unwrap();

        assert!(
            is_cached(&cache, KEPT_TAG, "a").await,
            "unrelated tag must survive"
        );
        assert!(
            !is_cached(&cache, FLUSHED_TAG, "b").await,
            "matching tag must be evicted"
        );
    }

    #[derive(Clone)]
    struct EvtMultiInvalidator;

    impl Event for EvtMultiInvalidator {
        fn name(&self) -> &'static str {
            "EvtMultiInvalidator"
        }
    }

    #[tokio::test]
    async fn all_registered_invalidators_run() {
        let cache = Arc::new(Cache::memory());
        seed(&cache, "scope:a", "k", "va").await;
        seed(&cache, "scope:b", "k", "vb").await;

        let calls = Arc::new(AtomicUsize::new(0));

        let counter = Arc::clone(&calls);
        register_invalidator::<EvtMultiInvalidator, _>(cache.clone(), move |_e| {
            counter.fetch_add(1, Ordering::SeqCst);
            vec!["scope:a".to_string()]
        });

        let counter = Arc::clone(&calls);
        register_invalidator::<EvtMultiInvalidator, _>(cache.clone(), move |_e| {
            counter.fetch_add(1, Ordering::SeqCst);
            vec!["scope:b".to_string()]
        });

        EvtMultiInvalidator.dispatch().await.unwrap();

        assert_eq!(calls.load(Ordering::SeqCst), 2, "both key_fns should run");
        assert!(!is_cached(&cache, "scope:a", "k").await);
        assert!(!is_cached(&cache, "scope:b", "k").await);
    }

    #[derive(Clone)]
    struct EvtEmptyTags;

    impl Event for EvtEmptyTags {
        fn name(&self) -> &'static str {
            "EvtEmptyTags"
        }
    }

    #[tokio::test]
    async fn empty_tag_set_is_a_noop() {
        let cache = Arc::new(Cache::memory());
        seed(&cache, "t", "k", "v").await;

        register_invalidator::<EvtEmptyTags, _>(cache.clone(), |_e| Vec::new());

        EvtEmptyTags.dispatch().await.unwrap();

        assert!(
            is_cached(&cache, "t", "k").await,
            "empty tag list must not flush anything"
        );
    }

    #[derive(Clone)]
    struct EvtLocalDispatcher {
        product: i64,
    }

    impl Event for EvtLocalDispatcher {
        fn name(&self) -> &'static str {
            "EvtLocalDispatcher"
        }
    }

    #[tokio::test]
    async fn register_invalidator_on_arbitrary_dispatcher() {
        const TAG: &str = "business:1:product:7";

        // Only `wired_dispatcher` gets the invalidator; `untouched_dispatcher` is the control.
        let wired_dispatcher = EventDispatcher::new();
        let untouched_dispatcher = EventDispatcher::new();

        let cache = Arc::new(Cache::memory());
        seed(&cache, TAG, "k", "v").await;

        register_invalidator_on::<EvtLocalDispatcher, _>(&wired_dispatcher, cache.clone(), |e| {
            vec![format!("business:1:product:{}", e.product)]
        });

        untouched_dispatcher
            .dispatch(EvtLocalDispatcher { product: 7 })
            .await
            .unwrap();
        assert!(
            is_cached(&cache, TAG, "k").await,
            "untouched dispatcher must not trigger the invalidator"
        );

        wired_dispatcher
            .dispatch(EvtLocalDispatcher { product: 7 })
            .await
            .unwrap();
        assert!(
            !is_cached(&cache, TAG, "k").await,
            "wired dispatcher must trigger the invalidator"
        );
    }
}