1use async_trait::async_trait;
2use half::f16;
3use kapsl_engine_api::{
4 BinaryTensorPacket, Engine, EngineError, EngineMetrics, EngineModelInfo, InferenceRequest,
5 TensorDtype,
6};
7use ndarray::ArrayD;
8use ort::execution_providers::ExecutionProvider as OrtExecutionProvider;
9use ort::session::builder::GraphOptimizationLevel;
10use ort::session::{Session, SessionInputValue};
11use ort::tensor::TensorElementType;
12use ort::value::Value;
13use std::borrow::Cow;
14use std::collections::{HashMap, VecDeque};
15use std::path::{Path, PathBuf};
16use std::sync::atomic::{AtomicBool, Ordering};
17use std::sync::{Arc, OnceLock, RwLock};
18use std::time::Instant;
19
/// Hardware backend that ONNX Runtime should run the model on.
///
/// Session creation checks that the selected provider is available and fails with an
/// error when it is not (see `OnnxBackend::create_session`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ExecutionProvider {
    /// ONNX Runtime's default CPU provider (no extra provider is registered).
    CPU,
    /// NVIDIA CUDA on the configured device id.
    CUDA,
    /// NVIDIA TensorRT on the configured device id, with CUDA registered after it.
    TensorRT,
    /// DirectML on the configured device id; only supported on Windows.
    DirectML,
    /// AMD ROCm on the configured device id.
    ROCm,
    /// Intel OpenVINO; the configured device id is not forwarded to the provider.
    OpenVINO,
    /// Apple CoreML, with the CPU provider registered after it; the device id is not
    /// forwarded to the provider.
    CoreML,
}
31
/// Input/output signature of a loaded model.
///
/// All `input_*` vectors are parallel (index `i` describes model input `i`), and the
/// same holds for the `output_*` vectors.
#[derive(Debug, Clone)]
pub struct ModelMetadata {
    /// Model input names in session order.
    pub input_names: Vec<String>,
    /// Model output names in session order.
    pub output_names: Vec<String>,
    /// Declared input shapes; empty for non-tensor inputs. Non-positive entries are
    /// treated as unknown/symbolic dims by the inference path.
    pub input_shapes: Vec<Vec<i64>>,
    /// Declared output shapes; empty for non-tensor outputs (presumably, mirroring the
    /// input handling — confirm in `load`).
    pub output_shapes: Vec<Vec<i64>>,
    /// Declared input element types; `None` for non-tensor inputs or element types the
    /// engine does not model.
    pub input_dtypes: Vec<Option<TensorDtype>>,
    /// Declared output element types; `None` when not representable as `TensorDtype`.
    pub output_dtypes: Vec<Option<TensorDtype>>,
}
41
/// ONNX Runtime implementation of the engine's `Engine` trait.
///
/// Holds the primary session and an LRU-bounded set of extra sessions keyed by
/// input-shape bucket. It also keeps the settings used to (re)create those sessions.
pub struct OnnxBackend {
    /// Primary session; `None` until a model has been loaded (presumably set by `load`).
    session: Arc<RwLock<Option<Session>>>,
    /// Extra per-shape-bucket sessions and their least-recently-used order.
    bucket_sessions: Arc<RwLock<BucketSessionState>>,
    /// Path of the loaded model, reused to build bucket sessions on demand.
    model_path: Arc<RwLock<Option<PathBuf>>>,
    /// Execution provider every session is created with.
    provider: ExecutionProvider,
    /// Graph optimization level encoded as 0..=4 (Disable, Level1, Level2, Level3, All).
    optimization_level: u8,
    /// Device index for GPU-style providers (validated non-negative by the builder).
    device_id: i32,
    /// Whether ORT's memory-pattern optimization is enabled for new sessions.
    memory_pattern: bool,
    /// Whether session creation should request that the CPU memory arena be disabled.
    disable_cpu_mem_arena: bool,
    /// Session budget; shape bucketing is disabled when this is 1.
    max_bucket_sessions: usize,
    /// Rounding step applied to non-batch dimensions when forming bucket keys.
    bucket_dim_granularity: usize,
    /// Number of leading dimensions that participate in bucket keys.
    bucket_max_dims: usize,
    /// Optional expected peak concurrency. NOTE(review): only logged in the code visible
    /// here; confirm where (if anywhere) it is consumed.
    peak_concurrency_hint: Option<u32>,
    /// Engine metrics. NOTE(review): updated outside the code visible here.
    metrics: Arc<RwLock<EngineMetrics>>,
    /// Model signature captured when the model is loaded (presumably by `load`).
    metadata: Arc<RwLock<Option<ModelMetadata>>>,
    /// Presumably set once a warm-up run has completed — confirm with the callers.
    warmed_up: Arc<AtomicBool>,
}
59
/// Bookkeeping for the secondary, shape-bucketed sessions.
#[derive(Default)]
struct BucketSessionState {
    /// Bucket key associated with the primary session. NOTE(review): not read in the
    /// code visible here; confirm its use elsewhere.
    primary_bucket_key: Option<String>,
    /// Secondary sessions keyed by bucket key.
    sessions: HashMap<String, Session>,
    /// Bucket keys ordered from least (front) to most (back) recently used; eviction
    /// pops from the front.
    lru: VecDeque<String>,
}
66
/// Env var toggling ORT's memory-pattern optimization (builder default: enabled).
const ORT_MEMORY_PATTERN_ENV: &str = "KAPSL_ORT_MEMORY_PATTERN";
/// Env var that, when truthy, asks session creation to disable the CPU memory arena
/// (builder default: not disabled).
const ORT_DISABLE_CPU_MEM_ARENA_ENV: &str = "KAPSL_ORT_DISABLE_CPU_MEM_ARENA";
/// Env var for the session budget used by shape bucketing (builder default 4, clamped
/// to `1..=ORT_SESSION_BUCKETS_MAX`; 1 disables bucketing).
const ORT_SESSION_BUCKETS_ENV: &str = "KAPSL_ORT_SESSION_BUCKETS";
/// Env var for the rounding step of bucketed dimensions (builder default 64, min 1).
const ORT_BUCKET_DIM_GRANULARITY_ENV: &str = "KAPSL_ORT_BUCKET_DIM_GRANULARITY";
/// Env var for how many leading dimensions form the bucket key (builder default 4, min 1).
const ORT_BUCKET_MAX_DIMS_ENV: &str = "KAPSL_ORT_BUCKET_MAX_DIMS";
/// Env var for an optional peak-concurrency hint; zero or unparsable values are ignored.
const MODEL_PEAK_CONCURRENCY_ENV: &str = "KAPSL_MODEL_PEAK_CONCURRENCY";
/// Upper bound applied to the configured session-bucket budget.
const ORT_SESSION_BUCKETS_MAX: usize = 64;
74
/// Reads a boolean switch from the environment variable `name`.
///
/// `1`/`true`/`yes`/`on` mean `true` and `0`/`false`/`no`/`off` mean `false`. Matching
/// is case-insensitive and ignores surrounding whitespace. Any other value, an unset
/// variable, or a non-Unicode value yields `default`.
fn read_env_flag(name: &str, default: bool) -> bool {
    let Ok(raw) = std::env::var(name) else {
        return default;
    };
    let normalized = raw.trim().to_ascii_lowercase();
    match normalized.as_str() {
        "1" | "true" | "yes" | "on" => true,
        "0" | "false" | "no" | "off" => false,
        _ => default,
    }
}
85
/// Parses the environment variable `name` as a `usize`, ignoring surrounding
/// whitespace. Returns `None` when it is unset, non-Unicode, or not a valid number.
fn read_env_usize(name: &str) -> Option<usize> {
    let raw = std::env::var(name).ok()?;
    raw.trim().parse().ok()
}
91
/// Parses the environment variable `name` as a strictly positive `u32`, ignoring
/// surrounding whitespace. Unset, non-Unicode, unparsable, or zero values give `None`.
fn read_env_u32(name: &str) -> Option<u32> {
    let raw = std::env::var(name).ok()?;
    match raw.trim().parse::<u32>() {
        Ok(parsed) if parsed > 0 => Some(parsed),
        _ => None,
    }
}
98
/// Whether sensitive identifiers (such as session ids) may appear verbatim in logs.
///
/// Reads `KAPSL_LOG_SENSITIVE_IDS` once per process and caches the answer. The values
/// `1`, `true`, `yes` and `on` (case-insensitive, surrounding whitespace ignored) enable
/// it. Anything else, including an unset variable, disables it.
fn expose_sensitive_ids_in_logs() -> bool {
    static CACHE: OnceLock<bool> = OnceLock::new();
    *CACHE.get_or_init(|| {
        // A single lookup suffices. NOTE(review): a former fallback re-read this exact
        // same name; if a legacy variable name was intended, add it explicitly here.
        std::env::var("KAPSL_LOG_SENSITIVE_IDS")
            .map(|value| {
                matches!(
                    value.trim().to_ascii_lowercase().as_str(),
                    "1" | "true" | "yes" | "on"
                )
            })
            .unwrap_or(false)
    })
}
114
115fn redact_session_id_for_log(session_id: &str) -> String {
116 if expose_sensitive_ids_in_logs() || session_id.is_empty() {
117 return session_id.to_string();
118 }
119 let prefix: String = session_id.chars().take(4).collect();
120 format!("{}...[redacted]", prefix)
121}
122
123fn map_ort_dtype(dtype: TensorElementType) -> Option<TensorDtype> {
124 match dtype {
125 TensorElementType::Float32 => Some(TensorDtype::Float32),
126 TensorElementType::Float64 => Some(TensorDtype::Float64),
127 TensorElementType::Float16 => Some(TensorDtype::Float16),
128 TensorElementType::Int32 => Some(TensorDtype::Int32),
129 TensorElementType::Int64 => Some(TensorDtype::Int64),
130 TensorElementType::Uint8 => Some(TensorDtype::Uint8),
131 _ => None,
132 }
133}
134
// SAFETY: NOTE(review): these impls assert that `OnnxBackend` may be moved to and
// shared between threads. Every field is either `Copy` configuration data or is wrapped
// in `Arc<RwLock<..>>` / `Arc<AtomicBool>`. Those wrappers are `Send + Sync`
// automatically whenever the wrapped types are, so these manual impls only matter if
// some wrapped type (presumably ORT's `Session`) is not. Confirm that the ORT version in
// use allows a `Session` to be moved across threads and used behind these locks before
// relying on them.
unsafe impl Send for OnnxBackend {}
unsafe impl Sync for OnnxBackend {}
140
/// Builder for [`OnnxBackend`]; see `OnnxBackendBuilder::new` for defaults.
#[derive(Debug)]
pub struct OnnxBackendBuilder {
    /// Execution provider for every session (default: CPU).
    provider: ExecutionProvider,
    /// Graph optimization level (default: `Level3`).
    optimization_level: GraphOptimizationLevel,
    /// Device index for GPU-style providers; never negative (default: 0).
    device_id: i32,
    /// Enable ORT's memory-pattern optimization (env-seeded, default `true`).
    memory_pattern: bool,
    /// Request disabling the CPU memory arena (env-seeded, default `false`).
    disable_cpu_mem_arena: bool,
    /// Session budget for shape bucketing, kept within `1..=ORT_SESSION_BUCKETS_MAX`.
    max_bucket_sessions: usize,
    /// Rounding step for bucketed dimensions; at least 1.
    bucket_dim_granularity: usize,
    /// Leading dimensions included in bucket keys; at least 1.
    bucket_max_dims: usize,
    /// Optional positive peak-concurrency hint.
    peak_concurrency_hint: Option<u32>,
}
153
154impl OnnxBackendBuilder {
155 pub fn new() -> Self {
156 let memory_pattern = read_env_flag(ORT_MEMORY_PATTERN_ENV, true);
157 let disable_cpu_mem_arena = read_env_flag(ORT_DISABLE_CPU_MEM_ARENA_ENV, false);
158 let max_bucket_sessions = read_env_usize(ORT_SESSION_BUCKETS_ENV)
159 .unwrap_or(4)
160 .clamp(1, ORT_SESSION_BUCKETS_MAX);
161 let bucket_dim_granularity = read_env_usize(ORT_BUCKET_DIM_GRANULARITY_ENV)
162 .unwrap_or(64)
163 .max(1);
164 let bucket_max_dims = read_env_usize(ORT_BUCKET_MAX_DIMS_ENV).unwrap_or(4).max(1);
165 let peak_concurrency_hint = read_env_u32(MODEL_PEAK_CONCURRENCY_ENV);
166 Self {
167 provider: ExecutionProvider::CPU,
168 optimization_level: GraphOptimizationLevel::Level3,
169 device_id: 0,
170 memory_pattern,
171 disable_cpu_mem_arena,
172 max_bucket_sessions,
173 bucket_dim_granularity,
174 bucket_max_dims,
175 peak_concurrency_hint,
176 }
177 }
178
179 pub fn with_provider(mut self, provider: ExecutionProvider) -> Self {
180 self.provider = provider;
181 self
182 }
183
184 pub fn with_optimization_level(mut self, opt_level: GraphOptimizationLevel) -> Self {
185 self.optimization_level = opt_level;
186 self
187 }
188
189 pub fn with_device_id(mut self, device_id: i32) -> Result<Self, String> {
190 if device_id < 0 {
191 return Err("Device ID must be non-negative".to_string());
192 }
193 self.device_id = device_id;
194 Ok(self)
195 }
196
197 pub fn with_memory_pattern(mut self, enabled: bool) -> Self {
198 self.memory_pattern = enabled;
199 self
200 }
201
202 pub fn with_disable_cpu_mem_arena(mut self, disabled: bool) -> Self {
203 self.disable_cpu_mem_arena = disabled;
204 self
205 }
206
207 pub fn with_max_bucket_sessions(mut self, max_bucket_sessions: usize) -> Self {
208 self.max_bucket_sessions = max_bucket_sessions.clamp(1, ORT_SESSION_BUCKETS_MAX);
209 self
210 }
211
212 pub fn with_bucket_dim_granularity(mut self, bucket_dim_granularity: usize) -> Self {
213 self.bucket_dim_granularity = bucket_dim_granularity.max(1);
214 self
215 }
216
217 pub fn with_bucket_max_dims(mut self, bucket_max_dims: usize) -> Self {
218 self.bucket_max_dims = bucket_max_dims.max(1);
219 self
220 }
221
222 pub fn with_peak_concurrency_hint(mut self, peak_concurrency_hint: u32) -> Self {
223 self.peak_concurrency_hint = Some(peak_concurrency_hint.max(1));
224 self
225 }
226
227 pub fn build(self) -> OnnxBackend {
228 let level_value = match self.optimization_level {
229 GraphOptimizationLevel::Disable => 0,
230 GraphOptimizationLevel::Level1 => 1,
231 GraphOptimizationLevel::Level2 => 2,
232 GraphOptimizationLevel::Level3 => 3,
233 GraphOptimizationLevel::All => 4,
234 };
235 OnnxBackend {
236 session: Arc::new(RwLock::new(None)),
237 bucket_sessions: Arc::new(RwLock::new(BucketSessionState::default())),
238 model_path: Arc::new(RwLock::new(None)),
239 provider: self.provider,
240 optimization_level: level_value,
241 device_id: self.device_id,
242 memory_pattern: self.memory_pattern,
243 disable_cpu_mem_arena: self.disable_cpu_mem_arena,
244 max_bucket_sessions: self.max_bucket_sessions,
245 bucket_dim_granularity: self.bucket_dim_granularity,
246 bucket_max_dims: self.bucket_max_dims,
247 peak_concurrency_hint: self.peak_concurrency_hint,
248 metrics: Arc::new(RwLock::new(EngineMetrics::default())),
249 metadata: Arc::new(RwLock::new(None)),
250 warmed_up: Arc::new(AtomicBool::new(false)),
251 }
252 }
253}
254
255impl Default for OnnxBackendBuilder {
256 fn default() -> Self {
257 Self::new()
258 }
259}
260
261impl OnnxBackend {
262 fn get_opt_level(&self) -> GraphOptimizationLevel {
264 match self.optimization_level {
265 0 => GraphOptimizationLevel::Disable,
266 1 => GraphOptimizationLevel::Level1,
267 2 => GraphOptimizationLevel::Level2,
268 3 => GraphOptimizationLevel::Level3,
269 _ => GraphOptimizationLevel::All,
270 }
271 }
272
273 pub fn builder() -> OnnxBackendBuilder {
274 OnnxBackendBuilder::new()
275 }
276
277 pub fn new_cpu() -> Self {
278 Self::builder().build()
279 }
280
281 pub fn new_cpu_with_optimization(opt_level: GraphOptimizationLevel) -> Self {
282 Self::builder().with_optimization_level(opt_level).build()
283 }
284
285 pub fn new_cuda(device_id: i32) -> Result<Self, String> {
286 Self::new_cuda_with_optimization(GraphOptimizationLevel::Level3, device_id)
287 }
288
289 pub fn new_cuda_with_optimization(
290 opt_level: GraphOptimizationLevel,
291 device_id: i32,
292 ) -> Result<Self, String> {
293 Ok(Self::builder()
294 .with_provider(ExecutionProvider::CUDA)
295 .with_optimization_level(opt_level)
296 .with_device_id(device_id)?
297 .build())
298 }
299
300 pub fn new_tensorrt(device_id: i32) -> Result<Self, String> {
301 Self::new_tensorrt_with_optimization(GraphOptimizationLevel::Level3, device_id)
302 }
303
304 pub fn new_tensorrt_with_optimization(
305 opt_level: GraphOptimizationLevel,
306 device_id: i32,
307 ) -> Result<Self, String> {
308 Ok(Self::builder()
309 .with_provider(ExecutionProvider::TensorRT)
310 .with_optimization_level(opt_level)
311 .with_device_id(device_id)?
312 .build())
313 }
314
315 pub fn new_directml(device_id: i32) -> Result<Self, String> {
316 Self::new_directml_with_optimization(GraphOptimizationLevel::Level3, device_id)
317 }
318
319 pub fn new_directml_with_optimization(
320 opt_level: GraphOptimizationLevel,
321 device_id: i32,
322 ) -> Result<Self, String> {
323 Ok(Self::builder()
324 .with_provider(ExecutionProvider::DirectML)
325 .with_optimization_level(opt_level)
326 .with_device_id(device_id)?
327 .build())
328 }
329
330 pub fn new_rocm(device_id: i32) -> Result<Self, String> {
331 Self::new_rocm_with_optimization(GraphOptimizationLevel::Level3, device_id)
332 }
333
334 pub fn new_rocm_with_optimization(
335 opt_level: GraphOptimizationLevel,
336 device_id: i32,
337 ) -> Result<Self, String> {
338 Ok(Self::builder()
339 .with_provider(ExecutionProvider::ROCm)
340 .with_optimization_level(opt_level)
341 .with_device_id(device_id)?
342 .build())
343 }
344
345 pub fn new_openvino_with_optimiation(
346 opt_level: GraphOptimizationLevel,
347 device_id: i32,
348 ) -> Result<Self, String> {
349 Ok(Self::builder()
350 .with_provider(ExecutionProvider::OpenVINO)
351 .with_optimization_level(opt_level)
352 .with_device_id(device_id)?
353 .build())
354 }
355
356 pub fn new_coreml_with_optimiation(
357 opt_level: GraphOptimizationLevel,
358 device_id: i32,
359 ) -> Result<Self, String> {
360 Ok(Self::builder()
361 .with_provider(ExecutionProvider::CoreML)
362 .with_optimization_level(opt_level)
363 .with_device_id(device_id)?
364 .build())
365 }
366
367 fn bucket_key_for_request(&self, request: &InferenceRequest) -> Option<String> {
368 if self.max_bucket_sessions <= 1 {
369 return None;
370 }
371
372 let mut key = format!(
373 "{}:r{}",
374 request.input.dtype.as_str(),
375 request.input.shape.len()
376 );
377 for (index, dim) in request
378 .input
379 .shape
380 .iter()
381 .take(self.bucket_max_dims)
382 .copied()
383 .enumerate()
384 {
385 let rounded = if dim <= 0 {
386 -1
387 } else if index == 0 {
388 dim
389 } else {
390 let granularity = self.bucket_dim_granularity as i64;
391 ((dim + granularity - 1) / granularity) * granularity
392 };
393 key.push(':');
394 key.push_str(&rounded.to_string());
395 }
396 if request.input.shape.len() > self.bucket_max_dims {
397 key.push_str(":*");
398 }
399 Some(key)
400 }
401
402 fn touch_bucket_lru(state: &mut BucketSessionState, bucket_key: &str) {
403 if let Some(pos) = state.lru.iter().position(|existing| existing == bucket_key) {
404 state.lru.remove(pos);
405 }
406 state.lru.push_back(bucket_key.to_string());
407 }
408
409 fn get_or_create_bucket_session<'a>(
410 &self,
411 state: &'a mut BucketSessionState,
412 bucket_key: &str,
413 ) -> Result<&'a mut Session, EngineError> {
414 if !state.sessions.contains_key(bucket_key) {
415 let secondary_capacity = self.max_bucket_sessions.saturating_sub(1).max(1);
416 while state.sessions.len() >= secondary_capacity {
417 let Some(evict_key) = state.lru.pop_front() else {
418 break;
419 };
420 state.sessions.remove(&evict_key);
421 }
422
423 let model_path = self
424 .model_path
425 .read()
426 .map_err(|_| EngineError::Backend {
427 message: "Lock poisoned".to_string(),
428 source: None,
429 })?
430 .clone()
431 .ok_or(EngineError::ModelNotLoaded)?;
432 let session = self.create_session(&model_path, self.get_opt_level())?;
433 state.sessions.insert(bucket_key.to_string(), session);
434 }
435
436 Self::touch_bucket_lru(state, bucket_key);
437 state
438 .sessions
439 .get_mut(bucket_key)
440 .ok_or(EngineError::ModelNotLoaded)
441 }
442
443 fn create_session(
444 &self,
445 model_path: &Path,
446 opt_level: GraphOptimizationLevel,
447 ) -> Result<Session, EngineError> {
448 let mut builder = Session::builder()
450 .map_err(|e| EngineError::ModelLoadError {
451 path: model_path.to_string_lossy().into_owned(),
452 source: Box::new(std::io::Error::other(e.to_string())),
453 })?
454 .with_optimization_level(opt_level)
455 .map_err(|e| EngineError::ModelLoadError {
456 path: model_path.to_string_lossy().into_owned(),
457 source: Box::new(std::io::Error::other(e.to_string())),
458 })?
459 .with_memory_pattern(self.memory_pattern)
460 .map_err(|e| EngineError::ModelLoadError {
461 path: model_path.to_string_lossy().into_owned(),
462 source: Box::new(std::io::Error::other(e.to_string())),
463 })?;
464 if self.disable_cpu_mem_arena {
465 builder = builder
466 .with_config_entry("session.disable_cpu_mem_arena", "1")
467 .map_err(|e| EngineError::ModelLoadError {
468 path: model_path.to_string_lossy().into_owned(),
469 source: Box::new(std::io::Error::other(e.to_string())),
470 })?;
471 }
472
473 let builder = match self.provider {
475 ExecutionProvider::CUDA => {
476 if !ort::execution_providers::CUDAExecutionProvider::default()
477 .is_available()
478 .unwrap_or(false)
479 {
480 return Err(EngineError::Backend {
481 message:
482 "CUDA execution provider is not available. Please check your CUDA installation."
483 .to_string(),
484 source: None,
485 });
486 }
487 builder
488 .with_execution_providers([
489 ort::execution_providers::CUDAExecutionProvider::default()
490 .with_device_id(self.device_id)
491 .build(),
492 ])
493 .map_err(|e| EngineError::ModelLoadError {
494 path: model_path.to_string_lossy().into_owned(),
495 source: Box::new(std::io::Error::other(e.to_string())),
496 })?
497 }
498 ExecutionProvider::TensorRT => {
499 if !ort::execution_providers::TensorRTExecutionProvider::default()
500 .is_available()
501 .unwrap_or(false)
502 {
503 return Err(EngineError::Backend {
504 message: "TensorRT execution provider is not available.".to_string(),
505 source: None,
506 });
507 }
508 builder
509 .with_execution_providers([
510 ort::execution_providers::TensorRTExecutionProvider::default()
511 .with_device_id(self.device_id)
512 .build(),
513 ort::execution_providers::CUDAExecutionProvider::default()
515 .with_device_id(self.device_id)
516 .build(),
517 ])
518 .map_err(|e| EngineError::ModelLoadError {
519 path: model_path.to_string_lossy().into_owned(),
520 source: Box::new(std::io::Error::other(e.to_string())),
521 })?
522 }
523 ExecutionProvider::DirectML => {
524 #[cfg(target_os = "windows")]
525 {
526 if !ort::execution_providers::DirectMLExecutionProvider::default()
527 .is_available()
528 .unwrap_or(false)
529 {
530 return Err(EngineError::Backend {
531 message: "DirectML execution provider is not available.".to_string(),
532 source: None,
533 });
534 }
535 builder
536 .with_execution_providers([
537 ort::execution_providers::DirectMLExecutionProvider::default()
538 .with_device_id(self.device_id)
539 .build(),
540 ])
541 .map_err(|e| EngineError::ModelLoadError {
542 path: model_path.to_string_lossy().into_owned(),
543 source: Box::new(std::io::Error::other(e.to_string())),
544 })?
545 }
546 #[cfg(not(target_os = "windows"))]
547 {
548 return Err(EngineError::Backend {
549 message: "DirectML execution provider is only supported on Windows."
550 .to_string(),
551 source: None,
552 });
553 }
554 }
555 ExecutionProvider::ROCm => {
556 if !ort::execution_providers::ROCmExecutionProvider::default()
557 .is_available()
558 .unwrap_or(false)
559 {
560 return Err(EngineError::Backend {
561 message: "ROCm execution provider is not available.".to_string(),
562 source: None,
563 });
564 }
565 builder
566 .with_execution_providers([
567 ort::execution_providers::ROCmExecutionProvider::default()
568 .with_device_id(self.device_id)
569 .build(),
570 ])
571 .map_err(|e| EngineError::ModelLoadError {
572 path: model_path.to_string_lossy().into_owned(),
573 source: Box::new(std::io::Error::other(e.to_string())),
574 })?
575 }
576 ExecutionProvider::OpenVINO => {
577 if !ort::execution_providers::OpenVINOExecutionProvider::default()
578 .is_available()
579 .unwrap_or(false)
580 {
581 return Err(EngineError::Backend {
582 message: "OpenVINO execution provider is not available.".to_string(),
583 source: None,
584 });
585 }
586 builder
587 .with_execution_providers([
588 ort::execution_providers::OpenVINOExecutionProvider::default().build(),
589 ])
590 .map_err(|e| EngineError::ModelLoadError {
591 path: model_path.to_string_lossy().into_owned(),
592 source: Box::new(std::io::Error::other(e.to_string())),
593 })?
594 }
595 ExecutionProvider::CoreML => {
596 if !ort::execution_providers::CoreMLExecutionProvider::default()
597 .is_available()
598 .unwrap_or(false)
599 {
600 return Err(EngineError::Backend {
601 message: "CoreML execution provider is not available.".to_string(),
602 source: None,
603 });
604 }
605 builder
606 .with_execution_providers([
607 ort::execution_providers::CoreMLExecutionProvider::default().build(),
608 ort::execution_providers::CPUExecutionProvider::default().build(),
612 ])
613 .map_err(|e| EngineError::ModelLoadError {
614 path: model_path.to_string_lossy().into_owned(),
615 source: Box::new(std::io::Error::other(e.to_string())),
616 })?
617 }
618 ExecutionProvider::CPU => {
619 builder
621 }
622 };
623
624 builder
626 .commit_from_file(model_path)
627 .map_err(|e| EngineError::ModelLoadError {
628 path: model_path.to_string_lossy().into_owned(),
629 source: Box::new(std::io::Error::other(e.to_string())),
630 })
631 }
632}
633
/// Owned, decoded copy of an input tensor's element buffer, one variant per supported
/// `TensorDtype` (produced by `validate_and_prepare_input`).
#[derive(Debug)]
enum PreparedInput {
    /// `float32` elements.
    F32(Vec<f32>),
    /// `float64` elements.
    F64(Vec<f64>),
    /// `float16` elements.
    F16(Vec<f16>),
    /// `int32` elements.
    I32(Vec<i32>),
    /// `int64` elements.
    I64(Vec<i64>),
    /// `uint8` elements (a byte-for-byte copy of the packet data).
    U8(Vec<u8>),
}
647
648fn validate_byte_len(
649 input: &BinaryTensorPacket,
650 num_elements: usize,
651 elem_size: usize,
652 dtype_label: &str,
653) -> Result<(), EngineError> {
654 let expected =
655 num_elements
656 .checked_mul(elem_size)
657 .ok_or_else(|| EngineError::InvalidInput {
658 message: "Data size overflow".to_string(),
659 source: None,
660 })?;
661 if input.data.len() != expected {
662 return Err(EngineError::InvalidInput {
663 message: format!(
664 "Data length mismatch: expected {} bytes ({} {} values) but got {} bytes",
665 expected,
666 num_elements,
667 dtype_label,
668 input.data.len()
669 ),
670 source: None,
671 });
672 }
673 Ok(())
674}
675
676fn parse_ne_f32(bytes: &[u8], num_elements: usize) -> Vec<f32> {
677 if let Some(values) = try_aligned_copy::<f32>(bytes) {
678 return values;
679 }
680 let mut values = Vec::with_capacity(num_elements);
681 for chunk in bytes.chunks_exact(4) {
682 values.push(f32::from_ne_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]));
683 }
684 values
685}
686
687fn parse_ne_f64(bytes: &[u8], num_elements: usize) -> Vec<f64> {
688 if let Some(values) = try_aligned_copy::<f64>(bytes) {
689 return values;
690 }
691 let mut values = Vec::with_capacity(num_elements);
692 for chunk in bytes.chunks_exact(8) {
693 values.push(f64::from_ne_bytes([
694 chunk[0], chunk[1], chunk[2], chunk[3], chunk[4], chunk[5], chunk[6], chunk[7],
695 ]));
696 }
697 values
698}
699
700fn parse_ne_f16(bytes: &[u8], num_elements: usize) -> Vec<f16> {
701 if let Some(values) = try_aligned_copy::<u16>(bytes) {
702 let mut out = Vec::with_capacity(num_elements);
703 out.extend(values.into_iter().map(f16::from_bits));
704 return out;
705 }
706 let mut values = Vec::with_capacity(num_elements);
707 for chunk in bytes.chunks_exact(2) {
708 values.push(f16::from_bits(u16::from_ne_bytes([chunk[0], chunk[1]])));
709 }
710 values
711}
712
713fn parse_ne_i32(bytes: &[u8], num_elements: usize) -> Vec<i32> {
714 if let Some(values) = try_aligned_copy::<i32>(bytes) {
715 return values;
716 }
717 let mut values = Vec::with_capacity(num_elements);
718 for chunk in bytes.chunks_exact(4) {
719 values.push(i32::from_ne_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]));
720 }
721 values
722}
723
724fn parse_ne_i64(bytes: &[u8], num_elements: usize) -> Vec<i64> {
725 if let Some(values) = try_aligned_copy::<i64>(bytes) {
726 return values;
727 }
728 let mut values = Vec::with_capacity(num_elements);
729 for chunk in bytes.chunks_exact(8) {
730 values.push(i64::from_ne_bytes([
731 chunk[0], chunk[1], chunk[2], chunk[3], chunk[4], chunk[5], chunk[6], chunk[7],
732 ]));
733 }
734 values
735}
736
737fn find_additional_input_by_name<'a>(
738 additional_inputs: &'a [kapsl_engine_api::NamedTensor],
739 name: &str,
740) -> Option<&'a BinaryTensorPacket> {
741 additional_inputs
742 .iter()
743 .find(|entry| entry.name == name)
744 .map(|entry| &entry.tensor)
745}
746
747fn ensure_unique_additional_input_names(
748 additional_inputs: &[kapsl_engine_api::NamedTensor],
749) -> Result<(), EngineError> {
750 for i in 0..additional_inputs.len() {
751 for j in (i + 1)..additional_inputs.len() {
752 if additional_inputs[i].name == additional_inputs[j].name {
753 return Err(EngineError::InvalidInput {
754 message: format!(
755 "Duplicate additional input name: {}",
756 additional_inputs[i].name
757 ),
758 source: None,
759 });
760 }
761 }
762 }
763 Ok(())
764}
765
/// Element types for which every bit pattern of the right size is a valid value.
///
/// # Safety
/// Implementors must be plain numeric types with no padding and no invalid bit
/// patterns. `try_aligned_copy` reinterprets arbitrary bytes as `Self`.
unsafe trait PlainBits: Copy {}

