//! Worker loop for `kojin_core`: dequeues task messages from a broker and
//! executes them with bounded concurrency, middleware hooks, retries, and
//! graceful shutdown. (File: kojin_core/worker.rs)
1use std::sync::Arc;
2use std::time::Duration;
3
4use tokio::sync::Semaphore;
5use tokio_util::sync::CancellationToken;
6
7use crate::broker::Broker;
8use crate::context::TaskContext;
9use crate::error::KojinError;
10use crate::message::TaskMessage;
11use crate::middleware::Middleware;
12use crate::result_backend::ResultBackend;
13use crate::signature::Signature;
14
15use crate::registry::TaskRegistry;
16use crate::state::TaskState;
17
/// Worker configuration.
///
/// Controls how many tasks run at once, which queues are polled, and the
/// timeouts used while polling and while draining in-flight tasks at shutdown.
/// See the `Default` impl for the stock values.
#[derive(Debug, Clone)]
pub struct WorkerConfig {
    /// Max concurrent tasks. Backed by a semaphore with this many permits
    /// in `Worker::run`.
    pub concurrency: usize,
    /// Queue names to consume from, passed to `Broker::dequeue` as-is.
    pub queues: Vec<String>,
    /// How long to wait for in-flight tasks during shutdown before giving up
    /// (a warning is logged if the deadline is hit).
    pub shutdown_timeout: Duration,
    /// Dequeue poll timeout; a `None` result after this long simply loops.
    pub dequeue_timeout: Duration,
}
30
31impl Default for WorkerConfig {
32    fn default() -> Self {
33        Self {
34            concurrency: 10,
35            queues: vec!["default".to_string()],
36            shutdown_timeout: Duration::from_secs(30),
37            dequeue_timeout: Duration::from_secs(5),
38        }
39    }
40}
41
42/// The worker loop that dequeues and executes tasks.
/// The worker loop that dequeues and executes tasks.
///
/// Construct with [`Worker::new`], optionally chain the `with_*` builders,
/// then call [`Worker::run`]. All shared pieces are `Arc`-wrapped so they can
/// be cloned into per-task `tokio::spawn` closures.
pub struct Worker<B: Broker> {
    /// Message source/sink; also used for ack/nack/dead-letter and
    /// chain/chord enqueues.
    broker: Arc<B>,
    /// Maps task names to handlers via `TaskRegistry::dispatch`.
    registry: Arc<TaskRegistry>,
    /// Middleware pipeline run before/after each task and on error.
    middlewares: Arc<Vec<Box<dyn Middleware>>>,
    /// Shared context handed to every task handler.
    context: Arc<TaskContext>,
    config: WorkerConfig,
    /// Cooperative shutdown signal; exposed via `cancel_token()`.
    cancel: CancellationToken,
    /// Optional backend for storing task results and group/chord state.
    result_backend: Option<Arc<dyn ResultBackend>>,
    /// Optional periodic-task scheduler, only with the `cron` feature.
    #[cfg(feature = "cron")]
    cron_registry: Option<crate::cron::CronRegistry>,
}
54
impl<B: Broker> Worker<B> {
    /// Create a new worker with the given broker, task registry, shared context, and config.
    ///
    /// Prefer `KojinBuilder` (from the `kojin` facade crate) for ergonomic construction.
    /// Starts with no middleware, no result backend, and (with the `cron`
    /// feature) no cron registry; use the `with_*` builders to add them.
    pub fn new(
        broker: B,
        registry: TaskRegistry,
        context: TaskContext,
        config: WorkerConfig,
    ) -> Self {
        Self {
            broker: Arc::new(broker),
            registry: Arc::new(registry),
            middlewares: Arc::new(Vec::new()),
            context: Arc::new(context),
            config,
            cancel: CancellationToken::new(),
            result_backend: None,
            #[cfg(feature = "cron")]
            cron_registry: None,
        }
    }

