Skip to main content

modo/job/
worker.rs

1use std::collections::HashMap;
2use std::pin::Pin;
3use std::sync::Arc;
4use std::time::Duration;
5
6use chrono::Utc;
7use tokio::sync::Semaphore;
8use tokio::task::JoinHandle;
9use tokio_util::sync::CancellationToken;
10
11use crate::db::{ConnExt, ConnQueryExt, Database, FromValue};
12use crate::error::Result;
13use crate::service::{Registry, RegistrySnapshot};
14
15use super::cleanup::cleanup_loop;
16use super::config::{JobConfig, QueueConfig};
17use super::context::JobContext;
18use super::handler::JobHandler;
19use super::meta::Meta;
20use super::reaper::reaper_loop;
21
/// Per-handler options controlling retry and timeout behavior.
///
/// Pass to [`WorkerBuilder::register_with`] to override defaults for a single
/// handler. Use [`WorkerBuilder::register`] for the default
/// `max_attempts = 3` / `timeout_secs = 300`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct JobOptions {
    /// Maximum number of execution attempts before the job is marked `Dead`.
    /// Defaults to `3`.
    pub max_attempts: u32,
    /// Per-execution timeout in seconds. If a handler exceeds this, the
    /// attempt is treated as a failure. Defaults to `300` (5 min).
    pub timeout_secs: u64,
}

