#[cfg(feature = "onnx")]
use std::num::NonZeroUsize;
#[cfg(feature = "onnx")]
use std::sync::Mutex;

#[cfg(feature = "onnx")]
use lru::LruCache;

#[cfg(feature = "onnx")]
use crate::{Result, RagError};
#[cfg(feature = "onnx")]
use super::OnnxEmbedder;
/// Thread-safe wrapper that memoizes `OnnxEmbedder` results in an LRU
/// cache, keyed by the exact input text, while tracking hit/miss counts.
///
/// All interior state is behind `Mutex`es so a `CachedEmbedder` can be
/// shared across threads (e.g. via `Arc`) using only `&self` methods.
#[cfg(feature = "onnx")]
#[derive(Debug)]
pub struct CachedEmbedder {
    /// Underlying ONNX embedder; locked for the duration of each inference.
    embedder: Mutex<OnnxEmbedder>,
    /// LRU map from input text to its embedding vector.
    cache: Mutex<LruCache<String, Vec<f32>>>,
    /// Number of lookups served from the cache.
    hits: Mutex<usize>,
    /// Number of lookups that required running the embedder.
    misses: Mutex<usize>,
}
#[cfg(feature = "onnx")]
impl CachedEmbedder {
    /// Creates a cached embedder holding at most `max_cache_size` entries.
    ///
    /// A `max_cache_size` of zero is clamped to 1, since `LruCache`
    /// requires a non-zero capacity.
    pub fn new(embedder: OnnxEmbedder, max_cache_size: usize) -> Self {
        let cache_size = NonZeroUsize::new(max_cache_size)
            .unwrap_or_else(|| NonZeroUsize::new(1).unwrap());
        Self {
            embedder: Mutex::new(embedder),
            cache: Mutex::new(LruCache::new(cache_size)),
            hits: Mutex::new(0),
            misses: Mutex::new(0),
        }
    }

    /// Creates a cached embedder with a default capacity of 10,000 entries.
    pub fn with_default_size(embedder: OnnxEmbedder) -> Self {
        Self::new(embedder, 10_000)
    }

    /// Converts a poisoned-lock error into a `RagError`, naming which lock
    /// was involved. Message format matches the previous inline strings.
    fn lock_err<T>(what: &str, e: std::sync::PoisonError<T>) -> RagError {
        RagError::EmbeddingError(format!("{} lock poisoned: {}", what, e))
    }

    /// Embeds `text`, serving from the LRU cache when possible.
    ///
    /// On a miss the underlying ONNX embedder is invoked and the result is
    /// inserted into the cache. Exactly one of the hit/miss counters is
    /// incremented per successful call.
    ///
    /// # Errors
    /// Returns `RagError::EmbeddingError` if any internal lock is poisoned
    /// or if the underlying embedder fails.
    pub fn embed(&self, text: &str) -> Result<Vec<f32>> {
        // Fast path: cache hit. The cache lock is released before the
        // (potentially slow) model inference below.
        {
            let mut cache = self.cache.lock().map_err(|e| Self::lock_err("Cache", e))?;
            if let Some(embedding) = cache.get(text) {
                let mut hits = self.hits.lock().map_err(|e| Self::lock_err("Hits", e))?;
                *hits += 1;
                return Ok(embedding.clone());
            }
        }
        // Slow path: run inference without holding the cache lock so other
        // threads can keep serving hits concurrently. (Two threads may race
        // to compute the same text; the later `put` simply overwrites.)
        let embedding = {
            let mut embedder = self.embedder.lock().map_err(|e| Self::lock_err("Embedder", e))?;
            embedder.embed(text)?
        };
        {
            let mut cache = self.cache.lock().map_err(|e| Self::lock_err("Cache", e))?;
            cache.put(text.to_string(), embedding.clone());
            let mut misses = self.misses.lock().map_err(|e| Self::lock_err("Misses", e))?;
            *misses += 1;
        }
        Ok(embedding)
    }

