Skip to main content

anvil_core/
queue.rs

1//! Queue subsystem. Jobs dispatched as serialized payloads; workers deserialize and run.
2
use std::collections::{HashMap, VecDeque};
use std::sync::Arc;
use std::time::Duration;

use async_trait::async_trait;
use parking_lot::Mutex;
use serde::{Deserialize, Serialize};
use sqlx::PgPool;
use uuid::Uuid;

use crate::container::Container;
use crate::Error;
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct QueuePayload {
18    pub id: Uuid,
19    pub job_type: String,
20    pub data: serde_json::Value,
21    pub attempts: i32,
22    pub max_attempts: i32,
23    pub queue: String,
24}
25
26pub type JobRunner = Arc<
27    dyn for<'a> Fn(
28            &'a Container,
29            &'a QueuePayload,
30        )
31            -> futures::future::BoxFuture<'a, Result<(), Error>>
32        + Send
33        + Sync,
34>;
35
36#[derive(Default, Clone)]
37pub struct JobRegistry {
38    runners: Arc<parking_lot::RwLock<HashMap<String, JobRunner>>>,
39}
40
41impl JobRegistry {
42    pub fn register<F>(&self, name: impl Into<String>, runner: F)
43    where
44        F: for<'a> Fn(
45                &'a Container,
46                &'a QueuePayload,
47            )
48                -> futures::future::BoxFuture<'a, Result<(), Error>>
49            + Send
50            + Sync
51            + 'static,
52    {
53        self.runners.write().insert(name.into(), Arc::new(runner));
54    }
55
56    pub fn get(&self, name: &str) -> Option<JobRunner> {
57        self.runners.read().get(name).cloned()
58    }
59}
60
61inventory::collect!(JobRegistration);
62
63pub struct JobRegistration {
64    pub name: &'static str,
65    pub runner: fn() -> JobRunner,
66}
67
68pub fn collect_inventory_registry() -> JobRegistry {
69    let registry = JobRegistry::default();
70    for reg in inventory::iter::<JobRegistration> {
71        let runner = (reg.runner)();
72        registry.runners.write().insert(reg.name.to_string(), runner);
73    }
74    registry
75}
76
77#[async_trait]
78pub trait QueueDriver: Send + Sync {
79    async fn push(&self, payload: QueuePayload) -> Result<(), Error>;
80    async fn pop(&self, queue: &str) -> Result<Option<QueuePayload>, Error>;
81    async fn fail(&self, payload: QueuePayload, error: String) -> Result<(), Error>;
82}
83
84#[derive(Clone)]
85pub struct QueueHandle {
86    driver: Arc<dyn QueueDriver>,
87    registry: JobRegistry,
88}
89
90impl QueueHandle {
91    pub fn new(driver: Arc<dyn QueueDriver>, registry: JobRegistry) -> Self {
92        Self { driver, registry }
93    }
94
95    pub fn in_memory(_pool: PgPool) -> Self {
96        Self {
97            driver: Arc::new(InMemoryDriver::default()),
98            registry: collect_inventory_registry(),
99        }
100    }
101
102    pub fn database(pool: PgPool) -> Self {
103        Self {
104            driver: Arc::new(DatabaseDriver { pool }),
105            registry: collect_inventory_registry(),
106        }
107    }
108
109    pub fn fake() -> (Self, Arc<Mutex<Vec<QueuePayload>>>) {
110        let log = Arc::new(Mutex::new(Vec::new()));
111        let driver = FakeDriver { log: log.clone() };
112        (
113            Self {
114                driver: Arc::new(driver),
115                registry: JobRegistry::default(),
116            },
117            log,
118        )
119    }
120
121    pub fn registry(&self) -> &JobRegistry {
122        &self.registry
123    }
124
125    pub async fn push(&self, payload: QueuePayload) -> Result<(), Error> {
126        self.driver.push(payload).await
127    }
128
129    pub async fn pop(&self, queue: &str) -> Result<Option<QueuePayload>, Error> {
130        self.driver.pop(queue).await
131    }
132
133    pub async fn fail(&self, payload: QueuePayload, error: String) -> Result<(), Error> {
134        self.driver.fail(payload, error).await
135    }
136}
137
138#[derive(Default)]
139struct InMemoryDriver {
140    queues: Mutex<HashMap<String, Vec<QueuePayload>>>,
141}
142
143#[async_trait]
144impl QueueDriver for InMemoryDriver {
145    async fn push(&self, payload: QueuePayload) -> Result<(), Error> {
146        self.queues
147            .lock()
148            .entry(payload.queue.clone())
149            .or_default()
150            .push(payload);
151        Ok(())
152    }
153    async fn pop(&self, queue: &str) -> Result<Option<QueuePayload>, Error> {
154        Ok(self.queues.lock().get_mut(queue).and_then(|v| v.pop()))
155    }
156    async fn fail(&self, payload: QueuePayload, error: String) -> Result<(), Error> {
157        tracing::error!(?payload, error, "job failed (in-memory)");
158        Ok(())
159    }
160}
161
162struct FakeDriver {
163    log: Arc<Mutex<Vec<QueuePayload>>>,
164}
165
166#[async_trait]
167impl QueueDriver for FakeDriver {
168    async fn push(&self, payload: QueuePayload) -> Result<(), Error> {
169        self.log.lock().push(payload);
170        Ok(())
171    }
172    async fn pop(&self, _queue: &str) -> Result<Option<QueuePayload>, Error> {
173        Ok(None)
174    }
175    async fn fail(&self, _: QueuePayload, _: String) -> Result<(), Error> {
176        Ok(())
177    }
178}
179
180pub struct DatabaseDriver {
181    pool: PgPool,
182}
183
184#[async_trait]
185impl QueueDriver for DatabaseDriver {
186    async fn push(&self, payload: QueuePayload) -> Result<(), Error> {
187        sqlx::query("INSERT INTO jobs (id, job_type, payload, attempts, max_attempts, queue, available_at) VALUES ($1, $2, $3, $4, $5, $6, NOW())")
188            .bind(payload.id)
189            .bind(&payload.job_type)
190            .bind(&payload.data)
191            .bind(payload.attempts)
192            .bind(payload.max_attempts)
193            .bind(&payload.queue)
194            .execute(&self.pool)
195            .await?;
196        Ok(())
197    }
198
199    async fn pop(&self, queue: &str) -> Result<Option<QueuePayload>, Error> {
200        let row: Option<(Uuid, String, serde_json::Value, i32, i32, String)> = sqlx::query_as(
201            r#"DELETE FROM jobs
202               WHERE id = (
203                   SELECT id FROM jobs
204                   WHERE queue = $1 AND available_at <= NOW()
205                   ORDER BY available_at
206                   LIMIT 1
207                   FOR UPDATE SKIP LOCKED
208               )
209               RETURNING id, job_type, payload, attempts, max_attempts, queue"#,
210        )
211        .bind(queue)
212        .fetch_optional(&self.pool)
213        .await?;
214        Ok(row.map(|(id, job_type, data, attempts, max_attempts, queue)| QueuePayload {
215            id,
216            job_type,
217            data,
218            attempts,
219            max_attempts,
220            queue,
221        }))
222    }
223
224    async fn fail(&self, payload: QueuePayload, error: String) -> Result<(), Error> {
225        sqlx::query("INSERT INTO failed_jobs (id, job_type, payload, error, failed_at) VALUES ($1, $2, $3, $4, NOW())")
226            .bind(payload.id)
227            .bind(&payload.job_type)
228            .bind(&payload.data)
229            .bind(error)
230            .execute(&self.pool)
231            .await?;
232        Ok(())
233    }
234}
235
236/// Run the queue worker loop: pop a job, look up its runner, run it, retry on failure.
237pub async fn run_worker(
238    container: Container,
239    queue: String,
240    shutdown: crate::shutdown::ShutdownHandle,
241) -> Result<(), Error> {
242    let handle = container.queue().clone();
243    let registry = handle.registry().clone();
244
245    tracing::info!(queue, "queue worker starting");
246
247    loop {
248        if shutdown.is_shutdown() {
249            tracing::info!("queue worker shutting down");
250            break;
251        }
252
253        let payload = match handle.pop(&queue).await? {
254            Some(p) => p,
255            None => {
256                tokio::select! {
257                    _ = tokio::time::sleep(Duration::from_secs(1)) => continue,
258                    _ = shutdown.wait() => break,
259                }
260            }
261        };
262
263        let runner = registry.get(&payload.job_type);
264        let Some(runner) = runner else {
265            tracing::error!(
266                job_type = %payload.job_type,
267                "no runner registered for job type"
268            );
269            handle.fail(payload, "no runner registered".into()).await?;
270            continue;
271        };
272
273        let mut payload_mut = payload.clone();
274        payload_mut.attempts += 1;
275
276        match runner(&container, &payload_mut).await {
277            Ok(()) => {
278                tracing::info!(job_type = %payload_mut.job_type, id = %payload_mut.id, "job complete");
279            }
280            Err(e) => {
281                tracing::warn!(error = ?e, attempts = payload_mut.attempts, "job failed");
282                if payload_mut.attempts >= payload_mut.max_attempts {
283                    handle.fail(payload_mut, e.to_string()).await?;
284                } else {
285                    let backoff = Duration::from_secs(2u64.pow(payload_mut.attempts as u32).min(60));
286                    tokio::time::sleep(backoff).await;
287                    handle.push(payload_mut).await?;
288                }
289            }
290        }
291    }
292    Ok(())
293}
294
295/// Push a job onto the configured queue (helper for the `Job::dispatch().await?` form).
296pub async fn dispatch_payload(
297    container: &Container,
298    job_type: impl Into<String>,
299    data: serde_json::Value,
300) -> Result<(), Error> {
301    let payload = QueuePayload {
302        id: Uuid::new_v4(),
303        job_type: job_type.into(),
304        data,
305        attempts: 0,
306        max_attempts: 3,
307        queue: "default".into(),
308    };
309    container.queue().push(payload).await
310}