    /// Set the result backend.
    ///
    /// Enables result storage after successful tasks, plus group/chord and
    /// chain-continuation handling in `execute_task`.
    pub fn with_result_backend(mut self, backend: Arc<dyn ResultBackend>) -> Self {
        self.result_backend = Some(backend);
        self
    }

    /// Set the cron registry for periodic task scheduling.
    ///
    /// `run` spawns a scheduler loop for it alongside the main dequeue loop.
    #[cfg(feature = "cron")]
    pub fn with_cron_registry(mut self, registry: crate::cron::CronRegistry) -> Self {
        self.cron_registry = Some(registry);
        self
    }

    /// Add middleware to the worker pipeline.
    ///
    /// # Panics
    /// Panics if called after the `middlewares` Arc has been cloned
    /// (i.e. after `run` has started), since `Arc::get_mut` then fails.
    pub fn with_middleware(mut self, middleware: impl Middleware) -> Self {
        Arc::get_mut(&mut self.middlewares)
            .expect("middleware can only be added before starting")
            .push(Box::new(middleware));
        self
    }

    /// Add a boxed middleware to the worker pipeline.
    ///
    /// # Panics
    /// Same constraint as [`Self::with_middleware`]: must be called before
    /// the worker is running.
    pub fn with_middleware_boxed(mut self, middleware: Box<dyn Middleware>) -> Self {
        Arc::get_mut(&mut self.middlewares)
            .expect("middleware can only be added before starting")
            .push(middleware);
        self
    }

    /// Get the cancellation token for external shutdown triggering.
    ///
    /// Cancelling the returned token makes `run` stop dequeuing and begin
    /// its graceful-drain phase.
    pub fn cancel_token(&self) -> CancellationToken {
        self.cancel.clone()
    }

