Skip to main content

awa_worker/
executor.rs

1use crate::context::JobContext;
2use awa_model::{AwaError, JobRow};
3use sqlx::PgPool;
4use std::any::Any;
5use std::collections::HashMap;
6use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
7use std::sync::Arc;
8use tokio::sync::RwLock;
9use tracing::{error, info, info_span, warn, Instrument};
10
/// Outcome a [`Worker`] reports after running a job; it decides the job's next state.
#[derive(Debug)]
pub enum JobResult {
    /// The job finished successfully and becomes `completed`.
    Completed,
    /// Run the job again once the duration has elapsed. The attempt used by this
    /// run is kept, so it counts toward the attempt budget.
    RetryAfter(std::time::Duration),
    /// Put the job back to run after the duration WITHOUT consuming an attempt.
    Snooze(std::time::Duration),
    /// Stop the job for good, recording the given reason.
    Cancel(String),
}
23
24/// Error type for job handlers — any error is retryable unless it's terminal.
25#[derive(Debug, thiserror::Error)]
26pub enum JobError {
27    /// Retryable error — will be retried if attempts remain.
28    #[error("{0}")]
29    Retryable(#[source] Box<dyn std::error::Error + Send + Sync>),
30
31    /// Terminal error — immediately fails the job regardless of remaining attempts.
32    #[error("terminal: {0}")]
33    Terminal(String),
34}
35
36impl JobError {
37    pub fn retryable(err: impl std::error::Error + Send + Sync + 'static) -> Self {
38        JobError::Retryable(Box::new(err))
39    }
40
41    pub fn terminal(msg: impl Into<String>) -> Self {
42        JobError::Terminal(msg.into())
43    }
44}
45
46/// Worker trait — implement this for each job type.
47#[async_trait::async_trait]
48pub trait Worker: Send + Sync + 'static {
49    /// The kind string for this worker (must match the job's kind).
50    fn kind(&self) -> &'static str;
51
52    /// Execute the job. The raw args JSON and context are provided.
53    async fn perform(&self, job_row: &JobRow, ctx: &JobContext) -> Result<JobResult, JobError>;
54}
55
56/// Type-erased worker wrapper for the registry.
57pub(crate) type BoxedWorker = Box<dyn Worker>;
58
59/// Manages job execution — spawns worker futures and tracks in-flight jobs.
60pub struct JobExecutor {
61    pool: PgPool,
62    workers: Arc<HashMap<String, BoxedWorker>>,
63    in_flight: Arc<RwLock<HashMap<i64, Arc<AtomicBool>>>>,
64    queue_in_flight: Arc<HashMap<String, Arc<AtomicU32>>>,
65    state: Arc<HashMap<std::any::TypeId, Box<dyn Any + Send + Sync>>>,
66    metrics: crate::metrics::AwaMetrics,
67}
68
69impl JobExecutor {
70    pub fn new(
71        pool: PgPool,
72        workers: Arc<HashMap<String, BoxedWorker>>,
73        in_flight: Arc<RwLock<HashMap<i64, Arc<AtomicBool>>>>,
74        queue_in_flight: Arc<HashMap<String, Arc<AtomicU32>>>,
75        state: Arc<HashMap<std::any::TypeId, Box<dyn Any + Send + Sync>>>,
76        metrics: crate::metrics::AwaMetrics,
77    ) -> Self {
78        Self {
79            pool,
80            workers,
81            in_flight,
82            queue_in_flight,
83            state,
84            metrics,
85        }
86    }
87
88    /// Execute a claimed job. Returns a JoinHandle for the spawned task.
89    pub fn execute(&self, job: JobRow, cancel: Arc<AtomicBool>) -> tokio::task::JoinHandle<()> {
90        let pool = self.pool.clone();
91        let workers = self.workers.clone();
92        let in_flight = self.in_flight.clone();
93        let queue_in_flight = self.queue_in_flight.clone();
94        let state = self.state.clone();
95        let metrics = self.metrics.clone();
96        let job_id = job.id;
97        let job_kind = job.kind.clone();
98        let job_queue = job.queue.clone();
99
100        let span = info_span!(
101            "job.execute",
102            job.id = job_id,
103            job.kind = %job_kind,
104            job.queue = %job_queue,
105            job.attempt = job.attempt,
106            otel.name = %format!("job.execute {}", job_kind),
107            otel.status_code = tracing::field::Empty,
108        );
109
110        tokio::spawn(
111            async move {
112                // Register as in-flight
113                {
114                    let mut guard = in_flight.write().await;
115                    guard.insert(job_id, cancel.clone());
116                }
117                if let Some(counter) = queue_in_flight.get(&job_queue) {
118                    counter.fetch_add(1, Ordering::SeqCst);
119                }
120                metrics.record_in_flight_change(&job_queue, 1);
121
122                let start = std::time::Instant::now();
123                let ctx = JobContext::new(job.clone(), cancel, state);
124
125                let result = match workers.get(&job.kind) {
126                    Some(worker) => worker.perform(&job, &ctx).await,
127                    None => {
128                        error!(kind = %job.kind, job_id, "No worker registered for job kind");
129                        Err(JobError::Terminal(format!(
130                            "unknown job kind: {}",
131                            job.kind
132                        )))
133                    }
134                };
135
136                let duration = start.elapsed();
137
138                // Record metrics based on outcome
139                match &result {
140                    Ok(JobResult::Completed) => {
141                        metrics.record_job_completed(&job_kind, &job_queue, duration);
142                    }
143                    Ok(JobResult::RetryAfter(_)) => {
144                        metrics.record_job_retried(&job_kind, &job_queue);
145                    }
146                    Ok(JobResult::Cancel(_)) => {
147                        metrics.jobs_cancelled.add(
148                            1,
149                            &[
150                                opentelemetry::KeyValue::new("awa.job.kind", job_kind.clone()),
151                                opentelemetry::KeyValue::new("awa.job.queue", job_queue.clone()),
152                            ],
153                        );
154                    }
155                    Ok(JobResult::Snooze(_)) => {} // Not a terminal outcome
156                    Err(JobError::Terminal(_)) => {
157                        metrics.record_job_failed(&job_kind, &job_queue, true);
158                    }
159                    Err(JobError::Retryable(_)) => {
160                        metrics.record_job_retried(&job_kind, &job_queue);
161                    }
162                }
163
164                // Complete the job based on the result
165                if let Err(err) = complete_job(&pool, &job, result).await {
166                    error!(job_id, error = %err, "Failed to complete job");
167                }
168
169                // Remove from in-flight
170                {
171                    let mut guard = in_flight.write().await;
172                    guard.remove(&job_id);
173                }
174                if let Some(counter) = queue_in_flight.get(&job_queue) {
175                    counter.fetch_sub(1, Ordering::SeqCst);
176                }
177                metrics.record_in_flight_change(&job_queue, -1);
178            }
179            .instrument(span),
180        )
181    }
182}
183
184/// Update job state in the database based on handler result.
185async fn complete_job(
186    pool: &PgPool,
187    job: &JobRow,
188    result: Result<JobResult, JobError>,
189) -> Result<(), AwaError> {
190    match result {
191        Ok(JobResult::Completed) => {
192            tracing::Span::current().record("otel.status_code", "OK");
193            info!(job_id = job.id, kind = %job.kind, attempt = job.attempt, "Job completed");
194            sqlx::query(
195                "UPDATE awa.jobs SET state = 'completed', finalized_at = now() WHERE id = $1",
196            )
197            .bind(job.id)
198            .execute(pool)
199            .await?;
200        }
201
202        Ok(JobResult::RetryAfter(duration)) => {
203            let seconds = duration.as_secs() as f64;
204            info!(
205                job_id = job.id,
206                kind = %job.kind,
207                retry_after_secs = seconds,
208                "Job requested retry after duration"
209            );
210            sqlx::query(
211                r#"
212                UPDATE awa.jobs
213                SET state = 'retryable',
214                    run_at = now() + make_interval(secs => $2),
215                    finalized_at = now()
216                WHERE id = $1
217                "#,
218            )
219            .bind(job.id)
220            .bind(seconds)
221            .execute(pool)
222            .await?;
223        }
224
225        Ok(JobResult::Snooze(duration)) => {
226            let seconds = duration.as_secs() as f64;
227            info!(
228                job_id = job.id,
229                kind = %job.kind,
230                snooze_secs = seconds,
231                "Job snoozed (attempt not incremented)"
232            );
233            // Snooze: back to available with new run_at, decrement attempt
234            // (since it was already incremented at claim time)
235            sqlx::query(
236                r#"
237                UPDATE awa.jobs
238                SET state = 'scheduled',
239                    run_at = now() + make_interval(secs => $2),
240                    attempt = attempt - 1,
241                    heartbeat_at = NULL,
242                    deadline_at = NULL
243                WHERE id = $1
244                "#,
245            )
246            .bind(job.id)
247            .bind(seconds)
248            .execute(pool)
249            .await?;
250        }
251
252        Ok(JobResult::Cancel(reason)) => {
253            info!(
254                job_id = job.id,
255                kind = %job.kind,
256                reason = %reason,
257                "Job cancelled by handler"
258            );
259            sqlx::query(
260                r#"
261                UPDATE awa.jobs
262                SET state = 'cancelled',
263                    finalized_at = now(),
264                    errors = errors || $2::jsonb
265                WHERE id = $1
266                "#,
267            )
268            .bind(job.id)
269            .bind(serde_json::json!({
270                "error": format!("cancelled: {}", reason),
271                "attempt": job.attempt,
272                "at": chrono::Utc::now().to_rfc3339()
273            }))
274            .execute(pool)
275            .await?;
276        }
277
278        Err(JobError::Terminal(msg)) => {
279            tracing::Span::current().record("otel.status_code", "ERROR");
280            error!(
281                job_id = job.id,
282                kind = %job.kind,
283                error = %msg,
284                "Job failed terminally"
285            );
286            sqlx::query(
287                r#"
288                UPDATE awa.jobs
289                SET state = 'failed',
290                    finalized_at = now(),
291                    errors = errors || $2::jsonb
292                WHERE id = $1
293                "#,
294            )
295            .bind(job.id)
296            .bind(serde_json::json!({
297                "error": msg,
298                "attempt": job.attempt,
299                "at": chrono::Utc::now().to_rfc3339(),
300                "terminal": true
301            }))
302            .execute(pool)
303            .await?;
304        }
305
306        Err(JobError::Retryable(err)) => {
307            let error_msg = err.to_string();
308            if job.attempt >= job.max_attempts {
309                tracing::Span::current().record("otel.status_code", "ERROR");
310                error!(
311                    job_id = job.id,
312                    kind = %job.kind,
313                    attempt = job.attempt,
314                    max_attempts = job.max_attempts,
315                    error = %error_msg,
316                    "Job failed (max attempts exhausted)"
317                );
318                sqlx::query(
319                    r#"
320                    UPDATE awa.jobs
321                    SET state = 'failed',
322                        finalized_at = now(),
323                        errors = errors || $2::jsonb
324                    WHERE id = $1
325                    "#,
326                )
327                .bind(job.id)
328                .bind(serde_json::json!({
329                    "error": error_msg,
330                    "attempt": job.attempt,
331                    "at": chrono::Utc::now().to_rfc3339()
332                }))
333                .execute(pool)
334                .await?;
335            } else {
336                warn!(
337                    job_id = job.id,
338                    kind = %job.kind,
339                    attempt = job.attempt,
340                    error = %error_msg,
341                    "Job failed (will retry)"
342                );
343                // Use database-side backoff calculation
344                sqlx::query(
345                    r#"
346                    UPDATE awa.jobs
347                    SET state = 'retryable',
348                        run_at = now() + awa.backoff_duration($2, $3),
349                        finalized_at = now(),
350                        heartbeat_at = NULL,
351                        deadline_at = NULL,
352                        errors = errors || $4::jsonb
353                    WHERE id = $1
354                    "#,
355                )
356                .bind(job.id)
357                .bind(job.attempt)
358                .bind(job.max_attempts)
359                .bind(serde_json::json!({
360                    "error": error_msg,
361                    "attempt": job.attempt,
362                    "at": chrono::Utc::now().to_rfc3339()
363                }))
364                .execute(pool)
365                .await?;
366            }
367        }
368    }
369
370    Ok(())
371}