forge_orchestration/inference/batch.rs

use std::collections::VecDeque;
use std::sync::Arc;
use std::time::{Duration, Instant};

use parking_lot::Mutex;
use tokio::sync::{oneshot, Notify};
use tracing::debug;

/// Configuration for request batching.
#[derive(Debug, Clone)]
pub struct BatchConfig {
    /// Maximum number of requests in a single batch.
    pub max_batch_size: usize,
    /// Maximum time a request may wait before a partial batch is flushed.
    pub max_wait_ms: u64,
    /// Minimum queue depth required before a notified worker proceeds.
    pub min_batch_size: usize,
    /// Whether batch size should adapt to load (declared but not yet consulted here).
    pub dynamic_sizing: bool,
}

impl Default for BatchConfig {
    fn default() -> Self {
        Self {
            max_batch_size: 32,
            max_wait_ms: 50,
            min_batch_size: 1,
            dynamic_sizing: true,
        }
    }
}

impl BatchConfig {
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the maximum batch size (clamped to at least 1).
    pub fn max_size(mut self, size: usize) -> Self {
        self.max_batch_size = size.max(1);
        self
    }

    /// Sets the maximum wait before a partial batch is flushed.
    pub fn max_wait(mut self, ms: u64) -> Self {
        self.max_wait_ms = ms;
        self
    }

    /// Sets the minimum batch size (clamped to at least 1).
    pub fn min_size(mut self, size: usize) -> Self {
        self.min_batch_size = size.max(1);
        self
    }
}
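
// Example with illustrative values: a configuration tuned for
// latency-sensitive serving, built with the chainable setters above.
//
//     let config = BatchConfig::new().max_size(16).max_wait(5).min_size(2);
//     assert_eq!(config.max_batch_size, 16);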

/// A queued request plus the channel used to deliver its result.
pub struct BatchRequest<T> {
    /// The caller-supplied input.
    pub payload: T,
    /// Completion channel back to the submitting task.
    response_tx: oneshot::Sender<BatchResult<T>>,
    /// When the request entered the queue.
    arrived_at: Instant,
}

/// The response delivered for a single request, with batch timing metadata.
#[derive(Debug)]
pub struct BatchResult<T> {
    /// The processed output.
    pub payload: T,
    /// How many requests shared the batch.
    pub batch_size: usize,
    /// Time spent queued before processing began.
    pub queue_time_ms: u64,
    /// Time spent processing the batch.
    pub process_time_ms: u64,
}

/// Coalesces concurrent requests into batches for a single worker to process.
pub struct BatchProcessor<T: Send + 'static> {
    config: BatchConfig,
    queue: Arc<Mutex<VecDeque<BatchRequest<T>>>>,
    notify: Arc<Notify>,
    stats: Arc<Mutex<BatchStats>>,
}

/// Running aggregates over all batches completed so far.
#[derive(Debug, Default, Clone)]
pub struct BatchStats {
    pub total_requests: u64,
    pub total_batches: u64,
    pub avg_batch_size: f64,
    pub avg_queue_time_ms: f64,
    pub avg_process_time_ms: f64,
}

impl<T: Send + 'static> BatchProcessor<T> {
    pub fn new(config: BatchConfig) -> Self {
        Self {
            config,
            queue: Arc::new(Mutex::new(VecDeque::new())),
            notify: Arc::new(Notify::new()),
            stats: Arc::new(Mutex::new(BatchStats::default())),
        }
    }

    /// Enqueues a payload and waits until its batch has been processed.
    pub async fn submit(&self, payload: T) -> Result<BatchResult<T>, BatchError> {
        let (tx, rx) = oneshot::channel();

        let request = BatchRequest {
            payload,
            response_tx: tx,
            arrived_at: Instant::now(),
        };

        self.queue.lock().push_back(request);

        // Wake the worker once per submission, after releasing the lock; the
        // worker decides in `wait_for_batch` whether enough work is queued.
        // A single notify suffices because `Notify` stores at most one permit.
        self.notify.notify_one();

        rx.await.map_err(|_| BatchError::Cancelled)
    }

    /// Number of requests currently queued.
    pub fn queue_len(&self) -> usize {
        self.queue.lock().len()
    }

    /// Snapshot of the running statistics.
    pub fn stats(&self) -> BatchStats {
        self.stats.lock().clone()
    }

    /// Drains up to `max_batch_size` requests from the front of the queue.
    pub fn collect_batch(&self) -> Vec<(T, oneshot::Sender<BatchResult<T>>, Instant)> {
        let mut queue = self.queue.lock();
        let batch_size = queue.len().min(self.config.max_batch_size);

        let mut batch = Vec::with_capacity(batch_size);
        for _ in 0..batch_size {
            if let Some(req) = queue.pop_front() {
                batch.push((req.payload, req.response_tx, req.arrived_at));
            }
        }

        batch
    }

    /// Sends each response back to its caller and updates running statistics.
    ///
    /// Each tuple is `(original payload, response channel, arrival time,
    /// response)`. `process_start` should be captured just before processing
    /// began, so queue time and processing time are measured correctly;
    /// measuring from inside this method would always report a processing
    /// time near zero.
    pub fn complete_batch(
        &self,
        results: Vec<(T, oneshot::Sender<BatchResult<T>>, Instant, T)>,
        process_start: Instant,
    ) {
        let batch_size = results.len();
        if batch_size == 0 {
            return;
        }
        let process_time_ms = process_start.elapsed().as_millis() as u64;

        let mut total_queue_time = 0u64;

        for (_, tx, arrived_at, response) in results {
            // Queue time spans arrival up to the start of processing.
            let queue_time = process_start
                .saturating_duration_since(arrived_at)
                .as_millis() as u64;
            total_queue_time += queue_time;

            let result = BatchResult {
                payload: response,
                batch_size,
                queue_time_ms: queue_time,
                process_time_ms,
            };

            // The submitter may have gone away; ignore failed sends.
            let _ = tx.send(result);
        }

        let mut stats = self.stats.lock();
        stats.total_requests += batch_size as u64;
        stats.total_batches += 1;

        // Incremental running means over all completed batches.
        let n = stats.total_batches as f64;
        stats.avg_batch_size = stats.avg_batch_size * (n - 1.0) / n + batch_size as f64 / n;
        stats.avg_queue_time_ms = stats.avg_queue_time_ms * (n - 1.0) / n
            + (total_queue_time as f64 / batch_size as f64) / n;
        stats.avg_process_time_ms =
            stats.avg_process_time_ms * (n - 1.0) / n + process_time_ms as f64 / n;

        debug!(batch_size = batch_size, "Batch completed");
    }

    /// Waits until a batch is worth collecting: returns `true` when a
    /// notification arrives and at least `min_batch_size` requests are
    /// queued, or when the `max_wait_ms` deadline passes with any work
    /// pending. Returns `false` on a premature wakeup, in which case the
    /// caller should simply wait again.
    pub async fn wait_for_batch(&self) -> bool {
        let timeout = Duration::from_millis(self.config.max_wait_ms);

        tokio::select! {
            _ = self.notify.notified() => {
                self.queue.lock().len() >= self.config.min_batch_size
            }
            _ = tokio::time::sleep(timeout) => {
                !self.queue.lock().is_empty()
            }
        }
    }
}
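
// A minimal sketch of a worker loop tying the pieces together: wait for
// enough work, drain a batch, run it through the model, then fan the
// responses back out. `run_model` is an assumption standing in for the real
// inference call; it must return one response per payload, in order. The
// `T: Clone` bound exists only so the sketch can keep the original payloads
// for `complete_batch`. Runs until the task is dropped (no shutdown signal).
pub async fn run_batch_loop<T, F>(processor: BatchProcessor<T>, mut run_model: F)
where
    T: Clone + Send + 'static,
    F: FnMut(Vec<T>) -> Vec<T>,
{
    loop {
        // Either enough requests are queued or the deadline passed with
        // work pending; otherwise go back to waiting.
        if !processor.wait_for_batch().await {
            continue;
        }

        let batch = processor.collect_batch();
        if batch.is_empty() {
            continue;
        }

        // Capture the processing start so `complete_batch` can separate
        // queue time from processing time.
        let process_start = Instant::now();
        let payloads: Vec<T> = batch.iter().map(|(p, _, _)| p.clone()).collect();
        let responses = run_model(payloads);

        let results: Vec<_> = batch
            .into_iter()
            .zip(responses)
            .map(|((payload, tx, arrived_at), response)| (payload, tx, arrived_at, response))
            .collect();
        processor.complete_batch(results, process_start);
    }
}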

// Manual `Clone` so clones share the queue, notifier, and stats (all `Arc`s)
// without requiring `T: Clone`.
impl<T: Send + 'static> Clone for BatchProcessor<T> {
    fn clone(&self) -> Self {
        Self {
            config: self.config.clone(),
            queue: self.queue.clone(),
            notify: self.notify.clone(),
            stats: self.stats.clone(),
        }
    }
}

/// Errors that batching operations can report.
#[derive(Debug, thiserror::Error)]
pub enum BatchError {
    #[error("Request cancelled")]
    Cancelled,
    #[error("Queue full")]
    QueueFull,
    #[error("Processing failed: {0}")]
    ProcessingFailed(String),
}

/// Accumulates items until a batch is full or its deadline passes.
///
/// Unlike [`BatchProcessor`], this is a passive building block: the caller
/// decides when to check readiness and when to take the batch.
pub struct BatchCollector<T> {
    items: Vec<T>,
    max_size: usize,
    created_at: Instant,
    max_wait: Duration,
}

impl<T> BatchCollector<T> {
    pub fn new(max_size: usize, max_wait_ms: u64) -> Self {
        Self {
            items: Vec::with_capacity(max_size),
            max_size,
            created_at: Instant::now(),
            max_wait: Duration::from_millis(max_wait_ms),
        }
    }

    /// Adds an item, returning `false` if the collector is already full.
    pub fn add(&mut self, item: T) -> bool {
        if self.items.len() < self.max_size {
            self.items.push(item);
            true
        } else {
            false
        }
    }

    /// True once the batch is full or the deadline has passed.
    pub fn is_ready(&self) -> bool {
        self.items.len() >= self.max_size || self.created_at.elapsed() >= self.max_wait
    }

    pub fn is_full(&self) -> bool {
        self.items.len() >= self.max_size
    }

    pub fn len(&self) -> usize {
        self.items.len()
    }

    pub fn is_empty(&self) -> bool {
        self.items.is_empty()
    }

    /// Consumes the collector, yielding the accumulated items.
    pub fn take(self) -> Vec<T> {
        self.items
    }

    /// Time elapsed since the collector was created.
    pub fn age(&self) -> Duration {
        self.created_at.elapsed()
    }
}
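
// A minimal sketch showing one way to drive a `BatchCollector` from a plain
// std::sync::mpsc channel (an assumption; the module itself prescribes no
// transport): pull items until the collector is full or its deadline passes,
// then hand the batch off.
pub fn drain_into_batch<T>(
    rx: &std::sync::mpsc::Receiver<T>,
    max_size: usize,
    max_wait_ms: u64,
) -> Vec<T> {
    let mut collector = BatchCollector::new(max_size, max_wait_ms);
    while !collector.is_ready() {
        // Poll with a short timeout so the deadline check stays responsive.
        match rx.recv_timeout(Duration::from_millis(1)) {
            Ok(item) => {
                collector.add(item);
            }
            Err(std::sync::mpsc::RecvTimeoutError::Timeout) => continue,
            Err(std::sync::mpsc::RecvTimeoutError::Disconnected) => break,
        }
    }
    collector.take()
}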

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_batch_collector() {
        let mut collector: BatchCollector<i32> = BatchCollector::new(3, 100);

        assert!(collector.add(1));
        assert!(collector.add(2));
        assert!(!collector.is_full());
        assert!(collector.add(3));
        assert!(collector.is_full());
        assert!(!collector.add(4));

        let items = collector.take();
        assert_eq!(items, vec![1, 2, 3]);
    }

    #[tokio::test]
    async fn test_batch_processor_stats() {
        let processor: BatchProcessor<String> = BatchProcessor::new(BatchConfig::default());

        let stats = processor.stats();
        assert_eq!(stats.total_requests, 0);
        assert_eq!(stats.total_batches, 0);
    }
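
    // Full round trip through the sketch worker loop defined above:
    // submit -> wait_for_batch -> collect_batch -> complete_batch. The
    // identity closure stands in for a real model call.
    #[tokio::test]
    async fn test_submit_round_trip() {
        let config = BatchConfig::new().max_size(4).max_wait(10);
        let processor: BatchProcessor<String> = BatchProcessor::new(config);

        tokio::spawn(run_batch_loop(processor.clone(), |payloads| payloads));

        let result = processor.submit("hello".to_string()).await.unwrap();
        assert_eq!(result.payload, "hello");
        assert!(result.batch_size >= 1);
    }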
}