Skip to main content

kapsl_backends/
onnx.rs

1use async_trait::async_trait;
2use half::f16;
3use kapsl_engine_api::{
4    BinaryTensorPacket, Engine, EngineError, EngineMetrics, EngineModelInfo, InferenceRequest,
5    TensorDtype,
6};
7use ndarray::ArrayD;
8use ort::execution_providers::ExecutionProvider as OrtExecutionProvider;
9use ort::session::builder::GraphOptimizationLevel;
10use ort::session::{Session, SessionInputValue};
11use ort::tensor::TensorElementType;
12use ort::value::Value;
13use std::borrow::Cow;
14use std::collections::{HashMap, VecDeque};
15use std::path::{Path, PathBuf};
16use std::sync::atomic::{AtomicBool, Ordering};
17use std::sync::{Arc, OnceLock, RwLock};
18use std::time::Instant;
19
/// Hardware/runtime target a session is created for.
///
/// Each variant selects the ONNX Runtime execution provider that
/// `OnnxBackend::create_session` registers on the session builder.
// TODO: Consider adding support for further backends.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ExecutionProvider {
    /// Default CPU execution; no extra provider is registered.
    CPU,
    /// CUDA execution provider, bound to the configured device id.
    CUDA,
    /// TensorRT execution provider; CUDA is registered after it as a fallback.
    TensorRT,
    /// DirectML execution provider; rejected at session creation on non-Windows targets.
    DirectML,
    /// ROCm execution provider, bound to the configured device id.
    ROCm,
    /// OpenVINO execution provider (registered with its default options).
    OpenVINO,
    /// CoreML execution provider; the CPU provider is registered after it as a fallback.
    CoreML,
}
31
32#[derive(Debug, Clone)]
33pub struct ModelMetadata {
34    pub input_names: Vec<String>,
35    pub output_names: Vec<String>,
36    pub input_shapes: Vec<Vec<i64>>,
37    pub output_shapes: Vec<Vec<i64>>,
38    pub input_dtypes: Vec<Option<TensorDtype>>,
39    pub output_dtypes: Vec<Option<TensorDtype>>,
40}
41
42pub struct OnnxBackend {
43    session: Arc<RwLock<Option<Session>>>,
44    bucket_sessions: Arc<RwLock<BucketSessionState>>,
45    model_path: Arc<RwLock<Option<PathBuf>>>,
46    provider: ExecutionProvider,
47    optimization_level: u8,
48    device_id: i32,
49    memory_pattern: bool,
50    disable_cpu_mem_arena: bool,
51    max_bucket_sessions: usize,
52    bucket_dim_granularity: usize,
53    bucket_max_dims: usize,
54    peak_concurrency_hint: Option<u32>,
55    metrics: Arc<RwLock<EngineMetrics>>,
56    metadata: Arc<RwLock<Option<ModelMetadata>>>,
57    warmed_up: Arc<AtomicBool>,
58}
59
60#[derive(Default)]
61struct BucketSessionState {
62    primary_bucket_key: Option<String>,
63    sessions: HashMap<String, Session>,
64    lru: VecDeque<String>,
65}
66
/// Boolean flag: enables/disables the ORT memory pattern (builder default: enabled).
const ORT_MEMORY_PATTERN_ENV: &str = "KAPSL_ORT_MEMORY_PATTERN";
/// Boolean flag: disables the CPU memory arena (builder default: not disabled).
const ORT_DISABLE_CPU_MEM_ARENA_ENV: &str = "KAPSL_ORT_DISABLE_CPU_MEM_ARENA";
/// Unsigned integer: maximum number of bucket sessions (builder default: 4).
const ORT_SESSION_BUCKETS_ENV: &str = "KAPSL_ORT_SESSION_BUCKETS";
/// Unsigned integer: rounding granularity for bucketed dimensions (builder default: 64).
const ORT_BUCKET_DIM_GRANULARITY_ENV: &str = "KAPSL_ORT_BUCKET_DIM_GRANULARITY";
/// Unsigned integer: number of leading dimensions used in a bucket key (builder default: 4).
const ORT_BUCKET_MAX_DIMS_ENV: &str = "KAPSL_ORT_BUCKET_MAX_DIMS";
/// Positive integer: peak concurrency hint (builder default: unset).
const MODEL_PEAK_CONCURRENCY_ENV: &str = "KAPSL_MODEL_PEAK_CONCURRENCY";
/// Upper bound the configured number of bucket sessions is clamped to.
const ORT_SESSION_BUCKETS_MAX: usize = 64;
74
/// Read a boolean flag from the environment variable `name`.
///
/// The value is trimmed and compared case-insensitively: `1`/`true`/`yes`/`on`
/// yield `true`, `0`/`false`/`no`/`off` yield `false`. A missing, non-Unicode
/// or unrecognised value yields `default`.
fn read_env_flag(name: &str, default: bool) -> bool {
    let Ok(raw) = std::env::var(name) else {
        return default;
    };
    match raw.trim().to_ascii_lowercase().as_str() {
        "1" | "true" | "yes" | "on" => true,
        "0" | "false" | "no" | "off" => false,
        _ => default,
    }
}
85
/// Read environment variable `name` as a `usize`.
///
/// Returns `None` when the variable is missing, not valid Unicode, or does not
/// parse as a `usize` after trimming surrounding whitespace.
fn read_env_usize(name: &str) -> Option<usize> {
    let raw = std::env::var(name).ok()?;
    raw.trim().parse::<usize>().ok()
}
91
/// Read environment variable `name` as a strictly positive `u32`.
///
/// Returns `None` when the variable is missing, not valid Unicode, does not
/// parse as a `u32` after trimming, or parses to zero.
fn read_env_u32(name: &str) -> Option<u32> {
    let raw = std::env::var(name).ok()?;
    match raw.trim().parse::<u32>() {
        Ok(parsed) if parsed > 0 => Some(parsed),
        _ => None,
    }
}
98
/// Report whether sensitive identifiers may be written to logs unredacted.
///
/// The answer comes from the `KAPSL_LOG_SENSITIVE_IDS` environment variable
/// (`1`/`true`/`yes`/`on`, trimmed and case-insensitive, enable it; anything
/// else, or a missing variable, disables it). The variable is read once; the
/// result is cached for the lifetime of the process, so later changes to the
/// environment have no effect.
fn expose_sensitive_ids_in_logs() -> bool {
    static CACHE: OnceLock<bool> = OnceLock::new();
    *CACHE.get_or_init(|| {
        // The previous fallback re-read the very same variable name, so a
        // single lookup is equivalent.
        std::env::var("KAPSL_LOG_SENSITIVE_IDS")
            .map(|value| {
                matches!(
                    value.trim().to_ascii_lowercase().as_str(),
                    "1" | "true" | "yes" | "on"
                )
            })
            .unwrap_or(false)
    })
}
114
115fn redact_session_id_for_log(session_id: &str) -> String {
116    if expose_sensitive_ids_in_logs() || session_id.is_empty() {
117        return session_id.to_string();
118    }
119    let prefix: String = session_id.chars().take(4).collect();
120    format!("{}...[redacted]", prefix)
121}
122
123fn map_ort_dtype(dtype: TensorElementType) -> Option<TensorDtype> {
124    match dtype {
125        TensorElementType::Float32 => Some(TensorDtype::Float32),
126        TensorElementType::Float64 => Some(TensorDtype::Float64),
127        TensorElementType::Float16 => Some(TensorDtype::Float16),
128        TensorElementType::Int32 => Some(TensorDtype::Int32),
129        TensorElementType::Int64 => Some(TensorDtype::Int64),
130        TensorElementType::Uint8 => Some(TensorDtype::Uint8),
131        _ => None,
132    }
133}
134
135// TODO: Review if manual unsafe impl is necessary - Session should already be Send + Sync
136// TODO: Add documentation explaining thread safety guarantees
137// Session is Send + Sync, so OnnxBackend is Send + Sync
138unsafe impl Send for OnnxBackend {}
139unsafe impl Sync for OnnxBackend {}
140
141#[derive(Debug)]
142pub struct OnnxBackendBuilder {
143    provider: ExecutionProvider,
144    optimization_level: GraphOptimizationLevel,
145    device_id: i32,
146    memory_pattern: bool,
147    disable_cpu_mem_arena: bool,
148    max_bucket_sessions: usize,
149    bucket_dim_granularity: usize,
150    bucket_max_dims: usize,
151    peak_concurrency_hint: Option<u32>,
152}
153
154impl OnnxBackendBuilder {
155    pub fn new() -> Self {
156        let memory_pattern = read_env_flag(ORT_MEMORY_PATTERN_ENV, true);
157        let disable_cpu_mem_arena = read_env_flag(ORT_DISABLE_CPU_MEM_ARENA_ENV, false);
158        let max_bucket_sessions = read_env_usize(ORT_SESSION_BUCKETS_ENV)
159            .unwrap_or(4)
160            .clamp(1, ORT_SESSION_BUCKETS_MAX);
161        let bucket_dim_granularity = read_env_usize(ORT_BUCKET_DIM_GRANULARITY_ENV)
162            .unwrap_or(64)
163            .max(1);
164        let bucket_max_dims = read_env_usize(ORT_BUCKET_MAX_DIMS_ENV).unwrap_or(4).max(1);
165        let peak_concurrency_hint = read_env_u32(MODEL_PEAK_CONCURRENCY_ENV);
166        Self {
167            provider: ExecutionProvider::CPU,
168            optimization_level: GraphOptimizationLevel::Level3,
169            device_id: 0,
170            memory_pattern,
171            disable_cpu_mem_arena,
172            max_bucket_sessions,
173            bucket_dim_granularity,
174            bucket_max_dims,
175            peak_concurrency_hint,
176        }
177    }
178
179    pub fn with_provider(mut self, provider: ExecutionProvider) -> Self {
180        self.provider = provider;
181        self
182    }
183
184    pub fn with_optimization_level(mut self, opt_level: GraphOptimizationLevel) -> Self {
185        self.optimization_level = opt_level;
186        self
187    }
188
189    pub fn with_device_id(mut self, device_id: i32) -> Result<Self, String> {
190        if device_id < 0 {
191            return Err("Device ID must be non-negative".to_string());
192        }
193        self.device_id = device_id;
194        Ok(self)
195    }
196
197    pub fn with_memory_pattern(mut self, enabled: bool) -> Self {
198        self.memory_pattern = enabled;
199        self
200    }
201
202    pub fn with_disable_cpu_mem_arena(mut self, disabled: bool) -> Self {
203        self.disable_cpu_mem_arena = disabled;
204        self
205    }
206
207    pub fn with_max_bucket_sessions(mut self, max_bucket_sessions: usize) -> Self {
208        self.max_bucket_sessions = max_bucket_sessions.clamp(1, ORT_SESSION_BUCKETS_MAX);
209        self
210    }
211
212    pub fn with_bucket_dim_granularity(mut self, bucket_dim_granularity: usize) -> Self {
213        self.bucket_dim_granularity = bucket_dim_granularity.max(1);
214        self
215    }
216
217    pub fn with_bucket_max_dims(mut self, bucket_max_dims: usize) -> Self {
218        self.bucket_max_dims = bucket_max_dims.max(1);
219        self
220    }
221
222    pub fn with_peak_concurrency_hint(mut self, peak_concurrency_hint: u32) -> Self {
223        self.peak_concurrency_hint = Some(peak_concurrency_hint.max(1));
224        self
225    }
226
227    pub fn build(self) -> OnnxBackend {
228        let level_value = match self.optimization_level {
229            GraphOptimizationLevel::Disable => 0,
230            GraphOptimizationLevel::Level1 => 1,
231            GraphOptimizationLevel::Level2 => 2,
232            GraphOptimizationLevel::Level3 => 3,
233            GraphOptimizationLevel::All => 4,
234        };
235        OnnxBackend {
236            session: Arc::new(RwLock::new(None)),
237            bucket_sessions: Arc::new(RwLock::new(BucketSessionState::default())),
238            model_path: Arc::new(RwLock::new(None)),
239            provider: self.provider,
240            optimization_level: level_value,
241            device_id: self.device_id,
242            memory_pattern: self.memory_pattern,
243            disable_cpu_mem_arena: self.disable_cpu_mem_arena,
244            max_bucket_sessions: self.max_bucket_sessions,
245            bucket_dim_granularity: self.bucket_dim_granularity,
246            bucket_max_dims: self.bucket_max_dims,
247            peak_concurrency_hint: self.peak_concurrency_hint,
248            metrics: Arc::new(RwLock::new(EngineMetrics::default())),
249            metadata: Arc::new(RwLock::new(None)),
250            warmed_up: Arc::new(AtomicBool::new(false)),
251        }
252    }
253}
254
255impl Default for OnnxBackendBuilder {
256    fn default() -> Self {
257        Self::new()
258    }
259}
260
261impl OnnxBackend {
262    /// Convert stored u8 optimization level to GraphOptimizationLevel
263    fn get_opt_level(&self) -> GraphOptimizationLevel {
264        match self.optimization_level {
265            0 => GraphOptimizationLevel::Disable,
266            1 => GraphOptimizationLevel::Level1,
267            2 => GraphOptimizationLevel::Level2,
268            3 => GraphOptimizationLevel::Level3,
269            _ => GraphOptimizationLevel::All,
270        }
271    }
272
273    pub fn builder() -> OnnxBackendBuilder {
274        OnnxBackendBuilder::new()
275    }
276
277    pub fn new_cpu() -> Self {
278        Self::builder().build()
279    }
280
281    pub fn new_cpu_with_optimization(opt_level: GraphOptimizationLevel) -> Self {
282        Self::builder().with_optimization_level(opt_level).build()
283    }
284
285    pub fn new_cuda(device_id: i32) -> Result<Self, String> {
286        Self::new_cuda_with_optimization(GraphOptimizationLevel::Level3, device_id)
287    }
288
289    pub fn new_cuda_with_optimization(
290        opt_level: GraphOptimizationLevel,
291        device_id: i32,
292    ) -> Result<Self, String> {
293        Ok(Self::builder()
294            .with_provider(ExecutionProvider::CUDA)
295            .with_optimization_level(opt_level)
296            .with_device_id(device_id)?
297            .build())
298    }
299
300    pub fn new_tensorrt(device_id: i32) -> Result<Self, String> {
301        Self::new_tensorrt_with_optimization(GraphOptimizationLevel::Level3, device_id)
302    }
303
304    pub fn new_tensorrt_with_optimization(
305        opt_level: GraphOptimizationLevel,
306        device_id: i32,
307    ) -> Result<Self, String> {
308        Ok(Self::builder()
309            .with_provider(ExecutionProvider::TensorRT)
310            .with_optimization_level(opt_level)
311            .with_device_id(device_id)?
312            .build())
313    }
314
315    pub fn new_directml(device_id: i32) -> Result<Self, String> {
316        Self::new_directml_with_optimization(GraphOptimizationLevel::Level3, device_id)
317    }
318
319    pub fn new_directml_with_optimization(
320        opt_level: GraphOptimizationLevel,
321        device_id: i32,
322    ) -> Result<Self, String> {
323        Ok(Self::builder()
324            .with_provider(ExecutionProvider::DirectML)
325            .with_optimization_level(opt_level)
326            .with_device_id(device_id)?
327            .build())
328    }
329
330    pub fn new_rocm(device_id: i32) -> Result<Self, String> {
331        Self::new_rocm_with_optimization(GraphOptimizationLevel::Level3, device_id)
332    }
333
334    pub fn new_rocm_with_optimization(
335        opt_level: GraphOptimizationLevel,
336        device_id: i32,
337    ) -> Result<Self, String> {
338        Ok(Self::builder()
339            .with_provider(ExecutionProvider::ROCm)
340            .with_optimization_level(opt_level)
341            .with_device_id(device_id)?
342            .build())
343    }
344
345    pub fn new_openvino_with_optimiation(
346        opt_level: GraphOptimizationLevel,
347        device_id: i32,
348    ) -> Result<Self, String> {
349        Ok(Self::builder()
350            .with_provider(ExecutionProvider::OpenVINO)
351            .with_optimization_level(opt_level)
352            .with_device_id(device_id)?
353            .build())
354    }
355
356    pub fn new_coreml_with_optimiation(
357        opt_level: GraphOptimizationLevel,
358        device_id: i32,
359    ) -> Result<Self, String> {
360        Ok(Self::builder()
361            .with_provider(ExecutionProvider::CoreML)
362            .with_optimization_level(opt_level)
363            .with_device_id(device_id)?
364            .build())
365    }
366
367    fn bucket_key_for_request(&self, request: &InferenceRequest) -> Option<String> {
368        if self.max_bucket_sessions <= 1 {
369            return None;
370        }
371
372        let mut key = format!(
373            "{}:r{}",
374            request.input.dtype.as_str(),
375            request.input.shape.len()
376        );
377        for (index, dim) in request
378            .input
379            .shape
380            .iter()
381            .take(self.bucket_max_dims)
382            .copied()
383            .enumerate()
384        {
385            let rounded = if dim <= 0 {
386                -1
387            } else if index == 0 {
388                dim
389            } else {
390                let granularity = self.bucket_dim_granularity as i64;
391                ((dim + granularity - 1) / granularity) * granularity
392            };
393            key.push(':');
394            key.push_str(&rounded.to_string());
395        }
396        if request.input.shape.len() > self.bucket_max_dims {
397            key.push_str(":*");
398        }
399        Some(key)
400    }
401
402    fn touch_bucket_lru(state: &mut BucketSessionState, bucket_key: &str) {
403        if let Some(pos) = state.lru.iter().position(|existing| existing == bucket_key) {
404            state.lru.remove(pos);
405        }
406        state.lru.push_back(bucket_key.to_string());
407    }
408
409    fn get_or_create_bucket_session<'a>(
410        &self,
411        state: &'a mut BucketSessionState,
412        bucket_key: &str,
413    ) -> Result<&'a mut Session, EngineError> {
414        if !state.sessions.contains_key(bucket_key) {
415            let secondary_capacity = self.max_bucket_sessions.saturating_sub(1).max(1);
416            while state.sessions.len() >= secondary_capacity {
417                let Some(evict_key) = state.lru.pop_front() else {
418                    break;
419                };
420                state.sessions.remove(&evict_key);
421            }
422
423            let model_path = self
424                .model_path
425                .read()
426                .map_err(|_| EngineError::Backend {
427                    message: "Lock poisoned".to_string(),
428                    source: None,
429                })?
430                .clone()
431                .ok_or(EngineError::ModelNotLoaded)?;
432            let session = self.create_session(&model_path, self.get_opt_level())?;
433            state.sessions.insert(bucket_key.to_string(), session);
434        }
435
436        Self::touch_bucket_lru(state, bucket_key);
437        state
438            .sessions
439            .get_mut(bucket_key)
440            .ok_or(EngineError::ModelNotLoaded)
441    }
442
443    fn create_session(
444        &self,
445        model_path: &Path,
446        opt_level: GraphOptimizationLevel,
447    ) -> Result<Session, EngineError> {
448        // Common builder setup
449        let mut builder = Session::builder()
450            .map_err(|e| EngineError::ModelLoadError {
451                path: model_path.to_string_lossy().into_owned(),
452                source: Box::new(std::io::Error::other(e.to_string())),
453            })?
454            .with_optimization_level(opt_level)
455            .map_err(|e| EngineError::ModelLoadError {
456                path: model_path.to_string_lossy().into_owned(),
457                source: Box::new(std::io::Error::other(e.to_string())),
458            })?
459            .with_memory_pattern(self.memory_pattern)
460            .map_err(|e| EngineError::ModelLoadError {
461                path: model_path.to_string_lossy().into_owned(),
462                source: Box::new(std::io::Error::other(e.to_string())),
463            })?;
464        if self.disable_cpu_mem_arena {
465            builder = builder
466                .with_config_entry("session.disable_cpu_mem_arena", "1")
467                .map_err(|e| EngineError::ModelLoadError {
468                    path: model_path.to_string_lossy().into_owned(),
469                    source: Box::new(std::io::Error::other(e.to_string())),
470                })?;
471        }
472
473        // Configure execution providers based on the selected backend
474        let builder = match self.provider {
475            ExecutionProvider::CUDA => {
476                if !ort::execution_providers::CUDAExecutionProvider::default()
477                    .is_available()
478                    .unwrap_or(false)
479                {
480                    return Err(EngineError::Backend {
481                        message:
482                            "CUDA execution provider is not available. Please check your CUDA installation."
483                                .to_string(),
484                        source: None,
485                    });
486                }
487                builder
488                    .with_execution_providers([
489                        ort::execution_providers::CUDAExecutionProvider::default()
490                            .with_device_id(self.device_id)
491                            .build(),
492                    ])
493                    .map_err(|e| EngineError::ModelLoadError {
494                        path: model_path.to_string_lossy().into_owned(),
495                        source: Box::new(std::io::Error::other(e.to_string())),
496                    })?
497            }
498            ExecutionProvider::TensorRT => {
499                if !ort::execution_providers::TensorRTExecutionProvider::default()
500                    .is_available()
501                    .unwrap_or(false)
502                {
503                    return Err(EngineError::Backend {
504                        message: "TensorRT execution provider is not available.".to_string(),
505                        source: None,
506                    });
507                }
508                builder
509                    .with_execution_providers([
510                        ort::execution_providers::TensorRTExecutionProvider::default()
511                            .with_device_id(self.device_id)
512                            .build(),
513                        // Fallback to CUDA if TensorRT has issues with some nodes
514                        ort::execution_providers::CUDAExecutionProvider::default()
515                            .with_device_id(self.device_id)
516                            .build(),
517                    ])
518                    .map_err(|e| EngineError::ModelLoadError {
519                        path: model_path.to_string_lossy().into_owned(),
520                        source: Box::new(std::io::Error::other(e.to_string())),
521                    })?
522            }
523            ExecutionProvider::DirectML => {
524                #[cfg(target_os = "windows")]
525                {
526                    if !ort::execution_providers::DirectMLExecutionProvider::default()
527                        .is_available()
528                        .unwrap_or(false)
529                    {
530                        return Err(EngineError::Backend {
531                            message: "DirectML execution provider is not available.".to_string(),
532                            source: None,
533                        });
534                    }
535                    builder
536                        .with_execution_providers([
537                            ort::execution_providers::DirectMLExecutionProvider::default()
538                                .with_device_id(self.device_id)
539                                .build(),
540                        ])
541                        .map_err(|e| EngineError::ModelLoadError {
542                            path: model_path.to_string_lossy().into_owned(),
543                            source: Box::new(std::io::Error::other(e.to_string())),
544                        })?
545                }
546                #[cfg(not(target_os = "windows"))]
547                {
548                    return Err(EngineError::Backend {
549                        message: "DirectML execution provider is only supported on Windows."
550                            .to_string(),
551                        source: None,
552                    });
553                }
554            }
555            ExecutionProvider::ROCm => {
556                if !ort::execution_providers::ROCmExecutionProvider::default()
557                    .is_available()
558                    .unwrap_or(false)
559                {
560                    return Err(EngineError::Backend {
561                        message: "ROCm execution provider is not available.".to_string(),
562                        source: None,
563                    });
564                }
565                builder
566                    .with_execution_providers([
567                        ort::execution_providers::ROCmExecutionProvider::default()
568                            .with_device_id(self.device_id)
569                            .build(),
570                    ])
571                    .map_err(|e| EngineError::ModelLoadError {
572                        path: model_path.to_string_lossy().into_owned(),
573                        source: Box::new(std::io::Error::other(e.to_string())),
574                    })?
575            }
576            ExecutionProvider::OpenVINO => {
577                if !ort::execution_providers::OpenVINOExecutionProvider::default()
578                    .is_available()
579                    .unwrap_or(false)
580                {
581                    return Err(EngineError::Backend {
582                        message: "OpenVINO execution provider is not available.".to_string(),
583                        source: None,
584                    });
585                }
586                builder
587                    .with_execution_providers([
588                        ort::execution_providers::OpenVINOExecutionProvider::default().build(),
589                    ])
590                    .map_err(|e| EngineError::ModelLoadError {
591                        path: model_path.to_string_lossy().into_owned(),
592                        source: Box::new(std::io::Error::other(e.to_string())),
593                    })?
594            }
595            ExecutionProvider::CoreML => {
596                if !ort::execution_providers::CoreMLExecutionProvider::default()
597                    .is_available()
598                    .unwrap_or(false)
599                {
600                    return Err(EngineError::Backend {
601                        message: "CoreML execution provider is not available.".to_string(),
602                        source: None,
603                    });
604                }
605                builder
606                    .with_execution_providers([
607                        ort::execution_providers::CoreMLExecutionProvider::default().build(),
608                        // Keep CPU registered as a fallback for nodes CoreML cannot handle. Without
609                        // this, ONNXRuntime will hard-fail on models that partially compile/run
610                        // under CoreML (ex: plan-building failures for certain ops).
611                        ort::execution_providers::CPUExecutionProvider::default().build(),
612                    ])
613                    .map_err(|e| EngineError::ModelLoadError {
614                        path: model_path.to_string_lossy().into_owned(),
615                        source: Box::new(std::io::Error::other(e.to_string())),
616                    })?
617            }
618            ExecutionProvider::CPU => {
619                // CPU is always available, nothing special to add
620                builder
621            }
622        };
623
624        // Finalize session from file
625        builder
626            .commit_from_file(model_path)
627            .map_err(|e| EngineError::ModelLoadError {
628                path: model_path.to_string_lossy().into_owned(),
629                source: Box::new(std::io::Error::other(e.to_string())),
630            })
631    }
632}
633
634/// PreparedInput is returned by the input validation helper. It contains the
635/// typed vector of elements parsed from raw bytes.
636///
637/// Note: kept at module scope so tests can access it.
638#[derive(Debug)]
639enum PreparedInput {
640    F32(Vec<f32>),
641    F64(Vec<f64>),
642    F16(Vec<f16>),
643    I32(Vec<i32>),
644    I64(Vec<i64>),
645    U8(Vec<u8>),
646}
647
648fn validate_byte_len(
649    input: &BinaryTensorPacket,
650    num_elements: usize,
651    elem_size: usize,
652    dtype_label: &str,
653) -> Result<(), EngineError> {
654    let expected =
655        num_elements
656            .checked_mul(elem_size)
657            .ok_or_else(|| EngineError::InvalidInput {
658                message: "Data size overflow".to_string(),
659                source: None,
660            })?;
661    if input.data.len() != expected {
662        return Err(EngineError::InvalidInput {
663            message: format!(
664                "Data length mismatch: expected {} bytes ({} {} values) but got {} bytes",
665                expected,
666                num_elements,
667                dtype_label,
668                input.data.len()
669            ),
670            source: None,
671        });
672    }
673    Ok(())
674}
675
676fn parse_ne_f32(bytes: &[u8], num_elements: usize) -> Vec<f32> {
677    if let Some(values) = try_aligned_copy::<f32>(bytes) {
678        return values;
679    }
680    let mut values = Vec::with_capacity(num_elements);
681    for chunk in bytes.chunks_exact(4) {
682        values.push(f32::from_ne_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]));
683    }
684    values
685}
686
687fn parse_ne_f64(bytes: &[u8], num_elements: usize) -> Vec<f64> {
688    if let Some(values) = try_aligned_copy::<f64>(bytes) {
689        return values;
690    }
691    let mut values = Vec::with_capacity(num_elements);
692    for chunk in bytes.chunks_exact(8) {
693        values.push(f64::from_ne_bytes([
694            chunk[0], chunk[1], chunk[2], chunk[3], chunk[4], chunk[5], chunk[6], chunk[7],
695        ]));
696    }
697    values
698}
699
700fn parse_ne_f16(bytes: &[u8], num_elements: usize) -> Vec<f16> {
701    if let Some(values) = try_aligned_copy::<u16>(bytes) {
702        let mut out = Vec::with_capacity(num_elements);
703        out.extend(values.into_iter().map(f16::from_bits));
704        return out;
705    }
706    let mut values = Vec::with_capacity(num_elements);
707    for chunk in bytes.chunks_exact(2) {
708        values.push(f16::from_bits(u16::from_ne_bytes([chunk[0], chunk[1]])));
709    }
710    values
711}
712
713fn parse_ne_i32(bytes: &[u8], num_elements: usize) -> Vec<i32> {
714    if let Some(values) = try_aligned_copy::<i32>(bytes) {
715        return values;
716    }
717    let mut values = Vec::with_capacity(num_elements);
718    for chunk in bytes.chunks_exact(4) {
719        values.push(i32::from_ne_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]));
720    }
721    values
722}
723
724fn parse_ne_i64(bytes: &[u8], num_elements: usize) -> Vec<i64> {
725    if let Some(values) = try_aligned_copy::<i64>(bytes) {
726        return values;
727    }
728    let mut values = Vec::with_capacity(num_elements);
729    for chunk in bytes.chunks_exact(8) {
730        values.push(i64::from_ne_bytes([
731            chunk[0], chunk[1], chunk[2], chunk[3], chunk[4], chunk[5], chunk[6], chunk[7],
732        ]));
733    }
734    values
735}
736
737fn find_additional_input_by_name<'a>(
738    additional_inputs: &'a [kapsl_engine_api::NamedTensor],
739    name: &str,
740) -> Option<&'a BinaryTensorPacket> {
741    additional_inputs
742        .iter()
743        .find(|entry| entry.name == name)
744        .map(|entry| &entry.tensor)
745}
746
747fn ensure_unique_additional_input_names(
748    additional_inputs: &[kapsl_engine_api::NamedTensor],
749) -> Result<(), EngineError> {
750    for i in 0..additional_inputs.len() {
751        for j in (i + 1)..additional_inputs.len() {
752            if additional_inputs[i].name == additional_inputs[j].name {
753                return Err(EngineError::InvalidInput {
754                    message: format!(
755                        "Duplicate additional input name: {}",
756                        additional_inputs[i].name
757                    ),
758                    source: None,
759                });
760            }
761        }
762    }
763    Ok(())
764}
765
/// Reinterpret `bytes` as a slice of `T` and copy it into a `Vec<T>`.
///
/// Returns `None` unless the whole buffer can be viewed as `T` values, i.e.
/// `align_to` leaves neither an unaligned prefix nor a leftover suffix.
/// Callers are expected to fall back to element-wise decoding in that case.
fn try_aligned_copy<T: Copy>(bytes: &[u8]) -> Option<Vec<T>> {
    // SAFETY: The concrete call sites only use plain numeric POD types, for
    // which every bit pattern is a valid value. The `T: Copy` bound alone does
    // not guarantee this, so do not call this with other types.
    let (head, body, tail) = unsafe { bytes.align_to::<T>() };
    let fully_covered = head.is_empty() && tail.is_empty();
    fully_covered.then(|| body.to_vec())
}
775
776/// Validate an incoming BinaryTensorPacket and return the shape and a typed
777/// vector if valid. This performs:
778///  - dtype support checking
779///  - shape -> element count computation
780///  - buffer length validation (must equal element_count * dtype_size)
781///  - safe byte->value conversion
782fn validate_and_prepare_input(
783    input: &kapsl_engine_api::BinaryTensorPacket,
784) -> Result<(Vec<i64>, PreparedInput), EngineError> {
785    // Compute number of elements from shape; treat empty shape as scalar (1 element)
786    let num_elements: usize = if input.shape.is_empty() {
787        1
788    } else {
789        // If any dimension is <= 0, reject as invalid shape
790        let mut prod: usize = 1;
791        for &d in &input.shape {
792            if d <= 0 {
793                return Err(EngineError::InvalidInput {
794                    message: format!("Invalid shape dimension: {}", d),
795                    source: None,
796                });
797            }
798            prod = prod
799                .checked_mul(d as usize)
800                .ok_or_else(|| EngineError::InvalidInput {
801                    message: "Shape multiplication overflow".to_string(),
802                    source: None,
803                })?;
804        }
805        prod
806    };
807
808    // Determine element size and branch by dtype
809    match input.dtype {
810        TensorDtype::Float32 => {
811            validate_byte_len(input, num_elements, 4, "float32")?;
812            let values = parse_ne_f32(&input.data, num_elements);
813            Ok((input.shape.clone(), PreparedInput::F32(values)))
814        }
815        TensorDtype::Float64 => {
816            validate_byte_len(input, num_elements, 8, "float64")?;
817            let values = parse_ne_f64(&input.data, num_elements);
818            Ok((input.shape.clone(), PreparedInput::F64(values)))
819        }
820        TensorDtype::Float16 => {
821            validate_byte_len(input, num_elements, 2, "float16")?;
822            let values = parse_ne_f16(&input.data, num_elements);
823            Ok((input.shape.clone(), PreparedInput::F16(values)))
824        }
825        TensorDtype::Int32 => {
826            validate_byte_len(input, num_elements, 4, "int32")?;
827            let values = parse_ne_i32(&input.data, num_elements);
828            Ok((input.shape.clone(), PreparedInput::I32(values)))
829        }
830        TensorDtype::Int64 => {
831            validate_byte_len(input, num_elements, 8, "int64")?;
832            let values = parse_ne_i64(&input.data, num_elements);
833            Ok((input.shape.clone(), PreparedInput::I64(values)))
834        }
835        TensorDtype::Uint8 => {
836            validate_byte_len(input, num_elements, 1, "uint8")?;
837            Ok((input.shape.clone(), PreparedInput::U8(input.data.clone())))
838        }
839        other => Err(EngineError::InvalidInput {
840            message: format!(
841                "Unsupported dtype: {}. Supported: float32, float64, float16, int32, int64, uint8",
842                other
843            ),
844            source: None,
845        }),
846    }
847}
848
849fn tensor_packet_to_session_input(
850    input: &BinaryTensorPacket,
851) -> Result<(Vec<usize>, SessionInputValue<'_>), EngineError> {
852    let (shape_i64, prepared) = validate_and_prepare_input(input)?;
853    let shape_usize = get_shape_usize(&shape_i64);
854
855    let value: SessionInputValue = match prepared {
856        PreparedInput::F32(v) => Value::from_array((shape_usize.clone(), v)).map(|v| v.into()),
857        PreparedInput::F64(v) => Value::from_array((shape_usize.clone(), v)).map(|v| v.into()),
858        PreparedInput::F16(v) => Value::from_array((shape_usize.clone(), v)).map(|v| v.into()),
859        PreparedInput::I32(v) => Value::from_array((shape_usize.clone(), v)).map(|v| v.into()),
860        PreparedInput::I64(v) => Value::from_array((shape_usize.clone(), v)).map(|v| v.into()),
861        PreparedInput::U8(v) => Value::from_array((shape_usize.clone(), v)).map(|v| v.into()),
862    }
863    .map_err(|e| EngineError::InferenceError {
864        reason: "Failed to create input tensor".to_string(),
865        source: Some(Box::new(e)),
866    })?;
867
868    Ok((shape_usize, value))
869}
870
871fn run_inference_with_session(
872    session: &mut Session,
873    request: &InferenceRequest,
874    metadata: &ModelMetadata,
875    shape_usize: Vec<usize>,
876    main_input_tensor: SessionInputValue<'_>,
877) -> Result<BinaryTensorPacket, EngineError> {
878    // We assume the first input maps to the provided input packet
879    if metadata.input_names.is_empty() {
880        return Err(EngineError::InferenceError {
881            reason: "Model has no inputs defined".to_string(),
882            source: None,
883        });
884    }
885
886    let outputs = if metadata.input_names.len() == 1 && request.additional_inputs.is_empty() {
887        session.run([main_input_tensor]).map_err(|e| {
888            log::error!("ONNX Runtime inference error: {:?}", e);
889            EngineError::InferenceError {
890                reason: format!("Inference failed: {}", e),
891                source: Some(Box::new(e)),
892            }
893        })?
894    } else {
895        // Prepare named input map only when required by multi-input models.
896        let mut inputs: Vec<(Cow<'_, str>, SessionInputValue)> =
897            Vec::with_capacity(metadata.input_names.len());
898        inputs.push((
899            Cow::Borrowed(metadata.input_names[0].as_str()),
900            main_input_tensor,
901        ));
902        ensure_unique_additional_input_names(&request.additional_inputs)?;
903
904        // Get batch size from main input (assume dim 0 is batch)
905        let batch_size = if !shape_usize.is_empty() {
906            shape_usize[0]
907        } else {
908            1
909        };
910        let seq_len = if shape_usize.len() > 1 {
911            shape_usize[1]
912        } else {
913            1
914        };
915
916        // Fill other inputs
917        // Workaround: ORT Value::from_array rejects 0-dimension (e.g. past_seq_len=0).
918        // Try to use ndarray which might handle it better, or maybe it was just a vector check.
919        // We set workaround length to 0 (correct behavior).
920        let workaround_past_len = 0;
921
922        for (i, name) in metadata.input_names.iter().enumerate().skip(1) {
923            let shape_def = &metadata.input_shapes[i];
924            if let Some(named_input) =
925                find_additional_input_by_name(&request.additional_inputs, name)
926            {
927                let (_, value) = tensor_packet_to_session_input(named_input)?;
928                inputs.push((Cow::Borrowed(name.as_str()), value));
929                continue;
930            }
931
932            if name.contains("attention_mask") {
933                // Attention mask: [batch, total_seq_len] -> 1s for active, 0 for past workaround
934                // Total len = seq_len + workaround_past_len
935                let total_len = seq_len + workaround_past_len;
936                // Attention mask: [batch, total_seq_len] -> 1s for active
937                let mask_shape = vec![batch_size as i64, total_len as i64];
938
939                let mut mask_data = Vec::with_capacity(batch_size * total_len);
940                for _ in 0..batch_size {
941                    // Mask out the dummy past
942                    mask_data.extend(std::iter::repeat_n(0i64, workaround_past_len));
943                    // Active sequence
944                    mask_data.extend(std::iter::repeat_n(1i64, seq_len));
945                }
946
947                log::debug!(
948                    "Creating attention_mask tensor for {} with shape {:?}",
949                    name,
950                    mask_shape
951                );
952
953                let mask_tensor = Value::from_array((get_shape_usize(&mask_shape), mask_data))
954                    .map_err(|e| EngineError::InferenceError {
955                        reason: format!("Failed to create attention_mask tensor for {}", name),
956                        source: Some(Box::new(e)),
957                    })?;
958                inputs.push((Cow::Borrowed(name.as_str()), mask_tensor.into()));
959            } else if name.contains("position_ids") {
960                let pos_shape = vec![batch_size as i64, seq_len as i64];
961                let mut pos_data = Vec::with_capacity(batch_size * seq_len);
962                // Position IDs should likely start after past? Or 0?
963                // If we have dummy past at 0..1, then real tokens start at 0?
964                // Actually if past is masked, it's effectively like it doesn't exist.
965                // So we keep pos ids 0-based.
966                for _ in 0..batch_size {
967                    for s in 0..seq_len {
968                        pos_data.push(s as i64);
969                    }
970                }
971                let pos_tensor = Value::from_array((get_shape_usize(&pos_shape), pos_data))
972                    .map_err(|e| EngineError::InferenceError {
973                        reason: format!("Failed to create position_ids tensor for {}", name),
974                        source: Some(Box::new(e)),
975                    })?;
976                inputs.push((Cow::Borrowed(name.as_str()), pos_tensor.into()));
977            } else if name.starts_with("past_key_values") {
978                let mut new_shape = Vec::new();
979                new_shape.push(batch_size); // dim 0
980
981                if shape_def.len() == 4 {
982                    let dim1 = if shape_def[1] > 0 {
983                        shape_def[1] as usize
984                    } else {
985                        1
986                    };
987                    new_shape.push(dim1);
988
989                    new_shape.push(workaround_past_len); // dim 2: 0
990
991                    let dim3 = if shape_def[3] > 0 {
992                        shape_def[3] as usize
993                    } else {
994                        64
995                    };
996                    new_shape.push(dim3);
997
998                    log::debug!("Creating KV tensor for {} with shape {:?}", name, new_shape);
999
1000                    let count: usize = new_shape.iter().product();
1001                    let empty_data: Vec<f16> = vec![f16::ZERO; count];
1002
1003                    // Use ndarray to construct possibly 0-dim tensor
1004                    let kv_array = ArrayD::from_shape_vec(new_shape, empty_data).map_err(|e| {
1005                        EngineError::InferenceError {
1006                            reason: format!("Failed to create ndarray for {}: {:?}", name, e),
1007                            source: Some(Box::new(e)),
1008                        }
1009                    })?;
1010
1011                    let kv_tensor =
1012                        Value::from_array(kv_array).map_err(|e| EngineError::InferenceError {
1013                            reason: format!(
1014                                "Failed to create empty KV tensor for {}: {:?}",
1015                                name, e
1016                            ),
1017                            source: Some(Box::new(e)),
1018                        })?;
1019                    inputs.push((Cow::Borrowed(name.as_str()), kv_tensor.into()));
1020                } else {
1021                    log::warn!(
1022                        "Skipping input {} due to unknown shape pattern {:?}",
1023                        name,
1024                        shape_def
1025                    );
1026                }
1027            } else {
1028                log::warn!(
1029                    "Skipping input {} as it is not recognized as auto-fillable",
1030                    name
1031                );
1032            }
1033        }
1034
1035        session.run(inputs).map_err(|e| {
1036            log::error!("ONNX Runtime inference error: {:?}", e);
1037            EngineError::InferenceError {
1038                reason: format!("Inference failed: {}", e),
1039                source: Some(Box::new(e)),
1040            }
1041        })?
1042    };
1043
1044    // For LLMs, we often get multiple outputs (logits + KV cache).
1045    // We currently only ignore the KV cache return values and just use the first output (logits).
1046    if outputs.len() > 1 {
1047        log::debug!(
1048            "Backend received {} outputs, using only the first one (logits)",
1049            outputs.len()
1050        );
1051    }
1052
1053    // Handle output - try f32 first, otherwise return an error.
1054    let output_value = &outputs[0];
1055    let output_packet = if let Ok((shape_ref, data)) = output_value.try_extract_tensor::<f32>() {
1056        BinaryTensorPacket {
1057            shape: shape_ref.to_vec(),
1058            dtype: TensorDtype::Float32,
1059            data: data.iter().flat_map(|&x| x.to_ne_bytes()).collect(),
1060        }
1061    } else if let Ok((shape_ref, data)) = output_value.try_extract_tensor::<f64>() {
1062        BinaryTensorPacket {
1063            shape: shape_ref.to_vec(),
1064            dtype: TensorDtype::Float64,
1065            data: data.iter().flat_map(|&x| x.to_ne_bytes()).collect(),
1066        }
1067    } else if let Ok((shape_ref, data)) = output_value.try_extract_tensor::<f16>() {
1068        BinaryTensorPacket {
1069            shape: shape_ref.to_vec(),
1070            dtype: TensorDtype::Float16,
1071            data: data
1072                .iter()
1073                .flat_map(|x| x.to_bits().to_ne_bytes())
1074                .collect(),
1075        }
1076    } else if let Ok((shape_ref, data)) = output_value.try_extract_tensor::<i32>() {
1077        BinaryTensorPacket {
1078            shape: shape_ref.to_vec(),
1079            dtype: TensorDtype::Int32,
1080            data: data.iter().flat_map(|&x| x.to_ne_bytes()).collect(),
1081        }
1082    } else if let Ok((shape_ref, data)) = output_value.try_extract_tensor::<i64>() {
1083        BinaryTensorPacket {
1084            shape: shape_ref.to_vec(),
1085            dtype: TensorDtype::Int64,
1086            data: data.iter().flat_map(|&x| x.to_ne_bytes()).collect(),
1087        }
1088    } else if let Ok((shape_ref, data)) = output_value.try_extract_tensor::<u8>() {
1089        BinaryTensorPacket {
1090            shape: shape_ref.to_vec(),
1091            dtype: TensorDtype::Uint8,
1092            data: data.to_vec(),
1093        }
1094    } else {
1095        return Err(EngineError::InferenceError {
1096            reason: "Failed to extract output tensor. Supported output dtypes: float32, float64, float16, int32, int64, uint8"
1097                .to_string(),
1098            source: None,
1099        });
1100    };
1101
1102    Ok(output_packet)
1103}
1104
1105#[async_trait]
1106impl Engine for OnnxBackend {
1107    // TODO: Extract session creation to a helper method to reduce duplication
1108    // TODO: Add better error messages with execution provider fallback information
1109    // TODO: Add capability checking before attempting to use hardware accelerators
1110    // TODO: Consider async loading for large models
1111    // TODO: Add progress callback for large model loads
1112    async fn load(&mut self, model_path: &Path) -> Result<(), EngineError> {
1113        let opt_level = self.get_opt_level();
1114        log::info!(
1115            "Loading ONNX model with optimization level: {:?} on provider {:?}",
1116            opt_level,
1117            self.provider
1118        );
1119        log::info!(
1120            "ORT memory config: memory_pattern={} disable_cpu_mem_arena={} session_buckets={} bucket_dim_granularity={} bucket_max_dims={} peak_concurrency_hint={:?}",
1121            self.memory_pattern,
1122            self.disable_cpu_mem_arena,
1123            self.max_bucket_sessions,
1124            self.bucket_dim_granularity,
1125            self.bucket_max_dims,
1126            self.peak_concurrency_hint
1127        );
1128
1129        let session = self.create_session(model_path, opt_level)?;
1130
1131        log::info!("Model Inputs:");
1132        for (i, input) in session.inputs().iter().enumerate() {
1133            log::info!("  Input {}: {} ({:?})", i, input.name(), input.dtype());
1134        }
1135
1136        // Extract metadata
1137        let input_names: Vec<String> = session
1138            .inputs()
1139            .iter()
1140            .map(|i| i.name().to_string())
1141            .collect();
1142        let output_names: Vec<String> = session
1143            .outputs()
1144            .iter()
1145            .map(|o| o.name().to_string())
1146            .collect();
1147
1148        let mut input_shapes = Vec::new();
1149        let mut input_dtypes = Vec::new();
1150        for input in session.inputs() {
1151            let (shape, dtype) = match input.dtype() {
1152                ort::value::ValueType::Tensor { ty, shape, .. } => {
1153                    (shape.iter().copied().collect(), map_ort_dtype(*ty))
1154                }
1155                _ => (vec![], None),
1156            };
1157            input_shapes.push(shape);
1158            input_dtypes.push(dtype);
1159        }
1160
1161        let mut output_shapes = Vec::new();
1162        let mut output_dtypes = Vec::new();
1163        for output in session.outputs() {
1164            let (shape, dtype) = match output.dtype() {
1165                ort::value::ValueType::Tensor { ty, shape, .. } => {
1166                    (shape.iter().copied().collect(), map_ort_dtype(*ty))
1167                }
1168                _ => (vec![], None),
1169            };
1170            output_shapes.push(shape);
1171            output_dtypes.push(dtype);
1172        }
1173
1174        let metadata = ModelMetadata {
1175            input_names,
1176            output_names,
1177            input_shapes,
1178            output_shapes,
1179            input_dtypes,
1180            output_dtypes,
1181        };
1182
1183        // Store metadata
1184        if let Ok(mut meta_guard) = self.metadata.write() {
1185            *meta_guard = Some(metadata);
1186        }
1187
1188        let mut session_guard = self.session.write().map_err(|_| EngineError::Backend {
1189            message: "Lock poisoned".to_string(),
1190            source: None,
1191        })?;
1192        *session_guard = Some(session);
1193        drop(session_guard);
1194
1195        if let Ok(mut model_path_guard) = self.model_path.write() {
1196            *model_path_guard = Some(model_path.to_path_buf());
1197        }
1198        if let Ok(mut bucket_guard) = self.bucket_sessions.write() {
1199            bucket_guard.primary_bucket_key = None;
1200            bucket_guard.sessions.clear();
1201            bucket_guard.lru.clear();
1202        }
1203
1204        // Reset warmup state
1205        self.warmed_up.store(false, Ordering::SeqCst);
1206
1207        Ok(())
1208    }
1209
1210    fn infer(&self, request: &InferenceRequest) -> Result<BinaryTensorPacket, EngineError> {
1211        let start_time = Instant::now();
1212        if let Some(session_id) = &request.session_id {
1213            log::debug!(
1214                "Processing request for session: {}",
1215                redact_session_id_for_log(session_id)
1216            );
1217        }
1218
1219        let metadata = self.metadata.read().map_err(|_| EngineError::Backend {
1220            message: "Lock poisoned".to_string(),
1221            source: None,
1222        })?;
1223        let metadata = metadata
1224            .as_ref()
1225            .cloned()
1226            .ok_or(EngineError::ModelNotLoaded)?;
1227
1228        // Validate and convert primary input.
1229        let (shape_usize, main_input_tensor) = tensor_packet_to_session_input(&request.input)?;
1230        let mut prepared_input = Some((shape_usize, main_input_tensor));
1231
1232        let output_packet = if let Some(bucket_key) = self.bucket_key_for_request(request) {
1233            let use_primary = {
1234                let mut bucket_guard =
1235                    self.bucket_sessions
1236                        .write()
1237                        .map_err(|_| EngineError::Backend {
1238                            message: "Lock poisoned".to_string(),
1239                            source: None,
1240                        })?;
1241                let primary_key = bucket_guard
1242                    .primary_bucket_key
1243                    .get_or_insert_with(|| bucket_key.clone());
1244                *primary_key == bucket_key
1245            };
1246
1247            if use_primary {
1248                let mut session_guard = self.session.write().map_err(|_| EngineError::Backend {
1249                    message: "Lock poisoned".to_string(),
1250                    source: None,
1251                })?;
1252                let session = session_guard.as_mut().ok_or(EngineError::ModelNotLoaded)?;
1253                let (shape_usize, main_input_tensor) = prepared_input
1254                    .take()
1255                    .ok_or_else(|| EngineError::backend("input already consumed".to_string()))?;
1256                run_inference_with_session(
1257                    session,
1258                    request,
1259                    &metadata,
1260                    shape_usize,
1261                    main_input_tensor,
1262                )?
1263            } else {
1264                let mut bucket_guard =
1265                    self.bucket_sessions
1266                        .write()
1267                        .map_err(|_| EngineError::Backend {
1268                            message: "Lock poisoned".to_string(),
1269                            source: None,
1270                        })?;
1271                let session = self.get_or_create_bucket_session(&mut bucket_guard, &bucket_key)?;
1272                let (shape_usize, main_input_tensor) = prepared_input
1273                    .take()
1274                    .ok_or_else(|| EngineError::backend("input already consumed".to_string()))?;
1275                run_inference_with_session(
1276                    session,
1277                    request,
1278                    &metadata,
1279                    shape_usize,
1280                    main_input_tensor,
1281                )?
1282            }
1283        } else {
1284            let mut session_guard = self.session.write().map_err(|_| EngineError::Backend {
1285                message: "Lock poisoned".to_string(),
1286                source: None,
1287            })?;
1288            let session = session_guard.as_mut().ok_or(EngineError::ModelNotLoaded)?;
1289            let (shape_usize, main_input_tensor) = prepared_input
1290                .take()
1291                .ok_or_else(|| EngineError::backend("input already consumed".to_string()))?;
1292            run_inference_with_session(session, request, &metadata, shape_usize, main_input_tensor)?
1293        };
1294
1295        // Update metrics
1296        let duration = start_time.elapsed().as_secs_f64();
1297        if let Ok(mut metrics) = self.metrics.write() {
1298            metrics.inference_time = duration;
1299            // We can't easily get exact memory usage per inference from ONNX Runtime easily here without allocator hooks,
1300            // so we leave it as is or update if we had a way.
1301        }
1302
1303        // Mark as warmed up
1304        self.warmed_up.store(true, Ordering::SeqCst);
1305
1306        Ok(output_packet)
1307    }
1308
1309    // TODO: Implement proper streaming for LLM models with token-by-token generation
1310    // TODO: This is a placeholder - real streaming should yield tokens as they're generated
1311    fn infer_stream(
1312        &self,
1313        request: &InferenceRequest,
1314    ) -> std::pin::Pin<
1315        Box<dyn futures::stream::Stream<Item = Result<BinaryTensorPacket, EngineError>> + Send>,
1316    > {
1317        // Call infer immediately to avoid lifetime issues
1318        let result = self.infer(request);
1319        // Wrap the single result in a stream using futures::stream::once
1320        Box::pin(futures::stream::once(async move { result }))
1321    }
1322
1323    // TODO: Add proper cleanup (free GPU memory, release resources)
1324    // TODO: Log unload operations for debugging
1325    fn unload(&mut self) {
1326        if let Ok(mut session_guard) = self.session.write() {
1327            *session_guard = None;
1328        }
1329        if let Ok(mut bucket_guard) = self.bucket_sessions.write() {
1330            bucket_guard.primary_bucket_key = None;
1331            bucket_guard.sessions.clear();
1332            bucket_guard.lru.clear();
1333        }
1334        if let Ok(mut model_path_guard) = self.model_path.write() {
1335            *model_path_guard = None;
1336        }
1337        if let Ok(mut meta_guard) = self.metadata.write() {
1338            *meta_guard = None;
1339        }
1340    }
1341
1342    fn metrics(&self) -> kapsl_engine_api::EngineMetrics {
1343        if let Ok(metrics) = self.metrics.read() {
1344            metrics.clone()
1345        } else {
1346            kapsl_engine_api::EngineMetrics::default()
1347        }
1348    }
1349
1350    fn health_check(&self) -> Result<(), EngineError> {
1351        // Check if session is loaded and lock is not poisoned
1352        let session_guard = self.session.read().map_err(|_| EngineError::Backend {
1353            message: "Session lock poisoned".to_string(),
1354            source: None,
1355        })?;
1356        let bucket_guard = self
1357            .bucket_sessions
1358            .read()
1359            .map_err(|_| EngineError::Backend {
1360                message: "Session cache lock poisoned".to_string(),
1361                source: None,
1362            })?;
1363
1364        if session_guard.is_some() || !bucket_guard.sessions.is_empty() {
1365            Ok(())
1366        } else {
1367            Err(EngineError::ModelNotLoaded)
1368        }
1369    }
1370
1371    fn model_info(&self) -> Option<EngineModelInfo> {
1372        let metadata_guard = self.metadata.read().ok()?;
1373        let metadata = metadata_guard.as_ref()?;
1374        Some(EngineModelInfo {
1375            input_names: metadata.input_names.clone(),
1376            output_names: metadata.output_names.clone(),
1377            input_shapes: metadata.input_shapes.clone(),
1378            output_shapes: metadata.output_shapes.clone(),
1379            input_dtypes: metadata
1380                .input_dtypes
1381                .iter()
1382                .map(|dtype| {
1383                    dtype
1384                        .as_ref()
1385                        .map(TensorDtype::as_str)
1386                        .unwrap_or("unknown")
1387                        .to_string()
1388                })
1389                .collect(),
1390            output_dtypes: metadata
1391                .output_dtypes
1392                .iter()
1393                .map(|dtype| {
1394                    dtype
1395                        .as_ref()
1396                        .map(TensorDtype::as_str)
1397                        .unwrap_or("unknown")
1398                        .to_string()
1399                })
1400                .collect(),
1401            framework: Some("onnx".to_string()),
1402            model_version: None,
1403            peak_concurrency: self.peak_concurrency_hint,
1404        })
1405    }
1406}
1407
/// Convert `i64` tensor dimensions into `usize` dimensions.
///
/// Each value is converted with a plain `as` cast, so callers are expected to
/// pass dimensions that were already validated as non-negative.
fn get_shape_usize(shape: &[i64]) -> Vec<usize> {
    let mut dims = Vec::with_capacity(shape.len());
    for &dim in shape {
        dims.push(dim as usize);
    }
    dims
}
1411
1412#[path = "onnx_tests.rs"]
1413mod onnx_tests;