pub mod models;
mod provider;
pub use models::{
EmbeddingConfig, InputNames, ModelConfig, ModelInfo, PoolingStrategy, DEFAULT_DIM,
DEFAULT_MODEL_REPO,
};
use provider::ort_err;
pub(crate) use provider::{create_session, select_provider};
use lru::LruCache;
use ndarray::{Array2, Array3, Axis};
use once_cell::sync::OnceCell;
use ort::session::Session;
use std::num::NonZeroUsize;
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex};
use thiserror::Error;
/// Returns the Hugging Face repo id that would be selected with no explicit
/// overrides (delegates to `ModelConfig::resolve`).
pub fn model_repo() -> String {
    let resolved = ModelConfig::resolve(None, None);
    resolved.repo
}
// Pinned BLAKE3 checksums for downloaded model artifacts. An empty string
// disables verification for that artifact (see `ensure_model`, which skips
// the whole verification step while both are empty).
const MODEL_BLAKE3: &str = "";
const TOKENIZER_BLAKE3: &str = "";
/// Errors surfaced by the embedder: model acquisition, tokenization,
/// inference, and artifact integrity checks.
#[derive(Error, Debug)]
pub enum EmbedderError {
    /// Model file could not be located or opened.
    #[error("Model not found: {0}")]
    ModelNotFound(String),
    /// Tokenizer failed to load, encode, or decode.
    #[error("Tokenizer error: {0}")]
    Tokenizer(String),
    /// ONNX session run failed or produced an unexpected output shape.
    #[error("Inference failed: {0}")]
    InferenceFailed(String),
    /// A downloaded artifact did not match its pinned BLAKE3 checksum.
    #[error("Checksum mismatch for {path}: expected {expected}, got {actual}")]
    ChecksumMismatch {
        path: String,
        expected: String,
        actual: String,
    },
    /// Query was empty (after trimming).
    #[error("Query cannot be empty")]
    EmptyQuery,
    /// Download from the Hugging Face Hub failed.
    #[error("HuggingFace Hub error: {0}")]
    HfHub(String),
}
/// A dense embedding vector. Vectors produced by `embed_batch` are
/// L2-normalized before being wrapped; `Embedding::new` itself performs no
/// validation (use `try_new` for checked construction).
#[derive(Debug, Clone)]
pub struct Embedding(Vec<f32>);
pub use crate::EMBEDDING_DIM;
/// Error returned by [`Embedding::try_new`] when the candidate vector is
/// empty or contains non-finite values.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EmbeddingDimensionError {
    pub actual: usize,
    pub expected: usize,
}

impl std::fmt::Display for EmbeddingDimensionError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "Invalid embedding dimension: expected {}, got {}",
            self.expected, self.actual,
        )
    }
}

impl std::error::Error for EmbeddingDimensionError {}
impl Embedding {
    /// Wraps a raw vector without any validation.
    pub fn new(data: Vec<f32>) -> Self {
        Self(data)
    }

    /// Checked constructor: rejects empty vectors and vectors containing
    /// NaN or infinities.
    pub fn try_new(data: Vec<f32>) -> Result<Self, EmbeddingDimensionError> {
        if data.is_empty() {
            return Err(EmbeddingDimensionError {
                actual: 0,
                expected: 1,
            });
        }
        let all_finite = data.iter().all(|v| v.is_finite());
        if !all_finite {
            let len = data.len();
            return Err(EmbeddingDimensionError {
                actual: len,
                expected: len,
            });
        }
        Ok(Self(data))
    }

    /// Borrows the components as a slice.
    pub fn as_slice(&self) -> &[f32] {
        &self.0
    }

    /// Borrows the backing vector.
    pub fn as_vec(&self) -> &Vec<f32> {
        &self.0
    }

    /// Consumes the wrapper, yielding the backing vector.
    pub fn into_inner(self) -> Vec<f32> {
        self.0
    }

    /// Number of components.
    pub fn len(&self) -> usize {
        self.0.len()
    }

    /// True when the vector has no components.
    pub fn is_empty(&self) -> bool {
        self.0.is_empty()
    }
}
/// Hardware backend used for ONNX inference.
#[derive(Debug, Clone, Copy)]
pub enum ExecutionProvider {
    CUDA { device_id: i32 },
    TensorRT { device_id: i32 },
    CPU,
}

impl std::fmt::Display for ExecutionProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match *self {
            Self::CUDA { device_id } => write!(f, "CUDA (device {})", device_id),
            Self::TensorRT { device_id } => write!(f, "TensorRT (device {})", device_id),
            Self::CPU => write!(f, "CPU"),
        }
    }
}
/// ONNX-backed text embedder with lazy session/tokenizer initialization and
/// two-level (in-memory LRU + on-disk) query caching.
pub struct Embedder {
    // Lazily created ONNX session; `None` until first inference or after
    // `clear_session`.
    session: Mutex<Option<Session>>,
    // Lazily loaded tokenizer, shared via `Arc` across callers.
    tokenizer: Mutex<Option<Arc<tokenizers::Tokenizer>>>,
    // Resolved (model, tokenizer) file paths; populated once by `model_paths`.
    model_paths: OnceCell<(PathBuf, PathBuf)>,
    provider: ExecutionProvider,
    // Maximum sequence length, copied from the model config at construction.
    max_length: usize,
    // In-memory LRU mapping query text -> embedding.
    query_cache: Mutex<LruCache<String, Embedding>>,
    // Optional persistent query cache; `None` when unavailable (non-fatal).
    disk_query_cache: Option<crate::cache::QueryCache>,
    // Embedding dimension observed from the first model output.
    detected_dim: std::sync::OnceLock<usize>,
    model_config: ModelConfig,
    // Cached model fingerprint keying the disk cache; see `model_fingerprint`.
    model_fingerprint: std::sync::OnceLock<String>,
}

