use std::path::{Path, PathBuf};
use crate::device::DeviceType;
use crate::error::MlResult;
/// Element type of a tensor, as carried by a [`TensorSpec`].
///
/// Every variant has a short lowercase name available through
/// [`TensorDType::name`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum TensorDType {
    /// 32-bit floating point (`"f32"`).
    F32,
    /// 16-bit floating point (`"f16"`). The onnx-backed loader in this file
    /// also reports bfloat16 tensors as this variant.
    F16,
    /// 64-bit floating point (`"f64"`).
    F64,
    /// 8-bit signed integer (`"i8"`).
    I8,
    /// 16-bit signed integer (`"i16"`).
    I16,
    /// 32-bit signed integer (`"i32"`).
    I32,
    /// 64-bit signed integer (`"i64"`).
    I64,
    /// 8-bit unsigned integer (`"u8"`).
    U8,
    /// 16-bit unsigned integer (`"u16"`).
    U16,
    /// 32-bit unsigned integer (`"u32"`).
    U32,
    /// 64-bit unsigned integer (`"u64"`).
    U64,
    /// Boolean (`"bool"`).
    Bool,
}

impl TensorDType {
    /// Returns the lowercase name of this element type, e.g. `"f32"` for
    /// [`TensorDType::F32`] or `"bool"` for [`TensorDType::Bool`].
    #[must_use]
    pub fn name(self) -> &'static str {
        // The match stays exhaustive on purpose: adding a variant must force
        // a decision about its name at compile time.
        match self {
            // Floating point.
            Self::F16 => "f16",
            Self::F32 => "f32",
            Self::F64 => "f64",
            // Signed integers.
            Self::I8 => "i8",
            Self::I16 => "i16",
            Self::I32 => "i32",
            Self::I64 => "i64",
            // Unsigned integers.
            Self::U8 => "u8",
            Self::U16 => "u16",
            Self::U32 => "u32",
            Self::U64 => "u64",
            // Everything else.
            Self::Bool => "bool",
        }
    }
}
/// Name, element type and shape of one model input or output tensor.
#[derive(Clone, Debug)]
pub struct TensorSpec {
    /// Name of the tensor.
    pub name: String,
    /// Element type of the tensor.
    pub dtype: TensorDType,
    /// One entry per dimension; `None` entries are what
    /// [`TensorSpec::dynamic_rank`] counts as dynamic.
    pub shape: Vec<Option<i64>>,
}

impl TensorSpec {
    /// Builds a spec from its parts, converting `name` into an owned `String`.
    #[must_use]
    pub fn new(name: impl Into<String>, dtype: TensorDType, shape: Vec<Option<i64>>) -> Self {
        let name = name.into();
        Self { name, dtype, shape }
    }

