// prax_query/async_optimize/concurrent.rs

//! Concurrent task execution with controlled parallelism.
//!
//! This module provides utilities for executing multiple independent database
//! operations in parallel while respecting concurrency limits, so that the
//! database connection pool is not overwhelmed.
6
7use std::future::Future;
8use std::sync::atomic::{AtomicU64, Ordering};
9use std::sync::Arc;
10use std::time::{Duration, Instant};
11
12use futures::stream::{FuturesUnordered, StreamExt};
13use tokio::sync::Semaphore;
14
/// Tuning knobs for concurrent execution.
///
/// Start from [`Default`] or one of the preset constructors and adjust the
/// result with the `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct ConcurrencyConfig {
    /// Upper bound on the number of operations allowed to run at the same time.
    pub max_concurrency: usize,
    /// Time limit applied to each individual operation; `None` disables it.
    pub operation_timeout: Option<Duration>,
    /// `true` keeps going after failures and reports every outcome;
    /// `false` stops at the first failure (fail fast).
    pub continue_on_error: bool,
    /// Whether timing statistics should be collected.
    ///
    /// NOTE(review): nothing in this module reads this flag; statistics are
    /// always gathered. Confirm whether it is honored elsewhere.
    pub collect_stats: bool,
}
27
28impl Default for ConcurrencyConfig {
29    fn default() -> Self {
30        Self {
31            max_concurrency: num_cpus::get().max(4),
32            operation_timeout: Some(Duration::from_secs(30)),
33            continue_on_error: true,
34            collect_stats: true,
35        }
36    }
37}
38
39impl ConcurrencyConfig {
40    /// Create config optimized for database introspection.
41    #[must_use]
42    pub fn for_introspection() -> Self {
43        Self {
44            max_concurrency: 8, // Balance between speed and connection usage
45            operation_timeout: Some(Duration::from_secs(60)),
46            continue_on_error: true,
47            collect_stats: true,
48        }
49    }
50
51    /// Create config optimized for migration operations.
52    #[must_use]
53    pub fn for_migrations() -> Self {
54        Self {
55            max_concurrency: 4, // More conservative for DDL
56            operation_timeout: Some(Duration::from_secs(120)),
57            continue_on_error: false, // Migrations should fail fast
58            collect_stats: true,
59        }
60    }
61
62    /// Create config optimized for bulk data operations.
63    #[must_use]
64    pub fn for_bulk_operations() -> Self {
65        Self {
66            max_concurrency: 16, // Higher parallelism for DML
67            operation_timeout: Some(Duration::from_secs(300)),
68            continue_on_error: true,
69            collect_stats: true,
70        }
71    }
72
73    /// Set maximum concurrency.
74    #[must_use]
75    pub fn with_max_concurrency(mut self, max: usize) -> Self {
76        self.max_concurrency = max.max(1);
77        self
78    }
79
80    /// Set operation timeout.
81    #[must_use]
82    pub fn with_timeout(mut self, timeout: Duration) -> Self {
83        self.operation_timeout = Some(timeout);
84        self
85    }
86
87    /// Disable timeout.
88    #[must_use]
89    pub fn without_timeout(mut self) -> Self {
90        self.operation_timeout = None;
91        self
92    }
93
94    /// Set continue on error behavior.
95    #[must_use]
96    pub fn with_continue_on_error(mut self, continue_on_error: bool) -> Self {
97        self.continue_on_error = continue_on_error;
98        self
99    }
100}
101
/// Failure report for a single concurrently executed task.
#[derive(Debug, Clone)]
pub struct TaskError {
    /// Index of the failing task.
    ///
    /// When produced by the executor, this is the task's position in the
    /// submitted sequence.
    pub task_id: usize,
    /// Human-readable description of what went wrong.
    pub message: String,
    /// `true` when the task was cut off by a timeout rather than returning an
    /// error itself.
    pub is_timeout: bool,
}

impl std::fmt::Display for TaskError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self.is_timeout {
            true => write!(f, "Task {} timed out: {}", self.task_id, self.message),
            false => write!(f, "Task {} failed: {}", self.task_id, self.message),
        }
    }
}

