1use std::sync::Arc;
2use std::time::{Duration, Instant};
3
4use forge_core::observability::{Metric, Span, SpanKind};
5use tokio::sync::mpsc;
6use uuid::Uuid;
7
8use super::executor::JobExecutor;
9use super::queue::JobQueue;
10use super::registry::JobRegistry;
11use crate::observability::ObservabilityState;
12
13#[derive(Debug, Clone)]
15pub struct WorkerConfig {
16 pub id: Option<Uuid>,
18 pub capabilities: Vec<String>,
20 pub max_concurrent: usize,
22 pub poll_interval: Duration,
24 pub batch_size: i32,
26 pub stale_cleanup_interval: Duration,
28 pub stale_threshold: chrono::Duration,
30}
31
32impl Default for WorkerConfig {
33 fn default() -> Self {
34 Self {
35 id: None,
36 capabilities: vec!["general".to_string()],
37 max_concurrent: 10,
38 poll_interval: Duration::from_millis(100),
39 batch_size: 10,
40 stale_cleanup_interval: Duration::from_secs(60),
41 stale_threshold: chrono::Duration::minutes(5),
42 }
43 }
44}
45
46pub struct Worker {
48 id: Uuid,
49 config: WorkerConfig,
50 queue: JobQueue,
51 executor: Arc<JobExecutor>,
52 shutdown_tx: Option<mpsc::Sender<()>>,
53 observability: Option<ObservabilityState>,
54}
55
56impl Worker {
57 pub fn new(
59 config: WorkerConfig,
60 queue: JobQueue,
61 registry: JobRegistry,
62 db_pool: sqlx::PgPool,
63 ) -> Self {
64 let id = config.id.unwrap_or_else(Uuid::new_v4);
65 let executor = Arc::new(JobExecutor::new(queue.clone(), registry, db_pool));
66
67 Self {
68 id,
69 config,
70 queue,
71 executor,
72 shutdown_tx: None,
73 observability: None,
74 }
75 }
76
77 pub fn with_observability(
79 config: WorkerConfig,
80 queue: JobQueue,
81 registry: JobRegistry,
82 db_pool: sqlx::PgPool,
83 observability: ObservabilityState,
84 ) -> Self {
85 let id = config.id.unwrap_or_else(Uuid::new_v4);
86 let executor = Arc::new(JobExecutor::new(queue.clone(), registry, db_pool));
87
88 Self {
89 id,
90 config,
91 queue,
92 executor,
93 shutdown_tx: None,
94 observability: Some(observability),
95 }
96 }
97
98 pub fn id(&self) -> Uuid {
100 self.id
101 }
102
103 pub fn capabilities(&self) -> &[String] {
105 &self.config.capabilities
106 }
107
108 pub async fn run(&mut self) -> Result<(), WorkerError> {
110 let (shutdown_tx, mut shutdown_rx) = mpsc::channel::<()>(1);
111 self.shutdown_tx = Some(shutdown_tx);
112
113 let semaphore = Arc::new(tokio::sync::Semaphore::new(self.config.max_concurrent));
115
116 let cleanup_queue = self.queue.clone();
118 let cleanup_interval = self.config.stale_cleanup_interval;
119 let stale_threshold = self.config.stale_threshold;
120 tokio::spawn(async move {
121 loop {
122 tokio::time::sleep(cleanup_interval).await;
123 if let Err(e) = cleanup_queue.release_stale(stale_threshold).await {
124 tracing::error!("Failed to cleanup stale jobs: {}", e);
125 }
126 }
127 });
128
129 tracing::info!(
130 worker_id = %self.id,
131 capabilities = ?self.config.capabilities,
132 "Worker started"
133 );
134
135 loop {
136 tokio::select! {
137 _ = shutdown_rx.recv() => {
138 tracing::info!(worker_id = %self.id, "Worker shutting down");
139 break;
140 }
141 _ = tokio::time::sleep(self.config.poll_interval) => {
142 let available = semaphore.available_permits();
144 if available == 0 {
145 continue;
146 }
147
148 let batch_size = (available as i32).min(self.config.batch_size);
149
150 let jobs = match self.queue.claim(
152 self.id,
153 &self.config.capabilities,
154 batch_size,
155 ).await {
156 Ok(jobs) => jobs,
157 Err(e) => {
158 tracing::error!("Failed to claim jobs: {}", e);
159 continue;
160 }
161 };
162
163 if let Some(ref obs) = self.observability {
165 let mut metric = Metric::counter("jobs_dispatched_total", jobs.len() as f64);
166 metric.labels.insert("worker_id".to_string(), self.id.to_string());
167 obs.record_metric(metric).await;
168 }
169
170 for job in jobs {
172 let permit = semaphore.clone().acquire_owned().await.unwrap();
173 let executor = self.executor.clone();
174 let job_id = job.id;
175 let job_type = job.job_type.clone();
176 let observability = self.observability.clone();
177 let worker_id = self.id;
178
179 tokio::spawn(async move {
180 let start = Instant::now();
181
182 tracing::debug!(
183 job_id = %job_id,
184 job_type = %job_type,
185 "Processing job"
186 );
187
188 let result = executor.execute(&job).await;
189 let duration = start.elapsed();
190
191 if let Some(ref obs) = observability {
193 let mut duration_metric = Metric::gauge(
194 "job_duration_seconds",
195 duration.as_secs_f64(),
196 );
197 duration_metric.labels.insert("job_type".to_string(), job_type.clone());
198 duration_metric.labels.insert("worker_id".to_string(), worker_id.to_string());
199 obs.record_metric(duration_metric).await;
200 }
201
202 if let Some(ref obs) = observability {
204 let mut span = Span::new(format!("job.{}", job_type));
205 span.kind = SpanKind::Consumer;
206 span.attributes.insert(
207 "job.id".to_string(),
208 serde_json::Value::String(job_id.to_string()),
209 );
210 span.attributes.insert(
211 "job.type".to_string(),
212 serde_json::Value::String(job_type.clone()),
213 );
214 span.attributes.insert(
215 "job.worker_id".to_string(),
216 serde_json::Value::String(worker_id.to_string()),
217 );
218 span.attributes.insert(
219 "job.duration_ms".to_string(),
220 serde_json::Value::Number(serde_json::Number::from(duration.as_millis() as u64)),
221 );
222
223 match &result {
224 super::executor::ExecutionResult::Completed { .. } => {
225 span.end_ok();
226 }
227 super::executor::ExecutionResult::Failed { error, .. } => {
228 span.end_error(error);
229 }
230 super::executor::ExecutionResult::TimedOut { .. } => {
231 span.end_error("Job timed out");
232 }
233 }
234
235 obs.record_span(span).await;
236 }
237
238 match &result {
239 super::executor::ExecutionResult::Completed { .. } => {
240 tracing::info!(
241 job_id = %job_id,
242 job_type = %job_type,
243 "Job completed"
244 );
245
246 if let Some(ref obs) = observability {
248 let mut metric = Metric::counter("jobs_completed_total", 1.0);
249 metric.labels.insert("job_type".to_string(), job_type.clone());
250 metric.labels.insert("worker_id".to_string(), worker_id.to_string());
251 obs.record_metric(metric).await;
252 }
253 }
254 super::executor::ExecutionResult::Failed { error, retryable } => {
255 if *retryable {
256 tracing::warn!(
257 job_id = %job_id,
258 job_type = %job_type,
259 error = %error,
260 "Job failed, will retry"
261 );
262 } else {
263 tracing::error!(
264 job_id = %job_id,
265 job_type = %job_type,
266 error = %error,
267 "Job failed permanently"
268 );
269 }
270
271 if let Some(ref obs) = observability {
273 let mut metric = Metric::counter("jobs_failed_total", 1.0);
274 metric.labels.insert("job_type".to_string(), job_type.clone());
275 metric.labels.insert("worker_id".to_string(), worker_id.to_string());
276 metric.labels.insert("retryable".to_string(), retryable.to_string());
277 obs.record_metric(metric).await;
278 }
279 }
280 super::executor::ExecutionResult::TimedOut { retryable } => {
281 tracing::warn!(
282 job_id = %job_id,
283 job_type = %job_type,
284 will_retry = %retryable,
285 "Job timed out"
286 );
287
288 if let Some(ref obs) = observability {
290 let mut metric = Metric::counter("jobs_timeout_total", 1.0);
291 metric.labels.insert("job_type".to_string(), job_type.clone());
292 metric.labels.insert("worker_id".to_string(), worker_id.to_string());
293 obs.record_metric(metric).await;
294 }
295 }
296 }
297
298 drop(permit); });
300 }
301 }
302 }
303 }
304
305 Ok(())
306 }
307
308 pub async fn shutdown(&self) {
310 if let Some(ref tx) = self.shutdown_tx {
311 let _ = tx.send(()).await;
312 }
313 }
314}
315
316#[derive(Debug, thiserror::Error)]
318pub enum WorkerError {
319 #[error("Database error: {0}")]
320 Database(String),
321
322 #[error("Job execution error: {0}")]
323 Execution(String),
324}
325
#[cfg(test)]
mod tests {
    use super::*;

    /// The stock configuration exposes the documented defaults.
    #[test]
    fn test_worker_config_default() {
        let WorkerConfig {
            capabilities,
            max_concurrent,
            batch_size,
            ..
        } = WorkerConfig::default();

        assert_eq!(capabilities, vec![String::from("general")]);
        assert_eq!(max_concurrent, 10);
        assert_eq!(batch_size, 10);
    }

    /// Overridden fields take effect while the rest fall back to defaults.
    #[test]
    fn test_worker_config_custom() {
        let capabilities: Vec<String> = ["media", "general"]
            .iter()
            .map(|tag| tag.to_string())
            .collect();
        let config = WorkerConfig {
            max_concurrent: 4,
            capabilities,
            ..WorkerConfig::default()
        };

        assert_eq!(config.capabilities.len(), 2);
        assert_eq!(config.max_concurrent, 4);
    }
}