Skip to main content

prax_query/async_optimize/
concurrent.rs

//! Concurrent task execution with controlled parallelism.
//!
//! This module provides utilities for executing multiple independent database
//! operations concurrently while respecting a concurrency limit, so that the
//! database connection pool is not overwhelmed.
6
7use std::future::Future;
8use std::sync::Arc;
9use std::sync::atomic::{AtomicU64, Ordering};
10use std::time::{Duration, Instant};
11
12use futures::stream::{FuturesUnordered, StreamExt};
13use tokio::sync::Semaphore;
14
15/// Configuration for concurrent execution.
16#[derive(Debug, Clone)]
17pub struct ConcurrencyConfig {
18    /// Maximum number of concurrent operations.
19    pub max_concurrency: usize,
20    /// Timeout for individual operations.
21    pub operation_timeout: Option<Duration>,
22    /// Whether to continue on error (collect all errors vs fail fast).
23    pub continue_on_error: bool,
24    /// Collect timing statistics.
25    pub collect_stats: bool,
26}
27
28impl Default for ConcurrencyConfig {
29    fn default() -> Self {
30        Self {
31            max_concurrency: num_cpus::get().max(4),
32            operation_timeout: Some(Duration::from_secs(30)),
33            continue_on_error: true,
34            collect_stats: true,
35        }
36    }
37}
38
39impl ConcurrencyConfig {
40    /// Create config optimized for database introspection.
41    #[must_use]
42    pub fn for_introspection() -> Self {
43        Self {
44            max_concurrency: 8, // Balance between speed and connection usage
45            operation_timeout: Some(Duration::from_secs(60)),
46            continue_on_error: true,
47            collect_stats: true,
48        }
49    }
50
51    /// Create config optimized for migration operations.
52    #[must_use]
53    pub fn for_migrations() -> Self {
54        Self {
55            max_concurrency: 4, // More conservative for DDL
56            operation_timeout: Some(Duration::from_secs(120)),
57            continue_on_error: false, // Migrations should fail fast
58            collect_stats: true,
59        }
60    }
61
62    /// Create config optimized for bulk data operations.
63    #[must_use]
64    pub fn for_bulk_operations() -> Self {
65        Self {
66            max_concurrency: 16, // Higher parallelism for DML
67            operation_timeout: Some(Duration::from_secs(300)),
68            continue_on_error: true,
69            collect_stats: true,
70        }
71    }
72
73    /// Set maximum concurrency.
74    #[must_use]
75    pub fn with_max_concurrency(mut self, max: usize) -> Self {
76        self.max_concurrency = max.max(1);
77        self
78    }
79
80    /// Set operation timeout.
81    #[must_use]
82    pub fn with_timeout(mut self, timeout: Duration) -> Self {
83        self.operation_timeout = Some(timeout);
84        self
85    }
86
87    /// Disable timeout.
88    #[must_use]
89    pub fn without_timeout(mut self) -> Self {
90        self.operation_timeout = None;
91        self
92    }
93
94    /// Set continue on error behavior.
95    #[must_use]
96    pub fn with_continue_on_error(mut self, continue_on_error: bool) -> Self {
97        self.continue_on_error = continue_on_error;
98        self
99    }
100}
101
/// Failure report for a single task run by [`ConcurrentExecutor`].
#[derive(Debug, Clone)]
pub struct TaskError {
    /// Zero-based position of the task in the submitted sequence.
    pub task_id: usize,
    /// Human-readable description of what went wrong.
    pub message: String,
    /// `true` when the task was cut off by the configured timeout.
    pub is_timeout: bool,
}

impl std::fmt::Display for TaskError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Ordinary failures are the common case; handle them first.
        if !self.is_timeout {
            return write!(f, "Task {} failed: {}", self.task_id, self.message);
        }
        write!(f, "Task {} timed out: {}", self.task_id, self.message)
    }
}

