use std::num::NonZeroUsize;
use std::sync::{Arc, Mutex};
use async_trait::async_trait;
use sha2::{Digest, Sha256};
use rustvello_proto::config::ClientDataStoreConfig;
use crate::error::{RustvelloError, RustvelloResult};
/// Prefix that marks a string as a client-data-store reference key.
///
/// [`generate_reference_key`] prepends it to the content hash and
/// [`is_reference`] tests for it.
pub const REFERENCE_PREFIX: &str = "__rustvello__cds__:";
/// Storage backend for values kept outside the inline payload.
///
/// Implementations must be `Send + Sync` so they can be shared behind an
/// `Arc<dyn ClientDataStore>`.
#[async_trait]
pub trait ClientDataStore: Send + Sync {
    /// Persists `value` under `key`.
    async fn store(&self, key: &str, value: &str) -> RustvelloResult<()>;

    /// Fetches the value previously stored under `key`.
    async fn retrieve(&self, key: &str) -> RustvelloResult<String>;

    /// Removes the data held by this backend.
    async fn purge(&self) -> RustvelloResult<()>;

    /// Human-readable backend name; `"Unknown"` unless overridden.
    fn backend_name(&self) -> &'static str {
        "Unknown"
    }

    /// Backend-specific `(label, value)` statistics; empty unless overridden.
    async fn usage_stats(&self) -> Vec<(&'static str, String)> {
        vec![]
    }
}
/// Returns `true` when `value` begins with [`REFERENCE_PREFIX`].
pub fn is_reference(value: &str) -> bool {
    value.strip_prefix(REFERENCE_PREFIX).is_some()
}
/// Builds the reference key for `data`.
///
/// The key is [`REFERENCE_PREFIX`] followed by the lowercase hexadecimal
/// SHA-256 digest of the UTF-8 bytes of `data`, so equal content always maps
/// to the same key.
pub fn generate_reference_key(data: &str) -> String {
    let mut hasher = Sha256::new();
    hasher.update(data.as_bytes());
    let digest = hasher.finalize();
    format!("{REFERENCE_PREFIX}{digest:x}")
}
/// Front end over a [`ClientDataStore`] backend that decides, by size, whether
/// a value is kept inline or stored in the backend, and keeps a local LRU
/// cache of stored values.
pub struct ClientDataStoreManager {
/// Backend that persists stored values.
backend: Arc<dyn ClientDataStore>,
/// Settings consulted by the manager (`disabled`, size thresholds, cache size).
config: ClientDataStoreConfig,
/// Local LRU cache from reference key to value, guarded by a mutex.
cache: Mutex<lru::LruCache<String, String>>,
}
impl ClientDataStoreManager {
    /// Creates a manager over `backend` using `config`.
    ///
    /// The local LRU cache holds `config.local_cache_size` entries; a size of
    /// zero is replaced by the minimum capacity of one entry, because the
    /// cache requires a non-zero capacity.
    pub fn new(backend: Arc<dyn ClientDataStore>, config: ClientDataStoreConfig) -> Self {
        let capacity = NonZeroUsize::new(config.local_cache_size).unwrap_or(NonZeroUsize::MIN);
        Self {
            backend,
            config,
            cache: Mutex::new(lru::LruCache::new(capacity)),
        }
    }

