1use std::future::Future;
8use std::sync::Arc;
9use std::sync::atomic::{AtomicU64, Ordering};
10use std::time::{Duration, Instant};
11
12use futures::stream::{FuturesUnordered, StreamExt};
13use tokio::sync::Semaphore;
14
/// Tuning knobs for [`ConcurrentExecutor`].
#[derive(Debug, Clone)]
pub struct ConcurrencyConfig {
    /// Maximum number of tasks allowed to run at the same time.
    pub max_concurrency: usize,
    /// Per-task timeout; `None` disables timeouts entirely.
    pub operation_timeout: Option<Duration>,
    /// When `false`, the executor stops consuming results after the first
    /// error and cancels the remaining tasks (fail-fast).
    pub continue_on_error: bool,
    /// Whether execution statistics should be gathered.
    // NOTE(review): this flag is never consulted in this module — stats are
    // always collected regardless of its value. Confirm intended behavior.
    pub collect_stats: bool,
}
27
28impl Default for ConcurrencyConfig {
29 fn default() -> Self {
30 Self {
31 max_concurrency: num_cpus::get().max(4),
32 operation_timeout: Some(Duration::from_secs(30)),
33 continue_on_error: true,
34 collect_stats: true,
35 }
36 }
37}
38
39impl ConcurrencyConfig {
40 #[must_use]
42 pub fn for_introspection() -> Self {
43 Self {
44 max_concurrency: 8, operation_timeout: Some(Duration::from_secs(60)),
46 continue_on_error: true,
47 collect_stats: true,
48 }
49 }
50
51 #[must_use]
53 pub fn for_migrations() -> Self {
54 Self {
55 max_concurrency: 4, operation_timeout: Some(Duration::from_secs(120)),
57 continue_on_error: false, collect_stats: true,
59 }
60 }
61
62 #[must_use]
64 pub fn for_bulk_operations() -> Self {
65 Self {
66 max_concurrency: 16, operation_timeout: Some(Duration::from_secs(300)),
68 continue_on_error: true,
69 collect_stats: true,
70 }
71 }
72
73 #[must_use]
75 pub fn with_max_concurrency(mut self, max: usize) -> Self {
76 self.max_concurrency = max.max(1);
77 self
78 }
79
80 #[must_use]
82 pub fn with_timeout(mut self, timeout: Duration) -> Self {
83 self.operation_timeout = Some(timeout);
84 self
85 }
86
87 #[must_use]
89 pub fn without_timeout(mut self) -> Self {
90 self.operation_timeout = None;
91 self
92 }
93
94 #[must_use]
96 pub fn with_continue_on_error(mut self, continue_on_error: bool) -> Self {
97 self.continue_on_error = continue_on_error;
98 self
99 }
100}
101
/// Describes why a single task failed.
#[derive(Debug, Clone)]
pub struct TaskError {
    /// Zero-based position of the task in the submitted batch.
    pub task_id: usize,
    /// Human-readable failure description.
    pub message: String,
    /// `true` when the failure was caused by the per-task timeout.
    pub is_timeout: bool,
}
112
113impl std::fmt::Display for TaskError {
114 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
115 if self.is_timeout {
116 write!(f, "Task {} timed out: {}", self.task_id, self.message)
117 } else {
118 write!(f, "Task {} failed: {}", self.task_id, self.message)
119 }
120 }
121}
122
// Marker impl: allows `TaskError` to be used behind `Box<dyn std::error::Error>`.
impl std::error::Error for TaskError {}
124
/// Outcome of a single task executed by [`ConcurrentExecutor`].
#[derive(Debug)]
pub enum TaskResult<T> {
    /// The task completed and returned `Ok`.
    Success {
        /// Zero-based position of the task in the submitted batch.
        task_id: usize,
        /// The value produced by the task.
        value: T,
        /// How long the task body ran (measured after the semaphore permit
        /// was acquired, so queueing time is excluded).
        duration: Duration,
    },
    /// The task returned `Err` or exceeded the configured timeout.
    Error(TaskError),
}
140
141impl<T> TaskResult<T> {
142 pub fn is_success(&self) -> bool {
144 matches!(self, Self::Success { .. })
145 }
146
147 pub fn into_value(self) -> Option<T> {
149 match self {
150 Self::Success { value, .. } => Some(value),
151 Self::Error(_) => None,
152 }
153 }
154
155 pub fn into_error(self) -> Option<TaskError> {
157 match self {
158 Self::Success { .. } => None,
159 Self::Error(e) => Some(e),
160 }
161 }
162}
163
/// Aggregate statistics for one executor run.
#[derive(Debug, Clone, Default)]
pub struct ExecutionStats {
    /// Number of tasks submitted.
    pub total_tasks: u64,
    /// Tasks that completed with `Ok`.
    pub successful: u64,
    /// Tasks that failed for any reason, including timeouts.
    pub failed: u64,
    /// Tasks that hit the per-task timeout (also counted in `failed`).
    pub timed_out: u64,
    /// Wall-clock time for the whole batch.
    pub total_duration: Duration,
    /// Mean duration of the successful tasks (zero when none succeeded).
    pub avg_task_duration: Duration,
    /// Highest number of tasks observed running simultaneously.
    pub max_concurrent: usize,
}
182
/// Runs async tasks with bounded concurrency and collects per-run stats.
pub struct ConcurrentExecutor {
    // Behavioral configuration (concurrency cap, timeout, error policy).
    config: ConcurrencyConfig,
    // Permit pool sized to `config.max_concurrency`.
    semaphore: Arc<Semaphore>,
    // Shared atomic counters; cloned into every task future.
    stats: ExecutionStatsCollector,
}
189
190impl ConcurrentExecutor {
191 pub fn new(config: ConcurrencyConfig) -> Self {
193 let semaphore = Arc::new(Semaphore::new(config.max_concurrency));
194 Self {
195 config,
196 semaphore,
197 stats: ExecutionStatsCollector::new(),
198 }
199 }
200
201 pub async fn execute_all<T, F, Fut>(
206 &self,
207 tasks: impl IntoIterator<Item = F>,
208 ) -> (Vec<TaskResult<T>>, ExecutionStats)
209 where
210 F: FnOnce() -> Fut + Send + 'static,
211 Fut: Future<Output = Result<T, String>> + Send + 'static,
212 T: Send + 'static,
213 {
214 let start = Instant::now();
215 self.stats.reset();
216
217 let tasks: Vec<_> = tasks.into_iter().collect();
218 let total_tasks = tasks.len();
219 self.stats.total.store(total_tasks as u64, Ordering::SeqCst);
220
221 let mut futures = FuturesUnordered::new();
222
223 for (task_id, task) in tasks.into_iter().enumerate() {
224 let semaphore = Arc::clone(&self.semaphore);
225 let timeout = self.config.operation_timeout;
226 let stats = self.stats.clone();
227
228 let future = async move {
229 let _permit = semaphore.acquire().await.expect("Semaphore closed");
231 stats.increment_concurrent();
232
233 let task_start = Instant::now();
234 let result = if let Some(timeout_duration) = timeout {
235 match tokio::time::timeout(timeout_duration, task()).await {
236 Ok(Ok(value)) => TaskResult::Success {
237 task_id,
238 value,
239 duration: task_start.elapsed(),
240 },
241 Ok(Err(msg)) => TaskResult::Error(TaskError {
242 task_id,
243 message: msg,
244 is_timeout: false,
245 }),
246 Err(_) => TaskResult::Error(TaskError {
247 task_id,
248 message: format!("Timeout after {:?}", timeout_duration),
249 is_timeout: true,
250 }),
251 }
252 } else {
253 match task().await {
254 Ok(value) => TaskResult::Success {
255 task_id,
256 value,
257 duration: task_start.elapsed(),
258 },
259 Err(msg) => TaskResult::Error(TaskError {
260 task_id,
261 message: msg,
262 is_timeout: false,
263 }),
264 }
265 };
266
267 stats.decrement_concurrent();
268
269 match &result {
270 TaskResult::Success { duration, .. } => {
271 stats.record_success(*duration);
272 }
273 TaskResult::Error(e) if e.is_timeout => {
274 stats.record_timeout();
275 }
276 TaskResult::Error(_) => {
277 stats.record_failure();
278 }
279 }
280
281 result
282 };
283
284 futures.push(future);
285 }
286
287 let mut results = Vec::with_capacity(total_tasks);
289
290 while let Some(result) = futures.next().await {
291 if !self.config.continue_on_error
292 && let TaskResult::Error(ref _e) = result
293 {
294 drop(futures);
296 results.push(result);
297
298 let stats = self.stats.finalize(start.elapsed());
299 return (results, stats);
300 }
301 results.push(result);
302 }
303
304 results.sort_by_key(|r| match r {
306 TaskResult::Success { task_id, .. } => *task_id,
307 TaskResult::Error(e) => e.task_id,
308 });
309
310 let stats = self.stats.finalize(start.elapsed());
311 (results, stats)
312 }
313
314 pub async fn execute_collect<T, F, Fut>(
318 &self,
319 tasks: impl IntoIterator<Item = F>,
320 ) -> (Vec<T>, Vec<TaskError>)
321 where
322 F: FnOnce() -> Fut + Send + 'static,
323 Fut: Future<Output = Result<T, String>> + Send + 'static,
324 T: Send + 'static,
325 {
326 let (results, _) = self.execute_all(tasks).await;
327
328 let mut values = Vec::new();
329 let mut errors = Vec::new();
330
331 for result in results {
332 match result {
333 TaskResult::Success { value, .. } => values.push(value),
334 TaskResult::Error(e) => errors.push(e),
335 }
336 }
337
338 (values, errors)
339 }
340
341 pub async fn execute_indexed<T, F, Fut>(
346 &self,
347 tasks: impl IntoIterator<Item = F>,
348 ) -> std::collections::HashMap<usize, Result<T, TaskError>>
349 where
350 F: FnOnce() -> Fut + Send + 'static,
351 Fut: Future<Output = Result<T, String>> + Send + 'static,
352 T: Send + 'static,
353 {
354 let (results, _) = self.execute_all(tasks).await;
355
356 results
357 .into_iter()
358 .map(|r| match r {
359 TaskResult::Success { task_id, value, .. } => (task_id, Ok(value)),
360 TaskResult::Error(e) => (e.task_id, Err(e)),
361 })
362 .collect()
363 }
364}
365
/// Thread-safe counters backing [`ExecutionStats`].
///
/// Cloning is cheap: every field is `Arc`-shared, so clones observe and
/// update the same underlying counters.
#[derive(Clone)]
struct ExecutionStatsCollector {
    total: Arc<AtomicU64>,
    successful: Arc<AtomicU64>,
    failed: Arc<AtomicU64>,
    timed_out: Arc<AtomicU64>,
    // Sum of successful task durations, in nanoseconds.
    total_task_duration_ns: Arc<AtomicU64>,
    current_concurrent: Arc<AtomicU64>,
    // High-water mark of `current_concurrent`.
    max_concurrent: Arc<AtomicU64>,
}
377
378impl ExecutionStatsCollector {
379 fn new() -> Self {
380 Self {
381 total: Arc::new(AtomicU64::new(0)),
382 successful: Arc::new(AtomicU64::new(0)),
383 failed: Arc::new(AtomicU64::new(0)),
384 timed_out: Arc::new(AtomicU64::new(0)),
385 total_task_duration_ns: Arc::new(AtomicU64::new(0)),
386 current_concurrent: Arc::new(AtomicU64::new(0)),
387 max_concurrent: Arc::new(AtomicU64::new(0)),
388 }
389 }
390
391 fn reset(&self) {
392 self.total.store(0, Ordering::SeqCst);
393 self.successful.store(0, Ordering::SeqCst);
394 self.failed.store(0, Ordering::SeqCst);
395 self.timed_out.store(0, Ordering::SeqCst);
396 self.total_task_duration_ns.store(0, Ordering::SeqCst);
397 self.current_concurrent.store(0, Ordering::SeqCst);
398 self.max_concurrent.store(0, Ordering::SeqCst);
399 }
400
401 fn increment_concurrent(&self) {
402 let current = self.current_concurrent.fetch_add(1, Ordering::SeqCst) + 1;
403 self.max_concurrent.fetch_max(current, Ordering::SeqCst);
404 }
405
406 fn decrement_concurrent(&self) {
407 self.current_concurrent.fetch_sub(1, Ordering::SeqCst);
408 }
409
410 fn record_success(&self, duration: Duration) {
411 self.successful.fetch_add(1, Ordering::SeqCst);
412 self.total_task_duration_ns
413 .fetch_add(duration.as_nanos() as u64, Ordering::SeqCst);
414 }
415
416 fn record_failure(&self) {
417 self.failed.fetch_add(1, Ordering::SeqCst);
418 }
419
420 fn record_timeout(&self) {
421 self.timed_out.fetch_add(1, Ordering::SeqCst);
422 self.failed.fetch_add(1, Ordering::SeqCst);
423 }
424
425 fn finalize(&self, total_duration: Duration) -> ExecutionStats {
426 let successful = self.successful.load(Ordering::SeqCst);
427 let total_task_duration_ns = self.total_task_duration_ns.load(Ordering::SeqCst);
428
429 let avg_task_duration = total_task_duration_ns
430 .checked_div(successful)
431 .map(Duration::from_nanos)
432 .unwrap_or(Duration::ZERO);
433
434 ExecutionStats {
435 total_tasks: self.total.load(Ordering::SeqCst),
436 successful,
437 failed: self.failed.load(Ordering::SeqCst),
438 timed_out: self.timed_out.load(Ordering::SeqCst),
439 total_duration,
440 avg_task_duration,
441 max_concurrent: self.max_concurrent.load(Ordering::SeqCst) as usize,
442 }
443 }
444}
445
446pub async fn execute_batch<T, I, F, Fut>(
451 items: I,
452 max_concurrency: usize,
453 operation: F,
454) -> Vec<Result<T, String>>
455where
456 I: IntoIterator,
457 F: Fn(I::Item) -> Fut + Clone + Send + 'static,
458 Fut: Future<Output = Result<T, String>> + Send + 'static,
459 T: Send + 'static,
460 I::Item: Send + 'static,
461{
462 let config = ConcurrencyConfig::default().with_max_concurrency(max_concurrency);
463 let executor = ConcurrentExecutor::new(config);
464
465 let tasks: Vec<_> = items
466 .into_iter()
467 .map(|item| {
468 let op = operation.clone();
469 move || op(item)
470 })
471 .collect();
472
473 let (results, _) = executor.execute_all(tasks).await;
474
475 results
476 .into_iter()
477 .map(|r| match r {
478 TaskResult::Success { value, .. } => Ok(value),
479 TaskResult::Error(e) => Err(e.message),
480 })
481 .collect()
482}
483
484pub async fn execute_chunked<T, I, F, Fut>(
489 items: I,
490 chunk_size: usize,
491 max_concurrency: usize,
492 operation: F,
493) -> Vec<Vec<Result<T, String>>>
494where
495 I: IntoIterator,
496 I::IntoIter: ExactSizeIterator,
497 F: Fn(Vec<I::Item>) -> Fut + Clone + Send + 'static,
498 Fut: Future<Output = Vec<Result<T, String>>> + Send + 'static,
499 T: Send + 'static,
500 I::Item: Send + Clone + 'static,
501{
502 let items: Vec<_> = items.into_iter().collect();
503 let chunks: Vec<Vec<_>> = items.chunks(chunk_size).map(|c| c.to_vec()).collect();
504
505 let config = ConcurrencyConfig::default().with_max_concurrency(max_concurrency);
506 let executor = ConcurrentExecutor::new(config);
507
508 let tasks: Vec<_> = chunks
509 .into_iter()
510 .map(|chunk| {
511 let op = operation.clone();
512 move || async move { Ok::<_, String>(op(chunk).await) }
513 })
514 .collect();
515
516 let (results, _) = executor.execute_all(tasks).await;
517
518 results.into_iter().filter_map(|r| r.into_value()).collect()
519}
520
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::AtomicUsize;

    // Happy path: all tasks succeed and results come back in task order.
    #[tokio::test]
    async fn test_concurrent_executor_basic() {
        let executor = ConcurrentExecutor::new(ConcurrencyConfig::default());

        let tasks: Vec<_> = (0..10)
            .map(|i| move || async move { Ok::<_, String>(i * 2) })
            .collect();

        let (results, stats) = executor.execute_all(tasks).await;

        assert_eq!(results.len(), 10);
        assert_eq!(stats.total_tasks, 10);
        assert_eq!(stats.successful, 10);
        assert_eq!(stats.failed, 0);

        // Results are sorted by task id, so index == task id here.
        for (i, result) in results.iter().enumerate() {
            match result {
                TaskResult::Success { value, .. } => {
                    assert_eq!(*value, i * 2);
                }
                _ => panic!("Expected success"),
            }
        }
    }

    // With continue_on_error, one failing task must not stop the others.
    #[tokio::test]
    async fn test_concurrent_executor_with_errors() {
        let config = ConcurrencyConfig::default().with_continue_on_error(true);
        let executor = ConcurrentExecutor::new(config);

        let tasks: Vec<_> = (0..5)
            .map(|i| {
                move || async move {
                    if i == 2 {
                        Err("Task 2 failed".to_string())
                    } else {
                        Ok::<_, String>(i)
                    }
                }
            })
            .collect();

        let (results, stats) = executor.execute_all(tasks).await;

        assert_eq!(results.len(), 5);
        assert_eq!(stats.successful, 4);
        assert_eq!(stats.failed, 1);
    }

    // Fail-fast mode: the returned results must include the error.
    // Concurrency of 1 makes completion order deterministic.
    #[tokio::test]
    async fn test_concurrent_executor_fail_fast() {
        let config = ConcurrencyConfig::default()
            .with_continue_on_error(false)
            .with_max_concurrency(1);
        let executor = ConcurrentExecutor::new(config);
        let counter = Arc::new(AtomicUsize::new(0));

        let tasks: Vec<_> = (0..5)
            .map(|i| {
                let counter = Arc::clone(&counter);
                move || async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                    if i == 2 {
                        Err("Task 2 failed".to_string())
                    } else {
                        Ok::<_, String>(i)
                    }
                }
            })
            .collect();

        let (results, _) = executor.execute_all(tasks).await;

        let has_error = results.iter().any(|r| matches!(r, TaskResult::Error(_)));
        assert!(has_error);
    }

    // The semaphore must cap the number of simultaneously running bodies.
    #[tokio::test]
    async fn test_concurrent_executor_respects_concurrency() {
        let max_concurrent = Arc::new(AtomicUsize::new(0));
        let current = Arc::new(AtomicUsize::new(0));

        let config = ConcurrencyConfig::default().with_max_concurrency(3);
        let executor = ConcurrentExecutor::new(config);

        let tasks: Vec<_> = (0..20)
            .map(|i| {
                let max_concurrent = Arc::clone(&max_concurrent);
                // Fix: this was mojibake (`¤t`) — an HTML-entity mangling
                // of `&current` — and did not compile.
                let current = Arc::clone(&current);
                move || async move {
                    let c = current.fetch_add(1, Ordering::SeqCst) + 1;
                    max_concurrent.fetch_max(c, Ordering::SeqCst);

                    tokio::time::sleep(Duration::from_millis(10)).await;

                    current.fetch_sub(1, Ordering::SeqCst);
                    Ok::<_, String>(i)
                }
            })
            .collect();

        let (results, stats) = executor.execute_all(tasks).await;

        assert_eq!(results.len(), 20);
        assert!(stats.max_concurrent <= 3);
        assert!(max_concurrent.load(Ordering::SeqCst) <= 3);
    }

    // The free-function wrapper should preserve count and success.
    #[tokio::test]
    async fn test_execute_batch() {
        let results = execute_batch(vec!["a", "b", "c"], 4, |s: &str| async move {
            Ok::<_, String>(s.len())
        })
        .await;

        assert_eq!(results.len(), 3);
        assert!(results.iter().all(|r| r.is_ok()));
    }

    // A task exceeding the configured timeout is reported as timed out.
    #[tokio::test]
    async fn test_timeout() {
        let config = ConcurrencyConfig::default().with_timeout(Duration::from_millis(50));
        let executor = ConcurrentExecutor::new(config);

        #[allow(clippy::type_complexity)]
        let tasks: Vec<
            Box<
                dyn FnOnce() -> std::pin::Pin<Box<dyn Future<Output = Result<i32, String>> + Send>>
                    + Send,
            >,
        > = vec![
            Box::new(|| {
                Box::pin(async {
                    tokio::time::sleep(Duration::from_millis(10)).await;
                    Ok::<_, String>(1)
                })
            }),
            Box::new(|| {
                Box::pin(async {
                    tokio::time::sleep(Duration::from_millis(200)).await;
                    Ok::<_, String>(2)
                })
            }),
        ];

        let (results, stats) = executor.execute_all(tasks).await;

        assert_eq!(results.len(), 2);
        assert_eq!(stats.timed_out, 1);
    }
}