pub mod config;
pub use config::CacheConfig;
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::time::Duration;
use moka::future::Cache as MokaCache;
/// Async, in-memory cache of JSON-serialized values, namespaced by source.
pub struct Cache {
    // Underlying moka cache. Keys are "{source}:{key}"; values are the
    // JSON-encoded bytes, wrapped in Arc so reads clone cheaply.
    inner: MokaCache<String, Arc<Vec<u8>>>,
    // Configuration the cache was built with (capacity, TTLs).
    config: CacheConfig,
    // Maps each source to the full keys inserted under it, enabling
    // source-wide invalidation. std Mutex — must not be held across an await.
    source_keys: Mutex<HashMap<String, Vec<String>>>,
}
impl Cache {
    /// Creates a cache with the given configuration.
    ///
    /// Capacity and the *default* TTL come from `config`.
    /// NOTE(review): the per-source TTLs in `config.source_ttls` are never
    /// applied to entries (see the note in `set`); supporting them would
    /// require building the moka cache with an `Expiry` policy here.
    #[must_use]
    pub fn new(config: CacheConfig) -> Self {
        let inner = MokaCache::builder()
            .max_capacity(config.max_capacity)
            .time_to_live(Duration::from_secs(config.default_ttl_secs))
            .build();
        Self {
            inner,
            config,
            source_keys: Mutex::new(HashMap::new()),
        }
    }

    /// Creates a cache configured via `CacheConfig::from_cascade`.
    #[must_use]
    pub fn from_cascade() -> Self {
        Self::new(CacheConfig::from_cascade())
    }

    /// Looks up `key` under `source` and deserializes the cached JSON bytes.
    ///
    /// Returns `None` on a cache miss *or* when the stored bytes fail to
    /// deserialize into `T` — callers cannot distinguish the two (a
    /// deserialization failure is still counted as a hit in metrics).
    pub async fn get<T: serde::de::DeserializeOwned>(&self, source: &str, key: &str) -> Option<T> {
        // Keys are namespaced per source so different sources cannot collide.
        let full_key = format!("{source}:{key}");
        let bytes = self.inner.get(&full_key).await;
        #[cfg(feature = "metrics")]
        if bytes.is_some() {
            metrics::counter!("dfe_cache_hits_total", "source" => source.to_string()).increment(1);
        } else {
            metrics::counter!("dfe_cache_misses_total", "source" => source.to_string())
                .increment(1);
        }
        let bytes = bytes?;
        serde_json::from_slice(&bytes).ok()
    }

    /// Serializes `value` to JSON and stores it under `source`/`key`.
    ///
    /// A serialization failure is silently ignored (the entry is simply not
    /// cached). The full key is recorded in `source_keys` so that
    /// `invalidate_source` can later remove every entry for the source.
    pub async fn set<T: serde::Serialize>(&self, source: &str, key: &str, value: T) {
        let full_key = format!("{source}:{key}");
        let bytes = match serde_json::to_vec(&value) {
            Ok(b) => Arc::new(b),
            Err(_) => return,
        };
        // NOTE(review): this per-source TTL is computed but cannot be applied:
        // moka's `future::Cache` has no per-insert TTL, so every entry uses
        // the builder-level default from `new`. TODO: apply `source_ttls` via
        // an `Expiry` policy on the builder.
        let _ttl = self.ttl_for_source(source);
        self.inner.insert(full_key.clone(), bytes).await;
        #[cfg(feature = "metrics")]
        metrics::gauge!("dfe_cache_entries").set(self.inner.entry_count() as f64);
        // Track the key for source-wide invalidation. Dedup so repeated `set`
        // calls for the same key do not grow the tracking vec without bound.
        // A poisoned lock degrades to best-effort: tracking is skipped.
        if let Ok(mut keys) = self.source_keys.lock() {
            let tracked = keys.entry(source.to_string()).or_default();
            if !tracked.contains(&full_key) {
                tracked.push(full_key);
            }
        }
    }

    /// Removes every cached entry that was stored under `source`.
    pub async fn invalidate_source(&self, source: &str) {
        // Take the tracked keys out while holding the lock, then invalidate
        // outside of it — a std Mutex guard must not live across an await.
        let keys = {
            let Ok(mut source_keys) = self.source_keys.lock() else {
                return;
            };
            source_keys.remove(source).unwrap_or_default()
        };
        for key in keys {
            self.inner.invalidate(&key).await;
        }
        #[cfg(feature = "metrics")]
        metrics::gauge!("dfe_cache_entries").set(self.inner.entry_count() as f64);
    }

