1use crate::{Error, Job, JobPayload, QueueConnection};
4use std::collections::HashMap;
5use std::future::Future;
6use std::pin::Pin;
7use std::sync::Arc;
8use std::time::Duration;
9use tokio::sync::Semaphore;
10use tracing::{debug, error, info, warn};
11
/// Settings for a queue worker: which queues to poll, how many jobs may run at
/// once, how long to idle when the queues are empty, and whether errors are fatal.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WorkerConfig {
    /// Queue names, checked in this order on every polling pass.
    pub queues: Vec<String>,
    /// Maximum number of jobs that may run concurrently.
    pub max_jobs: usize,
    /// Idle delay used after a pass in which every queue was empty.
    pub sleep_duration: Duration,
    /// When `true`, the worker's `run` returns the first error instead of logging it.
    pub stop_on_error: bool,
}
24
25impl Default for WorkerConfig {
26 fn default() -> Self {
27 Self {
28 queues: vec!["default".to_string()],
29 max_jobs: 10,
30 sleep_duration: Duration::from_secs(1),
31 stop_on_error: false,
32 }
33 }
34}
35
36impl WorkerConfig {
37 pub fn new(queues: Vec<String>) -> Self {
39 Self {
40 queues,
41 ..Default::default()
42 }
43 }
44
45 pub fn max_jobs(mut self, max: usize) -> Self {
47 self.max_jobs = max;
48 self
49 }
50}
51
52type JobHandler =
54 Arc<dyn Fn(String) -> Pin<Box<dyn Future<Output = Result<(), Error>> + Send>> + Send + Sync>;
55
/// Pulls job payloads from a `QueueConnection` and dispatches them to
/// registered handlers, with concurrency bounded by a semaphore.
pub struct Worker {
    /// Connection used to pop, release, fail and migrate jobs.
    connection: QueueConnection,
    /// Queue names, concurrency limit, idle delay and error policy.
    config: WorkerConfig,
    /// Handlers keyed by job type name (populated by `register`).
    handlers: HashMap<String, JobHandler>,
    /// One permit per concurrently running job; created with `config.max_jobs` permits.
    semaphore: Arc<Semaphore>,
    /// Signalled by `shutdown()` to make `run` stop.
    shutdown: Arc<tokio::sync::Notify>,
}
69
70impl Worker {
71 pub fn new(connection: QueueConnection, config: WorkerConfig) -> Self {
73 let semaphore = Arc::new(Semaphore::new(config.max_jobs));
74 Self {
75 connection,
76 config,
77 handlers: HashMap::new(),
78 semaphore,
79 shutdown: Arc::new(tokio::sync::Notify::new()),
80 }
81 }
82
83 pub fn register<J>(&mut self)
91 where
92 J: Job + serde::de::DeserializeOwned + 'static,
93 {
94 let type_name = std::any::type_name::<J>().to_string();
95
96 let handler: JobHandler = Arc::new(move |data: String| {
97 Box::pin(async move {
98 let job: J = serde_json::from_str(&data)
99 .map_err(|e| Error::DeserializationFailed(e.to_string()))?;
100 job.handle().await
101 })
102 });
103
104 self.handlers.insert(type_name, handler);
105 }
106
107 pub async fn run(&self) -> Result<(), Error> {
109 info!(
110 queues = ?self.config.queues,
111 max_jobs = self.config.max_jobs,
112 "Starting queue worker"
113 );
114
115 let conn = self.connection.clone();
117 let queues = self.config.queues.clone();
118 let shutdown = self.shutdown.clone();
119
120 tokio::spawn(async move {
121 loop {
122 tokio::select! {
123 _ = shutdown.notified() => break,
124 _ = tokio::time::sleep(Duration::from_secs(1)) => {
125 for queue in &queues {
126 if let Err(e) = conn.migrate_delayed(queue).await {
127 error!(queue = queue, error = %e, "Failed to migrate delayed jobs");
128 }
129 }
130 }
131 }
132 }
133 });
134
135 loop {
137 tokio::select! {
138 _ = self.shutdown.notified() => {
139 info!("Worker shutting down");
140 info!("Waiting for in-flight jobs to complete");
142 let _ = self.semaphore.acquire_many(self.config.max_jobs as u32).await;
143 return Ok(());
144 }
145 result = self.process_next() => {
146 if let Err(e) = result {
147 error!(error = %e, "Error processing job");
148 if self.config.stop_on_error {
149 return Err(e);
150 }
151 }
152 }
153 }
154 }
155 }
156
157 async fn process_next(&self) -> Result<(), Error> {
159 for queue in &self.config.queues {
161 if let Some(payload) = self.connection.pop_nowait(queue).await? {
162 self.process_job(payload).await?;
163 return Ok(());
164 }
165 }
166
167 tokio::time::sleep(self.config.sleep_duration).await;
169 Ok(())
170 }
171
172 async fn process_job(&self, payload: JobPayload) -> Result<(), Error> {
174 let permit = self.semaphore.clone().acquire_owned().await.unwrap();
175 let connection = self.connection.clone();
176 let handlers = self.handlers.clone();
177 let job_type = payload.job_type.clone();
178 let job_id = payload.id;
179
180 tokio::spawn(async move {
181 let _permit = permit; debug!(job_id = %job_id, job_type = &job_type, "Processing job");
184
185 let handler = match handlers.get(&job_type) {
186 Some(h) => h,
187 None => {
188 warn!(job_type = &job_type, "No handler registered for job type");
189 return;
190 }
191 };
192
193 match handler(payload.data.clone()).await {
194 Ok(()) => {
195 info!(job_id = %job_id, job_type = &job_type, "Job completed successfully");
196 }
197 Err(e) => {
198 error!(job_id = %job_id, job_type = &job_type, error = %e, "Job failed");
199
200 if payload.has_exceeded_retries() {
201 warn!(job_id = %job_id, "Job exceeded max retries, moving to failed queue");
202 if let Err(e) = connection.fail(payload, &e).await {
203 error!(error = %e, "Failed to move job to failed queue");
204 }
205 } else {
206 let delay = Duration::from_secs(2u64.pow(payload.attempts));
208 if let Err(e) = connection.release(payload, delay).await {
209 error!(error = %e, "Failed to release job for retry");
210 }
211 }
212 }
213 }
214 });
215
216 Ok(())
217 }
218
219 pub fn shutdown(&self) {
221 self.shutdown.notify_waiters();
222 }
223}
224
225impl Clone for Worker {
227 fn clone(&self) -> Self {
228 Self {
229 connection: self.connection.clone(),
230 config: self.config.clone(),
231 handlers: HashMap::new(), semaphore: self.semaphore.clone(),
233 shutdown: self.shutdown.clone(),
234 }
235 }
236}