Skip to main content

ronn_api/
async_session.rs

1//! Async inference API for non-blocking, high-throughput inference.
2//!
3//! This module provides async/await support for RONN inference sessions,
4//! enabling efficient concurrent request handling for production workloads.
5
6use crate::error::{Error, Result};
7use crate::{InferenceSession, Model, SessionOptions};
8use ronn_core::tensor::Tensor;
9use std::collections::HashMap;
10use std::sync::Arc;
11use std::sync::RwLock;
12
13/// Async wrapper for inference sessions.
14///
15/// Provides non-blocking inference operations using Tokio runtime.
16/// Ideal for web services, concurrent request handling, and high-throughput scenarios.
17///
18/// # Example
19///
20/// ```ignore
21/// use ronn_api::AsyncSession;
22/// use std::collections::HashMap;
23///
24/// #[tokio::main]
25/// async fn main() -> Result<(), Box<dyn std::error::Error>> {
26///     let session = AsyncSession::from_file("model.onnx").await?;
27///
28///     let inputs = HashMap::new(); // Add your inputs here
29///     let outputs = session.run(inputs).await?;
30///     Ok(())
31/// }
32/// ```
33pub struct AsyncSession {
34    inner: Arc<RwLock<InferenceSession>>,
35}
36
37impl AsyncSession {
38    /// Create a new async session from an ONNX model file.
39    ///
40    /// # Arguments
41    ///
42    /// * `path` - Path to the ONNX model file
43    ///
44    /// # Example
45    ///
46    /// ```no_run
47    /// # use ronn_api::AsyncSession;
48    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
49    /// let session = AsyncSession::from_file("model.onnx").await?;
50    /// # Ok(())
51    /// # }
52    /// ```
53    pub async fn from_file(path: impl AsRef<std::path::Path> + Send + 'static) -> Result<Self> {
54        let session = tokio::task::spawn_blocking(move || {
55            let model = Model::load(path)?;
56            model.create_session_default()
57        })
58        .await
59        .map_err(|e| Error::InferenceError(format!("Task join error: {}", e)))??;
60
61        Ok(Self {
62            inner: Arc::new(RwLock::new(session)),
63        })
64    }
65
66    /// Create a new async session with custom options.
67    ///
68    /// # Arguments
69    ///
70    /// * `path` - Path to the ONNX model file
71    /// * `options` - Session configuration options
72    ///
73    /// # Example
74    ///
75    /// ```no_run
76    /// # use ronn_api::{AsyncSession, SessionOptions, OptimizationLevel};
77    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
78    /// let options = SessionOptions::new()
79    ///     .with_optimization_level(OptimizationLevel::O3);
80    ///
81    /// let session = AsyncSession::with_options("model.onnx", options).await?;
82    /// # Ok(())
83    /// # }
84    /// ```
85    pub async fn with_options(
86        path: impl AsRef<std::path::Path> + Send + 'static,
87        options: SessionOptions,
88    ) -> Result<Self> {
89        let session = tokio::task::spawn_blocking(move || {
90            let model = Model::load(path)?;
91            model.create_session(options)
92        })
93        .await
94        .map_err(|e| Error::InferenceError(format!("Task join error: {}", e)))??;
95
96        Ok(Self {
97            inner: Arc::new(RwLock::new(session)),
98        })
99    }
100
101    /// Run async inference on the provided inputs.
102    ///
103    /// This method is non-blocking and returns a future that resolves to the outputs.
104    ///
105    /// # Arguments
106    ///
107    /// * `inputs` - Map of input names to tensors
108    ///
109    /// # Returns
110    ///
111    /// A future that resolves to a map of output names to tensors.
112    ///
113    /// # Example
114    ///
115    /// ```no_run
116    /// # use ronn_api::AsyncSession;
117    /// # use std::collections::HashMap;
118    /// # async fn example(session: AsyncSession, inputs: HashMap<String, ronn_core::tensor::Tensor>) -> Result<(), Box<dyn std::error::Error>> {
119    /// let outputs = session.run(inputs).await?;
120    /// # Ok(())
121    /// # }
122    /// ```
123    pub async fn run(&self, inputs: HashMap<String, Tensor>) -> Result<HashMap<String, Tensor>> {
124        // Clone the inner Arc so we can move it into the blocking task
125        let session_arc = Arc::clone(&self.inner);
126
127        // Run inference in blocking thread pool to avoid blocking tokio runtime
128        tokio::task::spawn_blocking(move || {
129            let session = session_arc
130                .read()
131                .map_err(|e| Error::InferenceError(format!("Lock poisoned: {}", e)))?;
132            // Convert HashMap<String, Tensor> to HashMap<&str, Tensor>
133            let inputs_ref: HashMap<&str, Tensor> = inputs
134                .iter()
135                .map(|(k, v)| (k.as_str(), v.clone()))
136                .collect();
137            session.run(inputs_ref)
138        })
139        .await
140        .map_err(|e| Error::InferenceError(format!("Task join error: {}", e)))?
141    }
142
143    /// Run inference with read lock (allows concurrent reads if session supports it).
144    ///
145    /// # Note
146    ///
147    /// This is the same as `run()` for now, but could be optimized for
148    /// concurrent inference in the future.
149    pub async fn run_concurrent(
150        &self,
151        inputs: HashMap<String, Tensor>,
152    ) -> Result<HashMap<String, Tensor>> {
153        self.run(inputs).await
154    }
155
156    /// Clone the async session for sharing across tasks.
157    ///
158    /// This is cheap (only clones the Arc), allowing multiple tasks
159    /// to share the same session.
160    pub fn clone_handle(&self) -> Self {
161        Self {
162            inner: Arc::clone(&self.inner),
163        }
164    }
165}
166
167impl Clone for AsyncSession {
168    fn clone(&self) -> Self {
169        self.clone_handle()
170    }
171}
172
173/// Batch processor for async inference.
174///
175/// Collects multiple requests and processes them in batches for improved throughput.
176pub struct AsyncBatchProcessor {
177    session: AsyncSession,
178    max_batch_size: usize,
179    timeout_ms: u64,
180}
181
182impl AsyncBatchProcessor {
183    /// Create a new batch processor.
184    ///
185    /// # Arguments
186    ///
187    /// * `session` - The async session to use for inference
188    /// * `max_batch_size` - Maximum number of requests to batch together
189    /// * `timeout_ms` - Maximum time to wait for batch to fill (milliseconds)
190    ///
191    /// # Example
192    ///
193    /// ```no_run
194    /// # use ronn_api::{AsyncSession, AsyncBatchProcessor};
195    /// # async fn example(session: AsyncSession) {
196    /// let processor = AsyncBatchProcessor::new(
197    ///     session,
198    ///     32,    // batch up to 32 requests
199    ///     10,    // wait max 10ms for batch to fill
200    /// );
201    /// # }
202    /// ```
203    pub fn new(session: AsyncSession, max_batch_size: usize, timeout_ms: u64) -> Self {
204        Self {
205            session,
206            max_batch_size,
207            timeout_ms,
208        }
209    }
210
211    /// Submit a request for batched inference.
212    ///
213    /// The request will be batched with other concurrent requests up to
214    /// `max_batch_size` or until `timeout_ms` elapses.
215    ///
216    /// # Arguments
217    ///
218    /// * `inputs` - Input tensors for this request
219    ///
220    /// # Returns
221    ///
222    /// Output tensors for this specific request
223    pub async fn infer(&self, inputs: HashMap<String, Tensor>) -> Result<HashMap<String, Tensor>> {
224        // For now, just pass through to session
225        // In a full implementation, this would use a channel to collect
226        // requests and batch them together
227        self.session.run(inputs).await
228    }
229}
230
#[cfg(test)]
mod tests {
    use super::*;

    /// Placeholder: constructing a session needs a real model file.
    #[tokio::test]
    async fn test_async_session_creation() {
        // TODO: implement once test models are available.
    }

    /// Placeholder: should issue several inference requests concurrently.
    #[tokio::test]
    async fn test_concurrent_inference() {
        // TODO: implement once test models are available.
    }

    /// Placeholder: should exercise the batch processor.
    #[tokio::test]
    async fn test_batch_processor() {
        // TODO: implement once test models are available.
    }
}