1use crate::backend::{BackendKind, EmbeddingBackend};
27use crate::batch::{mean_pooling, normalize_embeddings, BatchProcessor};
28use crate::error::{InferenceError, Result};
29use crate::models::ModelConfig;
30use async_trait::async_trait;
31use ort::execution_providers::{ArenaExtendStrategy, CUDAExecutionProvider};
32use ort::inputs;
33use ort::session::builder::GraphOptimizationLevel;
34use ort::session::Session;
35use ort::value::Tensor;
36use parking_lot::Mutex;
37use std::io::Read;
38use std::path::{Path, PathBuf};
39use std::sync::atomic::{AtomicUsize, Ordering};
40use std::sync::Arc;
41use tokenizers::Tokenizer;
42use tracing::{info, instrument, warn};
43
/// ONNX Runtime implementation of the embedding backend.
///
/// Inference is spread over a pool of sessions; each session sits behind its own mutex
/// because running a batch needs `&mut Session`.
pub struct OnnxBackend {
    /// Session pool; a request hands its chunks to sessions round-robin.
    sessions: Vec<Arc<Mutex<Session>>>,
    /// Wrapping counter used to pick the starting session for each request.
    next_session: AtomicUsize,
    /// Shared tokenizer/batching helper used by every inference task.
    processor: Arc<BatchProcessor>,
    /// Model configuration the backend was built from.
    config: ModelConfig,
    /// Embedding width, taken from `config.model.dimension()` at construction.
    dimension: usize,
}
52
/// Decides how many ONNX sessions the backend should create.
///
/// GPU mode always gets exactly one session, whatever `configured` says. CPU mode uses
/// `configured`, raised to a minimum of one so the pool is never empty.
fn resolve_pool_size(use_gpu: bool, configured: usize) -> usize {
    match use_gpu {
        true => 1,
        false => std::cmp::max(configured, 1),
    }
}
65
66impl OnnxBackend {
67 #[instrument(skip_all, fields(model = %config.model))]
69 pub async fn new(config: &ModelConfig) -> Result<Self> {
70 let config = config.clone();
71 let use_gpu = std::env::var("DAKERA_USE_GPU")
72 .map(|v| v == "1")
73 .unwrap_or(config.use_gpu);
74
75 if use_gpu {
76 info!("ONNX backend: CUDA execution provider enabled (DAKERA_USE_GPU=1)");
77 }
78 info!("Initialising ONNX backend: model={}", config.model);
79
80 let (tokenizer_path, onnx_path) = Self::download_model_files(&config, use_gpu).await?;
81
82 info!("Loading tokenizer from {:?}", tokenizer_path);
83 let tokenizer = Tokenizer::from_file(&tokenizer_path)
84 .map_err(|e| InferenceError::TokenizationError(e.to_string()))?;
85
86 let num_threads = config.num_threads.unwrap_or(4);
87 let pool_size = resolve_pool_size(use_gpu, config.session_pool_size);
88
89 let gpu_mem_limit_bytes: usize = std::env::var("DAKERA_GPU_MEM_LIMIT_GB")
91 .ok()
92 .and_then(|v| v.parse::<usize>().ok())
93 .unwrap_or(15)
94 * 1024
95 * 1024
96 * 1024;
97
98 if use_gpu {
99 info!(
100 "ONNX backend: GPU mode — pool_size=1, gpu_mem_limit={}GB",
101 gpu_mem_limit_bytes / (1024 * 1024 * 1024)
102 );
103 }
104
105 let onnx_path_clone = onnx_path.clone();
106
107 let sessions: Vec<Arc<Mutex<Session>>> =
108 tokio::task::spawn_blocking(move || -> Result<Vec<Arc<Mutex<Session>>>> {
109 (0..pool_size)
110 .map(|_| {
111 let builder = Session::builder()
112 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
113 .with_optimization_level(GraphOptimizationLevel::Level3)
114 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
115 .with_intra_threads(num_threads)
116 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
117
118 let mut builder = if use_gpu {
125 builder
126 .with_execution_providers([CUDAExecutionProvider::default()
127 .with_memory_limit(gpu_mem_limit_bytes)
128 .with_arena_extend_strategy(
129 ArenaExtendStrategy::SameAsRequested,
130 )
131 .build()])
132 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
133 } else {
134 builder
135 .with_memory_pattern(false)
136 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
137 };
138
139 let s = builder
140 .commit_from_file(&onnx_path_clone)
141 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
142 Ok(Arc::new(Mutex::new(s)))
143 })
144 .collect()
145 })
146 .await
147 .map_err(|e| {
148 InferenceError::ModelLoadError(format!("Session pool init panicked: {}", e))
149 })??;
150
151 let dimension = config.model.dimension();
152 let processor = Arc::new(BatchProcessor::new(
153 tokenizer,
154 config.model,
155 config.max_batch_size,
156 ));
157
158 info!(
159 "ONNX backend ready: model={}, dimension={}, threads={}, pool={}",
160 config.model, dimension, num_threads, pool_size
161 );
162
163 Ok(Self {
164 sessions,
165 next_session: AtomicUsize::new(0),
166 processor,
167 config,
168 dimension,
169 })
170 }
171
172 pub fn pool_size(&self) -> usize {
174 self.sessions.len()
175 }
176
177 #[instrument(skip_all, fields(model = %config.model))]
181 pub async fn download_model_files(
182 config: &ModelConfig,
183 use_gpu: bool,
184 ) -> Result<(PathBuf, PathBuf)> {
185 let model_id = config.model.model_id();
186 let onnx_repo_id = config.model.onnx_repo_id();
187 let onnx_filename = if use_gpu {
188 config.model.onnx_filename_gpu()
189 } else {
190 config.model.onnx_filename()
191 };
192
193 info!(
194 "Resolving model files: tokenizer={}, onnx={}@{}",
195 model_id, onnx_filename, onnx_repo_id
196 );
197
198 let tokenizer_cache_dir = Self::model_cache_dir(model_id)?;
199 let onnx_cache_dir = Self::model_cache_dir(onnx_repo_id)?;
200
201 let onnx_subdir = onnx_cache_dir.join("onnx");
202 std::fs::create_dir_all(&onnx_subdir)?;
203
204 let local_tokenizer = tokenizer_cache_dir.join("tokenizer.json");
205 let onnx_basename = Path::new(onnx_filename)
206 .file_name()
207 .and_then(|s| s.to_str())
208 .unwrap_or("model_quantized.onnx");
209 let local_onnx = onnx_subdir.join(onnx_basename);
210
211 if use_gpu && local_onnx.exists() {
213 let cached_size = local_onnx.metadata().map(|m| m.len()).unwrap_or(0);
214 if cached_size <= 500_000_000 {
215 warn!(
216 "Cached GPU ONNX at {:?} is {} bytes (≤500 MB) — likely truncated. Deleting.",
217 local_onnx, cached_size
218 );
219 let _ = std::fs::remove_file(&local_onnx);
220 }
221 }
222
223 if !local_tokenizer.exists() || !local_onnx.exists() {
224 let model_id_owned = model_id.to_string();
225 let onnx_repo_id_owned = onnx_repo_id.to_string();
226 let onnx_filename_owned = onnx_filename.to_string();
227 let tokenizer_cache = tokenizer_cache_dir.clone();
228 let onnx_cache = onnx_cache_dir.clone();
229
230 tokio::task::spawn_blocking(move || {
231 if !tokenizer_cache.join("tokenizer.json").exists() {
232 Self::download_hf_file(&model_id_owned, "tokenizer.json", &tokenizer_cache)
233 .map_err(|e| {
234 InferenceError::HubError(format!("Failed to download tokenizer: {}", e))
235 })?;
236 }
237 if !onnx_cache.join(&onnx_filename_owned).exists() {
238 Self::download_hf_file(&onnx_repo_id_owned, &onnx_filename_owned, &onnx_cache)
239 .map_err(|e| {
240 InferenceError::HubError(format!(
241 "Failed to download ONNX model: {}",
242 e
243 ))
244 })?;
245 }
246 Ok::<_, InferenceError>(())
247 })
248 .await
249 .map_err(|e| InferenceError::HubError(format!("Download task panicked: {}", e)))??;
250 } else {
251 info!("All model files found in local cache");
252 }
253
254 let final_onnx = onnx_cache_dir.join(onnx_filename);
255 Ok((local_tokenizer, final_onnx))
256 }
257
258 pub fn model_cache_dir(model_id: &str) -> Result<PathBuf> {
260 let base = std::env::var("HF_HOME")
261 .map(PathBuf::from)
262 .unwrap_or_else(|_| {
263 let home = std::env::var("HOME").unwrap_or_else(|_| {
264 warn!("HOME environment variable not set, using /tmp for model cache");
265 "/tmp".to_string()
266 });
267 PathBuf::from(home).join(".cache").join("huggingface")
268 });
269 let dir = base.join("dakera").join(model_id.replace('/', "--"));
270 std::fs::create_dir_all(&dir)?;
271 Ok(dir)
272 }
273
274 pub fn download_hf_file(
276 model_id: &str,
277 filename: &str,
278 cache_dir: &Path,
279 ) -> std::result::Result<PathBuf, String> {
280 let file_path = cache_dir.join(filename);
281 if file_path.exists() {
282 info!("Cached: {}/{}", model_id, filename);
283 return Ok(file_path);
284 }
285
286 if let Some(parent) = file_path.parent() {
287 std::fs::create_dir_all(parent)
288 .map_err(|e| format!("Failed to create directory {:?}: {}", parent, e))?;
289 }
290
291 let url = format!(
292 "https://huggingface.co/{}/resolve/main/{}",
293 model_id, filename
294 );
295 info!("Downloading: {}", url);
296
297 let hf_token = std::env::var("HF_TOKEN")
298 .or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
299 .ok();
300
301 let agent = ureq::AgentBuilder::new()
302 .redirects(0)
303 .timeout(std::time::Duration::from_secs(300))
304 .build();
305
306 let mut current_url = url;
307 let mut redirects = 0_u32;
308
309 let response = loop {
310 let mut req = agent.get(¤t_url);
311 if let Some(ref token) = hf_token {
312 req = req.set("Authorization", &format!("Bearer {}", token));
313 }
314 let resp = req.call();
315
316 let r = match resp {
317 Ok(r) => r,
318 Err(ureq::Error::Status(_status, r)) => r,
319 Err(e) => return Err(format!("{}: {}", filename, e)),
320 };
321
322 let status = r.status();
323 if (200..300).contains(&status) {
324 break r;
325 } else if (300..400).contains(&status) {
326 redirects += 1;
327 if redirects > 10 {
328 return Err(format!("{}: too many redirects", filename));
329 }
330 let location = r
331 .header("location")
332 .ok_or_else(|| format!("{}: redirect without Location header", filename))?
333 .to_string();
334
335 current_url = if location.starts_with('/') {
336 let parsed = url::Url::parse(¤t_url)
337 .map_err(|e| format!("{}: bad URL: {}", filename, e))?;
338 let host = parsed
339 .host_str()
340 .ok_or_else(|| format!("{}: missing host", filename))?;
341 format!("{}://{}{}", parsed.scheme(), host, location)
342 } else {
343 location
344 };
345 } else {
346 return Err(format!("{}: HTTP {}", filename, status));
347 }
348 };
349
350 let expected_bytes: Option<u64> = response
351 .header("x-linked-size")
352 .or_else(|| response.header("content-length"))
353 .and_then(|v| v.parse::<u64>().ok());
354
355 let mut bytes = Vec::new();
356 response
357 .into_reader()
358 .take(2_147_483_648)
359 .read_to_end(&mut bytes)
360 .map_err(|e| format!("Failed to read {}: {}", filename, e))?;
361
362 if let Some(expected) = expected_bytes {
363 if (bytes.len() as u64) < expected {
364 return Err(format!(
365 "{}: download incomplete — received {} of {} bytes",
366 filename,
367 bytes.len(),
368 expected
369 ));
370 }
371 }
372
373 std::fs::write(&file_path, &bytes)
374 .map_err(|e| format!("Failed to write {}: {}", filename, e))?;
375
376 info!("Downloaded {} ({} bytes)", filename, bytes.len());
377 Ok(file_path)
378 }
379
380 pub fn download_hf_file_pub(
382 model_id: &str,
383 filename: &str,
384 cache_dir: &Path,
385 ) -> std::result::Result<PathBuf, String> {
386 Self::download_hf_file(model_id, filename, cache_dir)
387 }
388
389 async fn embed_batch_internal(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
399 if texts.is_empty() {
400 return Ok(vec![]);
401 }
402
403 let pool_len = self.sessions.len();
404 let normalize = self.config.model.normalize_embeddings();
405 let start_idx = self.next_session.fetch_add(1, Ordering::Relaxed);
406 let mut batch_size = self.config.max_batch_size.max(1);
407
408 for attempt in 0_u32..=5 {
414 let batches: Vec<Vec<String>> = texts.chunks(batch_size).map(|b| b.to_vec()).collect();
415
416 let mut handles = Vec::with_capacity(batches.len());
417 for (i, batch_owned) in batches.into_iter().enumerate() {
418 let session = Arc::clone(&self.sessions[(start_idx + i) % pool_len]);
419 let processor = Arc::clone(&self.processor);
420 handles.push(tokio::task::spawn_blocking(move || {
424 let mut session_guard = session.lock();
425 Self::process_batch_blocking(
426 &batch_owned,
427 &mut session_guard,
428 &processor,
429 normalize,
430 )
431 }));
432 }
433
434 let mut all_embeddings: Vec<Vec<f32>> = Vec::with_capacity(texts.len());
435 let mut oom: Option<InferenceError> = None;
436
437 for handle in handles {
438 match handle.await {
439 Err(panic_err) => {
440 return Err(InferenceError::InferenceError(format!(
441 "Inference task panicked: {panic_err}"
442 )));
443 }
444 Ok(Err(e)) => {
445 if attempt < 5 && Self::is_gpu_oom(&e) {
446 oom = Some(e);
447 break;
448 }
449 return Err(e);
450 }
451 Ok(Ok(batch_embs)) => {
452 all_embeddings.extend(batch_embs);
453 }
454 }
455 }
456
457 if let Some(_oom_err) = oom {
458 let next_batch = (batch_size / 2).max(1);
459 warn!(
460 "ONNX allocator OOM (attempt {}/5) — retrying with batch_size {} → {}",
461 attempt + 1,
462 batch_size,
463 next_batch,
464 );
465 batch_size = next_batch;
466 continue;
467 }
468
469 return Ok(all_embeddings);
470 }
471
472 Err(InferenceError::InferenceError(format!(
473 "ONNX inference failed: allocator OOM after 5 batch-halving attempts (batch_size={batch_size})"
474 )))
475 }
476
477 fn is_gpu_oom(err: &InferenceError) -> bool {
478 let msg = err.to_string();
479 msg.contains("BFCArena")
480 || msg.contains("Failed to allocate memory")
481 || msg.contains("CUDA_OUT_OF_MEMORY")
482 || msg.contains("CUDA out of memory")
483 || (msg.contains("allocate") && msg.contains("buffer of size"))
484 }
485
486 fn process_batch_blocking(
487 texts: &[String],
488 session: &mut Session,
489 processor: &BatchProcessor,
490 normalize: bool,
491 ) -> Result<Vec<Vec<f32>>> {
492 let prepared = processor.tokenize_batch(texts)?;
493 let batch_size = prepared.batch_size;
494 let seq_len = prepared.seq_len;
495 let attention_mask_flat = prepared.attention_mask.clone();
496
497 let input_ids_tensor =
498 Tensor::<i64>::from_array(([batch_size, seq_len], prepared.input_ids))
499 .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
500 let attention_mask_tensor =
501 Tensor::<i64>::from_array(([batch_size, seq_len], prepared.attention_mask))
502 .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
503 let token_type_ids_tensor =
504 Tensor::<i64>::from_array(([batch_size, seq_len], prepared.token_type_ids))
505 .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
506
507 let outputs = session
508 .run(inputs![
509 "input_ids" => input_ids_tensor,
510 "attention_mask" => attention_mask_tensor,
511 "token_type_ids" => token_type_ids_tensor
512 ])
513 .map_err(|e: ort::Error| InferenceError::InferenceError(e.to_string()))?;
514
515 let (ort_shape, lhs_slice) = outputs[0]
516 .try_extract_tensor::<f32>()
517 .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
518
519 if ort_shape.len() != 3 {
520 return Err(InferenceError::InferenceError(format!(
521 "Expected 3D last_hidden_state, got {} dims",
522 ort_shape.len()
523 )));
524 }
525 let hidden_size = ort_shape[2] as usize;
526
527 let mut embeddings = mean_pooling(
528 lhs_slice,
529 batch_size,
530 seq_len,
531 hidden_size,
532 &attention_mask_flat,
533 );
534
535 if normalize {
536 normalize_embeddings(&mut embeddings);
537 }
538
539 Ok(embeddings)
540 }
541}
542
543#[async_trait]
544impl EmbeddingBackend for OnnxBackend {
545 async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
546 self.embed_batch_internal(texts).await
547 }
548
549 fn dimension(&self) -> usize {
550 self.dimension
551 }
552
553 fn backend_kind(&self) -> BackendKind {
554 BackendKind::Onnx
555 }
556}
557
#[cfg(test)]
mod tests {
    use super::{resolve_pool_size, OnnxBackend};
    use crate::error::InferenceError;

