1use crate::{Error, Job, JobPayload, QueueConnection};
4use async_trait::async_trait;
5use std::collections::HashMap;
6use std::future::Future;
7use std::pin::Pin;
8use std::sync::Arc;
9use std::time::Duration;
10use tokio::sync::Semaphore;
11use tracing::{debug, error, info, warn};
12
/// Hook that lets the application run a job inside a tenant-specific context.
///
/// The worker only calls [`TenantScopeProvider::with_scope`] for jobs whose
/// payload carries a `tenant_id`, and only when a provider has been installed
/// with `Worker::with_tenant_scope`. Must be `Send + Sync` because the worker
/// shares it (behind an `Arc`) with spawned job tasks.
#[async_trait]
pub trait TenantScopeProvider: Send + Sync {
    /// Runs the job future `f` within the scope of `tenant_id` and returns its
    /// result.
    ///
    /// If an implementation returns an `Err` without awaiting `f` (for example
    /// because the tenant does not exist), the job body never runs and the
    /// worker handles the error like any other job failure (retry or move to
    /// the failed queue).
    async fn with_scope(
        &self,
        tenant_id: i64,
        f: Pin<Box<dyn Future<Output = Result<(), Error>> + Send>>,
    ) -> Result<(), Error>;
}
28
/// Settings for a [`Worker`].
#[derive(Debug, Clone)]
pub struct WorkerConfig {
    /// Queue names to poll, in priority order: each poll takes a job from the
    /// first non-empty queue in this list.
    pub queues: Vec<String>,
    /// Maximum number of jobs allowed to run concurrently (the number of
    /// permits in the worker's semaphore).
    pub max_jobs: usize,
    /// How long the worker sleeps after finding every queue empty.
    pub sleep_duration: Duration,
    /// When `true`, `Worker::run` returns the first error it hits while
    /// polling queues instead of logging it and continuing.
    pub stop_on_error: bool,
}
41
42impl Default for WorkerConfig {
43 fn default() -> Self {
44 Self {
45 queues: vec!["default".to_string()],
46 max_jobs: 10,
47 sleep_duration: Duration::from_secs(1),
48 stop_on_error: false,
49 }
50 }
51}
52
53impl WorkerConfig {
54 pub fn new(queues: Vec<String>) -> Self {
56 Self {
57 queues,
58 ..Default::default()
59 }
60 }
61
62 pub fn max_jobs(mut self, max: usize) -> Self {
64 self.max_jobs = max;
65 self
66 }
67}
68
/// Type-erased job handler: takes the job's serialized data (parsed as JSON by
/// handlers built in `Worker::register`) and returns a boxed future that runs
/// the job.
///
/// Reference-counted and `Send + Sync` so a handler can be cloned cheaply and
/// moved into spawned job tasks.
type JobHandler =
    Arc<dyn Fn(String) -> Pin<Box<dyn Future<Output = Result<(), Error>> + Send>> + Send + Sync>;
