use std::sync::Arc;
use std::time::Duration;

use tokio::sync::Semaphore;
use tokio_util::sync::CancellationToken;

use crate::broker::Broker;
use crate::context::TaskContext;
use crate::error::KojinError;
use crate::message::TaskMessage;
use crate::middleware::Middleware;
use crate::result_backend::ResultBackend;
use crate::signature::Signature;

use crate::registry::TaskRegistry;
use crate::state::TaskState;

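/// Tuning knobs for a [`Worker`]. `concurrency` caps the number of tasks run
/// in parallel (it sizes the worker's semaphore), `queues` lists the queue
/// names polled on each dequeue, `shutdown_timeout` bounds how long a
/// cancelled worker waits for in-flight tasks to drain, and `dequeue_timeout`
/// is the per-poll timeout passed to [`Broker::dequeue`].
///
/// A minimal sketch of overriding one field while keeping the defaults:
///
/// ```ignore
/// let config = WorkerConfig {
///     concurrency: 4,
///     ..WorkerConfig::default()
/// };
/// ```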
#[derive(Debug, Clone)]
pub struct WorkerConfig {
    pub concurrency: usize,
    pub queues: Vec<String>,
    pub shutdown_timeout: Duration,
    pub dequeue_timeout: Duration,
}

impl Default for WorkerConfig {
    fn default() -> Self {
        Self {
            concurrency: 10,
            queues: vec!["default".to_string()],
            shutdown_timeout: Duration::from_secs(30),
            dequeue_timeout: Duration::from_secs(5),
        }
    }
}

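/// A queue worker: repeatedly dequeues [`TaskMessage`]s from the broker and
/// executes them through the [`TaskRegistry`], bounded by a concurrency
/// semaphore. Middleware, a result backend, and (behind the `cron` feature)
/// a cron registry are attached with the builder-style `with_*` methods
/// before the worker is started.
///
/// A usage sketch mirroring the tests at the bottom of this file
/// (`MemoryBroker` stands in for any [`Broker`] implementation):
///
/// ```ignore
/// let worker = Worker::new(broker, registry, TaskContext::new(), WorkerConfig::default())
///     .with_result_backend(backend);
/// let cancel = worker.cancel_token();
/// let handle = tokio::spawn(async move { worker.run().await });
/// // ...later, request a graceful shutdown:
/// cancel.cancel();
/// handle.await.unwrap();
/// ```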
pub struct Worker<B: Broker> {
    broker: Arc<B>,
    registry: Arc<TaskRegistry>,
    middlewares: Arc<Vec<Box<dyn Middleware>>>,
    context: Arc<TaskContext>,
    config: WorkerConfig,
    cancel: CancellationToken,
    result_backend: Option<Arc<dyn ResultBackend>>,
    #[cfg(feature = "cron")]
    cron_registry: Option<crate::cron::CronRegistry>,
}

impl<B: Broker> Worker<B> {
    /// Creates a worker with no middleware and no result backend attached.
    pub fn new(
        broker: B,
        registry: TaskRegistry,
        context: TaskContext,
        config: WorkerConfig,
    ) -> Self {
        Self {
            broker: Arc::new(broker),
            registry: Arc::new(registry),
            middlewares: Arc::new(Vec::new()),
            context: Arc::new(context),
            config,
            cancel: CancellationToken::new(),
            result_backend: None,
            #[cfg(feature = "cron")]
            cron_registry: None,
        }
    }

    /// Attaches a result backend used to persist task results and group state.
    pub fn with_result_backend(mut self, backend: Arc<dyn ResultBackend>) -> Self {
        self.result_backend = Some(backend);
        self
    }

    /// Attaches a cron registry whose schedules are driven by the scheduler
    /// loop spawned in [`Worker::run`].
    #[cfg(feature = "cron")]
    pub fn with_cron_registry(mut self, registry: crate::cron::CronRegistry) -> Self {
        self.cron_registry = Some(registry);
        self
    }

    /// Registers a middleware. Panics if called after the worker has started,
    /// since the middleware list is shared via `Arc` once tasks are spawned.
    pub fn with_middleware(mut self, middleware: impl Middleware) -> Self {
        Arc::get_mut(&mut self.middlewares)
            .expect("middleware can only be added before starting")
            .push(Box::new(middleware));
        self
    }

    /// Boxed variant of [`Worker::with_middleware`] for pre-boxed trait objects.
    pub fn with_middleware_boxed(mut self, middleware: Box<dyn Middleware>) -> Self {
        Arc::get_mut(&mut self.middlewares)
            .expect("middleware can only be added before starting")
            .push(middleware);
        self
    }

    /// Returns a clone of the worker's cancellation token; cancelling it
    /// initiates a graceful shutdown of [`Worker::run`].
    pub fn cancel_token(&self) -> CancellationToken {
        self.cancel.clone()
    }

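    /// Runs the worker until the cancellation token fires. Each iteration
    /// acquires a concurrency permit, dequeues a message (or times out and
    /// polls again), and spawns a task that holds the permit while
    /// [`execute_task`] runs. After cancellation, the loop drains: it waits
    /// up to `shutdown_timeout` for every permit to be returned before
    /// logging that the worker has stopped.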
    pub async fn run(&self) {
        let semaphore = Arc::new(Semaphore::new(self.config.concurrency));

        // If a cron registry is attached, drive its schedules on a background
        // task that shares this worker's cancellation token.
        #[cfg(feature = "cron")]
        let _cron_handle = {
            if let Some(ref cron_registry) = self.cron_registry {
                let broker = self.broker.clone();
                let registry = cron_registry.clone();
                let cancel = self.cancel.clone();
                Some(tokio::spawn(async move {
                    crate::cron::scheduler_loop(
                        broker,
                        registry,
                        cancel,
                        std::time::Duration::from_secs(1),
                    )
                    .await;
                }))
            } else {
                None
            }
        };

        tracing::info!(
            concurrency = self.config.concurrency,
            queues = ?self.config.queues,
            "Worker starting"
        );

        loop {
            if self.cancel.is_cancelled() {
                break;
            }

            // Acquire a concurrency permit before dequeuing so the worker
            // never pulls more work than it can run; acquire_owned only
            // fails if the semaphore has been closed.
            let permit = tokio::select! {
                permit = semaphore.clone().acquire_owned() => {
                    match permit {
                        Ok(p) => p,
                        Err(_) => break,
                    }
                }
                _ = self.cancel.cancelled() => break,
            };

            let message = tokio::select! {
                result = self.broker.dequeue(&self.config.queues, self.config.dequeue_timeout) => {
                    match result {
                        Ok(Some(msg)) => msg,
                        Ok(None) => {
                            // Nothing to do: release the permit and poll again.
                            drop(permit);
                            continue;
                        }
                        Err(e) => {
                            tracing::error!(error = %e, "Failed to dequeue");
                            drop(permit);
                            tokio::time::sleep(Duration::from_secs(1)).await;
                            continue;
                        }
                    }
                }
                _ = self.cancel.cancelled() => {
                    drop(permit);
                    break;
                }
            };

            let broker = self.broker.clone();
            let registry = self.registry.clone();
            let middlewares = self.middlewares.clone();
            let context = self.context.clone();
            let result_backend = self.result_backend.clone();

            tokio::spawn(async move {
                // Hold the permit for the lifetime of the task so the
                // semaphore reflects in-flight work during shutdown draining.
                let _permit = permit;
                execute_task(
                    broker,
                    registry,
                    middlewares,
                    context,
                    message,
                    result_backend,
                )
                .await;
            });
        }

        tracing::info!("Worker shutting down, waiting for in-flight tasks...");
        let drain_deadline = tokio::time::Instant::now() + self.config.shutdown_timeout;
        loop {
            // All permits free means no task is still in flight.
            if semaphore.available_permits() == self.config.concurrency {
                break;
            }
            if tokio::time::Instant::now() >= drain_deadline {
                tracing::warn!("Shutdown timeout reached, some tasks may not have completed");
                break;
            }
            tokio::time::sleep(Duration::from_millis(100)).await;
        }

        tracing::info!("Worker stopped");
    }
}

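/// Executes one dequeued message end to end: messages whose `eta` is still in
/// the future are acked and handed back to the broker's scheduler; otherwise
/// the task runs through `before` middleware and the registry dispatch, and on
/// success is acked, its result stored, and any group/chord bookkeeping and
/// chain continuation performed. Failures (including `before` middleware
/// errors) are routed to [`handle_failure`].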
async fn execute_task<B: Broker>(
    broker: Arc<B>,
    registry: Arc<TaskRegistry>,
    middlewares: Arc<Vec<Box<dyn Middleware>>>,
    context: Arc<TaskContext>,
    mut message: TaskMessage,
    result_backend: Option<Arc<dyn ResultBackend>>,
) {
    let task_id = message.id;
    let task_name = message.task_name.clone();

    // Messages with a future ETA are acked off the work queue and handed back
    // to the broker's scheduler rather than executed early.
    if let Some(eta) = message.eta {
        if eta > chrono::Utc::now() {
            tracing::debug!(task_id = %task_id, %eta, "Task ETA is in the future, re-scheduling");
            if let Err(e) = broker.ack(&task_id).await {
                tracing::error!(task_id = %task_id, error = %e, "Failed to ack before re-schedule");
            }
            if let Err(e) = broker.schedule(message, eta).await {
                tracing::error!(task_id = %task_id, error = %e, "Failed to re-schedule task with future ETA");
            }
            return;
        }
    }

    tracing::info!(task_id = %task_id, task_name = %task_name, "Executing task");
    message.state = TaskState::Started;

    for mw in middlewares.iter() {
        if let Err(e) = mw.before(&message).await {
            tracing::error!(task_id = %task_id, error = %e, "Middleware before() failed");
            handle_failure(broker, middlewares, message, e).await;
            return;
        }
    }

    match registry
        .dispatch(&task_name, message.payload.clone(), context)
        .await
    {
        Ok(result) => {
            for mw in middlewares.iter() {
                if let Err(e) = mw.after(&message, &result).await {
                    tracing::warn!(task_id = %task_id, error = %e, "Middleware after() failed");
                }
            }
            message.state = TaskState::Success;
            if let Err(e) = broker.ack(&task_id).await {
                tracing::error!(task_id = %task_id, error = %e, "Failed to ack task");
            }

            if let Some(ref backend) = result_backend {
                if let Err(e) = backend.store(&task_id, &result).await {
                    tracing::error!(task_id = %task_id, error = %e, "Failed to store result");
                }

                // Group bookkeeping: record this member's completion and, once
                // the whole group is done, fire the chord callback (if any)
                // with the collected group results attached as a header.
                if let Some(ref group_id) = message.group_id {
                    match backend
                        .complete_group_member(group_id, &task_id, &result)
                        .await
                    {
                        Ok(completed) => {
                            let total = message.group_total.unwrap_or(0);
                            tracing::debug!(
                                task_id = %task_id,
                                group_id = %group_id,
                                completed = completed,
                                total = total,
                                "Group member completed"
                            );
                            if completed == total {
                                if let Some(chord_callback) = message.chord_callback.take() {
                                    let mut callback_msg = *chord_callback;
                                    if let Ok(group_results) =
                                        backend.get_group_results(group_id).await
                                    {
                                        if let Ok(json) = serde_json::to_string(&group_results) {
                                            callback_msg
                                                .headers
                                                .insert("kojin.group_results".to_string(), json);
                                        }
                                    }
                                    if let Err(e) = broker.enqueue(callback_msg).await {
                                        tracing::error!(
                                            group_id = %group_id,
                                            error = %e,
                                            "Failed to enqueue chord callback"
                                        );
                                    } else {
                                        tracing::info!(
                                            group_id = %group_id,
                                            "Chord callback enqueued"
                                        );
                                    }
                                }
                            }
                        }
                        Err(e) => {
                            tracing::error!(
                                task_id = %task_id,
                                group_id = %group_id,
                                error = %e,
                                "Failed to complete group member"
                            );
                        }
                    }
                }

                // Chain continuation: if the message carries a serialized list
                // of follow-up signatures, enqueue the next one with this
                // task's result attached as its chain input.
                if let Some(chain_next_json) = message.headers.get("kojin.chain_next") {
                    match serde_json::from_str::<Vec<Signature>>(chain_next_json) {
                        Ok(remaining) if !remaining.is_empty() => {
                            let mut next_msg = remaining[0].clone().into_message();
                            if let Ok(json) = serde_json::to_string(&result) {
                                next_msg
                                    .headers
                                    .insert("kojin.chain_input".to_string(), json);
                            }
                            if let Some(ref corr) = message.correlation_id {
                                next_msg.correlation_id = Some(corr.clone());
                            }
                            if remaining.len() > 1 {
                                let rest: Vec<Signature> = remaining[1..].to_vec();
                                if let Ok(json) = serde_json::to_string(&rest) {
                                    next_msg
                                        .headers
                                        .insert("kojin.chain_next".to_string(), json);
                                }
                            }
                            if let Err(e) = broker.enqueue(next_msg).await {
                                tracing::error!(
                                    task_id = %task_id,
                                    error = %e,
                                    "Failed to enqueue chain continuation"
                                );
                            } else {
                                tracing::info!(
                                    task_id = %task_id,
                                    remaining = remaining.len() - 1,
                                    "Chain continuation enqueued"
                                );
                            }
                        }
                        Ok(_) => {}
                        Err(e) => {
                            tracing::error!(
                                task_id = %task_id,
                                error = %e,
                                "Failed to deserialize chain_next"
                            );
                        }
                    }
                }
            }

            tracing::info!(task_id = %task_id, task_name = %task_name, "Task completed successfully");
        }
        Err(e) => {
            tracing::error!(task_id = %task_id, task_name = %task_name, error = %e, "Task failed");
            handle_failure(broker, middlewares, message, e).await;
        }
    }
}

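/// Shared failure path: runs `on_error` middleware, then either retries the
/// message with backoff (while `retries < max_retries`) or marks it
/// dead-lettered and hands it to the broker's DLQ.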
async fn handle_failure<B: Broker>(
    broker: Arc<B>,
    middlewares: Arc<Vec<Box<dyn Middleware>>>,
    mut message: TaskMessage,
    error: KojinError,
) {
    let task_id = message.id;

    for mw in middlewares.iter() {
        if let Err(e) = mw.on_error(&message, &error).await {
            tracing::warn!(task_id = %task_id, error = %e, "Middleware on_error() failed");
        }
    }

    if message.retries < message.max_retries {
        message.retries += 1;
        message.state = TaskState::Retry;
        message.updated_at = chrono::Utc::now();

        let backoff_delay =
            crate::backoff::BackoffStrategy::default().delay_for(message.retries - 1);
        tracing::info!(
            task_id = %task_id,
            retry = message.retries,
            max_retries = message.max_retries,
            backoff = ?backoff_delay,
            "Retrying task"
        );

        // Note: this sleep runs inside the spawned task, so the worker's
        // concurrency permit stays held for the backoff duration.
        tokio::time::sleep(backoff_delay).await;

        if let Err(e) = broker.nack(message).await {
            tracing::error!(task_id = %task_id, error = %e, "Failed to nack/requeue task");
        }
    } else {
        message.state = TaskState::DeadLettered;
        message.updated_at = chrono::Utc::now();
        tracing::warn!(task_id = %task_id, "Max retries exceeded, moving to DLQ");

        if let Err(e) = broker.dead_letter(message).await {
            tracing::error!(task_id = %task_id, error = %e, "Failed to dead-letter task");
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::memory_broker::MemoryBroker;
    use crate::memory_result_backend::MemoryResultBackend;
    use crate::task::Task;
    use async_trait::async_trait;
    use serde::{Deserialize, Serialize};
    use std::sync::atomic::{AtomicU32, Ordering};

    #[derive(Debug, Serialize, Deserialize)]
    struct CountTask;

    static COUNTER: AtomicU32 = AtomicU32::new(0);

    #[async_trait]
    impl Task for CountTask {
        const NAME: &'static str = "count";
        const MAX_RETRIES: u32 = 0;
        type Output = ();

        async fn run(&self, _ctx: &TaskContext) -> crate::error::TaskResult<Self::Output> {
            COUNTER.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
    }

    #[tokio::test]
    async fn worker_processes_tasks() {
        let before = COUNTER.load(Ordering::SeqCst);

        let broker = MemoryBroker::new();
        let mut registry = TaskRegistry::new();
        registry.register::<CountTask>();

        for _ in 0..3 {
            broker
                .enqueue(TaskMessage::new(
                    "count",
                    "default",
                    serde_json::json!(null),
                ))
                .await
                .unwrap();
        }

        let config = WorkerConfig {
            concurrency: 2,
            queues: vec!["default".to_string()],
            shutdown_timeout: Duration::from_secs(5),
            dequeue_timeout: Duration::from_millis(100),
        };

        let worker = Worker::new(broker.clone(), registry, TaskContext::new(), config);
        let cancel = worker.cancel_token();

        let handle = tokio::spawn(async move {
            worker.run().await;
        });

        tokio::time::sleep(Duration::from_millis(500)).await;
        cancel.cancel();
        handle.await.unwrap();

        let after = COUNTER.load(Ordering::SeqCst);
        assert_eq!(after - before, 3);
    }

    #[derive(Debug, Serialize, Deserialize)]
    struct FailTask;

    #[async_trait]
    impl Task for FailTask {
        const NAME: &'static str = "fail_task";
        const MAX_RETRIES: u32 = 0;
        type Output = ();

        async fn run(&self, _ctx: &TaskContext) -> crate::error::TaskResult<Self::Output> {
            Err(KojinError::TaskFailed("intentional failure".into()))
        }
    }

    #[tokio::test]
    async fn worker_dead_letters_after_max_retries() {
        let broker = MemoryBroker::new();
        let mut registry = TaskRegistry::new();
        registry.register::<FailTask>();

        broker
            .enqueue(
                TaskMessage::new("fail_task", "default", serde_json::json!(null))
                    .with_max_retries(0),
            )
            .await
            .unwrap();

        let config = WorkerConfig {
            concurrency: 1,
            queues: vec!["default".to_string()],
            shutdown_timeout: Duration::from_secs(5),
            dequeue_timeout: Duration::from_millis(100),
        };

        let worker = Worker::new(broker.clone(), registry, TaskContext::new(), config);
        let cancel = worker.cancel_token();

        let handle = tokio::spawn(async move {
            worker.run().await;
        });

        tokio::time::sleep(Duration::from_millis(500)).await;
        cancel.cancel();
        handle.await.unwrap();

        assert_eq!(broker.dlq_len("default").await.unwrap(), 1);
    }

    #[tokio::test]
    async fn worker_graceful_shutdown() {
        let broker = MemoryBroker::new();
        let registry = TaskRegistry::new();

        let config = WorkerConfig {
            concurrency: 1,
            queues: vec!["default".to_string()],
            shutdown_timeout: Duration::from_secs(1),
            dequeue_timeout: Duration::from_millis(100),
        };

        let worker = Worker::new(broker, registry, TaskContext::new(), config);
        let cancel = worker.cancel_token();

        let handle = tokio::spawn(async move {
            worker.run().await;
        });

        cancel.cancel();
        tokio::time::timeout(Duration::from_secs(3), handle)
            .await
            .expect("Worker should shut down within timeout")
            .unwrap();
    }

    #[derive(Debug, Serialize, Deserialize)]
    struct AddTask {
        a: i32,
        b: i32,
    }

    #[async_trait]
    impl Task for AddTask {
        const NAME: &'static str = "add";
        const MAX_RETRIES: u32 = 0;
        type Output = i32;

        async fn run(&self, _ctx: &TaskContext) -> crate::error::TaskResult<Self::Output> {
            Ok(self.a + self.b)
        }
    }

    #[tokio::test]
    async fn worker_stores_results() {
        let broker = MemoryBroker::new();
        let backend = Arc::new(MemoryResultBackend::new());
        let mut registry = TaskRegistry::new();
        registry.register::<AddTask>();

        let msg = TaskMessage::new("add", "default", serde_json::json!({"a": 3, "b": 4}));
        let task_id = msg.id;
        broker.enqueue(msg).await.unwrap();

        let config = WorkerConfig {
            concurrency: 1,
            queues: vec!["default".to_string()],
            shutdown_timeout: Duration::from_secs(5),
            dequeue_timeout: Duration::from_millis(100),
        };

        let worker = Worker::new(broker.clone(), registry, TaskContext::new(), config)
            .with_result_backend(backend.clone());
        let cancel = worker.cancel_token();

        let handle = tokio::spawn(async move {
            worker.run().await;
        });

        tokio::time::sleep(Duration::from_millis(500)).await;
        cancel.cancel();
        handle.await.unwrap();

        let result = backend.get(&task_id).await.unwrap();
        assert_eq!(result, Some(serde_json::json!(7)));
    }

    static CHAIN_COUNTER: AtomicU32 = AtomicU32::new(0);

    #[derive(Debug, Serialize, Deserialize)]
    struct ChainCountTask;

    #[async_trait]
    impl Task for ChainCountTask {
        const NAME: &'static str = "chain_count";
        const MAX_RETRIES: u32 = 0;
        type Output = u32;

        async fn run(&self, _ctx: &TaskContext) -> crate::error::TaskResult<Self::Output> {
            let val = CHAIN_COUNTER.fetch_add(1, Ordering::SeqCst) + 1;
            Ok(val)
        }
    }

    #[tokio::test]
    async fn worker_chain_continuation() {
        let broker = MemoryBroker::new();
        let backend = Arc::new(MemoryResultBackend::new());
        let mut registry = TaskRegistry::new();
        registry.register::<ChainCountTask>();

        let before = CHAIN_COUNTER.load(Ordering::SeqCst);

        // Head task plus two chained signatures: three executions in total.
        let remaining = vec![
            crate::signature::Signature::new("chain_count", "default", serde_json::json!(null)),
            crate::signature::Signature::new("chain_count", "default", serde_json::json!(null)),
        ];
        let mut msg =
            TaskMessage::new("chain_count", "default", serde_json::json!(null)).with_max_retries(0);
        msg.headers.insert(
            "kojin.chain_next".to_string(),
            serde_json::to_string(&remaining).unwrap(),
        );
        broker.enqueue(msg).await.unwrap();

        let config = WorkerConfig {
            concurrency: 1,
            queues: vec!["default".to_string()],
            shutdown_timeout: Duration::from_secs(5),
            dequeue_timeout: Duration::from_millis(100),
        };

        let worker = Worker::new(broker.clone(), registry, TaskContext::new(), config)
            .with_result_backend(backend);
        let cancel = worker.cancel_token();

        let handle = tokio::spawn(async move {
            worker.run().await;
        });

        tokio::time::sleep(Duration::from_millis(1500)).await;
        cancel.cancel();
        handle.await.unwrap();

        let after = CHAIN_COUNTER.load(Ordering::SeqCst);
        assert_eq!(after - before, 3);
    }
}