    #[test]
    fn gpu_mode_always_pool_size_one() {
        assert_eq!(resolve_pool_size(true, 1), 1);
        assert_eq!(
            resolve_pool_size(true, 4),
            1,
            "GPU overrides configured pool_size=4 → 1"
        );
        assert_eq!(resolve_pool_size(true, 0), 1, "GPU overrides zero → 1");
    }

    #[test]
    fn cpu_mode_respects_configured_pool_size() {
        assert_eq!(resolve_pool_size(false, 4), 4);
        assert_eq!(resolve_pool_size(false, 1), 1);
        assert_eq!(
            resolve_pool_size(false, 0),
            1,
            "CPU clamps zero to minimum 1"
        );
    }

    /// Wraps `msg` in the error variant used for ONNX Runtime inference failures.
    fn oom_err(msg: &str) -> InferenceError {
        InferenceError::InferenceError(msg.to_string())
    }

    #[test]
    fn detects_bfcarena_oom() {
        let e = oom_err("Non-zero status code returned while running Add node. \
            Status Message: bfc_arena.cc:358 void *onnxruntime::BFCArena::\
            AllocateRawInternal(size_t, bool, Stream *) Failed to allocate memory \
            for requested buffer of size 8241152");
        assert!(OnnxBackend::is_gpu_oom(&e), "BFCArena OOM must be detected");
    }

