use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
use tracing::{debug, info};
use scirs2_core::ndarray_ext::Array1;
/// One cached query → result pair plus the bookkeeping needed for TTL
/// expiry and eviction (recency / frequency).
///
/// `Instant`-typed fields and the embedding cannot be serialized, so they
/// are `#[serde(skip)]`-ed and rebuilt with fresh defaults on
/// deserialization.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheEntry {
    /// Query string this entry was stored under.
    pub query: String,
    /// Optional embedding used for semantic (similarity) lookups.
    /// Not serialized; restored as `None`.
    #[serde(skip)]
    pub embedding: Option<Array1<f32>>,
    /// The cached payload.
    pub result: CachedResult,
    /// Insertion time; expiry is measured from here (`is_expired`).
    /// Reset to "now" on deserialization, since `Instant` is opaque.
    #[serde(skip)]
    #[serde(default = "Instant::now")]
    pub timestamp: Instant,
    /// Hit counter read by the LFU eviction strategy.
    pub hit_count: usize,
    /// Recency marker read by the LRU eviction strategy.
    #[serde(skip)]
    #[serde(default = "Instant::now")]
    pub last_access: Instant,
    /// Per-entry time-to-live; entries older than this are treated as expired.
    #[serde(skip)]
    #[serde(default = "default_ttl")]
    pub ttl: Duration,
}
/// Fallback TTL applied to entries rebuilt during deserialization: one hour.
fn default_ttl() -> Duration {
    const ONE_HOUR_SECS: u64 = 3600;
    Duration::from_secs(ONE_HOUR_SECS)
}
/// The serializable payload stored for a query.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CachedResult {
    /// The result data itself.
    pub data: String,
    /// Arbitrary key/value metadata attached to the result.
    pub metadata: HashMap<String, String>,
    /// Cost of the original computation in milliseconds; accumulated into
    /// `CacheStatistics::time_saved` on every hit.
    pub computation_time_ms: u64,
}
/// Tunables for `SemanticCache`; see the `Default` impl for stock values.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SemanticCacheConfig {
    /// Minimum cosine similarity for an embedding lookup to count as a hit.
    pub similarity_threshold: f32,
    /// Maximum number of entries before eviction kicks in.
    pub max_size: usize,
    /// TTL assigned to every newly inserted entry.
    pub default_ttl: Duration,
    /// When false, only exact-key lookups are performed.
    pub use_embeddings: bool,
    /// Which entry to remove when the cache is full.
    pub eviction_strategy: EvictionStrategy,
    /// When false, hit/miss/size/eviction counters are not maintained.
    pub track_statistics: bool,
}
impl Default for SemanticCacheConfig {
fn default() -> Self {
Self {
similarity_threshold: 0.85,
max_size: 1000,
default_ttl: Duration::from_secs(3600), use_embeddings: true,
eviction_strategy: EvictionStrategy::LRU,
track_statistics: true,
}
}
}
/// Policy for choosing which entry to remove when the cache is full.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum EvictionStrategy {
    /// Least recently used: entry with the smallest `last_access`.
    LRU,
    /// Least frequently used: entry with the smallest `hit_count`.
    LFU,
    /// Oldest by insertion `timestamp` (currently implemented the same
    /// way as `FIFO` — see `evict`).
    TTL,
    /// First in, first out: oldest insertion `timestamp`.
    FIFO,
}
/// Aggregate counters maintained while `track_statistics` is enabled.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CacheStatistics {
    /// Queries answered from the cache (exact or semantic match).
    pub hits: usize,
    /// Queries that found no usable entry.
    pub misses: usize,
    /// `hits + misses`.
    pub total_queries: usize,
    /// `hits as f64 / total_queries as f64`.
    pub hit_rate: f64,
    /// Running average similarity. NOTE(review): it is only fed by semantic
    /// hits but divided by the total hit count (see
    /// `record_hit_with_similarity`), so exact-match hits dilute it.
    pub avg_similarity: f32,
    /// Sum of `computation_time_ms` across all hits.
    pub time_saved: Duration,
    /// Number of entries currently stored.
    pub current_size: usize,
    /// Total entries removed by the eviction policy.
    pub evictions: usize,
}
/// An async, TTL-aware cache that matches queries either exactly or by
/// embedding similarity. Entries and statistics live behind separate
/// `tokio::sync::RwLock`s, so statistics updates do not contend with the
/// entry map itself.
pub struct SemanticCache {
    /// Configuration captured at construction; never mutated afterwards.
    config: SemanticCacheConfig,
    /// Query string → cached entry.
    entries: Arc<RwLock<HashMap<String, CacheEntry>>>,
    /// Aggregate hit/miss/size/eviction counters.
    statistics: Arc<RwLock<CacheStatistics>>,
}
impl SemanticCache {
    /// Creates an empty cache from `config` and logs the configured threshold.
    pub fn new(config: SemanticCacheConfig) -> Self {
        info!(
            "Initialized semantic cache with similarity threshold: {}",
            config.similarity_threshold
        );
        Self {
            config,
            entries: Arc::new(RwLock::new(HashMap::new())),
            statistics: Arc::new(RwLock::new(CacheStatistics::default())),
        }
    }

