use super::execution_provider::ExecutionProviderKind;
use super::session::{ONNXSession, SessionOptions};
use crate::ir::Envelope;
use crate::runtime_adapter::tensor_utils::{envelope_to_tensors, tensors_to_envelope};
use crate::runtime_adapter::{
AdapterError, AdapterResult, ModelMetadata, RuntimeAdapter, RuntimeAdapterExt,
};
use std::collections::HashMap;
use std::path::Path;
/// Runtime adapter that runs ONNX models through [`ONNXSession`].
///
/// Models are keyed by an id derived from the file stem of the path they were
/// loaded from (e.g. `models/resnet.onnx` is stored under `"resnet"`), so two
/// files with the same stem in different directories map to the same id.
pub struct OnnxRuntimeAdapter {
    /// Metadata for each loaded model, keyed by model id.
    models: HashMap<String, ModelMetadata>,
    /// Inference sessions, keyed by the same model id as `models`.
    sessions: HashMap<String, ONNXSession>,
    /// Id of the model targeted by `execute` and `warmup`; set by `load_model`
    /// and cleared by `unload_model` when that model is unloaded.
    current_model: Option<String>,
    /// Execution provider passed to `ONNXSession::build` when a model is loaded;
    /// changing it does not rebuild sessions that already exist.
    execution_provider: ExecutionProviderKind,
}
impl OnnxRuntimeAdapter {
pub fn new() -> Self {
Self {
models: HashMap::new(),
sessions: HashMap::new(),
current_model: None,
execution_provider: ExecutionProviderKind::Cpu,
}
}
pub fn with_execution_provider(execution_provider: ExecutionProviderKind) -> Self {
Self {
models: HashMap::new(),
sessions: HashMap::new(),
current_model: None,
execution_provider,
}
}
pub fn execution_provider(&self) -> &ExecutionProviderKind {
&self.execution_provider
}
pub fn set_execution_provider(&mut self, provider: ExecutionProviderKind) {
self.execution_provider = provider;
}
pub fn with_auto_selection(hints: &super::execution_provider::ModelHints) -> Self {
let provider = super::execution_provider::select_optimal_provider(hints);
Self::with_execution_provider(provider)
}
#[deprecated(
since = "0.0.24",
note = "Use with_auto_selection() with ModelHints instead"
)]
#[allow(dead_code)]
pub fn select_optimal_provider() -> ExecutionProviderKind {
#[cfg(all(feature = "ort-coreml", any(target_os = "macos", target_os = "ios")))]
{
use super::execution_provider::CoreMLConfig;
ExecutionProviderKind::CoreML(CoreMLConfig::with_neural_engine())
}
#[cfg(not(all(feature = "ort-coreml", any(target_os = "macos", target_os = "ios"))))]
{
ExecutionProviderKind::Cpu
}
}
fn validate_model_file(&self, model_path: &str) -> AdapterResult<()> {
let path = Path::new(model_path);
if !path.exists() {
return Err(AdapterError::ModelNotFound(format!(
"Model file not found: {}",
model_path
)));
}
if !path.is_file() {
return Err(AdapterError::InvalidInput(format!(
"Path is not a file: {}",
model_path
)));
}
if let Some(ext) = path.extension() {
if ext != "onnx" && ext != "ONNX" {
}
}
Ok(())
}
fn extract_model_id(&self, path: &str) -> String {
Path::new(path)
.file_stem()
.and_then(|s| s.to_str())
.unwrap_or("unknown")
.to_string()
}
fn real_inference(&self, session: &ONNXSession, input: &Envelope) -> AdapterResult<Envelope> {
let input_shapes: Vec<Vec<i64>> = session.input_shapes().to_vec();
let input_names: Vec<String> = session.input_names().to_vec();
let input_tensors = envelope_to_tensors(input, &input_shapes, &input_names)?;
let output_tensors = session.run(input_tensors).map_err(|e| {
AdapterError::InferenceFailed(format!("ONNX Runtime inference failed: {e}"))
})?;
if let Some(resolved) = session.resolved_providers() {
crate::tracing::add_metadata("execution_provider", resolved.primary);
}
eprintln!("🔵 DEBUG: Raw ONNX Output Tensors");
eprintln!(" Number of outputs: {}", output_tensors.len());
for (name, tensor) in &output_tensors {
eprintln!(
" Output '{}': shape {:?}, size {}",
name,
tensor.shape(),
tensor.len()
);
}
let output_names: Vec<String> = session.output_names().to_vec();
let output = tensors_to_envelope(&output_tensors, &output_names)?;
Ok(output)
}
pub fn get_session(&self, model_path: &str) -> AdapterResult<&ONNXSession> {
let model_id = self.extract_model_id(model_path);
self.sessions.get(&model_id).ok_or_else(|| {
AdapterError::ModelNotLoaded(format!(
"Session for model '{}' (path: {}) not found",
model_id, model_path
))
})
}
}
impl Default for OnnxRuntimeAdapter {
fn default() -> Self {
Self::new()
}
}
impl RuntimeAdapter for OnnxRuntimeAdapter {
    fn name(&self) -> &str {
        "onnx"
    }

