1use std::path::PathBuf;
22
23use ort::session::builder::GraphOptimizationLevel;
24use ort::session::Session;
25use ort::value::Tensor;
26use tokenizers::Tokenizer;
27
28use crate::semantic_index::{format_embedding_init_error, pre_validate_onnx_runtime};
29use crate::slog_info;
30
/// Hugging Face Hub repository that hosts the ONNX export of all-MiniLM-L6-v2.
const MINILM_REPO: &str = "Qdrant/all-MiniLM-L6-v2-onnx";
/// Model file name inside a repository snapshot.
const MINILM_MODEL_FILE: &str = "model.onnx";
/// Tokenizer definition file name inside a repository snapshot.
const MINILM_TOKENIZER_FILE: &str = "tokenizer.json";
/// Maximum token count configured as the tokenizer's truncation length; longer inputs are cut.
const MINILM_MAX_LENGTH: usize = 512;
/// Budget for `rows * longest_row_tokens^2` in a single inference call. `embed` splits its
/// input into consecutive micro-batches that stay within this budget (a lone text may exceed it).
const MAX_BATCH_ATTENTION_UNITS: usize = 4_000_000;
53
/// Number of intra-op threads to give the ONNX session.
///
/// Uses half of the parallelism reported by the OS, rounded up, and never fewer than one.
/// If the parallelism query fails, a single thread is assumed.
fn intra_thread_cap() -> usize {
    let cores = match std::thread::available_parallelism() {
        Ok(count) => count.get(),
        Err(_) => 1,
    };
    let half_rounded_up = cores.div_ceil(2);
    half_rounded_up.max(1)
}
63
/// Local sentence-embedding backend that runs all-MiniLM-L6-v2 through ONNX Runtime.
pub struct LocalEmbedder {
    /// ONNX Runtime session holding the loaded model graph.
    session: Session,
    /// Tokenizer loaded from the snapshot's tokenizer.json (truncation is configured in `new`).
    tokenizer: Tokenizer,
    /// Whether the model graph declares a `token_type_ids` input; if so it is fed all zeros.
    wants_token_type_ids: bool,
}
69
70impl LocalEmbedder {
71 pub fn new(model: &str) -> Result<Self, String> {
74 match model {
75 "all-MiniLM-L6-v2" | "all-minilm-l6-v2" => {}
76 other => {
77 return Err(format!(
78 "unsupported local embedding model '{other}'. Supported: all-MiniLM-L6-v2"
79 ))
80 }
81 }
82
83 pre_validate_onnx_runtime()?;
86
87 let (model_path, tokenizer_path) = resolve_model_files()?;
88
89 let threads = intra_thread_cap();
90 let session = Session::builder()
91 .map_err(|e| format!("failed to create ONNX session builder: {e}"))?
92 .with_optimization_level(GraphOptimizationLevel::Level3)
93 .map_err(|e| format!("failed to set ONNX optimization level: {e}"))?
94 .with_intra_threads(threads)
95 .map_err(|e| format!("failed to set ONNX intra-op threads: {e}"))?
96 .commit_from_file(&model_path)
97 .map_err(format_embedding_init_error)?;
101
102 let mut tokenizer = Tokenizer::from_file(&tokenizer_path)
103 .map_err(|e| format!("failed to load tokenizer {}: {e}", tokenizer_path.display()))?;
104 tokenizer
107 .with_truncation(Some(tokenizers::TruncationParams {
108 max_length: MINILM_MAX_LENGTH,
109 ..Default::default()
110 }))
111 .map_err(|e| format!("failed to set tokenizer truncation: {e}"))?;
112
113 let wants_token_type_ids = session
114 .inputs()
115 .iter()
116 .any(|input| input.name() == "token_type_ids");
117
118 slog_info!(
119 "local embedder ready: model=all-MiniLM-L6-v2 intra_threads={} token_type_ids={}",
120 threads,
121 wants_token_type_ids
122 );
123
124 Ok(Self {
125 session,
126 tokenizer,
127 wants_token_type_ids,
128 })
129 }
130
131 pub fn embed(&mut self, texts: &[String]) -> Result<Vec<Vec<f32>>, String> {
142 if texts.is_empty() {
143 return Ok(Vec::new());
144 }
145
146 let text_refs: Vec<&str> = texts.iter().map(String::as_str).collect();
147 let encodings = self
148 .tokenizer
149 .encode_batch(text_refs, true)
150 .map_err(|e| format!("tokenize batch: {e}"))?;
151
152 let mut result = Vec::with_capacity(encodings.len());
156 let mut batch_start = 0usize;
157 let mut batch_max = 0usize;
158 for (i, enc) in encodings.iter().enumerate() {
159 let len = enc.get_ids().len().max(1);
160 let count = i - batch_start; let candidate_max = batch_max.max(len);
162 let cost = (count + 1)
163 .saturating_mul(candidate_max)
164 .saturating_mul(candidate_max);
165 if count > 0 && cost > MAX_BATCH_ATTENTION_UNITS {
166 let vecs = self.run_inference(&encodings[batch_start..i])?;
167 result.extend(vecs);
168 batch_start = i;
169 batch_max = len;
170 } else {
171 batch_max = candidate_max;
172 }
173 }
174 let vecs = self.run_inference(&encodings[batch_start..])?;
176 result.extend(vecs);
177 Ok(result)
178 }
179
180 fn run_inference(
185 &mut self,
186 encodings: &[tokenizers::Encoding],
187 ) -> Result<Vec<Vec<f32>>, String> {
188 if encodings.is_empty() {
189 return Ok(Vec::new());
190 }
191
192 let batch = encodings.len();
193 let max_len = encodings
194 .iter()
195 .map(|e| e.get_ids().len())
196 .max()
197 .unwrap_or(1)
198 .max(1);
199
200 let mut ids = vec![0i64; batch * max_len];
204 let mut mask = vec![0i64; batch * max_len];
205 for (row, enc) in encodings.iter().enumerate() {
206 let row_ids = enc.get_ids();
207 let row_mask = enc.get_attention_mask();
208 let base = row * max_len;
209 for col in 0..row_ids.len() {
210 ids[base + col] = row_ids[col] as i64;
211 mask[base + col] = row_mask[col] as i64;
212 }
213 }
214
215 let input_ids = ndarray::Array2::<i64>::from_shape_vec((batch, max_len), ids)
216 .map_err(|e| format!("build input_ids tensor: {e}"))?;
217 let attention_mask = ndarray::Array2::<i64>::from_shape_vec((batch, max_len), mask)
218 .map_err(|e| format!("build attention_mask tensor: {e}"))?;
219
220 let mut inputs = ort::inputs![
221 "input_ids" => Tensor::from_array(input_ids).map_err(|e| format!("input_ids: {e}"))?,
222 "attention_mask" => Tensor::from_array(attention_mask.clone())
223 .map_err(|e| format!("attention_mask: {e}"))?,
224 ];
225 if self.wants_token_type_ids {
226 let token_type_ids = ndarray::Array2::<i64>::zeros((batch, max_len));
227 inputs.push((
228 "token_type_ids".into(),
229 Tensor::from_array(token_type_ids)
230 .map_err(|e| format!("token_type_ids: {e}"))?
231 .into(),
232 ));
233 }
234
235 let outputs = self
236 .session
237 .run(inputs)
238 .map_err(|e| format!("ONNX inference failed: {e}"))?;
239 let output = outputs
240 .values()
241 .next()
242 .ok_or_else(|| "ONNX model produced no output".to_string())?;
243
244 let (shape, data): (Vec<i64>, Vec<f32>) = match output.try_extract_tensor::<f32>() {
246 Ok((s, d)) => (s.to_vec(), d.to_vec()),
247 Err(_) => {
248 let (s, d) = output
249 .try_extract_tensor::<half::f16>()
250 .map_err(|e| format!("extract output tensor: {e}"))?;
251 (s.to_vec(), d.iter().map(|h| h.to_f32()).collect())
252 }
253 };
254 if shape.len() != 3 {
255 return Err(format!(
256 "unexpected ONNX output rank {} (expected 3: [batch, seq, dim])",
257 shape.len()
258 ));
259 }
260 let seq = shape[1] as usize;
261 let dim = shape[2] as usize;
262
263 let mut result = Vec::with_capacity(batch);
264 for row in 0..batch {
265 let mut emb = vec![0.0f32; dim];
266 let mut valid = 0.0f32;
267 for col in 0..seq {
268 if mask_at(&attention_mask, row, col) == 1 {
269 valid += 1.0;
270 let base = (row * seq + col) * dim;
271 for (d, slot) in emb.iter_mut().enumerate() {
272 *slot += data[base + d];
273 }
274 }
275 }
276 let denom = if valid == 0.0 { 1.0 } else { valid };
277 for slot in &mut emb {
278 *slot /= denom;
279 }
280 let norm = emb.iter().map(|x| x * x).sum::<f32>().sqrt();
281 for slot in &mut emb {
282 *slot /= norm + 1e-12;
283 }
284 result.push(emb);
285 }
286 Ok(result)
287 }
288}
289
290#[inline]
291fn mask_at(mask: &ndarray::Array2<i64>, row: usize, col: usize) -> i64 {
292 mask[[row, col]]
293}
294
295fn resolve_model_files() -> Result<(PathBuf, PathBuf), String> {
298 let cache_dir = embedding_cache_dir();
299
300 if let Some(found) = scan_local_snapshot(&cache_dir) {
301 return Ok(found);
302 }
303
304 download_via_hf_hub(&cache_dir)
307}
308
/// Returns the directory used to cache downloaded model files.
///
/// Resolution order:
/// 1. `$FASTEMBED_CACHE_DIR`, used verbatim;
/// 2. `<home>/.cache/fastembed`, where `<home>` is `$HOME`, then `$USERPROFILE`, then the
///    system temp directory.
///
/// Variables that are set but empty are treated as unset. Otherwise an empty value would
/// yield a relative path and the cache would silently land in the current working directory.
fn embedding_cache_dir() -> PathBuf {
    // Reads an environment variable, ignoring it when it is unset or empty.
    fn non_empty_var(key: &str) -> Option<std::ffi::OsString> {
        std::env::var_os(key).filter(|value| !value.is_empty())
    }

    if let Some(dir) = non_empty_var("FASTEMBED_CACHE_DIR") {
        return PathBuf::from(dir);
    }
    let home = non_empty_var("HOME")
        .or_else(|| non_empty_var("USERPROFILE"))
        .map(PathBuf::from)
        .unwrap_or_else(std::env::temp_dir);
    home.join(".cache").join("fastembed")
}
322
323fn scan_local_snapshot(cache_dir: &std::path::Path) -> Option<(PathBuf, PathBuf)> {
326 let repo_dir = cache_dir.join("models--Qdrant--all-MiniLM-L6-v2-onnx");
327 let snapshots = repo_dir.join("snapshots");
328 let mut candidates: Vec<PathBuf> = std::fs::read_dir(&snapshots)
329 .ok()?
330 .filter_map(|entry| entry.ok().map(|e| e.path()))
331 .filter(|p| p.is_dir())
332 .collect();
333 candidates.sort_by_key(|p| {
335 std::fs::metadata(p)
336 .and_then(|m| m.modified())
337 .unwrap_or(std::time::UNIX_EPOCH)
338 });
339 candidates.reverse();
340 for snap in candidates {
341 let model = snap.join(MINILM_MODEL_FILE);
342 let tokenizer = snap.join(MINILM_TOKENIZER_FILE);
343 if model.is_file() && tokenizer.is_file() {
344 return Some((model, tokenizer));
345 }
346 }
347 None
348}
349
350fn download_via_hf_hub(cache_dir: &std::path::Path) -> Result<(PathBuf, PathBuf), String> {
351 use hf_hub::api::sync::ApiBuilder;
352
353 slog_info!(
354 "downloading all-MiniLM-L6-v2 ({}) to {}",
355 MINILM_REPO,
356 cache_dir.display()
357 );
358 let api = ApiBuilder::new()
359 .with_progress(false)
360 .with_cache_dir(cache_dir.to_path_buf())
361 .build()
362 .map_err(|e| format!("failed to init hf-hub api: {e}"))?;
363 let repo = api.model(MINILM_REPO.to_string());
364 let model = repo
365 .get(MINILM_MODEL_FILE)
366 .map_err(|e| format!("failed to download {MINILM_MODEL_FILE}: {e}"))?;
367 let tokenizer = repo
368 .get(MINILM_TOKENIZER_FILE)
369 .map_err(|e| format!("failed to download {MINILM_TOKENIZER_FILE}: {e}"))?;
370 Ok((model, tokenizer))
371}
372
#[cfg(test)]
mod tests {
    use super::MINILM_MAX_LENGTH;
    use std::io::Write;
    use tokenizers::Tokenizer;

