Skip to main content

ronn_api/
session.rs

1use crate::error::{Error, Result};
2use dashmap::DashMap;
3use ronn_core::ModelGraph;
4use ronn_core::tensor::Tensor;
5use ronn_graph::{OptimizationLevel, Optimizer};
6use ronn_onnx::LoadedModel;
7use ronn_providers::{ProviderRegistry, ProviderType};
8use std::collections::HashMap;
9use std::sync::Arc;
10use tracing::{debug, info};
11
12/// Options for configuring an inference session
13#[derive(Debug, Clone)]
14pub struct SessionOptions {
15    optimization_level: OptimizationLevel,
16    provider_type: ProviderType,
17    num_threads: Option<usize>,
18    enable_profiling: bool,
19}
20
21impl SessionOptions {
22    /// Create new session options with defaults
23    pub fn new() -> Self {
24        Self::default()
25    }
26
27    /// Set optimization level
28    pub fn with_optimization_level(mut self, level: OptimizationLevel) -> Self {
29        self.optimization_level = level;
30        self
31    }
32
33    /// Set execution provider
34    pub fn with_provider(mut self, provider: ProviderType) -> Self {
35        self.provider_type = provider;
36        self
37    }
38
39    /// Set number of threads for CPU execution
40    pub fn with_num_threads(mut self, num_threads: usize) -> Self {
41        self.num_threads = Some(num_threads);
42        self
43    }
44
45    /// Enable profiling
46    pub fn with_profiling(mut self, enable: bool) -> Self {
47        self.enable_profiling = enable;
48        self
49    }
50
51    /// Get optimization level
52    pub fn optimization_level(&self) -> OptimizationLevel {
53        self.optimization_level
54    }
55
56    /// Get provider type
57    pub fn provider_type(&self) -> ProviderType {
58        self.provider_type
59    }
60}
61
62impl Default for SessionOptions {
63    fn default() -> Self {
64        Self {
65            optimization_level: OptimizationLevel::O2,
66            provider_type: ProviderType::CPU,
67            num_threads: None,
68            enable_profiling: false,
69        }
70    }
71}
72
73/// Builder for creating inference sessions
74pub struct SessionBuilder {
75    model: Arc<LoadedModel>,
76    options: SessionOptions,
77}
78
79impl SessionBuilder {
80    /// Create a new session builder
81    pub fn new(model: Arc<LoadedModel>, options: SessionOptions) -> Self {
82        Self { model, options }
83    }
84
85    /// Build the inference session
86    pub fn build(self) -> Result<InferenceSession> {
87        info!(
88            "Building inference session with options: {:?}",
89            self.options
90        );
91
92        // Clone the model graph for optimization
93        let mut graph = self.model.graph().clone();
94
95        // Apply optimizations
96        let optimizer = Optimizer::new(self.options.optimization_level);
97        let stats = optimizer.optimize(&mut graph)?;
98        info!(
99            "Optimization completed: {} changes in {} iterations",
100            stats.total_changes(),
101            stats.iterations
102        );
103
104        // Initialize provider registry with available providers
105        let provider_registry = ronn_providers::create_provider_system().map_err(|e| {
106            Error::ProviderError(format!("Failed to create provider system: {}", e))
107        })?;
108
109        // Get the requested provider
110        let provider = provider_registry
111            .get_provider(self.options.provider_type)
112            .ok_or_else(|| {
113                Error::ProviderError(format!(
114                    "Provider {:?} not available",
115                    self.options.provider_type
116                ))
117            })?;
118
119        info!("Using execution provider: {:?}", provider.provider_id());
120
121        let provider_type = self.options.provider_type;
122
123        Ok(InferenceSession {
124            model: self.model,
125            graph,
126            options: self.options,
127            provider_registry,
128            provider_type,
129            value_cache: Arc::new(DashMap::new()),
130        })
131    }
132}
133
134/// An inference session for running a model
135pub struct InferenceSession {
136    model: Arc<LoadedModel>,
137    graph: ModelGraph,
138    options: SessionOptions,
139    provider_registry: ProviderRegistry,
140    provider_type: ProviderType,
141    value_cache: Arc<DashMap<String, Tensor>>,
142}
143
144impl InferenceSession {
145    /// Run inference synchronously
146    ///
147    /// # Example
148    /// ```no_run
149    /// use ronn_api::{Model, Tensor};
150    /// use ronn_core::{DataType, TensorLayout};
151    /// use std::collections::HashMap;
152    ///
153    /// # fn example() -> Result<(), Box<dyn std::error::Error>> {
154    /// let model = Model::load("model.onnx")?;
155    /// let session = model.create_session_default()?;
156    ///
157    /// let mut inputs = HashMap::new();
158    /// inputs.insert("input", Tensor::zeros(
159    ///     vec![1, 3, 224, 224],
160    ///     DataType::F32,
161    ///     TensorLayout::RowMajor
162    /// )?);
163    ///
164    /// let outputs = session.run(inputs)?;
165    /// # Ok(())
166    /// # }
167    /// ```
168    pub fn run(&self, inputs: HashMap<&str, Tensor>) -> Result<HashMap<String, Tensor>> {
169        debug!("Running inference with {} inputs", inputs.len());
170
171        // Validate inputs
172        self.validate_inputs(&inputs)?;
173
174        // Load initializers into cache
175        for (name, tensor) in self.model.initializers() {
176            self.value_cache.insert(name.clone(), tensor.clone());
177        }
178
179        // Load input tensors into cache
180        for (name, tensor) in inputs {
181            self.value_cache.insert(name.to_string(), tensor);
182        }
183
184        // Execute the graph
185        self.execute_graph()?;
186
187        // Collect outputs
188        let mut outputs = HashMap::new();
189        for output_info in self.model.outputs() {
190            if let Some(tensor) = self.value_cache.get(&output_info.name) {
191                outputs.insert(output_info.name.clone(), tensor.clone());
192            } else {
193                return Err(Error::InferenceError(format!(
194                    "Output tensor not found: {}",
195                    output_info.name
196                )));
197            }
198        }
199
200        debug!("Inference completed with {} outputs", outputs.len());
201        Ok(outputs)
202    }
203
204    /// Run inference asynchronously
205    pub async fn run_async(
206        &self,
207        inputs: HashMap<&str, Tensor>,
208    ) -> Result<HashMap<String, Tensor>> {
209        // For now, just wrap synchronous execution
210        // Full async implementation would use tokio::spawn_blocking
211        tokio::task::spawn_blocking(move || {
212            // This is a simplified version - full implementation would handle the move properly
213            Err(Error::InferenceError(
214                "Async inference not yet implemented".to_string(),
215            ))
216        })
217        .await
218        .map_err(|e| Error::InferenceError(format!("Async execution failed: {}", e)))?
219    }
220
221    /// Run inference on a batch of inputs
222    pub fn run_batch(
223        &self,
224        batch: Vec<HashMap<&str, Tensor>>,
225    ) -> Result<Vec<HashMap<String, Tensor>>> {
226        batch.into_iter().map(|inputs| self.run(inputs)).collect()
227    }
228
229    fn validate_inputs(&self, inputs: &HashMap<&str, Tensor>) -> Result<()> {
230        for input_info in self.model.inputs() {
231            if !inputs.contains_key(input_info.name.as_str()) {
232                return Err(Error::InvalidInput(format!(
233                    "Missing required input: {}",
234                    input_info.name
235                )));
236            }
237        }
238        Ok(())
239    }
240
241    fn execute_graph(&self) -> Result<()> {
242        // Execute nodes in topological order
243        for node in self.graph.nodes() {
244            debug!("Executing node: {} ({})", node.id, node.op_type);
245
246            // Collect input tensors
247            let input_tensors: Vec<Tensor> = node
248                .inputs
249                .iter()
250                .filter_map(|name| self.value_cache.get(name).map(|t| t.clone()))
251                .collect();
252
253            // Get the operator implementation
254            let op_registry = ronn_onnx::OperatorRegistry::new();
255            let op = op_registry.get(&node.op_type).map_err(|e| {
256                Error::InferenceError(format!("Operator {} not supported: {}", node.op_type, e))
257            })?;
258
259            // Execute the operator
260            let input_refs: Vec<&Tensor> = input_tensors.iter().collect();
261            let outputs = op
262                .execute(&input_refs, &node.attributes)
263                .map_err(|e| Error::InferenceError(format!("Operator execution failed: {}", e)))?;
264
265            // Store output tensors
266            for (i, tensor) in outputs.into_iter().enumerate() {
267                if i < node.outputs.len() {
268                    self.value_cache.insert(node.outputs[i].clone(), tensor);
269                }
270            }
271        }
272
273        Ok(())
274    }
275
276    /// Get session options
277    pub fn options(&self) -> &SessionOptions {
278        &self.options
279    }
280
281    /// Get the model graph
282    pub fn graph(&self) -> &ModelGraph {
283        &self.graph
284    }
285}