1use crate::InferenceSession;
7use crate::error::{Error, Result};
8use ronn_core::tensor::Tensor;
9use std::collections::HashMap;
10use std::sync::Arc;
11use std::time::{Duration, Instant};
12use tokio::sync::{RwLock, mpsc};
13use tokio::time::timeout;
14
/// How the background worker groups incoming requests into batches.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum BatchStrategy {
    /// Wait for a fixed number of requests before running inference; a smaller batch is
    /// only formed when the request channel closes.
    Static {
        /// Number of requests gathered per batch.
        batch_size: usize,
    },
    /// Run as soon as `max_batch_size` requests are gathered or `timeout_ms` has elapsed
    /// since the first request of the batch arrived, whichever happens first.
    Dynamic {
        /// Upper bound on the number of requests per batch.
        max_batch_size: usize,
        /// Collection window in milliseconds, measured from the batch's first request.
        timeout_ms: u64,
    },
}

impl Default for BatchStrategy {
    /// Dynamic batching: up to 32 requests per batch with a 10 ms collection window.
    fn default() -> Self {
        const DEFAULT_MAX_BATCH_SIZE: usize = 32;
        const DEFAULT_TIMEOUT_MS: u64 = 10;

        BatchStrategy::Dynamic {
            timeout_ms: DEFAULT_TIMEOUT_MS,
            max_batch_size: DEFAULT_MAX_BATCH_SIZE,
        }
    }
}
40
41#[derive(Debug, Clone)]
43pub struct BatchConfig {
44 pub strategy: BatchStrategy,
46 pub queue_capacity: usize,
48 pub num_workers: usize,
50}
51
52impl Default for BatchConfig {
53 fn default() -> Self {
54 Self {
55 strategy: BatchStrategy::default(),
56 queue_capacity: 1024,
57 num_workers: 1,
58 }
59 }
60}
61
62pub struct BatchRequest {
64 pub inputs: HashMap<String, Tensor>,
66 response_tx: tokio::sync::oneshot::Sender<Result<HashMap<String, Tensor>>>,
68}
69
70impl BatchRequest {
71 pub fn new(
73 inputs: HashMap<String, Tensor>,
74 response_tx: tokio::sync::oneshot::Sender<Result<HashMap<String, Tensor>>>,
75 ) -> Self {
76 Self {
77 inputs,
78 response_tx,
79 }
80 }
81
82 fn send_response(self, result: Result<HashMap<String, Tensor>>) {
84 let _ = self.response_tx.send(result);
85 }
86}
87
88pub struct BatchProcessor {
126 request_tx: mpsc::Sender<BatchRequest>,
128 _worker_handle: tokio::task::JoinHandle<()>,
130 config: BatchConfig,
132}
133
134impl BatchProcessor {
135 pub fn new(session: InferenceSession, config: BatchConfig) -> Self {
137 let (request_tx, request_rx) = mpsc::channel(config.queue_capacity);
138
139 let worker_config = config.clone();
140 let worker_handle = tokio::spawn(async move {
141 Self::worker_loop(session, request_rx, worker_config).await;
142 });
143
144 Self {
145 request_tx,
146 _worker_handle: worker_handle,
147 config,
148 }
149 }
150
151 pub async fn process(
161 &self,
162 inputs: HashMap<String, Tensor>,
163 ) -> Result<HashMap<String, Tensor>> {
164 let (response_tx, response_rx) = tokio::sync::oneshot::channel();
165
166 let request = BatchRequest::new(inputs, response_tx);
167
168 self.request_tx
169 .send(request)
170 .await
171 .map_err(|_| Error::InferenceError("Batch processor channel closed".to_string()))?;
172
173 response_rx
174 .await
175 .map_err(|_| Error::InferenceError("Response channel closed".to_string()))?
176 }
177
178 async fn worker_loop(
180 session: InferenceSession,
181 mut request_rx: mpsc::Receiver<BatchRequest>,
182 config: BatchConfig,
183 ) {
184 let session = Arc::new(RwLock::new(session));
185
186 loop {
187 match config.strategy {
188 BatchStrategy::Static { batch_size } => {
189 let batch = Self::collect_static_batch(&mut request_rx, batch_size).await;
190 if batch.is_empty() {
191 break; }
193 Self::process_batch(session.clone(), batch).await;
194 }
195 BatchStrategy::Dynamic {
196 max_batch_size,
197 timeout_ms,
198 } => {
199 let batch =
200 Self::collect_dynamic_batch(&mut request_rx, max_batch_size, timeout_ms)
201 .await;
202 if batch.is_empty() {
203 break; }
205 Self::process_batch(session.clone(), batch).await;
206 }
207 }
208 }
209 }
210
211 async fn collect_static_batch(
213 request_rx: &mut mpsc::Receiver<BatchRequest>,
214 batch_size: usize,
215 ) -> Vec<BatchRequest> {
216 let mut batch = Vec::with_capacity(batch_size);
217
218 for _ in 0..batch_size {
219 match request_rx.recv().await {
220 Some(request) => batch.push(request),
221 None => break, }
223 }
224
225 batch
226 }
227
228 async fn collect_dynamic_batch(
230 request_rx: &mut mpsc::Receiver<BatchRequest>,
231 max_batch_size: usize,
232 timeout_ms: u64,
233 ) -> Vec<BatchRequest> {
234 let mut batch = Vec::with_capacity(max_batch_size);
235 let deadline = Duration::from_millis(timeout_ms);
236
237 match request_rx.recv().await {
239 Some(request) => batch.push(request),
240 None => return batch, }
242
243 let start = Instant::now();
245 while batch.len() < max_batch_size {
246 let remaining = deadline.saturating_sub(start.elapsed());
247 if remaining.is_zero() {
248 break;
249 }
250
251 match timeout(remaining, request_rx.recv()).await {
252 Ok(Some(request)) => batch.push(request),
253 Ok(None) => break, Err(_) => break, }
256 }
257
258 batch
259 }
260
261 async fn process_batch(session: Arc<RwLock<InferenceSession>>, batch: Vec<BatchRequest>) {
263 if batch.is_empty() {
264 return;
265 }
266
267 let batch_size = batch.len();
269 let combined_inputs = match Self::combine_inputs(&batch) {
270 Ok(inputs) => inputs,
271 Err(e) => {
272 let err_msg = format!("{}", e);
274 for request in batch {
275 request.send_response(Err(Error::InferenceError(err_msg.clone())));
276 }
277 return;
278 }
279 };
280
281 let inputs_ref: HashMap<&str, Tensor> = combined_inputs
283 .iter()
284 .map(|(k, v)| (k.as_str(), v.clone()))
285 .collect();
286
287 let session = session.read().await;
289 let combined_outputs = match session.run(inputs_ref) {
290 Ok(outputs) => outputs,
291 Err(e) => {
292 let err_msg = format!("{}", e);
294 for request in batch {
295 request.send_response(Err(Error::InferenceError(err_msg.clone())));
296 }
297 return;
298 }
299 };
300
301 match Self::split_outputs(combined_outputs, batch_size) {
303 Ok(individual_outputs) => {
304 for (request, outputs) in batch.into_iter().zip(individual_outputs) {
305 request.send_response(Ok(outputs));
306 }
307 }
308 Err(e) => {
309 let err_msg = format!("{}", e);
311 for request in batch {
312 request.send_response(Err(Error::InferenceError(err_msg.clone())));
313 }
314 }
315 }
316 }
317
318 fn combine_inputs(batch: &[BatchRequest]) -> Result<HashMap<String, Tensor>> {
320 if batch.is_empty() {
321 return Ok(HashMap::new());
322 }
323
324 let input_names: Vec<_> = batch[0].inputs.keys().cloned().collect();
326
327 let mut combined = HashMap::new();
328
329 for name in input_names {
330 let tensors: std::result::Result<Vec<_>, Error> = batch
332 .iter()
333 .map(|req| {
334 req.inputs.get(&name).ok_or_else(|| {
335 Error::InvalidInput(format!("Missing input tensor: {}", name))
336 })
337 })
338 .collect();
339 let tensors = tensors?;
340
341 let batched = Tensor::stack(&tensors, 0)
343 .map_err(|e| Error::InferenceError(format!("Failed to stack tensors: {}", e)))?;
344 combined.insert(name, batched);
345 }
346
347 Ok(combined)
348 }
349
350 fn split_outputs(
352 combined: HashMap<String, Tensor>,
353 batch_size: usize,
354 ) -> Result<Vec<HashMap<String, Tensor>>> {
355 let mut results = vec![HashMap::new(); batch_size];
356
357 for (name, batched_tensor) in combined {
358 let individual_tensors = batched_tensor
360 .split(batch_size, 0)
361 .map_err(|e| Error::InferenceError(format!("Failed to split tensors: {}", e)))?;
362
363 for (i, tensor) in individual_tensors.into_iter().enumerate() {
364 results[i].insert(name.clone(), tensor);
365 }
366 }
367
368 Ok(results)
369 }
370
371 pub fn config(&self) -> &BatchConfig {
373 &self.config
374 }
375}
376
/// Aggregate counters describing processed batches.
///
/// NOTE(review): nothing in this file fills these fields in; presumably callers record
/// them. With `Default`, `min_batch_size` starts at 0 until a real value is recorded.
#[derive(Clone, Debug, Default)]
pub struct BatchStats {
    /// Number of batches processed.
    pub total_batches: u64,
    /// Number of individual requests processed.
    pub total_requests: u64,
    /// Mean number of requests per batch.
    pub avg_batch_size: f64,
    /// Largest batch observed.
    pub max_batch_size: usize,
    /// Smallest batch observed.
    pub min_batch_size: usize,
    /// Total processing time in milliseconds.
    pub total_processing_time_ms: f64,
    /// Mean processing time per batch in milliseconds.
    pub avg_batch_time_ms: f64,
}