    fn supported_formats(&self) -> Vec<&'static str> {
        vec!["onnx", "onnx.gz"]
    }

    /// Runs one throw-away inference on the current model.
    ///
    /// Every input is fed a zero-filled `f32` tensor whose negative (dynamic)
    /// dimensions are replaced by 1. Warmup is best-effort: if the dummy run fails
    /// (e.g. a model that expects non-`f32` inputs), the failure is logged and
    /// `Ok(())` is still returned.
    ///
    /// # Errors
    /// [`AdapterError::ModelNotLoaded`] if no model is current or its session is missing.
    fn warmup(&mut self) -> AdapterResult<()> {
        use ndarray::{ArrayD, IxDyn};
        let model_id = self.current_model.as_ref().ok_or_else(|| {
            AdapterError::ModelNotLoaded("No model loaded. Call load_model() first.".to_string())
        })?;
        let session = self.sessions.get(model_id).ok_or_else(|| {
            AdapterError::ModelNotLoaded(format!("Session for model '{}' not found", model_id))
        })?;
        let dummy_inputs: HashMap<String, ArrayD<f32>> = session
            .input_names()
            .iter()
            .zip(session.input_shapes().iter())
            .map(|(name, shape)| {
                let resolved_shape: Vec<usize> = shape
                    .iter()
                    .map(|&d| if d < 0 { 1 } else { d as usize })
                    .collect();
                (name.clone(), ArrayD::<f32>::zeros(IxDyn(&resolved_shape)))
            })
            .collect();
        match session.run(dummy_inputs) {
            Ok(_) => log::info!(
                "Warmed up model '{}' with {} provider",
                model_id,
                self.execution_provider
            ),
            Err(e) => log::warn!(
                "Warmup run for model '{}' with {} provider failed (continuing): {e}",
                model_id,
                self.execution_provider
            ),
        }
        Ok(())
    }

    /// Loads the ONNX model at `path` and makes it the current model.
    ///
    /// The model id is the file stem of `path`. If a model with that id is
    /// already loaded, no new session is built; the existing model becomes current.
    ///
    /// # Errors
    /// - Errors from `validate_model_file` (missing path, path is not a file).
    /// - Errors from `ONNXSession::build`.
    fn load_model(&mut self, path: &str) -> AdapterResult<()> {
        self.validate_model_file(path)?;
        let model_id = self.extract_model_id(path);
        if let Some(existing) = self.models.get(&model_id) {
            if existing.model_path != path {
                log::warn!(
                    "Model id '{}' for '{}' collides with already-loaded '{}'; keeping the existing model",
                    model_id,
                    path,
                    existing.model_path
                );
            }
            log::warn!("Model '{}' is already loaded, skipping reload", model_id);
            // Still make it current, so a following `execute`/`warmup` targets the
            // model the caller just asked for rather than whichever was loaded last.
            self.current_model = Some(model_id);
            return Ok(());
        }
        let session = ONNXSession::build(
            path,
            self.execution_provider,
            SessionOptions {
                capture_resolved_ep: true,
            },
        )?;
        log::info!(
            "Loaded model '{}' with {} execution provider",
            model_id,
            self.execution_provider
        );
        // NOTE(review): negative (dynamic) dimensions are converted with `as u64`,
        // which wraps them (e.g. -1 becomes u64::MAX). Confirm that consumers of
        // `ModelMetadata` expect this encoding.
        let input_schema: HashMap<_, _> = session
            .input_names()
            .iter()
            .zip(session.input_shapes().iter())
            .map(|(name, shape)| (name.clone(), shape.iter().map(|&s| s as u64).collect()))
            .collect();
        let output_schema: HashMap<_, _> = session
            .output_names()
            .iter()
            .zip(session.output_shapes().iter())
            .map(|(name, shape)| (name.clone(), shape.iter().map(|&s| s as u64).collect()))
            .collect();
        let metadata = ModelMetadata {
            model_id: model_id.clone(),
            version: "1.0.0".to_string(),
            runtime_type: "onnx".to_string(),
            model_path: path.to_string(),
            input_schema,
            output_schema,
        };
        self.sessions.insert(model_id.clone(), session);
        self.models.insert(model_id.clone(), metadata);
        self.current_model = Some(model_id);
        Ok(())
    }

    /// Runs inference on the current model (see `load_model`).
    ///
    /// # Errors
    /// [`AdapterError::ModelNotLoaded`] if no model is current or its session is
    /// missing; otherwise, errors from the inference pipeline.
    fn execute(&self, input: &Envelope) -> AdapterResult<Envelope> {
        let model_id = self.current_model.as_ref().ok_or_else(|| {
            AdapterError::ModelNotLoaded("No model loaded. Call load_model() first.".to_string())
        })?;
        let session = self.sessions.get(model_id).ok_or_else(|| {
            AdapterError::ModelNotLoaded(format!("Session for model '{}' not found", model_id))
        })?;
        self.real_inference(session, input)
    }
}
impl RuntimeAdapterExt for OnnxRuntimeAdapter {
    /// Returns `true` if metadata for `model_id` is registered.
    fn is_loaded(&self, model_id: &str) -> bool {
        self.models.contains_key(model_id)
    }

