use anyhow::{anyhow, Context, Result};
use fastembed::TextEmbedding;
use ndarray::{s, Array2};
use ort::session::{builder::GraphOptimizationLevel, Session};
use ort::value::Value;
use tokenizers::{PaddingParams, PaddingStrategy, Tokenizer, TruncationParams};
use tracing::info;
#[cfg(feature = "embedder-hub")]
use fastembed::{EmbeddingModel, InitOptions};
#[cfg(feature = "embedder-hub")]
use std::collections::HashMap;
use crate::config::FastembedEmbedderConfig;
#[cfg(feature = "embedder-hub")]
use crate::hf_cache::{fetch_user_defined_files, HfModelFiles};
/// Text embedder that pairs a configuration with one of the inference
/// backends in [`Backend`] and keeps a running total of embedding time.
pub struct FastembedEmbedder {
/// Configuration given at construction; `embed` reads `batch_size`, `dim`
/// and `model_name` from it.
cfg: FastembedEmbedderConfig,
/// The inference backend chosen by the constructor.
backend: Backend,
/// Seconds accumulated by `embed` around the backend call; only updated
/// when that call succeeds (a failing call returns before the update).
embed_seconds: f64,
}
/// Inference backend held by [`FastembedEmbedder`].
enum Backend {
/// A stock `fastembed` [`TextEmbedding`] model. In this file it is only
/// constructed by `FastembedEmbedder::new`, which is gated on the
/// `embedder-hub` feature — hence the dead-code allowance without it.
#[cfg_attr(not(feature = "embedder-hub"), allow(dead_code))]
Stock(TextEmbedding),
/// An ONNX session and tokenizer driven directly by this module.
UserDefined(UserDefinedRunner),
}
/// How per-token model output is reduced to one vector per input text.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Pooling {
/// Use the hidden state at sequence position 0.
Cls,
/// Average the hidden states of positions whose attention mask is non-zero.
Mean,
}
/// Converts the textual pooling setting into a [`Pooling`] value.
///
/// Accepts exactly `"cls"` and `"mean"` (case-sensitive); any other input
/// yields an error that quotes the offending value.
fn parse_pooling(s: &str) -> Result<Pooling> {
    if s == "cls" {
        return Ok(Pooling::Cls);
    }
    if s == "mean" {
        return Ok(Pooling::Mean);
    }
    let other = s;
    Err(anyhow!(
        "embedder.pooling must be 'cls' or 'mean', got {other:?}"
    ))
}
/// ONNX Runtime session plus tokenizer used to embed text directly, without
/// going through a `fastembed` model wrapper.
struct UserDefinedRunner {
/// Session that runs the ONNX model.
session: Session,
/// Tokenizer used to encode each batch of texts.
tokenizer: Tokenizer,
/// True when the model declares a `token_type_ids` input, in which case
/// that tensor is supplied on every run.
need_token_type_ids: bool,
/// Pooling applied to the model's per-token output.
pooling: Pooling,
}
#[cfg(feature = "embedder-hub")]
/// Looks up the repository and ONNX file backing a known int8 model alias.
///
/// Returns `Some((repo, onnx_path))` for the aliases listed in the table
/// below and `None` for every other name.
fn user_defined_source(model_name: &str) -> Option<(&'static str, &'static str)> {
    // (alias, repository, ONNX file inside the repository)
    const INT8_ALIASES: [(&str, &str, &str); 2] = [
        (
            "Xenova/bge-base-en-v1.5-int8",
            "Xenova/bge-base-en-v1.5",
            "onnx/model_quantized.onnx",
        ),
        (
            "Xenova/bge-small-en-v1.5-int8",
            "Xenova/bge-small-en-v1.5",
            "onnx/model_quantized.onnx",
        ),
    ];

    INT8_ALIASES
        .iter()
        .find(|(alias, _, _)| *alias == model_name)
        .map(|&(_, repo, file)| (repo, file))
}
impl FastembedEmbedder {
#[cfg(feature = "embedder-hub")]
pub fn new(cfg: FastembedEmbedderConfig) -> Result<Self> {
if cfg.is_byo() {
let repo = cfg.hf_repo.as_deref().expect("BYO repo present");
let onnx_path = cfg.onnx_path.as_deref().expect("BYO onnx_path present");
let pooling = parse_pooling(&cfg.pooling)?;
let intra = cfg.threads.unwrap_or(1);
let runner = build_user_defined_runner(repo, onnx_path, pooling, intra)?;
info!(
"embedder loaded (BYO, YAML-driven): {} (dim={}, repo={}, file={}, pooling={:?})",
cfg.model_name, cfg.dim, repo, onnx_path, pooling
);
return Ok(Self {
cfg,
backend: Backend::UserDefined(runner),
embed_seconds: 0.0,
});
}
if let Some((repo, onnx_path)) = user_defined_source(&cfg.model_name) {
let intra = cfg.threads.unwrap_or(1);
let runner = build_user_defined_runner(repo, onnx_path, Pooling::Cls, intra)?;
info!(
"embedder loaded (user-defined, bit-exact): {} (dim={}, repo={}, file={})",
cfg.model_name, cfg.dim, repo, onnx_path
);
return Ok(Self {
cfg,
backend: Backend::UserDefined(runner),
embed_seconds: 0.0,
});
}
let variant = resolve_model_name(&cfg.model_name)?;
let opts = InitOptions::new(variant).with_show_download_progress(true);
let model = TextEmbedding::try_new(opts)
.with_context(|| format!("initialising fastembed model {:?}", cfg.model_name))?;
info!(
"embedder loaded (stock variant): {} (dim={})",
cfg.model_name, cfg.dim
);
Ok(Self {
cfg,
backend: Backend::Stock(model),
embed_seconds: 0.0,
})
}
pub fn from_user_defined_files(
cfg: FastembedEmbedderConfig,
onnx: Vec<u8>,
tokenizer: Vec<u8>,
tokenizer_config: Vec<u8>,
model_config: Vec<u8>,
) -> Result<Self> {
let pooling = parse_pooling(&cfg.pooling)?;
let intra = cfg.threads.unwrap_or(1);
let runner = build_user_defined_runner_from_bytes(
onnx,
tokenizer,
tokenizer_config,
model_config,
pooling,
intra,
)?;
info!(
"embedder loaded (bytes-in, no hf-hub): {} (dim={}, pooling={:?})",
cfg.model_name, cfg.dim, pooling
);
Ok(Self {
cfg,
backend: Backend::UserDefined(runner),
embed_seconds: 0.0,
})
}
pub fn embed_seconds(&self) -> f64 {
self.embed_seconds
}
pub fn dim(&self) -> usize {
self.cfg.dim
}
pub fn embed(&mut self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
if texts.is_empty() {
return Ok(Vec::new());
}
let t0 = std::time::Instant::now();
let vecs = match &mut self.backend {
Backend::Stock(model) => {
let refs: Vec<&str> = texts.iter().map(String::as_str).collect();
model
.embed(refs, Some(self.cfg.batch_size))
.context("fastembed embed call failed")?
}
Backend::UserDefined(runner) => {
let mut out: Vec<Vec<f32>> = Vec::with_capacity(texts.len());
for chunk in texts.chunks(self.cfg.batch_size.max(1)) {
let refs: Vec<&str> = chunk.iter().map(String::as_str).collect();
let batch = runner.embed_batch(&refs)?;
out.extend(batch);
}
out
}
};
self.embed_seconds += t0.elapsed().as_secs_f64();
if let Some(first) = vecs.first() {
if first.len() != self.cfg.dim {
return Err(anyhow!(
"model {} produced dim {}, config says dim={}",
self.cfg.model_name,
first.len(),
self.cfg.dim
));
}
}
Ok(vecs)
}
}
#[cfg(feature = "chunkers")]
impl crate::chunker::BoundaryEmbedder for FastembedEmbedder {
    /// Adapter for the chunker: copies the borrowed texts into owned strings
    /// and forwards them to [`FastembedEmbedder::embed`].
    fn embed_batch(&mut self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        let mut owned: Vec<String> = Vec::with_capacity(texts.len());
        for text in texts {
            owned.push(String::from(*text));
        }
        self.embed(owned)
    }
}
/// Fetches the model files for `repo` / `onnx_path` and builds a runner
/// from them.
///
/// The fetched `special_tokens_map` is not used. Failures of either step are
/// wrapped with context naming the repository.
#[cfg(feature = "embedder-hub")]
fn build_user_defined_runner(
    repo: &str,
    onnx_path: &str,
    pooling: Pooling,
    intra_threads: usize,
) -> Result<UserDefinedRunner> {
    let files: HfModelFiles = fetch_user_defined_files(repo, onnx_path)
        .with_context(|| format!("fetching user-defined files for {repo}"))?;

    let built = build_user_defined_runner_from_bytes(
        files.onnx,
        files.tokenizer,
        files.tokenizer_config,
        files.config,
        pooling,
        intra_threads,
    );
    built.with_context(|| format!("building user-defined runner for {repo}"))
}
/// Builds a [`UserDefinedRunner`] from in-memory model files.
///
/// Steps, in order: create the ONNX Runtime session (optimisation level 3,
/// `intra_threads` intra-op threads), detect whether the model declares a
/// `token_type_ids` input, load the tokenizer, parse both JSON configs, then
/// configure batch-longest padding and truncation on the tokenizer.
///
/// Values read from the configs, with their fallbacks:
/// * `pad_token_id` from `config` — defaults to 0.
/// * `pad_token` from `tokenizer_config` — defaults to `"[PAD]"` when absent
///   or not a string.
/// * `model_max_length` from `tokenizer_config` — defaults to 512 and is
///   capped at 512.
///
/// # Errors
/// Returns an error when the session cannot be built, the tokenizer or
/// either JSON document fails to parse, or the tokenizer rejects the
/// padding/truncation settings.
fn build_user_defined_runner_from_bytes(
    onnx: Vec<u8>,
    tokenizer: Vec<u8>,
    tokenizer_config: Vec<u8>,
    config: Vec<u8>,
    pooling: Pooling,
    intra_threads: usize,
) -> Result<UserDefinedRunner> {
    // Upper bound on the truncation length, whatever the config advertises.
    const MAX_SEQUENCE_LENGTH: f64 = 512.0;

    let session = Session::builder()
        .map_err(|e| anyhow!("ort session builder: {e}"))?
        .with_optimization_level(GraphOptimizationLevel::Level3)
        .map_err(|e| anyhow!("ort with_optimization_level: {e}"))?
        .with_intra_threads(intra_threads)
        .map_err(|e| anyhow!("ort with_intra_threads({intra_threads}): {e}"))?
        .commit_from_memory(&onnx)
        .map_err(|e| anyhow!("commit ONNX from memory: {e}"))?;

    let need_token_type_ids = session
        .inputs()
        .iter()
        .any(|input| input.name() == "token_type_ids");

    let mut tok =
        Tokenizer::from_bytes(&tokenizer).map_err(|e| anyhow!("tokenizer load failed: {e}"))?;

    let model_cfg: serde_json::Value =
        serde_json::from_slice(&config).map_err(|e| anyhow!("parse config.json: {e}"))?;
    let tok_cfg: serde_json::Value = serde_json::from_slice(&tokenizer_config)
        .map_err(|e| anyhow!("parse tokenizer_config.json: {e}"))?;

    let pad_id = model_cfg
        .get("pad_token_id")
        .and_then(serde_json::Value::as_u64)
        .map_or(0, |id| id as u32);
    let pad_token = tok_cfg
        .get("pad_token")
        .and_then(serde_json::Value::as_str)
        .map_or_else(|| "[PAD]".to_string(), str::to_string);
    // Read as a float so that very large advertised lengths still parse
    // before being capped.
    let advertised_max = tok_cfg
        .get("model_max_length")
        .and_then(serde_json::Value::as_f64)
        .unwrap_or(MAX_SEQUENCE_LENGTH);
    let model_max_length = advertised_max.min(MAX_SEQUENCE_LENGTH) as usize;

    let padding = PaddingParams {
        strategy: PaddingStrategy::BatchLongest,
        pad_token,
        pad_id,
        ..Default::default()
    };
    let truncation = TruncationParams {
        max_length: model_max_length,
        ..Default::default()
    };
    tok.with_padding(Some(padding))
        .with_truncation(Some(truncation))
        .map_err(|e| anyhow!("configure tokenizer padding/truncation: {e}"))?;

    Ok(UserDefinedRunner {
        session,
        tokenizer: tok,
        need_token_type_ids,
        pooling,
    })
}
impl UserDefinedRunner {
    /// Embeds one batch of texts: tokenize, run the ONNX session, pool the
    /// per-token output, then L2-normalise each pooled row.
    ///
    /// Returns one vector per row of the pooled output. The tensor width is
    /// taken from the first encoding, so every encoding must have the same
    /// length (the tokenizer is expected to pad the batch).
    ///
    /// # Errors
    /// Returns an error when tokenization fails or yields no encodings, the
    /// encodings do not form a rectangular batch, the session run fails, no
    /// `f32` output exists, the first `f32` output is not 3-dimensional, or
    /// its sequence axis is empty.
    fn embed_batch(&mut self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        let encodings = self
            .tokenizer
            .encode_batch(texts.to_vec(), true)
            .map_err(|e| anyhow!("tokenize batch: {e}"))?;
        let batch_size = encodings.len();
        let seq_len = encodings
            .first()
            .ok_or_else(|| anyhow!("empty encodings"))?
            .len();

        // Flatten the batch row-major; `from_shape_vec` below rejects ragged
        // input because the element count would not match the shape.
        let mut ids = Vec::with_capacity(batch_size * seq_len);
        let mut mask = Vec::with_capacity(batch_size * seq_len);
        let mut type_ids = Vec::with_capacity(batch_size * seq_len);
        for enc in &encodings {
            ids.extend(enc.get_ids().iter().map(|&x| i64::from(x)));
            mask.extend(enc.get_attention_mask().iter().map(|&x| i64::from(x)));
            type_ids.extend(enc.get_type_ids().iter().map(|&x| i64::from(x)));
        }
        let ids_arr: Array2<i64> =
            Array2::from_shape_vec((batch_size, seq_len), ids).context("ids array shape")?;
        let mask_arr: Array2<i64> =
            Array2::from_shape_vec((batch_size, seq_len), mask).context("mask array shape")?;
        let type_ids_arr: Array2<i64> = Array2::from_shape_vec((batch_size, seq_len), type_ids)
            .context("type_ids array shape")?;

        // The session input takes ownership, but mean pooling needs the mask
        // again afterwards, so hand the session a copy.
        let mask_for_ort = mask_arr.clone();
        let mut session_inputs = ort::inputs![
            "input_ids" => Value::from_array(ids_arr)?,
            "attention_mask" => Value::from_array(mask_for_ort)?,
        ];
        if self.need_token_type_ids {
            session_inputs.push((
                "token_type_ids".into(),
                Value::from_array(type_ids_arr)?.into(),
            ));
        }
        let outputs = self
            .session
            .run(session_inputs)
            .context("ort session.run")?;

        // Use the first output that can be read as an f32 tensor.
        let mut last_hidden: Option<ndarray::ArrayD<f32>> = None;
        for (_name, val) in outputs.iter() {
            if let Ok(arr) = val.try_extract_array::<f32>() {
                last_hidden = Some(arr.to_owned());
                break;
            }
        }
        let last_hidden =
            last_hidden.ok_or_else(|| anyhow!("no f32 output tensor found in session outputs"))?;
        if last_hidden.ndim() != 3 {
            return Err(anyhow!(
                "expected 3D output (batch, seq, hidden), got ndim={}",
                last_hidden.ndim()
            ));
        }

        let pooled: ndarray::Array2<f32> = match self.pooling {
            Pooling::Cls => {
                // Indexing position 0 of an empty axis would panic inside
                // `slice`; report it as an error instead.
                if last_hidden.shape()[1] == 0 {
                    return Err(anyhow!(
                        "cls pooling: model output has an empty sequence axis (shape {:?})",
                        last_hidden.shape()
                    ));
                }
                last_hidden
                    .slice(s![.., 0, ..])
                    .to_owned()
                    .into_dimensionality::<ndarray::Ix2>()
                    .map_err(|e| anyhow!("cls pooling: cannot view as Ix2: {e}"))?
            }
            Pooling::Mean => mean_pool(&last_hidden, &mask_arr)?,
        };

        // L2-normalise each row. The norm is summed in f64, then narrowed to
        // f32; the epsilon keeps an all-zero row from dividing by zero.
        let mut out = Vec::with_capacity(batch_size);
        for row in pooled.rows() {
            let norm_f64: f64 = row
                .iter()
                .map(|x| f64::from(*x).powi(2))
                .sum::<f64>()
                .sqrt();
            let denom = (norm_f64 as f32) + 1e-12_f32;
            let normalized: Vec<f32> = row.iter().map(|x| x / denom).collect();
            out.push(normalized);
        }
        Ok(out)
    }
}
/// Mask-weighted mean over the sequence axis.
///
/// * `last_hidden` — must be 3-dimensional; its axes are read as
///   `(batch, seq, hidden)`.
/// * `mask` — must have shape `(batch, seq)`; a position counts towards the
///   mean when its mask value is non-zero.
///
/// Returns a `(batch, hidden)` array. A row whose mask is entirely zero
/// falls back to the hidden state at sequence position 0 rather than
/// dividing by zero.
///
/// # Errors
/// Returns an error when `last_hidden` is not 3-dimensional, when the mask
/// shape does not match, or when the sequence axis is empty while the batch
/// is not (there is then nothing to pool and no position 0 to fall back to).
fn mean_pool(
    last_hidden: &ndarray::ArrayD<f32>,
    mask: &ndarray::Array2<i64>,
) -> Result<ndarray::Array2<f32>> {
    let shape = last_hidden.shape();
    if shape.len() != 3 {
        return Err(anyhow!("mean_pool expects 3D last_hidden, got {:?}", shape));
    }
    let (batch, seq, hidden) = (shape[0], shape[1], shape[2]);
    if mask.shape() != [batch, seq] {
        return Err(anyhow!(
            "mean_pool: mask shape {:?} does not match last_hidden batch/seq ({}, {})",
            mask.shape(),
            batch,
            seq
        ));
    }
    // With an empty sequence axis every row has a zero count, and the
    // fallback below would index position 0 of an empty axis and panic.
    if batch > 0 && seq == 0 {
        return Err(anyhow!(
            "mean_pool: last_hidden has an empty sequence axis (shape {:?})",
            shape
        ));
    }
    let last3 = last_hidden
        .view()
        .into_dimensionality::<ndarray::Ix3>()
        .map_err(|e| anyhow!("mean_pool: cannot view as Ix3: {e}"))?;

    let mut out = ndarray::Array2::<f32>::zeros((batch, hidden));
    for b in 0..batch {
        // Accumulate in sequence order so the f32 rounding is deterministic.
        let mut acc = vec![0.0_f32; hidden];
        let mut count: f32 = 0.0;
        for t in 0..seq {
            if mask[[b, t]] == 0 {
                continue;
            }
            count += 1.0;
            let row = last3.slice(s![b, t, ..]);
            for (slot, v) in acc.iter_mut().zip(row.iter()) {
                *slot += *v;
            }
        }

        if count == 0.0 {
            // Fully masked row: use the first token's hidden state as-is.
            let first = last3.slice(s![b, 0, ..]);
            for (i, v) in first.iter().enumerate() {
                out[[b, i]] = *v;
            }
        } else {
            for (i, total) in acc.iter().enumerate() {
                out[[b, i]] = *total / count;
            }
        }
    }
    Ok(out)
}
/// Maps a configured model name to the corresponding stock fastembed
/// variant.
///
/// # Errors
/// Returns an error listing every supported name when `name` is not in the
/// table.
#[cfg(feature = "embedder-hub")]
fn resolve_model_name(name: &str) -> Result<EmbeddingModel> {
    let table: HashMap<&str, EmbeddingModel> = HashMap::from([
        ("BAAI/bge-base-en-v1.5", EmbeddingModel::BGEBaseENV15),
        ("BAAI/bge-small-en-v1.5", EmbeddingModel::BGESmallENV15),
        ("BAAI/bge-large-en-v1.5", EmbeddingModel::BGELargeENV15),
        (
            "sentence-transformers/all-MiniLM-L6-v2",
            EmbeddingModel::AllMiniLML6V2,
        ),
        (
            "sentence-transformers/all-MiniLM-L6-v2-int8",
            EmbeddingModel::AllMiniLML6V2Q,
        ),
        (
            "nomic-ai/nomic-embed-text-v1.5",
            EmbeddingModel::NomicEmbedTextV15,
        ),
        (
            "nomic-ai/nomic-embed-text-v1.5-Q",
            EmbeddingModel::NomicEmbedTextV15Q,
        ),
    ]);

    match table.get(name) {
        Some(variant) => Ok(variant.clone()),
        None => Err(anyhow!(
            "chunkshop-rs does not map model_name {name:?} to a fastembed-rs variant. \
             Supported (stock): BAAI/bge-base-en-v1.5, BAAI/bge-small-en-v1.5, \
             BAAI/bge-large-en-v1.5, sentence-transformers/all-MiniLM-L6-v2, \
             sentence-transformers/all-MiniLM-L6-v2-int8, \
             nomic-ai/nomic-embed-text-v1.5, nomic-ai/nomic-embed-text-v1.5-Q. \
             Bit-exact (user-defined): Xenova/bge-base-en-v1.5-int8, \
             Xenova/bge-small-en-v1.5-int8."
        )),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance used for floating-point comparisons.
    const EPS: f32 = 1e-6;

    /// Builds a dynamic-rank `(batch, seq, hidden)` tensor from flat values.
    fn hidden_states(dims: (usize, usize, usize), values: Vec<f32>) -> ndarray::ArrayD<f32> {
        ndarray::Array3::<f32>::from_shape_vec(dims, values)
            .unwrap()
            .into_dyn()
    }

    /// Builds a `(batch, seq)` attention mask from flat values.
    fn attention_mask(dims: (usize, usize), values: Vec<i64>) -> ndarray::Array2<i64> {
        ndarray::Array2::<i64>::from_shape_vec(dims, values).unwrap()
    }

    /// Masked-out positions (the 99s) must not contribute to the mean.
    #[test]
    fn mean_pool_masks_padding() {
        let last_hidden = hidden_states(
            (1, 4, 3),
            vec![
                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 99.0, 99.0, 99.0, 99.0, 99.0, 99.0,
            ],
        );
        let mask = attention_mask((1, 4), vec![1, 1, 0, 0]);

        let pooled = mean_pool(&last_hidden, &mask).unwrap();

        assert_eq!(pooled.shape(), &[1, 3]);
        let row: Vec<f32> = pooled.row(0).to_vec();
        for (got, want) in row.iter().zip([2.5_f32, 3.5, 4.5]) {
            assert!((got - want).abs() < EPS, "got {row:?}");
        }
    }

    /// A fully masked row falls back to the first token's hidden state.
    #[test]
    fn mean_pool_all_padding_uses_first_token() {
        let last_hidden = hidden_states((1, 2, 2), vec![7.0, 8.0, 99.0, 99.0]);
        let mask = attention_mask((1, 2), vec![0, 0]);

        let pooled = mean_pool(&last_hidden, &mask).unwrap();

        let row: Vec<f32> = pooled.row(0).to_vec();
        assert_eq!(row, vec![7.0, 8.0]);
    }

    /// Each batch row is pooled with its own mask.
    #[test]
    fn mean_pool_multi_batch_independent_masks() {
        let last_hidden = hidden_states((2, 3, 1), vec![1.0, 2.0, 3.0, 10.0, 20.0, 30.0]);
        let mask = attention_mask((2, 3), vec![1, 1, 1, 1, 0, 0]);

        let pooled = mean_pool(&last_hidden, &mask).unwrap();

        assert!((pooled[[0, 0]] - 2.0).abs() < EPS);
        assert!((pooled[[1, 0]] - 10.0).abs() < EPS);
    }

    /// The two supported spellings parse; anything else is rejected.
    #[test]
    fn parse_pooling_round_trips() {
        assert!(matches!(parse_pooling("cls"), Ok(Pooling::Cls)));
        assert!(matches!(parse_pooling("mean"), Ok(Pooling::Mean)));
        assert!(parse_pooling("max").is_err());
    }
}