use crate::batch::{mean_pooling, normalize_embeddings, BatchProcessor};
use crate::error::{InferenceError, Result};
use crate::models::{EmbeddingModel, ModelConfig};
use ort::inputs;
use ort::session::builder::GraphOptimizationLevel;
use ort::session::Session;
use ort::value::Tensor;
use parking_lot::Mutex;
use std::io::Read;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use tokenizers::Tokenizer;
use tracing::{debug, info, instrument, warn};
/// ONNX-backed text embedding engine.
///
/// Holds one ONNX Runtime session (serialized behind a mutex) plus a batch
/// processor that handles tokenization, text preparation, and batching.
pub struct EmbeddingEngine {
    // Inference session; the Mutex serializes runs, the Arc lets the handle
    // move into `spawn_blocking` tasks.
    session: Arc<Mutex<Session>>,
    // Tokenization/batching helper, shared with blocking inference tasks.
    processor: Arc<BatchProcessor>,
    // Configuration this engine was constructed with.
    config: ModelConfig,
    // Embedding dimension, taken from the model config at construction.
    dimension: usize,
}
impl EmbeddingEngine {
#[instrument(skip_all, fields(model = %config.model))]
pub async fn new(config: ModelConfig) -> Result<Self> {
info!(
"Initializing ONNX embedding engine with model: {}",
config.model
);
let (tokenizer_path, onnx_path) = Self::download_model_files(&config).await?;
info!("Loading tokenizer from {:?}", tokenizer_path);
let tokenizer = Tokenizer::from_file(&tokenizer_path)
.map_err(|e| InferenceError::TokenizationError(e.to_string()))?;
info!("Loading ONNX model from {:?}", onnx_path);
let num_threads = config.num_threads.unwrap_or(1);
let session = Session::builder()
.map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
.with_optimization_level(GraphOptimizationLevel::Level3)
.map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
.with_intra_threads(num_threads)
.map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
.commit_from_file(&onnx_path)
.map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
let dimension = config.model.dimension();
let processor = Arc::new(BatchProcessor::new(
tokenizer,
config.model,
config.max_batch_size,
));
info!(
"ONNX embedding engine ready: model={}, dimension={}, threads={}",
config.model, dimension, num_threads
);
Ok(Self {
session: Arc::new(Mutex::new(session)),
processor,
config,
dimension,
})
}
#[instrument(skip_all, fields(model = %config.model))]
/// Resolves the on-disk tokenizer and ONNX model paths, downloading any
/// missing file from the Hugging Face Hub. Returns
/// `(tokenizer_path, onnx_model_path)`.
///
/// Fix: the existence check, the download target, and the returned path are
/// now all derived from one canonical location, `<cache>/<onnx_filename>`.
/// Previously the check used `<cache>/onnx/<basename>` while the return
/// value used `<cache>/<onnx_filename>`; those only agreed when the
/// configured filename was exactly `onnx/<basename>` — any other filename
/// caused spurious re-downloads and could return a path that did not exist.
async fn download_model_files(config: &ModelConfig) -> Result<(PathBuf, PathBuf)> {
    let model_id = config.model.model_id();
    let onnx_repo_id = config.model.onnx_repo_id();
    let onnx_filename = config.model.onnx_filename();
    info!(
        "Resolving model files: tokenizer={}, onnx={}@{}",
        model_id, onnx_filename, onnx_repo_id
    );
    let tokenizer_cache_dir = Self::model_cache_dir(model_id)?;
    let onnx_cache_dir = Self::model_cache_dir(onnx_repo_id)?;

    // Canonical local paths — exactly where download_hf_file writes.
    let local_tokenizer = tokenizer_cache_dir.join("tokenizer.json");
    let local_onnx = onnx_cache_dir.join(onnx_filename);
    if let Some(parent) = local_onnx.parent() {
        std::fs::create_dir_all(parent)?;
    }

    if !local_tokenizer.exists() || !local_onnx.exists() {
        // Downloads use blocking I/O (ureq), so run them off the runtime.
        let model_id_owned = model_id.to_string();
        let onnx_repo_id_owned = onnx_repo_id.to_string();
        let onnx_filename_owned = onnx_filename.to_string();
        let tokenizer_cache = tokenizer_cache_dir.clone();
        let onnx_cache = onnx_cache_dir.clone();
        tokio::task::spawn_blocking(move || {
            if !tokenizer_cache.join("tokenizer.json").exists() {
                Self::download_hf_file(&model_id_owned, "tokenizer.json", &tokenizer_cache)
                    .map_err(|e| {
                        InferenceError::HubError(format!("Failed to download tokenizer: {}", e))
                    })?;
            }
            if !onnx_cache.join(&onnx_filename_owned).exists() {
                Self::download_hf_file(&onnx_repo_id_owned, &onnx_filename_owned, &onnx_cache)
                    .map_err(|e| {
                        InferenceError::HubError(format!(
                            "Failed to download ONNX model: {}",
                            e
                        ))
                    })?;
            }
            Ok::<_, InferenceError>(())
        })
        .await
        .map_err(|e| InferenceError::HubError(format!("Download task panicked: {}", e)))??;
    } else {
        info!("All model files found in local cache");
    }

    info!(
        "Model files ready: tokenizer={:?}, onnx={:?}",
        local_tokenizer, local_onnx
    );
    Ok((local_tokenizer, local_onnx))
}
/// Computes (and creates) the cache directory for `model_id`.
///
/// Layout: `<base>/dakera/<model-id with '/' replaced by '--'>`, where
/// `<base>` is `$HF_HOME` if set, otherwise `$HOME/.cache/huggingface`,
/// falling back to `/tmp/.cache/huggingface` when `HOME` is also unset.
fn model_cache_dir(model_id: &str) -> Result<PathBuf> {
    let base = match std::env::var("HF_HOME") {
        Ok(hf_home) => PathBuf::from(hf_home),
        Err(_) => {
            let home = match std::env::var("HOME") {
                Ok(h) => h,
                Err(_) => {
                    warn!("HOME environment variable not set, using /tmp for model cache");
                    "/tmp".to_string()
                }
            };
            PathBuf::from(home).join(".cache").join("huggingface")
        }
    };
    // '/' in repo ids would create nested directories; flatten to '--'.
    let dir = base.join("dakera").join(model_id.replace('/', "--"));
    std::fs::create_dir_all(&dir)?;
    Ok(dir)
}
/// Public wrapper over the private download helper: fetches `filename` from
/// the `model_id` repo on the Hugging Face Hub into `cache_dir` (or returns
/// the cached path). Exposed so callers outside this module can reuse the
/// download logic.
pub fn download_hf_file_pub(
    model_id: &str,
    filename: &str,
    cache_dir: &Path,
) -> std::result::Result<PathBuf, String> {
    Self::download_hf_file(model_id, filename, cache_dir)
}
fn download_hf_file(
model_id: &str,
filename: &str,
cache_dir: &Path,
) -> std::result::Result<PathBuf, String> {
let file_path = cache_dir.join(filename);
if file_path.exists() {
info!("Cached: {}/{}", model_id, filename);
return Ok(file_path);
}
if let Some(parent) = file_path.parent() {
std::fs::create_dir_all(parent)
.map_err(|e| format!("Failed to create directory {:?}: {}", parent, e))?;
}
let url = format!(
"https://huggingface.co/{}/resolve/main/{}",
model_id, filename
);
info!("Downloading: {}", url);
let agent = ureq::AgentBuilder::new()
.redirects(0)
.timeout(std::time::Duration::from_secs(300))
.build();
let mut current_url = url.clone();
let mut redirects = 0;
let max_redirects = 10;
let response = loop {
let resp = agent.get(¤t_url).call();
let r = match resp {
Ok(r) => r,
Err(ureq::Error::Status(_status, r)) => r,
Err(e) => return Err(format!("{}: {}", filename, e)),
};
let status = r.status();
if (200..300).contains(&status) {
break r;
} else if (300..400).contains(&status) {
redirects += 1;
if redirects > max_redirects {
return Err(format!("{}: too many redirects", filename));
}
let location = r
.header("location")
.ok_or_else(|| format!("{}: redirect without Location header", filename))?
.to_string();
current_url = if location.starts_with('/') {
let parsed = url::Url::parse(¤t_url)
.map_err(|e| format!("{}: bad URL {}: {}", filename, current_url, e))?;
let host = parsed.host_str().ok_or_else(|| {
format!("{}: redirect URL missing host: {}", filename, current_url)
})?;
format!("{}://{}{}", parsed.scheme(), host, location)
} else {
location
};
info!("Redirect {} → {}", redirects, current_url);
} else {
return Err(format!("{}: HTTP {}", filename, status));
}
};
let mut bytes = Vec::new();
response
.into_reader()
.take(500_000_000) .read_to_end(&mut bytes)
.map_err(|e| format!("Failed to read {}: {}", filename, e))?;
std::fs::write(&file_path, &bytes)
.map_err(|e| format!("Failed to write {}: {}", filename, e))?;
info!("Downloaded {} ({} bytes)", filename, bytes.len());
Ok(file_path)
}
/// Returns the embedding dimension of the loaded model.
pub fn dimension(&self) -> usize {
    self.dimension
}
/// Returns which embedding model this engine is running.
pub fn model(&self) -> EmbeddingModel {
    self.config.model
}
#[instrument(skip(self, text), fields(text_len = text.len()))]
/// Embeds a single query string. The `true` flag selects the processor's
/// query-side text preparation path.
pub async fn embed_query(&self, text: &str) -> Result<Vec<f32>> {
    let inputs = [text.to_string()];
    let prepared = self.processor.prepare_texts(&inputs, true);
    let mut results = self.embed_batch_internal(&prepared).await?;
    if results.is_empty() {
        return Err(InferenceError::InferenceError(
            "No embedding returned for query".to_string(),
        ));
    }
    Ok(results.remove(0))
}
#[instrument(skip(self, texts), fields(count = texts.len()))]
/// Embeds a batch of query strings (query-side preparation).
pub async fn embed_queries(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
    self.embed_batch_internal(&self.processor.prepare_texts(texts, true))
        .await
}
#[instrument(skip(self, text), fields(text_len = text.len()))]
/// Embeds a single document/passage. The `false` flag selects the
/// processor's document-side text preparation path.
pub async fn embed_document(&self, text: &str) -> Result<Vec<f32>> {
    let inputs = [text.to_string()];
    let prepared = self.processor.prepare_texts(&inputs, false);
    let mut results = self.embed_batch_internal(&prepared).await?;
    if results.is_empty() {
        return Err(InferenceError::InferenceError(
            "No embedding returned for document".to_string(),
        ));
    }
    Ok(results.remove(0))
}
#[instrument(skip(self, texts), fields(count = texts.len()))]
/// Embeds a batch of documents/passages (document-side preparation).
pub async fn embed_documents(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
    self.embed_batch_internal(&self.processor.prepare_texts(texts, false))
        .await
}
#[instrument(skip(self, texts), fields(count = texts.len()))]
/// Embeds `texts` exactly as given, bypassing query/document preparation.
pub async fn embed_raw(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
    self.embed_batch_internal(texts).await
}
/// Shared embedding path: splits `texts` into model-sized batches and runs
/// each batch on a blocking thread, preserving input order in the output.
async fn embed_batch_internal(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
    if texts.is_empty() {
        return Ok(vec![]);
    }
    let mut results = Vec::with_capacity(texts.len());
    for chunk in self.processor.split_into_batches(texts) {
        let owned_chunk: Vec<String> = chunk.to_vec();
        let session = Arc::clone(&self.session);
        let processor = Arc::clone(&self.processor);
        let normalize = self.config.model.normalize_embeddings();
        let chunk_embeddings = tokio::task::spawn_blocking(move || {
            // The session mutex is taken inside the blocking task, so the
            // async executor never waits on it directly.
            let mut guard = session.lock();
            Self::process_batch_blocking(&owned_chunk, &mut guard, &processor, normalize)
        })
        .await
        .map_err(|e| {
            InferenceError::InferenceError(format!("Inference task panicked: {}", e))
        })??;
        results.extend(chunk_embeddings);
    }
    Ok(results)
}
/// Runs one tokenized batch through the ONNX session synchronously and
/// mean-pools the hidden states into one embedding per input text.
///
/// Invoked from a blocking task; the caller holds the engine's session
/// mutex for the duration of the call.
fn process_batch_blocking(
    texts: &[String],
    session: &mut Session,
    processor: &BatchProcessor,
    normalize: bool,
) -> Result<Vec<Vec<f32>>> {
    let prepared = processor.tokenize_batch(texts)?;
    let batch_size = prepared.batch_size;
    let seq_len = prepared.seq_len;
    // The mask Vec is moved into the attention-mask tensor below, but mean
    // pooling still needs it afterwards — hence the clone.
    let attention_mask_flat = prepared.attention_mask.clone();
    let input_ids_tensor =
        Tensor::<i64>::from_array(([batch_size, seq_len], prepared.input_ids))
            .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
    let attention_mask_tensor =
        Tensor::<i64>::from_array(([batch_size, seq_len], prepared.attention_mask))
            .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
    let token_type_ids_tensor =
        Tensor::<i64>::from_array(([batch_size, seq_len], prepared.token_type_ids))
            .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
    // BERT-style input triple — assumes every supported model graph accepts
    // these three named inputs (TODO confirm for non-BERT exports).
    let outputs = session
        .run(inputs![
            "input_ids" => input_ids_tensor,
            "attention_mask" => attention_mask_tensor,
            "token_type_ids" => token_type_ids_tensor
        ])
        .map_err(|e: ort::Error| InferenceError::InferenceError(e.to_string()))?;
    // NOTE(review): assumes output 0 is last_hidden_state with shape
    // [batch, seq, hidden] — confirm against the exported model's outputs.
    let (ort_shape, lhs_slice) = outputs[0]
        .try_extract_tensor::<f32>()
        .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
    if ort_shape.len() != 3 {
        return Err(InferenceError::InferenceError(format!(
            "Expected 3D last_hidden_state, got {} dims",
            ort_shape.len()
        )));
    }
    let hidden_size = ort_shape[2] as usize;
    // Pool over the sequence axis using the attention mask.
    let mut embeddings = mean_pooling(
        lhs_slice,
        batch_size,
        seq_len,
        hidden_size,
        &attention_mask_flat,
    );
    if normalize {
        // In-place normalization for models that expect unit-length vectors.
        normalize_embeddings(&mut embeddings);
    }
    debug!(
        "Generated {} embeddings of dimension {}",
        embeddings.len(),
        embeddings.first().map(|e| e.len()).unwrap_or(0)
    );
    Ok(embeddings)
}
/// Rough CPU inference-time estimate in milliseconds for `text_count` texts
/// of average character length `avg_text_len`.
///
/// Heuristic: ~4 characters per token, capped at the model's maximum
/// sequence length, divided by the model's nominal CPU token throughput.
pub fn estimate_time_ms(&self, text_count: usize, avg_text_len: usize) -> f64 {
    let max_tokens = self.config.model.max_seq_length() as f64;
    let tokens_per_text = f64::min(avg_text_len as f64 / 4.0, max_tokens);
    let total_tokens = tokens_per_text * text_count as f64;
    total_tokens / self.config.model.tokens_per_second_cpu() as f64 * 1000.0
}
}
// Manual Debug: only the lightweight configuration fields are printed; the
// live session and processor are intentionally omitted from the output.
impl std::fmt::Debug for EmbeddingEngine {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("EmbeddingEngine")
            .field("model", &self.config.model)
            .field("dimension", &self.dimension)
            .field("max_batch_size", &self.config.max_batch_size)
            .finish()
    }
}
/// Fluent builder for [`EmbeddingEngine`].
pub struct EmbeddingEngineBuilder {
    // Accumulated configuration, handed to `EmbeddingEngine::new` on build().
    config: ModelConfig,
}
impl EmbeddingEngineBuilder {
    /// Starts a builder preloaded with the default [`ModelConfig`].
    pub fn new() -> Self {
        let config = ModelConfig::default();
        Self { config }
    }

