use crate::batch::{mean_pooling, normalize_embeddings, BatchProcessor};
use crate::error::{InferenceError, Result};
use crate::models::{EmbeddingModel, ModelConfig};
use candle_core::{DType, Device};
use candle_nn::VarBuilder;
use candle_transformers::models::bert::{BertModel, Config as BertConfig};
use parking_lot::RwLock;
use std::io::Read;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use tokenizers::Tokenizer;
use tracing::{debug, info, instrument, warn};
/// Text-embedding inference engine wrapping a BERT model loaded via candle.
///
/// Cheaply shareable pieces are behind `Arc`; inference itself runs on
/// blocking threads (see `embed_batch_internal`).
pub struct EmbeddingEngine {
    /// Loaded BERT weights; `RwLock` so concurrent batches share read access.
    model: Arc<RwLock<BertModel>>,
    /// Tokenization + batching helper shared with worker tasks.
    processor: Arc<BatchProcessor>,
    /// Device the weights were loaded onto (CPU, CUDA, or Metal).
    device: Device,
    /// Configuration the engine was built from.
    config: ModelConfig,
    /// Output embedding dimension, cached from the model descriptor.
    dimension: usize,
}
impl EmbeddingEngine {
    /// Builds a ready-to-use engine: selects a device, resolves (downloading
    /// if necessary) the model files, then loads tokenizer, config, and
    /// weights.
    ///
    /// # Errors
    /// Returns an error if any model file cannot be resolved, parsed, or
    /// loaded onto the selected device.
    #[instrument(skip_all, fields(model = %config.model))]
    pub async fn new(config: ModelConfig) -> Result<Self> {
        info!("Initializing embedding engine with model: {}", config.model);
        let device = Self::select_device(&config)?;
        info!("Using device: {:?}", device);
        let (model_path, tokenizer_path, config_path) = Self::download_model_files(&config).await?;
        info!("Loading tokenizer from {:?}", tokenizer_path);
        let tokenizer = Tokenizer::from_file(&tokenizer_path)
            .map_err(|e| InferenceError::TokenizationError(e.to_string()))?;
        info!("Loading model config from {:?}", config_path);
        let model_config: BertConfig = {
            let config_str = std::fs::read_to_string(&config_path)?;
            serde_json::from_str(&config_str)
                .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
        };
        info!("Loading model weights from {:?}", model_path);
        // SAFETY: mmap is sound as long as the weights file is not modified
        // while the engine is alive; the cache directory is treated as
        // write-once by this module.
        // NOTE(review): this loader only understands safetensors, but the
        // download path can fall back to pytorch_model.bin — confirm whether
        // that fallback is ever exercised in practice.
        let vb =
            unsafe { VarBuilder::from_mmaped_safetensors(&[model_path], DType::F32, &device)? };
        let model = BertModel::load(vb, &model_config)
            .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
        let dimension = config.model.dimension();
        let processor = Arc::new(BatchProcessor::new(
            tokenizer,
            config.model,
            config.max_batch_size,
        ));
        info!(
            "Embedding engine initialized: model={}, dimension={}, max_batch={}",
            config.model, dimension, config.max_batch_size
        );
        Ok(Self {
            model: Arc::new(RwLock::new(model)),
            processor,
            device,
            config,
            dimension,
        })
    }

    /// Chooses the compute device honoring `config.use_gpu` and the enabled
    /// cargo features, falling back to CPU when no GPU is usable.
    fn select_device(config: &ModelConfig) -> Result<Device> {
        if config.use_gpu {
            #[cfg(feature = "cuda")]
            {
                if let Ok(device) = Device::new_cuda(0) {
                    return Ok(device);
                }
                warn!("CUDA requested but not available, falling back to CPU");
            }
            #[cfg(feature = "metal")]
            {
                if let Ok(device) = Device::new_metal(0) {
                    return Ok(device);
                }
                warn!("Metal requested but not available, falling back to CPU");
            }
            #[cfg(not(any(feature = "cuda", feature = "metal")))]
            {
                warn!("GPU requested but no GPU features enabled, using CPU");
            }
        }
        Ok(Device::Cpu)
    }

