Skip to main content

aft/
local_embed.rs

1//! Local ONNX embedding backend (all-MiniLM-L6-v2) driven directly through
2//! `ort`.
3//!
4//! Replaces the `fastembed` crate. We own the ORT session so we can cap
5//! intra-op threads — `fastembed` hardcoded `with_intra_threads(all cores)`,
6//! which pegged every core during indexing (the sustained-CPU complaint). We
7//! cap to `num_cpus / 2`, which an earlier measurement showed is both faster
8//! (1.7x) and far lighter (3.5x less CPU) than oversubscribing all cores.
9//!
10//! The pipeline reproduces fastembed's MiniLM path byte-for-byte (verified:
11//! cosine 1.000000 vs fastembed across code + prose), so existing semantic
12//! indexes remain valid with no re-embed:
13//!   - tokenizer.json, truncation forced to max_length=512 (the Qdrant
14//!     tokenizer ships an embedded max_length=128 that fastembed overrides),
15//!     add_special_tokens=true
16//!   - ONNX inputs input_ids / attention_mask / token_type_ids (i64)
17//!     → output last_hidden_state [batch, seq, dim]
18//!   - mean pool: sum(mask · tok, over seq) / max(sum(mask), 1)
19//!   - L2 normalize: v / (||v|| + 1e-12)
20
21use std::path::PathBuf;
22
23use ort::session::builder::GraphOptimizationLevel;
24use ort::session::Session;
25use ort::value::Tensor;
26use tokenizers::Tokenizer;
27
28use crate::semantic_index::{format_embedding_init_error, pre_validate_onnx_runtime};
29use crate::slog_info;
30
/// HuggingFace repo fastembed used for all-MiniLM-L6-v2; we reuse the same repo
/// and on-disk cache layout so already-downloaded models are picked up offline.
const MINILM_REPO: &str = "Qdrant/all-MiniLM-L6-v2-onnx";
/// File name of the ONNX model, both inside the repo and inside a cached
/// snapshot directory.
const MINILM_MODEL_FILE: &str = "model.onnx";
/// File name of the tokenizer definition, both inside the repo and inside a
/// cached snapshot directory.
const MINILM_TOKENIZER_FILE: &str = "tokenizer.json";
/// fastembed forces truncation to min(512, model_max_length=512). Existing
/// indexes were built at 512, so we MUST match it — the tokenizer.json itself
/// ships max_length=128, which would silently shorten long inputs and break
/// parity with persisted vectors.
const MINILM_MAX_LENGTH: usize = 512;
/// Per-inference memory budget, expressed in attention units (`batch × max_len²`).
///
/// The transient ONNX attention tensor scales with `batch × heads × seq_len²`,
/// so peak RSS is governed by the *largest single inference*, not total chunk
/// count (ORT's arena grows to the high-water mark and stays there). Measured:
/// `64 × 512² = 16.78M units → ~4.92 GB peak` — too high for 8–16 GB machines.
///
/// 4.0M units caps the worst case at roughly half that (~2–2.5 GB, re-measured):
/// at 512-token chunks it allows ~15 per inference; at ≤250 tokens (the common
/// case for code symbols) it allows the full 64-chunk batch, so short-chunk
/// throughput is unaffected and only long-chunk batches are split.
const MAX_BATCH_ATTENTION_UNITS: usize = 4_000_000;
53
/// Number of ORT intra-op threads to use: half of the available parallelism,
/// rounded up, and never less than 1.
///
/// Falls back to a single thread when the available parallelism cannot be
/// queried. Keeping the cap at roughly half the cores leaves the remainder
/// free for the agent / editor.
fn intra_thread_cap() -> usize {
    let cores = match std::thread::available_parallelism() {
        Ok(count) => count.get(),
        Err(_) => 1,
    };
    // `cores` is at least 1, so the rounded-up half is too; the explicit floor
    // simply documents the invariant.
    std::cmp::max(cores.div_ceil(2), 1)
}
63
/// Local embedding backend: an ONNX Runtime session for all-MiniLM-L6-v2
/// together with the tokenizer that prepares its inputs.
///
/// Build it with [`LocalEmbedder::new`] and turn texts into vectors with
/// [`LocalEmbedder::embed`].
pub struct LocalEmbedder {
    /// ORT session committed from the resolved `model.onnx`.
    session: Session,
    /// Tokenizer loaded from `tokenizer.json`, with truncation overridden to
    /// `MINILM_MAX_LENGTH`.
    tokenizer: Tokenizer,
    /// Whether the loaded model declares a `token_type_ids` input; when true an
    /// all-zero tensor is supplied for it on every inference.
    wants_token_type_ids: bool,
}
69
70impl LocalEmbedder {
71    /// Build the embedder for the named model. Only `all-MiniLM-L6-v2` is
72    /// supported as the local backend (matches the prior fastembed surface).
73    pub fn new(model: &str) -> Result<Self, String> {
74        match model {
75            "all-MiniLM-L6-v2" | "all-minilm-l6-v2" => {}
76            other => {
77                return Err(format!(
78                    "unsupported local embedding model '{other}'. Supported: all-MiniLM-L6-v2"
79                ))
80            }
81        }
82
83        // Fail with an actionable message instead of letting ort panic deep
84        // inside dlopen on an incompatible/absent ONNX Runtime.
85        pre_validate_onnx_runtime()?;
86
87        let (model_path, tokenizer_path) = resolve_model_files()?;
88
89        let threads = intra_thread_cap();
90        let session = Session::builder()
91            .map_err(|e| format!("failed to create ONNX session builder: {e}"))?
92            .with_optimization_level(GraphOptimizationLevel::Level3)
93            .map_err(|e| format!("failed to set ONNX optimization level: {e}"))?
94            .with_intra_threads(threads)
95            .map_err(|e| format!("failed to set ONNX intra-op threads: {e}"))?
96            .commit_from_file(&model_path)
97            // Route through the shared formatter so a missing/incompatible ONNX
98            // Runtime (dlopen failure) yields the actionable install hint rather
99            // than a raw ort error.
100            .map_err(format_embedding_init_error)?;
101
102        let mut tokenizer = Tokenizer::from_file(&tokenizer_path)
103            .map_err(|e| format!("failed to load tokenizer {}: {e}", tokenizer_path.display()))?;
104        // Override the tokenizer's embedded truncation (Qdrant ships 128) to 512
105        // for parity with fastembed and existing indexes.
106        tokenizer
107            .with_truncation(Some(tokenizers::TruncationParams {
108                max_length: MINILM_MAX_LENGTH,
109                ..Default::default()
110            }))
111            .map_err(|e| format!("failed to set tokenizer truncation: {e}"))?;
112
113        let wants_token_type_ids = session
114            .inputs()
115            .iter()
116            .any(|input| input.name() == "token_type_ids");
117
118        slog_info!(
119            "local embedder ready: model=all-MiniLM-L6-v2 intra_threads={} token_type_ids={}",
120            threads,
121            wants_token_type_ids
122        );
123
124        Ok(Self {
125            session,
126            tokenizer,
127            wants_token_type_ids,
128        })
129    }
130
131    /// Embed a batch of texts → one L2-normalized 384-dim vector each.
132    ///
133    /// Internally sub-batches by a token budget so a single ONNX inference can
134    /// never balloon peak RSS: the transient attention tensor scales with
135    /// `batch × heads × seq_len²`, so a batch that happens to contain many
136    /// long (512-token) chunks would otherwise spike memory (~5 GB worst case
137    /// at batch=64 × 512 tokens). We cap `batch × max_len²` per inference,
138    /// which keeps short-chunk batches at full size (no throughput loss) while
139    /// splitting long-chunk batches into smaller inferences. Output order and
140    /// vectors are identical to embedding the whole input in one call.
141    pub fn embed(&mut self, texts: &[String]) -> Result<Vec<Vec<f32>>, String> {
142        if texts.is_empty() {
143            return Ok(Vec::new());
144        }
145
146        let encodings = self
147            .tokenizer
148            .encode_batch(texts.to_vec(), true)
149            .map_err(|e| format!("tokenize batch: {e}"))?;
150
151        // Greedily partition (order-preserving) into sub-batches bounded by the
152        // attention-unit budget. `cost = (count) × max_len²`; flush before
153        // adding a row that would exceed the budget.
154        let mut result = Vec::with_capacity(encodings.len());
155        let mut batch_start = 0usize;
156        let mut batch_max = 0usize;
157        for (i, enc) in encodings.iter().enumerate() {
158            let len = enc.get_ids().len().max(1);
159            let count = i - batch_start; // size BEFORE adding row i
160            let candidate_max = batch_max.max(len);
161            let cost = (count + 1)
162                .saturating_mul(candidate_max)
163                .saturating_mul(candidate_max);
164            if count > 0 && cost > MAX_BATCH_ATTENTION_UNITS {
165                let vecs = self.run_inference(&encodings[batch_start..i])?;
166                result.extend(vecs);
167                batch_start = i;
168                batch_max = len;
169            } else {
170                batch_max = candidate_max;
171            }
172        }
173        // Flush the final sub-batch (encodings is non-empty here).
174        let vecs = self.run_inference(&encodings[batch_start..])?;
175        result.extend(vecs);
176        Ok(result)
177    }
178
179    /// Run one ONNX inference over a single sub-batch of pre-tokenized
180    /// encodings: pad to the sub-batch longest, run the model, mean-pool over
181    /// the attention mask, L2-normalize. Memory here is bounded by the caller
182    /// (`embed`) via the attention-unit budget.
183    fn run_inference(
184        &mut self,
185        encodings: &[tokenizers::Encoding],
186    ) -> Result<Vec<Vec<f32>>, String> {
187        if encodings.is_empty() {
188            return Ok(Vec::new());
189        }
190
191        let batch = encodings.len();
192        let max_len = encodings
193            .iter()
194            .map(|e| e.get_ids().len())
195            .max()
196            .unwrap_or(1)
197            .max(1);
198
199        // Pad to the batch-longest. The attention mask zeroes padding inside the
200        // model's attention and the mean-pool below ignores it, so a padded
201        // batch yields identical vectors to embedding each text alone.
202        let mut ids = vec![0i64; batch * max_len];
203        let mut mask = vec![0i64; batch * max_len];
204        for (row, enc) in encodings.iter().enumerate() {
205            let row_ids = enc.get_ids();
206            let row_mask = enc.get_attention_mask();
207            let base = row * max_len;
208            for col in 0..row_ids.len() {
209                ids[base + col] = row_ids[col] as i64;
210                mask[base + col] = row_mask[col] as i64;
211            }
212        }
213
214        let input_ids = ndarray::Array2::<i64>::from_shape_vec((batch, max_len), ids)
215            .map_err(|e| format!("build input_ids tensor: {e}"))?;
216        let attention_mask = ndarray::Array2::<i64>::from_shape_vec((batch, max_len), mask)
217            .map_err(|e| format!("build attention_mask tensor: {e}"))?;
218
219        let mut inputs = ort::inputs![
220            "input_ids" => Tensor::from_array(input_ids).map_err(|e| format!("input_ids: {e}"))?,
221            "attention_mask" => Tensor::from_array(attention_mask.clone())
222                .map_err(|e| format!("attention_mask: {e}"))?,
223        ];
224        if self.wants_token_type_ids {
225            let token_type_ids = ndarray::Array2::<i64>::zeros((batch, max_len));
226            inputs.push((
227                "token_type_ids".into(),
228                Tensor::from_array(token_type_ids)
229                    .map_err(|e| format!("token_type_ids: {e}"))?
230                    .into(),
231            ));
232        }
233
234        let outputs = self
235            .session
236            .run(inputs)
237            .map_err(|e| format!("ONNX inference failed: {e}"))?;
238        let output = outputs
239            .values()
240            .next()
241            .ok_or_else(|| "ONNX model produced no output".to_string())?;
242
243        // last_hidden_state may be f32 (standard) or f16 (uniform-fp16 exports).
244        let (shape, data): (Vec<i64>, Vec<f32>) = match output.try_extract_tensor::<f32>() {
245            Ok((s, d)) => (s.to_vec(), d.to_vec()),
246            Err(_) => {
247                let (s, d) = output
248                    .try_extract_tensor::<half::f16>()
249                    .map_err(|e| format!("extract output tensor: {e}"))?;
250                (s.to_vec(), d.iter().map(|h| h.to_f32()).collect())
251            }
252        };
253        if shape.len() != 3 {
254            return Err(format!(
255                "unexpected ONNX output rank {} (expected 3: [batch, seq, dim])",
256                shape.len()
257            ));
258        }
259        let seq = shape[1] as usize;
260        let dim = shape[2] as usize;
261
262        let mut result = Vec::with_capacity(batch);
263        for row in 0..batch {
264            let mut emb = vec![0.0f32; dim];
265            let mut valid = 0.0f32;
266            for col in 0..seq {
267                if mask_at(&attention_mask, row, col) == 1 {
268                    valid += 1.0;
269                    let base = (row * seq + col) * dim;
270                    for (d, slot) in emb.iter_mut().enumerate() {
271                        *slot += data[base + d];
272                    }
273                }
274            }
275            let denom = if valid == 0.0 { 1.0 } else { valid };
276            for slot in &mut emb {
277                *slot /= denom;
278            }
279            let norm = emb.iter().map(|x| x * x).sum::<f32>().sqrt();
280            for slot in &mut emb {
281                *slot /= norm + 1e-12;
282            }
283            result.push(emb);
284        }
285        Ok(result)
286    }
287}
288
289#[inline]
290fn mask_at(mask: &ndarray::Array2<i64>, row: usize, col: usize) -> i64 {
291    mask[[row, col]]
292}
293
294/// Resolve the MiniLM model.onnx + tokenizer.json, reusing an existing local
295/// download when present (offline-safe) and falling back to an hf-hub fetch.
296fn resolve_model_files() -> Result<(PathBuf, PathBuf), String> {
297    let cache_dir = embedding_cache_dir();
298
299    if let Some(found) = scan_local_snapshot(&cache_dir) {
300        return Ok(found);
301    }
302
303    // Not cached locally — download via hf-hub into the same cache layout so a
304    // subsequent run finds it through the local scan above.
305    download_via_hf_hub(&cache_dir)
306}
307
/// Directory that holds the downloaded embedding models.
///
/// `FASTEMBED_CACHE_DIR` wins when set (fastembed read it, and the
/// bridge/warmup set it to `<storage>/semantic/models`). Otherwise the default
/// is `<home>/.cache/fastembed`, where `<home>` is `HOME`, then `USERPROFILE`,
/// then the system temp directory. Keeping the same env var and default means
/// existing downloads are reused.
fn embedding_cache_dir() -> PathBuf {
    match std::env::var_os("FASTEMBED_CACHE_DIR") {
        Some(explicit) => PathBuf::from(explicit),
        None => {
            // The first home-like variable that is set wins; later ones are
            // not consulted.
            let base = ["HOME", "USERPROFILE"]
                .iter()
                .find_map(|key| std::env::var_os(key))
                .map_or_else(std::env::temp_dir, PathBuf::from);
            base.join(".cache").join("fastembed")
        }
    }
}
321
322/// hf-hub stores repos at `<cache>/models--<org>--<repo>/snapshots/<rev>/`.
323/// Find the newest snapshot that has both required files.
324fn scan_local_snapshot(cache_dir: &std::path::Path) -> Option<(PathBuf, PathBuf)> {
325    let repo_dir = cache_dir.join("models--Qdrant--all-MiniLM-L6-v2-onnx");
326    let snapshots = repo_dir.join("snapshots");
327    let mut candidates: Vec<PathBuf> = std::fs::read_dir(&snapshots)
328        .ok()?
329        .filter_map(|entry| entry.ok().map(|e| e.path()))
330        .filter(|p| p.is_dir())
331        .collect();
332    // Newest snapshot first (by modified time) so a refreshed revision wins.
333    candidates.sort_by_key(|p| {
334        std::fs::metadata(p)
335            .and_then(|m| m.modified())
336            .unwrap_or(std::time::UNIX_EPOCH)
337    });
338    candidates.reverse();
339    for snap in candidates {
340        let model = snap.join(MINILM_MODEL_FILE);
341        let tokenizer = snap.join(MINILM_TOKENIZER_FILE);
342        if model.is_file() && tokenizer.is_file() {
343            return Some((model, tokenizer));
344        }
345    }
346    None
347}
348
349fn download_via_hf_hub(cache_dir: &std::path::Path) -> Result<(PathBuf, PathBuf), String> {
350    use hf_hub::api::sync::ApiBuilder;
351
352    slog_info!(
353        "downloading all-MiniLM-L6-v2 ({}) to {}",
354        MINILM_REPO,
355        cache_dir.display()
356    );
357    let api = ApiBuilder::new()
358        .with_progress(false)
359        .with_cache_dir(cache_dir.to_path_buf())
360        .build()
361        .map_err(|e| format!("failed to init hf-hub api: {e}"))?;
362    let repo = api.model(MINILM_REPO.to_string());
363    let model = repo
364        .get(MINILM_MODEL_FILE)
365        .map_err(|e| format!("failed to download {MINILM_MODEL_FILE}: {e}"))?;
366    let tokenizer = repo
367        .get(MINILM_TOKENIZER_FILE)
368        .map_err(|e| format!("failed to download {MINILM_TOKENIZER_FILE}: {e}"))?;
369    Ok((model, tokenizer))
370}