Skip to main content

ronn_api/
batch.rs

//! Batch processing for high-throughput inference
//!
//! Provides static and dynamic batching to maximize GPU/CPU utilization
//! and achieve 3-10x throughput improvements.
5
6use crate::InferenceSession;
7use crate::error::{Error, Result};
8use ronn_core::tensor::Tensor;
9use std::collections::HashMap;
10use std::sync::Arc;
11use std::time::{Duration, Instant};
12use tokio::sync::{RwLock, mpsc};
13use tokio::time::timeout;
14
/// Strategy used to group incoming requests into batches.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BatchStrategy {
    /// Fixed-size batching: requests are gathered until the batch is full.
    Static {
        /// Number of requests gathered per batch.
        batch_size: usize,
    },
    /// Dynamic batching: a batch is dispatched once it reaches the maximum
    /// size or the timeout expires, whichever happens first.
    Dynamic {
        /// Upper bound on the number of requests in one batch.
        max_batch_size: usize,
        /// Longest wait, in milliseconds, before a partial batch is dispatched.
        timeout_ms: u64,
    },
}

impl Default for BatchStrategy {
    /// Dynamic batching with at most 32 requests and a 10 ms timeout.
    fn default() -> Self {
        const MAX_BATCH_SIZE: usize = 32;
        const TIMEOUT_MS: u64 = 10;

        BatchStrategy::Dynamic {
            max_batch_size: MAX_BATCH_SIZE,
            timeout_ms: TIMEOUT_MS,
        }
    }
}
40
41/// Configuration for batch processor
42#[derive(Debug, Clone)]
43pub struct BatchConfig {
44    /// Batching strategy
45    pub strategy: BatchStrategy,
46    /// Queue capacity for incoming requests
47    pub queue_capacity: usize,
48    /// Number of worker threads
49    pub num_workers: usize,
50}
51
52impl Default for BatchConfig {
53    fn default() -> Self {
54        Self {
55            strategy: BatchStrategy::default(),
56            queue_capacity: 1024,
57            num_workers: 1,
58        }
59    }
60}
61
62/// A single inference request
63pub struct BatchRequest {
64    /// Input tensors for this request
65    pub inputs: HashMap<String, Tensor>,
66    /// Channel to send result back
67    response_tx: tokio::sync::oneshot::Sender<Result<HashMap<String, Tensor>>>,
68}
69
70impl BatchRequest {
71    /// Create a new batch request
72    pub fn new(
73        inputs: HashMap<String, Tensor>,
74        response_tx: tokio::sync::oneshot::Sender<Result<HashMap<String, Tensor>>>,
75    ) -> Self {
76        Self {
77            inputs,
78            response_tx,
79        }
80    }
81
82    /// Send the response back to the caller
83    fn send_response(self, result: Result<HashMap<String, Tensor>>) {
84        let _ = self.response_tx.send(result);
85    }
86}
87
88/// Batch processor for high-throughput inference
89///
90/// Automatically batches incoming requests according to the configured strategy,
91/// executes them in a single forward pass, and returns individual results.
92///
93/// # Performance
94///
95/// - Static batching: 3-5x throughput for stable workloads
96/// - Dynamic batching: 5-10x throughput with variable request rates
97/// - Optimal for GPU inference where batch processing is highly efficient
98///
99/// # Example
100///
101/// ```ignore
102/// use ronn_api::{Model, SessionOptions, BatchProcessor, BatchConfig, BatchStrategy};
103/// use std::collections::HashMap;
104///
105/// async fn example() -> Result<(), Box<dyn std::error::Error>> {
106///     let model = Model::load("model.onnx")?;
107///     let session = model.create_session(SessionOptions::default())?;
108///
109///     let config = BatchConfig {
110///         strategy: BatchStrategy::Dynamic {
111///             max_batch_size: 32,
112///             timeout_ms: 10,
113///         },
114///         ..Default::default()
115///     };
116///
117///     let processor = BatchProcessor::new(session, config);
118///
119///     // Submit requests - they will be automatically batched
120///     let inputs = HashMap::new(); // Add your inputs here
121///     let output = processor.process(inputs).await?;
122///     Ok(())
123/// }
124/// ```
125pub struct BatchProcessor {
126    /// Request queue
127    request_tx: mpsc::Sender<BatchRequest>,
128    /// Worker handle
129    _worker_handle: tokio::task::JoinHandle<()>,
130    /// Configuration
131    config: BatchConfig,
132}
133
134impl BatchProcessor {
135    /// Create a new batch processor
136    pub fn new(session: InferenceSession, config: BatchConfig) -> Self {
137        let (request_tx, request_rx) = mpsc::channel(config.queue_capacity);
138
139        let worker_config = config.clone();
140        let worker_handle = tokio::spawn(async move {
141            Self::worker_loop(session, request_rx, worker_config).await;
142        });
143
144        Self {
145            request_tx,
146            _worker_handle: worker_handle,
147            config,
148        }
149    }
150
151    /// Submit a request for batch processing
152    ///
153    /// # Arguments
154    ///
155    /// * `inputs` - Input tensors for inference
156    ///
157    /// # Returns
158    ///
159    /// Future that resolves to the inference outputs
160    pub async fn process(
161        &self,
162        inputs: HashMap<String, Tensor>,
163    ) -> Result<HashMap<String, Tensor>> {
164        let (response_tx, response_rx) = tokio::sync::oneshot::channel();
165
166        let request = BatchRequest::new(inputs, response_tx);
167
168        self.request_tx
169            .send(request)
170            .await
171            .map_err(|_| Error::InferenceError("Batch processor channel closed".to_string()))?;
172
173        response_rx
174            .await
175            .map_err(|_| Error::InferenceError("Response channel closed".to_string()))?
176    }
177
178    /// Main worker loop - collects requests and processes batches
179    async fn worker_loop(
180        session: InferenceSession,
181        mut request_rx: mpsc::Receiver<BatchRequest>,
182        config: BatchConfig,
183    ) {
184        let session = Arc::new(RwLock::new(session));
185
186        loop {
187            match config.strategy {
188                BatchStrategy::Static { batch_size } => {
189                    let batch = Self::collect_static_batch(&mut request_rx, batch_size).await;
190                    if batch.is_empty() {
191                        break; // Channel closed
192                    }
193                    Self::process_batch(session.clone(), batch).await;
194                }
195                BatchStrategy::Dynamic {
196                    max_batch_size,
197                    timeout_ms,
198                } => {
199                    let batch =
200                        Self::collect_dynamic_batch(&mut request_rx, max_batch_size, timeout_ms)
201                            .await;
202                    if batch.is_empty() {
203                        break; // Channel closed
204                    }
205                    Self::process_batch(session.clone(), batch).await;
206                }
207            }
208        }
209    }
210
211    /// Collect a static batch - waits until batch_size requests are available
212    async fn collect_static_batch(
213        request_rx: &mut mpsc::Receiver<BatchRequest>,
214        batch_size: usize,
215    ) -> Vec<BatchRequest> {
216        let mut batch = Vec::with_capacity(batch_size);
217
218        for _ in 0..batch_size {
219            match request_rx.recv().await {
220                Some(request) => batch.push(request),
221                None => break, // Channel closed
222            }
223        }
224
225        batch
226    }
227
228    /// Collect a dynamic batch - fills up to max_batch_size or until timeout
229    async fn collect_dynamic_batch(
230        request_rx: &mut mpsc::Receiver<BatchRequest>,
231        max_batch_size: usize,
232        timeout_ms: u64,
233    ) -> Vec<BatchRequest> {
234        let mut batch = Vec::with_capacity(max_batch_size);
235        let deadline = Duration::from_millis(timeout_ms);
236
237        // Get first request (blocking)
238        match request_rx.recv().await {
239            Some(request) => batch.push(request),
240            None => return batch, // Channel closed
241        }
242
243        // Collect additional requests until timeout or batch full
244        let start = Instant::now();
245        while batch.len() < max_batch_size {
246            let remaining = deadline.saturating_sub(start.elapsed());
247            if remaining.is_zero() {
248                break;
249            }
250
251            match timeout(remaining, request_rx.recv()).await {
252                Ok(Some(request)) => batch.push(request),
253                Ok(None) => break, // Channel closed
254                Err(_) => break,   // Timeout
255            }
256        }
257
258        batch
259    }
260
261    /// Process a batch of requests
262    async fn process_batch(session: Arc<RwLock<InferenceSession>>, batch: Vec<BatchRequest>) {
263        if batch.is_empty() {
264            return;
265        }
266
267        // Combine inputs into batched tensors
268        let batch_size = batch.len();
269        let combined_inputs = match Self::combine_inputs(&batch) {
270            Ok(inputs) => inputs,
271            Err(e) => {
272                // Send error to all requests
273                let err_msg = format!("{}", e);
274                for request in batch {
275                    request.send_response(Err(Error::InferenceError(err_msg.clone())));
276                }
277                return;
278            }
279        };
280
281        // Convert HashMap<String, Tensor> to HashMap<&str, Tensor>
282        let inputs_ref: HashMap<&str, Tensor> = combined_inputs
283            .iter()
284            .map(|(k, v)| (k.as_str(), v.clone()))
285            .collect();
286
287        // Run inference on batched inputs
288        let session = session.read().await;
289        let combined_outputs = match session.run(inputs_ref) {
290            Ok(outputs) => outputs,
291            Err(e) => {
292                // Send error to all requests
293                let err_msg = format!("{}", e);
294                for request in batch {
295                    request.send_response(Err(Error::InferenceError(err_msg.clone())));
296                }
297                return;
298            }
299        };
300
301        // Split outputs and send back to individual requests
302        match Self::split_outputs(combined_outputs, batch_size) {
303            Ok(individual_outputs) => {
304                for (request, outputs) in batch.into_iter().zip(individual_outputs) {
305                    request.send_response(Ok(outputs));
306                }
307            }
308            Err(e) => {
309                // Send error to all requests
310                let err_msg = format!("{}", e);
311                for request in batch {
312                    request.send_response(Err(Error::InferenceError(err_msg.clone())));
313                }
314            }
315        }
316    }
317
318    /// Combine multiple requests' inputs into batched tensors
319    fn combine_inputs(batch: &[BatchRequest]) -> Result<HashMap<String, Tensor>> {
320        if batch.is_empty() {
321            return Ok(HashMap::new());
322        }
323
324        // Get all input names from first request
325        let input_names: Vec<_> = batch[0].inputs.keys().cloned().collect();
326
327        let mut combined = HashMap::new();
328
329        for name in input_names {
330            // Collect all tensors for this input
331            let tensors: std::result::Result<Vec<_>, Error> = batch
332                .iter()
333                .map(|req| {
334                    req.inputs.get(&name).ok_or_else(|| {
335                        Error::InvalidInput(format!("Missing input tensor: {}", name))
336                    })
337                })
338                .collect();
339            let tensors = tensors?;
340
341            // Stack tensors along batch dimension (dim 0)
342            let batched = Tensor::stack(&tensors, 0)
343                .map_err(|e| Error::InferenceError(format!("Failed to stack tensors: {}", e)))?;
344            combined.insert(name, batched);
345        }
346
347        Ok(combined)
348    }
349
350    /// Split batched outputs into individual results
351    fn split_outputs(
352        combined: HashMap<String, Tensor>,
353        batch_size: usize,
354    ) -> Result<Vec<HashMap<String, Tensor>>> {
355        let mut results = vec![HashMap::new(); batch_size];
356
357        for (name, batched_tensor) in combined {
358            // Split along batch dimension (dim 0)
359            let individual_tensors = batched_tensor
360                .split(batch_size, 0)
361                .map_err(|e| Error::InferenceError(format!("Failed to split tensors: {}", e)))?;
362
363            for (i, tensor) in individual_tensors.into_iter().enumerate() {
364                results[i].insert(name.clone(), tensor);
365            }
366        }
367
368        Ok(results)
369    }
370
371    /// Get the current configuration
372    pub fn config(&self) -> &BatchConfig {
373        &self.config
374    }
375}
376
/// Aggregate counters describing batch-processing activity.
#[derive(Debug, Clone, Default)]
pub struct BatchStats {
    /// Number of batches processed so far.
    pub total_batches: u64,
    /// Number of individual requests processed so far.
    pub total_requests: u64,
    /// Mean number of requests per batch.
    pub avg_batch_size: f64,
    /// Largest batch size observed.
    pub max_batch_size: usize,
    /// Smallest batch size observed.
    pub min_batch_size: usize,
    /// Total processing time, in milliseconds.
    pub total_processing_time_ms: f64,
    /// Mean processing time per batch, in milliseconds.
    pub avg_batch_time_ms: f64,
}

