1use crate::utils::error::gateway_error::GatewayError;
7use futures::stream::{self, StreamExt};
8use std::time::Duration;
9
/// Tuning knobs for `AsyncBatchExecutor`.
///
/// Start from [`AsyncBatchConfig::new`] (or `Default`) and chain the `with_*`
/// builder methods to override individual settings.
#[derive(Debug, Clone)]
pub struct AsyncBatchConfig {
    /// Maximum number of operations in flight at once.
    /// The `with_concurrency` builder clamps this to at least 1.
    pub concurrency: usize,
    /// Per-item time limit; an operation that exceeds it is reported as a
    /// timeout error for that item.
    pub timeout: Duration,
    /// Whether the batch should keep going after an item fails.
    ///
    /// NOTE(review): not currently consulted by `AsyncBatchExecutor::execute`,
    /// which always produces a result for every item.
    pub continue_on_error: bool,
    /// Maximum number of retries for a retryable failure.
    ///
    /// NOTE(review): not currently consulted by `AsyncBatchExecutor::execute`.
    pub max_retries: u32,
    /// Delay to wait between retry attempts.
    ///
    /// NOTE(review): not currently consulted by `AsyncBatchExecutor::execute`.
    pub retry_delay: Duration,
}

impl Default for AsyncBatchConfig {
    /// Defaults: 10 concurrent operations, 60 s timeout, continue on error,
    /// 1 retry, 1 s retry delay.
    fn default() -> Self {
        Self {
            concurrency: 10,
            timeout: Duration::from_secs(60),
            continue_on_error: true,
            max_retries: 1,
            retry_delay: Duration::from_secs(1),
        }
    }
}

impl AsyncBatchConfig {
    /// Creates a configuration with the default settings (see `Default`).
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the maximum number of concurrent operations.
    ///
    /// A value of 0 is raised to 1 so the executor can always make progress.
    pub fn with_concurrency(mut self, concurrency: usize) -> Self {
        self.concurrency = concurrency.max(1);
        self
    }

    /// Sets the per-item timeout.
    pub fn with_timeout(mut self, timeout: Duration) -> Self {
        self.timeout = timeout;
        self
    }

    /// Sets whether the batch should continue after an item fails.
    pub fn with_continue_on_error(mut self, continue_on_error: bool) -> Self {
        self.continue_on_error = continue_on_error;
        self
    }

    /// Sets the maximum number of retries for retryable failures.
    pub fn with_max_retries(mut self, max_retries: u32) -> Self {
        self.max_retries = max_retries;
        self
    }

    /// Sets the delay between retry attempts.
    ///
    /// Added so every field has a builder; previously `retry_delay` could only be
    /// changed by assigning the public field directly.
    pub fn with_retry_delay(mut self, retry_delay: Duration) -> Self {
        self.retry_delay = retry_delay;
        self
    }
}
67
/// Outcome of running the batch operation on a single input item.
#[derive(Debug, Clone)]
pub struct AsyncBatchItemResult<T> {
    /// Zero-based position of the item in the input sequence.
    pub index: usize,
    /// The operation's value, or the error it produced.
    pub result: std::result::Result<T, AsyncBatchError>,
    /// Time spent on this item.
    pub duration: Duration,
    /// Number of retry attempts made for this item.
    pub retries: u32,
}

/// Error recorded for a failed batch item.
#[derive(Debug, Clone)]
pub struct AsyncBatchError {
    /// Human-readable description; this is also the `Display` output.
    pub message: String,
    /// Optional machine-readable error code.
    pub code: Option<String>,
    /// Whether the failure is considered transient.
    pub retryable: bool,
}

impl std::fmt::Display for AsyncBatchError {
    /// Renders just the message text.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.message)
    }
}

