mentedb_embedding/
manager.rs1use mentedb_core::MemoryNode;
4use mentedb_core::error::MenteResult;
5
6use crate::cache::EmbeddingCache;
7use crate::provider::EmbeddingProvider;
8
/// Point-in-time snapshot of an [`EmbeddingManager`]'s counters.
///
/// Produced by [`EmbeddingManager::stats`]. The type is plain data, so it
/// derives equality to let callers and tests compare snapshots directly.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct EmbeddingStats {
    /// Hit count as reported by the embedding cache.
    pub cache_hits: u64,
    /// Miss count as reported by the embedding cache.
    pub cache_misses: u64,
    /// Number of embeddings obtained from the provider (cache hits are not
    /// counted).
    pub total_embeddings: u64,
}
16
17pub struct EmbeddingManager {
19 provider: Box<dyn EmbeddingProvider>,
20 cache: EmbeddingCache,
21 total_embeddings: u64,
22}
23
24impl EmbeddingManager {
25 pub fn new(provider: Box<dyn EmbeddingProvider>, cache_size: usize) -> Self {
27 Self {
28 provider,
29 cache: EmbeddingCache::new(cache_size),
30 total_embeddings: 0,
31 }
32 }
33
34 pub fn embed(&mut self, text: &str) -> MenteResult<Vec<f32>> {
36 let model = self.provider.model_name().to_string();
37
38 if let Some(cached) = self.cache.get(text, &model) {
39 return Ok(cached.to_vec());
40 }
41
42 let embedding = self.provider.embed(text)?;
43 self.cache.put(text, &model, embedding.clone());
44 self.total_embeddings += 1;
45 Ok(embedding)
46 }
47
48 pub fn embed_batch(&mut self, texts: &[&str]) -> MenteResult<Vec<Vec<f32>>> {
50 let model = self.provider.model_name().to_string();
51
52 let mut results: Vec<Option<Vec<f32>>> = Vec::with_capacity(texts.len());
53 let mut missing_indices: Vec<usize> = Vec::new();
54 let mut missing_texts: Vec<&str> = Vec::new();
55
56 for (i, text) in texts.iter().enumerate() {
57 if let Some(cached) = self.cache.get(text, &model) {
58 results.push(Some(cached.to_vec()));
59 } else {
60 results.push(None);
61 missing_indices.push(i);
62 missing_texts.push(text);
63 }
64 }
65
66 if !missing_texts.is_empty() {
67 let missing_refs: Vec<&str> = missing_texts.to_vec();
68 let computed = self.provider.embed_batch(&missing_refs)?;
69
70 for (idx, embedding) in missing_indices.into_iter().zip(computed) {
71 self.cache.put(texts[idx], &model, embedding.clone());
72 self.total_embeddings += 1;
73 results[idx] = Some(embedding);
74 }
75 }
76
77 Ok(results.into_iter().map(|r| r.unwrap()).collect())
78 }
79
80 pub fn embed_memory(&mut self, node: &mut MemoryNode) -> MenteResult<()> {
82 let embedding = self.embed(&node.content)?;
83 node.embedding = embedding;
84 Ok(())
85 }
86
87 pub fn stats(&self) -> EmbeddingStats {
89 let cache_stats = self.cache.stats();
90 EmbeddingStats {
91 cache_hits: cache_stats.hits,
92 cache_misses: cache_stats.misses,
93 total_embeddings: self.total_embeddings,
94 }
95 }
96
97 pub fn dimensions(&self) -> usize {
99 self.provider.dimensions()
100 }
101
102 pub fn model_name(&self) -> &str {
104 self.provider.model_name()
105 }
106
107 pub fn clear_cache(&mut self) {
109 self.cache.clear();
110 }
111}