use crate::error::{Error, Result};
use dashmap::DashMap;
use ronn_core::ModelGraph;
use ronn_core::tensor::Tensor;
use ronn_graph::{OptimizationLevel, Optimizer};
use ronn_onnx::LoadedModel;
use ronn_providers::{ProviderRegistry, ProviderType};
use std::collections::HashMap;
use std::sync::Arc;
use tracing::{debug, info};
/// Configuration used when building an [`InferenceSession`].
///
/// Start from [`SessionOptions::new`] (or [`Default`]) and adjust it with the
/// consuming `with_*` methods.
#[derive(Clone, Debug)]
pub struct SessionOptions {
    /// Graph optimization level handed to the optimizer at build time.
    optimization_level: OptimizationLevel,
    /// Execution provider requested from the provider registry at build time.
    provider_type: ProviderType,
    /// Optional thread count; `None` unless set via `with_num_threads`.
    num_threads: Option<usize>,
    /// Profiling flag; `false` unless set via `with_profiling`.
    enable_profiling: bool,
}
impl SessionOptions {
    /// Creates options populated with the [`Default`] values.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the graph optimization level used when the session is built.
    #[must_use]
    pub fn with_optimization_level(mut self, level: OptimizationLevel) -> Self {
        self.optimization_level = level;
        self
    }

    /// Selects the execution provider to request from the provider registry.
    #[must_use]
    pub fn with_provider(mut self, provider: ProviderType) -> Self {
        self.provider_type = provider;
        self
    }

    /// Records a desired thread count.
    ///
    /// NOTE(review): this value is stored but not consumed by the session builder in this
    /// module — confirm whether it is honored elsewhere.
    #[must_use]
    pub fn with_num_threads(mut self, num_threads: usize) -> Self {
        self.num_threads = Some(num_threads);
        self
    }

    /// Records whether profiling should be enabled.
    ///
    /// NOTE(review): this flag is stored but not consumed by the session builder in this
    /// module — confirm whether it is honored elsewhere.
    #[must_use]
    pub fn with_profiling(mut self, enable: bool) -> Self {
        self.enable_profiling = enable;
        self
    }

    /// Returns the configured graph optimization level.
    pub fn optimization_level(&self) -> OptimizationLevel {
        self.optimization_level
    }

    /// Returns the configured execution provider.
    pub fn provider_type(&self) -> ProviderType {
        self.provider_type
    }

    /// Returns the requested thread count, or `None` if none was set.
    pub fn num_threads(&self) -> Option<usize> {
        self.num_threads
    }

