// kojin_core/worker.rs

1use std::sync::Arc;
2use std::time::Duration;
3
4use tokio::sync::Semaphore;
5use tokio_util::sync::CancellationToken;
6
7use crate::broker::Broker;
8use crate::context::TaskContext;
9use crate::error::KojinError;
10use crate::message::TaskMessage;
11use crate::middleware::Middleware;
12use crate::result_backend::ResultBackend;
13use crate::signature::Signature;
14
15use crate::registry::TaskRegistry;
16use crate::state::TaskState;
17
/// Tunable settings that control how a worker consumes and runs tasks.
#[derive(Debug, Clone)]
pub struct WorkerConfig {
    /// Upper bound on the number of tasks executing at the same time.
    pub concurrency: usize,
    /// Names of the queues this worker dequeues from.
    pub queues: Vec<String>,
    /// Maximum time spent waiting for in-flight tasks once shutdown begins.
    pub shutdown_timeout: Duration,
    /// Timeout handed to each broker dequeue call.
    pub dequeue_timeout: Duration,
}
30
31impl Default for WorkerConfig {
32    fn default() -> Self {
33        Self {
34            concurrency: 10,
35            queues: vec!["default".to_string()],
36            shutdown_timeout: Duration::from_secs(30),
37            dequeue_timeout: Duration::from_secs(5),
38        }
39    }
40}
41
42/// The worker loop that dequeues and executes tasks.
43pub struct Worker<B: Broker> {
44    broker: Arc<B>,
45    registry: Arc<TaskRegistry>,
46    middlewares: Arc<Vec<Box<dyn Middleware>>>,
47    context: Arc<TaskContext>,
48    config: WorkerConfig,
49    cancel: CancellationToken,
50    result_backend: Option<Arc<dyn ResultBackend>>,
51    #[cfg(feature = "cron")]
52    cron_registry: Option<crate::cron::CronRegistry>,
53}
54
55impl<B: Broker> Worker<B> {
56    /// Create a new worker with the given broker, task registry, shared context, and config.
57    ///
58    /// Prefer `KojinBuilder` (from the `kojin` facade crate) for ergonomic construction.
59    pub fn new(
60        broker: B,
61        registry: TaskRegistry,
62        context: TaskContext,
63        config: WorkerConfig,
64    ) -> Self {
65        Self {
66            broker: Arc::new(broker),
67            registry: Arc::new(registry),
68            middlewares: Arc::new(Vec::new()),
69            context: Arc::new(context),
70            config,
71            cancel: CancellationToken::new(),
72            result_backend: None,
73            #[cfg(feature = "cron")]
74            cron_registry: None,
75        }
76    }
77
78    /// Set the result backend.
79    pub fn with_result_backend(mut self, backend: Arc<dyn ResultBackend>) -> Self {
80        self.result_backend = Some(backend);
81        self
82    }
83
84    /// Set the cron registry for periodic task scheduling.
85    #[cfg(feature = "cron")]
86    pub fn with_cron_registry(mut self, registry: crate::cron::CronRegistry) -> Self {
87        self.cron_registry = Some(registry);
88        self
89    }
90
91    /// Add middleware to the worker pipeline.
92    pub fn with_middleware(mut self, middleware: impl Middleware) -> Self {
93        Arc::get_mut(&mut self.middlewares)
94            .expect("middleware can only be added before starting")
95            .push(Box::new(middleware));
96        self
97    }
98
99    /// Add a boxed middleware to the worker pipeline.
100    pub fn with_middleware_boxed(mut self, middleware: Box<dyn Middleware>) -> Self {
101        Arc::get_mut(&mut self.middlewares)
102            .expect("middleware can only be added before starting")
103            .push(middleware);
104        self
105    }
106
107    /// Get the cancellation token for external shutdown triggering.
108    pub fn cancel_token(&self) -> CancellationToken {
109        self.cancel.clone()
110    }
111
112    /// Run the worker loop until shutdown.
113    pub async fn run(&self) {
114        let semaphore = Arc::new(Semaphore::new(self.config.concurrency));
115
116        // Spawn cron scheduler if configured
117        #[cfg(feature = "cron")]
118        let _cron_handle = {
119            if let Some(ref cron_registry) = self.cron_registry {
120                let broker = self.broker.clone();
121                let registry = cron_registry.clone();
122                let cancel = self.cancel.clone();
123                Some(tokio::spawn(async move {
124                    crate::cron::scheduler_loop(
125                        broker,
126                        registry,
127                        cancel,
128                        std::time::Duration::from_secs(1),
129                    )
130                    .await;
131                }))
132            } else {
133                None
134            }
135        };
136
137        tracing::info!(
138            concurrency = self.config.concurrency,
139            queues = ?self.config.queues,
140            "Worker starting"
141        );
142
143        loop {
144            if self.cancel.is_cancelled() {
145                break;
146            }
147
148            // Acquire a concurrency permit
149            let permit = tokio::select! {
150                permit = semaphore.clone().acquire_owned() => {
151                    match permit {
152                        Ok(p) => p,
153                        Err(_) => break, // Semaphore closed
154                    }
155                }
156                _ = self.cancel.cancelled() => break,
157            };
158
159            // Dequeue a message
160            let message = tokio::select! {
161                result = self.broker.dequeue(&self.config.queues, self.config.dequeue_timeout) => {
162                    match result {
163                        Ok(Some(msg)) => msg,
164                        Ok(None) => {
165                            drop(permit);
166                            continue; // Timeout, try again
167                        }
168                        Err(e) => {
169                            tracing::error!(error = %e, "Failed to dequeue");
170                            drop(permit);
171                            tokio::time::sleep(Duration::from_secs(1)).await;
172                            continue;
173                        }
174                    }
175                }
176                _ = self.cancel.cancelled() => {
177                    drop(permit);
178                    break;
179                }
180            };
181
182            // Spawn task execution
183            let broker = self.broker.clone();
184            let registry = self.registry.clone();
185            let middlewares = self.middlewares.clone();
186            let context = self.context.clone();
187            let result_backend = self.result_backend.clone();
188
189            tokio::spawn(async move {
190                let _permit = permit; // Hold permit until done
191                execute_task(
192                    broker,
193                    registry,
194                    middlewares,
195                    context,
196                    message,
197                    result_backend,
198                )
199                .await;
200            });
201        }
202
203        // Graceful shutdown: wait for in-flight tasks to complete
204        tracing::info!("Worker shutting down, waiting for in-flight tasks...");
205        let drain_deadline = tokio::time::Instant::now() + self.config.shutdown_timeout;
206        loop {
207            // When all permits are available, no tasks are in-flight
208            if semaphore.available_permits() == self.config.concurrency {
209                break;
210            }
211            if tokio::time::Instant::now() >= drain_deadline {
212                tracing::warn!("Shutdown timeout reached, some tasks may not have completed");
213                break;
214            }
215            tokio::time::sleep(Duration::from_millis(100)).await;
216        }
217
218        tracing::info!("Worker stopped");
219    }
220}
221
222async fn execute_task<B: Broker>(
223    broker: Arc<B>,
224    registry: Arc<TaskRegistry>,
225    middlewares: Arc<Vec<Box<dyn Middleware>>>,
226    context: Arc<TaskContext>,
227    mut message: TaskMessage,
228    result_backend: Option<Arc<dyn ResultBackend>>,
229) {
230    let task_id = message.id;
231    let task_name = message.task_name.clone();
232
233    tracing::info!(task_id = %task_id, task_name = %task_name, "Executing task");
234    message.state = TaskState::Started;
235
236    // Run before middleware
237    for mw in middlewares.iter() {
238        if let Err(e) = mw.before(&message).await {
239            tracing::error!(task_id = %task_id, error = %e, "Middleware before() failed");
240            handle_failure(broker, middlewares, message, e).await;
241            return;
242        }
243    }
244
245    // Dispatch to handler
246    match registry
247        .dispatch(&task_name, message.payload.clone(), context)
248        .await
249    {
250        Ok(result) => {
251            // Run after middleware
252            for mw in middlewares.iter() {
253                if let Err(e) = mw.after(&message, &result).await {
254                    tracing::warn!(task_id = %task_id, error = %e, "Middleware after() failed");
255                }
256            }
257            message.state = TaskState::Success;
258            if let Err(e) = broker.ack(&task_id).await {
259                tracing::error!(task_id = %task_id, error = %e, "Failed to ack task");
260            }
261
262            // Store result in backend
263            if let Some(ref backend) = result_backend {
264                if let Err(e) = backend.store(&task_id, &result).await {
265                    tracing::error!(task_id = %task_id, error = %e, "Failed to store result");
266                }
267
268                // Handle group completion
269                if let Some(ref group_id) = message.group_id {
270                    match backend
271                        .complete_group_member(group_id, &task_id, &result)
272                        .await
273                    {
274                        Ok(completed) => {
275                            let total = message.group_total.unwrap_or(0);
276                            tracing::debug!(
277                                task_id = %task_id,
278                                group_id = %group_id,
279                                completed = completed,
280                                total = total,
281                                "Group member completed"
282                            );
283                            // If all group members are done and there's a chord callback, enqueue it
284                            if completed == total {
285                                if let Some(chord_callback) = message.chord_callback.take() {
286                                    let mut callback_msg = *chord_callback;
287                                    // Inject group results into the callback payload via header
288                                    if let Ok(group_results) =
289                                        backend.get_group_results(group_id).await
290                                    {
291                                        if let Ok(json) = serde_json::to_string(&group_results) {
292                                            callback_msg
293                                                .headers
294                                                .insert("kojin.group_results".to_string(), json);
295                                        }
296                                    }
297                                    if let Err(e) = broker.enqueue(callback_msg).await {
298                                        tracing::error!(
299                                            group_id = %group_id,
300                                            error = %e,
301                                            "Failed to enqueue chord callback"
302                                        );
303                                    } else {
304                                        tracing::info!(
305                                            group_id = %group_id,
306                                            "Chord callback enqueued"
307                                        );
308                                    }
309                                }
310                            }
311                        }
312                        Err(e) => {
313                            tracing::error!(
314                                task_id = %task_id,
315                                group_id = %group_id,
316                                error = %e,
317                                "Failed to complete group member"
318                            );
319                        }
320                    }
321                }
322
323                // Handle chain continuation
324                if let Some(chain_next_json) = message.headers.get("kojin.chain_next") {
325                    match serde_json::from_str::<Vec<Signature>>(chain_next_json) {
326                        Ok(remaining) if !remaining.is_empty() => {
327                            let mut next_msg = remaining[0].clone().into_message();
328                            // Pass current result as input to next task
329                            if let Ok(json) = serde_json::to_string(&result) {
330                                next_msg
331                                    .headers
332                                    .insert("kojin.chain_input".to_string(), json);
333                            }
334                            // Propagate correlation_id
335                            if let Some(ref corr) = message.correlation_id {
336                                next_msg.correlation_id = Some(corr.clone());
337                            }
338                            // Store remaining chain steps (skip first)
339                            if remaining.len() > 1 {
340                                let rest: Vec<Signature> = remaining[1..].to_vec();
341                                if let Ok(json) = serde_json::to_string(&rest) {
342                                    next_msg
343                                        .headers
344                                        .insert("kojin.chain_next".to_string(), json);
345                                }
346                            }
347                            if let Err(e) = broker.enqueue(next_msg).await {
348                                tracing::error!(
349                                    task_id = %task_id,
350                                    error = %e,
351                                    "Failed to enqueue chain continuation"
352                                );
353                            } else {
354                                tracing::info!(
355                                    task_id = %task_id,
356                                    remaining = remaining.len() - 1,
357                                    "Chain continuation enqueued"
358                                );
359                            }
360                        }
361                        Ok(_) => {} // Empty remaining, chain done
362                        Err(e) => {
363                            tracing::error!(
364                                task_id = %task_id,
365                                error = %e,
366                                "Failed to deserialize chain_next"
367                            );
368                        }
369                    }
370                }
371            }
372
373            tracing::info!(task_id = %task_id, task_name = %task_name, "Task completed successfully");
374        }
375        Err(e) => {
376            tracing::error!(task_id = %task_id, task_name = %task_name, error = %e, "Task failed");
377            handle_failure(broker, middlewares, message, e).await;
378        }
379    }
380}
381
382async fn handle_failure<B: Broker>(
383    broker: Arc<B>,
384    middlewares: Arc<Vec<Box<dyn Middleware>>>,
385    mut message: TaskMessage,
386    error: KojinError,
387) {
388    let task_id = message.id;
389
390    // Run on_error middleware
391    for mw in middlewares.iter() {
392        if let Err(e) = mw.on_error(&message, &error).await {
393            tracing::warn!(task_id = %task_id, error = %e, "Middleware on_error() failed");
394        }
395    }
396
397    // Retry or dead-letter
398    if message.retries < message.max_retries {
399        message.retries += 1;
400        message.state = TaskState::Retry;
401        message.updated_at = chrono::Utc::now();
402
403        let backoff_delay =
404            crate::backoff::BackoffStrategy::default().delay_for(message.retries - 1);
405        tracing::info!(
406            task_id = %task_id,
407            retry = message.retries,
408            max_retries = message.max_retries,
409            backoff = ?backoff_delay,
410            "Retrying task"
411        );
412
413        // Simple sleep-based backoff (for MemoryBroker; Redis uses scheduled queue)
414        tokio::time::sleep(backoff_delay).await;
415
416        if let Err(e) = broker.nack(message).await {
417            tracing::error!(task_id = %task_id, error = %e, "Failed to nack/requeue task");
418        }
419    } else {
420        message.state = TaskState::DeadLettered;
421        message.updated_at = chrono::Utc::now();
422        tracing::warn!(task_id = %task_id, "Max retries exceeded, moving to DLQ");
423
424        if let Err(e) = broker.dead_letter(message).await {
425            tracing::error!(task_id = %task_id, error = %e, "Failed to dead-letter task");
426        }
427    }
428}
429
#[cfg(test)]
mod tests {
    use super::*;
    use crate::memory_broker::MemoryBroker;
    use crate::memory_result_backend::MemoryResultBackend;
    use crate::task::Task;
    use async_trait::async_trait;
    use serde::{Deserialize, Serialize};
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Builds a config for the `default` queue with a short (100 ms) dequeue timeout
    /// so the worker loop notices cancellation quickly.
    fn test_config(concurrency: usize, shutdown_timeout: Duration) -> WorkerConfig {
        WorkerConfig {
            concurrency,
            queues: vec!["default".to_string()],
            shutdown_timeout,
            dequeue_timeout: Duration::from_millis(100),
        }
    }

