use crate::cache::Cache;
use ferro_events::{global_dispatcher, Event, EventDispatcher};
use std::sync::Arc;
/// Wires a tag-based cache invalidator for events of type `E` onto `dispatcher`.
///
/// Whenever an `E` is dispatched through `dispatcher`, `key_fn` is called with
/// the event to produce a list of cache tags. Each of those tags is then
/// flushed from `cache` via `cache.tags(&[tag]).flush()`.
///
/// - Tags are flushed sequentially, in the order `key_fn` returned them.
/// - A tag that appears more than once is flushed only on its first occurrence.
/// - An empty list flushes nothing.
///
/// Invalidation is best-effort. A failed flush is logged via `tracing::warn!`,
/// with the error and the offending tag, and the remaining tags are still
/// processed. The listener always resolves to `Ok(())`, so flush failures are
/// never reported back to the dispatcher.
pub fn register_invalidator_on<E, F>(dispatcher: &EventDispatcher, cache: Arc<Cache>, key_fn: F)
where
    E: Event,
    F: Fn(&E) -> Vec<String> + Send + Sync + 'static,
{
    // The listener closure may run many times, so `key_fn` is shared behind an
    // `Arc` and handed to each event's future with a refcount bump.
    let key_fn = Arc::new(key_fn);
    dispatcher.on::<E, _, _>(move |event: E| {
        let cache = Arc::clone(&cache);
        let key_fn = Arc::clone(&key_fn);
        async move {
            let tags = key_fn(&event);
            // `key_fn` may derive the same tag from several parts of the event.
            // Flushing it a second time would only repeat work already done.
            let mut seen = std::collections::HashSet::new();
            for tag in &tags {
                if !seen.insert(tag.as_str()) {
                    continue;
                }
                if let Err(e) = cache.tags(&[tag.as_str()]).flush().await {
                    tracing::warn!(
                        error = %e,
                        tag = %tag,
                        "ferro-cache invalidator: tag flush failed"
                    );
                }
            }
            Ok(())
        }
    });
}
/// Wires a tag-based cache invalidator for events of type `E` onto the
/// dispatcher returned by `global_dispatcher()`.
///
/// This is a convenience wrapper: it forwards `cache` and `key_fn` unchanged to
/// `register_invalidator_on`, targeting the global dispatcher.
pub fn register_invalidator<E, F>(cache: Arc<Cache>, key_fn: F)
where
    E: Event,
    F: Fn(&E) -> Vec<String> + Send + Sync + 'static,
{
    let dispatcher = global_dispatcher();
    register_invalidator_on::<E, F>(dispatcher, cache, key_fn);
}
#[cfg(all(test, feature = "memory"))]
mod tests {
    //! Behavioural checks for the cache invalidator wiring.
    //!
    //! Each test declares its own event type. Listeners that one test registers
    //! on the shared global dispatcher are therefore keyed to a type that no
    //! other test dispatches.

    use super::*;
    use crate::Cache;
    use ferro_events::{Event, EventDispatcher};
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::Duration;

    /// Lifetime given to every seeded entry.
    const TTL: Duration = Duration::from_secs(60);

    /// Stores `value` under `key` within the tag scope `tag`; panics if the put fails.
    async fn seed(cache: &Arc<Cache>, tag: &'static str, key: &'static str, value: &'static str) {
        cache.tags(&[tag]).put(key, &value, TTL).await.unwrap();
    }

    /// Returns whether `key` exists within the tag scope `tag`; panics if the lookup fails.
    async fn is_cached(cache: &Arc<Cache>, tag: &'static str, key: &'static str) -> bool {
        cache.tags(&[tag]).has(key).await.unwrap()
    }

    #[derive(Clone)]
    struct EvtFlushSingle {
        product: i64,
    }