    /// Embeds a batch of texts, serving cached entries and computing the
    /// remainder in a single call to the underlying embedder.
    ///
    /// The returned vector has one embedding per input, in the same order
    /// as `texts`.
    ///
    /// # Errors
    /// Returns `RagError::EmbeddingError` if a lock is poisoned, if the
    /// embedder fails, or if the embedder returns a different number of
    /// embeddings than were requested.
    pub fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        let mut results = Vec::with_capacity(texts.len());
        let mut to_compute = Vec::new();
        let mut to_compute_indices = Vec::new();
        // Partition inputs into cached hits and texts needing inference.
        {
            let mut cache = self.cache.lock().map_err(|e| Self::lock_err("Cache", e))?;
            let mut hits = self.hits.lock().map_err(|e| Self::lock_err("Hits", e))?;
            for (idx, &text) in texts.iter().enumerate() {
                if let Some(embedding) = cache.get(text) {
                    results.push((idx, embedding.clone()));
                    *hits += 1;
                } else {
                    to_compute.push(text);
                    to_compute_indices.push(idx);
                }
            }
        }
        if !to_compute.is_empty() {
            let computed = {
                let mut embedder = self.embedder.lock().map_err(|e| Self::lock_err("Embedder", e))?;
                embedder.embed_batch(&to_compute)?
            };
            // Guard against a misbehaving embedder: without this check the
            // `zip`s below would silently truncate and the caller would get
            // fewer embeddings than texts.
            if computed.len() != to_compute.len() {
                return Err(RagError::EmbeddingError(format!(
                    "Embedder returned {} embeddings for {} texts",
                    computed.len(),
                    to_compute.len()
                )));
            }
            {
                let mut cache = self.cache.lock().map_err(|e| Self::lock_err("Cache", e))?;
                let mut misses = self.misses.lock().map_err(|e| Self::lock_err("Misses", e))?;
                for (text, embedding) in to_compute.iter().zip(computed.iter()) {
                    cache.put(text.to_string(), embedding.clone());
                    *misses += 1;
                }
            }
            for (idx, embedding) in to_compute_indices.into_iter().zip(computed.into_iter()) {
                results.push((idx, embedding));
            }
        }
        // Restore the caller's original ordering.
        results.sort_by_key(|(idx, _)| *idx);
        Ok(results.into_iter().map(|(_, emb)| emb).collect())
    }

    /// Clears the cache and resets the hit/miss counters.
    /// Poisoned locks are skipped rather than treated as errors.
    pub fn clear_cache(&self) {
        if let Ok(mut cache) = self.cache.lock() {
            cache.clear();
        }
        if let Ok(mut hits) = self.hits.lock() {
            *hits = 0;
        }
        if let Ok(mut misses) = self.misses.lock() {
            *misses = 0;
        }
    }

    /// Current number of cached embeddings (0 if the lock is poisoned).
    pub fn cache_size(&self) -> usize {
        self.cache.lock()
            .map(|cache| cache.len())
            .unwrap_or(0)
    }

    /// Returns `(hits, misses)` accumulated since creation or the last
    /// `clear_cache` (a counter reads as 0 if its lock is poisoned).
    pub fn cache_stats(&self) -> (usize, usize) {
        let hits = self.hits.lock().map(|h| *h).unwrap_or(0);
        let misses = self.misses.lock().map(|m| *m).unwrap_or(0);
        (hits, misses)
    }

    /// Maximum cache capacity (0 if the lock is poisoned).
    pub fn max_cache_size(&self) -> usize {
        self.cache.lock()
            .map(|cache| cache.cap().get())
            .unwrap_or(0)
    }

    /// Fraction of lookups served from the cache, or `None` if no lookups
    /// have happened yet.
    pub fn hit_rate(&self) -> Option<f64> {
        let (hits, misses) = self.cache_stats();
        let total = hits + misses;
        if total == 0 {
            None
        } else {
            Some(hits as f64 / total as f64)
        }
    }
}
#[cfg(all(test, feature = "onnx"))]
mod tests {
    use super::*;

    /// Builds the ONNX embedder fixture shared by every test. Panics if the
    /// model files are missing — which is why these tests are `#[ignore]`d
    /// and must be run explicitly with the assets in place.
    fn make_embedder() -> OnnxEmbedder {
        OnnxEmbedder::new("models/model.onnx", "models/tokenizer.json")
            .expect("Failed to create embedder")
    }

