1use std::sync::Arc;
2use std::time::Duration;
3
4use forge_core::CircuitBreakerClient;
5use forge_core::job::{JobContext, ProgressUpdate};
6use tokio::time::timeout;
7
8use super::queue::{JobQueue, JobRecord};
9use super::registry::{JobEntry, JobRegistry};
10
/// Executes claimed jobs: resolves their handler from the registry, maintains
/// progress/heartbeat bookkeeping while the handler runs, and records the
/// terminal outcome in the queue.
pub struct JobExecutor {
    // Persistent queue used for all state transitions (start/complete/fail/cancel).
    queue: JobQueue,
    // Shared registry of job handlers plus per-type config (timeout, retry, TTL).
    registry: Arc<JobRegistry>,
    // Database pool handed to each job's `JobContext`.
    db_pool: sqlx::PgPool,
    // HTTP client with circuit breaking, cloned into each `JobContext`.
    http_client: CircuitBreakerClient,
}
18
impl JobExecutor {
    /// Interval at which the background heartbeat task pings the queue so a
    /// long-running job is not mistaken for a dead one.
    const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(30);

    /// Creates an executor over `queue`, sharing `registry` behind an `Arc`
    /// and building a circuit-breaker HTTP client with default settings.
    pub fn new(queue: JobQueue, registry: JobRegistry, db_pool: sqlx::PgPool) -> Self {
        Self {
            queue,
            registry: Arc::new(registry),
            db_pool,
            http_client: CircuitBreakerClient::with_defaults(reqwest::Client::new()),
        }
    }