    /// Run the worker loop until shutdown.
    ///
    /// Loop shape: acquire a concurrency permit FIRST, then dequeue — so the
    /// broker is only polled when there is capacity to execute. Each message
    /// is executed on its own spawned task that holds the permit until done.
    /// On cancellation, waits up to `shutdown_timeout` for in-flight tasks.
    pub async fn run(&self) {
        let semaphore = Arc::new(Semaphore::new(self.config.concurrency));

        // Spawn cron scheduler if configured. The scheduler loop receives the
        // cancel token and a 1s tick interval; the handle is held (not awaited)
        // for the duration of `run`.
        // NOTE(review): the cron task is detached rather than joined at
        // shutdown — it is expected to exit on its own when `cancel` fires.
        #[cfg(feature = "cron")]
        let _cron_handle = {
            if let Some(ref cron_registry) = self.cron_registry {
                let broker = self.broker.clone();
                let registry = cron_registry.clone();
                let cancel = self.cancel.clone();
                Some(tokio::spawn(async move {
                    crate::cron::scheduler_loop(
                        broker,
                        registry,
                        cancel,
                        std::time::Duration::from_secs(1),
                    )
                    .await;
                }))
            } else {
                None
            }
        };

        tracing::info!(
            concurrency = self.config.concurrency,
            queues = ?self.config.queues,
            "Worker starting"
        );

        loop {
            if self.cancel.is_cancelled() {
                break;
            }

            // Acquire a concurrency permit, racing against cancellation.
            let permit = tokio::select! {
                permit = semaphore.clone().acquire_owned() => {
                    match permit {
                        Ok(p) => p,
                        Err(_) => break, // Semaphore closed
                    }
                }
                _ = self.cancel.cancelled() => break,
            };

            // Dequeue a message, also racing against cancellation. The permit
            // is explicitly dropped on every path that doesn't spawn a task.
            let message = tokio::select! {
                result = self.broker.dequeue(&self.config.queues, self.config.dequeue_timeout) => {
                    match result {
                        Ok(Some(msg)) => msg,
                        Ok(None) => {
                            drop(permit);
                            continue; // Timeout, try again
                        }
                        Err(e) => {
                            tracing::error!(error = %e, "Failed to dequeue");
                            drop(permit);
                            // Back off briefly so a broken broker doesn't spin the loop.
                            tokio::time::sleep(Duration::from_secs(1)).await;
                            continue;
                        }
                    }
                }
                _ = self.cancel.cancelled() => {
                    drop(permit);
                    break;
                }
            };

            // Spawn task execution; clone the Arc-wrapped collaborators into
            // the task. The permit moves in and is held until execution ends.
            let broker = self.broker.clone();
            let registry = self.registry.clone();
            let middlewares = self.middlewares.clone();
            let context = self.context.clone();
            let result_backend = self.result_backend.clone();

            tokio::spawn(async move {
                let _permit = permit; // Hold permit until done
                execute_task(
                    broker,
                    registry,
                    middlewares,
                    context,
                    message,
                    result_backend,
                )
                .await;
            });
        }

        // Graceful shutdown: wait for in-flight tasks to complete. In-flight
        // count is inferred from outstanding semaphore permits.
        tracing::info!("Worker shutting down, waiting for in-flight tasks...");
        let drain_deadline = tokio::time::Instant::now() + self.config.shutdown_timeout;
        loop {
            // When all permits are available, no tasks are in-flight
            if semaphore.available_permits() == self.config.concurrency {
                break;
            }
            if tokio::time::Instant::now() >= drain_deadline {
                tracing::warn!("Shutdown timeout reached, some tasks may not have completed");
                break;
            }
            tokio::time::sleep(Duration::from_millis(100)).await;
        }

        tracing::info!("Worker stopped");
    }
}
221
/// Execute a single dequeued task message end-to-end.
///
/// Pipeline: ETA guard (re-schedule future-dated messages) → `before`
/// middleware → registry dispatch → on success: `after` middleware, ack,
/// result storage, group/chord completion, and chain continuation; on any
/// failure: [`handle_failure`] (retry or dead-letter).
///
/// Orchestration metadata travels in message headers:
/// `"kojin.group_results"` (chord callback input), `"kojin.chain_next"`
/// (remaining chain steps), `"kojin.chain_input"` (previous step's result).
async fn execute_task<B: Broker>(
    broker: Arc<B>,
    registry: Arc<TaskRegistry>,
    middlewares: Arc<Vec<Box<dyn Middleware>>>,
    context: Arc<TaskContext>,
    mut message: TaskMessage,
    result_backend: Option<Arc<dyn ResultBackend>>,
) {
    let task_id = message.id;
    let task_name = message.task_name.clone();

    // ETA guard: if message has a future eta, re-schedule it
    // NOTE(review): ack happens before schedule; if schedule fails after a
    // successful ack, the message is lost except for the error log — confirm
    // whether the broker offers an atomic re-schedule instead.
    if let Some(eta) = message.eta {
        if eta > chrono::Utc::now() {
            tracing::debug!(task_id = %task_id, %eta, "task eta is in the future — re-scheduling");
            if let Err(e) = broker.ack(&task_id).await {
                tracing::error!(task_id = %task_id, error = %e, "failed to ack before re-schedule");
            }
            if let Err(e) = broker.schedule(message, eta).await {
                tracing::error!(task_id = %task_id, error = %e, "failed to re-schedule task with future eta");
            }
            return;
        }
    }

    tracing::info!(task_id = %task_id, task_name = %task_name, "Executing task");
    message.state = TaskState::Started;

    // Run before middleware. A before() error aborts execution and is treated
    // as a task failure (retry/DLQ path).
    for mw in middlewares.iter() {
        if let Err(e) = mw.before(&message).await {
            tracing::error!(task_id = %task_id, error = %e, "Middleware before() failed");
            handle_failure(broker, middlewares, message, e).await;
            return;
        }
    }

    // Dispatch to handler
    match registry
        .dispatch(&task_name, message.payload.clone(), context)
        .await
    {
        Ok(result) => {
            // Run after middleware; after() errors are logged but do not fail
            // the task.
            for mw in middlewares.iter() {
                if let Err(e) = mw.after(&message, &result).await {
                    tracing::warn!(task_id = %task_id, error = %e, "Middleware after() failed");
                }
            }
            message.state = TaskState::Success;
            if let Err(e) = broker.ack(&task_id).await {
                tracing::error!(task_id = %task_id, error = %e, "Failed to ack task");
            }

            // Store result in backend. All orchestration below (groups,
            // chords, chains) requires a configured result backend.
            if let Some(ref backend) = result_backend {
                if let Err(e) = backend.store(&task_id, &result).await {
                    tracing::error!(task_id = %task_id, error = %e, "Failed to store result");
                }

                // Handle group completion
                if let Some(ref group_id) = message.group_id {
                    match backend
                        .complete_group_member(group_id, &task_id, &result)
                        .await
                    {
                        Ok(completed) => {
                            let total = message.group_total.unwrap_or(0);
                            tracing::debug!(
                                task_id = %task_id,
                                group_id = %group_id,
                                completed = completed,
                                total = total,
                                "Group member completed"
                            );
                            // If all group members are done and there's a chord callback, enqueue it
                            if completed == total {
                                if let Some(chord_callback) = message.chord_callback.take() {
                                    let mut callback_msg = *chord_callback;
                                    // Inject group results into the callback payload via header
                                    if let Ok(group_results) =
                                        backend.get_group_results(group_id).await
                                    {
                                        if let Ok(json) = serde_json::to_string(&group_results) {
                                            callback_msg
                                                .headers
                                                .insert("kojin.group_results".to_string(), json);
                                        }
                                    }
                                    if let Err(e) = broker.enqueue(callback_msg).await {
                                        tracing::error!(
                                            group_id = %group_id,
                                            error = %e,
                                            "Failed to enqueue chord callback"
                                        );
                                    } else {
                                        tracing::info!(
                                            group_id = %group_id,
                                            "Chord callback enqueued"
                                        );
                                    }
                                }
                            }
                        }
                        Err(e) => {
                            tracing::error!(
                                task_id = %task_id,
                                group_id = %group_id,
                                error = %e,
                                "Failed to complete group member"
                            );
                        }
                    }
                }

                // Handle chain continuation: the "kojin.chain_next" header
                // holds the remaining signatures; enqueue the first with this
                // task's result as its chain input, and re-serialize the rest.
                if let Some(chain_next_json) = message.headers.get("kojin.chain_next") {
                    match serde_json::from_str::<Vec<Signature>>(chain_next_json) {
                        Ok(remaining) if !remaining.is_empty() => {
                            let mut next_msg = remaining[0].clone().into_message();
                            // Pass current result as input to next task
                            if let Ok(json) = serde_json::to_string(&result) {
                                next_msg
                                    .headers
                                    .insert("kojin.chain_input".to_string(), json);
                            }
                            // Propagate correlation_id
                            if let Some(ref corr) = message.correlation_id {
                                next_msg.correlation_id = Some(corr.clone());
                            }
                            // Store remaining chain steps (skip first)
                            if remaining.len() > 1 {
                                let rest: Vec<Signature> = remaining[1..].to_vec();
                                if let Ok(json) = serde_json::to_string(&rest) {
                                    next_msg
                                        .headers
                                        .insert("kojin.chain_next".to_string(), json);
                                }
                            }
                            if let Err(e) = broker.enqueue(next_msg).await {
                                tracing::error!(
                                    task_id = %task_id,
                                    error = %e,
                                    "Failed to enqueue chain continuation"
                                );
                            } else {
                                tracing::info!(
                                    task_id = %task_id,
                                    remaining = remaining.len() - 1,
                                    "Chain continuation enqueued"
                                );
                            }
                        }
                        Ok(_) => {} // Empty remaining, chain done
                        Err(e) => {
                            tracing::error!(
                                task_id = %task_id,
                                error = %e,
                                "Failed to deserialize chain_next"
                            );
                        }
                    }
                }
            }

            tracing::info!(task_id = %task_id, task_name = %task_name, "Task completed successfully");
        }
        Err(e) => {
            tracing::error!(task_id = %task_id, task_name = %task_name, error = %e, "Task failed");
            handle_failure(broker, middlewares, message, e).await;
        }
    }
}
395
/// Handle a failed task: run `on_error` middleware, then either retry (nack
/// with backoff) or move the message to the dead-letter queue once
/// `max_retries` is exhausted.
///
/// Called from the spawned execution task, so the inline backoff sleep below
/// occupies that task's concurrency permit for its duration.
async fn handle_failure<B: Broker>(
    broker: Arc<B>,
    middlewares: Arc<Vec<Box<dyn Middleware>>>,
    mut message: TaskMessage,
    error: KojinError,
) {
    let task_id = message.id;

    // Run on_error middleware; errors here are logged but don't change the
    // retry/DLQ decision.
    for mw in middlewares.iter() {
        if let Err(e) = mw.on_error(&message, &error).await {
            tracing::warn!(task_id = %task_id, error = %e, "Middleware on_error() failed");
        }
    }

    // Retry or dead-letter
    if message.retries < message.max_retries {
        message.retries += 1;
        message.state = TaskState::Retry;
        message.updated_at = chrono::Utc::now();

        // retries was just incremented, so `retries - 1` is the zero-based
        // attempt index fed to the backoff strategy.
        let backoff_delay =
            crate::backoff::BackoffStrategy::default().delay_for(message.retries - 1);
        tracing::info!(
            task_id = %task_id,
            retry = message.retries,
            max_retries = message.max_retries,
            backoff = ?backoff_delay,
            "Retrying task"
        );

        // Simple sleep-based backoff (for MemoryBroker; Redis uses scheduled queue)
        tokio::time::sleep(backoff_delay).await;

        // NOTE(review): assumes broker.nack() re-queues the message for
        // another delivery — confirm against the Broker trait contract.
        if let Err(e) = broker.nack(message).await {
            tracing::error!(task_id = %task_id, error = %e, "Failed to nack/requeue task");
        }
    } else {
        message.state = TaskState::DeadLettered;
        message.updated_at = chrono::Utc::now();
        tracing::warn!(task_id = %task_id, "Max retries exceeded, moving to DLQ");

        if let Err(e) = broker.dead_letter(message).await {
            tracing::error!(task_id = %task_id, error = %e, "Failed to dead-letter task");
        }
    }
}
443
#[cfg(test)]
mod tests {
    //! Worker integration tests against the in-memory broker/backend.
    //!
    //! These tests are timing-based: they run the worker in a background
    //! task, sleep long enough for processing, then cancel and join. Shared
    //! static counters are compared as before/after deltas so tests stay
    //! correct under cargo's parallel test execution.
    use super::*;
    use crate::memory_broker::MemoryBroker;
    use crate::memory_result_backend::MemoryResultBackend;
    use crate::task::Task;
    use async_trait::async_trait;
    use serde::{Deserialize, Serialize};
    use std::sync::atomic::{AtomicU32, Ordering};

