Skip to main content

litellm_rs/core/batch/
async_batch.rs

//! Async Batch Completion - Concurrent Request Processing
//!
//! This module provides high-performance concurrent batch processing for
//! chat completions, similar to Python LiteLLM's `abatch_completion()`.
5
6use crate::utils::error::gateway_error::GatewayError;
7use futures::stream::{self, StreamExt};
8use std::time::Duration;
9
/// Configuration for async batch processing.
///
/// Build one with [`AsyncBatchConfig::new`] (or `Default`) and chain the `with_*`
/// setters. All fields are public, so they can also be assigned directly.
#[derive(Debug, Clone)]
pub struct AsyncBatchConfig {
    /// Maximum number of requests in flight at once (default: 10).
    pub concurrency: usize,
    /// Timeout applied to each individual request (default: 60s).
    pub timeout: Duration,
    /// Whether to continue processing after an individual request fails (default: true).
    pub continue_on_error: bool,
    /// Maximum number of retries for a failed request (default: 1).
    ///
    /// NOTE: `AsyncBatchExecutor::execute` does not currently perform retries.
    pub max_retries: u32,
    /// Delay between retries (default: 1s).
    ///
    /// NOTE: `AsyncBatchExecutor::execute` does not currently perform retries.
    pub retry_delay: Duration,
}

impl Default for AsyncBatchConfig {
    /// Defaults: concurrency 10, 60s timeout, continue on error, 1 retry, 1s retry delay.
    fn default() -> Self {
        Self {
            concurrency: 10,
            timeout: Duration::from_secs(60),
            continue_on_error: true,
            max_retries: 1,
            retry_delay: Duration::from_secs(1),
        }
    }
}

impl AsyncBatchConfig {
    /// Create a new config with the default values (same as `Default::default`).
    pub fn new() -> Self {
        Self::default()
    }

    /// Set the concurrency limit.
    ///
    /// Values below 1 are clamped to 1 so the batch can always make progress.
    pub fn with_concurrency(mut self, concurrency: usize) -> Self {
        self.concurrency = concurrency.max(1);
        self
    }

    /// Set the timeout applied to each request.
    pub fn with_timeout(mut self, timeout: Duration) -> Self {
        self.timeout = timeout;
        self
    }

    /// Set whether to continue processing after an individual request fails.
    pub fn with_continue_on_error(mut self, continue_on_error: bool) -> Self {
        self.continue_on_error = continue_on_error;
        self
    }

    /// Set the maximum number of retries for a failed request.
    pub fn with_max_retries(mut self, max_retries: u32) -> Self {
        self.max_retries = max_retries;
        self
    }

    /// Set the delay between retries.
    ///
    /// Completes the builder so every field can be configured fluently.
    pub fn with_retry_delay(mut self, retry_delay: Duration) -> Self {
        self.retry_delay = retry_delay;
        self
    }
}
67
/// Outcome of a single request within a batch.
#[derive(Debug, Clone)]
pub struct AsyncBatchItemResult<T> {
    /// Position of this request in the input sequence.
    pub index: usize,
    /// The success value, or the error this request produced.
    pub result: Result<T, AsyncBatchError>,
    /// Time spent on this request.
    pub duration: Duration,
    /// How many retries were attempted for this request.
    pub retries: u32,
}

/// Error recorded for a failed request in an async batch.
#[derive(Debug, Clone)]
pub struct AsyncBatchError {
    /// Human-readable description; this is exactly what `Display` prints.
    pub message: String,
    /// Optional machine-readable error code.
    pub code: Option<String>,
    /// `true` when the failure is considered transient and worth retrying.
    pub retryable: bool,
}

impl std::fmt::Display for AsyncBatchError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Only the message is shown; `code` and `retryable` are visible via `Debug`.
        f.write_str(&self.message)
    }
}