/// In-memory query cache capacity when `CQS_QUERY_CACHE_SIZE` is unset.
const DEFAULT_QUERY_CACHE_SIZE: usize = 128;
impl Embedder {
pub fn new(model_config: ModelConfig) -> Result<Self, EmbedderError> {
Self::new_with_provider(model_config, select_provider())
}
pub fn new_cpu(model_config: ModelConfig) -> Result<Self, EmbedderError> {
Self::new_with_provider(model_config, ExecutionProvider::CPU)
}
fn new_with_provider(
model_config: ModelConfig,
provider: ExecutionProvider,
) -> Result<Self, EmbedderError> {
let max_length = model_config.max_seq_length;
let cache_size = match std::env::var("CQS_QUERY_CACHE_SIZE") {
Ok(val) => match val.parse::<usize>() {
Ok(n) if n > 0 => {
tracing::info!(
size = n,
"Query cache size override from CQS_QUERY_CACHE_SIZE"
);
n
}
_ => {
tracing::warn!(
value = %val,
"Invalid CQS_QUERY_CACHE_SIZE (must be positive integer), using default {DEFAULT_QUERY_CACHE_SIZE}"
);
DEFAULT_QUERY_CACHE_SIZE
}
},
Err(_) => DEFAULT_QUERY_CACHE_SIZE,
};
let query_cache = Mutex::new(LruCache::new(
NonZeroUsize::new(cache_size).expect("cache_size is non-zero"),
));
let disk_query_cache =
match crate::cache::QueryCache::open(&crate::cache::QueryCache::default_path()) {
Ok(c) => {
let _ = c.prune_older_than(7);
Some(c)
}
Err(e) => {
tracing::debug!(error = %e, "Disk query cache unavailable (non-fatal)");
None
}
};
Ok(Self {
session: Mutex::new(None),
tokenizer: Mutex::new(None),
model_paths: OnceCell::new(),
provider,
max_length,
query_cache,
disk_query_cache,
detected_dim: std::sync::OnceLock::new(),
model_config,
model_fingerprint: std::sync::OnceLock::new(),
})
}
    /// Returns the model configuration this embedder was constructed with.
    pub fn model_config(&self) -> &ModelConfig {
        &self.model_config
    }
pub fn model_fingerprint(&self) -> &str {
self.model_fingerprint.get_or_init(|| {
let _span = tracing::info_span!("compute_model_fingerprint").entered();
match self.model_paths() {
Ok((model_path, _)) => {
match std::fs::metadata(model_path) {
Ok(meta) if meta.len() > 2 * 1024 * 1024 * 1024 => {
let mtime = meta
.modified()
.ok()
.and_then(|t| t.duration_since(std::time::UNIX_EPOCH).ok())
.map(|d| d.as_secs())
.unwrap_or(0);
let fp = format!(
"{}_{}_{}",
self.model_config.repo,
meta.len(),
mtime
);
tracing::info!(size = meta.len(), "Model >2GB, using metadata fingerprint");
fp
}
_ => {
match std::fs::File::open(model_path) {
Ok(file) => {
let mut hasher = blake3::Hasher::new();
match hasher.update_reader(file) {
Ok(_) => {
let hash =
hasher.finalize().to_hex().to_string();
tracing::info!(
hash = &hash[..16],
"Model fingerprint computed (streaming)"
);
hash
}
Err(e) => {
tracing::warn!(error = %e, "Failed to stream-hash model, using repo+timestamp fallback");
let ts = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs();
format!("{}:{}", self.model_config.repo, ts)
}
}
}
Err(e) => {
tracing::warn!(error = %e, "Failed to open model for fingerprint, using repo+timestamp fallback");
let ts = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs();
format!("{}:{}", self.model_config.repo, ts)
}
}
}
}
}
Err(e) => {
tracing::warn!(error = %e, "Failed to get model paths for fingerprint, using repo+timestamp fallback");
let ts = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs();
format!("{}:{}", self.model_config.repo, ts)
}
}
})
}
fn model_paths(&self) -> Result<&(PathBuf, PathBuf), EmbedderError> {
self.model_paths
.get_or_try_init(|| ensure_model(&self.model_config))
}
fn session(&self) -> Result<std::sync::MutexGuard<'_, Option<Session>>, EmbedderError> {
let mut guard = self.session.lock().unwrap_or_else(|p| p.into_inner());
if guard.is_none() {
let _span = tracing::info_span!("embedder_session_init").entered();
let (model_path, _) = self.model_paths()?;
*guard = Some(create_session(model_path, self.provider)?);
tracing::info!("Embedder session initialized");
}
Ok(guard)
}
fn tokenizer(&self) -> Result<Arc<tokenizers::Tokenizer>, EmbedderError> {
{
let guard = self.tokenizer.lock().unwrap_or_else(|p| p.into_inner());
if let Some(t) = guard.as_ref() {
return Ok(Arc::clone(t));
}
}
let (_, tokenizer_path) = self.model_paths()?;
let loaded = Arc::new(
tokenizers::Tokenizer::from_file(tokenizer_path)
.map_err(|e| EmbedderError::Tokenizer(e.to_string()))?,
);
let mut guard = self.tokenizer.lock().unwrap_or_else(|p| p.into_inner());
if let Some(existing) = guard.as_ref() {
return Ok(Arc::clone(existing));
}
*guard = Some(Arc::clone(&loaded));
Ok(loaded)
}
pub fn token_count(&self, text: &str) -> Result<usize, EmbedderError> {
let encoding = self
.tokenizer()?
.encode(text, false)
.map_err(|e| EmbedderError::Tokenizer(e.to_string()))?;
Ok(encoding.get_ids().len())
}
pub fn token_counts_batch(&self, texts: &[&str]) -> Result<Vec<usize>, EmbedderError> {
if texts.is_empty() {
return Ok(vec![]);
}
let encodings = self
.tokenizer()?
.encode_batch(texts.to_vec(), false)
.map_err(|e| EmbedderError::Tokenizer(e.to_string()))?;
Ok(encodings.iter().map(|e| e.get_ids().len()).collect())
}
pub fn split_into_windows(
&self,
text: &str,
max_tokens: usize,
overlap: usize,
) -> Result<Vec<(String, u32)>, EmbedderError> {
if max_tokens == 0 {
return Ok(vec![]);
}
if overlap >= max_tokens / 2 {
return Err(EmbedderError::Tokenizer(format!(
"overlap ({overlap}) must be less than max_tokens/2 ({})",
max_tokens / 2
)));
}
let tokenizer = self.tokenizer()?;
let encoding = tokenizer
.encode(text, false)
.map_err(|e| EmbedderError::Tokenizer(e.to_string()))?;
let ids = encoding.get_ids();
if ids.len() <= max_tokens {
return Ok(vec![(text.to_string(), 0)]);
}
let mut windows = Vec::new();
let step = max_tokens - overlap;
let mut start = 0;
let mut window_idx = 0u32;
while start < ids.len() {
let end = (start + max_tokens).min(ids.len());
let window_ids: Vec<u32> = ids[start..end].to_vec();
let window_text = tokenizer
.decode(&window_ids, true)
.map_err(|e| EmbedderError::Tokenizer(e.to_string()))?;
windows.push((window_text, window_idx));
window_idx += 1;
if end >= ids.len() {
break;
}
start += step;
}
Ok(windows)
}
pub fn embed_documents(&self, texts: &[&str]) -> Result<Vec<Embedding>, EmbedderError> {
let _span = tracing::info_span!("embed_documents", count = texts.len()).entered();
let prefix = &self.model_config.doc_prefix;
let max_batch: usize = std::env::var("CQS_EMBED_BATCH_SIZE")
.ok()
.and_then(|v| v.parse().ok())
.filter(|&n: &usize| n > 0)
.unwrap_or(64);
if texts.len() <= max_batch {
let prefixed: Vec<String> = texts.iter().map(|t| format!("{}{}", prefix, t)).collect();
return self.embed_batch(&prefixed);
}
let mut all = Vec::with_capacity(texts.len());
for chunk in texts.chunks(max_batch) {
let prefixed: Vec<String> = chunk.iter().map(|t| format!("{}{}", prefix, t)).collect();
all.extend(self.embed_batch(&prefixed)?);
}
Ok(all)
}
fn max_query_bytes() -> usize {
std::env::var("CQS_MAX_QUERY_BYTES")
.ok()
.and_then(|v| v.parse().ok())
.filter(|&n: &usize| n > 0)
.unwrap_or(32 * 1024)
}
    /// Embeds a single query (with the model's query prefix), consulting the
    /// in-memory LRU cache and then the on-disk cache before running
    /// inference.
    ///
    /// The input is trimmed; an empty (post-trim) query yields
    /// [`EmbedderError::EmptyQuery`]. Over-long queries are truncated to
    /// `CQS_MAX_QUERY_BYTES` on a char boundary first, so caching is keyed
    /// on the truncated text.
    pub fn embed_query(&self, text: &str) -> Result<Embedding, EmbedderError> {
        let _span = tracing::info_span!("embed_query").entered();
        let text = text.trim();
        if text.is_empty() {
            return Err(EmbedderError::EmptyQuery);
        }
        let max_query_bytes = Self::max_query_bytes();
        let text = if text.len() > max_query_bytes {
            tracing::warn!(
                len = text.len(),
                max = max_query_bytes,
                "Query text truncated before embedding"
            );
            let mut end = max_query_bytes;
            // Walk back to a UTF-8 char boundary so the slice cannot panic.
            while !text.is_char_boundary(end) && end > 0 {
                end -= 1;
            }
            &text[..end]
        } else {
            text
        };
        // Fast path: in-memory LRU. Scoped so the lock is released before the
        // slower fingerprint/disk/inference work below.
        {
            let mut cache = self.query_cache.lock().unwrap_or_else(|poisoned| {
                tracing::warn!("Query cache lock poisoned (prior panic), recovering");
                poisoned.into_inner()
            });
            if let Some(cached) = cache.get(text) {
                tracing::trace!(query = text, "Query cache hit (memory)");
                return Ok(cached.clone());
            }
        }
        // Disk entries are keyed by model fingerprint so swapping the model
        // invalidates them.
        let model_fp = self.model_fingerprint();
        if let Some(ref disk) = self.disk_query_cache {
            if let Some(cached) = disk.get(text, model_fp) {
                tracing::trace!(query = text, "Query cache hit (disk)");
                // Promote the disk hit into the in-memory cache.
                let mut cache = self.query_cache.lock().unwrap_or_else(|p| p.into_inner());
                cache.put(text.to_string(), cached.clone());
                return Ok(cached);
            }
        }
        tracing::trace!(query = text, "Query cache miss");
        let prefixed = format!("{}{}", self.model_config.query_prefix, text);
        let results = self.embed_batch(&[prefixed])?;
        let base_embedding = results.into_iter().next().ok_or_else(|| {
            EmbedderError::InferenceFailed("embed_batch returned empty result".to_string())
        })?;
        let embedding = base_embedding;
        // Populate both cache tiers with the fresh result.
        {
            let mut cache = self.query_cache.lock().unwrap_or_else(|poisoned| {
                tracing::warn!("Query cache lock poisoned (prior panic), recovering");
                poisoned.into_inner()
            });
            cache.put(text.to_string(), embedding.clone());
        }
        if let Some(ref disk) = self.disk_query_cache {
            disk.put(text, model_fp, &embedding);
        }
        Ok(embedding)
    }
