1use std::future::Future;
8use std::sync::atomic::{AtomicU64, Ordering};
9use std::sync::Arc;
10use std::time::{Duration, Instant};
11
12use futures::stream::{FuturesUnordered, StreamExt};
13use tokio::sync::Semaphore;
14
/// Tuning knobs for [`ConcurrentExecutor`].
///
/// Start from [`Default::default`] or one of the preset constructors and adjust
/// with the `with_*` / `without_*` builder methods.
#[derive(Debug, Clone)]
pub struct ConcurrencyConfig {
    /// Number of semaphore permits, i.e. how many tasks may run at the same time.
    pub max_concurrency: usize,
    /// Per-task time limit; `None` lets a task run for as long as it needs.
    pub operation_timeout: Option<Duration>,
    /// When `false`, a run stops at the first task that fails or times out.
    pub continue_on_error: bool,
    /// Requests statistics collection.
    ///
    /// NOTE(review): nothing in this file reads this flag; statistics appear to be
    /// gathered unconditionally — confirm before relying on it.
    pub collect_stats: bool,
}
27
28impl Default for ConcurrencyConfig {
29 fn default() -> Self {
30 Self {
31 max_concurrency: num_cpus::get().max(4),
32 operation_timeout: Some(Duration::from_secs(30)),
33 continue_on_error: true,
34 collect_stats: true,
35 }
36 }
37}
38
39impl ConcurrencyConfig {
40 #[must_use]
42 pub fn for_introspection() -> Self {
43 Self {
44 max_concurrency: 8, operation_timeout: Some(Duration::from_secs(60)),
46 continue_on_error: true,
47 collect_stats: true,
48 }
49 }
50
51 #[must_use]
53 pub fn for_migrations() -> Self {
54 Self {
55 max_concurrency: 4, operation_timeout: Some(Duration::from_secs(120)),
57 continue_on_error: false, collect_stats: true,
59 }
60 }
61
62 #[must_use]
64 pub fn for_bulk_operations() -> Self {
65 Self {
66 max_concurrency: 16, operation_timeout: Some(Duration::from_secs(300)),
68 continue_on_error: true,
69 collect_stats: true,
70 }
71 }
72
73 #[must_use]
75 pub fn with_max_concurrency(mut self, max: usize) -> Self {
76 self.max_concurrency = max.max(1);
77 self
78 }
79
80 #[must_use]
82 pub fn with_timeout(mut self, timeout: Duration) -> Self {
83 self.operation_timeout = Some(timeout);
84 self
85 }
86
87 #[must_use]
89 pub fn without_timeout(mut self) -> Self {
90 self.operation_timeout = None;
91 self
92 }
93
94 #[must_use]
96 pub fn with_continue_on_error(mut self, continue_on_error: bool) -> Self {
97 self.continue_on_error = continue_on_error;
98 self
99 }
100}
101
/// Failure report for a single task run by [`ConcurrentExecutor`].
#[derive(Debug, Clone)]
pub struct TaskError {
    /// Zero-based position of the task in the submitted sequence.
    pub task_id: usize,
    /// Error text returned by the task, or a description of the exceeded timeout.
    pub message: String,
    /// `true` when the task was abandoned for exceeding its time limit.
    pub is_timeout: bool,
}

impl std::fmt::Display for TaskError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let TaskError {
            task_id,
            message,
            is_timeout,
        } = self;
        match *is_timeout {
            true => write!(f, "Task {} timed out: {}", task_id, message),
            false => write!(f, "Task {} failed: {}", task_id, message),
        }
    }
}