impl Default for JobOptions {
    /// Three attempts, each limited to 300 seconds.
    fn default() -> Self {
        Self {
            max_attempts: 3,
            timeout_secs: 300,
        }
    }
}
44
45type ErasedHandler =
46    Arc<dyn Fn(JobContext) -> Pin<Box<dyn Future<Output = Result<()>> + Send>> + Send + Sync>;
47
48struct HandlerEntry {
49    handler: ErasedHandler,
50    options: JobOptions,
51}
52
53/// Builder for constructing a [`Worker`] with registered job handlers.
54///
55/// Obtained via [`Worker::builder`]. Call [`WorkerBuilder::register`] (or
56/// [`WorkerBuilder::register_with`]) for each job name, then call
57/// [`WorkerBuilder::start`] to spawn the background loops and obtain a
58/// [`Worker`] handle.
59#[must_use]
60pub struct WorkerBuilder {
61    config: JobConfig,
62    registry: Arc<RegistrySnapshot>,
63    db: Database,
64    handlers: HashMap<String, HandlerEntry>,
65}
66
67impl WorkerBuilder {
68    /// Register a handler for the given job name with default [`JobOptions`].
69    pub fn register<H, Args>(mut self, name: &str, handler: H) -> Self
70    where
71        H: JobHandler<Args> + Send + Sync,
72    {
73        self.register_inner(name, handler, JobOptions::default());
74        self
75    }
76
77    /// Register a handler for the given job name with custom [`JobOptions`].
78    pub fn register_with<H, Args>(mut self, name: &str, handler: H, options: JobOptions) -> Self
79    where
80        H: JobHandler<Args> + Send + Sync,
81    {
82        self.register_inner(name, handler, options);
83        self
84    }
85
86    fn register_inner<H, Args>(&mut self, name: &str, handler: H, options: JobOptions)
87    where
88        H: JobHandler<Args> + Send + Sync,
89    {
90        let handler = Arc::new(
91            move |ctx: JobContext| -> Pin<Box<dyn Future<Output = Result<()>> + Send>> {
92                let h = handler.clone();
93                Box::pin(async move { h.call(ctx).await })
94            },
95        ) as ErasedHandler;
96
97        self.handlers
98            .insert(name.to_string(), HandlerEntry { handler, options });
99    }
100
101    /// Spawn the worker loops and return a [`Worker`] handle for shutdown.
102    ///
103    /// Three background tasks are started:
104    /// - **poll loop** — claims and dispatches pending jobs
105    /// - **stale reaper** — resets jobs stuck in `running` past the configured
106    ///   threshold
107    /// - **cleanup loop** (optional) — deletes old terminal jobs
108    pub async fn start(self) -> Worker {
109        let cancel = CancellationToken::new();
110        let handlers = Arc::new(self.handlers);
111        let handler_names: Vec<String> = handlers.keys().cloned().collect();
112
113        // Build per-queue semaphores
114        let queue_semaphores: Vec<(QueueConfig, Arc<Semaphore>)> = self
115            .config
116            .queues
117            .iter()
118            .map(|q| (q.clone(), Arc::new(Semaphore::new(q.concurrency as usize))))
119            .collect();
120
121        // Spawn poll loop
122        let poll_handle = tokio::spawn(poll_loop(
123            self.db.clone(),
124            self.registry.clone(),
125            handlers.clone(),
126            handler_names,
127            queue_semaphores,
128            self.config.poll_interval_secs,
129            cancel.clone(),
130        ));
131
132        // Spawn stale reaper
133        let reaper_handle = tokio::spawn(reaper_loop(
134            self.db.clone(),
135            self.config.stale_threshold_secs,
136            self.config.stale_reaper_interval_secs,
137            cancel.clone(),
138        ));
139
140        // Spawn cleanup (if configured)
141        let cleanup_handle = if let Some(ref cleanup) = self.config.cleanup {
142            Some(tokio::spawn(cleanup_loop(
143                self.db.clone(),
144                cleanup.interval_secs,
145                cleanup.retention_secs,
146                cancel.clone(),
147            )))
148        } else {
149            None
150        };
151
152        Worker {
153            cancel,
154            poll_handle,
155            reaper_handle,
156            cleanup_handle,
157            drain_timeout: Duration::from_secs(self.config.drain_timeout_secs),
158        }
159    }
160}
161
162/// A running job worker that processes enqueued jobs.
163///
164/// Implements [`crate::runtime::Task`] for graceful shutdown. Pass the
165/// `Worker` to the [`run!`](crate::run) macro so it is shut down when the
166/// process receives a termination signal.
167///
168/// Construct via [`Worker::builder`].
169pub struct Worker {
170    cancel: CancellationToken,
171    poll_handle: JoinHandle<()>,
172    reaper_handle: JoinHandle<()>,
173    cleanup_handle: Option<JoinHandle<()>>,
174    drain_timeout: Duration,
175}
176
177impl Worker {
178    /// Create a [`WorkerBuilder`] from config and service registry.
179    ///
180    /// # Panics
181    ///
182    /// Panics if a [`Database`](crate::db::Database) is not registered in
183    /// `registry`.
184    pub fn builder(config: &JobConfig, registry: &Registry) -> WorkerBuilder {
185        let snapshot = registry.snapshot();
186        let db = snapshot
187            .get::<Database>()
188            .expect("Database must be registered before building Worker");
189
190        WorkerBuilder {
191            config: config.clone(),
192            registry: snapshot,
193            db: (*db).clone(),
194            handlers: HashMap::new(),
195        }
196    }
197}
198
199impl crate::runtime::Task for Worker {
200    async fn shutdown(self) -> Result<()> {
201        self.cancel.cancel();
202        let drain = async {
203            let _ = self.poll_handle.await;
204            let _ = self.reaper_handle.await;
205            if let Some(h) = self.cleanup_handle {
206                let _ = h.await;
207            }
208        };
209        let _ = tokio::time::timeout(self.drain_timeout, drain).await;
210        Ok(())
211    }
212}
213
/// One row returned by the claim statement in `poll_loop`: a job that was
/// just moved to `running` and is about to be dispatched.
struct ClaimedJob {
    /// Row id, used to target the follow-up status updates.
    id: String,
    /// Job name; selects the registered handler.
    name: String,
    /// Queue the job was claimed from.
    queue: String,
    /// Raw payload text, passed to the handler unchanged.
    payload: String,
    /// Attempt counter as returned after the claim's `attempt + 1`.
    attempt: i32,
}
222
223async fn poll_loop(
224    db: Database,
225    registry: Arc<RegistrySnapshot>,
226    handlers: Arc<HashMap<String, HandlerEntry>>,
227    handler_names: Vec<String>,
228    queue_semaphores: Vec<(QueueConfig, Arc<Semaphore>)>,
229    poll_interval_secs: u64,
230    cancel: CancellationToken,
231) {
232    let poll_interval = Duration::from_secs(poll_interval_secs);
233
234    // Precompute the SQL template once — handler_names never changes after start.
235    let placeholders: Vec<String> = handler_names
236        .iter()
237        .enumerate()
238        .map(|(i, _)| format!("?{}", i + 5))
239        .collect();
240    let placeholders_str = placeholders.join(", ");
241    let limit_param = handler_names.len() + 5;
242    let claim_sql = format!(
243        "UPDATE jobs SET status = 'running', attempt = attempt + 1, \
244         started_at = ?1, updated_at = ?2 \
245         WHERE id IN (\
246             SELECT id FROM jobs \
247             WHERE status = 'pending' AND run_at <= ?3 \
248             AND queue = ?4 AND name IN ({placeholders_str}) \
249             ORDER BY run_at ASC LIMIT ?{limit_param}\
250         ) RETURNING id, name, queue, payload, attempt",
251    );
252
253    loop {
254        tokio::select! {
255            _ = cancel.cancelled() => break,
256            _ = tokio::time::sleep(poll_interval) => {
257                if handler_names.is_empty() {
258                    continue;
259                }
260
261                let now_str = Utc::now().to_rfc3339();
262
263                for (queue_config, semaphore) in &queue_semaphores {
264                    let slots = semaphore.available_permits();
265                    if slots == 0 {
266                        continue;
267                    }
268
269                    let mut params: Vec<libsql::Value> = vec![
270                        libsql::Value::Text(now_str.clone()),       // ?1 started_at
271                        libsql::Value::Text(now_str.clone()),       // ?2 updated_at
272                        libsql::Value::Text(now_str.clone()),       // ?3 run_at <=
273                        libsql::Value::Text(queue_config.name.clone()), // ?4 queue =
274                    ];
275                    for name in &handler_names {
276                        params.push(libsql::Value::Text(name.clone()));
277                    }
278                    params.push(libsql::Value::Integer(slots as i64)); // LIMIT
279
280                    let claimed = match db.conn().query_all_map(
281                        &claim_sql,
282                        params,
283                        |row| {
284                            Ok(ClaimedJob {
285                                id: String::from_value(row.get_value(0).map_err(crate::Error::from)?)?,
286                                name: String::from_value(row.get_value(1).map_err(crate::Error::from)?)?,
287                                queue: String::from_value(row.get_value(2).map_err(crate::Error::from)?)?,
288                                payload: String::from_value(row.get_value(3).map_err(crate::Error::from)?)?,
289                                attempt: i32::from_value(row.get_value(4).map_err(crate::Error::from)?)?,
290                            })
291                        },
292                    ).await {
293                        Ok(rows) => rows,
294                        Err(e) => {
295                            tracing::error!(error = %e, queue = %queue_config.name, "failed to claim jobs");
296                            continue;
297                        }
298                    };
299
300                    for job in claimed {
301                        let Some(entry) = handlers.get(&job.name) else {
302                            tracing::warn!(job_name = %job.name, "no handler registered");
303                            continue;
304                        };
305
306                        let permit = match semaphore.clone().try_acquire_owned() {
307                            Ok(p) => p,
308                            Err(_) => {
309                                tracing::warn!(job_id = %job.id, "no permit available, job will be reaped");
310                                break;
311                            }
312                        };
313
314                        let handler = entry.handler.clone();
315                        let max_attempts = entry.options.max_attempts;
316                        let timeout_secs = entry.options.timeout_secs;
317                        let reg = registry.clone();
318                        let db_clone = db.clone();
319                        let job_id = job.id.clone();
320                        let job_name = job.name.clone();
321
322                        let deadline =
323                            tokio::time::Instant::now() + Duration::from_secs(timeout_secs);
324
325                        let meta = Meta {
326                            id: job.id,
327                            name: job.name,
328                            queue: job.queue,
329                            attempt: job.attempt as u32,
330                            max_attempts,
331                            deadline: Some(deadline),
332                        };
333
334                        let ctx = JobContext {
335                            registry: reg,
336                            payload: job.payload,
337                            meta,
338                        };
339
340                        tokio::spawn(async move {
341                            let result = tokio::time::timeout(
342                                Duration::from_secs(timeout_secs),
343                                (handler)(ctx),
344                            )
345                            .await;
346
347                            let now_str = Utc::now().to_rfc3339();
348
349                            match result {
350                                Ok(Ok(())) => {
351                                    let _ = db_clone.conn().execute_raw(
352                                        "UPDATE jobs SET status = 'completed', \
353                                         completed_at = ?1, updated_at = ?2 WHERE id = ?3",
354                                        libsql::params![now_str.as_str(), now_str.as_str(), job_id.as_str()],
355                                    )
356                                    .await;
357
358                                    tracing::info!(
359                                        job_id = %job_id,
360                                        job_name = %job_name,
361                                        "job completed"
362                                    );
363                                }
364                                Ok(Err(e)) => {
365                                    let error_msg = format!("{e}");
366                                    handle_job_failure(
367                                        &db_clone,
368                                        &job_id,
369                                        &job_name,
370                                        job.attempt as u32,
371                                        max_attempts,
372                                        &error_msg,
373                                        &now_str,
374                                    )
375                                    .await;
376                                }
377                                Err(_) => {
378                                    handle_job_failure(
379                                        &db_clone,
380                                        &job_id,
381                                        &job_name,
382                                        job.attempt as u32,
383                                        max_attempts,
384                                        "timeout",
385                                        &now_str,
386                                    )
387                                    .await;
388                                }
389                            }
390
391                            drop(permit);
392                        });
393                    }
394                }
395            }
396        }
397    }
398}
399
400async fn handle_job_failure(
401    db: &Database,
402    job_id: &str,
403    job_name: &str,
404    attempt: u32,
405    max_attempts: u32,
406    error_msg: &str,
407    now_str: &str,
408) {
409    if attempt >= max_attempts {
410        let _ = db
411            .conn()
412            .execute_raw(
413                "UPDATE jobs SET status = 'dead', \
414                 failed_at = ?1, error_message = ?2, updated_at = ?3 WHERE id = ?4",
415                libsql::params![now_str, error_msg, now_str, job_id],
416            )
417            .await;
418
419        tracing::error!(
420            job_id = %job_id,
421            job_name = %job_name,
422            attempt = attempt,
423            error = %error_msg,
424            "job dead after max attempts"
425        );
426    } else {
427        let delay_secs = std::cmp::min(5u64 * 2u64.pow(attempt - 1), 3600);
428        let retry_at = (Utc::now() + chrono::Duration::seconds(delay_secs as i64)).to_rfc3339();
429
430        let _ = db
431            .conn()
432            .execute_raw(
433                "UPDATE jobs SET status = 'pending', \
434                 run_at = ?1, started_at = NULL, \
435                 failed_at = ?2, error_message = ?3, updated_at = ?4 WHERE id = ?5",
436                libsql::params![retry_at.as_str(), now_str, error_msg, now_str, job_id],
437            )
438            .await;
439
440        tracing::warn!(
441            job_id = %job_id,
442            job_name = %job_name,
443            attempt = attempt,
444            retry_in_secs = delay_secs,
445            error = %error_msg,
446            "job failed, rescheduled"
447        );
448    }
449}