1use std::collections::HashMap;
4use std::sync::Arc;
5use std::time::Duration;
6
7use async_trait::async_trait;
8use parking_lot::Mutex;
9use serde::{Deserialize, Serialize};
10use sqlx::PgPool;
11use uuid::Uuid;
12
13use crate::container::Container;
14use crate::Error;
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct QueuePayload {
18 pub id: Uuid,
19 pub job_type: String,
20 pub data: serde_json::Value,
21 pub attempts: i32,
22 pub max_attempts: i32,
23 pub queue: String,
24}
25
26pub type JobRunner = Arc<
27 dyn for<'a> Fn(
28 &'a Container,
29 &'a QueuePayload,
30 )
31 -> futures::future::BoxFuture<'a, Result<(), Error>>
32 + Send
33 + Sync,
34>;
35
36#[derive(Default, Clone)]
37pub struct JobRegistry {
38 runners: Arc<parking_lot::RwLock<HashMap<String, JobRunner>>>,
39}
40
41impl JobRegistry {
42 pub fn register<F>(&self, name: impl Into<String>, runner: F)
43 where
44 F: for<'a> Fn(
45 &'a Container,
46 &'a QueuePayload,
47 )
48 -> futures::future::BoxFuture<'a, Result<(), Error>>
49 + Send
50 + Sync
51 + 'static,
52 {
53 self.runners.write().insert(name.into(), Arc::new(runner));
54 }
55
56 pub fn get(&self, name: &str) -> Option<JobRunner> {
57 self.runners.read().get(name).cloned()
58 }
59}
60
61inventory::collect!(JobRegistration);
62
63pub struct JobRegistration {
64 pub name: &'static str,
65 pub runner: fn() -> JobRunner,
66}
67
68pub fn collect_inventory_registry() -> JobRegistry {
69 let registry = JobRegistry::default();
70 for reg in inventory::iter::<JobRegistration> {
71 let runner = (reg.runner)();
72 registry.runners.write().insert(reg.name.to_string(), runner);
73 }
74 registry
75}
76
77#[async_trait]
78pub trait QueueDriver: Send + Sync {
79 async fn push(&self, payload: QueuePayload) -> Result<(), Error>;
80 async fn pop(&self, queue: &str) -> Result<Option<QueuePayload>, Error>;
81 async fn fail(&self, payload: QueuePayload, error: String) -> Result<(), Error>;
82}
83
84#[derive(Clone)]
85pub struct QueueHandle {
86 driver: Arc<dyn QueueDriver>,
87 registry: JobRegistry,
88}
89
90impl QueueHandle {
91 pub fn new(driver: Arc<dyn QueueDriver>, registry: JobRegistry) -> Self {
92 Self { driver, registry }
93 }
94
95 pub fn in_memory(_pool: PgPool) -> Self {
99 Self::in_memory_no_pool()
100 }
101
102 pub fn in_memory_no_pool() -> Self {
105 Self {
106 driver: Arc::new(InMemoryDriver::default()),
107 registry: collect_inventory_registry(),
108 }
109 }
110
111 pub fn database(pool: PgPool) -> Self {
114 Self {
115 driver: Arc::new(DatabaseDriver { pool }),
116 registry: collect_inventory_registry(),
117 }
118 }
119
120 pub fn fake() -> (Self, Arc<Mutex<Vec<QueuePayload>>>) {
121 let log = Arc::new(Mutex::new(Vec::new()));
122 let driver = FakeDriver { log: log.clone() };
123 (
124 Self {
125 driver: Arc::new(driver),
126 registry: JobRegistry::default(),
127 },
128 log,
129 )
130 }
131
132 pub fn registry(&self) -> &JobRegistry {
133 &self.registry
134 }
135
136 pub async fn push(&self, payload: QueuePayload) -> Result<(), Error> {
137 self.driver.push(payload).await
138 }
139
140 pub async fn pop(&self, queue: &str) -> Result<Option<QueuePayload>, Error> {
141 self.driver.pop(queue).await
142 }
143
144 pub async fn fail(&self, payload: QueuePayload, error: String) -> Result<(), Error> {
145 self.driver.fail(payload, error).await
146 }
147}
148
149#[derive(Default)]
150struct InMemoryDriver {
151 queues: Mutex<HashMap<String, Vec<QueuePayload>>>,
152}
153
154#[async_trait]
155impl QueueDriver for InMemoryDriver {
156 async fn push(&self, payload: QueuePayload) -> Result<(), Error> {
157 self.queues
158 .lock()
159 .entry(payload.queue.clone())
160 .or_default()
161 .push(payload);
162 Ok(())
163 }
164 async fn pop(&self, queue: &str) -> Result<Option<QueuePayload>, Error> {
165 Ok(self.queues.lock().get_mut(queue).and_then(|v| v.pop()))
166 }
167 async fn fail(&self, payload: QueuePayload, error: String) -> Result<(), Error> {
168 tracing::error!(?payload, error, "job failed (in-memory)");
169 Ok(())
170 }
171}
172
173struct FakeDriver {
174 log: Arc<Mutex<Vec<QueuePayload>>>,
175}
176
177#[async_trait]
178impl QueueDriver for FakeDriver {
179 async fn push(&self, payload: QueuePayload) -> Result<(), Error> {
180 self.log.lock().push(payload);
181 Ok(())
182 }
183 async fn pop(&self, _queue: &str) -> Result<Option<QueuePayload>, Error> {
184 Ok(None)
185 }
186 async fn fail(&self, _: QueuePayload, _: String) -> Result<(), Error> {
187 Ok(())
188 }
189}
190
191pub struct DatabaseDriver {
192 pool: PgPool,
193}
194
195#[async_trait]
196impl QueueDriver for DatabaseDriver {
197 async fn push(&self, payload: QueuePayload) -> Result<(), Error> {
198 sqlx::query("INSERT INTO jobs (id, job_type, payload, attempts, max_attempts, queue, available_at) VALUES ($1, $2, $3, $4, $5, $6, NOW())")
199 .bind(payload.id)
200 .bind(&payload.job_type)
201 .bind(&payload.data)
202 .bind(payload.attempts)
203 .bind(payload.max_attempts)
204 .bind(&payload.queue)
205 .execute(&self.pool)
206 .await?;
207 Ok(())
208 }
209
210 async fn pop(&self, queue: &str) -> Result<Option<QueuePayload>, Error> {
211 let row: Option<(Uuid, String, serde_json::Value, i32, i32, String)> = sqlx::query_as(
212 r#"DELETE FROM jobs
213 WHERE id = (
214 SELECT id FROM jobs
215 WHERE queue = $1 AND available_at <= NOW()
216 ORDER BY available_at
217 LIMIT 1
218 FOR UPDATE SKIP LOCKED
219 )
220 RETURNING id, job_type, payload, attempts, max_attempts, queue"#,
221 )
222 .bind(queue)
223 .fetch_optional(&self.pool)
224 .await?;
225 Ok(row.map(|(id, job_type, data, attempts, max_attempts, queue)| QueuePayload {
226 id,
227 job_type,
228 data,
229 attempts,
230 max_attempts,
231 queue,
232 }))
233 }
234
235 async fn fail(&self, payload: QueuePayload, error: String) -> Result<(), Error> {
236 sqlx::query("INSERT INTO failed_jobs (id, job_type, payload, error, failed_at) VALUES ($1, $2, $3, $4, NOW())")
237 .bind(payload.id)
238 .bind(&payload.job_type)
239 .bind(&payload.data)
240 .bind(error)
241 .execute(&self.pool)
242 .await?;
243 Ok(())
244 }
245}
246
247pub async fn run_worker(
249 container: Container,
250 queue: String,
251 shutdown: crate::shutdown::ShutdownHandle,
252) -> Result<(), Error> {
253 let handle = container.queue().clone();
254 let registry = handle.registry().clone();
255
256 tracing::info!(queue, "queue worker starting");
257
258 loop {
259 if shutdown.is_shutdown() {
260 tracing::info!("queue worker shutting down");
261 break;
262 }
263
264 let payload = match handle.pop(&queue).await? {
265 Some(p) => p,
266 None => {
267 tokio::select! {
268 _ = tokio::time::sleep(Duration::from_secs(1)) => continue,
269 _ = shutdown.wait() => break,
270 }
271 }
272 };
273
274 let runner = registry.get(&payload.job_type);
275 let Some(runner) = runner else {
276 tracing::error!(
277 job_type = %payload.job_type,
278 "no runner registered for job type"
279 );
280 handle.fail(payload, "no runner registered".into()).await?;
281 continue;
282 };
283
284 let mut payload_mut = payload.clone();
285 payload_mut.attempts += 1;
286
287 match runner(&container, &payload_mut).await {
288 Ok(()) => {
289 tracing::info!(job_type = %payload_mut.job_type, id = %payload_mut.id, "job complete");
290 }
291 Err(e) => {
292 tracing::warn!(error = ?e, attempts = payload_mut.attempts, "job failed");
293 if payload_mut.attempts >= payload_mut.max_attempts {
294 handle.fail(payload_mut, e.to_string()).await?;
295 } else {
296 let backoff = Duration::from_secs(2u64.pow(payload_mut.attempts as u32).min(60));
297 tokio::time::sleep(backoff).await;
298 handle.push(payload_mut).await?;
299 }
300 }
301 }
302 }
303 Ok(())
304}
305
306pub async fn dispatch_payload(
308 container: &Container,
309 job_type: impl Into<String>,
310 data: serde_json::Value,
311) -> Result<(), Error> {
312 let payload = QueuePayload {
313 id: Uuid::new_v4(),
314 job_type: job_type.into(),
315 data,
316 attempts: 0,
317 max_attempts: 3,
318 queue: "default".into(),
319 };
320 container.queue().push(payload).await
321}