    /// Returns how many entries of `shape` are `None`.
    #[must_use]
    pub fn dynamic_rank(&self) -> usize {
        let mut unknown_dims = 0;
        for dim in &self.shape {
            if dim.is_none() {
                unknown_dims += 1;
            }
        }
        unknown_dims
    }
}
/// Descriptive metadata about a model: where it came from and what tensors it
/// takes and produces.
///
/// `Default` yields an empty path, no inputs or outputs, and `None` for both
/// optional fields.
#[derive(Clone, Debug, Default)]
pub struct ModelInfo {
/// Path associated with the model. The onnx-backed loader stores the file
/// path passed to `load`, or the caller-supplied virtual path for
/// `load_from_bytes`.
pub path: PathBuf,
/// Specs of the model's input tensors.
pub inputs: Vec<TensorSpec>,
/// Specs of the model's output tensors.
pub outputs: Vec<TensorSpec>,
/// Producer name from the model metadata; the onnx-backed loader stores
/// `None` when that name is empty.
pub producer: Option<String>,
/// NOTE(review): the onnx-backed loader in this file fills this field from
/// the session metadata's `ir_version` (mapping `0` to `None`), not from an
/// opset import - confirm whether that is intended.
pub opset_version: Option<i64>,
}
#[cfg(feature = "onnx")]
mod imp {
use super::{ModelInfo, TensorDType, TensorSpec};
use crate::device::DeviceType;
use crate::error::{MlError, MlResult};
use oxionnx::graph::TensorInfo;
use oxionnx::DType;
use oxionnx::{OptLevel, Session, SessionBuilder, Tensor};
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::sync::Mutex;
/// A loaded model: the inference session together with the [`ModelInfo`]
/// extracted from it at load time and the device it was requested for.
pub struct OnnxModel {
// Kept behind a `Mutex`; `run` takes the lock for each inference call.
session: Mutex<Session>,
// Metadata captured once at load time and handed out by `info()`.
info: ModelInfo,
// Device passed to the constructor, returned unchanged by `device()`.
device: DeviceType,
}
/// Debug output lists the `device` and `info` fields; the `session` field is
/// not included.
impl std::fmt::Debug for OnnxModel {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("OnnxModel");
        out.field("device", &self.device);
        out.field("info", &self.info);
        out.finish()
    }
}
impl OnnxModel {
    /// Loads a model from the file at `path` for use on `device`.
    ///
    /// The session is built with `OptLevel::All`, and the model's
    /// [`ModelInfo`] is extracted once, immediately after loading.
    ///
    /// # Errors
    ///
    /// * [`MlError::DeviceUnavailable`] when `device.is_available()` is false.
    /// * [`MlError::ModelLoad`] when the session builder cannot load the file;
    ///   `reason` holds the debug rendering of the underlying error.
    pub fn load(path: impl AsRef<Path>, device: DeviceType) -> MlResult<Self> {
        let path_ref = path.as_ref();
        Self::ensure_available(device)?;
        let session = SessionBuilder::new()
            .with_optimization_level(OptLevel::All)
            .load(path_ref)
            .map_err(|e| MlError::ModelLoad {
                path: PathBuf::from(path_ref),
                reason: format!("{e:?}"),
            })?;
        let info = extract_info(&session, path_ref);
        Ok(Self {
            session: Mutex::new(session),
            info,
            device,
        })
    }

    /// Loads a model from an in-memory buffer.
    ///
    /// `virtual_path` is not read from disk here; it is only recorded in the
    /// resulting [`ModelInfo`] and in any [`MlError::ModelLoad`] error.
    ///
    /// # Errors
    ///
    /// * [`MlError::DeviceUnavailable`] when `device.is_available()` is false.
    /// * [`MlError::ModelLoad`] when the session builder rejects `bytes`.
    pub fn load_from_bytes(
        bytes: &[u8],
        device: DeviceType,
        virtual_path: impl Into<PathBuf>,
    ) -> MlResult<Self> {
        Self::ensure_available(device)?;
        let path = virtual_path.into();
        let session = SessionBuilder::new()
            .with_optimization_level(OptLevel::All)
            .load_from_bytes(bytes)
            .map_err(|e| MlError::ModelLoad {
                // Cloned only on the error path; the owned value is still
                // needed below on success.
                path: path.clone(),
                reason: format!("{e:?}"),
            })?;
        let mut info = extract_info(&session, &path);
        info.path = path;
        Ok(Self {
            session: Mutex::new(session),
            info,
            device,
        })
    }

    /// Runs the session on `inputs` (keyed by input name) and returns the
    /// output tensors keyed by name.
    ///
    /// The session mutex is held for the whole call.
    ///
    /// # Errors
    ///
    /// * A pipeline error (`"session mutex poisoned"`) when the mutex is
    ///   poisoned.
    /// * [`MlError::OnnxRuntime`] with the debug rendering of the session's
    ///   error when the run itself fails.
    pub fn run(&self, inputs: &HashMap<&str, Tensor>) -> MlResult<HashMap<String, Tensor>> {
        let guard = self
            .session
            .lock()
            .map_err(|_| MlError::pipeline("onnx", "session mutex poisoned"))?;
        guard
            .run(inputs)
            .map_err(|e| MlError::OnnxRuntime(format!("{e:?}")))
    }