/// Lets `TaskError` travel through `Box<dyn Error>` / `?` chains.
impl std::error::Error for TaskError {}
124
125/// Result of a single task.
126#[derive(Debug)]
127pub enum TaskResult<T> {
128    /// Task completed successfully.
129    Success {
130        /// Task identifier.
131        task_id: usize,
132        /// The result value.
133        value: T,
134        /// Execution duration.
135        duration: Duration,
136    },
137    /// Task failed.
138    Error(TaskError),
139}
140
141impl<T> TaskResult<T> {
142    /// Check if the task succeeded.
143    pub fn is_success(&self) -> bool {
144        matches!(self, Self::Success { .. })
145    }
146
147    /// Get the value if successful.
148    pub fn into_value(self) -> Option<T> {
149        match self {
150            Self::Success { value, .. } => Some(value),
151            Self::Error(_) => None,
152        }
153    }
154
155    /// Get the error if failed.
156    pub fn into_error(self) -> Option<TaskError> {
157        match self {
158            Self::Success { .. } => None,
159            Self::Error(e) => Some(e),
160        }
161    }
162}
163
/// Summary counters produced by one `ConcurrentExecutor::execute_all` run.
#[derive(Debug, Clone, Default)]
pub struct ExecutionStats {
    /// Number of tasks submitted.
    pub total_tasks: u64,
    /// Tasks that returned `Ok`.
    pub successful: u64,
    /// Tasks that did not succeed, timeouts included.
    pub failed: u64,
    /// Tasks stopped by the per-operation timeout (also counted in `failed`).
    pub timed_out: u64,
    /// Elapsed time of the whole run.
    pub total_duration: Duration,
    /// Mean duration of the successful tasks; zero when none succeeded.
    pub avg_task_duration: Duration,
    /// Highest number of tasks observed running at the same time.
    pub max_concurrent: usize,
}
182
183/// Executor for running concurrent tasks with controlled parallelism.
184pub struct ConcurrentExecutor {
185    config: ConcurrencyConfig,
186    semaphore: Arc<Semaphore>,
187    stats: ExecutionStatsCollector,
188}
189
190impl ConcurrentExecutor {
191    /// Create a new concurrent executor.
192    pub fn new(config: ConcurrencyConfig) -> Self {
193        let semaphore = Arc::new(Semaphore::new(config.max_concurrency));
194        Self {
195            config,
196            semaphore,
197            stats: ExecutionStatsCollector::new(),
198        }
199    }
200
201    /// Execute all tasks concurrently with controlled parallelism.
202    ///
203    /// Tasks are started immediately but limited by the semaphore to ensure
204    /// at most `max_concurrency` tasks run at once.
205    pub async fn execute_all<T, F, Fut>(
206        &self,
207        tasks: impl IntoIterator<Item = F>,
208    ) -> (Vec<TaskResult<T>>, ExecutionStats)
209    where
210        F: FnOnce() -> Fut + Send + 'static,
211        Fut: Future<Output = Result<T, String>> + Send + 'static,
212        T: Send + 'static,
213    {
214        let start = Instant::now();
215        self.stats.reset();
216
217        let tasks: Vec<_> = tasks.into_iter().collect();
218        let total_tasks = tasks.len();
219        self.stats.total.store(total_tasks as u64, Ordering::SeqCst);
220
221        let mut futures = FuturesUnordered::new();
222
223        for (task_id, task) in tasks.into_iter().enumerate() {
224            let semaphore = Arc::clone(&self.semaphore);
225            let timeout = self.config.operation_timeout;
226            let stats = self.stats.clone();
227
228            let future = async move {
229                // Acquire semaphore permit
230                let _permit = semaphore.acquire().await.expect("Semaphore closed");
231                stats.increment_concurrent();
232
233                let task_start = Instant::now();
234                let result = if let Some(timeout_duration) = timeout {
235                    match tokio::time::timeout(timeout_duration, task()).await {
236                        Ok(Ok(value)) => TaskResult::Success {
237                            task_id,
238                            value,
239                            duration: task_start.elapsed(),
240                        },
241                        Ok(Err(msg)) => TaskResult::Error(TaskError {
242                            task_id,
243                            message: msg,
244                            is_timeout: false,
245                        }),
246                        Err(_) => TaskResult::Error(TaskError {
247                            task_id,
248                            message: format!("Timeout after {:?}", timeout_duration),
249                            is_timeout: true,
250                        }),
251                    }
252                } else {
253                    match task().await {
254                        Ok(value) => TaskResult::Success {
255                            task_id,
256                            value,
257                            duration: task_start.elapsed(),
258                        },
259                        Err(msg) => TaskResult::Error(TaskError {
260                            task_id,
261                            message: msg,
262                            is_timeout: false,
263                        }),
264                    }
265                };
266
267                stats.decrement_concurrent();
268
269                match &result {
270                    TaskResult::Success { duration, .. } => {
271                        stats.record_success(*duration);
272                    }
273                    TaskResult::Error(e) if e.is_timeout => {
274                        stats.record_timeout();
275                    }
276                    TaskResult::Error(_) => {
277                        stats.record_failure();
278                    }
279                }
280
281                result
282            };
283
284            futures.push(future);
285        }
286
287        // Collect results in order of completion
288        let mut results = Vec::with_capacity(total_tasks);
289
290        while let Some(result) = futures.next().await {
291            if !self.config.continue_on_error {
292                if let TaskResult::Error(ref _e) = result {
293                    // Cancel remaining futures by dropping them
294                    drop(futures);
295                    results.push(result);
296
297                    let stats = self.stats.finalize(start.elapsed());
298                    return (results, stats);
299                }
300            }
301            results.push(result);
302        }
303
304        // Sort by task_id to maintain original order
305        results.sort_by_key(|r| match r {
306            TaskResult::Success { task_id, .. } => *task_id,
307            TaskResult::Error(e) => e.task_id,
308        });
309
310        let stats = self.stats.finalize(start.elapsed());
311        (results, stats)
312    }
313
314    /// Execute tasks and collect only successful results.
315    ///
316    /// Returns the values in the same order as the input tasks.
317    pub async fn execute_collect<T, F, Fut>(
318        &self,
319        tasks: impl IntoIterator<Item = F>,
320    ) -> (Vec<T>, Vec<TaskError>)
321    where
322        F: FnOnce() -> Fut + Send + 'static,
323        Fut: Future<Output = Result<T, String>> + Send + 'static,
324        T: Send + 'static,
325    {
326        let (results, _) = self.execute_all(tasks).await;
327
328        let mut values = Vec::new();
329        let mut errors = Vec::new();
330
331        for result in results {
332            match result {
333                TaskResult::Success { value, .. } => values.push(value),
334                TaskResult::Error(e) => errors.push(e),
335            }
336        }
337
338        (values, errors)
339    }
340
341    /// Execute tasks with indexed results.
342    ///
343    /// Returns a map of task_id -> result, useful when you need to correlate
344    /// results with their original indices.
345    pub async fn execute_indexed<T, F, Fut>(
346        &self,
347        tasks: impl IntoIterator<Item = F>,
348    ) -> std::collections::HashMap<usize, Result<T, TaskError>>
349    where
350        F: FnOnce() -> Fut + Send + 'static,
351        Fut: Future<Output = Result<T, String>> + Send + 'static,
352        T: Send + 'static,
353    {
354        let (results, _) = self.execute_all(tasks).await;
355
356        results
357            .into_iter()
358            .map(|r| match r {
359                TaskResult::Success { task_id, value, .. } => (task_id, Ok(value)),
360                TaskResult::Error(e) => (e.task_id, Err(e)),
361            })
362            .collect()
363    }
364}
365
366/// Internal stats collector with atomic counters.
367#[derive(Clone)]
368struct ExecutionStatsCollector {
369    total: Arc<AtomicU64>,
370    successful: Arc<AtomicU64>,
371    failed: Arc<AtomicU64>,
372    timed_out: Arc<AtomicU64>,
373    total_task_duration_ns: Arc<AtomicU64>,
374    current_concurrent: Arc<AtomicU64>,
375    max_concurrent: Arc<AtomicU64>,
376}
377
378impl ExecutionStatsCollector {
379    fn new() -> Self {
380        Self {
381            total: Arc::new(AtomicU64::new(0)),
382            successful: Arc::new(AtomicU64::new(0)),
383            failed: Arc::new(AtomicU64::new(0)),
384            timed_out: Arc::new(AtomicU64::new(0)),
385            total_task_duration_ns: Arc::new(AtomicU64::new(0)),
386            current_concurrent: Arc::new(AtomicU64::new(0)),
387            max_concurrent: Arc::new(AtomicU64::new(0)),
388        }
389    }
390
391    fn reset(&self) {
392        self.total.store(0, Ordering::SeqCst);
393        self.successful.store(0, Ordering::SeqCst);
394        self.failed.store(0, Ordering::SeqCst);
395        self.timed_out.store(0, Ordering::SeqCst);
396        self.total_task_duration_ns.store(0, Ordering::SeqCst);
397        self.current_concurrent.store(0, Ordering::SeqCst);
398        self.max_concurrent.store(0, Ordering::SeqCst);
399    }
400
401    fn increment_concurrent(&self) {
402        let current = self.current_concurrent.fetch_add(1, Ordering::SeqCst) + 1;
403        self.max_concurrent.fetch_max(current, Ordering::SeqCst);
404    }
405
406    fn decrement_concurrent(&self) {
407        self.current_concurrent.fetch_sub(1, Ordering::SeqCst);
408    }
409
410    fn record_success(&self, duration: Duration) {
411        self.successful.fetch_add(1, Ordering::SeqCst);
412        self.total_task_duration_ns
413            .fetch_add(duration.as_nanos() as u64, Ordering::SeqCst);
414    }
415
416    fn record_failure(&self) {
417        self.failed.fetch_add(1, Ordering::SeqCst);
418    }
419
420    fn record_timeout(&self) {
421        self.timed_out.fetch_add(1, Ordering::SeqCst);
422        self.failed.fetch_add(1, Ordering::SeqCst);
423    }
424
425    fn finalize(&self, total_duration: Duration) -> ExecutionStats {
426        let successful = self.successful.load(Ordering::SeqCst);
427        let total_task_duration_ns = self.total_task_duration_ns.load(Ordering::SeqCst);
428
429        let avg_task_duration = if successful > 0 {
430            Duration::from_nanos(total_task_duration_ns / successful)
431        } else {
432            Duration::ZERO
433        };
434
435        ExecutionStats {
436            total_tasks: self.total.load(Ordering::SeqCst),
437            successful,
438            failed: self.failed.load(Ordering::SeqCst),
439            timed_out: self.timed_out.load(Ordering::SeqCst),
440            total_duration,
441            avg_task_duration,
442            max_concurrent: self.max_concurrent.load(Ordering::SeqCst) as usize,
443        }
444    }
445}
446
447/// Helper for executing a batch of similar operations concurrently.
448///
449/// This is a convenience function for common patterns like fetching
450/// metadata for multiple tables.
451pub async fn execute_batch<T, I, F, Fut>(
452    items: I,
453    max_concurrency: usize,
454    operation: F,
455) -> Vec<Result<T, String>>
456where
457    I: IntoIterator,
458    F: Fn(I::Item) -> Fut + Clone + Send + 'static,
459    Fut: Future<Output = Result<T, String>> + Send + 'static,
460    T: Send + 'static,
461    I::Item: Send + 'static,
462{
463    let config = ConcurrencyConfig::default().with_max_concurrency(max_concurrency);
464    let executor = ConcurrentExecutor::new(config);
465
466    let tasks: Vec<_> = items
467        .into_iter()
468        .map(|item| {
469            let op = operation.clone();
470            move || op(item)
471        })
472        .collect();
473
474    let (results, _) = executor.execute_all(tasks).await;
475
476    results
477        .into_iter()
478        .map(|r| match r {
479            TaskResult::Success { value, .. } => Ok(value),
480            TaskResult::Error(e) => Err(e.message),
481        })
482        .collect()
483}
484
485/// Execute operations in parallel chunks.
486///
487/// Useful for operations that benefit from batching (like multi-row inserts)
488/// combined with parallel execution of batches.
489pub async fn execute_chunked<T, I, F, Fut>(
490    items: I,
491    chunk_size: usize,
492    max_concurrency: usize,
493    operation: F,
494) -> Vec<Vec<Result<T, String>>>
495where
496    I: IntoIterator,
497    I::IntoIter: ExactSizeIterator,
498    F: Fn(Vec<I::Item>) -> Fut + Clone + Send + 'static,
499    Fut: Future<Output = Vec<Result<T, String>>> + Send + 'static,
500    T: Send + 'static,
501    I::Item: Send + Clone + 'static,
502{
503    let items: Vec<_> = items.into_iter().collect();
504    let chunks: Vec<Vec<_>> = items.chunks(chunk_size).map(|c| c.to_vec()).collect();
505
506    let config = ConcurrencyConfig::default().with_max_concurrency(max_concurrency);
507    let executor = ConcurrentExecutor::new(config);
508
509    let tasks: Vec<_> = chunks
510        .into_iter()
511        .map(|chunk| {
512            let op = operation.clone();
513            move || async move { Ok::<_, String>(op(chunk).await) }
514        })
515        .collect();
516
517    let (results, _) = executor.execute_all(tasks).await;
518
519    results.into_iter().filter_map(|r| r.into_value()).collect()
520}
521
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::AtomicUsize;

    #[tokio::test]
    async fn test_concurrent_executor_basic() {
        let executor = ConcurrentExecutor::new(ConcurrencyConfig::default());

        let tasks: Vec<_> = (0..10)
            .map(|i| move || async move { Ok::<_, String>(i * 2) })
            .collect();

        let (results, stats) = executor.execute_all(tasks).await;

        assert_eq!(results.len(), 10);
        assert_eq!(stats.total_tasks, 10);
        assert_eq!(stats.successful, 10);
        assert_eq!(stats.failed, 0);

        // Results come back in input order.
        for (i, result) in results.iter().enumerate() {
            match result {
                TaskResult::Success { value, .. } => assert_eq!(*value, i * 2),
                TaskResult::Error(e) => panic!("Expected success, got {}", e),
            }
        }
    }

    #[tokio::test]
    async fn test_concurrent_executor_with_errors() {
        let config = ConcurrencyConfig::default().with_continue_on_error(true);
        let executor = ConcurrentExecutor::new(config);

        let tasks: Vec<_> = (0..5)
            .map(|i| {
                move || async move {
                    if i == 2 {
                        Err("Task 2 failed".to_string())
                    } else {
                        Ok::<_, String>(i)
                    }
                }
            })
            .collect();

        let (results, stats) = executor.execute_all(tasks).await;

        assert_eq!(results.len(), 5);
        assert_eq!(stats.successful, 4);
        assert_eq!(stats.failed, 1);
        assert_eq!(stats.timed_out, 0);
        // Results are sorted by task id, so the failure sits at index 2.
        assert!(!results[2].is_success());
    }

    #[tokio::test]
    async fn test_concurrent_executor_fail_fast() {
        let config = ConcurrencyConfig::default()
            .with_continue_on_error(false)
            .with_max_concurrency(1); // Sequential to ensure order

        let executor = ConcurrentExecutor::new(config);
        let counter = Arc::new(AtomicUsize::new(0));

        let tasks: Vec<_> = (0..5)
            .map(|i| {
                let counter = Arc::clone(&counter);
                move || async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                    if i == 2 {
                        Err("Task 2 failed".to_string())
                    } else {
                        Ok::<_, String>(i)
                    }
                }
            })
            .collect();

        let (results, _) = executor.execute_all(tasks).await;

        // Fail-fast stops at the first error, so exactly one error is reported and it
        // comes from the only failing task.
        let error_ids: Vec<usize> = results
            .iter()
            .filter_map(|r| match r {
                TaskResult::Error(e) => Some(e.task_id),
                TaskResult::Success { .. } => None,
            })
            .collect();
        assert_eq!(error_ids, vec![2]);
        // Every reported result came from a task that actually ran.
        assert!(counter.load(Ordering::SeqCst) >= results.len());
    }

    #[tokio::test]
    async fn test_concurrent_executor_respects_concurrency() {
        let max_concurrent = Arc::new(AtomicUsize::new(0));
        let current = Arc::new(AtomicUsize::new(0));

        let config = ConcurrencyConfig::default().with_max_concurrency(3);
        let executor = ConcurrentExecutor::new(config);

        let tasks: Vec<_> = (0..20)
            .map(|i| {
                let max_concurrent = Arc::clone(&max_concurrent);
                let current = Arc::clone(&current);
                move || async move {
                    let c = current.fetch_add(1, Ordering::SeqCst) + 1;
                    max_concurrent.fetch_max(c, Ordering::SeqCst);

                    // Simulate work
                    tokio::time::sleep(Duration::from_millis(10)).await;

                    current.fetch_sub(1, Ordering::SeqCst);
                    Ok::<_, String>(i)
                }
            })
            .collect();

        let (results, stats) = executor.execute_all(tasks).await;

        assert_eq!(results.len(), 20);
        assert!(stats.max_concurrent <= 3);
        assert!(max_concurrent.load(Ordering::SeqCst) <= 3);
    }

    #[tokio::test]
    async fn test_execute_batch() {
        let results = execute_batch(vec!["a", "b", "c"], 4, |s: &str| async move {
            Ok::<_, String>(s.len())
        })
        .await;

        assert_eq!(results, vec![Ok(1), Ok(1), Ok(1)]);
    }

    #[tokio::test]
    async fn test_timeout() {
        let config = ConcurrencyConfig::default().with_timeout(Duration::from_millis(50));
        let executor = ConcurrentExecutor::new(config);

        let tasks: Vec<
            Box<
                dyn FnOnce() -> std::pin::Pin<Box<dyn Future<Output = Result<i32, String>> + Send>>
                    + Send,
            >,
        > = vec![
            Box::new(|| {
                Box::pin(async {
                    tokio::time::sleep(Duration::from_millis(10)).await;
                    Ok::<_, String>(1)
                })
            }),
            Box::new(|| {
                Box::pin(async {
                    tokio::time::sleep(Duration::from_millis(200)).await;
                    Ok::<_, String>(2)
                })
            }),
        ];

        let (results, stats) = executor.execute_all(tasks).await;

        assert_eq!(results.len(), 2);
        assert_eq!(stats.timed_out, 1);
        // Sorted by task id: the quick task succeeded, the slow one timed out.
        assert!(results[0].is_success());
        assert!(matches!(&results[1], TaskResult::Error(e) if e.is_timeout));
    }
}