Skip to main content

inference/backend/
onnx.rs

1//! ONNX Runtime embedding backend.
2//!
3//! This is the production default backend.  It wraps the existing session-pool
4//! logic that was previously inline in `EmbeddingEngine`.  No functional change —
5//! pure extraction to satisfy the [`EmbeddingBackend`] trait.
6//!
7//! ## GPU mode (DAKERA_USE_GPU=1)
8//!
9//! When GPU is enabled, `pool_size` is capped to **1** regardless of
10//! `DAKERA_ONNX_POOL_SIZE`.  A single session with a `parking_lot::Mutex`
11//! naturally serialises CUDA forward passes without a separate semaphore.  The
12//! CUDA execution provider is configured with a hard memory limit
13//! (`DAKERA_GPU_MEM_LIMIT_GB`, default 15 GB) and
14//! `ArenaExtendStrategy::SameAsRequested` to prevent `BFCArena` growth beyond
15//! the working set.  Together these guarantee at most one concurrent CUDA call
16//! and bounded VRAM usage, replacing the application-level
17//! `GPU_INFERENCE_SEMAPHORE` approach from v0.11.79/v0.11.80 (DAK-6134).
18//!
19//! ## CPU mode
20//!
21//! A pool of `N` independent ONNX sessions (`N = DAKERA_ONNX_POOL_SIZE`, default
22//! 4) serves concurrent callers via round-robin dispatch.  Each session runs
23//! inside a `tokio::task::spawn_blocking` call to avoid blocking the async
24//! executor.
25
26use crate::backend::{BackendKind, EmbeddingBackend};
27use crate::batch::{mean_pooling, normalize_embeddings, BatchProcessor};
28use crate::error::{InferenceError, Result};
29use crate::models::ModelConfig;
30use async_trait::async_trait;
31use ort::execution_providers::{ArenaExtendStrategy, CUDAExecutionProvider};
32use ort::inputs;
33use ort::session::builder::GraphOptimizationLevel;
34use ort::session::Session;
35use ort::value::Tensor;
36use parking_lot::Mutex;
37use std::io::Read;
38use std::path::{Path, PathBuf};
39use std::sync::atomic::{AtomicUsize, Ordering};
40use std::sync::Arc;
41use tokenizers::Tokenizer;
42use tracing::{info, instrument, warn};
43
44/// ONNX Runtime embedding backend with a session pool for concurrent inference.
45pub struct OnnxBackend {
46    sessions: Vec<Arc<Mutex<Session>>>,
47    next_session: AtomicUsize,
48    processor: Arc<BatchProcessor>,
49    config: ModelConfig,
50    dimension: usize,
51}
52
/// Decide how many ONNX sessions to build for the given inference mode.
///
/// * GPU: always exactly 1. With a single session, its `parking_lot::Mutex` serialises
///   CUDA forward passes, which replaces the application-level `GPU_INFERENCE_SEMAPHORE`.
/// * CPU: the configured size, clamped to at least 1 so the pool is never empty.
fn resolve_pool_size(use_gpu: bool, configured: usize) -> usize {
    match use_gpu {
        true => 1,
        false => std::cmp::max(configured, 1),
    }
}
65
66impl OnnxBackend {
67    /// Build a new `OnnxBackend` by downloading model files and building the session pool.
68    #[instrument(skip_all, fields(model = %config.model))]
69    pub async fn new(config: &ModelConfig) -> Result<Self> {
70        let config = config.clone();
71        let use_gpu = std::env::var("DAKERA_USE_GPU")
72            .map(|v| v == "1")
73            .unwrap_or(config.use_gpu);
74
75        if use_gpu {
76            info!("ONNX backend: CUDA execution provider enabled (DAKERA_USE_GPU=1)");
77        }
78        info!("Initialising ONNX backend: model={}", config.model);
79
80        let (tokenizer_path, onnx_path) = Self::download_model_files(&config, use_gpu).await?;
81
82        info!("Loading tokenizer from {:?}", tokenizer_path);
83        let tokenizer = Tokenizer::from_file(&tokenizer_path)
84            .map_err(|e| InferenceError::TokenizationError(e.to_string()))?;
85
86        let num_threads = config.num_threads.unwrap_or(4);
87        let pool_size = resolve_pool_size(use_gpu, config.session_pool_size);
88
89        // GPU memory limit: default 15 GB on L4 (24 GB total), overrideable via env.
90        let gpu_mem_limit_bytes: usize = std::env::var("DAKERA_GPU_MEM_LIMIT_GB")
91            .ok()
92            .and_then(|v| v.parse::<usize>().ok())
93            .unwrap_or(15)
94            * 1024
95            * 1024
96            * 1024;
97
98        if use_gpu {
99            info!(
100                "ONNX backend: GPU mode — pool_size=1, gpu_mem_limit={}GB",
101                gpu_mem_limit_bytes / (1024 * 1024 * 1024)
102            );
103        }
104
105        let onnx_path_clone = onnx_path.clone();
106
107        let sessions: Vec<Arc<Mutex<Session>>> =
108            tokio::task::spawn_blocking(move || -> Result<Vec<Arc<Mutex<Session>>>> {
109                (0..pool_size)
110                    .map(|_| {
111                        let builder = Session::builder()
112                            .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
113                            .with_optimization_level(GraphOptimizationLevel::Level3)
114                            .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
115                            .with_intra_threads(num_threads)
116                            .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
117
118                        // DAK-6145: disable memory pattern pre-allocation on CPU.
119                        // On memory-constrained servers, ORT's pattern-based pre-allocation
120                        // exhausts the BFCArena before inference starts, causing batch=4
121                        // (8MB query/Add buffer) to fail even after 3 halvings.  With
122                        // memory_pattern=false, allocation is fresh per-run so headroom
123                        // remains available for per-layer activations.
124                        let mut builder = if use_gpu {
125                            builder
126                                .with_execution_providers([CUDAExecutionProvider::default()
127                                    .with_memory_limit(gpu_mem_limit_bytes)
128                                    .with_arena_extend_strategy(
129                                        ArenaExtendStrategy::SameAsRequested,
130                                    )
131                                    .build()])
132                                .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
133                        } else {
134                            builder
135                                .with_memory_pattern(false)
136                                .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
137                        };
138
139                        let s = builder
140                            .commit_from_file(&onnx_path_clone)
141                            .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
142                        Ok(Arc::new(Mutex::new(s)))
143                    })
144                    .collect()
145            })
146            .await
147            .map_err(|e| {
148                InferenceError::ModelLoadError(format!("Session pool init panicked: {}", e))
149            })??;
150
151        let dimension = config.model.dimension();
152        let processor = Arc::new(BatchProcessor::new(
153            tokenizer,
154            config.model,
155            config.max_batch_size,
156        ));
157
158        info!(
159            "ONNX backend ready: model={}, dimension={}, threads={}, pool={}",
160            config.model, dimension, num_threads, pool_size
161        );
162
163        Ok(Self {
164            sessions,
165            next_session: AtomicUsize::new(0),
166            processor,
167            config,
168            dimension,
169        })
170    }
171
172    /// Number of ONNX sessions in the pool.
173    pub fn pool_size(&self) -> usize {
174        self.sessions.len()
175    }
176
177    // ── File download helpers (shared with CandleBackend) ──────────────────────
178
179    /// Resolve tokenizer and ONNX model files, downloading from HuggingFace if needed.
180    #[instrument(skip_all, fields(model = %config.model))]
181    pub async fn download_model_files(
182        config: &ModelConfig,
183        use_gpu: bool,
184    ) -> Result<(PathBuf, PathBuf)> {
185        let model_id = config.model.model_id();
186        let onnx_repo_id = config.model.onnx_repo_id();
187        let onnx_filename = if use_gpu {
188            config.model.onnx_filename_gpu()
189        } else {
190            config.model.onnx_filename()
191        };
192
193        info!(
194            "Resolving model files: tokenizer={}, onnx={}@{}",
195            model_id, onnx_filename, onnx_repo_id
196        );
197
198        let tokenizer_cache_dir = Self::model_cache_dir(model_id)?;
199        let onnx_cache_dir = Self::model_cache_dir(onnx_repo_id)?;
200
201        let onnx_subdir = onnx_cache_dir.join("onnx");
202        std::fs::create_dir_all(&onnx_subdir)?;
203
204        let local_tokenizer = tokenizer_cache_dir.join("tokenizer.json");
205        let onnx_basename = Path::new(onnx_filename)
206            .file_name()
207            .and_then(|s| s.to_str())
208            .unwrap_or("model_quantized.onnx");
209        let local_onnx = onnx_subdir.join(onnx_basename);
210
211        // GPU FP32 model truncation guard (DAK-5976)
212        if use_gpu && local_onnx.exists() {
213            let cached_size = local_onnx.metadata().map(|m| m.len()).unwrap_or(0);
214            if cached_size <= 500_000_000 {
215                warn!(
216                    "Cached GPU ONNX at {:?} is {} bytes (≤500 MB) — likely truncated. Deleting.",
217                    local_onnx, cached_size
218                );
219                let _ = std::fs::remove_file(&local_onnx);
220            }
221        }
222
223        if !local_tokenizer.exists() || !local_onnx.exists() {
224            let model_id_owned = model_id.to_string();
225            let onnx_repo_id_owned = onnx_repo_id.to_string();
226            let onnx_filename_owned = onnx_filename.to_string();
227            let tokenizer_cache = tokenizer_cache_dir.clone();
228            let onnx_cache = onnx_cache_dir.clone();
229
230            tokio::task::spawn_blocking(move || {
231                if !tokenizer_cache.join("tokenizer.json").exists() {
232                    Self::download_hf_file(&model_id_owned, "tokenizer.json", &tokenizer_cache)
233                        .map_err(|e| {
234                            InferenceError::HubError(format!("Failed to download tokenizer: {}", e))
235                        })?;
236                }
237                if !onnx_cache.join(&onnx_filename_owned).exists() {
238                    Self::download_hf_file(&onnx_repo_id_owned, &onnx_filename_owned, &onnx_cache)
239                        .map_err(|e| {
240                            InferenceError::HubError(format!(
241                                "Failed to download ONNX model: {}",
242                                e
243                            ))
244                        })?;
245                }
246                Ok::<_, InferenceError>(())
247            })
248            .await
249            .map_err(|e| InferenceError::HubError(format!("Download task panicked: {}", e)))??;
250        } else {
251            info!("All model files found in local cache");
252        }
253
254        let final_onnx = onnx_cache_dir.join(onnx_filename);
255        Ok((local_tokenizer, final_onnx))
256    }
257
258    /// Get or create the local model cache directory.
259    pub fn model_cache_dir(model_id: &str) -> Result<PathBuf> {
260        let base = std::env::var("HF_HOME")
261            .map(PathBuf::from)
262            .unwrap_or_else(|_| {
263                let home = std::env::var("HOME").unwrap_or_else(|_| {
264                    warn!("HOME environment variable not set, using /tmp for model cache");
265                    "/tmp".to_string()
266                });
267                PathBuf::from(home).join(".cache").join("huggingface")
268            });
269        let dir = base.join("dakera").join(model_id.replace('/', "--"));
270        std::fs::create_dir_all(&dir)?;
271        Ok(dir)
272    }
273
274    /// Download a single file from HuggingFace using ureq (sync, call inside spawn_blocking).
275    pub fn download_hf_file(
276        model_id: &str,
277        filename: &str,
278        cache_dir: &Path,
279    ) -> std::result::Result<PathBuf, String> {
280        let file_path = cache_dir.join(filename);
281        if file_path.exists() {
282            info!("Cached: {}/{}", model_id, filename);
283            return Ok(file_path);
284        }
285
286        if let Some(parent) = file_path.parent() {
287            std::fs::create_dir_all(parent)
288                .map_err(|e| format!("Failed to create directory {:?}: {}", parent, e))?;
289        }
290
291        let url = format!(
292            "https://huggingface.co/{}/resolve/main/{}",
293            model_id, filename
294        );
295        info!("Downloading: {}", url);
296
297        let hf_token = std::env::var("HF_TOKEN")
298            .or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
299            .ok();
300
301        let agent = ureq::AgentBuilder::new()
302            .redirects(0)
303            .timeout(std::time::Duration::from_secs(300))
304            .build();
305
306        let mut current_url = url;
307        let mut redirects = 0_u32;
308
309        let response = loop {
310            let mut req = agent.get(&current_url);
311            if let Some(ref token) = hf_token {
312                req = req.set("Authorization", &format!("Bearer {}", token));
313            }
314            let resp = req.call();
315
316            let r = match resp {
317                Ok(r) => r,
318                Err(ureq::Error::Status(_status, r)) => r,
319                Err(e) => return Err(format!("{}: {}", filename, e)),
320            };
321
322            let status = r.status();
323            if (200..300).contains(&status) {
324                break r;
325            } else if (300..400).contains(&status) {
326                redirects += 1;
327                if redirects > 10 {
328                    return Err(format!("{}: too many redirects", filename));
329                }
330                let location = r
331                    .header("location")
332                    .ok_or_else(|| format!("{}: redirect without Location header", filename))?
333                    .to_string();
334
335                current_url = if location.starts_with('/') {
336                    let parsed = url::Url::parse(&current_url)
337                        .map_err(|e| format!("{}: bad URL: {}", filename, e))?;
338                    let host = parsed
339                        .host_str()
340                        .ok_or_else(|| format!("{}: missing host", filename))?;
341                    format!("{}://{}{}", parsed.scheme(), host, location)
342                } else {
343                    location
344                };
345            } else {
346                return Err(format!("{}: HTTP {}", filename, status));
347            }
348        };
349
350        let expected_bytes: Option<u64> = response
351            .header("x-linked-size")
352            .or_else(|| response.header("content-length"))
353            .and_then(|v| v.parse::<u64>().ok());
354
355        let mut bytes = Vec::new();
356        response
357            .into_reader()
358            .take(2_147_483_648)
359            .read_to_end(&mut bytes)
360            .map_err(|e| format!("Failed to read {}: {}", filename, e))?;
361
362        if let Some(expected) = expected_bytes {
363            if (bytes.len() as u64) < expected {
364                return Err(format!(
365                    "{}: download incomplete — received {} of {} bytes",
366                    filename,
367                    bytes.len(),
368                    expected
369                ));
370            }
371        }
372
373        std::fs::write(&file_path, &bytes)
374            .map_err(|e| format!("Failed to write {}: {}", filename, e))?;
375
376        info!("Downloaded {} ({} bytes)", filename, bytes.len());
377        Ok(file_path)
378    }
379
380    /// Public alias for use by other inference modules (e.g. GLiNER NER engine).
381    pub fn download_hf_file_pub(
382        model_id: &str,
383        filename: &str,
384        cache_dir: &Path,
385    ) -> std::result::Result<PathBuf, String> {
386        Self::download_hf_file(model_id, filename, cache_dir)
387    }
388
389    // ── Internal embedding logic ───────────────────────────────────────────────
390
391    /// Internal batch embedding: split → distribute across pool → collect.
392    ///
393    /// On BFCArena / allocator OOM the batch is halved and retried until
394    /// `batch_size == 1`.  Starting from the default `max_batch_size=32` this
395    /// gives up to 5 halvings (32→16→8→4→2→1) before surfacing the error,
396    /// covering the case where even a 4-text batch exceeds available memory
397    /// under concurrent load (DAK-6145).
398    async fn embed_batch_internal(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
399        if texts.is_empty() {
400            return Ok(vec![]);
401        }
402
403        let pool_len = self.sessions.len();
404        let normalize = self.config.model.normalize_embeddings();
405        let start_idx = self.next_session.fetch_add(1, Ordering::Relaxed);
406        let mut batch_size = self.config.max_batch_size.max(1);
407
408        // DAK-6145: 5 halvings — 32→16→8→4→2→1 — before hard fail.
409        // Previous depth of 3 (stopping at batch=4) was insufficient: BGE-Large at
410        // batch=4 seq≈503 still needs 8MB for the first query/Add buffer, which the
411        // arena cannot provide when concurrent sessions have exhausted system RAM.
412        // batch=1 requires only ~2MB and reliably succeeds under memory pressure.
413        for attempt in 0_u32..=5 {
414            let batches: Vec<Vec<String>> = texts.chunks(batch_size).map(|b| b.to_vec()).collect();
415
416            let mut handles = Vec::with_capacity(batches.len());
417            for (i, batch_owned) in batches.into_iter().enumerate() {
418                let session = Arc::clone(&self.sessions[(start_idx + i) % pool_len]);
419                let processor = Arc::clone(&self.processor);
420                // GPU mode: pool_size=1 so all handles point to the same session. The
421                // parking_lot::Mutex serialises CUDA forward passes implicitly — no
422                // application-level semaphore needed (DAK-6134 deep fix).
423                handles.push(tokio::task::spawn_blocking(move || {
424                    let mut session_guard = session.lock();
425                    Self::process_batch_blocking(
426                        &batch_owned,
427                        &mut session_guard,
428                        &processor,
429                        normalize,
430                    )
431                }));
432            }
433
434            let mut all_embeddings: Vec<Vec<f32>> = Vec::with_capacity(texts.len());
435            let mut oom: Option<InferenceError> = None;
436
437            for handle in handles {
438                match handle.await {
439                    Err(panic_err) => {
440                        return Err(InferenceError::InferenceError(format!(
441                            "Inference task panicked: {panic_err}"
442                        )));
443                    }
444                    Ok(Err(e)) => {
445                        if attempt < 5 && Self::is_gpu_oom(&e) {
446                            oom = Some(e);
447                            break;
448                        }
449                        return Err(e);
450                    }
451                    Ok(Ok(batch_embs)) => {
452                        all_embeddings.extend(batch_embs);
453                    }
454                }
455            }
456
457            if let Some(_oom_err) = oom {
458                let next_batch = (batch_size / 2).max(1);
459                warn!(
460                    "ONNX allocator OOM (attempt {}/5) — retrying with batch_size {} → {}",
461                    attempt + 1,
462                    batch_size,
463                    next_batch,
464                );
465                batch_size = next_batch;
466                continue;
467            }
468
469            return Ok(all_embeddings);
470        }
471
472        Err(InferenceError::InferenceError(format!(
473            "ONNX inference failed: allocator OOM after 5 batch-halving attempts (batch_size={batch_size})"
474        )))
475    }
476
477    fn is_gpu_oom(err: &InferenceError) -> bool {
478        let msg = err.to_string();
479        msg.contains("BFCArena")
480            || msg.contains("Failed to allocate memory")
481            || msg.contains("CUDA_OUT_OF_MEMORY")
482            || msg.contains("CUDA out of memory")
483            || (msg.contains("allocate") && msg.contains("buffer of size"))
484    }
485
486    fn process_batch_blocking(
487        texts: &[String],
488        session: &mut Session,
489        processor: &BatchProcessor,
490        normalize: bool,
491    ) -> Result<Vec<Vec<f32>>> {
492        let prepared = processor.tokenize_batch(texts)?;
493        let batch_size = prepared.batch_size;
494        let seq_len = prepared.seq_len;
495        let attention_mask_flat = prepared.attention_mask.clone();
496
497        let input_ids_tensor =
498            Tensor::<i64>::from_array(([batch_size, seq_len], prepared.input_ids))
499                .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
500        let attention_mask_tensor =
501            Tensor::<i64>::from_array(([batch_size, seq_len], prepared.attention_mask))
502                .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
503        let token_type_ids_tensor =
504            Tensor::<i64>::from_array(([batch_size, seq_len], prepared.token_type_ids))
505                .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
506
507        let outputs = session
508            .run(inputs![
509                "input_ids" => input_ids_tensor,
510                "attention_mask" => attention_mask_tensor,
511                "token_type_ids" => token_type_ids_tensor
512            ])
513            .map_err(|e: ort::Error| InferenceError::InferenceError(e.to_string()))?;
514
515        let (ort_shape, lhs_slice) = outputs[0]
516            .try_extract_tensor::<f32>()
517            .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
518
519        if ort_shape.len() != 3 {
520            return Err(InferenceError::InferenceError(format!(
521                "Expected 3D last_hidden_state, got {} dims",
522                ort_shape.len()
523            )));
524        }
525        let hidden_size = ort_shape[2] as usize;
526
527        let mut embeddings = mean_pooling(
528            lhs_slice,
529            batch_size,
530            seq_len,
531            hidden_size,
532            &attention_mask_flat,
533        );
534
535        if normalize {
536            normalize_embeddings(&mut embeddings);
537        }
538
539        Ok(embeddings)
540    }
541}
542
543#[async_trait]
544impl EmbeddingBackend for OnnxBackend {
545    async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
546        self.embed_batch_internal(texts).await
547    }
548
549    fn dimension(&self) -> usize {
550        self.dimension
551    }
552
553    fn backend_kind(&self) -> BackendKind {
554        BackendKind::Onnx
555    }
556}
557
#[cfg(test)]
mod tests {
    use super::{resolve_pool_size, OnnxBackend};
    use crate::error::InferenceError;

