1use std::collections::HashMap;
4use std::sync::Arc;
5use std::time::Duration;
6
7use async_trait::async_trait;
8use parking_lot::Mutex;
9use serde::{Deserialize, Serialize};
10use sqlx::PgPool;
11use uuid::Uuid;
12
13use crate::container::Container;
14use crate::Error;
15
/// A single queued job, as stored by a [`QueueDriver`] and handed to a [`JobRunner`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueuePayload {
    /// Job identifier; [`dispatch_payload`] assigns a random v4 UUID.
    pub id: Uuid,
    /// Name used to look up the job's runner in the [`JobRegistry`].
    pub job_type: String,
    /// Job-specific JSON data; the drivers store it without interpreting it.
    pub data: serde_json::Value,
    /// Number of attempts made so far; [`run_worker`] increments it before each run.
    pub attempts: i32,
    /// After a failed run, once `attempts` reaches this value the job is passed
    /// to [`QueueHandle::fail`] instead of being re-queued.
    pub max_attempts: i32,
    /// Name of the queue the job is pushed onto and popped from.
    pub queue: String,
}
25
/// Type-erased async handler for one job type.
///
/// It is called with the application [`Container`] and the job's [`QueuePayload`]
/// and returns a boxed future that may borrow both for `'a`. The `Send + Sync`
/// bounds allow one runner to be shared across threads, and the `Arc` makes
/// cloning it cheap.
pub type JobRunner = Arc<
    dyn for<'a> Fn(
            &'a Container,
            &'a QueuePayload,
        ) -> futures::future::BoxFuture<'a, Result<(), Error>>
        + Send
        + Sync,
>;
35
/// Mapping from job type name to the [`JobRunner`] that executes it.
///
/// The map sits behind `Arc<RwLock<..>>`, so clones share the same entries:
/// a runner registered through any clone is visible to all of them.
#[derive(Default, Clone)]
pub struct JobRegistry {
    // Shared, lock-protected name -> runner table.
    runners: Arc<parking_lot::RwLock<HashMap<String, JobRunner>>>,
}
40
41impl JobRegistry {
42 pub fn register<F>(&self, name: impl Into<String>, runner: F)
43 where
44 F: for<'a> Fn(
45 &'a Container,
46 &'a QueuePayload,
47 )
48 -> futures::future::BoxFuture<'a, Result<(), Error>>
49 + Send
50 + Sync
51 + 'static,
52 {
53 self.runners.write().insert(name.into(), Arc::new(runner));
54 }
55
56 pub fn get(&self, name: &str) -> Option<JobRunner> {
57 self.runners.read().get(name).cloned()
58 }
59}
60
// Makes `JobRegistration` collectable through the `inventory` crate;
// `collect_inventory_registry` iterates over every collected instance.
inventory::collect!(JobRegistration);

