use dashmap::DashMap;
use serde::{Deserialize, Serialize};
/// Settings for [`SemanticCache`].
///
/// Every field may be omitted when deserializing; a missing field takes the
/// same value that the `Default` implementation produces.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SemanticCacheConfig {
    /// Master switch. While `false`, lookups find nothing and inserts are ignored.
    #[serde(default)]
    pub enabled: bool,

    /// Cosine-similarity cut-off; entries scoring below it are never returned.
    /// Defaults to `0.92`.
    #[serde(default = "default_threshold")]
    pub threshold: f32,

    /// Embedding model name. It is stored and exposed by
    /// `SemanticCache::embedding_model`, but not otherwise interpreted here.
    /// Empty when not configured.
    #[serde(default)]
    pub embedding_model: String,

    /// Upper bound on how many stored entries a single lookup examines;
    /// `0` removes the bound. Defaults to `100`.
    #[serde(default = "default_max_search")]
    pub max_search: usize,
}
/// Default similarity threshold (`0.92`), used both by serde for a missing
/// `threshold` field and by `SemanticCacheConfig::default`.
fn default_threshold() -> f32 {
    0.92_f32
}
/// Default per-lookup scan limit (`100` entries), used both by serde for a
/// missing `max_search` field and by `SemanticCacheConfig::default`.
fn default_max_search() -> usize {
    100_usize
}
impl Default for SemanticCacheConfig {
    /// A disabled configuration with no embedding model, the default
    /// threshold and the default scan limit.
    fn default() -> Self {
        let threshold = default_threshold();
        let max_search = default_max_search();
        Self {
            enabled: false,
            threshold,
            embedding_model: String::default(),
            max_search,
        }
    }
}
/// A stored embedding paired with the cache key it was inserted under.
struct EmbeddingEntry {
    /// The key this embedding belongs to; returned by similarity lookups.
    cache_key: String,
    /// The vector that queries are compared against.
    embedding: Vec<f32>,
}
/// Embeddings indexed by cache key, searched by cosine similarity.
///
/// Storage is a `DashMap`, which lets insert, remove and lookup all take `&self`.
pub struct SemanticCache {
    /// Stored embeddings, keyed by their cache key.
    entries: DashMap<String, EmbeddingEntry>,
    /// Settings supplied to `SemanticCache::new`.
    config: SemanticCacheConfig,
}
impl SemanticCache {
    /// Creates an empty cache governed by `config`.
    #[must_use]
    pub fn new(config: SemanticCacheConfig) -> Self {
        Self {
            entries: DashMap::new(),
            config,
        }
    }

    /// Finds the stored entry most similar to `query_embedding`.
    ///
    /// An entry qualifies when its cosine similarity is at least
    /// `config.threshold`. Among qualifying entries the highest score wins;
    /// on an exact tie, the first one visited wins.
    ///
    /// At most `config.max_search` entries are examined (`0` means no limit).
    /// Map iteration order is unspecified, so when the limit is smaller than
    /// the cache, which entries are examined is arbitrary.
    ///
    /// Returns `(cache_key, score)`, or `None` when the cache is disabled or
    /// empty, or when no examined entry reaches the threshold.
    #[must_use]
    pub fn find_similar(&self, query_embedding: &[f32]) -> Option<(String, f32)> {
        if !self.config.enabled || self.entries.is_empty() {
            return None;
        }

        // A configured limit of 0 means "scan everything".
        let limit = match self.config.max_search {
            0 => usize::MAX,
            n => n,
        };

        let mut best: Option<(String, f32)> = None;
        for entry in self.entries.iter().take(limit) {
            let score = cosine_similarity(query_embedding, &entry.embedding);
            // Inclusive against the threshold, so `threshold = 1.0` can still
            // match an identical vector. Strict between candidates to keep the
            // first-visited winner on ties. Both comparisons are positive, so a
            // NaN score (or NaN threshold) never qualifies.
            let improves = match &best {
                Some((_, best_score)) => score > *best_score,
                None => score >= self.config.threshold,
            };
            if improves {
                best = Some((entry.cache_key.clone(), score));
            }
        }
        best
    }

    /// Stores `embedding` under `cache_key`, replacing any earlier entry for
    /// that key. Does nothing while the cache is disabled.
    pub fn insert(&self, cache_key: String, embedding: Vec<f32>) {
        if !self.config.enabled {
            return;
        }
        self.entries.insert(
            cache_key.clone(),
            EmbeddingEntry {
                cache_key,
                embedding,
            },
        );
    }

    /// Removes the entry stored under `cache_key`, if any.
    ///
    /// Unlike `insert`, this runs even while the cache is disabled.
    pub fn remove(&self, cache_key: &str) {
        self.entries.remove(cache_key);
    }

    /// Number of stored entries.
    #[must_use]
    pub fn len(&self) -> usize {
        self.entries.len()
    }

    /// `true` when no entries are stored.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }

    /// Whether the cache was configured as enabled.
    #[must_use]
    #[inline]
    pub fn is_enabled(&self) -> bool {
        self.config.enabled
    }

