use anyhow::{Context, Result};
use tokenizers::Tokenizer;
use tract_onnx::prelude::*;
/// Fixed token length that every model input is padded or truncated to.
///
/// This is also the static sequence dimension of the `[1, MAX_SEQ_LEN]` input shape
/// configured on the tract model.
const MAX_SEQ_LEN: usize = 512;
/// A source of fixed-width text embeddings.
///
/// Implementors must be `Send + Sync`, so a provider can be shared across threads.
pub trait EmbeddingProvider: Send + Sync {
/// Computes the embedding vector for a single `text`.
fn embed(&self, text: &str) -> Result<Vec<f32>>;
/// Computes embeddings for several texts.
///
/// Expected to return one vector per input, in input order.
fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>>;
/// Length of the vectors produced by `embed` / `embed_batch`.
fn dimensions(&self) -> usize;
/// Identifier of the underlying embedding model.
fn model_id(&self) -> &str;
}
/// [`EmbeddingProvider`] backed by the `bge-base-en-v1.5` ONNX export, executed with tract.
pub struct TractProvider {
/// Optimized, runnable tract plan whose inputs are pinned to `[1, MAX_SEQ_LEN]` `i64` tensors.
plan: TypedRunnableModel<TypedModel>,
/// Tokenizer loaded from the model repo's `tokenizer.json`, with truncation enabled.
tokenizer: Tokenizer,
/// Identifier returned by `model_id()`.
model_id: String,
/// Embedding width returned by `dimensions()`.
dimensions: usize,
}
impl TractProvider {
pub fn new() -> Result<Self> {
let cache_dir = crate::paths::model_cache_dir();
let api = hf_hub::api::sync::ApiBuilder::new()
.with_cache_dir(cache_dir)
.build()
.context("Failed to initialize HF Hub API")?;
let repo = api.model("Xenova/bge-base-en-v1.5".to_string());
let model_path = repo
.get("onnx/model.onnx")
.context("Failed to fetch ONNX model from HF Hub")?;
let tokenizer_path = repo
.get("tokenizer.json")
.context("Failed to fetch tokenizer from HF Hub")?;
let mut model = tract_onnx::onnx()
.model_for_path(&model_path)
.context("Failed to load ONNX model with tract")?;
let mut tokenizer = Tokenizer::from_file(&tokenizer_path)
.map_err(|e| anyhow::anyhow!("Failed to load tokenizer: {}", e))?;
tokenizer
.with_truncation(Some(tokenizers::TruncationParams::default()))
.map_err(|e| anyhow::anyhow!("Failed to configure truncation: {}", e))?;
let input_fact = InferenceFact::dt_shape(
i64::datum_type(),
tvec![1_usize.to_dim(), MAX_SEQ_LEN.to_dim()],
);
for i in 0..3 {
model
.set_input_fact(i, input_fact.clone())
.with_context(|| format!("Failed to set input fact for input {}", i))?;
}
let model = model.into_optimized().context("Failed to optimize model")?;
let plan = model
.into_runnable()
.context("Failed to make model runnable")?;
Ok(Self {
plan,
tokenizer,
model_id: "BAAI/bge-base-en-v1.5".to_string(),
dimensions: 768,
})
}
}
/// Returns a copy of `tokens` fitted to exactly `MAX_SEQ_LEN` entries.
///
/// Tokens past `MAX_SEQ_LEN` are dropped from the end, and shorter inputs are right-filled
/// with `pad_value`.
fn pad_to_max(tokens: &[i64], pad_value: i64) -> Vec<i64> {
    let kept = &tokens[..tokens.len().min(MAX_SEQ_LEN)];
    let mut fitted = Vec::with_capacity(MAX_SEQ_LEN);
    fitted.extend_from_slice(kept);
    fitted.resize(MAX_SEQ_LEN, pad_value);
    fitted
}
impl EmbeddingProvider for TractProvider {
fn embed(&self, text: &str) -> Result<Vec<f32>> {
let encoding = self
.tokenizer
.encode(text, true)
.map_err(|e| anyhow::anyhow!("Failed to tokenize: {}", e))?;
let input_ids: Vec<i64> = encoding.get_ids().iter().map(|&x| x as i64).collect();
let attention_mask: Vec<i64> = encoding
.get_attention_mask()
.iter()
.map(|&x| x as i64)
.collect();
let token_type_ids: Vec<i64> = encoding.get_type_ids().iter().map(|&x| x as i64).collect();
let input_ids = pad_to_max(&input_ids, 0);
let attention_mask = pad_to_max(&attention_mask, 0);
let token_type_ids = pad_to_max(&token_type_ids, 0);
let input_ids_tensor =
tract_ndarray::Array2::from_shape_vec((1, MAX_SEQ_LEN), input_ids)?.into_tensor();
let attention_mask_tensor =
tract_ndarray::Array2::from_shape_vec((1, MAX_SEQ_LEN), attention_mask)?.into_tensor();
let token_type_ids_tensor =
tract_ndarray::Array2::from_shape_vec((1, MAX_SEQ_LEN), token_type_ids)?.into_tensor();
let inputs = tvec![
input_ids_tensor.into(),
attention_mask_tensor.into(),
token_type_ids_tensor.into(),
];
let outputs = self.plan.run(inputs).context("Failed to run inference")?;
let output_tensor = outputs[0]
.to_array_view::<f32>()
.context("Failed to convert output to f32 array")?;
let mut cls_pooled = output_tensor.slice(tract_ndarray::s![0, 0, ..]).to_vec();
let l2: f32 = cls_pooled.iter().map(|x| x * x).sum::<f32>().sqrt();
if l2 > 0.0 {
for v in cls_pooled.iter_mut() {
*v /= l2;
}
}
Ok(cls_pooled)
}
fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
texts.iter().map(|t| self.embed(t.as_str())).collect()
}
fn dimensions(&self) -> usize {
self.dimensions
}
fn model_id(&self) -> &str {
&self.model_id
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use serial_test::serial;

    /// Embedding width the provider is expected to produce.
    const EXPECTED_DIM: usize = 768;

    /// Euclidean (L2) norm of `v`.
    fn l2_norm(v: &[f32]) -> f32 {
        v.iter().map(|x| x * x).sum::<f32>().sqrt()
    }

    #[test]
    #[serial]
    fn test_embed_single() -> Result<()> {
        let embedding = TractProvider::new()?.embed("Hello, world!")?;
        assert_eq!(embedding.len(), EXPECTED_DIM);
        Ok(())
    }

    #[test]
    #[serial]
    fn test_embed_batch() -> Result<()> {
        let provider = TractProvider::new()?;
        let texts: Vec<String> = ["First text", "Second text"]
            .iter()
            .map(|s| s.to_string())
            .collect();
        let embeddings = provider.embed_batch(&texts)?;
        assert_eq!(embeddings.len(), 2);
        for embedding in &embeddings {
            assert_eq!(embedding.len(), EXPECTED_DIM);
        }
        Ok(())
    }

    #[test]
    #[serial]
    fn test_dimensions() -> Result<()> {
        assert_eq!(TractProvider::new()?.dimensions(), EXPECTED_DIM);
        Ok(())
    }

    #[test]
    #[serial]
    fn test_embed_long_input_truncation() -> Result<()> {
        let provider = TractProvider::new()?;
        // Far more tokens than MAX_SEQ_LEN, so truncation must kick in.
        let long_text = "the quick brown fox jumps over the lazy dog ".repeat(250);
        let embedding = provider.embed(&long_text)?;
        assert_eq!(
            embedding.len(),
            EXPECTED_DIM,
            "Long input should produce 768-dim embedding after truncation"
        );
        let norm = l2_norm(&embedding);
        assert!(
            (norm - 1.0).abs() < 1e-4,
            "Embedding should be L2-normalized, got norm {}",
            norm
        );
        Ok(())
    }
}