    /// Returns the metadata recorded for `model_id` when it was loaded.
    ///
    /// # Errors
    /// [`AdapterError::ModelNotLoaded`] if `model_id` is not loaded.
    fn get_metadata(&self, model_id: &str) -> AdapterResult<&ModelMetadata> {
        self.models.get(model_id).ok_or_else(|| {
            AdapterError::ModelNotLoaded(format!("Model '{}' is not loaded", model_id))
        })
    }

    /// Runs inference on a specific loaded model, regardless of which model is current.
    ///
    /// # Errors
    /// [`AdapterError::ModelNotLoaded`] if `model_id` is not loaded or has no
    /// session; otherwise, errors from the inference pipeline.
    fn infer(&self, model_id: &str, input: &Envelope) -> AdapterResult<Envelope> {
        if !self.is_loaded(model_id) {
            return Err(AdapterError::ModelNotLoaded(format!(
                "Model '{}' is not loaded. Call load_model() first.",
                model_id
            )));
        }
        let session = self.sessions.get(model_id).ok_or_else(|| {
            AdapterError::ModelNotLoaded(format!("Session for model '{}' not found", model_id))
        })?;
        self.real_inference(session, input)
    }

    /// Drops the session and metadata for `model_id`. If it was the current
    /// model, no model is current afterwards.
    ///
    /// # Errors
    /// [`AdapterError::ModelNotLoaded`] if `model_id` is not loaded.
    fn unload_model(&mut self, model_id: &str) -> AdapterResult<()> {
        if !self.models.contains_key(model_id) {
            return Err(AdapterError::ModelNotLoaded(format!(
                "Model '{}' is not loaded",
                model_id
            )));
        }
        self.sessions.remove(model_id);
        self.models.remove(model_id);
        // Compare as `&str` rather than allocating a `String` just for the check.
        if self.current_model.as_deref() == Some(model_id) {
            self.current_model = None;
        }
        Ok(())
    }