impl std::error::Error for AsyncBatchError {}
99
100impl From<GatewayError> for AsyncBatchError {
101 fn from(err: GatewayError) -> Self {
102 let retryable = matches!(
103 &err,
104 GatewayError::Timeout(_) | GatewayError::Network(_) | GatewayError::RateLimit { .. }
105 );
106
107 Self {
108 message: err.to_string(),
109 code: None,
110 retryable,
111 }
112 }
113}
114
/// Aggregate statistics for one executed batch.
#[derive(Debug, Clone)]
pub struct AsyncBatchSummary {
    /// Number of items processed.
    pub total: usize,
    /// Number of items whose operation returned `Ok`.
    pub succeeded: usize,
    /// Number of items whose operation returned an error.
    pub failed: usize,
    /// Wall-clock time for the whole batch.
    pub total_duration: Duration,
    /// Average time per item. As computed by the executor, this is the
    /// wall-clock total divided by the item count.
    pub avg_duration: Duration,
}
129
130pub struct AsyncBatchExecutor {
132 config: AsyncBatchConfig,
133}
134
135impl AsyncBatchExecutor {
136 pub fn new(config: AsyncBatchConfig) -> Self {
138 Self { config }
139 }
140
141 pub async fn execute<T, R, F, Fut>(
169 &self,
170 items: impl IntoIterator<Item = T>,
171 operation: F,
172 ) -> Vec<AsyncBatchItemResult<R>>
173 where
174 T: Send + 'static,
175 R: Send + 'static,
176 F: Fn(T) -> Fut + Send + Sync + Clone + 'static,
177 Fut: std::future::Future<Output = std::result::Result<R, GatewayError>> + Send,
178 {
179 let items_with_index: Vec<(usize, T)> = items.into_iter().enumerate().collect();
180 let config = self.config.clone();
181
182 let results: Vec<AsyncBatchItemResult<R>> = stream::iter(items_with_index)
183 .map(|(index, item)| {
184 let op = operation.clone();
185 let cfg = config.clone();
186
187 async move {
188 let start = std::time::Instant::now();
189 let retries = 0u32;
190
191 let result = tokio::time::timeout(cfg.timeout, op(item))
192 .await
193 .map_err(|_| {
194 GatewayError::Timeout(format!(
195 "Request {} timed out after {:?}",
196 index, cfg.timeout
197 ))
198 })
199 .and_then(|r| r);
200
201 match result {
202 Ok(value) => AsyncBatchItemResult {
203 index,
204 result: Ok(value),
205 duration: start.elapsed(),
206 retries,
207 },
208 Err(e) => {
209 let batch_err = AsyncBatchError::from(e);
210 AsyncBatchItemResult {
213 index,
214 result: Err(batch_err),
215 duration: start.elapsed(),
216 retries,
217 }
218 }
219 }
220 }
221 })
222 .buffer_unordered(config.concurrency)
223 .collect()
224 .await;
225
226 let mut sorted_results = results;
228 sorted_results.sort_by_key(|r| r.index);
229 sorted_results
230 }
231
232 pub async fn execute_with_summary<T, R, F, Fut>(
234 &self,
235 items: impl IntoIterator<Item = T>,
236 operation: F,
237 ) -> (Vec<AsyncBatchItemResult<R>>, AsyncBatchSummary)
238 where
239 T: Send + 'static,
240 R: Send + 'static,
241 F: Fn(T) -> Fut + Send + Sync + Clone + 'static,
242 Fut: std::future::Future<Output = std::result::Result<R, GatewayError>> + Send,
243 {
244 let start = std::time::Instant::now();
245 let results = self.execute(items, operation).await;
246 let total_duration = start.elapsed();
247
248 let total = results.len();
249 let succeeded = results.iter().filter(|r| r.result.is_ok()).count();
250 let failed = total - succeeded;
251 let avg_duration = if total > 0 {
252 Duration::from_nanos((total_duration.as_nanos() / total as u128) as u64)
253 } else {
254 Duration::ZERO
255 };
256
257 let summary = AsyncBatchSummary {
258 total,
259 succeeded,
260 failed,
261 total_duration,
262 avg_duration,
263 };
264
265 (results, summary)
266 }
267
268 pub fn config(&self) -> &AsyncBatchConfig {
270 &self.config
271 }
272}
273
274impl Default for AsyncBatchExecutor {
275 fn default() -> Self {
276 Self::new(AsyncBatchConfig::default())
277 }
278}
279
280pub async fn batch_execute<T, R, F, Fut>(
282 items: impl IntoIterator<Item = T>,
283 operation: F,
284 config: Option<AsyncBatchConfig>,
285) -> Vec<AsyncBatchItemResult<R>>
286where
287 T: Send + 'static,
288 R: Send + 'static,
289 F: Fn(T) -> Fut + Send + Sync + Clone + 'static,
290 Fut: std::future::Future<Output = std::result::Result<R, GatewayError>> + Send,
291{
292 let executor = AsyncBatchExecutor::new(config.unwrap_or_default());
293 executor.execute(items, operation).await
294}
295
#[cfg(test)]
mod tests {
    use super::*;

    // ==================== AsyncBatchConfig ====================

    #[test]
    fn test_async_batch_config_default() {
        let config = AsyncBatchConfig::default();

        assert_eq!(config.concurrency, 10);
        assert_eq!(config.timeout, Duration::from_secs(60));
        assert!(config.continue_on_error);
        assert_eq!(config.max_retries, 1);
        assert_eq!(config.retry_delay, Duration::from_secs(1));
    }

