use std::sync::Arc;
use std::time::Duration;

use tokio::sync::Semaphore;
use tokio_util::sync::CancellationToken;

use crate::broker::Broker;
use crate::context::TaskContext;
use crate::error::KojinError;
use crate::message::TaskMessage;
use crate::middleware::Middleware;
use crate::registry::TaskRegistry;
use crate::state::TaskState;

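/// Configuration for a [`Worker`].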
#[derive(Debug, Clone)]
pub struct WorkerConfig {
    /// Maximum number of tasks executing at once.
    pub concurrency: usize,
    /// Names of the queues this worker consumes from.
    pub queues: Vec<String>,
    /// How long to wait for in-flight tasks when shutting down.
    pub shutdown_timeout: Duration,
    /// Timeout passed to each broker dequeue call.
    pub dequeue_timeout: Duration,
}

impl Default for WorkerConfig {
    fn default() -> Self {
        Self {
            concurrency: 10,
            queues: vec!["default".to_string()],
            shutdown_timeout: Duration::from_secs(30),
            dequeue_timeout: Duration::from_secs(5),
        }
    }
}

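/// A queue worker that pulls [`TaskMessage`]s from a [`Broker`] and executes
/// the matching handler from a [`TaskRegistry`], running middleware hooks
/// around each task.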
pub struct Worker<B: Broker> {
    broker: Arc<B>,
    registry: Arc<TaskRegistry>,
    middlewares: Arc<Vec<Box<dyn Middleware>>>,
    context: Arc<TaskContext>,
    config: WorkerConfig,
    cancel: CancellationToken,
}

impl<B: Broker> Worker<B> {
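    /// Creates a worker over the given broker, registry, and shared task
    /// context. No middleware is installed yet.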
    pub fn new(
        broker: B,
        registry: TaskRegistry,
        context: TaskContext,
        config: WorkerConfig,
    ) -> Self {
        Self {
            broker: Arc::new(broker),
            registry: Arc::new(registry),
            middlewares: Arc::new(Vec::new()),
            context: Arc::new(context),
            config,
            cancel: CancellationToken::new(),
        }
    }

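    /// Adds a middleware to the chain. Middleware must be installed before
    /// `run` is called; this panics once the middleware list has been shared
    /// with running tasks.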
    pub fn with_middleware(mut self, middleware: impl Middleware + 'static) -> Self {
        Arc::get_mut(&mut self.middlewares)
            .expect("middleware can only be added before starting")
            .push(Box::new(middleware));
        self
    }

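    /// Like `with_middleware`, but accepts an already-boxed middleware.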
    pub fn with_middleware_boxed(mut self, middleware: Box<dyn Middleware>) -> Self {
        Arc::get_mut(&mut self.middlewares)
            .expect("middleware can only be added before starting")
            .push(middleware);
        self
    }

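    /// Returns a clone of the worker's cancellation token. Cancelling it
    /// triggers a graceful shutdown of `run`.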
    pub fn cancel_token(&self) -> CancellationToken {
        self.cancel.clone()
    }

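    /// Runs the worker until its cancellation token fires: messages are
    /// pulled from the configured queues and executed on spawned tasks, with
    /// at most `concurrency` running at once. On shutdown, in-flight tasks
    /// are drained for up to `shutdown_timeout`.
    ///
    /// A minimal usage sketch (marked `ignore` because the crate name and
    /// import paths are not shown here; adjust them to your crate layout):
    ///
    /// ```ignore
    /// let worker = Worker::new(
    ///     MemoryBroker::new(),
    ///     TaskRegistry::new(),
    ///     TaskContext::new(),
    ///     WorkerConfig::default(),
    /// );
    /// let cancel = worker.cancel_token();
    /// let handle = tokio::spawn(async move { worker.run().await });
    /// // ... later, e.g. from a signal handler:
    /// cancel.cancel();
    /// handle.await.unwrap();
    /// ```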
    pub async fn run(&self) {
        let semaphore = Arc::new(Semaphore::new(self.config.concurrency));

        tracing::info!(
            concurrency = self.config.concurrency,
            queues = ?self.config.queues,
            "Worker starting"
        );

        loop {
            if self.cancel.is_cancelled() {
                break;
            }

            // Wait for a free concurrency slot, or bail out on shutdown.
            let permit = tokio::select! {
                permit = semaphore.clone().acquire_owned() => {
                    match permit {
                        Ok(p) => p,
                        // Acquire only fails if the semaphore is closed, which
                        // never happens here; treat it as shutdown anyway.
                        Err(_) => break,
                    }
                }
                _ = self.cancel.cancelled() => break,
            };

            // Dequeue the next message, releasing the permit when nothing arrives.
            let message = tokio::select! {
                result = self.broker.dequeue(&self.config.queues, self.config.dequeue_timeout) => {
                    match result {
                        Ok(Some(msg)) => msg,
                        Ok(None) => {
                            drop(permit);
                            continue;
                        }
                        Err(e) => {
                            tracing::error!(error = %e, "Failed to dequeue");
                            drop(permit);
                            // Back off briefly so a broken broker doesn't spin the loop.
                            tokio::time::sleep(Duration::from_secs(1)).await;
                            continue;
                        }
                    }
                }
                _ = self.cancel.cancelled() => {
                    drop(permit);
                    break;
                }
            };

            let broker = self.broker.clone();
            let registry = self.registry.clone();
            let middlewares = self.middlewares.clone();
            let context = self.context.clone();

            tokio::spawn(async move {
                // Hold the permit for the lifetime of the task so the
                // semaphore accurately tracks in-flight work.
                let _permit = permit;
                execute_task(broker, registry, middlewares, context, message).await;
            });
        }

        tracing::info!("Worker shutting down, waiting for in-flight tasks...");
        let drain_deadline = tokio::time::Instant::now() + self.config.shutdown_timeout;
        loop {
            // All permits free means no tasks are still in flight.
            if semaphore.available_permits() == self.config.concurrency {
                break;
            }
            if tokio::time::Instant::now() >= drain_deadline {
                tracing::warn!("Shutdown timeout reached, some tasks may not have completed");
                break;
            }
            tokio::time::sleep(Duration::from_millis(100)).await;
        }

        tracing::info!("Worker stopped");
    }
}

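/// Executes a single dequeued message: runs middleware `before` hooks, then
/// dispatches to the registered task handler, acking on success and routing
/// any failure through `handle_failure`.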
async fn execute_task<B: Broker>(
    broker: Arc<B>,
    registry: Arc<TaskRegistry>,
    middlewares: Arc<Vec<Box<dyn Middleware>>>,
    context: Arc<TaskContext>,
    mut message: TaskMessage,
) {
    let task_id = message.id;
    let task_name = message.task_name.clone();

    tracing::info!(task_id = %task_id, task_name = %task_name, "Executing task");
    message.state = TaskState::Started;

    // A failing before() hook aborts execution and routes the message
    // through the normal failure path.
    for mw in middlewares.iter() {
        if let Err(e) = mw.before(&message).await {
            tracing::error!(task_id = %task_id, error = %e, "Middleware before() failed");
            handle_failure(broker, middlewares, message, e).await;
            return;
        }
    }

    match registry
        .dispatch(&task_name, message.payload.clone(), context)
        .await
    {
        Ok(result) => {
            // after() hooks are advisory: a failure is logged but does not
            // change the task's outcome.
            for mw in middlewares.iter() {
                if let Err(e) = mw.after(&message, &result).await {
                    tracing::warn!(task_id = %task_id, error = %e, "Middleware after() failed");
                }
            }
            message.state = TaskState::Success;
            if let Err(e) = broker.ack(&task_id).await {
                tracing::error!(task_id = %task_id, error = %e, "Failed to ack task");
            }
            tracing::info!(task_id = %task_id, task_name = %task_name, "Task completed successfully");
        }
        Err(e) => {
            tracing::error!(task_id = %task_id, task_name = %task_name, error = %e, "Task failed");
            handle_failure(broker, middlewares, message, e).await;
        }
    }
}

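/// Handles a failed message: runs middleware `on_error` hooks, then either
/// requeues the message with backoff (while retries remain) or moves it to
/// the dead-letter queue.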
async fn handle_failure<B: Broker>(
    broker: Arc<B>,
    middlewares: Arc<Vec<Box<dyn Middleware>>>,
    mut message: TaskMessage,
    error: KojinError,
) {
    let task_id = message.id;

    // on_error() hooks are advisory, like after(): failures are logged only.
    for mw in middlewares.iter() {
        if let Err(e) = mw.on_error(&message, &error).await {
            tracing::warn!(task_id = %task_id, error = %e, "Middleware on_error() failed");
        }
    }

    if message.retries < message.max_retries {
        message.retries += 1;
        message.state = TaskState::Retry;
        message.updated_at = chrono::Utc::now();

        let backoff_delay =
            crate::backoff::BackoffStrategy::default().delay_for(message.retries - 1);
        tracing::info!(
            task_id = %task_id,
            retry = message.retries,
            max_retries = message.max_retries,
            backoff = ?backoff_delay,
            "Retrying task"
        );

        // Note: sleeping in place holds this task's concurrency slot for the
        // whole backoff; a scheduled/delayed requeue would free it sooner.
        tokio::time::sleep(backoff_delay).await;

        if let Err(e) = broker.nack(message).await {
            tracing::error!(task_id = %task_id, error = %e, "Failed to nack/requeue task");
        }
    } else {
        message.state = TaskState::DeadLettered;
        message.updated_at = chrono::Utc::now();
        tracing::warn!(task_id = %task_id, "Max retries exceeded, moving to DLQ");

        if let Err(e) = broker.dead_letter(message).await {
            tracing::error!(task_id = %task_id, error = %e, "Failed to dead-letter task");
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::memory_broker::MemoryBroker;
    use crate::task::Task;
    use async_trait::async_trait;
    use serde::{Deserialize, Serialize};
    use std::sync::atomic::{AtomicU32, Ordering};

    // A task that records each execution in a global counter.
    #[derive(Debug, Serialize, Deserialize)]
    struct CountTask;

    static COUNTER: AtomicU32 = AtomicU32::new(0);

    #[async_trait]
    impl Task for CountTask {
        const NAME: &'static str = "count";
        const MAX_RETRIES: u32 = 0;
        type Output = ();

        async fn run(&self, _ctx: &TaskContext) -> crate::error::TaskResult<Self::Output> {
            COUNTER.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
    }

    #[tokio::test]
    async fn worker_processes_tasks() {
        COUNTER.store(0, Ordering::SeqCst);

        let broker = MemoryBroker::new();
        let mut registry = TaskRegistry::new();
        registry.register::<CountTask>();

        for _ in 0..3 {
            broker
                .enqueue(TaskMessage::new(
                    "count",
                    "default",
                    serde_json::json!(null),
                ))
                .await
                .unwrap();
        }

        let config = WorkerConfig {
            concurrency: 2,
            queues: vec!["default".to_string()],
            shutdown_timeout: Duration::from_secs(5),
            dequeue_timeout: Duration::from_millis(100),
        };

        let worker = Worker::new(broker.clone(), registry, TaskContext::new(), config);
        let cancel = worker.cancel_token();

        let handle = tokio::spawn(async move {
            worker.run().await;
        });

        // Give the worker time to drain the queue before shutting down.
        tokio::time::sleep(Duration::from_millis(500)).await;
        cancel.cancel();
        handle.await.unwrap();

        assert_eq!(COUNTER.load(Ordering::SeqCst), 3);
    }

    #[derive(Debug, Serialize, Deserialize)]
    struct FailTask;

    #[async_trait]
    impl Task for FailTask {
        const NAME: &'static str = "fail_task";
        const MAX_RETRIES: u32 = 0;
        type Output = ();

        async fn run(&self, _ctx: &TaskContext) -> crate::error::TaskResult<Self::Output> {
            Err(KojinError::TaskFailed("intentional failure".into()))
        }
    }

    #[tokio::test]
    async fn worker_dead_letters_after_max_retries() {
        let broker = MemoryBroker::new();
        let mut registry = TaskRegistry::new();
        registry.register::<FailTask>();

        broker
            .enqueue(
                TaskMessage::new("fail_task", "default", serde_json::json!(null))
                    .with_max_retries(0),
            )
            .await
            .unwrap();

        let config = WorkerConfig {
            concurrency: 1,
            queues: vec!["default".to_string()],
            shutdown_timeout: Duration::from_secs(5),
            dequeue_timeout: Duration::from_millis(100),
        };

        let worker = Worker::new(broker.clone(), registry, TaskContext::new(), config);
        let cancel = worker.cancel_token();

        let handle = tokio::spawn(async move {
            worker.run().await;
        });

        tokio::time::sleep(Duration::from_millis(500)).await;
        cancel.cancel();
        handle.await.unwrap();

        assert_eq!(broker.dlq_len("default").await, 1);
    }

    #[tokio::test]
    async fn worker_graceful_shutdown() {
        let broker = MemoryBroker::new();
        let registry = TaskRegistry::new();

        let config = WorkerConfig {
            concurrency: 1,
            queues: vec!["default".to_string()],
            shutdown_timeout: Duration::from_secs(1),
            dequeue_timeout: Duration::from_millis(100),
        };

        let worker = Worker::new(broker, registry, TaskContext::new(), config);
        let cancel = worker.cancel_token();

        let handle = tokio::spawn(async move {
            worker.run().await;
        });

        // Cancel immediately; the worker should exit well within the timeout.
        cancel.cancel();
        tokio::time::timeout(Duration::from_secs(3), handle)
            .await
            .expect("Worker should shutdown within timeout")
            .unwrap();
    }
}