    // Builds a tiny tokenizer.json that mirrors the MiniLM (BERT WordPiece) pipeline:
    // BertNormalizer -> BertPreTokenizer -> WordPiece -> BertProcessing ([CLS] ... [SEP]),
    // with right-side truncation at MINILM_MAX_LENGTH. The vocab holds only the words the
    // assertions below need, plus the "##ly" continuation piece.
    fn minilm_like_tokenizer_json() -> Vec<u8> {
        serde_json::json!({
            "version": "1.0",
            "truncation": {
                "direction": "Right",
                "max_length": MINILM_MAX_LENGTH,
                "strategy": "LongestFirst",
                "stride": 0
            },
            "padding": null,
            "added_tokens": [
                {"id": 0, "content": "[PAD]", "single_word": false, "lstrip": false, "rstrip": false, "normalized": false, "special": true},
                {"id": 1, "content": "[CLS]", "single_word": false, "lstrip": false, "rstrip": false, "normalized": false, "special": true},
                {"id": 2, "content": "[SEP]", "single_word": false, "lstrip": false, "rstrip": false, "normalized": false, "special": true},
                {"id": 3, "content": "[UNK]", "single_word": false, "lstrip": false, "rstrip": false, "normalized": false, "special": true}
            ],
            "normalizer": {
                "type": "BertNormalizer",
                "clean_text": true,
                "handle_chinese_chars": true,
                "strip_accents": null,
                "lowercase": true
            },
            "pre_tokenizer": {"type": "BertPreTokenizer"},
            "post_processor": {"type": "BertProcessing", "sep": ["[SEP]", 2], "cls": ["[CLS]", 1]},
            "decoder": null,
            "model": {
                "type": "WordPiece",
                "unk_token": "[UNK]",
                "continuing_subword_prefix": "##",
                "max_input_chars_per_word": 100,
                "vocab": {
                    "[PAD]": 0,
                    "[CLS]": 1,
                    "[SEP]": 2,
                    "[UNK]": 3,
                    "hello": 4,
                    "world": 5,
                    "!": 6,
                    "cafe": 7,
                    "naive": 8,
                    "##ly": 9
                }
            }
        })
        .to_string()
        .into_bytes()
    }

