1use crate::{Error, Job};
13use async_trait::async_trait;
14use chrono::Utc;
15use futures::FutureExt;
16use sea_orm::DatabaseConnection;
17use std::collections::HashMap;
18use std::future::Future;
19use std::panic::AssertUnwindSafe;
20use std::pin::Pin;
21use std::sync::atomic::{AtomicBool, Ordering};
22use std::sync::Arc;
23use std::time::Duration;
24use tokio::sync::Semaphore;
25use tracing::{debug, error, info, warn};
26
27#[async_trait]
38pub trait TenantScopeProvider: Send + Sync {
39 async fn with_scope(
41 &self,
42 tenant_id: i64,
43 f: Pin<Box<dyn Future<Output = Result<(), Error>> + Send>>,
44 ) -> Result<(), Error>;
45}
46
/// Tunables for a [`WorkerLoop`].
///
/// Every field is public, and each one also has a fluent setter so a config can
/// be built in a single expression starting from [`WorkerConfig::new`] or
/// [`WorkerConfig::default`].
#[derive(Debug, Clone)]
pub struct WorkerConfig {
    /// Queue names to poll, tried in this order on every pass of the loop.
    pub queues: Vec<String>,
    /// Number of permits in the worker's concurrency semaphore.
    pub max_jobs: usize,
    /// How long the loop sleeps after a pass in which no job was claimed.
    pub sleep_duration: Duration,
    /// When `true`, a reaper or claim error aborts the loop instead of only
    /// being logged.
    pub stop_on_error: bool,
    /// Timeout handed to the per-queue reaper on every pass.
    pub visibility_timeout: Duration,
}

impl Default for WorkerConfig {
    /// Polls the `"default"` queue with 10 slots, a 1 s idle sleep, a 300 s
    /// visibility timeout, and errors logged rather than fatal.
    fn default() -> Self {
        Self {
            queues: vec!["default".to_string()],
            max_jobs: 10,
            sleep_duration: Duration::from_secs(1),
            stop_on_error: false,
            visibility_timeout: Duration::from_secs(300),
        }
    }
}

impl WorkerConfig {
    /// Creates a config for the given queues with every other field at its
    /// default value.
    pub fn new(queues: Vec<String>) -> Self {
        Self {
            queues,
            ..Default::default()
        }
    }

    /// Sets [`max_jobs`](Self::max_jobs).
    #[must_use]
    pub fn max_jobs(mut self, max: usize) -> Self {
        self.max_jobs = max;
        self
    }

    /// Sets [`visibility_timeout`](Self::visibility_timeout).
    #[must_use]
    pub fn with_visibility_timeout(mut self, d: Duration) -> Self {
        self.visibility_timeout = d;
        self
    }

    /// Sets [`sleep_duration`](Self::sleep_duration).
    #[must_use]
    pub fn with_sleep_duration(mut self, d: Duration) -> Self {
        self.sleep_duration = d;
        self
    }