pub fn provider(&self) -> ExecutionProvider {
self.provider
}
pub fn clear_session(&self) {
let mut guard = self.session.lock().unwrap_or_else(|p| p.into_inner());
*guard = None;
let mut cache = self.query_cache.lock().unwrap_or_else(|p| p.into_inner());
cache.clear();
let mut tok = self.tokenizer.lock().unwrap_or_else(|p| p.into_inner());
*tok = None;
tracing::info!("Embedder session, query cache, and tokenizer cleared");
}
pub fn warm(&self) -> Result<(), EmbedderError> {
let _ = self.embed_query("warmup")?;
Ok(())
}
pub fn embedding_dim(&self) -> usize {
let dim = *self.detected_dim.get().unwrap_or(&self.model_config.dim);
if dim == 0 {
EMBEDDING_DIM
} else {
dim
}
}
    /// Core inference path: tokenizes `texts`, runs the ONNX model, pools the
    /// hidden states per the configured strategy, and L2-normalizes each
    /// resulting vector.
    fn embed_batch(&self, texts: &[String]) -> Result<Vec<Embedding>, EmbedderError> {
        use ort::session::SessionInputValue;
        use ort::value::Tensor;
        use std::borrow::Cow;
        let _span = tracing::info_span!("embed_batch", count = texts.len()).entered();
        if texts.is_empty() {
            return Ok(vec![]);
        }
        // Tokenize with special tokens (second argument `true`).
        let encodings = {
            let _tokenize = tracing::debug_span!("tokenize").entered();
            self.tokenizer()?
                .encode_batch(texts.to_vec(), true)
                .map_err(|e| EmbedderError::Tokenizer(e.to_string()))?
        };
        let input_ids: Vec<Vec<i64>> = encodings
            .iter()
            .map(|e| e.get_ids().iter().map(|&id| id as i64).collect())
            .collect();
        let attention_mask: Vec<Vec<i64>> = encodings
            .iter()
            .map(|e| e.get_attention_mask().iter().map(|&m| m as i64).collect())
            .collect();
        // Pad to the longest sequence, capped at the model's max length
        // (pad_2d_i64 truncates longer rows).
        let max_len = input_ids
            .iter()
            .map(|v| v.len())
            .max()
            .unwrap_or(0)
            .min(self.max_length);
        let input_ids_arr = pad_2d_i64(&input_ids, max_len, 0);
        let attention_mask_arr = pad_2d_i64(&attention_mask, max_len, 0);
        let input_ids_tensor = Tensor::from_array(input_ids_arr).map_err(ort_err)?;
        let attention_mask_tensor = Tensor::from_array(attention_mask_arr).map_err(ort_err)?;
        let names = &self.model_config.input_names;
        let mut inputs: Vec<(Cow<'_, str>, SessionInputValue<'_>)> = Vec::with_capacity(3);
        inputs.push((
            Cow::Borrowed(names.ids.as_str()),
            SessionInputValue::from(input_ids_tensor),
        ));
        inputs.push((
            Cow::Borrowed(names.mask.as_str()),
            SessionInputValue::from(attention_mask_tensor),
        ));
        // Models configured with a token-type input additionally receive
        // token_type_ids — all zeros for single-segment input.
        if let Some(ref tt_name) = names.token_types {
            let token_type_ids_arr = Array2::<i64>::zeros((texts.len(), max_len));
            let token_type_ids_tensor = Tensor::from_array(token_type_ids_arr).map_err(ort_err)?;
            inputs.push((
                Cow::Borrowed(tt_name.as_str()),
                SessionInputValue::from(token_type_ids_tensor),
            ));
        }
        let mut guard = self.session()?;
        let session = guard
            .as_mut()
            .expect("session() guarantees initialized after Ok return");
        let _inference = tracing::debug_span!("inference", max_len).entered();
        let outputs = session.run(inputs).map_err(ort_err)?;
        let output_name = self.model_config.output_name.as_str();
        let output = outputs.get(output_name).ok_or_else(|| {
            EmbedderError::InferenceFailed(format!(
                "ONNX model has no '{}' output. Available: {:?}",
                output_name,
                outputs.keys().collect::<Vec<_>>()
            ))
        })?;
        let (shape, data) = output.try_extract_tensor::<f32>().map_err(ort_err)?;
        let batch_size = texts.len();
        let seq_len = max_len;
        // Expect [batch, seq, dim] hidden states.
        if shape.len() != 3 {
            return Err(EmbedderError::InferenceFailed(format!(
                "Unexpected tensor shape: expected 3 dimensions [batch, seq, dim], got {} dimensions",
                shape.len()
            )));
        }
        let embedding_dim = shape[2] as usize;
        // Latch the dimension on first run; reject a change on later runs.
        match self.detected_dim.get() {
            Some(&expected) if expected != embedding_dim => {
                return Err(EmbedderError::InferenceFailed(format!(
                    "Embedding dimension changed: expected {expected}, got {embedding_dim}"
                )));
            }
            None => {
                let _ = self.detected_dim.set(embedding_dim);
                tracing::info!(
                    dim = embedding_dim,
                    "Detected embedding dimension from model"
                );
            }
            _ => {}
        }
        if shape[0] as usize != batch_size {
            return Err(EmbedderError::InferenceFailed(format!(
                "Tensor batch size mismatch: expected {}, got {}",
                batch_size, shape[0]
            )));
        }
        let hidden = Array3::from_shape_vec((batch_size, seq_len, embedding_dim), data.to_vec())
            .map_err(|e| EmbedderError::InferenceFailed(format!("tensor reshape failed: {e}")))?;
        let pooled_batch: Vec<Vec<f32>> = match self.model_config.pooling {
            PoolingStrategy::Mean => mean_pool(&hidden, &attention_mask, embedding_dim),
            PoolingStrategy::Cls => cls_pool(&hidden),
            PoolingStrategy::LastToken => last_token_pool(&hidden, &attention_mask),
        };
        let results = pooled_batch
            .into_iter()
            .map(|v| Embedding::new(normalize_l2(v)))
            .collect();
        Ok(results)
    }
}
/// Resolves the (model, tokenizer) file paths: prefers a local `CQS_ONNX_DIR`
/// (structured or flat layout), otherwise downloads from the Hugging Face Hub
/// and verifies pinned BLAKE3 checksums when they are configured.
fn ensure_model(config: &ModelConfig) -> Result<(PathBuf, PathBuf), EmbedderError> {
    if let Ok(dir) = std::env::var("CQS_ONNX_DIR") {
        let dir = dunce::canonicalize(PathBuf::from(&dir)).unwrap_or_else(|_| PathBuf::from(dir));
        let model_path = dir.join(&config.onnx_path);
        let tokenizer_path = dir.join(&config.tokenizer_path);
        // SEC-3: reject paths that canonicalize (e.g. via symlinks or `..`)
        // to a location outside the configured directory.
        for (label, path) in [("model", &model_path), ("tokenizer", &tokenizer_path)] {
            if let Ok(canonical) = dunce::canonicalize(path) {
                if !canonical.starts_with(&dir) {
                    return Err(EmbedderError::ModelNotFound(format!(
                        "SEC-3: {} path escapes CQS_ONNX_DIR: {} resolves to {}",
                        label,
                        path.display(),
                        canonical.display()
                    )));
                }
            }
        }
        if model_path.exists() && tokenizer_path.exists() {
            tracing::info!(dir = %dir.display(), "Using local ONNX model directory");
            return Ok((model_path, tokenizer_path));
        }
        // Flat layout fallback: model.onnx / tokenizer.json directly in dir.
        let flat_model = dir.join("model.onnx");
        let flat_tok = dir.join("tokenizer.json");
        if flat_model.exists() && flat_tok.exists() {
            tracing::info!(dir = %dir.display(), "Using local ONNX model directory (flat)");
            return Ok((flat_model, flat_tok));
        }
        tracing::warn!(dir = %dir.display(), "CQS_ONNX_DIR set but model files not found, falling back to HF download");
    }
    use hf_hub::api::sync::Api;
    let api = Api::new().map_err(|e| EmbedderError::HfHub(e.to_string()))?;
    let repo = api.model(config.repo.clone());
    let model_path = repo
        .get(&config.onnx_path)
        .map_err(|e| EmbedderError::HfHub(e.to_string()))?;
    let tokenizer_path = repo
        .get(&config.tokenizer_path)
        .map_err(|e| EmbedderError::HfHub(e.to_string()))?;
    // Checksum verification is skipped entirely while both pinned hashes are
    // empty. A marker file records a successful verification so the hashes
    // are not recomputed on every startup.
    if !MODEL_BLAKE3.is_empty() || !TOKENIZER_BLAKE3.is_empty() {
        let marker = model_path
            .parent()
            .unwrap_or(Path::new("."))
            .join(".cqs_verified");
        let expected_marker = format!("{}\n{}", MODEL_BLAKE3, TOKENIZER_BLAKE3);
        let already_verified = std::fs::read_to_string(&marker)
            .map(|s| s == expected_marker)
            .unwrap_or(false);
        if !already_verified {
            if !MODEL_BLAKE3.is_empty() {
                verify_checksum(&model_path, MODEL_BLAKE3)?;
            }
            if !TOKENIZER_BLAKE3.is_empty() {
                verify_checksum(&tokenizer_path, TOKENIZER_BLAKE3)?;
            }
            // Best-effort marker write; failure just means re-verification.
            let _ = std::fs::write(&marker, &expected_marker);
        }
    }
    Ok((model_path, tokenizer_path))
}
fn verify_checksum(path: &Path, expected: &str) -> Result<(), EmbedderError> {
let mut file =
std::fs::File::open(path).map_err(|e| EmbedderError::ModelNotFound(e.to_string()))?;
let mut hasher = blake3::Hasher::new();
std::io::copy(&mut file, &mut hasher)
.map_err(|e| EmbedderError::ModelNotFound(e.to_string()))?;
let actual = hasher.finalize().to_hex().to_string();
if actual != expected {
return Err(EmbedderError::ChecksumMismatch {
path: path.display().to_string(),
expected: expected.to_string(),
actual,
});
}
Ok(())
}
/// Packs variable-length rows into a `(batch, max_len)` array, right-padding
/// with `pad_value` and truncating rows longer than `max_len`.
pub(crate) fn pad_2d_i64(inputs: &[Vec<i64>], max_len: usize, pad_value: i64) -> Array2<i64> {
    let mut arr = Array2::from_elem((inputs.len(), max_len), pad_value);
    for (row, seq) in inputs.iter().enumerate() {
        let width = seq.len().min(max_len);
        for col in 0..width {
            arr[[row, col]] = seq[col];
        }
    }
    arr
}
/// Scales `v` in place to unit L2 norm; a zero vector is returned unchanged.
fn normalize_l2(mut v: Vec<f32>) -> Vec<f32> {
    let norm_sq: f32 = v.iter().map(|&x| x * x).sum();
    if norm_sq > 0.0 {
        let scale = norm_sq.sqrt().recip();
        for x in v.iter_mut() {
            *x *= scale;
        }
    }
    v
}
/// Mean-pools token embeddings over the sequence axis, weighted by the
/// attention mask; a row with an all-zero mask yields a zero vector.
fn mean_pool(
    hidden: &Array3<f32>,
    attention_mask: &[Vec<i64>],
    embedding_dim: usize,
) -> Vec<Vec<f32>> {
    let (batch_size, seq_len, _) = hidden.dim();
    // f32 mask aligned to the hidden tensor's (batch, seq) shape; missing
    // positions count as masked out.
    let mask_2d = Array2::from_shape_fn((batch_size, seq_len), |(i, j)| {
        attention_mask[i].get(j).copied().unwrap_or(0) as f32
    });
    let mask_3d = mask_2d.clone().insert_axis(Axis(2));
    let masked = hidden * &mask_3d;
    let summed = masked.sum_axis(Axis(1));
    let counts = mask_2d.sum_axis(Axis(1)).insert_axis(Axis(1));
    let mut pooled = Vec::with_capacity(batch_size);
    for i in 0..batch_size {
        let count = counts[[i, 0]];
        if count > 0.0 {
            pooled.push(summed.row(i).iter().map(|v| v / count).collect());
        } else {
            tracing::warn!(batch_idx = i, "Zero attention mask — producing zero vector");
            pooled.push(vec![0.0f32; embedding_dim]);
        }
    }
    pooled
}
/// CLS pooling: returns the first token's embedding for each sequence.
fn cls_pool(hidden: &Array3<f32>) -> Vec<Vec<f32>> {
    let batch_size = hidden.dim().0;
    let mut out = Vec::with_capacity(batch_size);
    for i in 0..batch_size {
        out.push(hidden.slice(ndarray::s![i, 0usize, ..]).to_vec());
    }
    out
}
/// Last-token pooling: for each sequence, selects the embedding at the last
/// position whose attention-mask entry is non-zero, falling back to index 0
/// (with a warning) when the mask is all zeros.
fn last_token_pool(hidden: &Array3<f32>, attention_mask: &[Vec<i64>]) -> Vec<Vec<f32>> {
    let (batch_size, seq_len, _) = hidden.dim();
    (0..batch_size)
        .map(|i| {
            let last_idx = match attention_mask.get(i) {
                Some(row) => match row.iter().take(seq_len).rposition(|&m| m != 0) {
                    Some(idx) => idx,
                    None => {
                        tracing::warn!(
                            batch_idx = i,
                            "last_token_pool: zero attention mask — using index 0"
                        );
                        0
                    }
                },
                // Missing mask row: silently default to the first position.
                None => 0,
            };
            hidden.slice(ndarray::s![i, last_idx, ..]).to_vec()
        })
        .collect()
}
#[cfg(test)]
mod tests {
use super::*;
    // --- Embedding wrapper accessors ---

