// chie_core/batch.rs

//! Batch processing utilities for parallel operations.
//!
//! This module provides utilities for efficiently processing multiple operations
//! in parallel with configurable concurrency limits and error handling.
//!
//! # Features
//!
//! - Parallel task execution with configurable concurrency
//! - Per-operation timeouts
//! - Error collection and reporting
//! - Retry, failure-limit, and progress settings on `BatchConfig`
//!   (NOTE: not yet enforced by `BatchProcessor` — see `max_retries`,
//!   `retry_delay`, `max_failures`, `track_progress`)
//!
//! # Example
//!
//! ```
//! use chie_core::batch::{BatchProcessor, BatchConfig};
//!
//! # async fn example() -> Result<(), Box<dyn std::error::Error>> {
//! let config = BatchConfig::default().with_max_concurrent(10);
//! let processor = BatchProcessor::new(config);
//!
//! let tasks = vec![1, 2, 3, 4, 5];
//! let results = processor.process_all(tasks, |num| async move {
//!     Ok::<_, String>(num * 2)
//! }).await;
//!
//! println!("Successful: {}, Failed: {}", results.successful, results.failed);
//! # Ok(())
//! # }
//! ```

use std::future::Future;
use std::sync::Arc;
use std::time::Duration;

use thiserror::Error;
use tokio::sync::Semaphore;

/// Batch processing error types.
///
/// NOTE(review): none of these variants are currently produced by
/// [`BatchProcessor`] itself (it is generic over the caller's error type);
/// they are offered as a shared error vocabulary for callers, e.g. as the
/// `E` in a [`BatchResult`].
#[derive(Debug, Error)]
pub enum BatchError {
    /// An individual operation exceeded its configured timeout.
    #[error("Operation timed out")]
    Timeout,

    /// The number of failures exceeded the allowed maximum (failures/limit).
    #[error("Too many failures: {0}/{1}")]
    TooManyFailures(usize, usize),

    /// Catch-all for caller-defined batch errors.
    #[error("Batch error: {0}")]
    Custom(String),
}
54
/// Configuration for batch processing.
///
/// NOTE(review): as of this revision, [`BatchProcessor`] only honors
/// `max_concurrent` and `operation_timeout`. The remaining fields
/// (`max_retries`, `retry_delay`, `max_failures`, `track_progress`) are
/// stored but never read by the processor — confirm before relying on them.
#[derive(Debug, Clone)]
pub struct BatchConfig {
    /// Maximum number of operations allowed to run concurrently.
    pub max_concurrent: usize,

    /// Timeout applied to each individual operation.
    pub operation_timeout: Duration,

    /// Maximum number of retries per operation (currently not enforced).
    pub max_retries: u32,

    /// Delay between retries (currently not enforced).
    pub retry_delay: Duration,

    /// Maximum allowed failures before aborting (currently not enforced).
    pub max_failures: Option<usize>,

    /// Enable progress tracking (currently not enforced).
    pub track_progress: bool,
}
76
77impl Default for BatchConfig {
78    fn default() -> Self {
79        Self {
80            max_concurrent: 50,
81            operation_timeout: Duration::from_secs(30),
82            max_retries: 2,
83            retry_delay: Duration::from_millis(100),
84            max_failures: None,
85            track_progress: true,
86        }
87    }
88}
89
90impl BatchConfig {
91    /// Create a new batch configuration.
92    #[must_use]
93    #[inline]
94    pub fn new() -> Self {
95        Self::default()
96    }
97
98    /// Set maximum concurrent operations.
99    #[must_use]
100    #[inline]
101    pub fn with_max_concurrent(mut self, max: usize) -> Self {
102        self.max_concurrent = max;
103        self
104    }
105
106    /// Set operation timeout.
107    #[must_use]
108    #[inline]
109    pub fn with_timeout(mut self, timeout: Duration) -> Self {
110        self.operation_timeout = timeout;
111        self
112    }
113
114    /// Set maximum retries.
115    #[must_use]
116    #[inline]
117    pub fn with_max_retries(mut self, retries: u32) -> Self {
118        self.max_retries = retries;
119        self
120    }
121
122    /// Set maximum failures.
123    #[must_use]
124    #[inline]
125    pub fn with_max_failures(mut self, max_failures: usize) -> Self {
126        self.max_failures = Some(max_failures);
127        self
128    }
129}
130
/// Result of batch processing.
#[derive(Debug, Clone)]
pub struct BatchResult<T, E> {
    /// Successful results, in the order their items were submitted
    /// (failed items are simply skipped).
    pub results: Vec<T>,

