
forge_runtime/jobs/executor.rs

use std::sync::Arc;

use forge_core::CircuitBreakerClient;
use forge_core::job::{JobContext, ProgressUpdate};
use tokio::time::timeout;

use super::queue::{JobQueue, JobRecord};
use super::registry::{JobEntry, JobRegistry};

/// Executes jobs with timeout and retry handling.
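///
/// Each claimed job is marked as running, handed to its registered handler
/// under that handler's configured timeout, and its progress updates are
/// streamed back to the queue while it runs. The terminal state (completed,
/// failed, cancelled, or timed out) is written back through the same
/// [`JobQueue`].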
pub struct JobExecutor {
    queue: JobQueue,
    registry: Arc<JobRegistry>,
    db_pool: sqlx::PgPool,
    http_client: CircuitBreakerClient,
}

impl JobExecutor {
    /// Create a new job executor.
    pub fn new(queue: JobQueue, registry: JobRegistry, db_pool: sqlx::PgPool) -> Self {
        Self {
            queue,
            registry: Arc::new(registry),
            db_pool,
            http_client: CircuitBreakerClient::with_defaults(reqwest::Client::new()),
        }
    }

    /// Execute a claimed job.
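    ///
    /// The terminal state is persisted through the [`JobQueue`] inside this
    /// method; the returned [`ExecutionResult`] only reports the outcome to
    /// the caller.
    ///
    /// A minimal sketch of a worker loop (the `claim_next` call is
    /// hypothetical and stands in for however claimed [`JobRecord`]s are
    /// obtained from the queue):
    ///
    /// ```ignore
    /// while let Some(job) = queue.claim_next().await? {
    ///     let result = executor.execute(&job).await;
    ///     if result.should_retry() {
    ///         tracing::warn!(job_id = %job.id, "job failed and will be retried");
    ///     }
    /// }
    /// ```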
    pub async fn execute(&self, job: &JobRecord) -> ExecutionResult {
        let entry = match self.registry.get(&job.job_type) {
            Some(e) => e,
            None => {
                return ExecutionResult::Failed {
                    error: format!("Unknown job type: {}", job.job_type),
                    retryable: false,
                };
            }
        };

        if matches!(job.status, forge_core::job::JobStatus::Cancelled) {
            return ExecutionResult::Cancelled {
                reason: Self::cancellation_reason(job, "Job cancelled"),
            };
        }

        // Mark job as running
        if let Err(e) = self.queue.start(job.id).await {
            if matches!(e, sqlx::Error::RowNotFound) {
                return ExecutionResult::Cancelled {
                    reason: Self::cancellation_reason(job, "Job cancelled"),
                };
            }
            return ExecutionResult::Failed {
                error: format!("Failed to start job: {}", e),
                retryable: true,
            };
        }

        // Set up progress channel
        let (progress_tx, progress_rx) = std::sync::mpsc::channel::<ProgressUpdate>();

        // Spawn task to consume progress updates and save to database
        // Use try_recv() with async sleep to avoid blocking the tokio runtime
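        // Note: the progress channel handed to JobContext is a std::sync::mpsc
        // sender, so this task polls with try_recv(); the 50ms sleep below bounds
        // how stale a progress update can be before it reaches the database.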
        let progress_queue = self.queue.clone();
        let progress_job_id = job.id;
        tokio::spawn(async move {
            loop {
                match progress_rx.try_recv() {
                    Ok(update) => {
                        if let Err(e) = progress_queue
                            .update_progress(
                                progress_job_id,
                                update.percentage as i32,
                                &update.message,
                            )
                            .await
                        {
                            tracing::warn!(
                                job_id = %progress_job_id,
                                "Failed to update job progress: {}",
                                e
                            );
                        }
                    }
                    Err(std::sync::mpsc::TryRecvError::Empty) => {
                        // No message yet, sleep briefly and check again
                        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
                    }
                    Err(std::sync::mpsc::TryRecvError::Disconnected) => {
                        // Sender dropped (job finished), exit loop
                        break;
                    }
                }
            }
        });

        // Create job context with progress channel
        let ctx = JobContext::new(
            job.id,
            job.job_type.clone(),
            job.attempts as u32,
            job.max_attempts as u32,
            self.db_pool.clone(),
            self.http_client.inner().clone(),
        )
        .with_saved(job.job_context.clone())
        .with_progress(progress_tx);

        // Execute with timeout
        let job_timeout = entry.info.timeout;
        let result = timeout(job_timeout, self.run_handler(&entry, &ctx, &job.input)).await;

        let ttl = entry.info.ttl;

        match result {
            Ok(Ok(output)) => {
                // Job completed successfully
                if let Err(e) = self.queue.complete(job.id, output.clone(), ttl).await {
                    tracing::error!("Failed to mark job {} as complete: {}", job.id, e);
                }
                ExecutionResult::Completed { output }
            }
            Ok(Err(e)) => {
                // Job failed
                let error_msg = e.to_string();
                // Treat the failure as a cancellation if the handler returned an
                // explicit cancellation error, or if cancellation was requested
                // while it was running.
                let cancel_requested = match ctx.is_cancel_requested().await {
                    Ok(value) => value,
                    Err(err) => {
                        tracing::warn!(
                            job_id = %job.id,
                            "Failed to check cancellation status: {}",
                            err
                        );
                        false
                    }
                };
                if matches!(e, forge_core::ForgeError::JobCancelled(_)) || cancel_requested {
                    let reason = Self::cancellation_reason(job, "Job cancellation requested");
                    if let Err(e) = self.queue.cancel(job.id, Some(&reason), ttl).await {
                        tracing::error!("Failed to mark job {} as cancelled: {}", job.id, e);
                    }
                    if let Err(e) = self
                        .run_compensation(&entry, &ctx, &job.input, &reason)
                        .await
                    {
                        tracing::error!("Compensation failed for job {}: {}", job.id, e);
                    }
                    return ExecutionResult::Cancelled { reason };
                }
                let should_retry = job.attempts < job.max_attempts;

                let retry_delay = if should_retry {
                    Some(entry.info.retry.calculate_backoff(job.attempts as u32))
                } else {
                    None
                };

                let chrono_delay = retry_delay.map(|d| {
                    chrono::Duration::from_std(d).unwrap_or(chrono::Duration::seconds(60))
                });

                if let Err(e) = self.queue.fail(job.id, &error_msg, chrono_delay, ttl).await {
                    tracing::error!("Failed to mark job {} as failed: {}", job.id, e);
                }

                ExecutionResult::Failed {
                    error: error_msg,
                    retryable: should_retry,
                }
            }
            Err(_) => {
                // Timeout
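                // Note: `timeout` drops the handler future here, so any in-flight
                // work is simply abandoned; no compensation handler runs for a
                // timed-out attempt.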
                let error_msg = format!("Job timed out after {:?}", job_timeout);
                let should_retry = job.attempts < job.max_attempts;

                let retry_delay = if should_retry {
                    Some(chrono::Duration::seconds(60))
                } else {
                    None
                };

                if let Err(e) = self.queue.fail(job.id, &error_msg, retry_delay, ttl).await {
                    tracing::error!("Failed to mark job {} as timed out: {}", job.id, e);
                }

                ExecutionResult::TimedOut {
                    retryable: should_retry,
                }
            }
        }
    }

    /// Run the job handler.
    async fn run_handler(
        &self,
        entry: &Arc<JobEntry>,
        ctx: &JobContext,
        input: &serde_json::Value,
    ) -> forge_core::Result<serde_json::Value> {
        (entry.handler)(ctx, input.clone()).await
    }

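    /// Run the registered compensation handler for a cancelled job, passing the
    /// cancellation reason so the handler can undo partial work.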
    async fn run_compensation(
        &self,
        entry: &Arc<JobEntry>,
        ctx: &JobContext,
        input: &serde_json::Value,
        reason: &str,
    ) -> forge_core::Result<()> {
        (entry.compensation)(ctx, input.clone(), reason).await
    }

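    /// Use the cancel reason recorded on the job, if any, falling back to the
    /// provided default message.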
    fn cancellation_reason(job: &JobRecord, fallback: &str) -> String {
        job.cancel_reason
            .clone()
            .unwrap_or_else(|| fallback.to_string())
    }
}

/// Result of job execution.
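///
/// A small sketch of consuming the result; the `use` path is an assumption and
/// may differ from how this type is actually re-exported:
///
/// ```ignore
/// use forge_runtime::jobs::executor::ExecutionResult;
///
/// fn outcome_label(result: &ExecutionResult) -> &'static str {
///     match result {
///         ExecutionResult::Completed { .. } => "completed",
///         ExecutionResult::Cancelled { .. } => "cancelled",
///         _ if result.should_retry() => "will retry",
///         _ => "failed permanently",
///     }
/// }
/// ```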
#[derive(Debug)]
pub enum ExecutionResult {
    /// Job completed successfully.
    Completed { output: serde_json::Value },
    /// Job failed.
    Failed { error: String, retryable: bool },
    /// Job timed out.
    TimedOut { retryable: bool },
    /// Job cancelled.
    Cancelled { reason: String },
}

impl ExecutionResult {
    /// Check if execution was successful.
    pub fn is_success(&self) -> bool {
        matches!(self, Self::Completed { .. })
    }

    /// Check if the job should be retried.
    pub fn should_retry(&self) -> bool {
        match self {
            Self::Failed { retryable, .. } => *retryable,
            Self::TimedOut { retryable } => *retryable,
            _ => false,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_execution_result_success() {
        let result = ExecutionResult::Completed {
            output: serde_json::json!({}),
        };
        assert!(result.is_success());
        assert!(!result.should_retry());
    }

    #[test]
    fn test_execution_result_failed_retryable() {
        let result = ExecutionResult::Failed {
            error: "test error".to_string(),
            retryable: true,
        };
        assert!(!result.is_success());
        assert!(result.should_retry());
    }

    #[test]
    fn test_execution_result_failed_not_retryable() {
        let result = ExecutionResult::Failed {
            error: "test error".to_string(),
            retryable: false,
        };
        assert!(!result.is_success());
        assert!(!result.should_retry());
    }

    #[test]
    fn test_execution_result_timeout() {
        let result = ExecutionResult::TimedOut { retryable: true };
        assert!(!result.is_success());
        assert!(result.should_retry());
    }

    #[test]
    fn test_execution_result_cancelled() {
        let result = ExecutionResult::Cancelled {
            reason: "user request".to_string(),
        };
        assert!(!result.is_success());
        assert!(!result.should_retry());
    }
}