use std::sync::Mutex;
use std::time::{Duration, Instant};
use oxibonsai_rag::embedding::{Embedder, TfIdfEmbedder};
use oxibonsai_rag::vector_store::cosine_similarity;
/// Tuning knobs for [`SemanticCache`].
#[derive(Debug, Clone)]
pub struct SemanticCacheConfig {
/// Minimum cosine similarity between the query embedding and a cached
/// entry's embedding for the entry to count as a hit.
pub similarity_threshold: f32,
/// Maximum number of cached entries; beyond this, the least-recently-used
/// entry is evicted on insert.
pub max_entries: usize,
/// How long an entry remains valid after insertion.
pub ttl: Duration,
/// Whether streaming responses may be cached. NOTE(review): not consulted
/// anywhere in this file — presumably read by callers; confirm.
pub cache_streaming: bool,
/// Prompts shorter than this are neither cached nor looked up.
pub min_prompt_chars: usize,
}
impl Default for SemanticCacheConfig {
fn default() -> Self {
Self {
similarity_threshold: 0.92,
max_entries: 1000,
ttl: Duration::from_secs(3600),
cache_streaming: false,
min_prompt_chars: 20,
}
}
}
/// A cache hit returned to the caller by [`SemanticCache::lookup`].
#[derive(Debug, Clone)]
pub struct CachedResponse {
/// The cached response text.
pub response: String,
/// The original prompt that produced the cached response (may differ from
/// the query prompt on a semantic — non-exact — match).
pub prompt: String,
/// Cosine similarity between the query and the matched entry.
pub similarity: f32,
/// When the underlying entry was inserted into the cache.
pub created_at: Instant,
/// Number of times the underlying entry has been served, including this hit.
pub hit_count: u64,
}
impl CachedResponse {
    /// How long ago this response was first cached.
    pub fn age(&self) -> Duration {
        self.created_at.elapsed()
    }

    /// True once the response has been alive strictly longer than `ttl`.
    pub fn is_expired(&self, ttl: Duration) -> bool {
        self.age() > ttl
    }
}
/// Internal cache record; never exposed to callers directly.
struct CacheEntry {
// Prompt text as originally inserted.
prompt: String,
// Response text as originally inserted.
response: String,
// TF-IDF embedding of `prompt` at insert time; may have a stale
// dimensionality after an embedder refit, in which case lookups skip it.
vector: Vec<f32>,
// Insert timestamp, used for TTL expiry.
created_at: Instant,
// Logical access-clock tick of the most recent insert/hit; drives LRU
// eviction (smallest value = least recently used).
last_accessed: u64,
// Number of lookup hits served by this entry.
hit_count: u64,
}
/// Aggregate counters describing cache effectiveness.
///
/// Every field starts at zero, so `Default` is derived instead of the
/// previous hand-written all-zero implementation (identical values).
#[derive(Debug, Clone, Default, serde::Serialize)]
pub struct SemanticCacheStats {
    /// Total lookups, counting hits, misses, uncacheable prompts, and
    /// embedding failures.
    pub total_requests: u64,
    /// Lookups served from the cache.
    pub cache_hits: u64,
    /// Lookups that found no acceptable entry.
    pub cache_misses: u64,
    /// `cache_hits / total_requests`, or 0.0 before any request.
    pub hit_rate: f32,
    /// Current number of entries in the cache.
    pub entries: usize,
    /// Entries removed by LRU capacity eviction.
    pub evictions: u64,
    /// Entries removed by explicit TTL sweeps (`evict_expired`).
    pub expired_evictions: u64,
    /// Running mean of the similarity score over all hits.
    pub avg_similarity_on_hit: f32,
}
/// Thread-safe semantic response cache: prompts are embedded with TF-IDF and
/// matched against cached entries by cosine similarity rather than exact text.
///
/// Interior mutability via several independent `Mutex`es lets `&self` methods
/// mutate state. Lock acquisition order in the methods is embedder → entries
/// → stats (with the short-lived clock/similarity locks taken in between),
/// kept consistent to avoid deadlock.
pub struct SemanticCache {
// Immutable after construction.
config: SemanticCacheConfig,
// The cached entries; linear-scanned on lookup.
entries: Mutex<Vec<CacheEntry>>,
// Current TF-IDF model; replaced wholesale on refit.
embedder: Mutex<TfIdfEmbedder>,
// Aggregate counters reported by `stats()`.
stats: Mutex<SemanticCacheStats>,
// Every cacheable prompt ever inserted; corpus for embedder refits.
all_prompts: Mutex<Vec<String>>,
// Monotonic logical clock used to timestamp accesses for LRU eviction.
access_clock: Mutex<u64>,
// Sum of hit similarities (f64 to limit float drift) backing
// `avg_similarity_on_hit`.
similarity_sum: Mutex<f64>,
}
/// Embedding dimensionality for the bootstrap TF-IDF fit in `SemanticCache::new`.
const BOOTSTRAP_DIM: usize = 64;
/// The embedder is refit on the first inserted prompt and then once every
/// this many prompts (see `SemanticCache::insert`).
const REFIT_BATCH_SIZE: usize = 16;
impl SemanticCache {
    /// Builds a cache from `config`.
    ///
    /// The TF-IDF embedder is bootstrapped on a tiny fixed corpus so prompts
    /// can be embedded before any real prompts have been observed; it is
    /// periodically re-fitted on the observed prompts (see `refit_embedder`).
    pub fn new(config: SemanticCacheConfig) -> Self {
        let bootstrap_docs = [
            "hello world query prompt response cache",
            "semantic similarity cosine embedding language model",
            "retrieval augmented generation inference rust",
        ];
        let embedder = TfIdfEmbedder::fit(&bootstrap_docs, BOOTSTRAP_DIM);
        Self {
            config,
            entries: Mutex::new(Vec::new()),
            embedder: Mutex::new(embedder),
            stats: Mutex::new(SemanticCacheStats::default()),
            all_prompts: Mutex::new(Vec::new()),
            access_clock: Mutex::new(0),
            similarity_sum: Mutex::new(0.0),
        }
    }

