use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::time::Duration;
use web_time::SystemTime;
use async_trait::async_trait;
use thiserror::Error;
use crate::data::DataTable;
/// A cached [`DataTable`] plus the bookkeeping used to decide whether it is
/// still fresh.
#[derive(Clone, Debug)]
pub struct CachedEntry {
    /// The cached payload.
    pub data: DataTable,
    /// When the payload was fetched; the entry's age is measured from here.
    pub fetched_at: SystemTime,
    /// How long after `fetched_at` the entry counts as fresh
    /// (see [`CachedEntry::is_expired`]).
    pub ttl: Duration,
    /// Labels matched by [`CacheBackend::invalidate_by_tag`].
    pub tags: Vec<String>,
    /// Arbitrary JSON metadata attached to the entry.
    pub metadata: HashMap<String, serde_json::Value>,
}
impl CachedEntry {
    /// Returns `true` once the entry is strictly older than its `ttl`.
    ///
    /// An age exactly equal to `ttl` still counts as fresh. If `fetched_at`
    /// lies in the future (e.g. the wall clock moved backwards), the entry is
    /// conservatively reported as expired.
    pub fn is_expired(&self) -> bool {
        match SystemTime::now().duration_since(self.fetched_at) {
            Ok(elapsed) => elapsed > self.ttl,
            // Clock skew: we cannot tell how old the entry is, so treat it as stale.
            Err(_) => true,
        }
    }

