1use std::sync::Arc;
4use std::time::Duration;
5use tokio::sync::Semaphore;
6use tracing::{error, info, warn};
7
8use crate::backend::QueueBackend;
9use crate::job::{Job, JobResult, JobStatus};
10
/// Tuning knobs for a [`WorkerPool`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct WorkerConfig {
    /// Maximum number of jobs executing at the same time.
    ///
    /// This is the permit count of the pool's semaphore, so a value of `0`
    /// means no job is ever started.
    pub max_concurrency: usize,
    /// How long the dispatch loop sleeps after the backend reports an empty queue.
    pub poll_interval: Duration,
}

impl Default for WorkerConfig {
    /// Five concurrent jobs, re-polling an empty queue every 100 ms.
    fn default() -> Self {
        Self {
            max_concurrency: 5,
            poll_interval: Duration::from_millis(100),
        }
    }
}
25
26use serde::de::DeserializeOwned;
27
28use std::sync::RwLock;
29
30pub struct WorkerPool<B: QueueBackend + ?Sized> {
31 pub backend: Arc<B>,
32 config: WorkerConfig,
33 registry: Arc<JobRegistry>,
34}
35
36type JobFactory = Box<dyn Fn(serde_json::Value) -> Box<dyn Job> + Send + Sync>;
37
38struct JobRegistry {
39 factories: RwLock<std::collections::HashMap<String, JobFactory>>,
40}
41
42impl<B: QueueBackend + 'static> WorkerPool<B> {
43 pub fn new(backend: B, config: WorkerConfig) -> Self {
44 Self::new_with_arc(Arc::new(backend), config)
45 }
46}
47
48impl<B: QueueBackend + ?Sized + 'static> WorkerPool<B> {
49 pub fn new_with_arc(backend: Arc<B>, config: WorkerConfig) -> Self {
51 Self {
52 backend,
53 config,
54 registry: Arc::new(JobRegistry {
55 factories: RwLock::new(std::collections::HashMap::new()),
56 }),
57 }
58 }
59
60 pub fn register_job_type<J: Job + DeserializeOwned + 'static>(&self, name: &str) {
62 let factory = Box::new(|payload: serde_json::Value| {
63 let job: J =
64 serde_json::from_value(payload).expect("Job payload deserialization failed");
65 Box::new(job) as Box<dyn Job>
66 });
67
68 self.registry
69 .factories
70 .write()
71 .expect("Job registry RwLock poisoned")
72 .insert(name.to_string(), factory);
73 }
74
75 pub fn register_job_factory<F>(&self, name: &str, factory: F)
77 where
78 F: Fn(serde_json::Value) -> Box<dyn Job> + Send + Sync + 'static,
79 {
80 self.registry
81 .factories
82 .write()
83 .expect("Job registry RwLock poisoned")
84 .insert(name.to_string(), Box::new(factory));
85 }
86
87 pub async fn start(&self) {
88 let semaphore = Arc::new(Semaphore::new(self.config.max_concurrency));
89
90 info!(
91 "Worker pool started with concurrency {}",
92 self.config.max_concurrency
93 );
94
95 loop {
96 if semaphore.available_permits() > 0 {
97 match self.backend.dequeue().await {
98 Ok(Some(entry)) => {
99 let permit = semaphore
100 .clone()
101 .acquire_owned()
102 .await
103 .expect("Worker semaphore closed unexpectedly");
104 let backend = self.backend.clone();
105 let registry = self.registry.clone();
106
107 tokio::spawn(async move {
108 let job_opt = {
109 let factories = registry
110 .factories
111 .read()
112 .expect("Job registry RwLock poisoned");
113 factories
114 .get(&entry.job_type)
115 .map(|f| f(entry.payload.clone()))
116 };
117
118 if let Some(mut job) = job_opt {
119 info!("Processing job {} ({})", entry.id, entry.job_type);
120
121 let result = job.execute().await;
122
123 match result {
124 JobResult::Success => {
125 let _ = backend
126 .update_status(
127 entry.id,
128 JobStatus::Completed,
129 None,
130 None,
131 )
132 .await;
133 }
134 JobResult::Retry(e) => {
135 let delay = job.backoff_strategy().delay(entry.attempts);
137 let delay_secs = delay.as_secs();
138
139 info!(
140 job_id = %entry.id,
141 attempt = entry.attempts + 1,
142 delay_secs = delay_secs,
143 "Job failed, scheduling retry with backoff"
144 );
145
146 let _ = backend
147 .update_status(
148 entry.id,
149 JobStatus::Failed(entry.attempts + 1),
150 Some(e),
151 Some(delay_secs),
152 )
153 .await;
154 }
155 JobResult::Fatal(e) => {
156 let _ = backend
157 .update_status(
158 entry.id,
159 JobStatus::DeadLetter,
160 Some(e),
161 None,
162 )
163 .await;
164 }
165 }
166 } else {
167 warn!("No handler registered for job type: {}", entry.job_type);
168 let _ = backend
169 .update_status(
170 entry.id,
171 JobStatus::DeadLetter,
172 Some(format!("No handler for {}", entry.job_type)),
173 None,
174 )
175 .await;
176 }
177
178 drop(permit);
179 });
180 }
181 Ok(None) => {
182 tokio::time::sleep(self.config.poll_interval).await;
184 }
185 Err(e) => {
186 error!("Queue error: {}", e);
187 tokio::time::sleep(Duration::from_secs(1)).await;
188 }
189 }
190 } else {
191 tokio::time::sleep(Duration::from_millis(50)).await;
193 }
194 }
195 }
196}