use std::future::Future;
use std::sync::Arc;
use std::time::Duration;

use thiserror::Error;
use tokio::sync::Semaphore;

/// Errors produced by batch-processing helpers.
///
/// NOTE(review): no code visible in this file constructs these variants —
/// confirm they are used by callers elsewhere before extending them.
#[derive(Debug, Error)]
pub enum BatchError {
    /// An individual operation exceeded the configured timeout.
    #[error("Operation timed out")]
    Timeout,

    /// More operations failed than allowed: `(observed, allowed_max)`.
    #[error("Too many failures: {0}/{1}")]
    TooManyFailures(usize, usize),

    /// Catch-all wrapper for a caller-supplied failure message.
    #[error("Batch error: {0}")]
    Custom(String),
}
54
/// Tuning knobs for batch execution.
///
/// NOTE(review): `max_retries`, `retry_delay`, `max_failures` and
/// `track_progress` are not consulted by the `BatchProcessor` visible in
/// this file — confirm whether other callers honor them.
#[derive(Debug, Clone)]
pub struct BatchConfig {
    /// Maximum number of operations allowed in flight at once.
    pub max_concurrent: usize,

    /// Per-operation timeout applied to each spawned task.
    pub operation_timeout: Duration,

    /// Number of retry attempts after a failure.
    pub max_retries: u32,

    /// Delay between retry attempts.
    pub retry_delay: Duration,

    /// Abort threshold: stop after this many failures (`None` = unlimited).
    pub max_failures: Option<usize>,

    /// Whether progress should be tracked/reported.
    pub track_progress: bool,
}

impl Default for BatchConfig {
    /// Defaults: 50 concurrent ops, 30 s timeout, 2 retries, 100 ms retry
    /// delay, unlimited failures, progress tracking enabled.
    fn default() -> Self {
        Self {
            max_concurrent: 50,
            operation_timeout: Duration::from_secs(30),
            max_retries: 2,
            retry_delay: Duration::from_millis(100),
            max_failures: None,
            track_progress: true,
        }
    }
}

impl BatchConfig {
    /// Equivalent to [`BatchConfig::default`].
    #[must_use]
    #[inline]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the maximum number of concurrent operations.
    #[must_use]
    #[inline]
    pub fn with_max_concurrent(mut self, max: usize) -> Self {
        self.max_concurrent = max;
        self
    }

    /// Sets the per-operation timeout.
    #[must_use]
    #[inline]
    pub fn with_timeout(mut self, timeout: Duration) -> Self {
        self.operation_timeout = timeout;
        self
    }

    /// Sets the number of retry attempts.
    #[must_use]
    #[inline]
    pub fn with_max_retries(mut self, retries: u32) -> Self {
        self.max_retries = retries;
        self
    }

    /// Sets the delay between retry attempts.
    ///
    /// Added for consistency: `retry_delay` previously had no builder
    /// while its sibling fields did.
    #[must_use]
    #[inline]
    pub fn with_retry_delay(mut self, delay: Duration) -> Self {
        self.retry_delay = delay;
        self
    }

    /// Sets the failure threshold after which processing should abort.
    #[must_use]
    #[inline]
    pub fn with_max_failures(mut self, max_failures: usize) -> Self {
        self.max_failures = Some(max_failures);
        self
    }

    /// Enables or disables progress tracking.
    ///
    /// Added for consistency: `track_progress` previously had no builder
    /// while its sibling fields did.
    #[must_use]
    #[inline]
    pub fn with_track_progress(mut self, track: bool) -> Self {
        self.track_progress = track;
        self
    }
}
130
/// Aggregated outcome of one batch run.
#[derive(Debug, Clone)]
pub struct BatchResult<T, E> {
    /// Values returned by operations that succeeded.
    pub results: Vec<T>,

    /// Errors returned by operations that failed.
    pub errors: Vec<E>,

    /// Number of items submitted.
    pub total: usize,

    /// Count of operations that completed with `Ok`.
    pub successful: usize,

    /// Count of operations that did not succeed.
    pub failed: usize,

    /// Wall-clock time the whole batch took.
    pub duration: Duration,
}

impl<T, E> BatchResult<T, E> {
    /// Fraction of submitted items that succeeded, in `[0.0, 1.0]`.
    /// An empty batch reports `0.0`.
    #[must_use]
    #[inline]
    pub fn success_rate(&self) -> f64 {
        match self.total {
            0 => 0.0,
            n => self.successful as f64 / n as f64,
        }
    }

    /// `true` when not a single operation failed.
    #[must_use]
    #[inline]
    pub const fn is_complete_success(&self) -> bool {
        self.failed == 0
    }

