use std::path::PathBuf;
use std::sync::{Arc, Mutex};
use tract_onnx::prelude::*;
use crate::options::{ModelInfo, Pooling, TractOptions, lookup};
/// Errors returned while constructing or running a [`TractEmbedModel`].
///
/// Every variant carries a human-readable detail string that is interpolated
/// into the `Display` message.
#[derive(Debug, thiserror::Error)]
pub enum TractError {
/// The requested model name was not resolved by `lookup`; carries the
/// requested name, or `"<none>"` when no name was supplied.
#[error("unknown tract embed model: {0}")]
UnknownModel(String),
/// Setting up the model failed: cache creation, file download, tokenizer
/// loading, or ONNX parsing/optimization.
#[error("tract model init failed: {0}")]
Init(String),
/// Tokenization, tensor construction, inference, or pooling failed.
#[error("tract embed failed: {0}")]
Embed(String),
/// The mutex guarding the runnable model was poisoned.
#[error("mutex poisoned: {0}")]
MutexPoisoned(String),
/// Awaiting the `spawn_blocking` inference task returned a join error.
#[error("blocking task panicked: {0}")]
TaskPanicked(String),
}
/// Result of a [`TractEmbedModel::embed`] call.
#[derive(Debug, Clone)]
pub struct TractResponse {
/// One embedding per input text, in input order.
pub embeddings: Vec<Vec<f32>>,
/// Identifier of the model that produced the embeddings.
pub model: String,
}
/// Runnable tract plan over a typed graph, as produced by
/// `into_optimized().into_runnable()` in `TractEmbedModel::from_options`.
type TractModel = SimplePlan<TypedFact, Box<dyn TypedOp>, Graph<TypedFact, Box<dyn TypedOp>>>;
/// Text-embedding model that runs an ONNX graph through tract.
pub struct TractEmbedModel {
// Runnable plan; `embed_blocking` locks the mutex around each `run` call.
model: Arc<Mutex<TractModel>>,
// Tokenizer loaded from the model's `tokenizer.json`.
tokenizer: Arc<tokenizers::Tokenizer>,
// `model_code` of the resolved `ModelInfo`; echoed in every response.
model_id: String,
// Expected embedding width (`ModelInfo::dim`).
dims: usize,
// Pooling strategy applied to rank-3 model outputs.
pooling: Pooling,
// Upper bound on tokens per text fed to the graph (`from_options` sets 512).
max_length: usize,
// Optional cap on texts per inference call; `None` runs all texts at once.
batch_size: Option<usize>,
// Number of inputs declared by the graph; 3 or more means token type ids
// are fed in addition to ids and attention mask.
input_count: usize,
}
/// Manual `Debug` that prints the scalar configuration only; the tract plan
/// and tokenizer are left out and signalled by the trailing `..`.
impl std::fmt::Debug for TractEmbedModel {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Destructure so the omitted fields are explicit at the `..`.
        let Self {
            model_id,
            dims,
            pooling,
            max_length,
            batch_size,
            input_count,
            ..
        } = self;