impl std::error::Error for AsyncBatchError {}
99
100impl From<GatewayError> for AsyncBatchError {
101    fn from(err: GatewayError) -> Self {
102        let retryable = matches!(
103            &err,
104            GatewayError::Timeout(_) | GatewayError::Network(_) | GatewayError::RateLimit { .. }
105        );
106
107        Self {
108            message: err.to_string(),
109            code: None,
110            retryable,
111        }
112    }
113}
114
/// Aggregate statistics for one batch run.
#[derive(Debug, Clone)]
pub struct AsyncBatchSummary {
    /// Number of results produced by the batch.
    pub total: usize,
    /// Number of results that succeeded.
    pub succeeded: usize,
    /// Number of results that failed.
    pub failed: usize,
    /// Wall-clock time spent on the whole batch.
    pub total_duration: Duration,
    /// Average time per request.
    pub avg_duration: Duration,
}
129
130/// Async batch executor for concurrent request processing
131pub struct AsyncBatchExecutor {
132    config: AsyncBatchConfig,
133}
134
135impl AsyncBatchExecutor {
136    /// Create a new batch executor
137    pub fn new(config: AsyncBatchConfig) -> Self {
138        Self { config }
139    }
140
141    /// Execute a batch of async operations concurrently
142    ///
143    /// # Arguments
144    /// * `items` - Iterator of items to process
145    /// * `operation` - Async function to execute for each item
146    ///
147    /// # Returns
148    /// Vector of results in the same order as input items
149    ///
150    /// # Example
151    /// ```rust,no_run
152    /// # use litellm_rs::core::batch::{AsyncBatchExecutor, AsyncBatchConfig};
153    /// # use std::time::Duration;
154    /// # async fn example() {
155    /// let executor = AsyncBatchExecutor::new(
156    ///     AsyncBatchConfig::new()
157    ///         .with_concurrency(5)
158    ///         .with_timeout(Duration::from_secs(30))
159    /// );
160    ///
161    /// let requests: Vec<String> = vec!["req1".into(), "req2".into(), "req3".into()];
162    /// let results = executor.execute(requests, |req| async move {
163    ///     // Process each request
164    ///     Ok::<_, litellm_rs::GatewayError>(format!("processed: {}", req))
165    /// }).await;
166    /// # }
167    /// ```
168    pub async fn execute<T, R, F, Fut>(
169        &self,
170        items: impl IntoIterator<Item = T>,
171        operation: F,
172    ) -> Vec<AsyncBatchItemResult<R>>
173    where
174        T: Send + 'static,
175        R: Send + 'static,
176        F: Fn(T) -> Fut + Send + Sync + Clone + 'static,
177        Fut: std::future::Future<Output = std::result::Result<R, GatewayError>> + Send,
178    {
179        let items_with_index: Vec<(usize, T)> = items.into_iter().enumerate().collect();
180        let config = self.config.clone();
181
182        let results: Vec<AsyncBatchItemResult<R>> = stream::iter(items_with_index)
183            .map(|(index, item)| {
184                let op = operation.clone();
185                let cfg = config.clone();
186
187                async move {
188                    let start = std::time::Instant::now();
189                    let retries = 0u32;
190
191                    let result = tokio::time::timeout(cfg.timeout, op(item))
192                        .await
193                        .map_err(|_| {
194                            GatewayError::Timeout(format!(
195                                "Request {} timed out after {:?}",
196                                index, cfg.timeout
197                            ))
198                        })
199                        .and_then(|r| r);
200
201                    match result {
202                        Ok(value) => AsyncBatchItemResult {
203                            index,
204                            result: Ok(value),
205                            duration: start.elapsed(),
206                            retries,
207                        },
208                        Err(e) => {
209                            let batch_err = AsyncBatchError::from(e);
210                            // Note: Can't retry because item is consumed
211                            // In a real implementation, we'd clone the item
212                            AsyncBatchItemResult {
213                                index,
214                                result: Err(batch_err),
215                                duration: start.elapsed(),
216                                retries,
217                            }
218                        }
219                    }
220                }
221            })
222            .buffer_unordered(config.concurrency)
223            .collect()
224            .await;
225
226        // Sort by index to maintain original order
227        let mut sorted_results = results;
228        sorted_results.sort_by_key(|r| r.index);
229        sorted_results
230    }
231
232    /// Execute with summary statistics
233    pub async fn execute_with_summary<T, R, F, Fut>(
234        &self,
235        items: impl IntoIterator<Item = T>,
236        operation: F,
237    ) -> (Vec<AsyncBatchItemResult<R>>, AsyncBatchSummary)
238    where
239        T: Send + 'static,
240        R: Send + 'static,
241        F: Fn(T) -> Fut + Send + Sync + Clone + 'static,
242        Fut: std::future::Future<Output = std::result::Result<R, GatewayError>> + Send,
243    {
244        let start = std::time::Instant::now();
245        let results = self.execute(items, operation).await;
246        let total_duration = start.elapsed();
247
248        let total = results.len();
249        let succeeded = results.iter().filter(|r| r.result.is_ok()).count();
250        let failed = total - succeeded;
251        let avg_duration = if total > 0 {
252            Duration::from_nanos((total_duration.as_nanos() / total as u128) as u64)
253        } else {
254            Duration::ZERO
255        };
256
257        let summary = AsyncBatchSummary {
258            total,
259            succeeded,
260            failed,
261            total_duration,
262            avg_duration,
263        };
264
265        (results, summary)
266    }
267
268    /// Get current configuration
269    pub fn config(&self) -> &AsyncBatchConfig {
270        &self.config
271    }
272}
273
274impl Default for AsyncBatchExecutor {
275    fn default() -> Self {
276        Self::new(AsyncBatchConfig::default())
277    }
278}
279
280/// Convenience function for batch completion without creating an executor
281pub async fn batch_execute<T, R, F, Fut>(
282    items: impl IntoIterator<Item = T>,
283    operation: F,
284    config: Option<AsyncBatchConfig>,
285) -> Vec<AsyncBatchItemResult<R>>
286where
287    T: Send + 'static,
288    R: Send + 'static,
289    F: Fn(T) -> Fut + Send + Sync + Clone + 'static,
290    Fut: std::future::Future<Output = std::result::Result<R, GatewayError>> + Send,
291{
292    let executor = AsyncBatchExecutor::new(config.unwrap_or_default());
293    executor.execute(items, operation).await
294}
295
// Unit tests covering:
// - the `AsyncBatchConfig` builder;
// - the error, result, and summary value types;
// - `AsyncBatchExecutor` execution: ordering, summaries, and timeouts.
#[cfg(test)]
mod tests {
    use super::*;