    /// Looks up `query`: first by exact key, then (when `use_embeddings` is
    /// set) by embedding similarity against every fresh entry. Returns
    /// `Ok(None)` on a miss.
    ///
    /// Fixes two defects of the previous version: a hit now updates the
    /// entry's `hit_count` / `last_access`, so the LRU and LFU eviction
    /// strategies have real data to act on (they previously saw the
    /// insertion-time defaults forever); and among several semantic
    /// candidates the *highest-similarity* one is returned instead of
    /// whichever the `HashMap` happened to iterate first.
    pub async fn get(&self, query: &str) -> Result<Option<CachedResult>> {
        debug!("Checking semantic cache for query: {}", query);
        // Write lock: serving a hit mutates recency/frequency metadata.
        let mut entries = self.entries.write().await;

        // Fast path: exact key match that has not expired.
        if let Some(entry) = entries.get_mut(query) {
            if !self.is_expired(entry) {
                entry.hit_count += 1;
                entry.last_access = Instant::now();
                let saved_ms = entry.result.computation_time_ms;
                let result = entry.result.clone();
                drop(entries);
                self.record_hit(saved_ms).await;
                return Ok(Some(result));
            }
        }

        // Slow path: scan for the best fresh embedding match at or above
        // the configured threshold.
        if self.config.use_embeddings {
            let query_embedding = self.compute_embedding(query)?;
            let mut best: Option<(String, f32)> = None;
            for (key, entry) in entries.iter() {
                if self.is_expired(entry) {
                    continue;
                }
                if let Some(entry_embedding) = entry.embedding.as_ref() {
                    let similarity =
                        self.compute_similarity(&query_embedding, entry_embedding)?;
                    if similarity >= self.config.similarity_threshold
                        && best.as_ref().map_or(true, |(_, s)| similarity > *s)
                    {
                        best = Some((key.clone(), similarity));
                    }
                }
            }
            if let Some((key, similarity)) = best {
                debug!("Semantic cache hit with similarity: {:.3}", similarity);
                let entry = entries
                    .get_mut(&key)
                    .expect("key was collected from this map under the same lock");
                entry.hit_count += 1;
                entry.last_access = Instant::now();
                let saved_ms = entry.result.computation_time_ms;
                let result = entry.result.clone();
                drop(entries);
                self.record_hit_with_similarity(saved_ms, similarity).await;
                return Ok(Some(result));
            }
        }

        drop(entries);
        self.record_miss().await;
        Ok(None)
    }

    /// Inserts (or replaces) the entry for `query`.
    ///
    /// Eviction runs only when a *new* key would push the map past
    /// `max_size`; replacing an existing key cannot grow the map, so the
    /// previous behavior of evicting unconditionally at capacity removed an
    /// unrelated entry for no reason.
    pub async fn put(
        &self,
        query: String,
        result: CachedResult,
        embedding: Option<Array1<f32>>,
    ) -> Result<()> {
        let mut entries = self.entries.write().await;
        if !entries.contains_key(&query) && entries.len() >= self.config.max_size {
            self.evict(&mut entries).await?;
        }
        let now = Instant::now();
        let entry = CacheEntry {
            query: query.clone(),
            embedding,
            result,
            timestamp: now,
            hit_count: 0,
            last_access: now,
            ttl: self.config.default_ttl,
        };
        entries.insert(query, entry);
        let size = entries.len();
        drop(entries);
        self.update_size(size).await;
        Ok(())
    }

