use crate::error::{CvError, CvResult};
use crate::ml::tensor::Tensor;
use ort::session::builder::{GraphOptimizationLevel, SessionBuilder};
use ort::session::Session as OrtSession;
use std::path::Path;
use std::sync::Arc;
#[cfg(feature = "cuda")]
use ort::execution_providers::CUDAExecutionProvider;
#[cfg(target_os = "macos")]
use ort::execution_providers::CoreMLExecutionProvider;
#[cfg(target_os = "windows")]
use ort::execution_providers::DirectMLExecutionProvider;
#[cfg(feature = "rocm")]
use ort::execution_providers::ROCmExecutionProvider;
#[cfg(feature = "tensorrt")]
use ort::execution_providers::TensorRTExecutionProvider;
/// Execution backend that an ONNX session is asked to run on.
///
/// [`DeviceType::Cpu`] is the default. Every other variant is only usable
/// when the matching cargo feature or target operating system was enabled at
/// compile time; otherwise loading a model reports an error.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum DeviceType {
    /// Plain CPU execution; no execution provider is registered.
    #[default]
    Cpu,
    /// CUDA execution provider (needs the `cuda` feature).
    Cuda,
    /// ROCm execution provider (needs the `rocm` feature).
    Rocm,
    /// TensorRT execution provider (needs the `tensorrt` feature).
    TensorRt,
    /// DirectML execution provider (Windows targets only).
    DirectMl,
    /// CoreML execution provider (macOS targets only).
    CoreMl,
}
/// Factory for ONNX sessions that are configured for one [`DeviceType`].
///
/// The runtime itself only stores the requested device; sessions are created
/// through its `load_model*` methods.
#[derive(Debug, Clone)]
pub struct OnnxRuntime {
    /// Device that every session built by this runtime is configured for.
    device: DeviceType,
}
impl OnnxRuntime {
    /// Creates a runtime that targets the CPU.
    ///
    /// # Errors
    ///
    /// The current implementation always returns `Ok`; the `Result` return
    /// type mirrors [`OnnxRuntime::with_device`].
    pub fn new() -> CvResult<Self> {
        Self::with_device(DeviceType::Cpu)
    }

    /// Creates a runtime for the requested `device`.
    ///
    /// The device is only recorded here. Whether it is actually available is
    /// checked when a model is loaded.
    ///
    /// # Errors
    ///
    /// The current implementation always returns `Ok`.
    pub fn with_device(device: DeviceType) -> CvResult<Self> {
        Ok(Self { device })
    }

    /// Loads an ONNX model from `path` and wraps it in a shareable [`Session`].
    ///
    /// # Errors
    ///
    /// Fails when the session builder cannot be created or configured, when
    /// the selected device is not available in this build, or when the model
    /// file cannot be loaded.
    pub fn load_model(&self, path: impl AsRef<Path>) -> CvResult<Session> {
        self.build_session(path.as_ref()).map(Self::wrap_session)
    }

    /// Loads an ONNX model from an in-memory buffer and wraps it in a
    /// shareable [`Session`].
    ///
    /// # Errors
    ///
    /// Same failure modes as [`OnnxRuntime::load_model`], except that the
    /// model is read from `bytes` instead of a file.
    pub fn load_model_from_bytes(&self, bytes: &[u8]) -> CvResult<Session> {
        self.build_session_from_bytes(bytes).map(Self::wrap_session)
    }

    /// Returns the device this runtime was created for.
    #[must_use]
    pub const fn device(&self) -> DeviceType {
        self.device
    }

    /// Puts a freshly built session behind the `Arc<Mutex<_>>` that
    /// [`Session`] handles share.
    fn wrap_session(session: OrtSession) -> Session {
        Session {
            inner: Arc::new(std::sync::Mutex::new(session)),
        }
    }

