1use anyhow::{anyhow, Context, Result};
26use fastembed::TextEmbedding;
27use ndarray::{s, Array2};
28use ort::session::{builder::GraphOptimizationLevel, Session};
29use ort::value::Value;
30use tokenizers::{PaddingParams, PaddingStrategy, Tokenizer, TruncationParams};
31use tracing::info;
32
33#[cfg(feature = "embedder-hub")]
34use fastembed::{EmbeddingModel, InitOptions};
35#[cfg(feature = "embedder-hub")]
36use std::collections::HashMap;
37
38use crate::config::FastembedEmbedderConfig;
39#[cfg(feature = "embedder-hub")]
40use crate::hf_cache::{fetch_user_defined_files, HfModelFiles};
41
42pub struct FastembedEmbedder {
43 cfg: FastembedEmbedderConfig,
44 backend: Backend,
45 embed_seconds: f64,
49}
50
51enum Backend {
52 #[cfg_attr(not(feature = "embedder-hub"), allow(dead_code))]
57 Stock(TextEmbedding),
58 UserDefined(UserDefinedRunner),
59}
60
/// How per-token hidden states are reduced to a single vector per input text.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Pooling {
    /// Take the hidden state at sequence position 0.
    Cls,
    /// Average the hidden states at positions whose attention mask is non-zero.
    Mean,
}
71
72fn parse_pooling(s: &str) -> Result<Pooling> {
73 match s {
74 "cls" => Ok(Pooling::Cls),
75 "mean" => Ok(Pooling::Mean),
76 other => Err(anyhow!(
77 "embedder.pooling must be 'cls' or 'mean', got {other:?}"
78 )),
79 }
80}
81
82struct UserDefinedRunner {
83 session: Session,
84 tokenizer: Tokenizer,
85 need_token_type_ids: bool,
86 pooling: Pooling,
87}
88
89#[cfg(feature = "embedder-hub")]
/// Resolves a built-in `-int8` model alias to the repository and ONNX file it
/// should be loaded from.
///
/// Returns `Some((repo, onnx_path))` for the two known aliases and `None` for
/// every other name (matching is exact and case-sensitive).
fn user_defined_source(model_name: &str) -> Option<(&'static str, &'static str)> {
    // Both aliases point at the same file name inside their repository.
    const QUANTIZED_ONNX: &str = "onnx/model_quantized.onnx";

    let repo = if model_name == "Xenova/bge-base-en-v1.5-int8" {
        "Xenova/bge-base-en-v1.5"
    } else if model_name == "Xenova/bge-small-en-v1.5-int8" {
        "Xenova/bge-small-en-v1.5"
    } else {
        return None;
    };
    Some((repo, QUANTIZED_ONNX))
}
103
104impl FastembedEmbedder {
105 #[cfg(feature = "embedder-hub")]
111 pub fn new(cfg: FastembedEmbedderConfig) -> Result<Self> {
112 if cfg.is_byo() {
119 let repo = cfg.hf_repo.as_deref().expect("BYO repo present");
121 let onnx_path = cfg.onnx_path.as_deref().expect("BYO onnx_path present");
122 let pooling = parse_pooling(&cfg.pooling)?;
123 let intra = cfg.threads.unwrap_or(1);
126 let runner = build_user_defined_runner(repo, onnx_path, pooling, intra)?;
127 info!(
128 "embedder loaded (BYO, YAML-driven): {} (dim={}, repo={}, file={}, pooling={:?})",
129 cfg.model_name, cfg.dim, repo, onnx_path, pooling
130 );
131 return Ok(Self {
132 cfg,
133 backend: Backend::UserDefined(runner),
134 embed_seconds: 0.0,
135 });
136 }
137
138 if let Some((repo, onnx_path)) = user_defined_source(&cfg.model_name) {
139 let intra = cfg.threads.unwrap_or(1);
144 let runner = build_user_defined_runner(repo, onnx_path, Pooling::Cls, intra)?;
145 info!(
146 "embedder loaded (user-defined, bit-exact): {} (dim={}, repo={}, file={})",
147 cfg.model_name, cfg.dim, repo, onnx_path
148 );
149 return Ok(Self {
150 cfg,
151 backend: Backend::UserDefined(runner),
152 embed_seconds: 0.0,
153 });
154 }
155
156 let variant = resolve_model_name(&cfg.model_name)?;
157 let opts = InitOptions::new(variant).with_show_download_progress(true);
158 let model = TextEmbedding::try_new(opts)
159 .with_context(|| format!("initialising fastembed model {:?}", cfg.model_name))?;
160 info!(
161 "embedder loaded (stock variant): {} (dim={})",
162 cfg.model_name, cfg.dim
163 );
164 Ok(Self {
165 cfg,
166 backend: Backend::Stock(model),
167 embed_seconds: 0.0,
168 })
169 }
170
171 pub fn from_user_defined_files(
188 cfg: FastembedEmbedderConfig,
189 onnx: Vec<u8>,
190 tokenizer: Vec<u8>,
191 tokenizer_config: Vec<u8>,
192 model_config: Vec<u8>,
193 ) -> Result<Self> {
194 let pooling = parse_pooling(&cfg.pooling)?;
195 let intra = cfg.threads.unwrap_or(1);
196 let runner = build_user_defined_runner_from_bytes(
197 onnx,
198 tokenizer,
199 tokenizer_config,
200 model_config,
201 pooling,
202 intra,
203 )?;
204 info!(
205 "embedder loaded (bytes-in, no hf-hub): {} (dim={}, pooling={:?})",
206 cfg.model_name, cfg.dim, pooling
207 );
208 Ok(Self {
209 cfg,
210 backend: Backend::UserDefined(runner),
211 embed_seconds: 0.0,
212 })
213 }
214
215 pub fn embed_seconds(&self) -> f64 {
217 self.embed_seconds
218 }
219
220 pub fn dim(&self) -> usize {
221 self.cfg.dim
222 }
223
224 pub fn embed(&mut self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
228 if texts.is_empty() {
229 return Ok(Vec::new());
230 }
231 let t0 = std::time::Instant::now();
232 let vecs = match &mut self.backend {
233 Backend::Stock(model) => {
234 let refs: Vec<&str> = texts.iter().map(String::as_str).collect();
235 model
236 .embed(refs, Some(self.cfg.batch_size))
237 .context("fastembed embed call failed")?
238 }
239 Backend::UserDefined(runner) => {
240 let mut out: Vec<Vec<f32>> = Vec::with_capacity(texts.len());
241 for chunk in texts.chunks(self.cfg.batch_size.max(1)) {
242 let refs: Vec<&str> = chunk.iter().map(String::as_str).collect();
243 let batch = runner.embed_batch(&refs)?;
244 out.extend(batch);
245 }
246 out
247 }
248 };
249 self.embed_seconds += t0.elapsed().as_secs_f64();
250 if let Some(first) = vecs.first() {
251 if first.len() != self.cfg.dim {
252 return Err(anyhow!(
253 "model {} produced dim {}, config says dim={}",
254 self.cfg.model_name,
255 first.len(),
256 self.cfg.dim
257 ));
258 }
259 }
260 Ok(vecs)
261 }
262}
263
/// Lets the chunker use this embedder for boundary detection by forwarding
/// borrowed string slices to [`FastembedEmbedder::embed`].
#[cfg(feature = "chunkers")]
impl crate::chunker::BoundaryEmbedder for FastembedEmbedder {
    /// Copies each slice into an owned `String` and embeds the batch.
    fn embed_batch(&mut self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        let mut owned: Vec<String> = Vec::with_capacity(texts.len());
        for text in texts {
            owned.push(String::from(*text));
        }
        self.embed(owned)
    }
}
280
/// Fetches the model files for `repo` / `onnx_path` and builds a runner from
/// them.
///
/// The fetched `special_tokens_map` is not used. `pooling` and
/// `intra_threads` are forwarded unchanged to
/// `build_user_defined_runner_from_bytes`.
///
/// # Errors
/// Fails when the files cannot be fetched or the runner cannot be built; both
/// failures are annotated with the repository name.
#[cfg(feature = "embedder-hub")]
fn build_user_defined_runner(
    repo: &str,
    onnx_path: &str,
    pooling: Pooling,
    intra_threads: usize,
) -> Result<UserDefinedRunner> {
    let HfModelFiles {
        onnx: onnx_bytes,
        tokenizer: tokenizer_bytes,
        tokenizer_config: tokenizer_config_bytes,
        special_tokens_map: _,
        config: model_config_bytes,
    } = fetch_user_defined_files(repo, onnx_path)
        .with_context(|| format!("fetching user-defined files for {repo}"))?;

    let built = build_user_defined_runner_from_bytes(
        onnx_bytes,
        tokenizer_bytes,
        tokenizer_config_bytes,
        model_config_bytes,
        pooling,
        intra_threads,
    );
    built.with_context(|| format!("building user-defined runner for {repo}"))
}
309
310fn build_user_defined_runner_from_bytes(
314 onnx: Vec<u8>,
315 tokenizer: Vec<u8>,
316 tokenizer_config: Vec<u8>,
317 config: Vec<u8>,
318 pooling: Pooling,
319 intra_threads: usize,
320) -> Result<UserDefinedRunner> {
321 let session = Session::builder()
328 .map_err(|e| anyhow!("ort session builder: {e}"))?
329 .with_optimization_level(GraphOptimizationLevel::Level3)
330 .map_err(|e| anyhow!("ort with_optimization_level: {e}"))?
331 .with_intra_threads(intra_threads)
332 .map_err(|e| anyhow!("ort with_intra_threads({intra_threads}): {e}"))?
333 .commit_from_memory(&onnx)
334 .map_err(|e| anyhow!("commit ONNX from memory: {e}"))?;
335
336 let need_token_type_ids = session
337 .inputs()
338 .iter()
339 .any(|i| i.name() == "token_type_ids");
340
341 let mut tokenizer =
342 Tokenizer::from_bytes(&tokenizer).map_err(|e| anyhow!("tokenizer load failed: {e}"))?;
343
344 let cfg_json: serde_json::Value =
349 serde_json::from_slice(&config).map_err(|e| anyhow!("parse config.json: {e}"))?;
350 let tcfg_json: serde_json::Value = serde_json::from_slice(&tokenizer_config)
351 .map_err(|e| anyhow!("parse tokenizer_config.json: {e}"))?;
352 let pad_id = cfg_json
353 .get("pad_token_id")
354 .and_then(|v| v.as_u64())
355 .unwrap_or(0) as u32;
356 let pad_token = tcfg_json
357 .get("pad_token")
358 .and_then(|v| v.as_str())
359 .unwrap_or("[PAD]")
360 .to_string();
361 let model_max_length = tcfg_json
362 .get("model_max_length")
363 .and_then(|v| v.as_f64())
364 .unwrap_or(512.0)
365 .min(512.0) as usize;
366
367 tokenizer
368 .with_padding(Some(PaddingParams {
369 strategy: PaddingStrategy::BatchLongest,
370 pad_token,
371 pad_id,
372 ..Default::default()
373 }))
374 .with_truncation(Some(TruncationParams {
375 max_length: model_max_length,
376 ..Default::default()
377 }))
378 .map_err(|e| anyhow!("configure tokenizer padding/truncation: {e}"))?;
379
380 Ok(UserDefinedRunner {
381 session,
382 tokenizer,
383 need_token_type_ids,
384 pooling,
385 })
386}
387
388impl UserDefinedRunner {
389 fn embed_batch(&mut self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
390 let encodings = self
391 .tokenizer
392 .encode_batch(texts.to_vec(), true)
393 .map_err(|e| anyhow!("tokenize batch: {e}"))?;
394
395 let batch_size = encodings.len();
396 let seq_len = encodings
397 .first()
398 .ok_or_else(|| anyhow!("empty encodings"))?
399 .len();
400
401 let mut ids = Vec::with_capacity(batch_size * seq_len);
402 let mut mask = Vec::with_capacity(batch_size * seq_len);
403 let mut type_ids = Vec::with_capacity(batch_size * seq_len);
404 for enc in &encodings {
405 ids.extend(enc.get_ids().iter().map(|x| *x as i64));
406 mask.extend(enc.get_attention_mask().iter().map(|x| *x as i64));
407 type_ids.extend(enc.get_type_ids().iter().map(|x| *x as i64));
408 }
409
410 let ids_arr: Array2<i64> =
411 Array2::from_shape_vec((batch_size, seq_len), ids).context("ids array shape")?;
412 let mask_arr: Array2<i64> =
413 Array2::from_shape_vec((batch_size, seq_len), mask).context("mask array shape")?;
414 let type_ids_arr: Array2<i64> = Array2::from_shape_vec((batch_size, seq_len), type_ids)
415 .context("type_ids array shape")?;
416
417 let mask_for_ort = mask_arr.clone();
420 let mut session_inputs = ort::inputs![
421 "input_ids" => Value::from_array(ids_arr)?,
422 "attention_mask" => Value::from_array(mask_for_ort)?,
423 ];
424 if self.need_token_type_ids {
425 session_inputs.push((
426 "token_type_ids".into(),
427 Value::from_array(type_ids_arr)?.into(),
428 ));
429 }
430
431 let outputs = self
432 .session
433 .run(session_inputs)
434 .context("ort session.run")?;
435
436 let mut last_hidden: Option<ndarray::ArrayD<f32>> = None;
440 for (_name, val) in outputs.iter() {
441 if let Ok(arr) = val.try_extract_array::<f32>() {
442 last_hidden = Some(arr.to_owned());
443 break;
444 }
445 }
446 let last_hidden =
447 last_hidden.ok_or_else(|| anyhow!("no f32 output tensor found in session outputs"))?;
448
449 if last_hidden.ndim() != 3 {
451 return Err(anyhow!(
452 "expected 3D output (batch, seq, hidden), got ndim={}",
453 last_hidden.ndim()
454 ));
455 }
456 let pooled: ndarray::Array2<f32> = match self.pooling {
457 Pooling::Cls => last_hidden
458 .slice(s![.., 0, ..])
459 .to_owned()
460 .into_dimensionality()
461 .unwrap(),
462 Pooling::Mean => mean_pool(&last_hidden, &mask_arr)?,
463 };
464
465 let mut out = Vec::with_capacity(batch_size);
466 for row in pooled.rows() {
467 let v: Vec<f32> = row.to_vec();
468 let norm_f64: f64 = v.iter().map(|x| (*x as f64).powi(2)).sum::<f64>().sqrt();
472 let denom = (norm_f64 as f32) + 1e-12_f32;
473 let normalized: Vec<f32> = v.iter().map(|x| x / denom).collect();
474 out.push(normalized);
475 }
476 Ok(out)
477 }
478}
479
480fn mean_pool(
488 last_hidden: &ndarray::ArrayD<f32>,
489 mask: &ndarray::Array2<i64>,
490) -> Result<ndarray::Array2<f32>> {
491 let shape = last_hidden.shape();
492 if shape.len() != 3 {
493 return Err(anyhow!("mean_pool expects 3D last_hidden, got {:?}", shape));
494 }
495 let (batch, seq, hidden) = (shape[0], shape[1], shape[2]);
496 if mask.shape() != [batch, seq] {
497 return Err(anyhow!(
498 "mean_pool: mask shape {:?} does not match last_hidden batch/seq ({}, {})",
499 mask.shape(),
500 batch,
501 seq
502 ));
503 }
504 let last3 = last_hidden
505 .view()
506 .into_dimensionality::<ndarray::Ix3>()
507 .map_err(|e| anyhow!("mean_pool: cannot view as Ix3: {e}"))?;
508 let mut out = ndarray::Array2::<f32>::zeros((batch, hidden));
509 for b in 0..batch {
510 let mut acc = vec![0.0_f32; hidden];
511 let mut count: f32 = 0.0;
512 for t in 0..seq {
513 if mask[[b, t]] != 0 {
514 count += 1.0;
515 let row = last3.slice(s![b, t, ..]);
516 for (i, v) in row.iter().enumerate() {
517 acc[i] += *v;
518 }
519 }
520 }
521 if count == 0.0 {
525 let row = last3.slice(s![b, 0, ..]);
526 for (i, v) in row.iter().enumerate() {
527 out[[b, i]] = *v;
528 }
529 } else {
530 for i in 0..hidden {
531 out[[b, i]] = acc[i] / count;
532 }
533 }
534 }
535 Ok(out)
536}
537
/// Maps a configured model name to the matching stock fastembed variant.
///
/// Matching is exact and case-sensitive.
///
/// # Errors
/// Returns an error listing every supported name when `name` is not in the
/// table.
#[cfg(feature = "embedder-hub")]
fn resolve_model_name(name: &str) -> Result<EmbeddingModel> {
    let table: HashMap<&str, EmbeddingModel> = HashMap::from([
        ("BAAI/bge-base-en-v1.5", EmbeddingModel::BGEBaseENV15),
        ("BAAI/bge-small-en-v1.5", EmbeddingModel::BGESmallENV15),
        ("BAAI/bge-large-en-v1.5", EmbeddingModel::BGELargeENV15),
        (
            "sentence-transformers/all-MiniLM-L6-v2",
            EmbeddingModel::AllMiniLML6V2,
        ),
        (
            "sentence-transformers/all-MiniLM-L6-v2-int8",
            EmbeddingModel::AllMiniLML6V2Q,
        ),
        (
            "nomic-ai/nomic-embed-text-v1.5",
            EmbeddingModel::NomicEmbedTextV15,
        ),
        (
            "nomic-ai/nomic-embed-text-v1.5-Q",
            EmbeddingModel::NomicEmbedTextV15Q,
        ),
    ]);

    match table.get(name) {
        Some(variant) => Ok(variant.clone()),
        None => Err(anyhow!(
            "chunkshop-rs does not map model_name {name:?} to a fastembed-rs variant. \
             Supported (stock): BAAI/bge-base-en-v1.5, BAAI/bge-small-en-v1.5, \
             BAAI/bge-large-en-v1.5, sentence-transformers/all-MiniLM-L6-v2, \
             sentence-transformers/all-MiniLM-L6-v2-int8, \
             nomic-ai/nomic-embed-text-v1.5, nomic-ai/nomic-embed-text-v1.5-Q. \
             Bit-exact (user-defined): Xenova/bge-base-en-v1.5-int8, \
             Xenova/bge-small-en-v1.5-int8."
        )),
    }
}
588
#[cfg(test)]
mod tests {
    use super::*;

