// mentedb_embedding/candle_provider.rs
use std::path::PathBuf;
8
9use candle_core::{Device, Tensor};
10use candle_nn::VarBuilder;
11use candle_transformers::models::bert::{BertModel, Config as BertConfig};
12use hf_hub::{Repo, RepoType, api::sync::Api};
13use mentedb_core::MenteError;
14use mentedb_core::error::MenteResult;
15use tokenizers::Tokenizer;
16
17use crate::provider::EmbeddingProvider;
18
/// Hugging Face model id used when the caller does not specify one.
const DEFAULT_MODEL_ID: &str = "sentence-transformers/all-MiniLM-L6-v2";
/// Local embedding provider backed by a Candle BERT model whose files are
/// fetched from the Hugging Face Hub (or a local cache) at construction time.
pub struct CandleEmbeddingProvider {
    // BERT weights used for the forward pass in `encode`.
    model: BertModel,
    // Tokenizer loaded from the repo's tokenizer.json; must match the model.
    tokenizer: Tokenizer,
    // Compute device; `load` currently always selects `Device::Cpu`.
    device: Device,
    // Embedding width, taken from the model config's `hidden_size`.
    dimensions: usize,
    // Model id string, reported back via `model_name()`.
    model_id: String,
}
34
35impl CandleEmbeddingProvider {
36 pub fn new() -> MenteResult<Self> {
40 Self::with_model(DEFAULT_MODEL_ID)
41 }
42
43 pub fn with_model(model_id: &str) -> MenteResult<Self> {
45 Self::load(model_id, None)
46 }
47
48 pub fn with_cache_dir(cache_dir: PathBuf) -> MenteResult<Self> {
50 Self::load(DEFAULT_MODEL_ID, Some(cache_dir))
51 }
52
53 fn load(model_id: &str, cache_dir: Option<PathBuf>) -> MenteResult<Self> {
54 let device = Device::Cpu;
55
56 let api = match cache_dir {
57 Some(dir) => {
58 let cache = hf_hub::Cache::new(dir);
59 hf_hub::api::sync::ApiBuilder::from_cache(cache)
60 .build()
61 .map_err(|e| {
62 MenteError::Storage(format!("Failed to create HF API with cache: {e}"))
63 })?
64 }
65 None => Api::new()
66 .map_err(|e| MenteError::Storage(format!("Failed to create HF API: {e}")))?,
67 };
68
69 let repo = api.repo(Repo::new(model_id.to_string(), RepoType::Model));
70
71 tracing::info!(model = model_id, "Loading local embedding model");
72
73 let config_path = repo
75 .get("config.json")
76 .map_err(|e| MenteError::Storage(format!("Failed to download config.json: {e}")))?;
77 let tokenizer_path = repo
78 .get("tokenizer.json")
79 .map_err(|e| MenteError::Storage(format!("Failed to download tokenizer.json: {e}")))?;
80 let weights_path = repo.get("model.safetensors").map_err(|e| {
81 MenteError::Storage(format!("Failed to download model.safetensors: {e}"))
82 })?;
83
84 let config_str = std::fs::read_to_string(&config_path)
86 .map_err(|e| MenteError::Storage(format!("Failed to read config: {e}")))?;
87 let config: BertConfig = serde_json::from_str(&config_str)
88 .map_err(|e| MenteError::Storage(format!("Failed to parse config: {e}")))?;
89
90 let tokenizer = Tokenizer::from_file(&tokenizer_path)
92 .map_err(|e| MenteError::Storage(format!("Failed to load tokenizer: {e}")))?;
93
94 let vb = unsafe {
96 VarBuilder::from_mmaped_safetensors(&[weights_path], candle_core::DType::F32, &device)
97 .map_err(|e| MenteError::Storage(format!("Failed to load weights: {e}")))?
98 };
99
100 let model = BertModel::load(vb, &config)
101 .map_err(|e| MenteError::Storage(format!("Failed to load model: {e}")))?;
102
103 let dimensions = config.hidden_size;
104
105 tracing::info!(
106 model = model_id,
107 dimensions = dimensions,
108 "Local embedding model loaded"
109 );
110
111 Ok(Self {
112 model,
113 tokenizer,
114 device,
115 dimensions,
116 model_id: model_id.to_string(),
117 })
118 }
119
120 fn encode(&self, texts: &[&str]) -> MenteResult<Vec<Vec<f32>>> {
122 if texts.is_empty() {
123 return Ok(Vec::new());
124 }
125
126 let encodings = self
127 .tokenizer
128 .encode_batch(texts.to_vec(), true)
129 .map_err(|e| MenteError::Storage(format!("Tokenization failed: {e}")))?;
130
131 let max_len = encodings
132 .iter()
133 .map(|e| e.get_ids().len())
134 .max()
135 .unwrap_or(0);
136
137 let mut all_input_ids: Vec<u32> = Vec::new();
138 let mut all_attention_mask: Vec<u32> = Vec::new();
139 let mut all_token_type_ids: Vec<u32> = Vec::new();
140
141 for encoding in &encodings {
142 let ids = encoding.get_ids();
143 let mask = encoding.get_attention_mask();
144 let type_ids = encoding.get_type_ids();
145
146 let pad_len = max_len - ids.len();
147
148 all_input_ids.extend_from_slice(ids);
149 all_input_ids.extend(std::iter::repeat_n(0u32, pad_len));
150
151 all_attention_mask.extend_from_slice(mask);
152 all_attention_mask.extend(std::iter::repeat_n(0u32, pad_len));
153
154 all_token_type_ids.extend_from_slice(type_ids);
155 all_token_type_ids.extend(std::iter::repeat_n(0u32, pad_len));
156 }
157
158 let batch_size = texts.len();
159
160 let input_ids = Tensor::from_vec(all_input_ids, (batch_size, max_len), &self.device)
161 .map_err(|e| MenteError::Storage(format!("Tensor creation failed: {e}")))?;
162
163 let attention_mask = Tensor::from_vec(
164 all_attention_mask.clone(),
165 (batch_size, max_len),
166 &self.device,
167 )
168 .map_err(|e| MenteError::Storage(format!("Tensor creation failed: {e}")))?;
169
170 let token_type_ids =
171 Tensor::from_vec(all_token_type_ids, (batch_size, max_len), &self.device)
172 .map_err(|e| MenteError::Storage(format!("Tensor creation failed: {e}")))?;
173
174 let output = self
176 .model
177 .forward(&input_ids, &token_type_ids, Some(&attention_mask))
178 .map_err(|e| MenteError::Storage(format!("Model forward pass failed: {e}")))?;
179
180 let mask_f32 = Tensor::from_vec(
182 all_attention_mask
183 .iter()
184 .map(|&v| v as f32)
185 .collect::<Vec<_>>(),
186 (batch_size, max_len),
187 &self.device,
188 )
189 .map_err(|e| MenteError::Storage(format!("Mask tensor failed: {e}")))?;
190
191 let mask_expanded = mask_f32
192 .unsqueeze(2)
193 .map_err(|e| MenteError::Storage(format!("Unsqueeze failed: {e}")))?
194 .broadcast_as(output.shape())
195 .map_err(|e| MenteError::Storage(format!("Broadcast failed: {e}")))?;
196
197 let masked = output
198 .mul(&mask_expanded)
199 .map_err(|e| MenteError::Storage(format!("Mul failed: {e}")))?;
200
201 let summed = masked
202 .sum(1)
203 .map_err(|e| MenteError::Storage(format!("Sum failed: {e}")))?;
204
205 let counts = mask_expanded
206 .sum(1)
207 .map_err(|e| MenteError::Storage(format!("Count sum failed: {e}")))?
208 .clamp(1e-9, f64::MAX)
209 .map_err(|e| MenteError::Storage(format!("Clamp failed: {e}")))?;
210
211 let mean_pooled = summed
212 .div(&counts)
213 .map_err(|e| MenteError::Storage(format!("Div failed: {e}")))?;
214
215 let norms = mean_pooled
217 .sqr()
218 .map_err(|e| MenteError::Storage(format!("Sqr failed: {e}")))?
219 .sum(1)
220 .map_err(|e| MenteError::Storage(format!("Norm sum failed: {e}")))?
221 .sqrt()
222 .map_err(|e| MenteError::Storage(format!("Sqrt failed: {e}")))?
223 .clamp(1e-12, f64::MAX)
224 .map_err(|e| MenteError::Storage(format!("Norm clamp failed: {e}")))?
225 .unsqueeze(1)
226 .map_err(|e| MenteError::Storage(format!("Norm unsqueeze failed: {e}")))?
227 .broadcast_as(mean_pooled.shape())
228 .map_err(|e| MenteError::Storage(format!("Norm broadcast failed: {e}")))?;
229
230 let normalized = mean_pooled
231 .div(&norms)
232 .map_err(|e| MenteError::Storage(format!("Normalize failed: {e}")))?;
233
234 let mut results = Vec::with_capacity(batch_size);
236 for i in 0..batch_size {
237 let emb = normalized
238 .get(i)
239 .map_err(|e| MenteError::Storage(format!("Get embedding failed: {e}")))?
240 .to_vec1::<f32>()
241 .map_err(|e| MenteError::Storage(format!("To vec failed: {e}")))?;
242 results.push(emb);
243 }
244
245 Ok(results)
246 }
247}
248
249impl EmbeddingProvider for CandleEmbeddingProvider {
250 fn embed(&self, text: &str) -> MenteResult<Vec<f32>> {
251 let results = self.encode(&[text])?;
252 results
253 .into_iter()
254 .next()
255 .ok_or_else(|| MenteError::Storage("Empty embedding result".to_string()))
256 }
257
258 fn embed_batch(&self, texts: &[&str]) -> MenteResult<Vec<Vec<f32>>> {
259 self.encode(texts)
260 }
261
262 fn dimensions(&self) -> usize {
263 self.dimensions
264 }
265
266 fn model_name(&self) -> &str {
267 &self.model_id
268 }
269}
270
#[cfg(test)]
mod tests {
    use super::*;

