radicle/node/policy/store.rs

1#![allow(clippy::type_complexity)]
2use std::collections::{BTreeMap, BTreeSet};
3use std::marker::PhantomData;
4use std::path::Path;
5use std::{fmt, io, ops::Not as _, str::FromStr, time};
6
7use sqlite as sql;
8use thiserror::Error;
9
10use crate::node::{Alias, AliasStore};
11use crate::prelude::{NodeId, RepoId};
12
13use super::{FollowPolicy, Policy, Scope, SeedPolicy, SeedingPolicy};
14
/// Maximum time a read-only connection waits for a locked database before giving up
/// on a read.
const DB_READ_TIMEOUT: time::Duration = time::Duration::from_millis(3_000);
/// Maximum time a read-write connection waits for a locked database before giving up
/// on a write.
const DB_WRITE_TIMEOUT: time::Duration = time::Duration::from_millis(6_000);
19
20#[derive(Error, Debug)]
21pub enum Error {
22    /// I/O error.
23    #[error("i/o error: {0}")]
24    Io(#[from] io::Error),
25    /// An Internal error.
26    #[error("internal error: {0}")]
27    Internal(#[from] sql::Error),
28}
29
30/// Read-only type witness.
31pub struct Read;
32/// Read-write type witness.
33pub struct Write;
34
35/// Read only config.
36pub type StoreReader = Store<Read>;
37/// Read-write config.
38pub type StoreWriter = Store<Write>;
39
40/// Policy configuration.
41pub struct Store<T> {
42    db: sql::Connection,
43    _marker: PhantomData<T>,
44}
45
46impl<T> fmt::Debug for Store<T> {
47    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
48        write!(f, "Store(..)")
49    }
50}
51
52impl Store<Read> {
53    const SCHEMA: &'static str = include_str!("schema.sql");
54
55    /// Same as [`Self::open`], but in read-only mode. This is useful to have multiple
56    /// open databases, as no locking is required.
57    pub fn reader<P: AsRef<Path>>(path: P) -> Result<Self, Error> {
58        let mut db =
59            sql::Connection::open_with_flags(path, sqlite::OpenFlags::new().with_read_only())?;
60        db.set_busy_timeout(DB_READ_TIMEOUT.as_millis() as usize)?;
61        db.execute(Self::SCHEMA)?;
62
63        Ok(Self {
64            db,
65            _marker: PhantomData,
66        })
67    }
68
69    /// Create a new in-memory address book.
70    pub fn memory() -> Result<Self, Error> {
71        let db = sql::Connection::open_with_flags(
72            ":memory:",
73            sqlite::OpenFlags::new().with_read_only(),
74        )?;
75        db.execute(Self::SCHEMA)?;
76
77        Ok(Self {
78            db,
79            _marker: PhantomData,
80        })
81    }
82}
83
84impl Store<Write> {
85    const SCHEMA: &'static str = include_str!("schema.sql");
86
87    /// Open a policy store at the given path. Creates a new store if it
88    /// doesn't exist.
89    pub fn open<P: AsRef<Path>>(path: P) -> Result<Self, Error> {
90        let mut db = sql::Connection::open(path)?;
91        db.set_busy_timeout(DB_WRITE_TIMEOUT.as_millis() as usize)?;
92        db.execute(Self::SCHEMA)?;
93
94        Ok(Self {
95            db,
96            _marker: PhantomData,
97        })
98    }
99
100    /// Create a new in-memory address book.
101    pub fn memory() -> Result<Self, Error> {
102        let db = sql::Connection::open(":memory:")?;
103        db.execute(Self::SCHEMA)?;
104
105        Ok(Self {
106            db,
107            _marker: PhantomData,
108        })
109    }
110
111    /// Get a read-only version of this store.
112    pub fn read_only(self) -> StoreReader {
113        Store {
114            db: self.db,
115            _marker: PhantomData,
116        }
117    }
118
119    /// Follow a node.
120    pub fn follow(&mut self, id: &NodeId, alias: Option<&Alias>) -> Result<bool, Error> {
121        let mut stmt = self.db.prepare(
122            "INSERT INTO `following` (id, alias)
123             VALUES (?1, ?2)
124             ON CONFLICT DO UPDATE
125             SET alias = ?2 WHERE alias != ?2",
126        )?;
127
128        stmt.bind((1, id))?;
129        stmt.bind((2, alias.map_or("", |alias| alias.as_str())))?;
130        stmt.next()?;
131
132        Ok(self.db.change_count() > 0)
133    }
134
135    /// Seed a repository.
136    pub fn seed(&mut self, id: &RepoId, scope: Scope) -> Result<bool, Error> {
137        let mut stmt = self.db.prepare(
138            "INSERT INTO `seeding` (id, scope)
139             VALUES (?1, ?2)
140             ON CONFLICT DO UPDATE
141             SET scope = ?2 WHERE scope != ?2",
142        )?;
143
144        stmt.bind((1, id))?;
145        stmt.bind((2, scope))?;
146        stmt.next()?;
147
148        Ok(self.db.change_count() > 0)
149    }
150
151    /// Set a node's follow policy.
152    pub fn set_follow_policy(&mut self, id: &NodeId, policy: Policy) -> Result<bool, Error> {
153        let mut stmt = self.db.prepare(
154            "INSERT INTO `following` (id, policy)
155             VALUES (?1, ?2)
156             ON CONFLICT DO UPDATE
157             SET policy = ?2 WHERE policy != ?2",
158        )?;
159
160        stmt.bind((1, id))?;
161        stmt.bind((2, policy))?;
162        stmt.next()?;
163
164        Ok(self.db.change_count() > 0)
165    }
166
167    /// Set a repository's seeding policy.
168    pub fn set_seed_policy(&mut self, id: &RepoId, policy: Policy) -> Result<bool, Error> {
169        let mut stmt = self.db.prepare(
170            "INSERT INTO `seeding` (id, policy)
171             VALUES (?1, ?2)
172             ON CONFLICT DO UPDATE
173             SET policy = ?2 WHERE policy != ?2",
174        )?;
175
176        stmt.bind((1, id))?;
177        stmt.bind((2, policy))?;
178        stmt.next()?;
179
180        Ok(self.db.change_count() > 0)
181    }
182
183    /// Unfollow a node.
184    pub fn unfollow(&mut self, id: &NodeId) -> Result<bool, Error> {
185        let mut stmt = self.db.prepare("DELETE FROM `following` WHERE id = ?")?;
186
187        stmt.bind((1, id))?;
188        stmt.next()?;
189
190        Ok(self.db.change_count() > 0)
191    }
192
193    /// Unseed a repository.
194    pub fn unseed(&mut self, id: &RepoId) -> Result<bool, Error> {
195        let mut stmt = self.db.prepare("DELETE FROM `seeding` WHERE id = ?")?;
196
197        stmt.bind((1, id))?;
198        stmt.next()?;
199
200        Ok(self.db.change_count() > 0)
201    }
202
203    /// Unblock a repository.
204    pub fn unblock_rid(&mut self, id: &RepoId) -> Result<bool, Error> {
205        let mut stmt = self
206            .db
207            .prepare("DELETE FROM `seeding` WHERE id = ? AND policy = 'block'")?;
208
209        stmt.bind((1, id))?;
210        stmt.next()?;
211
212        Ok(self.db.change_count() > 0)
213    }
214
215    /// Unblock a remote.
216    pub fn unblock_nid(&mut self, id: &NodeId) -> Result<bool, Error> {
217        let mut stmt = self
218            .db
219            .prepare("DELETE FROM `following` WHERE id = ? AND policy = 'block'")?;
220
221        stmt.bind((1, id))?;
222        stmt.next()?;
223
224        Ok(self.db.change_count() > 0)
225    }
226}
227
228/// `Read` methods for `Config`. This implies that a
229/// `Config<Write>` can access these functions as well.
230impl<T> Store<T> {
231    /// Check if a node is followed.
232    pub fn is_following(&self, id: &NodeId) -> Result<bool, Error> {
233        Ok(matches!(
234            self.follow_policy(id)?,
235            Some(FollowPolicy {
236                policy: Policy::Allow,
237                ..
238            })
239        ))
240    }
241
242    /// Check if a repository is seeded.
243    pub fn is_seeding(&self, id: &RepoId) -> Result<bool, Error> {
244        Ok(matches!(
245            self.seed_policy(id)?,
246            Some(SeedPolicy { policy, .. })
247            if policy.is_allow()
248        ))
249    }
250
251    /// Get a node's follow policy.
252    pub fn follow_policy(&self, id: &NodeId) -> Result<Option<FollowPolicy>, Error> {
253        let mut stmt = self
254            .db
255            .prepare("SELECT alias, policy FROM `following` WHERE id = ?")?;
256
257        stmt.bind((1, id))?;
258
259        if let Some(Ok(row)) = stmt.into_iter().next() {
260            let alias = row.try_read::<&str, _>("alias")?;
261            let alias = alias
262                .is_empty()
263                .not()
264                .then_some(alias.to_owned())
265                .and_then(|s| Alias::from_str(&s).ok());
266            let policy = row.try_read::<Policy, _>("policy")?;
267
268            return Ok(Some(FollowPolicy {
269                nid: *id,
270                alias,
271                policy,
272            }));
273        }
274        Ok(None)
275    }
276
277    /// Get a repository's seeding policy.
278    pub fn seed_policy(&self, id: &RepoId) -> Result<Option<SeedPolicy>, Error> {
279        let mut stmt = self
280            .db
281            .prepare("SELECT scope, policy FROM `seeding` WHERE id = ?")?;
282
283        stmt.bind((1, id))?;
284
285        if let Some(Ok(row)) = stmt.into_iter().next() {
286            let policy = match row.try_read::<Policy, _>("policy")? {
287                Policy::Allow => SeedingPolicy::Allow {
288                    scope: row.try_read::<Scope, _>("scope")?,
289                },
290                Policy::Block => SeedingPolicy::Block,
291            };
292            return Ok(Some(SeedPolicy { rid: *id, policy }));
293        }
294        Ok(None)
295    }
296
297    /// Get node follow policies.
298    pub fn follow_policies(&self) -> Result<FollowPolicies<'_>, Error> {
299        let stmt = self
300            .db
301            .prepare("SELECT id, alias, policy FROM `following`")?;
302        Ok(FollowPolicies {
303            inner: stmt.into_iter(),
304        })
305    }
306
307    /// Get repository seed policies.
308    pub fn seed_policies(&self) -> Result<SeedPolicies<'_>, Error> {
309        let stmt = self.db.prepare("SELECT id, scope, policy FROM `seeding`")?;
310        Ok(SeedPolicies {
311            inner: stmt.into_iter(),
312        })
313    }
314
315    pub fn nodes_by_alias<'a>(&'a self, alias: &Alias) -> Result<NodeAliasIter<'a>, Error> {
316        let mut stmt = self
317            .db
318            .prepare("SELECT id, alias FROM `following` WHERE UPPER(alias) LIKE ?")?;
319        let query = format!("%{}%", alias.as_str().to_uppercase());
320        stmt.bind((1, sql::Value::String(query)))?;
321        Ok(NodeAliasIter {
322            inner: stmt.into_iter(),
323        })
324    }
325}
326
327pub struct FollowPolicies<'a> {
328    inner: sql::CursorWithOwnership<'a>,
329}
330
331impl Iterator for FollowPolicies<'_> {
332    type Item = Result<FollowPolicy, Error>;
333
334    fn next(&mut self) -> Option<Self::Item> {
335        let row = self.inner.next()?;
336        let Ok(row) = row else { return self.next() };
337
338        let id = match row.try_read("id") {
339            Ok(id) => id,
340            Err(err) => return Some(Err(err.into())),
341        };
342
343        let alias = match row.try_read::<&str, _>("alias") {
344            Ok(alias) => alias.to_owned(),
345            Err(err) => return Some(Err(err.into())),
346        };
347
348        let alias = alias
349            .is_empty()
350            .not()
351            .then_some(alias.to_owned())
352            .and_then(|s| Alias::from_str(&s).ok());
353
354        let policy = match row.try_read::<Policy, _>("policy") {
355            Ok(policy) => policy,
356            Err(err) => return Some(Err(err.into())),
357        };
358
359        Some(Ok(FollowPolicy {
360            nid: id,
361            alias,
362            policy,
363        }))
364    }
365}
366
367pub struct SeedPolicies<'a> {
368    inner: sql::CursorWithOwnership<'a>,
369}
370
371impl Iterator for SeedPolicies<'_> {
372    type Item = Result<SeedPolicy, Error>;
373
374    fn next(&mut self) -> Option<Self::Item> {
375        let row = self.inner.next()?;
376        let Ok(row) = row else { return self.next() };
377
378        let id = match row.try_read("id") {
379            Ok(id) => id,
380            Err(err) => return Some(Err(err.into())),
381        };
382
383        let policy = match row.try_read::<Policy, _>("policy") {
384            Ok(policy) => policy,
385            Err(err) => return Some(Err(err.into())),
386        };
387
388        match policy {
389            Policy::Allow => match row.try_read::<Scope, _>("scope") {
390                Ok(scope) => Some(Ok(SeedPolicy {
391                    rid: id,
392                    policy: SeedingPolicy::Allow { scope },
393                })),
394                Err(err) => Some(Err(err.into())),
395            },
396            Policy::Block => Some(Ok(SeedPolicy {
397                rid: id,
398                policy: SeedingPolicy::Block,
399            })),
400        }
401    }
402}
403
404pub struct NodeAliasIter<'a> {
405    inner: sql::CursorWithOwnership<'a>,
406}
407
408impl NodeAliasIter<'_> {
409    fn parse_row(row: sql::Row) -> Result<(NodeId, Alias), Error> {
410        let nid = row.try_read::<NodeId, _>("id")?;
411        let alias = row.try_read::<Alias, _>("alias")?;
412        Ok((nid, alias))
413    }
414}
415
416impl Iterator for NodeAliasIter<'_> {
417    type Item = Result<(NodeId, Alias), Error>;
418
419    fn next(&mut self) -> Option<Self::Item> {
420        let row = self.inner.next()?;
421        Some(row.map_err(Error::from).and_then(Self::parse_row))
422    }
423}
424
425impl<T> AliasStore for Store<T> {
426    /// Retrieve `alias` of given node.
427    /// Calls `Self::node_policy` under the hood.
428    fn alias(&self, nid: &NodeId) -> Option<Alias> {
429        self.follow_policy(nid)
430            .map(|node| node.and_then(|n| n.alias))
431            .unwrap_or(None)
432    }
433
434    fn reverse_lookup(&self, alias: &Alias) -> BTreeMap<Alias, BTreeSet<NodeId>> {
435        let Ok(iter) = self.nodes_by_alias(alias) else {
436            return BTreeMap::new();
437        };
438        iter.flatten()
439            .fold(BTreeMap::new(), |mut result, (node, alias)| {
440                let nodes = result.entry(alias).or_default();
441                nodes.insert(node);
442                result
443            })
444    }
445}
446
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod test {
    use crate::{assert_matches, node};

    use super::*;
    use crate::test::arbitrary;

    #[test]
    fn test_follow_and_unfollow_node() {
        let id = arbitrary::gen::<NodeId>(1);
        let mut db = Store::open(":memory:").unwrap();
        let eve = Alias::new("eve");

        assert!(db.follow(&id, Some(&eve)).unwrap());
        assert!(db.is_following(&id).unwrap());
        // Following again with the same alias changes nothing.
        assert!(!db.follow(&id, Some(&eve)).unwrap());
        assert!(db.unfollow(&id).unwrap());
        assert!(!db.is_following(&id).unwrap());
    }

    #[test]
    fn test_seed_and_unseed_repo() {
        let id = arbitrary::gen::<RepoId>(1);
        let mut db = Store::open(":memory:").unwrap();

        assert!(db.seed(&id, Scope::All).unwrap());
        assert!(db.is_seeding(&id).unwrap());
        // Seeding again with the same scope changes nothing.
        assert!(!db.seed(&id, Scope::All).unwrap());
        assert!(db.unseed(&id).unwrap());
        assert!(!db.is_seeding(&id).unwrap());
    }

    #[test]
    fn test_node_policies() {
        let ids = arbitrary::vec::<NodeId>(3);
        let mut db = Store::open(":memory:").unwrap();

        for id in &ids {
            assert!(db.follow(id, None).unwrap());
        }
        // The query has no ORDER BY; this test expects rows in insertion order.
        let mut entries = db.follow_policies().unwrap();
        assert_matches!(entries.next(), Some(Ok(FollowPolicy { nid, .. })) if nid == ids[0]);
        assert_matches!(entries.next(), Some(Ok(FollowPolicy { nid, .. })) if nid == ids[1]);
        assert_matches!(entries.next(), Some(Ok(FollowPolicy { nid, .. })) if nid == ids[2]);
        // No extra entries.
        assert!(entries.next().is_none());
    }

    #[test]
    fn test_repo_policies() {
        let ids = arbitrary::vec::<RepoId>(3);
        let mut db = Store::open(":memory:").unwrap();

        for id in &ids {
            assert!(db.seed(id, Scope::All).unwrap());
        }
        // The query has no ORDER BY; this test expects rows in insertion order.
        let mut entries = db.seed_policies().unwrap();
        assert_matches!(entries.next(), Some(Ok(SeedPolicy { rid, .. })) if rid == ids[0]);
        assert_matches!(entries.next(), Some(Ok(SeedPolicy { rid, .. })) if rid == ids[1]);
        assert_matches!(entries.next(), Some(Ok(SeedPolicy { rid, .. })) if rid == ids[2]);
        // No extra entries.
        assert!(entries.next().is_none());
    }

    #[test]
    fn test_update_alias() {
        let id = arbitrary::gen::<NodeId>(1);
        let mut db = Store::open(":memory:").unwrap();

        assert!(db.follow(&id, Some(&Alias::new("eve"))).unwrap());
        assert_eq!(
            db.follow_policy(&id).unwrap().unwrap().alias,
            Some(Alias::from_str("eve").unwrap())
        );
        // `None` clears the alias.
        assert!(db.follow(&id, None).unwrap());
        assert_eq!(db.follow_policy(&id).unwrap().unwrap().alias, None);
        assert!(!db.follow(&id, None).unwrap());
        assert!(db.follow(&id, Some(&Alias::new("alice"))).unwrap());
        assert_eq!(
            db.follow_policy(&id).unwrap().unwrap().alias,
            Some(Alias::new("alice"))
        );
    }

    #[test]
    fn test_update_scope() {
        let id = arbitrary::gen::<RepoId>(1);
        let mut db = Store::open(":memory:").unwrap();

        assert!(db.seed(&id, Scope::All).unwrap());
        assert_eq!(
            db.seed_policy(&id).unwrap().unwrap().scope(),
            Some(Scope::All)
        );
        assert!(db.seed(&id, Scope::Followed).unwrap());
        assert_eq!(
            db.seed_policy(&id).unwrap().unwrap().scope(),
            Some(Scope::Followed)
        );
    }

    #[test]
    fn test_repo_policy() {
        let id = arbitrary::gen::<RepoId>(1);
        let mut db = Store::open(":memory:").unwrap();

        assert!(db.seed(&id, Scope::All).unwrap());
        assert!(db.seed_policy(&id).unwrap().unwrap().is_allow());
        assert!(db.set_seed_policy(&id, Policy::Block).unwrap());
        assert!(!db.seed_policy(&id).unwrap().unwrap().is_allow());
        assert_eq!(db.seed_policy(&id).unwrap().unwrap().scope(), None);
        // A blocked repository is not considered seeded.
        assert!(!db.is_seeding(&id).unwrap());
    }

    #[test]
    fn test_node_policy() {
        let id = arbitrary::gen::<NodeId>(1);
        let mut db = Store::open(":memory:").unwrap();

        assert!(db.follow(&id, None).unwrap());
        assert_eq!(
            db.follow_policy(&id).unwrap().unwrap().policy,
            Policy::Allow
        );
        assert!(db.set_follow_policy(&id, Policy::Block).unwrap());
        assert_eq!(
            db.follow_policy(&id).unwrap().unwrap().policy,
            Policy::Block
        );
        // A blocked node is not considered followed.
        assert!(!db.is_following(&id).unwrap());
    }

    #[test]
    fn test_node_aliases() {
        let mut db = Store::open(":memory:").unwrap();
        let input = node::properties::AliasInput::new();
        let (short, short_ids) = input.short();
        let (long, long_ids) = input.long();

        for id in short_ids {
            db.follow(id, Some(short)).unwrap();
        }

        for id in long_ids {
            db.follow(id, Some(long)).unwrap();
        }

        node::properties::test_reverse_lookup(&db, input)
    }
}