    /// Resolves `(weights, tokenizer, config)` paths, in priority order:
    /// 1. the standard HuggingFace hub cache,
    /// 2. this crate's own local cache directory,
    /// 3. a fresh download from huggingface.co (run on a blocking thread).
    #[instrument(skip_all, fields(model = %config.model))]
    async fn download_model_files(config: &ModelConfig) -> Result<(PathBuf, PathBuf, PathBuf)> {
        let model_id = config.model.model_id();
        info!("Resolving model files for: {}", model_id);
        let model_id_owned = model_id.to_string();
        // First, check the shared HF hub cache so we never re-download files
        // another tool already fetched.
        let hf_cache = hf_hub::Cache::default();
        let hf_repo = hf_hub::Repo::new(model_id_owned.clone(), hf_hub::RepoType::Model);
        let cached_repo = hf_cache.repo(hf_repo);
        let cached_model = cached_repo
            .get("model.safetensors")
            .or_else(|| cached_repo.get("pytorch_model.bin"));
        let cached_tokenizer = cached_repo.get("tokenizer.json");
        let cached_config = cached_repo.get("config.json");
        if let (Some(m), Some(t), Some(c)) = (cached_model, cached_tokenizer, cached_config) {
            info!("All model files found in HF cache");
            return Ok((m, t, c));
        }
        // Second, check our own cache directory.
        let cache_dir = Self::model_cache_dir(model_id)?;
        let local_model = cache_dir.join("model.safetensors");
        let local_model_bin = cache_dir.join("pytorch_model.bin");
        let local_tokenizer = cache_dir.join("tokenizer.json");
        let local_config = cache_dir.join("config.json");
        let model_exists = local_model.exists() || local_model_bin.exists();
        if model_exists && local_tokenizer.exists() && local_config.exists() {
            let mp = if local_model.exists() {
                local_model
            } else {
                local_model_bin
            };
            info!("All model files found in local cache");
            return Ok((mp, local_tokenizer, local_config));
        }
        // Finally, download. ureq is blocking, so hop to a blocking thread.
        info!("Downloading model files from HuggingFace...");
        let cd = cache_dir.clone();
        let mid = model_id_owned.clone();
        tokio::task::spawn_blocking(move || {
            // Prefer safetensors; fall back to the pytorch checkpoint name.
            Self::download_hf_file(&mid, "model.safetensors", &cd)
                .or_else(|_| Self::download_hf_file(&mid, "pytorch_model.bin", &cd))
                .map_err(|e| {
                    InferenceError::HubError(format!("Failed to download model weights: {}", e))
                })?;
            Self::download_hf_file(&mid, "tokenizer.json", &cd).map_err(|e| {
                InferenceError::HubError(format!("Failed to download tokenizer: {}", e))
            })?;
            Self::download_hf_file(&mid, "config.json", &cd).map_err(|e| {
                InferenceError::HubError(format!("Failed to download config: {}", e))
            })?;
            Ok::<_, InferenceError>(())
        })
        .await
        .map_err(|e| InferenceError::HubError(format!("Download task panicked: {}", e)))??;
        let final_model = if cache_dir.join("model.safetensors").exists() {
            cache_dir.join("model.safetensors")
        } else {
            cache_dir.join("pytorch_model.bin")
        };
        info!("Model files downloaded successfully to {:?}", cache_dir);
        Ok((final_model, local_tokenizer, local_config))
    }

    /// Returns (creating if needed) this crate's cache directory for
    /// `model_id`, rooted at `$HF_HOME` or `~/.cache/huggingface`.
    fn model_cache_dir(model_id: &str) -> Result<PathBuf> {
        let base = std::env::var("HF_HOME")
            .map(PathBuf::from)
            .unwrap_or_else(|_| {
                let home = std::env::var("HOME").unwrap_or_else(|_| {
                    tracing::warn!("HOME environment variable not set, using /tmp for model cache");
                    "/tmp".to_string()
                });
                PathBuf::from(home).join(".cache").join("huggingface")
            });
        // Mirror the hub convention of flattening "org/name" into one dir.
        let dir = base.join("dakera").join(model_id.replace('/', "--"));
        std::fs::create_dir_all(&dir)?;
        Ok(dir)
    }