    /// Removes the exact entry for `query`. Returns whether anything was
    /// actually removed.
    pub async fn invalidate(&self, query: &str) -> Result<bool> {
        let mut entries = self.entries.write().await;
        let removed = entries.remove(query).is_some();
        let size = entries.len();
        drop(entries);
        self.update_size(size).await;
        Ok(removed)
    }

    /// Drops every entry (statistics counters other than size are kept).
    pub async fn clear(&self) -> Result<()> {
        let mut entries = self.entries.write().await;
        entries.clear();
        drop(entries);
        self.update_size(0).await;
        Ok(())
    }

    /// Removes all expired entries and returns how many were dropped.
    pub async fn cleanup_expired(&self) -> Result<usize> {
        let mut entries = self.entries.write().await;
        let initial_size = entries.len();
        entries.retain(|_, entry| !self.is_expired(entry));
        let size = entries.len();
        let removed = initial_size - size;
        drop(entries);
        self.update_size(size).await;
        if removed > 0 {
            info!("Cleaned up {} expired cache entries", removed);
        }
        Ok(removed)
    }

    /// Returns a snapshot of the current statistics.
    pub async fn statistics(&self) -> CacheStatistics {
        self.statistics.read().await.clone()
    }

    /// Resets every statistics counter to zero.
    pub async fn reset_statistics(&self) -> Result<()> {
        let mut stats = self.statistics.write().await;
        *stats = CacheStatistics::default();
        Ok(())
    }

    /// True when the entry has outlived its TTL, measured from insertion.
    fn is_expired(&self, entry: &CacheEntry) -> bool {
        entry.timestamp.elapsed() > entry.ttl
    }

    /// Records an exact-match hit; `computation_time_ms` is the cost the
    /// hit avoided (folded into `time_saved`).
    async fn record_hit(&self, computation_time_ms: u64) {
        if !self.config.track_statistics {
            return;
        }
        let mut stats = self.statistics.write().await;
        stats.hits += 1;
        stats.total_queries += 1;
        stats.hit_rate = stats.hits as f64 / stats.total_queries as f64;
        stats.time_saved += Duration::from_millis(computation_time_ms);
    }

    /// Records a semantic hit and folds `similarity` into the running
    /// average. NOTE(review): the average divides by the *total* hit count,
    /// so exact-match hits (which carry no similarity) dilute it — kept
    /// as-is to preserve the existing statistic's meaning.
    async fn record_hit_with_similarity(&self, computation_time_ms: u64, similarity: f32) {
        if !self.config.track_statistics {
            return;
        }
        let mut stats = self.statistics.write().await;
        stats.hits += 1;
        stats.total_queries += 1;
        stats.hit_rate = stats.hits as f64 / stats.total_queries as f64;
        stats.time_saved += Duration::from_millis(computation_time_ms);
        let total_similarity = stats.avg_similarity * (stats.hits - 1) as f32 + similarity;
        stats.avg_similarity = total_similarity / stats.hits as f32;
    }

    /// Records a miss.
    async fn record_miss(&self) {
        if !self.config.track_statistics {
            return;
        }
        let mut stats = self.statistics.write().await;
        stats.misses += 1;
        stats.total_queries += 1;
        stats.hit_rate = stats.hits as f64 / stats.total_queries as f64;
    }

    /// Mirrors the entry-map size into the statistics snapshot.
    async fn update_size(&self, size: usize) {
        if !self.config.track_statistics {
            return;
        }
        let mut stats = self.statistics.write().await;
        stats.current_size = size;
    }

    /// Placeholder embedder: returns an all-zero 384-dim vector, so cosine
    /// similarity is always 0 and the semantic path never matches until a
    /// real model is wired in. TODO: plug in an actual embedding backend.
    fn compute_embedding(&self, _query: &str) -> Result<Array1<f32>> {
        Ok(Array1::zeros(384))
    }