    /// Runs `worker` in the background for `run_time`, then cancels it and waits
    /// for the run loop to return.
    async fn run_worker_for(worker: Worker<MemoryBroker>, run_time: Duration) {
        let cancel = worker.cancel_token();

        let handle = tokio::spawn(async move {
            worker.run().await;
        });

        tokio::time::sleep(run_time).await;
        cancel.cancel();
        handle.await.unwrap();
    }

    #[derive(Debug, Serialize, Deserialize)]
    struct CountTask;

    static COUNTER: AtomicU32 = AtomicU32::new(0);

    #[async_trait]
    impl Task for CountTask {
        const NAME: &'static str = "count";
        const MAX_RETRIES: u32 = 0;
        type Output = ();

        async fn run(&self, _ctx: &TaskContext) -> crate::error::TaskResult<Self::Output> {
            COUNTER.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
    }

    #[tokio::test]
    async fn worker_processes_tasks() {
        let before = COUNTER.load(Ordering::SeqCst);

        let broker = MemoryBroker::new();
        let mut registry = TaskRegistry::new();
        registry.register::<CountTask>();

        // Enqueue 3 tasks
        for _ in 0..3 {
            broker
                .enqueue(TaskMessage::new(
                    "count",
                    "default",
                    serde_json::json!(null),
                ))
                .await
                .unwrap();
        }

        let config = test_config(2, Duration::from_secs(5));
        let worker = Worker::new(broker.clone(), registry, TaskContext::new(), config);

        // Give the worker time to process the tasks, then shut it down
        run_worker_for(worker, Duration::from_millis(500)).await;

        let after = COUNTER.load(Ordering::SeqCst);
        assert_eq!(after - before, 3);
    }

