ronn_api/model.rs
1use crate::error::{Error, Result};
2use crate::session::{SessionBuilder, SessionOptions};
3use ronn_onnx::{LoadedModel, ModelLoader};
4use std::path::Path;
5use std::sync::Arc;
6use tracing::info;
7
8/// Represents a loaded ML model
9pub struct Model {
10 inner: Arc<LoadedModel>,
11}
12
13impl Model {
14 /// Load a model from an ONNX file
15 ///
16 /// # Example
17 /// ```no_run
18 /// use ronn_api::Model;
19 ///
20 /// let model = Model::load("model.onnx")?;
21 /// # Ok::<(), ronn_api::Error>(())
22 /// ```
23 pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> {
24 info!("Loading model from: {:?}", path.as_ref());
25 let loaded = ModelLoader::load_from_file(path)?;
26 Ok(Self {
27 inner: Arc::new(loaded),
28 })
29 }
30
31 /// Load a model from bytes
32 ///
33 /// # Example
34 /// ```no_run
35 /// use ronn_api::Model;
36 ///
37 /// let bytes = std::fs::read("model.onnx")?;
38 /// let model = Model::from_bytes(&bytes)?;
39 /// # Ok::<(), ronn_api::Error>(())
40 /// ```
41 pub fn from_bytes(bytes: &[u8]) -> Result<Self> {
42 info!("Loading model from bytes ({} bytes)", bytes.len());
43 let loaded = ModelLoader::load_from_bytes(bytes)?;
44 Ok(Self {
45 inner: Arc::new(loaded),
46 })
47 }
48
49 /// Create an inference session with default options
50 ///
51 /// # Example
52 /// ```no_run
53 /// use ronn_api::Model;
54 ///
55 /// let model = Model::load("model.onnx")?;
56 /// let session = model.create_session_default()?;
57 /// # Ok::<(), ronn_api::Error>(())
58 /// ```
59 pub fn create_session_default(&self) -> Result<crate::session::InferenceSession> {
60 self.create_session(SessionOptions::default())
61 }
62
63 /// Create an inference session with custom options
64 ///
65 /// # Example
66 /// ```no_run
67 /// use ronn_api::{Model, SessionOptions, OptimizationLevel};
68 /// use ronn_providers::ProviderType;
69 ///
70 /// let model = Model::load("model.onnx")?;
71 /// let options = SessionOptions::new()
72 /// .with_optimization_level(OptimizationLevel::O3)
73 /// .with_provider(ProviderType::GPU);
74 /// let session = model.create_session(options)?;
75 /// # Ok::<(), ronn_api::Error>(())
76 /// ```
77 pub fn create_session(
78 &self,
79 options: SessionOptions,
80 ) -> Result<crate::session::InferenceSession> {
81 SessionBuilder::new(self.inner.clone(), options).build()
82 }
83
84 /// Get model metadata
85 pub fn producer_name(&self) -> Option<&str> {
86 self.inner.producer_name.as_deref()
87 }
88
89 /// Get IR version
90 pub fn ir_version(&self) -> i64 {
91 self.inner.ir_version
92 }
93
94 /// Get input names
95 pub fn input_names(&self) -> Vec<&str> {
96 self.inner
97 .inputs()
98 .iter()
99 .map(|i| i.name.as_str())
100 .collect()
101 }
102
103 /// Get output names
104 pub fn output_names(&self) -> Vec<&str> {
105 self.inner
106 .outputs()
107 .iter()
108 .map(|o| o.name.as_str())
109 .collect()
110 }
111}