    /// Downloads one file from huggingface.co into `cache_dir`, following
    /// redirects manually (redirects are disabled on the agent so CDN
    /// `Location` hops can be logged and bounded).
    ///
    /// Returns the on-disk path; a pre-existing file is reused as-is.
    fn download_hf_file(
        model_id: &str,
        filename: &str,
        cache_dir: &Path,
    ) -> std::result::Result<PathBuf, String> {
        let file_path = cache_dir.join(filename);
        if file_path.exists() {
            info!("Cached: {}", filename);
            return Ok(file_path);
        }
        let url = format!(
            "https://huggingface.co/{}/resolve/main/{}",
            model_id, filename
        );
        info!("Downloading: {}", url);
        let agent = ureq::AgentBuilder::new()
            .redirects(0)
            .timeout(std::time::Duration::from_secs(300))
            .build();
        let mut current_url = url.clone();
        let mut redirects = 0;
        let max_redirects = 10;
        let response = loop {
            let resp = agent.get(&current_url).call();
            let r = match resp {
                Ok(r) => r,
                // ureq reports non-2xx statuses as errors; we still want the
                // response so redirects can be followed.
                Err(ureq::Error::Status(_status, r)) => r,
                Err(e) => return Err(format!("{}: {}", filename, e)),
            };
            let status = r.status();
            if (200..300).contains(&status) {
                break r;
            } else if (300..400).contains(&status) {
                redirects += 1;
                if redirects > max_redirects {
                    return Err(format!("{}: too many redirects", filename));
                }
                let location = r
                    .header("location")
                    .ok_or_else(|| format!("{}: redirect without Location header", filename))?
                    .to_string();
                // A relative Location keeps the current scheme and host.
                current_url = if location.starts_with('/') {
                    let parsed = url::Url::parse(&current_url)
                        .map_err(|e| format!("{}: bad URL {}: {}", filename, current_url, e))?;
                    let host = parsed.host_str().ok_or_else(|| {
                        format!("{}: redirect URL missing host: {}", filename, current_url)
                    })?;
                    format!("{}://{}{}", parsed.scheme(), host, location)
                } else {
                    location
                };
                info!("Redirect {} → {}", redirects, current_url);
            } else {
                return Err(format!("{}: HTTP {}", filename, status));
            }
        };
        // Cap the download so a misbehaving server cannot exhaust memory/disk.
        const MAX_DOWNLOAD_BYTES: u64 = 500_000_000;
        let mut bytes = Vec::new();
        response
            .into_reader()
            .take(MAX_DOWNLOAD_BYTES)
            .read_to_end(&mut bytes)
            .map_err(|e| format!("Failed to read {}: {}", filename, e))?;
        // A file at/above the cap was truncated by `take`; previously it was
        // silently written to the cache as if valid. Fail loudly instead.
        if bytes.len() as u64 >= MAX_DOWNLOAD_BYTES {
            return Err(format!(
                "{}: exceeds {} byte download limit (truncated)",
                filename, MAX_DOWNLOAD_BYTES
            ));
        }
        std::fs::write(&file_path, &bytes)
            .map_err(|e| format!("Failed to write {}: {}", filename, e))?;
        info!("Downloaded {} ({} bytes)", filename, bytes.len());
        Ok(file_path)
    }

    /// Output embedding dimension for the loaded model.
    pub fn dimension(&self) -> usize {
        self.dimension
    }

    /// Which embedding model this engine serves.
    pub fn model(&self) -> EmbeddingModel {
        self.config.model
    }

    /// The device weights live on.
    pub fn device(&self) -> &Device {
        &self.device
    }