    #[test]
    fn test_async_batch_config_new() {
        let config = AsyncBatchConfig::new();

        assert_eq!(config.concurrency, 10);
        assert_eq!(config.timeout, Duration::from_secs(60));
    }

    #[test]
    fn test_async_batch_config_with_concurrency() {
        let config = AsyncBatchConfig::new().with_concurrency(5);

        assert_eq!(config.concurrency, 5);
    }

    #[test]
    fn test_async_batch_config_with_concurrency_minimum() {
        let config = AsyncBatchConfig::new().with_concurrency(0);

        // Zero is clamped so the executor can always make progress.
        assert_eq!(config.concurrency, 1);
    }

    #[test]
    fn test_async_batch_config_with_timeout() {
        let config = AsyncBatchConfig::new().with_timeout(Duration::from_secs(30));

        assert_eq!(config.timeout, Duration::from_secs(30));
    }

    #[test]
    fn test_async_batch_config_with_continue_on_error() {
        let config = AsyncBatchConfig::new().with_continue_on_error(false);

        assert!(!config.continue_on_error);
    }

    #[test]
    fn test_async_batch_config_with_max_retries() {
        let config = AsyncBatchConfig::new().with_max_retries(3);

        assert_eq!(config.max_retries, 3);
    }

    #[test]
    fn test_async_batch_config_builder_chain() {
        let config = AsyncBatchConfig::new()
            .with_concurrency(20)
            .with_timeout(Duration::from_secs(120))
            .with_continue_on_error(false)
            .with_max_retries(5);

        assert_eq!(config.concurrency, 20);
        assert_eq!(config.timeout, Duration::from_secs(120));
        assert!(!config.continue_on_error);
        assert_eq!(config.max_retries, 5);
    }

    #[test]
    fn test_async_batch_config_clone() {
        let config = AsyncBatchConfig::new().with_concurrency(15);
        let cloned = config.clone();

        assert_eq!(config.concurrency, cloned.concurrency);
        assert_eq!(config.timeout, cloned.timeout);
    }

    #[test]
    fn test_async_batch_config_debug() {
        let config = AsyncBatchConfig::new();
        let debug_str = format!("{:?}", config);

        assert!(debug_str.contains("AsyncBatchConfig"));
        assert!(debug_str.contains("concurrency"));
    }

    // ==================== AsyncBatchError ====================

    #[test]
    fn test_async_batch_error_display() {
        let error = AsyncBatchError {
            message: "Test error".to_string(),
            code: None,
            retryable: false,
        };

        assert_eq!(format!("{}", error), "Test error");
    }

    #[test]
    fn test_async_batch_error_with_code() {
        let error = AsyncBatchError {
            message: "API error".to_string(),
            code: Some("E001".to_string()),
            retryable: true,
        };

        assert_eq!(error.code, Some("E001".to_string()));
        assert!(error.retryable);
    }

    #[test]
    fn test_async_batch_error_clone() {
        let error = AsyncBatchError {
            message: "Clone test".to_string(),
            code: Some("E002".to_string()),
            retryable: false,
        };

        let cloned = error.clone();
        assert_eq!(error.message, cloned.message);
        assert_eq!(error.code, cloned.code);
        assert_eq!(error.retryable, cloned.retryable);
    }

    #[test]
    fn test_async_batch_error_debug() {
        let error = AsyncBatchError {
            message: "Debug test".to_string(),
            code: None,
            retryable: false,
        };

        let debug_str = format!("{:?}", error);
        assert!(debug_str.contains("AsyncBatchError"));
        assert!(debug_str.contains("Debug test"));
    }

    #[test]
    fn test_async_batch_error_from_gateway_error_timeout() {
        let gateway_error = GatewayError::Timeout("Request timed out".to_string());
        let batch_error: AsyncBatchError = gateway_error.into();

        assert!(batch_error.retryable);
        assert!(batch_error.message.contains("timed out"));
    }

    #[test]
    fn test_async_batch_error_from_gateway_error_network() {
        let gateway_error = GatewayError::Network("Connection failed".to_string());
        let batch_error: AsyncBatchError = gateway_error.into();

        assert!(batch_error.retryable);
    }

    #[test]
    fn test_async_batch_error_from_gateway_error_rate_limit() {
        let gateway_error = GatewayError::RateLimit {
            message: "Rate limit exceeded".to_string(),
            retry_after: None,
            rpm_limit: None,
            tpm_limit: None,
        };
        let batch_error: AsyncBatchError = gateway_error.into();

        assert!(batch_error.retryable);
    }

