use std::sync::Arc;
use async_trait::async_trait;
use crate::commit::TenantId;
/// Source of truth for whether an identifier (`rid`) has been tombstoned
/// within a tenant.
///
/// The `Send + Sync` supertraits allow one provider to be shared across tasks,
/// e.g. as an `Arc<dyn TombstoneProvider>`.
#[async_trait]
pub trait TombstoneProvider: Send + Sync {
    /// Returns `true` if `rid` is tombstoned for `tenant_id`.
    async fn is_tombstoned(&self, tenant_id: TenantId, rid: &str) -> bool;

    /// Returns `true` if at least one entry of `rids` is tombstoned for `tenant_id`.
    ///
    /// The default implementation asks [`is_tombstoned`](Self::is_tombstoned)
    /// about each id in order and stops at the first hit; an empty slice yields
    /// `false`. Implementations with a batch lookup may override it.
    async fn any_tombstoned(&self, tenant_id: TenantId, rids: &[String]) -> bool {
        let mut hit = false;
        for candidate in rids {
            hit = self.is_tombstoned(tenant_id, candidate).await;
            if hit {
                // Short-circuit: one tombstoned id is enough.
                break;
            }
        }
        hit
    }
}
/// A [`TombstoneProvider`] that never reports anything as tombstoned.
///
/// Useful when tombstone tracking is disabled: every query answers `false`.
#[derive(Debug, Clone, Copy, Default)]
pub struct NoopTombstoneProvider;

#[async_trait]
impl TombstoneProvider for NoopTombstoneProvider {
    /// Always `false`.
    async fn is_tombstoned(&self, _tenant_id: TenantId, _rid: &str) -> bool {
        false
    }

    /// Always `false`.
    ///
    /// Overridden so callers do not pay one boxed `is_tombstoned` await per id
    /// just to learn that nothing is tombstoned.
    async fn any_tombstoned(&self, _tenant_id: TenantId, _rids: &[String]) -> bool {
        false
    }
}
/// Minimal async key/value cache interface.
///
/// Implementors must be `Send + Sync` so a cache can be shared across tasks.
#[async_trait]
pub trait Cache<K, V>: Send + Sync
where
    K: Send + Sync,
    V: Send + Sync + Clone,
{
    /// Looks up `key`, returning `None` on a miss.
    async fn get(&self, key: &K) -> Option<V>;

    /// Stores `value` under `key`.
    async fn put(&self, key: K, value: V);

    /// Drops the entry stored under `key`, if any.
    async fn invalidate(&self, key: &K);

    /// Drops every entry.
    async fn clear(&self);

    /// Number of entries currently stored.
    async fn len(&self) -> usize;

    /// Returns `true` when the cache holds no entries.
    ///
    /// Defaults to `self.len().await == 0`; override it if the backing store
    /// can answer more cheaply than counting.
    async fn is_empty(&self) -> bool {
        self.len().await == 0
    }
}
/// Implemented by cached values that can report which tenant and which
/// identifiers (`rid`s) they depend on.
///
/// [`TombstoneAwareCache`] uses this to decide whether a cached value is stale.
pub trait RidKeyed {
    /// Identifiers this value depends on; an empty list means none.
    fn rids(&self) -> Vec<String>;

