1use crate::error::{Error, Result};
2use dashmap::DashMap;
3use ronn_core::ModelGraph;
4use ronn_core::tensor::Tensor;
5use ronn_graph::{OptimizationLevel, Optimizer};
6use ronn_onnx::LoadedModel;
7use ronn_providers::{ProviderRegistry, ProviderType};
8use std::collections::HashMap;
9use std::sync::Arc;
10use tracing::{debug, info};
11
12#[derive(Debug, Clone)]
14pub struct SessionOptions {
15 optimization_level: OptimizationLevel,
16 provider_type: ProviderType,
17 num_threads: Option<usize>,
18 enable_profiling: bool,
19}
20
21impl SessionOptions {
22 pub fn new() -> Self {
24 Self::default()
25 }
26
27 pub fn with_optimization_level(mut self, level: OptimizationLevel) -> Self {
29 self.optimization_level = level;
30 self
31 }
32
33 pub fn with_provider(mut self, provider: ProviderType) -> Self {
35 self.provider_type = provider;
36 self
37 }
38
39 pub fn with_num_threads(mut self, num_threads: usize) -> Self {
41 self.num_threads = Some(num_threads);
42 self
43 }
44
45 pub fn with_profiling(mut self, enable: bool) -> Self {
47 self.enable_profiling = enable;
48 self
49 }
50
51 pub fn optimization_level(&self) -> OptimizationLevel {
53 self.optimization_level
54 }
55
56 pub fn provider_type(&self) -> ProviderType {
58 self.provider_type
59 }
60}
61
62impl Default for SessionOptions {
63 fn default() -> Self {
64 Self {
65 optimization_level: OptimizationLevel::O2,
66 provider_type: ProviderType::CPU,
67 num_threads: None,
68 enable_profiling: false,
69 }
70 }
71}
72
73pub struct SessionBuilder {
75 model: Arc<LoadedModel>,
76 options: SessionOptions,
77}
78
79impl SessionBuilder {
80 pub fn new(model: Arc<LoadedModel>, options: SessionOptions) -> Self {
82 Self { model, options }
83 }
84
85 pub fn build(self) -> Result<InferenceSession> {
87 info!(
88 "Building inference session with options: {:?}",
89 self.options
90 );
91
92 let mut graph = self.model.graph().clone();
94
95 let optimizer = Optimizer::new(self.options.optimization_level);
97 let stats = optimizer.optimize(&mut graph)?;
98 info!(
99 "Optimization completed: {} changes in {} iterations",
100 stats.total_changes(),
101 stats.iterations
102 );
103
104 let provider_registry = ronn_providers::create_provider_system().map_err(|e| {
106 Error::ProviderError(format!("Failed to create provider system: {}", e))
107 })?;
108
109 let provider = provider_registry
111 .get_provider(self.options.provider_type)
112 .ok_or_else(|| {
113 Error::ProviderError(format!(
114 "Provider {:?} not available",
115 self.options.provider_type
116 ))
117 })?;
118
119 info!("Using execution provider: {:?}", provider.provider_id());
120
121 let provider_type = self.options.provider_type;
122
123 Ok(InferenceSession {
124 model: self.model,
125 graph,
126 options: self.options,
127 provider_registry,
128 provider_type,
129 value_cache: Arc::new(DashMap::new()),
130 })
131 }
132}
133
134pub struct InferenceSession {
136 model: Arc<LoadedModel>,
137 graph: ModelGraph,
138 options: SessionOptions,
139 provider_registry: ProviderRegistry,
140 provider_type: ProviderType,
141 value_cache: Arc<DashMap<String, Tensor>>,
142}
143
144impl InferenceSession {
145 pub fn run(&self, inputs: HashMap<&str, Tensor>) -> Result<HashMap<String, Tensor>> {
169 debug!("Running inference with {} inputs", inputs.len());
170
171 self.validate_inputs(&inputs)?;
173
174 for (name, tensor) in self.model.initializers() {
176 self.value_cache.insert(name.clone(), tensor.clone());
177 }
178
179 for (name, tensor) in inputs {
181 self.value_cache.insert(name.to_string(), tensor);
182 }
183
184 self.execute_graph()?;
186
187 let mut outputs = HashMap::new();
189 for output_info in self.model.outputs() {
190 if let Some(tensor) = self.value_cache.get(&output_info.name) {
191 outputs.insert(output_info.name.clone(), tensor.clone());
192 } else {
193 return Err(Error::InferenceError(format!(
194 "Output tensor not found: {}",
195 output_info.name
196 )));
197 }
198 }
199
200 debug!("Inference completed with {} outputs", outputs.len());
201 Ok(outputs)
202 }
203
204 pub async fn run_async(
206 &self,
207 inputs: HashMap<&str, Tensor>,
208 ) -> Result<HashMap<String, Tensor>> {
209 tokio::task::spawn_blocking(move || {
212 Err(Error::InferenceError(
214 "Async inference not yet implemented".to_string(),
215 ))
216 })
217 .await
218 .map_err(|e| Error::InferenceError(format!("Async execution failed: {}", e)))?
219 }
220
221 pub fn run_batch(
223 &self,
224 batch: Vec<HashMap<&str, Tensor>>,
225 ) -> Result<Vec<HashMap<String, Tensor>>> {
226 batch.into_iter().map(|inputs| self.run(inputs)).collect()
227 }
228
229 fn validate_inputs(&self, inputs: &HashMap<&str, Tensor>) -> Result<()> {
230 for input_info in self.model.inputs() {
231 if !inputs.contains_key(input_info.name.as_str()) {
232 return Err(Error::InvalidInput(format!(
233 "Missing required input: {}",
234 input_info.name
235 )));
236 }
237 }
238 Ok(())
239 }
240
241 fn execute_graph(&self) -> Result<()> {
242 for node in self.graph.nodes() {
244 debug!("Executing node: {} ({})", node.id, node.op_type);
245
246 let input_tensors: Vec<Tensor> = node
248 .inputs
249 .iter()
250 .filter_map(|name| self.value_cache.get(name).map(|t| t.clone()))
251 .collect();
252
253 let op_registry = ronn_onnx::OperatorRegistry::new();
255 let op = op_registry.get(&node.op_type).map_err(|e| {
256 Error::InferenceError(format!("Operator {} not supported: {}", node.op_type, e))
257 })?;
258
259 let input_refs: Vec<&Tensor> = input_tensors.iter().collect();
261 let outputs = op
262 .execute(&input_refs, &node.attributes)
263 .map_err(|e| Error::InferenceError(format!("Operator execution failed: {}", e)))?;
264
265 for (i, tensor) in outputs.into_iter().enumerate() {
267 if i < node.outputs.len() {
268 self.value_cache.insert(node.outputs[i].clone(), tensor);
269 }
270 }
271 }
272
273 Ok(())
274 }
275
276 pub fn options(&self) -> &SessionOptions {
278 &self.options
279 }
280
281 pub fn graph(&self) -> &ModelGraph {
283 &self.graph
284 }
285}