    /// Removes a single cached entry.
    pub async fn invalidate(&self, source: &str, key: &str) {
        let full_key = format!("{source}:{key}");
        self.inner.invalidate(&full_key).await;
        // Keep the bookkeeping consistent with `invalidate_source`: drop the
        // key from `source_keys` so the tracking map does not accumulate
        // stale entries. Lock is taken after the await, never across it.
        if let Ok(mut keys) = self.source_keys.lock() {
            if let Some(tracked) = keys.get_mut(source) {
                tracked.retain(|k| k != &full_key);
            }
        }
    }

    /// TTL for entries from `source`, falling back to the default TTL when
    /// no per-source override is configured.
    ///
    /// NOTE(review): only consulted by `set`, which currently cannot apply
    /// it (see the note there).
    fn ttl_for_source(&self, source: &str) -> Duration {
        self.config.source_ttls.get(source).copied().map_or(
            Duration::from_secs(self.config.default_ttl_secs),
            Duration::from_secs,
        )
    }

    /// Number of entries currently in the cache. May lag behind recent
    /// inserts/invalidations until moka's pending maintenance tasks run
    /// (the tests call `run_pending_tasks` before asserting on this).
    #[must_use]
    pub fn entry_count(&self) -> u64 {
        self.inner.entry_count()
    }

    /// The configuration this cache was built with.
    #[must_use]
    pub fn config(&self) -> &CacheConfig {
        &self.config
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Small config with per-source TTL overrides, shared by every test.
    fn test_config() -> CacheConfig {
        let mut source_ttls = HashMap::new();
        source_ttls.insert("http".into(), 3600);
        source_ttls.insert("db".into(), 1800);
        CacheConfig {
            max_capacity: 100,
            default_ttl_secs: 60,
            source_ttls,
        }
    }

    #[tokio::test]
    async fn set_and_get() {
        let cache = Cache::new(test_config());
        cache.set("http", "url1", "value1".to_string()).await;
        let fetched: Option<String> = cache.get("http", "url1").await;
        assert_eq!(fetched.as_deref(), Some("value1"));
    }

    #[tokio::test]
    async fn get_missing_returns_none() {
        let cache = Cache::new(test_config());
        let missing: Option<String> = cache.get("http", "nonexistent").await;
        assert!(missing.is_none());
    }

    #[tokio::test]
    async fn sources_are_isolated() {
        let cache = Cache::new(test_config());
        cache.set("http", "key1", "http_value".to_string()).await;
        cache.set("db", "key1", "db_value".to_string()).await;
        // The same short key under two sources must resolve independently.
        let from_http: Option<String> = cache.get("http", "key1").await;
        let from_db: Option<String> = cache.get("db", "key1").await;
        assert_eq!(from_http.as_deref(), Some("http_value"));
        assert_eq!(from_db.as_deref(), Some("db_value"));
    }

    #[tokio::test]
    async fn invalidate_source_removes_only_that_source() {
        let cache = Cache::new(test_config());
        cache.set("http", "url1", "v1".to_string()).await;
        cache.set("http", "url2", "v2".to_string()).await;
        cache.set("db", "query1", "v3".to_string()).await;

        cache.invalidate_source("http").await;
        // Force moka to apply pending invalidations before asserting.
        cache.inner.run_pending_tasks().await;

        let http1: Option<String> = cache.get("http", "url1").await;
        let http2: Option<String> = cache.get("http", "url2").await;
        let db1: Option<String> = cache.get("db", "query1").await;
        assert!(http1.is_none(), "http url1 should be invalidated");
        assert!(http2.is_none(), "http url2 should be invalidated");
        assert_eq!(db1.as_deref(), Some("v3"), "db should be preserved");
    }

    #[tokio::test]
    async fn invalidate_single_entry() {
        let cache = Cache::new(test_config());
        cache.set("http", "url1", "v1".to_string()).await;
        cache.set("http", "url2", "v2".to_string()).await;

        cache.invalidate("http", "url1").await;
        cache.inner.run_pending_tasks().await;

        let removed: Option<String> = cache.get("http", "url1").await;
        let kept: Option<String> = cache.get("http", "url2").await;
        assert!(removed.is_none());
        assert_eq!(kept.as_deref(), Some("v2"));
    }

    #[tokio::test]
    async fn entry_count() {
        let cache = Cache::new(test_config());
        assert_eq!(cache.entry_count(), 0);
        cache.set("http", "url1", "v1".to_string()).await;
        cache.set("http", "url2", "v2".to_string()).await;
        cache.inner.run_pending_tasks().await;
        assert_eq!(cache.entry_count(), 2);
    }

    #[tokio::test]
    async fn complex_types() {
        let cache = Cache::new(test_config());
        let payload = serde_json::json!({"name": "test", "values": [1, 2, 3]});
        cache.set("db", "query1", payload.clone()).await;
        let roundtripped: Option<serde_json::Value> = cache.get("db", "query1").await;
        assert_eq!(roundtripped, Some(payload));
    }
}