    // ==================== AsyncBatchConfig Tests ====================

    #[test]
    fn test_async_batch_config_default() {
        let config = AsyncBatchConfig::default();

        assert_eq!(config.concurrency, 10);
        assert_eq!(config.timeout, Duration::from_secs(60));
        assert!(config.continue_on_error);
        assert_eq!(config.max_retries, 1);
        assert_eq!(config.retry_delay, Duration::from_secs(1));
    }

    #[test]
    fn test_async_batch_config_new() {
        let config = AsyncBatchConfig::new();

        assert_eq!(config.concurrency, 10);
        assert_eq!(config.timeout, Duration::from_secs(60));
    }

    #[test]
    fn test_async_batch_config_with_concurrency() {
        let config = AsyncBatchConfig::new().with_concurrency(5);

        assert_eq!(config.concurrency, 5);
    }

    #[test]
    fn test_async_batch_config_with_concurrency_minimum() {
        let config = AsyncBatchConfig::new().with_concurrency(0);

        // Should be at least 1 (`with_concurrency` clamps via `.max(1)`)
        assert_eq!(config.concurrency, 1);
    }

    #[test]
    fn test_async_batch_config_with_timeout() {
        let config = AsyncBatchConfig::new().with_timeout(Duration::from_secs(30));

        assert_eq!(config.timeout, Duration::from_secs(30));
    }

    #[test]
    fn test_async_batch_config_with_continue_on_error() {
        let config = AsyncBatchConfig::new().with_continue_on_error(false);

        assert!(!config.continue_on_error);
    }

    #[test]
    fn test_async_batch_config_with_max_retries() {
        let config = AsyncBatchConfig::new().with_max_retries(3);

        assert_eq!(config.max_retries, 3);
    }