    #[test]
    fn gpu_mode_always_pool_size_one() {
        assert_eq!(resolve_pool_size(true, 1), 1);
        assert_eq!(
            resolve_pool_size(true, 4),
            1,
            "GPU overrides configured pool_size=4 → 1"
        );
        assert_eq!(resolve_pool_size(true, 0), 1, "GPU overrides zero → 1");
    }

    #[test]
    fn cpu_mode_respects_configured_pool_size() {
        assert_eq!(resolve_pool_size(false, 4), 4);
        assert_eq!(resolve_pool_size(false, 1), 1);
        assert_eq!(
            resolve_pool_size(false, 0),
            1,
            "CPU clamps zero to minimum 1"
        );
    }

    // ── is_gpu_oom detection (DAK-6145) ─────────────────────────────────────

    /// Wrap `msg` in the error variant that ORT run failures are mapped to.
    fn oom_err(msg: &str) -> InferenceError {
        InferenceError::InferenceError(msg.to_string())
    }

    #[test]
    fn detects_bfcarena_oom() {
        let e = oom_err("Non-zero status code returned while running Add node. \
            Status Message: bfc_arena.cc:358 void *onnxruntime::BFCArena::\
            AllocateRawInternal(size_t, bool, Stream *) Failed to allocate memory \
            for requested buffer of size 8241152");
        assert!(OnnxBackend::is_gpu_oom(&e), "BFCArena OOM must be detected");
    }

