oximedia_ml/model.rs
1//! Pure-Rust ONNX model wrapper.
2//!
3//! [`OnnxModel`] is a thin façade over `oxionnx::Session` that exposes a
4//! stable, Pure-Rust-only surface. The intent is that the rest of
5//! oximedia-ml (and its downstream pipelines) never imports `oxionnx`
6//! symbols directly — everything goes through [`OnnxModel`],
7//! [`TensorSpec`], and [`ModelInfo`].
8//!
//! When the `onnx` feature is disabled, [`OnnxModel`] still exists but
//! its constructors return [`crate::error::MlError::FeatureDisabled`].
//! Pipelines can then decide whether to fall back to a non-ML heuristic
//! or propagate the error upward.
13//!
14//! ## Single-input convenience
15//!
16//! Many classifier / detector models have a single input and one or more
17//! float outputs. For these,
18//! [`OnnxModel::run_single`][crate::OnnxModel::run_single] skips the
19//! `HashMap<&str, Tensor>` boilerplate and returns flat `Vec<f32>`
20//! buffers keyed on output name:
21//!
22//! ```no_run
23//! # #[cfg(feature = "onnx")]
24//! # fn demo() -> oximedia_ml::MlResult<()> {
25//! use oximedia_ml::{DeviceType, OnnxModel};
26//!
27//! let model = OnnxModel::load("scene.onnx", DeviceType::auto())?;
28//! let outputs = model.run_single(
29//! "input",
30//! vec![0.0_f32; 1 * 3 * 224 * 224],
31//! vec![1, 3, 224, 224],
32//! )?;
33//! // `outputs` is `HashMap<String, Vec<f32>>`.
34//! # let _ = outputs;
35//! # Ok(())
36//! # }
37//! ```
38//!
39//! ## Metadata
40//!
41//! [`ModelInfo`] exposes input/output [`TensorSpec`]s (name, dtype,
42//! shape with dynamic dims as `None`), the ONNX producer name, and the
43//! opset version. Callers can use [`TensorSpec::dynamic_rank`] to decide
44//! whether dynamic shape plumbing is needed.
45
46use std::path::{Path, PathBuf};
47
48use crate::device::DeviceType;
49use crate::error::MlResult;
50
/// Scalar element type declared by a model input or output.
///
/// Covers a pragmatic subset of the ONNX dtype list. `oxionnx` keeps
/// tensor data as `f32` internally, so any variant other than
/// [`TensorDType::F32`] signals that a cast happens at the boundary
/// (performed inside the pipeline layer).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TensorDType {
    /// IEEE 754 single-precision (32-bit) float.
    F32,
    /// IEEE 754 half-precision (16-bit) float.
    F16,
    /// IEEE 754 double-precision (64-bit) float.
    F64,
    /// 8-bit signed integer.
    I8,
    /// 16-bit signed integer.
    I16,
    /// 32-bit signed integer.
    I32,
    /// 64-bit signed integer.
    I64,
    /// 8-bit unsigned integer.
    U8,
    /// 16-bit unsigned integer.
    U16,
    /// 32-bit unsigned integer.
    U32,
    /// 64-bit unsigned integer.
    U64,
    /// Boolean.
    Bool,
}

