//! Background job worker.
//!
//! forge_runtime/jobs/worker.rs
use std::sync::Arc;
use std::time::Duration;

use tokio::sync::mpsc;
use tracing::Instrument;
use uuid::Uuid;

use super::executor::JobExecutor;
use super::queue::JobQueue;
use super::registry::JobRegistry;
/// Configuration for a background job worker.
#[derive(Debug, Clone)]
pub struct WorkerConfig {
    /// Worker ID; a random UUID is generated when `None`.
    pub id: Option<Uuid>,
    /// Capabilities this worker serves (e.g., ["general", "media"]);
    /// passed to the queue when claiming so only matching jobs are taken.
    pub capabilities: Vec<String>,
    /// Maximum number of jobs executed concurrently.
    pub max_concurrent: usize,
    /// How long to wait between claim attempts.
    pub poll_interval: Duration,
    /// Upper bound on jobs claimed per poll (further capped at runtime by
    /// the number of free concurrency slots).
    pub batch_size: i32,
    /// How often the background cleanup task runs.
    pub stale_cleanup_interval: Duration,
    /// Age after which a claimed-but-unfinished job is released back to pending.
    pub stale_threshold: chrono::Duration,
}
31impl Default for WorkerConfig {
32    fn default() -> Self {
33        Self {
34            id: None,
35            capabilities: vec!["general".to_string()],
36            max_concurrent: 10,
37            poll_interval: Duration::from_millis(100),
38            batch_size: 10,
39            stale_cleanup_interval: Duration::from_secs(60),
40            stale_threshold: chrono::Duration::minutes(5),
41        }
42    }
43}
/// Background job worker.
///
/// Claims jobs from the [`JobQueue`] matching its configured capabilities
/// and executes each on its own Tokio task via a shared [`JobExecutor`].
pub struct Worker {
    /// Identity used when claiming jobs from the queue.
    id: Uuid,
    /// Behavioral knobs: capabilities, concurrency, polling cadence.
    config: WorkerConfig,
    /// Handle to the shared job queue.
    queue: JobQueue,
    /// Executes claimed jobs; shared (`Arc`) with spawned per-job tasks.
    executor: Arc<JobExecutor>,
    /// Populated by `run`; `shutdown` sends on it to request exit.
    /// NOTE(review): `run` takes `&mut self` for its whole lifetime, so
    /// `shutdown(&self)` cannot be called on the same instance while `run`
    /// is executing — confirm callers clone/hold the sender some other way.
    shutdown_tx: Option<mpsc::Sender<()>>,
}
54impl Worker {
55    /// Create a new worker.
56    pub fn new(
57        config: WorkerConfig,
58        queue: JobQueue,
59        registry: JobRegistry,
60        db_pool: sqlx::PgPool,
61    ) -> Self {
62        let id = config.id.unwrap_or_else(Uuid::new_v4);
63        let executor = Arc::new(JobExecutor::new(queue.clone(), registry, db_pool));
64
65        Self {
66            id,
67            config,
68            queue,
69            executor,
70            shutdown_tx: None,
71        }
72    }
73
74    /// Get worker ID.
75    pub fn id(&self) -> Uuid {
76        self.id
77    }
78
79    /// Get worker capabilities.
80    pub fn capabilities(&self) -> &[String] {
81        &self.config.capabilities
82    }
83
84    /// Run the worker (blocks until shutdown).
85    pub async fn run(&mut self) -> Result<(), WorkerError> {
86        let (shutdown_tx, mut shutdown_rx) = mpsc::channel::<()>(1);
87        self.shutdown_tx = Some(shutdown_tx);
88
89        // Semaphore to limit concurrent jobs
90        let semaphore = Arc::new(tokio::sync::Semaphore::new(self.config.max_concurrent));
91
92        // Spawn stale and expired cleanup task
93        let cleanup_queue = self.queue.clone();
94        let cleanup_interval = self.config.stale_cleanup_interval;
95        let stale_threshold = self.config.stale_threshold;
96        tokio::spawn(async move {
97            loop {
98                tokio::time::sleep(cleanup_interval).await;
99
100                // Release stale jobs back to pending
101                if let Err(e) = cleanup_queue.release_stale(stale_threshold).await {
102                    tracing::warn!(error = %e, "Failed to cleanup stale jobs");
103                }
104
105                // Delete expired job records
106                match cleanup_queue.cleanup_expired().await {
107                    Ok(count) if count > 0 => {
108                        tracing::debug!(count, "Cleaned up expired job records");
109                    }
110                    Err(e) => {
111                        tracing::warn!(error = %e, "Failed to cleanup expired jobs");
112                    }
113                    _ => {}
114                }
115            }
116        });
117
118        tracing::debug!(
119            worker_id = %self.id,
120            capabilities = ?self.config.capabilities,
121            "Worker started"
122        );
123
124        loop {
125            tokio::select! {
126                _ = shutdown_rx.recv() => {
127                    tracing::debug!(worker_id = %self.id, "Worker shutting down");
128                    break;
129                }
130                _ = tokio::time::sleep(self.config.poll_interval) => {
131                    // Calculate how many jobs we can claim
132                    let available = semaphore.available_permits();
133                    if available == 0 {
134                        continue;
135                    }
136
137                    let batch_size = (available as i32).min(self.config.batch_size);
138
139                    // Claim jobs
140                    let jobs = match self.queue.claim(
141                        self.id,
142                        &self.config.capabilities,
143                        batch_size,
144                    ).await {
145                        Ok(jobs) => jobs,
146                        Err(e) => {
147                            tracing::warn!(error = %e, "Failed to claim jobs");
148                            continue;
149                        }
150                    };
151
152                    // Process each job
153                    for job in jobs {
154                        let permit = match semaphore.clone().acquire_owned().await {
155                            Ok(p) => p,
156                            Err(_) => {
157                                tracing::error!("Worker semaphore closed, stopping job processing");
158                                break;
159                            }
160                        };
161                        let executor = self.executor.clone();
162                        let job_id = job.id;
163                        let job_type = job.job_type.clone();
164
165                        tokio::spawn(async move {
166                            let start = std::time::Instant::now();
167                            let span = tracing::info_span!(
168                                "job.execute",
169                                job_id = %job_id,
170                                job_type = %job_type,
171                            );
172
173                            let result = executor.execute(&job).instrument(span).await;
174
175                            let duration_secs = start.elapsed().as_secs_f64();
176
177                            match &result {
178                                super::executor::ExecutionResult::Completed { .. } => {
179                                    tracing::info!(job_id = %job_id, job_type = %job_type, duration_ms = (duration_secs * 1000.0) as u64, "Job completed");
180                                    crate::observability::record_job_execution(&job_type, "completed", duration_secs);
181                                }
182                                super::executor::ExecutionResult::Failed { error, retryable } => {
183                                    if *retryable {
184                                        tracing::warn!(job_id = %job_id, job_type = %job_type, error = %error, "Job failed, will retry");
185                                        crate::observability::record_job_execution(&job_type, "retrying", duration_secs);
186                                    } else {
187                                        tracing::error!(job_id = %job_id, job_type = %job_type, error = %error, "Job failed permanently");
188                                        crate::observability::record_job_execution(&job_type, "failed", duration_secs);
189                                    }
190                                }
191                                super::executor::ExecutionResult::TimedOut { retryable } => {
192                                    tracing::error!(job_id = %job_id, job_type = %job_type, will_retry = %retryable, "Job timed out");
193                                    crate::observability::record_job_execution(&job_type, "timeout", duration_secs);
194                                }
195                                super::executor::ExecutionResult::Cancelled { reason } => {
196                                    tracing::info!(job_id = %job_id, job_type = %job_type, reason = %reason, "Job cancelled");
197                                    crate::observability::record_job_execution(&job_type, "cancelled", duration_secs);
198                                }
199                            }
200
201                            drop(permit);
202                        });
203                    }
204                }
205            }
206        }
207
208        Ok(())
209    }
210
211    /// Request graceful shutdown.
212    pub async fn shutdown(&self) {
213        if let Some(ref tx) = self.shutdown_tx {
214            let _ = tx.send(()).await;
215        }
216    }
217}
/// Worker errors.
///
/// NOTE(review): neither variant is constructed anywhere in the code
/// visible here (`run` logs queue/executor failures instead of returning
/// them) — confirm whether other modules build these.
#[derive(Debug, thiserror::Error)]
pub enum WorkerError {
    /// Failure talking to the backing database.
    #[error("Database error: {0}")]
    Database(String),