    // ==================== AsyncBatchItemResult ====================

    #[test]
    fn test_async_batch_item_result_success() {
        let result: AsyncBatchItemResult<String> = AsyncBatchItemResult {
            index: 0,
            result: Ok("Success".to_string()),
            duration: Duration::from_millis(100),
            retries: 0,
        };

        assert_eq!(result.index, 0);
        assert!(result.result.is_ok());
        assert_eq!(result.retries, 0);
    }

    #[test]
    fn test_async_batch_item_result_failure() {
        let error = AsyncBatchError {
            message: "Failed".to_string(),
            code: None,
            retryable: false,
        };

        let result: AsyncBatchItemResult<String> = AsyncBatchItemResult {
            index: 1,
            result: Err(error),
            duration: Duration::from_millis(50),
            retries: 2,
        };

        assert_eq!(result.index, 1);
        assert!(result.result.is_err());
        assert_eq!(result.retries, 2);
    }

    #[test]
    fn test_async_batch_item_result_clone() {
        let result: AsyncBatchItemResult<i32> = AsyncBatchItemResult {
            index: 5,
            result: Ok(42),
            duration: Duration::from_millis(200),
            retries: 1,
        };

        let cloned = result.clone();
        assert_eq!(result.index, cloned.index);
        assert_eq!(result.duration, cloned.duration);
        assert_eq!(result.retries, cloned.retries);
    }

    // ==================== AsyncBatchSummary ====================

    #[test]
    fn test_async_batch_summary_creation() {
        let summary = AsyncBatchSummary {
            total: 10,
            succeeded: 8,
            failed: 2,
            total_duration: Duration::from_secs(5),
            avg_duration: Duration::from_millis(500),
        };

        assert_eq!(summary.total, 10);
        assert_eq!(summary.succeeded, 8);
        assert_eq!(summary.failed, 2);
    }

    #[test]
    fn test_async_batch_summary_clone() {
        let summary = AsyncBatchSummary {
            total: 5,
            succeeded: 5,
            failed: 0,
            total_duration: Duration::from_secs(2),
            avg_duration: Duration::from_millis(400),
        };

        let cloned = summary.clone();
        assert_eq!(summary.total, cloned.total);
        assert_eq!(summary.succeeded, cloned.succeeded);
        assert_eq!(summary.total_duration, cloned.total_duration);
    }

    #[test]
    fn test_async_batch_summary_debug() {
        let summary = AsyncBatchSummary {
            total: 3,
            succeeded: 2,
            failed: 1,
            total_duration: Duration::from_secs(1),
            avg_duration: Duration::from_millis(333),
        };

        let debug_str = format!("{:?}", summary);
        assert!(debug_str.contains("AsyncBatchSummary"));
    }

    // ==================== AsyncBatchExecutor construction ====================

    #[test]
    fn test_async_batch_executor_new() {
        let config = AsyncBatchConfig::new().with_concurrency(5);
        let executor = AsyncBatchExecutor::new(config);

        assert_eq!(executor.config().concurrency, 5);
    }

    #[test]
    fn test_async_batch_executor_default() {
        let executor = AsyncBatchExecutor::default();

        assert_eq!(executor.config().concurrency, 10);
        assert_eq!(executor.config().timeout, Duration::from_secs(60));
    }

    #[test]
    fn test_async_batch_executor_config() {
        let config = AsyncBatchConfig::new()
            .with_concurrency(15)
            .with_timeout(Duration::from_secs(90));
        let executor = AsyncBatchExecutor::new(config);

        let retrieved_config = executor.config();
        assert_eq!(retrieved_config.concurrency, 15);
        assert_eq!(retrieved_config.timeout, Duration::from_secs(90));
    }

    // ==================== AsyncBatchExecutor execution ====================