    /// Returns whether profiling was requested.
    pub fn profiling_enabled(&self) -> bool {
        self.enable_profiling
    }
}
impl Default for SessionOptions {
fn default() -> Self {
Self {
optimization_level: OptimizationLevel::O2,
provider_type: ProviderType::CPU,
num_threads: None,
enable_profiling: false,
}
}
}
/// Turns a loaded model plus [`SessionOptions`] into a ready-to-run [`InferenceSession`].
pub struct SessionBuilder {
    /// Loaded model; its graph is cloned and optimized, and the `Arc` is moved into the session.
    model: Arc<LoadedModel>,
    /// Options controlling graph optimization and provider selection.
    options: SessionOptions,
}
impl SessionBuilder {
pub fn new(model: Arc<LoadedModel>, options: SessionOptions) -> Self {
Self { model, options }
}
pub fn build(self) -> Result<InferenceSession> {
info!(
"Building inference session with options: {:?}",
self.options
);
let mut graph = self.model.graph().clone();
let optimizer = Optimizer::new(self.options.optimization_level);
let stats = optimizer.optimize(&mut graph)?;
info!(
"Optimization completed: {} changes in {} iterations",
stats.total_changes(),
stats.iterations
);
let provider_registry = ronn_providers::create_provider_system().map_err(|e| {
Error::ProviderError(format!("Failed to create provider system: {}", e))
})?;
let provider = provider_registry
.get_provider(self.options.provider_type)
.ok_or_else(|| {
Error::ProviderError(format!(
"Provider {:?} not available",
self.options.provider_type
))
})?;
info!("Using execution provider: {:?}", provider.provider_id());
let provider_type = self.options.provider_type;
Ok(InferenceSession {
model: self.model,
graph,
options: self.options,
provider_registry,
provider_type,
value_cache: Arc::new(DashMap::new()),
})
}
}
/// A model prepared for execution: its optimized graph plus the selected provider.
pub struct InferenceSession {
    /// Source model; supplies declared inputs, outputs, and initializers.
    model: Arc<LoadedModel>,
    /// Copy of the model graph after optimization.
    graph: ModelGraph,
    /// Options the session was built with.
    options: SessionOptions,
    /// Provider registry created when the session was built.
    provider_registry: ProviderRegistry,
    /// Execution provider taken from `options` when the session was built.
    provider_type: ProviderType,
    /// Name-keyed tensor map shared behind an `Arc`; created empty at build time.
    value_cache: Arc<DashMap<String, Tensor>>,
}
impl InferenceSession {
pub fn run(&self, inputs: HashMap<&str, Tensor>) -> Result<HashMap<String, Tensor>> {
debug!("Running inference with {} inputs", inputs.len());
self.validate_inputs(&inputs)?;
for (name, tensor) in self.model.initializers() {
self.value_cache.insert(name.clone(), tensor.clone());
}
for (name, tensor) in inputs {
self.value_cache.insert(name.to_string(), tensor);
}
self.execute_graph()?;
let mut outputs = HashMap::new();
for output_info in self.model.outputs() {
if let Some(tensor) = self.value_cache.get(&output_info.name) {
outputs.insert(output_info.name.clone(), tensor.clone());
} else {
return Err(Error::InferenceError(format!(
"Output tensor not found: {}",
output_info.name
)));
}
}
debug!("Inference completed with {} outputs", outputs.len());
Ok(outputs)
}
pub async fn run_async(
&self,
inputs: HashMap<&str, Tensor>,
) -> Result<HashMap<String, Tensor>> {
tokio::task::spawn_blocking(move || {
Err(Error::InferenceError(
"Async inference not yet implemented".to_string(),
))
})
.await
.map_err(|e| Error::InferenceError(format!("Async execution failed: {}", e)))?
}
pub fn run_batch(
&self,
batch: Vec<HashMap<&str, Tensor>>,
) -> Result<Vec<HashMap<String, Tensor>>> {
batch.into_iter().map(|inputs| self.run(inputs)).collect()
}
fn validate_inputs(&self, inputs: &HashMap<&str, Tensor>) -> Result<()> {
for input_info in self.model.inputs() {
if !inputs.contains_key(input_info.name.as_str()) {
return Err(Error::InvalidInput(format!(
"Missing required input: {}",
input_info.name
)));
}
}
Ok(())
}
fn execute_graph(&self) -> Result<()> {
for node in self.graph.nodes() {
debug!("Executing node: {} ({})", node.id, node.op_type);
let input_tensors: Vec<Tensor> = node
.inputs
.iter()
.filter_map(|name| self.value_cache.get(name).map(|t| t.clone()))
.collect();
let op_registry = ronn_onnx::OperatorRegistry::new();
let op = op_registry.get(&node.op_type).map_err(|e| {
Error::InferenceError(format!("Operator {} not supported: {}", node.op_type, e))
})?;
let input_refs: Vec<&Tensor> = input_tensors.iter().collect();
let outputs = op
.execute(&input_refs, &node.attributes)
.map_err(|e| Error::InferenceError(format!("Operator execution failed: {}", e)))?;
for (i, tensor) in outputs.into_iter().enumerate() {
if i < node.outputs.len() {
self.value_cache.insert(node.outputs[i].clone(), tensor);
}
}
}
Ok(())
}
pub fn options(&self) -> &SessionOptions {
&self.options
}
pub fn graph(&self) -> &ModelGraph {
&self.graph
}
}