use super::execution_provider::ExecutionProviderKind;
use super::profiling::{parse_profile_json, ResolvedExecutionProviders};
use crate::runtime_adapter::{AdapterError, AdapterResult};
use ndarray::{ArrayD, IxDyn};
use ort::session::{builder::GraphOptimizationLevel, Session};
use ort::tensor::TensorElementType;
use ort::value::Value;
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::sync::Mutex;
use tempfile::TempDir;
/// Metadata gathered by `ONNXSession::extract_input_metadata`: input names, declared
/// shapes, and element types. The three vectors are index-aligned (entry `i` of each
/// describes input `i`).
type InputMetadata = (Vec<String>, Vec<Vec<i64>>, Vec<Option<TensorElementType>>);
/// Options controlling how `ONNXSession::build` configures a session.
///
/// The struct is `#[non_exhaustive]`, so code outside this crate cannot construct it with a
/// struct literal. Start from [`SessionOptions::default`] and adjust it with the `with_*`
/// methods (or by assigning the public fields).
#[non_exhaustive]
#[derive(Debug, Default, Clone, Copy)]
pub struct SessionOptions {
    /// Enable ORT profiling so the execution providers ORT actually assigned can be
    /// reported by `ONNXSession::resolved_providers` after the first successful inference.
    /// Defaults to `false`.
    pub capture_resolved_ep: bool,
}

impl SessionOptions {
    /// Returns these options with `capture_resolved_ep` set to `enabled`.
    ///
    /// This is the only way to turn capture on with a single expression from outside the
    /// crate, because `#[non_exhaustive]` forbids external struct literals.
    #[must_use]
    pub fn with_capture_resolved_ep(mut self, enabled: bool) -> Self {
        self.capture_resolved_ep = enabled;
        self
    }
}
/// Lifecycle of the optional "which execution provider actually ran" capture.
///
/// `Disabled` never changes. `Pending` is replaced exactly once, by `Harvested` or
/// `Failed`, after the first successful inference (see
/// `ONNXSession::maybe_harvest_resolved_ep`).
enum ResolvedEpState {
/// Capture was not requested (`SessionOptions::capture_resolved_ep` was `false`).
Disabled,
/// Profiling is enabled and waiting for the first successful inference.
Pending {
/// Directory that receives ORT's profile output. It is held here so it is not deleted
/// until the state is replaced after harvesting.
_tempdir: TempDir,
},
/// The profile was parsed; this is the per-EP summary.
Harvested(ResolvedExecutionProviders),
/// Ending profiling or parsing the profile failed; this is the stringified error.
Failed(String),
}
/// A loaded ONNX Runtime session plus the I/O metadata captured when it was built.
///
/// Created with `ONNXSession::build`. Inference locks a `Mutex` around the ORT session
/// for the whole call, so concurrent inferences on one instance are serialized.
pub struct ONNXSession {
/// The ORT session, locked for the duration of each inference.
session: Mutex<Session>,
/// Input names in model order (or a single `"input"` placeholder if the model reports none).
input_names: Vec<String>,
/// Output names in model order (or a single `"output"` placeholder if the model reports none).
output_names: Vec<String>,
/// Input shapes, index-aligned with `input_names`.
input_shapes: Vec<Vec<i64>>,
/// Output shapes, index-aligned with `output_names`.
output_shapes: Vec<Vec<i64>>,
/// Input element types, index-aligned with `input_names`. `None` when ORT reports no tensor
/// element type, and for the placeholder input.
input_dtypes: Vec<Option<TensorElementType>>,
/// The execution provider requested at build time. What ORT actually used is tracked
/// separately in `resolved_state`.
execution_provider: ExecutionProviderKind,
/// Resolved-EP capture state.
/// NOTE(review): declared after `session`, so the session is dropped before a still-pending
/// profiling tempdir is deleted. Keep this order unless that is confirmed irrelevant for ORT.
resolved_state: Mutex<ResolvedEpState>,
}
impl ONNXSession {
    /// Loads the ONNX model at `model_path` and prepares it for inference on
    /// `execution_provider`.
    ///
    /// When `options.capture_resolved_ep` is set, ORT profiling is written into a private
    /// temporary directory. The execution providers ORT actually assigned can then be
    /// harvested after the first successful inference (see `resolved_providers`).
    ///
    /// # Errors
    ///
    /// * [`AdapterError::ModelNotFound`] if `model_path` does not exist.
    /// * [`AdapterError::RuntimeError`] if any of these fail:
    ///   - creating or configuring the session builder;
    ///   - creating the profiling tempdir;
    ///   - configuring the execution provider;
    ///   - loading the model.
    pub fn build(
        model_path: &str,
        execution_provider: ExecutionProviderKind,
        options: SessionOptions,
    ) -> AdapterResult<Self> {
        let path = Path::new(model_path);
        if !path.exists() {
            return Err(AdapterError::ModelNotFound(format!(
                "Model file not found: {}",
                model_path
            )));
        }
        // The commit result is deliberately discarded.
        // NOTE(review): presumably so that repeated builds in one process do not fail once
        // the global ORT environment exists. Confirm against the ort version in use.
        let _ = ort::init().commit();
        let mut builder = Session::builder()
            .map_err(|e| {
                AdapterError::RuntimeError(format!("Failed to create session builder: {}", e))
            })?
            .with_optimization_level(GraphOptimizationLevel::Level3)
            .map_err(|e| {
                AdapterError::RuntimeError(format!("Failed to set optimization level: {}", e))
            })?;
        let resolved_state = if options.capture_resolved_ep {
            let tempdir = tempfile::tempdir().map_err(|e| {
                AdapterError::RuntimeError(format!(
                    "Failed to create profile tempdir for resolved-EP capture: {}",
                    e
                ))
            })?;
            // Used as ORT's profiling file prefix. The tempdir itself moves into `Pending`
            // so it outlives this function until the profile has been harvested.
            let profile_prefix: PathBuf = tempdir.path().join("xybrid-profile");
            builder = builder.with_profiling(&profile_prefix).map_err(|e| {
                AdapterError::RuntimeError(format!(
                    "Failed to enable profiling for resolved-EP capture: {}",
                    e
                ))
            })?;
            ResolvedEpState::Pending { _tempdir: tempdir }
        } else {
            ResolvedEpState::Disabled
        };
        builder = Self::configure_execution_provider(builder, &execution_provider)?;
        let session = builder
            .commit_from_file(model_path)
            .map_err(|e| AdapterError::RuntimeError(format!("Failed to load ONNX model: {}", e)))?;
        let (input_names, input_shapes, input_dtypes) = Self::extract_input_metadata(&session)?;
        let (output_names, output_shapes) = Self::extract_output_metadata(&session)?;
        log::info!(
            "Created ONNX session with {} execution provider for model: {} (capture_resolved_ep={})",
            execution_provider,
            model_path,
            options.capture_resolved_ep,
        );
        Ok(Self {
            session: Mutex::new(session),
            input_names,
            output_names,
            input_shapes,
            output_shapes,
            input_dtypes,
            execution_provider,
            resolved_state: Mutex::new(resolved_state),
        })
    }

