Skip to main content

inference/backend/
onnx.rs

1//! ONNX Runtime embedding backend.
2//!
3//! This is the production default backend.  It wraps the existing session-pool
4//! logic that was previously inline in `EmbeddingEngine`.  No functional change —
5//! pure extraction to satisfy the [`EmbeddingBackend`] trait.
6//!
7//! A pool of `N` independent ONNX sessions (`N = DAKERA_ONNX_POOL_SIZE`, default 4)
8//! serves concurrent callers via round-robin dispatch.  Each session runs inside a
9//! `tokio::task::spawn_blocking` call to avoid blocking the async executor.
10
11use crate::backend::{BackendKind, EmbeddingBackend};
12use crate::batch::{mean_pooling, normalize_embeddings, BatchProcessor};
13use crate::error::{InferenceError, Result};
14use crate::models::ModelConfig;
15use async_trait::async_trait;
16use ort::execution_providers::CUDAExecutionProvider;
17use ort::inputs;
18use ort::session::builder::GraphOptimizationLevel;
19use ort::session::Session;
20use ort::value::Tensor;
21use parking_lot::Mutex;
22use std::io::Read;
23use std::path::{Path, PathBuf};
24use std::sync::atomic::{AtomicUsize, Ordering};
25use std::sync::Arc;
26use tokenizers::Tokenizer;
27use tracing::{info, instrument, warn};
28
29/// ONNX Runtime embedding backend with a session pool for concurrent inference.
30pub struct OnnxBackend {
31    sessions: Vec<Arc<Mutex<Session>>>,
32    next_session: AtomicUsize,
33    processor: Arc<BatchProcessor>,
34    config: ModelConfig,
35    dimension: usize,
36}
37
38impl OnnxBackend {
39    /// Build a new `OnnxBackend` by downloading model files and building the session pool.
40    #[instrument(skip_all, fields(model = %config.model))]
41    pub async fn new(config: &ModelConfig) -> Result<Self> {
42        let config = config.clone();
43        let use_gpu = std::env::var("DAKERA_USE_GPU")
44            .map(|v| v == "1")
45            .unwrap_or(config.use_gpu);
46
47        if use_gpu {
48            info!("ONNX backend: CUDA execution provider enabled (DAKERA_USE_GPU=1)");
49        }
50        info!("Initialising ONNX backend: model={}", config.model);
51
52        let (tokenizer_path, onnx_path) = Self::download_model_files(&config, use_gpu).await?;
53
54        info!("Loading tokenizer from {:?}", tokenizer_path);
55        let tokenizer = Tokenizer::from_file(&tokenizer_path)
56            .map_err(|e| InferenceError::TokenizationError(e.to_string()))?;
57
58        let num_threads = config.num_threads.unwrap_or(4);
59        let pool_size = config.session_pool_size.max(1);
60        let onnx_path_clone = onnx_path.clone();
61
62        let sessions: Vec<Arc<Mutex<Session>>> =
63            tokio::task::spawn_blocking(move || -> Result<Vec<Arc<Mutex<Session>>>> {
64                (0..pool_size)
65                    .map(|_| {
66                        let builder = Session::builder()
67                            .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
68                            .with_optimization_level(GraphOptimizationLevel::Level3)
69                            .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
70                            .with_intra_threads(num_threads)
71                            .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
72
73                        let mut builder = if use_gpu {
74                            builder
75                                .with_execution_providers(
76                                    [CUDAExecutionProvider::default().build()],
77                                )
78                                .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
79                        } else {
80                            builder
81                        };
82
83                        let s = builder
84                            .commit_from_file(&onnx_path_clone)
85                            .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
86                        Ok(Arc::new(Mutex::new(s)))
87                    })
88                    .collect()
89            })
90            .await
91            .map_err(|e| {
92                InferenceError::ModelLoadError(format!("Session pool init panicked: {}", e))
93            })??;
94
95        let dimension = config.model.dimension();
96        let processor = Arc::new(BatchProcessor::new(
97            tokenizer,
98            config.model,
99            config.max_batch_size,
100        ));
101
102        info!(
103            "ONNX backend ready: model={}, dimension={}, threads={}, pool={}",
104            config.model, dimension, num_threads, pool_size
105        );
106
107        Ok(Self {
108            sessions,
109            next_session: AtomicUsize::new(0),
110            processor,
111            config,
112            dimension,
113        })
114    }
115
116    /// Number of ONNX sessions in the pool.
117    pub fn pool_size(&self) -> usize {
118        self.sessions.len()
119    }
120
121    // ── File download helpers (shared with CandleBackend) ──────────────────────
122
123    /// Resolve tokenizer and ONNX model files, downloading from HuggingFace if needed.
124    #[instrument(skip_all, fields(model = %config.model))]
125    pub async fn download_model_files(
126        config: &ModelConfig,
127        use_gpu: bool,
128    ) -> Result<(PathBuf, PathBuf)> {
129        let model_id = config.model.model_id();
130        let onnx_repo_id = config.model.onnx_repo_id();
131        let onnx_filename = if use_gpu {
132            config.model.onnx_filename_gpu()
133        } else {
134            config.model.onnx_filename()
135        };
136
137        info!(
138            "Resolving model files: tokenizer={}, onnx={}@{}",
139            model_id, onnx_filename, onnx_repo_id
140        );
141
142        let tokenizer_cache_dir = Self::model_cache_dir(model_id)?;
143        let onnx_cache_dir = Self::model_cache_dir(onnx_repo_id)?;
144
145        let onnx_subdir = onnx_cache_dir.join("onnx");
146        std::fs::create_dir_all(&onnx_subdir)?;
147
148        let local_tokenizer = tokenizer_cache_dir.join("tokenizer.json");
149        let onnx_basename = Path::new(onnx_filename)
150            .file_name()
151            .and_then(|s| s.to_str())
152            .unwrap_or("model_quantized.onnx");
153        let local_onnx = onnx_subdir.join(onnx_basename);
154
155        // GPU FP32 model truncation guard (DAK-5976)
156        if use_gpu && local_onnx.exists() {
157            let cached_size = local_onnx.metadata().map(|m| m.len()).unwrap_or(0);
158            if cached_size <= 500_000_000 {
159                warn!(
160                    "Cached GPU ONNX at {:?} is {} bytes (≤500 MB) — likely truncated. Deleting.",
161                    local_onnx, cached_size
162                );
163                let _ = std::fs::remove_file(&local_onnx);
164            }
165        }
166
167        if !local_tokenizer.exists() || !local_onnx.exists() {
168            let model_id_owned = model_id.to_string();
169            let onnx_repo_id_owned = onnx_repo_id.to_string();
170            let onnx_filename_owned = onnx_filename.to_string();
171            let tokenizer_cache = tokenizer_cache_dir.clone();
172            let onnx_cache = onnx_cache_dir.clone();
173
174            tokio::task::spawn_blocking(move || {
175                if !tokenizer_cache.join("tokenizer.json").exists() {
176                    Self::download_hf_file(&model_id_owned, "tokenizer.json", &tokenizer_cache)
177                        .map_err(|e| {
178                            InferenceError::HubError(format!("Failed to download tokenizer: {}", e))
179                        })?;
180                }
181                if !onnx_cache.join(&onnx_filename_owned).exists() {
182                    Self::download_hf_file(&onnx_repo_id_owned, &onnx_filename_owned, &onnx_cache)
183                        .map_err(|e| {
184                            InferenceError::HubError(format!(
185                                "Failed to download ONNX model: {}",
186                                e
187                            ))
188                        })?;
189                }
190                Ok::<_, InferenceError>(())
191            })
192            .await
193            .map_err(|e| InferenceError::HubError(format!("Download task panicked: {}", e)))??;
194        } else {
195            info!("All model files found in local cache");
196        }
197
198        let final_onnx = onnx_cache_dir.join(onnx_filename);
199        Ok((local_tokenizer, final_onnx))
200    }
201
202    /// Get or create the local model cache directory.
203    pub fn model_cache_dir(model_id: &str) -> Result<PathBuf> {
204        let base = std::env::var("HF_HOME")
205            .map(PathBuf::from)
206            .unwrap_or_else(|_| {
207                let home = std::env::var("HOME").unwrap_or_else(|_| {
208                    warn!("HOME environment variable not set, using /tmp for model cache");
209                    "/tmp".to_string()
210                });
211                PathBuf::from(home).join(".cache").join("huggingface")
212            });
213        let dir = base.join("dakera").join(model_id.replace('/', "--"));
214        std::fs::create_dir_all(&dir)?;
215        Ok(dir)
216    }
217
218    /// Download a single file from HuggingFace using ureq (sync, call inside spawn_blocking).
219    pub fn download_hf_file(
220        model_id: &str,
221        filename: &str,
222        cache_dir: &Path,
223    ) -> std::result::Result<PathBuf, String> {
224        let file_path = cache_dir.join(filename);
225        if file_path.exists() {
226            info!("Cached: {}/{}", model_id, filename);
227            return Ok(file_path);
228        }
229
230        if let Some(parent) = file_path.parent() {
231            std::fs::create_dir_all(parent)
232                .map_err(|e| format!("Failed to create directory {:?}: {}", parent, e))?;
233        }
234
235        let url = format!(
236            "https://huggingface.co/{}/resolve/main/{}",
237            model_id, filename
238        );
239        info!("Downloading: {}", url);
240
241        let hf_token = std::env::var("HF_TOKEN")
242            .or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
243            .ok();
244
245        let agent = ureq::AgentBuilder::new()
246            .redirects(0)
247            .timeout(std::time::Duration::from_secs(300))
248            .build();
249
250        let mut current_url = url;
251        let mut redirects = 0_u32;
252
253        let response = loop {
254            let mut req = agent.get(&current_url);
255            if let Some(ref token) = hf_token {
256                req = req.set("Authorization", &format!("Bearer {}", token));
257            }
258            let resp = req.call();
259
260            let r = match resp {
261                Ok(r) => r,
262                Err(ureq::Error::Status(_status, r)) => r,
263                Err(e) => return Err(format!("{}: {}", filename, e)),
264            };
265
266            let status = r.status();
267            if (200..300).contains(&status) {
268                break r;
269            } else if (300..400).contains(&status) {
270                redirects += 1;
271                if redirects > 10 {
272                    return Err(format!("{}: too many redirects", filename));
273                }
274                let location = r
275                    .header("location")
276                    .ok_or_else(|| format!("{}: redirect without Location header", filename))?
277                    .to_string();
278
279                current_url = if location.starts_with('/') {
280                    let parsed = url::Url::parse(&current_url)
281                        .map_err(|e| format!("{}: bad URL: {}", filename, e))?;
282                    let host = parsed
283                        .host_str()
284                        .ok_or_else(|| format!("{}: missing host", filename))?;
285                    format!("{}://{}{}", parsed.scheme(), host, location)
286                } else {
287                    location
288                };
289            } else {
290                return Err(format!("{}: HTTP {}", filename, status));
291            }
292        };
293
294        let expected_bytes: Option<u64> = response
295            .header("x-linked-size")
296            .or_else(|| response.header("content-length"))
297            .and_then(|v| v.parse::<u64>().ok());
298
299        let mut bytes = Vec::new();
300        response
301            .into_reader()
302            .take(2_147_483_648)
303            .read_to_end(&mut bytes)
304            .map_err(|e| format!("Failed to read {}: {}", filename, e))?;
305
306        if let Some(expected) = expected_bytes {
307            if (bytes.len() as u64) < expected {
308                return Err(format!(
309                    "{}: download incomplete — received {} of {} bytes",
310                    filename,
311                    bytes.len(),
312                    expected
313                ));
314            }
315        }
316
317        std::fs::write(&file_path, &bytes)
318            .map_err(|e| format!("Failed to write {}: {}", filename, e))?;
319
320        info!("Downloaded {} ({} bytes)", filename, bytes.len());
321        Ok(file_path)
322    }
323
324    /// Public alias for use by other inference modules (e.g. GLiNER NER engine).
325    pub fn download_hf_file_pub(
326        model_id: &str,
327        filename: &str,
328        cache_dir: &Path,
329    ) -> std::result::Result<PathBuf, String> {
330        Self::download_hf_file(model_id, filename, cache_dir)
331    }
332
333    // ── Internal embedding logic ───────────────────────────────────────────────
334
335    /// Internal batch embedding: split → distribute across pool → collect.
336    async fn embed_batch_internal(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
337        if texts.is_empty() {
338            return Ok(vec![]);
339        }
340
341        let pool_len = self.sessions.len();
342        let normalize = self.config.model.normalize_embeddings();
343        let start_idx = self.next_session.fetch_add(1, Ordering::Relaxed);
344        let mut batch_size = self.config.max_batch_size.max(1);
345
346        for attempt in 0_u32..=3 {
347            let batches: Vec<Vec<String>> = texts.chunks(batch_size).map(|b| b.to_vec()).collect();
348
349            let mut handles = Vec::with_capacity(batches.len());
350            for (i, batch_owned) in batches.into_iter().enumerate() {
351                let session = Arc::clone(&self.sessions[(start_idx + i) % pool_len]);
352                let processor = Arc::clone(&self.processor);
353                handles.push(tokio::task::spawn_blocking(move || {
354                    let mut session_guard = session.lock();
355                    Self::process_batch_blocking(
356                        &batch_owned,
357                        &mut session_guard,
358                        &processor,
359                        normalize,
360                    )
361                }));
362            }
363
364            let mut all_embeddings: Vec<Vec<f32>> = Vec::with_capacity(texts.len());
365            let mut oom: Option<InferenceError> = None;
366
367            for handle in handles {
368                match handle.await {
369                    Err(panic_err) => {
370                        return Err(InferenceError::InferenceError(format!(
371                            "Inference task panicked: {panic_err}"
372                        )));
373                    }
374                    Ok(Err(e)) => {
375                        if attempt < 3 && Self::is_gpu_oom(&e) {
376                            oom = Some(e);
377                            break;
378                        }
379                        return Err(e);
380                    }
381                    Ok(Ok(batch_embs)) => {
382                        all_embeddings.extend(batch_embs);
383                    }
384                }
385            }
386
387            if oom.is_some() {
388                let next_batch = (batch_size / 2).max(1);
389                warn!(
390                    "ONNX allocator OOM (attempt {}/3) — retrying with batch_size {} → {}",
391                    attempt + 1,
392                    batch_size,
393                    next_batch,
394                );
395                batch_size = next_batch;
396                continue;
397            }
398
399            return Ok(all_embeddings);
400        }
401
402        Err(InferenceError::InferenceError(format!(
403            "ONNX inference failed: allocator OOM after 3 batch-halving attempts (batch_size={batch_size})"
404        )))
405    }
406
407    fn is_gpu_oom(err: &InferenceError) -> bool {
408        let msg = err.to_string();
409        msg.contains("BFCArena")
410            || msg.contains("Failed to allocate memory")
411            || msg.contains("CUDA_OUT_OF_MEMORY")
412            || msg.contains("CUDA out of memory")
413            || (msg.contains("allocate") && msg.contains("buffer of size"))
414    }
415
416    fn process_batch_blocking(
417        texts: &[String],
418        session: &mut Session,
419        processor: &BatchProcessor,
420        normalize: bool,
421    ) -> Result<Vec<Vec<f32>>> {
422        let prepared = processor.tokenize_batch(texts)?;
423        let batch_size = prepared.batch_size;
424        let seq_len = prepared.seq_len;
425        let attention_mask_flat = prepared.attention_mask.clone();
426
427        let input_ids_tensor =
428            Tensor::<i64>::from_array(([batch_size, seq_len], prepared.input_ids))
429                .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
430        let attention_mask_tensor =
431            Tensor::<i64>::from_array(([batch_size, seq_len], prepared.attention_mask))
432                .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
433        let token_type_ids_tensor =
434            Tensor::<i64>::from_array(([batch_size, seq_len], prepared.token_type_ids))
435                .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
436
437        let outputs = session
438            .run(inputs![
439                "input_ids" => input_ids_tensor,
440                "attention_mask" => attention_mask_tensor,
441                "token_type_ids" => token_type_ids_tensor
442            ])
443            .map_err(|e: ort::Error| InferenceError::InferenceError(e.to_string()))?;
444
445        let (ort_shape, lhs_slice) = outputs[0]
446            .try_extract_tensor::<f32>()
447            .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
448
449        if ort_shape.len() != 3 {
450            return Err(InferenceError::InferenceError(format!(
451                "Expected 3D last_hidden_state, got {} dims",
452                ort_shape.len()
453            )));
454        }
455        let hidden_size = ort_shape[2] as usize;
456
457        let mut embeddings = mean_pooling(
458            lhs_slice,
459            batch_size,
460            seq_len,
461            hidden_size,
462            &attention_mask_flat,
463        );
464
465        if normalize {
466            normalize_embeddings(&mut embeddings);
467        }
468
469        Ok(embeddings)
470    }
471}
472
473#[async_trait]
474impl EmbeddingBackend for OnnxBackend {
475    async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
476        self.embed_batch_internal(texts).await
477    }
478
479    fn dimension(&self) -> usize {
480        self.dimension
481    }
482
483    fn backend_kind(&self) -> BackendKind {
484        BackendKind::Onnx
485    }
486}