impl BatchStats {
    /// Requests per second over the recorded processing time.
    ///
    /// Returns `0.0` when no processing time has been recorded.
    pub fn throughput(&self) -> f64 {
        let elapsed_ms = self.total_processing_time_ms;
        if elapsed_ms == 0.0 {
            return 0.0;
        }

        // Scale by 1000 to convert requests per millisecond to per second.
        (self.total_requests as f64 * 1000.0) / elapsed_ms
    }

    /// Average batch size as a fraction of `max_batch_size`.
    ///
    /// Returns `0.0` when `max_batch_size` is zero.
    pub fn utilization(&self, max_batch_size: usize) -> f64 {
        match max_batch_size {
            0 => 0.0,
            limit => self.avg_batch_size / limit as f64,
        }
    }
}
415
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration uses dynamic batching (32 requests, 10 ms),
    /// a 1024-slot queue, and one worker.
    #[test]
    fn test_batch_config_default() {
        let BatchConfig {
            strategy,
            queue_capacity,
            num_workers,
        } = BatchConfig::default();

        assert_eq!(queue_capacity, 1024);
        assert_eq!(num_workers, 1);

        if let BatchStrategy::Dynamic {
            max_batch_size,
            timeout_ms,
        } = strategy
        {
            assert_eq!(max_batch_size, 32);
            assert_eq!(timeout_ms, 10);
        } else {
            panic!("Expected dynamic strategy");
        }
    }

    /// A static strategy keeps the batch size it was built with.
    #[test]
    fn test_batch_strategy_static() {
        let strategy = BatchStrategy::Static { batch_size: 16 };

        if let BatchStrategy::Static { batch_size } = strategy {
            assert_eq!(batch_size, 16);
        } else {
            panic!("Expected static strategy");
        }
    }

    /// 1000 requests in 1000 ms is 1000 requests per second.
    #[test]
    fn test_batch_stats_throughput() {
        let mut stats = BatchStats::default();
        stats.total_requests = 1000;
        stats.total_processing_time_ms = 1000.0;

        assert_eq!(stats.throughput(), 1000.0); // 1000 req/s
    }

    /// An average batch of 16 against a maximum of 32 is 50% utilization.
    #[test]
    fn test_batch_stats_utilization() {
        let mut stats = BatchStats::default();
        stats.avg_batch_size = 16.0;

        assert_eq!(stats.utilization(32), 0.5); // 50% utilization
    }
}