Skip to main content

chunkshop/
embedder.rs

1//! Fastembed-backed embedder.
2//!
//! Three paths:
4//!
5//! 1. **Stock-variant path** — for models where fastembed-rs's built-in
6//!    registry already matches what we want (BGE non-quantized, MiniLM, etc.).
7//!    Resolves through `resolve_model_name` and uses `TextEmbedding::try_new`.
8//!    Requires `embedder-hub` (hf-hub auto-download).
9//!
10//! 2. **User-defined path (bit-exact)** — for `Xenova/bge-base-en-v1.5-int8`
11//!    and `Xenova/bge-small-en-v1.5-int8`, where the goal is byte-identical
12//!    output vs Python. Hand-rolls the ORT session because fastembed-rs's
13//!    `try_new_from_user_defined` hardcodes `with_intra_threads(available_parallelism())`,
14//!    which makes the reduction order CPU-count-dependent and breaks bit-
15//!    exactness across machines. We pin `with_intra_threads(1)` for these two
16//!    int8 models and replicate fastembed's tokenize → infer → CLS-pool →
17//!    L2-normalize pipeline. Requires `embedder-hub` for the HF download.
18//!
19//! 3. **Bytes-in path** — `from_user_defined_files`. Caller hands in the ONNX
20//!    and tokenizer bytes already loaded from disk / Postgres / their own
21//!    storage. No `hf-hub` involvement, available under `embedder-core`.
22//!    Used by embedded library consumers who manage model artifacts
23//!    out-of-band.
24
25use anyhow::{anyhow, Context, Result};
26use fastembed::TextEmbedding;
27use ndarray::{s, Array2};
28use ort::session::{builder::GraphOptimizationLevel, Session};
29use ort::value::Value;
30use tokenizers::{PaddingParams, PaddingStrategy, Tokenizer, TruncationParams};
31use tracing::info;
32
33#[cfg(feature = "embedder-hub")]
34use fastembed::{EmbeddingModel, InitOptions};
35#[cfg(feature = "embedder-hub")]
36use std::collections::HashMap;
37
38use crate::config::FastembedEmbedderConfig;
39#[cfg(feature = "embedder-hub")]
40use crate::hf_cache::{fetch_user_defined_files, HfModelFiles};
41
/// Text embedder that wraps one of two inference backends (see [`Backend`])
/// and tracks how long it has spent embedding.
pub struct FastembedEmbedder {
    /// Embedder configuration. `embed()` reads `dim`, `batch_size` and
    /// `model_name` from it; the constructors read `pooling` and `threads`.
    cfg: FastembedEmbedderConfig,
    /// The inference implementation chosen at construction time.
    backend: Backend,
    /// Cumulative wall time, in seconds, spent inside `embed()`. Mirrors
    /// Python's `FastembedProvider.embed_seconds`. Used by the bakeoff to
    /// break out the embedder's portion of total cell wall time.
    embed_seconds: f64,
}
50
/// Inference implementation behind a [`FastembedEmbedder`].
enum Backend {
    /// fastembed's stock `TextEmbedding` (registry variant). Constructed
    /// only by `FastembedEmbedder::new`, which is gated under `embedder-hub`
    /// because `TextEmbedding::try_new` itself requires fastembed's hf-hub
    /// feature. Allow dead_code under `embedder-core`-only builds.
    #[cfg_attr(not(feature = "embedder-hub"), allow(dead_code))]
    Stock(TextEmbedding),
    /// Hand-rolled ONNX Runtime session plus tokenizer. Built by the BYO and
    /// Xenova int8 branches of `FastembedEmbedder::new` and by
    /// `FastembedEmbedder::from_user_defined_files`.
    UserDefined(UserDefinedRunner),
}
60
/// Pooling strategy for the user-defined ONNX path. Stock fastembed-rs
/// variants pool internally; this enum only governs the hand-rolled forward
/// pass in `UserDefinedRunner::embed_batch`. The config strings `"cls"` and
/// `"mean"` are mapped onto these variants by `parse_pooling`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Pooling {
    /// Take the hidden state at sequence position 0 (the [CLS] token).
    /// Used by BGE family + most BERT-derived retrieval models.
    Cls,
    /// Mean of token-level hidden states, masked to non-padding tokens.
    /// Used by sentence-transformers / e5 / etc.
    Mean,
}
71
72fn parse_pooling(s: &str) -> Result<Pooling> {
73    match s {
74        "cls" => Ok(Pooling::Cls),
75        "mean" => Ok(Pooling::Mean),
76        other => Err(anyhow!(
77            "embedder.pooling must be 'cls' or 'mean', got {other:?}"
78        )),
79    }
80}
81
/// State for the hand-rolled tokenize → infer → pool → normalize pipeline.
struct UserDefinedRunner {
    /// ONNX Runtime session committed from the model bytes.
    session: Session,
    /// Tokenizer configured with batch-longest padding and truncation.
    tokenizer: Tokenizer,
    /// True when the ONNX graph declares a `token_type_ids` input, in which
    /// case `embed_batch` feeds that tensor alongside ids and mask.
    need_token_type_ids: bool,
    /// How per-token hidden states are reduced to one vector per text.
    pooling: Pooling,
}
88
/// Looks up the HuggingFace source for the Xenova int8 model names that get
/// the hand-rolled (bit-exact) path.
///
/// Returns `Some((repo, onnx_path))` for a recognised name and `None` for
/// everything else. Both recognised models use the same ONNX file path
/// inside their repository.
#[cfg(feature = "embedder-hub")]
fn user_defined_source(model_name: &str) -> Option<(&'static str, &'static str)> {
    const QUANTIZED_ONNX: &str = "onnx/model_quantized.onnx";

    let repo = match model_name {
        "Xenova/bge-base-en-v1.5-int8" => "Xenova/bge-base-en-v1.5",
        "Xenova/bge-small-en-v1.5-int8" => "Xenova/bge-small-en-v1.5",
        _ => return None,
    };
    Some((repo, QUANTIZED_ONNX))
}
103
104impl FastembedEmbedder {
105    /// Constructor that fetches model files from HuggingFace at runtime.
106    /// Requires the `embedder-hub` Cargo feature (pulls `hf-hub` + native-tls).
107    /// Embedded library consumers that load model bytes from their own
108    /// storage should use [`FastembedEmbedder::from_user_defined_files`]
109    /// instead — it works under `embedder-core` alone.
110    #[cfg(feature = "embedder-hub")]
111    pub fn new(cfg: FastembedEmbedderConfig) -> Result<Self> {
112        // Priority order for embedder dispatch:
113        //   1. BYO mode (cfg.hf_repo set): user-defined ONNX path, runtime
114        //      pooling per cfg.pooling. No registry lookups.
115        //   2. Hardcoded user_defined_source (Xenova int8 BGE): bit-near-exact
116        //      hand-rolled CLS-pooled path.
117        //   3. fastembed-rs stock variants (resolve_model_name).
118        if cfg.is_byo() {
119            // Already validated by FastembedEmbedderConfig::validate() at config-load.
120            let repo = cfg.hf_repo.as_deref().expect("BYO repo present");
121            let onnx_path = cfg.onnx_path.as_deref().expect("BYO onnx_path present");
122            let pooling = parse_pooling(&cfg.pooling)?;
123            // Honor cfg.threads for BYO. Default 1 if unset — conservative
124            // for shared boxes, but users can opt into multi-thread.
125            let intra = cfg.threads.unwrap_or(1);
126            let runner = build_user_defined_runner(repo, onnx_path, pooling, intra)?;
127            info!(
128                "embedder loaded (BYO, YAML-driven): {} (dim={}, repo={}, file={}, pooling={:?})",
129                cfg.model_name, cfg.dim, repo, onnx_path, pooling
130            );
131            return Ok(Self {
132                cfg,
133                backend: Backend::UserDefined(runner),
134                embed_seconds: 0.0,
135            });
136        }
137
138        if let Some((repo, onnx_path)) = user_defined_source(&cfg.model_name) {
139            // Hardcoded Xenova int8 path stays at intra_threads=1 by default
140            // for bit-near-exact parity vs Python (parity tests depend on
141            // this). YAML can override via `threads:` if user prioritizes
142            // speed and accepts the parity drift.
143            let intra = cfg.threads.unwrap_or(1);
144            let runner = build_user_defined_runner(repo, onnx_path, Pooling::Cls, intra)?;
145            info!(
146                "embedder loaded (user-defined, bit-exact): {} (dim={}, repo={}, file={})",
147                cfg.model_name, cfg.dim, repo, onnx_path
148            );
149            return Ok(Self {
150                cfg,
151                backend: Backend::UserDefined(runner),
152                embed_seconds: 0.0,
153            });
154        }
155
156        let variant = resolve_model_name(&cfg.model_name)?;
157        let opts = InitOptions::new(variant).with_show_download_progress(true);
158        let model = TextEmbedding::try_new(opts)
159            .with_context(|| format!("initialising fastembed model {:?}", cfg.model_name))?;
160        info!(
161            "embedder loaded (stock variant): {} (dim={})",
162            cfg.model_name, cfg.dim
163        );
164        Ok(Self {
165            cfg,
166            backend: Backend::Stock(model),
167            embed_seconds: 0.0,
168        })
169    }
170
171    /// Bytes-in constructor: caller supplies the model files directly. No
172    /// HuggingFace fetch, no `hf-hub` dep. Available under `embedder-core`.
173    ///
174    /// `onnx` is the raw ONNX model bytes. `tokenizer` is the raw
175    /// `tokenizer.json` bytes. `tokenizer_config` and `model_config` carry
176    /// the JSON bytes from `tokenizer_config.json` and `config.json` (used to
177    /// pin pad-token / max-length tokenizer settings — same shape the
178    /// HF-fetch path uses).
179    ///
180    /// Pooling is read from `cfg.pooling` (`"cls"` or `"mean"`). `cfg.threads`
181    /// controls ORT intra-threads (default 1). `cfg.is_byo()` is not required
182    /// — this constructor never consults `cfg.hf_repo` / `cfg.onnx_path`.
183    ///
184    /// Used by embedded library consumers (e.g. AIDB pgrx extension) that
185    /// load model bytes from extension-managed storage rather than fetching
186    /// at runtime.
187    pub fn from_user_defined_files(
188        cfg: FastembedEmbedderConfig,
189        onnx: Vec<u8>,
190        tokenizer: Vec<u8>,
191        tokenizer_config: Vec<u8>,
192        model_config: Vec<u8>,
193    ) -> Result<Self> {
194        let pooling = parse_pooling(&cfg.pooling)?;
195        let intra = cfg.threads.unwrap_or(1);
196        let runner = build_user_defined_runner_from_bytes(
197            onnx,
198            tokenizer,
199            tokenizer_config,
200            model_config,
201            pooling,
202            intra,
203        )?;
204        info!(
205            "embedder loaded (bytes-in, no hf-hub): {} (dim={}, pooling={:?})",
206            cfg.model_name, cfg.dim, pooling
207        );
208        Ok(Self {
209            cfg,
210            backend: Backend::UserDefined(runner),
211            embed_seconds: 0.0,
212        })
213    }
214
215    /// Cumulative wall time spent in `embed()` calls so far.
216    pub fn embed_seconds(&self) -> f64 {
217        self.embed_seconds
218    }
219
220    pub fn dim(&self) -> usize {
221        self.cfg.dim
222    }
223
224    /// Embed a batch of texts. Returns a flat `Vec<Vec<f32>>` ordered to match
225    /// the input. Verifies the output dim matches the config `dim`. Tracks
226    /// cumulative wall time in `self.embed_seconds`.
227    pub fn embed(&mut self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
228        if texts.is_empty() {
229            return Ok(Vec::new());
230        }
231        let t0 = std::time::Instant::now();
232        let vecs = match &mut self.backend {
233            Backend::Stock(model) => {
234                let refs: Vec<&str> = texts.iter().map(String::as_str).collect();
235                model
236                    .embed(refs, Some(self.cfg.batch_size))
237                    .context("fastembed embed call failed")?
238            }
239            Backend::UserDefined(runner) => {
240                let mut out: Vec<Vec<f32>> = Vec::with_capacity(texts.len());
241                for chunk in texts.chunks(self.cfg.batch_size.max(1)) {
242                    let refs: Vec<&str> = chunk.iter().map(String::as_str).collect();
243                    let batch = runner.embed_batch(&refs)?;
244                    out.extend(batch);
245                }
246                out
247            }
248        };
249        self.embed_seconds += t0.elapsed().as_secs_f64();
250        if let Some(first) = vecs.first() {
251            if first.len() != self.cfg.dim {
252                return Err(anyhow!(
253                    "model {} produced dim {}, config says dim={}",
254                    self.cfg.model_name,
255                    first.len(),
256                    self.cfg.dim
257                ));
258            }
259        }
260        Ok(vecs)
261    }
262}
263
/// In-tree [`crate::chunker::BoundaryEmbedder`] implementation backed by
/// [`FastembedEmbedder`].
///
/// [`FastembedEmbedder::embed`] takes owned `Vec<String>` input, so each
/// borrowed text is copied into a fresh `String` before delegating. That is
/// one extra allocation pass per call, which is negligible vs. the ORT
/// inference cost.
///
/// Present when `chunkers` is enabled (this module is itself `embedder-core`-
/// gated, so `FastembedEmbedder` is always available here). The trait lives
/// in `crate::chunker` and is always available.
#[cfg(feature = "chunkers")]
impl crate::chunker::BoundaryEmbedder for FastembedEmbedder {
    fn embed_batch(&mut self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        let mut owned: Vec<String> = Vec::with_capacity(texts.len());
        for &text in texts {
            owned.push(text.to_string());
        }
        self.embed(owned)
    }
}
280
/// HF-fetch path: pulls the model files for `repo` through
/// `fetch_user_defined_files`, then hands the bytes to
/// `build_user_defined_runner_from_bytes`. Only compiled under
/// `embedder-hub`.
///
/// The fetched `special_tokens_map` is not used. Errors from either step are
/// wrapped with context that names the repository.
#[cfg(feature = "embedder-hub")]
fn build_user_defined_runner(
    repo: &str,
    onnx_path: &str,
    pooling: Pooling,
    intra_threads: usize,
) -> Result<UserDefinedRunner> {
    let files: HfModelFiles = fetch_user_defined_files(repo, onnx_path)
        .with_context(|| format!("fetching user-defined files for {repo}"))?;

    let built = build_user_defined_runner_from_bytes(
        files.onnx,
        files.tokenizer,
        files.tokenizer_config,
        files.config,
        pooling,
        intra_threads,
    );
    built.with_context(|| format!("building user-defined runner for {repo}"))
}
309
310/// Bytes-in builder: shared by `build_user_defined_runner` (HF path) and
311/// [`FastembedEmbedder::from_user_defined_files`] (bytes-in API). No HF
312/// dependency — available under `embedder-core`.
313fn build_user_defined_runner_from_bytes(
314    onnx: Vec<u8>,
315    tokenizer: Vec<u8>,
316    tokenizer_config: Vec<u8>,
317    config: Vec<u8>,
318    pooling: Pooling,
319    intra_threads: usize,
320) -> Result<UserDefinedRunner> {
321    // intra_threads = 1 is the bit-exactness setting. Caller passes 1 by
322    // default for the Xenova int8 BGE bit-near-exact path (parity tests
323    // depend on it). For BYO mode with `threads: 4` in YAML, the caller
324    // passes 4 — bit-exactness isn't promised for arbitrary BYO models
325    // anyway, and multi-threaded inference is 2-4× faster on big batches.
326    // ORT optimization level Level3 stays the same regardless.
327    let session = Session::builder()
328        .map_err(|e| anyhow!("ort session builder: {e}"))?
329        .with_optimization_level(GraphOptimizationLevel::Level3)
330        .map_err(|e| anyhow!("ort with_optimization_level: {e}"))?
331        .with_intra_threads(intra_threads)
332        .map_err(|e| anyhow!("ort with_intra_threads({intra_threads}): {e}"))?
333        .commit_from_memory(&onnx)
334        .map_err(|e| anyhow!("commit ONNX from memory: {e}"))?;
335
336    let need_token_type_ids = session
337        .inputs()
338        .iter()
339        .any(|i| i.name() == "token_type_ids");
340
341    let mut tokenizer =
342        Tokenizer::from_bytes(&tokenizer).map_err(|e| anyhow!("tokenizer load failed: {e}"))?;
343
344    // Mirror fastembed-py's tokenizer configuration: read pad token / id from
345    // config.json + tokenizer_config.json, set BatchLongest padding + 512
346    // truncation. Without this, our tokenizer pads per its bundled defaults
347    // which can differ from Python's resulting attention_mask shape.
348    let cfg_json: serde_json::Value =
349        serde_json::from_slice(&config).map_err(|e| anyhow!("parse config.json: {e}"))?;
350    let tcfg_json: serde_json::Value = serde_json::from_slice(&tokenizer_config)
351        .map_err(|e| anyhow!("parse tokenizer_config.json: {e}"))?;
352    let pad_id = cfg_json
353        .get("pad_token_id")
354        .and_then(|v| v.as_u64())
355        .unwrap_or(0) as u32;
356    let pad_token = tcfg_json
357        .get("pad_token")
358        .and_then(|v| v.as_str())
359        .unwrap_or("[PAD]")
360        .to_string();
361    let model_max_length = tcfg_json
362        .get("model_max_length")
363        .and_then(|v| v.as_f64())
364        .unwrap_or(512.0)
365        .min(512.0) as usize;
366
367    tokenizer
368        .with_padding(Some(PaddingParams {
369            strategy: PaddingStrategy::BatchLongest,
370            pad_token,
371            pad_id,
372            ..Default::default()
373        }))
374        .with_truncation(Some(TruncationParams {
375            max_length: model_max_length,
376            ..Default::default()
377        }))
378        .map_err(|e| anyhow!("configure tokenizer padding/truncation: {e}"))?;
379
380    Ok(UserDefinedRunner {
381        session,
382        tokenizer,
383        need_token_type_ids,
384        pooling,
385    })
386}
387
388impl UserDefinedRunner {
389    fn embed_batch(&mut self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
390        let encodings = self
391            .tokenizer
392            .encode_batch(texts.to_vec(), true)
393            .map_err(|e| anyhow!("tokenize batch: {e}"))?;
394
395        let batch_size = encodings.len();
396        let seq_len = encodings
397            .first()
398            .ok_or_else(|| anyhow!("empty encodings"))?
399            .len();
400
401        let mut ids = Vec::with_capacity(batch_size * seq_len);
402        let mut mask = Vec::with_capacity(batch_size * seq_len);
403        let mut type_ids = Vec::with_capacity(batch_size * seq_len);
404        for enc in &encodings {
405            ids.extend(enc.get_ids().iter().map(|x| *x as i64));
406            mask.extend(enc.get_attention_mask().iter().map(|x| *x as i64));
407            type_ids.extend(enc.get_type_ids().iter().map(|x| *x as i64));
408        }
409
410        let ids_arr: Array2<i64> =
411            Array2::from_shape_vec((batch_size, seq_len), ids).context("ids array shape")?;
412        let mask_arr: Array2<i64> =
413            Array2::from_shape_vec((batch_size, seq_len), mask).context("mask array shape")?;
414        let type_ids_arr: Array2<i64> = Array2::from_shape_vec((batch_size, seq_len), type_ids)
415            .context("type_ids array shape")?;
416
417        // Clone mask for ORT input — we need to keep a copy for mean_pool below.
418        // Keeping the clone close to the move site so the diff explains itself.
419        let mask_for_ort = mask_arr.clone();
420        let mut session_inputs = ort::inputs![
421            "input_ids" => Value::from_array(ids_arr)?,
422            "attention_mask" => Value::from_array(mask_for_ort)?,
423        ];
424        if self.need_token_type_ids {
425            session_inputs.push((
426                "token_type_ids".into(),
427                Value::from_array(type_ids_arr)?.into(),
428            ));
429        }
430
431        let outputs = self
432            .session
433            .run(session_inputs)
434            .context("ort session.run")?;
435
436        // Output is the model's last_hidden_state (BERT-style). Find the
437        // first f32 tensor in the outputs map — for the Xenova int8 BGE
438        // models there's one output ("last_hidden_state").
439        let mut last_hidden: Option<ndarray::ArrayD<f32>> = None;
440        for (_name, val) in outputs.iter() {
441            if let Ok(arr) = val.try_extract_array::<f32>() {
442                last_hidden = Some(arr.to_owned());
443                break;
444            }
445        }
446        let last_hidden =
447            last_hidden.ok_or_else(|| anyhow!("no f32 output tensor found in session outputs"))?;
448
449        // Expect shape (batch, seq, hidden). Pool per `self.pooling`.
450        if last_hidden.ndim() != 3 {
451            return Err(anyhow!(
452                "expected 3D output (batch, seq, hidden), got ndim={}",
453                last_hidden.ndim()
454            ));
455        }
456        let pooled: ndarray::Array2<f32> = match self.pooling {
457            Pooling::Cls => last_hidden
458                .slice(s![.., 0, ..])
459                .to_owned()
460                .into_dimensionality()
461                .unwrap(),
462            Pooling::Mean => mean_pool(&last_hidden, &mask_arr)?,
463        };
464
465        let mut out = Vec::with_capacity(batch_size);
466        for row in pooled.rows() {
467            let v: Vec<f32> = row.to_vec();
468            // Numpy's np.linalg.norm on f32 promotes to f64 internally for
469            // the sum-of-squares accumulation; mirror that to maximize
470            // cross-language parity. Final result is still f32.
471            let norm_f64: f64 = v.iter().map(|x| (*x as f64).powi(2)).sum::<f64>().sqrt();
472            let denom = (norm_f64 as f32) + 1e-12_f32;
473            let normalized: Vec<f32> = v.iter().map(|x| x / denom).collect();
474            out.push(normalized);
475        }
476        Ok(out)
477    }
478}
479
480/// Mean-pool the last_hidden output across non-padding tokens.
481///
482/// `last_hidden` shape: (batch, seq, hidden). `mask` shape: (batch, seq) with
483/// 1 = real token, 0 = padding. Result shape: (batch, hidden). Mirrors what
484/// sentence-transformers / fastembed do for mean-pooled models like e5 /
485/// MiniLM. Without masking, padding tokens contribute zero-ish but real
486/// values to the mean — for short inputs this materially distorts the vector.
487fn mean_pool(
488    last_hidden: &ndarray::ArrayD<f32>,
489    mask: &ndarray::Array2<i64>,
490) -> Result<ndarray::Array2<f32>> {
491    let shape = last_hidden.shape();
492    if shape.len() != 3 {
493        return Err(anyhow!("mean_pool expects 3D last_hidden, got {:?}", shape));
494    }
495    let (batch, seq, hidden) = (shape[0], shape[1], shape[2]);
496    if mask.shape() != [batch, seq] {
497        return Err(anyhow!(
498            "mean_pool: mask shape {:?} does not match last_hidden batch/seq ({}, {})",
499            mask.shape(),
500            batch,
501            seq
502        ));
503    }
504    let last3 = last_hidden
505        .view()
506        .into_dimensionality::<ndarray::Ix3>()
507        .map_err(|e| anyhow!("mean_pool: cannot view as Ix3: {e}"))?;
508    let mut out = ndarray::Array2::<f32>::zeros((batch, hidden));
509    for b in 0..batch {
510        let mut acc = vec![0.0_f32; hidden];
511        let mut count: f32 = 0.0;
512        for t in 0..seq {
513            if mask[[b, t]] != 0 {
514                count += 1.0;
515                let row = last3.slice(s![b, t, ..]);
516                for (i, v) in row.iter().enumerate() {
517                    acc[i] += *v;
518                }
519            }
520        }
521        // If a row has zero unmasked tokens (shouldn't happen — tokenizers
522        // emit at least the [CLS]/<s> token even for empty input), fall back
523        // to the first token to avoid NaN. Otherwise divide.
524        if count == 0.0 {
525            let row = last3.slice(s![b, 0, ..]);
526            for (i, v) in row.iter().enumerate() {
527                out[[b, i]] = *v;
528            }
529        } else {
530            for i in 0..hidden {
531                out[[b, i]] = acc[i] / count;
532            }
533        }
534    }
535    Ok(out)
536}
537
/// Translates a Python-style `model_name` into the matching fastembed-rs
/// `EmbeddingModel` variant.
///
/// Only reached for names that `user_defined_source` does not recognise —
/// the Xenova int8 BGE names are served by the user-defined path instead.
///
/// # Errors
///
/// Returns an error listing every supported name when `name` is unknown.
#[cfg(feature = "embedder-hub")]
fn resolve_model_name(name: &str) -> Result<EmbeddingModel> {
    let known: HashMap<&str, EmbeddingModel> = HashMap::from([
        ("BAAI/bge-base-en-v1.5", EmbeddingModel::BGEBaseENV15),
        ("BAAI/bge-small-en-v1.5", EmbeddingModel::BGESmallENV15),
        ("BAAI/bge-large-en-v1.5", EmbeddingModel::BGELargeENV15),
        (
            "sentence-transformers/all-MiniLM-L6-v2",
            EmbeddingModel::AllMiniLML6V2,
        ),
        // The semantic chunker's default boundary model is the int8 MiniLM.
        // It is routed to fastembed-rs's stock quantized AllMiniLML6V2Q,
        // which is close enough for boundary detection: semantic chunks are
        // not promised byte-identical to Python. The hand-rolled path has a
        // `Pooling::Mean` branch, but this name is not wired to it.
        (
            "sentence-transformers/all-MiniLM-L6-v2-int8",
            EmbeddingModel::AllMiniLML6V2Q,
        ),
        // Nomic v1.5 long-context. Stock fastembed-rs handles pooling and
        // normalization for these, so no user-defined branch is needed. The
        // `-Q` name selects the quantized variant — the same model_name
        // Python's fastembed accepts.
        (
            "nomic-ai/nomic-embed-text-v1.5",
            EmbeddingModel::NomicEmbedTextV15,
        ),
        (
            "nomic-ai/nomic-embed-text-v1.5-Q",
            EmbeddingModel::NomicEmbedTextV15Q,
        ),
    ]);

    match known.get(name) {
        Some(variant) => Ok(variant.clone()),
        None => Err(anyhow!(
            "chunkshop-rs does not map model_name {name:?} to a fastembed-rs variant. \
             Supported (stock): BAAI/bge-base-en-v1.5, BAAI/bge-small-en-v1.5, \
             BAAI/bge-large-en-v1.5, sentence-transformers/all-MiniLM-L6-v2, \
             sentence-transformers/all-MiniLM-L6-v2-int8, \
             nomic-ai/nomic-embed-text-v1.5, nomic-ai/nomic-embed-text-v1.5-Q. \
             Bit-exact (user-defined): Xenova/bge-base-en-v1.5-int8, \
             Xenova/bge-small-en-v1.5-int8."
        )),
    }
}
588
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a dynamic-rank (batch, seq, hidden) tensor from row-major data.
    fn hidden_states(shape: (usize, usize, usize), data: Vec<f32>) -> ndarray::ArrayD<f32> {
        ndarray::Array3::<f32>::from_shape_vec(shape, data)
            .unwrap()
            .into_dyn()
    }

    /// Builds a (batch, seq) attention mask from row-major data.
    fn attention_mask(shape: (usize, usize), data: Vec<i64>) -> ndarray::Array2<i64> {
        ndarray::Array2::<i64>::from_shape_vec(shape, data).unwrap()
    }

    /// Padding positions must not leak into the mean. The fixture is small
    /// enough to verify by hand.
    #[test]
    fn mean_pool_masks_padding() {
        // batch=1, seq=4, hidden=3. Tokens 0 and 1 are real: [1,2,3] and
        // [4,5,6], so the expected mean is [2.5, 3.5, 4.5]. Tokens 2 and 3
        // are padding filled with 99 — a broken mask would pull the result
        // toward 99.
        let last_hidden = hidden_states(
            (1, 4, 3),
            vec![
                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 99.0, 99.0, 99.0, 99.0, 99.0, 99.0,
            ],
        );
        let mask = attention_mask((1, 4), vec![1, 1, 0, 0]);

        let pooled = mean_pool(&last_hidden, &mask).unwrap();
        assert_eq!(pooled.shape(), &[1, 3]);

        let row: Vec<f32> = pooled.row(0).to_vec();
        for (idx, want) in [2.5_f32, 3.5, 4.5].into_iter().enumerate() {
            assert!((row[idx] - want).abs() < 1e-6, "got {row:?}");
        }
    }

    /// A fully masked row takes the first token's values rather than
    /// producing NaN. Defensive path.
    #[test]
    fn mean_pool_all_padding_uses_first_token() {
        let last_hidden = hidden_states((1, 2, 2), vec![7.0, 8.0, 99.0, 99.0]);
        let mask = attention_mask((1, 2), vec![0, 0]);

        let pooled = mean_pool(&last_hidden, &mask).unwrap();
        let row: Vec<f32> = pooled.row(0).to_vec();
        assert_eq!(row, vec![7.0, 8.0]);
    }

    /// Each batch row is pooled against its own mask row.
    #[test]
    fn mean_pool_multi_batch_independent_masks() {
        let last_hidden = hidden_states(
            (2, 3, 1),
            vec![
                1.0, 2.0, 3.0, // batch 0
                10.0, 20.0, 30.0, // batch 1
            ],
        );
        // batch 0: every token is real → mean = 2.0
        // batch 1: only the first token is real → mean = 10.0
        let mask = attention_mask((2, 3), vec![1, 1, 1, 1, 0, 0]);

        let pooled = mean_pool(&last_hidden, &mask).unwrap();
        for (b, want) in [2.0_f32, 10.0].into_iter().enumerate() {
            assert!((pooled[[b, 0]] - want).abs() < 1e-6);
        }
    }

    #[test]
    fn parse_pooling_round_trips() {
        let accepted = [("cls", Pooling::Cls), ("mean", Pooling::Mean)];
        for (text, want) in accepted {
            assert_eq!(parse_pooling(text).unwrap(), want);
        }
        assert!(parse_pooling("max").is_err());
    }
}