    /// Tenant whose tombstones apply to this value.
    fn tenant_id(&self) -> TenantId;
}
pub struct TombstoneAwareCache<K, V, C>
where
K: Send + Sync,
V: Send + Sync + Clone + RidKeyed,
C: Cache<K, V>,
{
inner: C,
tombstones: Arc<dyn TombstoneProvider>,
_phantom: std::marker::PhantomData<(K, V)>,
}
impl<K, V, C> TombstoneAwareCache<K, V, C>
where
    K: Send + Sync,
    V: Send + Sync + Clone + RidKeyed,
    C: Cache<K, V>,
{
    /// Wraps `inner` so that reads are filtered through `tombstones`.
    ///
    /// Construction does no I/O; the provider is consulted lazily on `get`.
    pub fn new(inner: C, tombstones: Arc<dyn TombstoneProvider>) -> Self {
        let _phantom = std::marker::PhantomData;
        Self {
            tombstones,
            inner,
            _phantom,
        }
    }
}
#[async_trait]
impl<K, V, C> Cache<K, V> for TombstoneAwareCache<K, V, C>
where
    K: Send + Sync,
    V: Send + Sync + Clone + RidKeyed,
    C: Cache<K, V>,
{
    /// Returns the inner cache's value for `key` unless it references a
    /// tombstoned `rid`.
    ///
    /// Values that reference no `rid`s are returned without consulting the
    /// provider. A stale hit is evicted from the inner cache and reported as
    /// `None`.
    ///
    /// NOTE(review): the tombstone check and the eviction are not atomic. If
    /// another task `put`s a fresh value for `key` in between, that fresh value
    /// is evicted as well. The result is an extra miss, not a wrong value.
    async fn get(&self, key: &K) -> Option<V> {
        let hit = self.inner.get(key).await?;
        let referenced = hit.rids();
        // `&&` keeps the provider (and `tenant_id()`) untouched for rid-less values.
        let stale = !referenced.is_empty()
            && self
                .tombstones
                .any_tombstoned(hit.tenant_id(), &referenced)
                .await;
        if stale {
            self.inner.invalidate(key).await;
            None
        } else {
            Some(hit)
        }
    }

    /// Forwards to the inner cache unchanged; staleness is only checked on `get`.
    async fn put(&self, key: K, value: V) {
        self.inner.put(key, value).await
    }

    /// Forwards to the inner cache.
    async fn invalidate(&self, key: &K) {
        self.inner.invalidate(key).await
    }

    /// Forwards to the inner cache.
    async fn clear(&self) {
        self.inner.clear().await
    }

    /// Entry count of the inner cache, including stale entries not yet read.
    async fn len(&self) -> usize {
        self.inner.len().await
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use parking_lot::RwLock;
    use std::collections::HashMap;

    /// In-memory [`Cache`] used as the inner store for the tests below.
    ///
    /// Bounds live on the `impl` blocks, not the struct, so `new` needs none.
    struct MemCache<K, V> {
        inner: RwLock<HashMap<K, V>>,
    }

    impl<K, V> MemCache<K, V> {
        fn new() -> Self {
            Self {
                inner: RwLock::new(HashMap::new()),
            }
        }
    }

    #[async_trait]
    impl<K, V> Cache<K, V> for MemCache<K, V>
    where
        K: Eq + std::hash::Hash + Clone + Send + Sync,
        V: Send + Sync + Clone,
    {
        // Lock guards are never held across an `.await`, so the
        // non-`Send` parking_lot guards are fine here.
        async fn get(&self, key: &K) -> Option<V> {
            self.inner.read().get(key).cloned()
        }
        async fn put(&self, key: K, value: V) {
            self.inner.write().insert(key, value);
        }
        async fn invalidate(&self, key: &K) {
            self.inner.write().remove(key);
        }
        async fn clear(&self) {
            self.inner.write().clear();
        }
        async fn len(&self) -> usize {
            self.inner.read().len()
        }
    }

    /// Cached value under test: records the tenant and rids it was built from.
    #[derive(Clone)]
    struct CachedQueryResult {
        tenant: TenantId,
        rids: Vec<String>,
    }

    impl RidKeyed for CachedQueryResult {
        fn tenant_id(&self) -> TenantId {
            self.tenant
        }
        fn rids(&self) -> Vec<String> {
            self.rids.clone()
        }
    }

    /// Provider backed by a fixed list of `(tenant, rid)` tombstones.
    struct StaticTombstones {
        tombstoned: Vec<(TenantId, String)>,
    }

    #[async_trait]
    impl TombstoneProvider for StaticTombstones {
        async fn is_tombstoned(&self, tenant_id: TenantId, rid: &str) -> bool {
            self.tombstoned
                .iter()
                .any(|(t, r)| *t == tenant_id && r == rid)
        }
    }

    /// Builds a [`CachedQueryResult`] for `tenant` referencing `rids`.
    fn query_result(tenant: TenantId, rids: &[&str]) -> CachedQueryResult {
        CachedQueryResult {
            tenant,
            rids: rids.iter().map(|&rid| rid.to_owned()).collect(),
        }
    }

    /// Provider that reports exactly the given `(tenant, rid)` pairs as tombstoned.
    fn tombstones(entries: &[(TenantId, &str)]) -> Arc<dyn TombstoneProvider> {
        Arc::new(StaticTombstones {
            tombstoned: entries
                .iter()
                .map(|&(tenant, rid)| (tenant, rid.to_owned()))
                .collect(),
        })
    }

    /// Wraps a fresh, empty [`MemCache`] with `provider`.
    fn cache_with(
        provider: Arc<dyn TombstoneProvider>,
    ) -> TombstoneAwareCache<String, CachedQueryResult, MemCache<String, CachedQueryResult>> {
        TombstoneAwareCache::new(MemCache::new(), provider)
    }

    #[tokio::test]
    async fn noop_provider_says_nothing_is_tombstoned() {
        let p = NoopTombstoneProvider;
        for rid in ["a", "b", "c"] {
            assert!(!p.is_tombstoned(TenantId::new(1), rid).await);
        }
    }

    #[tokio::test]
    async fn noop_provider_any_tombstoned_returns_false_for_any_input() {
        let p = NoopTombstoneProvider;
        assert!(
            !p.any_tombstoned(TenantId::new(1), &["a".into(), "b".into(), "c".into()])
                .await
        );
        assert!(!p.any_tombstoned(TenantId::new(1), &[]).await);
    }

    #[tokio::test]
    async fn tombstone_aware_cache_serves_when_no_tombstones() {
        let cache = cache_with(Arc::new(NoopTombstoneProvider));
        let key = "query_42".to_string();
        cache
            .put(key.clone(), query_result(TenantId::new(1), &["mem_a", "mem_b"]))
            .await;
        let got = cache.get(&key).await.expect("entry should be served");
        assert_eq!(got.rids, vec!["mem_a", "mem_b"]);
    }

    #[tokio::test]
    async fn tombstone_aware_cache_misses_on_absent_key() {
        let cache = cache_with(Arc::new(NoopTombstoneProvider));
        assert!(cache.get(&"missing".to_string()).await.is_none());
        assert_eq!(cache.len().await, 0);
    }

    #[tokio::test]
    async fn tombstone_aware_cache_invalidates_when_a_rid_is_tombstoned() {
        let cache = cache_with(tombstones(&[(TenantId::new(1), "mem_b")]));
        let key = "query_with_b".to_string();
        cache
            .put(key.clone(), query_result(TenantId::new(1), &["mem_a", "mem_b"]))
            .await;
        assert!(cache.get(&key).await.is_none());
        // The stale entry was evicted, so a second read is also a miss.
        assert!(cache.get(&key).await.is_none());
        assert_eq!(cache.len().await, 0);
    }

    #[tokio::test]
    async fn tombstone_aware_cache_evicts_only_the_stale_entry() {
        let cache = cache_with(tombstones(&[(TenantId::new(1), "mem_b")]));
        cache
            .put("stale".to_string(), query_result(TenantId::new(1), &["mem_b"]))
            .await;
        cache
            .put("fresh".to_string(), query_result(TenantId::new(1), &["mem_a"]))
            .await;
        assert!(cache.get(&"stale".to_string()).await.is_none());
        assert_eq!(cache.len().await, 1);
        assert!(cache.get(&"fresh".to_string()).await.is_some());
    }

    #[tokio::test]
    async fn tombstone_aware_cache_put_does_not_filter_stale_values() {
        // `put` is a pass-through; staleness is only evaluated lazily on `get`.
        let cache = cache_with(tombstones(&[(TenantId::new(1), "mem_b")]));
        cache
            .put("k".to_string(), query_result(TenantId::new(1), &["mem_b"]))
            .await;
        assert_eq!(cache.len().await, 1);
    }

    #[tokio::test]
    async fn tombstone_aware_cache_serves_when_tombstone_is_for_other_tenant() {
        let cache = cache_with(tombstones(&[(TenantId::new(2), "mem_b")]));
        cache
            .put("k".to_string(), query_result(TenantId::new(1), &["mem_a", "mem_b"]))
            .await;
        assert!(cache.get(&"k".to_string()).await.is_some());
    }

    #[tokio::test]
    async fn tombstone_aware_cache_skips_check_for_empty_rids() {
        let cache = cache_with(tombstones(&[(TenantId::new(1), "anything")]));
        cache
            .put("k".to_string(), query_result(TenantId::new(1), &[]))
            .await;
        assert!(cache.get(&"k".to_string()).await.is_some());
    }

    #[tokio::test]
    async fn tombstone_aware_cache_passes_through_put_invalidate_clear() {
        let cache = cache_with(Arc::new(NoopTombstoneProvider));
        let v = query_result(TenantId::new(1), &["x"]);
        cache.put("k1".to_string(), v.clone()).await;
        cache.put("k2".to_string(), v).await;
        assert_eq!(cache.len().await, 2);
        cache.invalidate(&"k1".to_string()).await;
        assert_eq!(cache.len().await, 1);
        cache.clear().await;
        assert_eq!(cache.len().await, 0);
    }

    #[tokio::test]
    async fn dyn_dispatch_works() {
        let provider: Arc<dyn TombstoneProvider> = Arc::new(NoopTombstoneProvider);
        assert!(!provider.is_tombstoned(TenantId::new(1), "x").await);
    }
}