    /// `true` when at least one operation failed.
    #[must_use]
    #[inline]
    pub const fn has_failures(&self) -> bool {
        self.failed != 0
    }
}
179
180pub struct BatchProcessor {
182 config: BatchConfig,
183 semaphore: Arc<Semaphore>,
184}
185
186impl BatchProcessor {
187 #[must_use]
189 #[inline]
190 pub fn new(config: BatchConfig) -> Self {
191 let semaphore = Arc::new(Semaphore::new(config.max_concurrent));
192 Self { config, semaphore }
193 }
194
195 pub async fn process_all<T, R, E, F, Fut>(&self, items: Vec<T>, f: F) -> BatchResult<R, E>
197 where
198 T: Send + 'static,
199 R: Send + 'static,
200 E: Send + 'static,
201 F: Fn(T) -> Fut + Send + Sync + 'static,
202 Fut: Future<Output = Result<R, E>> + Send,
203 {
204 let start = std::time::Instant::now();
205 let total = items.len();
206 let f = Arc::new(f);
207
208 let mut handles = Vec::new();
209
210 for item in items {
211 let semaphore = self.semaphore.clone();
212 let f = f.clone();
213 let timeout = self.config.operation_timeout;
214
215 let handle = tokio::spawn(async move {
216 let _permit = semaphore.acquire().await.unwrap();
217
218 match tokio::time::timeout(timeout, f(item)).await {
220 Ok(Ok(value)) => Some(Ok(value)),
221 Ok(Err(e)) => Some(Err(e)),
222 Err(_) => None, }
224 });
225
226 handles.push(handle);
227 }
228
229 let mut results = Vec::new();
230 let mut errors = Vec::new();
231
232 for handle in handles {
233 match handle.await {
234 Ok(Some(Ok(value))) => results.push(value),
235 Ok(Some(Err(e))) => errors.push(e),
236 Ok(None) => {
237 }
239 Err(_) => {
240 }
242 }
243 }
244
245 let successful = results.len();
246 let failed = errors.len();
247 let duration = start.elapsed();
248
249 BatchResult {
250 results,
251 errors,
252 total,
253 successful,
254 failed,
255 duration,
256 }
257 }
258
259 pub async fn process_all_ok<T, R, E, F, Fut>(&self, items: Vec<T>, f: F) -> Vec<R>
261 where
262 T: Send + 'static,
263 R: Send + 'static,
264 E: Send + 'static,
265 F: Fn(T) -> Fut + Send + Sync + 'static,
266 Fut: Future<Output = Result<R, E>> + Send,
267 {
268 let result = self.process_all(items, f).await;
269 result.results
270 }
271
272 #[must_use]
274 #[inline]
275 pub const fn config(&self) -> &BatchConfig {
276 &self.config
277 }
278}
279
280pub struct BatchIterator<I> {
282 iter: I,
283 batch_size: usize,
284}
285
286impl<I: Iterator> BatchIterator<I> {
287 #[must_use]
289 #[inline]
290 pub fn new(iter: I, batch_size: usize) -> Self {
291 Self { iter, batch_size }
292 }
293}
294
295impl<I: Iterator> Iterator for BatchIterator<I> {
296 type Item = Vec<I::Item>;
297
298 fn next(&mut self) -> Option<Self::Item> {
299 let mut batch = Vec::with_capacity(self.batch_size);
300 for _ in 0..self.batch_size {
301 match self.iter.next() {
302 Some(item) => batch.push(item),
303 None => break,
304 }
305 }
306
307 if batch.is_empty() { None } else { Some(batch) }
308 }
309}
310
311pub trait BatchIteratorExt: Iterator + Sized {
313 fn batches(self, size: usize) -> BatchIterator<Self> {
315 BatchIterator::new(self, size)
316 }
317}
318
319impl<I: Iterator> BatchIteratorExt for I {}
320
321pub async fn parallel_map<T, R, E, F, Fut>(
323 items: Vec<T>,
324 max_concurrent: usize,
325 f: F,
326) -> BatchResult<R, E>
327where
328 T: Send + 'static,
329 R: Send + 'static,
330 E: Send + 'static,
331 F: Fn(T) -> Fut + Send + Sync + 'static,
332 Fut: Future<Output = Result<R, E>> + Send,
333{
334 let config = BatchConfig::default().with_max_concurrent(max_concurrent);
335 let processor = BatchProcessor::new(config);
336 processor.process_all(items, f).await
337}
338
#[cfg(test)]
mod tests {
    use super::*;

    // Defaults must match the values documented on `BatchConfig::default`.
    #[tokio::test]
    async fn test_batch_config_default() {
        let config = BatchConfig::default();
        assert_eq!(config.max_concurrent, 50);
        assert_eq!(config.max_retries, 2);
    }

    // Each builder method should overwrite exactly its own field.
    #[tokio::test]
    async fn test_batch_config_builder() {
        let config = BatchConfig::new()
            .with_max_concurrent(10)
            .with_max_retries(5)
            .with_timeout(Duration::from_secs(60));

        assert_eq!(config.max_concurrent, 10);
        assert_eq!(config.max_retries, 5);
        assert_eq!(config.operation_timeout, Duration::from_secs(60));
    }