    /// Convenience wrapper around [`OnnxModel::run`] for a single `f32` input.
    ///
    /// `data` and `shape` are wrapped in a `Tensor` registered under
    /// `input_name`; each output tensor is reduced to its `data` vector.
    ///
    /// # Errors
    ///
    /// * A pipeline error when the element count implied by `shape` does not
    ///   fit in `usize`, or when it differs from `data.len()`.
    /// * Any error returned by [`OnnxModel::run`].
    pub fn run_single(
        &self,
        input_name: &str,
        data: Vec<f32>,
        shape: Vec<usize>,
    ) -> MlResult<HashMap<String, Vec<f32>>> {
        // A plain `product()` panics on overflow in debug builds and wraps in
        // release builds, where a wrapped product could match `data.len()` by
        // accident and let an impossible shape through validation.
        let expected = Self::checked_element_count(&shape).ok_or_else(|| {
            MlError::pipeline(
                "onnx",
                format!("run_single: element count of shape {shape:?} overflows usize"),
            )
        })?;
        if data.len() != expected {
            return Err(MlError::pipeline(
                "onnx",
                format!(
                    "run_single: data len {} does not match shape product {}",
                    data.len(),
                    expected,
                ),
            ));
        }
        let tensor = Tensor { data, shape };
        let mut inputs: HashMap<&str, Tensor> = HashMap::with_capacity(1);
        inputs.insert(input_name, tensor);
        let outputs = self.run(&inputs)?;
        Ok(outputs
            .into_iter()
            .map(|(name, t)| (name, t.data))
            .collect())
    }

    /// Returns the metadata extracted when the model was loaded.
    #[must_use]
    pub fn info(&self) -> &ModelInfo {
        &self.info
    }

    /// Returns the device this model was loaded for.
    #[must_use]
    pub fn device(&self) -> DeviceType {
        self.device
    }

    /// Fails with [`MlError::DeviceUnavailable`] unless `device` reports
    /// itself as available.
    fn ensure_available(device: DeviceType) -> MlResult<()> {
        if device.is_available() {
            Ok(())
        } else {
            Err(MlError::DeviceUnavailable(device.name().to_string()))
        }
    }

    /// Multiplies the dimensions of `shape`, returning `None` on overflow.
    ///
    /// An empty shape yields `Some(1)`, matching `Iterator::product`. A shape
    /// containing a zero dimension always yields `Some(0)`, even if the other
    /// dimensions alone would overflow.
    fn checked_element_count(shape: &[usize]) -> Option<usize> {
        if shape.contains(&0) {
            return Some(0);
        }
        shape
            .iter()
            .try_fold(1usize, |acc, &dim| acc.checked_mul(dim))
    }
}
fn extract_info(session: &Session, path: &Path) -> ModelInfo {
let inputs = session
.input_info()
.iter()
.map(tensor_info_to_spec)
.collect();
let outputs = session
.output_info()
.iter()
.map(tensor_info_to_spec)
.collect();
let meta = session.metadata();
let producer = meta.producer_name.clone();
let opset_version = if meta.ir_version == 0 {
None
} else {
Some(meta.ir_version)
};
ModelInfo {
path: PathBuf::from(path),
inputs,
outputs,
producer: if producer.is_empty() {
None
} else {
Some(producer)
},
opset_version,
}
}
/// Converts one session-level tensor description into the public
/// [`TensorSpec`]: the name is cloned, the dtype goes through
/// `dtype_to_public`, and every known dimension is cast to `i64` while
/// unknown (`None`) dimensions stay `None`.
fn tensor_info_to_spec(info: &TensorInfo) -> TensorSpec {
    // Plain `as` cast, as before; assumes real dimension sizes fit in `i64`.
    let shape: Vec<Option<i64>> = info
        .shape
        .iter()
        .map(|dim| dim.map(|extent| extent as i64))
        .collect();
    let dtype = dtype_to_public(info.dtype);
    let name = info.name.clone();
    TensorSpec { name, dtype, shape }
}
/// Maps the session library's element type onto the public [`TensorDType`].
///
/// The mapping is one-to-one except for `DType::BF16`, which has no public
/// counterpart and is reported as [`TensorDType::F16`].
fn dtype_to_public(dtype: DType) -> TensorDType {
    match dtype {
        // Floating point; both 16-bit formats collapse onto `F16`.
        DType::F16 | DType::BF16 => TensorDType::F16,
        DType::F32 => TensorDType::F32,
        DType::F64 => TensorDType::F64,
        // Signed integers.
        DType::I8 => TensorDType::I8,
        DType::I16 => TensorDType::I16,
        DType::I32 => TensorDType::I32,
        DType::I64 => TensorDType::I64,
        // Unsigned integers.
        DType::U8 => TensorDType::U8,
        DType::U16 => TensorDType::U16,
        DType::U32 => TensorDType::U32,
        DType::U64 => TensorDType::U64,
        // Boolean.
        DType::Bool => TensorDType::Bool,
    }
}
}
#[cfg(not(feature = "onnx"))]
mod imp {
    use super::ModelInfo;
    use crate::device::DeviceType;
    use crate::error::{MlError, MlResult};
    use std::collections::HashMap;
    use std::path::{Path, PathBuf};