    /// Returns the ids of all loaded models, in no particular order.
    fn list_loaded_models(&self) -> Vec<String> {
        self.models.keys().cloned().collect()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::EnvelopeKind;
    use std::fs;
    use tempfile::TempDir;

    /// Writes a non-ONNX placeholder file named `test_model.onnx` into a fresh
    /// temp dir. The `TempDir` must be kept alive for the file to exist.
    fn create_mock_onnx_file() -> (TempDir, String) {
        let temp_dir = TempDir::new().unwrap();
        let model_path = temp_dir.path().join("test_model.onnx");
        fs::write(&model_path, b"fake onnx model data").unwrap();
        (temp_dir, model_path.to_string_lossy().to_string())
    }

    #[test]
    fn test_create_adapter() {
        let adapter = OnnxRuntimeAdapter::new();
        assert!(adapter.list_loaded_models().is_empty());
        assert_eq!(*adapter.execution_provider(), ExecutionProviderKind::Cpu);
    }

    #[test]
    fn test_default_matches_new() {
        let adapter = OnnxRuntimeAdapter::default();
        assert!(adapter.list_loaded_models().is_empty());
        assert_eq!(*adapter.execution_provider(), ExecutionProviderKind::Cpu);
    }

    #[test]
    fn test_adapter_with_execution_provider() {
        let adapter = OnnxRuntimeAdapter::with_execution_provider(ExecutionProviderKind::Cpu);
        assert_eq!(*adapter.execution_provider(), ExecutionProviderKind::Cpu);
    }

    #[test]
    fn test_adapter_name() {
        let adapter = OnnxRuntimeAdapter::new();
        assert_eq!(adapter.name(), "onnx");
    }

    #[test]
    fn test_supported_formats() {
        let adapter = OnnxRuntimeAdapter::new();
        let formats = adapter.supported_formats();
        assert!(formats.contains(&"onnx"));
        assert!(formats.contains(&"onnx.gz"));
    }

    #[test]
    fn test_validate_model_file_accepts_existing_file() {
        let (_temp_dir, model_path) = create_mock_onnx_file();
        let adapter = OnnxRuntimeAdapter::new();
        assert!(adapter.validate_model_file(&model_path).is_ok());
    }

    #[test]
    fn test_validate_model_file_rejects_directory() {
        let temp_dir = TempDir::new().unwrap();
        let adapter = OnnxRuntimeAdapter::new();
        let result = adapter.validate_model_file(&temp_dir.path().to_string_lossy());
        assert!(matches!(result, Err(AdapterError::InvalidInput(_))));
    }

    #[test]
    fn test_extract_model_id_uses_file_stem() {
        let adapter = OnnxRuntimeAdapter::new();
        assert_eq!(adapter.extract_model_id("/models/resnet50.onnx"), "resnet50");
        assert_eq!(adapter.extract_model_id(""), "unknown");
    }

    #[test]
    fn test_load_model_not_found() {
        let mut adapter = OnnxRuntimeAdapter::new();
        let result = adapter.load_model("/nonexistent/model.onnx");
        assert!(matches!(result, Err(AdapterError::ModelNotFound(_))));
    }

    #[test]
    fn test_execute_no_model_loaded() {
        let adapter = OnnxRuntimeAdapter::new();
        let input = Envelope::new(EnvelopeKind::Text("test".to_string()));
        let result = adapter.execute(&input);
        assert!(matches!(result, Err(AdapterError::ModelNotLoaded(_))));
    }

    #[test]
    fn test_infer_model_not_loaded() {
        let adapter = OnnxRuntimeAdapter::new();
        let input = Envelope::new(EnvelopeKind::Text("test".to_string()));
        let result = adapter.infer("nonexistent-model", &input);
        assert!(matches!(result, Err(AdapterError::ModelNotLoaded(_))));
    }

    #[test]
    fn test_unload_model_not_loaded() {
        let mut adapter = OnnxRuntimeAdapter::new();
        let result = adapter.unload_model("nonexistent-model");
        assert!(matches!(result, Err(AdapterError::ModelNotLoaded(_))));
    }

    #[test]
    fn test_get_session_not_loaded() {
        let adapter = OnnxRuntimeAdapter::new();
        let result = adapter.get_session("/models/absent.onnx");
        assert!(matches!(result, Err(AdapterError::ModelNotLoaded(_))));
    }

    #[test]
    fn test_get_metadata_not_loaded() {
        let adapter = OnnxRuntimeAdapter::new();
        assert!(!adapter.is_loaded("absent"));
        assert!(matches!(
            adapter.get_metadata("absent"),
            Err(AdapterError::ModelNotLoaded(_))
        ));
    }

    #[test]
    #[allow(deprecated)]
    fn test_select_optimal_provider() {
        let provider = OnnxRuntimeAdapter::select_optimal_provider();
        #[cfg(all(feature = "ort-coreml", any(target_os = "macos", target_os = "ios")))]
        assert!(matches!(provider, ExecutionProviderKind::CoreML(_)));
        #[cfg(not(all(feature = "ort-coreml", any(target_os = "macos", target_os = "ios"))))]
        assert_eq!(provider, ExecutionProviderKind::Cpu);
    }

    #[test]
    fn test_warmup_no_model_loaded() {
        let mut adapter = OnnxRuntimeAdapter::new();
        let result = adapter.warmup();
        assert!(matches!(result, Err(AdapterError::ModelNotLoaded(_))));
    }
}