1use crate::backend::{BackendKind, EmbeddingBackend};
12use crate::batch::{mean_pooling, normalize_embeddings, BatchProcessor};
13use crate::error::{InferenceError, Result};
14use crate::models::ModelConfig;
15use async_trait::async_trait;
16use ort::execution_providers::CUDAExecutionProvider;
17use ort::inputs;
18use ort::session::builder::GraphOptimizationLevel;
19use ort::session::Session;
20use ort::value::Tensor;
21use parking_lot::Mutex;
22use std::io::Read;
23use std::path::{Path, PathBuf};
24use std::sync::atomic::{AtomicUsize, Ordering};
25use std::sync::Arc;
26use tokenizers::Tokenizer;
27use tracing::{info, instrument, warn};
28
29pub struct OnnxBackend {
31 sessions: Vec<Arc<Mutex<Session>>>,
32 next_session: AtomicUsize,
33 processor: Arc<BatchProcessor>,
34 config: ModelConfig,
35 dimension: usize,
36}
37
38impl OnnxBackend {
39 #[instrument(skip_all, fields(model = %config.model))]
41 pub async fn new(config: &ModelConfig) -> Result<Self> {
42 let config = config.clone();
43 let use_gpu = std::env::var("DAKERA_USE_GPU")
44 .map(|v| v == "1")
45 .unwrap_or(config.use_gpu);
46
47 if use_gpu {
48 info!("ONNX backend: CUDA execution provider enabled (DAKERA_USE_GPU=1)");
49 }
50 info!("Initialising ONNX backend: model={}", config.model);
51
52 let (tokenizer_path, onnx_path) = Self::download_model_files(&config, use_gpu).await?;
53
54 info!("Loading tokenizer from {:?}", tokenizer_path);
55 let tokenizer = Tokenizer::from_file(&tokenizer_path)
56 .map_err(|e| InferenceError::TokenizationError(e.to_string()))?;
57
58 let num_threads = config.num_threads.unwrap_or(4);
59 let pool_size = config.session_pool_size.max(1);
60 let onnx_path_clone = onnx_path.clone();
61
62 let sessions: Vec<Arc<Mutex<Session>>> =
63 tokio::task::spawn_blocking(move || -> Result<Vec<Arc<Mutex<Session>>>> {
64 (0..pool_size)
65 .map(|_| {
66 let builder = Session::builder()
67 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
68 .with_optimization_level(GraphOptimizationLevel::Level3)
69 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
70 .with_intra_threads(num_threads)
71 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
72
73 let mut builder = if use_gpu {
74 builder
75 .with_execution_providers(
76 [CUDAExecutionProvider::default().build()],
77 )
78 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
79 } else {
80 builder
81 };
82
83 let s = builder
84 .commit_from_file(&onnx_path_clone)
85 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
86 Ok(Arc::new(Mutex::new(s)))
87 })
88 .collect()
89 })
90 .await
91 .map_err(|e| {
92 InferenceError::ModelLoadError(format!("Session pool init panicked: {}", e))
93 })??;
94
95 let dimension = config.model.dimension();
96 let processor = Arc::new(BatchProcessor::new(
97 tokenizer,
98 config.model,
99 config.max_batch_size,
100 ));
101
102 info!(
103 "ONNX backend ready: model={}, dimension={}, threads={}, pool={}",
104 config.model, dimension, num_threads, pool_size
105 );
106
107 Ok(Self {
108 sessions,
109 next_session: AtomicUsize::new(0),
110 processor,
111 config,
112 dimension,
113 })
114 }
115
116 pub fn pool_size(&self) -> usize {
118 self.sessions.len()
119 }
120
121 #[instrument(skip_all, fields(model = %config.model))]
125 pub async fn download_model_files(
126 config: &ModelConfig,
127 use_gpu: bool,
128 ) -> Result<(PathBuf, PathBuf)> {
129 let model_id = config.model.model_id();
130 let onnx_repo_id = config.model.onnx_repo_id();
131 let onnx_filename = if use_gpu {
132 config.model.onnx_filename_gpu()
133 } else {
134 config.model.onnx_filename()
135 };
136
137 info!(
138 "Resolving model files: tokenizer={}, onnx={}@{}",
139 model_id, onnx_filename, onnx_repo_id
140 );
141
142 let tokenizer_cache_dir = Self::model_cache_dir(model_id)?;
143 let onnx_cache_dir = Self::model_cache_dir(onnx_repo_id)?;
144
145 let onnx_subdir = onnx_cache_dir.join("onnx");
146 std::fs::create_dir_all(&onnx_subdir)?;
147
148 let local_tokenizer = tokenizer_cache_dir.join("tokenizer.json");
149 let onnx_basename = Path::new(onnx_filename)
150 .file_name()
151 .and_then(|s| s.to_str())
152 .unwrap_or("model_quantized.onnx");
153 let local_onnx = onnx_subdir.join(onnx_basename);
154
155 if use_gpu && local_onnx.exists() {
157 let cached_size = local_onnx.metadata().map(|m| m.len()).unwrap_or(0);
158 if cached_size <= 500_000_000 {
159 warn!(
160 "Cached GPU ONNX at {:?} is {} bytes (≤500 MB) — likely truncated. Deleting.",
161 local_onnx, cached_size
162 );
163 let _ = std::fs::remove_file(&local_onnx);
164 }
165 }
166
167 if !local_tokenizer.exists() || !local_onnx.exists() {
168 let model_id_owned = model_id.to_string();
169 let onnx_repo_id_owned = onnx_repo_id.to_string();
170 let onnx_filename_owned = onnx_filename.to_string();
171 let tokenizer_cache = tokenizer_cache_dir.clone();
172 let onnx_cache = onnx_cache_dir.clone();
173
174 tokio::task::spawn_blocking(move || {
175 if !tokenizer_cache.join("tokenizer.json").exists() {
176 Self::download_hf_file(&model_id_owned, "tokenizer.json", &tokenizer_cache)
177 .map_err(|e| {
178 InferenceError::HubError(format!("Failed to download tokenizer: {}", e))
179 })?;
180 }
181 if !onnx_cache.join(&onnx_filename_owned).exists() {
182 Self::download_hf_file(&onnx_repo_id_owned, &onnx_filename_owned, &onnx_cache)
183 .map_err(|e| {
184 InferenceError::HubError(format!(
185 "Failed to download ONNX model: {}",
186 e
187 ))
188 })?;
189 }
190 Ok::<_, InferenceError>(())
191 })
192 .await
193 .map_err(|e| InferenceError::HubError(format!("Download task panicked: {}", e)))??;
194 } else {
195 info!("All model files found in local cache");
196 }
197
198 let final_onnx = onnx_cache_dir.join(onnx_filename);
199 Ok((local_tokenizer, final_onnx))
200 }
201
202 pub fn model_cache_dir(model_id: &str) -> Result<PathBuf> {
204 let base = std::env::var("HF_HOME")
205 .map(PathBuf::from)
206 .unwrap_or_else(|_| {
207 let home = std::env::var("HOME").unwrap_or_else(|_| {
208 warn!("HOME environment variable not set, using /tmp for model cache");
209 "/tmp".to_string()
210 });
211 PathBuf::from(home).join(".cache").join("huggingface")
212 });
213 let dir = base.join("dakera").join(model_id.replace('/', "--"));
214 std::fs::create_dir_all(&dir)?;
215 Ok(dir)
216 }
217
218 pub fn download_hf_file(
220 model_id: &str,
221 filename: &str,
222 cache_dir: &Path,
223 ) -> std::result::Result<PathBuf, String> {
224 let file_path = cache_dir.join(filename);
225 if file_path.exists() {
226 info!("Cached: {}/{}", model_id, filename);
227 return Ok(file_path);
228 }
229
230 if let Some(parent) = file_path.parent() {
231 std::fs::create_dir_all(parent)
232 .map_err(|e| format!("Failed to create directory {:?}: {}", parent, e))?;
233 }
234
235 let url = format!(
236 "https://huggingface.co/{}/resolve/main/{}",
237 model_id, filename
238 );
239 info!("Downloading: {}", url);
240
241 let hf_token = std::env::var("HF_TOKEN")
242 .or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
243 .ok();
244
245 let agent = ureq::AgentBuilder::new()
246 .redirects(0)
247 .timeout(std::time::Duration::from_secs(300))
248 .build();
249
250 let mut current_url = url;
251 let mut redirects = 0_u32;
252
253 let response = loop {
254 let mut req = agent.get(¤t_url);
255 if let Some(ref token) = hf_token {
256 req = req.set("Authorization", &format!("Bearer {}", token));
257 }
258 let resp = req.call();
259
260 let r = match resp {
261 Ok(r) => r,
262 Err(ureq::Error::Status(_status, r)) => r,
263 Err(e) => return Err(format!("{}: {}", filename, e)),
264 };
265
266 let status = r.status();
267 if (200..300).contains(&status) {
268 break r;
269 } else if (300..400).contains(&status) {
270 redirects += 1;
271 if redirects > 10 {
272 return Err(format!("{}: too many redirects", filename));
273 }
274 let location = r
275 .header("location")
276 .ok_or_else(|| format!("{}: redirect without Location header", filename))?
277 .to_string();
278
279 current_url = if location.starts_with('/') {
280 let parsed = url::Url::parse(¤t_url)
281 .map_err(|e| format!("{}: bad URL: {}", filename, e))?;
282 let host = parsed
283 .host_str()
284 .ok_or_else(|| format!("{}: missing host", filename))?;
285 format!("{}://{}{}", parsed.scheme(), host, location)
286 } else {
287 location
288 };
289 } else {
290 return Err(format!("{}: HTTP {}", filename, status));
291 }
292 };
293
294 let expected_bytes: Option<u64> = response
295 .header("x-linked-size")
296 .or_else(|| response.header("content-length"))
297 .and_then(|v| v.parse::<u64>().ok());
298
299 let mut bytes = Vec::new();
300 response
301 .into_reader()
302 .take(2_147_483_648)
303 .read_to_end(&mut bytes)
304 .map_err(|e| format!("Failed to read {}: {}", filename, e))?;
305
306 if let Some(expected) = expected_bytes {
307 if (bytes.len() as u64) < expected {
308 return Err(format!(
309 "{}: download incomplete — received {} of {} bytes",
310 filename,
311 bytes.len(),
312 expected
313 ));
314 }
315 }
316
317 std::fs::write(&file_path, &bytes)
318 .map_err(|e| format!("Failed to write {}: {}", filename, e))?;
319
320 info!("Downloaded {} ({} bytes)", filename, bytes.len());
321 Ok(file_path)
322 }
323
324 pub fn download_hf_file_pub(
326 model_id: &str,
327 filename: &str,
328 cache_dir: &Path,
329 ) -> std::result::Result<PathBuf, String> {
330 Self::download_hf_file(model_id, filename, cache_dir)
331 }
332
333 async fn embed_batch_internal(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
337 if texts.is_empty() {
338 return Ok(vec![]);
339 }
340
341 let pool_len = self.sessions.len();
342 let normalize = self.config.model.normalize_embeddings();
343 let start_idx = self.next_session.fetch_add(1, Ordering::Relaxed);
344 let mut batch_size = self.config.max_batch_size.max(1);
345
346 for attempt in 0_u32..=3 {
347 let batches: Vec<Vec<String>> = texts.chunks(batch_size).map(|b| b.to_vec()).collect();
348
349 let mut handles = Vec::with_capacity(batches.len());
350 for (i, batch_owned) in batches.into_iter().enumerate() {
351 let session = Arc::clone(&self.sessions[(start_idx + i) % pool_len]);
352 let processor = Arc::clone(&self.processor);
353 handles.push(tokio::task::spawn_blocking(move || {
354 let mut session_guard = session.lock();
355 Self::process_batch_blocking(
356 &batch_owned,
357 &mut session_guard,
358 &processor,
359 normalize,
360 )
361 }));
362 }
363
364 let mut all_embeddings: Vec<Vec<f32>> = Vec::with_capacity(texts.len());
365 let mut oom: Option<InferenceError> = None;
366
367 for handle in handles {
368 match handle.await {
369 Err(panic_err) => {
370 return Err(InferenceError::InferenceError(format!(
371 "Inference task panicked: {panic_err}"
372 )));
373 }
374 Ok(Err(e)) => {
375 if attempt < 3 && Self::is_gpu_oom(&e) {
376 oom = Some(e);
377 break;
378 }
379 return Err(e);
380 }
381 Ok(Ok(batch_embs)) => {
382 all_embeddings.extend(batch_embs);
383 }
384 }
385 }
386
387 if oom.is_some() {
388 let next_batch = (batch_size / 2).max(1);
389 warn!(
390 "ONNX allocator OOM (attempt {}/3) — retrying with batch_size {} → {}",
391 attempt + 1,
392 batch_size,
393 next_batch,
394 );
395 batch_size = next_batch;
396 continue;
397 }
398
399 return Ok(all_embeddings);
400 }
401
402 Err(InferenceError::InferenceError(format!(
403 "ONNX inference failed: allocator OOM after 3 batch-halving attempts (batch_size={batch_size})"
404 )))
405 }
406
407 fn is_gpu_oom(err: &InferenceError) -> bool {
408 let msg = err.to_string();
409 msg.contains("BFCArena")
410 || msg.contains("Failed to allocate memory")
411 || msg.contains("CUDA_OUT_OF_MEMORY")
412 || msg.contains("CUDA out of memory")
413 || (msg.contains("allocate") && msg.contains("buffer of size"))
414 }
415
416 fn process_batch_blocking(
417 texts: &[String],
418 session: &mut Session,
419 processor: &BatchProcessor,
420 normalize: bool,
421 ) -> Result<Vec<Vec<f32>>> {
422 let prepared = processor.tokenize_batch(texts)?;
423 let batch_size = prepared.batch_size;
424 let seq_len = prepared.seq_len;
425 let attention_mask_flat = prepared.attention_mask.clone();
426
427 let input_ids_tensor =
428 Tensor::<i64>::from_array(([batch_size, seq_len], prepared.input_ids))
429 .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
430 let attention_mask_tensor =
431 Tensor::<i64>::from_array(([batch_size, seq_len], prepared.attention_mask))
432 .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
433 let token_type_ids_tensor =
434 Tensor::<i64>::from_array(([batch_size, seq_len], prepared.token_type_ids))
435 .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
436
437 let outputs = session
438 .run(inputs![
439 "input_ids" => input_ids_tensor,
440 "attention_mask" => attention_mask_tensor,
441 "token_type_ids" => token_type_ids_tensor
442 ])
443 .map_err(|e: ort::Error| InferenceError::InferenceError(e.to_string()))?;
444
445 let (ort_shape, lhs_slice) = outputs[0]
446 .try_extract_tensor::<f32>()
447 .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
448
449 if ort_shape.len() != 3 {
450 return Err(InferenceError::InferenceError(format!(
451 "Expected 3D last_hidden_state, got {} dims",
452 ort_shape.len()
453 )));
454 }
455 let hidden_size = ort_shape[2] as usize;
456
457 let mut embeddings = mean_pooling(
458 lhs_slice,
459 batch_size,
460 seq_len,
461 hidden_size,
462 &attention_mask_flat,
463 );
464
465 if normalize {
466 normalize_embeddings(&mut embeddings);
467 }
468
469 Ok(embeddings)
470 }
471}
472
473#[async_trait]
474impl EmbeddingBackend for OnnxBackend {
475 async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
476 self.embed_batch_internal(texts).await
477 }
478
479 fn dimension(&self) -> usize {
480 self.dimension
481 }
482
483 fn backend_kind(&self) -> BackendKind {
484 BackendKind::Onnx
485 }
486}