    /// Registers the requested execution provider on `builder`.
    ///
    /// `Cpu` returns the builder unchanged, so ORT's defaults apply. The `CoreML` arm exists
    /// only with the `ort-coreml` feature. It maps this crate's compute-unit choice onto
    /// ORT's CoreML options.
    ///
    /// # Errors
    ///
    /// [`AdapterError::RuntimeError`] if ORT rejects the execution-provider registration.
    fn configure_execution_provider(
        builder: ort::session::builder::SessionBuilder,
        provider: &ExecutionProviderKind,
    ) -> AdapterResult<ort::session::builder::SessionBuilder> {
        match provider {
            ExecutionProviderKind::Cpu => Ok(builder),
            #[cfg(feature = "ort-coreml")]
            ExecutionProviderKind::CoreML(config) => {
                use super::execution_provider::CoreMLComputeUnits;
                use ort::ep;
                let coreml_ep = {
                    let mut coreml = ep::CoreML::default();
                    coreml = coreml.with_subgraphs(config.use_subgraphs);
                    coreml = coreml.with_compute_units(match config.compute_units {
                        CoreMLComputeUnits::CpuOnly => ep::coreml::ComputeUnits::CPUOnly,
                        CoreMLComputeUnits::CpuAndGpu => ep::coreml::ComputeUnits::CPUAndGPU,
                        CoreMLComputeUnits::CpuAndNeuralEngine => {
                            ep::coreml::ComputeUnits::CPUAndNeuralEngine
                        }
                        CoreMLComputeUnits::All => ep::coreml::ComputeUnits::All,
                    });
                    coreml.build()
                };
                log::debug!("Configuring CoreML execution provider: {:?}", config);
                builder.with_execution_providers([coreml_ep]).map_err(|e| {
                    AdapterError::RuntimeError(format!(
                        "Failed to configure CoreML execution provider: {}",
                        e
                    ))
                })
            }
        }
    }