    #[test]
    fn test_async_batch_config_builder_chain() {
        let config = AsyncBatchConfig::new()
            .with_concurrency(20)
            .with_timeout(Duration::from_secs(120))
            .with_continue_on_error(false)
            .with_max_retries(5);

        assert_eq!(config.concurrency, 20);
        assert_eq!(config.timeout, Duration::from_secs(120));
        assert!(!config.continue_on_error);
        assert_eq!(config.max_retries, 5);
    }

    #[test]
    fn test_async_batch_config_clone() {
        let config = AsyncBatchConfig::new().with_concurrency(15);
        let cloned = config.clone();

        assert_eq!(config.concurrency, cloned.concurrency);
        assert_eq!(config.timeout, cloned.timeout);
    }

    #[test]
    fn test_async_batch_config_debug() {
        let config = AsyncBatchConfig::new();
        let debug_str = format!("{:?}", config);

        assert!(debug_str.contains("AsyncBatchConfig"));
        assert!(debug_str.contains("concurrency"));
    }

    // ==================== AsyncBatchError Tests ====================

    #[test]
    fn test_async_batch_error_display() {
        let error = AsyncBatchError {
            message: "Test error".to_string(),
            code: None,
            retryable: false,
        };

        // `Display` prints only the message.
        assert_eq!(format!("{}", error), "Test error");
    }

    #[test]
    fn test_async_batch_error_with_code() {
        let error = AsyncBatchError {
            message: "API error".to_string(),
            code: Some("E001".to_string()),
            retryable: true,
        };

        assert_eq!(error.code, Some("E001".to_string()));
        assert!(error.retryable);
    }

    #[test]
    fn test_async_batch_error_clone() {
        let error = AsyncBatchError {
            message: "Clone test".to_string(),
            code: Some("E002".to_string()),
            retryable: false,
        };

        let cloned = error.clone();
        assert_eq!(error.message, cloned.message);
        assert_eq!(error.code, cloned.code);
        assert_eq!(error.retryable, cloned.retryable);
    }

    #[test]
    fn test_async_batch_error_debug() {
        let error = AsyncBatchError {
            message: "Debug test".to_string(),
            code: None,
            retryable: false,
        };

        let debug_str = format!("{:?}", error);
        assert!(debug_str.contains("AsyncBatchError"));
        assert!(debug_str.contains("Debug test"));
    }

    // The next three tests pin the variants that `From<GatewayError>` treats as retryable.
    #[test]
    fn test_async_batch_error_from_gateway_error_timeout() {
        let gateway_error = GatewayError::Timeout("Request timed out".to_string());
        let batch_error: AsyncBatchError = gateway_error.into();

        assert!(batch_error.retryable);
        assert!(batch_error.message.contains("timed out"));
    }

    #[test]
    fn test_async_batch_error_from_gateway_error_network() {
        let gateway_error = GatewayError::Network("Connection failed".to_string());
        let batch_error: AsyncBatchError = gateway_error.into();

        assert!(batch_error.retryable);
    }

    #[test]
    fn test_async_batch_error_from_gateway_error_rate_limit() {
        let gateway_error = GatewayError::RateLimit {
            message: "Rate limit exceeded".to_string(),
            retry_after: None,
            rpm_limit: None,
            tpm_limit: None,
        };
        let batch_error: AsyncBatchError = gateway_error.into();

        assert!(batch_error.retryable);
    }

    // ==================== AsyncBatchItemResult Tests ====================

    #[test]
    fn test_async_batch_item_result_success() {
        let result: AsyncBatchItemResult<String> = AsyncBatchItemResult {
            index: 0,
            result: Ok("Success".to_string()),
            duration: Duration::from_millis(100),
            retries: 0,
        };

        assert_eq!(result.index, 0);
        assert!(result.result.is_ok());
        assert_eq!(result.retries, 0);
    }

    #[test]
    fn test_async_batch_item_result_failure() {
        let error = AsyncBatchError {
            message: "Failed".to_string(),
            code: None,
            retryable: false,
        };

        let result: AsyncBatchItemResult<String> = AsyncBatchItemResult {
            index: 1,
            result: Err(error),
            duration: Duration::from_millis(50),
            retries: 2,
        };

        assert_eq!(result.index, 1);
        assert!(result.result.is_err());
        assert_eq!(result.retries, 2);
    }

