// inference/backend/onnx.rs

//! ONNX Runtime embedding backend.
//!
//! This is the production default backend.  It wraps the existing session-pool
//! logic that was previously inline in `EmbeddingEngine`.  No functional change —
//! pure extraction to satisfy the [`EmbeddingBackend`] trait.
//!
//! A pool of `N` independent ONNX sessions (`N = DAKERA_ONNX_POOL_SIZE`, default 4)
//! serves concurrent callers via round-robin dispatch.  Each batch's forward pass runs
//! on its assigned session inside a `tokio::task::spawn_blocking` call, so inference
//! never blocks the async executor.

11use crate::backend::{BackendKind, EmbeddingBackend};
12use crate::batch::{mean_pooling, normalize_embeddings, BatchProcessor};
13use crate::error::{InferenceError, Result};
14use crate::models::ModelConfig;
15use async_trait::async_trait;
16use ort::execution_providers::CUDAExecutionProvider;
17use ort::inputs;
18use ort::session::builder::GraphOptimizationLevel;
19use ort::session::Session;
20use ort::value::Tensor;
21use parking_lot::Mutex;
22use std::io::Read;
23use std::path::{Path, PathBuf};
24use std::sync::atomic::{AtomicUsize, Ordering};
25use std::sync::Arc;
26use tokenizers::Tokenizer;
27use tracing::{info, instrument, warn};
28
29/// ONNX Runtime embedding backend with a session pool for concurrent inference.
30pub struct OnnxBackend {
31    sessions: Vec<Arc<Mutex<Session>>>,
32    next_session: AtomicUsize,
33    processor: Arc<BatchProcessor>,
34    config: ModelConfig,
35    dimension: usize,
36    /// Resolved once at construction from DAKERA_USE_GPU env var; controls GPU semaphore
37    /// acquisition in embed_batch_internal to prevent concurrent VRAM exhaustion (DAK-6129).
38    use_gpu: bool,
39}
40
41impl OnnxBackend {
42    /// Build a new `OnnxBackend` by downloading model files and building the session pool.
43    #[instrument(skip_all, fields(model = %config.model))]
44    pub async fn new(config: &ModelConfig) -> Result<Self> {
45        let config = config.clone();
46        let use_gpu = std::env::var("DAKERA_USE_GPU")
47            .map(|v| v == "1")
48            .unwrap_or(config.use_gpu);
49
50        if use_gpu {
51            info!("ONNX backend: CUDA execution provider enabled (DAKERA_USE_GPU=1)");
52        }
53        info!("Initialising ONNX backend: model={}", config.model);
54
55        let (tokenizer_path, onnx_path) = Self::download_model_files(&config, use_gpu).await?;
56
57        info!("Loading tokenizer from {:?}", tokenizer_path);
58        let tokenizer = Tokenizer::from_file(&tokenizer_path)
59            .map_err(|e| InferenceError::TokenizationError(e.to_string()))?;
60
61        let num_threads = config.num_threads.unwrap_or(4);
62        let pool_size = config.session_pool_size.max(1);
63        let onnx_path_clone = onnx_path.clone();
64
65        let sessions: Vec<Arc<Mutex<Session>>> =
66            tokio::task::spawn_blocking(move || -> Result<Vec<Arc<Mutex<Session>>>> {
67                (0..pool_size)
68                    .map(|_| {
69                        let builder = Session::builder()
70                            .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
71                            .with_optimization_level(GraphOptimizationLevel::Level3)
72                            .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
73                            .with_intra_threads(num_threads)
74                            .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
75
76                        let mut builder = if use_gpu {
77                            builder
78                                .with_execution_providers(
79                                    [CUDAExecutionProvider::default().build()],
80                                )
81                                .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
82                        } else {
83                            builder
84                        };
85
86                        let s = builder
87                            .commit_from_file(&onnx_path_clone)
88                            .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
89                        Ok(Arc::new(Mutex::new(s)))
90                    })
91                    .collect()
92            })
93            .await
94            .map_err(|e| {
95                InferenceError::ModelLoadError(format!("Session pool init panicked: {}", e))
96            })??;
97
98        let dimension = config.model.dimension();
99        let processor = Arc::new(BatchProcessor::new(
100            tokenizer,
101            config.model,
102            config.max_batch_size,
103        ));
104
105        info!(
106            "ONNX backend ready: model={}, dimension={}, threads={}, pool={}",
107            config.model, dimension, num_threads, pool_size
108        );
109
110        Ok(Self {
111            sessions,
112            next_session: AtomicUsize::new(0),
113            processor,
114            config,
115            dimension,
116            use_gpu,
117        })
118    }
119
120    /// Number of ONNX sessions in the pool.
121    pub fn pool_size(&self) -> usize {
122        self.sessions.len()
123    }
124
    // ── File download helpers (shared with CandleBackend) ──────────────────────
126
127    /// Resolve tokenizer and ONNX model files, downloading from HuggingFace if needed.
128    #[instrument(skip_all, fields(model = %config.model))]
129    pub async fn download_model_files(
130        config: &ModelConfig,
131        use_gpu: bool,
132    ) -> Result<(PathBuf, PathBuf)> {
133        let model_id = config.model.model_id();
134        let onnx_repo_id = config.model.onnx_repo_id();
135        let onnx_filename = if use_gpu {
136            config.model.onnx_filename_gpu()
137        } else {
138            config.model.onnx_filename()
139        };
140
141        info!(
142            "Resolving model files: tokenizer={}, onnx={}@{}",
143            model_id, onnx_filename, onnx_repo_id
144        );
145
146        let tokenizer_cache_dir = Self::model_cache_dir(model_id)?;
147        let onnx_cache_dir = Self::model_cache_dir(onnx_repo_id)?;
148
149        let onnx_subdir = onnx_cache_dir.join("onnx");
150        std::fs::create_dir_all(&onnx_subdir)?;
151
152        let local_tokenizer = tokenizer_cache_dir.join("tokenizer.json");
153        let onnx_basename = Path::new(onnx_filename)
154            .file_name()
155            .and_then(|s| s.to_str())
156            .unwrap_or("model_quantized.onnx");
157        let local_onnx = onnx_subdir.join(onnx_basename);
158
159        // GPU FP32 model truncation guard (DAK-5976)
160        if use_gpu && local_onnx.exists() {
161            let cached_size = local_onnx.metadata().map(|m| m.len()).unwrap_or(0);
162            if cached_size <= 500_000_000 {
163                warn!(
164                    "Cached GPU ONNX at {:?} is {} bytes (≤500 MB) — likely truncated. Deleting.",
165                    local_onnx, cached_size
166                );
167                let _ = std::fs::remove_file(&local_onnx);
168            }
169        }
170
171        if !local_tokenizer.exists() || !local_onnx.exists() {
172            let model_id_owned = model_id.to_string();
173            let onnx_repo_id_owned = onnx_repo_id.to_string();
174            let onnx_filename_owned = onnx_filename.to_string();
175            let tokenizer_cache = tokenizer_cache_dir.clone();
176            let onnx_cache = onnx_cache_dir.clone();
177
178            tokio::task::spawn_blocking(move || {
179                if !tokenizer_cache.join("tokenizer.json").exists() {
180                    Self::download_hf_file(&model_id_owned, "tokenizer.json", &tokenizer_cache)
181                        .map_err(|e| {
182                            InferenceError::HubError(format!("Failed to download tokenizer: {}", e))
183                        })?;
184                }
185                if !onnx_cache.join(&onnx_filename_owned).exists() {
186                    Self::download_hf_file(&onnx_repo_id_owned, &onnx_filename_owned, &onnx_cache)
187                        .map_err(|e| {
188                            InferenceError::HubError(format!(
189                                "Failed to download ONNX model: {}",
190                                e
191                            ))
192                        })?;
193                }
194                Ok::<_, InferenceError>(())
195            })
196            .await
197            .map_err(|e| InferenceError::HubError(format!("Download task panicked: {}", e)))??;
198        } else {
199            info!("All model files found in local cache");
200        }
201
202        let final_onnx = onnx_cache_dir.join(onnx_filename);
203        Ok((local_tokenizer, final_onnx))
204    }
205
206    /// Get or create the local model cache directory.
207    pub fn model_cache_dir(model_id: &str) -> Result<PathBuf> {
208        let base = std::env::var("HF_HOME")
209            .map(PathBuf::from)
210            .unwrap_or_else(|_| {
211                let home = std::env::var("HOME").unwrap_or_else(|_| {
212                    warn!("HOME environment variable not set, using /tmp for model cache");
213                    "/tmp".to_string()
214                });
215                PathBuf::from(home).join(".cache").join("huggingface")
216            });
217        let dir = base.join("dakera").join(model_id.replace('/', "--"));
218        std::fs::create_dir_all(&dir)?;
219        Ok(dir)
220    }
221
222    /// Download a single file from HuggingFace using ureq (sync, call inside spawn_blocking).
223    pub fn download_hf_file(
224        model_id: &str,
225        filename: &str,
226        cache_dir: &Path,
227    ) -> std::result::Result<PathBuf, String> {
228        let file_path = cache_dir.join(filename);
229        if file_path.exists() {
230            info!("Cached: {}/{}", model_id, filename);
231            return Ok(file_path);
232        }
233
234        if let Some(parent) = file_path.parent() {
235            std::fs::create_dir_all(parent)
236                .map_err(|e| format!("Failed to create directory {:?}: {}", parent, e))?;
237        }
238
239        let url = format!(
240            "https://huggingface.co/{}/resolve/main/{}",
241            model_id, filename
242        );
243        info!("Downloading: {}", url);
244
245        let hf_token = std::env::var("HF_TOKEN")
246            .or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
247            .ok();
248
249        let agent = ureq::AgentBuilder::new()
250            .redirects(0)
251            .timeout(std::time::Duration::from_secs(300))
252            .build();
253
254        let mut current_url = url;
255        let mut redirects = 0_u32;
256
257        let response = loop {
258            let mut req = agent.get(&current_url);
259            if let Some(ref token) = hf_token {
260                req = req.set("Authorization", &format!("Bearer {}", token));
261            }
262            let resp = req.call();
263
264            let r = match resp {
265                Ok(r) => r,
266                Err(ureq::Error::Status(_status, r)) => r,
267                Err(e) => return Err(format!("{}: {}", filename, e)),
268            };
269
270            let status = r.status();
271            if (200..300).contains(&status) {
272                break r;
273            } else if (300..400).contains(&status) {
274                redirects += 1;
275                if redirects > 10 {
276                    return Err(format!("{}: too many redirects", filename));
277                }
278                let location = r
279                    .header("location")
280                    .ok_or_else(|| format!("{}: redirect without Location header", filename))?
281                    .to_string();
282
283                current_url = if location.starts_with('/') {
284                    let parsed = url::Url::parse(&current_url)
285                        .map_err(|e| format!("{}: bad URL: {}", filename, e))?;
286                    let host = parsed
287                        .host_str()
288                        .ok_or_else(|| format!("{}: missing host", filename))?;
289                    format!("{}://{}{}", parsed.scheme(), host, location)
290                } else {
291                    location
292                };
293            } else {
294                return Err(format!("{}: HTTP {}", filename, status));
295            }
296        };
297
298        let expected_bytes: Option<u64> = response
299            .header("x-linked-size")
300            .or_else(|| response.header("content-length"))
301            .and_then(|v| v.parse::<u64>().ok());
302
303        let mut bytes = Vec::new();
304        response
305            .into_reader()
306            .take(2_147_483_648)
307            .read_to_end(&mut bytes)
308            .map_err(|e| format!("Failed to read {}: {}", filename, e))?;
309
310        if let Some(expected) = expected_bytes {
311            if (bytes.len() as u64) < expected {
312                return Err(format!(
313                    "{}: download incomplete — received {} of {} bytes",
314                    filename,
315                    bytes.len(),
316                    expected
317                ));
318            }
319        }
320
321        std::fs::write(&file_path, &bytes)
322            .map_err(|e| format!("Failed to write {}: {}", filename, e))?;
323
324        info!("Downloaded {} ({} bytes)", filename, bytes.len());
325        Ok(file_path)
326    }
327
328    /// Public alias for use by other inference modules (e.g. GLiNER NER engine).
329    pub fn download_hf_file_pub(
330        model_id: &str,
331        filename: &str,
332        cache_dir: &Path,
333    ) -> std::result::Result<PathBuf, String> {
334        Self::download_hf_file(model_id, filename, cache_dir)
335    }
336
    // ── Internal embedding logic ───────────────────────────────────────────────
338
339    /// Internal batch embedding: split → distribute across pool → collect.
340    async fn embed_batch_internal(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
341        if texts.is_empty() {
342            return Ok(vec![]);
343        }
344
345        let pool_len = self.sessions.len();
346        let normalize = self.config.model.normalize_embeddings();
347        let start_idx = self.next_session.fetch_add(1, Ordering::Relaxed);
348        let mut batch_size = self.config.max_batch_size.max(1);
349
350        for attempt in 0_u32..=3 {
351            let batches: Vec<Vec<String>> = texts.chunks(batch_size).map(|b| b.to_vec()).collect();
352
353            let mut handles = Vec::with_capacity(batches.len());
354            for (i, batch_owned) in batches.into_iter().enumerate() {
355                let session = Arc::clone(&self.sessions[(start_idx + i) % pool_len]);
356                let processor = Arc::clone(&self.processor);
357                // DAK-6129: acquire the global GPU semaphore (1 permit) before each CUDA forward
358                // pass to prevent concurrent VRAM allocations from exhausting device memory.
359                // The .await serializes GPU batches; the permit is moved into the closure and
360                // dropped when spawn_blocking completes — matching engine.rs / candle_backend.rs.
361                let gpu_permit = if self.use_gpu {
362                    Some(
363                        std::sync::Arc::clone(&crate::GPU_INFERENCE_SEMAPHORE)
364                            .acquire_owned()
365                            .await
366                            .map_err(|_| {
367                                InferenceError::InferenceError(
368                                    "GPU inference semaphore closed unexpectedly".to_string(),
369                                )
370                            })?,
371                    )
372                } else {
373                    None
374                };
375                handles.push(tokio::task::spawn_blocking(move || {
376                    let _gpu_permit = gpu_permit; // held until blocking task completes
377                    let mut session_guard = session.lock();
378                    Self::process_batch_blocking(
379                        &batch_owned,
380                        &mut session_guard,
381                        &processor,
382                        normalize,
383                    )
384                }));
385            }
386
387            let mut all_embeddings: Vec<Vec<f32>> = Vec::with_capacity(texts.len());
388            let mut oom: Option<InferenceError> = None;
389
390            for handle in handles {
391                match handle.await {
392                    Err(panic_err) => {
393                        return Err(InferenceError::InferenceError(format!(
394                            "Inference task panicked: {panic_err}"
395                        )));
396                    }
397                    Ok(Err(e)) => {
398                        if attempt < 3 && Self::is_gpu_oom(&e) {
399                            oom = Some(e);
400                            break;
401                        }
402                        return Err(e);
403                    }
404                    Ok(Ok(batch_embs)) => {
405                        all_embeddings.extend(batch_embs);
406                    }
407                }
408            }
409
410            if oom.is_some() {
411                let next_batch = (batch_size / 2).max(1);
412                warn!(
413                    "ONNX allocator OOM (attempt {}/3) — retrying with batch_size {} → {}",
414                    attempt + 1,
415                    batch_size,
416                    next_batch,
417                );
418                batch_size = next_batch;
419                continue;
420            }
421
422            return Ok(all_embeddings);
423        }
424
425        Err(InferenceError::InferenceError(format!(
426            "ONNX inference failed: allocator OOM after 3 batch-halving attempts (batch_size={batch_size})"
427        )))
428    }
429
430    fn is_gpu_oom(err: &InferenceError) -> bool {
431        let msg = err.to_string();
432        msg.contains("BFCArena")
433            || msg.contains("Failed to allocate memory")
434            || msg.contains("CUDA_OUT_OF_MEMORY")
435            || msg.contains("CUDA out of memory")
436            || (msg.contains("allocate") && msg.contains("buffer of size"))
437    }
438
439    fn process_batch_blocking(
440        texts: &[String],
441        session: &mut Session,
442        processor: &BatchProcessor,
443        normalize: bool,
444    ) -> Result<Vec<Vec<f32>>> {
445        let prepared = processor.tokenize_batch(texts)?;
446        let batch_size = prepared.batch_size;
447        let seq_len = prepared.seq_len;
448        let attention_mask_flat = prepared.attention_mask.clone();
449
450        let input_ids_tensor =
451            Tensor::<i64>::from_array(([batch_size, seq_len], prepared.input_ids))
452                .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
453        let attention_mask_tensor =
454            Tensor::<i64>::from_array(([batch_size, seq_len], prepared.attention_mask))
455                .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
456        let token_type_ids_tensor =
457            Tensor::<i64>::from_array(([batch_size, seq_len], prepared.token_type_ids))
458                .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
459
460        let outputs = session
461            .run(inputs![
462                "input_ids" => input_ids_tensor,
463                "attention_mask" => attention_mask_tensor,
464                "token_type_ids" => token_type_ids_tensor
465            ])
466            .map_err(|e: ort::Error| InferenceError::InferenceError(e.to_string()))?;
467
468        let (ort_shape, lhs_slice) = outputs[0]
469            .try_extract_tensor::<f32>()
470            .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
471
472        if ort_shape.len() != 3 {
473            return Err(InferenceError::InferenceError(format!(
474                "Expected 3D last_hidden_state, got {} dims",
475                ort_shape.len()
476            )));
477        }
478        let hidden_size = ort_shape[2] as usize;
479
480        let mut embeddings = mean_pooling(
481            lhs_slice,
482            batch_size,
483            seq_len,
484            hidden_size,
485            &attention_mask_flat,
486        );
487
488        if normalize {
489            normalize_embeddings(&mut embeddings);
490        }
491
492        Ok(embeddings)
493    }
494}
495
496#[async_trait]
497impl EmbeddingBackend for OnnxBackend {
498    async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
499        self.embed_batch_internal(texts).await
500    }
501
502    fn dimension(&self) -> usize {
503        self.dimension
504    }
505
506    fn backend_kind(&self) -> BackendKind {
507        BackendKind::Onnx
508    }
509}