    /// Collects input names, shapes, and element types from the loaded session.
    ///
    /// Inputs without a tensor shape are recorded as `[-1]`. If the model reports no
    /// inputs at all, a single placeholder input is synthesized: name `"input"`, shape
    /// `[1, 1, 16000]`, dtype `None`.
    /// NOTE(review): the placeholder shape looks like one second of 16 kHz mono audio.
    /// Confirm the intent with callers.
    fn extract_input_metadata(session: &Session) -> AdapterResult<InputMetadata> {
        let mut input_names = Vec::new();
        let mut input_shapes = Vec::new();
        let mut input_dtypes = Vec::new();
        for input in session.inputs() {
            input_names.push(input.name().to_string());
            if let Some(shape) = input.dtype().tensor_shape() {
                input_shapes.push(shape.iter().copied().collect());
            } else {
                input_shapes.push(vec![-1]);
            }
            input_dtypes.push(input.dtype().tensor_type());
        }
        if input_names.is_empty() {
            input_names.push("input".to_string());
            input_shapes.push(vec![1, 1, 16000]);
            input_dtypes.push(None);
        }
        Ok((input_names, input_shapes, input_dtypes))
    }

    /// Collects output names and shapes from the loaded session.
    ///
    /// This uses the same rule as `extract_input_metadata`: the declared tensor shape when
    /// ORT reports one, otherwise the `[-1]` marker. If the model reports no outputs, a
    /// single placeholder output is synthesized: name `"output"`, shape `[1, 512]`.
    /// NOTE(review): the placeholder shape is a magic value. Confirm the intent with callers.
    fn extract_output_metadata(session: &Session) -> AdapterResult<(Vec<String>, Vec<Vec<i64>>)> {
        let mut output_names = Vec::new();
        let mut output_shapes = Vec::new();
        for output in session.outputs() {
            output_names.push(output.name().to_string());
            if let Some(shape) = output.dtype().tensor_shape() {
                output_shapes.push(shape.iter().copied().collect());
            } else {
                output_shapes.push(vec![-1]);
            }
        }
        if output_names.is_empty() {
            output_names.push("output".to_string());
            output_shapes.push(vec![1, 512]);
        }
        Ok((output_names, output_shapes))
    }

    /// Runs inference on `f32` inputs keyed by input name.
    ///
    /// Each array is wrapped in an ORT tensor and forwarded to `run_with_values`, which
    /// handles output conversion.
    ///
    /// # Errors
    ///
    /// [`AdapterError::RuntimeError`] if an array cannot be converted to an ORT tensor.
    /// It also returns every error `run_with_values` can return.
    pub fn run(
        &self,
        inputs: HashMap<String, ArrayD<f32>>,
    ) -> AdapterResult<HashMap<String, ArrayD<f32>>> {
        let value_inputs: HashMap<String, Value> = inputs
            .into_iter()
            .map(|(k, v)| {
                Ok((
                    k,
                    Value::from_array(v)
                        .map_err(|e| {
                            AdapterError::RuntimeError(format!("Failed to convert tensor: {}", e))
                        })?
                        .into(),
                ))
            })
            .collect::<AdapterResult<_>>()?;
        self.run_with_values(value_inputs)
    }