    // A fresh cache starts empty with zeroed stats.
    #[test]
    #[ignore]
    fn test_cache_creation() {
        let cached = CachedEmbedder::new(make_embedder(), 42);
        assert_eq!(cached.max_cache_size(), 42);
        assert_eq!(cached.cache_size(), 0);
        assert_eq!(cached.cache_stats(), (0, 0));
        assert_eq!(cached.hit_rate(), None);
    }

    #[test]
    #[ignore]
    fn test_with_default_size() {
        let cached = CachedEmbedder::with_default_size(make_embedder());
        assert_eq!(cached.max_cache_size(), 10_000);
    }

    // A requested capacity of 0 is clamped to 1.
    #[test]
    #[ignore]
    fn test_zero_cache_size() {
        let cached = CachedEmbedder::new(make_embedder(), 0);
        assert_eq!(cached.max_cache_size(), 1);
    }

    // Embedding the same text twice hits the cache the second time and
    // returns an identical vector.
    #[test]
    #[ignore]
    fn test_cache_hit() {
        let cached = CachedEmbedder::new(make_embedder(), 100);
        let embedding1 = cached.embed("Hello, world!").unwrap();
        let (hits, misses) = cached.cache_stats();
        assert_eq!(hits, 0);
        assert_eq!(misses, 1);
        assert_eq!(cached.cache_size(), 1);
        let embedding2 = cached.embed("Hello, world!").unwrap();
        let (hits, misses) = cached.cache_stats();
        assert_eq!(hits, 1);
        assert_eq!(misses, 1);
        assert_eq!(cached.cache_size(), 1);
        assert_eq!(embedding1, embedding2);
        assert_eq!(cached.hit_rate(), Some(0.5));
    }

    // Distinct texts all miss and produce distinct embeddings.
    #[test]
    #[ignore]
    fn test_cache_miss() {
        let cached = CachedEmbedder::new(make_embedder(), 100);
        let emb1 = cached.embed("Hello").unwrap();
        let emb2 = cached.embed("World").unwrap();
        let emb3 = cached.embed("RAG").unwrap();
        let (hits, misses) = cached.cache_stats();
        assert_eq!(hits, 0);
        assert_eq!(misses, 3);
        assert_eq!(cached.cache_size(), 3);
        assert_eq!(cached.hit_rate(), Some(0.0));
        assert_ne!(emb1, emb2);
        assert_ne!(emb2, emb3);
    }

    // With capacity 3, inserting a 4th entry evicts the least recently
    // used one ("text1"), so re-embedding it is a miss again.
    #[test]
    #[ignore]
    fn test_cache_eviction() {
        let cached = CachedEmbedder::new(make_embedder(), 3);
        let _emb1 = cached.embed("text1").unwrap();
        let _emb2 = cached.embed("text2").unwrap();
        let _emb3 = cached.embed("text3").unwrap();
        assert_eq!(cached.cache_size(), 3);
        let _emb4 = cached.embed("text4").unwrap();
        assert_eq!(cached.cache_size(), 3);
        let _ = cached.embed("text2").unwrap();
        let _ = cached.embed("text3").unwrap();
        let (hits, misses) = cached.cache_stats();
        assert_eq!(hits, 2);
        assert_eq!(misses, 4);
        let _ = cached.embed("text1").unwrap();
        let (hits, misses) = cached.cache_stats();
        assert_eq!(hits, 2);
        assert_eq!(misses, 5);
    }

    // clear_cache empties the cache and zeroes the statistics.
    #[test]
    #[ignore]
    fn test_clear_cache() {
        let cached = CachedEmbedder::new(make_embedder(), 100);
        let _emb1 = cached.embed("Hello").unwrap();
        let _emb2 = cached.embed("World").unwrap();
        assert_eq!(cached.cache_size(), 2);
        assert_eq!(cached.cache_stats(), (0, 2));
        cached.clear_cache();
        assert_eq!(cached.cache_size(), 0);
        assert_eq!(cached.cache_stats(), (0, 0));
        assert_eq!(cached.hit_rate(), None);
        let _emb1_again = cached.embed("Hello").unwrap();
        let (hits, misses) = cached.cache_stats();
        assert_eq!(hits, 0);
        assert_eq!(misses, 1);
    }

