use async_trait::async_trait;
use super::types::{CacheKey, CacheStats, ContextFingerprint, KVCacheEntry};
use crate::error::OxiRagError;
/// Storage backend for [`KVCacheEntry`] values looked up by [`ContextFingerprint`].
///
/// Implementors must be `Send + Sync` so a store can be shared across async tasks.
#[async_trait]
pub trait PrefixCacheStore: Send + Sync {
    /// Returns the entry cached for `fingerprint` (as an owned value), or `None` if
    /// there is no such entry.
    async fn get(&self, fingerprint: &ContextFingerprint) -> Option<KVCacheEntry>;

    /// Inserts `entry` into the store and returns the [`CacheKey`] it was stored under.
    ///
    /// # Errors
    ///
    /// Returns an [`OxiRagError`] if the implementation cannot store the entry.
    async fn put(&mut self, entry: KVCacheEntry) -> Result<CacheKey, OxiRagError>;

    /// Removes the entry stored under `key`, returning it if it was present.
    async fn remove(&mut self, key: &CacheKey) -> Option<KVCacheEntry>;

    /// Reports whether an entry for `fingerprint` is currently cached.
    async fn contains(&self, fingerprint: &ContextFingerprint) -> bool;

    /// Removes every entry from the store.
    async fn clear(&mut self);

    /// Returns the store's statistics.
    fn stats(&self) -> CacheStats;

    /// Returns the number of entries currently held.
    fn len(&self) -> usize;

    /// Returns `true` when the store holds no entries.
    ///
    /// The default is derived from [`len`](Self::len), so implementors only need to
    /// override it when they can answer more cheaply.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Looks for a cached entry usable for `fingerprint` via prefix matching rather
    /// than exact lookup.
    ///
    /// The default implementation does no prefix matching and always returns `None`.
    /// Stores that support prefix reuse should override it.
    async fn find_prefix_match(&self, fingerprint: &ContextFingerprint) -> Option<KVCacheEntry> {
        // Explicitly consume the argument so the default body does not warn about it.
        let _ = fingerprint;
        None
    }

    /// Evicts expired entries and returns a count.
    ///
    /// NOTE(review): presumably the number of entries evicted; confirm against the
    /// implementations.
    async fn evict_expired(&mut self) -> usize;

    /// Reports the store's memory usage.
    ///
    /// NOTE(review): the unit is not fixed by this trait (presumably bytes); confirm
    /// against the implementations.
    fn memory_usage(&self) -> usize;
}
/// Convenience operations layered on top of a [`PrefixCacheStore`].
#[async_trait]
pub trait PrefixCacheExt: PrefixCacheStore {
    /// Returns the entry cached for `fingerprint`, or runs `compute` to build one,
    /// inserts it into the store, and returns it.
    ///
    /// `compute` is only invoked on a cache miss.
    ///
    /// # Errors
    ///
    /// Propagates any error produced by `compute` or by inserting the freshly
    /// computed entry.
    async fn get_or_compute<F: FnOnce() -> Result<KVCacheEntry, OxiRagError> + Send>(
        &mut self,
        fingerprint: &ContextFingerprint,
        compute: F,
    ) -> Result<KVCacheEntry, OxiRagError>;
}
/// Blanket implementation giving every [`PrefixCacheStore`] the [`PrefixCacheExt`]
/// helpers, including unsized trait objects such as `dyn PrefixCacheStore`.
///
/// No extra `Send` bound is needed: `PrefixCacheStore` already requires `Send + Sync`.
#[async_trait]
impl<T: PrefixCacheStore + ?Sized> PrefixCacheExt for T {
    async fn get_or_compute<F>(
        &mut self,
        fingerprint: &ContextFingerprint,
        compute: F,
    ) -> Result<KVCacheEntry, OxiRagError>
    where
        F: FnOnce() -> Result<KVCacheEntry, OxiRagError> + Send,
    {
        // Hit: serve straight from the store without running `compute`.
        if let Some(entry) = self.get(fingerprint).await {
            return Ok(entry);
        }

        // Miss: build the entry. A `compute` error returns before anything is stored.
        let entry = compute()?;

        // `put` takes ownership and the caller also needs the entry, so store a clone.
        // NOTE(review): the entry is stored exactly as `compute` returned it. If its own
        // fingerprint differs from `fingerprint`, later lookups for `fingerprint` may keep
        // missing; confirm how `put` keys entries.
        self.put(entry.clone()).await?;
        Ok(entry)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::prefix_cache::InMemoryPrefixCache;
    use crate::prefix_cache::types::PrefixCacheConfig;

    /// A miss runs `compute` and returns its entry. A later call for the same
    /// fingerprint is served from the cache without running `compute` again.
    #[tokio::test]
    async fn test_prefix_cache_ext_get_or_compute() {
        let mut cache = InMemoryPrefixCache::new(PrefixCacheConfig::default());
        let fingerprint = ContextFingerprint::new(12345, 100, "test");

        // First call: cache miss, so the closure must run and its entry is returned.
        let mut computed = false;
        let entry = cache
            .get_or_compute(&fingerprint, || {
                computed = true;
                Ok(KVCacheEntry::new(
                    "key1",
                    fingerprint.clone(),
                    vec![1.0, 2.0],
                    100,
                ))
            })
            .await
            .unwrap();
        assert!(computed);
        assert_eq!(entry.fingerprint, fingerprint);
        assert_eq!(entry.kv_data, vec![1.0, 2.0]);

        // Second call: cache hit, so the closure must not run and the originally
        // computed data comes back instead of the new payload.
        let mut computed2 = false;
        let entry2 = cache
            .get_or_compute(&fingerprint, || {
                computed2 = true;
                Ok(KVCacheEntry::new(
                    "key2",
                    fingerprint.clone(),
                    vec![3.0, 4.0],
                    100,
                ))
            })
            .await
            .unwrap();
        assert!(!computed2);
        assert_eq!(entry2.kv_data, vec![1.0, 2.0]);
    }
}