    // NOTE(review): constructing the provider fetches model files via hf_hub,
    // so these tests need network access or a pre-populated HF cache.

    #[test]
    fn test_candle_provider_loads() {
        let loaded = CandleEmbeddingProvider::new();
        assert!(
            loaded.is_ok(),
            "Failed to load model: {:?}",
            loaded.err()
        );
    }

    #[test]
    fn test_candle_embed_single() {
        let provider = CandleEmbeddingProvider::new().unwrap();
        let emb = provider.embed("hello world").unwrap();
        assert_eq!(emb.len(), provider.dimensions());

        // encode() L2-normalizes, so the vector norm must be ~1.
        let norm = emb.iter().fold(0.0f32, |acc, x| acc + x * x).sqrt();
        assert!((norm - 1.0).abs() < 1e-4, "Not normalized: {norm}");
    }

    #[test]
    fn test_candle_embed_batch() {
        let provider = CandleEmbeddingProvider::new().unwrap();
        let results = provider.embed_batch(&["hello", "world", "test"]).unwrap();
        assert_eq!(results.len(), 3);
        for emb in &results {
            assert_eq!(emb.len(), provider.dimensions());
        }
    }

    #[test]
    fn test_candle_semantic_similarity() {
        let provider = CandleEmbeddingProvider::new().unwrap();
        let e1 = provider.embed("PostgreSQL database").unwrap();
        let e2 = provider.embed("relational database system").unwrap();
        let e3 = provider.embed("chocolate cake recipe").unwrap();

        // Embeddings are unit vectors, so a dot product is cosine similarity.
        let dot = |a: &[f32], b: &[f32]| -> f32 { a.iter().zip(b).map(|(x, y)| x * y).sum() };
        let sim_related = dot(&e1, &e2);
        let sim_unrelated = dot(&e1, &e3);

        assert!(
            sim_related > sim_unrelated,
            "Related texts should be more similar: related={sim_related}, unrelated={sim_unrelated}"
        );
    }
}