use crate::error::{CvError, CvResult};
use crate::ml::tensor::Tensor;
use std::path::Path;
use std::sync::Arc;
#[cfg(feature = "onnx")]
use oxionnx::Session as OxiSession;
/// Compute device an [`OnnxRuntime`] is asked to target.
///
/// The default is [`DeviceType::Cpu`]. The type is `Copy`, comparable and
/// hashable, so it can be used directly as a map or set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum DeviceType {
    /// Plain CPU execution; the default device.
    #[default]
    Cpu,
    /// CUDA execution.
    Cuda,
    /// ROCm execution.
    Rocm,
    /// TensorRT execution.
    TensorRt,
    /// DirectML execution.
    DirectMl,
    /// CoreML execution.
    CoreMl,
    /// WebGPU execution.
    WebGpu,
}
/// Factory for model sessions bound to a single [`DeviceType`].
///
/// The runtime only records which device was requested; it is cheap to clone.
#[derive(Debug, Clone)]
pub struct OnnxRuntime {
    /// Device selected when the runtime was constructed.
    device: DeviceType,
}
impl OnnxRuntime {
pub fn new() -> CvResult<Self> {
Self::with_device(DeviceType::Cpu)
}
pub fn with_device(device: DeviceType) -> CvResult<Self> {
Ok(Self { device })
}
#[cfg(feature = "onnx")]
pub fn load_model(&self, path: impl AsRef<Path>) -> CvResult<Session> {
let path = path.as_ref();
let session = self.build_session(path)?;
Ok(Session {
inner: Arc::new(std::sync::Mutex::new(session)),
})
}
#[cfg(feature = "onnx")]
pub fn load_model_from_bytes(&self, bytes: &[u8]) -> CvResult<Session> {
let session = self.build_session_from_bytes(bytes)?;
Ok(Session {
inner: Arc::new(std::sync::Mutex::new(session)),
})
}
#[must_use]
pub const fn device(&self) -> DeviceType {
self.device
}
#[cfg(feature = "onnx")]
fn build_session(&self, path: &Path) -> CvResult<OxiSession> {
self.check_device_support()?;
oxionnx::Session::builder()
.with_optimization_level(oxionnx::OptLevel::All)
.load(path)
.map_err(|e| CvError::model_load(format!("Failed to load model from file: {e}")))
}
#[cfg(feature = "onnx")]
fn build_session_from_bytes(&self, bytes: &[u8]) -> CvResult<OxiSession> {
self.check_device_support()?;
oxionnx::Session::builder()
.with_optimization_level(oxionnx::OptLevel::All)
.load_from_bytes(bytes)
.map_err(|e| CvError::model_load(format!("Failed to load model from bytes: {e}")))
}
#[cfg(feature = "onnx")]
fn check_device_support(&self) -> CvResult<()> {
match self.device {
DeviceType::Cpu => Ok(()),
#[cfg(feature = "cuda")]
DeviceType::Cuda => Ok(()),
#[cfg(not(feature = "cuda"))]
DeviceType::Cuda => Err(CvError::onnx_runtime(
"CUDA support requires the 'cuda' feature".to_owned(),
)),
#[cfg(feature = "rocm")]
DeviceType::Rocm => {
Ok(())
}
#[cfg(not(feature = "rocm"))]
DeviceType::Rocm => Err(CvError::onnx_runtime(
"ROCm support requires the 'rocm' feature (currently CPU-only fallback; \
no oxionnx ROCm execution provider is available yet)"
.to_owned(),
)),
#[cfg(feature = "tensorrt")]
DeviceType::TensorRt => Ok(()),
#[cfg(not(feature = "tensorrt"))]
DeviceType::TensorRt => Err(CvError::onnx_runtime(
"TensorRT support requires the 'tensorrt' feature".to_owned(),
)),
#[cfg(feature = "directml")]
DeviceType::DirectMl => Ok(()),
#[cfg(not(feature = "directml"))]
DeviceType::DirectMl => Err(CvError::onnx_runtime(
"DirectML support requires the 'directml' feature".to_owned(),
)),
#[cfg(target_os = "macos")]
DeviceType::CoreMl => Ok(()),
#[cfg(not(target_os = "macos"))]
DeviceType::CoreMl => Err(CvError::onnx_runtime(
"CoreML is only available on macOS".to_owned(),
)),
#[cfg(feature = "webgpu")]
DeviceType::WebGpu => Ok(()),
#[cfg(not(feature = "webgpu"))]
DeviceType::WebGpu => Err(CvError::onnx_runtime(
"WebGPU support requires the 'webgpu' feature".to_owned(),
)),
}
}
}
impl Default for OnnxRuntime {
    /// Builds a runtime bound to [`DeviceType::Cpu`] without going through
    /// the fallible constructors.
    fn default() -> Self {
        let device = DeviceType::Cpu;
        Self { device }
    }
}
/// Handle to a loaded model, available when the `onnx` feature is enabled.
///
/// Cloning the handle clones the `Arc`, so every clone refers to the same
/// mutex-guarded backend session.
#[cfg(feature = "onnx")]
#[derive(Clone)]
pub struct Session {
// Backend session behind a mutex so the handle can be shared between clones.
inner: Arc<std::sync::Mutex<OxiSession>>,
}
#[cfg(feature = "onnx")]
impl Session {
    /// Runs the model on `inputs` and returns the produced tensors.
    ///
    /// Inputs are matched positionally against the model's input names; passing
    /// fewer tensors than the model declares is not rejected here. Outputs are
    /// returned in the order of the model's output names, and a name with no
    /// entry in the backend's result is skipped.
    ///
    /// # Errors
    ///
    /// Returns an error when an input cannot be converted, the session lock
    /// cannot be taken, more inputs are supplied than the model declares,
    /// inference itself fails, or an output cannot be converted back.
    pub fn run(&self, inputs: &[Tensor]) -> CvResult<Vec<Tensor>> {
        // Convert up front; the first failing tensor aborts the call.
        let mut converted: Vec<oxionnx::Tensor> = Vec::with_capacity(inputs.len());
        for tensor in inputs {
            converted.push(Tensor::to_oxionnx_tensor(tensor)?);
        }

        let guard = self
            .inner
            .lock()
            .map_err(|e| CvError::onnx_runtime(format!("Session lock error: {e}")))?;

        let expected: Vec<String> = guard.input_names().to_vec();
        let supplied = converted.len();
        if supplied > expected.len() {
            return Err(CvError::onnx_runtime(format!(
                "Too many inputs: got {}, model expects {}",
                supplied,
                expected.len()
            )));
        }

        // Pair each declared name with the tensor at the same position.
        let named_inputs: std::collections::HashMap<&str, oxionnx::Tensor> = expected
            .iter()
            .map(String::as_str)
            .zip(converted)
            .collect();

        let outputs = guard
            .run(&named_inputs)
            .map_err(|e| CvError::onnx_runtime(format!("Inference failed: {e}")))?;

        let declared = guard.output_names().to_vec();
        let mut results = Vec::with_capacity(declared.len());
        for name in &declared {
            if let Some(raw) = outputs.get(name) {
                results.push(Tensor::from_oxionnx_tensor(raw)?);
            }
        }
        Ok(results)
    }

