1use std::future::Future;
8use std::sync::Arc;
9use std::sync::atomic::{AtomicU64, Ordering};
10use std::time::{Duration, Instant};
11
12use futures::stream::{FuturesUnordered, StreamExt};
13use tokio::sync::Semaphore;
14
/// Tuning knobs for [`ConcurrentExecutor`].
#[derive(Debug, Clone)]
pub struct ConcurrencyConfig {
    /// Maximum number of tasks allowed to run at the same time.
    pub max_concurrency: usize,
    /// Per-task timeout; `None` disables the timeout entirely.
    pub operation_timeout: Option<Duration>,
    /// When `false`, execution stops at the first task error (fail-fast).
    pub continue_on_error: bool,
    /// Whether execution statistics should be collected.
    /// NOTE(review): this flag is not consulted anywhere in this module —
    /// stats are currently always collected. Confirm intended behavior.
    pub collect_stats: bool,
}
27
28impl Default for ConcurrencyConfig {
29 fn default() -> Self {
30 Self {
31 max_concurrency: num_cpus::get().max(4),
32 operation_timeout: Some(Duration::from_secs(30)),
33 continue_on_error: true,
34 collect_stats: true,
35 }
36 }
37}
38
39impl ConcurrencyConfig {
40 #[must_use]
42 pub fn for_introspection() -> Self {
43 Self {
44 max_concurrency: 8, operation_timeout: Some(Duration::from_secs(60)),
46 continue_on_error: true,
47 collect_stats: true,
48 }
49 }
50
51 #[must_use]
53 pub fn for_migrations() -> Self {
54 Self {
55 max_concurrency: 4, operation_timeout: Some(Duration::from_secs(120)),
57 continue_on_error: false, collect_stats: true,
59 }
60 }
61
62 #[must_use]
64 pub fn for_bulk_operations() -> Self {
65 Self {
66 max_concurrency: 16, operation_timeout: Some(Duration::from_secs(300)),
68 continue_on_error: true,
69 collect_stats: true,
70 }
71 }
72
73 #[must_use]
75 pub fn with_max_concurrency(mut self, max: usize) -> Self {
76 self.max_concurrency = max.max(1);
77 self
78 }
79
80 #[must_use]
82 pub fn with_timeout(mut self, timeout: Duration) -> Self {
83 self.operation_timeout = Some(timeout);
84 self
85 }
86
87 #[must_use]
89 pub fn without_timeout(mut self) -> Self {
90 self.operation_timeout = None;
91 self
92 }
93
94 #[must_use]
96 pub fn with_continue_on_error(mut self, continue_on_error: bool) -> Self {
97 self.continue_on_error = continue_on_error;
98 self
99 }
100}
101
/// Describes a single failed task.
#[derive(Debug, Clone)]
pub struct TaskError {
    /// Index of the task within the submitted batch.
    pub task_id: usize,
    /// Human-readable failure description.
    pub message: String,
    /// `true` when the failure was caused by the operation timeout.
    pub is_timeout: bool,
}

impl std::fmt::Display for TaskError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Same message shape either way; only the verb differs.
        let verb = if self.is_timeout { "timed out" } else { "failed" };
        write!(f, "Task {} {}: {}", self.task_id, verb, self.message)
    }
}