    #[tokio::test]
    async fn test_async_batch_executor_execute_empty() {
        let executor = AsyncBatchExecutor::default();
        let items: Vec<i32> = vec![];

        let results = executor
            .execute(items, |x| async move { Ok::<_, GatewayError>(x * 2) })
            .await;

        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn test_async_batch_executor_execute_single() {
        let executor = AsyncBatchExecutor::default();
        let items = vec![5];

        let results = executor
            .execute(items, |x| async move { Ok::<_, GatewayError>(x * 2) })
            .await;

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].index, 0);
        assert_eq!(results[0].result.as_ref().unwrap(), &10);
    }

    #[tokio::test]
    async fn test_async_batch_executor_execute_multiple() {
        let executor = AsyncBatchExecutor::new(AsyncBatchConfig::new().with_concurrency(3));
        let items = vec![1, 2, 3, 4, 5];

        let results = executor
            .execute(items, |x| async move { Ok::<_, GatewayError>(x * 10) })
            .await;

        assert_eq!(results.len(), 5);
        for (i, result) in results.iter().enumerate() {
            // Results come back in input order.
            assert_eq!(result.index, i);
            assert_eq!(result.result.as_ref().unwrap(), &((i + 1) as i32 * 10));
        }
    }

    #[tokio::test]
    async fn test_async_batch_executor_maintains_order() {
        let executor = AsyncBatchExecutor::new(AsyncBatchConfig::new().with_concurrency(10));
        let items: Vec<i32> = (0..20).collect();

        let results = executor
            .execute(items, |x| async move { Ok::<_, GatewayError>(x) })
            .await;

        for (i, result) in results.iter().enumerate() {
            // Order is preserved even though items complete concurrently.
            assert_eq!(result.index, i);
        }
    }

    #[tokio::test]
    async fn test_async_batch_executor_with_summary_empty() {
        let executor = AsyncBatchExecutor::default();
        let items: Vec<i32> = vec![];

        let (results, summary) = executor
            .execute_with_summary(items, |x| async move { Ok::<_, GatewayError>(x) })
            .await;

        assert!(results.is_empty());
        assert_eq!(summary.total, 0);
        assert_eq!(summary.succeeded, 0);
        assert_eq!(summary.failed, 0);
    }

    #[tokio::test]
    async fn test_async_batch_executor_with_summary_success() {
        let executor = AsyncBatchExecutor::default();
        let items = vec![1, 2, 3];

        let (results, summary) = executor
            .execute_with_summary(items, |x| async move { Ok::<_, GatewayError>(x * 2) })
            .await;

        assert_eq!(results.len(), 3);
        assert_eq!(summary.total, 3);
        assert_eq!(summary.succeeded, 3);
        assert_eq!(summary.failed, 0);
    }

    #[tokio::test]
    async fn test_async_batch_executor_with_summary_mixed() {
        let executor = AsyncBatchExecutor::default();
        let items = vec![1, 2, 3, 4, 5];

        let (results, summary) = executor
            .execute_with_summary(items, |x| async move {
                if x % 2 == 0 {
                    Err(GatewayError::Internal("Even number".to_string()))
                } else {
                    Ok::<_, GatewayError>(x)
                }
            })
            .await;

        assert_eq!(results.len(), 5);
        assert_eq!(summary.total, 5);
        assert_eq!(summary.succeeded, 3);
        assert_eq!(summary.failed, 2);

        // Item 1 (0-based) is the value 2: it fails, and an internal error is
        // not classified as retryable.
        assert!(results[0].result.is_ok());
        let err = results[1].result.as_ref().unwrap_err();
        assert!(!err.retryable);
    }

    // ==================== batch_execute ====================

    #[tokio::test]
    async fn test_batch_execute_with_default_config() {
        let items = vec![1, 2, 3];

        let results =
            batch_execute(items, |x| async move { Ok::<_, GatewayError>(x + 1) }, None).await;

        assert_eq!(results.len(), 3);
        assert!(results.iter().all(|r| r.result.is_ok()));
    }

    #[tokio::test]
    async fn test_batch_execute_with_custom_config() {
        let config = AsyncBatchConfig::new().with_concurrency(2);
        let items = vec![10, 20, 30];

        let results = batch_execute(
            items,
            |x| async move { Ok::<_, GatewayError>(x / 10) },
            Some(config),
        )
        .await;

        assert_eq!(results.len(), 3);
        assert_eq!(results[0].result.as_ref().unwrap(), &1);
        assert_eq!(results[1].result.as_ref().unwrap(), &2);
        assert_eq!(results[2].result.as_ref().unwrap(), &3);
    }

    // ==================== Timeouts ====================

    #[tokio::test]
    async fn test_async_batch_executor_timeout() {
        let executor = AsyncBatchExecutor::new(
            AsyncBatchConfig::new().with_timeout(Duration::from_millis(50)),
        );
        let items = vec![1];

        let results = executor
            .execute(items, |_x| async move {
                tokio::time::sleep(Duration::from_millis(200)).await;
                Ok::<_, GatewayError>(42)
            })
            .await;

        assert_eq!(results.len(), 1);
        assert!(results[0].result.is_err());

        // A timeout is reported through `GatewayError::Timeout`, so it is
        // classified as retryable and the message says so.
        let err = results[0].result.as_ref().unwrap_err();
        assert!(err.retryable);
        assert!(err.message.contains("timed out"));
    }
}