    // Task that increments a global counter; used to count executions.
    #[derive(Debug, Serialize, Deserialize)]
    struct CountTask;

    static COUNTER: AtomicU32 = AtomicU32::new(0);

    #[async_trait]
    impl Task for CountTask {
        const NAME: &'static str = "count";
        const MAX_RETRIES: u32 = 0;
        type Output = ();

        async fn run(&self, _ctx: &TaskContext) -> crate::error::TaskResult<Self::Output> {
            COUNTER.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
    }

    /// Three enqueued tasks are all executed by a 2-concurrency worker.
    #[tokio::test]
    async fn worker_processes_tasks() {
        let before = COUNTER.load(Ordering::SeqCst);

        let broker = MemoryBroker::new();
        let mut registry = TaskRegistry::new();
        registry.register::<CountTask>();

        // Enqueue 3 tasks
        for _ in 0..3 {
            broker
                .enqueue(TaskMessage::new(
                    "count",
                    "default",
                    serde_json::json!(null),
                ))
                .await
                .unwrap();
        }

        let config = WorkerConfig {
            concurrency: 2,
            queues: vec!["default".to_string()],
            shutdown_timeout: Duration::from_secs(5),
            dequeue_timeout: Duration::from_millis(100),
        };

        let worker = Worker::new(broker.clone(), registry, TaskContext::new(), config);
        let cancel = worker.cancel_token();

        // Run worker in background
        let handle = tokio::spawn(async move {
            worker.run().await;
        });

        // Wait for tasks to be processed
        tokio::time::sleep(Duration::from_millis(500)).await;
        cancel.cancel();
        handle.await.unwrap();

        let after = COUNTER.load(Ordering::SeqCst);
        assert_eq!(after - before, 3);
    }

