1use std::sync::Arc;
2
3use anyhow::Result;
4use futures::stream::{self, StreamExt, TryStreamExt};
5use indicatif::{ProgressBar, ProgressStyle};
6use thiserror::Error;
7use tokio::sync::Semaphore;
8use tracing::{debug, info, warn};
9
/// Errors produced by bulk execution.
#[derive(Error, Debug)]
pub enum BulkError {
    /// Aggregate error returned by the non-fail-fast path of `run` when one
    /// or more tasks failed; `count` is the number of failed tasks.
    #[error("Multiple tasks failed: {count} failures")]
    MultipleFailed { count: usize },

    /// Failure to acquire a concurrency permit (semaphore closed).
    #[error("Semaphore acquire error: {0}")]
    SemaphoreError(#[from] tokio::sync::AcquireError),
}
18
/// Configuration for building a `BulkExecutor` via `BulkExecutor::from_config`.
#[derive(Clone)]
pub struct BulkConfig {
    /// Maximum number of tasks in flight at once (0 is clamped to 1 by the executor).
    pub concurrency: usize,
    /// When true, items are logged and skipped instead of executed.
    pub dry_run: bool,
    /// When true, a visible progress bar is drawn; otherwise a hidden bar is used.
    pub show_progress: bool,
    /// When true, `run` aborts on the first task failure.
    pub fail_fast: bool,
}
26
27impl Default for BulkConfig {
28 fn default() -> Self {
29 Self {
30 concurrency: 4,
31 dry_run: false,
32 show_progress: true,
33 fail_fast: false,
34 }
35 }
36}
37
/// Outcome of `BulkExecutor::execute_with_results`: per-item results split
/// into successes and failures.
#[derive(Debug)]
pub struct BulkResult<T> {
    /// Values returned by tasks that completed successfully.
    pub successful: Vec<T>,
    /// Original item index paired with the error for each failed task.
    pub failed: Vec<(usize, anyhow::Error)>,
}
43
44impl<T> BulkResult<T> {
45 pub fn is_complete_success(&self) -> bool {
46 self.failed.is_empty()
47 }
48
49 pub fn success_count(&self) -> usize {
50 self.successful.len()
51 }
52
53 pub fn failure_count(&self) -> usize {
54 self.failed.len()
55 }
56}
57
/// Runs a batch of async jobs with bounded concurrency, optional dry-run
/// mode, progress reporting, and configurable failure handling.
pub struct BulkExecutor {
    // Maximum number of jobs in flight at once (constructors clamp this to >= 1).
    concurrency: usize,
    // Log-and-skip instead of invoking the job closure.
    dry_run: bool,
    // Draw a visible progress bar; a hidden bar is used when false.
    show_progress: bool,
    // Abort `run` on the first failure instead of collecting all failures.
    fail_fast: bool,
}
65
66impl BulkExecutor {
67 pub fn new(concurrency: usize, dry_run: bool) -> Self {
68 Self {
69 concurrency: concurrency.max(1),
70 dry_run,
71 show_progress: true,
72 fail_fast: false,
73 }
74 }
75
76 pub fn from_config(config: BulkConfig) -> Self {
77 Self {
78 concurrency: config.concurrency.max(1),
79 dry_run: config.dry_run,
80 show_progress: config.show_progress,
81 fail_fast: config.fail_fast,
82 }
83 }
84
85 pub fn with_progress(mut self, show_progress: bool) -> Self {
86 self.show_progress = show_progress;
87 self
88 }
89
90 pub fn with_fail_fast(mut self, fail_fast: bool) -> Self {
91 self.fail_fast = fail_fast;
92 self
93 }
94
95 pub async fn run<T, Fut, F>(&self, items: Vec<T>, job: F) -> Result<()>
96 where
97 T: Send + Sync + std::fmt::Debug + 'static,
98 F: Fn(T) -> Fut + Send + Sync + 'static,
99 Fut: std::future::Future<Output = Result<()>> + Send,
100 {
101 if items.is_empty() {
102 debug!("No items to process");
103 return Ok(());
104 }
105
106 let total = items.len();
107 info!(
108 total,
109 concurrency = self.concurrency,
110 "Starting bulk execution"
111 );
112
113 let semaphore = Arc::new(Semaphore::new(self.concurrency));
114 let job = Arc::new(job);
115 let progress = self.create_progress_bar(total);
116 let dry_run = self.dry_run;
117
118 let results = stream::iter(items.into_iter().enumerate().map(|(idx, item)| {
119 let job = Arc::clone(&job);
120 let semaphore = Arc::clone(&semaphore);
121 let progress = progress.clone();
122 async move {
123 let _permit = semaphore.acquire().await?;
124 if dry_run {
125 info!(?item, "Dry run: skipping execution");
126 progress.inc(1);
127 return Ok(());
128 }
129 debug!(index = idx, "Processing item");
130 match job(item).await {
131 Ok(()) => {
132 progress.inc(1);
133 Ok(())
134 }
135 Err(e) => {
136 warn!(index = idx, error = %e, "Task failed");
137 progress.inc(1);
138 Err(e)
139 }
140 }
141 }
142 }))
143 .buffer_unordered(self.concurrency);
144
145 if self.fail_fast {
146 results.try_collect::<Vec<_>>().await?;
147 } else {
148 let all_results: Vec<Result<()>> = results.collect().await;
149 let failures: Vec<_> = all_results.into_iter().filter_map(|r| r.err()).collect();
150
151 if !failures.is_empty() {
152 warn!(failure_count = failures.len(), "Some tasks failed");
153 progress.finish_with_message(format!("Completed with {} failures", failures.len()));
154 return Err(BulkError::MultipleFailed {
155 count: failures.len(),
156 }
157 .into());
158 }
159 }
160
161 progress.finish_with_message("All tasks completed successfully");
162 info!(total, "Bulk execution completed");
163 Ok(())
164 }
165
166 pub async fn execute_with_results<T, R, Fut, F>(
167 &self,
168 items: Vec<T>,
169 job: F,
170 ) -> Result<BulkResult<R>>
171 where
172 T: Send + Sync + std::fmt::Debug + 'static,
173 R: Send + 'static,
174 F: Fn(T) -> Fut + Send + Sync + 'static,
175 Fut: std::future::Future<Output = Result<R>> + Send,
176 {
177 if items.is_empty() {
178 debug!("No items to process");
179 return Ok(BulkResult {
180 successful: vec![],
181 failed: vec![],
182 });
183 }
184
185 let total = items.len();
186 info!(
187 total,
188 concurrency = self.concurrency,
189 "Starting bulk execution with results"
190 );
191
192 let semaphore = Arc::new(Semaphore::new(self.concurrency));
193 let job = Arc::new(job);
194 let progress = self.create_progress_bar(total);
195 let dry_run = self.dry_run;
196
197 let results: Vec<(usize, Result<R>)> =
198 stream::iter(items.into_iter().enumerate().map(|(idx, item)| {
199 let job = Arc::clone(&job);
200 let semaphore = Arc::clone(&semaphore);
201 let progress = progress.clone();
202 async move {
203 let _permit = semaphore.acquire().await?;
204 if dry_run {
205 info!(?item, "Dry run: skipping execution");
206 progress.inc(1);
207 return Ok::<(usize, Result<R>), anyhow::Error>((
208 idx,
209 Err(anyhow::anyhow!("Dry run")),
210 ));
211 }
212 debug!(index = idx, "Processing item");
213 let result = job(item).await;
214 progress.inc(1);
215 Ok((idx, result))
216 }
217 }))
218 .buffer_unordered(self.concurrency)
219 .try_collect()
220 .await?;
221
222 let mut successful = Vec::new();
223 let mut failed = Vec::new();
224
225 for (idx, result) in results {
226 match result {
227 Ok(value) => successful.push(value),
228 Err(error) => failed.push((idx, error)),
229 }
230 }
231
232 if !failed.is_empty() {
233 warn!(
234 success_count = successful.len(),
235 failure_count = failed.len(),
236 "Some tasks failed"
237 );
238 progress.finish_with_message(format!(
239 "Completed: {} succeeded, {} failed",
240 successful.len(),
241 failed.len()
242 ));
243 } else {
244 progress.finish_with_message("All tasks completed successfully");
245 }
246
247 info!(
248 success = successful.len(),
249 failures = failed.len(),
250 "Bulk execution completed"
251 );
252
253 Ok(BulkResult { successful, failed })
254 }
255
256 fn create_progress_bar(&self, total: usize) -> ProgressBar {
257 let progress = if self.show_progress {
258 ProgressBar::new(total as u64)
259 } else {
260 ProgressBar::hidden()
261 };
262
263 progress.set_style(
264 ProgressStyle::with_template(
265 "{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} {msg}",
266 )
267 .unwrap_or_else(|e| {
268 warn!("Invalid progress template: {}, using default", e);
269 ProgressStyle::default_bar()
270 })
271 .progress_chars("#>-")
272 .tick_chars("⠁⠂⠄⡀⢀⠠⠐⠈ "),
273 );
274
275 progress
276 }
277}
278
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    #[test]
    fn test_new_executor() {
        let exec = BulkExecutor::new(5, false);
        assert_eq!(exec.concurrency, 5);
        assert!(!exec.dry_run);
    }