    /// Looks up the best cached response for `prompt`.
    ///
    /// Returns `Some` when a non-expired entry's vector reaches
    /// `config.similarity_threshold` cosine similarity with the embedded
    /// prompt (the highest-scoring such entry wins). Every call — including
    /// uncacheable prompts and embedding failures — increments
    /// `total_requests`, and the hit/miss counters, hit rate, and running
    /// average hit similarity are kept up to date. On a hit the entry's hit
    /// count and LRU clock are bumped.
    pub fn lookup(&self, prompt: &str) -> Option<CachedResponse> {
        if !self.is_cacheable(prompt) {
            // Too short to ever be cached: record the miss and bail early.
            let mut stats = self.stats.lock().expect("stats lock poisoned");
            stats.total_requests += 1;
            stats.cache_misses += 1;
            self.update_hit_rate(&mut stats);
            return None;
        }
        // Embed under the embedder lock only; the guard is dropped before the
        // entries lock is taken so the two are never held together.
        let query_vec = {
            let embedder = self.embedder.lock().expect("embedder lock poisoned");
            match embedder.embed(prompt) {
                Ok(v) => v,
                Err(_) => {
                    // Embedding failed: treat as a plain miss, not an error.
                    let mut stats = self.stats.lock().expect("stats lock poisoned");
                    stats.total_requests += 1;
                    stats.cache_misses += 1;
                    self.update_hit_rate(&mut stats);
                    return None;
                }
            }
        };
        let mut entries = self.entries.lock().expect("entries lock poisoned");
        let ttl = self.config.ttl;
        let threshold = self.config.similarity_threshold;
        let mut best_score = f32::NEG_INFINITY;
        let mut best_idx: Option<usize> = None;
        for (idx, entry) in entries.iter().enumerate() {
            if entry.created_at.elapsed() > ttl {
                // Expired entries are skipped here and reclaimed by
                // `evict_expired` (or by LRU pressure) later.
                continue;
            }
            if entry.vector.len() != query_vec.len() {
                // Vector came from an embedder fit with a different
                // dimensionality; it cannot be compared meaningfully.
                continue;
            }
            let score = cosine_similarity(&query_vec, &entry.vector);
            if score >= threshold && score > best_score {
                best_score = score;
                best_idx = Some(idx);
            }
        }
        // Lock order entries -> stats matches `insert`, so no deadlock.
        let mut stats = self.stats.lock().expect("stats lock poisoned");
        stats.total_requests += 1;
        match best_idx {
            Some(idx) => {
                // Advance the logical clock so LRU eviction sees this entry
                // as recently used.
                let clock = {
                    let mut c = self.access_clock.lock().expect("clock lock poisoned");
                    *c += 1;
                    *c
                };
                let entry = &mut entries[idx];
                entry.hit_count += 1;
                entry.last_accessed = clock;
                let response = CachedResponse {
                    response: entry.response.clone(),
                    prompt: entry.prompt.clone(),
                    similarity: best_score,
                    created_at: entry.created_at,
                    hit_count: entry.hit_count,
                };
                stats.cache_hits += 1;
                self.update_hit_rate(&mut stats);
                {
                    // Running mean of hit similarity, accumulated in f64 to
                    // limit floating-point drift over many hits.
                    let mut sim_sum = self
                        .similarity_sum
                        .lock()
                        .expect("similarity_sum lock poisoned");
                    *sim_sum += best_score as f64;
                    stats.avg_similarity_on_hit = (*sim_sum / stats.cache_hits as f64) as f32;
                }
                Some(response)
            }
            None => {
                stats.cache_misses += 1;
                self.update_hit_rate(&mut stats);
                None
            }
        }
    }