    // Task that always fails; used to exercise the dead-letter path.
    #[derive(Debug, Serialize, Deserialize)]
    struct FailTask;

    #[async_trait]
    impl Task for FailTask {
        const NAME: &'static str = "fail_task";
        const MAX_RETRIES: u32 = 0;
        type Output = ();

        async fn run(&self, _ctx: &TaskContext) -> crate::error::TaskResult<Self::Output> {
            Err(KojinError::TaskFailed("intentional failure".into()))
        }
    }

    /// A task with max_retries=0 that fails lands in the DLQ immediately.
    #[tokio::test]
    async fn worker_dead_letters_after_max_retries() {
        let broker = MemoryBroker::new();
        let mut registry = TaskRegistry::new();
        registry.register::<FailTask>();

        broker
            .enqueue(
                TaskMessage::new("fail_task", "default", serde_json::json!(null))
                    .with_max_retries(0),
            )
            .await
            .unwrap();

        let config = WorkerConfig {
            concurrency: 1,
            queues: vec!["default".to_string()],
            shutdown_timeout: Duration::from_secs(5),
            dequeue_timeout: Duration::from_millis(100),
        };

        let worker = Worker::new(broker.clone(), registry, TaskContext::new(), config);
        let cancel = worker.cancel_token();

        let handle = tokio::spawn(async move {
            worker.run().await;
        });

        tokio::time::sleep(Duration::from_millis(500)).await;
        cancel.cancel();
        handle.await.unwrap();

        assert_eq!(broker.dlq_len("default").await.unwrap(), 1);
    }

