1use std::sync::Arc;
2
3use anyhow::Result;
4use futures::stream::{self, StreamExt, TryStreamExt};
5use indicatif::{ProgressBar, ProgressStyle};
6use thiserror::Error;
7use tokio::sync::Semaphore;
8use tracing::{debug, info, warn};
9
/// Errors produced by bulk execution.
#[derive(Error, Debug)]
pub enum BulkError {
    /// Returned by non-fail-fast runs when one or more tasks failed;
    /// `count` is the number of failed tasks.
    #[error("Multiple tasks failed: {count} failures")]
    MultipleFailed { count: usize },

    /// Failure to acquire a semaphore permit (only occurs if the
    /// semaphore is closed).
    #[error("Semaphore acquire error: {0}")]
    SemaphoreError(#[from] tokio::sync::AcquireError),
}
18
/// Configuration used to build a `BulkExecutor` via `BulkExecutor::from_config`.
#[derive(Clone)]
pub struct BulkConfig {
    /// Maximum number of tasks run concurrently (clamped to at least 1 when applied).
    pub concurrency: usize,
    /// When true, tasks are logged and counted but never executed.
    pub dry_run: bool,
    /// When false, a hidden progress bar is used instead of a visible one.
    pub show_progress: bool,
    /// When true, `run` aborts on the first task failure.
    pub fail_fast: bool,
}
26
27impl Default for BulkConfig {
28 fn default() -> Self {
29 Self {
30 concurrency: 4,
31 dry_run: false,
32 show_progress: true,
33 fail_fast: false,
34 }
35 }
36}
37
/// Outcome of `BulkExecutor::execute_with_results`: the values produced by
/// successful tasks plus, for each failed task, its original input index and
/// the error it produced.
#[derive(Debug)]
pub struct BulkResult<T> {
    /// Values returned by tasks that completed successfully.
    pub successful: Vec<T>,
    /// `(input index, error)` for each task that failed.
    pub failed: Vec<(usize, anyhow::Error)>,
}
43
44impl<T> BulkResult<T> {
45 pub fn is_complete_success(&self) -> bool {
46 self.failed.is_empty()
47 }
48
49 pub fn success_count(&self) -> usize {
50 self.successful.len()
51 }
52
53 pub fn failure_count(&self) -> usize {
54 self.failed.len()
55 }
56}
57
/// Runs a batch of async tasks with bounded concurrency, optional dry-run
/// mode, progress reporting, and configurable failure handling.
pub struct BulkExecutor {
    // Maximum number of tasks in flight at once (constructors clamp to >= 1).
    concurrency: usize,
    // When true, tasks are logged but never executed.
    dry_run: bool,
    // When false, the progress bar is hidden.
    show_progress: bool,
    // When true, `run` aborts on the first task failure.
    fail_fast: bool,
}
65
66impl BulkExecutor {
67 pub fn new(concurrency: usize, dry_run: bool) -> Self {
68 Self {
69 concurrency: concurrency.max(1),
70 dry_run,
71 show_progress: true,
72 fail_fast: false,
73 }
74 }
75
76 pub fn from_config(config: BulkConfig) -> Self {
77 Self {
78 concurrency: config.concurrency.max(1),
79 dry_run: config.dry_run,
80 show_progress: config.show_progress,
81 fail_fast: config.fail_fast,
82 }
83 }
84
85 pub fn with_progress(mut self, show_progress: bool) -> Self {
86 self.show_progress = show_progress;
87 self
88 }
89
90 pub fn with_fail_fast(mut self, fail_fast: bool) -> Self {
91 self.fail_fast = fail_fast;
92 self
93 }
94
95 pub async fn run<T, Fut, F>(&self, items: Vec<T>, job: F) -> Result<()>
96 where
97 T: Send + Sync + std::fmt::Debug + 'static,
98 F: Fn(T) -> Fut + Send + Sync + 'static,
99 Fut: std::future::Future<Output = Result<()>> + Send,
100 {
101 if items.is_empty() {
102 debug!("No items to process");
103 return Ok(());
104 }
105
106 let total = items.len();
107 info!(
108 total,
109 concurrency = self.concurrency,
110 "Starting bulk execution"
111 );
112
113 let semaphore = Arc::new(Semaphore::new(self.concurrency));
114 let job = Arc::new(job);
115 let progress = self.create_progress_bar(total);
116 let dry_run = self.dry_run;
117
118 let results = stream::iter(items.into_iter().enumerate().map(|(idx, item)| {
119 let job = Arc::clone(&job);
120 let semaphore = Arc::clone(&semaphore);
121 let progress = progress.clone();
122 async move {
123 let _permit = semaphore.acquire().await?;
124 if dry_run {
125 info!(?item, "Dry run: skipping execution");
126 progress.inc(1);
127 return Ok(());
128 }
129 debug!(index = idx, "Processing item");
130 match job(item).await {
131 Ok(()) => {
132 progress.inc(1);
133 Ok(())
134 }
135 Err(e) => {
136 warn!(index = idx, error = %e, "Task failed");
137 progress.inc(1);
138 Err(e)
139 }
140 }
141 }
142 }))
143 .buffer_unordered(self.concurrency);
144
145 if self.fail_fast {
146 results.try_collect::<Vec<_>>().await?;
147 } else {
148 let all_results: Vec<Result<()>> = results.collect().await;
149 let failures: Vec<_> = all_results.into_iter().filter_map(|r| r.err()).collect();
150
151 if !failures.is_empty() {
152 warn!(failure_count = failures.len(), "Some tasks failed");
153 progress.finish_with_message(format!("Completed with {} failures", failures.len()));
154 return Err(BulkError::MultipleFailed {
155 count: failures.len(),
156 }
157 .into());
158 }
159 }
160
161 progress.finish_with_message("All tasks completed successfully");
162 info!(total, "Bulk execution completed");
163 Ok(())
164 }
165
166 pub async fn execute_with_results<T, R, Fut, F>(
167 &self,
168 items: Vec<T>,
169 job: F,
170 ) -> Result<BulkResult<R>>
171 where
172 T: Send + Sync + std::fmt::Debug + 'static,
173 R: Send + 'static,
174 F: Fn(T) -> Fut + Send + Sync + 'static,
175 Fut: std::future::Future<Output = Result<R>> + Send,
176 {
177 if items.is_empty() {
178 debug!("No items to process");
179 return Ok(BulkResult {
180 successful: vec![],
181 failed: vec![],
182 });
183 }
184
185 let total = items.len();
186 info!(
187 total,
188 concurrency = self.concurrency,
189 "Starting bulk execution with results"
190 );
191
192 let semaphore = Arc::new(Semaphore::new(self.concurrency));
193 let job = Arc::new(job);
194 let progress = self.create_progress_bar(total);
195 let dry_run = self.dry_run;
196
197 let results: Vec<(usize, Result<R>)> =
198 stream::iter(items.into_iter().enumerate().map(|(idx, item)| {
199 let job = Arc::clone(&job);
200 let semaphore = Arc::clone(&semaphore);
201 let progress = progress.clone();
202 async move {
203 let _permit = semaphore.acquire().await?;
204 if dry_run {
205 info!(?item, "Dry run: skipping execution");
206 progress.inc(1);
207 return Ok::<(usize, Result<R>), anyhow::Error>((
208 idx,
209 Err(anyhow::anyhow!("Dry run")),
210 ));
211 }
212 debug!(index = idx, "Processing item");
213 let result = job(item).await;
214 progress.inc(1);
215 Ok((idx, result))
216 }
217 }))
218 .buffer_unordered(self.concurrency)
219 .try_collect()
220 .await?;
221
222 let mut successful = Vec::new();
223 let mut failed = Vec::new();
224
225 for (idx, result) in results {
226 match result {
227 Ok(value) => successful.push(value),
228 Err(error) => failed.push((idx, error)),
229 }
230 }
231
232 if !failed.is_empty() {
233 warn!(
234 success_count = successful.len(),
235 failure_count = failed.len(),
236 "Some tasks failed"
237 );
238 progress.finish_with_message(format!(
239 "Completed: {} succeeded, {} failed",
240 successful.len(),
241 failed.len()
242 ));
243 } else {
244 progress.finish_with_message("All tasks completed successfully");
245 }
246
247 info!(
248 success = successful.len(),
249 failures = failed.len(),
250 "Bulk execution completed"
251 );
252
253 Ok(BulkResult { successful, failed })
254 }
255
256 fn create_progress_bar(&self, total: usize) -> ProgressBar {
257 let progress = if self.show_progress {
258 ProgressBar::new(total as u64)
259 } else {
260 ProgressBar::hidden()
261 };
262
263 progress.set_style(
264 ProgressStyle::with_template(
265 "{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} {msg}",
266 )
267 .unwrap()
268 .progress_chars("#>-")
269 .tick_chars("⠁⠂⠄⡀⢀⠠⠐⠈ "),
270 );
271
272 progress
273 }
274}
275
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    // `new` stores the given concurrency and dry_run flag verbatim.
    #[test]
    fn test_new_executor() {
        let executor = BulkExecutor::new(5, false);
        assert_eq!(executor.concurrency, 5);
        assert!(!executor.dry_run);
    }

