Skip to main content

rustvello_postgres/
broker.rs

1//! PostgreSQL-backed [`Broker`] implementation.
2
3use std::sync::Arc;
4
5use async_trait::async_trait;
6
7use rustvello_core::broker::Broker;
8use rustvello_core::error::RustvelloResult;
9use rustvello_proto::identifiers::{InvocationId, TaskId};
10
11use crate::db::{pg_err, Database};
12
13/// PostgreSQL-backed broker implementation.
14///
15/// Persists the queue to a PostgreSQL database, suitable for multi-node deployments.
16pub struct PostgresBroker {
17    db: Arc<Database>,
18}
19
20impl PostgresBroker {
21    pub fn new(db: Arc<Database>) -> Self {
22        Self { db }
23    }
24}
25
26#[async_trait]
27impl Broker for PostgresBroker {
28    async fn route_invocation(&self, invocation_id: &InvocationId) -> RustvelloResult<()> {
29        let client = self.db.conn().await?;
30        client
31            .execute(
32                "INSERT INTO broker_queue (invocation_id) VALUES ($1)",
33                &[&invocation_id.as_str()],
34            )
35            .await
36            .map_err(pg_err)?;
37        Ok(())
38    }
39
40    async fn retrieve_invocation(
41        &self,
42        task_id: Option<&TaskId>,
43    ) -> RustvelloResult<Option<InvocationId>> {
44        let client = self.db.conn().await?;
45
46        // Atomically select and delete using a CTE for crash safety.
47        let row = if let Some(tid) = task_id {
48            client
49                .query_opt(
50                    "DELETE FROM broker_queue
51                     WHERE id = (
52                         SELECT bq.id FROM broker_queue bq
53                         JOIN invocations inv ON bq.invocation_id = inv.invocation_id
54                         WHERE inv.task_id = $1
55                         ORDER BY bq.id ASC LIMIT 1
56                         FOR UPDATE OF bq SKIP LOCKED
57                     )
58                     RETURNING invocation_id",
59                    &[&tid.to_string()],
60                )
61                .await
62                .map_err(pg_err)?
63        } else {
64            client
65                .query_opt(
66                    "DELETE FROM broker_queue
67                     WHERE id = (SELECT id FROM broker_queue ORDER BY id ASC LIMIT 1 FOR UPDATE SKIP LOCKED)
68                     RETURNING invocation_id",
69                    &[],
70                )
71                .await
72                .map_err(pg_err)?
73        };
74
75        Ok(row.map(|r| InvocationId::from_string(r.get::<_, String>(0))))
76    }
77
78    async fn count_invocations(&self, task_id: Option<&TaskId>) -> RustvelloResult<usize> {
79        let client = self.db.conn().await?;
80        let row = if let Some(tid) = task_id {
81            client
82                .query_one(
83                    "SELECT COUNT(*) FROM broker_queue bq \
84                     JOIN invocations inv ON bq.invocation_id = inv.invocation_id \
85                     WHERE inv.task_id = $1",
86                    &[&tid.to_string()],
87                )
88                .await
89                .map_err(pg_err)?
90        } else {
91            client
92                .query_one("SELECT COUNT(*) FROM broker_queue", &[])
93                .await
94                .map_err(pg_err)?
95        };
96        let count: i64 = row.get(0);
97        Ok(usize::try_from(count).unwrap_or(usize::MAX))
98    }
99
100    async fn purge(&self, task_id: Option<&TaskId>) -> RustvelloResult<()> {
101        let client = self.db.conn().await?;
102        if let Some(tid) = task_id {
103            client
104                .execute(
105                    "DELETE FROM broker_queue WHERE invocation_id IN (\
106                     SELECT bq.invocation_id FROM broker_queue bq \
107                     JOIN invocations inv ON bq.invocation_id = inv.invocation_id \
108                     WHERE inv.task_id = $1)",
109                    &[&tid.to_string()],
110                )
111                .await
112                .map_err(pg_err)?;
113        } else {
114            client
115                .execute("DELETE FROM broker_queue", &[])
116                .await
117                .map_err(pg_err)?;
118        }
119        Ok(())
120    }
121
122    async fn retrieve_invocation_for_language(
123        &self,
124        language: &str,
125    ) -> RustvelloResult<Option<InvocationId>> {
126        let client = self.db.conn().await?;
127        let prefix = format!("{language}::");
128        let row = client
129            .query_opt(
130                "DELETE FROM broker_queue
131                 WHERE id = (
132                     SELECT bq.id FROM broker_queue bq
133                     JOIN invocations inv ON bq.invocation_id = inv.invocation_id
134                     WHERE inv.task_id LIKE $1 || '%'
135                     ORDER BY bq.id ASC LIMIT 1
136                     FOR UPDATE OF bq SKIP LOCKED
137                 )
138                 RETURNING invocation_id",
139                &[&prefix],
140            )
141            .await
142            .map_err(pg_err)?;
143        Ok(row.map(|r| InvocationId::from_string(r.get::<_, String>(0))))
144    }
145}