    /// Selects which embedding model to load.
    pub fn model(mut self, model: EmbeddingModel) -> Self {
        self.config.model = model;
        self
    }

    /// Overrides the model cache directory.
    pub fn cache_dir(mut self, dir: impl Into<String>) -> Self {
        self.config.cache_dir = Some(dir.into());
        self
    }

    /// Sets the maximum number of texts per inference batch.
    pub fn max_batch_size(mut self, size: usize) -> Self {
        self.config.max_batch_size = size;
        self
    }

    /// Enables or disables GPU execution.
    pub fn use_gpu(mut self, enable: bool) -> Self {
        self.config.use_gpu = enable;
        self
    }

    /// Sets the number of intra-op threads for the ONNX session.
    pub fn num_threads(mut self, threads: usize) -> Self {
        self.config.num_threads = Some(threads);
        self
    }

    /// Consumes the builder and constructs the engine.
    pub async fn build(self) -> Result<EmbeddingEngine> {
        EmbeddingEngine::new(self.config).await
    }
}
impl Default for EmbeddingEngineBuilder {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Mutex as StdMutex;

    // Single shared lock for every test that reads or mutates process
    // environment variables. Previously three tests each declared their own
    // function-local `static ENV_LOCK` — those are *distinct* statics, so
    // they provided no mutual exclusion at all and env-mutating tests could
    // race when the harness runs tests in parallel.
    static ENV_LOCK: StdMutex<()> = StdMutex::new(());

