use crate::error::{Error, Result};
use crate::session::{SessionBuilder, SessionOptions};
use ronn_onnx::{LoadedModel, ModelLoader};
use std::path::Path;
use std::sync::Arc;
use tracing::info;
/// A model loaded through `ronn_onnx`'s [`ModelLoader`], held behind an [`Arc`].
///
/// Sessions created from a `Model` are handed a clone of that `Arc` (a
/// reference-count bump), not a copy of the underlying [`LoadedModel`].
pub struct Model {
    /// Shared handle to the loaded model; cloned into each session builder.
    inner: Arc<LoadedModel>,
}

impl Model {
    /// Loads a model from the file at `path`.
    ///
    /// Emits an `info`-level log line naming the path before loading.
    ///
    /// # Errors
    ///
    /// Returns the error produced by [`ModelLoader::load_from_file`], converted
    /// into this crate's error type.
    pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> {
        info!("Loading model from: {:?}", path.as_ref());
        ModelLoader::load_from_file(path)
            .map(Self::from_loaded)
            .map_err(Into::into)
    }

    /// Loads a model from an in-memory byte buffer.
    ///
    /// Emits an `info`-level log line with the buffer length before loading.
    ///
    /// # Errors
    ///
    /// Returns the error produced by [`ModelLoader::load_from_bytes`], converted
    /// into this crate's error type.
    pub fn from_bytes(bytes: &[u8]) -> Result<Self> {
        info!("Loading model from bytes ({} bytes)", bytes.len());
        ModelLoader::load_from_bytes(bytes)
            .map(Self::from_loaded)
            .map_err(Into::into)
    }

    /// Wraps a freshly loaded model in the shared `Arc` handle.
    fn from_loaded(loaded: LoadedModel) -> Self {
        Self {
            inner: Arc::new(loaded),
        }
    }

    /// Creates an inference session using [`SessionOptions::default`].
    ///
    /// # Errors
    ///
    /// Propagates any error from [`Model::create_session`].
    pub fn create_session_default(&self) -> Result<crate::session::InferenceSession> {
        self.create_session(Default::default())
    }

    /// Creates an inference session for this model with the given `options`.
    ///
    /// The session builder receives a shared handle to the loaded model, so
    /// several sessions can be created from one `Model`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the session builder's `build` step.
    pub fn create_session(
        &self,
        options: SessionOptions,
    ) -> Result<crate::session::InferenceSession> {
        let shared = Arc::clone(&self.inner);
        SessionBuilder::new(shared, options).build()
    }

    /// Returns the producer name stored on the loaded model, or `None` if absent.
    pub fn producer_name(&self) -> Option<&str> {
        self.inner.producer_name.as_deref()
    }

    /// Returns the IR version stored on the loaded model.
    pub fn ir_version(&self) -> i64 {
        self.inner.ir_version
    }

    /// Returns the names of the model's inputs, in the order the loaded model lists them.
    pub fn input_names(&self) -> Vec<&str> {
        let inputs = self.inner.inputs();
        inputs.iter().map(|input| input.name.as_str()).collect()
    }

    /// Returns the names of the model's outputs, in the order the loaded model lists them.
    pub fn output_names(&self) -> Vec<&str> {
        let outputs = self.inner.outputs();
        outputs.iter().map(|output| output.name.as_str()).collect()
    }
}