Skip to main content

ronn_api/
model.rs

1use crate::error::{Error, Result};
2use crate::session::{SessionBuilder, SessionOptions};
3use ronn_onnx::{LoadedModel, ModelLoader};
4use std::path::Path;
5use std::sync::Arc;
6use tracing::info;
7
8/// Represents a loaded ML model
9pub struct Model {
10    inner: Arc<LoadedModel>,
11}
12
13impl Model {
14    /// Load a model from an ONNX file
15    ///
16    /// # Example
17    /// ```no_run
18    /// use ronn_api::Model;
19    ///
20    /// let model = Model::load("model.onnx")?;
21    /// # Ok::<(), ronn_api::Error>(())
22    /// ```
23    pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> {
24        info!("Loading model from: {:?}", path.as_ref());
25        let loaded = ModelLoader::load_from_file(path)?;
26        Ok(Self {
27            inner: Arc::new(loaded),
28        })
29    }
30
31    /// Load a model from bytes
32    ///
33    /// # Example
34    /// ```no_run
35    /// use ronn_api::Model;
36    ///
37    /// let bytes = std::fs::read("model.onnx")?;
38    /// let model = Model::from_bytes(&bytes)?;
39    /// # Ok::<(), ronn_api::Error>(())
40    /// ```
41    pub fn from_bytes(bytes: &[u8]) -> Result<Self> {
42        info!("Loading model from bytes ({} bytes)", bytes.len());
43        let loaded = ModelLoader::load_from_bytes(bytes)?;
44        Ok(Self {
45            inner: Arc::new(loaded),
46        })
47    }
48
49    /// Create an inference session with default options
50    ///
51    /// # Example
52    /// ```no_run
53    /// use ronn_api::Model;
54    ///
55    /// let model = Model::load("model.onnx")?;
56    /// let session = model.create_session_default()?;
57    /// # Ok::<(), ronn_api::Error>(())
58    /// ```
59    pub fn create_session_default(&self) -> Result<crate::session::InferenceSession> {
60        self.create_session(SessionOptions::default())
61    }
62
63    /// Create an inference session with custom options
64    ///
65    /// # Example
66    /// ```no_run
67    /// use ronn_api::{Model, SessionOptions, OptimizationLevel};
68    /// use ronn_providers::ProviderType;
69    ///
70    /// let model = Model::load("model.onnx")?;
71    /// let options = SessionOptions::new()
72    ///     .with_optimization_level(OptimizationLevel::O3)
73    ///     .with_provider(ProviderType::GPU);
74    /// let session = model.create_session(options)?;
75    /// # Ok::<(), ronn_api::Error>(())
76    /// ```
77    pub fn create_session(
78        &self,
79        options: SessionOptions,
80    ) -> Result<crate::session::InferenceSession> {
81        SessionBuilder::new(self.inner.clone(), options).build()
82    }
83
84    /// Get model metadata
85    pub fn producer_name(&self) -> Option<&str> {
86        self.inner.producer_name.as_deref()
87    }
88
89    /// Get IR version
90    pub fn ir_version(&self) -> i64 {
91        self.inner.ir_version
92    }
93
94    /// Get input names
95    pub fn input_names(&self) -> Vec<&str> {
96        self.inner
97            .inputs()
98            .iter()
99            .map(|i| i.name.as_str())
100            .collect()
101    }
102
103    /// Get output names
104    pub fn output_names(&self) -> Vec<&str> {
105        self.inner
106            .outputs()
107            .iter()
108            .map(|o| o.name.as_str())
109            .collect()
110    }
111}