    #[test]
    fn test_async_batch_item_result_clone() {
        let result: AsyncBatchItemResult<i32> = AsyncBatchItemResult {
            index: 5,
            result: Ok(42),
            duration: Duration::from_millis(200),
            retries: 1,
        };

        let cloned = result.clone();
        assert_eq!(result.index, cloned.index);
        assert_eq!(result.duration, cloned.duration);
        assert_eq!(result.retries, cloned.retries);
    }

    // ==================== AsyncBatchSummary Tests ====================

    #[test]
    fn test_async_batch_summary_creation() {
        let summary = AsyncBatchSummary {
            total: 10,
            succeeded: 8,
            failed: 2,
            total_duration: Duration::from_secs(5),
            avg_duration: Duration::from_millis(500),
        };

        assert_eq!(summary.total, 10);
        assert_eq!(summary.succeeded, 8);
        assert_eq!(summary.failed, 2);
    }

    #[test]
    fn test_async_batch_summary_clone() {
        let summary = AsyncBatchSummary {
            total: 5,
            succeeded: 5,
            failed: 0,
            total_duration: Duration::from_secs(2),
            avg_duration: Duration::from_millis(400),
        };

        let cloned = summary.clone();
        assert_eq!(summary.total, cloned.total);
        assert_eq!(summary.succeeded, cloned.succeeded);
        assert_eq!(summary.total_duration, cloned.total_duration);
    }

    #[test]
    fn test_async_batch_summary_debug() {
        let summary = AsyncBatchSummary {
            total: 3,
            succeeded: 2,
            failed: 1,
            total_duration: Duration::from_secs(1),
            avg_duration: Duration::from_millis(333),
        };

        let debug_str = format!("{:?}", summary);
        assert!(debug_str.contains("AsyncBatchSummary"));
    }

    // ==================== AsyncBatchExecutor Tests ====================

    // NOTE(review): none of the executor tests below exercise
    // `continue_on_error = false`, `max_retries`/`retry_delay`, or a zero
    // `concurrency` assigned directly to the public field.

    #[test]
    fn test_async_batch_executor_new() {
        let config = AsyncBatchConfig::new().with_concurrency(5);
        let executor = AsyncBatchExecutor::new(config);

        assert_eq!(executor.config().concurrency, 5);
    }

    #[test]
    fn test_async_batch_executor_default() {
        let executor = AsyncBatchExecutor::default();

        assert_eq!(executor.config().concurrency, 10);
        assert_eq!(executor.config().timeout, Duration::from_secs(60));
    }

    #[test]
    fn test_async_batch_executor_config() {
        let config = AsyncBatchConfig::new()
            .with_concurrency(15)
            .with_timeout(Duration::from_secs(90));
        let executor = AsyncBatchExecutor::new(config);

        let retrieved_config = executor.config();
        assert_eq!(retrieved_config.concurrency, 15);
        assert_eq!(retrieved_config.timeout, Duration::from_secs(90));
    }