impl TensorDType {
    /// Short lowercase label for this dtype (`"f32"`, `"i64"`, `"bool"`, ...).
    #[must_use]
    pub fn name(self) -> &'static str {
        // Arms are grouped by family; the match stays exhaustive so a new
        // variant cannot be added without also giving it a label.
        let label = match self {
            // Floating point.
            TensorDType::F16 => "f16",
            TensorDType::F32 => "f32",
            TensorDType::F64 => "f64",
            // Signed integers.
            TensorDType::I8 => "i8",
            TensorDType::I16 => "i16",
            TensorDType::I32 => "i32",
            TensorDType::I64 => "i64",
            // Unsigned integers.
            TensorDType::U8 => "u8",
            TensorDType::U16 => "u16",
            TensorDType::U32 => "u32",
            TensorDType::U64 => "u64",
            // Everything else.
            TensorDType::Bool => "bool",
        };
        label
    }
}
104
105/// Describes a single model input or output tensor.
106#[derive(Clone, Debug)]
107pub struct TensorSpec {
108 /// Tensor name as declared in the ONNX graph.
109 pub name: String,
110 /// Scalar dtype.
111 pub dtype: TensorDType,
112 /// Shape with dynamic (None) dimensions expressed as `None`.
113 /// Static dimensions are positive integers; `i64` is used to match
114 /// the ONNX specification convention.
115 pub shape: Vec<Option<i64>>,
116}
117
118impl TensorSpec {
119 /// Create a new [`TensorSpec`].
120 #[must_use]
121 pub fn new(name: impl Into<String>, dtype: TensorDType, shape: Vec<Option<i64>>) -> Self {
122 Self {
123 name: name.into(),
124 dtype,
125 shape,
126 }
127 }
128
129 /// Number of dynamic dimensions (those reported as `None`).
130 #[must_use]
131 pub fn dynamic_rank(&self) -> usize {
132 self.shape.iter().filter(|d| d.is_none()).count()
133 }
134}
135
136/// Static metadata describing a loaded ONNX model.
137///
138/// Returned by [`OnnxModel::info`][crate::OnnxModel::info]; inspect
139/// [`Self::inputs`] / [`Self::outputs`] to validate the expected tensor
140/// contract before running inference.
141#[derive(Clone, Debug, Default)]
142pub struct ModelInfo {
143 /// Source path of the model.
144 pub path: PathBuf,
145 /// Model input tensor specifications.
146 pub inputs: Vec<TensorSpec>,
147 /// Model output tensor specifications.
148 pub outputs: Vec<TensorSpec>,
149 /// Producer name as declared in the ONNX file.
150 pub producer: Option<String>,
151 /// Opset version, if reported by the backend.
152 pub opset_version: Option<i64>,
153}
154
155#[cfg(feature = "onnx")]
156mod imp {
157 use super::{ModelInfo, TensorDType, TensorSpec};
158 use crate::device::DeviceType;
159 use crate::error::{MlError, MlResult};
160 use oxionnx::execution_providers::ProviderKind;
161 use oxionnx::graph::TensorInfo;
162 use oxionnx::DType;
163 use oxionnx::{OptLevel, Session, SessionBuilder, Tensor};
164 use std::collections::HashMap;
165 use std::path::{Path, PathBuf};
166 use std::sync::Mutex;
167
168 /// Pure-Rust ONNX model handle.
169 pub struct OnnxModel {
170 session: Mutex<Session>,
171 info: ModelInfo,
172 device: DeviceType,
173 }
174
175 impl std::fmt::Debug for OnnxModel {
176 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
177 f.debug_struct("OnnxModel")
178 .field("device", &self.device)
179 .field("info", &self.info)
180 .finish()
181 }
182 }
183
184 impl OnnxModel {
185 /// Load an ONNX model from disk.
186 ///
187 /// The session is configured with an ordered [`ProviderKind`] list
188 /// derived from `device` (see `device_to_providers`). CPU is
189 /// always the implicit terminal fallback, so the session can never
190 /// fail solely because a GPU backend is unavailable at runtime.
191 ///
192 /// # Errors
193 ///
194 /// * [`MlError::DeviceUnavailable`] if `device` is not available in
195 /// this build / runtime.
196 /// * [`MlError::ModelLoad`] if the file cannot be parsed by the
197 /// OxiONNX runtime.
198 pub fn load(path: impl AsRef<Path>, device: DeviceType) -> MlResult<Self> {
199 let path_ref = path.as_ref();
200 if !device.is_available() {
201 return Err(MlError::DeviceUnavailable(device.name().to_string()));
202 }
203
204 let providers = device_to_providers(device);
205 let session = SessionBuilder::new()
206 .with_optimization_level(OptLevel::All)
207 .with_provider_kinds(providers)
208 .load(path_ref)
209 .map_err(|e| MlError::ModelLoad {
210 path: PathBuf::from(path_ref),
211 reason: format!("{e:?}"),
212 })?;
213
214 let info = extract_info(&session, path_ref);
215
216 Ok(Self {
217 session: Mutex::new(session),
218 info,
219 device,
220 })
221 }
222
223 /// Load an ONNX model from an in-memory byte buffer.
224 ///
225 /// `virtual_path` is a synthetic identifier used for
226 /// [`ModelInfo::path`] and cache keying; it does not need to
227 /// refer to a real file on disk.
228 ///
229 /// The session is configured with an ordered [`ProviderKind`] list
230 /// derived from `device` (see `device_to_providers`).
231 ///
232 /// # Errors
233 ///
234 /// * [`MlError::DeviceUnavailable`] if `device` is not available.
235 /// * [`MlError::ModelLoad`] if `bytes` is not a valid ONNX payload.
236 pub fn load_from_bytes(
237 bytes: &[u8],
238 device: DeviceType,
239 virtual_path: impl Into<PathBuf>,
240 ) -> MlResult<Self> {
241 if !device.is_available() {
242 return Err(MlError::DeviceUnavailable(device.name().to_string()));
243 }
244 let path = virtual_path.into();
245 let providers = device_to_providers(device);
246 let session = SessionBuilder::new()
247 .with_optimization_level(OptLevel::All)
248 .with_provider_kinds(providers)
249 .load_from_bytes(bytes)
250 .map_err(|e| MlError::ModelLoad {
251 path: path.clone(),
252 reason: format!("{e:?}"),
253 })?;
254
255 let mut info = extract_info(&session, &path);
256 info.path = path;
257
258 Ok(Self {
259 session: Mutex::new(session),
260 info,
261 device,
262 })
263 }
264
265 /// Execute a forward pass.
266 ///
267 /// Inputs are keyed by input tensor name. Outputs are returned
268 /// as a map from output name to tensor.
269 ///
270 /// # Errors
271 ///
272 /// * [`MlError::Pipeline`] with stage `"onnx"` if the session
273 /// mutex is poisoned.
274 /// * [`MlError::OnnxRuntime`] if inference fails inside OxiONNX.
275 pub fn run(&self, inputs: &HashMap<&str, Tensor>) -> MlResult<HashMap<String, Tensor>> {
276 let guard = self
277 .session
278 .lock()
279 .map_err(|_| MlError::pipeline("onnx", "session mutex poisoned"))?;
280 guard
281 .run(inputs)
282 .map_err(|e| MlError::OnnxRuntime(format!("{e:?}")))
283 }
284
285 /// Execute a forward pass for a single-input model, returning the
286 /// raw `f32` buffer of every output tensor.
287 ///
288 /// This is a convenience adapter for the common case where the
289 /// caller has a single `Vec<f32>` payload and a shape, and does not
290 /// want to import `oxionnx::Tensor` into its own crate. Pure-Rust
291 /// downstream crates (e.g. `oximedia-recommend`) can build their
292 /// embedding pipelines on top of `OnnxModel` without touching the
293 /// `oxionnx` symbol surface at all, preserving the encapsulation
294 /// documented in the module header.
295 ///
296 /// # Errors
297 ///
298 /// * [`MlError::OnnxRuntime`] if the underlying `Session::run` fails.
299 /// * [`MlError::pipeline`] if the session mutex is poisoned.
300 pub fn run_single(
301 &self,
302 input_name: &str,
303 data: Vec<f32>,
304 shape: Vec<usize>,
305 ) -> MlResult<HashMap<String, Vec<f32>>> {
306 let expected = shape.iter().product::<usize>();
307 if data.len() != expected {
308 return Err(MlError::pipeline(
309 "onnx",
310 format!(
311 "run_single: data len {} does not match shape product {}",
312 data.len(),
313 expected,
314 ),
315 ));
316 }
317 let tensor = Tensor { data, shape };
318 let mut inputs: HashMap<&str, Tensor> = HashMap::with_capacity(1);
319 inputs.insert(input_name, tensor);
320 let outputs = self.run(&inputs)?;
321 Ok(outputs
322 .into_iter()
323 .map(|(name, t)| (name, t.data))
324 .collect())
325 }
326
327 /// Return the loaded model metadata.
328 #[must_use]
329 pub fn info(&self) -> &ModelInfo {
330 &self.info
331 }
332
333 /// Return the device this model was loaded onto.
334 #[must_use]
335 pub fn device(&self) -> DeviceType {
336 self.device
337 }
338 }
339
340 /// Map a [`DeviceType`] to an ordered list of [`ProviderKind`] backends.
341 ///
342 /// The returned list is passed to
343 /// [`SessionBuilder::with_provider_kinds`][oxionnx::SessionBuilder::with_provider_kinds]
344 /// so that the oxionnx dispatch loop tries the requested backend first
345 /// and falls back to CPU when it is unavailable or returns `None` for a
346 /// specific operator.
347 ///
348 /// | DeviceType | Provider priority list |
349 /// |-------------|--------------------------------------|
350 /// | Cpu | `[Cpu]` |
351 /// | Cuda | `[Cuda, Cpu]` (feature `cuda`) |
352 /// | WebGpu | `[Gpu, Cpu]` (feature `webgpu`) |
353 /// | DirectMl | `[DirectMl, Cpu]` (feature `directml`) |
354 /// | CoreMl | `[Cpu]` (CoreML not yet wired) |
355 fn device_to_providers(device: DeviceType) -> Vec<ProviderKind> {
356 match device {
357 DeviceType::Cpu | DeviceType::CoreMl => vec![ProviderKind::Cpu],
358 DeviceType::Cuda => {
359 // cuda feature is required for ProviderKind::Cuda to exist.
360 #[cfg(feature = "cuda")]
361 {
362 vec![ProviderKind::Cuda, ProviderKind::Cpu]
363 }
364 #[cfg(not(feature = "cuda"))]
365 {
366 vec![ProviderKind::Cpu]
367 }
368 }
369 DeviceType::WebGpu => {
370 // webgpu feature is required for ProviderKind::Gpu to exist.
371 #[cfg(feature = "webgpu")]
372 {
373 vec![ProviderKind::Gpu, ProviderKind::Cpu]
374 }
375 #[cfg(not(feature = "webgpu"))]
376 {
377 vec![ProviderKind::Cpu]
378 }
379 }
380 DeviceType::DirectMl => {
381 // directml feature is required for ProviderKind::DirectMl to exist.
382 #[cfg(feature = "directml")]
383 {
384 vec![ProviderKind::DirectMl, ProviderKind::Cpu]
385 }
386 #[cfg(not(feature = "directml"))]
387 {
388 vec![ProviderKind::Cpu]
389 }
390 }
391 }
392 }
393
394 fn extract_info(session: &Session, path: &Path) -> ModelInfo {
395 let inputs = session
396 .input_info()
397 .iter()
398 .map(tensor_info_to_spec)
399 .collect();
400 let outputs = session
401 .output_info()
402 .iter()
403 .map(tensor_info_to_spec)
404 .collect();
405
406 let meta = session.metadata();
407 let producer = meta.producer_name.clone();
408 let opset_version = if meta.ir_version == 0 {
409 None
410 } else {
411 Some(meta.ir_version)
412 };
413
414 ModelInfo {
415 path: PathBuf::from(path),
416 inputs,
417 outputs,
418 producer: if producer.is_empty() {
419 None
420 } else {
421 Some(producer)
422 },
423 opset_version,
424 }
425 }
426
427 fn tensor_info_to_spec(info: &TensorInfo) -> TensorSpec {
428 TensorSpec {
429 name: info.name.clone(),
430 dtype: dtype_to_public(info.dtype),
431 shape: info.shape.iter().map(|d| d.map(|v| v as i64)).collect(),
432 }
433 }
434
435 fn dtype_to_public(dtype: DType) -> TensorDType {
436 match dtype {
437 DType::F32 => TensorDType::F32,
438 DType::F16 | DType::BF16 => TensorDType::F16,
439 DType::F64 => TensorDType::F64,
440 DType::I8 => TensorDType::I8,
441 DType::I16 => TensorDType::I16,
442 DType::I32 => TensorDType::I32,
443 DType::I64 => TensorDType::I64,
444 DType::U8 => TensorDType::U8,
445 DType::U16 => TensorDType::U16,
446 DType::U32 => TensorDType::U32,
447 DType::U64 => TensorDType::U64,
448 DType::Bool => TensorDType::Bool,
449 }
450 }
451}
452
453#[cfg(not(feature = "onnx"))]
454mod imp {
455 use super::ModelInfo;
456 use crate::device::DeviceType;
457 use crate::error::{MlError, MlResult};
458 use std::collections::HashMap;
459 use std::path::{Path, PathBuf};
460
461 /// Stub ONNX model used when the `onnx` feature is disabled.
462 ///
463 /// All constructors return [`MlError::FeatureDisabled`] so that
464 /// downstream code can degrade gracefully at runtime without special
465 /// `cfg` handling of its own.
466 #[derive(Debug)]
467 pub struct OnnxModel {
468 _priv: (),
469 }
470
471 impl OnnxModel {
472 /// Always fails with [`MlError::FeatureDisabled`] in this build.
473 pub fn load(_path: impl AsRef<Path>, _device: DeviceType) -> MlResult<Self> {
474 Err(MlError::FeatureDisabled("onnx"))
475 }
476
477 /// Always fails with [`MlError::FeatureDisabled`] in this build.
478 pub fn load_from_bytes(
479 _bytes: &[u8],
480 _device: DeviceType,
481 _virtual_path: impl Into<PathBuf>,
482 ) -> MlResult<Self> {
483 Err(MlError::FeatureDisabled("onnx"))
484 }
485
486 /// Always fails with [`MlError::FeatureDisabled`] in this build.
487 pub fn run_single(
488 &self,
489 _input_name: &str,
490 _data: Vec<f32>,
491 _shape: Vec<usize>,
492 ) -> MlResult<HashMap<String, Vec<f32>>> {
493 Err(MlError::FeatureDisabled("onnx"))
494 }
495
496 /// Always returns an empty [`ModelInfo`].
497 #[must_use]
498 pub fn info(&self) -> &ModelInfo {
499 // Unreachable in practice — `load` never succeeds — but provides
500 // a safe default so pattern matching compiles.
501 static EMPTY: std::sync::OnceLock<ModelInfo> = std::sync::OnceLock::new();
502 EMPTY.get_or_init(ModelInfo::default)
503 }
504
505 /// Device this stub was requested with (always CPU).
506 #[must_use]
507 pub fn device(&self) -> DeviceType {
508 DeviceType::Cpu
509 }
510 }
511}
512
513pub use imp::OnnxModel;
514
515/// Convenience wrapper: load a model with the auto-selected device.
516///
517/// Equivalent to
518/// `OnnxModel::load(path, DeviceType::auto())`. Use when you do not
519/// need to pin a specific backend.
520///
521/// # Errors
522///
523/// Propagates any error from
524/// [`OnnxModel::load`][crate::OnnxModel::load].
525pub fn load_auto(path: impl AsRef<Path>) -> MlResult<OnnxModel> {
526 OnnxModel::load(path, DeviceType::auto())
527}
528
/// Path-derived key used by [`crate::ModelCache`].
///
/// Returns the canonicalised form of `path` when canonicalisation
/// succeeds. If it fails (for example because the target does not exist
/// yet), the path is returned unchanged as an owned `PathBuf`.
#[must_use]
pub fn canonical_path(path: &Path) -> PathBuf {
    match path.canonicalize() {
        Ok(resolved) => resolved,
        Err(_) => path.to_path_buf(),
    }
}
537
#[cfg(test)]
mod tests {
    use super::*;
    #[cfg(not(feature = "onnx"))]
    use crate::error::MlError;

