Skip to main content

p2panda_store/address_book/
sqlite.rs

1// SPDX-License-Identifier: MIT OR Apache-2.0
2
3use std::collections::HashSet;
4use std::fmt::Display;
5use std::str::FromStr;
6use std::time::Duration;
7
8use p2panda_core::cbor::{decode_cbor, encode_cbor};
9use p2panda_core::{Topic, VerifyingKey};
10use serde::{Deserialize, Serialize};
11use sqlx::{query, query_as, query_scalar};
12
13use crate::address_book::{AddressBookStore, NodeInfo};
14use crate::sqlite::{SqliteError, SqliteStore};
15
16impl<N> AddressBookStore<VerifyingKey, N> for SqliteStore
17where
18    N: NodeInfo<VerifyingKey> + Serialize + for<'de> Deserialize<'de>,
19{
20    type Error = SqliteError;
21
22    async fn insert_node_info(&self, info: N) -> Result<bool, Self::Error> {
23        let is_upsert = {
24            let row = self
25                .tx(async |tx| {
26                    query_as::<_, (i32,)>("SELECT COUNT(*) FROM node_infos_v1 WHERE node_id = ?")
27                        .bind(info.id().to_hex())
28                        .fetch_one(&mut **tx)
29                        .await
30                        .map_err(SqliteError::Sqlite)
31                })
32                .await?;
33
34            row.0 == 1
35        };
36
37        self.tx(async |tx| {
38            query(
39                "
40                INSERT
41                INTO
42                    node_infos_v1 (
43                        node_id,
44                        node_info,
45                        bootstrap
46                    )
47                VALUES
48                    (?, ?, ?)
49                ON CONFLICT(node_id)
50                DO UPDATE
51                    SET
52                        node_info = EXCLUDED.node_info,
53                        bootstrap = EXCLUDED.bootstrap
54                ",
55            )
56            .bind(info.id().to_hex())
57            .bind(
58                encode_cbor(&info)
59                    .map_err(|err| SqliteError::Encode("node_info".to_string(), err))?,
60            )
61            .bind(info.is_bootstrap())
62            .execute(&mut **tx)
63            .await
64            .map_err(SqliteError::Sqlite)
65        })
66        .await?;
67
68        Ok(!is_upsert)
69    }
70
71    async fn remove_node_info(&self, id: &VerifyingKey) -> Result<bool, Self::Error> {
72        // Remove node's info.
73        let result = self
74            .tx(async |tx| {
75                query(
76                    "
77                    DELETE FROM
78                        node_infos_v1
79                    WHERE
80                        node_id = ?
81                    ",
82                )
83                .bind(id.to_hex())
84                .execute(&mut **tx)
85                .await
86                .map_err(SqliteError::Sqlite)
87            })
88            .await?;
89
90        // Remove associated topics for this node.
91        self.tx(async |tx| {
92            query(
93                "
94                DELETE FROM
95                    topics2node_infos_v1
96                WHERE
97                    node_id = ?
98                ",
99            )
100            .bind(id.to_hex())
101            .execute(&mut **tx)
102            .await
103            .map_err(SqliteError::Sqlite)
104        })
105        .await?;
106
107        Ok(result.rows_affected() > 0)
108    }
109
110    async fn remove_older_than(&self, duration: Duration) -> Result<usize, Self::Error> {
111        let result = self
112            .tx(async |tx| {
113                query_as::<_, (String,)>(
114                    "
115                    DELETE FROM
116                        node_infos_v1
117                    WHERE
118                        updated_at < UNIXEPOCH() - ?
119                    RETURNING
120                        node_id
121                    ",
122                )
123                .bind(duration.as_secs() as i64)
124                .fetch_all(&mut **tx)
125                .await
126                .map_err(SqliteError::Sqlite)
127            })
128            .await?;
129
130        let node_ids: Vec<&String> = result.iter().map(|item| &item.0).collect();
131
132        // Remove associated topics for removed nodes.
133        self.tx(async |tx| {
134            query(&format!(
135                "
136                DELETE FROM
137                    topics2node_infos_v1
138                WHERE
139                    node_id IN ({})
140                ",
141                in_op_str(&node_ids)
142            ))
143            .execute(&mut **tx)
144            .await
145            .map_err(SqliteError::Sqlite)
146        })
147        .await?;
148
149        Ok(node_ids.len())
150    }
151
152    async fn node_info(&self, id: &VerifyingKey) -> Result<Option<N>, Self::Error> {
153        let result = self
154            .execute(async |pool| {
155                query_as::<_, (Vec<u8>,)>(
156                    "
157                    SELECT
158                        node_info
159                    FROM
160                        node_infos_v1
161                    WHERE
162                        node_id = ?
163                    ",
164                )
165                .bind(id.to_hex())
166                .fetch_optional(pool)
167                .await
168                .map_err(SqliteError::Sqlite)
169            })
170            .await?;
171
172        decode_node_info(result)
173    }
174
175    async fn node_topics(&self, id: &VerifyingKey) -> Result<HashSet<Topic>, Self::Error> {
176        let result = self
177            .execute(async |pool| {
178                query_as::<_, (String,)>(
179                    "
180                    SELECT
181                        topic_id
182                    FROM
183                        topics2node_infos_v1
184                    WHERE
185                        node_id = ?
186                    ",
187                )
188                .bind(id.to_hex())
189                .fetch_all(pool)
190                .await
191                .map_err(SqliteError::Sqlite)
192            })
193            .await?;
194
195        result
196            .iter()
197            .map(|item| {
198                Topic::from_str(&item.0)
199                    .map_err(|err| SqliteError::Decode("topic_id".to_string(), err.into()))
200            })
201            .collect()
202    }
203
204    async fn all_node_infos(&self) -> Result<Vec<N>, Self::Error> {
205        let result = self
206            .execute(async |pool| {
207                query_as::<_, (Vec<u8>,)>(
208                    "
209                    SELECT
210                        node_info
211                    FROM
212                        node_infos_v1
213                    ",
214                )
215                .fetch_all(pool)
216                .await
217                .map_err(SqliteError::Sqlite)
218            })
219            .await?;
220
221        decode_node_infos(result)
222    }
223
224    async fn all_nodes_len(&self) -> Result<usize, Self::Error> {
225        let count: i64 = self
226            .execute(async |pool| {
227                query_scalar(
228                    "
229                    SELECT
230                        COUNT(node_id)
231                    FROM
232                        node_infos_v1
233                    ",
234                )
235                .fetch_one(pool)
236                .await
237                .map_err(SqliteError::Sqlite)
238            })
239            .await?;
240
241        Ok(count as usize)
242    }
243
244    async fn all_bootstrap_nodes_len(&self) -> Result<usize, Self::Error> {
245        let count: i64 = self
246            .execute(async |pool| {
247                query_scalar(
248                    "
249                    SELECT
250                        COUNT(node_id)
251                    FROM
252                        node_infos_v1
253                    WHERE
254                        bootstrap = TRUE
255                    ",
256                )
257                .fetch_one(pool)
258                .await
259                .map_err(SqliteError::Sqlite)
260            })
261            .await?;
262
263        Ok(count as usize)
264    }
265
266    async fn selected_node_infos(&self, ids: &[VerifyingKey]) -> Result<Vec<N>, Self::Error> {
267        let result = self
268            .execute(async |pool| {
269                query_as::<_, (Vec<u8>,)>(&format!(
270                    "
271                    SELECT
272                        node_info
273                    FROM
274                        node_infos_v1
275                    WHERE
276                        node_id IN ({})
277                    ",
278                    in_op_str(ids)
279                ))
280                .fetch_all(pool)
281                .await
282                .map_err(SqliteError::Sqlite)
283            })
284            .await?;
285
286        decode_node_infos(result)
287    }
288
289    async fn set_topics(
290        &self,
291        id: VerifyingKey,
292        topics: HashSet<Topic>,
293    ) -> Result<(), Self::Error> {
294        // Remove all previous topics set for this node id and replace it with new values. Both
295        // updates will be executed inside the same atomic transaction.
296        self.tx(async |tx| {
297            query(
298                "
299                DELETE FROM
300                    topics2node_infos_v1
301                WHERE
302                    node_id = ?
303                ",
304            )
305            .bind(id.to_hex())
306            .execute(&mut **tx)
307            .await
308            .map_err(SqliteError::Sqlite)
309        })
310        .await?;
311
312        for topic in topics {
313            self.tx(async |tx| {
314                query(
315                    "
316                    INSERT OR IGNORE
317                    INTO
318                        topics2node_infos_v1 (
319                            node_id,
320                            topic_id
321                        )
322                    VALUES
323                        (?, ?)
324                    ",
325                )
326                .bind(id.to_hex())
327                .bind(topic.to_string())
328                .execute(&mut **tx)
329                .await
330                .map_err(SqliteError::Sqlite)
331            })
332            .await?;
333        }
334
335        Ok(())
336    }
337
338    async fn node_infos_by_topics(&self, topics: &[Topic]) -> Result<Vec<N>, Self::Error> {
339        let result = self
340            .execute(async |pool| {
341                query_as::<_, (Vec<u8>,)>(&format!(
342                    "
343                    SELECT
344                        node_infos_v1.node_info
345                    FROM
346                        node_infos_v1
347                    LEFT JOIN topics2node_infos_v1
348                        ON node_infos_v1.node_id = topics2node_infos_v1.node_id
349                    WHERE
350                        topics2node_infos_v1.topic_id IN ({})
351                    GROUP BY
352                        node_infos_v1.node_id
353                    ",
354                    in_op_str(topics)
355                ))
356                .fetch_all(pool)
357                .await
358                .map_err(SqliteError::Sqlite)
359            })
360            .await?;
361
362        decode_node_infos(result)
363    }
364
365    async fn random_node(&self) -> Result<Option<N>, Self::Error> {
366        let result = self
367            .execute(async |pool| {
368                query_as::<_, (Vec<u8>,)>(
369                    "
370                    SELECT
371                        node_info
372                    FROM
373                        node_infos_v1
374                    ORDER BY RANDOM()
375                    LIMIT 1
376                    ",
377                )
378                .fetch_optional(pool)
379                .await
380                .map_err(SqliteError::Sqlite)
381            })
382            .await?;
383
384        decode_node_info(result)
385    }
386
387    async fn random_bootstrap_node(&self) -> Result<Option<N>, Self::Error> {
388        let result = self
389            .execute(async |pool| {
390                query_as::<_, (Vec<u8>,)>(
391                    "
392                    SELECT
393                        node_info
394                    FROM
395                        node_infos_v1
396                    WHERE
397                        bootstrap = TRUE
398                    ORDER BY RANDOM()
399                    LIMIT 1
400                    ",
401                )
402                .fetch_optional(pool)
403                .await
404                .map_err(SqliteError::Sqlite)
405            })
406            .await?;
407
408        decode_node_info(result)
409    }
410}
411
#[cfg(any(test, feature = "test_utils"))]
#[doc(hidden)]
impl SqliteStore {
    /// Test helper: overwrites the `updated_at` value of the node info row stored for `id`.
    ///
    /// `timestamp` is written as-is into the column after a plain `as i64` cast. Nothing
    /// happens when no row exists for `id`.
    pub async fn set_last_changed(
        &self,
        id: &VerifyingKey,
        timestamp: u64,
    ) -> Result<(), SqliteError> {
        // SQLite integers are signed, convert once up-front.
        let updated_at = timestamp as i64;

        self.tx(async |tx| {
            let statement = query(
                "
                UPDATE
                    node_infos_v1
                SET
                    updated_at = ?
                WHERE
                    node_id = ?
                ",
            )
            .bind(updated_at)
            .bind(id.to_hex());

            statement
                .execute(&mut **tx)
                .await
                .map_err(SqliteError::Sqlite)
        })
        .await?;

        Ok(())
    }
}
442
/// Takes a list of items implementing `Display` to turn it into the operand list of an SQL "IN"
/// operator where each item is represented as a single-quoted string literal.
///
/// Single quotes inside an item are escaped by doubling them (`'` becomes `''`), so an item can
/// never terminate its own literal and alter the surrounding statement. An empty list results in
/// an empty string.
///
/// ```text
/// SELECT * FROM users
/// WHERE
///     id IN ('1a', '2b', '3c');
/// ```
fn in_op_str<T: Display>(list: &[T]) -> String {
    list.iter()
        .map(|item| format!("'{}'", item.to_string().replace('\'', "''")))
        .collect::<Vec<String>>()
        .join(",")
}
457
458/// Deserialize multiple rows containing encoded node info.
459fn decode_node_infos<N>(result: Vec<(Vec<u8>,)>) -> Result<Vec<N>, SqliteError>
460where
461    N: NodeInfo<VerifyingKey> + Serialize + for<'a> Deserialize<'a>,
462{
463    result
464        .iter()
465        .map(|item| {
466            decode_cbor(&item.0[..])
467                .map_err(|err| SqliteError::Decode("node_info".to_string(), err.into()))
468        })
469        .collect()
470}
471
472/// Deserialize single row maybe containing encoded node info.
473fn decode_node_info<N>(result: Option<(Vec<u8>,)>) -> Result<Option<N>, SqliteError>
474where
475    N: NodeInfo<VerifyingKey> + Serialize + for<'a> Deserialize<'a>,
476{
477    match result {
478        Some((bytes,)) => {
479            Ok(Some(decode_cbor(&bytes[..]).map_err(|err| {
480                SqliteError::Decode("node_info".to_string(), err.into())
481            })?))
482        }
483        None => Ok(None),
484    }
485}