1use std::collections::HashMap;
4use std::sync::Arc;
5use std::time::Duration;
6
7use async_trait::async_trait;
8use parking_lot::Mutex;
9use serde::{Deserialize, Serialize};
10use sqlx::PgPool;
11use uuid::Uuid;
12
13use crate::container::Container;
14use crate::Error;
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct QueuePayload {
18 pub id: Uuid,
19 pub job_type: String,
20 pub data: serde_json::Value,
21 pub attempts: i32,
22 pub max_attempts: i32,
23 pub queue: String,
24}
25
26pub type JobRunner = Arc<
27 dyn for<'a> Fn(
28 &'a Container,
29 &'a QueuePayload,
30 ) -> futures::future::BoxFuture<'a, Result<(), Error>>
31 + Send
32 + Sync,
33>;
34
35#[derive(Default, Clone)]
36pub struct JobRegistry {
37 runners: Arc<parking_lot::RwLock<HashMap<String, JobRunner>>>,
38}
39
40impl JobRegistry {
41 pub fn register<F>(&self, name: impl Into<String>, runner: F)
42 where
43 F: for<'a> Fn(
44 &'a Container,
45 &'a QueuePayload,
46 ) -> futures::future::BoxFuture<'a, Result<(), Error>>
47 + Send
48 + Sync
49 + 'static,
50 {
51 self.runners.write().insert(name.into(), Arc::new(runner));
52 }
53
54 pub fn get(&self, name: &str) -> Option<JobRunner> {
55 self.runners.read().get(name).cloned()
56 }
57}
58
59inventory::collect!(JobRegistration);
60
61pub struct JobRegistration {
62 pub name: &'static str,
63 pub runner: fn() -> JobRunner,
64}
65
66pub fn collect_inventory_registry() -> JobRegistry {
67 let registry = JobRegistry::default();
68 for reg in inventory::iter::<JobRegistration> {
69 let runner = (reg.runner)();
70 registry
71 .runners
72 .write()
73 .insert(reg.name.to_string(), runner);
74 }
75 registry
76}
77
78#[async_trait]
79pub trait QueueDriver: Send + Sync {
80 async fn push(&self, payload: QueuePayload) -> Result<(), Error>;
81 async fn pop(&self, queue: &str) -> Result<Option<QueuePayload>, Error>;
82 async fn fail(&self, payload: QueuePayload, error: String) -> Result<(), Error>;
83}
84
85#[derive(Clone)]
86pub struct QueueHandle {
87 driver: Arc<dyn QueueDriver>,
88 registry: JobRegistry,
89}
90
91impl QueueHandle {
92 pub fn new(driver: Arc<dyn QueueDriver>, registry: JobRegistry) -> Self {
93 Self { driver, registry }
94 }
95
96 pub fn in_memory(_pool: PgPool) -> Self {
100 Self::in_memory_no_pool()
101 }
102
103 pub fn in_memory_no_pool() -> Self {
106 Self {
107 driver: Arc::new(InMemoryDriver::default()),
108 registry: collect_inventory_registry(),
109 }
110 }
111
112 pub fn database(pool: PgPool) -> Self {
115 Self {
116 driver: Arc::new(DatabaseDriver { pool }),
117 registry: collect_inventory_registry(),
118 }
119 }
120
121 pub fn fake() -> (Self, Arc<Mutex<Vec<QueuePayload>>>) {
122 let log = Arc::new(Mutex::new(Vec::new()));
123 let driver = FakeDriver { log: log.clone() };
124 (
125 Self {
126 driver: Arc::new(driver),
127 registry: JobRegistry::default(),
128 },
129 log,
130 )
131 }
132
133 pub fn registry(&self) -> &JobRegistry {
134 &self.registry
135 }
136
137 pub async fn push(&self, payload: QueuePayload) -> Result<(), Error> {
138 self.driver.push(payload).await
139 }
140
141 pub async fn pop(&self, queue: &str) -> Result<Option<QueuePayload>, Error> {
142 self.driver.pop(queue).await
143 }
144
145 pub async fn fail(&self, payload: QueuePayload, error: String) -> Result<(), Error> {
146 self.driver.fail(payload, error).await
147 }
148}
149
150#[derive(Default)]
151struct InMemoryDriver {
152 queues: Mutex<HashMap<String, Vec<QueuePayload>>>,
153}
154
155#[async_trait]
156impl QueueDriver for InMemoryDriver {
157 async fn push(&self, payload: QueuePayload) -> Result<(), Error> {
158 self.queues
159 .lock()
160 .entry(payload.queue.clone())
161 .or_default()
162 .push(payload);
163 Ok(())
164 }
165 async fn pop(&self, queue: &str) -> Result<Option<QueuePayload>, Error> {
166 Ok(self.queues.lock().get_mut(queue).and_then(|v| v.pop()))
167 }
168 async fn fail(&self, payload: QueuePayload, error: String) -> Result<(), Error> {
169 tracing::error!(?payload, error, "job failed (in-memory)");
170 Ok(())
171 }
172}
173
174struct FakeDriver {
175 log: Arc<Mutex<Vec<QueuePayload>>>,
176}
177
178#[async_trait]
179impl QueueDriver for FakeDriver {
180 async fn push(&self, payload: QueuePayload) -> Result<(), Error> {
181 self.log.lock().push(payload);
182 Ok(())
183 }
184 async fn pop(&self, _queue: &str) -> Result<Option<QueuePayload>, Error> {
185 Ok(None)
186 }
187 async fn fail(&self, _: QueuePayload, _: String) -> Result<(), Error> {
188 Ok(())
189 }
190}
191
192pub struct DatabaseDriver {
193 pool: PgPool,
194}
195
196#[async_trait]
197impl QueueDriver for DatabaseDriver {
198 async fn push(&self, payload: QueuePayload) -> Result<(), Error> {
199 sqlx::query("INSERT INTO jobs (id, job_type, payload, attempts, max_attempts, queue, available_at) VALUES ($1, $2, $3, $4, $5, $6, NOW())")
200 .bind(payload.id)
201 .bind(&payload.job_type)
202 .bind(&payload.data)
203 .bind(payload.attempts)
204 .bind(payload.max_attempts)
205 .bind(&payload.queue)
206 .execute(&self.pool)
207 .await?;
208 Ok(())
209 }
210
211 async fn pop(&self, queue: &str) -> Result<Option<QueuePayload>, Error> {
212 let row: Option<(Uuid, String, serde_json::Value, i32, i32, String)> = sqlx::query_as(
213 r#"DELETE FROM jobs
214 WHERE id = (
215 SELECT id FROM jobs
216 WHERE queue = $1 AND available_at <= NOW()
217 ORDER BY available_at
218 LIMIT 1
219 FOR UPDATE SKIP LOCKED
220 )
221 RETURNING id, job_type, payload, attempts, max_attempts, queue"#,
222 )
223 .bind(queue)
224 .fetch_optional(&self.pool)
225 .await?;
226 Ok(row.map(
227 |(id, job_type, data, attempts, max_attempts, queue)| QueuePayload {
228 id,
229 job_type,
230 data,
231 attempts,
232 max_attempts,
233 queue,
234 },
235 ))
236 }
237
238 async fn fail(&self, payload: QueuePayload, error: String) -> Result<(), Error> {
239 sqlx::query("INSERT INTO failed_jobs (id, job_type, payload, error, failed_at) VALUES ($1, $2, $3, $4, NOW())")
240 .bind(payload.id)
241 .bind(&payload.job_type)
242 .bind(&payload.data)
243 .bind(error)
244 .execute(&self.pool)
245 .await?;
246 Ok(())
247 }
248}
249
250pub async fn run_worker(
252 container: Container,
253 queue: String,
254 shutdown: crate::shutdown::ShutdownHandle,
255) -> Result<(), Error> {
256 let handle = container.queue().clone();
257 let registry = handle.registry().clone();
258
259 tracing::info!(queue, "queue worker starting");
260
261 loop {
262 if shutdown.is_shutdown() {
263 tracing::info!("queue worker shutting down");
264 break;
265 }
266
267 let payload = match handle.pop(&queue).await? {
268 Some(p) => p,
269 None => {
270 tokio::select! {
271 _ = tokio::time::sleep(Duration::from_secs(1)) => continue,
272 _ = shutdown.wait() => break,
273 }
274 }
275 };
276
277 let runner = registry.get(&payload.job_type);
278 let Some(runner) = runner else {
279 tracing::error!(
280 job_type = %payload.job_type,
281 "no runner registered for job type"
282 );
283 handle.fail(payload, "no runner registered".into()).await?;
284 continue;
285 };
286
287 let mut payload_mut = payload.clone();
288 payload_mut.attempts += 1;
289
290 match runner(&container, &payload_mut).await {
291 Ok(()) => {
292 tracing::info!(job_type = %payload_mut.job_type, id = %payload_mut.id, "job complete");
293 }
294 Err(e) => {
295 tracing::warn!(error = ?e, attempts = payload_mut.attempts, "job failed");
296 if payload_mut.attempts >= payload_mut.max_attempts {
297 handle.fail(payload_mut, e.to_string()).await?;
298 } else {
299 let backoff =
300 Duration::from_secs(2u64.pow(payload_mut.attempts as u32).min(60));
301 tokio::time::sleep(backoff).await;
302 handle.push(payload_mut).await?;
303 }
304 }
305 }
306 }
307 Ok(())
308}
309
310pub async fn dispatch_payload(
312 container: &Container,
313 job_type: impl Into<String>,
314 data: serde_json::Value,
315) -> Result<(), Error> {
316 let payload = QueuePayload {
317 id: Uuid::new_v4(),
318 job_type: job_type.into(),
319 data,
320 attempts: 0,
321 max_attempts: 3,
322 queue: "default".into(),
323 };
324 container.queue().push(payload).await
325}