    // A concurrency of 0 is clamped up to 1 by `new`.
    #[test]
    fn test_new_executor_zero_concurrency() {
        let executor = BulkExecutor::new(0, false);
        assert_eq!(executor.concurrency, 1);
    }

    // The dry_run flag is preserved by the constructor.
    #[test]
    fn test_new_executor_dry_run() {
        let executor = BulkExecutor::new(3, true);
        assert_eq!(executor.concurrency, 3);
        assert!(executor.dry_run);
    }

    // An empty item list is a no-op success.
    #[tokio::test]
    async fn test_run_empty_items() {
        let executor = BulkExecutor::new(2, false);
        let items: Vec<i32> = vec![];

        let result = executor.run(items, |_item| async { Ok(()) }).await;
        assert!(result.is_ok());
    }

    // A single item invokes the job exactly once.
    #[tokio::test]
    async fn test_run_single_item() {
        let executor = BulkExecutor::new(1, false);
        let counter = Arc::new(AtomicUsize::new(0));
        let counter_clone = Arc::clone(&counter);

        let items = vec![1];
        let result = executor
            .run(items, move |_item| {
                let counter = Arc::clone(&counter_clone);
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                    Ok(())
                }
            })
            .await;

        assert!(result.is_ok());
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    // Every item is processed exactly once across concurrent tasks.
    #[tokio::test]
    async fn test_run_multiple_items() {
        let executor = BulkExecutor::new(3, false);
        let counter = Arc::new(AtomicUsize::new(0));
        let counter_clone = Arc::clone(&counter);

        let items = vec![1, 2, 3, 4, 5];
        let result = executor
            .run(items, move |_item| {
                let counter = Arc::clone(&counter_clone);
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                    Ok(())
                }
            })
            .await;

