1use crate::{
34 EmbeddingProvider, EmbeddingRequest, EmbeddingResponse, LlmProvider, LlmRequest, LlmResponse,
35 Result,
36};
37use async_trait::async_trait;
38use std::sync::Arc;
39use tokio::sync::{mpsc, oneshot, Mutex};
40use tokio::time::Duration;
41
/// Configuration for request batching.
///
/// A batch is flushed as soon as it contains `max_batch_size` requests,
/// or once `max_wait_ms` milliseconds have elapsed since the first
/// request of the batch arrived — whichever happens first.
// Two plain integers: trivially `Copy`, and `PartialEq`/`Eq` let callers
// and tests compare configurations directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct BatchConfig {
    /// Maximum number of requests collected into a single batch.
    pub max_batch_size: usize,
    /// Maximum time (in milliseconds) to wait for a batch to fill
    /// before flushing it anyway.
    pub max_wait_ms: u64,
}

impl Default for BatchConfig {
    /// Defaults: up to 10 requests per batch, flushed after 100 ms.
    fn default() -> Self {
        Self {
            max_batch_size: 10,
            max_wait_ms: 100,
        }
    }
}
59
/// Running statistics about batches flushed by a batch worker.
// `PartialEq` (new) lets callers compare snapshots; `Eq` is impossible
// because of the `f64` field.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct BatchStats {
    /// Number of batches flushed so far.
    pub batches_processed: usize,
    /// Total individual requests across all flushed batches.
    pub total_requests: usize,
    /// Mean requests per batch (`total_requests / batches_processed`).
    pub avg_batch_size: f64,
    /// Batches flushed because the wait window elapsed (or the request
    /// channel closed) before the batch filled.
    pub timeout_batches: usize,
    /// Batches flushed because they reached `max_batch_size`.
    pub full_batches: usize,
}