    /// Locks the local cache, mapping a poisoned mutex to
    /// [`RustvelloError::Internal`].
    ///
    /// Callers must drop the returned guard before any `.await`.
    fn lock_cache(
        &self,
    ) -> RustvelloResult<std::sync::MutexGuard<'_, lru::LruCache<String, String>>> {
        self.cache.lock().map_err(|e| RustvelloError::Internal {
            message: format!("CDS cache lock poisoned: {e}"),
        })
    }

    /// Stores `serialized` in the backend when it is large enough and returns
    /// what should be used in its place.
    ///
    /// The input is returned unchanged when the store is disabled, when the
    /// input is already a reference key, when it is shorter than
    /// `min_size_to_cache` bytes, or when `max_size_to_cache` is non-zero and
    /// the input is longer than it (a warning is logged in that last case).
    /// Otherwise the value is written to the backend and the local cache under
    /// its content-derived key, and that key is returned. A warning is also
    /// logged for stored values longer than `warn_threshold`.
    ///
    /// # Errors
    ///
    /// Returns the backend's error if the store fails, or
    /// [`RustvelloError::Internal`] if the cache mutex is poisoned.
    pub async fn store_if_large(&self, serialized: &str) -> RustvelloResult<String> {
        let size = serialized.len();
        if self.config.disabled
            || is_reference(serialized)
            || size < self.config.min_size_to_cache
        {
            return Ok(serialized.to_owned());
        }
        // A `max_size_to_cache` of zero means "no upper limit".
        if self.config.max_size_to_cache > 0 && size > self.config.max_size_to_cache {
            tracing::warn!(
                "Value size ({size} bytes) exceeds max_size_to_cache ({}). Returning inline.",
                self.config.max_size_to_cache
            );
            return Ok(serialized.to_owned());
        }
        if size > self.config.warn_threshold {
            tracing::warn!(
                "Value size ({size} bytes) exceeds warn_threshold ({}). Consider restructuring.",
                self.config.warn_threshold
            );
        }

        let key = generate_reference_key(serialized);
        self.backend.store(&key, serialized).await?;
        // The guard is a temporary of this statement, so it is released before returning.
        self.lock_cache()?.put(key.clone(), serialized.to_owned());
        Ok(key)
    }

    /// Resolves `data` to the value it stands for.
    ///
    /// Non-reference input is returned as-is. A reference key is looked up in
    /// the local cache first and in the backend on a miss; a value fetched
    /// from the backend is added to the cache.
    ///
    /// # Errors
    ///
    /// Returns the backend's error if retrieval fails, or
    /// [`RustvelloError::Internal`] if the cache mutex is poisoned.
    pub async fn resolve(&self, data: &str) -> RustvelloResult<String> {
        if !is_reference(data) {
            return Ok(data.to_owned());
        }
        // Scoped so the guard is dropped before awaiting the backend.
        {
            let mut cache = self.lock_cache()?;
            if let Some(cached) = cache.get(data) {
                return Ok(cached.clone());
            }
        }
        let value = self.backend.retrieve(data).await?;
        self.lock_cache()?.put(data.to_owned(), value.clone());
        Ok(value)
    }

    /// Writes `value` under `key` directly to the backend, without touching
    /// the local cache.
    pub async fn store(&self, key: &str, value: &str) -> RustvelloResult<()> {
        self.backend.store(key, value).await
    }

    /// Reads `key` directly from the backend, without consulting the local
    /// cache.
    pub async fn retrieve(&self, key: &str) -> RustvelloResult<String> {
        self.backend.retrieve(key).await
    }

    /// Clears the local cache, then purges the backend.
    ///
    /// # Errors
    ///
    /// Returns [`RustvelloError::Internal`] if the cache mutex is poisoned (the
    /// backend is then not purged), or the backend's error if its purge fails.
    pub async fn purge(&self) -> RustvelloResult<()> {
        self.lock_cache()?.clear();
        self.backend.purge().await
    }

    /// Returns the configuration this manager was built with.
    pub fn config(&self) -> &ClientDataStoreConfig {
        &self.config
    }

    /// Returns the backend's name.
    pub fn backend_name(&self) -> &'static str {
        self.backend.backend_name()
    }

    /// Returns the backend's usage statistics.
    pub async fn usage_stats(&self) -> Vec<(&'static str, String)> {
        self.backend.usage_stats().await
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::RustvelloError;
    use std::collections::HashMap;

    /// In-memory `ClientDataStore` backed by a mutex-guarded `HashMap`.
    struct FakeBackend {
        data: Mutex<HashMap<String, String>>,
    }

    impl FakeBackend {
        fn new() -> Self {
            Self {
                data: Mutex::default(),
            }
        }
    }

    #[async_trait]
    impl ClientDataStore for FakeBackend {
        async fn store(&self, key: &str, value: &str) -> RustvelloResult<()> {
            let (owned_key, owned_value) = (key.to_owned(), value.to_owned());
            self.data.lock().unwrap().insert(owned_key, owned_value);
            Ok(())
        }

        async fn retrieve(&self, key: &str) -> RustvelloResult<String> {
            let found = self.data.lock().unwrap().get(key).cloned();
            found.ok_or_else(|| RustvelloError::state_backend(format!("key not found: {key}")))
        }

        async fn purge(&self) -> RustvelloResult<()> {
            self.data.lock().unwrap().clear();
            Ok(())
        }
    }

    /// Builds a manager over a fresh `FakeBackend`.
    fn make_manager(config: ClientDataStoreConfig) -> ClientDataStoreManager {
        let backend: Arc<dyn ClientDataStore> = Arc::new(FakeBackend::new());
        ClientDataStoreManager::new(backend, config)
    }

    /// Default configuration with only `min_size_to_cache` overridden.
    fn config_with_min_size(min_size: usize) -> ClientDataStoreConfig {
        let mut config = ClientDataStoreConfig::default();
        config.min_size_to_cache = min_size;
        config
    }

    #[test]
    fn reference_key_format() {
        let generated = generate_reference_key("hello world");
        assert!(generated.starts_with(REFERENCE_PREFIX));
        assert!(is_reference(&generated));
        assert!(!is_reference("just a normal string"));
    }

    #[test]
    fn reference_key_deterministic() {
        let first = generate_reference_key("same content");
        let second = generate_reference_key("same content");
        assert_eq!(first, second);
    }

    #[test]
    fn reference_key_differs_for_different_content() {
        let key_a = generate_reference_key("content A");
        let key_b = generate_reference_key("content B");
        assert_ne!(key_a, key_b);
    }

    #[tokio::test]
    async fn inline_below_threshold() {
        let manager = make_manager(config_with_min_size(100));
        let payload = "short";
        let outcome = manager.store_if_large(payload).await.unwrap();
        assert_eq!(outcome, payload);
        assert!(!is_reference(&outcome));
    }

    #[tokio::test]
    async fn externalize_above_threshold() {
        let manager = make_manager(config_with_min_size(10));
        let payload = "a]".repeat(20);
        let reference = manager.store_if_large(&payload).await.unwrap();
        assert!(is_reference(&reference));
        let round_trip = manager.resolve(&reference).await.unwrap();
        assert_eq!(round_trip, payload);
    }

    #[tokio::test]
    async fn inline_when_disabled() {
        // A tiny minimum size shows that `disabled` wins over the size check.
        let mut config = config_with_min_size(1);
        config.disabled = true;
        let manager = make_manager(config);
        let payload = "x".repeat(100);
        let outcome = manager.store_if_large(&payload).await.unwrap();
        assert!(!is_reference(&outcome));
        assert_eq!(outcome, payload);
    }

    #[tokio::test]
    async fn inline_above_max_size() {
        let mut config = config_with_min_size(10);
        config.max_size_to_cache = 50;
        let manager = make_manager(config);
        let payload = "x".repeat(100);
        let outcome = manager.store_if_large(&payload).await.unwrap();
        assert!(!is_reference(&outcome));
        assert_eq!(outcome, payload);
    }

    #[tokio::test]
    async fn resolve_inline_passthrough() {
        let manager = make_manager(ClientDataStoreConfig::default());
        let payload = "not a reference";
        let outcome = manager.resolve(payload).await.unwrap();
        assert_eq!(outcome, payload);
    }

    #[tokio::test]
    async fn content_hash_dedup() {
        let manager = make_manager(config_with_min_size(5));
        let payload = "deduplicate me";
        let first = manager.store_if_large(payload).await.unwrap();
        let second = manager.store_if_large(payload).await.unwrap();
        assert_eq!(first, second);
    }

    #[tokio::test]
    async fn purge_clears_cache_and_backend() {
        let manager = make_manager(config_with_min_size(5));
        let payload = "some large payload here";
        let reference = manager.store_if_large(payload).await.unwrap();
        assert!(is_reference(&reference));
        manager.purge().await.unwrap();
        let after_purge = manager.resolve(&reference).await;
        assert!(after_purge.is_err());
    }

    #[tokio::test]
    async fn lru_cache_hit() {
        let mut config = config_with_min_size(5);
        config.local_cache_size = 10;
        let manager = make_manager(config);
        let payload = "cached payload data";
        let reference = manager.store_if_large(payload).await.unwrap();
        assert!(is_reference(&reference));
        let before = manager.resolve(&reference).await.unwrap();
        assert_eq!(before, payload);
        // Purge only the backend: the second resolve must be served by the local cache.
        manager.backend.purge().await.unwrap();
        let after = manager.resolve(&reference).await.unwrap();
        assert_eq!(after, payload);
    }

    #[tokio::test]
    async fn passthrough_existing_reference() {
        let manager = make_manager(config_with_min_size(5));
        let existing = format!("{REFERENCE_PREFIX}abc123");
        let outcome = manager.store_if_large(&existing).await.unwrap();
        assert_eq!(outcome, existing);
    }

    #[test]
    fn lru_cache_eviction() {
        let capacity = NonZeroUsize::new(3).unwrap();
        let mut cache = lru::LruCache::<String, String>::new(capacity);
        for (k, v) in [("a", "1"), ("b", "2"), ("c", "3")] {
            cache.put(k.into(), v.into());
        }
        assert_eq!(cache.len(), 3);
        cache.put("d".into(), "4".into());
        assert_eq!(cache.len(), 3);
        assert!(cache.get("a").is_none());
        assert!(cache.get("d").is_some());
    }

    #[test]
    fn lru_cache_access_refreshes() {
        let capacity = NonZeroUsize::new(3).unwrap();
        let mut cache = lru::LruCache::<String, String>::new(capacity);
        for (k, v) in [("a", "1"), ("b", "2"), ("c", "3")] {
            cache.put(k.into(), v.into());
        }
        // Touch "a" so that "b" becomes the least recently used entry.
        let _ = cache.get("a");
        cache.put("d".into(), "4".into());
        assert!(cache.get("a").is_some());
        assert!(cache.get("b").is_none());
    }
}