    /// Error values from operations that failed with an explicit error.
    /// NOTE(review): operations that time out or panic yield no error
    /// value, so callers should not assume `errors.len()`, `successful`,
    /// and `failed` always reconcile exactly with `total` — verify against
    /// the processor implementation in use.
    pub errors: Vec<E>,

    /// Total operations attempted.
    pub total: usize,

    /// Number of operations that produced a successful result.
    pub successful: usize,

    /// Number of operations counted as failed.
    pub failed: usize,

    /// Wall-clock time taken for the whole batch.
    pub duration: Duration,
}
152
153impl<T, E> BatchResult<T, E> {
154    /// Get success rate (0.0 to 1.0).
155    #[must_use]
156    #[inline]
157    pub fn success_rate(&self) -> f64 {
158        if self.total == 0 {
159            0.0
160        } else {
161            self.successful as f64 / self.total as f64
162        }
163    }
164
165    /// Check if all operations succeeded.
166    #[must_use]
167    #[inline]
168    pub const fn is_complete_success(&self) -> bool {
169        self.failed == 0
170    }
171
172    /// Check if any operations failed.
173    #[must_use]
174    #[inline]
175    pub const fn has_failures(&self) -> bool {
176        self.failed > 0
177    }
178}
179
/// Batch processor for parallel operations.
///
/// Holds a shared [`Semaphore`] sized to `config.max_concurrent`; every
/// spawned operation acquires a permit before running, which bounds the
/// number of operations in flight at once.
pub struct BatchProcessor {
    // Configuration applied to every batch processed by this instance.
    config: BatchConfig,
    // Concurrency limiter shared (via Arc) with the spawned tasks.
    semaphore: Arc<Semaphore>,
}
185
186impl BatchProcessor {
187    /// Create a new batch processor.
188    #[must_use]
189    #[inline]
190    pub fn new(config: BatchConfig) -> Self {
191        let semaphore = Arc::new(Semaphore::new(config.max_concurrent));
192        Self { config, semaphore }
193    }
194
195    /// Process all items with the given async function.
196    pub async fn process_all<T, R, E, F, Fut>(&self, items: Vec<T>, f: F) -> BatchResult<R, E>
197    where
198        T: Send + 'static,
199        R: Send + 'static,
200        E: Send + 'static,
201        F: Fn(T) -> Fut + Send + Sync + 'static,
202        Fut: Future<Output = Result<R, E>> + Send,
203    {
204        let start = std::time::Instant::now();
205        let total = items.len();
206        let f = Arc::new(f);
207
208        let mut handles = Vec::new();
209
210        for item in items {
211            let semaphore = self.semaphore.clone();
212            let f = f.clone();
213            let timeout = self.config.operation_timeout;
214
215            let handle = tokio::spawn(async move {
216                let _permit = semaphore.acquire().await.unwrap();
217
218                // Execute with timeout
219                match tokio::time::timeout(timeout, f(item)).await {
220                    Ok(Ok(value)) => Some(Ok(value)),
221                    Ok(Err(e)) => Some(Err(e)),
222                    Err(_) => None, // Timeout
223                }
224            });
225
226            handles.push(handle);
227        }
228
229        let mut results = Vec::new();
230        let mut errors = Vec::new();
231
232        for handle in handles {
233            match handle.await {
234                Ok(Some(Ok(value))) => results.push(value),
235                Ok(Some(Err(e))) => errors.push(e),
236                Ok(None) => {
237                    // Timeout occurred
238                }
239                Err(_) => {
240                    // Task panicked or was cancelled
241                }
242            }
243        }
244
245        let successful = results.len();
246        let failed = errors.len();
247        let duration = start.elapsed();
248
249        BatchResult {
250            results,
251            errors,
252            total,
253            successful,
254            failed,
255            duration,
256        }
257    }
258
259    /// Process all items and collect only successful results.
260    pub async fn process_all_ok<T, R, E, F, Fut>(&self, items: Vec<T>, f: F) -> Vec<R>
261    where
262        T: Send + 'static,
263        R: Send + 'static,
264        E: Send + 'static,
265        F: Fn(T) -> Fut + Send + Sync + 'static,
266        Fut: Future<Output = Result<R, E>> + Send,
267    {
268        let result = self.process_all(items, f).await;
269        result.results
270    }
271
272    /// Get the configuration.
273    #[must_use]
274    #[inline]
275    pub const fn config(&self) -> &BatchConfig {
276        &self.config
277    }
278}
279
/// Batch iterator for processing items in chunks.
///
/// Wraps any iterator and yields `Vec`s of up to `batch_size` items; the
/// final batch may be shorter. A `batch_size` of zero yields nothing.
pub struct BatchIterator<I> {
    // Underlying iterator being chunked.
    iter: I,
    // Maximum number of items per yielded batch.
    batch_size: usize,
}

