vectoria-core 0.1.5

Embedded hybrid search engine core — BM25 + vector + behavioral signals
use super::EmbeddingProvider;
use anyhow::Result;
use async_trait::async_trait;
use foyer::{Cache, CacheBuilder, LruConfig};
use std::sync::{Arc, Mutex};

/// An [`EmbeddingProvider`] wrapper that keeps previously computed embeddings in a `foyer` cache.
///
/// NOTE(review): every cache call in this file goes through a shared (non-`mut`) mutex guard,
/// so the `Mutex` may be redundant if `foyer::Cache` is itself `Sync` — confirm before removing.
pub struct CachedEmbedding {
    /// Wrapped provider that computes an embedding when the cache has no entry for a text.
    inner: Arc<dyn EmbeddingProvider>,
    /// Embeddings keyed by the exact input text; each value is an `Arc`-wrapped vector.
    cache: Mutex<Cache<String, Arc<Vec<f32>>>>,
}

impl CachedEmbedding {
    /// Builds a caching layer in front of `inner`.
    ///
    /// `capacity` is handed to `CacheBuilder::new` unchanged, and the cache is configured for
    /// LRU eviction with `high_priority_pool_ratio` set to `0.1`.
    pub fn new(inner: Arc<dyn EmbeddingProvider>, capacity: usize) -> Self {
        let eviction = LruConfig {
            high_priority_pool_ratio: 0.1,
        };
        let store = CacheBuilder::new(capacity)
            .with_eviction_config(eviction)
            .build();

        Self {
            inner,
            cache: Mutex::new(store),
        }
    }
}

#[async_trait]
impl EmbeddingProvider for CachedEmbedding {
    async fn embed(&self, text: &str) -> Result<Vec<f32>> {
        {
            let cache = self.cache.lock().unwrap();
            if let Some(entry) = cache.get(text) {
                return Ok(entry.value().as_ref().clone());
            }
        }
        let vector = self.inner.embed(text).await?;
        {
            let cache = self.cache.lock().unwrap();
            cache.insert(text.to_string(), Arc::new(vector.clone()));
        }
        Ok(vector)
    }

    async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        let mut results = Vec::with_capacity(texts.len());
        let mut miss_indices: Vec<usize> = Vec::new();
        let mut miss_texts: Vec<&str> = Vec::new();

        {
            let cache = self.cache.lock().unwrap();
            for (i, text) in texts.iter().enumerate() {
                if let Some(entry) = cache.get(*text) {
                    results.push((i, entry.value().as_ref().clone()));
                } else {
                    miss_indices.push(i);
                    miss_texts.push(text);
                    results.push((i, vec![]));
                }
            }
        }

        if !miss_texts.is_empty() {
            let embedded = self.inner.embed_batch(&miss_texts).await?;
            let cache = self.cache.lock().unwrap();
            for (j, (orig_idx, vec)) in miss_indices.iter().zip(embedded.into_iter()).enumerate() {
                cache.insert(miss_texts[j].to_string(), Arc::new(vec.clone()));
                results[*orig_idx] = (*orig_idx, vec);
            }
        }

        results.sort_by_key(|(i, _)| *i);
        Ok(results.into_iter().map(|(_, v)| v).collect())
    }

    fn model_id(&self) -> &str { self.inner.model_id() }
    fn dims(&self) -> usize { self.inner.dims() }
}