    /// Embeds a single query string (query-style preprocessing applied).
    #[instrument(skip(self, text), fields(text_len = text.len()))]
    pub async fn embed_query(&self, text: &str) -> Result<Vec<f32>> {
        let texts = vec![text.to_string()];
        let prepared = self.processor.prepare_texts(&texts, true);
        let embeddings = self.embed_batch_internal(&prepared).await?;
        embeddings.into_iter().next().ok_or_else(|| {
            InferenceError::InferenceError("No embedding returned for query".to_string())
        })
    }

    /// Embeds many query strings at once.
    #[instrument(skip(self, texts), fields(count = texts.len()))]
    pub async fn embed_queries(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        let prepared = self.processor.prepare_texts(texts, true);
        self.embed_batch_internal(&prepared).await
    }

    /// Embeds a single document string (document-style preprocessing).
    #[instrument(skip(self, text), fields(text_len = text.len()))]
    pub async fn embed_document(&self, text: &str) -> Result<Vec<f32>> {
        let texts = vec![text.to_string()];
        let prepared = self.processor.prepare_texts(&texts, false);
        let embeddings = self.embed_batch_internal(&prepared).await?;
        embeddings.into_iter().next().ok_or_else(|| {
            InferenceError::InferenceError("No embedding returned for document".to_string())
        })
    }

    /// Embeds many document strings at once.
    #[instrument(skip(self, texts), fields(count = texts.len()))]
    pub async fn embed_documents(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        let prepared = self.processor.prepare_texts(texts, false);
        self.embed_batch_internal(&prepared).await
    }

    /// Embeds texts exactly as given, with no query/document preprocessing.
    #[instrument(skip(self, texts), fields(count = texts.len()))]
    pub async fn embed_raw(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        self.embed_batch_internal(texts).await
    }

    /// Splits `texts` into batches and runs each on a blocking thread so
    /// model inference never stalls the async executor.
    async fn embed_batch_internal(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        if texts.is_empty() {
            return Ok(vec![]);
        }
        let batches = self.processor.split_into_batches(texts);
        let mut all_embeddings = Vec::with_capacity(texts.len());
        for batch in batches {
            let batch_owned: Vec<String> = batch.to_vec();
            let model = Arc::clone(&self.model);
            let processor = Arc::clone(&self.processor);
            let device = self.device.clone();
            let normalize = self.config.model.normalize_embeddings();
            let batch_embeddings = tokio::task::spawn_blocking(move || {
                Self::process_batch_blocking(&batch_owned, &model, &processor, &device, normalize)
            })
            .await
            .map_err(|e| {
                InferenceError::InferenceError(format!("Inference task panicked: {}", e))
            })??;
            all_embeddings.extend(batch_embeddings);
        }
        Ok(all_embeddings)
    }

    /// Runs one batch synchronously: tokenize → forward pass → mean pooling
    /// → optional L2 normalization → `Vec<Vec<f32>>`.
    fn process_batch_blocking(
        texts: &[String],
        model: &Arc<RwLock<BertModel>>,
        processor: &BatchProcessor,
        device: &Device,
        normalize: bool,
    ) -> Result<Vec<Vec<f32>>> {
        let prepared = processor.tokenize_batch(texts, device)?;
        let model_guard = model.read();
        let input_ids = prepared.input_ids.to_dtype(DType::U32)?;
        let attention_mask = prepared.attention_mask.to_dtype(DType::U32)?;
        let token_type_ids = prepared.token_type_ids.to_dtype(DType::U32)?;
        let output = model_guard.forward(&input_ids, &token_type_ids, Some(&attention_mask))?;
        // Pooling weights the mask as f32 so padded positions contribute 0.
        let attention_mask_f32 = prepared.attention_mask.to_dtype(DType::F32)?;
        let pooled = mean_pooling(&output, &attention_mask_f32)?;
        let normalized = if normalize {
            normalize_embeddings(&pooled)?
        } else {
            pooled
        };
        // Release the read lock before the device→host copy below.
        drop(model_guard);
        let embeddings = normalized.to_vec2::<f32>()?;
        debug!(
            "Generated {} embeddings of dimension {}",
            embeddings.len(),
            embeddings.first().map(|e| e.len()).unwrap_or(0)
        );
        Ok(embeddings)
    }