    /// Stores `response` for `prompt`, evicting the least-recently-used entry
    /// when the cache is at capacity.
    ///
    /// The prompt also joins the refit corpus; the embedder is re-fitted on
    /// the first prompt and then every `REFIT_BATCH_SIZE` prompts. Prompts
    /// shorter than `min_prompt_chars` are ignored, as is everything when
    /// `max_entries` is zero (caching disabled).
    pub fn insert(&self, prompt: &str, response: &str) {
        if !self.is_cacheable(prompt) {
            return;
        }
        // A capacity of zero disables caching entirely. Without this guard,
        // the LRU branch below would call `min_by_key` on an empty list and
        // panic on its `expect`.
        if self.config.max_entries == 0 {
            return;
        }
        {
            let mut all_prompts = self.all_prompts.lock().expect("all_prompts lock poisoned");
            all_prompts.push(prompt.to_string());
            let should_refit = all_prompts.len() == 1 || all_prompts.len() % REFIT_BATCH_SIZE == 0;
            // Release the guard first: `refit_embedder` re-takes this lock.
            drop(all_prompts);
            if should_refit {
                self.refit_embedder();
            }
        }
        let vector = {
            let embedder = self.embedder.lock().expect("embedder lock poisoned");
            match embedder.embed(prompt) {
                Ok(v) => v,
                // Unembeddable prompts are silently not cached.
                Err(_) => return,
            }
        };
        let clock = {
            let mut c = self.access_clock.lock().expect("clock lock poisoned");
            *c += 1;
            *c
        };
        let mut entries = self.entries.lock().expect("entries lock poisoned");
        if entries.len() >= self.config.max_entries {
            // Evict the entry with the oldest logical access tick. Safe:
            // max_entries >= 1 here, so `entries` is non-empty.
            let lru_idx = entries
                .iter()
                .enumerate()
                .min_by_key(|(_, e)| e.last_accessed)
                .map(|(i, _)| i)
                .expect("entries is non-empty");
            entries.swap_remove(lru_idx);
            let mut stats = self.stats.lock().expect("stats lock poisoned");
            stats.evictions += 1;
        }
        entries.push(CacheEntry {
            prompt: prompt.to_string(),
            response: response.to_string(),
            vector,
            created_at: Instant::now(),
            last_accessed: clock,
            hit_count: 0,
        });
        let mut stats = self.stats.lock().expect("stats lock poisoned");
        stats.entries = entries.len();
    }

    /// Drops every entry older than the configured TTL and returns how many
    /// were removed. Also refreshes `stats.entries`.
    pub fn evict_expired(&self) -> usize {
        let ttl = self.config.ttl;
        let mut entries = self.entries.lock().expect("entries lock poisoned");
        let before = entries.len();
        entries.retain(|e| e.created_at.elapsed() <= ttl);
        let removed = before - entries.len();
        let mut stats = self.stats.lock().expect("stats lock poisoned");
        stats.expired_evictions += removed as u64;
        stats.entries = entries.len();
        removed
    }

    /// Empties the cache and the refit corpus, and resets all statistics.
    /// The embedder keeps its current fit.
    pub fn clear(&self) {
        self.entries.lock().expect("entries lock poisoned").clear();
        self.all_prompts
            .lock()
            .expect("all_prompts lock poisoned")
            .clear();
        *self
            .similarity_sum
            .lock()
            .expect("similarity_sum lock poisoned") = 0.0;
        *self.stats.lock().expect("stats lock poisoned") = SemanticCacheStats::default();
    }

    /// Current number of entries, including expired ones not yet swept.
    pub fn len(&self) -> usize {
        self.entries.lock().expect("entries lock poisoned").len()
    }