    /// Masked-out positions (filled with 99.0) must not affect the mean.
    #[test]
    fn mean_pool_masks_padding() {
        // One text, four positions, hidden size three; last two are padding.
        let last_hidden = ndarray::arr3(&[[
            [1.0_f32, 2.0, 3.0],
            [4.0, 5.0, 6.0],
            [99.0, 99.0, 99.0],
            [99.0, 99.0, 99.0],
        ]])
        .into_dyn();
        let mask = ndarray::arr2(&[[1_i64, 1, 0, 0]]);

        let pooled = mean_pool(&last_hidden, &mask).unwrap();
        assert_eq!(pooled.shape(), &[1, 3]);

        let row: Vec<f32> = pooled.row(0).to_vec();
        for (got, want) in row.iter().zip([2.5_f32, 3.5, 4.5]) {
            assert!((got - want).abs() < 1e-6, "got {row:?}");
        }
    }

    /// A fully masked row falls back to the hidden state at position 0.
    #[test]
    fn mean_pool_all_padding_uses_first_token() {
        let last_hidden = ndarray::arr3(&[[[7.0_f32, 8.0], [99.0, 99.0]]]).into_dyn();
        let mask = ndarray::arr2(&[[0_i64, 0]]);

        let pooled = mean_pool(&last_hidden, &mask).unwrap();
        assert_eq!(pooled.row(0).to_vec(), vec![7.0, 8.0]);
    }

    /// Each batch row is pooled using only its own mask.
    #[test]
    fn mean_pool_multi_batch_independent_masks() {
        // Row 0 averages all three positions; row 1 keeps only the first.
        let last_hidden =
            ndarray::arr3(&[[[1.0_f32], [2.0], [3.0]], [[10.0], [20.0], [30.0]]]).into_dyn();
        let mask = ndarray::arr2(&[[1_i64, 1, 1], [1, 0, 0]]);

        let pooled = mean_pool(&last_hidden, &mask).unwrap();
        for (b, want) in [2.0_f32, 10.0].into_iter().enumerate() {
            assert!((pooled[[b, 0]] - want).abs() < 1e-6);
        }
    }

    /// Known pooling names parse to their variants; anything else is rejected.
    #[test]
    fn parse_pooling_round_trips() {
        for (text, want) in [("cls", Pooling::Cls), ("mean", Pooling::Mean)] {
            assert_eq!(parse_pooling(text).unwrap(), want);
        }
        assert!(parse_pooling("max").is_err());
    }
}