72
/// Pulls jobs from one or more queues and runs them with registered handlers.
pub struct Worker {
    /// Connection used to pop, release, fail, and migrate jobs.
    connection: QueueConnection,
    /// Queues to poll, concurrency limit, idle sleep, and error policy.
    config: WorkerConfig,
    /// Handlers keyed by the job type name (`std::any::type_name` of the
    /// registered job type).
    handlers: HashMap<String, JobHandler>,
    /// Limits concurrently running jobs; created with `config.max_jobs` permits.
    semaphore: Arc<Semaphore>,
    /// Signal used to ask `run` to stop; the `Arc` is shared with clones.
    shutdown: Arc<tokio::sync::Notify>,
    /// Optional provider that wraps jobs carrying a `tenant_id`.
    tenant_scope: Option<Arc<dyn TenantScopeProvider>>,
}
88
89impl Worker {
90 pub fn new(connection: QueueConnection, config: WorkerConfig) -> Self {
92 let semaphore = Arc::new(Semaphore::new(config.max_jobs));
93 Self {
94 connection,
95 config,
96 handlers: HashMap::new(),
97 semaphore,
98 shutdown: Arc::new(tokio::sync::Notify::new()),
99 tenant_scope: None,
100 }
101 }
102
103 pub fn with_tenant_scope(mut self, provider: Arc<dyn TenantScopeProvider>) -> Self {
109 self.tenant_scope = Some(provider);
110 self
111 }
112
113 pub fn register<J>(&mut self)
121 where
122 J: Job + serde::de::DeserializeOwned + 'static,
123 {
124 let type_name = std::any::type_name::<J>().to_string();
125
126 let handler: JobHandler = Arc::new(move |data: String| {
127 Box::pin(async move {
128 let job: J = serde_json::from_str(&data)
129 .map_err(|e| Error::DeserializationFailed(e.to_string()))?;
130 job.handle().await
131 })
132 });
133
134 self.handlers.insert(type_name, handler);
135 }
136
137 pub async fn run(&self) -> Result<(), Error> {
139 info!(
140 queues = ?self.config.queues,
141 max_jobs = self.config.max_jobs,
142 "Starting queue worker"
143 );
144
145 let conn = self.connection.clone();
147 let queues = self.config.queues.clone();
148 let shutdown = self.shutdown.clone();
149
150 tokio::spawn(async move {
151 loop {
152 tokio::select! {
153 _ = shutdown.notified() => break,
154 _ = tokio::time::sleep(Duration::from_secs(1)) => {
155 for queue in &queues {
156 if let Err(e) = conn.migrate_delayed(queue).await {
157 error!(queue = queue, error = %e, "Failed to migrate delayed jobs");
158 }
159 }
160 }
161 }
162 }
163 });
164
165 loop {
167 tokio::select! {
168 _ = self.shutdown.notified() => {
169 info!("Worker shutting down");
170 info!("Waiting for in-flight jobs to complete");
172 let _ = self.semaphore.acquire_many(self.config.max_jobs as u32).await;
173 return Ok(());
174 }
175 result = self.process_next() => {
176 if let Err(e) = result {
177 error!(error = %e, "Error processing job");
178 if self.config.stop_on_error {
179 return Err(e);
180 }
181 }
182 }
183 }
184 }
185 }
186
187 async fn process_next(&self) -> Result<(), Error> {
189 for queue in &self.config.queues {
191 if let Some(payload) = self.connection.pop_nowait(queue).await? {
192 self.process_job(payload).await?;
193 return Ok(());
194 }
195 }
196
197 tokio::time::sleep(self.config.sleep_duration).await;
199 Ok(())
200 }
201
202 async fn process_job(&self, payload: JobPayload) -> Result<(), Error> {
204 let permit = self.semaphore.clone().acquire_owned().await.unwrap();
205 let connection = self.connection.clone();
206 let handlers = self.handlers.clone();
207 let job_type = payload.job_type.clone();
208 let job_id = payload.id;
209 let tenant_scope = self.tenant_scope.clone();
210 let tenant_id = payload.tenant_id;
211
212 tokio::spawn(async move {
213 let _permit = permit; debug!(
216 job_id = %job_id,
217 job_type = &job_type,
218 tenant_id = ?tenant_id,
219 "Processing job"
220 );
221
222 let handler = match handlers.get(&job_type) {
223 Some(h) => h,
224 None => {
225 warn!(job_type = &job_type, "No handler registered for job type");
226 return;
227 }
228 };
229
230 let job_result = match (&tenant_scope, tenant_id) {
233 (Some(scope), Some(id)) => {
234 let job_fut = Box::pin(handler(payload.data.clone()));
235 scope.with_scope(id, job_fut).await
236 }
237 _ => handler(payload.data.clone()).await,
238 };
239
240 match job_result {
241 Ok(()) => {
242 info!(job_id = %job_id, job_type = &job_type, "Job completed successfully");
243 }
244 Err(e) => {
245 error!(job_id = %job_id, job_type = &job_type, error = %e, "Job failed");
246
247 if payload.has_exceeded_retries() {
248 warn!(job_id = %job_id, "Job exceeded max retries, moving to failed queue");
249 if let Err(e) = connection.fail(payload, &e).await {
250 error!(error = %e, "Failed to move job to failed queue");
251 }
252 } else {
253 let delay = Duration::from_secs(2u64.pow(payload.attempts));
254 if let Err(e) = connection.release(payload, delay).await {
255 error!(error = %e, "Failed to release job for retry");
256 }
257 }
258 }
259 }
260 });
261
262 Ok(())
263 }
264
265 pub fn shutdown(&self) {
267 self.shutdown.notify_waiters();
268 }
269}
270
271impl Clone for Worker {
273 fn clone(&self) -> Self {
274 Self {
275 connection: self.connection.clone(),
276 config: self.config.clone(),
277 handlers: HashMap::new(), semaphore: self.semaphore.clone(),
279 shutdown: self.shutdown.clone(),
280 tenant_scope: self.tenant_scope.clone(),
281 }
282 }
283}
284
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Mutex;

    /// Compile-time check: `TenantScopeProvider` can be used as a trait object
    /// (`Arc<dyn TenantScopeProvider>`), which is how `Worker` stores it.
    #[test]
    fn test_tenant_scope_provider_is_object_safe() {
        struct NoopProvider;

        #[async_trait]
        impl TenantScopeProvider for NoopProvider {
            async fn with_scope(
                &self,
                _tenant_id: i64,
                f: Pin<Box<dyn Future<Output = Result<(), Error>> + Send>>,
            ) -> Result<(), Error> {
                f.await
            }
        }

        let _provider: Arc<dyn TenantScopeProvider> = Arc::new(NoopProvider);
    }

    /// Test provider that records every `tenant_id` it is called with. When
    /// `should_fail` is set, it returns `Error::tenant_not_found` without
    /// running the job future.
    struct MockScopeProvider {
        // Shared so a test can keep a handle after moving the provider into an `Arc`.
        called_with: Arc<Mutex<Vec<i64>>>,
        should_fail: bool,
    }

    impl MockScopeProvider {
        /// Provider that records the call and then runs the job future.
        fn new() -> Self {
            Self {
                called_with: Arc::new(Mutex::new(Vec::new())),
                should_fail: false,
            }
        }

        /// Provider that records the call and then fails with "tenant not found".
        fn failing() -> Self {
            Self {
                called_with: Arc::new(Mutex::new(Vec::new())),
                should_fail: true,
            }
        }
    }

    #[async_trait]
    impl TenantScopeProvider for MockScopeProvider {
        async fn with_scope(
            &self,
            tenant_id: i64,
            f: Pin<Box<dyn Future<Output = Result<(), Error>> + Send>>,
        ) -> Result<(), Error> {
            // The guard is a temporary dropped at the end of this statement,
            // so it is not held across the `.await` below.
            self.called_with.lock().unwrap().push(tenant_id);
            if self.should_fail {
                return Err(Error::tenant_not_found(tenant_id));
            }
            f.await
        }
    }

    /// Builds a `Worker` backed by a fake server on an ephemeral local port.
    ///
    /// The fake server answers every line it receives with `+OK\r\n` (a RESP
    /// simple-string reply). This is presumably just enough for
    /// `QueueConnection::new` to succeed; the tests using it only inspect
    /// worker fields and never push or pop jobs. Connecting is capped at two
    /// seconds so a protocol mismatch fails fast instead of hanging.
    async fn make_worker() -> Worker {
        use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
        use tokio::net::TcpListener;

        // Port 0 lets the OS pick a free port.
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let port = listener.local_addr().unwrap().port();

        tokio::spawn(async move {
            // Accept connections until the listener errors; one task per connection.
            loop {
                let Ok((mut stream, _)) = listener.accept().await else {
                    break;
                };
                tokio::spawn(async move {
                    let (reader, mut writer) = stream.split();
                    let mut lines = BufReader::new(reader).lines();
                    // Reply "+OK" to every line regardless of content.
                    while let Ok(Some(_line)) = lines.next_line().await {
                        let _ = writer.write_all(b"+OK\r\n").await;
                    }
                });
            }
        });

        let config = crate::QueueConfig::new(format!("redis://127.0.0.1:{port}"));
        let conn = tokio::time::timeout(
            std::time::Duration::from_secs(2),
            crate::QueueConnection::new(config),
        )
        .await
        .expect("fake Redis connection timed out")
        .expect("fake Redis connection failed");

        Worker::new(conn, WorkerConfig::default())
    }

    /// `with_tenant_scope` stores the provider on the worker.
    #[tokio::test]
    async fn test_with_tenant_scope_stores_provider() {
        let worker = make_worker().await;
        let provider = Arc::new(MockScopeProvider::new());
        let worker = worker.with_tenant_scope(provider);
        assert!(
            worker.tenant_scope.is_some(),
            "tenant_scope must be Some after with_tenant_scope()"
        );
    }

    /// A freshly constructed worker has no tenant scope provider.
    #[tokio::test]
    async fn test_worker_without_scope_has_none_by_default() {
        let worker = make_worker().await;
        assert!(
            worker.tenant_scope.is_none(),
            "tenant_scope must be None by default"
        );
    }

    /// Cloning a worker keeps its tenant scope provider.
    #[tokio::test]
    async fn test_clone_preserves_tenant_scope() {
        let worker = make_worker().await;
        let provider: Arc<dyn TenantScopeProvider> = Arc::new(MockScopeProvider::new());
        let worker = worker.with_tenant_scope(provider);
        let cloned = worker.clone();
        assert!(
            cloned.tenant_scope.is_some(),
            "Clone must preserve tenant_scope"
        );
    }

    /// Cloning a worker without a provider does not invent one.
    #[tokio::test]
    async fn test_clone_without_scope_preserves_none() {
        let worker = make_worker().await;
        let cloned = worker.clone();
        assert!(
            cloned.tenant_scope.is_none(),
            "Clone must preserve None tenant_scope"
        );
    }

    /// Sanity check of the mock: it records the tenant id and runs the future.
    #[tokio::test]
    async fn test_mock_scope_provider_calls_future() {
        let provider = MockScopeProvider::new();
        let calls = provider.called_with.clone();

        let result = provider.with_scope(42, Box::pin(async { Ok(()) })).await;

        assert!(result.is_ok());
        assert_eq!(calls.lock().unwrap().as_slice(), &[42]);
    }

    /// Sanity check of the failing mock: it returns `TenantNotFound` for the id.
    #[tokio::test]
    async fn test_mock_scope_provider_failure_returns_tenant_not_found() {
        let provider = MockScopeProvider::failing();

        let result = provider.with_scope(99, Box::pin(async { Ok(()) })).await;

        assert!(matches!(
            result,
            Err(Error::TenantNotFound { tenant_id: 99 })
        ));
    }

    // The three `test_scope_dispatch_*` tests below re-implement the
    // `match (&tenant_scope, tenant_id)` dispatch from `Worker::process_job`
    // rather than calling it.
    // NOTE(review): a change to the dispatch in `process_job` will not be caught
    // here; keep them in sync.

    /// Provider present + `tenant_id` present: the job runs inside `with_scope`.
    #[tokio::test]
    async fn test_scope_dispatch_tenant_id_some_calls_with_scope() {
        let mock = MockScopeProvider::new();
        let calls = mock.called_with.clone();
        let provider: Arc<dyn TenantScopeProvider> = Arc::new(mock);

        let tenant_id: Option<i64> = Some(1);
        let tenant_scope: Option<Arc<dyn TenantScopeProvider>> = Some(provider);

        let job_ran = Arc::new(Mutex::new(false));
        let job_ran_clone = job_ran.clone();
        let job_fut = Box::pin(async move {
            *job_ran_clone.lock().unwrap() = true;
            Ok(())
        });

        let result = match (&tenant_scope, tenant_id) {
            (Some(scope), Some(id)) => scope.with_scope(id, job_fut).await,
            _ => job_fut.await,
        };

        assert!(result.is_ok());
        assert_eq!(calls.lock().unwrap().as_slice(), &[1i64]);
        assert!(*job_ran.lock().unwrap(), "job future must have been called");
    }

    /// Provider present but no `tenant_id`: the job runs directly, no scope.
    #[tokio::test]
    async fn test_scope_dispatch_tenant_id_none_skips_with_scope() {
        let mock = MockScopeProvider::new();
        let calls = mock.called_with.clone();
        let provider: Arc<dyn TenantScopeProvider> = Arc::new(mock);

        let tenant_id: Option<i64> = None;
        let tenant_scope: Option<Arc<dyn TenantScopeProvider>> = Some(provider);

        let job_ran = Arc::new(Mutex::new(false));
        let job_ran_clone = job_ran.clone();
        let job_fut = Box::pin(async move {
            *job_ran_clone.lock().unwrap() = true;
            Ok(())
        });

        let result = match (&tenant_scope, tenant_id) {
            (Some(scope), Some(id)) => scope.with_scope(id, job_fut).await,
            _ => job_fut.await,
        };

        assert!(result.is_ok());
        assert!(
            calls.lock().unwrap().is_empty(),
            "with_scope must not be called when tenant_id is None"
        );
        assert!(
            *job_ran.lock().unwrap(),
            "job future must still run directly"
        );
    }

    /// No provider but a `tenant_id`: the job runs directly.
    #[tokio::test]
    async fn test_scope_dispatch_no_provider_runs_job_directly() {
        let tenant_id: Option<i64> = Some(1);
        let tenant_scope: Option<Arc<dyn TenantScopeProvider>> = None;

        let job_ran = Arc::new(Mutex::new(false));
        let job_ran_clone = job_ran.clone();
        let job_fut = Box::pin(async move {
            *job_ran_clone.lock().unwrap() = true;
            Ok(())
        });

        let result = match (&tenant_scope, tenant_id) {
            (Some(scope), Some(id)) => scope.with_scope(id, job_fut).await,
            _ => job_fut.await,
        };

        assert!(result.is_ok());
        assert!(
            *job_ran.lock().unwrap(),
            "job must run directly without a provider"
        );
    }
}