Skip to main content

oximedia_ml/
model.rs

1//! Pure-Rust ONNX model wrapper.
2//!
3//! [`OnnxModel`] is a thin façade over `oxionnx::Session` that exposes a
4//! stable, Pure-Rust-only surface. The intent is that the rest of
5//! oximedia-ml (and its downstream pipelines) never imports `oxionnx`
6//! symbols directly — everything goes through [`OnnxModel`],
7//! [`TensorSpec`], and [`ModelInfo`].
8//!
9//! When the `onnx` feature is disabled, [`OnnxModel`] still exists but
10//! its constructor returns [`crate::error::MlError::FeatureDisabled`].
11//! Pipelines can then decide whether to fall back to a non-ML heuristic
12//! or propagate the error upward.
13//!
14//! ## Single-input convenience
15//!
16//! Many classifier / detector models have a single input and one or more
17//! float outputs. For these,
18//! [`OnnxModel::run_single`][crate::OnnxModel::run_single] skips the
19//! `HashMap<&str, Tensor>` boilerplate and returns flat `Vec<f32>`
20//! buffers keyed on output name:
21//!
22//! ```no_run
23//! # #[cfg(feature = "onnx")]
24//! # fn demo() -> oximedia_ml::MlResult<()> {
25//! use oximedia_ml::{DeviceType, OnnxModel};
26//!
27//! let model = OnnxModel::load("scene.onnx", DeviceType::auto())?;
28//! let outputs = model.run_single(
29//!     "input",
30//!     vec![0.0_f32; 1 * 3 * 224 * 224],
31//!     vec![1, 3, 224, 224],
32//! )?;
33//! // `outputs` is `HashMap<String, Vec<f32>>`.
34//! # let _ = outputs;
35//! # Ok(())
36//! # }
37//! ```
38//!
39//! ## Metadata
40//!
41//! [`ModelInfo`] exposes input/output [`TensorSpec`]s (name, dtype,
42//! shape with dynamic dims as `None`), the ONNX producer name, and the
43//! opset version. Callers can use [`TensorSpec::dynamic_rank`] to decide
44//! whether dynamic shape plumbing is needed.
45
46use std::path::{Path, PathBuf};
47
48use crate::device::DeviceType;
49use crate::error::MlResult;
50
/// Scalar element type declared for a model input or output tensor.
///
/// Covers a practical subset of the ONNX element types. The `oxionnx`
/// tensor type used by this crate carries `f32` data, so a declared dtype
/// other than [`TensorDType::F32`] tells the caller that a conversion is
/// needed at the boundary.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum TensorDType {
    /// IEEE 754 single-precision (32-bit) float.
    F32,
    /// IEEE 754 half-precision (16-bit) float.
    F16,
    /// IEEE 754 double-precision (64-bit) float.
    F64,
    /// 8-bit signed integer.
    I8,
    /// 16-bit signed integer.
    I16,
    /// 32-bit signed integer.
    I32,
    /// 64-bit signed integer.
    I64,
    /// 8-bit unsigned integer.
    U8,
    /// 16-bit unsigned integer.
    U16,
    /// 32-bit unsigned integer.
    U32,
    /// 64-bit unsigned integer.
    U64,
    /// Boolean value.
    Bool,
}