    /// True when the cache holds no entries.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Snapshot of the current statistics.
    pub fn stats(&self) -> SemanticCacheStats {
        self.stats.lock().expect("stats lock poisoned").clone()
    }

    /// A prompt is cacheable when it has at least `min_prompt_chars`
    /// characters, counted as Unicode scalar values (`chars()`) rather than
    /// UTF-8 bytes so multi-byte text is not over-counted.
    fn is_cacheable(&self, prompt: &str) -> bool {
        prompt.chars().count() >= self.config.min_prompt_chars
    }

    /// Re-fits the TF-IDF embedder over every prompt seen so far, widening
    /// the feature space with the corpus (capped at 4096 features).
    ///
    /// Existing entry vectors are NOT re-embedded; if the dimensionality
    /// changes they become unmatchable (skipped in `lookup`) and age out via
    /// TTL or LRU eviction.
    fn refit_embedder(&self) {
        let all_prompts = self.all_prompts.lock().expect("all_prompts lock poisoned");
        if all_prompts.is_empty() {
            return;
        }
        let max_features = BOOTSTRAP_DIM.max(all_prompts.len() * 4).min(4096);
        let doc_refs: Vec<&str> = all_prompts.iter().map(|s| s.as_str()).collect();
        let new_embedder = TfIdfEmbedder::fit(&doc_refs, max_features);
        // Release the corpus lock before taking the embedder lock.
        drop(all_prompts);
        let mut embedder = self.embedder.lock().expect("embedder lock poisoned");
        *embedder = new_embedder;
    }