    /// Runs inference on pre-built ORT values and returns every model output as `f32`.
    ///
    /// The session lock is held for the whole call. Each output is read as `f32` or, failing
    /// that, as `i64` widened to `f32`; any other element type is an error. Once all outputs
    /// have been converted, a capture-enabled session also harvests its resolved execution
    /// providers on the first such call.
    ///
    /// # Errors
    ///
    /// * [`AdapterError::RuntimeError`] if the session mutex is poisoned, an output has an
    ///   unsupported element type, or an output cannot be reshaped.
    /// * [`AdapterError::InferenceFailed`] if ORT's `run` fails.
    ///
    /// # Panics
    ///
    /// Panics if ORT's result lacks one of `output_names`. `Index` has no error path, and
    /// this is only plausible when the `"output"` placeholder was synthesized at build time.
    pub fn run_with_values(
        &self,
        inputs: HashMap<String, Value>,
    ) -> AdapterResult<HashMap<String, ArrayD<f32>>> {
        use ort::session::SessionInputs;
        let mut session_guard = self
            .session
            .lock()
            .map_err(|e| AdapterError::RuntimeError(format!("Failed to lock session: {}", e)))?;
        let ort_inputs: Vec<(
            std::borrow::Cow<'_, str>,
            ort::session::SessionInputValue<'_>,
        )> = inputs
            .into_iter()
            .map(|(name, value)| (std::borrow::Cow::Owned(name), value.into()))
            .collect();
        let outputs = session_guard
            .run(SessionInputs::from(ort_inputs))
            .map_err(|e| {
                AdapterError::InferenceFailed(format!("ONNX Runtime inference failed: {}", e))
            })?;
        let mut result = HashMap::with_capacity(self.output_names.len());
        for output_name in &self.output_names {
            let output_value = &outputs[output_name.as_str()];
            let array_d = if let Ok(view) = output_value.try_extract_array::<f32>() {
                Self::output_to_array(view.shape(), view.iter().copied())?
            } else if let Ok(view) = output_value.try_extract_array::<i64>() {
                // Integer outputs (e.g. token or class ids) are widened to f32. Magnitudes
                // above 2^24 are not exactly representable and lose precision.
                Self::output_to_array(view.shape(), view.iter().map(|&x| x as f32))?
            } else {
                return Err(AdapterError::RuntimeError(format!(
                    "Failed to extract output '{}': unsupported type (expected f32 or i64)",
                    output_name
                )));
            };
            result.insert(output_name.clone(), array_d);
        }
        // `outputs` may borrow the session, so release it before the harvester takes
        // `&mut Session`.
        drop(outputs);
        self.maybe_harvest_resolved_ep(&mut session_guard);
        Ok(result)
    }

    /// Copies an extracted output view into an owned, row-major `ArrayD<f32>`.
    ///
    /// `values` must yield elements in logical (row-major) order. ndarray's `iter()` does
    /// that regardless of the view's memory layout, which is why it is used instead of
    /// `to_owned().as_slice().unwrap()`. That approach copies twice and panics for
    /// non-standard layouts.
    fn output_to_array(
        shape: &[usize],
        values: impl Iterator<Item = f32>,
    ) -> AdapterResult<ArrayD<f32>> {
        ArrayD::from_shape_vec(IxDyn(shape), values.collect()).map_err(|e| {
            AdapterError::RuntimeError(format!("Failed to convert output to ArrayD: {}", e))
        })
    }

    /// Model input names, in the order recorded at build time.
    pub fn input_names(&self) -> &[String] {
        &self.input_names
    }

    /// Model output names, in the order recorded at build time.
    pub fn output_names(&self) -> &[String] {
        &self.output_names
    }

    /// Input shapes, index-aligned with [`Self::input_names`].
    pub fn input_shapes(&self) -> &[Vec<i64>] {
        &self.input_shapes
    }

    /// Input element types, index-aligned with [`Self::input_names`].
    pub fn input_dtypes(&self) -> &[Option<TensorElementType>] {
        &self.input_dtypes
    }

    /// Output shapes, index-aligned with [`Self::output_names`].
    pub fn output_shapes(&self) -> &[Vec<i64>] {
        &self.output_shapes
    }

    /// The execution provider requested at build time.
    pub fn execution_provider(&self) -> &ExecutionProviderKind {
        &self.execution_provider
    }

    /// The execution providers ORT actually used, if they were captured.
    ///
    /// Returns `None` in any of these cases:
    /// * capture was disabled;
    /// * no successful inference has happened yet;
    /// * harvesting failed;
    /// * the state mutex is poisoned.
    pub fn resolved_providers(&self) -> Option<ResolvedExecutionProviders> {
        let state = self.resolved_state.lock().ok()?;
        match &*state {
            ResolvedEpState::Harvested(summary) => Some(summary.clone()),
            _ => None,
        }
    }

    /// Human-readable rendering of the capture state, intended for test diagnostics.
    #[doc(hidden)]
    pub fn resolved_state_debug(&self) -> String {
        match self.resolved_state.lock() {
            Ok(state) => match &*state {
                ResolvedEpState::Disabled => "Disabled".into(),
                ResolvedEpState::Pending { .. } => "Pending".into(),
                ResolvedEpState::Harvested(s) => format!("Harvested({s:?})"),
                ResolvedEpState::Failed(e) => format!("Failed({e})"),
            },
            Err(e) => format!("MutexPoisoned({e})"),
        }
    }

