1use std::path::PathBuf;
22
23use ort::session::builder::GraphOptimizationLevel;
24use ort::session::Session;
25use ort::value::Tensor;
26use tokenizers::Tokenizer;
27
28use crate::semantic_index::{format_embedding_init_error, pre_validate_onnx_runtime};
29use crate::slog_info;
30
// --- Model source -----------------------------------------------------------

/// Hugging Face repository the ONNX export of all-MiniLM-L6-v2 is fetched from.
const MINILM_REPO: &str = "Qdrant/all-MiniLM-L6-v2-onnx";
/// File name of the ONNX model graph inside a repository snapshot.
const MINILM_MODEL_FILE: &str = "model.onnx";
/// File name of the tokenizer definition inside a repository snapshot.
const MINILM_TOKENIZER_FILE: &str = "tokenizer.json";

// --- Size limits ------------------------------------------------------------

/// Maximum number of tokens kept per text; the tokenizer truncates longer inputs to this length.
const MINILM_MAX_LENGTH: usize = 512;
/// Budget for one inference sub-batch, measured as
/// `(rows in the sub-batch) * (longest encoding length)^2` by `LocalEmbedder::embed`.
const MAX_BATCH_ATTENTION_UNITS: usize = 4_000_000;
53
/// Number of intra-op threads to give the ONNX session: half of the reported available
/// parallelism, rounded up, and never less than one. When the parallelism query fails, a single
/// core is assumed.
fn intra_thread_cap() -> usize {
    let cores = match std::thread::available_parallelism() {
        Ok(n) => n.get(),
        Err(_) => 1,
    };
    std::cmp::max(cores.div_ceil(2), 1)
}
63
64pub struct LocalEmbedder {
65 session: Session,
66 tokenizer: Tokenizer,
67 wants_token_type_ids: bool,
68}
69
70impl LocalEmbedder {
71 pub fn new(model: &str) -> Result<Self, String> {
74 match model {
75 "all-MiniLM-L6-v2" | "all-minilm-l6-v2" => {}
76 other => {
77 return Err(format!(
78 "unsupported local embedding model '{other}'. Supported: all-MiniLM-L6-v2"
79 ))
80 }
81 }
82
83 pre_validate_onnx_runtime()?;
86
87 let (model_path, tokenizer_path) = resolve_model_files()?;
88
89 let threads = intra_thread_cap();
90 let session = Session::builder()
91 .map_err(|e| format!("failed to create ONNX session builder: {e}"))?
92 .with_optimization_level(GraphOptimizationLevel::Level3)
93 .map_err(|e| format!("failed to set ONNX optimization level: {e}"))?
94 .with_intra_threads(threads)
95 .map_err(|e| format!("failed to set ONNX intra-op threads: {e}"))?
96 .commit_from_file(&model_path)
97 .map_err(format_embedding_init_error)?;
101
102 let mut tokenizer = Tokenizer::from_file(&tokenizer_path)
103 .map_err(|e| format!("failed to load tokenizer {}: {e}", tokenizer_path.display()))?;
104 tokenizer
107 .with_truncation(Some(tokenizers::TruncationParams {
108 max_length: MINILM_MAX_LENGTH,
109 ..Default::default()
110 }))
111 .map_err(|e| format!("failed to set tokenizer truncation: {e}"))?;
112
113 let wants_token_type_ids = session
114 .inputs()
115 .iter()
116 .any(|input| input.name() == "token_type_ids");
117
118 slog_info!(
119 "local embedder ready: model=all-MiniLM-L6-v2 intra_threads={} token_type_ids={}",
120 threads,
121 wants_token_type_ids
122 );
123
124 Ok(Self {
125 session,
126 tokenizer,
127 wants_token_type_ids,
128 })
129 }
130
131 pub fn embed(&mut self, texts: &[String]) -> Result<Vec<Vec<f32>>, String> {
142 if texts.is_empty() {
143 return Ok(Vec::new());
144 }
145
146 let encodings = self
147 .tokenizer
148 .encode_batch(texts.to_vec(), true)
149 .map_err(|e| format!("tokenize batch: {e}"))?;
150
151 let mut result = Vec::with_capacity(encodings.len());
155 let mut batch_start = 0usize;
156 let mut batch_max = 0usize;
157 for (i, enc) in encodings.iter().enumerate() {
158 let len = enc.get_ids().len().max(1);
159 let count = i - batch_start; let candidate_max = batch_max.max(len);
161 let cost = (count + 1)
162 .saturating_mul(candidate_max)
163 .saturating_mul(candidate_max);
164 if count > 0 && cost > MAX_BATCH_ATTENTION_UNITS {
165 let vecs = self.run_inference(&encodings[batch_start..i])?;
166 result.extend(vecs);
167 batch_start = i;
168 batch_max = len;
169 } else {
170 batch_max = candidate_max;
171 }
172 }
173 let vecs = self.run_inference(&encodings[batch_start..])?;
175 result.extend(vecs);
176 Ok(result)
177 }
178
179 fn run_inference(
184 &mut self,
185 encodings: &[tokenizers::Encoding],
186 ) -> Result<Vec<Vec<f32>>, String> {
187 if encodings.is_empty() {
188 return Ok(Vec::new());
189 }
190
191 let batch = encodings.len();
192 let max_len = encodings
193 .iter()
194 .map(|e| e.get_ids().len())
195 .max()
196 .unwrap_or(1)
197 .max(1);
198
199 let mut ids = vec![0i64; batch * max_len];
203 let mut mask = vec![0i64; batch * max_len];
204 for (row, enc) in encodings.iter().enumerate() {
205 let row_ids = enc.get_ids();
206 let row_mask = enc.get_attention_mask();
207 let base = row * max_len;
208 for col in 0..row_ids.len() {
209 ids[base + col] = row_ids[col] as i64;
210 mask[base + col] = row_mask[col] as i64;
211 }
212 }
213
214 let input_ids = ndarray::Array2::<i64>::from_shape_vec((batch, max_len), ids)
215 .map_err(|e| format!("build input_ids tensor: {e}"))?;
216 let attention_mask = ndarray::Array2::<i64>::from_shape_vec((batch, max_len), mask)
217 .map_err(|e| format!("build attention_mask tensor: {e}"))?;
218
219 let mut inputs = ort::inputs![
220 "input_ids" => Tensor::from_array(input_ids).map_err(|e| format!("input_ids: {e}"))?,
221 "attention_mask" => Tensor::from_array(attention_mask.clone())
222 .map_err(|e| format!("attention_mask: {e}"))?,
223 ];
224 if self.wants_token_type_ids {
225 let token_type_ids = ndarray::Array2::<i64>::zeros((batch, max_len));
226 inputs.push((
227 "token_type_ids".into(),
228 Tensor::from_array(token_type_ids)
229 .map_err(|e| format!("token_type_ids: {e}"))?
230 .into(),
231 ));
232 }
233
234 let outputs = self
235 .session
236 .run(inputs)
237 .map_err(|e| format!("ONNX inference failed: {e}"))?;
238 let output = outputs
239 .values()
240 .next()
241 .ok_or_else(|| "ONNX model produced no output".to_string())?;
242
243 let (shape, data): (Vec<i64>, Vec<f32>) = match output.try_extract_tensor::<f32>() {
245 Ok((s, d)) => (s.to_vec(), d.to_vec()),
246 Err(_) => {
247 let (s, d) = output
248 .try_extract_tensor::<half::f16>()
249 .map_err(|e| format!("extract output tensor: {e}"))?;
250 (s.to_vec(), d.iter().map(|h| h.to_f32()).collect())
251 }
252 };
253 if shape.len() != 3 {
254 return Err(format!(
255 "unexpected ONNX output rank {} (expected 3: [batch, seq, dim])",
256 shape.len()
257 ));
258 }
259 let seq = shape[1] as usize;
260 let dim = shape[2] as usize;
261
262 let mut result = Vec::with_capacity(batch);
263 for row in 0..batch {
264 let mut emb = vec![0.0f32; dim];
265 let mut valid = 0.0f32;
266 for col in 0..seq {
267 if mask_at(&attention_mask, row, col) == 1 {
268 valid += 1.0;
269 let base = (row * seq + col) * dim;
270 for (d, slot) in emb.iter_mut().enumerate() {
271 *slot += data[base + d];
272 }
273 }
274 }
275 let denom = if valid == 0.0 { 1.0 } else { valid };
276 for slot in &mut emb {
277 *slot /= denom;
278 }
279 let norm = emb.iter().map(|x| x * x).sum::<f32>().sqrt();
280 for slot in &mut emb {
281 *slot /= norm + 1e-12;
282 }
283 result.push(emb);
284 }
285 Ok(result)
286 }
287}
288
289#[inline]
290fn mask_at(mask: &ndarray::Array2<i64>, row: usize, col: usize) -> i64 {
291 mask[[row, col]]
292}
293
294fn resolve_model_files() -> Result<(PathBuf, PathBuf), String> {
297 let cache_dir = embedding_cache_dir();
298
299 if let Some(found) = scan_local_snapshot(&cache_dir) {
300 return Ok(found);
301 }
302
303 download_via_hf_hub(&cache_dir)
306}
307
/// Returns the directory used to cache downloaded embedding models.
///
/// Resolution order:
/// 1. `FASTEMBED_CACHE_DIR`, if set to a non-empty value;
/// 2. otherwise `<home>/.cache/fastembed`, where `<home>` is `HOME` or, failing that,
///    `USERPROFILE` (each only if non-empty);
/// 3. if neither home variable is usable, the system temporary directory stands in for `<home>`.
fn embedding_cache_dir() -> PathBuf {
    // Empty values count as unset. An empty path would otherwise silently resolve relative to
    // the current working directory.
    let non_empty = |key: &str| std::env::var_os(key).filter(|v| !v.is_empty());

    if let Some(dir) = non_empty("FASTEMBED_CACHE_DIR") {
        return PathBuf::from(dir);
    }
    let home = non_empty("HOME")
        .or_else(|| non_empty("USERPROFILE"))
        .map(PathBuf::from)
        .unwrap_or_else(std::env::temp_dir);
    home.join(".cache").join("fastembed")
}
321
322fn scan_local_snapshot(cache_dir: &std::path::Path) -> Option<(PathBuf, PathBuf)> {
325 let repo_dir = cache_dir.join("models--Qdrant--all-MiniLM-L6-v2-onnx");
326 let snapshots = repo_dir.join("snapshots");
327 let mut candidates: Vec<PathBuf> = std::fs::read_dir(&snapshots)
328 .ok()?
329 .filter_map(|entry| entry.ok().map(|e| e.path()))
330 .filter(|p| p.is_dir())
331 .collect();
332 candidates.sort_by_key(|p| {
334 std::fs::metadata(p)
335 .and_then(|m| m.modified())
336 .unwrap_or(std::time::UNIX_EPOCH)
337 });
338 candidates.reverse();
339 for snap in candidates {
340 let model = snap.join(MINILM_MODEL_FILE);
341 let tokenizer = snap.join(MINILM_TOKENIZER_FILE);
342 if model.is_file() && tokenizer.is_file() {
343 return Some((model, tokenizer));
344 }
345 }
346 None
347}
348
349fn download_via_hf_hub(cache_dir: &std::path::Path) -> Result<(PathBuf, PathBuf), String> {
350 use hf_hub::api::sync::ApiBuilder;
351
352 slog_info!(
353 "downloading all-MiniLM-L6-v2 ({}) to {}",
354 MINILM_REPO,
355 cache_dir.display()
356 );
357 let api = ApiBuilder::new()
358 .with_progress(false)
359 .with_cache_dir(cache_dir.to_path_buf())
360 .build()
361 .map_err(|e| format!("failed to init hf-hub api: {e}"))?;
362 let repo = api.model(MINILM_REPO.to_string());
363 let model = repo
364 .get(MINILM_MODEL_FILE)
365 .map_err(|e| format!("failed to download {MINILM_MODEL_FILE}: {e}"))?;
366 let tokenizer = repo
367 .get(MINILM_TOKENIZER_FILE)
368 .map_err(|e| format!("failed to download {MINILM_TOKENIZER_FILE}: {e}"))?;
369 Ok((model, tokenizer))
370}