// SAFETY: all of these are primitive numeric types without padding, and every bit
// pattern of their size is a valid value.
unsafe impl PlainBits for u8 {}
unsafe impl PlainBits for u16 {}
unsafe impl PlainBits for i32 {}
unsafe impl PlainBits for i64 {}
unsafe impl PlainBits for f32 {}
unsafe impl PlainBits for f64 {}

/// Copies `bytes` into a `Vec<T>` in one step when `bytes` is suitably aligned for `T`
/// and its length is an exact multiple of `size_of::<T>()`.
///
/// Returns `None` otherwise; callers then fall back to chunk-wise decoding. The bound
/// is `PlainBits` rather than `Copy` because with `T = bool`/`char` this safe function
/// would otherwise manufacture invalid values (undefined behaviour).
fn try_aligned_copy<T: PlainBits>(bytes: &[u8]) -> Option<Vec<T>> {
    // SAFETY: `T: PlainBits` guarantees any byte sequence of `size_of::<T>()` bytes is
    // a valid `T`, and `align_to` only places correctly aligned, fully covered elements
    // in the middle slice.
    let (prefix, aligned, suffix) = unsafe { bytes.align_to::<T>() };
    if prefix.is_empty() && suffix.is_empty() {
        Some(aligned.to_vec())
    } else {
        None
    }
}
775
776fn validate_and_prepare_input(
783 input: &kapsl_engine_api::BinaryTensorPacket,
784) -> Result<(Vec<i64>, PreparedInput), EngineError> {
785 let num_elements: usize = if input.shape.is_empty() {
787 1
788 } else {
789 let mut prod: usize = 1;
791 for &d in &input.shape {
792 if d <= 0 {
793 return Err(EngineError::InvalidInput {
794 message: format!("Invalid shape dimension: {}", d),
795 source: None,
796 });
797 }
798 prod = prod
799 .checked_mul(d as usize)
800 .ok_or_else(|| EngineError::InvalidInput {
801 message: "Shape multiplication overflow".to_string(),
802 source: None,
803 })?;
804 }
805 prod
806 };
807
808 match input.dtype {
810 TensorDtype::Float32 => {
811 validate_byte_len(input, num_elements, 4, "float32")?;
812 let values = parse_ne_f32(&input.data, num_elements);
813 Ok((input.shape.clone(), PreparedInput::F32(values)))
814 }
815 TensorDtype::Float64 => {
816 validate_byte_len(input, num_elements, 8, "float64")?;
817 let values = parse_ne_f64(&input.data, num_elements);
818 Ok((input.shape.clone(), PreparedInput::F64(values)))
819 }
820 TensorDtype::Float16 => {
821 validate_byte_len(input, num_elements, 2, "float16")?;
822 let values = parse_ne_f16(&input.data, num_elements);
823 Ok((input.shape.clone(), PreparedInput::F16(values)))
824 }
825 TensorDtype::Int32 => {
826 validate_byte_len(input, num_elements, 4, "int32")?;
827 let values = parse_ne_i32(&input.data, num_elements);
828 Ok((input.shape.clone(), PreparedInput::I32(values)))
829 }
830 TensorDtype::Int64 => {
831 validate_byte_len(input, num_elements, 8, "int64")?;
832 let values = parse_ne_i64(&input.data, num_elements);
833 Ok((input.shape.clone(), PreparedInput::I64(values)))
834 }
835 TensorDtype::Uint8 => {
836 validate_byte_len(input, num_elements, 1, "uint8")?;
837 Ok((input.shape.clone(), PreparedInput::U8(input.data.clone())))
838 }
839 other => Err(EngineError::InvalidInput {
840 message: format!(
841 "Unsupported dtype: {}. Supported: float32, float64, float16, int32, int64, uint8",
842 other
843 ),
844 source: None,
845 }),
846 }
847}
848
849fn tensor_packet_to_session_input(
850 input: &BinaryTensorPacket,
851) -> Result<(Vec<usize>, SessionInputValue<'_>), EngineError> {
852 let (shape_i64, prepared) = validate_and_prepare_input(input)?;
853 let shape_usize = get_shape_usize(&shape_i64);
854
855 let value: SessionInputValue = match prepared {
856 PreparedInput::F32(v) => Value::from_array((shape_usize.clone(), v)).map(|v| v.into()),
857 PreparedInput::F64(v) => Value::from_array((shape_usize.clone(), v)).map(|v| v.into()),
858 PreparedInput::F16(v) => Value::from_array((shape_usize.clone(), v)).map(|v| v.into()),
859 PreparedInput::I32(v) => Value::from_array((shape_usize.clone(), v)).map(|v| v.into()),
860 PreparedInput::I64(v) => Value::from_array((shape_usize.clone(), v)).map(|v| v.into()),
861 PreparedInput::U8(v) => Value::from_array((shape_usize.clone(), v)).map(|v| v.into()),
862 }
863 .map_err(|e| EngineError::InferenceError {
864 reason: "Failed to create input tensor".to_string(),
865 source: Some(Box::new(e)),
866 })?;
867
868 Ok((shape_usize, value))
869}
870
871fn run_inference_with_session(
872 session: &mut Session,
873 request: &InferenceRequest,
874 metadata: &ModelMetadata,
875 shape_usize: Vec<usize>,
876 main_input_tensor: SessionInputValue<'_>,
877) -> Result<BinaryTensorPacket, EngineError> {
878 if metadata.input_names.is_empty() {
880 return Err(EngineError::InferenceError {
881 reason: "Model has no inputs defined".to_string(),
882 source: None,
883 });
884 }
885
886 let outputs = if metadata.input_names.len() == 1 && request.additional_inputs.is_empty() {
887 session.run([main_input_tensor]).map_err(|e| {
888 log::error!("ONNX Runtime inference error: {:?}", e);
889 EngineError::InferenceError {
890 reason: format!("Inference failed: {}", e),
891 source: Some(Box::new(e)),
892 }
893 })?
894 } else {
895 let mut inputs: Vec<(Cow<'_, str>, SessionInputValue)> =
897 Vec::with_capacity(metadata.input_names.len());
898 inputs.push((
899 Cow::Borrowed(metadata.input_names[0].as_str()),
900 main_input_tensor,
901 ));
902 ensure_unique_additional_input_names(&request.additional_inputs)?;
903
904 let batch_size = if !shape_usize.is_empty() {
906 shape_usize[0]
907 } else {
908 1
909 };
910 let seq_len = if shape_usize.len() > 1 {
911 shape_usize[1]
912 } else {
913 1
914 };
915
916 let workaround_past_len = 0;
921
922 for (i, name) in metadata.input_names.iter().enumerate().skip(1) {
923 let shape_def = &metadata.input_shapes[i];
924 if let Some(named_input) =
925 find_additional_input_by_name(&request.additional_inputs, name)
926 {
927 let (_, value) = tensor_packet_to_session_input(named_input)?;
928 inputs.push((Cow::Borrowed(name.as_str()), value));
929 continue;
930 }
931
932 if name.contains("attention_mask") {
933 let total_len = seq_len + workaround_past_len;
936 let mask_shape = vec![batch_size as i64, total_len as i64];
938
939 let mut mask_data = Vec::with_capacity(batch_size * total_len);
940 for _ in 0..batch_size {
941 mask_data.extend(std::iter::repeat_n(0i64, workaround_past_len));
943 mask_data.extend(std::iter::repeat_n(1i64, seq_len));
945 }
946
947 log::debug!(
948 "Creating attention_mask tensor for {} with shape {:?}",
949 name,
950 mask_shape
951 );
952
953 let mask_tensor = Value::from_array((get_shape_usize(&mask_shape), mask_data))
954 .map_err(|e| EngineError::InferenceError {
955 reason: format!("Failed to create attention_mask tensor for {}", name),
956 source: Some(Box::new(e)),
957 })?;
958 inputs.push((Cow::Borrowed(name.as_str()), mask_tensor.into()));
959 } else if name.contains("position_ids") {
960 let pos_shape = vec![batch_size as i64, seq_len as i64];
961 let mut pos_data = Vec::with_capacity(batch_size * seq_len);
962 for _ in 0..batch_size {
967 for s in 0..seq_len {
968 pos_data.push(s as i64);
969 }
970 }
971 let pos_tensor = Value::from_array((get_shape_usize(&pos_shape), pos_data))
972 .map_err(|e| EngineError::InferenceError {
973 reason: format!("Failed to create position_ids tensor for {}", name),
974 source: Some(Box::new(e)),
975 })?;
976 inputs.push((Cow::Borrowed(name.as_str()), pos_tensor.into()));
977 } else if name.starts_with("past_key_values") {
978 let mut new_shape = Vec::new();
979 new_shape.push(batch_size); if shape_def.len() == 4 {
982 let dim1 = if shape_def[1] > 0 {
983 shape_def[1] as usize
984 } else {
985 1
986 };
987 new_shape.push(dim1);
988
989 new_shape.push(workaround_past_len); let dim3 = if shape_def[3] > 0 {
992 shape_def[3] as usize
993 } else {
994 64
995 };
996 new_shape.push(dim3);
997
998 log::debug!("Creating KV tensor for {} with shape {:?}", name, new_shape);
999
1000 let count: usize = new_shape.iter().product();
1001 let empty_data: Vec<f16> = vec![f16::ZERO; count];
1002
1003 let kv_array = ArrayD::from_shape_vec(new_shape, empty_data).map_err(|e| {
1005 EngineError::InferenceError {
1006 reason: format!("Failed to create ndarray for {}: {:?}", name, e),
1007 source: Some(Box::new(e)),
1008 }
1009 })?;
1010
1011 let kv_tensor =
1012 Value::from_array(kv_array).map_err(|e| EngineError::InferenceError {
1013 reason: format!(
1014 "Failed to create empty KV tensor for {}: {:?}",
1015 name, e
1016 ),
1017 source: Some(Box::new(e)),
1018 })?;
1019 inputs.push((Cow::Borrowed(name.as_str()), kv_tensor.into()));
1020 } else {
1021 log::warn!(
1022 "Skipping input {} due to unknown shape pattern {:?}",
1023 name,
1024 shape_def
1025 );
1026 }
1027 } else {
1028 log::warn!(
1029 "Skipping input {} as it is not recognized as auto-fillable",
1030 name
1031 );
1032 }
1033 }
1034
1035 session.run(inputs).map_err(|e| {
1036 log::error!("ONNX Runtime inference error: {:?}", e);
1037 EngineError::InferenceError {
1038 reason: format!("Inference failed: {}", e),
1039 source: Some(Box::new(e)),
1040 }
1041 })?
1042 };
1043
1044 if outputs.len() > 1 {
1047 log::debug!(
1048 "Backend received {} outputs, using only the first one (logits)",
1049 outputs.len()
1050 );
1051 }
1052
1053 let output_value = &outputs[0];
1055 let output_packet = if let Ok((shape_ref, data)) = output_value.try_extract_tensor::<f32>() {
1056 BinaryTensorPacket {
1057 shape: shape_ref.to_vec(),
1058 dtype: TensorDtype::Float32,
1059 data: data.iter().flat_map(|&x| x.to_ne_bytes()).collect(),
1060 }
1061 } else if let Ok((shape_ref, data)) = output_value.try_extract_tensor::<f64>() {
1062 BinaryTensorPacket {
1063 shape: shape_ref.to_vec(),
1064 dtype: TensorDtype::Float64,
1065 data: data.iter().flat_map(|&x| x.to_ne_bytes()).collect(),
1066 }
1067 } else if let Ok((shape_ref, data)) = output_value.try_extract_tensor::<f16>() {
1068 BinaryTensorPacket {
1069 shape: shape_ref.to_vec(),
1070 dtype: TensorDtype::Float16,
1071 data: data
1072 .iter()
1073 .flat_map(|x| x.to_bits().to_ne_bytes())
1074 .collect(),
1075 }
1076 } else if let Ok((shape_ref, data)) = output_value.try_extract_tensor::<i32>() {
1077 BinaryTensorPacket {
1078 shape: shape_ref.to_vec(),
1079 dtype: TensorDtype::Int32,
1080 data: data.iter().flat_map(|&x| x.to_ne_bytes()).collect(),
1081 }
1082 } else if let Ok((shape_ref, data)) = output_value.try_extract_tensor::<i64>() {
1083 BinaryTensorPacket {
1084 shape: shape_ref.to_vec(),
1085 dtype: TensorDtype::Int64,
1086 data: data.iter().flat_map(|&x| x.to_ne_bytes()).collect(),
1087 }
1088 } else if let Ok((shape_ref, data)) = output_value.try_extract_tensor::<u8>() {
1089 BinaryTensorPacket {
1090 shape: shape_ref.to_vec(),
1091 dtype: TensorDtype::Uint8,
1092 data: data.to_vec(),
1093 }
1094 } else {
1095 return Err(EngineError::InferenceError {
1096 reason: "Failed to extract output tensor. Supported output dtypes: float32, float64, float16, int32, int64, uint8"
1097 .to_string(),
1098 source: None,
1099 });
1100 };
1101
1102 Ok(output_packet)
1103}
1104
1105#[async_trait]
1106impl Engine for OnnxBackend {
1107 async fn load(&mut self, model_path: &Path) -> Result<(), EngineError> {
1113 let opt_level = self.get_opt_level();
1114 log::info!(
1115 "Loading ONNX model with optimization level: {:?} on provider {:?}",
1116 opt_level,
1117 self.provider
1118 );
1119 log::info!(
1120 "ORT memory config: memory_pattern={} disable_cpu_mem_arena={} session_buckets={} bucket_dim_granularity={} bucket_max_dims={} peak_concurrency_hint={:?}",
1121 self.memory_pattern,
1122 self.disable_cpu_mem_arena,
1123 self.max_bucket_sessions,
1124 self.bucket_dim_granularity,
1125 self.bucket_max_dims,
1126 self.peak_concurrency_hint
1127 );
1128
1129 let session = self.create_session(model_path, opt_level)?;
1130
1131 log::info!("Model Inputs:");
1132 for (i, input) in session.inputs().iter().enumerate() {
1133 log::info!(" Input {}: {} ({:?})", i, input.name(), input.dtype());
1134 }
1135
1136 let input_names: Vec<String> = session
1138 .inputs()
1139 .iter()
1140 .map(|i| i.name().to_string())
1141 .collect();
1142 let output_names: Vec<String> = session
1143 .outputs()
1144 .iter()
1145 .map(|o| o.name().to_string())
1146 .collect();
1147
1148 let mut input_shapes = Vec::new();
1149 let mut input_dtypes = Vec::new();
1150 for input in session.inputs() {
1151 let (shape, dtype) = match input.dtype() {
1152 ort::value::ValueType::Tensor { ty, shape, .. } => {
1153 (shape.iter().copied().collect(), map_ort_dtype(*ty))
1154 }
1155 _ => (vec![], None),
1156 };
1157 input_shapes.push(shape);
1158 input_dtypes.push(dtype);
1159 }
1160
1161 let mut output_shapes = Vec::new();
1162 let mut output_dtypes = Vec::new();
1163 for output in session.outputs() {
1164 let (shape, dtype) = match output.dtype() {
1165 ort::value::ValueType::Tensor { ty, shape, .. } => {
1166 (shape.iter().copied().collect(), map_ort_dtype(*ty))
1167 }
1168 _ => (vec![], None),
1169 };
1170 output_shapes.push(shape);
1171 output_dtypes.push(dtype);
1172 }
1173
1174 let metadata = ModelMetadata {
1175 input_names,
1176 output_names,
1177 input_shapes,
1178 output_shapes,
1179 input_dtypes,
1180 output_dtypes,
1181 };
1182
1183 if let Ok(mut meta_guard) = self.metadata.write() {
1185 *meta_guard = Some(metadata);
1186 }
1187
1188 let mut session_guard = self.session.write().map_err(|_| EngineError::Backend {
1189 message: "Lock poisoned".to_string(),
1190 source: None,
1191 })?;
1192 *session_guard = Some(session);
1193 drop(session_guard);
1194
1195 if let Ok(mut model_path_guard) = self.model_path.write() {
1196 *model_path_guard = Some(model_path.to_path_buf());
1197 }
1198 if let Ok(mut bucket_guard) = self.bucket_sessions.write() {
1199 bucket_guard.primary_bucket_key = None;
1200 bucket_guard.sessions.clear();
1201 bucket_guard.lru.clear();
1202 }
1203
1204 self.warmed_up.store(false, Ordering::SeqCst);
1206
1207 Ok(())
1208 }
1209
1210 fn infer(&self, request: &InferenceRequest) -> Result<BinaryTensorPacket, EngineError> {
1211 let start_time = Instant::now();
1212 if let Some(session_id) = &request.session_id {
1213 log::debug!(
1214 "Processing request for session: {}",
1215 redact_session_id_for_log(session_id)
1216 );
1217 }
1218
1219 let metadata = self.metadata.read().map_err(|_| EngineError::Backend {
1220 message: "Lock poisoned".to_string(),
1221 source: None,
1222 })?;
1223 let metadata = metadata
1224 .as_ref()
1225 .cloned()
1226 .ok_or(EngineError::ModelNotLoaded)?;
1227
1228 let (shape_usize, main_input_tensor) = tensor_packet_to_session_input(&request.input)?;
1230 let mut prepared_input = Some((shape_usize, main_input_tensor));
1231
1232 let output_packet = if let Some(bucket_key) = self.bucket_key_for_request(request) {
1233 let use_primary = {
1234 let mut bucket_guard =
1235 self.bucket_sessions
1236 .write()
1237 .map_err(|_| EngineError::Backend {
1238 message: "Lock poisoned".to_string(),
1239 source: None,
1240 })?;
1241 let primary_key = bucket_guard
1242 .primary_bucket_key
1243 .get_or_insert_with(|| bucket_key.clone());
1244 *primary_key == bucket_key
1245 };
1246
1247 if use_primary {
1248 let mut session_guard = self.session.write().map_err(|_| EngineError::Backend {
1249 message: "Lock poisoned".to_string(),
1250 source: None,
1251 })?;
1252 let session = session_guard.as_mut().ok_or(EngineError::ModelNotLoaded)?;
1253 let (shape_usize, main_input_tensor) = prepared_input
1254 .take()
1255 .ok_or_else(|| EngineError::backend("input already consumed".to_string()))?;
1256 run_inference_with_session(
1257 session,
1258 request,
1259 &metadata,
1260 shape_usize,
1261 main_input_tensor,
1262 )?
1263 } else {
1264 let mut bucket_guard =
1265 self.bucket_sessions
1266 .write()
1267 .map_err(|_| EngineError::Backend {
1268 message: "Lock poisoned".to_string(),
1269 source: None,
1270 })?;
1271 let session = self.get_or_create_bucket_session(&mut bucket_guard, &bucket_key)?;
1272 let (shape_usize, main_input_tensor) = prepared_input
1273 .take()
1274 .ok_or_else(|| EngineError::backend("input already consumed".to_string()))?;
1275 run_inference_with_session(
1276 session,
1277 request,
1278 &metadata,
1279 shape_usize,
1280 main_input_tensor,
1281 )?
1282 }
1283 } else {
1284 let mut session_guard = self.session.write().map_err(|_| EngineError::Backend {
1285 message: "Lock poisoned".to_string(),
1286 source: None,
1287 })?;
1288 let session = session_guard.as_mut().ok_or(EngineError::ModelNotLoaded)?;
1289 let (shape_usize, main_input_tensor) = prepared_input
1290 .take()
1291 .ok_or_else(|| EngineError::backend("input already consumed".to_string()))?;
1292 run_inference_with_session(session, request, &metadata, shape_usize, main_input_tensor)?
1293 };
1294
1295 let duration = start_time.elapsed().as_secs_f64();
1297 if let Ok(mut metrics) = self.metrics.write() {
1298 metrics.inference_time = duration;
1299 }
1302
1303 self.warmed_up.store(true, Ordering::SeqCst);
1305
1306 Ok(output_packet)
1307 }
1308
1309 fn infer_stream(
1312 &self,
1313 request: &InferenceRequest,
1314 ) -> std::pin::Pin<
1315 Box<dyn futures::stream::Stream<Item = Result<BinaryTensorPacket, EngineError>> + Send>,
1316 > {
1317 let result = self.infer(request);
1319 Box::pin(futures::stream::once(async move { result }))
1321 }
1322
1323 fn unload(&mut self) {
1326 if let Ok(mut session_guard) = self.session.write() {
1327 *session_guard = None;
1328 }
1329 if let Ok(mut bucket_guard) = self.bucket_sessions.write() {
1330 bucket_guard.primary_bucket_key = None;
1331 bucket_guard.sessions.clear();
1332 bucket_guard.lru.clear();
1333 }
1334 if let Ok(mut model_path_guard) = self.model_path.write() {
1335 *model_path_guard = None;
1336 }
1337 if let Ok(mut meta_guard) = self.metadata.write() {
1338 *meta_guard = None;
1339 }
1340 }
1341
1342 fn metrics(&self) -> kapsl_engine_api::EngineMetrics {
1343 if let Ok(metrics) = self.metrics.read() {
1344 metrics.clone()
1345 } else {
1346 kapsl_engine_api::EngineMetrics::default()
1347 }
1348 }
1349
1350 fn health_check(&self) -> Result<(), EngineError> {
1351 let session_guard = self.session.read().map_err(|_| EngineError::Backend {
1353 message: "Session lock poisoned".to_string(),
1354 source: None,
1355 })?;
1356 let bucket_guard = self
1357 .bucket_sessions
1358 .read()
1359 .map_err(|_| EngineError::Backend {
1360 message: "Session cache lock poisoned".to_string(),
1361 source: None,
1362 })?;
1363
1364 if session_guard.is_some() || !bucket_guard.sessions.is_empty() {
1365 Ok(())
1366 } else {
1367 Err(EngineError::ModelNotLoaded)
1368 }
1369 }
1370
1371 fn model_info(&self) -> Option<EngineModelInfo> {
1372 let metadata_guard = self.metadata.read().ok()?;
1373 let metadata = metadata_guard.as_ref()?;
1374 Some(EngineModelInfo {
1375 input_names: metadata.input_names.clone(),
1376 output_names: metadata.output_names.clone(),
1377 input_shapes: metadata.input_shapes.clone(),
1378 output_shapes: metadata.output_shapes.clone(),
1379 input_dtypes: metadata
1380 .input_dtypes
1381 .iter()
1382 .map(|dtype| {
1383 dtype
1384 .as_ref()
1385 .map(TensorDtype::as_str)
1386 .unwrap_or("unknown")
1387 .to_string()
1388 })
1389 .collect(),
1390 output_dtypes: metadata
1391 .output_dtypes
1392 .iter()
1393 .map(|dtype| {
1394 dtype
1395 .as_ref()
1396 .map(TensorDtype::as_str)
1397 .unwrap_or("unknown")
1398 .to_string()
1399 })
1400 .collect(),
1401 framework: Some("onnx".to_string()),
1402 model_version: None,
1403 peak_concurrency: self.peak_concurrency_hint,
1404 })
1405 }
1406}
1407
/// Converts an `i64` shape into `usize` dimensions, element by element.
///
/// The conversion is a plain `as` cast, so negative entries wrap around
/// (`-1` becomes `usize::MAX`). Callers are expected to pass non-negative
/// dimensions.
fn get_shape_usize(shape: &[i64]) -> Vec<usize> {
    let mut dims = Vec::with_capacity(shape.len());
    for dim in shape {
        dims.push(*dim as usize);
    }
    dims
}
1411
1412#[path = "onnx_tests.rs"]
1413mod onnx_tests;