impl BatchStats {
    /// Requests per second: `total_requests * 1000 / total_processing_time_ms`.
    ///
    /// Returns `0.0` when the recorded processing time is zero.
    pub fn throughput(&self) -> f64 {
        let elapsed_ms = self.total_processing_time_ms;
        if elapsed_ms != 0.0 {
            self.total_requests as f64 * 1000.0 / elapsed_ms
        } else {
            0.0
        }
    }

    /// Fraction of `max_batch_size` filled on average (`avg_batch_size / max_batch_size`).
    ///
    /// Returns `0.0` when `max_batch_size` is 0.
    pub fn utilization(&self, max_batch_size: usize) -> f64 {
        match max_batch_size {
            0 => 0.0,
            capacity => self.avg_batch_size / capacity as f64,
        }
    }
}
415
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_batch_config_default() {
        let BatchConfig {
            strategy,
            queue_capacity,
            num_workers,
        } = BatchConfig::default();

        assert_eq!(queue_capacity, 1024);
        assert_eq!(num_workers, 1);

        if let BatchStrategy::Dynamic {
            max_batch_size,
            timeout_ms,
        } = strategy
        {
            assert_eq!((max_batch_size, timeout_ms), (32, 10));
        } else {
            panic!("Expected dynamic strategy");
        }
    }

    #[test]
    fn test_batch_strategy_static() {
        const EXPECTED: usize = 16;
        let strategy = BatchStrategy::Static {
            batch_size: EXPECTED,
        };

        if let BatchStrategy::Static { batch_size } = strategy {
            assert_eq!(batch_size, EXPECTED);
        } else {
            panic!("Expected static strategy");
        }
    }

    #[test]
    fn test_batch_stats_throughput() {
        // 1000 requests in 1000 ms is 1000 requests per second.
        let stats = BatchStats {
            total_processing_time_ms: 1000.0,
            total_requests: 1000,
            ..Default::default()
        };
        let expected = 1000.0;
        assert_eq!(stats.throughput(), expected);
    }

    #[test]
    fn test_batch_stats_utilization() {
        // An average of 16 requests out of a capacity of 32 is half full.
        let stats = BatchStats {
            avg_batch_size: 16.0,
            ..Default::default()
        };
        let expected = 0.5;
        assert_eq!(stats.utilization(32), expected);
    }
}