forge_orchestration/inference/
batch.rs

//! Request batching for AI/ML inference
//!
//! Provides dynamic batching to improve throughput for inference workloads.
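//!
//! # Example
//!
//! A minimal worker-loop sketch. The crate path `forge_orchestration` and the
//! `run_inference` function are assumptions for illustration, not part of this
//! module:
//!
//! ```ignore
//! use std::time::Instant;
//! use forge_orchestration::inference::batch::{BatchConfig, BatchProcessor};
//!
//! let processor: BatchProcessor<Vec<f32>> = BatchProcessor::new(BatchConfig::default());
//! let worker = processor.clone();
//!
//! tokio::spawn(async move {
//!     loop {
//!         // Wait until enough requests are queued or the wait timer fires.
//!         if !worker.wait_for_batch().await {
//!             continue;
//!         }
//!         let batch = worker.collect_batch();
//!         if batch.is_empty() {
//!             continue;
//!         }
//!         // Run the model on the whole batch, then send each response back.
//!         let started = Instant::now();
//!         let outputs = run_inference(batch.iter().map(|(input, _, _)| input));
//!         let results = batch
//!             .into_iter()
//!             .zip(outputs)
//!             .map(|((input, tx, arrived), output)| (input, tx, arrived, output))
//!             .collect();
//!         worker.complete_batch(results, started);
//!     }
//! });
//! ```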

use std::collections::VecDeque;
use std::sync::Arc;
use std::time::{Duration, Instant};
use parking_lot::Mutex;
use tokio::sync::{oneshot, Notify};
use tracing::debug;

/// Configuration for batch processing
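///
/// # Example
///
/// A minimal sketch of the builder-style setters (the crate path below is an
/// assumption for illustration):
///
/// ```ignore
/// use forge_orchestration::inference::batch::BatchConfig;
///
/// // Batches of up to 16 requests, flushed after at most 25 ms,
/// // processed immediately once 4 requests are queued.
/// let config = BatchConfig::new()
///     .max_size(16)
///     .max_wait(25)
///     .min_size(4);
/// assert_eq!(config.max_batch_size, 16);
/// ```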
#[derive(Debug, Clone)]
pub struct BatchConfig {
    /// Maximum batch size
    pub max_batch_size: usize,
    /// Maximum wait time before processing a partial batch
    pub max_wait_ms: u64,
    /// Minimum batch size to trigger immediate processing
    pub min_batch_size: usize,
    /// Enable dynamic batch sizing based on load
    pub dynamic_sizing: bool,
}

impl Default for BatchConfig {
    fn default() -> Self {
        Self {
            max_batch_size: 32,
            max_wait_ms: 50,
            min_batch_size: 1,
            dynamic_sizing: true,
        }
    }
}

impl BatchConfig {
    /// Create a new batch config
    pub fn new() -> Self {
        Self::default()
    }

    /// Set maximum batch size
    pub fn max_size(mut self, size: usize) -> Self {
        self.max_batch_size = size.max(1);
        self
    }

    /// Set maximum wait time in milliseconds
    pub fn max_wait(mut self, ms: u64) -> Self {
        self.max_wait_ms = ms;
        self
    }

    /// Set minimum batch size for immediate processing
    pub fn min_size(mut self, size: usize) -> Self {
        self.min_batch_size = size.max(1);
        self
    }
}

/// A request in the batch queue
pub struct BatchRequest<T> {
    /// Request payload
    pub payload: T,
    /// Response channel
    response_tx: oneshot::Sender<BatchResult<T>>,
    /// Request arrival time
    arrived_at: Instant,
}

/// Result of a batched request
#[derive(Debug)]
pub struct BatchResult<T> {
    /// Response payload
    pub payload: T,
    /// Batch size this request was processed with
    pub batch_size: usize,
    /// Time spent waiting in queue
    pub queue_time_ms: u64,
    /// Processing time
    pub process_time_ms: u64,
}

/// Batch processor for inference requests
pub struct BatchProcessor<T: Send + 'static> {
    config: BatchConfig,
    queue: Arc<Mutex<VecDeque<BatchRequest<T>>>>,
    notify: Arc<Notify>,
    stats: Arc<Mutex<BatchStats>>,
}

/// Statistics for batch processing
#[derive(Debug, Default, Clone)]
pub struct BatchStats {
    /// Total requests processed
    pub total_requests: u64,
    /// Total batches processed
    pub total_batches: u64,
    /// Average batch size
    pub avg_batch_size: f64,
    /// Average queue time in ms
    pub avg_queue_time_ms: f64,
    /// Average processing time in ms
    pub avg_process_time_ms: f64,
}

impl<T: Send + 'static> BatchProcessor<T> {
    /// Create a new batch processor
    pub fn new(config: BatchConfig) -> Self {
        Self {
            config,
            queue: Arc::new(Mutex::new(VecDeque::new())),
            notify: Arc::new(Notify::new()),
            stats: Arc::new(Mutex::new(BatchStats::default())),
        }
    }

    /// Submit a request and wait for the result
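    ///
    /// A minimal sketch (assuming a `processor` handle and an `input` payload):
    ///
    /// ```ignore
    /// let result = processor.submit(input).await?;
    /// println!("served in a batch of {}", result.batch_size);
    /// ```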
    pub async fn submit(&self, payload: T) -> Result<BatchResult<T>, BatchError> {
        let (tx, rx) = oneshot::channel();

        let request = BatchRequest {
            payload,
            response_tx: tx,
            arrived_at: Instant::now(),
        };

        {
            let mut queue = self.queue.lock();
            queue.push_back(request);
        }

        // Wake the batching loop: it decides whether the queue is large enough
        // to process now or whether to keep waiting for its timer.
        self.notify.notify_one();

        rx.await.map_err(|_| BatchError::Cancelled)
    }

    /// Get current queue length
    pub fn queue_len(&self) -> usize {
        self.queue.lock().len()
    }

    /// Get batch statistics
    pub fn stats(&self) -> BatchStats {
        self.stats.lock().clone()
    }

    /// Collect a batch of requests (up to `max_batch_size`).
    ///
    /// Each entry is `(payload, response sender, arrival time)`; pass these
    /// back to [`Self::complete_batch`] together with the responses.
    pub fn collect_batch(&self) -> Vec<(T, oneshot::Sender<BatchResult<T>>, Instant)> {
        let mut queue = self.queue.lock();
        let batch_size = queue.len().min(self.config.max_batch_size);

        let mut batch = Vec::with_capacity(batch_size);
        for _ in 0..batch_size {
            if let Some(req) = queue.pop_front() {
                batch.push((req.payload, req.response_tx, req.arrived_at));
            }
        }

        batch
    }

    /// Complete a batch with results.
    ///
    /// `results` pairs each request's original payload, response channel, and
    /// arrival time with the response produced for it. `process_started_at` is
    /// the instant at which processing of this batch began, so that queue time
    /// and processing time can be reported separately.
    pub fn complete_batch(
        &self,
        results: Vec<(T, oneshot::Sender<BatchResult<T>>, Instant, T)>,
        process_started_at: Instant,
    ) {
        let batch_size = results.len();
        if batch_size == 0 {
            // Nothing to do (and avoids dividing by zero in the averages below).
            return;
        }
        let process_time_ms = process_started_at.elapsed().as_millis() as u64;

        let mut total_queue_time = 0u64;

        for (_, tx, arrived_at, response) in results {
            // Queue time is the span from arrival until processing began.
            let queue_time = process_started_at
                .saturating_duration_since(arrived_at)
                .as_millis() as u64;
            total_queue_time += queue_time;

            let result = BatchResult {
                payload: response,
                batch_size,
                queue_time_ms: queue_time,
                process_time_ms,
            };

            // The receiver may have been dropped (request cancelled); ignore that.
            let _ = tx.send(result);
        }

        // Update running averages: avg_new = avg_old * (n - 1) / n + x / n.
        let mut stats = self.stats.lock();
        stats.total_requests += batch_size as u64;
        stats.total_batches += 1;

        let n = stats.total_batches as f64;
        stats.avg_batch_size = stats.avg_batch_size * (n - 1.0) / n + batch_size as f64 / n;
        stats.avg_queue_time_ms = stats.avg_queue_time_ms * (n - 1.0) / n
            + (total_queue_time as f64 / batch_size as f64) / n;
        stats.avg_process_time_ms =
            stats.avg_process_time_ms * (n - 1.0) / n + process_time_ms as f64 / n;

        debug!(batch_size = batch_size, "Batch completed");
    }

    /// Wait for a batch to be ready.
    ///
    /// Returns `true` when the queue holds at least `min_batch_size` requests,
    /// or when the wait timed out with any requests pending.
    pub async fn wait_for_batch(&self) -> bool {
        let timeout = Duration::from_millis(self.config.max_wait_ms);

        tokio::select! {
            _ = self.notify.notified() => {
                // Woken by a submitter: check if we have enough for a batch
                self.queue.lock().len() >= self.config.min_batch_size
            }
            _ = tokio::time::sleep(timeout) => {
                // Timeout - process whatever we have
                !self.queue.lock().is_empty()
            }
        }
    }
}

/// Cloning a `BatchProcessor` produces another handle to the same queue,
/// notifier, and statistics (shared via `Arc`); `T` itself does not need to
/// implement `Clone`.
impl<T: Send + 'static> Clone for BatchProcessor<T> {
    fn clone(&self) -> Self {
        Self {
            config: self.config.clone(),
            queue: self.queue.clone(),
            notify: self.notify.clone(),
            stats: self.stats.clone(),
        }
    }
}

/// Batch processing error
#[derive(Debug, thiserror::Error)]
pub enum BatchError {
    /// Request was cancelled
    #[error("Request cancelled")]
    Cancelled,
    /// Queue is full
    #[error("Queue full")]
    QueueFull,
    /// Processing failed
    #[error("Processing failed: {0}")]
    ProcessingFailed(String),
}

/// Simple batch collector for manual batch processing
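///
/// # Example
///
/// A minimal sketch of manual batching (the crate path below is an assumption
/// for illustration):
///
/// ```ignore
/// use forge_orchestration::inference::batch::BatchCollector;
///
/// // Flush once 8 items are collected or 20 ms have passed.
/// let mut collector = BatchCollector::new(8, 20);
/// collector.add("request-1");
/// collector.add("request-2");
/// if collector.is_ready() {
///     let batch = collector.take();
///     // ... run inference over `batch` ...
/// }
/// ```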
pub struct BatchCollector<T> {
    items: Vec<T>,
    max_size: usize,
    created_at: Instant,
    max_wait: Duration,
}

impl<T> BatchCollector<T> {
    /// Create a new batch collector
    pub fn new(max_size: usize, max_wait_ms: u64) -> Self {
        Self {
            items: Vec::with_capacity(max_size),
            max_size,
            created_at: Instant::now(),
            max_wait: Duration::from_millis(max_wait_ms),
        }
    }

    /// Add an item to the batch.
    ///
    /// Returns `false` (and discards `item`) if the batch is already full.
    pub fn add(&mut self, item: T) -> bool {
        if self.items.len() < self.max_size {
            self.items.push(item);
            true
        } else {
            false
        }
    }

    /// Check if batch is ready (full or timeout)
    pub fn is_ready(&self) -> bool {
        self.items.len() >= self.max_size || self.created_at.elapsed() >= self.max_wait
    }

    /// Check if batch is full
    pub fn is_full(&self) -> bool {
        self.items.len() >= self.max_size
    }

    /// Get current batch size
    pub fn len(&self) -> usize {
        self.items.len()
    }

    /// Check if batch is empty
    pub fn is_empty(&self) -> bool {
        self.items.is_empty()
    }

    /// Take the collected items
    pub fn take(self) -> Vec<T> {
        self.items
    }

    /// Time since batch was created
    pub fn age(&self) -> Duration {
        self.created_at.elapsed()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_batch_collector() {
        let mut collector: BatchCollector<i32> = BatchCollector::new(3, 100);

        assert!(collector.add(1));
        assert!(collector.add(2));
        assert!(!collector.is_full());
        assert!(collector.add(3));
        assert!(collector.is_full());
        assert!(!collector.add(4)); // Should fail, batch is full

        let items = collector.take();
        assert_eq!(items, vec![1, 2, 3]);
    }
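
    // A small check of the builder-style configuration; the specific values
    // here are arbitrary and only illustrate the setters.
    #[test]
    fn test_batch_config_builder() {
        let config = BatchConfig::new().max_size(16).max_wait(25).min_size(4);
        assert_eq!(config.max_batch_size, 16);
        assert_eq!(config.max_wait_ms, 25);
        assert_eq!(config.min_batch_size, 4);

        // Sizes are clamped to at least 1.
        let clamped = BatchConfig::new().max_size(0).min_size(0);
        assert_eq!(clamped.max_batch_size, 1);
        assert_eq!(clamped.min_batch_size, 1);
    }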

    #[tokio::test]
    async fn test_batch_processor_stats() {
        let processor: BatchProcessor<String> = BatchProcessor::new(BatchConfig::default());

        let stats = processor.stats();
        assert_eq!(stats.total_requests, 0);
        assert_eq!(stats.total_batches, 0);
    }
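
    // End-to-end sketch of the submit / collect / complete cycle on a
    // current-thread runtime. The doubling "model" below is a stand-in for
    // real inference, used only to make the response observable.
    #[tokio::test]
    async fn test_submit_and_complete_batch() {
        let processor: BatchProcessor<u32> = BatchProcessor::new(BatchConfig::default());
        let worker = processor.clone();

        // Client side: submit one request and wait for its batched result.
        let client = tokio::spawn(async move { worker.submit(7).await });

        // Give the client task a chance to enqueue its request.
        tokio::time::sleep(Duration::from_millis(10)).await;

        // Worker side: drain the queue, "run inference", and send responses.
        let batch = processor.collect_batch();
        assert_eq!(batch.len(), 1);

        let started = Instant::now();
        let results: Vec<_> = batch
            .into_iter()
            .map(|(input, tx, arrived)| (input, tx, arrived, input * 2))
            .collect();
        processor.complete_batch(results, started);

        let result = client.await.unwrap().unwrap();
        assert_eq!(result.payload, 14);
        assert_eq!(result.batch_size, 1);
        assert_eq!(processor.stats().total_requests, 1);
    }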
}