Skip to main content

modo/job/
worker.rs

1use std::collections::HashMap;
2use std::pin::Pin;
3use std::sync::Arc;
4use std::time::Duration;
5
6use chrono::Utc;
7use tokio::sync::Semaphore;
8use tokio::task::JoinHandle;
9use tokio_util::sync::CancellationToken;
10
11use crate::db::{ConnExt, ConnQueryExt, Database, FromValue};
12use crate::error::Result;
13use crate::service::{Registry, RegistrySnapshot};
14
15use super::cleanup::cleanup_loop;
16use super::config::{JobConfig, QueueConfig};
17use super::context::JobContext;
18use super::handler::JobHandler;
19use super::meta::Meta;
20use super::reaper::reaper_loop;
21
/// Per-handler options controlling retry and timeout behavior.
///
/// Start from [`JobOptions::default`] and override only the fields you need.
#[derive(Debug, Clone)]
pub struct JobOptions {
    /// Maximum number of execution attempts before the job is marked `Dead`.
    /// A failed attempt whose number is at or above this value is not retried.
    /// Defaults to `3`.
    pub max_attempts: u32,
    /// Per-execution timeout in seconds. If a handler exceeds this, the
    /// attempt is treated as a failure. Defaults to `300` (5 min).
    pub timeout_secs: u64,
}
31
32impl Default for JobOptions {
33    fn default() -> Self {
34        Self {
35            max_attempts: 3,
36            timeout_secs: 300,
37        }
38    }
39}
40
41type ErasedHandler =
42    Arc<dyn Fn(JobContext) -> Pin<Box<dyn Future<Output = Result<()>> + Send>> + Send + Sync>;
43
44struct HandlerEntry {
45    handler: ErasedHandler,
46    options: JobOptions,
47}
48
49/// Builder for constructing a [`Worker`] with registered job handlers.
50///
51/// Obtained via [`Worker::builder`]. Call [`WorkerBuilder::register`] (or
52/// [`WorkerBuilder::register_with`]) for each job name, then call
53/// [`WorkerBuilder::start`] to spawn the background loops and obtain a
54/// [`Worker`] handle.
55#[must_use]
56pub struct WorkerBuilder {
57    config: JobConfig,
58    registry: Arc<RegistrySnapshot>,
59    db: Database,
60    handlers: HashMap<String, HandlerEntry>,
61}
62
63impl WorkerBuilder {
64    /// Register a handler for the given job name with default [`JobOptions`].
65    pub fn register<H, Args>(mut self, name: &str, handler: H) -> Self
66    where
67        H: JobHandler<Args> + Send + Sync,
68    {
69        self.register_inner(name, handler, JobOptions::default());
70        self
71    }
72
73    /// Register a handler for the given job name with custom [`JobOptions`].
74    pub fn register_with<H, Args>(mut self, name: &str, handler: H, options: JobOptions) -> Self
75    where
76        H: JobHandler<Args> + Send + Sync,
77    {
78        self.register_inner(name, handler, options);
79        self
80    }
81
82    fn register_inner<H, Args>(&mut self, name: &str, handler: H, options: JobOptions)
83    where
84        H: JobHandler<Args> + Send + Sync,
85    {
86        let handler = Arc::new(
87            move |ctx: JobContext| -> Pin<Box<dyn Future<Output = Result<()>> + Send>> {
88                let h = handler.clone();
89                Box::pin(async move { h.call(ctx).await })
90            },
91        ) as ErasedHandler;
92
93        self.handlers
94            .insert(name.to_string(), HandlerEntry { handler, options });
95    }
96
97    /// Spawn the worker loops and return a [`Worker`] handle for shutdown.
98    ///
99    /// Three background tasks are started:
100    /// - **poll loop** — claims and dispatches pending jobs
101    /// - **stale reaper** — resets jobs stuck in `running` past the configured
102    ///   threshold
103    /// - **cleanup loop** (optional) — deletes old terminal jobs
104    pub async fn start(self) -> Worker {
105        let cancel = CancellationToken::new();
106        let handlers = Arc::new(self.handlers);
107        let handler_names: Vec<String> = handlers.keys().cloned().collect();
108
109        // Build per-queue semaphores
110        let queue_semaphores: Vec<(QueueConfig, Arc<Semaphore>)> = self
111            .config
112            .queues
113            .iter()
114            .map(|q| (q.clone(), Arc::new(Semaphore::new(q.concurrency as usize))))
115            .collect();
116
117        // Spawn poll loop
118        let poll_handle = tokio::spawn(poll_loop(
119            self.db.clone(),
120            self.registry.clone(),
121            handlers.clone(),
122            handler_names,
123            queue_semaphores,
124            self.config.poll_interval_secs,
125            cancel.clone(),
126        ));
127
128        // Spawn stale reaper
129        let reaper_handle = tokio::spawn(reaper_loop(
130            self.db.clone(),
131            self.config.stale_threshold_secs,
132            self.config.stale_reaper_interval_secs,
133            cancel.clone(),
134        ));
135
136        // Spawn cleanup (if configured)
137        let cleanup_handle = if let Some(ref cleanup) = self.config.cleanup {
138            Some(tokio::spawn(cleanup_loop(
139                self.db.clone(),
140                cleanup.interval_secs,
141                cleanup.retention_secs,
142                cancel.clone(),
143            )))
144        } else {
145            None
146        };
147
148        Worker {
149            cancel,
150            poll_handle,
151            reaper_handle,
152            cleanup_handle,
153            drain_timeout: Duration::from_secs(self.config.drain_timeout_secs),
154        }
155    }
156}
157
158/// A running job worker that processes enqueued jobs.
159///
160/// Implements [`crate::runtime::Task`] for graceful shutdown. Pass the
161/// `Worker` to the [`run!`](crate::run) macro so it is shut down when the
162/// process receives a termination signal.
163///
164/// Construct via [`Worker::builder`].
165pub struct Worker {
166    cancel: CancellationToken,
167    poll_handle: JoinHandle<()>,
168    reaper_handle: JoinHandle<()>,
169    cleanup_handle: Option<JoinHandle<()>>,
170    drain_timeout: Duration,
171}
172
173impl Worker {
174    /// Create a [`WorkerBuilder`] from config and service registry.
175    ///
176    /// # Panics
177    ///
178    /// Panics if a [`Database`](crate::db::Database) is not registered in
179    /// `registry`.
180    pub fn builder(config: &JobConfig, registry: &Registry) -> WorkerBuilder {
181        let snapshot = registry.snapshot();
182        let db = snapshot
183            .get::<Database>()
184            .expect("Database must be registered before building Worker");
185
186        WorkerBuilder {
187            config: config.clone(),
188            registry: snapshot,
189            db: (*db).clone(),
190            handlers: HashMap::new(),
191        }
192    }
193}
194
195impl crate::runtime::Task for Worker {
196    async fn shutdown(self) -> Result<()> {
197        self.cancel.cancel();
198        let drain = async {
199            let _ = self.poll_handle.await;
200            let _ = self.reaper_handle.await;
201            if let Some(h) = self.cleanup_handle {
202                let _ = h.await;
203            }
204        };
205        let _ = tokio::time::timeout(self.drain_timeout, drain).await;
206        Ok(())
207    }
208}
209
/// One row returned by the poll loop's claim statement.
///
/// Each row is a job that has just been switched to `running` and is about to
/// be dispatched to its handler.
struct ClaimedJob {
    /// Primary key of the `jobs` row.
    id: String,
    /// Job name; selects the registered handler.
    name: String,
    /// Queue the job was claimed from.
    queue: String,
    /// Payload text as stored in the row, passed through to the job context.
    payload: String,
    /// Attempt counter as returned after the claim's `attempt = attempt + 1`.
    attempt: i32,
}
218
219async fn poll_loop(
220    db: Database,
221    registry: Arc<RegistrySnapshot>,
222    handlers: Arc<HashMap<String, HandlerEntry>>,
223    handler_names: Vec<String>,
224    queue_semaphores: Vec<(QueueConfig, Arc<Semaphore>)>,
225    poll_interval_secs: u64,
226    cancel: CancellationToken,
227) {
228    let poll_interval = Duration::from_secs(poll_interval_secs);
229
230    // Precompute the SQL template once — handler_names never changes after start.
231    let placeholders: Vec<String> = handler_names
232        .iter()
233        .enumerate()
234        .map(|(i, _)| format!("?{}", i + 5))
235        .collect();
236    let placeholders_str = placeholders.join(", ");
237    let limit_param = handler_names.len() + 5;
238    let claim_sql = format!(
239        "UPDATE jobs SET status = 'running', attempt = attempt + 1, \
240         started_at = ?1, updated_at = ?2 \
241         WHERE id IN (\
242             SELECT id FROM jobs \
243             WHERE status = 'pending' AND run_at <= ?3 \
244             AND queue = ?4 AND name IN ({placeholders_str}) \
245             ORDER BY run_at ASC LIMIT ?{limit_param}\
246         ) RETURNING id, name, queue, payload, attempt",
247    );
248
249    loop {
250        tokio::select! {
251            _ = cancel.cancelled() => break,
252            _ = tokio::time::sleep(poll_interval) => {
253                if handler_names.is_empty() {
254                    continue;
255                }
256
257                let now_str = Utc::now().to_rfc3339();
258
259                for (queue_config, semaphore) in &queue_semaphores {
260                    let slots = semaphore.available_permits();
261                    if slots == 0 {
262                        continue;
263                    }
264
265                    let mut params: Vec<libsql::Value> = vec![
266                        libsql::Value::Text(now_str.clone()),       // ?1 started_at
267                        libsql::Value::Text(now_str.clone()),       // ?2 updated_at
268                        libsql::Value::Text(now_str.clone()),       // ?3 run_at <=
269                        libsql::Value::Text(queue_config.name.clone()), // ?4 queue =
270                    ];
271                    for name in &handler_names {
272                        params.push(libsql::Value::Text(name.clone()));
273                    }
274                    params.push(libsql::Value::Integer(slots as i64)); // LIMIT
275
276                    let claimed = match db.conn().query_all_map(
277                        &claim_sql,
278                        params,
279                        |row| {
280                            Ok(ClaimedJob {
281                                id: String::from_value(row.get_value(0).map_err(crate::Error::from)?)?,
282                                name: String::from_value(row.get_value(1).map_err(crate::Error::from)?)?,
283                                queue: String::from_value(row.get_value(2).map_err(crate::Error::from)?)?,
284                                payload: String::from_value(row.get_value(3).map_err(crate::Error::from)?)?,
285                                attempt: i32::from_value(row.get_value(4).map_err(crate::Error::from)?)?,
286                            })
287                        },
288                    ).await {
289                        Ok(rows) => rows,
290                        Err(e) => {
291                            tracing::error!(error = %e, queue = %queue_config.name, "failed to claim jobs");
292                            continue;
293                        }
294                    };
295
296                    for job in claimed {
297                        let Some(entry) = handlers.get(&job.name) else {
298                            tracing::warn!(job_name = %job.name, "no handler registered");
299                            continue;
300                        };
301
302                        let permit = match semaphore.clone().try_acquire_owned() {
303                            Ok(p) => p,
304                            Err(_) => {
305                                tracing::warn!(job_id = %job.id, "no permit available, job will be reaped");
306                                break;
307                            }
308                        };
309
310                        let handler = entry.handler.clone();
311                        let max_attempts = entry.options.max_attempts;
312                        let timeout_secs = entry.options.timeout_secs;
313                        let reg = registry.clone();
314                        let db_clone = db.clone();
315                        let job_id = job.id.clone();
316                        let job_name = job.name.clone();
317
318                        let deadline =
319                            tokio::time::Instant::now() + Duration::from_secs(timeout_secs);
320
321                        let meta = Meta {
322                            id: job.id,
323                            name: job.name,
324                            queue: job.queue,
325                            attempt: job.attempt as u32,
326                            max_attempts,
327                            deadline: Some(deadline),
328                        };
329
330                        let ctx = JobContext {
331                            registry: reg,
332                            payload: job.payload,
333                            meta,
334                        };
335
336                        tokio::spawn(async move {
337                            let result = tokio::time::timeout(
338                                Duration::from_secs(timeout_secs),
339                                (handler)(ctx),
340                            )
341                            .await;
342
343                            let now_str = Utc::now().to_rfc3339();
344
345                            match result {
346                                Ok(Ok(())) => {
347                                    let _ = db_clone.conn().execute_raw(
348                                        "UPDATE jobs SET status = 'completed', \
349                                         completed_at = ?1, updated_at = ?2 WHERE id = ?3",
350                                        libsql::params![now_str.as_str(), now_str.as_str(), job_id.as_str()],
351                                    )
352                                    .await;
353
354                                    tracing::info!(
355                                        job_id = %job_id,
356                                        job_name = %job_name,
357                                        "job completed"
358                                    );
359                                }
360                                Ok(Err(e)) => {
361                                    let error_msg = format!("{e}");
362                                    handle_job_failure(
363                                        &db_clone,
364                                        &job_id,
365                                        &job_name,
366                                        job.attempt as u32,
367                                        max_attempts,
368                                        &error_msg,
369                                        &now_str,
370                                    )
371                                    .await;
372                                }
373                                Err(_) => {
374                                    handle_job_failure(
375                                        &db_clone,
376                                        &job_id,
377                                        &job_name,
378                                        job.attempt as u32,
379                                        max_attempts,
380                                        "timeout",
381                                        &now_str,
382                                    )
383                                    .await;
384                                }
385                            }
386
387                            drop(permit);
388                        });
389                    }
390                }
391            }
392        }
393    }
394}
395
396async fn handle_job_failure(
397    db: &Database,
398    job_id: &str,
399    job_name: &str,
400    attempt: u32,
401    max_attempts: u32,
402    error_msg: &str,
403    now_str: &str,
404) {
405    if attempt >= max_attempts {
406        let _ = db
407            .conn()
408            .execute_raw(
409                "UPDATE jobs SET status = 'dead', \
410                 failed_at = ?1, error_message = ?2, updated_at = ?3 WHERE id = ?4",
411                libsql::params![now_str, error_msg, now_str, job_id],
412            )
413            .await;
414
415        tracing::error!(
416            job_id = %job_id,
417            job_name = %job_name,
418            attempt = attempt,
419            error = %error_msg,
420            "job dead after max attempts"
421        );
422    } else {
423        let delay_secs = std::cmp::min(5u64 * 2u64.pow(attempt - 1), 3600);
424        let retry_at = (Utc::now() + chrono::Duration::seconds(delay_secs as i64)).to_rfc3339();
425
426        let _ = db
427            .conn()
428            .execute_raw(
429                "UPDATE jobs SET status = 'pending', \
430                 run_at = ?1, started_at = NULL, \
431                 failed_at = ?2, error_message = ?3, updated_at = ?4 WHERE id = ?5",
432                libsql::params![retry_at.as_str(), now_str, error_msg, now_str, job_id],
433            )
434            .await;
435
436        tracing::warn!(
437            job_id = %job_id,
438            job_name = %job_name,
439            attempt = attempt,
440            retry_in_secs = delay_secs,
441            error = %error_msg,
442            "job failed, rescheduled"
443        );
444    }
445}