Skip to main content

forge_runtime/jobs/
executor.rs

use std::sync::Arc;
use std::time::Duration;

use forge_core::job::{JobContext, ProgressUpdate};
use tokio::time::timeout;

use super::queue::{JobQueue, JobRecord};
use super::registry::{JobEntry, JobRegistry};
8
9/// Executes jobs with timeout and retry handling.
10pub struct JobExecutor {
11    queue: JobQueue,
12    registry: Arc<JobRegistry>,
13    db_pool: sqlx::PgPool,
14    http_client: reqwest::Client,
15}
16
17impl JobExecutor {
18    /// Create a new job executor.
19    pub fn new(queue: JobQueue, registry: JobRegistry, db_pool: sqlx::PgPool) -> Self {
20        Self {
21            queue,
22            registry: Arc::new(registry),
23            db_pool,
24            http_client: reqwest::Client::new(),
25        }
26    }
27
28    /// Execute a claimed job.
29    pub async fn execute(&self, job: &JobRecord) -> ExecutionResult {
30        let entry = match self.registry.get(&job.job_type) {
31            Some(e) => e,
32            None => {
33                return ExecutionResult::Failed {
34                    error: format!("Unknown job type: {}", job.job_type),
35                    retryable: false,
36                };
37            }
38        };
39
40        // Mark job as running
41        if let Err(e) = self.queue.start(job.id).await {
42            return ExecutionResult::Failed {
43                error: format!("Failed to start job: {}", e),
44                retryable: true,
45            };
46        }
47
48        // Set up progress channel
49        let (progress_tx, progress_rx) = std::sync::mpsc::channel::<ProgressUpdate>();
50
51        // Spawn task to consume progress updates and save to database
52        // Use try_recv() with async sleep to avoid blocking the tokio runtime
53        let progress_queue = self.queue.clone();
54        let progress_job_id = job.id;
55        tokio::spawn(async move {
56            loop {
57                match progress_rx.try_recv() {
58                    Ok(update) => {
59                        if let Err(e) = progress_queue
60                            .update_progress(
61                                progress_job_id,
62                                update.percentage as i32,
63                                &update.message,
64                            )
65                            .await
66                        {
67                            tracing::warn!(
68                                job_id = %progress_job_id,
69                                "Failed to update job progress: {}",
70                                e
71                            );
72                        }
73                    }
74                    Err(std::sync::mpsc::TryRecvError::Empty) => {
75                        // No message yet, sleep briefly and check again
76                        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
77                    }
78                    Err(std::sync::mpsc::TryRecvError::Disconnected) => {
79                        // Sender dropped (job finished), exit loop
80                        break;
81                    }
82                }
83            }
84        });
85
86        // Create job context with progress channel
87        let ctx = JobContext::new(
88            job.id,
89            job.job_type.clone(),
90            job.attempts as u32,
91            job.max_attempts as u32,
92            self.db_pool.clone(),
93            self.http_client.clone(),
94        )
95        .with_progress(progress_tx);
96
97        // Execute with timeout
98        let job_timeout = entry.info.timeout;
99        let result = timeout(job_timeout, self.run_handler(&entry, &ctx, &job.input)).await;
100
101        match result {
102            Ok(Ok(output)) => {
103                // Job completed successfully
104                if let Err(e) = self.queue.complete(job.id, output.clone()).await {
105                    tracing::error!("Failed to mark job {} as complete: {}", job.id, e);
106                }
107                ExecutionResult::Completed { output }
108            }
109            Ok(Err(e)) => {
110                // Job failed
111                let error_msg = e.to_string();
112                let should_retry = job.attempts < job.max_attempts;
113
114                let retry_delay = if should_retry {
115                    Some(entry.info.retry.calculate_backoff(job.attempts as u32))
116                } else {
117                    None
118                };
119
120                let chrono_delay = retry_delay.map(|d| {
121                    chrono::Duration::from_std(d).unwrap_or(chrono::Duration::seconds(60))
122                });
123
124                if let Err(e) = self.queue.fail(job.id, &error_msg, chrono_delay).await {
125                    tracing::error!("Failed to mark job {} as failed: {}", job.id, e);
126                }
127
128                ExecutionResult::Failed {
129                    error: error_msg,
130                    retryable: should_retry,
131                }
132            }
133            Err(_) => {
134                // Timeout
135                let error_msg = format!("Job timed out after {:?}", job_timeout);
136                let should_retry = job.attempts < job.max_attempts;
137
138                let retry_delay = if should_retry {
139                    Some(chrono::Duration::seconds(60))
140                } else {
141                    None
142                };
143
144                if let Err(e) = self.queue.fail(job.id, &error_msg, retry_delay).await {
145                    tracing::error!("Failed to mark job {} as timed out: {}", job.id, e);
146                }
147
148                ExecutionResult::TimedOut {
149                    retryable: should_retry,
150                }
151            }
152        }
153    }
154
155    /// Run the job handler.
156    async fn run_handler(
157        &self,
158        entry: &Arc<JobEntry>,
159        ctx: &JobContext,
160        input: &serde_json::Value,
161    ) -> forge_core::Result<serde_json::Value> {
162        (entry.handler)(ctx, input.clone()).await
163    }
164}
165
166/// Result of job execution.
167#[derive(Debug)]
168pub enum ExecutionResult {
169    /// Job completed successfully.
170    Completed { output: serde_json::Value },
171    /// Job failed.
172    Failed { error: String, retryable: bool },
173    /// Job timed out.
174    TimedOut { retryable: bool },
175}
176
177impl ExecutionResult {
178    /// Check if execution was successful.
179    pub fn is_success(&self) -> bool {
180        matches!(self, Self::Completed { .. })
181    }
182
183    /// Check if the job should be retried.
184    pub fn should_retry(&self) -> bool {
185        match self {
186            Self::Failed { retryable, .. } => *retryable,
187            Self::TimedOut { retryable } => *retryable,
188            _ => false,
189        }
190    }
191}
192
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_execution_result_success() {
        let result = ExecutionResult::Completed {
            output: serde_json::json!({}),
        };
        assert!(result.is_success());
        assert!(!result.should_retry());
    }

    #[test]
    fn test_execution_result_failed_retryable() {
        let result = ExecutionResult::Failed {
            error: "test error".to_string(),
            retryable: true,
        };
        assert!(!result.is_success());
        assert!(result.should_retry());
    }

    #[test]
    fn test_execution_result_failed_not_retryable() {
        let result = ExecutionResult::Failed {
            error: "test error".to_string(),
            retryable: false,
        };
        assert!(!result.is_success());
        assert!(!result.should_retry());
    }

    #[test]
    fn test_execution_result_timeout() {
        let result = ExecutionResult::TimedOut { retryable: true };
        assert!(!result.is_success());
        assert!(result.should_retry());
    }

    /// A timeout on the final attempt must not be reported as retryable.
    #[test]
    fn test_execution_result_timeout_not_retryable() {
        let result = ExecutionResult::TimedOut { retryable: false };
        assert!(!result.is_success());
        assert!(!result.should_retry());
    }
}