use std::sync::Arc;
use std::time::Duration;

use tokio::sync::Semaphore;
use tokio_util::sync::CancellationToken;

use crate::broker::Broker;
use crate::context::TaskContext;
use crate::error::KojinError;
use crate::message::TaskMessage;
use crate::middleware::Middleware;
use crate::registry::TaskRegistry;
use crate::result_backend::ResultBackend;
use crate::signature::Signature;
use crate::state::TaskState;

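/// Runtime configuration for a [`Worker`].
///
/// A minimal construction sketch; the values are illustrative, not
/// recommendations, and the import path assumes this crate is named `kojin`
/// and re-exports `WorkerConfig` at the root:
///
/// ```ignore
/// use std::time::Duration;
/// use kojin::WorkerConfig; // hypothetical re-export path
///
/// let config = WorkerConfig {
///     concurrency: 4,
///     queues: vec!["default".to_string(), "emails".to_string()],
///     shutdown_timeout: Duration::from_secs(30),
///     dequeue_timeout: Duration::from_secs(5),
/// };
/// ```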
#[derive(Debug, Clone)]
pub struct WorkerConfig {
    /// Maximum number of tasks executed concurrently.
    pub concurrency: usize,
    /// Queues this worker consumes from.
    pub queues: Vec<String>,
    /// How long to wait for in-flight tasks to finish on shutdown.
    pub shutdown_timeout: Duration,
    /// How long a single broker dequeue call may block before returning empty.
    pub dequeue_timeout: Duration,
}

impl Default for WorkerConfig {
    fn default() -> Self {
        Self {
            concurrency: 10,
            queues: vec!["default".to_string()],
            shutdown_timeout: Duration::from_secs(30),
            dequeue_timeout: Duration::from_secs(5),
        }
    }
}

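/// A task worker: dequeues [`TaskMessage`]s from a [`Broker`], dispatches them
/// through a [`TaskRegistry`] with middleware hooks, and optionally persists
/// results to a [`ResultBackend`].
///
/// A minimal end-to-end sketch (the import paths and the `MemoryBroker`
/// re-export are assumptions; adapt to your crate layout):
///
/// ```ignore
/// use kojin::{MemoryBroker, TaskContext, TaskRegistry, Worker, WorkerConfig};
///
/// #[tokio::main]
/// async fn main() {
///     let worker = Worker::new(
///         MemoryBroker::new(),
///         TaskRegistry::new(),
///         TaskContext::new(),
///         WorkerConfig::default(),
///     );
///     // Runs until `worker.cancel_token()` is cancelled elsewhere.
///     worker.run().await;
/// }
/// ```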
pub struct Worker<B: Broker> {
    broker: Arc<B>,
    registry: Arc<TaskRegistry>,
    middlewares: Arc<Vec<Box<dyn Middleware>>>,
    context: Arc<TaskContext>,
    config: WorkerConfig,
    cancel: CancellationToken,
    result_backend: Option<Arc<dyn ResultBackend>>,
    #[cfg(feature = "cron")]
    cron_registry: Option<crate::cron::CronRegistry>,
}

impl<B: Broker> Worker<B> {
    /// Creates a worker from a broker, a task registry, a shared task
    /// context, and a configuration.
    pub fn new(
        broker: B,
        registry: TaskRegistry,
        context: TaskContext,
        config: WorkerConfig,
    ) -> Self {
        Self {
            broker: Arc::new(broker),
            registry: Arc::new(registry),
            middlewares: Arc::new(Vec::new()),
            context: Arc::new(context),
            config,
            cancel: CancellationToken::new(),
            result_backend: None,
            #[cfg(feature = "cron")]
            cron_registry: None,
        }
    }

    /// Attaches a result backend used to persist task results.
    pub fn with_result_backend(mut self, backend: Arc<dyn ResultBackend>) -> Self {
        self.result_backend = Some(backend);
        self
    }

    /// Attaches a cron registry; `run` will spawn a scheduler loop alongside
    /// the dequeue loop.
    #[cfg(feature = "cron")]
    pub fn with_cron_registry(mut self, registry: crate::cron::CronRegistry) -> Self {
        self.cron_registry = Some(registry);
        self
    }

    /// Adds a middleware to run around task execution. Middleware can only
    /// be added before the worker starts.
    pub fn with_middleware(mut self, middleware: impl Middleware) -> Self {
        Arc::get_mut(&mut self.middlewares)
            .expect("middleware can only be added before starting")
            .push(Box::new(middleware));
        self
    }

    /// Adds an already-boxed middleware; same restriction as
    /// [`Self::with_middleware`].
    pub fn with_middleware_boxed(mut self, middleware: Box<dyn Middleware>) -> Self {
        Arc::get_mut(&mut self.middlewares)
            .expect("middleware can only be added before starting")
            .push(middleware);
        self
    }

    /// Returns a clone of the cancellation token used to stop this worker.
    pub fn cancel_token(&self) -> CancellationToken {
        self.cancel.clone()
    }

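    /// Runs the worker until the cancellation token fires, then drains.
    ///
    /// Concurrency is bounded by a semaphore: a permit is acquired before each
    /// dequeue, handed to the spawned task, and released when that task
    /// finishes. After cancellation the worker waits up to
    /// `config.shutdown_timeout` for in-flight tasks to complete.
    ///
    /// A typical shutdown wiring sketch (the Ctrl-C handler is illustrative):
    ///
    /// ```ignore
    /// let cancel = worker.cancel_token();
    /// tokio::spawn(async move {
    ///     let _ = tokio::signal::ctrl_c().await;
    ///     cancel.cancel();
    /// });
    /// worker.run().await;
    /// ```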
    pub async fn run(&self) {
        let semaphore = Arc::new(Semaphore::new(self.config.concurrency));

        #[cfg(feature = "cron")]
        let _cron_handle = {
            if let Some(ref cron_registry) = self.cron_registry {
                let broker = self.broker.clone();
                let registry = cron_registry.clone();
                let cancel = self.cancel.clone();
                Some(tokio::spawn(async move {
                    crate::cron::scheduler_loop(
                        broker,
                        registry,
                        cancel,
                        Duration::from_secs(1),
                    )
                    .await;
                }))
            } else {
                None
            }
        };

        tracing::info!(
            concurrency = self.config.concurrency,
            queues = ?self.config.queues,
            "Worker starting"
        );

        loop {
            if self.cancel.is_cancelled() {
                break;
            }

            // Acquire a concurrency permit before dequeuing so we never pull
            // more work off the broker than we can run.
            let permit = tokio::select! {
                permit = semaphore.clone().acquire_owned() => {
                    match permit {
                        Ok(p) => p,
                        // The semaphore is closed; nothing left to do.
                        Err(_) => break,
                    }
                }
                _ = self.cancel.cancelled() => break,
            };

            let message = tokio::select! {
                result = self.broker.dequeue(&self.config.queues, self.config.dequeue_timeout) => {
                    match result {
                        Ok(Some(msg)) => msg,
                        // Dequeue timed out empty; release the permit and poll again.
                        Ok(None) => {
                            drop(permit);
                            continue;
                        }
                        Err(e) => {
                            tracing::error!(error = %e, "Failed to dequeue");
                            drop(permit);
                            tokio::time::sleep(Duration::from_secs(1)).await;
                            continue;
                        }
                    }
                }
                _ = self.cancel.cancelled() => {
                    drop(permit);
                    break;
                }
            };

            let broker = self.broker.clone();
            let registry = self.registry.clone();
            let middlewares = self.middlewares.clone();
            let context = self.context.clone();
            let result_backend = self.result_backend.clone();

            tokio::spawn(async move {
                // Hold the permit for the lifetime of the task so the
                // semaphore reflects in-flight work.
                let _permit = permit;
                execute_task(
                    broker,
                    registry,
                    middlewares,
                    context,
                    message,
                    result_backend,
                )
                .await;
            });
        }

        tracing::info!("Worker shutting down, waiting for in-flight tasks...");
        let drain_deadline = tokio::time::Instant::now() + self.config.shutdown_timeout;
        loop {
            // All permits returned means no tasks are in flight.
            if semaphore.available_permits() == self.config.concurrency {
                break;
            }
            if tokio::time::Instant::now() >= drain_deadline {
                tracing::warn!("Shutdown timeout reached, some tasks may not have completed");
                break;
            }
            tokio::time::sleep(Duration::from_millis(100)).await;
        }

        tracing::info!("Worker stopped");
    }
}

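/// Executes one task end to end: `before` middleware, registry dispatch,
/// `after` middleware, broker ack, then result storage plus group/chord and
/// chain follow-ups when a result backend is configured. Dispatch errors are
/// routed to [`handle_failure`].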
async fn execute_task<B: Broker>(
    broker: Arc<B>,
    registry: Arc<TaskRegistry>,
    middlewares: Arc<Vec<Box<dyn Middleware>>>,
    context: Arc<TaskContext>,
    mut message: TaskMessage,
    result_backend: Option<Arc<dyn ResultBackend>>,
) {
    let task_id = message.id;
    let task_name = message.task_name.clone();

    tracing::info!(task_id = %task_id, task_name = %task_name, "Executing task");
    message.state = TaskState::Started;

    for mw in middlewares.iter() {
        if let Err(e) = mw.before(&message).await {
            tracing::error!(task_id = %task_id, error = %e, "Middleware before() failed");
            handle_failure(broker, middlewares, message, e).await;
            return;
        }
    }

    match registry
        .dispatch(&task_name, message.payload.clone(), context)
        .await
    {
        Ok(result) => {
            for mw in middlewares.iter() {
                if let Err(e) = mw.after(&message, &result).await {
                    tracing::warn!(task_id = %task_id, error = %e, "Middleware after() failed");
                }
            }
            message.state = TaskState::Success;
            if let Err(e) = broker.ack(&task_id).await {
                tracing::error!(task_id = %task_id, error = %e, "Failed to ack task");
            }

            if let Some(ref backend) = result_backend {
                if let Err(e) = backend.store(&task_id, &result).await {
                    tracing::error!(task_id = %task_id, error = %e, "Failed to store result");
                }

                // Group/chord bookkeeping: record this member's completion
                // and, if it was the last member, fire the chord callback with
                // the collected group results.
                if let Some(ref group_id) = message.group_id {
                    match backend
                        .complete_group_member(group_id, &task_id, &result)
                        .await
                    {
                        Ok(completed) => {
                            let total = message.group_total.unwrap_or(0);
                            tracing::debug!(
                                task_id = %task_id,
                                group_id = %group_id,
                                completed = completed,
                                total = total,
                                "Group member completed"
                            );
                            if completed == total {
                                if let Some(chord_callback) = message.chord_callback.take() {
                                    let mut callback_msg = *chord_callback;
                                    if let Ok(group_results) =
                                        backend.get_group_results(group_id).await
                                    {
                                        if let Ok(json) = serde_json::to_string(&group_results) {
                                            callback_msg
                                                .headers
                                                .insert("kojin.group_results".to_string(), json);
                                        }
                                    }
                                    if let Err(e) = broker.enqueue(callback_msg).await {
                                        tracing::error!(
                                            group_id = %group_id,
                                            error = %e,
                                            "Failed to enqueue chord callback"
                                        );
                                    } else {
                                        tracing::info!(
                                            group_id = %group_id,
                                            "Chord callback enqueued"
                                        );
                                    }
                                }
                            }
                        }
                        Err(e) => {
                            tracing::error!(
                                task_id = %task_id,
                                group_id = %group_id,
                                error = %e,
                                "Failed to complete group member"
                            );
                        }
                    }
                }

                // Chain continuation: if the message carries remaining
                // signatures, enqueue the next one with this task's result as
                // its chain input.
                if let Some(chain_next_json) = message.headers.get("kojin.chain_next") {
                    match serde_json::from_str::<Vec<Signature>>(chain_next_json) {
                        Ok(remaining) if !remaining.is_empty() => {
                            let mut next_msg = remaining[0].clone().into_message();
                            if let Ok(json) = serde_json::to_string(&result) {
                                next_msg
                                    .headers
                                    .insert("kojin.chain_input".to_string(), json);
                            }
                            if let Some(ref corr) = message.correlation_id {
                                next_msg.correlation_id = Some(corr.clone());
                            }
                            if remaining.len() > 1 {
                                let rest: Vec<Signature> = remaining[1..].to_vec();
                                if let Ok(json) = serde_json::to_string(&rest) {
                                    next_msg
                                        .headers
                                        .insert("kojin.chain_next".to_string(), json);
                                }
                            }
                            if let Err(e) = broker.enqueue(next_msg).await {
                                tracing::error!(
                                    task_id = %task_id,
                                    error = %e,
                                    "Failed to enqueue chain continuation"
                                );
                            } else {
                                tracing::info!(
                                    task_id = %task_id,
                                    remaining = remaining.len() - 1,
                                    "Chain continuation enqueued"
                                );
                            }
                        }
                        // An empty list means the chain is finished.
                        Ok(_) => {}
                        Err(e) => {
                            tracing::error!(
                                task_id = %task_id,
                                error = %e,
                                "Failed to deserialize chain_next"
                            );
                        }
                    }
                }
            }

            tracing::info!(task_id = %task_id, task_name = %task_name, "Task completed successfully");
        }
        Err(e) => {
            tracing::error!(task_id = %task_id, task_name = %task_name, error = %e, "Task failed");
            handle_failure(broker, middlewares, message, e).await;
        }
    }
}

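/// Handles a failed task: runs `on_error` middleware, then either requeues the
/// message with backoff (while retries remain) or moves it to the dead-letter
/// queue.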
async fn handle_failure<B: Broker>(
    broker: Arc<B>,
    middlewares: Arc<Vec<Box<dyn Middleware>>>,
    mut message: TaskMessage,
    error: KojinError,
) {
    let task_id = message.id;

    for mw in middlewares.iter() {
        if let Err(e) = mw.on_error(&message, &error).await {
            tracing::warn!(task_id = %task_id, error = %e, "Middleware on_error() failed");
        }
    }

    if message.retries < message.max_retries {
        message.retries += 1;
        message.state = TaskState::Retry;
        message.updated_at = chrono::Utc::now();

        let backoff_delay =
            crate::backoff::BackoffStrategy::default().delay_for(message.retries - 1);
        tracing::info!(
            task_id = %task_id,
            retry = message.retries,
            max_retries = message.max_retries,
            backoff = ?backoff_delay,
            "Retrying task"
        );

        // The sleep runs inside the spawned task, so a concurrency permit is
        // held for the duration of the backoff.
        tokio::time::sleep(backoff_delay).await;

        if let Err(e) = broker.nack(message).await {
            tracing::error!(task_id = %task_id, error = %e, "Failed to nack/requeue task");
        }
    } else {
        message.state = TaskState::DeadLettered;
        message.updated_at = chrono::Utc::now();
        tracing::warn!(task_id = %task_id, "Max retries exceeded, moving to DLQ");

        if let Err(e) = broker.dead_letter(message).await {
            tracing::error!(task_id = %task_id, error = %e, "Failed to dead-letter task");
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::memory_broker::MemoryBroker;
    use crate::memory_result_backend::MemoryResultBackend;
    use crate::task::Task;
    use async_trait::async_trait;
    use serde::{Deserialize, Serialize};
    use std::sync::atomic::{AtomicU32, Ordering};

    #[derive(Debug, Serialize, Deserialize)]
    struct CountTask;

    static COUNTER: AtomicU32 = AtomicU32::new(0);

    #[async_trait]
    impl Task for CountTask {
        const NAME: &'static str = "count";
        const MAX_RETRIES: u32 = 0;
        type Output = ();

        async fn run(&self, _ctx: &TaskContext) -> crate::error::TaskResult<Self::Output> {
            COUNTER.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
    }

    #[tokio::test]
    async fn worker_processes_tasks() {
        let before = COUNTER.load(Ordering::SeqCst);

        let broker = MemoryBroker::new();
        let mut registry = TaskRegistry::new();
        registry.register::<CountTask>();

        for _ in 0..3 {
            broker
                .enqueue(TaskMessage::new(
                    "count",
                    "default",
                    serde_json::json!(null),
                ))
                .await
                .unwrap();
        }

        let config = WorkerConfig {
            concurrency: 2,
            queues: vec!["default".to_string()],
            shutdown_timeout: Duration::from_secs(5),
            dequeue_timeout: Duration::from_millis(100),
        };

        let worker = Worker::new(broker.clone(), registry, TaskContext::new(), config);
        let cancel = worker.cancel_token();

        let handle = tokio::spawn(async move {
            worker.run().await;
        });

        tokio::time::sleep(Duration::from_millis(500)).await;
        cancel.cancel();
        handle.await.unwrap();

        let after = COUNTER.load(Ordering::SeqCst);
        assert_eq!(after - before, 3);
    }

    #[derive(Debug, Serialize, Deserialize)]
    struct FailTask;

    #[async_trait]
    impl Task for FailTask {
        const NAME: &'static str = "fail_task";
        const MAX_RETRIES: u32 = 0;
        type Output = ();

        async fn run(&self, _ctx: &TaskContext) -> crate::error::TaskResult<Self::Output> {
            Err(KojinError::TaskFailed("intentional failure".into()))
        }
    }

    #[tokio::test]
    async fn worker_dead_letters_after_max_retries() {
        let broker = MemoryBroker::new();
        let mut registry = TaskRegistry::new();
        registry.register::<FailTask>();

        broker
            .enqueue(
                TaskMessage::new("fail_task", "default", serde_json::json!(null))
                    .with_max_retries(0),
            )
            .await
            .unwrap();

        let config = WorkerConfig {
            concurrency: 1,
            queues: vec!["default".to_string()],
            shutdown_timeout: Duration::from_secs(5),
            dequeue_timeout: Duration::from_millis(100),
        };

        let worker = Worker::new(broker.clone(), registry, TaskContext::new(), config);
        let cancel = worker.cancel_token();

        let handle = tokio::spawn(async move {
            worker.run().await;
        });

        tokio::time::sleep(Duration::from_millis(500)).await;
        cancel.cancel();
        handle.await.unwrap();

        assert_eq!(broker.dlq_len("default").await, 1);
    }

    #[tokio::test]
    async fn worker_graceful_shutdown() {
        let broker = MemoryBroker::new();
        let registry = TaskRegistry::new();

        let config = WorkerConfig {
            concurrency: 1,
            queues: vec!["default".to_string()],
            shutdown_timeout: Duration::from_secs(1),
            dequeue_timeout: Duration::from_millis(100),
        };

        let worker = Worker::new(broker, registry, TaskContext::new(), config);
        let cancel = worker.cancel_token();

        let handle = tokio::spawn(async move {
            worker.run().await;
        });

        cancel.cancel();
        tokio::time::timeout(Duration::from_secs(3), handle)
            .await
            .expect("Worker should shutdown within timeout")
            .unwrap();
    }

    #[derive(Debug, Serialize, Deserialize)]
    struct AddTask {
        a: i32,
        b: i32,
    }

    #[async_trait]
    impl Task for AddTask {
        const NAME: &'static str = "add";
        const MAX_RETRIES: u32 = 0;
        type Output = i32;

        async fn run(&self, _ctx: &TaskContext) -> crate::error::TaskResult<Self::Output> {
            Ok(self.a + self.b)
        }
    }

    #[tokio::test]
    async fn worker_stores_results() {
        let broker = MemoryBroker::new();
        let backend = Arc::new(MemoryResultBackend::new());
        let mut registry = TaskRegistry::new();
        registry.register::<AddTask>();

        let msg = TaskMessage::new("add", "default", serde_json::json!({"a": 3, "b": 4}));
        let task_id = msg.id;
        broker.enqueue(msg).await.unwrap();

        let config = WorkerConfig {
            concurrency: 1,
            queues: vec!["default".to_string()],
            shutdown_timeout: Duration::from_secs(5),
            dequeue_timeout: Duration::from_millis(100),
        };

        let worker = Worker::new(broker.clone(), registry, TaskContext::new(), config)
            .with_result_backend(backend.clone());
        let cancel = worker.cancel_token();

        let handle = tokio::spawn(async move {
            worker.run().await;
        });

        tokio::time::sleep(Duration::from_millis(500)).await;
        cancel.cancel();
        handle.await.unwrap();

        let result = backend.get(&task_id).await.unwrap();
        assert_eq!(result, Some(serde_json::json!(7)));
    }

    static CHAIN_COUNTER: AtomicU32 = AtomicU32::new(0);

    #[derive(Debug, Serialize, Deserialize)]
    struct ChainCountTask;

    #[async_trait]
    impl Task for ChainCountTask {
        const NAME: &'static str = "chain_count";
        const MAX_RETRIES: u32 = 0;
        type Output = u32;

        async fn run(&self, _ctx: &TaskContext) -> crate::error::TaskResult<Self::Output> {
            let val = CHAIN_COUNTER.fetch_add(1, Ordering::SeqCst) + 1;
            Ok(val)
        }
    }

    #[tokio::test]
    async fn worker_chain_continuation() {
        let broker = MemoryBroker::new();
        let backend = Arc::new(MemoryResultBackend::new());
        let mut registry = TaskRegistry::new();
        registry.register::<ChainCountTask>();

        let before = CHAIN_COUNTER.load(Ordering::SeqCst);

        // Two more signatures chained after the first message: three runs total.
        let remaining = vec![
            crate::signature::Signature::new("chain_count", "default", serde_json::json!(null)),
            crate::signature::Signature::new("chain_count", "default", serde_json::json!(null)),
        ];
        let mut msg = TaskMessage::new("chain_count", "default", serde_json::json!(null))
            .with_max_retries(0);
        msg.headers.insert(
            "kojin.chain_next".to_string(),
            serde_json::to_string(&remaining).unwrap(),
        );
        broker.enqueue(msg).await.unwrap();

        let config = WorkerConfig {
            concurrency: 1,
            queues: vec!["default".to_string()],
            shutdown_timeout: Duration::from_secs(5),
            dequeue_timeout: Duration::from_millis(100),
        };

        let worker = Worker::new(broker.clone(), registry, TaskContext::new(), config)
            .with_result_backend(backend);
        let cancel = worker.cancel_token();

        let handle = tokio::spawn(async move {
            worker.run().await;
        });

        tokio::time::sleep(Duration::from_millis(1500)).await;
        cancel.cancel();
        handle.await.unwrap();

        let after = CHAIN_COUNTER.load(Ordering::SeqCst);
        assert_eq!(after - before, 3);
    }
}