    /// Drives one claimed job to completion and records the outcome.
    ///
    /// Flow: resolve the handler from the registry, mark the job started,
    /// spawn progress-forwarding and heartbeat tasks, run the handler under
    /// the job type's timeout, then persist the terminal state
    /// (complete / fail / cancel) and return the matching [`ExecutionResult`].
    pub async fn execute(&self, job: &JobRecord) -> ExecutionResult {
        // Unknown job types are a permanent failure -- retrying cannot help.
        let entry = match self.registry.get(&job.job_type) {
            Some(e) => e,
            None => {
                return ExecutionResult::Failed {
                    error: format!("Unknown job type: {}", job.job_type),
                    retryable: false,
                };
            }
        };

        // The job may have been cancelled between being claimed and reaching here.
        if matches!(job.status, forge_core::job::JobStatus::Cancelled) {
            return ExecutionResult::Cancelled {
                reason: Self::cancellation_reason(job, "Job cancelled"),
            };
        }

        // Transition the row to running. RowNotFound is interpreted as the row
        // having gone away (treated as cancellation); any other DB error is
        // reported as a retryable failure.
        if let Err(e) = self.queue.start(job.id).await {
            if matches!(e, sqlx::Error::RowNotFound) {
                return ExecutionResult::Cancelled {
                    reason: Self::cancellation_reason(job, "Job cancelled"),
                };
            }
            return ExecutionResult::Failed {
                error: format!("Failed to start job: {}", e),
                retryable: true,
            };
        }

        // Progress channel handed to the handler via JobContext. This is a std
        // (sync) mpsc channel -- presumably required by JobContext's API -- so
        // the forwarding task below must poll with try_recv rather than await.
        let (progress_tx, progress_rx) = std::sync::mpsc::channel::<ProgressUpdate>();

        // Forward progress updates into the queue. The task ends when the
        // sender (owned by the JobContext) is dropped and try_recv reports
        // Disconnected; while the channel is empty it sleeps 50ms to avoid
        // busy-spinning. The task is detached -- it is not awaited on exit.
        let progress_queue = self.queue.clone();
        let progress_job_id = job.id;
        tokio::spawn(async move {
            loop {
                match progress_rx.try_recv() {
                    Ok(update) => {
                        if let Err(e) = progress_queue
                            .update_progress(
                                progress_job_id,
                                update.percentage as i32,
                                &update.message,
                            )
                            .await
                        {
                            // Progress reporting is best-effort; a lost update is not fatal.
                            tracing::debug!(job_id = %progress_job_id, error = %e, "Failed to update job progress");
                        }
                    }
                    Err(std::sync::mpsc::TryRecvError::Empty) => {
                        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
                    }
                    Err(std::sync::mpsc::TryRecvError::Disconnected) => {
                        break;
                    }
                }
            }
        });

        // Build the per-run context: attempt counters, DB pool, HTTP client,
        // state saved by a previous attempt, and the progress sender.
        let mut ctx = JobContext::new(
            job.id,
            job.job_type.clone(),
            job.attempts as u32,
            job.max_attempts as u32,
            self.db_pool.clone(),
            self.http_client.clone(),
        )
        .with_saved(job.job_context.clone())
        .with_progress(progress_tx);
        ctx.set_http_timeout(entry.info.http_timeout);

        // Heartbeat task: ping the queue every HEARTBEAT_INTERVAL until the
        // watch channel flips to true (or its sender is dropped).
        let heartbeat_queue = self.queue.clone();
        let heartbeat_job_id = job.id;
        let (heartbeat_stop_tx, mut heartbeat_stop_rx) = tokio::sync::watch::channel(false);
        let heartbeat_task = tokio::spawn(async move {
            loop {
                tokio::select! {
                    _ = tokio::time::sleep(Self::HEARTBEAT_INTERVAL) => {
                        if let Err(e) = heartbeat_queue.heartbeat(heartbeat_job_id).await {
                            tracing::debug!(job_id = %heartbeat_job_id, error = %e, "Failed to update job heartbeat");
                        }
                    }
                    changed = heartbeat_stop_rx.changed() => {
                        // Stop on an explicit signal, or if the sender went away.
                        if changed.is_err() || *heartbeat_stop_rx.borrow() {
                            break;
                        }
                    }
                }
            }
        });

        // Run the handler under the job type's overall timeout. On timeout the
        // handler future is dropped (cancelled at its next await point).
        let job_timeout = entry.info.timeout;
        let result = timeout(job_timeout, self.run_handler(&entry, &ctx, &job.input)).await;

        // Stop the heartbeat before persisting the terminal state, so a
        // heartbeat cannot race with the final status write.
        let _ = heartbeat_stop_tx.send(true);
        let _ = heartbeat_task.await;

        let ttl = entry.info.ttl;

        match result {
            // Handler finished successfully: persist the output.
            Ok(Ok(output)) => {
                if let Err(e) = self.queue.complete(job.id, output.clone(), ttl).await {
                    tracing::debug!(job_id = %job.id, error = %e, "Failed to mark job as complete");
                }
                ExecutionResult::Completed { output }
            }
            // Handler returned an error: distinguish cancellation from failure.
            Ok(Err(e)) => {
                let error_msg = e.to_string();
                // Cancellation can surface either as a JobCancelled error from
                // the handler or via the context's own check; consult both.
                let cancel_requested = match ctx.is_cancel_requested().await {
                    Ok(value) => value,
                    Err(err) => {
                        // If the check itself fails, fall through to normal
                        // failure handling rather than assuming "cancelled".
                        tracing::debug!(job_id = %job.id, error = %err, "Failed to check cancellation status");
                        false
                    }
                };
                if matches!(e, forge_core::ForgeError::JobCancelled(_)) || cancel_requested {
                    let reason = Self::cancellation_reason(job, "Job cancellation requested");
                    if let Err(e) = self.queue.cancel(job.id, Some(&reason), ttl).await {
                        tracing::debug!(job_id = %job.id, error = %e, "Failed to cancel job");
                    }
                    // Give the job type a chance to undo partial work.
                    if let Err(e) = self
                        .run_compensation(&entry, &ctx, &job.input, &reason)
                        .await
                    {
                        tracing::error!(job_id = %job.id, error = %e, "Job compensation failed");
                    }
                    return ExecutionResult::Cancelled { reason };
                }
                // Ordinary failure: retry while attempts remain, with a backoff
                // taken from the job type's retry policy.
                let should_retry = job.attempts < job.max_attempts;

                let retry_delay = if should_retry {
                    Some(entry.info.retry.calculate_backoff(job.attempts as u32))
                } else {
                    None
                };

                // Fall back to 60s if the std Duration does not fit chrono's range.
                let chrono_delay = retry_delay.map(|d| {
                    chrono::Duration::from_std(d).unwrap_or(chrono::Duration::seconds(60))
                });

                if let Err(e) = self.queue.fail(job.id, &error_msg, chrono_delay, ttl).await {
                    tracing::error!(job_id = %job.id, error = %e, "Failed to record job failure");
                }

                ExecutionResult::Failed {
                    error: error_msg,
                    retryable: should_retry,
                }
            }
            // Timed out: the handler future was dropped. Note the retry delay
            // here is a fixed 60s; the per-type backoff policy is not consulted.
            Err(_) => {
                let error_msg = format!("Job timed out after {:?}", job_timeout);
                let should_retry = job.attempts < job.max_attempts;

                let retry_delay = if should_retry {
                    Some(chrono::Duration::seconds(60))
                } else {
                    None
                };

                if let Err(e) = self.queue.fail(job.id, &error_msg, retry_delay, ttl).await {
                    tracing::error!(job_id = %job.id, error = %e, "Failed to record job timeout");
                }

                ExecutionResult::TimedOut {
                    retryable: should_retry,
                }
            }
        }
    }

