1use crate::context::JobContext;
2use awa_model::{AwaError, JobRow};
3use sqlx::PgPool;
4use std::any::Any;
5use std::collections::HashMap;
6use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
7use std::sync::Arc;
8use tokio::sync::RwLock;
9use tracing::{error, info, info_span, warn, Instrument};
10
/// Successful outcome a worker reports after handling one attempt of a job.
///
/// The executor turns each variant into a state transition of the job row.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum JobResult {
    /// The job finished; the row is marked `completed`.
    Completed,
    /// Run the job again after the given delay; the row becomes `retryable`.
    RetryAfter(std::time::Duration),
    /// Reschedule after the given delay without consuming an attempt; the row becomes
    /// `scheduled`.
    Snooze(std::time::Duration),
    /// Stop the job permanently with the given reason; the row is marked `cancelled`.
    Cancel(String),
}
23
24#[derive(Debug, thiserror::Error)]
26pub enum JobError {
27 #[error("{0}")]
29 Retryable(#[source] Box<dyn std::error::Error + Send + Sync>),
30
31 #[error("terminal: {0}")]
33 Terminal(String),
34}
35
36impl JobError {
37 pub fn retryable(err: impl std::error::Error + Send + Sync + 'static) -> Self {
38 JobError::Retryable(Box::new(err))
39 }
40
41 pub fn terminal(msg: impl Into<String>) -> Self {
42 JobError::Terminal(msg.into())
43 }
44}
45
46#[async_trait::async_trait]
48pub trait Worker: Send + Sync + 'static {
49 fn kind(&self) -> &'static str;
51
52 async fn perform(&self, job_row: &JobRow, ctx: &JobContext) -> Result<JobResult, JobError>;
54}
55
56pub(crate) type BoxedWorker = Box<dyn Worker>;
58
59pub struct JobExecutor {
61 pool: PgPool,
62 workers: Arc<HashMap<String, BoxedWorker>>,
63 in_flight: Arc<RwLock<HashMap<i64, Arc<AtomicBool>>>>,
64 queue_in_flight: Arc<HashMap<String, Arc<AtomicU32>>>,
65 state: Arc<HashMap<std::any::TypeId, Box<dyn Any + Send + Sync>>>,
66 metrics: crate::metrics::AwaMetrics,
67}
68
69impl JobExecutor {
70 pub fn new(
71 pool: PgPool,
72 workers: Arc<HashMap<String, BoxedWorker>>,
73 in_flight: Arc<RwLock<HashMap<i64, Arc<AtomicBool>>>>,
74 queue_in_flight: Arc<HashMap<String, Arc<AtomicU32>>>,
75 state: Arc<HashMap<std::any::TypeId, Box<dyn Any + Send + Sync>>>,
76 metrics: crate::metrics::AwaMetrics,
77 ) -> Self {
78 Self {
79 pool,
80 workers,
81 in_flight,
82 queue_in_flight,
83 state,
84 metrics,
85 }
86 }
87
88 pub fn execute(&self, job: JobRow, cancel: Arc<AtomicBool>) -> tokio::task::JoinHandle<()> {
90 let pool = self.pool.clone();
91 let workers = self.workers.clone();
92 let in_flight = self.in_flight.clone();
93 let queue_in_flight = self.queue_in_flight.clone();
94 let state = self.state.clone();
95 let metrics = self.metrics.clone();
96 let job_id = job.id;
97 let job_kind = job.kind.clone();
98 let job_queue = job.queue.clone();
99
100 let span = info_span!(
101 "job.execute",
102 job.id = job_id,
103 job.kind = %job_kind,
104 job.queue = %job_queue,
105 job.attempt = job.attempt,
106 otel.name = %format!("job.execute {}", job_kind),
107 otel.status_code = tracing::field::Empty,
108 );
109
110 tokio::spawn(
111 async move {
112 {
114 let mut guard = in_flight.write().await;
115 guard.insert(job_id, cancel.clone());
116 }
117 if let Some(counter) = queue_in_flight.get(&job_queue) {
118 counter.fetch_add(1, Ordering::SeqCst);
119 }
120 metrics.record_in_flight_change(&job_queue, 1);
121
122 let start = std::time::Instant::now();
123 let ctx = JobContext::new(job.clone(), cancel, state);
124
125 let result = match workers.get(&job.kind) {
126 Some(worker) => worker.perform(&job, &ctx).await,
127 None => {
128 error!(kind = %job.kind, job_id, "No worker registered for job kind");
129 Err(JobError::Terminal(format!(
130 "unknown job kind: {}",
131 job.kind
132 )))
133 }
134 };
135
136 let duration = start.elapsed();
137
138 match &result {
140 Ok(JobResult::Completed) => {
141 metrics.record_job_completed(&job_kind, &job_queue, duration);
142 }
143 Ok(JobResult::RetryAfter(_)) => {
144 metrics.record_job_retried(&job_kind, &job_queue);
145 }
146 Ok(JobResult::Cancel(_)) => {
147 metrics.jobs_cancelled.add(
148 1,
149 &[
150 opentelemetry::KeyValue::new("awa.job.kind", job_kind.clone()),
151 opentelemetry::KeyValue::new("awa.job.queue", job_queue.clone()),
152 ],
153 );
154 }
155 Ok(JobResult::Snooze(_)) => {} Err(JobError::Terminal(_)) => {
157 metrics.record_job_failed(&job_kind, &job_queue, true);
158 }
159 Err(JobError::Retryable(_)) => {
160 metrics.record_job_retried(&job_kind, &job_queue);
161 }
162 }
163
164 if let Err(err) = complete_job(&pool, &job, result).await {
166 error!(job_id, error = %err, "Failed to complete job");
167 }
168
169 {
171 let mut guard = in_flight.write().await;
172 guard.remove(&job_id);
173 }
174 if let Some(counter) = queue_in_flight.get(&job_queue) {
175 counter.fetch_sub(1, Ordering::SeqCst);
176 }
177 metrics.record_in_flight_change(&job_queue, -1);
178 }
179 .instrument(span),
180 )
181 }
182}
183
184async fn complete_job(
186 pool: &PgPool,
187 job: &JobRow,
188 result: Result<JobResult, JobError>,
189) -> Result<(), AwaError> {
190 match result {
191 Ok(JobResult::Completed) => {
192 tracing::Span::current().record("otel.status_code", "OK");
193 info!(job_id = job.id, kind = %job.kind, attempt = job.attempt, "Job completed");
194 sqlx::query(
195 "UPDATE awa.jobs SET state = 'completed', finalized_at = now() WHERE id = $1",
196 )
197 .bind(job.id)
198 .execute(pool)
199 .await?;
200 }
201
202 Ok(JobResult::RetryAfter(duration)) => {
203 let seconds = duration.as_secs() as f64;
204 info!(
205 job_id = job.id,
206 kind = %job.kind,
207 retry_after_secs = seconds,
208 "Job requested retry after duration"
209 );
210 sqlx::query(
211 r#"
212 UPDATE awa.jobs
213 SET state = 'retryable',
214 run_at = now() + make_interval(secs => $2),
215 finalized_at = now()
216 WHERE id = $1
217 "#,
218 )
219 .bind(job.id)
220 .bind(seconds)
221 .execute(pool)
222 .await?;
223 }
224
225 Ok(JobResult::Snooze(duration)) => {
226 let seconds = duration.as_secs() as f64;
227 info!(
228 job_id = job.id,
229 kind = %job.kind,
230 snooze_secs = seconds,
231 "Job snoozed (attempt not incremented)"
232 );
233 sqlx::query(
236 r#"
237 UPDATE awa.jobs
238 SET state = 'scheduled',
239 run_at = now() + make_interval(secs => $2),
240 attempt = attempt - 1,
241 heartbeat_at = NULL,
242 deadline_at = NULL
243 WHERE id = $1
244 "#,
245 )
246 .bind(job.id)
247 .bind(seconds)
248 .execute(pool)
249 .await?;
250 }
251
252 Ok(JobResult::Cancel(reason)) => {
253 info!(
254 job_id = job.id,
255 kind = %job.kind,
256 reason = %reason,
257 "Job cancelled by handler"
258 );
259 sqlx::query(
260 r#"
261 UPDATE awa.jobs
262 SET state = 'cancelled',
263 finalized_at = now(),
264 errors = errors || $2::jsonb
265 WHERE id = $1
266 "#,
267 )
268 .bind(job.id)
269 .bind(serde_json::json!({
270 "error": format!("cancelled: {}", reason),
271 "attempt": job.attempt,
272 "at": chrono::Utc::now().to_rfc3339()
273 }))
274 .execute(pool)
275 .await?;
276 }
277
278 Err(JobError::Terminal(msg)) => {
279 tracing::Span::current().record("otel.status_code", "ERROR");
280 error!(
281 job_id = job.id,
282 kind = %job.kind,
283 error = %msg,
284 "Job failed terminally"
285 );
286 sqlx::query(
287 r#"
288 UPDATE awa.jobs
289 SET state = 'failed',
290 finalized_at = now(),
291 errors = errors || $2::jsonb
292 WHERE id = $1
293 "#,
294 )
295 .bind(job.id)
296 .bind(serde_json::json!({
297 "error": msg,
298 "attempt": job.attempt,
299 "at": chrono::Utc::now().to_rfc3339(),
300 "terminal": true
301 }))
302 .execute(pool)
303 .await?;
304 }
305
306 Err(JobError::Retryable(err)) => {
307 let error_msg = err.to_string();
308 if job.attempt >= job.max_attempts {
309 tracing::Span::current().record("otel.status_code", "ERROR");
310 error!(
311 job_id = job.id,
312 kind = %job.kind,
313 attempt = job.attempt,
314 max_attempts = job.max_attempts,
315 error = %error_msg,
316 "Job failed (max attempts exhausted)"
317 );
318 sqlx::query(
319 r#"
320 UPDATE awa.jobs
321 SET state = 'failed',
322 finalized_at = now(),
323 errors = errors || $2::jsonb
324 WHERE id = $1
325 "#,
326 )
327 .bind(job.id)
328 .bind(serde_json::json!({
329 "error": error_msg,
330 "attempt": job.attempt,
331 "at": chrono::Utc::now().to_rfc3339()
332 }))
333 .execute(pool)
334 .await?;
335 } else {
336 warn!(
337 job_id = job.id,
338 kind = %job.kind,
339 attempt = job.attempt,
340 error = %error_msg,
341 "Job failed (will retry)"
342 );
343 sqlx::query(
345 r#"
346 UPDATE awa.jobs
347 SET state = 'retryable',
348 run_at = now() + awa.backoff_duration($2, $3),
349 finalized_at = now(),
350 heartbeat_at = NULL,
351 deadline_at = NULL,
352 errors = errors || $4::jsonb
353 WHERE id = $1
354 "#,
355 )
356 .bind(job.id)
357 .bind(job.attempt)
358 .bind(job.max_attempts)
359 .bind(serde_json::json!({
360 "error": error_msg,
361 "attempt": job.attempt,
362 "at": chrono::Utc::now().to_rfc3339()
363 }))
364 .execute(pool)
365 .await?;
366 }
367 }
368 }
369
370 Ok(())
371}