impl BatchStats {
    /// Record one flushed batch of `batch_size` requests.
    ///
    /// `is_timeout` is true when the flush was triggered by the wait
    /// window expiring rather than by the batch filling up.
    fn update(&mut self, batch_size: usize, is_timeout: bool) {
        self.batches_processed += 1;
        self.total_requests += batch_size;
        // `batches_processed` was just incremented, so the divisor is
        // always at least 1 — no division by zero.
        self.avg_batch_size = self.total_requests as f64 / self.batches_processed as f64;
        if is_timeout {
            self.timeout_batches += 1;
        } else {
            self.full_batches += 1;
        }
    }
}
87
/// A queued completion request paired with the one-shot channel used to
/// deliver its result back to the waiting caller.
struct BatchRequest {
    request: LlmRequest,
    // If this sender is dropped without sending (e.g. the worker stops),
    // the caller observes a closed response channel.
    response_tx: oneshot::Sender<Result<LlmResponse>>,
}
92
/// State for the background task that drains the request channel and
/// groups incoming completion requests into batches per `config`.
struct BatchWorker<P> {
    // Sole strong handle to the wrapped provider; cloned into each
    // per-request dispatch task.
    provider: Arc<P>,
    config: BatchConfig,
    // Shared with the owning `BatchProvider` so callers can read stats.
    stats: Arc<Mutex<BatchStats>>,
    // Receiving half of the queue; `None` from `recv` means the
    // provider handle (and its sender) was dropped.
    rx: mpsc::UnboundedReceiver<BatchRequest>,
}
99
100impl<P: LlmProvider + 'static> BatchWorker<P> {
101 async fn run(mut self) {
102 let mut pending_requests: Vec<BatchRequest> = Vec::new();
103
104 loop {
105 if pending_requests.is_empty() {
107 match self.rx.recv().await {
108 Some(batch_req) => pending_requests.push(batch_req),
109 None => break, }
111 }
112
113 let start = tokio::time::Instant::now();
115 let max_wait = Duration::from_millis(self.config.max_wait_ms);
116
117 while pending_requests.len() < self.config.max_batch_size {
118 let remaining = max_wait.saturating_sub(start.elapsed());
119 if remaining.is_zero() {
120 break;
121 }
122
123 match tokio::time::timeout(remaining, self.rx.recv()).await {
124 Ok(Some(batch_req)) => pending_requests.push(batch_req),
125 Ok(None) => break, Err(_) => break, }
128 }
129
130 if !pending_requests.is_empty() {
132 let batch_size = pending_requests.len();
133 let is_timeout = batch_size < self.config.max_batch_size;
134
135 {
137 let mut stats = self.stats.lock().await;
138 stats.update(batch_size, is_timeout);
139 }
140
141 let provider = Arc::clone(&self.provider);
143 let requests = std::mem::take(&mut pending_requests);
144
145 tokio::spawn(async move {
146 for batch_req in requests {
147 let provider = Arc::clone(&provider);
148 let request = batch_req.request;
149 let response_tx = batch_req.response_tx;
150
151 tokio::spawn(async move {
152 let result = provider.complete(request).await;
153 let _ = response_tx.send(result);
154 });
155 }
156 });
157 }
158 }
159 }
160}
161
/// Wraps an `LlmProvider` and funnels `complete` calls through a
/// background worker that groups them into timed batches.
pub struct BatchProvider<P> {
    // Sending half of the worker's queue; dropping it stops the worker.
    tx: mpsc::UnboundedSender<BatchRequest>,
    // Shared with the worker, which updates it after each flushed batch.
    stats: Arc<Mutex<BatchStats>>,
    // The provider itself is moved into the worker task; `P` appears
    // here only to keep the wrapper generic over the wrapped type.
    _phantom: std::marker::PhantomData<P>,
}
168
169impl<P: LlmProvider + 'static> BatchProvider<P> {
170 pub fn new(provider: P, config: BatchConfig) -> Self {
172 let (tx, rx) = mpsc::unbounded_channel();
173 let stats = Arc::new(Mutex::new(BatchStats::default()));
174
175 let worker = BatchWorker {
176 provider: Arc::new(provider),
177 config,
178 stats: Arc::clone(&stats),
179 rx,
180 };
181
182 tokio::spawn(worker.run());
183
184 Self {
185 tx,
186 stats,
187 _phantom: std::marker::PhantomData,
188 }
189 }
190
191 pub async fn stats(&self) -> BatchStats {
193 self.stats.lock().await.clone()
194 }
195}
196
197#[async_trait]
198impl<P: LlmProvider + 'static> LlmProvider for BatchProvider<P> {
199 async fn complete(&self, request: LlmRequest) -> Result<LlmResponse> {
200 let (response_tx, response_rx) = oneshot::channel();
201
202 let batch_req = BatchRequest {
203 request,
204 response_tx,
205 };
206
207 self.tx
208 .send(batch_req)
209 .map_err(|_| crate::LlmError::Other("Batch worker has stopped".to_string()))?;
210
211 response_rx
212 .await
213 .map_err(|_| crate::LlmError::Other("Response channel closed".to_string()))?
214 }
215}
216
/// A queued embedding request paired with the one-shot channel used to
/// deliver its result back to the waiting caller.
struct EmbeddingBatchRequest {
    request: EmbeddingRequest,
    // If this sender is dropped without sending (e.g. the worker stops),
    // the caller observes a closed response channel.
    response_tx: oneshot::Sender<Result<EmbeddingResponse>>,
}
222
/// State for the background task that drains the request channel and
/// groups incoming embedding requests into batches per `config`.
struct EmbeddingBatchWorker<P> {
    // Sole strong handle to the wrapped provider; cloned into each
    // per-request dispatch task.
    provider: Arc<P>,
    config: BatchConfig,
    // Shared with the owning `EmbeddingBatchProvider` for stats reads.
    stats: Arc<Mutex<BatchStats>>,
    // Receiving half of the queue; `None` from `recv` means the
    // provider handle (and its sender) was dropped.
    rx: mpsc::UnboundedReceiver<EmbeddingBatchRequest>,
}
229
230impl<P: EmbeddingProvider + 'static> EmbeddingBatchWorker<P> {
231 async fn run(mut self) {
232 let mut pending_requests: Vec<EmbeddingBatchRequest> = Vec::new();
233
234 loop {
235 if pending_requests.is_empty() {
236 match self.rx.recv().await {
237 Some(batch_req) => pending_requests.push(batch_req),
238 None => break,
239 }
240 }
241
242 let start = tokio::time::Instant::now();
243 let max_wait = Duration::from_millis(self.config.max_wait_ms);
244
245 while pending_requests.len() < self.config.max_batch_size {
246 let remaining = max_wait.saturating_sub(start.elapsed());
247 if remaining.is_zero() {
248 break;
249 }
250
251 match tokio::time::timeout(remaining, self.rx.recv()).await {
252 Ok(Some(batch_req)) => pending_requests.push(batch_req),
253 Ok(None) => break,
254 Err(_) => break,
255 }
256 }
257
258 if !pending_requests.is_empty() {
259 let batch_size = pending_requests.len();
260 let is_timeout = batch_size < self.config.max_batch_size;
261
262 {
263 let mut stats = self.stats.lock().await;
264 stats.update(batch_size, is_timeout);
265 }
266
267 let provider = Arc::clone(&self.provider);
268 let requests = std::mem::take(&mut pending_requests);
269
270 tokio::spawn(async move {
271 for batch_req in requests {
272 let provider = Arc::clone(&provider);
273 let request = batch_req.request;
274 let response_tx = batch_req.response_tx;
275
276 tokio::spawn(async move {
277 let result = provider.embed(request).await;
278 let _ = response_tx.send(result);
279 });
280 }
281 });
282 }
283 }
284 }
285}
286
/// Wraps an `EmbeddingProvider` and funnels `embed` calls through a
/// background worker that groups them into timed batches.
pub struct EmbeddingBatchProvider<P> {
    // Sending half of the worker's queue; dropping it stops the worker.
    tx: mpsc::UnboundedSender<EmbeddingBatchRequest>,
    // Shared with the worker, which updates it after each flushed batch.
    stats: Arc<Mutex<BatchStats>>,
    // The provider itself is moved into the worker task; `P` appears
    // here only to keep the wrapper generic over the wrapped type.
    _phantom: std::marker::PhantomData<P>,
}
293
294impl<P: EmbeddingProvider + 'static> EmbeddingBatchProvider<P> {
295 pub fn new(provider: P, config: BatchConfig) -> Self {
297 let (tx, rx) = mpsc::unbounded_channel();
298 let stats = Arc::new(Mutex::new(BatchStats::default()));
299
300 let worker = EmbeddingBatchWorker {
301 provider: Arc::new(provider),
302 config,
303 stats: Arc::clone(&stats),
304 rx,
305 };
306
307 tokio::spawn(worker.run());
308
309 Self {
310 tx,
311 stats,
312 _phantom: std::marker::PhantomData,
313 }
314 }
315
316 pub async fn stats(&self) -> BatchStats {
318 self.stats.lock().await.clone()
319 }
320}
321
322#[async_trait]
323impl<P: EmbeddingProvider + 'static> EmbeddingProvider for EmbeddingBatchProvider<P> {
324 async fn embed(&self, request: EmbeddingRequest) -> Result<EmbeddingResponse> {
325 let (response_tx, response_rx) = oneshot::channel();
326
327 let batch_req = EmbeddingBatchRequest {
328 request,
329 response_tx,
330 };
331
332 self.tx
333 .send(batch_req)
334 .map_err(|_| crate::LlmError::Other("Batch worker has stopped".to_string()))?;
335
336 response_rx
337 .await
338 .map_err(|_| crate::LlmError::Other("Response channel closed".to_string()))?
339 }
340}
341
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{LlmResponse, Usage};
    use tokio::time::sleep;

    /// Test double: echoes the prompt back, optionally after a delay.
    struct MockProvider {
        delay_ms: u64,
    }

    #[async_trait]
    impl LlmProvider for MockProvider {
        async fn complete(&self, request: LlmRequest) -> Result<LlmResponse> {
            if self.delay_ms > 0 {
                sleep(Duration::from_millis(self.delay_ms)).await;
            }
            Ok(LlmResponse {
                content: format!("Response to: {}", request.prompt),
                model: "mock-model".to_string(),
                usage: Some(Usage {
                    prompt_tokens: 10,
                    completion_tokens: 20,
                    total_tokens: 30,
                }),
                tool_calls: vec![],
            })
        }
    }

    /// Defaults must match the documented 10-request / 100 ms window.
    #[tokio::test]
    async fn test_batch_config_default() {
        let config = BatchConfig::default();
        assert_eq!(config.max_batch_size, 10);
        assert_eq!(config.max_wait_ms, 100);
    }

    /// A single request flows through the worker and is recorded as one
    /// single-request batch.
    #[tokio::test]
    async fn test_batch_provider_single_request() {
        let provider = MockProvider { delay_ms: 0 };
        let config = BatchConfig {
            max_batch_size: 5,
            max_wait_ms: 50,
        };
        let batch_provider = BatchProvider::new(provider, config);

        let request = LlmRequest {
            prompt: "Hello".to_string(),
            system_prompt: None,
            temperature: None,
            max_tokens: None,
            tools: vec![],
            images: vec![],
        };

        let response = batch_provider.complete(request).await.unwrap();
        assert_eq!(response.content, "Response to: Hello");
        assert_eq!(response.model, "mock-model");

        // Give the worker time to record stats for the flushed batch.
        sleep(Duration::from_millis(100)).await;

        let stats = batch_provider.stats().await;
        assert_eq!(stats.total_requests, 1);
        assert_eq!(stats.batches_processed, 1);
    }

    /// Five concurrent requests with max_batch_size = 3 cannot fit in a
    /// single batch, so at least two batches must be processed.
    #[tokio::test]
    async fn test_batch_provider_multiple_requests() {
        let provider = MockProvider { delay_ms: 10 };
        let config = BatchConfig {
            max_batch_size: 3,
            max_wait_ms: 200,
        };
        let batch_provider = Arc::new(BatchProvider::new(provider, config));

        let mut handles = vec![];

        // Fire all five requests concurrently from separate tasks.
        for i in 0..5 {
            let bp = Arc::clone(&batch_provider);
            let handle = tokio::spawn(async move {
                let request = LlmRequest {
                    prompt: format!("Request {}", i),
                    system_prompt: None,
                    temperature: None,
                    max_tokens: None,
                    tools: vec![],
                    images: vec![],
                };
                bp.complete(request).await
            });
            handles.push(handle);
        }

        for handle in handles {
            let result = handle.await.unwrap();
            assert!(result.is_ok());
        }

        // Allow the worker to flush and record the final batch.
        sleep(Duration::from_millis(300)).await;

        let stats = batch_provider.stats().await;
        assert_eq!(stats.total_requests, 5);
        assert!(stats.batches_processed >= 2);
    }

    /// avg_batch_size must reflect every recorded request.
    #[tokio::test]
    async fn test_batch_stats_calculation() {
        let provider = MockProvider { delay_ms: 0 };
        let config = BatchConfig {
            max_batch_size: 2,
            max_wait_ms: 50,
        };
        let batch_provider = Arc::new(BatchProvider::new(provider, config));

        let mut handles = vec![];
        for i in 0..4 {
            let bp = Arc::clone(&batch_provider);
            let handle = tokio::spawn(async move {
                let request = LlmRequest {
                    prompt: format!("Request {}", i),
                    system_prompt: None,
                    temperature: None,
                    max_tokens: None,
                    tools: vec![],
                    images: vec![],
                };
                bp.complete(request).await
            });
            handles.push(handle);
            // Pause mid-way so requests land in separate batch windows.
            if i == 1 {
                sleep(Duration::from_millis(10)).await;
            }
        }

        for handle in handles {
            let _ = handle.await.unwrap();
        }

        sleep(Duration::from_millis(200)).await;

        let stats = batch_provider.stats().await;
        assert_eq!(stats.total_requests, 4);
        assert!(stats.avg_batch_size > 0.0);
    }

    /// A lone request against a large batch size must be flushed by the
    /// timeout path, never by filling the batch.
    #[tokio::test]
    async fn test_batch_timeout_trigger() {
        let provider = MockProvider { delay_ms: 0 };
        let config = BatchConfig {
            max_batch_size: 10, max_wait_ms: 50, };
        let batch_provider = BatchProvider::new(provider, config);

        let request = LlmRequest {
            prompt: "Single request".to_string(),
            system_prompt: None,
            temperature: None,
            max_tokens: None,
            tools: vec![],
            images: vec![],
        };

        let response = batch_provider.complete(request).await.unwrap();
        assert_eq!(response.content, "Response to: Single request");

        sleep(Duration::from_millis(100)).await;

        let stats = batch_provider.stats().await;
        assert_eq!(stats.timeout_batches, 1);
        assert_eq!(stats.full_batches, 0);
    }
}