    /// Failure while executing a job.
    #[error("Job execution error: {0}")]
    Execution(String),
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every documented default should hold — the previous test only
    /// covered three of the seven fields.
    #[test]
    fn test_worker_config_default() {
        let config = WorkerConfig::default();
        assert_eq!(config.id, None);
        assert_eq!(config.capabilities, vec!["general".to_string()]);
        assert_eq!(config.max_concurrent, 10);
        assert_eq!(config.poll_interval, Duration::from_millis(100));
        assert_eq!(config.batch_size, 10);
        assert_eq!(config.stale_cleanup_interval, Duration::from_secs(60));
        assert_eq!(config.stale_threshold, chrono::Duration::minutes(5));
    }

    /// Struct-update syntax overrides only the named fields; the rest keep
    /// their defaults.
    #[test]
    fn test_worker_config_custom() {
        let config = WorkerConfig {
            capabilities: vec!["media".to_string(), "general".to_string()],
            max_concurrent: 4,
            ..Default::default()
        };
        assert_eq!(config.capabilities.len(), 2);
        assert_eq!(config.max_concurrent, 4);
        // Unspecified fields fall back to defaults.
        assert_eq!(config.batch_size, 10);
        assert_eq!(config.poll_interval, Duration::from_millis(100));
    }
}