    #[test]
    fn test_new_executor_zero_concurrency() {
        // A zero limit is clamped up to one.
        let exec = BulkExecutor::new(0, false);
        assert_eq!(exec.concurrency, 1);
    }

    #[test]
    fn test_new_executor_dry_run() {
        let exec = BulkExecutor::new(3, true);
        assert_eq!(exec.concurrency, 3);
        assert!(exec.dry_run);
    }

    #[tokio::test]
    async fn test_run_empty_items() {
        let exec = BulkExecutor::new(2, false);
        let empty: Vec<i32> = Vec::new();

        let outcome = exec.run(empty, |_| async { Ok(()) }).await;
        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn test_run_single_item() {
        let exec = BulkExecutor::new(1, false);
        let hits = Arc::new(AtomicUsize::new(0));
        let hits_for_job = Arc::clone(&hits);

        let outcome = exec
            .run(vec![1], move |_| {
                let hits = Arc::clone(&hits_for_job);
                async move {
                    hits.fetch_add(1, Ordering::SeqCst);
                    Ok(())
                }
            })
            .await;

        assert!(outcome.is_ok());
        assert_eq!(hits.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_run_multiple_items() {
        let exec = BulkExecutor::new(3, false);
        let hits = Arc::new(AtomicUsize::new(0));
        let hits_for_job = Arc::clone(&hits);

        let outcome = exec
            .run(vec![1, 2, 3, 4, 5], move |_| {
                let hits = Arc::clone(&hits_for_job);
                async move {
                    hits.fetch_add(1, Ordering::SeqCst);
                    Ok(())
                }
            })
            .await;

        assert!(outcome.is_ok());
        assert_eq!(hits.load(Ordering::SeqCst), 5);
    }

    #[tokio::test]
    async fn test_dry_run_skips_execution() {
        // In dry-run mode the job closure must never be invoked.
        let exec = BulkExecutor::new(2, true);
        let hits = Arc::new(AtomicUsize::new(0));
        let hits_for_job = Arc::clone(&hits);

        let outcome = exec
            .run(vec![1, 2, 3], move |_| {
                let hits = Arc::clone(&hits_for_job);
                async move {
                    hits.fetch_add(1, Ordering::SeqCst);
                    Ok(())
                }
            })
            .await;

        assert!(outcome.is_ok());
        assert_eq!(hits.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn test_run_with_error() {
        // Fail-fast: the first task error is returned verbatim.
        let exec = BulkExecutor::new(2, false).with_fail_fast(true);

        let outcome = exec
            .run(vec![1, 2, 3], |n| async move {
                if n == 2 {
                    anyhow::bail!("Test error on item 2");
                }
                Ok(())
            })
            .await;

        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("Test error on item 2"));
    }

    #[tokio::test]
    async fn test_run_with_multiple_errors() {
        // Collect mode: failures are aggregated into MultipleFailed.
        let exec = BulkExecutor::new(2, false);

        let outcome = exec
            .run(vec![1, 2, 3, 4], |n| async move {
                if n == 2 || n == 4 {
                    anyhow::bail!("Test error on item {}", n);
                }
                Ok(())
            })
            .await;

        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("Multiple tasks failed") || message.contains("2 failures"));
    }

    #[tokio::test]
    async fn test_concurrency_limit() {
        use std::time::Duration;
        use tokio::time::sleep;

        let exec = BulkExecutor::new(2, false);
        let in_flight = Arc::new(AtomicUsize::new(0));
        let peak = Arc::new(AtomicUsize::new(0));
        let in_flight_for_job = Arc::clone(&in_flight);
        let peak_for_job = Arc::clone(&peak);

        let outcome = exec
            .run(vec![1, 2, 3, 4, 5], move |_| {
                let in_flight = Arc::clone(&in_flight_for_job);
                let peak = Arc::clone(&peak_for_job);
                async move {
                    // Track the high-water mark of simultaneous tasks.
                    let now_running = in_flight.fetch_add(1, Ordering::SeqCst) + 1;
                    peak.fetch_max(now_running, Ordering::SeqCst);
                    sleep(Duration::from_millis(10)).await;
                    in_flight.fetch_sub(1, Ordering::SeqCst);
                    Ok(())
                }
            })
            .await;

        assert!(outcome.is_ok());
        assert!(peak.load(Ordering::SeqCst) <= 2);
    }
}