    /// Feature name reported in every `FeatureDisabled` error from this stub.
    const FEATURE: &str = "onnx";

    /// Stand-in for the real model type when the `onnx` feature is off.
    ///
    /// Both constructors in this module always return an error and the only
    /// field is private, so code outside this module cannot obtain a value.
    #[derive(Debug)]
    pub struct OnnxModel {
        _priv: (),
    }

    impl OnnxModel {
        /// Shared failure path: always `Err(MlError::FeatureDisabled("onnx"))`.
        fn disabled<T>() -> MlResult<T> {
            Err(MlError::FeatureDisabled(FEATURE))
        }

        /// Always fails with [`MlError::FeatureDisabled`]; the arguments are
        /// ignored.
        pub fn load(_path: impl AsRef<Path>, _device: DeviceType) -> MlResult<Self> {
            Self::disabled()
        }

        /// Always fails with [`MlError::FeatureDisabled`]; the arguments are
        /// ignored.
        pub fn load_from_bytes(
            _bytes: &[u8],
            _device: DeviceType,
            _virtual_path: impl Into<PathBuf>,
        ) -> MlResult<Self> {
            Self::disabled()
        }

        /// Always fails with [`MlError::FeatureDisabled`]; the arguments are
        /// ignored.
        pub fn run_single(
            &self,
            _input_name: &str,
            _data: Vec<f32>,
            _shape: Vec<usize>,
        ) -> MlResult<HashMap<String, Vec<f32>>> {
            Self::disabled()
        }

        /// Returns a process-wide default (empty) [`ModelInfo`], created on
        /// first use.
        #[must_use]
        pub fn info(&self) -> &ModelInfo {
            static PLACEHOLDER: std::sync::OnceLock<ModelInfo> = std::sync::OnceLock::new();
            PLACEHOLDER.get_or_init(ModelInfo::default)
        }

        /// Always reports [`DeviceType::Cpu`].
        #[must_use]
        pub fn device(&self) -> DeviceType {
            DeviceType::Cpu
        }
    }
}
pub use imp::OnnxModel;
/// Loads the model at `path` on whatever device `DeviceType::auto()` selects.
///
/// # Errors
///
/// Returns whatever [`OnnxModel::load`] returns for that path and device.
pub fn load_auto(path: impl AsRef<Path>) -> MlResult<OnnxModel> {
    let device = DeviceType::auto();
    OnnxModel::load(path, device)
}
/// Returns the canonical form of `path`, or an owned copy of `path` itself
/// when canonicalization fails (for example because the path does not exist).
///
/// Best-effort by design: the underlying I/O error is discarded.
#[must_use]
pub fn canonical_path(path: &Path) -> PathBuf {
    match path.canonicalize() {
        Ok(resolved) => resolved,
        Err(_) => path.to_path_buf(),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    #[cfg(not(feature = "onnx"))]
    use crate::error::MlError;

    /// `dynamic_rank` counts exactly the `None` entries of the shape.
    #[test]
    fn tensor_spec_dynamic_rank_counts_nones() {
        let spec = TensorSpec::new(
            "x",
            TensorDType::F32,
            vec![None, Some(3), Some(224), Some(224)],
        );
        assert_eq!(spec.dynamic_rank(), 1);
    }

    /// Spot-checks the lowercase names returned by `TensorDType::name`.
    #[test]
    fn dtype_names_are_canonical() {
        assert_eq!(TensorDType::F32.name(), "f32");
        assert_eq!(TensorDType::I64.name(), "i64");
        assert_eq!(TensorDType::Bool.name(), "bool");
    }

    /// Without the `onnx` feature, loading must fail with `FeatureDisabled`.
    #[cfg(not(feature = "onnx"))]
    #[test]
    fn load_without_onnx_feature_reports_feature_disabled() {
        let err =
            OnnxModel::load("does-not-matter.onnx", DeviceType::Cpu).expect_err("expected failure");
        // `matches!` only yields a bool; without `assert!` the check was a
        // no-op and the test passed for any error variant.
        assert!(matches!(err, MlError::FeatureDisabled("onnx")));
    }
}