    /// Returns a copy of the model's input names, or an empty list when the
    /// session lock cannot be taken.
    #[must_use]
    pub fn input_names(&self) -> Vec<String> {
        match self.inner.lock() {
            Ok(guard) => guard.input_names().to_vec(),
            Err(_) => Vec::new(),
        }
    }

    /// Returns a copy of the model's output names, or an empty list when the
    /// session lock cannot be taken.
    #[must_use]
    pub fn output_names(&self) -> Vec<String> {
        match self.inner.lock() {
            Ok(guard) => guard.output_names().to_vec(),
            Err(_) => Vec::new(),
        }
    }

    /// Returns how many inputs the model declares, or `0` when the session
    /// lock cannot be taken.
    #[must_use]
    pub fn input_count(&self) -> usize {
        match self.inner.lock() {
            Ok(guard) => guard.input_names().len(),
            Err(_) => 0,
        }
    }

    /// Returns how many outputs the model declares, or `0` when the session
    /// lock cannot be taken.
    #[must_use]
    pub fn output_count(&self) -> usize {
        match self.inner.lock() {
            Ok(guard) => guard.output_names().len(),
            Err(_) => 0,
        }
    }
}
/// Placeholder `Session` compiled when the `onnx` feature is disabled, so the
/// type name still exists in such builds.
#[cfg(not(feature = "onnx"))]
#[derive(Clone)]
pub struct Session {
// Private unit field: code outside this module cannot construct the struct.
_private: (),
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The derived default device is the CPU.
    #[test]
    fn test_device_type_default() {
        let device = DeviceType::default();
        assert_eq!(device, DeviceType::Cpu);
    }

    /// `new()` succeeds and targets the CPU.
    #[test]
    fn test_onnx_runtime_new() {
        let created = OnnxRuntime::new();
        assert!(created.is_ok());
        let runtime = created.expect("value should be valid");
        assert_eq!(runtime.device(), DeviceType::Cpu);
    }

    /// `with_device` keeps the requested device and rejects backends whose
    /// feature is compiled out.
    #[test]
    fn test_onnx_runtime_with_device() {
        let cuda = OnnxRuntime::with_device(DeviceType::Cuda);
        assert!(cuda.is_ok());
        let cuda = cuda.expect("value should be valid");
        assert_eq!(cuda.device(), DeviceType::Cuda);

        #[cfg(not(feature = "webgpu"))]
        {
            let webgpu = OnnxRuntime::with_device(DeviceType::WebGpu);
            assert!(webgpu.is_err(), "WebGPU should require 'webgpu' feature");
        }

        #[cfg(not(feature = "directml"))]
        {
            let directml = OnnxRuntime::with_device(DeviceType::DirectMl);
            assert!(
                directml.is_err(),
                "DirectML should require 'directml' feature"
            );
        }
    }

    /// `Default` yields a CPU runtime.
    #[test]
    fn test_onnx_runtime_default() {
        let device = OnnxRuntime::default().device();
        assert_eq!(device, DeviceType::Cpu);
    }
}