Skip to main content

forge_runtime/jobs/
worker.rs

1use std::sync::Arc;
2use std::time::Duration;
3
4use tokio::sync::mpsc;
5use uuid::Uuid;
6
7use super::executor::JobExecutor;
8use super::queue::JobQueue;
9use super::registry::JobRegistry;
10
11/// Worker configuration.
12#[derive(Debug, Clone)]
13pub struct WorkerConfig {
14    /// Worker ID (auto-generated if not provided).
15    pub id: Option<Uuid>,
16    /// Worker capabilities (e.g., ["general", "media"]).
17    pub capabilities: Vec<String>,
18    /// Maximum concurrent jobs.
19    pub max_concurrent: usize,
20    /// Poll interval when queue is empty.
21    pub poll_interval: Duration,
22    /// Batch size for claiming jobs.
23    pub batch_size: i32,
24    /// Stale job cleanup interval.
25    pub stale_cleanup_interval: Duration,
26    /// Stale job threshold.
27    pub stale_threshold: chrono::Duration,
28}
29
30impl Default for WorkerConfig {
31    fn default() -> Self {
32        Self {
33            id: None,
34            capabilities: vec!["general".to_string()],
35            max_concurrent: 10,
36            poll_interval: Duration::from_millis(100),
37            batch_size: 10,
38            stale_cleanup_interval: Duration::from_secs(60),
39            stale_threshold: chrono::Duration::minutes(5),
40        }
41    }
42}
43
44/// Background job worker.
45pub struct Worker {
46    id: Uuid,
47    config: WorkerConfig,
48    queue: JobQueue,
49    executor: Arc<JobExecutor>,
50    shutdown_tx: Option<mpsc::Sender<()>>,
51}
52
53impl Worker {
54    /// Create a new worker.
55    pub fn new(
56        config: WorkerConfig,
57        queue: JobQueue,
58        registry: JobRegistry,
59        db_pool: sqlx::PgPool,
60    ) -> Self {
61        let id = config.id.unwrap_or_else(Uuid::new_v4);
62        let executor = Arc::new(JobExecutor::new(queue.clone(), registry, db_pool));
63
64        Self {
65            id,
66            config,
67            queue,
68            executor,
69            shutdown_tx: None,
70        }
71    }
72
73    /// Get worker ID.
74    pub fn id(&self) -> Uuid {
75        self.id
76    }
77
78    /// Get worker capabilities.
79    pub fn capabilities(&self) -> &[String] {
80        &self.config.capabilities
81    }
82
83    /// Run the worker (blocks until shutdown).
84    pub async fn run(&mut self) -> Result<(), WorkerError> {
85        let (shutdown_tx, mut shutdown_rx) = mpsc::channel::<()>(1);
86        self.shutdown_tx = Some(shutdown_tx);
87
88        // Semaphore to limit concurrent jobs
89        let semaphore = Arc::new(tokio::sync::Semaphore::new(self.config.max_concurrent));
90
91        // Spawn stale and expired cleanup task
92        let cleanup_queue = self.queue.clone();
93        let cleanup_interval = self.config.stale_cleanup_interval;
94        let stale_threshold = self.config.stale_threshold;
95        tokio::spawn(async move {
96            loop {
97                tokio::time::sleep(cleanup_interval).await;
98
99                // Release stale jobs back to pending
100                if let Err(e) = cleanup_queue.release_stale(stale_threshold).await {
101                    tracing::warn!(error = %e, "Failed to cleanup stale jobs");
102                }
103
104                // Delete expired job records
105                match cleanup_queue.cleanup_expired().await {
106                    Ok(count) if count > 0 => {
107                        tracing::debug!(count, "Cleaned up expired job records");
108                    }
109                    Err(e) => {
110                        tracing::warn!(error = %e, "Failed to cleanup expired jobs");
111                    }
112                    _ => {}
113                }
114            }
115        });
116
117        tracing::debug!(
118            worker_id = %self.id,
119            capabilities = ?self.config.capabilities,
120            "Worker started"
121        );
122
123        loop {
124            tokio::select! {
125                _ = shutdown_rx.recv() => {
126                    tracing::debug!(worker_id = %self.id, "Worker shutting down");
127                    break;
128                }
129                _ = tokio::time::sleep(self.config.poll_interval) => {
130                    // Calculate how many jobs we can claim
131                    let available = semaphore.available_permits();
132                    if available == 0 {
133                        continue;
134                    }
135
136                    let batch_size = (available as i32).min(self.config.batch_size);
137
138                    // Claim jobs
139                    let jobs = match self.queue.claim(
140                        self.id,
141                        &self.config.capabilities,
142                        batch_size,
143                    ).await {
144                        Ok(jobs) => jobs,
145                        Err(e) => {
146                            tracing::warn!(error = %e, "Failed to claim jobs");
147                            continue;
148                        }
149                    };
150
151                    // Process each job
152                    for job in jobs {
153                        let permit = semaphore.clone().acquire_owned().await.expect("semaphore closed");
154                        let executor = self.executor.clone();
155                        let job_id = job.id;
156                        let job_type = job.job_type.clone();
157
158                        tokio::spawn(async move {
159                            tracing::trace!(job_id = %job_id, job_type = %job_type, "Processing job");
160
161                            let result = executor.execute(&job).await;
162
163                            match &result {
164                                super::executor::ExecutionResult::Completed { .. } => {
165                                    tracing::debug!(job_id = %job_id, job_type = %job_type, "Job completed");
166                                }
167                                super::executor::ExecutionResult::Failed { error, retryable } => {
168                                    if *retryable {
169                                        tracing::debug!(job_id = %job_id, error = %error, "Job failed, will retry");
170                                    } else {
171                                        tracing::warn!(job_id = %job_id, job_type = %job_type, error = %error, "Job failed permanently");
172                                    }
173                                }
174                                super::executor::ExecutionResult::TimedOut { retryable } => {
175                                    tracing::debug!(job_id = %job_id, will_retry = %retryable, "Job timed out");
176                                }
177                                super::executor::ExecutionResult::Cancelled { reason } => {
178                                    tracing::debug!(job_id = %job_id, reason = %reason, "Job cancelled");
179                                }
180                            }
181
182                            drop(permit); // Release semaphore
183                        });
184                    }
185                }
186            }
187        }
188
189        Ok(())
190    }
191
192    /// Request graceful shutdown.
193    pub async fn shutdown(&self) {
194        if let Some(ref tx) = self.shutdown_tx {
195            let _ = tx.send(()).await;
196        }
197    }
198}
199
200/// Worker errors.
201#[derive(Debug, thiserror::Error)]
202pub enum WorkerError {
203    #[error("Database error: {0}")]
204    Database(String),
205
206    #[error("Job execution error: {0}")]
207    Execution(String),
208}
209
#[cfg(test)]
mod tests {
    use super::*;

    /// Every field of the default configuration is pinned, so an accidental change
    /// to a default is caught.
    #[test]
    fn test_worker_config_default() {
        let config = WorkerConfig::default();
        assert!(config.id.is_none());
        assert_eq!(config.capabilities, vec!["general".to_string()]);
        assert_eq!(config.max_concurrent, 10);
        assert_eq!(config.poll_interval, Duration::from_millis(100));
        assert_eq!(config.batch_size, 10);
        assert_eq!(config.stale_cleanup_interval, Duration::from_secs(60));
        assert_eq!(config.stale_threshold, chrono::Duration::minutes(5));
    }

    /// Overridden fields take the new values, and everything else keeps its default.
    #[test]
    fn test_worker_config_custom() {
        let config = WorkerConfig {
            capabilities: vec!["media".to_string(), "general".to_string()],
            max_concurrent: 4,
            ..Default::default()
        };
        assert_eq!(
            config.capabilities,
            vec!["media".to_string(), "general".to_string()]
        );
        assert_eq!(config.max_concurrent, 4);
        assert!(config.id.is_none());
        assert_eq!(config.batch_size, 10);
        assert_eq!(config.poll_interval, Duration::from_millis(100));
    }
}