p2panda_store/address_book/
sqlite.rs1use std::collections::HashSet;
4use std::fmt::Display;
5use std::str::FromStr;
6use std::time::Duration;
7
8use p2panda_core::cbor::{decode_cbor, encode_cbor};
9use p2panda_core::{Topic, VerifyingKey};
10use serde::{Deserialize, Serialize};
11use sqlx::{query, query_as, query_scalar};
12
13use crate::address_book::{AddressBookStore, NodeInfo};
14use crate::sqlite::{SqliteError, SqliteStore};
15
16impl<N> AddressBookStore<VerifyingKey, N> for SqliteStore
17where
18 N: NodeInfo<VerifyingKey> + Serialize + for<'de> Deserialize<'de>,
19{
20 type Error = SqliteError;
21
22 async fn insert_node_info(&self, info: N) -> Result<bool, Self::Error> {
23 let is_upsert = {
24 let row = self
25 .tx(async |tx| {
26 query_as::<_, (i32,)>("SELECT COUNT(*) FROM node_infos_v1 WHERE node_id = ?")
27 .bind(info.id().to_hex())
28 .fetch_one(&mut **tx)
29 .await
30 .map_err(SqliteError::Sqlite)
31 })
32 .await?;
33
34 row.0 == 1
35 };
36
37 self.tx(async |tx| {
38 query(
39 "
40 INSERT
41 INTO
42 node_infos_v1 (
43 node_id,
44 node_info,
45 bootstrap
46 )
47 VALUES
48 (?, ?, ?)
49 ON CONFLICT(node_id)
50 DO UPDATE
51 SET
52 node_info = EXCLUDED.node_info,
53 bootstrap = EXCLUDED.bootstrap
54 ",
55 )
56 .bind(info.id().to_hex())
57 .bind(
58 encode_cbor(&info)
59 .map_err(|err| SqliteError::Encode("node_info".to_string(), err))?,
60 )
61 .bind(info.is_bootstrap())
62 .execute(&mut **tx)
63 .await
64 .map_err(SqliteError::Sqlite)
65 })
66 .await?;
67
68 Ok(!is_upsert)
69 }
70
71 async fn remove_node_info(&self, id: &VerifyingKey) -> Result<bool, Self::Error> {
72 let result = self
74 .tx(async |tx| {
75 query(
76 "
77 DELETE FROM
78 node_infos_v1
79 WHERE
80 node_id = ?
81 ",
82 )
83 .bind(id.to_hex())
84 .execute(&mut **tx)
85 .await
86 .map_err(SqliteError::Sqlite)
87 })
88 .await?;
89
90 self.tx(async |tx| {
92 query(
93 "
94 DELETE FROM
95 topics2node_infos_v1
96 WHERE
97 node_id = ?
98 ",
99 )
100 .bind(id.to_hex())
101 .execute(&mut **tx)
102 .await
103 .map_err(SqliteError::Sqlite)
104 })
105 .await?;
106
107 Ok(result.rows_affected() > 0)
108 }
109
110 async fn remove_older_than(&self, duration: Duration) -> Result<usize, Self::Error> {
111 let result = self
112 .tx(async |tx| {
113 query_as::<_, (String,)>(
114 "
115 DELETE FROM
116 node_infos_v1
117 WHERE
118 updated_at < UNIXEPOCH() - ?
119 RETURNING
120 node_id
121 ",
122 )
123 .bind(duration.as_secs() as i64)
124 .fetch_all(&mut **tx)
125 .await
126 .map_err(SqliteError::Sqlite)
127 })
128 .await?;
129
130 let node_ids: Vec<&String> = result.iter().map(|item| &item.0).collect();
131
132 self.tx(async |tx| {
134 query(&format!(
135 "
136 DELETE FROM
137 topics2node_infos_v1
138 WHERE
139 node_id IN ({})
140 ",
141 in_op_str(&node_ids)
142 ))
143 .execute(&mut **tx)
144 .await
145 .map_err(SqliteError::Sqlite)
146 })
147 .await?;
148
149 Ok(node_ids.len())
150 }
151
152 async fn node_info(&self, id: &VerifyingKey) -> Result<Option<N>, Self::Error> {
153 let result = self
154 .execute(async |pool| {
155 query_as::<_, (Vec<u8>,)>(
156 "
157 SELECT
158 node_info
159 FROM
160 node_infos_v1
161 WHERE
162 node_id = ?
163 ",
164 )
165 .bind(id.to_hex())
166 .fetch_optional(pool)
167 .await
168 .map_err(SqliteError::Sqlite)
169 })
170 .await?;
171
172 decode_node_info(result)
173 }
174
175 async fn node_topics(&self, id: &VerifyingKey) -> Result<HashSet<Topic>, Self::Error> {
176 let result = self
177 .execute(async |pool| {
178 query_as::<_, (String,)>(
179 "
180 SELECT
181 topic_id
182 FROM
183 topics2node_infos_v1
184 WHERE
185 node_id = ?
186 ",
187 )
188 .bind(id.to_hex())
189 .fetch_all(pool)
190 .await
191 .map_err(SqliteError::Sqlite)
192 })
193 .await?;
194
195 result
196 .iter()
197 .map(|item| {
198 Topic::from_str(&item.0)
199 .map_err(|err| SqliteError::Decode("topic_id".to_string(), err.into()))
200 })
201 .collect()
202 }
203
204 async fn all_node_infos(&self) -> Result<Vec<N>, Self::Error> {
205 let result = self
206 .execute(async |pool| {
207 query_as::<_, (Vec<u8>,)>(
208 "
209 SELECT
210 node_info
211 FROM
212 node_infos_v1
213 ",
214 )
215 .fetch_all(pool)
216 .await
217 .map_err(SqliteError::Sqlite)
218 })
219 .await?;
220
221 decode_node_infos(result)
222 }
223
224 async fn all_nodes_len(&self) -> Result<usize, Self::Error> {
225 let count: i64 = self
226 .execute(async |pool| {
227 query_scalar(
228 "
229 SELECT
230 COUNT(node_id)
231 FROM
232 node_infos_v1
233 ",
234 )
235 .fetch_one(pool)
236 .await
237 .map_err(SqliteError::Sqlite)
238 })
239 .await?;
240
241 Ok(count as usize)
242 }
243
244 async fn all_bootstrap_nodes_len(&self) -> Result<usize, Self::Error> {
245 let count: i64 = self
246 .execute(async |pool| {
247 query_scalar(
248 "
249 SELECT
250 COUNT(node_id)
251 FROM
252 node_infos_v1
253 WHERE
254 bootstrap = TRUE
255 ",
256 )
257 .fetch_one(pool)
258 .await
259 .map_err(SqliteError::Sqlite)
260 })
261 .await?;
262
263 Ok(count as usize)
264 }
265
266 async fn selected_node_infos(&self, ids: &[VerifyingKey]) -> Result<Vec<N>, Self::Error> {
267 let result = self
268 .execute(async |pool| {
269 query_as::<_, (Vec<u8>,)>(&format!(
270 "
271 SELECT
272 node_info
273 FROM
274 node_infos_v1
275 WHERE
276 node_id IN ({})
277 ",
278 in_op_str(ids)
279 ))
280 .fetch_all(pool)
281 .await
282 .map_err(SqliteError::Sqlite)
283 })
284 .await?;
285
286 decode_node_infos(result)
287 }
288
289 async fn set_topics(
290 &self,
291 id: VerifyingKey,
292 topics: HashSet<Topic>,
293 ) -> Result<(), Self::Error> {
294 self.tx(async |tx| {
297 query(
298 "
299 DELETE FROM
300 topics2node_infos_v1
301 WHERE
302 node_id = ?
303 ",
304 )
305 .bind(id.to_hex())
306 .execute(&mut **tx)
307 .await
308 .map_err(SqliteError::Sqlite)
309 })
310 .await?;
311
312 for topic in topics {
313 self.tx(async |tx| {
314 query(
315 "
316 INSERT OR IGNORE
317 INTO
318 topics2node_infos_v1 (
319 node_id,
320 topic_id
321 )
322 VALUES
323 (?, ?)
324 ",
325 )
326 .bind(id.to_hex())
327 .bind(topic.to_string())
328 .execute(&mut **tx)
329 .await
330 .map_err(SqliteError::Sqlite)
331 })
332 .await?;
333 }
334
335 Ok(())
336 }
337
338 async fn node_infos_by_topics(&self, topics: &[Topic]) -> Result<Vec<N>, Self::Error> {
339 let result = self
340 .execute(async |pool| {
341 query_as::<_, (Vec<u8>,)>(&format!(
342 "
343 SELECT
344 node_infos_v1.node_info
345 FROM
346 node_infos_v1
347 LEFT JOIN topics2node_infos_v1
348 ON node_infos_v1.node_id = topics2node_infos_v1.node_id
349 WHERE
350 topics2node_infos_v1.topic_id IN ({})
351 GROUP BY
352 node_infos_v1.node_id
353 ",
354 in_op_str(topics)
355 ))
356 .fetch_all(pool)
357 .await
358 .map_err(SqliteError::Sqlite)
359 })
360 .await?;
361
362 decode_node_infos(result)
363 }
364
365 async fn random_node(&self) -> Result<Option<N>, Self::Error> {
366 let result = self
367 .execute(async |pool| {
368 query_as::<_, (Vec<u8>,)>(
369 "
370 SELECT
371 node_info
372 FROM
373 node_infos_v1
374 ORDER BY RANDOM()
375 LIMIT 1
376 ",
377 )
378 .fetch_optional(pool)
379 .await
380 .map_err(SqliteError::Sqlite)
381 })
382 .await?;
383
384 decode_node_info(result)
385 }
386
387 async fn random_bootstrap_node(&self) -> Result<Option<N>, Self::Error> {
388 let result = self
389 .execute(async |pool| {
390 query_as::<_, (Vec<u8>,)>(
391 "
392 SELECT
393 node_info
394 FROM
395 node_infos_v1
396 WHERE
397 bootstrap = TRUE
398 ORDER BY RANDOM()
399 LIMIT 1
400 ",
401 )
402 .fetch_optional(pool)
403 .await
404 .map_err(SqliteError::Sqlite)
405 })
406 .await?;
407
408 decode_node_info(result)
409 }
410}
411
#[cfg(any(test, feature = "test_utils"))]
#[doc(hidden)]
impl SqliteStore {
    /// Test helper that overwrites the `updated_at` column of the node stored under `id`.
    ///
    /// `timestamp` is written verbatim (cast to `i64`). It is presumably seconds since the Unix
    /// epoch, since `updated_at` is compared against `UNIXEPOCH()` elsewhere; confirm.
    /// Updating an unknown `id` matches no row and is not an error.
    pub async fn set_last_changed(
        &self,
        id: &VerifyingKey,
        timestamp: u64,
    ) -> Result<(), SqliteError> {
        let statement = "
            UPDATE
                node_infos_v1
            SET
                updated_at = ?
            WHERE
                node_id = ?
            ";

        self.tx(async |tx| {
            query(statement)
                .bind(timestamp as i64)
                .bind(id.to_hex())
                .execute(&mut **tx)
                .await
                .map(|_| ())
                .map_err(SqliteError::Sqlite)
        })
        .await?;

        Ok(())
    }
}
442
443fn in_op_str<T: Display>(list: &[T]) -> String {
452 list.iter()
453 .map(|item| format!("'{item}'"))
454 .collect::<Vec<String>>()
455 .join(",")
456}
457
458fn decode_node_infos<N>(result: Vec<(Vec<u8>,)>) -> Result<Vec<N>, SqliteError>
460where
461 N: NodeInfo<VerifyingKey> + Serialize + for<'a> Deserialize<'a>,
462{
463 result
464 .iter()
465 .map(|item| {
466 decode_cbor(&item.0[..])
467 .map_err(|err| SqliteError::Decode("node_info".to_string(), err.into()))
468 })
469 .collect()
470}
471
472fn decode_node_info<N>(result: Option<(Vec<u8>,)>) -> Result<Option<N>, SqliteError>
474where
475 N: NodeInfo<VerifyingKey> + Serialize + for<'a> Deserialize<'a>,
476{
477 match result {
478 Some((bytes,)) => {
479 Ok(Some(decode_cbor(&bytes[..]).map_err(|err| {
480 SqliteError::Decode("node_info".to_string(), err.into())
481 })?))
482 }
483 None => Ok(None),
484 }
485}