    /// Recomputes `hit_rate` from the hit/request counters; 0.0 when no
    /// requests have been made.
    fn update_hit_rate(&self, stats: &mut SemanticCacheStats) {
        stats.hit_rate = if stats.total_requests == 0 {
            0.0
        } else {
            stats.cache_hits as f32 / stats.total_requests as f32
        };
    }
}
/// Convenience wrapper pairing a [`SemanticCache`] with an inference
/// callback; see `run_or_cache`.
pub struct CachedInference {
/// The underlying semantic cache; public so callers can inspect stats or
/// manage entries directly.
pub cache: SemanticCache,
}
impl CachedInference {
pub fn new(config: SemanticCacheConfig) -> Self {
Self {
cache: SemanticCache::new(config),
}
}
pub fn run_or_cache<F>(&self, prompt: &str, run_inference: F) -> (String, bool)
where
F: FnOnce() -> String,
{
if let Some(cached) = self.cache.lookup(prompt) {
return (cached.response, true);
}
let response = run_inference();
self.cache.insert(prompt, &response);
(response, false)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Config whose entries expire after 50 ms.
    fn short_ttl_config() -> SemanticCacheConfig {
        SemanticCacheConfig {
            ttl: Duration::from_millis(50),
            ..Default::default()
        }
    }

    /// Config that treats nearly any similarity score as a hit.
    fn low_threshold_config() -> SemanticCacheConfig {
        SemanticCacheConfig {
            similarity_threshold: 0.1,
            ..Default::default()
        }
    }

    #[test]
    fn test_semantic_cache_miss_on_empty() {
        let cache = SemanticCache::new(SemanticCacheConfig::default());
        let outcome = cache.lookup("What is the meaning of life?");
        assert!(outcome.is_none());
    }

    #[test]
    fn test_semantic_cache_exact_match() {
        let cache = SemanticCache::new(low_threshold_config());
        let question = "What is the capital of France and why is it important?";
        cache.insert(question, "Paris is the capital of France.");
        let found = cache.lookup(question);
        assert!(found.is_some(), "exact prompt should hit the cache");
        let cached = found.expect("just asserted Some");
        assert_eq!(cached.response, "Paris is the capital of France.");
        // An identical prompt should embed to (nearly) the same vector.
        assert!(cached.similarity > 0.9, "similarity={}", cached.similarity);
    }

    #[test]
    fn test_semantic_cache_insert_and_lookup() {
        let cache = SemanticCache::new(SemanticCacheConfig {
            similarity_threshold: 0.5,
            ..Default::default()
        });
        let question = "Explain the concept of machine learning in detail";
        cache.insert(question, "Machine learning is a branch of AI.");
        assert_eq!(cache.len(), 1);
        assert!(cache.lookup(question).is_some());
    }

    #[test]
    fn test_semantic_cache_ttl_expiry() {
        let cache = SemanticCache::new(short_ttl_config());
        let question = "Tell me everything about neural networks and deep learning";
        cache.insert(question, "Neural networks are computational graphs.");
        assert!(
            cache.lookup(question).is_some(),
            "should hit before TTL expires"
        );
        // Sleep well past the 50 ms TTL so the entry is stale.
        std::thread::sleep(Duration::from_millis(100));
        assert!(
            cache.lookup(question).is_none(),
            "should miss after TTL expires"
        );
    }

    #[test]
    fn test_semantic_cache_min_prompt_length() {
        let cache = SemanticCache::new(SemanticCacheConfig::default());
        cache.insert("Hi", "Hello!");
        assert_eq!(cache.len(), 0, "short prompt should not be cached");
        assert!(cache.lookup("Hi").is_none());
    }

    #[test]
    fn test_semantic_cache_evict_expired() {
        let cache = SemanticCache::new(short_ttl_config());
        for idx in 0..5 {
            let prompt = format!(
                "This is a sufficiently long prompt number {} for caching purposes",
                idx
            );
            cache.insert(&prompt, "response");
        }
        assert_eq!(cache.len(), 5);
        std::thread::sleep(Duration::from_millis(100));
        let swept = cache.evict_expired();
        assert_eq!(swept, 5, "all entries should have expired");
        assert_eq!(cache.len(), 0);
        assert_eq!(cache.stats().expired_evictions, 5);
    }

    #[test]
    fn test_semantic_cache_stats_hit_rate() {
        let cache = SemanticCache::new(low_threshold_config());
        let question = "Describe the architecture of transformer neural networks in depth";
        cache.insert(question, "Transformers use attention mechanisms.");
        let _ = cache.lookup(question);
        let _ = cache.lookup("Completely unrelated gibberish zzzzzzzz that matches nothing");
        let snapshot = cache.stats();
        assert_eq!(snapshot.cache_hits, 1);
        assert_eq!(snapshot.cache_misses, 1);
        assert_eq!(snapshot.total_requests, 2);
        assert!(
            (snapshot.hit_rate - 0.5).abs() < 1e-5,
            "hit_rate={}",
            snapshot.hit_rate
        );
    }

    #[test]
    fn test_semantic_cache_clear() {
        let cache = SemanticCache::new(low_threshold_config());
        for idx in 0..10 {
            let prompt = format!(
                "This is prompt number {} that is long enough to be cached by the system",
                idx
            );
            cache.insert(&prompt, "some response");
        }
        assert!(!cache.is_empty());
        cache.clear();
        assert!(cache.is_empty());
        assert_eq!(cache.stats().total_requests, 0);
    }

    #[test]
    fn test_cached_inference_returns_cached() {
        let ci = CachedInference::new(low_threshold_config());
        let question = "What is Rust and why is it used for systems programming?";
        let (first, was_hit) =
            ci.run_or_cache(question, || "Rust is a systems language.".to_string());
        assert!(!was_hit, "first call must be a miss");
        assert_eq!(first, "Rust is a systems language.");
        // The closure must not run on the second, semantically identical call.
        let (second, was_hit_again) = ci.run_or_cache(question, || panic!("should not be called"));
        assert!(was_hit_again, "second identical call must be a hit");
        assert_eq!(second, "Rust is a systems language.");
    }

    #[test]
    fn test_cached_inference_calls_fn_on_miss() {
        let ci = CachedInference::new(SemanticCacheConfig::default());
        let mut invoked = false;
        let (resp, hit) = ci.run_or_cache(
            "Explain quantum entanglement in detail for a physics student",
            || {
                invoked = true;
                "Quantum entanglement is a phenomenon…".to_string()
            },
        );
        assert!(!hit);
        assert!(invoked);
        assert!(!resp.is_empty());
    }

    #[test]
    fn test_cache_config_defaults() {
        let cfg = SemanticCacheConfig::default();
        assert!((cfg.similarity_threshold - 0.92).abs() < 1e-6);
        assert_eq!(cfg.max_entries, 1000);
        assert_eq!(cfg.ttl, Duration::from_secs(3600));
        assert!(!cfg.cache_streaming);
        assert_eq!(cfg.min_prompt_chars, 20);
    }

    #[test]
    fn test_cached_response_is_expired() {
        let fresh = CachedResponse {
            response: "answer".to_string(),
            prompt: "question".to_string(),
            similarity: 0.95,
            created_at: Instant::now(),
            hit_count: 1,
        };
        assert!(!fresh.is_expired(Duration::from_secs(60)));
        // After any nonzero elapsed time, a zero TTL is always exceeded.
        std::thread::sleep(Duration::from_millis(1));
        assert!(fresh.is_expired(Duration::ZERO));
    }
}