// forge_runtime/jobs/executor.rs

1use std::sync::Arc;
2use std::time::Duration;
3
4use forge_core::CircuitBreakerClient;
5use forge_core::job::{JobContext, ProgressUpdate};
6use tokio::time::timeout;
7
8use super::queue::{JobQueue, JobRecord};
9use super::registry::{JobEntry, JobRegistry};
10
11/// Executes jobs with timeout and retry handling.
12pub struct JobExecutor {
13    queue: JobQueue,
14    registry: Arc<JobRegistry>,
15    db_pool: sqlx::PgPool,
16    http_client: CircuitBreakerClient,
17}
18
19impl JobExecutor {
20    const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(30);
21
22    /// Create a new job executor.
23    pub fn new(queue: JobQueue, registry: JobRegistry, db_pool: sqlx::PgPool) -> Self {
24        Self {
25            queue,
26            registry: Arc::new(registry),
27            db_pool,
28            http_client: CircuitBreakerClient::with_defaults(reqwest::Client::new()),
29        }
30    }
31
32    /// Execute a claimed job.
33    pub async fn execute(&self, job: &JobRecord) -> ExecutionResult {
34        let entry = match self.registry.get(&job.job_type) {
35            Some(e) => e,
36            None => {
37                return ExecutionResult::Failed {
38                    error: format!("Unknown job type: {}", job.job_type),
39                    retryable: false,
40                };
41            }
42        };
43
44        if matches!(job.status, forge_core::job::JobStatus::Cancelled) {
45            return ExecutionResult::Cancelled {
46                reason: Self::cancellation_reason(job, "Job cancelled"),
47            };
48        }
49
50        // Mark job as running
51        if let Err(e) = self.queue.start(job.id).await {
52            if matches!(e, sqlx::Error::RowNotFound) {
53                return ExecutionResult::Cancelled {
54                    reason: Self::cancellation_reason(job, "Job cancelled"),
55                };
56            }
57            return ExecutionResult::Failed {
58                error: format!("Failed to start job: {}", e),
59                retryable: true,
60            };
61        }
62
63        // Set up progress channel
64        let (progress_tx, progress_rx) = std::sync::mpsc::channel::<ProgressUpdate>();
65
66        // Spawn task to consume progress updates and save to database
67        // Use try_recv() with async sleep to avoid blocking the tokio runtime
68        let progress_queue = self.queue.clone();
69        let progress_job_id = job.id;
70        tokio::spawn(async move {
71            loop {
72                match progress_rx.try_recv() {
73                    Ok(update) => {
74                        if let Err(e) = progress_queue
75                            .update_progress(
76                                progress_job_id,
77                                update.percentage as i32,
78                                &update.message,
79                            )
80                            .await
81                        {
82                            tracing::debug!(job_id = %progress_job_id, error = %e, "Failed to update job progress");
83                        }
84                    }
85                    Err(std::sync::mpsc::TryRecvError::Empty) => {
86                        // No message yet, sleep briefly and check again
87                        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
88                    }
89                    Err(std::sync::mpsc::TryRecvError::Disconnected) => {
90                        // Sender dropped (job finished), exit loop
91                        break;
92                    }
93                }
94            }
95        });
96
97        // Create job context with progress channel
98        let mut ctx = JobContext::new(
99            job.id,
100            job.job_type.clone(),
101            job.attempts as u32,
102            job.max_attempts as u32,
103            self.db_pool.clone(),
104            self.http_client.clone(),
105        )
106        .with_saved(job.job_context.clone())
107        .with_progress(progress_tx);
108        ctx.set_http_timeout(entry.info.http_timeout);
109
110        // Keepalive heartbeat prevents stale cleanup from reclaiming healthy long jobs.
111        let heartbeat_queue = self.queue.clone();
112        let heartbeat_job_id = job.id;
113        let (heartbeat_stop_tx, mut heartbeat_stop_rx) = tokio::sync::watch::channel(false);
114        let heartbeat_task = tokio::spawn(async move {
115            loop {
116                tokio::select! {
117                    _ = tokio::time::sleep(Self::HEARTBEAT_INTERVAL) => {
118                        if let Err(e) = heartbeat_queue.heartbeat(heartbeat_job_id).await {
119                            tracing::debug!(job_id = %heartbeat_job_id, error = %e, "Failed to update job heartbeat");
120                        }
121                    }
122                    changed = heartbeat_stop_rx.changed() => {
123                        if changed.is_err() || *heartbeat_stop_rx.borrow() {
124                            break;
125                        }
126                    }
127                }
128            }
129        });
130
131        // Execute with timeout
132        let job_timeout = entry.info.timeout;
133        let result = timeout(job_timeout, self.run_handler(&entry, &ctx, &job.input)).await;
134
135        let _ = heartbeat_stop_tx.send(true);
136        let _ = heartbeat_task.await;
137
138        let ttl = entry.info.ttl;
139
140        match result {
141            Ok(Ok(output)) => {
142                // Job completed successfully
143                if let Err(e) = self.queue.complete(job.id, output.clone(), ttl).await {
144                    tracing::debug!(job_id = %job.id, error = %e, "Failed to mark job as complete");
145                }
146                ExecutionResult::Completed { output }
147            }
148            Ok(Err(e)) => {
149                // Job failed
150                let error_msg = e.to_string();
151                // Accepts either an explicit cancellation error or a late cancellation request.
152                let cancel_requested = match ctx.is_cancel_requested().await {
153                    Ok(value) => value,
154                    Err(err) => {
155                        tracing::debug!(job_id = %job.id, error = %err, "Failed to check cancellation status");
156                        false
157                    }
158                };
159                if matches!(e, forge_core::ForgeError::JobCancelled(_)) || cancel_requested {
160                    let reason = Self::cancellation_reason(job, "Job cancellation requested");
161                    if let Err(e) = self.queue.cancel(job.id, Some(&reason), ttl).await {
162                        tracing::debug!(job_id = %job.id, error = %e, "Failed to cancel job");
163                    }
164                    if let Err(e) = self
165                        .run_compensation(&entry, &ctx, &job.input, &reason)
166                        .await
167                    {
168                        tracing::error!(job_id = %job.id, error = %e, "Job compensation failed");
169                    }
170                    return ExecutionResult::Cancelled { reason };
171                }
172                let should_retry = job.attempts < job.max_attempts;
173
174                let retry_delay = if should_retry {
175                    Some(entry.info.retry.calculate_backoff(job.attempts as u32))
176                } else {
177                    None
178                };
179
180                let chrono_delay = retry_delay.map(|d| {
181                    chrono::Duration::from_std(d).unwrap_or(chrono::Duration::seconds(60))
182                });
183
184                if let Err(e) = self.queue.fail(job.id, &error_msg, chrono_delay, ttl).await {
185                    tracing::error!(job_id = %job.id, error = %e, "Failed to record job failure");
186                }
187
188                ExecutionResult::Failed {
189                    error: error_msg,
190                    retryable: should_retry,
191                }
192            }
193            Err(_) => {
194                // Timeout
195                let error_msg = format!("Job timed out after {:?}", job_timeout);
196                let should_retry = job.attempts < job.max_attempts;
197
198                let retry_delay = if should_retry {
199                    Some(chrono::Duration::seconds(60))
200                } else {
201                    None
202                };
203
204                if let Err(e) = self.queue.fail(job.id, &error_msg, retry_delay, ttl).await {
205                    tracing::error!(job_id = %job.id, error = %e, "Failed to record job timeout");
206                }
207
208                ExecutionResult::TimedOut {
209                    retryable: should_retry,
210                }
211            }
212        }
213    }
214
215    /// Run the job handler.
216    async fn run_handler(
217        &self,
218        entry: &Arc<JobEntry>,
219        ctx: &JobContext,
220        input: &serde_json::Value,
221    ) -> forge_core::Result<serde_json::Value> {
222        (entry.handler)(ctx, input.clone()).await
223    }
224
225    async fn run_compensation(
226        &self,
227        entry: &Arc<JobEntry>,
228        ctx: &JobContext,
229        input: &serde_json::Value,
230        reason: &str,
231    ) -> forge_core::Result<()> {
232        (entry.compensation)(ctx, input.clone(), reason).await
233    }
234
235    fn cancellation_reason(job: &JobRecord, fallback: &str) -> String {
236        job.cancel_reason
237            .clone()
238            .unwrap_or_else(|| fallback.to_string())
239    }
240}
241
242/// Result of job execution.
243#[derive(Debug)]
244pub enum ExecutionResult {
245    /// Job completed successfully.
246    Completed { output: serde_json::Value },
247    /// Job failed.
248    Failed { error: String, retryable: bool },
249    /// Job timed out.
250    TimedOut { retryable: bool },
251    /// Job cancelled.
252    Cancelled { reason: String },
253}
254
255impl ExecutionResult {
256    /// Check if execution was successful.
257    pub fn is_success(&self) -> bool {
258        matches!(self, Self::Completed { .. })
259    }
260
261    /// Check if the job should be retried.
262    pub fn should_retry(&self) -> bool {
263        match self {
264            Self::Failed { retryable, .. } => *retryable,
265            Self::TimedOut { retryable } => *retryable,
266            _ => false,
267        }
268    }
269}
270
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_execution_result_success() {
        let result = ExecutionResult::Completed {
            output: serde_json::json!({}),
        };
        assert!(result.is_success());
        assert!(!result.should_retry());
    }

    #[test]
    fn test_execution_result_failed_retryable() {
        let result = ExecutionResult::Failed {
            error: "test error".to_string(),
            retryable: true,
        };
        assert!(!result.is_success());
        assert!(result.should_retry());
    }

    #[test]
    fn test_execution_result_failed_not_retryable() {
        let result = ExecutionResult::Failed {
            error: "test error".to_string(),
            retryable: false,
        };
        assert!(!result.is_success());
        assert!(!result.should_retry());
    }

    #[test]
    fn test_execution_result_timeout() {
        let result = ExecutionResult::TimedOut { retryable: true };
        assert!(!result.is_success());
        assert!(result.should_retry());
    }

    // A timeout on the final attempt must not be reported as retryable.
    #[test]
    fn test_execution_result_timeout_not_retryable() {
        let result = ExecutionResult::TimedOut { retryable: false };
        assert!(!result.is_success());
        assert!(!result.should_retry());
    }

    #[test]
    fn test_execution_result_cancelled() {
        let result = ExecutionResult::Cancelled {
            reason: "user request".to_string(),
        };
        assert!(!result.is_success());
        assert!(!result.should_retry());
    }
}