    #[test]
    fn detects_cuda_out_of_memory() {
        let e = oom_err("CUDA_OUT_OF_MEMORY: out of memory on device 0");
        assert!(OnnxBackend::is_gpu_oom(&e));
    }

    #[test]
    fn detects_cuda_out_of_memory_phrase() {
        let e = oom_err("CUDA out of memory. Tried to allocate 2.00 GiB");
        assert!(OnnxBackend::is_gpu_oom(&e));
    }

    #[test]
    fn detects_failed_to_allocate_memory() {
        let e = oom_err("Failed to allocate memory for requested buffer of size 1234");
        assert!(OnnxBackend::is_gpu_oom(&e));
    }

    #[test]
    fn detects_allocate_buffer_pattern() {
        // Matches only the compound `allocate` + `buffer of size` rule; none of the
        // literal phrases appear in this message.
        let e = oom_err("Could not allocate a buffer of size 1234");
        assert!(OnnxBackend::is_gpu_oom(&e));
    }

    #[test]
    fn allocate_without_buffer_size_not_detected() {
        let e = oom_err("Failed to allocate tensor for input_ids");
        assert!(
            !OnnxBackend::is_gpu_oom(&e),
            "`allocate` alone must not trigger OOM retry"
        );
    }

    #[test]
    fn non_oom_error_not_detected() {
        let e = oom_err("Shape mismatch: expected [4, 512] got [4, 256]");
        assert!(!OnnxBackend::is_gpu_oom(&e), "shape error must not trigger OOM retry");
    }

    #[test]
    fn batch_halving_reaches_one_in_five_steps() {
        let mut batch_size = 32_usize;
        let mut halvings = 0_u32;
        while batch_size > 1 {
            batch_size = (batch_size / 2).max(1);
            halvings += 1;
        }
        assert_eq!(batch_size, 1);
        assert!(halvings <= 5, "expected ≤5 halvings, got {halvings}");
    }
}