// kojin_core/worker.rs

use std::sync::Arc;
use std::time::Duration;

use tokio::sync::Semaphore;
use tokio_util::sync::CancellationToken;

use crate::broker::Broker;
use crate::context::TaskContext;
use crate::error::KojinError;
use crate::message::TaskMessage;
use crate::middleware::Middleware;
use crate::registry::TaskRegistry;
use crate::state::TaskState;

/// Worker configuration.
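///
/// A minimal construction sketch; the values are illustrative, not
/// recommendations, and the `kojin_core::worker` import path is assumed:
///
/// ```ignore
/// use std::time::Duration;
/// use kojin_core::worker::WorkerConfig;
///
/// let config = WorkerConfig {
///     concurrency: 32,
///     queues: vec!["default".into(), "emails".into()],
///     ..WorkerConfig::default()
/// };
/// assert_eq!(config.shutdown_timeout, Duration::from_secs(30));
/// ```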
#[derive(Debug, Clone)]
pub struct WorkerConfig {
    /// Maximum number of tasks executing concurrently.
    pub concurrency: usize,
    /// Queue names to consume from.
    pub queues: Vec<String>,
    /// How long to wait for in-flight tasks during shutdown.
    pub shutdown_timeout: Duration,
    /// How long each dequeue poll waits for a message before giving up.
    pub dequeue_timeout: Duration,
}

impl Default for WorkerConfig {
    fn default() -> Self {
        Self {
            concurrency: 10,
            queues: vec!["default".to_string()],
            shutdown_timeout: Duration::from_secs(30),
            dequeue_timeout: Duration::from_secs(5),
        }
    }
}

/// The worker loop that dequeues and executes tasks.
pub struct Worker<B: Broker> {
    broker: Arc<B>,
    registry: Arc<TaskRegistry>,
    middlewares: Arc<Vec<Box<dyn Middleware>>>,
    context: Arc<TaskContext>,
    config: WorkerConfig,
    cancel: CancellationToken,
}

impl<B: Broker> Worker<B> {
    pub fn new(
        broker: B,
        registry: TaskRegistry,
        context: TaskContext,
        config: WorkerConfig,
    ) -> Self {
        Self {
            broker: Arc::new(broker),
            registry: Arc::new(registry),
            middlewares: Arc::new(Vec::new()),
            context: Arc::new(context),
            config,
            cancel: CancellationToken::new(),
        }
    }

    /// Add middleware to the worker pipeline.
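    ///
    /// A sketch of a logging middleware, assuming the `Middleware` trait in
    /// `crate::middleware` exposes the async `before`/`after`/`on_error`
    /// hooks invoked by `execute_task` below, with default implementations
    /// for the hooks not overridden here (exact signatures may differ):
    ///
    /// ```ignore
    /// struct LoggingMiddleware;
    ///
    /// #[async_trait::async_trait]
    /// impl Middleware for LoggingMiddleware {
    ///     async fn before(&self, msg: &TaskMessage) -> Result<(), KojinError> {
    ///         tracing::debug!(task_id = %msg.id, "task starting");
    ///         Ok(())
    ///     }
    /// }
    ///
    /// let worker = worker.with_middleware(LoggingMiddleware);
    /// ```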
    pub fn with_middleware(mut self, middleware: impl Middleware) -> Self {
        Arc::get_mut(&mut self.middlewares)
            .expect("middleware can only be added before starting")
            .push(Box::new(middleware));
        self
    }

    /// Add a boxed middleware to the worker pipeline.
    pub fn with_middleware_boxed(mut self, middleware: Box<dyn Middleware>) -> Self {
        Arc::get_mut(&mut self.middlewares)
            .expect("middleware can only be added before starting")
            .push(middleware);
        self
    }

    /// Get the cancellation token for external shutdown triggering.
    pub fn cancel_token(&self) -> CancellationToken {
        self.cancel.clone()
    }

    /// Run the worker loop until shutdown.
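    ///
    /// A usage sketch; broker, registry, and context construction are elided,
    /// and wiring Ctrl-C assumes tokio's `signal` feature is enabled:
    ///
    /// ```ignore
    /// let worker = Worker::new(broker, registry, context, WorkerConfig::default());
    /// let cancel = worker.cancel_token();
    ///
    /// // Trigger graceful shutdown on Ctrl-C.
    /// tokio::spawn(async move {
    ///     let _ = tokio::signal::ctrl_c().await;
    ///     cancel.cancel();
    /// });
    ///
    /// worker.run().await;
    /// ```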
    pub async fn run(&self) {
        let semaphore = Arc::new(Semaphore::new(self.config.concurrency));

        tracing::info!(
            concurrency = self.config.concurrency,
            queues = ?self.config.queues,
            "Worker starting"
        );

        loop {
            if self.cancel.is_cancelled() {
                break;
            }

            // Acquire a concurrency permit
            let permit = tokio::select! {
                permit = semaphore.clone().acquire_owned() => {
                    match permit {
                        Ok(p) => p,
                        Err(_) => break, // Semaphore closed
                    }
                }
                _ = self.cancel.cancelled() => break,
            };

            // Dequeue a message
            let message = tokio::select! {
                result = self.broker.dequeue(&self.config.queues, self.config.dequeue_timeout) => {
                    match result {
                        Ok(Some(msg)) => msg,
                        Ok(None) => {
                            drop(permit);
                            continue; // Timeout, try again
                        }
                        Err(e) => {
                            tracing::error!(error = %e, "Failed to dequeue");
                            drop(permit);
                            tokio::time::sleep(Duration::from_secs(1)).await;
                            continue;
                        }
                    }
                }
                _ = self.cancel.cancelled() => {
                    drop(permit);
                    break;
                }
            };

            // Spawn task execution
            let broker = self.broker.clone();
            let registry = self.registry.clone();
            let middlewares = self.middlewares.clone();
            let context = self.context.clone();

            tokio::spawn(async move {
                let _permit = permit; // Hold permit until done
                execute_task(broker, registry, middlewares, context, message).await;
            });
        }

        // Graceful shutdown: wait for in-flight tasks to complete
        tracing::info!("Worker shutting down, waiting for in-flight tasks...");
        let drain_deadline = tokio::time::Instant::now() + self.config.shutdown_timeout;
        loop {
            // When all permits are available, no tasks are in-flight
            if semaphore.available_permits() == self.config.concurrency {
                break;
            }
            if tokio::time::Instant::now() >= drain_deadline {
                tracing::warn!("Shutdown timeout reached, some tasks may not have completed");
                break;
            }
            tokio::time::sleep(Duration::from_millis(100)).await;
        }

        tracing::info!("Worker stopped");
    }
}

async fn execute_task<B: Broker>(
    broker: Arc<B>,
    registry: Arc<TaskRegistry>,
    middlewares: Arc<Vec<Box<dyn Middleware>>>,
    context: Arc<TaskContext>,
    mut message: TaskMessage,
) {
    let task_id = message.id;
    let task_name = message.task_name.clone();

    tracing::info!(task_id = %task_id, task_name = %task_name, "Executing task");
    message.state = TaskState::Started;

    // Run before middleware
    for mw in middlewares.iter() {
        if let Err(e) = mw.before(&message).await {
            tracing::error!(task_id = %task_id, error = %e, "Middleware before() failed");
            handle_failure(broker, middlewares, message, e).await;
            return;
        }
    }

    // Dispatch to handler
    match registry
        .dispatch(&task_name, message.payload.clone(), context)
        .await
    {
        Ok(result) => {
            // Run after middleware
            for mw in middlewares.iter() {
                if let Err(e) = mw.after(&message, &result).await {
                    tracing::warn!(task_id = %task_id, error = %e, "Middleware after() failed");
                }
            }
            message.state = TaskState::Success;
            if let Err(e) = broker.ack(&task_id).await {
                tracing::error!(task_id = %task_id, error = %e, "Failed to ack task");
            }
            tracing::info!(task_id = %task_id, task_name = %task_name, "Task completed successfully");
        }
        Err(e) => {
            tracing::error!(task_id = %task_id, task_name = %task_name, error = %e, "Task failed");
            handle_failure(broker, middlewares, message, e).await;
        }
    }
}

async fn handle_failure<B: Broker>(
    broker: Arc<B>,
    middlewares: Arc<Vec<Box<dyn Middleware>>>,
    mut message: TaskMessage,
    error: KojinError,
) {
    let task_id = message.id;

    // Run on_error middleware
    for mw in middlewares.iter() {
        if let Err(e) = mw.on_error(&message, &error).await {
            tracing::warn!(task_id = %task_id, error = %e, "Middleware on_error() failed");
        }
    }

    // Retry or dead-letter
    if message.retries < message.max_retries {
        message.retries += 1;
        message.state = TaskState::Retry;
        message.updated_at = chrono::Utc::now();

        let backoff_delay =
            crate::backoff::BackoffStrategy::default().delay_for(message.retries - 1);
        tracing::info!(
            task_id = %task_id,
            retry = message.retries,
            max_retries = message.max_retries,
            backoff = ?backoff_delay,
            "Retrying task"
        );

        // Simple sleep-based backoff (for MemoryBroker; Redis uses scheduled queue)
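        // The sleep runs inside the spawned task, which still holds its
        // concurrency permit, so long backoffs occupy a worker slot.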
        tokio::time::sleep(backoff_delay).await;

        if let Err(e) = broker.nack(message).await {
            tracing::error!(task_id = %task_id, error = %e, "Failed to nack/requeue task");
        }
    } else {
        message.state = TaskState::DeadLettered;
        message.updated_at = chrono::Utc::now();
        tracing::warn!(task_id = %task_id, "Max retries exceeded, moving to DLQ");

        if let Err(e) = broker.dead_letter(message).await {
            tracing::error!(task_id = %task_id, error = %e, "Failed to dead-letter task");
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::memory_broker::MemoryBroker;
    use crate::task::Task;
    use async_trait::async_trait;
    use serde::{Deserialize, Serialize};
    use std::sync::atomic::{AtomicU32, Ordering};

    #[derive(Debug, Serialize, Deserialize)]
    struct CountTask;

    static COUNTER: AtomicU32 = AtomicU32::new(0);

    #[async_trait]
    impl Task for CountTask {
        const NAME: &'static str = "count";
        const MAX_RETRIES: u32 = 0;
        type Output = ();

        async fn run(&self, _ctx: &TaskContext) -> crate::error::TaskResult<Self::Output> {
            COUNTER.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
    }

    #[tokio::test]
    async fn worker_processes_tasks() {
        COUNTER.store(0, Ordering::SeqCst);

        let broker = MemoryBroker::new();
        let mut registry = TaskRegistry::new();
        registry.register::<CountTask>();

        // Enqueue 3 tasks
        for _ in 0..3 {
            broker
                .enqueue(TaskMessage::new(
                    "count",
                    "default",
                    serde_json::json!(null),
                ))
                .await
                .unwrap();
        }

        let config = WorkerConfig {
            concurrency: 2,
            queues: vec!["default".to_string()],
            shutdown_timeout: Duration::from_secs(5),
            dequeue_timeout: Duration::from_millis(100),
        };

        let worker = Worker::new(broker.clone(), registry, TaskContext::new(), config);
        let cancel = worker.cancel_token();

        // Run worker in background
        let handle = tokio::spawn(async move {
            worker.run().await;
        });

        // Wait for tasks to be processed
        tokio::time::sleep(Duration::from_millis(500)).await;
        cancel.cancel();
        handle.await.unwrap();

        assert_eq!(COUNTER.load(Ordering::SeqCst), 3);
    }

    #[derive(Debug, Serialize, Deserialize)]
    struct FailTask;

    #[async_trait]
    impl Task for FailTask {
        const NAME: &'static str = "fail_task";
        const MAX_RETRIES: u32 = 0;
        type Output = ();

        async fn run(&self, _ctx: &TaskContext) -> crate::error::TaskResult<Self::Output> {
            Err(KojinError::TaskFailed("intentional failure".into()))
        }
    }

    #[tokio::test]
    async fn worker_dead_letters_after_max_retries() {
        let broker = MemoryBroker::new();
        let mut registry = TaskRegistry::new();
        registry.register::<FailTask>();

        broker
            .enqueue(
                TaskMessage::new("fail_task", "default", serde_json::json!(null))
                    .with_max_retries(0),
            )
            .await
            .unwrap();

        let config = WorkerConfig {
            concurrency: 1,
            queues: vec!["default".to_string()],
            shutdown_timeout: Duration::from_secs(5),
            dequeue_timeout: Duration::from_millis(100),
        };

        let worker = Worker::new(broker.clone(), registry, TaskContext::new(), config);
        let cancel = worker.cancel_token();

        let handle = tokio::spawn(async move {
            worker.run().await;
        });

        tokio::time::sleep(Duration::from_millis(500)).await;
        cancel.cancel();
        handle.await.unwrap();

        assert_eq!(broker.dlq_len("default").await, 1);
    }

    #[tokio::test]
    async fn worker_graceful_shutdown() {
        let broker = MemoryBroker::new();
        let registry = TaskRegistry::new();

        let config = WorkerConfig {
            concurrency: 1,
            queues: vec!["default".to_string()],
            shutdown_timeout: Duration::from_secs(1),
            dequeue_timeout: Duration::from_millis(100),
        };

        let worker = Worker::new(broker, registry, TaskContext::new(), config);
        let cancel = worker.cancel_token();

        let handle = tokio::spawn(async move {
            worker.run().await;
        });

        // Cancel immediately
        cancel.cancel();
        // Should complete within shutdown timeout
        tokio::time::timeout(Duration::from_secs(3), handle)
            .await
            .expect("Worker should shut down within timeout")
            .unwrap();
    }
}