ronn_api/async_session.rs
1//! Async inference API for non-blocking, high-throughput inference.
2//!
3//! This module provides async/await support for RONN inference sessions,
4//! enabling efficient concurrent request handling for production workloads.
5
6use crate::error::{Error, Result};
7use crate::{InferenceSession, Model, SessionOptions};
8use ronn_core::tensor::Tensor;
9use std::collections::HashMap;
10use std::sync::Arc;
11use std::sync::RwLock;
12
13/// Async wrapper for inference sessions.
14///
15/// Provides non-blocking inference operations using Tokio runtime.
16/// Ideal for web services, concurrent request handling, and high-throughput scenarios.
17///
18/// # Example
19///
20/// ```ignore
21/// use ronn_api::AsyncSession;
22/// use std::collections::HashMap;
23///
24/// #[tokio::main]
25/// async fn main() -> Result<(), Box<dyn std::error::Error>> {
26/// let session = AsyncSession::from_file("model.onnx").await?;
27///
28/// let inputs = HashMap::new(); // Add your inputs here
29/// let outputs = session.run(inputs).await?;
30/// Ok(())
31/// }
32/// ```
33pub struct AsyncSession {
34 inner: Arc<RwLock<InferenceSession>>,
35}
36
37impl AsyncSession {
38 /// Create a new async session from an ONNX model file.
39 ///
40 /// # Arguments
41 ///
42 /// * `path` - Path to the ONNX model file
43 ///
44 /// # Example
45 ///
46 /// ```no_run
47 /// # use ronn_api::AsyncSession;
48 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
49 /// let session = AsyncSession::from_file("model.onnx").await?;
50 /// # Ok(())
51 /// # }
52 /// ```
53 pub async fn from_file(path: impl AsRef<std::path::Path> + Send + 'static) -> Result<Self> {
54 let session = tokio::task::spawn_blocking(move || {
55 let model = Model::load(path)?;
56 model.create_session_default()
57 })
58 .await
59 .map_err(|e| Error::InferenceError(format!("Task join error: {}", e)))??;
60
61 Ok(Self {
62 inner: Arc::new(RwLock::new(session)),
63 })
64 }
65
66 /// Create a new async session with custom options.
67 ///
68 /// # Arguments
69 ///
70 /// * `path` - Path to the ONNX model file
71 /// * `options` - Session configuration options
72 ///
73 /// # Example
74 ///
75 /// ```no_run
76 /// # use ronn_api::{AsyncSession, SessionOptions, OptimizationLevel};
77 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
78 /// let options = SessionOptions::new()
79 /// .with_optimization_level(OptimizationLevel::O3);
80 ///
81 /// let session = AsyncSession::with_options("model.onnx", options).await?;
82 /// # Ok(())
83 /// # }
84 /// ```
85 pub async fn with_options(
86 path: impl AsRef<std::path::Path> + Send + 'static,
87 options: SessionOptions,
88 ) -> Result<Self> {
89 let session = tokio::task::spawn_blocking(move || {
90 let model = Model::load(path)?;
91 model.create_session(options)
92 })
93 .await
94 .map_err(|e| Error::InferenceError(format!("Task join error: {}", e)))??;
95
96 Ok(Self {
97 inner: Arc::new(RwLock::new(session)),
98 })
99 }
100
101 /// Run async inference on the provided inputs.
102 ///
103 /// This method is non-blocking and returns a future that resolves to the outputs.
104 ///
105 /// # Arguments
106 ///
107 /// * `inputs` - Map of input names to tensors
108 ///
109 /// # Returns
110 ///
111 /// A future that resolves to a map of output names to tensors.
112 ///
113 /// # Example
114 ///
115 /// ```no_run
116 /// # use ronn_api::AsyncSession;
117 /// # use std::collections::HashMap;
118 /// # async fn example(session: AsyncSession, inputs: HashMap<String, ronn_core::tensor::Tensor>) -> Result<(), Box<dyn std::error::Error>> {
119 /// let outputs = session.run(inputs).await?;
120 /// # Ok(())
121 /// # }
122 /// ```
123 pub async fn run(&self, inputs: HashMap<String, Tensor>) -> Result<HashMap<String, Tensor>> {
124 // Clone the inner Arc so we can move it into the blocking task
125 let session_arc = Arc::clone(&self.inner);
126
127 // Run inference in blocking thread pool to avoid blocking tokio runtime
128 tokio::task::spawn_blocking(move || {
129 let session = session_arc
130 .read()
131 .map_err(|e| Error::InferenceError(format!("Lock poisoned: {}", e)))?;
132 // Convert HashMap<String, Tensor> to HashMap<&str, Tensor>
133 let inputs_ref: HashMap<&str, Tensor> = inputs
134 .iter()
135 .map(|(k, v)| (k.as_str(), v.clone()))
136 .collect();
137 session.run(inputs_ref)
138 })
139 .await
140 .map_err(|e| Error::InferenceError(format!("Task join error: {}", e)))?
141 }
142
143 /// Run inference with read lock (allows concurrent reads if session supports it).
144 ///
145 /// # Note
146 ///
147 /// This is the same as `run()` for now, but could be optimized for
148 /// concurrent inference in the future.
149 pub async fn run_concurrent(
150 &self,
151 inputs: HashMap<String, Tensor>,
152 ) -> Result<HashMap<String, Tensor>> {
153 self.run(inputs).await
154 }
155
156 /// Clone the async session for sharing across tasks.
157 ///
158 /// This is cheap (only clones the Arc), allowing multiple tasks
159 /// to share the same session.
160 pub fn clone_handle(&self) -> Self {
161 Self {
162 inner: Arc::clone(&self.inner),
163 }
164 }
165}
166
167impl Clone for AsyncSession {
168 fn clone(&self) -> Self {
169 self.clone_handle()
170 }
171}
172
173/// Batch processor for async inference.
174///
175/// Collects multiple requests and processes them in batches for improved throughput.
176pub struct AsyncBatchProcessor {
177 session: AsyncSession,
178 max_batch_size: usize,
179 timeout_ms: u64,
180}
181
182impl AsyncBatchProcessor {
183 /// Create a new batch processor.
184 ///
185 /// # Arguments
186 ///
187 /// * `session` - The async session to use for inference
188 /// * `max_batch_size` - Maximum number of requests to batch together
189 /// * `timeout_ms` - Maximum time to wait for batch to fill (milliseconds)
190 ///
191 /// # Example
192 ///
193 /// ```no_run
194 /// # use ronn_api::{AsyncSession, AsyncBatchProcessor};
195 /// # async fn example(session: AsyncSession) {
196 /// let processor = AsyncBatchProcessor::new(
197 /// session,
198 /// 32, // batch up to 32 requests
199 /// 10, // wait max 10ms for batch to fill
200 /// );
201 /// # }
202 /// ```
203 pub fn new(session: AsyncSession, max_batch_size: usize, timeout_ms: u64) -> Self {
204 Self {
205 session,
206 max_batch_size,
207 timeout_ms,
208 }
209 }
210
211 /// Submit a request for batched inference.
212 ///
213 /// The request will be batched with other concurrent requests up to
214 /// `max_batch_size` or until `timeout_ms` elapses.
215 ///
216 /// # Arguments
217 ///
218 /// * `inputs` - Input tensors for this request
219 ///
220 /// # Returns
221 ///
222 /// Output tensors for this specific request
223 pub async fn infer(&self, inputs: HashMap<String, Tensor>) -> Result<HashMap<String, Tensor>> {
224 // For now, just pass through to session
225 // In a full implementation, this would use a channel to collect
226 // requests and batch them together
227 self.session.run(inputs).await
228 }
229}
230
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_async_session_creation() {
        // Placeholder: creating a session needs a real model file, so this
        // stays empty until test models are available.
    }

    #[tokio::test]
    async fn test_concurrent_inference() {
        // Placeholder: should issue several inference requests concurrently
        // once test models are available.
    }

    #[tokio::test]
    async fn test_batch_processor() {
        // Placeholder: should exercise the batch processor once test models
        // are available.
    }
}