forge_runtime/jobs/
worker.rs

1use std::sync::Arc;
2use std::time::{Duration, Instant};
3
4use forge_core::observability::{Metric, Span, SpanKind};
5use tokio::sync::mpsc;
6use uuid::Uuid;
7
8use super::executor::JobExecutor;
9use super::queue::JobQueue;
10use super::registry::JobRegistry;
11use crate::observability::ObservabilityState;
12
13/// Worker configuration.
14#[derive(Debug, Clone)]
15pub struct WorkerConfig {
16    /// Worker ID (auto-generated if not provided).
17    pub id: Option<Uuid>,
18    /// Worker capabilities (e.g., ["general", "media"]).
19    pub capabilities: Vec<String>,
20    /// Maximum concurrent jobs.
21    pub max_concurrent: usize,
22    /// Poll interval when queue is empty.
23    pub poll_interval: Duration,
24    /// Batch size for claiming jobs.
25    pub batch_size: i32,
26    /// Stale job cleanup interval.
27    pub stale_cleanup_interval: Duration,
28    /// Stale job threshold.
29    pub stale_threshold: chrono::Duration,
30}
31
32impl Default for WorkerConfig {
33    fn default() -> Self {
34        Self {
35            id: None,
36            capabilities: vec!["general".to_string()],
37            max_concurrent: 10,
38            poll_interval: Duration::from_millis(100),
39            batch_size: 10,
40            stale_cleanup_interval: Duration::from_secs(60),
41            stale_threshold: chrono::Duration::minutes(5),
42        }
43    }
44}
45
46/// Background job worker.
47pub struct Worker {
48    id: Uuid,
49    config: WorkerConfig,
50    queue: JobQueue,
51    executor: Arc<JobExecutor>,
52    shutdown_tx: Option<mpsc::Sender<()>>,
53    observability: Option<ObservabilityState>,
54}
55
56impl Worker {
57    /// Create a new worker.
58    pub fn new(
59        config: WorkerConfig,
60        queue: JobQueue,
61        registry: JobRegistry,
62        db_pool: sqlx::PgPool,
63    ) -> Self {
64        let id = config.id.unwrap_or_else(Uuid::new_v4);
65        let executor = Arc::new(JobExecutor::new(queue.clone(), registry, db_pool));
66
67        Self {
68            id,
69            config,
70            queue,
71            executor,
72            shutdown_tx: None,
73            observability: None,
74        }
75    }
76
77    /// Create a new worker with observability.
78    pub fn with_observability(
79        config: WorkerConfig,
80        queue: JobQueue,
81        registry: JobRegistry,
82        db_pool: sqlx::PgPool,
83        observability: ObservabilityState,
84    ) -> Self {
85        let id = config.id.unwrap_or_else(Uuid::new_v4);
86        let executor = Arc::new(JobExecutor::new(queue.clone(), registry, db_pool));
87
88        Self {
89            id,
90            config,
91            queue,
92            executor,
93            shutdown_tx: None,
94            observability: Some(observability),
95        }
96    }
97
98    /// Get worker ID.
99    pub fn id(&self) -> Uuid {
100        self.id
101    }
102
103    /// Get worker capabilities.
104    pub fn capabilities(&self) -> &[String] {
105        &self.config.capabilities
106    }
107
108    /// Run the worker (blocks until shutdown).
109    pub async fn run(&mut self) -> Result<(), WorkerError> {
110        let (shutdown_tx, mut shutdown_rx) = mpsc::channel::<()>(1);
111        self.shutdown_tx = Some(shutdown_tx);
112
113        // Semaphore to limit concurrent jobs
114        let semaphore = Arc::new(tokio::sync::Semaphore::new(self.config.max_concurrent));
115
116        // Spawn stale cleanup task
117        let cleanup_queue = self.queue.clone();
118        let cleanup_interval = self.config.stale_cleanup_interval;
119        let stale_threshold = self.config.stale_threshold;
120        tokio::spawn(async move {
121            loop {
122                tokio::time::sleep(cleanup_interval).await;
123                if let Err(e) = cleanup_queue.release_stale(stale_threshold).await {
124                    tracing::error!("Failed to cleanup stale jobs: {}", e);
125                }
126            }
127        });
128
129        tracing::info!(
130            worker_id = %self.id,
131            capabilities = ?self.config.capabilities,
132            "Worker started"
133        );
134
135        loop {
136            tokio::select! {
137                _ = shutdown_rx.recv() => {
138                    tracing::info!(worker_id = %self.id, "Worker shutting down");
139                    break;
140                }
141                _ = tokio::time::sleep(self.config.poll_interval) => {
142                    // Calculate how many jobs we can claim
143                    let available = semaphore.available_permits();
144                    if available == 0 {
145                        continue;
146                    }
147
148                    let batch_size = (available as i32).min(self.config.batch_size);
149
150                    // Claim jobs
151                    let jobs = match self.queue.claim(
152                        self.id,
153                        &self.config.capabilities,
154                        batch_size,
155                    ).await {
156                        Ok(jobs) => jobs,
157                        Err(e) => {
158                            tracing::error!("Failed to claim jobs: {}", e);
159                            continue;
160                        }
161                    };
162
163                    // Record jobs claimed metric
164                    if let Some(ref obs) = self.observability {
165                        let mut metric = Metric::counter("jobs_dispatched_total", jobs.len() as f64);
166                        metric.labels.insert("worker_id".to_string(), self.id.to_string());
167                        obs.record_metric(metric).await;
168                    }
169
170                    // Process each job
171                    for job in jobs {
172                        let permit = semaphore.clone().acquire_owned().await.unwrap();
173                        let executor = self.executor.clone();
174                        let job_id = job.id;
175                        let job_type = job.job_type.clone();
176                        let observability = self.observability.clone();
177                        let worker_id = self.id;
178
179                        tokio::spawn(async move {
180                            let start = Instant::now();
181
182                            tracing::debug!(
183                                job_id = %job_id,
184                                job_type = %job_type,
185                                "Processing job"
186                            );
187
188                            let result = executor.execute(&job).await;
189                            let duration = start.elapsed();
190
191                            // Record job duration metric
192                            if let Some(ref obs) = observability {
193                                let mut duration_metric = Metric::gauge(
194                                    "job_duration_seconds",
195                                    duration.as_secs_f64(),
196                                );
197                                duration_metric.labels.insert("job_type".to_string(), job_type.clone());
198                                duration_metric.labels.insert("worker_id".to_string(), worker_id.to_string());
199                                obs.record_metric(duration_metric).await;
200                            }
201
202                            // Record job execution span
203                            if let Some(ref obs) = observability {
204                                let mut span = Span::new(format!("job.{}", job_type));
205                                span.kind = SpanKind::Consumer;
206                                span.attributes.insert(
207                                    "job.id".to_string(),
208                                    serde_json::Value::String(job_id.to_string()),
209                                );
210                                span.attributes.insert(
211                                    "job.type".to_string(),
212                                    serde_json::Value::String(job_type.clone()),
213                                );
214                                span.attributes.insert(
215                                    "job.worker_id".to_string(),
216                                    serde_json::Value::String(worker_id.to_string()),
217                                );
218                                span.attributes.insert(
219                                    "job.duration_ms".to_string(),
220                                    serde_json::Value::Number(serde_json::Number::from(duration.as_millis() as u64)),
221                                );
222
223                                match &result {
224                                    super::executor::ExecutionResult::Completed { .. } => {
225                                        span.end_ok();
226                                    }
227                                    super::executor::ExecutionResult::Failed { error, .. } => {
228                                        span.end_error(error);
229                                    }
230                                    super::executor::ExecutionResult::TimedOut { .. } => {
231                                        span.end_error("Job timed out");
232                                    }
233                                }
234
235                                obs.record_span(span).await;
236                            }
237
238                            match &result {
239                                super::executor::ExecutionResult::Completed { .. } => {
240                                    tracing::info!(
241                                        job_id = %job_id,
242                                        job_type = %job_type,
243                                        "Job completed"
244                                    );
245
246                                    // Record completed metric
247                                    if let Some(ref obs) = observability {
248                                        let mut metric = Metric::counter("jobs_completed_total", 1.0);
249                                        metric.labels.insert("job_type".to_string(), job_type.clone());
250                                        metric.labels.insert("worker_id".to_string(), worker_id.to_string());
251                                        obs.record_metric(metric).await;
252                                    }
253                                }
254                                super::executor::ExecutionResult::Failed { error, retryable } => {
255                                    if *retryable {
256                                        tracing::warn!(
257                                            job_id = %job_id,
258                                            job_type = %job_type,
259                                            error = %error,
260                                            "Job failed, will retry"
261                                        );
262                                    } else {
263                                        tracing::error!(
264                                            job_id = %job_id,
265                                            job_type = %job_type,
266                                            error = %error,
267                                            "Job failed permanently"
268                                        );
269                                    }
270
271                                    // Record failed metric
272                                    if let Some(ref obs) = observability {
273                                        let mut metric = Metric::counter("jobs_failed_total", 1.0);
274                                        metric.labels.insert("job_type".to_string(), job_type.clone());
275                                        metric.labels.insert("worker_id".to_string(), worker_id.to_string());
276                                        metric.labels.insert("retryable".to_string(), retryable.to_string());
277                                        obs.record_metric(metric).await;
278                                    }
279                                }
280                                super::executor::ExecutionResult::TimedOut { retryable } => {
281                                    tracing::warn!(
282                                        job_id = %job_id,
283                                        job_type = %job_type,
284                                        will_retry = %retryable,
285                                        "Job timed out"
286                                    );
287
288                                    // Record timeout metric
289                                    if let Some(ref obs) = observability {
290                                        let mut metric = Metric::counter("jobs_timeout_total", 1.0);
291                                        metric.labels.insert("job_type".to_string(), job_type.clone());
292                                        metric.labels.insert("worker_id".to_string(), worker_id.to_string());
293                                        obs.record_metric(metric).await;
294                                    }
295                                }
296                            }
297
298                            drop(permit); // Release semaphore
299                        });
300                    }
301                }
302            }
303        }
304
305        Ok(())
306    }
307
308    /// Request graceful shutdown.
309    pub async fn shutdown(&self) {
310        if let Some(ref tx) = self.shutdown_tx {
311            let _ = tx.send(()).await;
312        }
313    }
314}
315
316/// Worker errors.
317#[derive(Debug, thiserror::Error)]
318pub enum WorkerError {
319    #[error("Database error: {0}")]
320    Database(String),
321
322    #[error("Job execution error: {0}")]
323    Execution(String),
324}
325
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration exposes the documented stock values.
    #[test]
    fn test_worker_config_default() {
        let WorkerConfig {
            capabilities,
            max_concurrent,
            batch_size,
            ..
        } = WorkerConfig::default();

        assert_eq!(capabilities, vec![String::from("general")]);
        assert_eq!(max_concurrent, 10);
        assert_eq!(batch_size, 10);
    }

    /// Overridden fields take effect while the rest fall back to defaults.
    #[test]
    fn test_worker_config_custom() {
        let capabilities: Vec<String> = ["media", "general"]
            .iter()
            .map(|name| name.to_string())
            .collect();
        let config = WorkerConfig {
            capabilities,
            max_concurrent: 4,
            ..Default::default()
        };

        assert_eq!(config.capabilities.len(), 2);
        assert_eq!(config.max_concurrent, 4);
    }
}