impl TensorDType {
    /// Short, lowercase, stable identifier for this dtype.
    ///
    /// The returned string is the variant name in lowercase, e.g. `"f32"`,
    /// `"i64"` or `"bool"`.
    #[must_use]
    pub fn name(self) -> &'static str {
        match self {
            // Floating-point types.
            TensorDType::F16 => "f16",
            TensorDType::F32 => "f32",
            TensorDType::F64 => "f64",
            // Signed integers.
            TensorDType::I8 => "i8",
            TensorDType::I16 => "i16",
            TensorDType::I32 => "i32",
            TensorDType::I64 => "i64",
            // Unsigned integers.
            TensorDType::U8 => "u8",
            TensorDType::U16 => "u16",
            TensorDType::U32 => "u32",
            TensorDType::U64 => "u64",
            // Everything else.
            TensorDType::Bool => "bool",
        }
    }
}
104
105/// Describes a single model input or output tensor.
106#[derive(Clone, Debug)]
107pub struct TensorSpec {
108    /// Tensor name as declared in the ONNX graph.
109    pub name: String,
110    /// Scalar dtype.
111    pub dtype: TensorDType,
112    /// Shape with dynamic (None) dimensions expressed as `None`.
113    /// Static dimensions are positive integers; `i64` is used to match
114    /// the ONNX specification convention.
115    pub shape: Vec<Option<i64>>,
116}
117
118impl TensorSpec {
119    /// Create a new [`TensorSpec`].
120    #[must_use]
121    pub fn new(name: impl Into<String>, dtype: TensorDType, shape: Vec<Option<i64>>) -> Self {
122        Self {
123            name: name.into(),
124            dtype,
125            shape,
126        }
127    }
128
129    /// Number of dynamic dimensions (those reported as `None`).
130    #[must_use]
131    pub fn dynamic_rank(&self) -> usize {
132        self.shape.iter().filter(|d| d.is_none()).count()
133    }
134}
135
/// Static metadata describing a loaded ONNX model.
///
/// Returned by [`OnnxModel::info`][crate::OnnxModel::info]; inspect
/// [`Self::inputs`] / [`Self::outputs`] to validate the expected tensor
/// contract before running inference.
///
/// The `Default` value has an empty path, no inputs or outputs, and both
/// optional fields set to `None`.
#[derive(Clone, Debug, Default)]
pub struct ModelInfo {
    /// Source path of the model. For models loaded from memory this is the
    /// caller-supplied virtual path rather than a real file.
    pub path: PathBuf,
    /// Model input tensor specifications.
    pub inputs: Vec<TensorSpec>,
    /// Model output tensor specifications.
    pub outputs: Vec<TensorSpec>,
    /// Producer name as declared in the ONNX file; `None` when the model
    /// declares an empty producer name.
    pub producer: Option<String>,
    /// Opset version, if reported by the backend.
    ///
    /// NOTE(review): in the `onnx` build this is currently filled from the
    /// session metadata's `ir_version` (with `0` mapped to `None`). The ONNX
    /// IR version and the operator-set version are distinct concepts —
    /// confirm which one callers expect here.
    pub opset_version: Option<i64>,
}
154
155#[cfg(feature = "onnx")]
156mod imp {
157    use super::{ModelInfo, TensorDType, TensorSpec};
158    use crate::device::DeviceType;
159    use crate::error::{MlError, MlResult};
160    use oxionnx::execution_providers::ProviderKind;
161    use oxionnx::graph::TensorInfo;
162    use oxionnx::DType;
163    use oxionnx::{OptLevel, Session, SessionBuilder, Tensor};
164    use std::collections::HashMap;
165    use std::path::{Path, PathBuf};
166    use std::sync::Mutex;
167
168    /// Pure-Rust ONNX model handle.
169    pub struct OnnxModel {
170        session: Mutex<Session>,
171        info: ModelInfo,
172        device: DeviceType,
173    }
174
175    impl std::fmt::Debug for OnnxModel {
176        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
177            f.debug_struct("OnnxModel")
178                .field("device", &self.device)
179                .field("info", &self.info)
180                .finish()
181        }
182    }
183
184    impl OnnxModel {
185        /// Load an ONNX model from disk.
186        ///
187        /// The session is configured with an ordered [`ProviderKind`] list
188        /// derived from `device` (see `device_to_providers`).  CPU is
189        /// always the implicit terminal fallback, so the session can never
190        /// fail solely because a GPU backend is unavailable at runtime.
191        ///
192        /// # Errors
193        ///
194        /// * [`MlError::DeviceUnavailable`] if `device` is not available in
195        ///   this build / runtime.
196        /// * [`MlError::ModelLoad`] if the file cannot be parsed by the
197        ///   OxiONNX runtime.
198        pub fn load(path: impl AsRef<Path>, device: DeviceType) -> MlResult<Self> {
199            let path_ref = path.as_ref();
200            if !device.is_available() {
201                return Err(MlError::DeviceUnavailable(device.name().to_string()));
202            }
203
204            let providers = device_to_providers(device);
205            let session = SessionBuilder::new()
206                .with_optimization_level(OptLevel::All)
207                .with_provider_kinds(providers)
208                .load(path_ref)
209                .map_err(|e| MlError::ModelLoad {
210                    path: PathBuf::from(path_ref),
211                    reason: format!("{e:?}"),
212                })?;
213
214            let info = extract_info(&session, path_ref);
215
216            Ok(Self {
217                session: Mutex::new(session),
218                info,
219                device,
220            })
221        }
222
223        /// Load an ONNX model from an in-memory byte buffer.
224        ///
225        /// `virtual_path` is a synthetic identifier used for
226        /// [`ModelInfo::path`] and cache keying; it does not need to
227        /// refer to a real file on disk.
228        ///
229        /// The session is configured with an ordered [`ProviderKind`] list
230        /// derived from `device` (see `device_to_providers`).
231        ///
232        /// # Errors
233        ///
234        /// * [`MlError::DeviceUnavailable`] if `device` is not available.
235        /// * [`MlError::ModelLoad`] if `bytes` is not a valid ONNX payload.
236        pub fn load_from_bytes(
237            bytes: &[u8],
238            device: DeviceType,
239            virtual_path: impl Into<PathBuf>,
240        ) -> MlResult<Self> {
241            if !device.is_available() {
242                return Err(MlError::DeviceUnavailable(device.name().to_string()));
243            }
244            let path = virtual_path.into();
245            let providers = device_to_providers(device);
246            let session = SessionBuilder::new()
247                .with_optimization_level(OptLevel::All)
248                .with_provider_kinds(providers)
249                .load_from_bytes(bytes)
250                .map_err(|e| MlError::ModelLoad {
251                    path: path.clone(),
252                    reason: format!("{e:?}"),
253                })?;
254
255            let mut info = extract_info(&session, &path);
256            info.path = path;
257
258            Ok(Self {
259                session: Mutex::new(session),
260                info,
261                device,
262            })
263        }
264
265        /// Execute a forward pass.
266        ///
267        /// Inputs are keyed by input tensor name. Outputs are returned
268        /// as a map from output name to tensor.
269        ///
270        /// # Errors
271        ///
272        /// * [`MlError::Pipeline`] with stage `"onnx"` if the session
273        ///   mutex is poisoned.
274        /// * [`MlError::OnnxRuntime`] if inference fails inside OxiONNX.
275        pub fn run(&self, inputs: &HashMap<&str, Tensor>) -> MlResult<HashMap<String, Tensor>> {
276            let guard = self
277                .session
278                .lock()
279                .map_err(|_| MlError::pipeline("onnx", "session mutex poisoned"))?;
280            guard
281                .run(inputs)
282                .map_err(|e| MlError::OnnxRuntime(format!("{e:?}")))
283        }
284
285        /// Execute a forward pass for a single-input model, returning the
286        /// raw `f32` buffer of every output tensor.
287        ///
288        /// This is a convenience adapter for the common case where the
289        /// caller has a single `Vec<f32>` payload and a shape, and does not
290        /// want to import `oxionnx::Tensor` into its own crate.  Pure-Rust
291        /// downstream crates (e.g. `oximedia-recommend`) can build their
292        /// embedding pipelines on top of `OnnxModel` without touching the
293        /// `oxionnx` symbol surface at all, preserving the encapsulation
294        /// documented in the module header.
295        ///
296        /// # Errors
297        ///
298        /// * [`MlError::OnnxRuntime`] if the underlying `Session::run` fails.
299        /// * [`MlError::pipeline`] if the session mutex is poisoned.
300        pub fn run_single(
301            &self,
302            input_name: &str,
303            data: Vec<f32>,
304            shape: Vec<usize>,
305        ) -> MlResult<HashMap<String, Vec<f32>>> {
306            let expected = shape.iter().product::<usize>();
307            if data.len() != expected {
308                return Err(MlError::pipeline(
309                    "onnx",
310                    format!(
311                        "run_single: data len {} does not match shape product {}",
312                        data.len(),
313                        expected,
314                    ),
315                ));
316            }
317            let tensor = Tensor { data, shape };
318            let mut inputs: HashMap<&str, Tensor> = HashMap::with_capacity(1);
319            inputs.insert(input_name, tensor);
320            let outputs = self.run(&inputs)?;
321            Ok(outputs
322                .into_iter()
323                .map(|(name, t)| (name, t.data))
324                .collect())
325        }
326
        /// Return the loaded model metadata.
        ///
        /// The metadata is the value stored by the constructor; this
        /// accessor only borrows it and does not lock the session.
        #[must_use]
        pub fn info(&self) -> &ModelInfo {
            &self.info
        }
332
        /// Return the device this model was loaded onto.
        ///
        /// This is the `device` value that was passed to the constructor,
        /// returned by copy.
        #[must_use]
        pub fn device(&self) -> DeviceType {
            self.device
        }
338    }
339
340    /// Map a [`DeviceType`] to an ordered list of [`ProviderKind`] backends.
341    ///
342    /// The returned list is passed to
343    /// [`SessionBuilder::with_provider_kinds`][oxionnx::SessionBuilder::with_provider_kinds]
344    /// so that the oxionnx dispatch loop tries the requested backend first
345    /// and falls back to CPU when it is unavailable or returns `None` for a
346    /// specific operator.
347    ///
348    /// | DeviceType  | Provider priority list              |
349    /// |-------------|--------------------------------------|
350    /// | Cpu         | `[Cpu]`                             |
351    /// | Cuda        | `[Cuda, Cpu]` (feature `cuda`)      |
352    /// | WebGpu      | `[Gpu, Cpu]`  (feature `webgpu`)    |
353    /// | DirectMl    | `[DirectMl, Cpu]` (feature `directml`) |
354    /// | CoreMl      | `[Cpu]` (CoreML not yet wired)      |
355    fn device_to_providers(device: DeviceType) -> Vec<ProviderKind> {
356        match device {
357            DeviceType::Cpu | DeviceType::CoreMl => vec![ProviderKind::Cpu],
358            DeviceType::Cuda => {
359                // cuda feature is required for ProviderKind::Cuda to exist.
360                #[cfg(feature = "cuda")]
361                {
362                    vec![ProviderKind::Cuda, ProviderKind::Cpu]
363                }
364                #[cfg(not(feature = "cuda"))]
365                {
366                    vec![ProviderKind::Cpu]
367                }
368            }
369            DeviceType::WebGpu => {
370                // webgpu feature is required for ProviderKind::Gpu to exist.
371                #[cfg(feature = "webgpu")]
372                {
373                    vec![ProviderKind::Gpu, ProviderKind::Cpu]
374                }
375                #[cfg(not(feature = "webgpu"))]
376                {
377                    vec![ProviderKind::Cpu]
378                }
379            }
380            DeviceType::DirectMl => {
381                // directml feature is required for ProviderKind::DirectMl to exist.
382                #[cfg(feature = "directml")]
383                {
384                    vec![ProviderKind::DirectMl, ProviderKind::Cpu]
385                }
386                #[cfg(not(feature = "directml"))]
387                {
388                    vec![ProviderKind::Cpu]
389                }
390            }
391        }
392    }
393
394    fn extract_info(session: &Session, path: &Path) -> ModelInfo {
395        let inputs = session
396            .input_info()
397            .iter()
398            .map(tensor_info_to_spec)
399            .collect();
400        let outputs = session
401            .output_info()
402            .iter()
403            .map(tensor_info_to_spec)
404            .collect();
405
406        let meta = session.metadata();
407        let producer = meta.producer_name.clone();
408        let opset_version = if meta.ir_version == 0 {
409            None
410        } else {
411            Some(meta.ir_version)
412        };
413
414        ModelInfo {
415            path: PathBuf::from(path),
416            inputs,
417            outputs,
418            producer: if producer.is_empty() {
419                None
420            } else {
421                Some(producer)
422            },
423            opset_version,
424        }
425    }
426
427    fn tensor_info_to_spec(info: &TensorInfo) -> TensorSpec {
428        TensorSpec {
429            name: info.name.clone(),
430            dtype: dtype_to_public(info.dtype),
431            shape: info.shape.iter().map(|d| d.map(|v| v as i64)).collect(),
432        }
433    }
434
435    fn dtype_to_public(dtype: DType) -> TensorDType {
436        match dtype {
437            DType::F32 => TensorDType::F32,
438            DType::F16 | DType::BF16 => TensorDType::F16,
439            DType::F64 => TensorDType::F64,
440            DType::I8 => TensorDType::I8,
441            DType::I16 => TensorDType::I16,
442            DType::I32 => TensorDType::I32,
443            DType::I64 => TensorDType::I64,
444            DType::U8 => TensorDType::U8,
445            DType::U16 => TensorDType::U16,
446            DType::U32 => TensorDType::U32,
447            DType::U64 => TensorDType::U64,
448            DType::Bool => TensorDType::Bool,
449        }
450    }
451}
452
453#[cfg(not(feature = "onnx"))]
454mod imp {
455    use super::ModelInfo;
456    use crate::device::DeviceType;
457    use crate::error::{MlError, MlResult};
458    use std::collections::HashMap;
459    use std::path::{Path, PathBuf};
460
461    /// Stub ONNX model used when the `onnx` feature is disabled.
462    ///
463    /// All constructors return [`MlError::FeatureDisabled`] so that
464    /// downstream code can degrade gracefully at runtime without special
465    /// `cfg` handling of its own.
466    #[derive(Debug)]
467    pub struct OnnxModel {
468        _priv: (),
469    }
470
471    impl OnnxModel {
472        /// Always fails with [`MlError::FeatureDisabled`] in this build.
473        pub fn load(_path: impl AsRef<Path>, _device: DeviceType) -> MlResult<Self> {
474            Err(MlError::FeatureDisabled("onnx"))
475        }
476
477        /// Always fails with [`MlError::FeatureDisabled`] in this build.
478        pub fn load_from_bytes(
479            _bytes: &[u8],
480            _device: DeviceType,
481            _virtual_path: impl Into<PathBuf>,
482        ) -> MlResult<Self> {
483            Err(MlError::FeatureDisabled("onnx"))
484        }
485
486        /// Always fails with [`MlError::FeatureDisabled`] in this build.
487        pub fn run_single(
488            &self,
489            _input_name: &str,
490            _data: Vec<f32>,
491            _shape: Vec<usize>,
492        ) -> MlResult<HashMap<String, Vec<f32>>> {
493            Err(MlError::FeatureDisabled("onnx"))
494        }
495
496        /// Always returns an empty [`ModelInfo`].
497        #[must_use]
498        pub fn info(&self) -> &ModelInfo {
499            // Unreachable in practice — `load` never succeeds — but provides
500            // a safe default so pattern matching compiles.
501            static EMPTY: std::sync::OnceLock<ModelInfo> = std::sync::OnceLock::new();
502            EMPTY.get_or_init(ModelInfo::default)
503        }
504
505        /// Device this stub was requested with (always CPU).
506        #[must_use]
507        pub fn device(&self) -> DeviceType {
508            DeviceType::Cpu
509        }
510    }
511}
512
513pub use imp::OnnxModel;
514
515/// Convenience wrapper: load a model with the auto-selected device.
516///
517/// Equivalent to
518/// `OnnxModel::load(path, DeviceType::auto())`. Use when you do not
519/// need to pin a specific backend.
520///
521/// # Errors
522///
523/// Propagates any error from
524/// [`OnnxModel::load`][crate::OnnxModel::load].
525pub fn load_auto(path: impl AsRef<Path>) -> MlResult<OnnxModel> {
526    OnnxModel::load(path, DeviceType::auto())
527}
528
/// Path-derived key used by [`crate::ModelCache`].
///
/// Returns the canonical form of `path` when the file system can resolve
/// it. If canonicalisation fails for any reason (for instance the target
/// does not exist yet), the input path is returned unchanged.
#[must_use]
pub fn canonical_path(path: &Path) -> PathBuf {
    match path.canonicalize() {
        Ok(resolved) => resolved,
        // Any failure falls back to the caller's path as given.
        Err(_) => path.to_path_buf(),
    }
}
537
#[cfg(test)]
mod tests {
    use super::*;
    #[cfg(not(feature = "onnx"))]
    use crate::error::MlError;

