use std::sync::Arc;
use std::time::Duration;

use forge_core::CircuitBreakerClient;
use forge_core::job::{JobContext, ProgressUpdate};
use tokio::time::timeout;

use super::queue::{JobQueue, JobRecord};
use super::registry::{JobEntry, JobRegistry};

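/// Executes jobs pulled from the queue using the handlers registered in the
/// job registry, and reports each attempt's outcome back to the queue.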
pub struct JobExecutor {
    queue: JobQueue,
    registry: Arc<JobRegistry>,
    db_pool: sqlx::PgPool,
    http_client: CircuitBreakerClient,
}

impl JobExecutor {
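    /// How often a running job's heartbeat is refreshed while its handler executes.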
    const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(30);

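    /// Builds an executor over the given queue and handler registry, using a
    /// circuit-breaker-wrapped HTTP client with default settings.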
    pub fn new(queue: JobQueue, registry: JobRegistry, db_pool: sqlx::PgPool) -> Self {
        Self {
            queue,
            registry: Arc::new(registry),
            db_pool,
            http_client: CircuitBreakerClient::with_defaults(reqwest::Client::new()),
        }
    }

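    /// Runs a single job attempt end to end: looks up the handler, marks the job
    /// as started, forwards progress updates and heartbeats while the handler runs,
    /// enforces the registered timeout, and records the outcome on the queue.
    ///
    /// A rough usage sketch (setup and module paths elided, hence `ignore`):
    ///
    /// ```ignore
    /// let result = executor.execute(&job).await;
    /// if result.should_retry() {
    ///     // the queue has already been scheduled for a retry via `fail(...)`
    /// }
    /// ```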
    pub async fn execute(&self, job: &JobRecord) -> ExecutionResult {
        let entry = match self.registry.get(&job.job_type) {
            Some(e) => e,
            None => {
                return ExecutionResult::Failed {
                    error: format!("Unknown job type: {}", job.job_type),
                    retryable: false,
                };
            }
        };

        if matches!(job.status, forge_core::job::JobStatus::Cancelled) {
            return ExecutionResult::Cancelled {
                reason: Self::cancellation_reason(job, "Job cancelled"),
            };
        }

        if let Err(e) = self.queue.start(job.id).await {
            if matches!(e, sqlx::Error::RowNotFound) {
                return ExecutionResult::Cancelled {
                    reason: Self::cancellation_reason(job, "Job cancelled"),
                };
            }
            return ExecutionResult::Failed {
                error: format!("Failed to start job: {}", e),
                retryable: true,
            };
        }

        let (progress_tx, progress_rx) = std::sync::mpsc::channel::<ProgressUpdate>();

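        // Bridge progress updates from the handler's synchronous sender onto the
        // queue. The task polls with `try_recv` because the std mpsc receiver has no
        // async API; it exits once the sender (owned by the job context) is dropped.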
        let progress_queue = self.queue.clone();
        let progress_job_id = job.id;
        tokio::spawn(async move {
            loop {
                match progress_rx.try_recv() {
                    Ok(update) => {
                        if let Err(e) = progress_queue
                            .update_progress(
                                progress_job_id,
                                update.percentage as i32,
                                &update.message,
                            )
                            .await
                        {
                            tracing::trace!(job_id = %progress_job_id, error = %e, "Failed to update job progress");
                        }
                    }
                    Err(std::sync::mpsc::TryRecvError::Empty) => {
                        tokio::time::sleep(Duration::from_millis(50)).await;
                    }
                    Err(std::sync::mpsc::TryRecvError::Disconnected) => {
                        break;
                    }
                }
            }
        });

        let ctx = JobContext::new(
            job.id,
            job.job_type.clone(),
            job.attempts as u32,
            job.max_attempts as u32,
            self.db_pool.clone(),
            self.http_client.inner().clone(),
        )
        .with_saved(job.job_context.clone())
        .with_progress(progress_tx);
        let heartbeat_queue = self.queue.clone();
        let heartbeat_job_id = job.id;
        let (heartbeat_stop_tx, mut heartbeat_stop_rx) = tokio::sync::watch::channel(false);
        let heartbeat_task = tokio::spawn(async move {
            loop {
                tokio::select! {
                    _ = tokio::time::sleep(Self::HEARTBEAT_INTERVAL) => {
                        if let Err(e) = heartbeat_queue.heartbeat(heartbeat_job_id).await {
                            tracing::trace!(job_id = %heartbeat_job_id, error = %e, "Failed to update job heartbeat");
                        }
                    }
                    changed = heartbeat_stop_rx.changed() => {
                        if changed.is_err() || *heartbeat_stop_rx.borrow() {
                            break;
                        }
                    }
                }
            }
        });

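        // Run the handler under the timeout configured for this job type.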
        let job_timeout = entry.info.timeout;
        let result = timeout(job_timeout, self.run_handler(&entry, &ctx, &job.input)).await;

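        // The handler is done (or timed out): stop the heartbeat task before
        // recording the outcome.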
        let _ = heartbeat_stop_tx.send(true);
        let _ = heartbeat_task.await;

        let ttl = entry.info.ttl;

        match result {
            Ok(Ok(output)) => {
                if let Err(e) = self.queue.complete(job.id, output.clone(), ttl).await {
                    tracing::debug!(job_id = %job.id, error = %e, "Failed to mark job as complete");
                }
                ExecutionResult::Completed { output }
            }
            Ok(Err(e)) => {
                let error_msg = e.to_string();
                let cancel_requested = match ctx.is_cancel_requested().await {
                    Ok(value) => value,
                    Err(err) => {
                        tracing::trace!(job_id = %job.id, error = %err, "Failed to check cancellation status");
                        false
                    }
                };
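                // Treat explicit JobCancelled errors and pending cancel requests the
                // same way: record the cancellation and run the handler's compensation.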
                if matches!(e, forge_core::ForgeError::JobCancelled(_)) || cancel_requested {
                    let reason = Self::cancellation_reason(job, "Job cancellation requested");
                    let _ = self.queue.cancel(job.id, Some(&reason), ttl).await;
                    if let Err(e) = self
                        .run_compensation(&entry, &ctx, &job.input, &reason)
                        .await
                    {
                        tracing::warn!(job_id = %job.id, error = %e, "Job compensation failed");
                    }
                    return ExecutionResult::Cancelled { reason };
                }
                let should_retry = job.attempts < job.max_attempts;

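                // While attempts remain, schedule the next run using the handler's
                // configured backoff; fall back to 60s if the delay cannot be
                // represented as a chrono::Duration.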
                let retry_delay = if should_retry {
                    Some(entry.info.retry.calculate_backoff(job.attempts as u32))
                } else {
                    None
                };

                let chrono_delay = retry_delay.map(|d| {
                    chrono::Duration::from_std(d).unwrap_or(chrono::Duration::seconds(60))
                });

                let _ = self.queue.fail(job.id, &error_msg, chrono_delay, ttl).await;

                ExecutionResult::Failed {
                    error: error_msg,
                    retryable: should_retry,
                }
            }
            Err(_) => {
                let error_msg = format!("Job timed out after {:?}", job_timeout);
                let should_retry = job.attempts < job.max_attempts;

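                // Timed-out attempts use a flat 60s retry delay while attempts remain.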
                let retry_delay = if should_retry {
                    Some(chrono::Duration::seconds(60))
                } else {
                    None
                };

                let _ = self.queue.fail(job.id, &error_msg, retry_delay, ttl).await;

                ExecutionResult::TimedOut {
                    retryable: should_retry,
                }
            }
        }
    }

    async fn run_handler(
        &self,
        entry: &Arc<JobEntry>,
        ctx: &JobContext,
        input: &serde_json::Value,
    ) -> forge_core::Result<serde_json::Value> {
        (entry.handler)(ctx, input.clone()).await
    }

    async fn run_compensation(
        &self,
        entry: &Arc<JobEntry>,
        ctx: &JobContext,
        input: &serde_json::Value,
        reason: &str,
    ) -> forge_core::Result<()> {
        (entry.compensation)(ctx, input.clone(), reason).await
    }

    fn cancellation_reason(job: &JobRecord, fallback: &str) -> String {
        job.cancel_reason
            .clone()
            .unwrap_or_else(|| fallback.to_string())
    }
}

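/// Outcome of a single execution attempt, as reported back to the worker loop.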
#[derive(Debug)]
pub enum ExecutionResult {
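    /// The handler returned successfully and the output was recorded.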
    Completed { output: serde_json::Value },
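    /// The handler returned an error; `retryable` reflects whether attempts remain.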
    Failed { error: String, retryable: bool },
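    /// The handler exceeded the timeout configured for its job type.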
    TimedOut { retryable: bool },
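    /// The job was cancelled before or during execution.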
    Cancelled { reason: String },
}

impl ExecutionResult {
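    /// Returns `true` only for `Completed` results.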
    pub fn is_success(&self) -> bool {
        matches!(self, Self::Completed { .. })
    }

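    /// Returns `true` when the attempt failed or timed out and attempts remain.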
    pub fn should_retry(&self) -> bool {
        match self {
            Self::Failed { retryable, .. } => *retryable,
            Self::TimedOut { retryable } => *retryable,
            _ => false,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_execution_result_success() {
        let result = ExecutionResult::Completed {
            output: serde_json::json!({}),
        };
        assert!(result.is_success());
        assert!(!result.should_retry());
    }

    #[test]
    fn test_execution_result_failed_retryable() {
        let result = ExecutionResult::Failed {
            error: "test error".to_string(),
            retryable: true,
        };
        assert!(!result.is_success());
        assert!(result.should_retry());
    }

    #[test]
    fn test_execution_result_failed_not_retryable() {
        let result = ExecutionResult::Failed {
            error: "test error".to_string(),
            retryable: false,
        };
        assert!(!result.is_success());
        assert!(!result.should_retry());
    }

    #[test]
    fn test_execution_result_timeout() {
        let result = ExecutionResult::TimedOut { retryable: true };
        assert!(!result.is_success());
        assert!(result.should_retry());
    }

    #[test]
    fn test_execution_result_cancelled() {
        let result = ExecutionResult::Cancelled {
            reason: "user request".to_string(),
        };
        assert!(!result.is_success());
        assert!(!result.should_retry());
    }
}