rustvello_sqlite/
broker.rs1use std::sync::Arc;
2
3use async_trait::async_trait;
4
5use rustvello_core::broker::Broker;
6use rustvello_core::error::RustvelloResult;
7use rustvello_proto::identifiers::{InvocationId, TaskId};
8
9use crate::db::{blocking, lock_err, sql_err, Database};
10
11pub struct SqliteBroker {
15 db: Arc<Database>,
16}
17
18impl SqliteBroker {
19 pub fn new(db: Arc<Database>) -> Self {
20 Self { db }
21 }
22}
23
24#[async_trait]
25impl Broker for SqliteBroker {
26 async fn route_invocation(&self, invocation_id: &InvocationId) -> RustvelloResult<()> {
27 let db = Arc::clone(&self.db);
28 let id = invocation_id.clone();
29 blocking(move || {
30 let conn = db.conn.lock().map_err(lock_err)?;
31 conn.execute(
32 "INSERT INTO broker_queue (invocation_id) VALUES (?1)",
33 [id.as_str()],
34 )
35 .map_err(sql_err)?;
36 Ok(())
37 })
38 .await
39 }
40
41 async fn retrieve_invocation(
42 &self,
43 task_id: Option<&TaskId>,
44 ) -> RustvelloResult<Option<InvocationId>> {
45 let db = Arc::clone(&self.db);
46 let task_id = task_id.cloned();
47 blocking(move || {
48 let conn = db.conn.lock().map_err(lock_err)?;
49
50 let tx = conn.unchecked_transaction().map_err(sql_err)?;
51
52 let result: Option<(i64, String)> = if let Some(ref tid) = task_id {
53 tx.query_row(
54 "SELECT bq.id, bq.invocation_id FROM broker_queue bq \
55 JOIN invocations inv ON bq.invocation_id = inv.invocation_id \
56 WHERE inv.task_id = ?1 \
57 ORDER BY bq.id ASC LIMIT 1",
58 [&tid.to_string()],
59 |row| Ok((row.get(0)?, row.get(1)?)),
60 )
61 .ok()
62 } else {
63 tx.query_row(
64 "SELECT id, invocation_id FROM broker_queue ORDER BY id ASC LIMIT 1",
65 [],
66 |row| Ok((row.get(0)?, row.get(1)?)),
67 )
68 .ok()
69 };
70
71 if let Some((row_id, inv_id)) = result {
72 tx.execute("DELETE FROM broker_queue WHERE id = ?1", [row_id])
73 .map_err(sql_err)?;
74 tx.commit().map_err(sql_err)?;
75 Ok(Some(InvocationId::from_string(inv_id)))
76 } else {
77 Ok(None)
78 }
79 })
80 .await
81 }
82
83 async fn count_invocations(&self, task_id: Option<&TaskId>) -> RustvelloResult<usize> {
84 let db = Arc::clone(&self.db);
85 let task_id = task_id.cloned();
86 blocking(move || {
87 let conn = db.conn.lock().map_err(lock_err)?;
88 let count: i64 = if let Some(ref tid) = task_id {
89 conn.query_row(
90 "SELECT COUNT(*) FROM broker_queue bq \
91 JOIN invocations inv ON bq.invocation_id = inv.invocation_id \
92 WHERE inv.task_id = ?1",
93 [&tid.to_string()],
94 |row| row.get(0),
95 )
96 .map_err(sql_err)?
97 } else {
98 conn.query_row("SELECT COUNT(*) FROM broker_queue", [], |row| row.get(0))
99 .map_err(sql_err)?
100 };
101 Ok(count as usize)
102 })
103 .await
104 }
105
106 async fn purge(&self, task_id: Option<&TaskId>) -> RustvelloResult<()> {
107 let db = Arc::clone(&self.db);
108 let task_id = task_id.cloned();
109 blocking(move || {
110 let conn = db.conn.lock().map_err(lock_err)?;
111 if let Some(ref tid) = task_id {
112 conn.execute(
113 "DELETE FROM broker_queue WHERE invocation_id IN (\
114 SELECT bq.invocation_id FROM broker_queue bq \
115 JOIN invocations inv ON bq.invocation_id = inv.invocation_id \
116 WHERE inv.task_id = ?1)",
117 [&tid.to_string()],
118 )
119 .map_err(sql_err)?;
120 } else {
121 conn.execute("DELETE FROM broker_queue", [])
122 .map_err(sql_err)?;
123 }
124 Ok(())
125 })
126 .await
127 }
128
129 async fn retrieve_invocation_for_language(
130 &self,
131 language: &str,
132 ) -> RustvelloResult<Option<InvocationId>> {
133 let db = Arc::clone(&self.db);
134 let language = language.to_owned();
135 blocking(move || {
136 let conn = db.conn.lock().map_err(lock_err)?;
137 let tx = conn.unchecked_transaction().map_err(sql_err)?;
138
139 let global: Option<(i64, String)> = tx
142 .query_row(
143 "SELECT bq.id, bq.invocation_id FROM broker_queue bq \
144 LEFT JOIN invocations inv ON bq.invocation_id = inv.invocation_id \
145 WHERE inv.invocation_id IS NULL \
146 ORDER BY bq.id ASC LIMIT 1",
147 [],
148 |row| Ok((row.get(0)?, row.get(1)?)),
149 )
150 .ok();
151
152 let result = if global.is_some() {
153 global
154 } else {
155 let prefix = format!("{language}::");
157 tx.query_row(
158 "SELECT bq.id, bq.invocation_id FROM broker_queue bq \
159 JOIN invocations inv ON bq.invocation_id = inv.invocation_id \
160 WHERE inv.task_id LIKE ?1 || '%' \
161 ORDER BY bq.id ASC LIMIT 1",
162 [&prefix],
163 |row| Ok((row.get(0)?, row.get(1)?)),
164 )
165 .ok()
166 };
167
168 if let Some((row_id, inv_id)) = result {
169 tx.execute("DELETE FROM broker_queue WHERE id = ?1", [row_id])
170 .map_err(sql_err)?;
171 tx.commit().map_err(sql_err)?;
172 Ok(Some(InvocationId::from_string(inv_id)))
173 } else {
174 Ok(None)
175 }
176 })
177 .await
178 }
179}
180
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a broker backed by `Database::in_memory()`.
    fn make_broker() -> SqliteBroker {
        SqliteBroker::new(Arc::new(Database::in_memory().unwrap()))
    }

    /// Routed invocations come back in insertion order, then the queue is empty.
    #[tokio::test]
    async fn test_route_and_retrieve() {
        let broker = make_broker();
        let first = InvocationId::new();
        let second = InvocationId::new();

        for id in [&first, &second] {
            broker.route_invocation(id).await.unwrap();
        }

        let queued = broker.count_invocations(None).await.unwrap();
        assert_eq!(queued, 2);

        for expected in [&first, &second] {
            let retrieved = broker.retrieve_invocation(None).await.unwrap();
            assert_eq!(retrieved.unwrap().as_str(), expected.as_str());
        }

        let leftover = broker.retrieve_invocation(None).await.unwrap();
        assert!(leftover.is_none());
    }

    /// An unfiltered purge leaves no queued invocations behind.
    #[tokio::test]
    async fn test_purge() {
        let broker = make_broker();
        for _ in 0..2 {
            broker.route_invocation(&InvocationId::new()).await.unwrap();
        }

        broker.purge(None).await.unwrap();

        let remaining = broker.count_invocations(None).await.unwrap();
        assert_eq!(remaining, 0);
    }
}