1use crate::backend::{BackendKind, EmbeddingBackend};
12use crate::batch::{mean_pooling, normalize_embeddings, BatchProcessor};
13use crate::error::{InferenceError, Result};
14use crate::models::ModelConfig;
15use async_trait::async_trait;
16use ort::execution_providers::CUDAExecutionProvider;
17use ort::inputs;
18use ort::session::builder::GraphOptimizationLevel;
19use ort::session::Session;
20use ort::value::Tensor;
21use parking_lot::Mutex;
22use std::io::Read;
23use std::path::{Path, PathBuf};
24use std::sync::atomic::{AtomicUsize, Ordering};
25use std::sync::Arc;
26use tokenizers::Tokenizer;
27use tracing::{info, instrument, warn};
28
29pub struct OnnxBackend {
31 sessions: Vec<Arc<Mutex<Session>>>,
32 next_session: AtomicUsize,
33 processor: Arc<BatchProcessor>,
34 config: ModelConfig,
35 dimension: usize,
36 use_gpu: bool,
39}
40
41impl OnnxBackend {
42 #[instrument(skip_all, fields(model = %config.model))]
44 pub async fn new(config: &ModelConfig) -> Result<Self> {
45 let config = config.clone();
46 let use_gpu = std::env::var("DAKERA_USE_GPU")
47 .map(|v| v == "1")
48 .unwrap_or(config.use_gpu);
49
50 if use_gpu {
51 info!("ONNX backend: CUDA execution provider enabled (DAKERA_USE_GPU=1)");
52 }
53 info!("Initialising ONNX backend: model={}", config.model);
54
55 let (tokenizer_path, onnx_path) = Self::download_model_files(&config, use_gpu).await?;
56
57 info!("Loading tokenizer from {:?}", tokenizer_path);
58 let tokenizer = Tokenizer::from_file(&tokenizer_path)
59 .map_err(|e| InferenceError::TokenizationError(e.to_string()))?;
60
61 let num_threads = config.num_threads.unwrap_or(4);
62 let pool_size = config.session_pool_size.max(1);
63 let onnx_path_clone = onnx_path.clone();
64
65 let sessions: Vec<Arc<Mutex<Session>>> =
66 tokio::task::spawn_blocking(move || -> Result<Vec<Arc<Mutex<Session>>>> {
67 (0..pool_size)
68 .map(|_| {
69 let builder = Session::builder()
70 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
71 .with_optimization_level(GraphOptimizationLevel::Level3)
72 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
73 .with_intra_threads(num_threads)
74 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
75
76 let mut builder = if use_gpu {
77 builder
78 .with_execution_providers(
79 [CUDAExecutionProvider::default().build()],
80 )
81 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
82 } else {
83 builder
84 };
85
86 let s = builder
87 .commit_from_file(&onnx_path_clone)
88 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
89 Ok(Arc::new(Mutex::new(s)))
90 })
91 .collect()
92 })
93 .await
94 .map_err(|e| {
95 InferenceError::ModelLoadError(format!("Session pool init panicked: {}", e))
96 })??;
97
98 let dimension = config.model.dimension();
99 let processor = Arc::new(BatchProcessor::new(
100 tokenizer,
101 config.model,
102 config.max_batch_size,
103 ));
104
105 info!(
106 "ONNX backend ready: model={}, dimension={}, threads={}, pool={}",
107 config.model, dimension, num_threads, pool_size
108 );
109
110 Ok(Self {
111 sessions,
112 next_session: AtomicUsize::new(0),
113 processor,
114 config,
115 dimension,
116 use_gpu,
117 })
118 }
119
120 pub fn pool_size(&self) -> usize {
122 self.sessions.len()
123 }
124
125 #[instrument(skip_all, fields(model = %config.model))]
129 pub async fn download_model_files(
130 config: &ModelConfig,
131 use_gpu: bool,
132 ) -> Result<(PathBuf, PathBuf)> {
133 let model_id = config.model.model_id();
134 let onnx_repo_id = config.model.onnx_repo_id();
135 let onnx_filename = if use_gpu {
136 config.model.onnx_filename_gpu()
137 } else {
138 config.model.onnx_filename()
139 };
140
141 info!(
142 "Resolving model files: tokenizer={}, onnx={}@{}",
143 model_id, onnx_filename, onnx_repo_id
144 );
145
146 let tokenizer_cache_dir = Self::model_cache_dir(model_id)?;
147 let onnx_cache_dir = Self::model_cache_dir(onnx_repo_id)?;
148
149 let onnx_subdir = onnx_cache_dir.join("onnx");
150 std::fs::create_dir_all(&onnx_subdir)?;
151
152 let local_tokenizer = tokenizer_cache_dir.join("tokenizer.json");
153 let onnx_basename = Path::new(onnx_filename)
154 .file_name()
155 .and_then(|s| s.to_str())
156 .unwrap_or("model_quantized.onnx");
157 let local_onnx = onnx_subdir.join(onnx_basename);
158
159 if use_gpu && local_onnx.exists() {
161 let cached_size = local_onnx.metadata().map(|m| m.len()).unwrap_or(0);
162 if cached_size <= 500_000_000 {
163 warn!(
164 "Cached GPU ONNX at {:?} is {} bytes (≤500 MB) — likely truncated. Deleting.",
165 local_onnx, cached_size
166 );
167 let _ = std::fs::remove_file(&local_onnx);
168 }
169 }
170
171 if !local_tokenizer.exists() || !local_onnx.exists() {
172 let model_id_owned = model_id.to_string();
173 let onnx_repo_id_owned = onnx_repo_id.to_string();
174 let onnx_filename_owned = onnx_filename.to_string();
175 let tokenizer_cache = tokenizer_cache_dir.clone();
176 let onnx_cache = onnx_cache_dir.clone();
177
178 tokio::task::spawn_blocking(move || {
179 if !tokenizer_cache.join("tokenizer.json").exists() {
180 Self::download_hf_file(&model_id_owned, "tokenizer.json", &tokenizer_cache)
181 .map_err(|e| {
182 InferenceError::HubError(format!("Failed to download tokenizer: {}", e))
183 })?;
184 }
185 if !onnx_cache.join(&onnx_filename_owned).exists() {
186 Self::download_hf_file(&onnx_repo_id_owned, &onnx_filename_owned, &onnx_cache)
187 .map_err(|e| {
188 InferenceError::HubError(format!(
189 "Failed to download ONNX model: {}",
190 e
191 ))
192 })?;
193 }
194 Ok::<_, InferenceError>(())
195 })
196 .await
197 .map_err(|e| InferenceError::HubError(format!("Download task panicked: {}", e)))??;
198 } else {
199 info!("All model files found in local cache");
200 }
201
202 let final_onnx = onnx_cache_dir.join(onnx_filename);
203 Ok((local_tokenizer, final_onnx))
204 }
205
206 pub fn model_cache_dir(model_id: &str) -> Result<PathBuf> {
208 let base = std::env::var("HF_HOME")
209 .map(PathBuf::from)
210 .unwrap_or_else(|_| {
211 let home = std::env::var("HOME").unwrap_or_else(|_| {
212 warn!("HOME environment variable not set, using /tmp for model cache");
213 "/tmp".to_string()
214 });
215 PathBuf::from(home).join(".cache").join("huggingface")
216 });
217 let dir = base.join("dakera").join(model_id.replace('/', "--"));
218 std::fs::create_dir_all(&dir)?;
219 Ok(dir)
220 }
221
222 pub fn download_hf_file(
224 model_id: &str,
225 filename: &str,
226 cache_dir: &Path,
227 ) -> std::result::Result<PathBuf, String> {
228 let file_path = cache_dir.join(filename);
229 if file_path.exists() {
230 info!("Cached: {}/{}", model_id, filename);
231 return Ok(file_path);
232 }
233
234 if let Some(parent) = file_path.parent() {
235 std::fs::create_dir_all(parent)
236 .map_err(|e| format!("Failed to create directory {:?}: {}", parent, e))?;
237 }
238
239 let url = format!(
240 "https://huggingface.co/{}/resolve/main/{}",
241 model_id, filename
242 );
243 info!("Downloading: {}", url);
244
245 let hf_token = std::env::var("HF_TOKEN")
246 .or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
247 .ok();
248
249 let agent = ureq::AgentBuilder::new()
250 .redirects(0)
251 .timeout(std::time::Duration::from_secs(300))
252 .build();
253
254 let mut current_url = url;
255 let mut redirects = 0_u32;
256
257 let response = loop {
258 let mut req = agent.get(¤t_url);
259 if let Some(ref token) = hf_token {
260 req = req.set("Authorization", &format!("Bearer {}", token));
261 }
262 let resp = req.call();
263
264 let r = match resp {
265 Ok(r) => r,
266 Err(ureq::Error::Status(_status, r)) => r,
267 Err(e) => return Err(format!("{}: {}", filename, e)),
268 };
269
270 let status = r.status();
271 if (200..300).contains(&status) {
272 break r;
273 } else if (300..400).contains(&status) {
274 redirects += 1;
275 if redirects > 10 {
276 return Err(format!("{}: too many redirects", filename));
277 }
278 let location = r
279 .header("location")
280 .ok_or_else(|| format!("{}: redirect without Location header", filename))?
281 .to_string();
282
283 current_url = if location.starts_with('/') {
284 let parsed = url::Url::parse(¤t_url)
285 .map_err(|e| format!("{}: bad URL: {}", filename, e))?;
286 let host = parsed
287 .host_str()
288 .ok_or_else(|| format!("{}: missing host", filename))?;
289 format!("{}://{}{}", parsed.scheme(), host, location)
290 } else {
291 location
292 };
293 } else {
294 return Err(format!("{}: HTTP {}", filename, status));
295 }
296 };
297
298 let expected_bytes: Option<u64> = response
299 .header("x-linked-size")
300 .or_else(|| response.header("content-length"))
301 .and_then(|v| v.parse::<u64>().ok());
302
303 let mut bytes = Vec::new();
304 response
305 .into_reader()
306 .take(2_147_483_648)
307 .read_to_end(&mut bytes)
308 .map_err(|e| format!("Failed to read {}: {}", filename, e))?;
309
310 if let Some(expected) = expected_bytes {
311 if (bytes.len() as u64) < expected {
312 return Err(format!(
313 "{}: download incomplete — received {} of {} bytes",
314 filename,
315 bytes.len(),
316 expected
317 ));
318 }
319 }
320
321 std::fs::write(&file_path, &bytes)
322 .map_err(|e| format!("Failed to write {}: {}", filename, e))?;
323
324 info!("Downloaded {} ({} bytes)", filename, bytes.len());
325 Ok(file_path)
326 }
327
328 pub fn download_hf_file_pub(
330 model_id: &str,
331 filename: &str,
332 cache_dir: &Path,
333 ) -> std::result::Result<PathBuf, String> {
334 Self::download_hf_file(model_id, filename, cache_dir)
335 }
336
337 async fn embed_batch_internal(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
341 if texts.is_empty() {
342 return Ok(vec![]);
343 }
344
345 let pool_len = self.sessions.len();
346 let normalize = self.config.model.normalize_embeddings();
347 let start_idx = self.next_session.fetch_add(1, Ordering::Relaxed);
348 let mut batch_size = self.config.max_batch_size.max(1);
349
350 for attempt in 0_u32..=3 {
351 let batches: Vec<Vec<String>> = texts.chunks(batch_size).map(|b| b.to_vec()).collect();
352
353 let mut handles = Vec::with_capacity(batches.len());
354 for (i, batch_owned) in batches.into_iter().enumerate() {
355 let session = Arc::clone(&self.sessions[(start_idx + i) % pool_len]);
356 let processor = Arc::clone(&self.processor);
357 let gpu_permit = if self.use_gpu {
362 Some(
363 std::sync::Arc::clone(&crate::GPU_INFERENCE_SEMAPHORE)
364 .acquire_owned()
365 .await
366 .map_err(|_| {
367 InferenceError::InferenceError(
368 "GPU inference semaphore closed unexpectedly".to_string(),
369 )
370 })?,
371 )
372 } else {
373 None
374 };
375 handles.push(tokio::task::spawn_blocking(move || {
376 let _gpu_permit = gpu_permit; let mut session_guard = session.lock();
378 Self::process_batch_blocking(
379 &batch_owned,
380 &mut session_guard,
381 &processor,
382 normalize,
383 )
384 }));
385 }
386
387 let mut all_embeddings: Vec<Vec<f32>> = Vec::with_capacity(texts.len());
388 let mut oom: Option<InferenceError> = None;
389
390 for handle in handles {
391 match handle.await {
392 Err(panic_err) => {
393 return Err(InferenceError::InferenceError(format!(
394 "Inference task panicked: {panic_err}"
395 )));
396 }
397 Ok(Err(e)) => {
398 if attempt < 3 && Self::is_gpu_oom(&e) {
399 oom = Some(e);
400 break;
401 }
402 return Err(e);
403 }
404 Ok(Ok(batch_embs)) => {
405 all_embeddings.extend(batch_embs);
406 }
407 }
408 }
409
410 if oom.is_some() {
411 let next_batch = (batch_size / 2).max(1);
412 warn!(
413 "ONNX allocator OOM (attempt {}/3) — retrying with batch_size {} → {}",
414 attempt + 1,
415 batch_size,
416 next_batch,
417 );
418 batch_size = next_batch;
419 continue;
420 }
421
422 return Ok(all_embeddings);
423 }
424
425 Err(InferenceError::InferenceError(format!(
426 "ONNX inference failed: allocator OOM after 3 batch-halving attempts (batch_size={batch_size})"
427 )))
428 }
429
430 fn is_gpu_oom(err: &InferenceError) -> bool {
431 let msg = err.to_string();
432 msg.contains("BFCArena")
433 || msg.contains("Failed to allocate memory")
434 || msg.contains("CUDA_OUT_OF_MEMORY")
435 || msg.contains("CUDA out of memory")
436 || (msg.contains("allocate") && msg.contains("buffer of size"))
437 }
438
439 fn process_batch_blocking(
440 texts: &[String],
441 session: &mut Session,
442 processor: &BatchProcessor,
443 normalize: bool,
444 ) -> Result<Vec<Vec<f32>>> {
445 let prepared = processor.tokenize_batch(texts)?;
446 let batch_size = prepared.batch_size;
447 let seq_len = prepared.seq_len;
448 let attention_mask_flat = prepared.attention_mask.clone();
449
450 let input_ids_tensor =
451 Tensor::<i64>::from_array(([batch_size, seq_len], prepared.input_ids))
452 .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
453 let attention_mask_tensor =
454 Tensor::<i64>::from_array(([batch_size, seq_len], prepared.attention_mask))
455 .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
456 let token_type_ids_tensor =
457 Tensor::<i64>::from_array(([batch_size, seq_len], prepared.token_type_ids))
458 .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
459
460 let outputs = session
461 .run(inputs![
462 "input_ids" => input_ids_tensor,
463 "attention_mask" => attention_mask_tensor,
464 "token_type_ids" => token_type_ids_tensor
465 ])
466 .map_err(|e: ort::Error| InferenceError::InferenceError(e.to_string()))?;
467
468 let (ort_shape, lhs_slice) = outputs[0]
469 .try_extract_tensor::<f32>()
470 .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
471
472 if ort_shape.len() != 3 {
473 return Err(InferenceError::InferenceError(format!(
474 "Expected 3D last_hidden_state, got {} dims",
475 ort_shape.len()
476 )));
477 }
478 let hidden_size = ort_shape[2] as usize;
479
480 let mut embeddings = mean_pooling(
481 lhs_slice,
482 batch_size,
483 seq_len,
484 hidden_size,
485 &attention_mask_flat,
486 );
487
488 if normalize {
489 normalize_embeddings(&mut embeddings);
490 }
491
492 Ok(embeddings)
493 }
494}
495
496#[async_trait]
497impl EmbeddingBackend for OnnxBackend {
498 async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
499 self.embed_batch_internal(texts).await
500 }
501
502 fn dimension(&self) -> usize {
503 self.dimension
504 }
505
506 fn backend_kind(&self) -> BackendKind {
507 BackendKind::Onnx
508 }
509}