    /// The configured embedding model name (empty when none was configured).
    #[must_use]
    pub fn embedding_model(&self) -> &str {
        &self.config.embedding_model
    }
}
/// Cosine similarity of `a` and `b`, clamped to `[-1.0, 1.0]`.
///
/// Returns `0.0` when the slices differ in length, are empty, or when either
/// has zero or non-finite magnitude (e.g. contains NaN). A malformed
/// embedding therefore scores as "unrelated" instead of propagating NaN.
#[must_use]
#[inline]
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }

    let (dot, norm_a, norm_b) = a
        .iter()
        .zip(b)
        .fold((0.0f32, 0.0f32, 0.0f32), |(dot, na, nb), (&x, &y)| {
            (dot + x * y, na + x * x, nb + y * y)
        });

    let denom = norm_a.sqrt() * norm_b.sqrt();
    // Only a truly zero or non-finite magnitude is degenerate. A fixed
    // `f32::EPSILON` cut-off would wrongly reject valid small vectors, e.g.
    // `[1e-4]` against itself, whose norm product is ~1e-8.
    if denom <= 0.0 || !denom.is_finite() {
        return 0.0;
    }
    // Rounding can push the ratio marginally past +/-1; keep it in range.
    (dot / denom).clamp(-1.0, 1.0)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an enabled cache with the given threshold and scan limit.
    fn enabled_cache(threshold: f32, max_search: usize) -> SemanticCache {
        SemanticCache::new(SemanticCacheConfig {
            enabled: true,
            threshold,
            max_search,
            ..Default::default()
        })
    }

    #[test]
    fn cosine_identical_vectors() {
        let v = [1.0_f32, 2.0, 3.0];
        let sim = cosine_similarity(&v, &v);
        assert!((sim - 1.0).abs() < 0.001);
    }

    #[test]
    fn cosine_orthogonal_vectors() {
        let sim = cosine_similarity(&[1.0, 0.0], &[0.0, 1.0]);
        assert!(sim.abs() < 0.001);
    }

    #[test]
    fn cosine_opposite_vectors() {
        let sim = cosine_similarity(&[1.0, 0.0], &[-1.0, 0.0]);
        assert!((sim + 1.0).abs() < 0.001);
    }

    #[test]
    fn cosine_different_lengths() {
        assert_eq!(cosine_similarity(&[1.0, 2.0], &[1.0]), 0.0);
    }

    #[test]
    fn cosine_empty() {
        assert_eq!(cosine_similarity(&[], &[]), 0.0);
    }

    #[test]
    fn semantic_cache_disabled() {
        let cache = SemanticCache::new(SemanticCacheConfig::default());
        assert!(!cache.is_enabled());
        assert!(cache.find_similar(&[1.0, 2.0]).is_none());
    }

    #[test]
    fn semantic_cache_disabled_ignores_inserts() {
        let cache = SemanticCache::new(SemanticCacheConfig::default());
        cache.insert("key1".into(), vec![1.0, 0.0]);
        assert!(cache.is_empty());
    }

    #[test]
    fn semantic_cache_insert_and_find() {
        let cache = enabled_cache(0.9, 100);
        cache.insert("key1".into(), vec![1.0, 0.0, 0.0]);
        cache.insert("key2".into(), vec![0.0, 1.0, 0.0]);
        let (key, score) = cache
            .find_similar(&[0.99, 0.01, 0.0])
            .expect("key1 should clear the threshold");
        assert_eq!(key, "key1");
        assert!(score > 0.9);
    }

    #[test]
    fn semantic_cache_below_threshold() {
        let cache = enabled_cache(0.99, 100);
        cache.insert("key1".into(), vec![1.0, 0.0, 0.0]);
        assert!(cache.find_similar(&[0.7, 0.7, 0.0]).is_none());
    }

    #[test]
    fn semantic_cache_insert_replaces_existing_key() {
        let cache = enabled_cache(0.9, 100);
        cache.insert("key1".into(), vec![1.0, 0.0]);
        cache.insert("key1".into(), vec![0.0, 1.0]);
        assert_eq!(cache.len(), 1);
        let (key, _) = cache
            .find_similar(&[0.0, 1.0])
            .expect("the replacement embedding should match");
        assert_eq!(key, "key1");
        assert!(cache.find_similar(&[1.0, 0.0]).is_none());
    }

    #[test]
    fn semantic_cache_remove() {
        let cache = SemanticCache::new(SemanticCacheConfig {
            enabled: true,
            ..Default::default()
        });
        cache.insert("key1".into(), vec![1.0]);
        assert_eq!(cache.len(), 1);
        cache.remove("key1");
        assert!(cache.is_empty());
    }

    #[test]
    fn semantic_cache_max_search_limit() {
        // Map iteration order is unspecified, so only order-independent outcomes
        // are asserted: with every entry matching, whichever entries fall inside
        // the limit must yield a hit.
        let cache = enabled_cache(0.5, 2);
        for i in 0..5 {
            cache.insert(format!("key{i}"), vec![1.0, 0.0, 0.0]);
        }
        let (key, score) = cache
            .find_similar(&[1.0, 0.0, 0.0])
            .expect("every entry matches, so a limited scan still hits");
        assert!(key.starts_with("key"));
        assert!(score > 0.99);
        assert_eq!(cache.len(), 5);
    }

    #[test]
    fn semantic_cache_zero_max_search_scans_everything() {
        let cache = enabled_cache(0.9, 0);
        for i in 0..20 {
            cache.insert(format!("miss{i}"), vec![0.0, 1.0]);
        }
        cache.insert("hit".into(), vec![1.0, 0.0]);
        let (key, _) = cache
            .find_similar(&[1.0, 0.0])
            .expect("an unlimited scan must reach the only matching entry");
        assert_eq!(key, "hit");
    }
}