/// A statically declared job runner, gathered by [`collect_inventory_registry`].
///
/// NOTE(review): presumably instances are submitted with `inventory::submit!`
/// elsewhere in the codebase — confirm.
pub struct JobRegistration {
    /// Job type name the runner is stored under; the worker matches it against
    /// [`QueuePayload::job_type`].
    pub name: &'static str,
    /// Factory producing the runner; called once per
    /// [`collect_inventory_registry`] call.
    pub runner: fn() -> JobRunner,
}
67
68pub fn collect_inventory_registry() -> JobRegistry {
69 let registry = JobRegistry::default();
70 for reg in inventory::iter::<JobRegistration> {
71 let runner = (reg.runner)();
72 registry.runners.write().insert(reg.name.to_string(), runner);
73 }
74 registry
75}
76
/// Storage backend for queued jobs, used through a [`QueueHandle`].
#[async_trait]
pub trait QueueDriver: Send + Sync {
    /// Enqueues `payload`. The in-memory and database drivers file it under
    /// `payload.queue`.
    async fn push(&self, payload: QueuePayload) -> Result<(), Error>;
    /// Takes the next job available on `queue`, or returns `None` when there is
    /// nothing to run.
    async fn pop(&self, queue: &str) -> Result<Option<QueuePayload>, Error>;
    /// Records a job that will not be retried, together with the error text
    /// explaining why.
    async fn fail(&self, payload: QueuePayload, error: String) -> Result<(), Error>;
}
83
/// Handle for enqueueing and dequeueing jobs.
///
/// It pairs a [`QueueDriver`] with the [`JobRegistry`] that workers consult.
/// Both fields are reference-counted, so clones are cheap and share state.
#[derive(Clone)]
pub struct QueueHandle {
    // Storage backend every push/pop/fail is forwarded to.
    driver: Arc<dyn QueueDriver>,
    // Runners available to workers built from this handle.
    registry: JobRegistry,
}
89
90impl QueueHandle {
91 pub fn new(driver: Arc<dyn QueueDriver>, registry: JobRegistry) -> Self {
92 Self { driver, registry }
93 }
94
95 pub fn in_memory(_pool: PgPool) -> Self {
96 Self {
97 driver: Arc::new(InMemoryDriver::default()),
98 registry: collect_inventory_registry(),
99 }
100 }
101
102 pub fn database(pool: PgPool) -> Self {
103 Self {
104 driver: Arc::new(DatabaseDriver { pool }),
105 registry: collect_inventory_registry(),
106 }
107 }
108
109 pub fn fake() -> (Self, Arc<Mutex<Vec<QueuePayload>>>) {
110 let log = Arc::new(Mutex::new(Vec::new()));
111 let driver = FakeDriver { log: log.clone() };
112 (
113 Self {
114 driver: Arc::new(driver),
115 registry: JobRegistry::default(),
116 },
117 log,
118 )
119 }
120
121 pub fn registry(&self) -> &JobRegistry {
122 &self.registry
123 }
124
125 pub async fn push(&self, payload: QueuePayload) -> Result<(), Error> {
126 self.driver.push(payload).await
127 }
128
129 pub async fn pop(&self, queue: &str) -> Result<Option<QueuePayload>, Error> {
130 self.driver.pop(queue).await
131 }
132
133 pub async fn fail(&self, payload: QueuePayload, error: String) -> Result<(), Error> {
134 self.driver.fail(payload, error).await
135 }
136}
137
138#[derive(Default)]
139struct InMemoryDriver {
140 queues: Mutex<HashMap<String, Vec<QueuePayload>>>,
141}
142
143#[async_trait]
144impl QueueDriver for InMemoryDriver {
145 async fn push(&self, payload: QueuePayload) -> Result<(), Error> {
146 self.queues
147 .lock()
148 .entry(payload.queue.clone())
149 .or_default()
150 .push(payload);
151 Ok(())
152 }
153 async fn pop(&self, queue: &str) -> Result<Option<QueuePayload>, Error> {
154 Ok(self.queues.lock().get_mut(queue).and_then(|v| v.pop()))
155 }
156 async fn fail(&self, payload: QueuePayload, error: String) -> Result<(), Error> {
157 tracing::error!(?payload, error, "job failed (in-memory)");
158 Ok(())
159 }
160}
161
/// Test double: `push` appends to `log`, `pop` never yields a job, and `fail`
/// does nothing.
struct FakeDriver {
    /// Shared with the caller of `QueueHandle::fake` so tests can inspect what
    /// was pushed.
    log: Arc<Mutex<Vec<QueuePayload>>>,
}
165
166#[async_trait]
167impl QueueDriver for FakeDriver {
168 async fn push(&self, payload: QueuePayload) -> Result<(), Error> {
169 self.log.lock().push(payload);
170 Ok(())
171 }
172 async fn pop(&self, _queue: &str) -> Result<Option<QueuePayload>, Error> {
173 Ok(None)
174 }
175 async fn fail(&self, _: QueuePayload, _: String) -> Result<(), Error> {
176 Ok(())
177 }
178}
179
/// Postgres-backed [`QueueDriver`].
///
/// Pending jobs are stored in a `jobs` table and failures are recorded in a
/// `failed_jobs` table.
pub struct DatabaseDriver {
    // Connection pool that every driver query runs against.
    pool: PgPool,
}
183
184#[async_trait]
185impl QueueDriver for DatabaseDriver {
186 async fn push(&self, payload: QueuePayload) -> Result<(), Error> {
187 sqlx::query("INSERT INTO jobs (id, job_type, payload, attempts, max_attempts, queue, available_at) VALUES ($1, $2, $3, $4, $5, $6, NOW())")
188 .bind(payload.id)
189 .bind(&payload.job_type)
190 .bind(&payload.data)
191 .bind(payload.attempts)
192 .bind(payload.max_attempts)
193 .bind(&payload.queue)
194 .execute(&self.pool)
195 .await?;
196 Ok(())
197 }
198
199 async fn pop(&self, queue: &str) -> Result<Option<QueuePayload>, Error> {
200 let row: Option<(Uuid, String, serde_json::Value, i32, i32, String)> = sqlx::query_as(
201 r#"DELETE FROM jobs
202 WHERE id = (
203 SELECT id FROM jobs
204 WHERE queue = $1 AND available_at <= NOW()
205 ORDER BY available_at
206 LIMIT 1
207 FOR UPDATE SKIP LOCKED
208 )
209 RETURNING id, job_type, payload, attempts, max_attempts, queue"#,
210 )
211 .bind(queue)
212 .fetch_optional(&self.pool)
213 .await?;
214 Ok(row.map(|(id, job_type, data, attempts, max_attempts, queue)| QueuePayload {
215 id,
216 job_type,
217 data,
218 attempts,
219 max_attempts,
220 queue,
221 }))
222 }
223
224 async fn fail(&self, payload: QueuePayload, error: String) -> Result<(), Error> {
225 sqlx::query("INSERT INTO failed_jobs (id, job_type, payload, error, failed_at) VALUES ($1, $2, $3, $4, NOW())")
226 .bind(payload.id)
227 .bind(&payload.job_type)
228 .bind(&payload.data)
229 .bind(error)
230 .execute(&self.pool)
231 .await?;
232 Ok(())
233 }
234}
235
236pub async fn run_worker(
238 container: Container,
239 queue: String,
240 shutdown: crate::shutdown::ShutdownHandle,
241) -> Result<(), Error> {
242 let handle = container.queue().clone();
243 let registry = handle.registry().clone();
244
245 tracing::info!(queue, "queue worker starting");
246
247 loop {
248 if shutdown.is_shutdown() {
249 tracing::info!("queue worker shutting down");
250 break;
251 }
252
253 let payload = match handle.pop(&queue).await? {
254 Some(p) => p,
255 None => {
256 tokio::select! {
257 _ = tokio::time::sleep(Duration::from_secs(1)) => continue,
258 _ = shutdown.wait() => break,
259 }
260 }
261 };
262
263 let runner = registry.get(&payload.job_type);
264 let Some(runner) = runner else {
265 tracing::error!(
266 job_type = %payload.job_type,
267 "no runner registered for job type"
268 );
269 handle.fail(payload, "no runner registered".into()).await?;
270 continue;
271 };
272
273 let mut payload_mut = payload.clone();
274 payload_mut.attempts += 1;
275
276 match runner(&container, &payload_mut).await {
277 Ok(()) => {
278 tracing::info!(job_type = %payload_mut.job_type, id = %payload_mut.id, "job complete");
279 }
280 Err(e) => {
281 tracing::warn!(error = ?e, attempts = payload_mut.attempts, "job failed");
282 if payload_mut.attempts >= payload_mut.max_attempts {
283 handle.fail(payload_mut, e.to_string()).await?;
284 } else {
285 let backoff = Duration::from_secs(2u64.pow(payload_mut.attempts as u32).min(60));
286 tokio::time::sleep(backoff).await;
287 handle.push(payload_mut).await?;
288 }
289 }
290 }
291 }
292 Ok(())
293}
294
295pub async fn dispatch_payload(
297 container: &Container,
298 job_type: impl Into<String>,
299 data: serde_json::Value,
300) -> Result<(), Error> {
301 let payload = QueuePayload {
302 id: Uuid::new_v4(),
303 job_type: job_type.into(),
304 data,
305 attempts: 0,
306 max_attempts: 3,
307 queue: "default".into(),
308 };
309 container.queue().push(payload).await
310}