        let mut out = f.debug_struct("TractEmbedModel");
        out.field("model_id", model_id);
        out.field("dims", dims);
        out.field("pooling", pooling);
        out.field("max_length", max_length);
        out.field("batch_size", batch_size);
        out.field("input_count", input_count);
        out.finish_non_exhaustive()
    }
}
impl TractEmbedModel {
    /// Builds a model from `opts`.
    ///
    /// Resolves the model name through `lookup`, fetches the ONNX file and
    /// tokenizer through the model cache (blocking the calling thread until
    /// the downloads finish), then parses and optimizes the graph into a
    /// runnable plan.
    ///
    /// # Errors
    ///
    /// * [`TractError::UnknownModel`] when `lookup` does not know the name.
    /// * [`TractError::Init`] when the cache, a download, the tokenizer, or
    ///   any ONNX loading stage fails.
    pub fn from_options(opts: TractOptions) -> Result<Self, TractError> {
        let TractOptions {
            model_name,
            cache_dir,
            max_batch_size,
            show_download_progress: _,
        } = opts;

        let requested = model_name.as_deref();
        let Some(info) = lookup(requested) else {
            return Err(TractError::UnknownModel(
                requested.unwrap_or("<none>").to_string(),
            ));
        };

        let cache = match cache_dir {
            Some(dir) => blazen_model_cache::ModelCache::with_dir(dir),
            None => blazen_model_cache::ModelCache::new()
                .map_err(|e| TractError::Init(format!("cache init failed: {e}")))?,
        };
        let (onnx_path, tokenizer_path) = block_on_downloads(&cache, info)?;

        let tokenizer = tokenizers::Tokenizer::from_file(&tokenizer_path)
            .map_err(|e| TractError::Init(format!("tokenizer load failed: {e}")))?;

        // Parse -> optimize -> make runnable, each stage with its own error text.
        let parsed = tract_onnx::onnx()
            .model_for_path(&onnx_path)
            .map_err(|e| TractError::Init(format!("onnx parse failed: {e}")))?;
        let optimized = parsed
            .into_optimized()
            .map_err(|e| TractError::Init(format!("onnx optimize failed: {e}")))?;
        let plan = optimized
            .into_runnable()
            .map_err(|e| TractError::Init(format!("onnx runnable failed: {e}")))?;

        // Remember how many inputs the graph declares so inference knows
        // whether to feed token type ids.
        let input_count = plan.model().inputs.len();

        Ok(Self {
            model: Arc::new(Mutex::new(plan)),
            tokenizer: Arc::new(tokenizer),
            model_id: info.model_code.to_string(),
            dims: info.dim,
            pooling: info.pooling,
            max_length: 512,
            batch_size: max_batch_size,
            input_count,
        })
    }

    /// Identifier of the loaded model.
    #[must_use]
    pub fn model_id(&self) -> &str {
        &self.model_id
    }

    /// Expected length of every embedding vector.
    #[must_use]
    pub fn dimensions(&self) -> usize {
        self.dims
    }

