rustvello_postgres/
broker.rs1use std::sync::Arc;
4
5use async_trait::async_trait;
6
7use rustvello_core::broker::Broker;
8use rustvello_core::error::RustvelloResult;
9use rustvello_proto::identifiers::{InvocationId, TaskId};
10
11use crate::db::{pg_err, Database};
12
13pub struct PostgresBroker {
17 db: Arc<Database>,
18}
19
20impl PostgresBroker {
21 pub fn new(db: Arc<Database>) -> Self {
22 Self { db }
23 }
24}
25
26#[async_trait]
27impl Broker for PostgresBroker {
28 async fn route_invocation(&self, invocation_id: &InvocationId) -> RustvelloResult<()> {
29 let client = self.db.conn().await?;
30 client
31 .execute(
32 "INSERT INTO broker_queue (invocation_id) VALUES ($1)",
33 &[&invocation_id.as_str()],
34 )
35 .await
36 .map_err(pg_err)?;
37 Ok(())
38 }
39
40 async fn retrieve_invocation(
41 &self,
42 task_id: Option<&TaskId>,
43 ) -> RustvelloResult<Option<InvocationId>> {
44 let client = self.db.conn().await?;
45
46 let row = if let Some(tid) = task_id {
48 client
49 .query_opt(
50 "DELETE FROM broker_queue
51 WHERE id = (
52 SELECT bq.id FROM broker_queue bq
53 JOIN invocations inv ON bq.invocation_id = inv.invocation_id
54 WHERE inv.task_id = $1
55 ORDER BY bq.id ASC LIMIT 1
56 FOR UPDATE OF bq SKIP LOCKED
57 )
58 RETURNING invocation_id",
59 &[&tid.to_string()],
60 )
61 .await
62 .map_err(pg_err)?
63 } else {
64 client
65 .query_opt(
66 "DELETE FROM broker_queue
67 WHERE id = (SELECT id FROM broker_queue ORDER BY id ASC LIMIT 1 FOR UPDATE SKIP LOCKED)
68 RETURNING invocation_id",
69 &[],
70 )
71 .await
72 .map_err(pg_err)?
73 };
74
75 Ok(row.map(|r| InvocationId::from_string(r.get::<_, String>(0))))
76 }
77
78 async fn count_invocations(&self, task_id: Option<&TaskId>) -> RustvelloResult<usize> {
79 let client = self.db.conn().await?;
80 let row = if let Some(tid) = task_id {
81 client
82 .query_one(
83 "SELECT COUNT(*) FROM broker_queue bq \
84 JOIN invocations inv ON bq.invocation_id = inv.invocation_id \
85 WHERE inv.task_id = $1",
86 &[&tid.to_string()],
87 )
88 .await
89 .map_err(pg_err)?
90 } else {
91 client
92 .query_one("SELECT COUNT(*) FROM broker_queue", &[])
93 .await
94 .map_err(pg_err)?
95 };
96 let count: i64 = row.get(0);
97 Ok(usize::try_from(count).unwrap_or(usize::MAX))
98 }
99
100 async fn purge(&self, task_id: Option<&TaskId>) -> RustvelloResult<()> {
101 let client = self.db.conn().await?;
102 if let Some(tid) = task_id {
103 client
104 .execute(
105 "DELETE FROM broker_queue WHERE invocation_id IN (\
106 SELECT bq.invocation_id FROM broker_queue bq \
107 JOIN invocations inv ON bq.invocation_id = inv.invocation_id \
108 WHERE inv.task_id = $1)",
109 &[&tid.to_string()],
110 )
111 .await
112 .map_err(pg_err)?;
113 } else {
114 client
115 .execute("DELETE FROM broker_queue", &[])
116 .await
117 .map_err(pg_err)?;
118 }
119 Ok(())
120 }
121
122 async fn retrieve_invocation_for_language(
123 &self,
124 language: &str,
125 ) -> RustvelloResult<Option<InvocationId>> {
126 let client = self.db.conn().await?;
127 let prefix = format!("{language}::");
128 let row = client
129 .query_opt(
130 "DELETE FROM broker_queue
131 WHERE id = (
132 SELECT bq.id FROM broker_queue bq
133 JOIN invocations inv ON bq.invocation_id = inv.invocation_id
134 WHERE inv.task_id LIKE $1 || '%'
135 ORDER BY bq.id ASC LIMIT 1
136 FOR UPDATE OF bq SKIP LOCKED
137 )
138 RETURNING invocation_id",
139 &[&prefix],
140 )
141 .await
142 .map_err(pg_err)?;
143 Ok(row.map(|r| InvocationId::from_string(r.get::<_, String>(0))))
144 }
145}