    #[test]
    fn detects_cuda_out_of_memory() {
        let e = oom_err("CUDA_OUT_OF_MEMORY: out of memory on device 0");
        assert!(OnnxBackend::is_gpu_oom(&e));
    }

    #[test]
    fn detects_cuda_out_of_memory_phrase() {
        let e = oom_err("CUDA error: CUDA out of memory on device 0");
        assert!(OnnxBackend::is_gpu_oom(&e));
    }

    #[test]
    fn detects_failed_to_allocate_memory() {
        let e = oom_err("Failed to allocate memory for requested buffer of size 1234");
        assert!(OnnxBackend::is_gpu_oom(&e));
    }

    /// Matches only the `allocate` + `buffer of size` rule, none of the fixed markers.
    #[test]
    fn detects_allocate_buffer_pattern() {
        let e = oom_err("Could not allocate a buffer of size 4096 for output");
        assert!(OnnxBackend::is_gpu_oom(&e));
    }

    #[test]
    fn allocate_without_buffer_size_not_detected() {
        let e = oom_err("Unable to allocate thread pool");
        assert!(
            !OnnxBackend::is_gpu_oom(&e),
            "'allocate' alone must not trigger OOM retry"
        );
    }

    #[test]
    fn non_oom_error_not_detected() {
        let e = oom_err("Shape mismatch: expected [4, 512] got [4, 256]");
        assert!(!OnnxBackend::is_gpu_oom(&e), "shape error must not trigger OOM retry");
    }

    /// Verify that halving from max_batch_size=32 reaches batch_size=1 in ≤5 steps.
    #[test]
    fn batch_halving_reaches_one_in_five_steps() {
        let mut batch_size = 32_usize;
        let mut halvings = 0_u32;
        while batch_size > 1 {
            batch_size = (batch_size / 2).max(1);
            halvings += 1;
        }
        assert_eq!(batch_size, 1);
        assert!(halvings <= 5, "expected ≤5 halvings, got {halvings}");
    }
}