    /// Sets [`stop_on_error`](Self::stop_on_error).
    #[must_use]
    pub fn with_stop_on_error(mut self, stop: bool) -> Self {
        self.stop_on_error = stop;
        self
    }
}
99
100type JobHandler = Arc<
109 dyn Fn(String, u32) -> Pin<Box<dyn Future<Output = (Result<(), Error>, Duration)> + Send>>
110 + Send
111 + Sync,
112>;
113
114pub struct WorkerLoop {
124 config: WorkerConfig,
126 handlers: HashMap<String, JobHandler>,
128 semaphore: Arc<Semaphore>,
130 shutdown: Arc<AtomicBool>,
132 tenant_scope: Option<Arc<dyn TenantScopeProvider>>,
134 worker_id: String,
136}
137
138impl WorkerLoop {
139 pub fn new(config: WorkerConfig) -> Self {
144 let semaphore = Arc::new(Semaphore::new(config.max_jobs));
145 Self {
146 config,
147 handlers: HashMap::new(),
148 semaphore,
149 shutdown: Arc::new(AtomicBool::new(false)),
150 tenant_scope: None,
151 worker_id: uuid::Uuid::new_v4().to_string(),
152 }
153 }
154
155 pub fn with_tenant_scope(mut self, provider: Arc<dyn TenantScopeProvider>) -> Self {
161 self.tenant_scope = Some(provider);
162 self
163 }
164
165 pub fn register<J>(&mut self)
176 where
177 J: Job + serde::de::DeserializeOwned + 'static,
178 {
179 let type_name = std::any::type_name::<J>().to_string();
180
181 let handler: JobHandler = Arc::new(move |data: String, attempt: u32| {
182 Box::pin(async move {
183 let job: J = match serde_json::from_str::<J>(&data) {
184 Ok(j) => j,
185 Err(e) => {
186 return (
187 Err(Error::DeserializationFailed(e.to_string())),
188 Duration::from_secs(5),
189 )
190 }
191 };
192 let delay = job.retry_delay(attempt);
194 let result = job.handle().await;
195 (result, delay)
196 })
197 });
198
199 self.handlers.insert(type_name, handler);
200 }
201
202 pub fn from_registry(config: WorkerConfig) -> Self {
207 let mut w = Self::new(config);
208 crate::db::Queue::apply_registrars(&mut w);
209 w
210 }
211
212 pub fn shutdown(&self) {
217 self.shutdown.store(true, Ordering::SeqCst);
218 }
219
220 pub async fn run(&self) -> Result<(), Error> {
228 let conn: &'static DatabaseConnection = crate::db::Queue::connection();
229
230 info!(
231 worker_id = %self.worker_id,
232 queues = ?self.config.queues,
233 max_jobs = self.config.max_jobs,
234 "WorkerLoop starting"
235 );
236
237 let signal_task = {
243 let shutdown = self.shutdown.clone();
244 tokio::spawn(async move {
245 let mut sigterm = match tokio::signal::unix::signal(
248 tokio::signal::unix::SignalKind::terminate(),
249 ) {
250 Ok(s) => s,
251 Err(e) => {
252 error!(error = %e, "failed to install SIGTERM handler — requesting shutdown");
253 shutdown.store(true, Ordering::SeqCst);
254 return;
255 }
256 };
257 tokio::select! {
258 _ = sigterm.recv() => {
259 info!("SIGTERM received — shutting down WorkerLoop");
260 }
261 _ = tokio::signal::ctrl_c() => {
262 info!("Ctrl-C received — shutting down WorkerLoop");
263 }
264 }
265 shutdown.store(true, Ordering::SeqCst);
266 })
267 };
268 let _signal_guard = AbortOnDrop(signal_task);
270
271 'outer: loop {
272 if self.shutdown.load(Ordering::SeqCst) {
274 info!(worker_id = %self.worker_id, "Shutdown flag set — draining in-flight jobs");
275
276 let _drain_guard = self
283 .semaphore
284 .acquire_many(self.config.max_jobs as u32)
285 .await;
286
287 crate::db::requeue_claimed_by(conn, &self.worker_id)
289 .await
290 .map_err(|e| {
291 error!(error = %e, "requeue_claimed_by failed during shutdown");
292 e
293 })?;
294
295 info!(worker_id = %self.worker_id, "WorkerLoop shut down cleanly");
296 return Ok(());
298 }
299
300 for queue in &self.config.queues {
302 match crate::db::reaper(conn, queue, self.config.visibility_timeout).await {
304 Ok(()) => {}
305 Err(e) => {
306 error!(queue = %queue, error = %e, "reaper error");
307 if self.config.stop_on_error {
308 return Err(e);
309 }
310 }
311 }
312
313 match crate::db::claim(conn, queue, &self.worker_id).await {
315 Ok(Some(job_row)) => {
316 self.spawn_job(conn, job_row);
317 continue 'outer; }
319 Ok(None) => {} Err(e) => {
321 error!(queue = %queue, error = %e, "claim error");
322 if self.config.stop_on_error {
323 return Err(e);
324 }
325 }
326 }
327 }
328
329 tokio::time::sleep(self.config.sleep_duration).await;
331 }
332 }
333
334 fn spawn_job(&self, conn: &'static DatabaseConnection, job_row: crate::db::JobRow) {
339 let permit = self.semaphore.clone();
342 let handlers = self.handlers.clone();
343 let tenant_scope = self.tenant_scope.clone();
344 let worker_id = self.worker_id.clone();
345 let shutdown = self.shutdown.clone();
346
347 tokio::spawn(async move {
348 if shutdown.load(Ordering::SeqCst) {
353 return;
354 }
355
356 let _permit = permit.acquire_owned().await.expect("semaphore closed");
358
359 if shutdown.load(Ordering::SeqCst) {
364 return;
365 }
366
367 let job_id = job_row.id;
368 let job_type = job_row.job_type.clone();
369 let tenant_id = job_row.tenant_id;
370 let attempts = job_row.attempts;
371 let max_retries = job_row.max_retries;
372
373 debug!(
374 job_id = %job_id,
375 job_type = %job_type,
376 attempts = attempts,
377 tenant_id = ?tenant_id,
378 worker_id = %worker_id,
379 "Executing job"
380 );
381
382 let handler = match handlers.get(&job_type) {
383 Some(h) => h.clone(),
384 None => {
385 warn!(job_type = %job_type, "No handler registered — releasing job for retry");
386 let available_at = Utc::now()
388 + chrono::Duration::from_std(Duration::from_secs(5)).unwrap_or_default();
389 crate::db::release_job(conn, job_id, attempts + 1, available_at)
390 .await
391 .ok();
392 return;
393 }
394 };
395
396 let result = AssertUnwindSafe(async move {
401 match (&tenant_scope, tenant_id) {
403 (Some(scope), Some(id)) => {
404 let fut = Box::pin(async move {
405 let (res, _delay) = handler(job_row.payload.clone(), attempts).await;
406 res
407 });
408 (scope.with_scope(id, fut).await, Duration::from_secs(5))
409 }
410 _ => handler(job_row.payload.clone(), attempts).await,
411 }
412 })
413 .catch_unwind()
414 .await;
415
416 match result {
417 Ok((Ok(()), _)) => {
419 debug!(job_id = %job_id, job_type = %job_type, "Job succeeded — deleting row");
420 crate::db::delete_job(conn, job_id).await.ok();
421 }
422
423 Ok((Err(e), retry_delay)) => {
425 error!(job_id = %job_id, job_type = %job_type, error = %e, "Job handler returned error");
426 handle_failure(
427 conn,
428 job_id,
429 attempts,
430 max_retries,
431 &e.to_string(),
432 retry_delay,
433 )
434 .await;
435 }
436
437 Err(_panic) => {
439 error!(job_id = %job_id, job_type = %job_type, "Job handler panicked — counting as failure");
440 let msg = "job handler panicked";
441 let delay = default_jitter_delay(attempts);
444 handle_failure(conn, job_id, attempts, max_retries, msg, delay).await;
445 }
446 }
447 });
448 }
449}
450
451async fn handle_failure(
453 conn: &'static DatabaseConnection,
454 job_id: i64,
455 attempts: u32,
456 max_retries: u32,
457 err_msg: &str,
458 retry_delay: Duration,
459) {
460 if attempts + 1 >= max_retries {
461 warn!(job_id = %job_id, attempts = attempts, "Job exhausted retries — parking as failed");
463 crate::db::fail_job(conn, job_id, err_msg).await.ok();
464 } else {
465 let available_at = Utc::now() + chrono::Duration::from_std(retry_delay).unwrap_or_default();
467 debug!(
468 job_id = %job_id,
469 retry_at = %available_at,
470 "Scheduling job retry"
471 );
472 crate::db::release_job(conn, job_id, attempts + 1, available_at)
473 .await
474 .ok();
475 }
476}
477
478fn default_jitter_delay(attempt: u32) -> Duration {
483 use rand::Rng;
484 let base_secs: u64 = 5;
485 let cap_secs: u64 = 15 * 60;
486 let max_delay = cap_secs.min(base_secs.saturating_mul(2u64.saturating_pow(attempt)));
487 let jitter = rand::thread_rng().gen_range(0..=max_delay);
488 Duration::from_secs(jitter)
489}
490
491struct AbortOnDrop(tokio::task::JoinHandle<()>);
496
497impl Drop for AbortOnDrop {
498 fn drop(&mut self) {
499 self.0.abort();
500 }
501}
502
503pub type Worker = WorkerLoop;
505
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Mutex;

    /// Boxed job future in the shape `TenantScopeProvider::with_scope` accepts.
    type BoxedJob = Pin<Box<dyn Future<Output = Result<(), Error>> + Send>>;

    /// Returns a flag together with a job future that sets the flag when run.
    fn tracked_job() -> (Arc<Mutex<bool>>, BoxedJob) {
        let ran = Arc::new(Mutex::new(false));
        let marker = Arc::clone(&ran);
        let job: BoxedJob = Box::pin(async move {
            *marker.lock().unwrap() = true;
            Ok(())
        });
        (ran, job)
    }

    /// Same dispatch rule the scope tests exercise: go through the provider only
    /// when both a provider and a tenant id are present.
    async fn dispatch(
        scope: &Option<Arc<dyn TenantScopeProvider>>,
        tenant_id: Option<i64>,
        job: BoxedJob,
    ) -> Result<(), Error> {
        match (scope, tenant_id) {
            (Some(provider), Some(id)) => provider.with_scope(id, job).await,
            _ => job.await,
        }
    }

    /// Test double that records every tenant id it is asked to scope.
    struct MockScopeProvider {
        called_with: Arc<Mutex<Vec<i64>>>,
        should_fail: bool,
    }

    impl MockScopeProvider {
        fn build(should_fail: bool) -> Self {
            Self {
                called_with: Arc::new(Mutex::new(Vec::new())),
                should_fail,
            }
        }

        /// Provider that forwards to the job future.
        fn new() -> Self {
            Self::build(false)
        }

        /// Provider that rejects every tenant without running the job.
        fn failing() -> Self {
            Self::build(true)
        }
    }

    #[async_trait]
    impl TenantScopeProvider for MockScopeProvider {
        async fn with_scope(
            &self,
            tenant_id: i64,
            f: Pin<Box<dyn Future<Output = Result<(), Error>> + Send>>,
        ) -> Result<(), Error> {
            self.called_with.lock().unwrap().push(tenant_id);
            if self.should_fail {
                Err(Error::tenant_not_found(tenant_id))
            } else {
                f.await
            }
        }
    }

    #[test]
    fn test_tenant_scope_provider_is_object_safe() {
        struct NoopProvider;

        #[async_trait]
        impl TenantScopeProvider for NoopProvider {
            async fn with_scope(
                &self,
                _tenant_id: i64,
                f: Pin<Box<dyn Future<Output = Result<(), Error>> + Send>>,
            ) -> Result<(), Error> {
                f.await
            }
        }

        // Compiles only if the trait can be used as a trait object.
        let _provider: Arc<dyn TenantScopeProvider> = Arc::new(NoopProvider);
    }

    #[test]
    fn test_worker_loop_new() {
        let worker = WorkerLoop::new(WorkerConfig::default());
        assert!(worker.tenant_scope.is_none());
        assert!(!worker.worker_id.is_empty());
    }

    #[test]
    fn test_with_tenant_scope_stores_provider() {
        let worker = WorkerLoop::new(WorkerConfig::default())
            .with_tenant_scope(Arc::new(MockScopeProvider::new()));
        assert!(worker.tenant_scope.is_some());
    }

    #[test]
    fn test_worker_without_scope_has_none_by_default() {
        let worker = WorkerLoop::new(WorkerConfig::default());
        assert!(worker.tenant_scope.is_none());
    }

    #[tokio::test]
    async fn test_mock_scope_provider_calls_future() {
        let provider = MockScopeProvider::new();
        let seen = Arc::clone(&provider.called_with);

        let outcome = provider.with_scope(42, Box::pin(async { Ok(()) })).await;

        assert!(outcome.is_ok());
        assert_eq!(seen.lock().unwrap().as_slice(), &[42]);
    }

    #[tokio::test]
    async fn test_mock_scope_provider_failure_returns_tenant_not_found() {
        let provider = MockScopeProvider::failing();

        let outcome = provider.with_scope(99, Box::pin(async { Ok(()) })).await;

        assert!(matches!(
            outcome,
            Err(Error::TenantNotFound { tenant_id: 99 })
        ));
    }

    #[tokio::test]
    async fn test_scope_dispatch_tenant_id_some_calls_with_scope() {
        let mock = MockScopeProvider::new();
        let seen = Arc::clone(&mock.called_with);
        let scope: Option<Arc<dyn TenantScopeProvider>> = Some(Arc::new(mock));
        let (job_ran, job) = tracked_job();

        let outcome = dispatch(&scope, Some(1), job).await;

        assert!(outcome.is_ok());
        assert_eq!(seen.lock().unwrap().as_slice(), &[1i64]);
        assert!(*job_ran.lock().unwrap(), "job future must have been called");
    }

    #[tokio::test]
    async fn test_scope_dispatch_tenant_id_none_skips_with_scope() {
        let mock = MockScopeProvider::new();
        let seen = Arc::clone(&mock.called_with);
        let scope: Option<Arc<dyn TenantScopeProvider>> = Some(Arc::new(mock));
        let (job_ran, job) = tracked_job();

        let outcome = dispatch(&scope, None, job).await;

        assert!(outcome.is_ok());
        assert!(
            seen.lock().unwrap().is_empty(),
            "with_scope must not be called when tenant_id is None"
        );
        assert!(
            *job_ran.lock().unwrap(),
            "job future must still run directly"
        );
    }

    #[tokio::test]
    async fn test_scope_dispatch_no_provider_runs_job_directly() {
        let scope: Option<Arc<dyn TenantScopeProvider>> = None;
        let (job_ran, job) = tracked_job();

        let outcome = dispatch(&scope, Some(1), job).await;

        assert!(outcome.is_ok());
        assert!(
            *job_ran.lock().unwrap(),
            "job must run directly without a provider"
        );
    }

    #[test]
    fn test_shutdown_sets_flag() {
        let worker = WorkerLoop::new(WorkerConfig::default());
        assert!(!worker.shutdown.load(Ordering::SeqCst));
        worker.shutdown();
        assert!(worker.shutdown.load(Ordering::SeqCst));
    }

    #[test]
    fn test_worker_config_visibility_timeout_default() {
        let config = WorkerConfig::default();
        assert_eq!(config.visibility_timeout, Duration::from_secs(300));
    }

    #[test]
    fn test_default_jitter_delay_bounds() {
        // (attempt, largest delay in seconds that attempt may produce)
        let bounds: [(u32, u64); 3] = [(0, 5), (3, 40), (30, 900)];
        for _ in 0..50 {
            for (attempt, max_secs) in bounds {
                assert!(default_jitter_delay(attempt).as_secs() <= max_secs);
            }
        }
    }
}