    /// `dynamic_rank` counts exactly the `None` entries of the shape.
    #[test]
    fn tensor_spec_dynamic_rank_counts_nones() {
        let spec = TensorSpec::new(
            "x",
            TensorDType::F32,
            vec![None, Some(3), Some(224), Some(224)],
        );
        assert_eq!(spec.dynamic_rank(), 1);
    }

    /// Spot-check the short dtype labels.
    #[test]
    fn dtype_names_are_canonical() {
        assert_eq!(TensorDType::F32.name(), "f32");
        assert_eq!(TensorDType::I64.name(), "i64");
        assert_eq!(TensorDType::Bool.name(), "bool");
    }

    /// Without the `onnx` feature the stub loader must fail with
    /// `FeatureDisabled("onnx")`.
    #[cfg(not(feature = "onnx"))]
    #[test]
    fn load_without_onnx_feature_reports_feature_disabled() {
        let err =
            OnnxModel::load("does-not-matter.onnx", DeviceType::Cpu).expect_err("expected failure");
        // `matches!` only yields a bool; it has to be asserted, otherwise
        // the test passes whatever error variant comes back.
        assert!(
            matches!(err, MlError::FeatureDisabled("onnx")),
            "expected MlError::FeatureDisabled(\"onnx\")"
        );
    }
}
568}