    // Acquires the env lock, recovering from poisoning so one failed test
    // does not cascade into every later env-dependent test.
    fn env_guard() -> std::sync::MutexGuard<'static, ()> {
        ENV_LOCK.lock().unwrap_or_else(|p| p.into_inner())
    }

    #[test]
    fn test_estimate_time() {
        // estimate_time_ms needs a fully constructed engine (model files on
        // disk), so only the throughput constant feeding it is checked here.
        let config = ModelConfig::new(EmbeddingModel::MiniLM);
        let tokens_per_second = config.model.tokens_per_second_cpu() as f64;
        assert!(tokens_per_second > 0.0);
    }

    #[test]
    fn test_builder() {
        let builder = EmbeddingEngineBuilder::new()
            .model(EmbeddingModel::BgeSmall)
            .max_batch_size(64)
            .use_gpu(false);
        assert_eq!(builder.config.model, EmbeddingModel::BgeSmall);
        assert_eq!(builder.config.max_batch_size, 64);
        assert!(!builder.config.use_gpu);
    }

    #[test]
    fn test_model_cache_dir_with_hf_home() {
        let _guard = env_guard();
        let saved_hf = std::env::var("HF_HOME").ok();
        let tmp = std::env::temp_dir().join("dakera_test_hf_home");
        // SAFETY: env mutation is serialized by ENV_LOCK.
        unsafe { std::env::set_var("HF_HOME", &tmp) };
        let result = EmbeddingEngine::model_cache_dir("org/my-model");
        // Restore any pre-existing value instead of unconditionally removing.
        match saved_hf {
            Some(v) => unsafe { std::env::set_var("HF_HOME", v) },
            None => unsafe { std::env::remove_var("HF_HOME") },
        }
        let path = result.unwrap();
        assert!(
            path.starts_with(&tmp),
            "expected path under {tmp:?}, got {path:?}"
        );
        assert!(
            path.to_str().unwrap().contains("org--my-model"),
            "model_id separator not applied: {path:?}"
        );
    }

    #[test]
    fn test_model_cache_dir_contains_dakera_subdir() {
        let _guard = env_guard();
        let path =
            EmbeddingEngine::model_cache_dir("sentence-transformers/all-MiniLM-L6-v2").unwrap();
        let s = path.to_str().unwrap();
        assert!(s.contains("dakera"), "expected 'dakera' in path: {s}");
        assert!(
            s.contains("sentence-transformers--all-MiniLM-L6-v2"),
            "expected transformed model id in path: {s}"
        );
    }

    #[test]
    fn test_model_cache_dir_creates_directory() {
        let _guard = env_guard();
        let dir = EmbeddingEngine::model_cache_dir("test/cache-dir-creation-probe").unwrap();
        assert!(dir.exists(), "model_cache_dir should create the directory");
    }

    #[test]
    fn test_download_hf_file_returns_path_when_already_cached() {
        let tmp = std::env::temp_dir().join("dakera_test_cached_file");
        std::fs::create_dir_all(&tmp).unwrap();
        let file_path = tmp.join("config.json");
        std::fs::write(&file_path, b"{}").unwrap();
        let result = EmbeddingEngine::download_hf_file("test/model", "config.json", &tmp);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), file_path);
    }

    #[test]
    fn test_download_hf_file_returns_correct_path_for_cached_onnx() {
        let tmp = std::env::temp_dir().join("dakera_test_cached_onnx");
        let onnx_dir = tmp.join("onnx");
        std::fs::create_dir_all(&onnx_dir).unwrap();
        let onnx_path = onnx_dir.join("model_quantized.onnx");
        std::fs::write(&onnx_path, b"fake_onnx_model").unwrap();
        let result = EmbeddingEngine::download_hf_file(
            "Xenova/all-MiniLM-L6-v2",
            "onnx/model_quantized.onnx",
            &tmp,
        );
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), onnx_path);
    }

    #[test]
    fn test_builder_default_impl() {
        let b1 = EmbeddingEngineBuilder::new();
        let b2 = EmbeddingEngineBuilder::default();
        assert_eq!(b1.config.model, b2.config.model);
        assert_eq!(b1.config.max_batch_size, b2.config.max_batch_size);
    }

    #[test]
    fn test_builder_model_field() {
        let builder = EmbeddingEngineBuilder::new().model(EmbeddingModel::E5Small);
        assert_eq!(builder.config.model, EmbeddingModel::E5Small);
    }

    #[test]
    fn test_builder_cache_dir() {
        let builder = EmbeddingEngineBuilder::new().cache_dir("/tmp/my-models");
        assert_eq!(builder.config.cache_dir, Some("/tmp/my-models".to_string()));
    }

    #[test]
    fn test_builder_max_batch_size() {
        let builder = EmbeddingEngineBuilder::new().max_batch_size(128);
        assert_eq!(builder.config.max_batch_size, 128);
    }

    #[test]
    fn test_builder_use_gpu_true() {
        let builder = EmbeddingEngineBuilder::new().use_gpu(true);
        assert!(builder.config.use_gpu);
    }

    #[test]
    fn test_builder_use_gpu_false() {
        let builder = EmbeddingEngineBuilder::new().use_gpu(false);
        assert!(!builder.config.use_gpu);
    }

    #[test]
    fn test_builder_num_threads() {
        let builder = EmbeddingEngineBuilder::new().num_threads(4);
        assert_eq!(builder.config.num_threads, Some(4));
    }

    #[test]
    fn test_builder_chain_all_fields() {
        let builder = EmbeddingEngineBuilder::new()
            .model(EmbeddingModel::BgeSmall)
            .cache_dir("/cache")
            .max_batch_size(16)
            .use_gpu(false)
            .num_threads(2);
        assert_eq!(builder.config.model, EmbeddingModel::BgeSmall);
        assert_eq!(builder.config.cache_dir, Some("/cache".to_string()));
        assert_eq!(builder.config.max_batch_size, 16);
        assert!(!builder.config.use_gpu);
        assert_eq!(builder.config.num_threads, Some(2));
    }

    #[test]
    fn test_estimate_time_zero_count() {
        let tps = EmbeddingModel::MiniLM.tokens_per_second_cpu() as f64;
        let estimate = (0.0 / tps) * 1000.0;
        assert_eq!(estimate, 0.0);
    }

    #[test]
    fn test_estimate_time_formula_cpu() {
        let model = EmbeddingModel::MiniLM;
        let tokens_per_text = (100.0f64 / 4.0).min(model.max_seq_length() as f64);
        let total_tokens = tokens_per_text * 10.0;
        let estimate = (total_tokens / model.tokens_per_second_cpu() as f64) * 1000.0;
        assert!(
            (estimate - 50.0).abs() < 1e-6,
            "expected 50.0ms, got {estimate}"
        );
    }

    #[test]
    fn test_estimate_time_capped_at_max_seq_length() {
        let model = EmbeddingModel::MiniLM;
        let avg_len = 100_000;
        let tokens_per_text = (avg_len as f64 / 4.0).min(model.max_seq_length() as f64);
        assert_eq!(tokens_per_text, 256.0);
    }

    #[test]
    fn test_model_config_new() {
        let cfg = ModelConfig::new(EmbeddingModel::BgeSmall);
        assert_eq!(cfg.model, EmbeddingModel::BgeSmall);
        assert_eq!(cfg.max_batch_size, 32);
        assert!(!cfg.use_gpu);
        assert!(cfg.cache_dir.is_none());
        assert!(cfg.num_threads.is_none());
    }

    #[test]
    fn test_model_config_default() {
        let cfg = ModelConfig::default();
        assert_eq!(cfg.model, EmbeddingModel::MiniLM);
        assert_eq!(cfg.max_batch_size, 32);
        assert!(!cfg.use_gpu);
    }

    #[test]
    fn test_model_config_with_cache_dir() {
        let cfg = ModelConfig::new(EmbeddingModel::MiniLM).with_cache_dir("/tmp/models");
        assert_eq!(cfg.cache_dir, Some("/tmp/models".to_string()));
    }

    #[test]
    fn test_model_config_with_max_batch_size() {
        let cfg = ModelConfig::new(EmbeddingModel::MiniLM).with_max_batch_size(64);
        assert_eq!(cfg.max_batch_size, 64);
    }

    #[test]
    fn test_model_config_with_gpu() {
        let cfg = ModelConfig::new(EmbeddingModel::MiniLM).with_gpu(true);
        assert!(cfg.use_gpu);
    }

    #[test]
    fn test_model_config_with_num_threads() {
        let cfg = ModelConfig::new(EmbeddingModel::MiniLM).with_num_threads(8);
        assert_eq!(cfg.num_threads, Some(8));
    }

    #[test]
    fn test_model_config_chained_builder() {
        let cfg = ModelConfig::new(EmbeddingModel::E5Small)
            .with_cache_dir("/cache")
            .with_max_batch_size(16)
            .with_gpu(false)
            .with_num_threads(4);
        assert_eq!(cfg.model, EmbeddingModel::E5Small);
        assert_eq!(cfg.cache_dir, Some("/cache".to_string()));
        assert_eq!(cfg.max_batch_size, 16);
        assert!(!cfg.use_gpu);
        assert_eq!(cfg.num_threads, Some(4));
    }

    #[test]
    fn test_model_cache_dir_no_home_fallback() {
        let _guard = env_guard();
        let saved_home = std::env::var("HOME").ok();
        let saved_hf = std::env::var("HF_HOME").ok();
        // SAFETY: env mutation is serialized by ENV_LOCK.
        unsafe {
            std::env::remove_var("HOME");
            std::env::remove_var("HF_HOME");
        }
        let result = EmbeddingEngine::model_cache_dir("test/fallback-model");
        if let Some(h) = saved_home {
            unsafe { std::env::set_var("HOME", h) };
        }
        if let Some(h) = saved_hf {
            unsafe { std::env::set_var("HF_HOME", h) };
        }
        let path = result.unwrap();
        assert!(
            path.starts_with("/tmp"),
            "expected path under /tmp, got {path:?}"
        );
    }

    #[test]
    fn test_model_cache_dir_deep_model_id() {
        let _guard = env_guard();
        let path = EmbeddingEngine::model_cache_dir("org/sub/model-name-with-dashes").unwrap();
        let s = path.to_str().unwrap();
        assert!(
            s.contains("org--sub--model-name-with-dashes"),
            "expected transformed path, got: {s}"
        );
    }

    #[test]
    fn test_model_cache_dir_minilm_model_id() {
        let _guard = env_guard();
        let path = EmbeddingEngine::model_cache_dir(EmbeddingModel::MiniLM.model_id()).unwrap();
        let s = path.to_str().unwrap();
        assert!(s.contains("sentence-transformers--all-MiniLM-L6-v2"));
    }

    #[test]
    fn test_model_cache_dir_bge_model_id() {
        let _guard = env_guard();
        let path = EmbeddingEngine::model_cache_dir(EmbeddingModel::BgeSmall.model_id()).unwrap();
        let s = path.to_str().unwrap();
        assert!(s.contains("BAAI--bge-small-en-v1.5"));
    }

    #[test]
    fn test_model_cache_dir_e5_model_id() {
        let _guard = env_guard();
        let path = EmbeddingEngine::model_cache_dir(EmbeddingModel::E5Small.model_id()).unwrap();
        let s = path.to_str().unwrap();
        assert!(s.contains("intfloat--e5-small-v2"));
    }

    #[test]
    fn test_download_hf_file_pytorch_bin_cached() {
        let tmp = std::env::temp_dir().join("dakera_test_pytorch_bin");
        std::fs::create_dir_all(&tmp).unwrap();
        let model_path = tmp.join("pytorch_model.bin");
        std::fs::write(&model_path, b"fake_pytorch_weights").unwrap();
        let result = EmbeddingEngine::download_hf_file("test/model", "pytorch_model.bin", &tmp);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), model_path);
    }

    #[test]
    fn test_download_hf_file_tokenizer_cached() {
        let tmp = std::env::temp_dir().join("dakera_test_tokenizer_cached");
        std::fs::create_dir_all(&tmp).unwrap();
        let tok_path = tmp.join("tokenizer.json");
        std::fs::write(&tok_path, br#"{"version":"1.0"}"#).unwrap();
        let result = EmbeddingEngine::download_hf_file("test/model", "tokenizer.json", &tmp);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), tok_path);
    }

    #[test]
    fn test_download_hf_file_config_json_cached() {
        let tmp = std::env::temp_dir().join("dakera_test_config_cached");
        std::fs::create_dir_all(&tmp).unwrap();
        let cfg_path = tmp.join("config.json");
        std::fs::write(&cfg_path, b"{}").unwrap();
        let result = EmbeddingEngine::download_hf_file("test/model", "config.json", &tmp);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), cfg_path);
    }

    #[tokio::test]
    async fn test_new_fails_with_invalid_tokenizer_json() {
        let _guard = env_guard();
        let saved_hf = std::env::var("HF_HOME").ok();
        let tmp = std::env::temp_dir().join("dakera_test_engine_new_fail_tok");
        let model_dir = tmp
            .join("dakera")
            .join("sentence-transformers--all-MiniLM-L6-v2");
        std::fs::create_dir_all(&model_dir).unwrap();
        std::fs::write(model_dir.join("model.safetensors"), b"not_real_weights").unwrap();
        std::fs::write(model_dir.join("tokenizer.json"), b"NOT_VALID_JSON").unwrap();
        std::fs::write(model_dir.join("config.json"), b"{}").unwrap();
        // SAFETY: env mutation is serialized by ENV_LOCK.
        unsafe { std::env::set_var("HF_HOME", &tmp) };
        let config = ModelConfig::new(EmbeddingModel::MiniLM);
        let result = EmbeddingEngine::new(config).await;
        match saved_hf {
            Some(v) => unsafe { std::env::set_var("HF_HOME", v) },
            None => unsafe { std::env::remove_var("HF_HOME") },
        }
        assert!(
            result.is_err(),
            "expected Err from new() with invalid tokenizer, got Ok"
        );
    }

    #[test]
    fn test_builder_with_all_models() {
        for model in [
            EmbeddingModel::MiniLM,
            EmbeddingModel::BgeSmall,
            EmbeddingModel::E5Small,
        ] {
            let builder = EmbeddingEngineBuilder::new().model(model);
            assert_eq!(builder.config.model, model);
        }
    }

    #[test]
    fn test_builder_max_batch_size_one() {
        let builder = EmbeddingEngineBuilder::new().max_batch_size(1);
        assert_eq!(builder.config.max_batch_size, 1);
    }

    #[test]
    fn test_builder_num_threads_zero() {
        let builder = EmbeddingEngineBuilder::new().num_threads(0);
        assert_eq!(builder.config.num_threads, Some(0));
    }

    #[tokio::test]
    async fn test_engine_getters_when_model_cached() {
        // Reads HF_HOME via model_cache_dir, so serialize with env tests.
        let _guard = env_guard();
        let config = ModelConfig::new(EmbeddingModel::MiniLM);
        match EmbeddingEngine::new(config).await {
            Ok(engine) => {
                assert_eq!(engine.dimension(), 384);
                assert_eq!(engine.model(), EmbeddingModel::MiniLM);
                let _ = format!("{:?}", engine);
                let ms = engine.estimate_time_ms(10, 50);
                assert!(ms >= 0.0);
            }
            // Model not cached locally (e.g. offline CI) — nothing to check.
            Err(_) => {}
        }
    }

    #[tokio::test]
    async fn test_engine_embed_empty_batch_when_cached() {
        // Reads HF_HOME via model_cache_dir, so serialize with env tests.
        let _guard = env_guard();
        let config = ModelConfig::new(EmbeddingModel::MiniLM);
        match EmbeddingEngine::new(config).await {
            Ok(engine) => {
                let result = engine.embed_raw(&[]).await;
                assert!(result.is_ok());
                assert!(result.unwrap().is_empty());
            }
            Err(_) => {}
        }
    }
}