1use std::collections::HashMap;
2use std::pin::Pin;
3use std::sync::Arc;
4use std::time::Duration;
5
6use chrono::Utc;
7use tokio::sync::Semaphore;
8use tokio::task::JoinHandle;
9use tokio_util::sync::CancellationToken;
10
11use crate::db::{ConnExt, ConnQueryExt, Database, FromValue};
12use crate::error::Result;
13use crate::service::{Registry, RegistrySnapshot};
14
15use super::cleanup::cleanup_loop;
16use super::config::{JobConfig, QueueConfig};
17use super::context::JobContext;
18use super::handler::JobHandler;
19use super::meta::Meta;
20use super::reaper::reaper_loop;
21
/// Per-handler execution options supplied when a job handler is registered.
///
/// The worker reads these values each time it dispatches a claimed job to the
/// handler they were registered with.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct JobOptions {
    /// Attempt count at which a failing job is marked `dead` instead of
    /// being rescheduled.
    pub max_attempts: u32,
    /// Maximum time, in seconds, a single handler invocation may run before
    /// it is treated as a `"timeout"` failure.
    pub timeout_secs: u64,
}

impl Default for JobOptions {
    /// Returns the default options: 3 attempts and a 300 second timeout.
    fn default() -> Self {
        Self {
            max_attempts: 3,
            timeout_secs: 300,
        }
    }
}
44
45type ErasedHandler =
46 Arc<dyn Fn(JobContext) -> Pin<Box<dyn Future<Output = Result<()>> + Send>> + Send + Sync>;
47
48struct HandlerEntry {
49 handler: ErasedHandler,
50 options: JobOptions,
51}
52
53#[must_use]
60pub struct WorkerBuilder {
61 config: JobConfig,
62 registry: Arc<RegistrySnapshot>,
63 db: Database,
64 handlers: HashMap<String, HandlerEntry>,
65}
66
67impl WorkerBuilder {
68 pub fn register<H, Args>(mut self, name: &str, handler: H) -> Self
70 where
71 H: JobHandler<Args> + Send + Sync,
72 {
73 self.register_inner(name, handler, JobOptions::default());
74 self
75 }
76
77 pub fn register_with<H, Args>(mut self, name: &str, handler: H, options: JobOptions) -> Self
79 where
80 H: JobHandler<Args> + Send + Sync,
81 {
82 self.register_inner(name, handler, options);
83 self
84 }
85
86 fn register_inner<H, Args>(&mut self, name: &str, handler: H, options: JobOptions)
87 where
88 H: JobHandler<Args> + Send + Sync,
89 {
90 let handler = Arc::new(
91 move |ctx: JobContext| -> Pin<Box<dyn Future<Output = Result<()>> + Send>> {
92 let h = handler.clone();
93 Box::pin(async move { h.call(ctx).await })
94 },
95 ) as ErasedHandler;
96
97 self.handlers
98 .insert(name.to_string(), HandlerEntry { handler, options });
99 }
100
101 pub async fn start(self) -> Worker {
109 let cancel = CancellationToken::new();
110 let handlers = Arc::new(self.handlers);
111 let handler_names: Vec<String> = handlers.keys().cloned().collect();
112
113 let queue_semaphores: Vec<(QueueConfig, Arc<Semaphore>)> = self
115 .config
116 .queues
117 .iter()
118 .map(|q| (q.clone(), Arc::new(Semaphore::new(q.concurrency as usize))))
119 .collect();
120
121 let poll_handle = tokio::spawn(poll_loop(
123 self.db.clone(),
124 self.registry.clone(),
125 handlers.clone(),
126 handler_names,
127 queue_semaphores,
128 self.config.poll_interval_secs,
129 cancel.clone(),
130 ));
131
132 let reaper_handle = tokio::spawn(reaper_loop(
134 self.db.clone(),
135 self.config.stale_threshold_secs,
136 self.config.stale_reaper_interval_secs,
137 cancel.clone(),
138 ));
139
140 let cleanup_handle = if let Some(ref cleanup) = self.config.cleanup {
142 Some(tokio::spawn(cleanup_loop(
143 self.db.clone(),
144 cleanup.interval_secs,
145 cleanup.retention_secs,
146 cancel.clone(),
147 )))
148 } else {
149 None
150 };
151
152 Worker {
153 cancel,
154 poll_handle,
155 reaper_handle,
156 cleanup_handle,
157 drain_timeout: Duration::from_secs(self.config.drain_timeout_secs),
158 }
159 }
160}
161
162pub struct Worker {
170 cancel: CancellationToken,
171 poll_handle: JoinHandle<()>,
172 reaper_handle: JoinHandle<()>,
173 cleanup_handle: Option<JoinHandle<()>>,
174 drain_timeout: Duration,
175}
176
177impl Worker {
178 pub fn builder(config: &JobConfig, registry: &Registry) -> WorkerBuilder {
185 let snapshot = registry.snapshot();
186 let db = snapshot
187 .get::<Database>()
188 .expect("Database must be registered before building Worker");
189
190 WorkerBuilder {
191 config: config.clone(),
192 registry: snapshot,
193 db: (*db).clone(),
194 handlers: HashMap::new(),
195 }
196 }
197}
198
199impl crate::runtime::Task for Worker {
200 async fn shutdown(self) -> Result<()> {
201 self.cancel.cancel();
202 let drain = async {
203 let _ = self.poll_handle.await;
204 let _ = self.reaper_handle.await;
205 if let Some(h) = self.cleanup_handle {
206 let _ = h.await;
207 }
208 };
209 let _ = tokio::time::timeout(self.drain_timeout, drain).await;
210 Ok(())
211 }
212}
213
/// One row returned by the claim statement issued by the poll loop
/// (`RETURNING id, name, queue, payload, attempt`).
struct ClaimedJob {
    /// Job identifier.
    id: String,
    /// Job name, used to look up the registered handler.
    name: String,
    /// Name of the queue the job was claimed from.
    queue: String,
    /// Payload text handed to the handler unchanged.
    payload: String,
    /// Attempt number as returned by the claim statement.
    attempt: i32,
}
222
223async fn poll_loop(
224 db: Database,
225 registry: Arc<RegistrySnapshot>,
226 handlers: Arc<HashMap<String, HandlerEntry>>,
227 handler_names: Vec<String>,
228 queue_semaphores: Vec<(QueueConfig, Arc<Semaphore>)>,
229 poll_interval_secs: u64,
230 cancel: CancellationToken,
231) {
232 let poll_interval = Duration::from_secs(poll_interval_secs);
233
234 let placeholders: Vec<String> = handler_names
236 .iter()
237 .enumerate()
238 .map(|(i, _)| format!("?{}", i + 5))
239 .collect();
240 let placeholders_str = placeholders.join(", ");
241 let limit_param = handler_names.len() + 5;
242 let claim_sql = format!(
243 "UPDATE jobs SET status = 'running', attempt = attempt + 1, \
244 started_at = ?1, updated_at = ?2 \
245 WHERE id IN (\
246 SELECT id FROM jobs \
247 WHERE status = 'pending' AND run_at <= ?3 \
248 AND queue = ?4 AND name IN ({placeholders_str}) \
249 ORDER BY run_at ASC LIMIT ?{limit_param}\
250 ) RETURNING id, name, queue, payload, attempt",
251 );
252
253 loop {
254 tokio::select! {
255 _ = cancel.cancelled() => break,
256 _ = tokio::time::sleep(poll_interval) => {
257 if handler_names.is_empty() {
258 continue;
259 }
260
261 let now_str = Utc::now().to_rfc3339();
262
263 for (queue_config, semaphore) in &queue_semaphores {
264 let slots = semaphore.available_permits();
265 if slots == 0 {
266 continue;
267 }
268
269 let mut params: Vec<libsql::Value> = vec![
270 libsql::Value::Text(now_str.clone()), libsql::Value::Text(now_str.clone()), libsql::Value::Text(now_str.clone()), libsql::Value::Text(queue_config.name.clone()), ];
275 for name in &handler_names {
276 params.push(libsql::Value::Text(name.clone()));
277 }
278 params.push(libsql::Value::Integer(slots as i64)); let claimed = match db.conn().query_all_map(
281 &claim_sql,
282 params,
283 |row| {
284 Ok(ClaimedJob {
285 id: String::from_value(row.get_value(0).map_err(crate::Error::from)?)?,
286 name: String::from_value(row.get_value(1).map_err(crate::Error::from)?)?,
287 queue: String::from_value(row.get_value(2).map_err(crate::Error::from)?)?,
288 payload: String::from_value(row.get_value(3).map_err(crate::Error::from)?)?,
289 attempt: i32::from_value(row.get_value(4).map_err(crate::Error::from)?)?,
290 })
291 },
292 ).await {
293 Ok(rows) => rows,
294 Err(e) => {
295 tracing::error!(error = %e, queue = %queue_config.name, "failed to claim jobs");
296 continue;
297 }
298 };
299
300 for job in claimed {
301 let Some(entry) = handlers.get(&job.name) else {
302 tracing::warn!(job_name = %job.name, "no handler registered");
303 continue;
304 };
305
306 let permit = match semaphore.clone().try_acquire_owned() {
307 Ok(p) => p,
308 Err(_) => {
309 tracing::warn!(job_id = %job.id, "no permit available, job will be reaped");
310 break;
311 }
312 };
313
314 let handler = entry.handler.clone();
315 let max_attempts = entry.options.max_attempts;
316 let timeout_secs = entry.options.timeout_secs;
317 let reg = registry.clone();
318 let db_clone = db.clone();
319 let job_id = job.id.clone();
320 let job_name = job.name.clone();
321
322 let deadline =
323 tokio::time::Instant::now() + Duration::from_secs(timeout_secs);
324
325 let meta = Meta {
326 id: job.id,
327 name: job.name,
328 queue: job.queue,
329 attempt: job.attempt as u32,
330 max_attempts,
331 deadline: Some(deadline),
332 };
333
334 let ctx = JobContext {
335 registry: reg,
336 payload: job.payload,
337 meta,
338 };
339
340 tokio::spawn(async move {
341 let result = tokio::time::timeout(
342 Duration::from_secs(timeout_secs),
343 (handler)(ctx),
344 )
345 .await;
346
347 let now_str = Utc::now().to_rfc3339();
348
349 match result {
350 Ok(Ok(())) => {
351 let _ = db_clone.conn().execute_raw(
352 "UPDATE jobs SET status = 'completed', \
353 completed_at = ?1, updated_at = ?2 WHERE id = ?3",
354 libsql::params![now_str.as_str(), now_str.as_str(), job_id.as_str()],
355 )
356 .await;
357
358 tracing::info!(
359 job_id = %job_id,
360 job_name = %job_name,
361 "job completed"
362 );
363 }
364 Ok(Err(e)) => {
365 let error_msg = format!("{e}");
366 handle_job_failure(
367 &db_clone,
368 &job_id,
369 &job_name,
370 job.attempt as u32,
371 max_attempts,
372 &error_msg,
373 &now_str,
374 )
375 .await;
376 }
377 Err(_) => {
378 handle_job_failure(
379 &db_clone,
380 &job_id,
381 &job_name,
382 job.attempt as u32,
383 max_attempts,
384 "timeout",
385 &now_str,
386 )
387 .await;
388 }
389 }
390
391 drop(permit);
392 });
393 }
394 }
395 }
396 }
397 }
398}
399
400async fn handle_job_failure(
401 db: &Database,
402 job_id: &str,
403 job_name: &str,
404 attempt: u32,
405 max_attempts: u32,
406 error_msg: &str,
407 now_str: &str,
408) {
409 if attempt >= max_attempts {
410 let _ = db
411 .conn()
412 .execute_raw(
413 "UPDATE jobs SET status = 'dead', \
414 failed_at = ?1, error_message = ?2, updated_at = ?3 WHERE id = ?4",
415 libsql::params![now_str, error_msg, now_str, job_id],
416 )
417 .await;
418
419 tracing::error!(
420 job_id = %job_id,
421 job_name = %job_name,
422 attempt = attempt,
423 error = %error_msg,
424 "job dead after max attempts"
425 );
426 } else {
427 let delay_secs = std::cmp::min(5u64 * 2u64.pow(attempt - 1), 3600);
428 let retry_at = (Utc::now() + chrono::Duration::seconds(delay_secs as i64)).to_rfc3339();
429
430 let _ = db
431 .conn()
432 .execute_raw(
433 "UPDATE jobs SET status = 'pending', \
434 run_at = ?1, started_at = NULL, \
435 failed_at = ?2, error_message = ?3, updated_at = ?4 WHERE id = ?5",
436 libsql::params![retry_at.as_str(), now_str, error_msg, now_str, job_id],
437 )
438 .await;
439
440 tracing::warn!(
441 job_id = %job_id,
442 job_name = %job_name,
443 attempt = attempt,
444 retry_in_secs = delay_secs,
445 error = %error_msg,
446 "job failed, rescheduled"
447 );
448 }
449}