    #[derive(Debug, Serialize, Deserialize)]
    struct FailTask;

    #[async_trait]
    impl Task for FailTask {
        const NAME: &'static str = "fail_task";
        const MAX_RETRIES: u32 = 0;
        type Output = ();

        async fn run(&self, _ctx: &TaskContext) -> crate::error::TaskResult<Self::Output> {
            Err(KojinError::TaskFailed("intentional failure".into()))
        }
    }

    #[tokio::test]
    async fn worker_dead_letters_after_max_retries() {
        let broker = MemoryBroker::new();
        let mut registry = TaskRegistry::new();
        registry.register::<FailTask>();

        broker
            .enqueue(
                TaskMessage::new("fail_task", "default", serde_json::json!(null))
                    .with_max_retries(0),
            )
            .await
            .unwrap();

        let config = test_config(1, Duration::from_secs(5));
        let worker = Worker::new(broker.clone(), registry, TaskContext::new(), config);

        run_worker_for(worker, Duration::from_millis(500)).await;

        assert_eq!(broker.dlq_len("default").await, 1);
    }

    #[tokio::test]
    async fn worker_graceful_shutdown() {
        let broker = MemoryBroker::new();
        let registry = TaskRegistry::new();

        let config = test_config(1, Duration::from_secs(1));
        let worker = Worker::new(broker, registry, TaskContext::new(), config);
        let cancel = worker.cancel_token();

        let handle = tokio::spawn(async move {
            worker.run().await;
        });

        // Cancel immediately
        cancel.cancel();
        // Should complete within shutdown timeout
        tokio::time::timeout(Duration::from_secs(3), handle)
            .await
            .expect("Worker should shutdown within timeout")
            .unwrap();
    }