    /// Invokes the registered handler for this job type with the given input.
    async fn run_handler(
        &self,
        entry: &Arc<JobEntry>,
        ctx: &JobContext,
        input: &serde_json::Value,
    ) -> forge_core::Result<serde_json::Value> {
        (entry.handler)(ctx, input.clone()).await
    }

    /// Invokes the job type's compensation callback after a cancellation.
    async fn run_compensation(
        &self,
        entry: &Arc<JobEntry>,
        ctx: &JobContext,
        input: &serde_json::Value,
        reason: &str,
    ) -> forge_core::Result<()> {
        (entry.compensation)(ctx, input.clone(), reason).await
    }

    /// Returns the job's stored cancel reason, or `fallback` if none was recorded.
    fn cancellation_reason(job: &JobRecord, fallback: &str) -> String {
        job.cancel_reason
            .clone()
            .unwrap_or_else(|| fallback.to_string())
    }
}
241
/// Terminal outcome of a single `JobExecutor::execute` run.
#[derive(Debug)]
pub enum ExecutionResult {
    /// Handler finished successfully with this JSON output.
    Completed { output: serde_json::Value },
    /// Handler or a queue operation failed; `retryable` indicates whether
    /// another attempt is allowed.
    Failed { error: String, retryable: bool },
    /// Handler exceeded the job type's timeout.
    TimedOut { retryable: bool },
    /// Job was cancelled, either before starting or mid-run.
    Cancelled { reason: String },
}
254
255impl ExecutionResult {
256 pub fn is_success(&self) -> bool {
258 matches!(self, Self::Completed { .. })
259 }
260
261 pub fn should_retry(&self) -> bool {
263 match self {
264 Self::Failed { retryable, .. } => *retryable,
265 Self::TimedOut { retryable } => *retryable,
266 _ => false,
267 }
268 }
269}
270
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared fixture: builds a `Failed` result with the given retry flag.
    fn failed_result(retryable: bool) -> ExecutionResult {
        ExecutionResult::Failed {
            error: "test error".to_string(),
            retryable,
        }
    }

    #[test]
    fn test_execution_result_success() {
        let completed = ExecutionResult::Completed {
            output: serde_json::json!({}),
        };
        assert!(completed.is_success());
        assert!(!completed.should_retry());
    }

    #[test]
    fn test_execution_result_failed_retryable() {
        let res = failed_result(true);
        assert!(res.should_retry());
        assert!(!res.is_success());
    }

    #[test]
    fn test_execution_result_failed_not_retryable() {
        let res = failed_result(false);
        assert!(!res.should_retry());
        assert!(!res.is_success());
    }

    #[test]
    fn test_execution_result_timeout() {
        let timed_out = ExecutionResult::TimedOut { retryable: true };
        assert!(timed_out.should_retry());
        assert!(!timed_out.is_success());
    }

    #[test]
    fn test_execution_result_cancelled() {
        let cancelled = ExecutionResult::Cancelled {
            reason: "user request".to_string(),
        };
        assert!(!cancelled.is_success());
        assert!(!cancelled.should_retry());
    }
}