// mentedb_embedding/manager.rs
1//! Embedding manager that wraps a provider with caching and statistics.
2
3use mentedb_core::MemoryNode;
4use mentedb_core::error::MenteResult;
5
6use crate::cache::EmbeddingCache;
7use crate::provider::EmbeddingProvider;
8
/// Statistics about embedding operations.
#[derive(Debug, Clone, Default)]
pub struct EmbeddingStats {
    /// Embed requests answered from the cache (mirrors the cache's own hit counter).
    pub cache_hits: u64,
    /// Embed requests not found in the cache (mirrors the cache's own miss counter).
    pub cache_misses: u64,
    /// Embeddings actually computed by the provider; cache hits are not counted.
    pub total_embeddings: u64,
}
16
/// Manages embedding generation with caching.
pub struct EmbeddingManager {
    /// Backend that actually computes embeddings.
    provider: Box<dyn EmbeddingProvider>,
    /// Cache of previously computed vectors, keyed by (text, model name).
    cache: EmbeddingCache,
    /// Embeddings computed by the provider so far; cache hits are not counted.
    total_embeddings: u64,
}
23
24impl EmbeddingManager {
25    /// Create a new embedding manager with the given provider and cache size.
26    pub fn new(provider: Box<dyn EmbeddingProvider>, cache_size: usize) -> Self {
27        Self {
28            provider,
29            cache: EmbeddingCache::new(cache_size),
30            total_embeddings: 0,
31        }
32    }
33
34    /// Generate an embedding for the given text, using the cache when possible.
35    pub fn embed(&mut self, text: &str) -> MenteResult<Vec<f32>> {
36        let model = self.provider.model_name().to_string();
37
38        if let Some(cached) = self.cache.get(text, &model) {
39            return Ok(cached.to_vec());
40        }
41
42        let embedding = self.provider.embed(text)?;
43        self.cache.put(text, &model, embedding.clone());
44        self.total_embeddings += 1;
45        Ok(embedding)
46    }
47
48    /// Generate embeddings for a batch of texts, using the cache for already-computed ones.
49    pub fn embed_batch(&mut self, texts: &[&str]) -> MenteResult<Vec<Vec<f32>>> {
50        let model = self.provider.model_name().to_string();
51
52        let mut results: Vec<Option<Vec<f32>>> = Vec::with_capacity(texts.len());
53        let mut missing_indices: Vec<usize> = Vec::new();
54        let mut missing_texts: Vec<&str> = Vec::new();
55
56        for (i, text) in texts.iter().enumerate() {
57            if let Some(cached) = self.cache.get(text, &model) {
58                results.push(Some(cached.to_vec()));
59            } else {
60                results.push(None);
61                missing_indices.push(i);
62                missing_texts.push(text);
63            }
64        }
65
66        if !missing_texts.is_empty() {
67            let missing_refs: Vec<&str> = missing_texts.to_vec();
68            let computed = self.provider.embed_batch(&missing_refs)?;
69
70            for (idx, embedding) in missing_indices.into_iter().zip(computed) {
71                self.cache.put(texts[idx], &model, embedding.clone());
72                self.total_embeddings += 1;
73                results[idx] = Some(embedding);
74            }
75        }
76
77        Ok(results.into_iter().map(|r| r.unwrap()).collect())
78    }
79
80    /// Embed a memory node's content and set its embedding field.
81    pub fn embed_memory(&mut self, node: &mut MemoryNode) -> MenteResult<()> {
82        let embedding = self.embed(&node.content)?;
83        node.embedding = embedding;
84        Ok(())
85    }
86
87    /// Get statistics about embedding operations.
88    pub fn stats(&self) -> EmbeddingStats {
89        let cache_stats = self.cache.stats();
90        EmbeddingStats {
91            cache_hits: cache_stats.hits,
92            cache_misses: cache_stats.misses,
93            total_embeddings: self.total_embeddings,
94        }
95    }
96
97    /// Get the dimensionality of the underlying provider.
98    pub fn dimensions(&self) -> usize {
99        self.provider.dimensions()
100    }
101
102    /// Get the model name of the underlying provider.
103    pub fn model_name(&self) -> &str {
104        self.provider.model_name()
105    }
106
107    /// Clear the embedding cache.
108    pub fn clear_cache(&mut self) {
109        self.cache.clear();
110    }
111}