    #[derive(Debug, Serialize, Deserialize)]
    struct AddTask {
        a: i32,
        b: i32,
    }

    #[async_trait]
    impl Task for AddTask {
        const NAME: &'static str = "add";
        const MAX_RETRIES: u32 = 0;
        type Output = i32;

        async fn run(&self, _ctx: &TaskContext) -> crate::error::TaskResult<Self::Output> {
            Ok(self.a + self.b)
        }
    }

    #[tokio::test]
    async fn worker_stores_results() {
        let broker = MemoryBroker::new();
        let backend = Arc::new(MemoryResultBackend::new());
        let mut registry = TaskRegistry::new();
        registry.register::<AddTask>();

        let msg = TaskMessage::new("add", "default", serde_json::json!({"a": 3, "b": 4}));
        let task_id = msg.id;
        broker.enqueue(msg).await.unwrap();

        let config = test_config(1, Duration::from_secs(5));
        let worker = Worker::new(broker.clone(), registry, TaskContext::new(), config)
            .with_result_backend(backend.clone());

        run_worker_for(worker, Duration::from_millis(500)).await;

        let result = backend.get(&task_id).await.unwrap();
        assert_eq!(result, Some(serde_json::json!(7)));
    }

    static CHAIN_COUNTER: AtomicU32 = AtomicU32::new(0);