    /// Embeds `texts`, returning one vector per text in input order.
    ///
    /// An empty slice short-circuits to an empty response. Otherwise the
    /// texts are copied and inference runs on tokio's blocking pool.
    ///
    /// # Errors
    ///
    /// * [`TractError::TaskPanicked`] when the blocking task cannot be joined.
    /// * Any error produced by the inference routine itself.
    pub async fn embed(&self, texts: &[String]) -> Result<TractResponse, TractError> {
        if texts.is_empty() {
            return Ok(TractResponse {
                embeddings: Vec::new(),
                model: self.model_id.clone(),
            });
        }

        // Everything the blocking closure needs must be owned (`'static`).
        let owned_texts: Vec<String> = texts.to_vec();
        let plan = Arc::clone(&self.model);
        let tokenizer = Arc::clone(&self.tokenizer);
        let (dims, pooling, max_length) = (self.dims, self.pooling, self.max_length);
        let (batch_size, input_count) = (self.batch_size, self.input_count);
        let model = self.model_id.clone();

        let job = move || {
            embed_blocking(
                &plan,
                &tokenizer,
                &owned_texts,
                dims,
                pooling,
                max_length,
                batch_size,
                input_count,
            )
        };

        let embeddings = match tokio::task::spawn_blocking(job).await {
            Ok(outcome) => outcome?,
            Err(join_err) => return Err(TractError::TaskPanicked(join_err.to_string())),
        };

        Ok(TractResponse { embeddings, model })
    }
}
/// Synchronously fetches every file the model needs and returns the local
/// paths of the ONNX graph and `tokenizer.json`, in that order.
///
/// Files listed in `info.additional_files` are downloaded too, but their
/// paths are not returned.
///
/// When a tokio runtime handle is available on the current thread, the
/// downloads are driven through `block_in_place` + `Handle::block_on`;
/// otherwise a temporary current-thread runtime is built for the call.
///
/// NOTE(review): tokio documents `block_in_place` as panicking when invoked
/// from inside a current-thread runtime — confirm async callers use the
/// multi-thread flavor.
///
/// # Errors
///
/// Returns [`TractError::Init`] when any download fails or the fallback
/// runtime cannot be built.
fn block_on_downloads(
    cache: &blazen_model_cache::ModelCache,
    info: &ModelInfo,
) -> Result<(PathBuf, PathBuf), TractError> {
    let fetch_all = async {
        let onnx_path = match cache
            .download(info.model_code, info.model_file, None)
            .await
        {
            Ok(path) => path,
            Err(e) => {
                return Err(TractError::Init(format!(
                    "failed to download {}: {}",
                    info.model_file, e
                )));
            }
        };

        let tokenizer_path = match cache
            .download(info.model_code, "tokenizer.json", None)
            .await
        {
            Ok(path) => path,
            Err(e) => {
                return Err(TractError::Init(format!(
                    "failed to download tokenizer.json: {e}"
                )));
            }
        };

        // Extra files only need to be present in the cache.
        for extra in info.additional_files {
            if let Err(e) = cache.download(info.model_code, extra, None).await {
                return Err(TractError::Init(format!(
                    "failed to download {extra}: {e}"
                )));
            }
        }

        Ok::<_, TractError>((onnx_path, tokenizer_path))
    };

    match tokio::runtime::Handle::try_current() {
        // Reuse the ambient runtime.
        Ok(handle) => tokio::task::block_in_place(|| handle.block_on(fetch_all)),
        // No ambient runtime: spin up a throwaway one for this call.
        Err(_) => {
            let runtime = tokio::runtime::Builder::new_current_thread()
                .enable_all()
                .build()
                .map_err(|e| TractError::Init(format!("runtime build failed: {e}")))?;
            runtime.block_on(fetch_all)
        }
    }
}
/// Tokenizes `texts`, runs them through the tract plan chunk by chunk, and
/// returns one L2-normalized embedding per text, in input order.
///
/// * `model_handle` — runnable plan; locked only for the duration of each `run`.
/// * `tokenizer` — encodes each chunk with special tokens enabled.
/// * `dims` — expected embedding width; mismatching model output is an error.
/// * `pooling` — strategy used when the first model output is rank 3.
/// * `max_length` — sequences longer than this are cut to their first
///   `max_length` tokens.
/// * `batch_size` — texts per inference call; `None` means a single batch,
///   and `Some(0)` is treated as 1.
/// * `input_count` — number of graph inputs; with 3 or more, token type ids
///   are fed as the third input.
///
/// A chunk whose longest encoding is empty yields all-zero vectors without
/// running the model.
///
/// # Errors
///
/// * [`TractError::Embed`] for tokenization, tensor shaping, inference, or
///   output-shape failures.
/// * [`TractError::MutexPoisoned`] when the model mutex is poisoned.
#[allow(clippy::too_many_arguments)] // one argument per model field forwarded by `embed`
fn embed_blocking(
    model_handle: &Mutex<TractModel>,
    tokenizer: &tokenizers::Tokenizer,
    texts: &[String],
    dims: usize,
    pooling: Pooling,
    max_length: usize,
    batch_size: Option<usize>,
    input_count: usize,
) -> Result<Vec<Vec<f32>>, TractError> {
    let chunk_size = batch_size.unwrap_or(texts.len()).max(1);
    // Only graphs declaring a third input receive token type ids.
    let feed_type_ids = input_count >= 3;
    let mut all_embeddings: Vec<Vec<f32>> = Vec::with_capacity(texts.len());

    for chunk in texts.chunks(chunk_size) {
        let encodings = tokenizer
            .encode_batch(chunk.to_vec(), true)
            .map_err(|e| TractError::Embed(format!("tokenize failed: {e}")))?;
        let batch = encodings.len();

        // Pad every row to the longest encoding, capped at `max_length`.
        let seq_len = encodings
            .iter()
            .map(tokenizers::Encoding::len)
            .max()
            .unwrap_or(0)
            .min(max_length);
        if seq_len == 0 {
            all_embeddings.extend(std::iter::repeat_with(|| vec![0.0_f32; dims]).take(batch));
            continue;
        }

        // Row-major (batch, seq_len) buffers; zero doubles as padding.
        let cells = batch * seq_len;
        let mut input_ids = vec![0_i64; cells];
        let mut attention_mask = vec![0_i64; cells];
        // Skip the allocation entirely when the graph does not take type ids.
        let mut token_type_ids = if feed_type_ids {
            vec![0_i64; cells]
        } else {
            Vec::new()
        };
        for (row, enc) in encodings.iter().enumerate() {
            let ids = enc.get_ids();
            let take = ids.len().min(seq_len);
            let start = row * seq_len;
            let span = start..start + take;
            for (dst, &src) in input_ids[span.clone()].iter_mut().zip(ids) {
                *dst = i64::from(src);
            }
            for (dst, &src) in attention_mask[span.clone()]
                .iter_mut()
                .zip(enc.get_attention_mask())
            {
                *dst = i64::from(src);
            }
            if feed_type_ids {
                for (dst, &src) in token_type_ids[span].iter_mut().zip(enc.get_type_ids()) {
                    *dst = i64::from(src);
                }
            }
        }

        // The ids array is consumed directly: it is not needed after inference.
        let ids_tensor: Tensor = tract_ndarray::Array2::from_shape_vec((batch, seq_len), input_ids)
            .map_err(|e| TractError::Embed(format!("ids reshape failed: {e}")))?
            .into();
        // The mask is kept (hence the clone) because mean pooling reads it later.
        let mask_arr = tract_ndarray::Array2::from_shape_vec((batch, seq_len), attention_mask)
            .map_err(|e| TractError::Embed(format!("mask reshape failed: {e}")))?;
        let mask_tensor: Tensor = mask_arr.clone().into();

        let inputs: TVec<TValue> = if feed_type_ids {
            let types_tensor: Tensor =
                tract_ndarray::Array2::from_shape_vec((batch, seq_len), token_type_ids)
                    .map_err(|e| TractError::Embed(format!("types reshape failed: {e}")))?
                    .into();
            tvec!(ids_tensor.into(), mask_tensor.into(), types_tensor.into())
        } else {
            tvec!(ids_tensor.into(), mask_tensor.into())
        };

        // Hold the lock only while the plan runs.
        let outputs = {
            let locked = model_handle
                .lock()
                .map_err(|e| TractError::MutexPoisoned(e.to_string()))?;
            locked
                .run(inputs)
                .map_err(|e| TractError::Embed(format!("tract run failed: {e}")))?
        };

        let hidden = outputs
            .first()
            .ok_or_else(|| TractError::Embed("no outputs from tract graph".to_string()))?;
        let view = hidden
            .to_array_view::<f32>()
            .map_err(|e| TractError::Embed(format!("output view failed: {e}")))?;

        let pooled: Vec<Vec<f32>> = match view.ndim() {
            // (batch, seq, hidden): per-token states that still need pooling.
            3 => {
                let array = view
                    .view()
                    .into_dimensionality::<tract_ndarray::Ix3>()
                    .map_err(|e| TractError::Embed(format!("output ndim coerce failed: {e}")))?;
                pool_hidden_states(array, &mask_arr, pooling, dims)?
            }
            // (batch, width): the graph already produced one vector per text.
            2 => {
                let array = view
                    .view()
                    .into_dimensionality::<tract_ndarray::Ix2>()
                    .map_err(|e| TractError::Embed(format!("output ndim coerce failed: {e}")))?;
                // Enforce the same width contract the rank-3 path enforces.
                let width = array.dim().1;
                if width != dims {
                    return Err(TractError::Embed(format!(
                        "model output embedding size {width} != expected dim {dims}"
                    )));
                }
                array
                    .outer_iter()
                    .map(|row| row.iter().copied().collect::<Vec<f32>>())
                    .collect()
            }
            other => {
                return Err(TractError::Embed(format!(
                    "unexpected output rank {other}, expected 2 or 3"
                )));
            }
        };

        all_embeddings.extend(pooled.into_iter().map(|mut row| {
            l2_normalize(&mut row);
            row
        }));
    }

    Ok(all_embeddings)
}
/// Reduces per-token hidden states of shape `(batch, seq_len, hidden_dim)`
/// to one vector per batch row.
///
/// * `Pooling::Cls` copies the hidden state of the first token.
/// * `Pooling::Mean` averages the hidden states weighted by `mask`; rows
///   whose mask is all zero come out as zero vectors.
///
/// # Errors
///
/// Returns [`TractError::Embed`] when
/// * the hidden size differs from `dims`,
/// * CLS pooling is requested but the sequence axis is empty, or
/// * mean pooling is requested and `mask` is not `(batch, seq_len)`.
fn pool_hidden_states(
    hidden: tract_ndarray::ArrayView3<f32>,
    mask: &tract_ndarray::Array2<i64>,
    pooling: Pooling,
    dims: usize,
) -> Result<Vec<Vec<f32>>, TractError> {
    let (batch, seq_len, hidden_dim) = hidden.dim();
    if hidden_dim != dims {
        return Err(TractError::Embed(format!(
            "model output hidden size {hidden_dim} != expected dim {dims}"
        )));
    }

    let mut out: Vec<Vec<f32>> = Vec::with_capacity(batch);
    match pooling {
        Pooling::Cls => {
            // Slicing token 0 of an empty axis would panic; report it instead.
            if batch > 0 && seq_len == 0 {
                return Err(TractError::Embed(
                    "model output has an empty sequence axis, cannot take first-token state"
                        .to_string(),
                ));
            }
            for b in 0..batch {
                let first_token = hidden.slice(tract_ndarray::s![b, 0_usize, ..]);
                out.push(first_token.iter().copied().collect());
            }
        }
        Pooling::Mean => {
            // Indexing the mask below would panic on a shape mismatch.
            let (mask_rows, mask_cols) = mask.dim();
            if (mask_rows, mask_cols) != (batch, seq_len) {
                return Err(TractError::Embed(format!(
                    "attention mask shape ({mask_rows}, {mask_cols}) != model output shape ({batch}, {seq_len})"
                )));
            }
            for b in 0..batch {
                let mut acc = vec![0.0_f32; hidden_dim];
                let mut weight_sum: f32 = 0.0;
                for t in 0..seq_len {
                    // Mask entries are small integers, so the cast is exact in practice.
                    #[allow(clippy::cast_precision_loss)]
                    let w = mask[[b, t]] as f32;
                    if w == 0.0 {
                        continue;
                    }
                    weight_sum += w;
                    let token = hidden.slice(tract_ndarray::s![b, t, ..]);
                    for (sum, &value) in acc.iter_mut().zip(token.iter()) {
                        *sum += value * w;
                    }
                }
                // Guard against division by zero for fully masked rows.
                let denom = weight_sum.max(1e-12);
                for v in &mut acc {
                    *v /= denom;
                }
                out.push(acc);
            }
        }
    }
    Ok(out)
}
/// Scales `v` in place to unit Euclidean length.
///
/// The norm is clamped to at least `1e-12` before dividing, so an all-zero
/// (or empty) slice is left unchanged instead of producing NaNs.
fn l2_normalize(v: &mut [f32]) {
    let mut sum_of_squares = 0.0_f32;
    for &component in v.iter() {
        sum_of_squares += component * component;
    }
    let scale = f32::max(sum_of_squares.sqrt(), 1e-12);
    v.iter_mut().for_each(|component| *component /= scale);
}
#[cfg(test)]
mod tests {
    use super::*;

