dag_executor/dag/
worker_pool.rs1use crate::advanced::{CircuitBreaker, RetryPolicy};
4use crate::context::Context;
5use crate::error::TaskError;
6use crate::metrics::MetricsCollector;
7use crate::tasks::{Task, TaskOutput};
8use std::sync::Arc;
9use std::time::Duration;
10use tokio::sync::Semaphore;
11use tokio::task::JoinHandle;
12
13#[derive(Debug)]
15pub struct TaskResult {
16 pub id: String,
18 pub attempts: u32,
20 pub outcome: Result<TaskOutput, TaskError>,
22}
23
24pub struct WorkerPool {
31 semaphore: Arc<Semaphore>,
32 retry: RetryPolicy,
33 timeout: Option<Duration>,
34 metrics: Arc<MetricsCollector>,
35}
36
37impl WorkerPool {
38 pub fn new(
40 concurrency: usize,
41 retry: RetryPolicy,
42 timeout: Option<Duration>,
43 metrics: Arc<MetricsCollector>,
44 ) -> Self {
45 WorkerPool {
46 semaphore: Arc::new(Semaphore::new(concurrency.max(1))),
47 retry,
48 timeout,
49 metrics,
50 }
51 }
52
53 pub fn available_permits(&self) -> usize {
55 self.semaphore.available_permits()
56 }
57
58 pub fn spawn(
63 &self,
64 task: Arc<dyn Task>,
65 ctx: Arc<Context>,
66 breaker: Option<Arc<CircuitBreaker>>,
67 ) -> JoinHandle<TaskResult> {
68 let semaphore = self.semaphore.clone();
69 let retry = self.retry;
70 let timeout = self.timeout;
71 let metrics = self.metrics.clone();
72 let id = task.id().to_string();
73
74 tokio::spawn(async move {
75 let _permit = semaphore
78 .acquire_owned()
79 .await
80 .expect("semaphore is never closed");
81
82 metrics.task_started();
83 let mut attempts = 0u32;
84
85 loop {
86 if ctx.is_cancelled() {
87 return TaskResult {
88 id,
89 attempts,
90 outcome: Err(TaskError::Cancelled),
91 };
92 }
93
94 if let Some(ref b) = breaker {
95 if !b.allow_request() {
96 return TaskResult {
97 id: id.clone(),
98 attempts,
99 outcome: Err(TaskError::CircuitOpen(id.clone())),
100 };
101 }
102 }
103
104 attempts += 1;
105 let result = run_once(task.clone(), ctx.clone(), timeout).await;
106
107 match result {
108 Ok(output) => {
109 if let Some(ref b) = breaker {
110 b.record_success();
111 }
112 return TaskResult {
113 id,
114 attempts,
115 outcome: Ok(output),
116 };
117 }
118 Err(err) => {
119 if let Some(ref b) = breaker {
120 b.record_failure();
121 }
122 let retryable = err.is_retryable()
123 && retry.should_retry(attempts)
124 && !ctx.is_cancelled();
125 if !retryable {
126 return TaskResult {
127 id,
128 attempts,
129 outcome: Err(err),
130 };
131 }
132 metrics.retry();
133 let delay = retry.delay_for(attempts);
134 if !delay.is_zero() {
135 tokio::time::sleep(delay).await;
136 }
137 }
138 }
139 }
140 })
141 }
142}
143
144async fn run_once(
146 task: Arc<dyn Task>,
147 ctx: Arc<Context>,
148 timeout: Option<Duration>,
149) -> Result<TaskOutput, TaskError> {
150 match timeout {
151 Some(t) => match tokio::time::timeout(t, task.execute(ctx)).await {
152 Ok(res) => res,
153 Err(_) => Err(TaskError::Timeout(t)),
154 },
155 None => task.execute(ctx).await,
156 }
157}