    /// An idle worker cancelled immediately must exit well within the
    /// shutdown timeout (no hang in the dequeue select).
    #[tokio::test]
    async fn worker_graceful_shutdown() {
        let broker = MemoryBroker::new();
        let registry = TaskRegistry::new();

        let config = WorkerConfig {
            concurrency: 1,
            queues: vec!["default".to_string()],
            shutdown_timeout: Duration::from_secs(1),
            dequeue_timeout: Duration::from_millis(100),
        };

        let worker = Worker::new(broker, registry, TaskContext::new(), config);
        let cancel = worker.cancel_token();

        let handle = tokio::spawn(async move {
            worker.run().await;
        });

        // Cancel immediately
        cancel.cancel();
        // Should complete within shutdown timeout
        tokio::time::timeout(Duration::from_secs(3), handle)
            .await
            .expect("Worker should shutdown within timeout")
            .unwrap();
    }

    // Task with a real (i32) output; used to verify result storage.
    #[derive(Debug, Serialize, Deserialize)]
    struct AddTask {
        a: i32,
        b: i32,
    }

    #[async_trait]
    impl Task for AddTask {
        const NAME: &'static str = "add";
        const MAX_RETRIES: u32 = 0;
        type Output = i32;

        async fn run(&self, _ctx: &TaskContext) -> crate::error::TaskResult<Self::Output> {
            Ok(self.a + self.b)
        }
    }