    // hit_rate tracks hits / (hits + misses) as lookups accumulate.
    #[test]
    #[ignore]
    fn test_cache_statistics() {
        let cached = CachedEmbedder::new(make_embedder(), 100);
        assert_eq!(cached.cache_stats(), (0, 0));
        assert_eq!(cached.hit_rate(), None);
        let _emb = cached.embed("test").unwrap();
        assert_eq!(cached.cache_stats(), (0, 1));
        assert_eq!(cached.hit_rate(), Some(0.0));
        let _emb = cached.embed("test").unwrap();
        assert_eq!(cached.cache_stats(), (1, 1));
        assert_eq!(cached.hit_rate(), Some(0.5));
        let _emb = cached.embed("test").unwrap();
        assert_eq!(cached.cache_stats(), (2, 1));
        assert!((cached.hit_rate().unwrap() - 0.6666).abs() < 0.01);
    }

    // A repeated batch is served entirely from cache.
    #[test]
    #[ignore]
    fn test_batch_embed_all_cached() {
        let cached = CachedEmbedder::new(make_embedder(), 100);
        let texts = ["Hello", "World", "RAG"];
        let embeddings1 = cached.embed_batch(&texts).unwrap();
        assert_eq!(embeddings1.len(), 3);
        let (hits, misses) = cached.cache_stats();
        assert_eq!(hits, 0);
        assert_eq!(misses, 3);
        let embeddings2 = cached.embed_batch(&texts).unwrap();
        assert_eq!(embeddings2.len(), 3);
        let (hits, misses) = cached.cache_stats();
        assert_eq!(hits, 3);
        assert_eq!(misses, 3);
        assert_eq!(embeddings1, embeddings2);
    }

    // A batch with some cached and some new texts counts hits and misses
    // separately.
    #[test]
    #[ignore]
    fn test_batch_embed_mixed() {
        let cached = CachedEmbedder::new(make_embedder(), 100);
        let _emb1 = cached.embed("Hello").unwrap();
        let _emb2 = cached.embed("World").unwrap();
        let initial_stats = cached.cache_stats();
        assert_eq!(initial_stats, (0, 2));
        let texts = ["Hello", "World", "RAG", "System"];
        let embeddings = cached.embed_batch(&texts).unwrap();
        assert_eq!(embeddings.len(), 4);
        let (hits, misses) = cached.cache_stats();
        assert_eq!(hits, 2);
        assert_eq!(misses, 4);
    }

    // Batch output order matches the input order regardless of which
    // entries were cached beforehand.
    #[test]
    #[ignore]
    fn test_batch_embed_order_preserved() {
        let cached = CachedEmbedder::new(make_embedder(), 100);
        let emb_rag = cached.embed("RAG").unwrap();
        let emb_hello = cached.embed("Hello").unwrap();
        let texts = ["Hello", "World", "RAG"];
        let embeddings = cached.embed_batch(&texts).unwrap();
        assert_eq!(embeddings[0], emb_hello);
        assert_eq!(embeddings[2], emb_rag);
        let emb_world = cached.embed("World").unwrap();
        assert_eq!(embeddings[1], emb_world);
    }

    // Ten threads embedding overlapping texts: every call counts exactly
    // one hit or miss, and repeated texts produce some hits.
    #[test]
    #[ignore]
    fn test_concurrent_access() {
        use std::sync::Arc;
        use std::thread;
        let cached = Arc::new(CachedEmbedder::new(make_embedder(), 100));
        let mut handles = vec![];
        for i in 0..10 {
            let cached_clone = Arc::clone(&cached);
            let handle = thread::spawn(move || {
                for j in 0..10 {
                    let text = format!("text_{}", (i + j) % 5);
                    let _ = cached_clone.embed(&text).unwrap();
                }
            });
            handles.push(handle);
        }
        for handle in handles {
            handle.join().unwrap();
        }
        assert!(cached.cache_size() > 0);
        let (hits, misses) = cached.cache_stats();
        assert_eq!(hits + misses, 100);
        assert!(hits > 0, "Expected cache hits with repeated texts");
    }

    #[test]
    #[ignore]
    fn test_embedding_dimension() {
        let cached = CachedEmbedder::new(make_embedder(), 100);
        let embedding = cached.embed("Test text").unwrap();
        assert_eq!(embedding.len(), 384, "MiniLM-L6-v2 should produce 384-dim embeddings");
        let norm: f32 = embedding.iter().map(|&x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-5, "Embedding should be L2-normalized");
    }
}