    /// Creates a session builder with level-3 graph optimization and the
    /// execution provider for `self.device` applied, in that order.
    fn prepared_builder(&self) -> CvResult<SessionBuilder> {
        let fresh = OrtSession::builder()
            .map_err(|e| CvError::onnx_runtime(format!("Failed to create session builder: {e}")))?;
        let optimized = fresh
            .with_optimization_level(GraphOptimizationLevel::Level3)
            .map_err(|e| CvError::onnx_runtime(format!("Failed to set optimization level: {e}")))?;
        self.configure_execution_provider(optimized)
    }

    /// Builds the underlying session from a model file.
    fn build_session(&self, path: &Path) -> CvResult<OrtSession> {
        self.prepared_builder()?
            .commit_from_file(path)
            .map_err(|e| CvError::model_load(format!("Failed to load model from file: {e}")))
    }

    /// Builds the underlying session from an in-memory model.
    fn build_session_from_bytes(&self, bytes: &[u8]) -> CvResult<OrtSession> {
        self.prepared_builder()?
            .commit_from_memory(bytes)
            .map_err(|e| CvError::model_load(format!("Failed to load model from bytes: {e}")))
    }

    /// Registers the execution provider that corresponds to `self.device`.
    ///
    /// Each non-CPU device has two mutually exclusive arms selected at compile
    /// time: one that registers the provider, and one that reports that the
    /// provider is not part of this build.
    ///
    /// NOTE(review): some `ort` versions treat a failed provider registration
    /// as non-fatal and fall back to CPU unless the provider is built with
    /// `error_on_failure()` — confirm whether a silent fallback is acceptable.
    // NOTE(review): `builder` is consumed by the `Cpu` arm in every
    // configuration, so this allow looks unnecessary — confirm before removing.
    #[allow(unused_variables)]
    fn configure_execution_provider(&self, builder: SessionBuilder) -> CvResult<SessionBuilder> {
        match self.device {
            DeviceType::Cpu => Ok(builder),

            #[cfg(feature = "cuda")]
            DeviceType::Cuda => builder
                .with_execution_providers([CUDAExecutionProvider::default().build()])
                .map_err(|e| {
                    CvError::onnx_runtime(format!("Failed to configure CUDA provider: {e}"))
                }),
            #[cfg(not(feature = "cuda"))]
            DeviceType::Cuda => Err(CvError::onnx_runtime(
                "CUDA support not compiled in".to_owned(),
            )),

            #[cfg(feature = "rocm")]
            DeviceType::Rocm => builder
                .with_execution_providers([ROCmExecutionProvider::default().build()])
                .map_err(|e| {
                    CvError::onnx_runtime(format!("Failed to configure ROCm provider: {e}"))
                }),
            #[cfg(not(feature = "rocm"))]
            DeviceType::Rocm => Err(CvError::onnx_runtime(
                "ROCm support not compiled in".to_owned(),
            )),

            #[cfg(feature = "tensorrt")]
            DeviceType::TensorRt => builder
                .with_execution_providers([TensorRTExecutionProvider::default().build()])
                .map_err(|e| {
                    CvError::onnx_runtime(format!("Failed to configure TensorRT provider: {e}"))
                }),
            #[cfg(not(feature = "tensorrt"))]
            DeviceType::TensorRt => Err(CvError::onnx_runtime(
                "TensorRT support not compiled in".to_owned(),
            )),

            #[cfg(target_os = "windows")]
            DeviceType::DirectMl => builder
                .with_execution_providers([DirectMLExecutionProvider::default().build()])
                .map_err(|e| {
                    CvError::onnx_runtime(format!("Failed to configure DirectML provider: {e}"))
                }),
            #[cfg(not(target_os = "windows"))]
            DeviceType::DirectMl => Err(CvError::onnx_runtime(
                "DirectML is only available on Windows".to_owned(),
            )),

            #[cfg(target_os = "macos")]
            DeviceType::CoreMl => builder
                .with_execution_providers([CoreMLExecutionProvider::default().build()])
                .map_err(|e| {
                    CvError::onnx_runtime(format!("Failed to configure CoreML provider: {e}"))
                }),
            #[cfg(not(target_os = "macos"))]
            DeviceType::CoreMl => Err(CvError::onnx_runtime(
                "CoreML is only available on macOS".to_owned(),
            )),
        }
    }
}
impl Default for OnnxRuntime {
    /// Returns a runtime targeting [`DeviceType::Cpu`], the same device that
    /// [`OnnxRuntime::new`] selects, without the `Result` wrapper.
    fn default() -> Self {
        let device = DeviceType::Cpu;
        Self { device }
    }
}
/// Handle to a loaded ONNX model.
///
/// Cloning the handle clones the `Arc`, so all clones refer to the same
/// underlying session and serialize access to it through one mutex.
#[derive(Clone)]
pub struct Session {
    /// Shared, mutex-guarded session that every clone of this handle uses.
    inner: Arc<std::sync::Mutex<OrtSession>>,
}
impl Session {
pub fn run(&self, inputs: &[Tensor]) -> CvResult<Vec<Tensor>> {
let ort_inputs: Vec<ort::value::DynValue> = inputs
.iter()
.map(super::tensor::Tensor::to_ort_value)
.collect::<CvResult<Vec<_>>>()?;
let mut session = self
.inner
.lock()
.map_err(|e| CvError::onnx_runtime(format!("Session lock error: {e}")))?;
let input_names: Vec<String> = session
.inputs()
.iter()
.map(|i| i.name().to_string())
.collect();
if ort_inputs.len() > input_names.len() {
return Err(CvError::onnx_runtime(format!(
"Too many inputs: got {}, model expects {}",
ort_inputs.len(),
input_names.len()
)));
}
let named_inputs: Vec<(String, ort::value::DynValue)> =
input_names.into_iter().zip(ort_inputs).collect();
let outputs = session
.run(named_inputs)
.map_err(|e| CvError::onnx_runtime(format!("Inference failed: {e}")))?;
outputs
.values()
.map(|value| Tensor::from_ort_value(&value))
.collect()
}
#[must_use]
pub fn input_names(&self) -> Vec<String> {
self.inner.lock().map_or_else(
|_| Vec::new(),
|s| s.inputs().iter().map(|i| i.name().to_string()).collect(),
)
}
#[must_use]
pub fn output_names(&self) -> Vec<String> {
self.inner.lock().map_or_else(
|_| Vec::new(),
|s| s.outputs().iter().map(|o| o.name().to_string()).collect(),
)
}
#[must_use]
pub fn input_count(&self) -> usize {
self.inner.lock().map_or(0, |s| s.inputs().len())
}
#[must_use]
pub fn output_count(&self) -> usize {
self.inner.lock().map_or(0, |s| s.outputs().len())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The derived default device is the CPU.
    #[test]
    fn test_device_type_default() {
        let device = DeviceType::default();
        assert_eq!(device, DeviceType::Cpu);
    }

    /// `OnnxRuntime::new` succeeds and selects the CPU.
    #[test]
    fn test_onnx_runtime_new() {
        let created = OnnxRuntime::new();
        assert!(created.is_ok());

        let runtime = created.expect("value should be valid");
        assert_eq!(runtime.device(), DeviceType::Cpu);
    }

    /// Constructing a runtime only records the device, so requesting CUDA
    /// succeeds here regardless of build features.
    #[test]
    fn test_onnx_runtime_with_device() {
        let created = OnnxRuntime::with_device(DeviceType::Cuda);
        assert!(created.is_ok());

        let runtime = created.expect("value should be valid");
        assert_eq!(runtime.device(), DeviceType::Cuda);
    }

    /// The `Default` impl selects the CPU.
    #[test]
    fn test_onnx_runtime_default() {
        let device = OnnxRuntime::default().device();
        assert_eq!(device, DeviceType::Cpu);
    }
}