use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::sync::{Arc, RwLock};
use std::time::Duration;
use oxionnx::{ModelInfo, NodeProfile, OptLevel, Session, Tensor};
use crate::error::VoirsError;
/// Options applied to the session builder when an [`OnnxSession`] is created.
#[derive(Debug, Clone)]
pub struct OnnxSessionConfig {
/// Optimization level handed to the builder's `with_optimization_level`.
pub opt_level: OptLevel,
/// When `true`, the builder's `with_profiling` is invoked before loading.
pub enable_profiling: bool,
/// Value handed to the builder's `with_memory_pool`.
pub enable_memory_pool: bool,
}
impl Default for OnnxSessionConfig {
fn default() -> Self {
Self {
opt_level: OptLevel::All,
enable_profiling: false,
enable_memory_pool: false,
}
}
}
/// Wrapper around an `oxionnx` [`Session`] that can be shared between clones.
pub struct OnnxSession {
// The underlying session, reference-counted and guarded by a read/write lock.
session: Arc<RwLock<Session>>,
// `Some(path)` when loaded via `from_file`, `None` when loaded via `from_bytes`.
model_path: Option<PathBuf>,
// Copy of the configuration the session was built with.
config: OnnxSessionConfig,
}
impl OnnxSession {
pub fn from_file(path: &Path, config: &OnnxSessionConfig) -> Result<Self, VoirsError> {
let mut builder = Session::builder()
.with_optimization_level(config.opt_level)
.with_memory_pool(config.enable_memory_pool);
if config.enable_profiling {
builder = builder.with_profiling();
}
let session = builder.load(path).map_err(|e| {
VoirsError::ModelLoadError(format!(
"Failed to load ONNX model '{}': {}",
path.display(),
e
))
})?;
Ok(Self {
session: Arc::new(RwLock::new(session)),
model_path: Some(path.to_path_buf()),
config: config.clone(),
})
}
pub fn from_bytes(bytes: &[u8], config: &OnnxSessionConfig) -> Result<Self, VoirsError> {
let mut builder = Session::builder()
.with_optimization_level(config.opt_level)
.with_memory_pool(config.enable_memory_pool);
if config.enable_profiling {
builder = builder.with_profiling();
}
let session = builder.load_from_bytes(bytes).map_err(|e| {
VoirsError::ModelLoadError(format!("Failed to load ONNX model from bytes: {}", e))
})?;
Ok(Self {
session: Arc::new(RwLock::new(session)),
model_path: None,
config: config.clone(),
})
}
pub fn run(
&self,
inputs: &HashMap<&str, Tensor>,
) -> Result<HashMap<String, Tensor>, VoirsError> {
let session = self.session.read().map_err(|e| {
VoirsError::InferenceError(format!("Failed to acquire session lock: {}", e))
})?;
session
.run(inputs)
.map_err(|e| VoirsError::InferenceError(format!("ONNX inference failed: {}", e)))
}
pub fn run_one(
&self,
name: &str,
input: Tensor,
) -> Result<HashMap<String, Tensor>, VoirsError> {
let session = self.session.read().map_err(|e| {
VoirsError::InferenceError(format!("Failed to acquire session lock: {}", e))
})?;
session
.run_one(name, input)
.map_err(|e| VoirsError::InferenceError(format!("ONNX inference failed: {}", e)))
}
pub fn input_names(&self) -> Result<Vec<String>, VoirsError> {
let session = self.session.read().map_err(|e| {
VoirsError::InferenceError(format!("Failed to acquire session lock: {}", e))
})?;
Ok(session.input_names().to_vec())
}
pub fn output_names(&self) -> Result<Vec<String>, VoirsError> {
let session = self.session.read().map_err(|e| {
VoirsError::InferenceError(format!("Failed to acquire session lock: {}", e))
})?;
Ok(session.output_names().to_vec())
}
pub fn model_info(&self) -> Result<ModelInfo, VoirsError> {
let session = self.session.read().map_err(|e| {
VoirsError::InferenceError(format!("Failed to acquire session lock: {}", e))
})?;
Ok(session.model_info())
}
pub fn profiling_results(&self) -> Result<Option<Vec<NodeProfile>>, VoirsError> {
let session = self.session.read().map_err(|e| {
VoirsError::InferenceError(format!("Failed to acquire session lock: {}", e))
})?;
Ok(session.profiling_results())
}
pub fn profiling_summary(&self) -> Result<Option<ProfilingSummary>, VoirsError> {
let profiles = self.profiling_results()?;
Ok(profiles.map(|p| ProfilingSummary::from_node_profiles(&p)))
}
pub fn export_dot(&self) -> Result<String, VoirsError> {
let session = self.session.read().map_err(|e| {
VoirsError::InferenceError(format!("Failed to acquire session lock: {}", e))
})?;
Ok(session.export_dot())
}
pub fn estimated_memory_bytes(&self) -> Result<Option<usize>, VoirsError> {
let session = self.session.read().map_err(|e| {
VoirsError::InferenceError(format!("Failed to acquire session lock: {}", e))
})?;
Ok(session.estimated_memory_bytes())
}
pub fn model_path(&self) -> Option<&Path> {
self.model_path.as_deref()
}
pub fn config(&self) -> &OnnxSessionConfig {
&self.config
}
pub fn inner(&self) -> Arc<RwLock<Session>> {
Arc::clone(&self.session)
}
pub fn weights(&self) -> Result<Vec<(String, Vec<usize>)>, VoirsError> {
let session = self.session.read().map_err(|e| {
VoirsError::InferenceError(format!("Failed to acquire session lock: {}", e))
})?;
Ok(session
.weights()
.iter()
.map(|(k, v)| (k.clone(), v.shape.clone()))
.collect())
}
}
impl Clone for OnnxSession {
    /// Produces a new handle that shares the same underlying session (the
    /// `Arc` is cloned, not the session) and carries copies of the model path
    /// and configuration.
    fn clone(&self) -> Self {
        let Self {
            session,
            model_path,
            config,
        } = self;
        Self {
            session: Arc::clone(session),
            model_path: model_path.clone(),
            config: config.clone(),
        }
    }
}
/// Aggregated view over a set of per-node profiling records.
#[derive(Debug, Clone)]
pub struct ProfilingSummary {
/// Sum of the durations of all profiled nodes.
pub total_duration: Duration,
/// `(node name, duration)` pairs sorted by descending duration, capped at 20 entries.
pub top_nodes: Vec<(String, Duration)>,
/// Accumulated duration per operator type.
pub op_type_durations: HashMap<String, Duration>,
/// Name of the node with the longest duration, or `None` when there are no profiles.
pub bottleneck_node: Option<String>,
}
impl ProfilingSummary {
    /// Builds a summary from per-node profiling records.
    ///
    /// The summary contains the total duration, the accumulated duration per
    /// operator type, the nodes ranked by descending duration (at most 20 are
    /// kept) and the name of the slowest node. Nodes with equal durations keep
    /// their relative input order because the ranking uses a stable sort.
    pub fn from_node_profiles(profiles: &[NodeProfile]) -> Self {
        // Maximum number of entries retained in `top_nodes`.
        const TOP_NODE_LIMIT: usize = 20;

        let total_duration: Duration = profiles.iter().map(|node| node.duration).sum();

        // One pass fills both the per-operator totals and the ranking input.
        let mut op_type_durations: HashMap<String, Duration> = HashMap::new();
        let mut ranked: Vec<(String, Duration)> = Vec::with_capacity(profiles.len());
        for node in profiles {
            *op_type_durations.entry(node.op_type.clone()).or_default() += node.duration;
            ranked.push((node.node_name.clone(), node.duration));
        }

        // Longest duration first.
        ranked.sort_by(|lhs, rhs| rhs.1.cmp(&lhs.1));

        // The bottleneck is read before truncation; it is simply the head of the ranking.
        let bottleneck_node = ranked.first().map(|(name, _)| name.clone());
        ranked.truncate(TOP_NODE_LIMIT);

        Self {
            total_duration,
            top_nodes: ranked,
            op_type_durations,
            bottleneck_node,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration keeps profiling and the memory pool off.
    #[test]
    fn test_onnx_session_config_default() {
        let OnnxSessionConfig {
            enable_profiling,
            enable_memory_pool,
            ..
        } = OnnxSessionConfig::default();
        assert!(!enable_profiling);
        assert!(!enable_memory_pool);
    }

    /// Summarizing zero profiles yields an all-empty summary.
    #[test]
    fn test_profiling_summary_empty() {
        let no_profiles: &[NodeProfile] = &[];
        let empty = ProfilingSummary::from_node_profiles(no_profiles);
        assert_eq!(empty.total_duration, Duration::ZERO);
        assert!(empty.top_nodes.is_empty());
        assert!(empty.op_type_durations.is_empty());
        assert!(empty.bottleneck_node.is_none());
    }

    /// Totals, ranking and per-operator sums are derived from the given profiles.
    #[test]
    fn test_profiling_summary_from_profiles() {
        let first_matmul = NodeProfile {
            node_name: String::from("matmul_1"),
            op_type: String::from("MatMul"),
            duration: Duration::from_millis(10),
            output_shapes: vec![vec![1, 512]],
        };
        let relu = NodeProfile {
            node_name: String::from("relu_1"),
            op_type: String::from("Relu"),
            duration: Duration::from_millis(2),
            output_shapes: vec![vec![1, 512]],
        };
        let second_matmul = NodeProfile {
            node_name: String::from("matmul_2"),
            op_type: String::from("MatMul"),
            duration: Duration::from_millis(8),
            output_shapes: vec![vec![1, 256]],
        };
        let recorded = vec![first_matmul, relu, second_matmul];

        let summary = ProfilingSummary::from_node_profiles(&recorded);

        assert_eq!(summary.total_duration, Duration::from_millis(20));
        assert_eq!(summary.bottleneck_node.as_deref(), Some("matmul_1"));
        assert_eq!(summary.top_nodes.len(), 3);
        assert_eq!(summary.top_nodes[0].0, "matmul_1");

        let matmul_total = summary
            .op_type_durations
            .get("MatMul")
            .expect("MatMul should exist");
        assert_eq!(*matmul_total, Duration::from_millis(18));
    }
}