impl std::error::Error for TaskError {}
124
125/// Result of a single task.
126#[derive(Debug)]
127pub enum TaskResult<T> {
128    /// Task completed successfully.
129    Success {
130        /// Task identifier.
131        task_id: usize,
132        /// The result value.
133        value: T,
134        /// Execution duration.
135        duration: Duration,
136    },
137    /// Task failed.
138    Error(TaskError),
139}
140
141impl<T> TaskResult<T> {
142    /// Check if the task succeeded.
143    pub fn is_success(&self) -> bool {
144        matches!(self, Self::Success { .. })
145    }
146
147    /// Get the value if successful.
148    pub fn into_value(self) -> Option<T> {
149        match self {
150            Self::Success { value, .. } => Some(value),
151            Self::Error(_) => None,
152        }
153    }
154
155    /// Get the error if failed.
156    pub fn into_error(self) -> Option<TaskError> {
157        match self {
158            Self::Success { .. } => None,
159            Self::Error(e) => Some(e),
160        }
161    }
162}
163
/// Summary of one concurrent execution run.
#[derive(Debug, Clone, Default)]
pub struct ExecutionStats {
    /// Number of tasks submitted.
    pub total_tasks: u64,
    /// Tasks that returned `Ok`.
    pub successful: u64,
    /// Tasks that did not succeed; timed-out tasks are counted here as well.
    pub failed: u64,
    /// Subset of `failed` that hit the per-operation timeout.
    pub timed_out: u64,
    /// Wall-clock time of the whole run.
    pub total_duration: Duration,
    /// Mean duration over successful tasks only (zero when none succeeded).
    pub avg_task_duration: Duration,
    /// Highest number of tasks observed running simultaneously.
    pub max_concurrent: usize,
}
182
183/// Executor for running concurrent tasks with controlled parallelism.
184pub struct ConcurrentExecutor {
185    config: ConcurrencyConfig,
186    semaphore: Arc<Semaphore>,
187    stats: ExecutionStatsCollector,
188}
189
190impl ConcurrentExecutor {
191    /// Create a new concurrent executor.
192    pub fn new(config: ConcurrencyConfig) -> Self {
193        let semaphore = Arc::new(Semaphore::new(config.max_concurrency));
194        Self {
195            config,
196            semaphore,
197            stats: ExecutionStatsCollector::new(),
198        }
199    }
200
201    /// Execute all tasks concurrently with controlled parallelism.
202    ///
203    /// Tasks are started immediately but limited by the semaphore to ensure
204    /// at most `max_concurrency` tasks run at once.
205    pub async fn execute_all<T, F, Fut>(
206        &self,
207        tasks: impl IntoIterator<Item = F>,
208    ) -> (Vec<TaskResult<T>>, ExecutionStats)
209    where
210        F: FnOnce() -> Fut + Send + 'static,
211        Fut: Future<Output = Result<T, String>> + Send + 'static,
212        T: Send + 'static,
213    {
214        let start = Instant::now();
215        self.stats.reset();
216
217        let tasks: Vec<_> = tasks.into_iter().collect();
218        let total_tasks = tasks.len();
219        self.stats.total.store(total_tasks as u64, Ordering::SeqCst);
220
221        let mut futures = FuturesUnordered::new();
222
223        for (task_id, task) in tasks.into_iter().enumerate() {
224            let semaphore = Arc::clone(&self.semaphore);
225            let timeout = self.config.operation_timeout;
226            let stats = self.stats.clone();
227
228            let future = async move {
229                // Acquire semaphore permit
230                let _permit = semaphore.acquire().await.expect("Semaphore closed");
231                stats.increment_concurrent();
232
233                let task_start = Instant::now();
234                let result = if let Some(timeout_duration) = timeout {
235                    match tokio::time::timeout(timeout_duration, task()).await {
236                        Ok(Ok(value)) => TaskResult::Success {
237                            task_id,
238                            value,
239                            duration: task_start.elapsed(),
240                        },
241                        Ok(Err(msg)) => TaskResult::Error(TaskError {
242                            task_id,
243                            message: msg,
244                            is_timeout: false,
245                        }),
246                        Err(_) => TaskResult::Error(TaskError {
247                            task_id,
248                            message: format!("Timeout after {:?}", timeout_duration),
249                            is_timeout: true,
250                        }),
251                    }
252                } else {
253                    match task().await {
254                        Ok(value) => TaskResult::Success {
255                            task_id,
256                            value,
257                            duration: task_start.elapsed(),
258                        },
259                        Err(msg) => TaskResult::Error(TaskError {
260                            task_id,
261                            message: msg,
262                            is_timeout: false,
263                        }),
264                    }
265                };
266
267                stats.decrement_concurrent();
268
269                match &result {
270                    TaskResult::Success { duration, .. } => {
271                        stats.record_success(*duration);
272                    }
273                    TaskResult::Error(e) if e.is_timeout => {
274                        stats.record_timeout();
275                    }
276                    TaskResult::Error(_) => {
277                        stats.record_failure();
278                    }
279                }
280
281                result
282            };
283
284            futures.push(future);
285        }
286
287        // Collect results in order of completion
288        let mut results = Vec::with_capacity(total_tasks);
289
290        while let Some(result) = futures.next().await {
291            if !self.config.continue_on_error {
292                if let TaskResult::Error(ref _e) = result {
293                    // Cancel remaining futures by dropping them
294                    drop(futures);
295                    results.push(result);
296
297                    let stats = self.stats.finalize(start.elapsed());
298                    return (results, stats);
299                }
300            }
301            results.push(result);
302        }
303
304        // Sort by task_id to maintain original order
305        results.sort_by_key(|r| match r {
306            TaskResult::Success { task_id, .. } => *task_id,
307            TaskResult::Error(e) => e.task_id,
308        });
309
310        let stats = self.stats.finalize(start.elapsed());
311        (results, stats)
312    }
313
314    /// Execute tasks and collect only successful results.
315    ///
316    /// Returns the values in the same order as the input tasks.
317    pub async fn execute_collect<T, F, Fut>(
318        &self,
319        tasks: impl IntoIterator<Item = F>,
320    ) -> (Vec<T>, Vec<TaskError>)
321    where
322        F: FnOnce() -> Fut + Send + 'static,
323        Fut: Future<Output = Result<T, String>> + Send + 'static,
324        T: Send + 'static,
325    {
326        let (results, _) = self.execute_all(tasks).await;
327
328        let mut values = Vec::new();
329        let mut errors = Vec::new();
330
331        for result in results {
332            match result {
333                TaskResult::Success { value, .. } => values.push(value),
334                TaskResult::Error(e) => errors.push(e),
335            }
336        }
337
338        (values, errors)
339    }
340
341    /// Execute tasks with indexed results.
342    ///
343    /// Returns a map of task_id -> result, useful when you need to correlate
344    /// results with their original indices.
345    pub async fn execute_indexed<T, F, Fut>(
346        &self,
347        tasks: impl IntoIterator<Item = F>,
348    ) -> std::collections::HashMap<usize, Result<T, TaskError>>
349    where
350        F: FnOnce() -> Fut + Send + 'static,
351        Fut: Future<Output = Result<T, String>> + Send + 'static,
352        T: Send + 'static,
353    {
354        let (results, _) = self.execute_all(tasks).await;
355
356        results
357            .into_iter()
358            .map(|r| match r {
359                TaskResult::Success {
360                    task_id, value, ..
361                } => (task_id, Ok(value)),
362                TaskResult::Error(e) => (e.task_id, Err(e)),
363            })
364            .collect()
365    }
366}
367
368/// Internal stats collector with atomic counters.
369#[derive(Clone)]
370struct ExecutionStatsCollector {
371    total: Arc<AtomicU64>,
372    successful: Arc<AtomicU64>,
373    failed: Arc<AtomicU64>,
374    timed_out: Arc<AtomicU64>,
375    total_task_duration_ns: Arc<AtomicU64>,
376    current_concurrent: Arc<AtomicU64>,
377    max_concurrent: Arc<AtomicU64>,
378}
379
380impl ExecutionStatsCollector {
381    fn new() -> Self {
382        Self {
383            total: Arc::new(AtomicU64::new(0)),
384            successful: Arc::new(AtomicU64::new(0)),
385            failed: Arc::new(AtomicU64::new(0)),
386            timed_out: Arc::new(AtomicU64::new(0)),
387            total_task_duration_ns: Arc::new(AtomicU64::new(0)),
388            current_concurrent: Arc::new(AtomicU64::new(0)),
389            max_concurrent: Arc::new(AtomicU64::new(0)),
390        }
391    }
392
393    fn reset(&self) {
394        self.total.store(0, Ordering::SeqCst);
395        self.successful.store(0, Ordering::SeqCst);
396        self.failed.store(0, Ordering::SeqCst);
397        self.timed_out.store(0, Ordering::SeqCst);
398        self.total_task_duration_ns.store(0, Ordering::SeqCst);
399        self.current_concurrent.store(0, Ordering::SeqCst);
400        self.max_concurrent.store(0, Ordering::SeqCst);
401    }
402
403    fn increment_concurrent(&self) {
404        let current = self.current_concurrent.fetch_add(1, Ordering::SeqCst) + 1;
405        self.max_concurrent.fetch_max(current, Ordering::SeqCst);
406    }
407
408    fn decrement_concurrent(&self) {
409        self.current_concurrent.fetch_sub(1, Ordering::SeqCst);
410    }
411
412    fn record_success(&self, duration: Duration) {
413        self.successful.fetch_add(1, Ordering::SeqCst);
414        self.total_task_duration_ns
415            .fetch_add(duration.as_nanos() as u64, Ordering::SeqCst);
416    }
417
418    fn record_failure(&self) {
419        self.failed.fetch_add(1, Ordering::SeqCst);
420    }
421
422    fn record_timeout(&self) {
423        self.timed_out.fetch_add(1, Ordering::SeqCst);
424        self.failed.fetch_add(1, Ordering::SeqCst);
425    }
426
427    fn finalize(&self, total_duration: Duration) -> ExecutionStats {
428        let successful = self.successful.load(Ordering::SeqCst);
429        let total_task_duration_ns = self.total_task_duration_ns.load(Ordering::SeqCst);
430
431        let avg_task_duration = if successful > 0 {
432            Duration::from_nanos(total_task_duration_ns / successful)
433        } else {
434            Duration::ZERO
435        };
436
437        ExecutionStats {
438            total_tasks: self.total.load(Ordering::SeqCst),
439            successful,
440            failed: self.failed.load(Ordering::SeqCst),
441            timed_out: self.timed_out.load(Ordering::SeqCst),
442            total_duration,
443            avg_task_duration,
444            max_concurrent: self.max_concurrent.load(Ordering::SeqCst) as usize,
445        }
446    }
447}
448
449/// Helper for executing a batch of similar operations concurrently.
450///
451/// This is a convenience function for common patterns like fetching
452/// metadata for multiple tables.
453pub async fn execute_batch<T, I, F, Fut>(
454    items: I,
455    max_concurrency: usize,
456    operation: F,
457) -> Vec<Result<T, String>>
458where
459    I: IntoIterator,
460    F: Fn(I::Item) -> Fut + Clone + Send + 'static,
461    Fut: Future<Output = Result<T, String>> + Send + 'static,
462    T: Send + 'static,
463    I::Item: Send + 'static,
464{
465    let config = ConcurrencyConfig::default().with_max_concurrency(max_concurrency);
466    let executor = ConcurrentExecutor::new(config);
467
468    let tasks: Vec<_> = items
469        .into_iter()
470        .map(|item| {
471            let op = operation.clone();
472            move || op(item)
473        })
474        .collect();
475
476    let (results, _) = executor.execute_all(tasks).await;
477
478    results
479        .into_iter()
480        .map(|r| match r {
481            TaskResult::Success { value, .. } => Ok(value),
482            TaskResult::Error(e) => Err(e.message),
483        })
484        .collect()
485}
486
487/// Execute operations in parallel chunks.
488///
489/// Useful for operations that benefit from batching (like multi-row inserts)
490/// combined with parallel execution of batches.
491pub async fn execute_chunked<T, I, F, Fut>(
492    items: I,
493    chunk_size: usize,
494    max_concurrency: usize,
495    operation: F,
496) -> Vec<Vec<Result<T, String>>>
497where
498    I: IntoIterator,
499    I::IntoIter: ExactSizeIterator,
500    F: Fn(Vec<I::Item>) -> Fut + Clone + Send + 'static,
501    Fut: Future<Output = Vec<Result<T, String>>> + Send + 'static,
502    T: Send + 'static,
503    I::Item: Send + Clone + 'static,
504{
505    let items: Vec<_> = items.into_iter().collect();
506    let chunks: Vec<Vec<_>> = items.chunks(chunk_size).map(|c| c.to_vec()).collect();
507
508    let config = ConcurrencyConfig::default().with_max_concurrency(max_concurrency);
509    let executor = ConcurrentExecutor::new(config);
510
511    let tasks: Vec<_> = chunks
512        .into_iter()
513        .map(|chunk| {
514            let op = operation.clone();
515            move || async move { Ok::<_, String>(op(chunk).await) }
516        })
517        .collect();
518
519    let (results, _) = executor.execute_all(tasks).await;
520
521    results
522        .into_iter()
523        .filter_map(|r| r.into_value())
524        .collect()
525}
526
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::AtomicUsize;

    #[tokio::test]
    async fn test_concurrent_executor_basic() {
        let executor = ConcurrentExecutor::new(ConcurrencyConfig::default());

        let tasks: Vec<_> = (0..10)
            .map(|i| move || async move { Ok::<_, String>(i * 2) })
            .collect();

        let (results, stats) = executor.execute_all(tasks).await;

        assert_eq!(results.len(), 10);
        assert_eq!(stats.total_tasks, 10);
        assert_eq!(stats.successful, 10);
        assert_eq!(stats.failed, 0);

        // Verify results are in order
        for (i, result) in results.iter().enumerate() {
            match result {
                TaskResult::Success { value, .. } => {
                    assert_eq!(*value, i * 2);
                }
                _ => panic!("Expected success"),
            }
        }
    }

    #[tokio::test]
    async fn test_concurrent_executor_with_errors() {
        let config = ConcurrencyConfig::default().with_continue_on_error(true);
        let executor = ConcurrentExecutor::new(config);

        let tasks: Vec<_> = (0..5)
            .map(|i| {
                move || async move {
                    if i == 2 {
                        Err("Task 2 failed".to_string())
                    } else {
                        Ok::<_, String>(i)
                    }
                }
            })
            .collect();

        let (results, stats) = executor.execute_all(tasks).await;

        assert_eq!(results.len(), 5);
        assert_eq!(stats.successful, 4);
        assert_eq!(stats.failed, 1);
    }

    #[tokio::test]
    async fn test_concurrent_executor_fail_fast() {
        let config = ConcurrencyConfig::default()
            .with_continue_on_error(false)
            .with_max_concurrency(1); // Sequential to ensure order

        let executor = ConcurrentExecutor::new(config);
        let counter = Arc::new(AtomicUsize::new(0));

        let tasks: Vec<_> = (0..5)
            .map(|i| {
                let counter = Arc::clone(&counter);
                move || async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                    if i == 2 {
                        Err("Task 2 failed".to_string())
                    } else {
                        Ok::<_, String>(i)
                    }
                }
            })
            .collect();

        let (results, _) = executor.execute_all(tasks).await;

        // Should have stopped at first error - check using pattern match
        let has_error = results.iter().any(|r| matches!(r, TaskResult::Error(_)));
        assert!(has_error);
    }

    #[tokio::test]
    async fn test_concurrent_executor_respects_concurrency() {
        let max_concurrent = Arc::new(AtomicUsize::new(0));
        let current = Arc::new(AtomicUsize::new(0));

        let config = ConcurrencyConfig::default().with_max_concurrency(3);
        let executor = ConcurrentExecutor::new(config);

        let tasks: Vec<_> = (0..20)
            .map(|i| {
                let max_concurrent = Arc::clone(&max_concurrent);
                let current = Arc::clone(&current);
                move || async move {
                    let c = current.fetch_add(1, Ordering::SeqCst) + 1;
                    max_concurrent.fetch_max(c, Ordering::SeqCst);

                    // Simulate work
                    tokio::time::sleep(Duration::from_millis(10)).await;

                    current.fetch_sub(1, Ordering::SeqCst);
                    Ok::<_, String>(i)
                }
            })
            .collect();

        let (results, stats) = executor.execute_all(tasks).await;

        assert_eq!(results.len(), 20);
        assert!(stats.max_concurrent <= 3);
        assert!(max_concurrent.load(Ordering::SeqCst) <= 3);
    }

    #[tokio::test]
    async fn test_execute_batch() {
        let results = execute_batch(
            vec!["a", "bb", "ccc"],
            4,
            |s: &str| async move { Ok::<_, String>(s.len()) },
        )
        .await;

        // One entry per item, in input order.
        let expected: Vec<Result<usize, String>> = vec![Ok(1), Ok(2), Ok(3)];
        assert_eq!(results, expected);
    }

    #[tokio::test]
    async fn test_execute_chunked() {
        let chunks = execute_chunked(vec![1, 2, 3, 4, 5], 2, 2, |chunk: Vec<i32>| async move {
            chunk
                .into_iter()
                .map(|x| Ok::<_, String>(x * 10))
                .collect::<Vec<_>>()
        })
        .await;

        // 5 items in chunks of 2 -> 3 chunks, returned in input order.
        assert_eq!(chunks.len(), 3);
        let flat: Vec<i32> = chunks
            .into_iter()
            .flatten()
            .map(|r| r.expect("chunk item should succeed"))
            .collect();
        assert_eq!(flat, vec![10, 20, 30, 40, 50]);
    }

    #[tokio::test]
    async fn test_execute_collect() {
        let executor = ConcurrentExecutor::new(ConcurrencyConfig::default());

        let tasks: Vec<_> = (0..4usize)
            .map(|i| {
                move || async move {
                    if i % 2 == 0 {
                        Ok(i)
                    } else {
                        Err(format!("odd {}", i))
                    }
                }
            })
            .collect();

        let (values, errors) = executor.execute_collect(tasks).await;

        assert_eq!(values, vec![0, 2]);
        let failed_ids: Vec<usize> = errors.iter().map(|e| e.task_id).collect();
        assert_eq!(failed_ids, vec![1, 3]);
    }

    #[tokio::test]
    async fn test_execute_indexed() {
        let executor = ConcurrentExecutor::new(ConcurrencyConfig::default());

        let tasks: Vec<_> = (0..3usize)
            .map(|i| move || async move { Ok::<_, String>(i * i) })
            .collect();

        let indexed = executor.execute_indexed(tasks).await;

        assert_eq!(indexed.len(), 3);
        for i in 0..3usize {
            assert_eq!(indexed[&i].as_ref().ok(), Some(&(i * i)));
        }
    }

    #[test]
    fn test_task_error_display() {
        let failed = TaskError {
            task_id: 7,
            message: "boom".to_string(),
            is_timeout: false,
        };
        assert_eq!(failed.to_string(), "Task 7 failed: boom");

        let slow = TaskError {
            task_id: 1,
            message: "slow".to_string(),
            is_timeout: true,
        };
        assert_eq!(slow.to_string(), "Task 1 timed out: slow");
    }

    #[tokio::test]
    async fn test_timeout() {
        let config = ConcurrencyConfig::default().with_timeout(Duration::from_millis(50));
        let executor = ConcurrentExecutor::new(config);

        let tasks: Vec<Box<dyn FnOnce() -> std::pin::Pin<Box<dyn Future<Output = Result<i32, String>> + Send>> + Send>> = vec![
            Box::new(|| Box::pin(async {
                tokio::time::sleep(Duration::from_millis(10)).await;
                Ok::<_, String>(1)
            })),
            Box::new(|| Box::pin(async {
                tokio::time::sleep(Duration::from_millis(200)).await;
                Ok::<_, String>(2)
            })),
        ];

        let (results, stats) = executor.execute_all(tasks).await;

        assert_eq!(results.len(), 2);
        assert_eq!(stats.timed_out, 1);
        // Results are ordered by task id: the fast task succeeds, the slow one times out.
        assert!(results[0].is_success());
        assert!(matches!(&results[1], TaskResult::Error(e) if e.is_timeout));
    }
}
679