    #[test]
    fn test_embedding_new() {
        let data = vec![0.5; EMBEDDING_DIM];
        let emb = Embedding::new(data.clone());
        assert_eq!(emb.as_slice(), &data);
    }

    #[test]
    fn test_embedding_len() {
        let emb = Embedding::new(vec![1.0; EMBEDDING_DIM]);
        assert_eq!(emb.len(), EMBEDDING_DIM);
    }

    #[test]
    fn test_embedding_is_empty() {
        let empty = Embedding::new(vec![]);
        assert!(empty.is_empty());
        let non_empty = Embedding::new(vec![1.0; EMBEDDING_DIM]);
        assert!(!non_empty.is_empty());
    }

    #[test]
    fn test_embedding_into_inner() {
        let data = vec![1.0; EMBEDDING_DIM];
        let emb = Embedding::new(data.clone());
        assert_eq!(emb.into_inner(), data);
    }

    #[test]
    fn test_embedding_as_vec() {
        let data = vec![1.0; EMBEDDING_DIM];
        let emb = Embedding::new(data.clone());
        assert_eq!(emb.as_vec(), &data);
    }

    // --- Embedding::try_new validation ---

    #[test]
    fn tc33_try_new_empty_vec_errors() {
        let result = Embedding::try_new(vec![]);
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert_eq!(err.actual, 0);
        assert_eq!(err.expected, 1);
    }

