// mentedb_embedding — candle_provider.rs

//! Local embedding provider using Candle (pure Rust ML framework).
//!
//! Downloads and caches a small transformer model (all-MiniLM-L6-v2) from
//! Hugging Face on first use. Generates 384-dimensional embeddings locally
//! with no API key required.

use std::path::PathBuf;

use candle_core::{Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::bert::{BertModel, Config as BertConfig};
use hf_hub::{Repo, RepoType, api::sync::Api};
use mentedb_core::MenteError;
use mentedb_core::error::MenteResult;
use tokenizers::Tokenizer;

use crate::provider::EmbeddingProvider;

/// Default model for local embeddings.
///
/// all-MiniLM-L6-v2 is a small sentence-transformer that produces
/// 384-dimensional embeddings (see the module docs); it is downloaded from
/// Hugging Face on first use.
const DEFAULT_MODEL_ID: &str = "sentence-transformers/all-MiniLM-L6-v2";

/// Local embedding provider powered by Candle.
///
/// Uses a small BERT-based model to generate embeddings entirely on CPU,
/// with no external API calls. The model is downloaded from Hugging Face
/// on first use and cached in the HF cache directory.
pub struct CandleEmbeddingProvider {
    /// Loaded BERT weights; drives the forward pass in `encode`.
    model: BertModel,
    /// Hugging Face tokenizer matching the model's vocabulary.
    tokenizer: Tokenizer,
    /// Compute device — always `Device::Cpu` (set in `load`).
    device: Device,
    /// Embedding width, taken from the model config's `hidden_size`.
    dimensions: usize,
    /// Hugging Face model ID, reported by `model_name`.
    model_id: String,
}

35impl CandleEmbeddingProvider {
36    /// Create a new provider with the default model (all-MiniLM-L6-v2).
37    ///
38    /// Downloads the model on first use. Subsequent calls load from cache.
39    pub fn new() -> MenteResult<Self> {
40        Self::with_model(DEFAULT_MODEL_ID)
41    }
42
43    /// Create a new provider with a specific Hugging Face model ID.
44    pub fn with_model(model_id: &str) -> MenteResult<Self> {
45        Self::load(model_id, None)
46    }
47
48    /// Create a new provider with a custom cache directory.
49    pub fn with_cache_dir(cache_dir: PathBuf) -> MenteResult<Self> {
50        Self::load(DEFAULT_MODEL_ID, Some(cache_dir))
51    }
52
53    fn load(model_id: &str, cache_dir: Option<PathBuf>) -> MenteResult<Self> {
54        let device = Device::Cpu;
55
56        let api = match cache_dir {
57            Some(dir) => {
58                let cache = hf_hub::Cache::new(dir);
59                hf_hub::api::sync::ApiBuilder::from_cache(cache)
60                    .build()
61                    .map_err(|e| {
62                        MenteError::Storage(format!("Failed to create HF API with cache: {e}"))
63                    })?
64            }
65            None => Api::new()
66                .map_err(|e| MenteError::Storage(format!("Failed to create HF API: {e}")))?,
67        };
68
69        let repo = api.repo(Repo::new(model_id.to_string(), RepoType::Model));
70
71        tracing::info!(model = model_id, "Loading local embedding model");
72
73        // Download model files
74        let config_path = repo
75            .get("config.json")
76            .map_err(|e| MenteError::Storage(format!("Failed to download config.json: {e}")))?;
77        let tokenizer_path = repo
78            .get("tokenizer.json")
79            .map_err(|e| MenteError::Storage(format!("Failed to download tokenizer.json: {e}")))?;
80        let weights_path = repo.get("model.safetensors").map_err(|e| {
81            MenteError::Storage(format!("Failed to download model.safetensors: {e}"))
82        })?;
83
84        // Load config
85        let config_str = std::fs::read_to_string(&config_path)
86            .map_err(|e| MenteError::Storage(format!("Failed to read config: {e}")))?;
87        let config: BertConfig = serde_json::from_str(&config_str)
88            .map_err(|e| MenteError::Storage(format!("Failed to parse config: {e}")))?;
89
90        // Load tokenizer
91        let tokenizer = Tokenizer::from_file(&tokenizer_path)
92            .map_err(|e| MenteError::Storage(format!("Failed to load tokenizer: {e}")))?;
93
94        // Load model weights
95        let vb = unsafe {
96            VarBuilder::from_mmaped_safetensors(&[weights_path], candle_core::DType::F32, &device)
97                .map_err(|e| MenteError::Storage(format!("Failed to load weights: {e}")))?
98        };
99
100        let model = BertModel::load(vb, &config)
101            .map_err(|e| MenteError::Storage(format!("Failed to load model: {e}")))?;
102
103        let dimensions = config.hidden_size;
104
105        tracing::info!(
106            model = model_id,
107            dimensions = dimensions,
108            "Local embedding model loaded"
109        );
110
111        Ok(Self {
112            model,
113            tokenizer,
114            device,
115            dimensions,
116            model_id: model_id.to_string(),
117        })
118    }
119
120    /// Encode texts into embeddings using mean pooling over token embeddings.
121    fn encode(&self, texts: &[&str]) -> MenteResult<Vec<Vec<f32>>> {
122        if texts.is_empty() {
123            return Ok(Vec::new());
124        }
125
126        let encodings = self
127            .tokenizer
128            .encode_batch(texts.to_vec(), true)
129            .map_err(|e| MenteError::Storage(format!("Tokenization failed: {e}")))?;
130
131        let max_len = encodings
132            .iter()
133            .map(|e| e.get_ids().len())
134            .max()
135            .unwrap_or(0);
136
137        let mut all_input_ids: Vec<u32> = Vec::new();
138        let mut all_attention_mask: Vec<u32> = Vec::new();
139        let mut all_token_type_ids: Vec<u32> = Vec::new();
140
141        for encoding in &encodings {
142            let ids = encoding.get_ids();
143            let mask = encoding.get_attention_mask();
144            let type_ids = encoding.get_type_ids();
145
146            let pad_len = max_len - ids.len();
147
148            all_input_ids.extend_from_slice(ids);
149            all_input_ids.extend(std::iter::repeat_n(0u32, pad_len));
150
151            all_attention_mask.extend_from_slice(mask);
152            all_attention_mask.extend(std::iter::repeat_n(0u32, pad_len));
153
154            all_token_type_ids.extend_from_slice(type_ids);
155            all_token_type_ids.extend(std::iter::repeat_n(0u32, pad_len));
156        }
157
158        let batch_size = texts.len();
159
160        let input_ids = Tensor::from_vec(all_input_ids, (batch_size, max_len), &self.device)
161            .map_err(|e| MenteError::Storage(format!("Tensor creation failed: {e}")))?;
162
163        let attention_mask = Tensor::from_vec(
164            all_attention_mask.clone(),
165            (batch_size, max_len),
166            &self.device,
167        )
168        .map_err(|e| MenteError::Storage(format!("Tensor creation failed: {e}")))?;
169
170        let token_type_ids =
171            Tensor::from_vec(all_token_type_ids, (batch_size, max_len), &self.device)
172                .map_err(|e| MenteError::Storage(format!("Tensor creation failed: {e}")))?;
173
174        // Forward pass
175        let output = self
176            .model
177            .forward(&input_ids, &token_type_ids, Some(&attention_mask))
178            .map_err(|e| MenteError::Storage(format!("Model forward pass failed: {e}")))?;
179
180        // Mean pooling: average token embeddings, masked by attention
181        let mask_f32 = Tensor::from_vec(
182            all_attention_mask
183                .iter()
184                .map(|&v| v as f32)
185                .collect::<Vec<_>>(),
186            (batch_size, max_len),
187            &self.device,
188        )
189        .map_err(|e| MenteError::Storage(format!("Mask tensor failed: {e}")))?;
190
191        let mask_expanded = mask_f32
192            .unsqueeze(2)
193            .map_err(|e| MenteError::Storage(format!("Unsqueeze failed: {e}")))?
194            .broadcast_as(output.shape())
195            .map_err(|e| MenteError::Storage(format!("Broadcast failed: {e}")))?;
196
197        let masked = output
198            .mul(&mask_expanded)
199            .map_err(|e| MenteError::Storage(format!("Mul failed: {e}")))?;
200
201        let summed = masked
202            .sum(1)
203            .map_err(|e| MenteError::Storage(format!("Sum failed: {e}")))?;
204
205        let counts = mask_expanded
206            .sum(1)
207            .map_err(|e| MenteError::Storage(format!("Count sum failed: {e}")))?
208            .clamp(1e-9, f64::MAX)
209            .map_err(|e| MenteError::Storage(format!("Clamp failed: {e}")))?;
210
211        let mean_pooled = summed
212            .div(&counts)
213            .map_err(|e| MenteError::Storage(format!("Div failed: {e}")))?;
214
215        // L2 normalize each embedding
216        let norms = mean_pooled
217            .sqr()
218            .map_err(|e| MenteError::Storage(format!("Sqr failed: {e}")))?
219            .sum(1)
220            .map_err(|e| MenteError::Storage(format!("Norm sum failed: {e}")))?
221            .sqrt()
222            .map_err(|e| MenteError::Storage(format!("Sqrt failed: {e}")))?
223            .clamp(1e-12, f64::MAX)
224            .map_err(|e| MenteError::Storage(format!("Norm clamp failed: {e}")))?
225            .unsqueeze(1)
226            .map_err(|e| MenteError::Storage(format!("Norm unsqueeze failed: {e}")))?
227            .broadcast_as(mean_pooled.shape())
228            .map_err(|e| MenteError::Storage(format!("Norm broadcast failed: {e}")))?;
229
230        let normalized = mean_pooled
231            .div(&norms)
232            .map_err(|e| MenteError::Storage(format!("Normalize failed: {e}")))?;
233
234        // Extract embeddings
235        let mut results = Vec::with_capacity(batch_size);
236        for i in 0..batch_size {
237            let emb = normalized
238                .get(i)
239                .map_err(|e| MenteError::Storage(format!("Get embedding failed: {e}")))?
240                .to_vec1::<f32>()
241                .map_err(|e| MenteError::Storage(format!("To vec failed: {e}")))?;
242            results.push(emb);
243        }
244
245        Ok(results)
246    }
247}
248
249impl EmbeddingProvider for CandleEmbeddingProvider {
250    fn embed(&self, text: &str) -> MenteResult<Vec<f32>> {
251        let results = self.encode(&[text])?;
252        results
253            .into_iter()
254            .next()
255            .ok_or_else(|| MenteError::Storage("Empty embedding result".to_string()))
256    }
257
258    fn embed_batch(&self, texts: &[&str]) -> MenteResult<Vec<Vec<f32>>> {
259        self.encode(texts)
260    }
261
262    fn dimensions(&self) -> usize {
263        self.dimensions
264    }
265
266    fn model_name(&self) -> &str {
267        &self.model_id
268    }
269}
270
#[cfg(test)]
mod tests {
    use super::*;

    // NOTE: these tests download the model from Hugging Face on first run
    // and therefore need network access (cached afterwards).

    #[test]
    fn test_candle_provider_loads() {
        let result = CandleEmbeddingProvider::new();
        assert!(result.is_ok(), "Failed to load model: {:?}", result.err());
    }

    #[test]
    fn test_candle_embed_single() {
        let provider = CandleEmbeddingProvider::new().unwrap();
        let embedding = provider.embed("hello world").unwrap();
        assert_eq!(embedding.len(), provider.dimensions());

        // Embeddings are L2-normalized, so the norm should be ~1.
        let norm = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-4, "Not normalized: {norm}");
    }

    #[test]
    fn test_candle_embed_batch() {
        let provider = CandleEmbeddingProvider::new().unwrap();
        let batch = provider.embed_batch(&["hello", "world", "test"]).unwrap();
        assert_eq!(batch.len(), 3);
        assert!(batch.iter().all(|e| e.len() == provider.dimensions()));
    }

    #[test]
    fn test_candle_semantic_similarity() {
        let provider = CandleEmbeddingProvider::new().unwrap();
        let db = provider.embed("PostgreSQL database").unwrap();
        let related = provider.embed("relational database system").unwrap();
        let unrelated = provider.embed("chocolate cake recipe").unwrap();

        // Vectors are unit-length, so a dot product is cosine similarity.
        let dot = |a: &[f32], b: &[f32]| a.iter().zip(b).map(|(x, y)| x * y).sum::<f32>();
        let sim_related = dot(&db, &related);
        let sim_unrelated = dot(&db, &unrelated);

        assert!(
            sim_related > sim_unrelated,
            "Related texts should be more similar: related={sim_related}, unrelated={sim_unrelated}"
        );
    }
}