    /// Rough latency estimate in milliseconds for embedding `text_count`
    /// texts of average length `avg_text_len` characters.
    ///
    /// Heuristic: ~4 characters per token, capped at the model's max sequence
    /// length; GPU is assumed ~10x faster than the model's CPU throughput.
    pub fn estimate_time_ms(&self, text_count: usize, avg_text_len: usize) -> f64 {
        let tokens_per_text =
            (avg_text_len as f64 / 4.0).min(self.config.model.max_seq_length() as f64);
        let total_tokens = tokens_per_text * text_count as f64;
        let tokens_per_second = self.config.model.tokens_per_second_cpu() as f64;
        let speed_multiplier = if matches!(self.device, Device::Cpu) {
            1.0
        } else {
            10.0
        };
        (total_tokens / (tokens_per_second * speed_multiplier)) * 1000.0
    }
}
// Manual Debug impl: BertModel and BatchProcessor do not derive Debug, so we
// print a curated summary of the engine instead.
impl std::fmt::Debug for EmbeddingEngine {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("EmbeddingEngine")
            .field("model", &self.config.model)
            .field("dimension", &self.dimension)
            .field("device", &self.device)
            .field("max_batch_size", &self.config.max_batch_size)
            .finish()
    }
}
/// Fluent builder for [`EmbeddingEngine`]; accumulates a `ModelConfig` and
/// hands it to `EmbeddingEngine::new` on `build()`.
pub struct EmbeddingEngineBuilder {
    // Config being assembled; starts from ModelConfig::default().
    config: ModelConfig,
}
impl EmbeddingEngineBuilder {
    /// Starts a builder seeded with the default model configuration.
    pub fn new() -> Self {
        let config = ModelConfig::default();
        Self { config }
    }

    /// Selects which embedding model the engine will load.
    pub fn model(self, model: EmbeddingModel) -> Self {
        let mut builder = self;
        builder.config.model = model;
        builder
    }

    /// Overrides the directory used to cache model files.
    pub fn cache_dir(self, dir: impl Into<String>) -> Self {
        let mut builder = self;
        builder.config.cache_dir = Some(dir.into());
        builder
    }

    /// Sets the maximum number of texts processed per inference batch.
    pub fn max_batch_size(self, size: usize) -> Self {
        let mut builder = self;
        builder.config.max_batch_size = size;
        builder
    }

    /// Enables or disables GPU usage (subject to compiled features).
    pub fn use_gpu(self, enable: bool) -> Self {
        let mut builder = self;
        builder.config.use_gpu = enable;
        builder
    }

    /// Sets the number of CPU threads to use for inference.
    pub fn num_threads(self, threads: usize) -> Self {
        let mut builder = self;
        builder.config.num_threads = Some(threads);
        builder
    }

    /// Consumes the builder and constructs the engine.
    pub async fn build(self) -> Result<EmbeddingEngine> {
        EmbeddingEngine::new(self.config).await
    }
}
impl Default for EmbeddingEngineBuilder {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The latency heuristic requires a positive CPU throughput figure.
    #[test]
    fn test_estimate_time() {
        let cfg = ModelConfig::new(EmbeddingModel::MiniLM);
        let cpu_tokens_per_sec = cfg.model.tokens_per_second_cpu() as f64;
        assert!(cpu_tokens_per_sec > 0.0);
    }

    /// Builder setters should each land in the underlying config.
    #[test]
    fn test_builder() {
        let built = EmbeddingEngineBuilder::new()
            .use_gpu(false)
            .max_batch_size(64)
            .model(EmbeddingModel::BgeSmall);
        assert_eq!(built.config.model, EmbeddingModel::BgeSmall);
        assert_eq!(built.config.max_batch_size, 64);
        assert!(!built.config.use_gpu);
    }
}