    // `from_options` blocks on downloads via `tokio::task::block_in_place`,
    // which tokio only permits on the multi-thread runtime, so every async
    // test that constructs a model must opt into that flavor.

    /// An empty input slice yields an empty response tagged with the model id.
    /// Skips silently when the model cannot be loaded (e.g. no network).
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn embed_empty_input_returns_empty() {
        let Ok(model) = TractEmbedModel::from_options(TractOptions::default()) else {
            eprintln!("skipping embed_empty_input_returns_empty: model not available");
            return;
        };
        let response = model.embed(&[]).await.expect("empty embed should succeed");
        assert!(response.embeddings.is_empty());
        assert_eq!(response.model, model.model_id());
    }

    /// Unknown model names are rejected before any download is attempted.
    #[test]
    fn unknown_model_name_is_rejected() {
        let opts = TractOptions {
            model_name: Some("NotARealModel".to_string()),
            ..TractOptions::default()
        };
        let err = TractEmbedModel::from_options(opts).unwrap_err();
        assert!(matches!(err, TractError::UnknownModel(_)));
    }

    /// End-to-end check: one unit-length vector of the advertised width per text.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    #[ignore = "requires model download from HuggingFace"]
    async fn embed_returns_correct_count_and_dims() {
        let model = TractEmbedModel::from_options(TractOptions::default())
            .expect("should create model with default options");
        let response = model
            .embed(&["hello".into(), "world".into()])
            .await
            .expect("embedding should succeed");
        assert_eq!(response.embeddings.len(), 2);
        assert_eq!(response.embeddings[0].len(), model.dimensions());
        let norm: f32 = response.embeddings[0]
            .iter()
            .map(|x| x * x)
            .sum::<f32>()
            .sqrt();
        assert!((norm - 1.0).abs() < 1e-3, "expected ~1.0 norm, got {norm}");
    }
}