    /// Cosine similarity between two embeddings; returns 0.0 when either
    /// vector has zero norm (avoids a NaN from 0/0).
    fn compute_similarity(&self, emb1: &Array1<f32>, emb2: &Array1<f32>) -> Result<f32> {
        let dot_product: f32 = emb1.iter().zip(emb2.iter()).map(|(a, b)| a * b).sum();
        let norm1: f32 = emb1.iter().map(|a| a * a).sum::<f32>().sqrt();
        let norm2: f32 = emb2.iter().map(|a| a * a).sum::<f32>().sqrt();
        if norm1 == 0.0 || norm2 == 0.0 {
            return Ok(0.0);
        }
        Ok(dot_product / (norm1 * norm2))
    }

    /// Removes one entry according to the configured strategy. `TTL` and
    /// `FIFO` both currently evict the oldest `timestamp`. Caller must hold
    /// the entries write lock (passed in as `entries`); the statistics lock
    /// is taken separately, so no lock-ordering hazard exists here.
    async fn evict(&self, entries: &mut HashMap<String, CacheEntry>) -> Result<()> {
        if entries.is_empty() {
            return Ok(());
        }
        let key_to_remove = match self.config.eviction_strategy {
            EvictionStrategy::LRU => entries
                .iter()
                .min_by_key(|(_, entry)| entry.last_access)
                .map(|(key, _)| key.clone()),
            EvictionStrategy::LFU => entries
                .iter()
                .min_by_key(|(_, entry)| entry.hit_count)
                .map(|(key, _)| key.clone()),
            EvictionStrategy::TTL | EvictionStrategy::FIFO => entries
                .iter()
                .min_by_key(|(_, entry)| entry.timestamp)
                .map(|(key, _)| key.clone()),
        };
        if let Some(key) = key_to_remove {
            entries.remove(&key);
            if self.config.track_statistics {
                let mut stats = self.statistics.write().await;
                stats.evictions += 1;
            }
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a throwaway payload with a fixed 100 ms compute cost.
    fn sample(data: &str) -> CachedResult {
        CachedResult {
            data: data.to_string(),
            metadata: HashMap::new(),
            computation_time_ms: 100,
        }
    }

    #[tokio::test]
    async fn test_exact_match_cache() {
        let cache = SemanticCache::new(SemanticCacheConfig::default());
        cache
            .put("test query".to_string(), sample("test data"), None)
            .await
            .expect("should succeed");
        let hit = cache.get("test query").await.expect("should succeed");
        assert!(hit.is_some());
        assert_eq!(hit.expect("should succeed").data, "test data");
    }

    #[tokio::test]
    async fn test_cache_miss() {
        let cache = SemanticCache::new(SemanticCacheConfig::default());
        assert!(cache
            .get("nonexistent query")
            .await
            .expect("should succeed")
            .is_none());
    }

    #[tokio::test]
    async fn test_cache_invalidation() {
        let cache = SemanticCache::new(SemanticCacheConfig::default());
        cache
            .put("test query".to_string(), sample("test data"), None)
            .await
            .expect("should succeed");
        let before = cache.get("test query").await.expect("should succeed");
        assert!(before.is_some());
        cache
            .invalidate("test query")
            .await
            .expect("should succeed");
        let after = cache.get("test query").await.expect("should succeed");
        assert!(after.is_none());
    }

    #[tokio::test]
    async fn test_cache_statistics() {
        let cache = SemanticCache::new(SemanticCacheConfig::default());
        cache
            .put("test query".to_string(), sample("test data"), None)
            .await
            .expect("should succeed");
        // One hit followed by one miss.
        cache.get("test query").await.expect("should succeed");
        cache.get("other query").await.expect("should succeed");
        let stats = cache.statistics().await;
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.total_queries, 2);
        assert_eq!(stats.hit_rate, 0.5);
    }

    #[tokio::test]
    async fn test_cache_eviction() {
        let config = SemanticCacheConfig {
            max_size: 2,
            ..Default::default()
        };
        let cache = SemanticCache::new(config);
        // Inserting a third entry into a two-slot cache forces one eviction.
        for key in ["query1", "query2", "query3"] {
            cache
                .put(key.to_string(), sample("test"), None)
                .await
                .expect("should succeed");
        }
        let stats = cache.statistics().await;
        assert_eq!(stats.current_size, 2);
        assert_eq!(stats.evictions, 1);
    }
}