use crate::backend::{BackendKind, EmbeddingBackend};
use crate::batch::{mean_pooling, normalize_embeddings, BatchProcessor};
use crate::error::{InferenceError, Result};
use crate::models::ModelConfig;
use async_trait::async_trait;
use ort::execution_providers::CUDAExecutionProvider;
use ort::inputs;
use ort::session::builder::GraphOptimizationLevel;
use ort::session::Session;
use ort::value::Tensor;
use parking_lot::Mutex;
use std::io::Read;
use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use tokenizers::Tokenizer;
use tracing::{info, instrument, warn};
/// ONNX Runtime implementation of [`EmbeddingBackend`].
///
/// Owns a fixed pool of sessions that embedding requests rotate through, together with the
/// shared tokenizer/batching state built from the model configuration.
pub struct OnnxBackend {
// Session pool; each session sits behind its own mutex so different batches of one request
// can run on different sessions at the same time.
sessions: Vec<Arc<Mutex<Session>>>,
// Round-robin counter that picks the first session used by each embedding call.
next_session: AtomicUsize,
// Shared tokenizer/batching helper, cloned (by `Arc`) into every blocking inference task.
processor: Arc<BatchProcessor>,
// Owned copy of the configuration the backend was built from.
config: ModelConfig,
// Embedding dimension reported to callers, taken from `config.model.dimension()`.
dimension: usize,
// True when sessions were built with the CUDA execution provider; also gates the global GPU
// inference semaphore.
use_gpu: bool,
}
impl OnnxBackend {
#[instrument(skip_all, fields(model = %config.model))]
pub async fn new(config: &ModelConfig) -> Result<Self> {
let config = config.clone();
let use_gpu = std::env::var("DAKERA_USE_GPU")
.map(|v| v == "1")
.unwrap_or(config.use_gpu);
if use_gpu {
info!("ONNX backend: CUDA execution provider enabled (DAKERA_USE_GPU=1)");
}
info!("Initialising ONNX backend: model={}", config.model);
let (tokenizer_path, onnx_path) = Self::download_model_files(&config, use_gpu).await?;
info!("Loading tokenizer from {:?}", tokenizer_path);
let tokenizer = Tokenizer::from_file(&tokenizer_path)
.map_err(|e| InferenceError::TokenizationError(e.to_string()))?;
let num_threads = config.num_threads.unwrap_or(4);
let pool_size = config.session_pool_size.max(1);
let onnx_path_clone = onnx_path.clone();
let sessions: Vec<Arc<Mutex<Session>>> =
tokio::task::spawn_blocking(move || -> Result<Vec<Arc<Mutex<Session>>>> {
(0..pool_size)
.map(|_| {
let builder = Session::builder()
.map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
.with_optimization_level(GraphOptimizationLevel::Level3)
.map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
.with_intra_threads(num_threads)
.map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
let mut builder = if use_gpu {
builder
.with_execution_providers(
[CUDAExecutionProvider::default().build()],
)
.map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
} else {
builder
};
let s = builder
.commit_from_file(&onnx_path_clone)
.map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
Ok(Arc::new(Mutex::new(s)))
})
.collect()
})
.await
.map_err(|e| {
InferenceError::ModelLoadError(format!("Session pool init panicked: {}", e))
})??;
let dimension = config.model.dimension();
let processor = Arc::new(BatchProcessor::new(
tokenizer,
config.model,
config.max_batch_size,
));
info!(
"ONNX backend ready: model={}, dimension={}, threads={}, pool={}",
config.model, dimension, num_threads, pool_size
);
Ok(Self {
sessions,
next_session: AtomicUsize::new(0),
processor,
config,
dimension,
use_gpu,
})
}
pub fn pool_size(&self) -> usize {
self.sessions.len()
}
#[instrument(skip_all, fields(model = %config.model))]
pub async fn download_model_files(
config: &ModelConfig,
use_gpu: bool,
) -> Result<(PathBuf, PathBuf)> {
let model_id = config.model.model_id();
let onnx_repo_id = config.model.onnx_repo_id();
let onnx_filename = if use_gpu {
config.model.onnx_filename_gpu()
} else {
config.model.onnx_filename()
};
info!(
"Resolving model files: tokenizer={}, onnx={}@{}",
model_id, onnx_filename, onnx_repo_id
);
let tokenizer_cache_dir = Self::model_cache_dir(model_id)?;
let onnx_cache_dir = Self::model_cache_dir(onnx_repo_id)?;
let onnx_subdir = onnx_cache_dir.join("onnx");
std::fs::create_dir_all(&onnx_subdir)?;
let local_tokenizer = tokenizer_cache_dir.join("tokenizer.json");
let onnx_basename = Path::new(onnx_filename)
.file_name()
.and_then(|s| s.to_str())
.unwrap_or("model_quantized.onnx");
let local_onnx = onnx_subdir.join(onnx_basename);
if use_gpu && local_onnx.exists() {
let cached_size = local_onnx.metadata().map(|m| m.len()).unwrap_or(0);
if cached_size <= 500_000_000 {
warn!(
"Cached GPU ONNX at {:?} is {} bytes (≤500 MB) — likely truncated. Deleting.",
local_onnx, cached_size
);
let _ = std::fs::remove_file(&local_onnx);
}
}
if !local_tokenizer.exists() || !local_onnx.exists() {
let model_id_owned = model_id.to_string();
let onnx_repo_id_owned = onnx_repo_id.to_string();
let onnx_filename_owned = onnx_filename.to_string();
let tokenizer_cache = tokenizer_cache_dir.clone();
let onnx_cache = onnx_cache_dir.clone();
tokio::task::spawn_blocking(move || {
if !tokenizer_cache.join("tokenizer.json").exists() {
Self::download_hf_file(&model_id_owned, "tokenizer.json", &tokenizer_cache)
.map_err(|e| {
InferenceError::HubError(format!("Failed to download tokenizer: {}", e))
})?;
}
if !onnx_cache.join(&onnx_filename_owned).exists() {
Self::download_hf_file(&onnx_repo_id_owned, &onnx_filename_owned, &onnx_cache)
.map_err(|e| {
InferenceError::HubError(format!(
"Failed to download ONNX model: {}",
e
))
})?;
}
Ok::<_, InferenceError>(())
})
.await
.map_err(|e| InferenceError::HubError(format!("Download task panicked: {}", e)))??;
} else {
info!("All model files found in local cache");
}
let final_onnx = onnx_cache_dir.join(onnx_filename);
Ok((local_tokenizer, final_onnx))
}
pub fn model_cache_dir(model_id: &str) -> Result<PathBuf> {
let base = std::env::var("HF_HOME")
.map(PathBuf::from)
.unwrap_or_else(|_| {
let home = std::env::var("HOME").unwrap_or_else(|_| {
warn!("HOME environment variable not set, using /tmp for model cache");
"/tmp".to_string()
});
PathBuf::from(home).join(".cache").join("huggingface")
});
let dir = base.join("dakera").join(model_id.replace('/', "--"));
std::fs::create_dir_all(&dir)?;
Ok(dir)
}
pub fn download_hf_file(
model_id: &str,
filename: &str,
cache_dir: &Path,
) -> std::result::Result<PathBuf, String> {
let file_path = cache_dir.join(filename);
if file_path.exists() {
info!("Cached: {}/{}", model_id, filename);
return Ok(file_path);
}
if let Some(parent) = file_path.parent() {
std::fs::create_dir_all(parent)
.map_err(|e| format!("Failed to create directory {:?}: {}", parent, e))?;
}
let url = format!(
"https://huggingface.co/{}/resolve/main/{}",
model_id, filename
);
info!("Downloading: {}", url);
let hf_token = std::env::var("HF_TOKEN")
.or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
.ok();
let agent = ureq::AgentBuilder::new()
.redirects(0)
.timeout(std::time::Duration::from_secs(300))
.build();
let mut current_url = url;
let mut redirects = 0_u32;
let response = loop {
let mut req = agent.get(¤t_url);
if let Some(ref token) = hf_token {
req = req.set("Authorization", &format!("Bearer {}", token));
}
let resp = req.call();
let r = match resp {
Ok(r) => r,
Err(ureq::Error::Status(_status, r)) => r,
Err(e) => return Err(format!("{}: {}", filename, e)),
};
let status = r.status();
if (200..300).contains(&status) {
break r;
} else if (300..400).contains(&status) {
redirects += 1;
if redirects > 10 {
return Err(format!("{}: too many redirects", filename));
}
let location = r
.header("location")
.ok_or_else(|| format!("{}: redirect without Location header", filename))?
.to_string();
current_url = if location.starts_with('/') {
let parsed = url::Url::parse(¤t_url)
.map_err(|e| format!("{}: bad URL: {}", filename, e))?;
let host = parsed
.host_str()
.ok_or_else(|| format!("{}: missing host", filename))?;
format!("{}://{}{}", parsed.scheme(), host, location)
} else {
location
};
} else {
return Err(format!("{}: HTTP {}", filename, status));
}
};
let expected_bytes: Option<u64> = response
.header("x-linked-size")
.or_else(|| response.header("content-length"))
.and_then(|v| v.parse::<u64>().ok());
let mut bytes = Vec::new();
response
.into_reader()
.take(2_147_483_648)
.read_to_end(&mut bytes)
.map_err(|e| format!("Failed to read {}: {}", filename, e))?;
if let Some(expected) = expected_bytes {
if (bytes.len() as u64) < expected {
return Err(format!(
"{}: download incomplete — received {} of {} bytes",
filename,
bytes.len(),
expected
));
}
}
std::fs::write(&file_path, &bytes)
.map_err(|e| format!("Failed to write {}: {}", filename, e))?;
info!("Downloaded {} ({} bytes)", filename, bytes.len());
Ok(file_path)
}
pub fn download_hf_file_pub(
model_id: &str,
filename: &str,
cache_dir: &Path,
) -> std::result::Result<PathBuf, String> {
Self::download_hf_file(model_id, filename, cache_dir)
}
async fn embed_batch_internal(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
if texts.is_empty() {
return Ok(vec![]);
}
let pool_len = self.sessions.len();
let normalize = self.config.model.normalize_embeddings();
let start_idx = self.next_session.fetch_add(1, Ordering::Relaxed);
let mut batch_size = self.config.max_batch_size.max(1);
for attempt in 0_u32..=3 {
let batches: Vec<Vec<String>> = texts.chunks(batch_size).map(|b| b.to_vec()).collect();
let mut handles = Vec::with_capacity(batches.len());
for (i, batch_owned) in batches.into_iter().enumerate() {
let session = Arc::clone(&self.sessions[(start_idx + i) % pool_len]);
let processor = Arc::clone(&self.processor);
let gpu_permit = if self.use_gpu {
Some(
std::sync::Arc::clone(&crate::GPU_INFERENCE_SEMAPHORE)
.acquire_owned()
.await
.map_err(|_| {
InferenceError::InferenceError(
"GPU inference semaphore closed unexpectedly".to_string(),
)
})?,
)
} else {
None
};
handles.push(tokio::task::spawn_blocking(move || {
let _gpu_permit = gpu_permit; let mut session_guard = session.lock();
Self::process_batch_blocking(
&batch_owned,
&mut session_guard,
&processor,
normalize,
)
}));
}
let mut all_embeddings: Vec<Vec<f32>> = Vec::with_capacity(texts.len());
let mut oom: Option<InferenceError> = None;
for handle in handles {
match handle.await {
Err(panic_err) => {
return Err(InferenceError::InferenceError(format!(
"Inference task panicked: {panic_err}"
)));
}
Ok(Err(e)) => {
if attempt < 3 && Self::is_gpu_oom(&e) {
oom = Some(e);
break;
}
return Err(e);
}
Ok(Ok(batch_embs)) => {
all_embeddings.extend(batch_embs);
}
}
}
if oom.is_some() {
let next_batch = (batch_size / 2).max(1);
warn!(
"ONNX allocator OOM (attempt {}/3) — retrying with batch_size {} → {}",
attempt + 1,
batch_size,
next_batch,
);
batch_size = next_batch;
continue;
}
return Ok(all_embeddings);
}
Err(InferenceError::InferenceError(format!(
"ONNX inference failed: allocator OOM after 3 batch-halving attempts (batch_size={batch_size})"
)))
}
fn is_gpu_oom(err: &InferenceError) -> bool {
let msg = err.to_string();
msg.contains("BFCArena")
|| msg.contains("Failed to allocate memory")
|| msg.contains("CUDA_OUT_OF_MEMORY")
|| msg.contains("CUDA out of memory")
|| (msg.contains("allocate") && msg.contains("buffer of size"))
}
fn process_batch_blocking(
texts: &[String],
session: &mut Session,
processor: &BatchProcessor,
normalize: bool,
) -> Result<Vec<Vec<f32>>> {
let prepared = processor.tokenize_batch(texts)?;
let batch_size = prepared.batch_size;
let seq_len = prepared.seq_len;
let attention_mask_flat = prepared.attention_mask.clone();
let input_ids_tensor =
Tensor::<i64>::from_array(([batch_size, seq_len], prepared.input_ids))
.map_err(|e| InferenceError::InferenceError(e.to_string()))?;
let attention_mask_tensor =
Tensor::<i64>::from_array(([batch_size, seq_len], prepared.attention_mask))
.map_err(|e| InferenceError::InferenceError(e.to_string()))?;
let token_type_ids_tensor =
Tensor::<i64>::from_array(([batch_size, seq_len], prepared.token_type_ids))
.map_err(|e| InferenceError::InferenceError(e.to_string()))?;
let outputs = session
.run(inputs![
"input_ids" => input_ids_tensor,
"attention_mask" => attention_mask_tensor,
"token_type_ids" => token_type_ids_tensor
])
.map_err(|e: ort::Error| InferenceError::InferenceError(e.to_string()))?;
let (ort_shape, lhs_slice) = outputs[0]
.try_extract_tensor::<f32>()
.map_err(|e| InferenceError::InferenceError(e.to_string()))?;
if ort_shape.len() != 3 {
return Err(InferenceError::InferenceError(format!(
"Expected 3D last_hidden_state, got {} dims",
ort_shape.len()
)));
}
let hidden_size = ort_shape[2] as usize;
let mut embeddings = mean_pooling(
lhs_slice,
batch_size,
seq_len,
hidden_size,
&attention_mask_flat,
);
if normalize {
normalize_embeddings(&mut embeddings);
}
Ok(embeddings)
}
}
#[async_trait]
impl EmbeddingBackend for OnnxBackend {
    /// Identifies this implementation as the ONNX Runtime backend.
    fn backend_kind(&self) -> BackendKind {
        BackendKind::Onnx
    }

    /// Length of every vector produced by `embed_batch`, fixed when the backend was built.
    fn dimension(&self) -> usize {
        self.dimension
    }

    /// Embeds `texts` through the pooled ONNX sessions, one vector per input.
    async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        let vectors = self.embed_batch_internal(texts).await?;
        Ok(vectors)
    }
}