impl<I: Iterator> BatchIterator<I> {
    /// Create a new batch iterator over `iter` with the given chunk size.
    #[must_use]
    #[inline]
    pub fn new(iter: I, batch_size: usize) -> Self {
        Self { iter, batch_size }
    }
}

impl<I: Iterator> Iterator for BatchIterator<I> {
    type Item = Vec<I::Item>;

    fn next(&mut self) -> Option<Self::Item> {
        // Pull at most `batch_size` items; `by_ref` keeps ownership so the
        // next call resumes where this one stopped.
        let chunk: Vec<I::Item> = self.iter.by_ref().take(self.batch_size).collect();
        if chunk.is_empty() { None } else { Some(chunk) }
    }
}

/// Extension trait for creating batch iterators.
pub trait BatchIteratorExt: Iterator + Sized {
    /// Split this iterator into batches of at most `size` items.
    fn batches(self, size: usize) -> BatchIterator<Self> {
        BatchIterator::new(self, size)
    }
}

impl<I: Iterator> BatchIteratorExt for I {}
320
321/// Process items in parallel with a simple function.
322pub async fn parallel_map<T, R, E, F, Fut>(
323    items: Vec<T>,
324    max_concurrent: usize,
325    f: F,
326) -> BatchResult<R, E>
327where
328    T: Send + 'static,
329    R: Send + 'static,
330    E: Send + 'static,
331    F: Fn(T) -> Fut + Send + Sync + 'static,
332    Fut: Future<Output = Result<R, E>> + Send,
333{
334    let config = BatchConfig::default().with_max_concurrent(max_concurrent);
335    let processor = BatchProcessor::new(config);
336    processor.process_all(items, f).await
337}
338
#[cfg(test)]
mod tests {
    use super::*;

    // Fully synchronous tests use plain #[test]: #[tokio::test] spins up a
    // runtime per test for no benefit when nothing is awaited.

    #[test]
    fn test_batch_config_default() {
        let config = BatchConfig::default();
        assert_eq!(config.max_concurrent, 50);
        assert_eq!(config.max_retries, 2);
    }

    #[test]
    fn test_batch_config_builder() {
        let config = BatchConfig::new()
            .with_max_concurrent(10)
            .with_max_retries(5)
            .with_timeout(Duration::from_secs(60));

        assert_eq!(config.max_concurrent, 10);
        assert_eq!(config.max_retries, 5);
        assert_eq!(config.operation_timeout, Duration::from_secs(60));
    }