    #[tokio::test]
    async fn test_async_batch_executor_execute_empty() {
        let executor = AsyncBatchExecutor::default();
        let items: Vec<i32> = vec![];

        let results = executor
            .execute(items, |x| async move { Ok::<_, GatewayError>(x * 2) })
            .await;

        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn test_async_batch_executor_execute_single() {
        let executor = AsyncBatchExecutor::default();
        let items = vec![5];

        let results = executor
            .execute(items, |x| async move { Ok::<_, GatewayError>(x * 2) })
            .await;

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].index, 0);
        assert_eq!(results[0].result.as_ref().unwrap(), &10);
    }

    #[tokio::test]
    async fn test_async_batch_executor_execute_multiple() {
        let executor = AsyncBatchExecutor::new(AsyncBatchConfig::new().with_concurrency(3));
        let items = vec![1, 2, 3, 4, 5];

        let results = executor
            .execute(items, |x| async move { Ok::<_, GatewayError>(x * 10) })
            .await;

        assert_eq!(results.len(), 5);
        // Results should be sorted by index
        for (i, result) in results.iter().enumerate() {
            assert_eq!(result.index, i);
            assert_eq!(result.result.as_ref().unwrap(), &((i + 1) as i32 * 10));
        }
    }

    // `buffer_unordered` may complete items out of order; `execute` re-sorts by index.
    #[tokio::test]
    async fn test_async_batch_executor_maintains_order() {
        let executor = AsyncBatchExecutor::new(AsyncBatchConfig::new().with_concurrency(10));
        let items: Vec<i32> = (0..20).collect();

        let results = executor
            .execute(items, |x| async move { Ok::<_, GatewayError>(x) })
            .await;

        // Verify results are in original order
        for (i, result) in results.iter().enumerate() {
            assert_eq!(result.index, i);
        }
    }

    #[tokio::test]
    async fn test_async_batch_executor_with_summary_empty() {
        let executor = AsyncBatchExecutor::default();
        let items: Vec<i32> = vec![];

        let (results, summary) = executor
            .execute_with_summary(items, |x| async move { Ok::<_, GatewayError>(x) })
            .await;

        assert!(results.is_empty());
        assert_eq!(summary.total, 0);
        assert_eq!(summary.succeeded, 0);
        assert_eq!(summary.failed, 0);
    }

    #[tokio::test]
    async fn test_async_batch_executor_with_summary_success() {
        let executor = AsyncBatchExecutor::default();
        let items = vec![1, 2, 3];

        let (results, summary) = executor
            .execute_with_summary(items, |x| async move { Ok::<_, GatewayError>(x * 2) })
            .await;

        assert_eq!(results.len(), 3);
        assert_eq!(summary.total, 3);
        assert_eq!(summary.succeeded, 3);
        assert_eq!(summary.failed, 0);
    }

    // Even inputs fail. With the default `continue_on_error = true`, every item
    // still yields a result.
    #[tokio::test]
    async fn test_async_batch_executor_with_summary_mixed() {
        let executor = AsyncBatchExecutor::default();
        let items = vec![1, 2, 3, 4, 5];

        let (results, summary) = executor
            .execute_with_summary(items, |x| async move {
                if x % 2 == 0 {
                    Err(GatewayError::Internal("Even number".to_string()))
                } else {
                    Ok::<_, GatewayError>(x)
                }
            })
            .await;

        assert_eq!(results.len(), 5);
        assert_eq!(summary.total, 5);
        assert_eq!(summary.succeeded, 3); // 1, 3, 5
        assert_eq!(summary.failed, 2); // 2, 4
    }

    // ==================== batch_execute Function Tests ====================

    #[tokio::test]
    async fn test_batch_execute_with_default_config() {
        let items = vec![1, 2, 3];

        let results =
            batch_execute(items, |x| async move { Ok::<_, GatewayError>(x + 1) }, None).await;

        assert_eq!(results.len(), 3);
        assert!(results.iter().all(|r| r.result.is_ok()));
    }

    #[tokio::test]
    async fn test_batch_execute_with_custom_config() {
        let config = AsyncBatchConfig::new().with_concurrency(2);
        let items = vec![10, 20, 30];

        let results = batch_execute(
            items,
            |x| async move { Ok::<_, GatewayError>(x / 10) },
            Some(config),
        )
        .await;

        assert_eq!(results.len(), 3);
        assert_eq!(results[0].result.as_ref().unwrap(), &1);
        assert_eq!(results[1].result.as_ref().unwrap(), &2);
        assert_eq!(results[2].result.as_ref().unwrap(), &3);
    }

    // ==================== Timeout Tests ====================

    // The 200ms sleep exceeds the 50ms per-request timeout, so the item must fail.
    #[tokio::test]
    async fn test_async_batch_executor_timeout() {
        let executor = AsyncBatchExecutor::new(
            AsyncBatchConfig::new().with_timeout(Duration::from_millis(50)),
        );
        let items = vec![1];

        let results = executor
            .execute(items, |_x| async move {
                tokio::time::sleep(Duration::from_millis(200)).await;
                Ok::<_, GatewayError>(42)
            })
            .await;

        assert_eq!(results.len(), 1);
        assert!(results[0].result.is_err());
    }
}