    #[test]
    fn tc33_try_new_nan_errors() {
        let result = Embedding::try_new(vec![1.0, f32::NAN, 3.0]);
        assert!(result.is_err(), "NaN should be rejected by try_new");
    }

    #[test]
    fn tc33_try_new_inf_errors() {
        let result = Embedding::try_new(vec![1.0, f32::INFINITY, 3.0]);
        assert!(result.is_err(), "Infinity should be rejected by try_new");
        let result = Embedding::try_new(vec![f32::NEG_INFINITY]);
        assert!(result.is_err(), "Negative infinity should be rejected");
    }

    #[test]
    fn tc33_try_new_valid_ok() {
        let data = vec![0.1, 0.2, 0.3, 0.4, 0.5];
        let result = Embedding::try_new(data.clone());
        assert!(result.is_ok());
        assert_eq!(result.unwrap().as_slice(), &data);
    }

    // --- normalize_l2 ---

    #[test]
    fn test_normalize_l2_unit_vector() {
        let v = normalize_l2(vec![1.0, 0.0, 0.0]);
        assert!((v[0] - 1.0).abs() < 1e-6);
        assert!((v[1] - 0.0).abs() < 1e-6);
        assert!((v[2] - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_normalize_l2_produces_unit_vector() {
        let v = normalize_l2(vec![3.0, 4.0]);
        assert!((v[0] - 0.6).abs() < 1e-6);
        assert!((v[1] - 0.8).abs() < 1e-6);
        let magnitude: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((magnitude - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_normalize_l2_zero_vector() {
        let v = normalize_l2(vec![0.0, 0.0, 0.0]);
        assert_eq!(v, vec![0.0, 0.0, 0.0]);
    }

    #[test]
    fn test_normalize_l2_empty_vector() {
        let v = normalize_l2(vec![]);
        assert!(v.is_empty());
    }

    // Builds a (batch, seq, dim) tensor from nested vecs for pooling tests.
    fn make_hidden(values: Vec<Vec<Vec<f32>>>) -> Array3<f32> {
        let batch = values.len();
        let seq = values[0].len();
        let dim = values[0][0].len();
        let flat: Vec<f32> = values.into_iter().flatten().flatten().collect();
        Array3::from_shape_vec((batch, seq, dim), flat).expect("synthetic shape mismatch")
    }

    // --- pooling strategies ---

    #[test]
    fn mean_pool_respects_mask() {
        let hidden = make_hidden(vec![vec![
            vec![1.0, 2.0],
            vec![3.0, 4.0],
            // Masked-out position: must not affect the mean.
            vec![100.0, 200.0],
        ]]);
        let mask = vec![vec![1i64, 1, 0]];
        let pooled = mean_pool(&hidden, &mask, 2);
        assert_eq!(pooled.len(), 1, "one batch item");
        assert!((pooled[0][0] - 2.0).abs() < 1e-6);
        assert!((pooled[0][1] - 3.0).abs() < 1e-6);
    }

    #[test]
    fn mean_pool_zero_mask_returns_zero_vector() {
        let hidden = make_hidden(vec![vec![vec![5.0, 5.0], vec![6.0, 6.0]]]);
        let mask = vec![vec![0i64, 0]];
        let pooled = mean_pool(&hidden, &mask, 2);
        assert_eq!(pooled[0], vec![0.0, 0.0]);
    }

    #[test]
    fn cls_pool_returns_first_token() {
        let hidden = make_hidden(vec![
            vec![vec![1.0, 2.0], vec![9.9, 9.9]],
            vec![vec![3.0, 4.0], vec![7.7, 7.7]],
        ]);
        let pooled = cls_pool(&hidden);
        assert_eq!(pooled.len(), 2);
        assert_eq!(pooled[0], vec![1.0, 2.0]);
        assert_eq!(pooled[1], vec![3.0, 4.0]);
    }

    #[test]
    fn last_token_pool_picks_last_unmasked() {
        let hidden = make_hidden(vec![
            vec![
                vec![0.0, 0.0],
                vec![0.0, 0.0],
                vec![42.0, 43.0],
                vec![9.0, 9.0],
            ],
            vec![
                vec![11.0, 12.0],
                vec![0.0, 0.0],
                vec![0.0, 0.0],
                vec![0.0, 0.0],
            ],
        ]);
        let mask = vec![vec![1i64, 1, 1, 0], vec![1i64, 0, 0, 0]];
        let pooled = last_token_pool(&hidden, &mask);
        assert_eq!(pooled[0], vec![42.0, 43.0]);
        assert_eq!(pooled[1], vec![11.0, 12.0]);
    }

    #[test]
    fn last_token_pool_zero_mask_falls_back_to_index_0() {
        let hidden = make_hidden(vec![vec![vec![7.0, 8.0], vec![9.0, 10.0]]]);
        let mask = vec![vec![0i64, 0]];
        let pooled = last_token_pool(&hidden, &mask);
        assert_eq!(pooled[0], vec![7.0, 8.0]);
    }

    #[test]
    fn test_execution_provider_display() {
        assert_eq!(format!("{}", ExecutionProvider::CPU), "CPU");
        assert_eq!(
            format!("{}", ExecutionProvider::CUDA { device_id: 0 }),
            "CUDA (device 0)"
        );
        assert_eq!(
            format!("{}", ExecutionProvider::TensorRT { device_id: 1 }),
            "TensorRT (device 1)"
        );
    }

    #[test]
    fn test_model_dimensions() {
        assert_eq!(EMBEDDING_DIM, 1024);
    }

    // --- pad_2d_i64 ---

    #[test]
    fn test_pad_2d_i64_basic() {
        let inputs = vec![vec![1, 2, 3], vec![4, 5]];
        let result = pad_2d_i64(&inputs, 4, 0);
        assert_eq!(result.shape(), &[2, 4]);
        assert_eq!(result[[0, 0]], 1);
        assert_eq!(result[[0, 1]], 2);
        assert_eq!(result[[0, 2]], 3);
        assert_eq!(result[[0, 3]], 0);
        assert_eq!(result[[1, 0]], 4);
        assert_eq!(result[[1, 1]], 5);
        assert_eq!(result[[1, 2]], 0);
        assert_eq!(result[[1, 3]], 0);
    }

    #[test]
    fn test_pad_2d_i64_truncates() {
        let inputs = vec![vec![1, 2, 3, 4, 5]];
        let result = pad_2d_i64(&inputs, 3, 0);
        assert_eq!(result.shape(), &[1, 3]);
        assert_eq!(result[[0, 0]], 1);
        assert_eq!(result[[0, 1]], 2);
        assert_eq!(result[[0, 2]], 3);
    }

    #[test]
    fn test_pad_2d_i64_empty_input() {
        let inputs: Vec<Vec<i64>> = vec![];
        let result = pad_2d_i64(&inputs, 5, 0);
        assert_eq!(result.shape(), &[0, 5]);
    }

    #[test]
    fn test_pad_2d_i64_custom_pad_value() {
        let inputs = vec![vec![1]];
        let result = pad_2d_i64(&inputs, 3, -1);
        assert_eq!(result[[0, 0]], 1);
        assert_eq!(result[[0, 1]], -1);
        assert_eq!(result[[0, 2]], -1);
    }

    // --- error formatting ---

    #[test]
    fn test_embedder_error_display() {
        let err = EmbedderError::EmptyQuery;
        assert_eq!(format!("{}", err), "Query cannot be empty");
        let err = EmbedderError::ModelNotFound("model.onnx".to_string());
        assert!(format!("{}", err).contains("model.onnx"));
        let err = EmbedderError::Tokenizer("invalid token".to_string());
        assert!(format!("{}", err).contains("invalid token"));
        let err = EmbedderError::ChecksumMismatch {
            path: "/path/to/file".to_string(),
            expected: "abc123".to_string(),
            actual: "def456".to_string(),
        };
        assert!(format!("{}", err).contains("abc123"));
        assert!(format!("{}", err).contains("def456"));
    }

    #[test]
    fn test_embedder_error_from_ort() {
        let err: EmbedderError = EmbedderError::InferenceFailed("test error".to_string());
        assert!(matches!(err, EmbedderError::InferenceFailed(_)));
    }
    // Property-based checks for normalization and the Embedding wrapper.
    mod proptests {
        use super::*;
        use proptest::prelude::*;
        proptest! {
            // Normalization yields a unit vector, except for the zero vector
            // which must stay zero.
            #[test]
            fn prop_normalize_l2_unit_or_zero(v in prop::collection::vec(-1e6f32..1e6f32, 1..100)) {
                let normalized = normalize_l2(v.clone());
                let magnitude: f32 = normalized.iter().map(|x| x * x).sum::<f32>().sqrt();
                let input_is_zero = v.iter().all(|&x| x == 0.0);
                if input_is_zero {
                    prop_assert!(magnitude < 1e-6, "Zero input should give zero output");
                } else {
                    prop_assert!(
                        (magnitude - 1.0).abs() < 1e-4,
                        "Non-zero input should give unit vector, got magnitude {}",
                        magnitude
                    );
                }
            }
            // Scaling by a positive factor must not flip direction.
            #[test]
            fn prop_normalize_l2_preserves_direction(v in prop::collection::vec(1.0f32..100.0, 1..50)) {
                let normalized = normalize_l2(v.clone());
                let dot: f32 = v.iter().zip(normalized.iter()).map(|(a, b)| a * b).sum();
                prop_assert!(dot > 0.0, "Direction should be preserved");
            }
            #[test]
            fn prop_embedding_length_preserved(use_model_dim in proptest::bool::ANY) {
                let _ = use_model_dim;
                let emb = Embedding::new(vec![0.5; EMBEDDING_DIM]);
                prop_assert_eq!(emb.len(), EMBEDDING_DIM);
                prop_assert_eq!(emb.as_slice().len(), EMBEDDING_DIM);
                prop_assert_eq!(emb.as_vec().len(), EMBEDDING_DIM);
            }
        }
    }
#[test]
#[ignore] fn test_clear_session_and_reinit() {
let embedder = Embedder::new(ModelConfig::e5_base()).unwrap();
let _ = embedder.embed_query("test");
embedder.clear_session();
let result = embedder.embed_query("test again");
assert!(result.is_ok());
}
#[test]
fn test_clear_session_idempotent() {
let embedder = Embedder::new_cpu(ModelConfig::e5_base()).unwrap();
embedder.clear_session(); embedder.clear_session(); }
mod integration {
use super::*;
#[test]
#[ignore] fn test_token_count_empty() {
let embedder =
Embedder::new(ModelConfig::e5_base()).expect("Failed to create embedder");
let count = embedder.token_count("").expect("token_count failed");
assert_eq!(count, 0);
}
#[test]
#[ignore]
fn test_token_count_simple() {
let embedder =
Embedder::new(ModelConfig::e5_base()).expect("Failed to create embedder");
let count = embedder
.token_count("hello world")
.expect("token_count failed");
assert!(
(2..=4).contains(&count),
"Expected 2-4 tokens, got {}",
count
);
}
#[test]
#[ignore]
fn test_token_count_code() {
let embedder =
Embedder::new(ModelConfig::e5_base()).expect("Failed to create embedder");
let code = "fn main() { println!(\"Hello\"); }";
let count = embedder.token_count(code).expect("token_count failed");
assert!(count > 5, "Expected >5 tokens for code, got {}", count);
}
#[test]
#[ignore]
fn test_token_count_unicode() {
let embedder =
Embedder::new(ModelConfig::e5_base()).expect("Failed to create embedder");
let text = "\u{3053}\u{3093}\u{306b}\u{3061}\u{306f}\u{4e16}\u{754c}"; let count = embedder.token_count(text).expect("token_count failed");
assert!(count > 0, "Expected >0 tokens for unicode, got {}", count);
}
}
    // Tests for CQS_ONNX_DIR resolution in `ensure_model`. The env var is
    // process-global, so a mutex serializes these tests.
    mod ensure_model_tests {
        use super::*;
        use std::sync::Mutex;
        static ONNX_DIR_MUTEX: Mutex<()> = Mutex::new(());
        // Minimal config pointing at the structured onnx/ + tokenizer layout.
        fn test_model_config() -> ModelConfig {
            ModelConfig {
                name: "test".to_string(),
                repo: "test/model".to_string(),
                onnx_path: "onnx/model.onnx".to_string(),
                tokenizer_path: "tokenizer.json".to_string(),
                dim: 768,
                max_seq_length: 512,
                query_prefix: String::new(),
                doc_prefix: String::new(),
                input_names: crate::embedder::models::InputNames::bert(),
                output_name: "last_hidden_state".to_string(),
                pooling: crate::embedder::models::PoolingStrategy::Mean,
            }
        }
        #[test]
        fn cqs_onnx_dir_structured_layout() {
            let _lock = ONNX_DIR_MUTEX.lock().unwrap();
            let dir = tempfile::TempDir::new().unwrap();
            let onnx_dir = dir.path().join("onnx");
            std::fs::create_dir_all(&onnx_dir).unwrap();
            std::fs::write(onnx_dir.join("model.onnx"), b"fake").unwrap();
            std::fs::write(dir.path().join("tokenizer.json"), b"fake").unwrap();
            std::env::set_var("CQS_ONNX_DIR", dir.path().to_str().unwrap());
            let result = ensure_model(&test_model_config());
            std::env::remove_var("CQS_ONNX_DIR");
            let (model, tok) = result.unwrap();
            assert!(
                model.to_string_lossy().ends_with("model.onnx"),
                "Expected model path ending in model.onnx, got {:?}",
                model
            );
            assert!(
                tok.to_string_lossy().ends_with("tokenizer.json"),
                "Expected tokenizer path ending in tokenizer.json, got {:?}",
                tok
            );
        }
        #[test]
        fn cqs_onnx_dir_flat_layout() {
            let _lock = ONNX_DIR_MUTEX.lock().unwrap();
            let dir = tempfile::TempDir::new().unwrap();
            std::fs::write(dir.path().join("model.onnx"), b"fake").unwrap();
            std::fs::write(dir.path().join("tokenizer.json"), b"fake").unwrap();
            std::env::set_var("CQS_ONNX_DIR", dir.path().to_str().unwrap());
            let result = ensure_model(&test_model_config());
            std::env::remove_var("CQS_ONNX_DIR");
            let (model, tok) = result.unwrap();
            assert!(
                model.to_string_lossy().ends_with("model.onnx"),
                "Expected model path ending in model.onnx, got {:?}",
                model
            );
            assert!(
                tok.to_string_lossy().ends_with("tokenizer.json"),
                "Expected tokenizer path ending in tokenizer.json, got {:?}",
                tok
            );
        }
        #[test]
        fn cqs_onnx_dir_missing_files_falls_through() {
            let _lock = ONNX_DIR_MUTEX.lock().unwrap();
            let dir = tempfile::TempDir::new().unwrap();
            std::env::set_var("CQS_ONNX_DIR", dir.path().to_str().unwrap());
            let result = ensure_model(&test_model_config());
            std::env::remove_var("CQS_ONNX_DIR");
            // An empty dir must not be used; either an error or a path that
            // was resolved elsewhere (e.g. HF cache) is acceptable.
            assert!(
                result.is_err() || !result.as_ref().unwrap().0.starts_with(dir.path()),
                "Should not return paths from empty CQS_ONNX_DIR"
            );
        }
    }
mod embedder_init_failure {
use super::*;
use std::sync::Mutex;
static ONNX_DIR_MUTEX: Mutex<()> = Mutex::new(());
#[test]
fn embedder_with_bogus_onnx_path_returns_err_on_embed() {
let _lock = ONNX_DIR_MUTEX.lock().unwrap();
let dir = tempfile::TempDir::new().unwrap();
std::fs::write(dir.path().join("tokenizer.json"), b"{}").unwrap();
std::fs::create_dir_all(dir.path().join("onnx")).unwrap();
let config = ModelConfig {
name: "bogus".to_string(),
repo: "nonexistent/model".to_string(),
onnx_path: "onnx/model.onnx".to_string(),
tokenizer_path: "tokenizer.json".to_string(),
dim: 768,
max_seq_length: 512,
query_prefix: String::new(),
doc_prefix: String::new(),
input_names: crate::embedder::models::InputNames::bert(),
output_name: "last_hidden_state".to_string(),
pooling: crate::embedder::models::PoolingStrategy::Mean,
};
std::env::set_var("CQS_ONNX_DIR", dir.path().to_str().unwrap());
let embedder = Embedder::new_cpu(config);
std::env::remove_var("CQS_ONNX_DIR");
match embedder {
Ok(emb) => {
let result = emb.embed_query("test query");
assert!(
result.is_err(),
"embed_query should return Err with missing model, got Ok"
);
}
Err(_e) => {
}
}
}
#[test]
fn embedder_init_failure_is_not_cached() {
let _lock = ONNX_DIR_MUTEX.lock().unwrap();
let dir = tempfile::TempDir::new().unwrap();
std::env::set_var("CQS_ONNX_DIR", dir.path().to_str().unwrap());
let embedder = Embedder::new_cpu(ModelConfig {
name: "bogus".to_string(),
repo: "nonexistent/model".to_string(),
onnx_path: "model.onnx".to_string(),
tokenizer_path: "tokenizer.json".to_string(),
dim: 768,
max_seq_length: 512,
query_prefix: String::new(),
doc_prefix: String::new(),
input_names: crate::embedder::models::InputNames::bert(),
output_name: "last_hidden_state".to_string(),
pooling: crate::embedder::models::PoolingStrategy::Mean,
});
std::env::remove_var("CQS_ONNX_DIR");
match embedder {
Ok(emb) => {
let first = emb.embed_query("test");
let second = emb.embed_query("test again");
assert!(first.is_err(), "First embed should fail");
assert!(
second.is_err(),
"Second embed should also fail (not cached bad state)"
);
}
Err(_) => {
}
}
}
}
}