    /// Time elapsed since `fetched_at`.
    ///
    /// Saturates to [`Duration::ZERO`] when `fetched_at` is in the future. In
    /// that case this reports zero for an entry that
    /// [`is_expired`](Self::is_expired) considers expired.
    pub fn age(&self) -> Duration {
        let now = SystemTime::now();
        now.duration_since(self.fetched_at).unwrap_or_default()
    }
}
#[derive(Debug, Error, Clone)]
pub enum CacheError {
#[error("cache backend error: {0}")]
Backend(String),
}
/// Async storage backend for [`CachedEntry`] values keyed by a `u64`.
///
/// Native (non-`wasm32`) declaration: implementors must be `Send + Sync`, and the
/// methods are expanded by `#[async_trait]`. The `wasm32` declaration of this
/// trait has the same methods but uses `#[async_trait(?Send)]` and drops the
/// `Send + Sync` supertraits. Keep both declarations in sync.
#[cfg(not(target_arch = "wasm32"))]
#[async_trait]
pub trait CacheBackend: Send + Sync {
/// Looks up the entry stored under `key`, returning `None` if there is none.
///
/// NOTE(review): `MemoryBackend` returns stored entries without checking
/// [`CachedEntry::is_expired`]; presumably callers check expiry themselves —
/// confirm for other backends.
async fn get(&self, key: u64) -> Option<CachedEntry>;
/// Stores `entry` under `key`.
///
/// # Errors
/// Returns a [`CacheError`] if the backend fails to store the entry.
async fn put(&self, key: u64, entry: CachedEntry) -> Result<(), CacheError>;
/// Invalidates the entry stored under `key` (`MemoryBackend` removes it;
/// a missing key is not an error there).
async fn invalidate(&self, key: u64) -> Result<(), CacheError>;
/// Invalidates entries labelled with `tag` (`MemoryBackend` removes every
/// entry whose `tags` contain it).
async fn invalidate_by_tag(&self, tag: &str) -> Result<(), CacheError>;
/// Invalidates all entries.
async fn clear(&self) -> Result<(), CacheError>;
/// Hook for releasing backend resources; the default implementation does nothing.
async fn shutdown(&self) {}
}
/// Async storage backend for [`CachedEntry`] values keyed by a `u64`.
///
/// `wasm32` declaration: expanded with `#[async_trait(?Send)]` and without the
/// `Send + Sync` supertraits that the native declaration of this trait requires.
/// Keep the method set identical to the native declaration.
#[cfg(target_arch = "wasm32")]
#[async_trait(?Send)]
pub trait CacheBackend {
/// Looks up the entry stored under `key`, returning `None` if there is none.
///
/// NOTE(review): `MemoryBackend` returns stored entries without checking
/// [`CachedEntry::is_expired`]; presumably callers check expiry themselves —
/// confirm for other backends.
async fn get(&self, key: u64) -> Option<CachedEntry>;
/// Stores `entry` under `key`.
///
/// # Errors
/// Returns a [`CacheError`] if the backend fails to store the entry.
async fn put(&self, key: u64, entry: CachedEntry) -> Result<(), CacheError>;
/// Invalidates the entry stored under `key` (`MemoryBackend` removes it;
/// a missing key is not an error there).
async fn invalidate(&self, key: u64) -> Result<(), CacheError>;
/// Invalidates entries labelled with `tag` (`MemoryBackend` removes every
/// entry whose `tags` contain it).
async fn invalidate_by_tag(&self, tag: &str) -> Result<(), CacheError>;
/// Invalidates all entries.
async fn clear(&self) -> Result<(), CacheError>;
/// Hook for releasing backend resources; the default implementation does nothing.
async fn shutdown(&self) {}
}
/// In-process [`CacheBackend`] that keeps entries in a `HashMap` guarded by a
/// `std::sync::Mutex`.
///
/// The map sits behind an `Arc`, so clones of a `MemoryBackend` share the same
/// entries.
#[derive(Clone, Debug, Default)]
pub struct MemoryBackend {
    /// Shared entry table, keyed by cache key.
    inner: Arc<Mutex<HashMap<u64, CachedEntry>>>,
}
impl MemoryBackend {
pub fn new() -> Self {
Self {
inner: Arc::new(Mutex::new(HashMap::new())),
}
}
pub fn len(&self) -> usize {
self.inner
.lock()
.expect("memory cache lock poisoned")
.len()
}
pub fn is_empty(&self) -> bool {
self.len() == 0
}
}
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
impl CacheBackend for MemoryBackend {
async fn get(&self, key: u64) -> Option<CachedEntry> {
let guard = self.inner.lock().expect("memory cache lock poisoned");
guard.get(&key).cloned()
}
async fn put(&self, key: u64, entry: CachedEntry) -> Result<(), CacheError> {
let mut guard = self.inner.lock().expect("memory cache lock poisoned");
guard.insert(key, entry);
Ok(())
}
async fn invalidate(&self, key: u64) -> Result<(), CacheError> {
let mut guard = self.inner.lock().expect("memory cache lock poisoned");
guard.remove(&key);
Ok(())
}
async fn invalidate_by_tag(&self, tag: &str) -> Result<(), CacheError> {
let mut guard = self.inner.lock().expect("memory cache lock poisoned");
guard.retain(|_, entry| !entry.tags.iter().any(|t| t == tag));
Ok(())
}
async fn clear(&self) -> Result<(), CacheError> {
let mut guard = self.inner.lock().expect("memory cache lock poisoned");
guard.clear();
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::Row;
    use serde_json::json;

    /// Builds a one-row entry fetched "now" with a 60-second TTL and the given tags.
    fn make_entry(tags: Vec<&str>) -> CachedEntry {
        let row: Row = [("x".to_string(), json!(1.0))].into_iter().collect();
        CachedEntry {
            data: DataTable::from_rows(&[row]).unwrap(),
            fetched_at: SystemTime::now(),
            ttl: Duration::from_secs(60),
            tags: tags.into_iter().map(String::from).collect(),
            metadata: HashMap::new(),
        }
    }

    #[tokio::test]
    async fn memory_backend_get_put_roundtrip() {
        let backend = MemoryBackend::new();
        backend.put(1, make_entry(vec![])).await.unwrap();
        let got = backend.get(1).await;
        assert!(got.is_some());
        assert_eq!(backend.len(), 1);
    }

    #[tokio::test]
    async fn memory_backend_invalidate_single() {
        let backend = MemoryBackend::new();
        backend.put(1, make_entry(vec![])).await.unwrap();
        backend.put(2, make_entry(vec![])).await.unwrap();
        backend.invalidate(1).await.unwrap();
        assert!(backend.get(1).await.is_none());
        assert!(backend.get(2).await.is_some());
    }

    #[tokio::test]
    async fn memory_backend_invalidate_by_tag() {
        let backend = MemoryBackend::new();
        backend.put(1, make_entry(vec!["slug:foo"])).await.unwrap();
        backend.put(2, make_entry(vec!["slug:foo"])).await.unwrap();
        backend.put(3, make_entry(vec!["slug:bar"])).await.unwrap();
        backend.invalidate_by_tag("slug:foo").await.unwrap();
        assert!(backend.get(1).await.is_none());
        assert!(backend.get(2).await.is_none());
        assert!(backend.get(3).await.is_some());
    }

    #[tokio::test]
    async fn memory_backend_clear() {
        let backend = MemoryBackend::new();
        backend.put(1, make_entry(vec![])).await.unwrap();
        backend.put(2, make_entry(vec![])).await.unwrap();
        backend.clear().await.unwrap();
        assert_eq!(backend.len(), 0);
        assert!(backend.is_empty());
    }

    // Nothing here is async, so a plain `#[test]` suffices. Backdating
    // `fetched_at` replaces a blocking `std::thread::sleep`, which must not run
    // on an async runtime thread, and makes the test instantaneous.
    #[test]
    fn cached_entry_expiry() {
        let mut entry = make_entry(vec![]);
        entry.ttl = Duration::from_millis(0);
        entry.fetched_at = SystemTime::now() - Duration::from_secs(1);
        assert!(entry.is_expired());
    }

    #[test]
    fn cached_entry_fresh_is_not_expired() {
        let entry = make_entry(vec![]);
        assert!(!entry.is_expired());
        assert!(entry.age() <= entry.ttl);
    }

    // Pins the clock-skew behaviour: a future timestamp counts as expired,
    // while `age()` saturates to zero.
    #[test]
    fn cached_entry_future_timestamp_counts_as_expired() {
        let mut entry = make_entry(vec![]);
        entry.fetched_at = SystemTime::now() + Duration::from_secs(3600);
        assert!(entry.is_expired());
        assert_eq!(entry.age(), Duration::ZERO);
    }
}