    #[tokio::test]
    async fn test_batch_processor_basic() {
        let config = BatchConfig::default();
        let processor = BatchProcessor::new(config);

        let items = vec![1, 2, 3, 4, 5];
        let result = processor
            .process_all(items, |x| async move { Ok::<_, String>(x * 2) })
            .await;

        assert_eq!(result.successful, 5);
        assert_eq!(result.failed, 0);
        assert_eq!(result.results.len(), 5);
        assert!(result.is_complete_success());
    }

    #[tokio::test]
    async fn test_batch_processor_with_failures() {
        let config = BatchConfig::default();
        let processor = BatchProcessor::new(config);

        let items = vec![1, 2, 3, 4, 5];
        let result = processor
            .process_all(items, |x| async move {
                if x % 2 == 0 {
                    Err(format!("Error: {}", x))
                } else {
                    Ok(x * 2)
                }
            })
            .await;

        assert_eq!(result.successful, 3); // 1, 3, 5
        assert_eq!(result.failed, 2); // 2, 4
        assert!(result.has_failures());
        assert!(!result.is_complete_success());
    }

    #[tokio::test]
    async fn test_batch_result_success_rate() {
        let config = BatchConfig::default();
        let processor = BatchProcessor::new(config);

        let items = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        let result = processor
            .process_all(items, |x| async move {
                if x <= 7 { Ok(x) } else { Err("error") }
            })
            .await;

        assert_eq!(result.total, 10);
        assert_eq!(result.successful, 7);
        assert_eq!(result.failed, 3);
        // 7.0 / 10.0 is exactly the f64 nearest to 0.7, so == is safe here.
        assert_eq!(result.success_rate(), 0.7);
    }

    #[tokio::test]
    async fn test_batch_processor_ok_only() {
        let config = BatchConfig::default();
        let processor = BatchProcessor::new(config);

        let items = vec![1, 2, 3, 4, 5];
        let results = processor
            .process_all_ok(items, |x| async move {
                if x % 2 == 0 { Err("error") } else { Ok(x * 2) }
            })
            .await;

        assert_eq!(results.len(), 3); // Only 1, 3, 5 succeed
        assert_eq!(results, vec![2, 6, 10]);
    }

    #[test]
    fn test_batch_iterator() {
        let items = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        let batches: Vec<_> = items.into_iter().batches(3).collect();

        assert_eq!(batches.len(), 4);
        assert_eq!(batches[0], vec![1, 2, 3]);
        assert_eq!(batches[1], vec![4, 5, 6]);
        assert_eq!(batches[2], vec![7, 8, 9]);
        assert_eq!(batches[3], vec![10]); // trailing partial batch
    }

    #[tokio::test]
    async fn test_parallel_map() {
        let items = vec![1, 2, 3, 4, 5];
        let result = parallel_map(items, 10, |x| async move { Ok::<_, String>(x * 2) }).await;

        assert_eq!(result.successful, 5);
        assert_eq!(result.failed, 0);
    }

    #[tokio::test]
    async fn test_concurrent_limit() {
        use std::sync::Arc;
        use std::sync::atomic::{AtomicUsize, Ordering};

        // Track how many closures are inside the critical section at once;
        // the semaphore should cap this at max_concurrent.
        let concurrent = Arc::new(AtomicUsize::new(0));
        let max_seen = Arc::new(AtomicUsize::new(0));

        let config = BatchConfig::default().with_max_concurrent(5);
        let processor = BatchProcessor::new(config);

        let items = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];

        let concurrent_clone = concurrent.clone();
        let max_seen_clone = max_seen.clone();

        let _result = processor
            .process_all(items, move |_x| {
                let concurrent = concurrent_clone.clone();
                let max_seen = max_seen_clone.clone();
                async move {
                    let current = concurrent.fetch_add(1, Ordering::SeqCst) + 1;
                    max_seen.fetch_max(current, Ordering::SeqCst);

                    tokio::time::sleep(Duration::from_millis(10)).await;

                    concurrent.fetch_sub(1, Ordering::SeqCst);
                    Ok::<_, String>(())
                }
            })
            .await;

        let max = max_seen.load(Ordering::SeqCst);
        assert!(max <= 5, "Max concurrent was {}, expected <= 5", max);
    }
}