impl std::error::Error for TaskError {}
124
125#[derive(Debug)]
127pub enum TaskResult<T> {
128 Success {
130 task_id: usize,
132 value: T,
134 duration: Duration,
136 },
137 Error(TaskError),
139}
140
141impl<T> TaskResult<T> {
142 pub fn is_success(&self) -> bool {
144 matches!(self, Self::Success { .. })
145 }
146
147 pub fn into_value(self) -> Option<T> {
149 match self {
150 Self::Success { value, .. } => Some(value),
151 Self::Error(_) => None,
152 }
153 }
154
155 pub fn into_error(self) -> Option<TaskError> {
157 match self {
158 Self::Success { .. } => None,
159 Self::Error(e) => Some(e),
160 }
161 }
162}
163
/// Summary of one [`ConcurrentExecutor::execute_all`] run.
#[derive(Debug, Clone, Default)]
pub struct ExecutionStats {
    /// Number of tasks submitted.
    pub total_tasks: u64,
    /// Tasks that returned `Ok`.
    pub successful: u64,
    /// Tasks that returned `Err` or timed out (timeouts are counted here and in `timed_out`).
    pub failed: u64,
    /// Tasks abandoned because they exceeded the configured timeout.
    pub timed_out: u64,
    /// Wall-clock time of the whole run.
    pub total_duration: Duration,
    /// Mean running time of the successful tasks only; zero when none succeeded.
    pub avg_task_duration: Duration,
    /// Highest number of tasks observed running at the same time.
    pub max_concurrent: usize,
}
182
183pub struct ConcurrentExecutor {
185 config: ConcurrencyConfig,
186 semaphore: Arc<Semaphore>,
187 stats: ExecutionStatsCollector,
188}
189
190impl ConcurrentExecutor {
191 pub fn new(config: ConcurrencyConfig) -> Self {
193 let semaphore = Arc::new(Semaphore::new(config.max_concurrency));
194 Self {
195 config,
196 semaphore,
197 stats: ExecutionStatsCollector::new(),
198 }
199 }
200
201 pub async fn execute_all<T, F, Fut>(
206 &self,
207 tasks: impl IntoIterator<Item = F>,
208 ) -> (Vec<TaskResult<T>>, ExecutionStats)
209 where
210 F: FnOnce() -> Fut + Send + 'static,
211 Fut: Future<Output = Result<T, String>> + Send + 'static,
212 T: Send + 'static,
213 {
214 let start = Instant::now();
215 self.stats.reset();
216
217 let tasks: Vec<_> = tasks.into_iter().collect();
218 let total_tasks = tasks.len();
219 self.stats.total.store(total_tasks as u64, Ordering::SeqCst);
220
221 let mut futures = FuturesUnordered::new();
222
223 for (task_id, task) in tasks.into_iter().enumerate() {
224 let semaphore = Arc::clone(&self.semaphore);
225 let timeout = self.config.operation_timeout;
226 let stats = self.stats.clone();
227
228 let future = async move {
229 let _permit = semaphore.acquire().await.expect("Semaphore closed");
231 stats.increment_concurrent();
232
233 let task_start = Instant::now();
234 let result = if let Some(timeout_duration) = timeout {
235 match tokio::time::timeout(timeout_duration, task()).await {
236 Ok(Ok(value)) => TaskResult::Success {
237 task_id,
238 value,
239 duration: task_start.elapsed(),
240 },
241 Ok(Err(msg)) => TaskResult::Error(TaskError {
242 task_id,
243 message: msg,
244 is_timeout: false,
245 }),
246 Err(_) => TaskResult::Error(TaskError {
247 task_id,
248 message: format!("Timeout after {:?}", timeout_duration),
249 is_timeout: true,
250 }),
251 }
252 } else {
253 match task().await {
254 Ok(value) => TaskResult::Success {
255 task_id,
256 value,
257 duration: task_start.elapsed(),
258 },
259 Err(msg) => TaskResult::Error(TaskError {
260 task_id,
261 message: msg,
262 is_timeout: false,
263 }),
264 }
265 };
266
267 stats.decrement_concurrent();
268
269 match &result {
270 TaskResult::Success { duration, .. } => {
271 stats.record_success(*duration);
272 }
273 TaskResult::Error(e) if e.is_timeout => {
274 stats.record_timeout();
275 }
276 TaskResult::Error(_) => {
277 stats.record_failure();
278 }
279 }
280
281 result
282 };
283
284 futures.push(future);
285 }
286
287 let mut results = Vec::with_capacity(total_tasks);
289
290 while let Some(result) = futures.next().await {
291 if !self.config.continue_on_error {
292 if let TaskResult::Error(ref _e) = result {
293 drop(futures);
295 results.push(result);
296
297 let stats = self.stats.finalize(start.elapsed());
298 return (results, stats);
299 }
300 }
301 results.push(result);
302 }
303
304 results.sort_by_key(|r| match r {
306 TaskResult::Success { task_id, .. } => *task_id,
307 TaskResult::Error(e) => e.task_id,
308 });
309
310 let stats = self.stats.finalize(start.elapsed());
311 (results, stats)
312 }
313
314 pub async fn execute_collect<T, F, Fut>(
318 &self,
319 tasks: impl IntoIterator<Item = F>,
320 ) -> (Vec<T>, Vec<TaskError>)
321 where
322 F: FnOnce() -> Fut + Send + 'static,
323 Fut: Future<Output = Result<T, String>> + Send + 'static,
324 T: Send + 'static,
325 {
326 let (results, _) = self.execute_all(tasks).await;
327
328 let mut values = Vec::new();
329 let mut errors = Vec::new();
330
331 for result in results {
332 match result {
333 TaskResult::Success { value, .. } => values.push(value),
334 TaskResult::Error(e) => errors.push(e),
335 }
336 }
337
338 (values, errors)
339 }
340
341 pub async fn execute_indexed<T, F, Fut>(
346 &self,
347 tasks: impl IntoIterator<Item = F>,
348 ) -> std::collections::HashMap<usize, Result<T, TaskError>>
349 where
350 F: FnOnce() -> Fut + Send + 'static,
351 Fut: Future<Output = Result<T, String>> + Send + 'static,
352 T: Send + 'static,
353 {
354 let (results, _) = self.execute_all(tasks).await;
355
356 results
357 .into_iter()
358 .map(|r| match r {
359 TaskResult::Success {
360 task_id, value, ..
361 } => (task_id, Ok(value)),
362 TaskResult::Error(e) => (e.task_id, Err(e)),
363 })
364 .collect()
365 }
366}
367
368#[derive(Clone)]
370struct ExecutionStatsCollector {
371 total: Arc<AtomicU64>,
372 successful: Arc<AtomicU64>,
373 failed: Arc<AtomicU64>,
374 timed_out: Arc<AtomicU64>,
375 total_task_duration_ns: Arc<AtomicU64>,
376 current_concurrent: Arc<AtomicU64>,
377 max_concurrent: Arc<AtomicU64>,
378}
379
380impl ExecutionStatsCollector {
381 fn new() -> Self {
382 Self {
383 total: Arc::new(AtomicU64::new(0)),
384 successful: Arc::new(AtomicU64::new(0)),
385 failed: Arc::new(AtomicU64::new(0)),
386 timed_out: Arc::new(AtomicU64::new(0)),
387 total_task_duration_ns: Arc::new(AtomicU64::new(0)),
388 current_concurrent: Arc::new(AtomicU64::new(0)),
389 max_concurrent: Arc::new(AtomicU64::new(0)),
390 }
391 }
392
393 fn reset(&self) {
394 self.total.store(0, Ordering::SeqCst);
395 self.successful.store(0, Ordering::SeqCst);
396 self.failed.store(0, Ordering::SeqCst);
397 self.timed_out.store(0, Ordering::SeqCst);
398 self.total_task_duration_ns.store(0, Ordering::SeqCst);
399 self.current_concurrent.store(0, Ordering::SeqCst);
400 self.max_concurrent.store(0, Ordering::SeqCst);
401 }
402
403 fn increment_concurrent(&self) {
404 let current = self.current_concurrent.fetch_add(1, Ordering::SeqCst) + 1;
405 self.max_concurrent.fetch_max(current, Ordering::SeqCst);
406 }
407
408 fn decrement_concurrent(&self) {
409 self.current_concurrent.fetch_sub(1, Ordering::SeqCst);
410 }
411
412 fn record_success(&self, duration: Duration) {
413 self.successful.fetch_add(1, Ordering::SeqCst);
414 self.total_task_duration_ns
415 .fetch_add(duration.as_nanos() as u64, Ordering::SeqCst);
416 }
417
418 fn record_failure(&self) {
419 self.failed.fetch_add(1, Ordering::SeqCst);
420 }
421
422 fn record_timeout(&self) {
423 self.timed_out.fetch_add(1, Ordering::SeqCst);
424 self.failed.fetch_add(1, Ordering::SeqCst);
425 }
426
427 fn finalize(&self, total_duration: Duration) -> ExecutionStats {
428 let successful = self.successful.load(Ordering::SeqCst);
429 let total_task_duration_ns = self.total_task_duration_ns.load(Ordering::SeqCst);
430
431 let avg_task_duration = if successful > 0 {
432 Duration::from_nanos(total_task_duration_ns / successful)
433 } else {
434 Duration::ZERO
435 };
436
437 ExecutionStats {
438 total_tasks: self.total.load(Ordering::SeqCst),
439 successful,
440 failed: self.failed.load(Ordering::SeqCst),
441 timed_out: self.timed_out.load(Ordering::SeqCst),
442 total_duration,
443 avg_task_duration,
444 max_concurrent: self.max_concurrent.load(Ordering::SeqCst) as usize,
445 }
446 }
447}
448
449pub async fn execute_batch<T, I, F, Fut>(
454 items: I,
455 max_concurrency: usize,
456 operation: F,
457) -> Vec<Result<T, String>>
458where
459 I: IntoIterator,
460 F: Fn(I::Item) -> Fut + Clone + Send + 'static,
461 Fut: Future<Output = Result<T, String>> + Send + 'static,
462 T: Send + 'static,
463 I::Item: Send + 'static,
464{
465 let config = ConcurrencyConfig::default().with_max_concurrency(max_concurrency);
466 let executor = ConcurrentExecutor::new(config);
467
468 let tasks: Vec<_> = items
469 .into_iter()
470 .map(|item| {
471 let op = operation.clone();
472 move || op(item)
473 })
474 .collect();
475
476 let (results, _) = executor.execute_all(tasks).await;
477
478 results
479 .into_iter()
480 .map(|r| match r {
481 TaskResult::Success { value, .. } => Ok(value),
482 TaskResult::Error(e) => Err(e.message),
483 })
484 .collect()
485}
486
487pub async fn execute_chunked<T, I, F, Fut>(
492 items: I,
493 chunk_size: usize,
494 max_concurrency: usize,
495 operation: F,
496) -> Vec<Vec<Result<T, String>>>
497where
498 I: IntoIterator,
499 I::IntoIter: ExactSizeIterator,
500 F: Fn(Vec<I::Item>) -> Fut + Clone + Send + 'static,
501 Fut: Future<Output = Vec<Result<T, String>>> + Send + 'static,
502 T: Send + 'static,
503 I::Item: Send + Clone + 'static,
504{
505 let items: Vec<_> = items.into_iter().collect();
506 let chunks: Vec<Vec<_>> = items.chunks(chunk_size).map(|c| c.to_vec()).collect();
507
508 let config = ConcurrencyConfig::default().with_max_concurrency(max_concurrency);
509 let executor = ConcurrentExecutor::new(config);
510
511 let tasks: Vec<_> = chunks
512 .into_iter()
513 .map(|chunk| {
514 let op = operation.clone();
515 move || async move { Ok::<_, String>(op(chunk).await) }
516 })
517 .collect();
518
519 let (results, _) = executor.execute_all(tasks).await;
520
521 results
522 .into_iter()
523 .filter_map(|r| r.into_value())
524 .collect()
525}
526
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::AtomicUsize;

    #[tokio::test]
    async fn test_concurrent_executor_basic() {
        let executor = ConcurrentExecutor::new(ConcurrencyConfig::default());

        let tasks: Vec<_> = (0..10)
            .map(|i| move || async move { Ok::<_, String>(i * 2) })
            .collect();

        let (results, stats) = executor.execute_all(tasks).await;

        assert_eq!(results.len(), 10);
        assert_eq!(stats.total_tasks, 10);
        assert_eq!(stats.successful, 10);
        assert_eq!(stats.failed, 0);

        // Results come back in submission order, so the index is the task id.
        for (i, result) in results.iter().enumerate() {
            match result {
                TaskResult::Success { value, .. } => assert_eq!(*value, i * 2),
                _ => panic!("Expected success"),
            }
        }
    }

    #[tokio::test]
    async fn test_concurrent_executor_with_errors() {
        let config = ConcurrencyConfig::default().with_continue_on_error(true);
        let executor = ConcurrentExecutor::new(config);

        let tasks: Vec<_> = (0..5)
            .map(|i| {
                move || async move {
                    if i == 2 {
                        Err("Task 2 failed".to_string())
                    } else {
                        Ok::<_, String>(i)
                    }
                }
            })
            .collect();

        let (results, stats) = executor.execute_all(tasks).await;

        assert_eq!(results.len(), 5);
        assert!(!results[2].is_success());
        assert_eq!(stats.successful, 4);
        assert_eq!(stats.failed, 1);
        assert_eq!(stats.timed_out, 0);
    }

    #[tokio::test]
    async fn test_concurrent_executor_fail_fast() {
        let config = ConcurrencyConfig::default()
            .with_continue_on_error(false)
            .with_max_concurrency(1);
        let executor = ConcurrentExecutor::new(config);
        let counter = Arc::new(AtomicUsize::new(0));

        let tasks: Vec<_> = (0..5)
            .map(|i| {
                let counter = Arc::clone(&counter);
                move || async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                    if i == 2 {
                        Err("Task 2 failed".to_string())
                    } else {
                        Ok::<_, String>(i)
                    }
                }
            })
            .collect();

        let (results, _) = executor.execute_all(tasks).await;

        let has_error = results.iter().any(|r| matches!(r, TaskResult::Error(_)));
        assert!(has_error);
    }

    #[tokio::test]
    async fn test_concurrent_executor_respects_concurrency() {
        let max_concurrent = Arc::new(AtomicUsize::new(0));
        let current = Arc::new(AtomicUsize::new(0));

        let config = ConcurrencyConfig::default().with_max_concurrency(3);
        let executor = ConcurrentExecutor::new(config);

        let tasks: Vec<_> = (0..20)
            .map(|i| {
                let max_concurrent = Arc::clone(&max_concurrent);
                let current = Arc::clone(&current);
                move || async move {
                    let c = current.fetch_add(1, Ordering::SeqCst) + 1;
                    max_concurrent.fetch_max(c, Ordering::SeqCst);

                    tokio::time::sleep(Duration::from_millis(10)).await;

                    current.fetch_sub(1, Ordering::SeqCst);
                    Ok::<_, String>(i)
                }
            })
            .collect();

        let (results, stats) = executor.execute_all(tasks).await;

        assert_eq!(results.len(), 20);
        assert!(stats.max_concurrent <= 3);
        assert!(max_concurrent.load(Ordering::SeqCst) <= 3);
    }

    #[tokio::test]
    async fn test_execute_batch() {
        let results = execute_batch(
            vec!["a", "b", "c"],
            4,
            |s: &str| async move { Ok::<_, String>(s.len()) },
        )
        .await;

        assert_eq!(results.len(), 3);
        assert_eq!(results, vec![Ok::<usize, String>(1); 3]);
    }

    #[tokio::test]
    async fn test_execute_collect_and_indexed() {
        let executor = ConcurrentExecutor::new(ConcurrencyConfig::default());
        let make_tasks = || {
            (0..4).map(|i: i32| {
                move || async move {
                    if i % 2 == 0 {
                        Ok::<_, String>(i)
                    } else {
                        Err(format!("odd {}", i))
                    }
                }
            })
        };

        let (values, errors) = executor.execute_collect(make_tasks()).await;
        assert_eq!(values, vec![0, 2]);
        let failed_ids: Vec<usize> = errors.iter().map(|e| e.task_id).collect();
        assert_eq!(failed_ids, vec![1, 3]);

        let indexed = executor.execute_indexed(make_tasks()).await;
        assert_eq!(indexed.len(), 4);
        assert!(matches!(indexed.get(&2), Some(Ok(2))));
        assert!(matches!(indexed.get(&3), Some(Err(e)) if e.message == "odd 3"));
    }

    #[tokio::test]
    async fn test_execute_chunked() {
        let chunks = execute_chunked(vec![1, 2, 3, 4, 5], 2, 2, |chunk: Vec<i32>| async move {
            chunk
                .into_iter()
                .map(|x| Ok::<_, String>(x * 10))
                .collect::<Vec<_>>()
        })
        .await;

        assert_eq!(chunks.len(), 3);
        assert_eq!(chunks[0], vec![Ok::<i32, String>(10), Ok(20)]);
        assert_eq!(chunks[2], vec![Ok::<i32, String>(50)]);
    }

    #[tokio::test]
    async fn test_timeout() {
        let config = ConcurrencyConfig::default().with_timeout(Duration::from_millis(50));
        let executor = ConcurrentExecutor::new(config);

        type BoxedTask = Box<
            dyn FnOnce() -> std::pin::Pin<Box<dyn Future<Output = Result<i32, String>> + Send>>
                + Send,
        >;
        let tasks: Vec<BoxedTask> = vec![
            Box::new(|| {
                Box::pin(async {
                    tokio::time::sleep(Duration::from_millis(10)).await;
                    Ok::<_, String>(1)
                })
            }),
            Box::new(|| {
                Box::pin(async {
                    tokio::time::sleep(Duration::from_millis(200)).await;
                    Ok::<_, String>(2)
                })
            }),
        ];

        let (results, stats) = executor.execute_all(tasks).await;

        assert_eq!(results.len(), 2);
        assert!(results[0].is_success());
        assert!(!results[1].is_success());
        assert_eq!(stats.timed_out, 1);
    }
}
679