    /// With a result backend configured, the task's output is retrievable
    /// by task id after execution.
    #[tokio::test]
    async fn worker_stores_results() {
        let broker = MemoryBroker::new();
        let backend = Arc::new(MemoryResultBackend::new());
        let mut registry = TaskRegistry::new();
        registry.register::<AddTask>();

        let msg = TaskMessage::new("add", "default", serde_json::json!({"a": 3, "b": 4}));
        let task_id = msg.id;
        broker.enqueue(msg).await.unwrap();

        let config = WorkerConfig {
            concurrency: 1,
            queues: vec!["default".to_string()],
            shutdown_timeout: Duration::from_secs(5),
            dequeue_timeout: Duration::from_millis(100),
        };

        let worker = Worker::new(broker.clone(), registry, TaskContext::new(), config)
            .with_result_backend(backend.clone());
        let cancel = worker.cancel_token();

        let handle = tokio::spawn(async move {
            worker.run().await;
        });

        tokio::time::sleep(Duration::from_millis(500)).await;
        cancel.cancel();
        handle.await.unwrap();

        let result = backend.get(&task_id).await.unwrap();
        assert_eq!(result, Some(serde_json::json!(7)));
    }

    static CHAIN_COUNTER: AtomicU32 = AtomicU32::new(0);

    // Counting task for the chain test; returns the post-increment value.
    #[derive(Debug, Serialize, Deserialize)]
    struct ChainCountTask;

    #[async_trait]
    impl Task for ChainCountTask {
        const NAME: &'static str = "chain_count";
        const MAX_RETRIES: u32 = 0;
        type Output = u32;

        async fn run(&self, _ctx: &TaskContext) -> crate::error::TaskResult<Self::Output> {
            let val = CHAIN_COUNTER.fetch_add(1, Ordering::SeqCst) + 1;
            Ok(val)
        }
    }

    /// A head message carrying two signatures in "kojin.chain_next" results
    /// in all three chain steps executing (head + two continuations).
    #[tokio::test]
    async fn worker_chain_continuation() {
        let broker = MemoryBroker::new();
        let backend = Arc::new(MemoryResultBackend::new());
        let mut registry = TaskRegistry::new();
        registry.register::<ChainCountTask>();

        let before = CHAIN_COUNTER.load(Ordering::SeqCst);

        // Build a chain: chain_count -> chain_count -> chain_count
        let remaining = vec![
            crate::signature::Signature::new("chain_count", "default", serde_json::json!(null)),
            crate::signature::Signature::new("chain_count", "default", serde_json::json!(null)),
        ];
        let mut msg =
            TaskMessage::new("chain_count", "default", serde_json::json!(null)).with_max_retries(0);
        msg.headers.insert(
            "kojin.chain_next".to_string(),
            serde_json::to_string(&remaining).unwrap(),
        );
        broker.enqueue(msg).await.unwrap();

        let config = WorkerConfig {
            concurrency: 1,
            queues: vec!["default".to_string()],
            shutdown_timeout: Duration::from_secs(5),
            dequeue_timeout: Duration::from_millis(100),
        };

        let worker = Worker::new(broker.clone(), registry, TaskContext::new(), config)
            .with_result_backend(backend);
        let cancel = worker.cancel_token();

        let handle = tokio::spawn(async move {
            worker.run().await;
        });

        // Longer sleep: each continuation must be enqueued and dequeued in turn.
        tokio::time::sleep(Duration::from_millis(1500)).await;
        cancel.cancel();
        handle.await.unwrap();

        // All 3 tasks in the chain should have executed
        let after = CHAIN_COUNTER.load(Ordering::SeqCst);
        assert_eq!(after - before, 3);
    }
}