        assert!(result.is_ok());
        assert_eq!(counter.load(Ordering::SeqCst), 5);
    }

    // In dry-run mode the job closure is never invoked.
    #[tokio::test]
    async fn test_dry_run_skips_execution() {
        let executor = BulkExecutor::new(2, true);
        let counter = Arc::new(AtomicUsize::new(0));
        let counter_clone = Arc::clone(&counter);

        let items = vec![1, 2, 3];
        let result = executor
            .run(items, move |_item| {
                let counter = Arc::clone(&counter_clone);
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                    Ok(())
                }
            })
            .await;

        assert!(result.is_ok());
        assert_eq!(counter.load(Ordering::SeqCst), 0);
    }

    // With fail-fast enabled, the original task error is propagated as-is.
    #[tokio::test]
    async fn test_run_with_error() {
        let executor = BulkExecutor::new(2, false).with_fail_fast(true);
        let items = vec![1, 2, 3];

        let result = executor
            .run(items, |item| async move {
                if item == 2 {
                    anyhow::bail!("Test error on item 2");
                }
                Ok(())
            })
            .await;

        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Test error on item 2"));
    }

    // Without fail-fast, all failures are aggregated into a summary error.
    #[tokio::test]
    async fn test_run_with_multiple_errors() {
        let executor = BulkExecutor::new(2, false);
        let items = vec![1, 2, 3, 4];

        let result = executor
            .run(items, |item| async move {
                if item == 2 || item == 4 {
                    anyhow::bail!("Test error on item {}", item);
                }
                Ok(())
            })
            .await;

        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("Multiple tasks failed") || err_msg.contains("2 failures"));
    }

    // Tracks the observed maximum number of simultaneously running tasks and
    // asserts it never exceeds the configured concurrency limit.
    #[tokio::test]
    async fn test_concurrency_limit() {
        use std::time::Duration;
        use tokio::time::sleep;

        let executor = BulkExecutor::new(2, false);
        let active_count = Arc::new(AtomicUsize::new(0));
        let max_concurrent = Arc::new(AtomicUsize::new(0));
        let active_clone = Arc::clone(&active_count);
        let max_clone = Arc::clone(&max_concurrent);

        let items = vec![1, 2, 3, 4, 5];
        let result = executor
            .run(items, move |_item| {
                let active = Arc::clone(&active_clone);
                let max = Arc::clone(&max_clone);
                async move {
                    let current = active.fetch_add(1, Ordering::SeqCst) + 1;
                    max.fetch_max(current, Ordering::SeqCst);
                    sleep(Duration::from_millis(10)).await;
                    active.fetch_sub(1, Ordering::SeqCst);
                    Ok(())
                }
            })
            .await;

        assert!(result.is_ok());
        assert!(max_concurrent.load(Ordering::SeqCst) <= 2);
    }
}