    impl Event for EvtFlushSingle {
        fn name(&self) -> &'static str {
            "EvtFlushSingle"
        }
    }

    #[derive(Clone)]
    struct EvtFlushNonMatching {
        product: i64,
    }

    impl Event for EvtFlushNonMatching {
        fn name(&self) -> &'static str {
            "EvtFlushNonMatching"
        }
    }

    #[derive(Clone)]
    struct EvtMultiInvalidator;

    impl Event for EvtMultiInvalidator {
        fn name(&self) -> &'static str {
            "EvtMultiInvalidator"
        }
    }

    #[derive(Clone)]
    struct EvtEmptyTags;

    impl Event for EvtEmptyTags {
        fn name(&self) -> &'static str {
            "EvtEmptyTags"
        }
    }

    #[derive(Clone)]
    struct EvtLocalDispatcher {
        product: i64,
    }

    impl Event for EvtLocalDispatcher {
        fn name(&self) -> &'static str {
            "EvtLocalDispatcher"
        }
    }

    #[tokio::test]
    async fn flushes_matching_tag() {
        let cache = Arc::new(Cache::memory());
        seed(&cache, "business:1:product:7", "availability:foo", "slot-grid-blob").await;
        assert!(
            is_cached(&cache, "business:1:product:7", "availability:foo").await,
            "precondition: entry exists before invalidator runs"
        );

        register_invalidator::<EvtFlushSingle, _>(Arc::clone(&cache), |evt| {
            vec![format!("business:1:product:{}", evt.product)]
        });
        EvtFlushSingle { product: 7 }.dispatch().await.unwrap();

        let still_cached = is_cached(&cache, "business:1:product:7", "availability:foo").await;
        assert!(!still_cached, "entry should be evicted after matching event");
    }

    #[tokio::test]
    async fn does_not_flush_unrelated_tags() {
        let cache = Arc::new(Cache::memory());
        seed(&cache, "business:1:product:7", "a", "kept").await;
        seed(&cache, "business:1:product:99", "b", "evicted").await;

        register_invalidator::<EvtFlushNonMatching, _>(Arc::clone(&cache), |evt| {
            vec![format!("business:1:product:{}", evt.product)]
        });
        EvtFlushNonMatching { product: 99 }.dispatch().await.unwrap();

        assert!(
            is_cached(&cache, "business:1:product:7", "a").await,
            "unrelated tag must survive"
        );
        assert!(
            !is_cached(&cache, "business:1:product:99", "b").await,
            "matching tag must be evicted"
        );
    }

    #[tokio::test]
    async fn all_registered_invalidators_run() {
        let cache = Arc::new(Cache::memory());
        seed(&cache, "scope:a", "k", "va").await;
        seed(&cache, "scope:b", "k", "vb").await;

        // Two independent invalidators for the same event type, one per scope,
        // both bumping a shared call counter.
        let calls = Arc::new(AtomicUsize::new(0));
        for scope in ["scope:a", "scope:b"] {
            let counter = Arc::clone(&calls);
            register_invalidator::<EvtMultiInvalidator, _>(Arc::clone(&cache), move |_evt| {
                counter.fetch_add(1, Ordering::SeqCst);
                vec![scope.to_string()]
            });
        }

        EvtMultiInvalidator.dispatch().await.unwrap();

        assert_eq!(calls.load(Ordering::SeqCst), 2, "both key_fns should run");
        assert!(!is_cached(&cache, "scope:a", "k").await);
        assert!(!is_cached(&cache, "scope:b", "k").await);
    }

    #[tokio::test]
    async fn empty_tag_set_is_a_noop() {
        let cache = Arc::new(Cache::memory());
        seed(&cache, "t", "k", "v").await;

        register_invalidator::<EvtEmptyTags, _>(Arc::clone(&cache), |_evt| Vec::new());
        EvtEmptyTags.dispatch().await.unwrap();

        assert!(
            is_cached(&cache, "t", "k").await,
            "empty tag list must not flush anything"
        );
    }

    #[tokio::test]
    async fn register_invalidator_on_arbitrary_dispatcher() {
        let wired_dispatcher = EventDispatcher::new();
        let untouched_dispatcher = EventDispatcher::new();
        let cache = Arc::new(Cache::memory());
        seed(&cache, "business:1:product:7", "k", "v").await;

        register_invalidator_on::<EvtLocalDispatcher, _>(
            &wired_dispatcher,
            Arc::clone(&cache),
            |evt| vec![format!("business:1:product:{}", evt.product)],
        );

        // A dispatcher the invalidator was never registered on must not reach it.
        untouched_dispatcher
            .dispatch(EvtLocalDispatcher { product: 7 })
            .await
            .unwrap();
        assert!(
            is_cached(&cache, "business:1:product:7", "k").await,
            "untouched dispatcher must not trigger the invalidator"
        );

        wired_dispatcher
            .dispatch(EvtLocalDispatcher { product: 7 })
            .await
            .unwrap();
        assert!(
            !is_cached(&cache, "business:1:product:7", "k").await,
            "wired dispatcher must trigger the invalidator"
        );
    }
}