    // Shared expectations for a tokenizer built from `minilm_like_tokenizer_json`.
    fn assert_load_encode_parity(tokenizer: Tokenizer) {
        // Lowercasing plus punctuation splitting, wrapped in [CLS]=1 ... [SEP]=2.
        let ascii = tokenizer.encode("Hello WORLD!", true).unwrap();
        assert_eq!(ascii.get_ids(), &[1, 4, 5, 6, 2]);

        // The expected ids rely on accents being removed by normalisation ("Café" -> "cafe",
        // "naïvely" -> "naive" + "##ly").
        let unicode = tokenizer.encode("Café naïvely", true).unwrap();
        assert_eq!(unicode.get_ids(), &[1, 7, 8, 9, 2]);

        // Input longer than the limit is truncated to exactly MINILM_MAX_LENGTH ids while
        // keeping [CLS] first and [SEP] last.
        let long_text = std::iter::repeat("hello")
            .take(MINILM_MAX_LENGTH + 20)
            .collect::<Vec<_>>()
            .join(" ");
        let long = tokenizer.encode(long_text.as_str(), true).unwrap();
        let ids = long.get_ids();
        assert_eq!(ids.len(), MINILM_MAX_LENGTH);
        assert_eq!(ids.first(), Some(&1));
        assert_eq!(ids.last(), Some(&2));
        assert!(ids[1..MINILM_MAX_LENGTH - 1].iter().all(|id| *id == 4));
    }

    // The tokenizer must behave identically whether loaded from in-memory bytes or from a
    // file on disk (the production path uses `Tokenizer::from_file`).
    #[test]
    fn tokenizers_slim_features_load_and_encode_minilm_wordpiece() {
        let json = minilm_like_tokenizer_json();

        assert_load_encode_parity(Tokenizer::from_bytes(&json).unwrap());

        let mut file = tempfile::NamedTempFile::new().unwrap();
        file.write_all(&json).unwrap();
        file.flush().unwrap();
        assert_load_encode_parity(Tokenizer::from_file(file.path()).unwrap());
    }
}