    // Happy path: every operation succeeds, counts and results line up.
    #[tokio::test]
    async fn test_batch_processor_basic() {
        let config = BatchConfig::default();
        let processor = BatchProcessor::new(config);

        let items = vec![1, 2, 3, 4, 5];
        let result = processor
            .process_all(items, |x| async move { Ok::<_, String>(x * 2) })
            .await;

        assert_eq!(result.successful, 5);
        assert_eq!(result.failed, 0);
        assert_eq!(result.results.len(), 5);
        assert!(result.is_complete_success());
    }

    // Mixed outcome: even inputs fail, odd inputs succeed.
    #[tokio::test]
    async fn test_batch_processor_with_failures() {
        let config = BatchConfig::default();
        let processor = BatchProcessor::new(config);

        let items = vec![1, 2, 3, 4, 5];
        let result = processor
            .process_all(items, |x| async move {
                if x % 2 == 0 {
                    Err(format!("Error: {}", x))
                } else {
                    Ok(x * 2)
                }
            })
            .await;

        assert_eq!(result.successful, 3);
        assert_eq!(result.failed, 2);
        assert!(result.has_failures());
        assert!(!result.is_complete_success());
    }

    // 7 of 10 succeed -> success_rate of exactly 0.7.
    #[tokio::test]
    async fn test_batch_result_success_rate() {
        let config = BatchConfig::default();
        let processor = BatchProcessor::new(config);

        let items = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        let result = processor
            .process_all(items, |x| async move {
                if x <= 7 { Ok(x) } else { Err("error") }
            })
            .await;

        assert_eq!(result.total, 10);
        assert_eq!(result.successful, 7);
        assert_eq!(result.failed, 3);
        assert_eq!(result.success_rate(), 0.7);
    }

    // `process_all_ok` drops errors and keeps only the Ok values.
    // NOTE(review): the ordering assertion (2, 6, 10) relies on results
    // being collected in submission order, which holds because handles
    // are awaited in spawn order.
    #[tokio::test]
    async fn test_batch_processor_ok_only() {
        let config = BatchConfig::default();
        let processor = BatchProcessor::new(config);

        let items = vec![1, 2, 3, 4, 5];
        let results = processor
            .process_all_ok(items, |x| async move {
                if x % 2 == 0 { Err("error") } else { Ok(x * 2) }
            })
            .await;

        assert_eq!(results.len(), 3);
        assert_eq!(results, vec![2, 6, 10]);
    }

    // 10 items in batches of 3 -> 3 full chunks plus one leftover.
    #[tokio::test]
    async fn test_batch_iterator() {
        let items = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        let batches: Vec<_> = items.into_iter().batches(3).collect();

        assert_eq!(batches.len(), 4);
        assert_eq!(batches[0], vec![1, 2, 3]);
        assert_eq!(batches[1], vec![4, 5, 6]);
        assert_eq!(batches[2], vec![7, 8, 9]);
        assert_eq!(batches[3], vec![10]);
    }

    // Free-function wrapper behaves like the processor's happy path.
    #[tokio::test]
    async fn test_parallel_map() {
        let items = vec![1, 2, 3, 4, 5];
        let result = parallel_map(items, 10, |x| async move { Ok::<_, String>(x * 2) }).await;

        assert_eq!(result.successful, 5);
        assert_eq!(result.failed, 0);
    }

    // Verifies the semaphore actually caps in-flight operations at 5:
    // each operation bumps a counter on entry, records the peak, sleeps
    // briefly to force overlap, then decrements on exit.
    #[tokio::test]
    async fn test_concurrent_limit() {
        use std::sync::Arc;
        use std::sync::atomic::{AtomicUsize, Ordering};

        let concurrent = Arc::new(AtomicUsize::new(0));
        let max_seen = Arc::new(AtomicUsize::new(0));

        let config = BatchConfig::default().with_max_concurrent(5);
        let processor = BatchProcessor::new(config);

        let items = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];

        let concurrent_clone = concurrent.clone();
        let max_seen_clone = max_seen.clone();

        let _result = processor
            .process_all(items, move |_x| {
                let concurrent = concurrent_clone.clone();
                let max_seen = max_seen_clone.clone();
                async move {
                    let current = concurrent.fetch_add(1, Ordering::SeqCst) + 1;
                    // fetch_max keeps the highest concurrency level observed.
                    max_seen.fetch_max(current, Ordering::SeqCst);

                    tokio::time::sleep(Duration::from_millis(10)).await;

                    concurrent.fetch_sub(1, Ordering::SeqCst);
                    Ok::<_, String>(())
                }
            })
            .await;

        let max = max_seen.load(Ordering::SeqCst);
        assert!(max <= 5, "Max concurrent was {}, expected <= 5", max);
    }
}