    #[derive(Debug, Serialize, Deserialize)]
    struct ChainCountTask;

    #[async_trait]
    impl Task for ChainCountTask {
        const NAME: &'static str = "chain_count";
        const MAX_RETRIES: u32 = 0;
        type Output = u32;

        async fn run(&self, _ctx: &TaskContext) -> crate::error::TaskResult<Self::Output> {
            let val = CHAIN_COUNTER.fetch_add(1, Ordering::SeqCst) + 1;
            Ok(val)
        }
    }

    #[tokio::test]
    async fn worker_chain_continuation() {
        let broker = MemoryBroker::new();
        let backend = Arc::new(MemoryResultBackend::new());
        let mut registry = TaskRegistry::new();
        registry.register::<ChainCountTask>();

        let before = CHAIN_COUNTER.load(Ordering::SeqCst);

        // Build a chain: chain_count -> chain_count -> chain_count
        let remaining = vec![
            crate::signature::Signature::new("chain_count", "default", serde_json::json!(null)),
            crate::signature::Signature::new("chain_count", "default", serde_json::json!(null)),
        ];
        let mut msg =
            TaskMessage::new("chain_count", "default", serde_json::json!(null)).with_max_retries(0);
        msg.headers.insert(
            "kojin.chain_next".to_string(),
            serde_json::to_string(&remaining).unwrap(),
        );
        broker.enqueue(msg).await.unwrap();

        let config = test_config(1, Duration::from_secs(5));
        let worker = Worker::new(broker.clone(), registry, TaskContext::new(), config)
            .with_result_backend(backend);

        run_worker_for(worker, Duration::from_millis(1500)).await;

        // All 3 tasks in the chain should have executed
        let after = CHAIN_COUNTER.load(Ordering::SeqCst);
        assert_eq!(after - before, 3);
    }
}