Skip to main content

dag_executor/dag/
worker_pool.rs

1//! Concurrent task execution with bounded parallelism, retries and timeouts.
2
3use crate::advanced::{CircuitBreaker, RetryPolicy};
4use crate::context::Context;
5use crate::error::TaskError;
6use crate::metrics::MetricsCollector;
7use crate::tasks::{Task, TaskOutput};
8use std::sync::Arc;
9use std::time::Duration;
10use tokio::sync::Semaphore;
11use tokio::task::JoinHandle;
12
13/// The final result of running one task (after any retries).
14#[derive(Debug)]
15pub struct TaskResult {
16    /// The task id.
17    pub id: String,
18    /// Total attempts made.
19    pub attempts: u32,
20    /// Success output or the terminal error.
21    pub outcome: Result<TaskOutput, TaskError>,
22}
23
24/// Executes tasks concurrently under a global concurrency limit.
25///
26/// Each task is run on its own Tokio task but must first acquire a permit from
27/// a shared [`Semaphore`], which caps how many run at once. The pool also owns
28/// the retry loop, per-attempt timeout, and circuit-breaker gating so the
29/// executor's scheduling loop stays simple.
30pub struct WorkerPool {
31    semaphore: Arc<Semaphore>,
32    retry: RetryPolicy,
33    timeout: Option<Duration>,
34    metrics: Arc<MetricsCollector>,
35}
36
37impl WorkerPool {
38    /// Create a pool allowing `concurrency` simultaneous tasks.
39    pub fn new(
40        concurrency: usize,
41        retry: RetryPolicy,
42        timeout: Option<Duration>,
43        metrics: Arc<MetricsCollector>,
44    ) -> Self {
45        WorkerPool {
46            semaphore: Arc::new(Semaphore::new(concurrency.max(1))),
47            retry,
48            timeout,
49            metrics,
50        }
51    }
52
53    /// Permits currently available (i.e. free concurrency slots).
54    pub fn available_permits(&self) -> usize {
55        self.semaphore.available_permits()
56    }
57
58    /// Spawn `task` for execution, returning a handle to its eventual result.
59    ///
60    /// `breaker`, if supplied, gates execution: when open, the task fails fast
61    /// with [`TaskError::CircuitOpen`] without consuming an attempt.
62    pub fn spawn(
63        &self,
64        task: Arc<dyn Task>,
65        ctx: Arc<Context>,
66        breaker: Option<Arc<CircuitBreaker>>,
67    ) -> JoinHandle<TaskResult> {
68        let semaphore = self.semaphore.clone();
69        let retry = self.retry;
70        let timeout = self.timeout;
71        let metrics = self.metrics.clone();
72        let id = task.id().to_string();
73
74        tokio::spawn(async move {
75            // Hold a permit for the entire (possibly multi-attempt) lifetime so
76            // concurrency stays bounded even while a task is backing off.
77            let _permit = semaphore
78                .acquire_owned()
79                .await
80                .expect("semaphore is never closed");
81
82            metrics.task_started();
83            let mut attempts = 0u32;
84
85            loop {
86                if ctx.is_cancelled() {
87                    return TaskResult {
88                        id,
89                        attempts,
90                        outcome: Err(TaskError::Cancelled),
91                    };
92                }
93
94                if let Some(ref b) = breaker {
95                    if !b.allow_request() {
96                        return TaskResult {
97                            id: id.clone(),
98                            attempts,
99                            outcome: Err(TaskError::CircuitOpen(id.clone())),
100                        };
101                    }
102                }
103
104                attempts += 1;
105                let result = run_once(task.clone(), ctx.clone(), timeout).await;
106
107                match result {
108                    Ok(output) => {
109                        if let Some(ref b) = breaker {
110                            b.record_success();
111                        }
112                        return TaskResult {
113                            id,
114                            attempts,
115                            outcome: Ok(output),
116                        };
117                    }
118                    Err(err) => {
119                        if let Some(ref b) = breaker {
120                            b.record_failure();
121                        }
122                        let retryable = err.is_retryable()
123                            && retry.should_retry(attempts)
124                            && !ctx.is_cancelled();
125                        if !retryable {
126                            return TaskResult {
127                                id,
128                                attempts,
129                                outcome: Err(err),
130                            };
131                        }
132                        metrics.retry();
133                        let delay = retry.delay_for(attempts);
134                        if !delay.is_zero() {
135                            tokio::time::sleep(delay).await;
136                        }
137                    }
138                }
139            }
140        })
141    }
142}
143
144/// Run a single attempt, applying the optional timeout.
145async fn run_once(
146    task: Arc<dyn Task>,
147    ctx: Arc<Context>,
148    timeout: Option<Duration>,
149) -> Result<TaskOutput, TaskError> {
150    match timeout {
151        Some(t) => match tokio::time::timeout(t, task.execute(ctx)).await {
152            Ok(res) => res,
153            Err(_) => Err(TaskError::Timeout(t)),
154        },
155        None => task.execute(ctx).await,
156    }
157}