    /// `dynamic_rank` counts exactly the `None` entries of the shape.
    #[test]
    fn tensor_spec_dynamic_rank_counts_nones() {
        let spec = TensorSpec::new(
            "x",
            TensorDType::F32,
            vec![None, Some(3), Some(224), Some(224)],
        );
        assert_eq!(spec.dynamic_rank(), 1);

        let fully_static = TensorSpec::new("y", TensorDType::F32, vec![Some(1), Some(2)]);
        assert_eq!(fully_static.dynamic_rank(), 0);
    }

    /// Every dtype maps to its lowercase variant name.
    #[test]
    fn dtype_names_are_canonical() {
        let expected = [
            (TensorDType::F32, "f32"),
            (TensorDType::F16, "f16"),
            (TensorDType::F64, "f64"),
            (TensorDType::I8, "i8"),
            (TensorDType::I16, "i16"),
            (TensorDType::I32, "i32"),
            (TensorDType::I64, "i64"),
            (TensorDType::U8, "u8"),
            (TensorDType::U16, "u16"),
            (TensorDType::U32, "u32"),
            (TensorDType::U64, "u64"),
            (TensorDType::Bool, "bool"),
        ];
        for (dtype, name) in expected {
            assert_eq!(dtype.name(), name);
        }
    }

    /// Without the `onnx` feature, loading must fail with
    /// `MlError::FeatureDisabled("onnx")`.
    #[cfg(not(feature = "onnx"))]
    #[test]
    fn load_without_onnx_feature_reports_feature_disabled() {
        let err =
            OnnxModel::load("does-not-matter.onnx", DeviceType::Cpu).expect_err("expected failure");
        // `matches!` only yields a bool; it has to be asserted, otherwise
        // the test passes for any error variant.
        assert!(
            matches!(err, MlError::FeatureDisabled("onnx")),
            "expected MlError::FeatureDisabled(\"onnx\")"
        );
    }
}