oxify_connect_llm/batch.rs

//! Batch request processing for efficient LLM operations
//!
//! This module provides batching capabilities to process multiple LLM requests efficiently.
//! It collects requests over a time window and processes them together, reducing overhead.
//!
//! # Example
//!
//! ```rust
//! use oxify_connect_llm::{BatchConfig, BatchProvider, OpenAIProvider, LlmProvider, LlmRequest};
//!
//! # async fn example() -> Result<(), Box<dyn std::error::Error>> {
//! let provider = OpenAIProvider::new("test-key".to_string(), "gpt-4".to_string());
//! let config = BatchConfig {
//!     max_batch_size: 10,
//!     max_wait_ms: 100,
//! };
//! let batch_provider = BatchProvider::new(provider, config);
//!
//! // Multiple concurrent requests will be batched automatically
//! let request = LlmRequest {
//!     prompt: "Hello".to_string(),
//!     system_prompt: None,
//!     temperature: None,
//!     max_tokens: None,
//!     tools: vec![],
//!     images: vec![],
//! };
//! // let response = batch_provider.complete(request).await?;
//! # Ok(())
//! # }
//! ```

use crate::{
    EmbeddingProvider, EmbeddingRequest, EmbeddingResponse, LlmProvider, LlmRequest, LlmResponse,
    Result,
};
use async_trait::async_trait;
use std::sync::Arc;
use tokio::sync::{mpsc, oneshot, Mutex};
use tokio::time::Duration;

/// Configuration for batch processing
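///
/// # Example
///
/// A minimal sketch of a custom configuration (the values are illustrative,
/// not tuned recommendations):
///
/// ```rust
/// use oxify_connect_llm::BatchConfig;
///
/// // Flush a batch once 32 requests are queued or 50 ms have elapsed,
/// // whichever comes first.
/// let config = BatchConfig {
///     max_batch_size: 32,
///     max_wait_ms: 50,
/// };
/// assert_eq!(config.max_batch_size, 32);
/// ```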
#[derive(Debug, Clone)]
pub struct BatchConfig {
    /// Maximum number of requests to batch together
    pub max_batch_size: usize,
    /// Maximum time to wait for more requests (in milliseconds)
    pub max_wait_ms: u64,
}

impl Default for BatchConfig {
    fn default() -> Self {
        Self {
            max_batch_size: 10,
            max_wait_ms: 100,
        }
    }
}

/// Statistics about batch processing
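///
/// `avg_batch_size` is maintained incrementally as
/// `total_requests as f64 / batches_processed as f64`; for example, two
/// batches of 3 and 1 requests give an average of 2.0.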
#[derive(Debug, Clone, Default)]
pub struct BatchStats {
    /// Total number of batches processed
    pub batches_processed: usize,
    /// Total number of individual requests processed
    pub total_requests: usize,
    /// Average batch size
    pub avg_batch_size: f64,
    /// Number of timeouts (batch sent due to max_wait_ms)
    pub timeout_batches: usize,
    /// Number of full batches (batch sent due to max_batch_size)
    pub full_batches: usize,
}

impl BatchStats {
    fn update(&mut self, batch_size: usize, is_timeout: bool) {
        self.batches_processed += 1;
        self.total_requests += batch_size;
        self.avg_batch_size = self.total_requests as f64 / self.batches_processed as f64;
        if is_timeout {
            self.timeout_batches += 1;
        } else {
            self.full_batches += 1;
        }
    }
}

struct BatchRequest {
    request: LlmRequest,
    response_tx: oneshot::Sender<Result<LlmResponse>>,
}

struct BatchWorker<P> {
    provider: Arc<P>,
    config: BatchConfig,
    stats: Arc<Mutex<BatchStats>>,
    rx: mpsc::UnboundedReceiver<BatchRequest>,
}

impl<P: LlmProvider + 'static> BatchWorker<P> {
    async fn run(mut self) {
        let mut pending_requests: Vec<BatchRequest> = Vec::new();

        loop {
            // Wait for the first request before opening a new batch window
            if pending_requests.is_empty() {
                match self.rx.recv().await {
                    Some(batch_req) => pending_requests.push(batch_req),
                    None => break, // Channel closed
                }
            }

            // Collect more requests up to max_batch_size or max_wait_ms
            let start = tokio::time::Instant::now();
            let max_wait = Duration::from_millis(self.config.max_wait_ms);

            while pending_requests.len() < self.config.max_batch_size {
                let remaining = max_wait.saturating_sub(start.elapsed());
                if remaining.is_zero() {
                    break;
                }

                match tokio::time::timeout(remaining, self.rx.recv()).await {
                    Ok(Some(batch_req)) => pending_requests.push(batch_req),
                    Ok(None) => break, // Channel closed
                    Err(_) => break,   // Timeout - process current batch
                }
            }

            // Process batch
            if !pending_requests.is_empty() {
                let batch_size = pending_requests.len();
                let is_timeout = batch_size < self.config.max_batch_size;

                // Update stats
                {
                    let mut stats = self.stats.lock().await;
                    stats.update(batch_size, is_timeout);
                }

                // Dispatch each request concurrently; every request still makes its
                // own provider call, so batching amortizes queueing overhead rather
                // than combining calls into a single API request.
                let provider = Arc::clone(&self.provider);
                let requests = std::mem::take(&mut pending_requests);

                tokio::spawn(async move {
                    for batch_req in requests {
                        let provider = Arc::clone(&provider);
                        let request = batch_req.request;
                        let response_tx = batch_req.response_tx;

                        tokio::spawn(async move {
                            let result = provider.complete(request).await;
                            let _ = response_tx.send(result);
                        });
                    }
                });
            }
        }
    }
}

/// Batch provider that wraps any LLM provider with batching capabilities
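///
/// Calls to [`LlmProvider::complete`] are queued on an unbounded channel and
/// picked up by a background worker, which flushes a batch once
/// `max_batch_size` requests are pending or `max_wait_ms` has elapsed,
/// whichever comes first.
///
/// # Example
///
/// A minimal sketch following the module-level example (the hidden function
/// is never executed, so no real API key is needed):
///
/// ```rust
/// use oxify_connect_llm::{BatchConfig, BatchProvider, OpenAIProvider};
///
/// # async fn example() {
/// let provider = OpenAIProvider::new("test-key".to_string(), "gpt-4".to_string());
/// let batch_provider = BatchProvider::new(provider, BatchConfig::default());
///
/// // Inspect batching behaviour after some traffic has gone through.
/// let stats = batch_provider.stats().await;
/// println!("average batch size: {:.2}", stats.avg_batch_size);
/// # }
/// ```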
pub struct BatchProvider<P> {
    tx: mpsc::UnboundedSender<BatchRequest>,
    stats: Arc<Mutex<BatchStats>>,
    _phantom: std::marker::PhantomData<P>,
}

impl<P: LlmProvider + 'static> BatchProvider<P> {
    /// Create a new batch provider
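    ///
    /// Spawns a background worker task with `tokio::spawn`, so this must be
    /// called from within a Tokio runtime.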
    pub fn new(provider: P, config: BatchConfig) -> Self {
        let (tx, rx) = mpsc::unbounded_channel();
        let stats = Arc::new(Mutex::new(BatchStats::default()));

        let worker = BatchWorker {
            provider: Arc::new(provider),
            config,
            stats: Arc::clone(&stats),
            rx,
        };

        tokio::spawn(worker.run());

        Self {
            tx,
            stats,
            _phantom: std::marker::PhantomData,
        }
    }

    /// Get current batch processing statistics
    pub async fn stats(&self) -> BatchStats {
        self.stats.lock().await.clone()
    }
}

#[async_trait]
impl<P: LlmProvider + 'static> LlmProvider for BatchProvider<P> {
    async fn complete(&self, request: LlmRequest) -> Result<LlmResponse> {
        let (response_tx, response_rx) = oneshot::channel();

        let batch_req = BatchRequest {
            request,
            response_tx,
        };

        self.tx
            .send(batch_req)
            .map_err(|_| crate::LlmError::Other("Batch worker has stopped".to_string()))?;

        response_rx
            .await
            .map_err(|_| crate::LlmError::Other("Response channel closed".to_string()))?
    }
}

// Embedding batch support
struct EmbeddingBatchRequest {
    request: EmbeddingRequest,
    response_tx: oneshot::Sender<Result<EmbeddingResponse>>,
}

struct EmbeddingBatchWorker<P> {
    provider: Arc<P>,
    config: BatchConfig,
    stats: Arc<Mutex<BatchStats>>,
    rx: mpsc::UnboundedReceiver<EmbeddingBatchRequest>,
}

impl<P: EmbeddingProvider + 'static> EmbeddingBatchWorker<P> {
    async fn run(mut self) {
        let mut pending_requests: Vec<EmbeddingBatchRequest> = Vec::new();

        loop {
            if pending_requests.is_empty() {
                match self.rx.recv().await {
                    Some(batch_req) => pending_requests.push(batch_req),
                    None => break,
                }
            }

            let start = tokio::time::Instant::now();
            let max_wait = Duration::from_millis(self.config.max_wait_ms);

            while pending_requests.len() < self.config.max_batch_size {
                let remaining = max_wait.saturating_sub(start.elapsed());
                if remaining.is_zero() {
                    break;
                }

                match tokio::time::timeout(remaining, self.rx.recv()).await {
                    Ok(Some(batch_req)) => pending_requests.push(batch_req),
                    Ok(None) => break,
                    Err(_) => break,
                }
            }

            if !pending_requests.is_empty() {
                let batch_size = pending_requests.len();
                let is_timeout = batch_size < self.config.max_batch_size;

                {
                    let mut stats = self.stats.lock().await;
                    stats.update(batch_size, is_timeout);
                }

                let provider = Arc::clone(&self.provider);
                let requests = std::mem::take(&mut pending_requests);

                tokio::spawn(async move {
                    for batch_req in requests {
                        let provider = Arc::clone(&provider);
                        let request = batch_req.request;
                        let response_tx = batch_req.response_tx;

                        tokio::spawn(async move {
                            let result = provider.embed(request).await;
                            let _ = response_tx.send(result);
                        });
                    }
                });
            }
        }
    }
}

/// Batch provider for embeddings
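///
/// Mirrors [`BatchProvider`]: [`EmbeddingRequest`]s submitted through
/// [`EmbeddingProvider::embed`] are queued, flushed in batches according to
/// the same [`BatchConfig`] rules, and resolved individually against the
/// wrapped provider.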
pub struct EmbeddingBatchProvider<P> {
    tx: mpsc::UnboundedSender<EmbeddingBatchRequest>,
    stats: Arc<Mutex<BatchStats>>,
    _phantom: std::marker::PhantomData<P>,
}

impl<P: EmbeddingProvider + 'static> EmbeddingBatchProvider<P> {
    /// Create a new embedding batch provider
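    ///
    /// Like [`BatchProvider::new`], this spawns a background worker task with
    /// `tokio::spawn` and therefore must be called from within a Tokio runtime.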
    pub fn new(provider: P, config: BatchConfig) -> Self {
        let (tx, rx) = mpsc::unbounded_channel();
        let stats = Arc::new(Mutex::new(BatchStats::default()));

        let worker = EmbeddingBatchWorker {
            provider: Arc::new(provider),
            config,
            stats: Arc::clone(&stats),
            rx,
        };

        tokio::spawn(worker.run());

        Self {
            tx,
            stats,
            _phantom: std::marker::PhantomData,
        }
    }

    /// Get current batch processing statistics
    pub async fn stats(&self) -> BatchStats {
        self.stats.lock().await.clone()
    }
}

#[async_trait]
impl<P: EmbeddingProvider + 'static> EmbeddingProvider for EmbeddingBatchProvider<P> {
    async fn embed(&self, request: EmbeddingRequest) -> Result<EmbeddingResponse> {
        let (response_tx, response_rx) = oneshot::channel();

        let batch_req = EmbeddingBatchRequest {
            request,
            response_tx,
        };

        self.tx
            .send(batch_req)
            .map_err(|_| crate::LlmError::Other("Batch worker has stopped".to_string()))?;

        response_rx
            .await
            .map_err(|_| crate::LlmError::Other("Response channel closed".to_string()))?
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{LlmResponse, Usage};
    use tokio::time::sleep;

    struct MockProvider {
        delay_ms: u64,
    }

    #[async_trait]
    impl LlmProvider for MockProvider {
        async fn complete(&self, request: LlmRequest) -> Result<LlmResponse> {
            if self.delay_ms > 0 {
                sleep(Duration::from_millis(self.delay_ms)).await;
            }
            Ok(LlmResponse {
                content: format!("Response to: {}", request.prompt),
                model: "mock-model".to_string(),
                usage: Some(Usage {
                    prompt_tokens: 10,
                    completion_tokens: 20,
                    total_tokens: 30,
                }),
                tool_calls: vec![],
            })
        }
    }

    #[tokio::test]
    async fn test_batch_config_default() {
        let config = BatchConfig::default();
        assert_eq!(config.max_batch_size, 10);
        assert_eq!(config.max_wait_ms, 100);
    }

    #[tokio::test]
    async fn test_batch_provider_single_request() {
        let provider = MockProvider { delay_ms: 0 };
        let config = BatchConfig {
            max_batch_size: 5,
            max_wait_ms: 50,
        };
        let batch_provider = BatchProvider::new(provider, config);

        let request = LlmRequest {
            prompt: "Hello".to_string(),
            system_prompt: None,
            temperature: None,
            max_tokens: None,
            tools: vec![],
            images: vec![],
        };

        let response = batch_provider.complete(request).await.unwrap();
        assert_eq!(response.content, "Response to: Hello");
        assert_eq!(response.model, "mock-model");

        // Wait a bit for batch processing
        sleep(Duration::from_millis(100)).await;

        let stats = batch_provider.stats().await;
        assert_eq!(stats.total_requests, 1);
        assert_eq!(stats.batches_processed, 1);
    }

    #[tokio::test]
    async fn test_batch_provider_multiple_requests() {
        let provider = MockProvider { delay_ms: 10 };
        let config = BatchConfig {
            max_batch_size: 3,
            max_wait_ms: 200,
        };
        let batch_provider = Arc::new(BatchProvider::new(provider, config));

        let mut handles = vec![];

        // Send 5 requests concurrently
        for i in 0..5 {
            let bp = Arc::clone(&batch_provider);
            let handle = tokio::spawn(async move {
                let request = LlmRequest {
                    prompt: format!("Request {}", i),
                    system_prompt: None,
                    temperature: None,
                    max_tokens: None,
                    tools: vec![],
                    images: vec![],
                };
                bp.complete(request).await
            });
            handles.push(handle);
        }

        // Wait for all requests to complete
        for handle in handles {
            let result = handle.await.unwrap();
            assert!(result.is_ok());
        }

        // Wait for batch processing to settle
        sleep(Duration::from_millis(300)).await;

        let stats = batch_provider.stats().await;
        assert_eq!(stats.total_requests, 5);
        // Should have at least 2 batches (3 + 2)
        assert!(stats.batches_processed >= 2);
    }

    #[tokio::test]
    async fn test_batch_stats_calculation() {
        let provider = MockProvider { delay_ms: 0 };
        let config = BatchConfig {
            max_batch_size: 2,
            max_wait_ms: 50,
        };
        let batch_provider = Arc::new(BatchProvider::new(provider, config));

        // Send 4 requests (should create 2 batches of size 2)
        let mut handles = vec![];
        for i in 0..4 {
            let bp = Arc::clone(&batch_provider);
            let handle = tokio::spawn(async move {
                let request = LlmRequest {
                    prompt: format!("Request {}", i),
                    system_prompt: None,
                    temperature: None,
                    max_tokens: None,
                    tools: vec![],
                    images: vec![],
                };
                bp.complete(request).await
            });
            handles.push(handle);
            // Small delay to ensure batching
            if i == 1 {
                sleep(Duration::from_millis(10)).await;
            }
        }

        for handle in handles {
            let _ = handle.await.unwrap();
        }

        sleep(Duration::from_millis(200)).await;

        let stats = batch_provider.stats().await;
        assert_eq!(stats.total_requests, 4);
        assert!(stats.avg_batch_size > 0.0);
    }

    #[tokio::test]
    async fn test_batch_timeout_trigger() {
        let provider = MockProvider { delay_ms: 0 };
        let config = BatchConfig {
            max_batch_size: 10, // Large batch size
            max_wait_ms: 50,    // Short wait time
        };
        let batch_provider = BatchProvider::new(provider, config);

        // Send single request - should be processed after timeout
        let request = LlmRequest {
            prompt: "Single request".to_string(),
            system_prompt: None,
            temperature: None,
            max_tokens: None,
            tools: vec![],
            images: vec![],
        };

        let response = batch_provider.complete(request).await.unwrap();
        assert_eq!(response.content, "Response to: Single request");

        sleep(Duration::from_millis(100)).await;

        let stats = batch_provider.stats().await;
        assert_eq!(stats.timeout_batches, 1);
        assert_eq!(stats.full_batches, 0);
    }
}