1use crate::{Queue, QueueBackend, QueueResult, QueueError, JobEntry, JobResult, QueueConfig};
4use async_trait::async_trait;
5use std::collections::HashMap;
6use std::sync::Arc;
7use tokio::sync::{RwLock, Semaphore};
8use tokio::time::{interval, Duration, timeout};
9use tracing::{info, warn, error, debug};
10use futures::future::BoxFuture;
11
12pub type JobHandler = Arc<dyn Fn(JobEntry) -> BoxFuture<'static, JobResult<()>> + Send + Sync>;
14
15pub struct WorkerRegistry {
17 handlers: RwLock<HashMap<String, JobHandler>>,
18}
19
20impl WorkerRegistry {
21 pub fn new() -> Self {
23 Self {
24 handlers: RwLock::new(HashMap::new()),
25 }
26 }
27
28 pub async fn register<T: crate::Job + JobTypeProvider + 'static>(&self, handler: impl JobProcessor<T> + Send + Sync + 'static) {
30 let job_type = T::default_job_type();
31 let handler: JobHandler = Arc::new(move |entry: JobEntry| {
32 let handler = handler.clone();
33 Box::pin(async move {
34 handler.process(entry).await
35 })
36 });
37
38 self.handlers.write().await.insert(job_type.to_string(), handler);
39 info!("Registered job handler for type: {}", job_type);
40 }
41
42 pub async fn get_handler(&self, job_type: &str) -> Option<JobHandler> {
44 self.handlers.read().await.get(job_type).cloned()
45 }
46
47 pub async fn list_job_types(&self) -> Vec<String> {
49 self.handlers.read().await.keys().cloned().collect()
50 }
51}
52
53impl Default for WorkerRegistry {
54 fn default() -> Self {
55 Self::new()
56 }
57}
58
59#[async_trait]
61pub trait JobProcessor<T: crate::Job>: Clone + Send + Sync {
62 async fn process(&self, entry: JobEntry) -> JobResult<()>;
64}
65
66#[async_trait]
68impl<T, F, Fut> JobProcessor<T> for F
69where
70 T: crate::Job + 'static,
71 F: Fn(T) -> Fut + Clone + Send + Sync + 'static,
72 Fut: std::future::Future<Output = JobResult<()>> + Send + 'static,
73{
74 async fn process(&self, entry: JobEntry) -> JobResult<()> {
75 let job: T = serde_json::from_value(entry.payload.clone())?;
76 self(job).await
77 }
78}
79
80pub struct Worker<B: QueueBackend> {
82 queue: Arc<Queue<B>>,
83 registry: Arc<WorkerRegistry>,
84 config: QueueConfig,
85 concurrency_limiter: Arc<Semaphore>,
86}
87
88impl<B: QueueBackend + 'static> Worker<B> {
89 pub fn new(queue: Queue<B>, registry: WorkerRegistry, config: QueueConfig) -> Self {
91 let concurrency_limiter = Arc::new(Semaphore::new(*config.get_max_workers()));
92
93 Self {
94 queue: Arc::new(queue),
95 registry: Arc::new(registry),
96 config,
97 concurrency_limiter,
98 }
99 }
100
101 pub async fn start(&self) -> QueueResult<()> {
103 info!("Starting worker with {} max concurrent jobs", self.config.get_max_workers());
104
105 let mut poll_interval = interval(*self.config.get_poll_interval());
106
107 loop {
108 poll_interval.tick().await;
109
110 match self.queue.dequeue().await {
112 Ok(Some(job_entry)) => {
113 let permit = match self.concurrency_limiter.clone().try_acquire_owned() {
114 Ok(permit) => permit,
115 Err(_) => {
116 debug!("No available worker slots, skipping job processing");
118 continue;
119 }
120 };
121
122 let queue = self.queue.clone();
123 let registry = self.registry.clone();
124 let job_timeout = *self.config.get_default_timeout();
125
126 tokio::spawn(async move {
128 let _permit = permit; let job_id = job_entry.id();
131 let job_type = job_entry.job_type().to_string();
132
133 debug!("Processing job {} of type {}", job_id, job_type);
134
135 let result = if let Some(handler) = registry.get_handler(&job_type).await {
136 match timeout(job_timeout, handler(job_entry)).await {
138 Ok(result) => result,
139 Err(_) => {
140 error!("Job {} timed out after {:?}", job_id, job_timeout);
141 Err(Box::new(QueueError::Timeout) as Box<dyn std::error::Error + Send + Sync>)
142 }
143 }
144 } else {
145 error!("No handler registered for job type: {}", job_type);
146 Err(Box::new(QueueError::Configuration(
147 format!("No handler for job type: {}", job_type)
148 )) as Box<dyn std::error::Error + Send + Sync>)
149 };
150
151 match &result {
153 Ok(_) => {
154 info!("Job {} completed successfully", job_id);
155 if let Err(e) = queue.complete(job_id, result).await {
156 error!("Failed to mark job {} as completed: {}", job_id, e);
157 }
158 }
159 Err(e) => {
160 warn!("Job {} failed: {}", job_id, e);
161 if let Err(e2) = queue.complete(job_id, result).await {
162 error!("Failed to mark job {} as completed: {}", job_id, e2);
163 }
164 }
165 }
166 });
167 }
168 Ok(None) => {
169 continue;
171 }
172 Err(e) => {
173 error!("Failed to dequeue job: {}", e);
174 tokio::time::sleep(Duration::from_secs(1)).await;
176 }
177 }
178 }
179 }
180
181 pub async fn start_with_shutdown(&self, mut shutdown: tokio::sync::mpsc::Receiver<()>) -> QueueResult<()> {
183 info!("Starting worker with graceful shutdown support");
184
185 let mut poll_interval = interval(*self.config.get_poll_interval());
186 let mut shutting_down = false;
187
188 loop {
189 tokio::select! {
190 _ = poll_interval.tick(), if !shutting_down => {
191 match self.queue.dequeue().await {
193 Ok(Some(job_entry)) => {
194 let permit = match self.concurrency_limiter.clone().try_acquire_owned() {
195 Ok(permit) => permit,
196 Err(_) => {
197 debug!("No available worker slots, skipping job processing");
198 continue;
199 }
200 };
201
202 let queue = self.queue.clone();
203 let registry = self.registry.clone();
204 let job_timeout = *self.config.get_default_timeout();
205
206 tokio::spawn(async move {
208 let _permit = permit;
209
210 let job_id = job_entry.id();
211 let job_type = job_entry.job_type().to_string();
212
213 debug!("Processing job {} of type {}", job_id, job_type);
214
215 let result = if let Some(handler) = registry.get_handler(&job_type).await {
216 match timeout(job_timeout, handler(job_entry)).await {
217 Ok(result) => result,
218 Err(_) => {
219 error!("Job {} timed out after {:?}", job_id, job_timeout);
220 Err(Box::new(QueueError::Timeout) as Box<dyn std::error::Error + Send + Sync>)
221 }
222 }
223 } else {
224 error!("No handler registered for job type: {}", job_type);
225 Err(Box::new(QueueError::Configuration(
226 format!("No handler for job type: {}", job_type)
227 )) as Box<dyn std::error::Error + Send + Sync>)
228 };
229
230 match &result {
232 Ok(_) => {
233 info!("Job {} completed successfully", job_id);
234 if let Err(e) = queue.complete(job_id, result).await {
235 error!("Failed to mark job {} as completed: {}", job_id, e);
236 }
237 }
238 Err(e) => {
239 warn!("Job {} failed: {}", job_id, e);
240 if let Err(e2) = queue.complete(job_id, result).await {
241 error!("Failed to mark job {} as completed: {}", job_id, e2);
242 }
243 }
244 }
245 });
246 }
247 Ok(None) => {
248 continue;
250 }
251 Err(e) => {
252 error!("Failed to dequeue job: {}", e);
253 tokio::time::sleep(Duration::from_secs(1)).await;
254 }
255 }
256 }
257
258 _ = shutdown.recv() => {
259 info!("Shutdown signal received, stopping new job processing");
260 shutting_down = true;
261
262 let active_jobs = *self.config.get_max_workers() - self.concurrency_limiter.available_permits();
264 info!("Waiting for {} active jobs to complete", active_jobs);
265 while self.concurrency_limiter.available_permits() < *self.config.get_max_workers() {
266 tokio::time::sleep(Duration::from_millis(100)).await;
267 }
268
269 info!("Worker shutdown complete");
270 break;
271 }
272 }
273 }
274
275 Ok(())
276 }
277
278 pub async fn stats(&self) -> QueueResult<WorkerStats> {
280 let queue_stats = self.queue.stats().await?;
281 let available_slots = self.concurrency_limiter.available_permits();
282 let active_jobs = *self.config.get_max_workers() - available_slots;
283 let job_types = self.registry.list_job_types().await;
284
285 Ok(WorkerStats {
286 queue_stats,
287 max_workers: *self.config.get_max_workers(),
288 active_workers: active_jobs,
289 available_workers: available_slots,
290 registered_job_types: job_types,
291 })
292 }
293}
294
295#[derive(Debug, Clone)]
297pub struct WorkerStats {
298 pub queue_stats: crate::QueueStats,
299 pub max_workers: usize,
300 pub active_workers: usize,
301 pub available_workers: usize,
302 pub registered_job_types: Vec<String>,
303}
304
305pub trait JobTypeProvider {
307 fn default_job_type() -> &'static str;
308}
309
310impl<T> JobTypeProvider for T
312where
313 T: crate::Job + Default,
314{
315 fn default_job_type() -> &'static str {
316 T::default().job_type()
317 }
318}
319
320#[cfg(test)]
321mod tests {
322 use super::*;
323 use crate::{backends::MemoryBackend, Priority};
324 use crate::config::QueueConfigBuilder;
325 use serde::{Deserialize, Serialize};
326 use std::sync::atomic::{AtomicU32, Ordering};
327
328 #[derive(Debug, Clone, Serialize, Deserialize, Default)]
329 struct TestJob {
330 id: u32,
331 message: String,
332 }
333
334 #[async_trait]
335 impl crate::Job for TestJob {
336 async fn execute(&self) -> JobResult<()> {
337 Ok(())
338 }
339
340 fn job_type(&self) -> &'static str {
341 "test"
342 }
343 }
344
345 #[derive(Clone)]
346 struct TestJobProcessor {
347 counter: Arc<AtomicU32>,
348 }
349
350 impl TestJobProcessor {
351 fn new() -> Self {
352 Self {
353 counter: Arc::new(AtomicU32::new(0)),
354 }
355 }
356
357 fn get_count(&self) -> u32 {
358 self.counter.load(Ordering::Relaxed)
359 }
360 }
361
362 #[async_trait]
363 impl JobProcessor<TestJob> for TestJobProcessor {
364 async fn process(&self, _entry: JobEntry) -> JobResult<()> {
365 self.counter.fetch_add(1, Ordering::Relaxed);
366 Ok(())
367 }
368 }
369
370 #[tokio::test]
371 async fn test_worker_registry() {
372 let registry = WorkerRegistry::new();
373 let processor = TestJobProcessor::new();
374
375 registry.register::<TestJob>(processor.clone()).await;
376
377 let job_types = registry.list_job_types().await;
378 assert_eq!(job_types, vec!["test"]);
379
380 let handler = registry.get_handler("test").await;
381 assert!(handler.is_some());
382
383 let no_handler = registry.get_handler("nonexistent").await;
384 assert!(no_handler.is_none());
385 }
386
387 #[tokio::test]
388 async fn test_job_processing() {
389 let backend = MemoryBackend::new(QueueConfigBuilder::testing().build().expect("Failed to build config"));
390 let queue = Queue::new(backend);
391 let registry = WorkerRegistry::new();
392 let processor = TestJobProcessor::new();
393
394 registry.register::<TestJob>(processor.clone()).await;
395
396 let job = TestJob {
398 id: 1,
399 message: "test message".to_string(),
400 };
401 let job_id = queue.enqueue(job, Some(Priority::Normal)).await.unwrap();
402
403 let job_entry = queue.dequeue().await.unwrap().unwrap();
405 let handler = registry.get_handler("test").await.unwrap();
406 let result = handler(job_entry).await;
407
408 assert!(result.is_ok());
409 assert_eq!(processor.get_count(), 1);
410
411 queue.complete(job_id, result).await.unwrap();
413
414 let stats = queue.stats().await.unwrap();
416 assert_eq!(stats.completed_jobs, 1);
417 }
418
419 #[tokio::test]
420 async fn test_worker_stats() {
421 let backend = MemoryBackend::new(QueueConfigBuilder::testing().build().expect("Failed to build config"));
422 let queue = Queue::new(backend);
423 let registry = WorkerRegistry::new();
424 let processor = TestJobProcessor::new();
425 let config = QueueConfigBuilder::testing().build().expect("Failed to build config");
426
427 registry.register::<TestJob>(processor).await;
428 let worker = Worker::new(queue, registry, config);
429
430 let stats = worker.stats().await.unwrap();
431 assert_eq!(stats.max_workers, 1); assert_eq!(stats.active_workers, 0);
433 assert_eq!(stats.available_workers, 1);
434 assert_eq!(stats.registered_job_types, vec!["test"]);
435 }
436}