impl std::error::Error for TaskError {}
124
/// Outcome of a single task executed by [`ConcurrentExecutor`].
#[derive(Debug)]
pub enum TaskResult<T> {
    /// The task completed and produced a value.
    Success {
        /// Index of the task within the submitted batch.
        task_id: usize,
        /// The value the task produced.
        value: T,
        /// Wall-clock time the task body took to run.
        duration: Duration,
    },
    /// The task failed or timed out; details are in the [`TaskError`].
    Error(TaskError),
}
140
141impl<T> TaskResult<T> {
142 pub fn is_success(&self) -> bool {
144 matches!(self, Self::Success { .. })
145 }
146
147 pub fn into_value(self) -> Option<T> {
149 match self {
150 Self::Success { value, .. } => Some(value),
151 Self::Error(_) => None,
152 }
153 }
154
155 pub fn into_error(self) -> Option<TaskError> {
157 match self {
158 Self::Success { .. } => None,
159 Self::Error(e) => Some(e),
160 }
161 }
162}
163
/// Aggregated statistics for one `execute_all` run.
#[derive(Debug, Clone, Default)]
pub struct ExecutionStats {
    /// Number of tasks submitted.
    pub total_tasks: u64,
    /// Tasks that completed successfully.
    pub successful: u64,
    /// Tasks that failed for any reason. Timeouts are counted here as well
    /// (`record_timeout` increments both counters), so `failed >= timed_out`.
    pub failed: u64,
    /// Tasks that failed specifically because of the operation timeout.
    pub timed_out: u64,
    /// Wall-clock duration of the whole batch.
    pub total_duration: Duration,
    /// Mean duration of the *successful* tasks only; zero when none succeeded.
    pub avg_task_duration: Duration,
    /// Peak number of tasks observed running at the same time.
    pub max_concurrent: usize,
}
182
/// Runs batches of fallible async tasks with bounded concurrency, optional
/// per-task timeouts, and execution statistics.
pub struct ConcurrentExecutor {
    // Behavioral configuration (limits, timeout, error policy).
    config: ConcurrencyConfig,
    // Gates task starts so at most `config.max_concurrency` run at once.
    semaphore: Arc<Semaphore>,
    // Shared atomic counters updated by in-flight tasks.
    // NOTE(review): `execute_all` resets these at the start of every run, so
    // overlapping calls on one executor would corrupt each other's stats —
    // confirm callers serialize runs per executor.
    stats: ExecutionStatsCollector,
}
189
190impl ConcurrentExecutor {
191 pub fn new(config: ConcurrencyConfig) -> Self {
193 let semaphore = Arc::new(Semaphore::new(config.max_concurrency));
194 Self {
195 config,
196 semaphore,
197 stats: ExecutionStatsCollector::new(),
198 }
199 }
200
201 pub async fn execute_all<T, F, Fut>(
206 &self,
207 tasks: impl IntoIterator<Item = F>,
208 ) -> (Vec<TaskResult<T>>, ExecutionStats)
209 where
210 F: FnOnce() -> Fut + Send + 'static,
211 Fut: Future<Output = Result<T, String>> + Send + 'static,
212 T: Send + 'static,
213 {
214 let start = Instant::now();
215 self.stats.reset();
216
217 let tasks: Vec<_> = tasks.into_iter().collect();
218 let total_tasks = tasks.len();
219 self.stats.total.store(total_tasks as u64, Ordering::SeqCst);
220
221 let mut futures = FuturesUnordered::new();
222
223 for (task_id, task) in tasks.into_iter().enumerate() {
224 let semaphore = Arc::clone(&self.semaphore);
225 let timeout = self.config.operation_timeout;
226 let stats = self.stats.clone();
227
228 let future = async move {
229 let _permit = semaphore.acquire().await.expect("Semaphore closed");
231 stats.increment_concurrent();
232
233 let task_start = Instant::now();
234 let result = if let Some(timeout_duration) = timeout {
235 match tokio::time::timeout(timeout_duration, task()).await {
236 Ok(Ok(value)) => TaskResult::Success {
237 task_id,
238 value,
239 duration: task_start.elapsed(),
240 },
241 Ok(Err(msg)) => TaskResult::Error(TaskError {
242 task_id,
243 message: msg,
244 is_timeout: false,
245 }),
246 Err(_) => TaskResult::Error(TaskError {
247 task_id,
248 message: format!("Timeout after {:?}", timeout_duration),
249 is_timeout: true,
250 }),
251 }
252 } else {
253 match task().await {
254 Ok(value) => TaskResult::Success {
255 task_id,
256 value,
257 duration: task_start.elapsed(),
258 },
259 Err(msg) => TaskResult::Error(TaskError {
260 task_id,
261 message: msg,
262 is_timeout: false,
263 }),
264 }
265 };
266
267 stats.decrement_concurrent();
268
269 match &result {
270 TaskResult::Success { duration, .. } => {
271 stats.record_success(*duration);
272 }
273 TaskResult::Error(e) if e.is_timeout => {
274 stats.record_timeout();
275 }
276 TaskResult::Error(_) => {
277 stats.record_failure();
278 }
279 }
280
281 result
282 };
283
284 futures.push(future);
285 }
286
287 let mut results = Vec::with_capacity(total_tasks);
289
290 while let Some(result) = futures.next().await {
291 if !self.config.continue_on_error {
292 if let TaskResult::Error(ref _e) = result {
293 drop(futures);
295 results.push(result);
296
297 let stats = self.stats.finalize(start.elapsed());
298 return (results, stats);
299 }
300 }
301 results.push(result);
302 }
303
304 results.sort_by_key(|r| match r {
306 TaskResult::Success { task_id, .. } => *task_id,
307 TaskResult::Error(e) => e.task_id,
308 });
309
310 let stats = self.stats.finalize(start.elapsed());
311 (results, stats)
312 }
313
314 pub async fn execute_collect<T, F, Fut>(
318 &self,
319 tasks: impl IntoIterator<Item = F>,
320 ) -> (Vec<T>, Vec<TaskError>)
321 where
322 F: FnOnce() -> Fut + Send + 'static,
323 Fut: Future<Output = Result<T, String>> + Send + 'static,
324 T: Send + 'static,
325 {
326 let (results, _) = self.execute_all(tasks).await;
327
328 let mut values = Vec::new();
329 let mut errors = Vec::new();
330
331 for result in results {
332 match result {
333 TaskResult::Success { value, .. } => values.push(value),
334 TaskResult::Error(e) => errors.push(e),
335 }
336 }
337
338 (values, errors)
339 }
340
341 pub async fn execute_indexed<T, F, Fut>(
346 &self,
347 tasks: impl IntoIterator<Item = F>,
348 ) -> std::collections::HashMap<usize, Result<T, TaskError>>
349 where
350 F: FnOnce() -> Fut + Send + 'static,
351 Fut: Future<Output = Result<T, String>> + Send + 'static,
352 T: Send + 'static,
353 {
354 let (results, _) = self.execute_all(tasks).await;
355
356 results
357 .into_iter()
358 .map(|r| match r {
359 TaskResult::Success { task_id, value, .. } => (task_id, Ok(value)),
360 TaskResult::Error(e) => (e.task_id, Err(e)),
361 })
362 .collect()
363 }
364}
365
// Internal, cheaply-cloneable bundle of atomic counters. Every in-flight task
// future holds a clone; all fields are `Arc`ed so clones observe the same
// underlying counters.
#[derive(Clone)]
struct ExecutionStatsCollector {
    total: Arc<AtomicU64>,
    successful: Arc<AtomicU64>,
    failed: Arc<AtomicU64>,
    timed_out: Arc<AtomicU64>,
    // Sum of successful task durations, in nanoseconds.
    total_task_duration_ns: Arc<AtomicU64>,
    // Tasks currently running; feeds the high-water mark below.
    current_concurrent: Arc<AtomicU64>,
    // High-water mark of `current_concurrent`.
    max_concurrent: Arc<AtomicU64>,
}
377
378impl ExecutionStatsCollector {
379 fn new() -> Self {
380 Self {
381 total: Arc::new(AtomicU64::new(0)),
382 successful: Arc::new(AtomicU64::new(0)),
383 failed: Arc::new(AtomicU64::new(0)),
384 timed_out: Arc::new(AtomicU64::new(0)),
385 total_task_duration_ns: Arc::new(AtomicU64::new(0)),
386 current_concurrent: Arc::new(AtomicU64::new(0)),
387 max_concurrent: Arc::new(AtomicU64::new(0)),
388 }
389 }
390
391 fn reset(&self) {
392 self.total.store(0, Ordering::SeqCst);
393 self.successful.store(0, Ordering::SeqCst);
394 self.failed.store(0, Ordering::SeqCst);
395 self.timed_out.store(0, Ordering::SeqCst);
396 self.total_task_duration_ns.store(0, Ordering::SeqCst);
397 self.current_concurrent.store(0, Ordering::SeqCst);
398 self.max_concurrent.store(0, Ordering::SeqCst);
399 }
400
401 fn increment_concurrent(&self) {
402 let current = self.current_concurrent.fetch_add(1, Ordering::SeqCst) + 1;
403 self.max_concurrent.fetch_max(current, Ordering::SeqCst);
404 }
405
406 fn decrement_concurrent(&self) {
407 self.current_concurrent.fetch_sub(1, Ordering::SeqCst);
408 }
409
410 fn record_success(&self, duration: Duration) {
411 self.successful.fetch_add(1, Ordering::SeqCst);
412 self.total_task_duration_ns
413 .fetch_add(duration.as_nanos() as u64, Ordering::SeqCst);
414 }
415
416 fn record_failure(&self) {
417 self.failed.fetch_add(1, Ordering::SeqCst);
418 }
419
420 fn record_timeout(&self) {
421 self.timed_out.fetch_add(1, Ordering::SeqCst);
422 self.failed.fetch_add(1, Ordering::SeqCst);
423 }
424
425 fn finalize(&self, total_duration: Duration) -> ExecutionStats {
426 let successful = self.successful.load(Ordering::SeqCst);
427 let total_task_duration_ns = self.total_task_duration_ns.load(Ordering::SeqCst);
428
429 let avg_task_duration = if successful > 0 {
430 Duration::from_nanos(total_task_duration_ns / successful)
431 } else {
432 Duration::ZERO
433 };
434
435 ExecutionStats {
436 total_tasks: self.total.load(Ordering::SeqCst),
437 successful,
438 failed: self.failed.load(Ordering::SeqCst),
439 timed_out: self.timed_out.load(Ordering::SeqCst),
440 total_duration,
441 avg_task_duration,
442 max_concurrent: self.max_concurrent.load(Ordering::SeqCst) as usize,
443 }
444 }
445}
446
447pub async fn execute_batch<T, I, F, Fut>(
452 items: I,
453 max_concurrency: usize,
454 operation: F,
455) -> Vec<Result<T, String>>
456where
457 I: IntoIterator,
458 F: Fn(I::Item) -> Fut + Clone + Send + 'static,
459 Fut: Future<Output = Result<T, String>> + Send + 'static,
460 T: Send + 'static,
461 I::Item: Send + 'static,
462{
463 let config = ConcurrencyConfig::default().with_max_concurrency(max_concurrency);
464 let executor = ConcurrentExecutor::new(config);
465
466 let tasks: Vec<_> = items
467 .into_iter()
468 .map(|item| {
469 let op = operation.clone();
470 move || op(item)
471 })
472 .collect();
473
474 let (results, _) = executor.execute_all(tasks).await;
475
476 results
477 .into_iter()
478 .map(|r| match r {
479 TaskResult::Success { value, .. } => Ok(value),
480 TaskResult::Error(e) => Err(e.message),
481 })
482 .collect()
483}
484
485pub async fn execute_chunked<T, I, F, Fut>(
490 items: I,
491 chunk_size: usize,
492 max_concurrency: usize,
493 operation: F,
494) -> Vec<Vec<Result<T, String>>>
495where
496 I: IntoIterator,
497 I::IntoIter: ExactSizeIterator,
498 F: Fn(Vec<I::Item>) -> Fut + Clone + Send + 'static,
499 Fut: Future<Output = Vec<Result<T, String>>> + Send + 'static,
500 T: Send + 'static,
501 I::Item: Send + Clone + 'static,
502{
503 let items: Vec<_> = items.into_iter().collect();
504 let chunks: Vec<Vec<_>> = items.chunks(chunk_size).map(|c| c.to_vec()).collect();
505
506 let config = ConcurrencyConfig::default().with_max_concurrency(max_concurrency);
507 let executor = ConcurrentExecutor::new(config);
508
509 let tasks: Vec<_> = chunks
510 .into_iter()
511 .map(|chunk| {
512 let op = operation.clone();
513 move || async move { Ok::<_, String>(op(chunk).await) }
514 })
515 .collect();
516
517 let (results, _) = executor.execute_all(tasks).await;
518
519 results.into_iter().filter_map(|r| r.into_value()).collect()
520}
521
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::AtomicUsize;

    #[tokio::test]
    async fn test_concurrent_executor_basic() {
        let executor = ConcurrentExecutor::new(ConcurrencyConfig::default());

        let tasks: Vec<_> = (0..10)
            .map(|i| move || async move { Ok::<_, String>(i * 2) })
            .collect();

        let (results, stats) = executor.execute_all(tasks).await;

        assert_eq!(results.len(), 10);
        assert_eq!(stats.total_tasks, 10);
        assert_eq!(stats.successful, 10);
        assert_eq!(stats.failed, 0);

        // Results come back sorted by task id, so index == task id here.
        for (i, result) in results.iter().enumerate() {
            match result {
                TaskResult::Success { value, .. } => {
                    assert_eq!(*value, i * 2);
                }
                _ => panic!("Expected success"),
            }
        }
    }

    #[tokio::test]
    async fn test_concurrent_executor_with_errors() {
        let config = ConcurrencyConfig::default().with_continue_on_error(true);
        let executor = ConcurrentExecutor::new(config);

        let tasks: Vec<_> = (0..5)
            .map(|i| {
                move || async move {
                    if i == 2 {
                        Err("Task 2 failed".to_string())
                    } else {
                        Ok::<_, String>(i)
                    }
                }
            })
            .collect();

        let (results, stats) = executor.execute_all(tasks).await;

        assert_eq!(results.len(), 5);
        assert_eq!(stats.successful, 4);
        assert_eq!(stats.failed, 1);
    }

    #[tokio::test]
    async fn test_concurrent_executor_fail_fast() {
        // max_concurrency = 1 makes completion order deterministic, so the
        // failing task reliably stops the run early.
        let config = ConcurrencyConfig::default()
            .with_continue_on_error(false)
            .with_max_concurrency(1);
        let executor = ConcurrentExecutor::new(config);
        let counter = Arc::new(AtomicUsize::new(0));

        let tasks: Vec<_> = (0..5)
            .map(|i| {
                let counter = Arc::clone(&counter);
                move || async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                    if i == 2 {
                        Err("Task 2 failed".to_string())
                    } else {
                        Ok::<_, String>(i)
                    }
                }
            })
            .collect();

        let (results, _) = executor.execute_all(tasks).await;

        let has_error = results.iter().any(|r| matches!(r, TaskResult::Error(_)));
        assert!(has_error);
    }

    #[tokio::test]
    async fn test_concurrent_executor_respects_concurrency() {
        let max_concurrent = Arc::new(AtomicUsize::new(0));
        let current = Arc::new(AtomicUsize::new(0));

        let config = ConcurrencyConfig::default().with_max_concurrency(3);
        let executor = ConcurrentExecutor::new(config);

        let tasks: Vec<_> = (0..20)
            .map(|i| {
                let max_concurrent = Arc::clone(&max_concurrent);
                let current = Arc::clone(&current);
                move || async move {
                    let c = current.fetch_add(1, Ordering::SeqCst) + 1;
                    max_concurrent.fetch_max(c, Ordering::SeqCst);

                    tokio::time::sleep(Duration::from_millis(10)).await;

                    current.fetch_sub(1, Ordering::SeqCst);
                    Ok::<_, String>(i)
                }
            })
            .collect();

        let (results, stats) = executor.execute_all(tasks).await;

        assert_eq!(results.len(), 20);
        assert!(stats.max_concurrent <= 3);
        assert!(max_concurrent.load(Ordering::SeqCst) <= 3);
    }

    #[tokio::test]
    async fn test_execute_batch() {
        let results = execute_batch(vec!["a", "b", "c"], 4, |s: &str| async move {
            Ok::<_, String>(s.len())
        })
        .await;

        assert_eq!(results.len(), 3);
        assert!(results.iter().all(|r| r.is_ok()));
    }

    #[tokio::test]
    async fn test_timeout() {
        let config = ConcurrencyConfig::default().with_timeout(Duration::from_millis(50));
        let executor = ConcurrentExecutor::new(config);

        // Boxed closures so the two differently-shaped futures share one type.
        let tasks: Vec<
            Box<
                dyn FnOnce() -> std::pin::Pin<Box<dyn Future<Output = Result<i32, String>> + Send>>
                    + Send,
            >,
        > = vec![
            Box::new(|| {
                Box::pin(async {
                    tokio::time::sleep(Duration::from_millis(10)).await;
                    Ok::<_, String>(1)
                })
            }),
            Box::new(|| {
                Box::pin(async {
                    tokio::time::sleep(Duration::from_millis(200)).await;
                    Ok::<_, String>(2)
                })
            }),
        ];

        let (results, stats) = executor.execute_all(tasks).await;

        assert_eq!(results.len(), 2);
        assert_eq!(stats.timed_out, 1);
    }
}