    /// Harvests the resolved execution providers after a successful inference.
    ///
    /// This does work only while the state is `Pending`. It ends ORT profiling, parses the
    /// profile, and stores `Harvested` or `Failed`. Replacing `Pending` drops the profiling
    /// tempdir. The operation is best-effort: failures are logged and recorded, never
    /// propagated to the inference caller.
    fn maybe_harvest_resolved_ep(&self, session_guard: &mut Session) {
        let mut state = match self.resolved_state.lock() {
            Ok(g) => g,
            Err(e) => {
                log::warn!("resolved-EP state mutex poisoned: {e}");
                return;
            }
        };
        if !matches!(*state, ResolvedEpState::Pending { .. }) {
            return;
        }
        let next = match session_guard.end_profiling() {
            Ok(profile_path) => match parse_profile_json(Path::new(&profile_path)) {
                Ok(summary) => {
                    log::debug!(
                        "Resolved EP for ONNX session: primary={}, breakdown={:?}",
                        summary.primary,
                        summary.breakdown
                    );
                    ResolvedEpState::Harvested(summary)
                }
                Err(parse_err) => {
                    log::warn!("Failed to parse ONNX profile {profile_path}: {parse_err}");
                    ResolvedEpState::Failed(parse_err.to_string())
                }
            },
            Err(end_err) => {
                log::warn!("Failed to end ONNX profiling: {end_err}");
                ResolvedEpState::Failed(end_err.to_string())
            }
        };
        *state = next;
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use std::path::{Path, PathBuf};
    use tempfile::TempDir;

    /// Locations probed for the optional MNIST fixture. Each is relative to a working
    /// directory the test harness may be launched from.
    const MNIST_CANDIDATES: [&str; 3] = [
        "test_models/mnist-12.onnx",
        "../test_models/mnist-12.onnx",
        "../../test_models/mnist-12.onnx",
    ];

    /// Returns the first MNIST fixture path that exists. Otherwise it prints the probed
    /// paths and returns `None`, so the calling test can skip itself.
    fn find_mnist_model() -> Option<PathBuf> {
        let found = MNIST_CANDIDATES
            .iter()
            .map(PathBuf::from)
            .find(|p| p.exists());
        if found.is_none() {
            println!(
                "MNIST model not found, skipping test. Tried: {:?}",
                MNIST_CANDIDATES
            );
        }
        found
    }

    /// Builds a CPU session for the MNIST fixture, panicking with `context` on failure.
    fn build_mnist(model_path: &Path, options: SessionOptions, context: &str) -> ONNXSession {
        ONNXSession::build(
            model_path.to_str().expect("fixture path should be valid UTF-8"),
            ExecutionProviderKind::Cpu,
            options,
        )
        .unwrap_or_else(|e| panic!("{context}: {e:?}"))
    }

    /// An all-zero 1x1x28x28 image keyed by the session's first input name.
    fn zero_mnist_inputs(session: &ONNXSession) -> HashMap<String, ArrayD<f32>> {
        let image = ArrayD::<f32>::zeros(IxDyn(&[1, 1, 28, 28]));
        HashMap::from([(session.input_names()[0].clone(), image)])
    }

    #[test]
    fn test_session_creation_fails_on_nonexistent_file() {
        // The existence check must win even when capture would create a tempdir.
        for options in [
            SessionOptions::default(),
            SessionOptions {
                capture_resolved_ep: true,
            },
        ] {
            let result = ONNXSession::build(
                "/nonexistent/model.onnx",
                ExecutionProviderKind::Cpu,
                options,
            );
            assert!(
                matches!(result, Err(AdapterError::ModelNotFound(_))),
                "expected ModelNotFound with {:?}",
                options
            );
        }
    }

    #[test]
    fn test_session_creation_with_mock_file() {
        let temp_dir = TempDir::new().unwrap();
        let model_path = temp_dir.path().join("test_model.onnx");
        fs::write(&model_path, b"fake onnx data").unwrap();
        // ORT is expected to reject garbage bytes. If a build ever succeeds, the session
        // must still expose at least one input and one output name.
        match ONNXSession::build(
            model_path.to_str().unwrap(),
            ExecutionProviderKind::Cpu,
            SessionOptions::default(),
        ) {
            Ok(session) => {
                assert!(!session.input_names().is_empty());
                assert!(!session.output_names().is_empty());
            }
            Err(e) => println!("Expected error (invalid ONNX format): {:?}", e),
        }
    }

    #[test]
    fn test_mnist_model_loading() {
        let model_path = match find_mnist_model() {
            Some(p) => p,
            None => return,
        };
        let session = build_mnist(
            &model_path,
            SessionOptions::default(),
            "Failed to load MNIST model",
        );
        let input_names = session.input_names();
        let output_names = session.output_names();
        println!("MNIST Input names: {:?}", input_names);
        println!("MNIST Output names: {:?}", output_names);
        println!("MNIST Input shapes: {:?}", session.input_shapes());
        println!("MNIST Output shapes: {:?}", session.output_shapes());
        assert!(!input_names.is_empty(), "Should have at least one input");
        assert!(!output_names.is_empty(), "Should have at least one output");
        assert_ne!(
            input_names[0], "input",
            "Should have real input name, not placeholder"
        );
        assert_ne!(
            output_names[0], "output",
            "Should have real output name, not placeholder"
        );
    }

    #[test]
    fn test_mnist_inference() {
        let model_path = match find_mnist_model() {
            Some(p) => p,
            None => return,
        };
        let session = build_mnist(
            &model_path,
            SessionOptions::default(),
            "Failed to load MNIST model",
        );
        let outputs = session
            .run(zero_mnist_inputs(&session))
            .unwrap_or_else(|e| panic!("Inference failed: {e:?}"));
        assert!(!outputs.is_empty(), "Should have at least one output");
        let output_name = &session.output_names()[0];
        let output_tensor = outputs
            .get(output_name)
            .expect("Output should contain expected output name");
        println!("MNIST Output shape: {:?}", output_tensor.shape());
        println!("MNIST Output size: {}", output_tensor.len());
        assert_eq!(
            output_tensor.shape(),
            &[1, 10],
            "MNIST should output shape [1, 10]"
        );
        assert_eq!(
            output_tensor.len(),
            10,
            "MNIST output should have 10 elements"
        );
    }

    #[test]
    fn resolved_providers_returns_none_when_capture_disabled() {
        let model_path = match find_mnist_model() {
            Some(p) => p,
            None => return,
        };
        let session = build_mnist(
            &model_path,
            SessionOptions::default(),
            "Failed to load MNIST model",
        );
        assert!(
            session.resolved_providers().is_none(),
            "resolved_providers() should be None before any inference"
        );
        session
            .run(zero_mnist_inputs(&session))
            .expect("MNIST inference must succeed");
        assert!(
            session.resolved_providers().is_none(),
            "resolved_providers() must stay None when capture is disabled; state: {}",
            session.resolved_state_debug()
        );
        assert_eq!(session.resolved_state_debug(), "Disabled");
    }

    #[test]
    fn resolved_providers_populates_after_first_inference() {
        let model_path = match find_mnist_model() {
            Some(p) => p,
            None => return,
        };
        let session = build_mnist(
            &model_path,
            SessionOptions {
                capture_resolved_ep: true,
            },
            "Failed to load MNIST model with resolved-EP capture enabled",
        );
        assert!(
            session.resolved_providers().is_none(),
            "resolved_providers() should be None before the first inference"
        );
        session
            .run(zero_mnist_inputs(&session))
            .expect("MNIST inference must succeed");
        let summary = session.resolved_providers().unwrap_or_else(|| {
            panic!(
                "resolved_providers() must populate after the first inference; \
                 actual state: {}",
                session.resolved_state_debug()
            )
        });
        assert!(
            !summary.primary.is_empty(),
            "primary EP should be a non-empty string; got {:?}",
            summary
        );
        assert!(
            !summary.breakdown.is_empty(),
            "breakdown should list at least one EP; got {:?}",
            summary
        );
        let total_ops: usize = summary.breakdown.iter().map(|(_, n)| *n).sum();
        assert!(
            total_ops >= 1,
            "breakdown should account for at least one Node event; got {:?}",
            summary
        );
        if cfg!(not(feature = "ort-coreml")) {
            assert_eq!(
                summary.primary, "cpu",
                "non-CoreML build should resolve to CPU; got {:?}",
                summary
            );
        }
        // Harvesting happens once: a second inference must leave the captured summary as-is.
        session
            .run(zero_mnist_inputs(&session))
            .expect("second MNIST inference must succeed");
        assert_eq!(
            session.resolved_state_debug(),
            format!("Harvested({summary:?})")
        );
    }
}