1use std::collections::HashMap;
2use std::pin::Pin;
3use std::sync::Arc;
4use std::time::Duration;
5
6use chrono::Utc;
7use tokio::sync::Semaphore;
8use tokio::task::JoinHandle;
9use tokio_util::sync::CancellationToken;
10
11use crate::db::{ConnExt, ConnQueryExt, Database, FromValue};
12use crate::error::Result;
13use crate::service::{Registry, RegistrySnapshot};
14
15use super::cleanup::cleanup_loop;
16use super::config::{JobConfig, QueueConfig};
17use super::context::JobContext;
18use super::handler::JobHandler;
19use super::meta::Meta;
20use super::reaper::reaper_loop;
21
/// Per-handler execution settings supplied when a job handler is registered.
pub struct JobOptions {
    /// Attempt number at which a failing job is marked `dead` instead of
    /// being rescheduled for another try.
    pub max_attempts: u32,
    /// Time budget, in seconds, for a single run of the handler. A run that
    /// exceeds it is treated as a failure with the message `timeout`.
    pub timeout_secs: u64,
}
31
32impl Default for JobOptions {
33 fn default() -> Self {
34 Self {
35 max_attempts: 3,
36 timeout_secs: 300,
37 }
38 }
39}
40
41type ErasedHandler =
42 Arc<dyn Fn(JobContext) -> Pin<Box<dyn Future<Output = Result<()>> + Send>> + Send + Sync>;
43
44struct HandlerEntry {
45 handler: ErasedHandler,
46 options: JobOptions,
47}
48
49#[must_use]
56pub struct WorkerBuilder {
57 config: JobConfig,
58 registry: Arc<RegistrySnapshot>,
59 db: Database,
60 handlers: HashMap<String, HandlerEntry>,
61}
62
63impl WorkerBuilder {
64 pub fn register<H, Args>(mut self, name: &str, handler: H) -> Self
66 where
67 H: JobHandler<Args> + Send + Sync,
68 {
69 self.register_inner(name, handler, JobOptions::default());
70 self
71 }
72
73 pub fn register_with<H, Args>(mut self, name: &str, handler: H, options: JobOptions) -> Self
75 where
76 H: JobHandler<Args> + Send + Sync,
77 {
78 self.register_inner(name, handler, options);
79 self
80 }
81
82 fn register_inner<H, Args>(&mut self, name: &str, handler: H, options: JobOptions)
83 where
84 H: JobHandler<Args> + Send + Sync,
85 {
86 let handler = Arc::new(
87 move |ctx: JobContext| -> Pin<Box<dyn Future<Output = Result<()>> + Send>> {
88 let h = handler.clone();
89 Box::pin(async move { h.call(ctx).await })
90 },
91 ) as ErasedHandler;
92
93 self.handlers
94 .insert(name.to_string(), HandlerEntry { handler, options });
95 }
96
97 pub async fn start(self) -> Worker {
105 let cancel = CancellationToken::new();
106 let handlers = Arc::new(self.handlers);
107 let handler_names: Vec<String> = handlers.keys().cloned().collect();
108
109 let queue_semaphores: Vec<(QueueConfig, Arc<Semaphore>)> = self
111 .config
112 .queues
113 .iter()
114 .map(|q| (q.clone(), Arc::new(Semaphore::new(q.concurrency as usize))))
115 .collect();
116
117 let poll_handle = tokio::spawn(poll_loop(
119 self.db.clone(),
120 self.registry.clone(),
121 handlers.clone(),
122 handler_names,
123 queue_semaphores,
124 self.config.poll_interval_secs,
125 cancel.clone(),
126 ));
127
128 let reaper_handle = tokio::spawn(reaper_loop(
130 self.db.clone(),
131 self.config.stale_threshold_secs,
132 self.config.stale_reaper_interval_secs,
133 cancel.clone(),
134 ));
135
136 let cleanup_handle = if let Some(ref cleanup) = self.config.cleanup {
138 Some(tokio::spawn(cleanup_loop(
139 self.db.clone(),
140 cleanup.interval_secs,
141 cleanup.retention_secs,
142 cancel.clone(),
143 )))
144 } else {
145 None
146 };
147
148 Worker {
149 cancel,
150 poll_handle,
151 reaper_handle,
152 cleanup_handle,
153 drain_timeout: Duration::from_secs(self.config.drain_timeout_secs),
154 }
155 }
156}
157
158pub struct Worker {
166 cancel: CancellationToken,
167 poll_handle: JoinHandle<()>,
168 reaper_handle: JoinHandle<()>,
169 cleanup_handle: Option<JoinHandle<()>>,
170 drain_timeout: Duration,
171}
172
173impl Worker {
174 pub fn builder(config: &JobConfig, registry: &Registry) -> WorkerBuilder {
181 let snapshot = registry.snapshot();
182 let db = snapshot
183 .get::<Database>()
184 .expect("Database must be registered before building Worker");
185
186 WorkerBuilder {
187 config: config.clone(),
188 registry: snapshot,
189 db: (*db).clone(),
190 handlers: HashMap::new(),
191 }
192 }
193}
194
195impl crate::runtime::Task for Worker {
196 async fn shutdown(self) -> Result<()> {
197 self.cancel.cancel();
198 let drain = async {
199 let _ = self.poll_handle.await;
200 let _ = self.reaper_handle.await;
201 if let Some(h) = self.cleanup_handle {
202 let _ = h.await;
203 }
204 };
205 let _ = tokio::time::timeout(self.drain_timeout, drain).await;
206 Ok(())
207 }
208}
209
/// One row returned by the claim statement: a job that has just been switched
/// to `running` by this worker.
struct ClaimedJob {
    /// Value of the `id` column.
    id: String,
    /// Value of the `name` column; used to look up the registered handler.
    name: String,
    /// Value of the `queue` column.
    queue: String,
    /// Value of the `payload` column, passed to the handler unchanged.
    payload: String,
    /// Value of the `attempt` column after the claim incremented it.
    attempt: i32,
}
218
219async fn poll_loop(
220 db: Database,
221 registry: Arc<RegistrySnapshot>,
222 handlers: Arc<HashMap<String, HandlerEntry>>,
223 handler_names: Vec<String>,
224 queue_semaphores: Vec<(QueueConfig, Arc<Semaphore>)>,
225 poll_interval_secs: u64,
226 cancel: CancellationToken,
227) {
228 let poll_interval = Duration::from_secs(poll_interval_secs);
229
230 let placeholders: Vec<String> = handler_names
232 .iter()
233 .enumerate()
234 .map(|(i, _)| format!("?{}", i + 5))
235 .collect();
236 let placeholders_str = placeholders.join(", ");
237 let limit_param = handler_names.len() + 5;
238 let claim_sql = format!(
239 "UPDATE jobs SET status = 'running', attempt = attempt + 1, \
240 started_at = ?1, updated_at = ?2 \
241 WHERE id IN (\
242 SELECT id FROM jobs \
243 WHERE status = 'pending' AND run_at <= ?3 \
244 AND queue = ?4 AND name IN ({placeholders_str}) \
245 ORDER BY run_at ASC LIMIT ?{limit_param}\
246 ) RETURNING id, name, queue, payload, attempt",
247 );
248
249 loop {
250 tokio::select! {
251 _ = cancel.cancelled() => break,
252 _ = tokio::time::sleep(poll_interval) => {
253 if handler_names.is_empty() {
254 continue;
255 }
256
257 let now_str = Utc::now().to_rfc3339();
258
259 for (queue_config, semaphore) in &queue_semaphores {
260 let slots = semaphore.available_permits();
261 if slots == 0 {
262 continue;
263 }
264
265 let mut params: Vec<libsql::Value> = vec![
266 libsql::Value::Text(now_str.clone()), libsql::Value::Text(now_str.clone()), libsql::Value::Text(now_str.clone()), libsql::Value::Text(queue_config.name.clone()), ];
271 for name in &handler_names {
272 params.push(libsql::Value::Text(name.clone()));
273 }
274 params.push(libsql::Value::Integer(slots as i64)); let claimed = match db.conn().query_all_map(
277 &claim_sql,
278 params,
279 |row| {
280 Ok(ClaimedJob {
281 id: String::from_value(row.get_value(0).map_err(crate::Error::from)?)?,
282 name: String::from_value(row.get_value(1).map_err(crate::Error::from)?)?,
283 queue: String::from_value(row.get_value(2).map_err(crate::Error::from)?)?,
284 payload: String::from_value(row.get_value(3).map_err(crate::Error::from)?)?,
285 attempt: i32::from_value(row.get_value(4).map_err(crate::Error::from)?)?,
286 })
287 },
288 ).await {
289 Ok(rows) => rows,
290 Err(e) => {
291 tracing::error!(error = %e, queue = %queue_config.name, "failed to claim jobs");
292 continue;
293 }
294 };
295
296 for job in claimed {
297 let Some(entry) = handlers.get(&job.name) else {
298 tracing::warn!(job_name = %job.name, "no handler registered");
299 continue;
300 };
301
302 let permit = match semaphore.clone().try_acquire_owned() {
303 Ok(p) => p,
304 Err(_) => {
305 tracing::warn!(job_id = %job.id, "no permit available, job will be reaped");
306 break;
307 }
308 };
309
310 let handler = entry.handler.clone();
311 let max_attempts = entry.options.max_attempts;
312 let timeout_secs = entry.options.timeout_secs;
313 let reg = registry.clone();
314 let db_clone = db.clone();
315 let job_id = job.id.clone();
316 let job_name = job.name.clone();
317
318 let deadline =
319 tokio::time::Instant::now() + Duration::from_secs(timeout_secs);
320
321 let meta = Meta {
322 id: job.id,
323 name: job.name,
324 queue: job.queue,
325 attempt: job.attempt as u32,
326 max_attempts,
327 deadline: Some(deadline),
328 };
329
330 let ctx = JobContext {
331 registry: reg,
332 payload: job.payload,
333 meta,
334 };
335
336 tokio::spawn(async move {
337 let result = tokio::time::timeout(
338 Duration::from_secs(timeout_secs),
339 (handler)(ctx),
340 )
341 .await;
342
343 let now_str = Utc::now().to_rfc3339();
344
345 match result {
346 Ok(Ok(())) => {
347 let _ = db_clone.conn().execute_raw(
348 "UPDATE jobs SET status = 'completed', \
349 completed_at = ?1, updated_at = ?2 WHERE id = ?3",
350 libsql::params![now_str.as_str(), now_str.as_str(), job_id.as_str()],
351 )
352 .await;
353
354 tracing::info!(
355 job_id = %job_id,
356 job_name = %job_name,
357 "job completed"
358 );
359 }
360 Ok(Err(e)) => {
361 let error_msg = format!("{e}");
362 handle_job_failure(
363 &db_clone,
364 &job_id,
365 &job_name,
366 job.attempt as u32,
367 max_attempts,
368 &error_msg,
369 &now_str,
370 )
371 .await;
372 }
373 Err(_) => {
374 handle_job_failure(
375 &db_clone,
376 &job_id,
377 &job_name,
378 job.attempt as u32,
379 max_attempts,
380 "timeout",
381 &now_str,
382 )
383 .await;
384 }
385 }
386
387 drop(permit);
388 });
389 }
390 }
391 }
392 }
393 }
394}
395
396async fn handle_job_failure(
397 db: &Database,
398 job_id: &str,
399 job_name: &str,
400 attempt: u32,
401 max_attempts: u32,
402 error_msg: &str,
403 now_str: &str,
404) {
405 if attempt >= max_attempts {
406 let _ = db
407 .conn()
408 .execute_raw(
409 "UPDATE jobs SET status = 'dead', \
410 failed_at = ?1, error_message = ?2, updated_at = ?3 WHERE id = ?4",
411 libsql::params![now_str, error_msg, now_str, job_id],
412 )
413 .await;
414
415 tracing::error!(
416 job_id = %job_id,
417 job_name = %job_name,
418 attempt = attempt,
419 error = %error_msg,
420 "job dead after max attempts"
421 );
422 } else {
423 let delay_secs = std::cmp::min(5u64 * 2u64.pow(attempt - 1), 3600);
424 let retry_at = (Utc::now() + chrono::Duration::seconds(delay_secs as i64)).to_rfc3339();
425
426 let _ = db
427 .conn()
428 .execute_raw(
429 "UPDATE jobs SET status = 'pending', \
430 run_at = ?1, started_at = NULL, \
431 failed_at = ?2, error_message = ?3, updated_at = ?4 WHERE id = ?5",
432 libsql::params![retry_at.as_str(), now_str, error_msg, now_str, job_id],
433 )
434 .await;
435
436 tracing::warn!(
437 job_id = %job_id,
438 job_name = %job_name,
439 attempt = attempt,
440 retry_in_secs = delay_secs,
441 error = %error_msg,
442 "job failed, rescheduled"
443 );
444 }
445}