use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use async_trait::async_trait;
use kovan_map::HashMap;
use crate::embedding::Embedder;
use crate::error::Result;
/// An [`Embedder`] decorator that memoises embeddings, keyed by the input text.
///
/// Keys are a 64-bit hash of the text, so two distinct texts whose hashes
/// collide would share one cached embedding. The cache is bounded coarsely:
/// once the tracked entry count exceeds `max_size`, the whole cache is flushed
/// instead of evicting individual entries.
pub struct CachedEmbedder {
    /// Embedder consulted on cache misses.
    inner: Arc<dyn Embedder>,
    /// Entry count above which the cache is flushed.
    max_size: usize,
    /// Memoised embeddings, keyed by text hash.
    cache: HashMap<u64, Vec<f32>>,
    /// Number of inserted keys, tracked alongside `cache` (approximate under concurrency).
    size: AtomicUsize,
}
impl CachedEmbedder {
    /// Cache capacity used by [`CachedEmbedder::new`].
    pub const DEFAULT_MAX_SIZE: usize = 10_000;

    /// Wraps `inner` with a cache of [`Self::DEFAULT_MAX_SIZE`] entries.
    pub fn new(inner: Arc<dyn Embedder>) -> Self {
        Self::with_max_size(inner, Self::DEFAULT_MAX_SIZE)
    }

    /// Wraps `inner` with a cache that is flushed once it holds more than
    /// `max_size` entries.
    pub fn with_max_size(inner: Arc<dyn Embedder>, max_size: usize) -> Self {
        Self {
            inner,
            cache: HashMap::new(),
            size: AtomicUsize::new(0),
            max_size,
        }
    }

    /// Hashes `text` into a cache key.
    ///
    /// The seeds are fixed, so the same text always maps to the same key.
    fn hash_text(text: &str) -> u64 {
        ahash::RandomState::with_seeds(0, 0, 0, 0).hash_one(text)
    }

    /// Flushes the whole cache once the tracked entry count exceeds `max_size`.
    fn maybe_evict(&self) {
        if self.size.load(Ordering::Relaxed) > self.max_size {
            // Reset the counter *before* clearing. An insert racing with the
            // flush can then only be over-counted, which just triggers an
            // earlier flush. Clearing first could under-count an entry that
            // survives the clear, letting the map outgrow `max_size` unnoticed.
            self.size.store(0, Ordering::Relaxed);
            self.cache.clear();
        }
    }
}
#[async_trait]
impl Embedder for CachedEmbedder {
    /// Returns the embedding for `text`, serving it from the cache when present.
    ///
    /// On a miss the inner embedder is called and its result is cached before
    /// being returned.
    ///
    /// # Errors
    /// Propagates any error from the inner embedder; nothing is cached then.
    async fn embed(&self, text: &str) -> Result<Vec<f32>> {
        let hash = Self::hash_text(text);
        if let Some(cached) = self.cache.get(&hash) {
            return Ok(cached);
        }
        let embedding = self.inner.embed(text).await?;
        // Only count keys that were not already present (another caller may
        // have inserted the same key since the lookup above).
        if self.cache.insert(hash, embedding.clone()).is_none() {
            self.size.fetch_add(1, Ordering::Relaxed);
        }
        self.maybe_evict();
        Ok(embedding)
    }

    /// Embeds every text in `texts`, returning vectors in the same order.
    ///
    /// Cache hits are served directly. Distinct cache misses are sent to the
    /// inner embedder in one `embed_batch` call, so a text repeated within the
    /// batch is embedded only once. The results are cached.
    ///
    /// If the inner embedder returns fewer vectors than requested, the texts it
    /// skipped are embedded one at a time via [`Embedder::embed`] instead of
    /// panicking.
    ///
    /// # Errors
    /// Propagates any error from the inner embedder.
    async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        // Hash each text once. The hash serves as both cache key and dedup key.
        let hashes: Vec<u64> = texts.iter().map(|text| Self::hash_text(text)).collect();
        let cached: Vec<Option<Vec<f32>>> =
            hashes.iter().map(|hash| self.cache.get(hash)).collect();

        // Distinct cache misses, in first-occurrence order.
        let mut seen = std::collections::HashSet::new();
        let mut miss_hashes = Vec::new();
        let mut miss_texts = Vec::new();
        for ((text, &hash), hit) in texts.iter().zip(&hashes).zip(&cached) {
            if hit.is_none() && seen.insert(hash) {
                miss_hashes.push(hash);
                miss_texts.push(text.clone());
            }
        }

        // Embeddings computed by this call, so duplicates within the batch can
        // be served without relying on the (concurrently flushable) cache.
        let mut fresh = std::collections::HashMap::with_capacity(miss_hashes.len());
        if !miss_texts.is_empty() {
            let embeddings = self.inner.embed_batch(&miss_texts).await?;
            // `zip` stops early if the inner embedder returned too few vectors;
            // those texts are handled by the fallback below.
            for (hash, embedding) in miss_hashes.into_iter().zip(embeddings) {
                if self.cache.insert(hash, embedding.clone()).is_none() {
                    self.size.fetch_add(1, Ordering::Relaxed);
                }
                fresh.insert(hash, embedding);
            }
        }

        let mut results = Vec::with_capacity(texts.len());
        for ((text, hash), hit) in texts.iter().zip(&hashes).zip(cached) {
            let embedding = match hit.or_else(|| fresh.get(hash).cloned()) {
                Some(embedding) => embedding,
                // The inner batch call returned no vector for this text; embed
                // it on its own rather than panicking.
                None => self.embed(text).await?,
            };
            results.push(embedding);
        }
        self.maybe_evict();
        Ok(results)
    }

    /// Dimensionality of the vectors, as reported by the inner embedder.
    fn dimensions(&self) -> usize {
        self.inner.dimensions()
    }

    /// Model name reported by the inner embedder.
    fn model_name(&self) -> &str {
        self.inner.model_name()
    }
}