radicle/node/notifications/store.rs

1#![allow(clippy::type_complexity)]
2use std::marker::PhantomData;
3use std::num::TryFromIntError;
4use std::path::Path;
5use std::sync::Arc;
6use std::{fmt, io, str::FromStr, time};
7
8use localtime::LocalTime;
9use sqlite as sql;
10use thiserror::Error;
11
12use crate::git;
13use crate::git::{Oid, RefError, RefString};
14use crate::prelude::RepoId;
15use crate::sql::transaction;
16use crate::storage::RefUpdate;
17
18use super::{
19    Notification, NotificationId, NotificationKind, NotificationKindError, NotificationStatus,
20};
21
/// Upper bound on how long a read-only connection blocks on a locked database
/// before the read fails.
const DB_READ_TIMEOUT: time::Duration = time::Duration::from_millis(3_000);
/// Upper bound on how long a read-write connection blocks on a locked database
/// before the write fails.
const DB_WRITE_TIMEOUT: time::Duration = time::Duration::from_millis(6_000);
26
27#[derive(Error, Debug)]
28pub enum Error {
29    /// I/O error.
30    #[error("i/o error: {0}")]
31    Io(#[from] io::Error),
32    /// An Internal error.
33    #[error("internal error: {0}")]
34    Internal(#[from] sql::Error),
35    /// Timestamp error.
36    #[error("invalid timestamp: {0}")]
37    Timestamp(#[from] TryFromIntError),
38    /// Invalid Git ref name.
39    #[error("invalid ref name: {0}")]
40    RefName(#[from] RefError),
41    /// Invalid Git ref format.
42    #[error("invalid ref format: {0}")]
43    RefFormat(#[from] git_ext::ref_format::Error),
44    /// Invalid notification kind.
45    #[error("invalid notification kind: {0}")]
46    NotificationKind(#[from] NotificationKindError),
47    /// Not found.
48    #[error("notification {0} not found")]
49    NotificationNotFound(NotificationId),
50    /// Internal unit overflow.
51    #[error("the unit overflowed")]
52    UnitOverflow,
53}
54
/// Read-only type witness.
///
/// A zero-sized marker used as the type parameter of `Store` to expose only the
/// read API.
#[derive(Clone, Copy, Debug)]
pub struct Read;
/// Read-write type witness.
///
/// A zero-sized marker used as the type parameter of `Store` to expose both the
/// read and the write API.
#[derive(Clone, Copy, Debug)]
pub struct Write;
61
62/// Notifications store.
63#[derive(Clone)]
64pub struct Store<T> {
65    db: Arc<sql::ConnectionThreadSafe>,
66    marker: PhantomData<T>,
67}
68
69impl<T> fmt::Debug for Store<T> {
70    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
71        write!(f, "Store(..)")
72    }
73}
74
75impl Store<Read> {
76    const SCHEMA: &'static str = include_str!("schema.sql");
77
78    /// Same as [`Self::open`], but in read-only mode. This is useful to have multiple
79    /// open databases, as no locking is required.
80    pub fn reader<P: AsRef<Path>>(path: P) -> Result<Self, Error> {
81        let mut db = sql::Connection::open_thread_safe_with_flags(
82            path,
83            sqlite::OpenFlags::new().with_read_only(),
84        )?;
85        db.set_busy_timeout(DB_READ_TIMEOUT.as_millis() as usize)?;
86        db.execute(Self::SCHEMA)?;
87
88        Ok(Self {
89            db: Arc::new(db),
90            marker: PhantomData,
91        })
92    }
93
94    /// Create a new in-memory address book.
95    pub fn memory() -> Result<Self, Error> {
96        let db = sql::Connection::open_thread_safe_with_flags(
97            ":memory:",
98            sqlite::OpenFlags::new().with_read_only(),
99        )?;
100        db.execute(Self::SCHEMA)?;
101
102        Ok(Self {
103            db: Arc::new(db),
104            marker: PhantomData,
105        })
106    }
107}
108
109impl Store<Write> {
110    const SCHEMA: &'static str = include_str!("schema.sql");
111
112    /// Open a policy store at the given path. Creates a new store if it
113    /// doesn't exist.
114    pub fn open<P: AsRef<Path>>(path: P) -> Result<Self, Error> {
115        let mut db = sql::Connection::open_thread_safe(path)?;
116        db.set_busy_timeout(DB_WRITE_TIMEOUT.as_millis() as usize)?;
117        db.execute(Self::SCHEMA)?;
118
119        Ok(Self {
120            db: Arc::new(db),
121            marker: PhantomData,
122        })
123    }
124
125    /// Create a new in-memory address book.
126    pub fn memory() -> Result<Self, Error> {
127        let db = sql::Connection::open_thread_safe(":memory:")?;
128        db.execute(Self::SCHEMA)?;
129
130        Ok(Self {
131            db: Arc::new(db),
132            marker: PhantomData,
133        })
134    }
135
136    /// Get a read-only version of this store.
137    pub fn read_only(self) -> Store<Read> {
138        Store {
139            db: self.db,
140            marker: PhantomData,
141        }
142    }
143
144    /// Set notification read status for the given notifications.
145    pub fn set_status(
146        &mut self,
147        status: NotificationStatus,
148        ids: &[NotificationId],
149    ) -> Result<bool, Error> {
150        transaction(&self.db, |_| {
151            let mut stmt = self.db.prepare(
152                "UPDATE `repository-notifications`
153                 SET status = ?1
154                 WHERE rowid = ?2",
155            )?;
156            for id in ids {
157                stmt.bind((1, &status))?;
158                stmt.bind((2, *id as i64))?;
159                stmt.next()?;
160                stmt.reset()?;
161            }
162            Ok(self.db.change_count() > 0)
163        })
164    }
165
166    /// Insert a notification. Resets the status to *unread* if it already exists.
167    pub fn insert(
168        &mut self,
169        repo: &RepoId,
170        update: &RefUpdate,
171        timestamp: LocalTime,
172    ) -> Result<bool, Error> {
173        let mut stmt = self.db.prepare(
174            "INSERT INTO `repository-notifications` (repo, ref, old, new, timestamp)
175             VALUES (?1, ?2, ?3, ?4, ?5)
176             ON CONFLICT DO UPDATE
177             SET old = ?3, new = ?4, timestamp = ?5, status = null",
178        )?;
179        let old = update.old().map(|o| o.to_string());
180        let new = update.new().map(|o| o.to_string());
181
182        stmt.bind((1, repo))?;
183        stmt.bind((2, update.name().as_str()))?;
184        stmt.bind((3, old.as_deref()))?;
185        stmt.bind((4, new.as_deref()))?;
186        stmt.bind((5, i64::try_from(timestamp.as_millis())?))?;
187        stmt.next()?;
188
189        Ok(self.db.change_count() > 0)
190    }
191
192    /// Delete the given notifications.
193    pub fn clear(&mut self, ids: &[NotificationId]) -> Result<usize, Error> {
194        transaction(&self.db, |_| {
195            let mut stmt = self
196                .db
197                .prepare("DELETE FROM `repository-notifications` WHERE rowid = ?")?;
198
199            // N.b. we need to keep the count manually since the change count
200            // will always be `1` because of each reset.
201            let mut count = 0;
202            for id in ids {
203                stmt.bind((1, *id as i64))?;
204                stmt.next()?;
205                stmt.reset()?;
206                count += self.db.change_count();
207            }
208            Ok(count)
209        })
210    }
211
212    /// Delete all notifications of a repo.
213    pub fn clear_by_repo(&mut self, repo: &RepoId) -> Result<usize, Error> {
214        let mut stmt = self
215            .db
216            .prepare("DELETE FROM `repository-notifications` WHERE repo = ?")?;
217
218        stmt.bind((1, repo))?;
219        stmt.next()?;
220
221        Ok(self.db.change_count())
222    }
223
224    /// Delete all notifications from all repos.
225    pub fn clear_all(&mut self) -> Result<usize, Error> {
226        self.db
227            .prepare("DELETE FROM `repository-notifications`")?
228            .next()?;
229
230        Ok(self.db.change_count())
231    }
232}
233
234/// `Read` methods for `Store`. This implies that a
235/// `Store<Write>` can access these functions as well.
236impl<T> Store<T> {
237    /// Get a specific notification.
238    pub fn get(&self, id: NotificationId) -> Result<Notification, Error> {
239        let mut stmt = self.db.prepare(
240            "SELECT rowid, repo, ref, old, new, status, timestamp
241             FROM `repository-notifications`
242             WHERE rowid = ?",
243        )?;
244        stmt.bind((1, id as i64))?;
245
246        if let Some(Ok(row)) = stmt.into_iter().next() {
247            return parse::notification(row);
248        }
249        Err(Error::NotificationNotFound(id))
250    }
251
252    /// Get all notifications.
253    pub fn all(&self) -> Result<impl Iterator<Item = Result<Notification, Error>> + '_, Error> {
254        let stmt = self.db.prepare(
255            "SELECT rowid, repo, ref, old, new, status, timestamp
256             FROM `repository-notifications`
257             ORDER BY timestamp DESC",
258        )?;
259
260        Ok(stmt.into_iter().map(move |row| {
261            let row = row?;
262            parse::notification(row)
263        }))
264    }
265
266    // Get notifications that were created between the given times: `since <= t < until`.
267    pub fn by_timestamp(
268        &self,
269        since: LocalTime,
270        until: LocalTime,
271    ) -> Result<impl Iterator<Item = Result<Notification, Error>> + '_, Error> {
272        let mut stmt = self.db.prepare(
273            "SELECT rowid, repo, ref, old, new, status, timestamp
274             FROM `repository-notifications`
275             WHERE timestamp >= ?1 AND timestamp < ?2
276             ORDER BY timestamp",
277        )?;
278        let since = i64::try_from(since.as_millis())?;
279        let until = i64::try_from(until.as_millis())?;
280
281        stmt.bind((1, since))?;
282        stmt.bind((2, until))?;
283
284        Ok(stmt.into_iter().map(move |row| {
285            let row = row?;
286            parse::notification(row)
287        }))
288    }
289
290    /// Get notifications by repo.
291    pub fn by_repo(
292        &self,
293        repo: &RepoId,
294        order_by: &str,
295    ) -> Result<impl Iterator<Item = Result<Notification, Error>> + '_, Error> {
296        let mut stmt = self.db.prepare(format!(
297            "SELECT rowid, repo, ref, old, new, status, timestamp
298             FROM `repository-notifications`
299             WHERE repo = ?
300             ORDER BY {order_by} DESC",
301        ))?;
302        stmt.bind((1, repo))?;
303
304        Ok(stmt.into_iter().map(move |row| {
305            let row = row?;
306            parse::notification(row)
307        }))
308    }
309
310    /// Get the total notification count.
311    pub fn count(&self) -> Result<usize, Error> {
312        let stmt = self
313            .db
314            .prepare("SELECT COUNT(*) FROM `repository-notifications`")?;
315
316        let count: i64 = stmt
317            .into_iter()
318            .next()
319            .expect("COUNT will always return a single row")?
320            .read(0);
321        let count: usize = count.try_into().map_err(|_| Error::UnitOverflow)?;
322
323        Ok(count)
324    }
325
326    /// Get the total notification count by repos.
327    pub fn counts_by_repo(
328        &self,
329    ) -> Result<impl Iterator<Item = Result<(RepoId, usize), Error>> + '_, Error> {
330        let stmt = self.db.prepare(
331            "SELECT repo, COUNT(*) as count
332             FROM `repository-notifications`
333             GROUP BY repo",
334        )?;
335
336        Ok(stmt.into_iter().map(|row| {
337            let row = row?;
338            let count = row.try_read::<i64, _>("count")? as usize;
339            let repo = row.try_read::<RepoId, _>("repo")?;
340
341            Ok((repo, count))
342        }))
343    }
344
345    /// Get the notification count for the given repo.
346    pub fn count_by_repo(&self, repo: &RepoId) -> Result<usize, Error> {
347        let mut stmt = self
348            .db
349            .prepare("SELECT COUNT(*) FROM `repository-notifications` WHERE repo = ?")?;
350
351        stmt.bind((1, repo))?;
352
353        let count: i64 = stmt
354            .into_iter()
355            .next()
356            .expect("COUNT will always return a single row")?
357            .read(0);
358        let count: usize = count.try_into().map_err(|_| Error::UnitOverflow)?;
359
360        Ok(count)
361    }
362}
363
364mod parse {
365    use super::*;
366
367    pub fn notification(row: sql::Row) -> Result<Notification, Error> {
368        let id = row.try_read::<i64, _>("rowid")? as NotificationId;
369        let repo = row.try_read::<RepoId, _>("repo")?;
370        let refstr = row.try_read::<&str, _>("ref")?;
371        let status = row.try_read::<NotificationStatus, _>("status")?;
372        let old = row
373            .try_read::<Option<&str>, _>("old")?
374            .map(|oid| {
375                Oid::from_str(oid).map_err(|e| {
376                    Error::Internal(sql::Error {
377                        code: None,
378                        message: Some(format!("sql: invalid oid in `old` column: {oid:?}: {e}")),
379                    })
380                })
381            })
382            .unwrap_or(Ok(git::raw::Oid::zero().into()))?;
383        let new = row
384            .try_read::<Option<&str>, _>("new")?
385            .map(|oid| {
386                Oid::from_str(oid).map_err(|e| {
387                    Error::Internal(sql::Error {
388                        code: None,
389                        message: Some(format!("sql: invalid oid in `new` column: {oid:?}: {e}")),
390                    })
391                })
392            })
393            .unwrap_or(Ok(git::raw::Oid::zero().into()))?;
394        let update = RefUpdate::from(RefString::try_from(refstr)?, old, new);
395        let (namespace, qualified) = git::parse_ref(refstr)?;
396        let timestamp = row.try_read::<i64, _>("timestamp")?;
397        let timestamp = LocalTime::from_millis(timestamp as u128);
398        let qualified = qualified.to_owned();
399        let kind = NotificationKind::try_from(qualified.clone())?;
400
401        Ok(Notification {
402            id,
403            repo,
404            update,
405            remote: namespace,
406            qualified,
407            status,
408            kind,
409            timestamp,
410        })
411    }
412}
413
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod test {
    use radicle_git_ext::ref_format::{qualified, refname};

    use super::*;
    use crate::{cob, node::NodeId, test::arbitrary};

    #[test]
    fn test_clear() {
        let mut db = Store::open(":memory:").unwrap();
        let repo = arbitrary::gen::<RepoId>(1);
        let old = arbitrary::oid();
        let time = LocalTime::from_millis(32188142);
        let master = arbitrary::oid();

        for i in 0..3 {
            let update = RefUpdate::Updated {
                name: format!("refs/heads/feature/{i}").try_into().unwrap(),
                old,
                new: master,
            };
            assert!(db.insert(&repo, &update, time).unwrap());
        }
        assert_eq!(db.count().unwrap(), 3);
        assert_eq!(db.count_by_repo(&repo).unwrap(), 3);
        assert_eq!(db.clear_by_repo(&repo).unwrap(), 3);
        assert_eq!(db.count().unwrap(), 0);
        assert_eq!(db.count_by_repo(&repo).unwrap(), 0);
    }

    #[test]
    fn test_clear_by_ids() {
        let mut db = Store::open(":memory:").unwrap();
        let repo = arbitrary::gen::<RepoId>(1);
        let oid = arbitrary::oid();
        let time = LocalTime::from_millis(32188142);

        for i in 0..3 {
            let update = RefUpdate::Created {
                name: format!("refs/heads/feature/{i}").try_into().unwrap(),
                oid,
            };
            assert!(db.insert(&repo, &update, time).unwrap());
        }
        // Unknown ids are skipped and not counted.
        assert_eq!(db.clear(&[1, 3, 42]).unwrap(), 2);
        assert_eq!(db.count().unwrap(), 1);
        assert_eq!(db.get(2).unwrap().id, 2);
        assert!(matches!(db.get(1), Err(Error::NotificationNotFound(1))));
    }

    #[test]
    fn test_by_timestamp() {
        let mut db = Store::open(":memory:").unwrap();
        let repo = arbitrary::gen::<RepoId>(1);
        let oid = arbitrary::oid();
        let t1 = LocalTime::from_millis(1000);
        let t2 = LocalTime::from_millis(2000);
        let t3 = LocalTime::from_millis(3000);

        for (i, t) in [t1, t2, t3].into_iter().enumerate() {
            let update = RefUpdate::Created {
                name: format!("refs/heads/feature/{i}").try_into().unwrap(),
                oid,
            };
            assert!(db.insert(&repo, &update, t).unwrap());
        }
        // The range is half-open: `since <= t < until`, oldest first.
        let timestamps = db
            .by_timestamp(t1, t3)
            .unwrap()
            .map(|n| n.unwrap().timestamp)
            .collect::<Vec<_>>();

        assert_eq!(timestamps, vec![t1, t2]);
    }

    #[test]
    fn test_counts_by_repo() {
        let mut db = Store::open(":memory:").unwrap();
        let repo1 = arbitrary::gen::<RepoId>(1);
        let repo2 = arbitrary::gen::<RepoId>(1);
        let oid = arbitrary::oid();
        let time = LocalTime::from_millis(32188142);

        let update1 = RefUpdate::Created {
            name: refname!("refs/heads/feature/1"),
            oid,
        };
        let update2 = RefUpdate::Created {
            name: refname!("refs/heads/feature/2"),
            oid,
        };
        let update3 = RefUpdate::Created {
            name: refname!("refs/heads/feature/3"),
            oid,
        };
        assert!(db.insert(&repo1, &update1, time).unwrap());
        assert!(db.insert(&repo1, &update2, time).unwrap());
        assert!(db.insert(&repo2, &update3, time).unwrap());

        let counts = db
            .counts_by_repo()
            .unwrap()
            .collect::<Result<std::collections::HashMap<_, _>, _>>()
            .unwrap();

        assert_eq!(counts.get(&repo1).unwrap(), &2);
        assert_eq!(counts.get(&repo2).unwrap(), &1);
    }

    #[test]
    fn test_branch_notifications() {
        let repo = arbitrary::gen::<RepoId>(1);
        let old = arbitrary::oid();
        let master = arbitrary::oid();
        let other = arbitrary::oid();
        let time1 = LocalTime::from_millis(32188142);
        let time2 = LocalTime::from_millis(32189874);
        let time3 = LocalTime::from_millis(32189879);
        let mut db = Store::open(":memory:").unwrap();

        let update1 = RefUpdate::Updated {
            name: refname!("refs/heads/master"),
            old,
            new: master,
        };
        let update2 = RefUpdate::Created {
            name: refname!("refs/heads/other"),
            oid: other,
        };
        let update3 = RefUpdate::Deleted {
            name: refname!("refs/heads/dev"),
            oid: other,
        };
        assert!(db.insert(&repo, &update1, time1).unwrap());
        assert!(db.insert(&repo, &update2, time2).unwrap());
        assert!(db.insert(&repo, &update3, time3).unwrap());

        let mut notifs = db.by_repo(&repo, "timestamp").unwrap();

        assert_eq!(
            notifs.next().unwrap().unwrap(),
            Notification {
                id: 3,
                repo,
                remote: None,
                qualified: qualified!("refs/heads/dev"),
                update: update3,
                kind: NotificationKind::Branch {
                    name: refname!("dev")
                },
                status: NotificationStatus::Unread,
                timestamp: time3,
            }
        );
        assert_eq!(
            notifs.next().unwrap().unwrap(),
            Notification {
                id: 2,
                repo,
                remote: None,
                qualified: qualified!("refs/heads/other"),
                update: update2,
                kind: NotificationKind::Branch {
                    name: refname!("other")
                },
                status: NotificationStatus::Unread,
                timestamp: time2,
            }
        );
        assert_eq!(
            notifs.next().unwrap().unwrap(),
            Notification {
                id: 1,
                repo,
                remote: None,
                qualified: qualified!("refs/heads/master"),
                update: update1,
                kind: NotificationKind::Branch {
                    name: refname!("master")
                },
                status: NotificationStatus::Unread,
                timestamp: time1,
            }
        );
        assert!(notifs.next().is_none());
    }

    #[test]
    fn test_notification_status() {
        let repo = arbitrary::gen::<RepoId>(1);
        let oid = arbitrary::oid();
        let time = LocalTime::from_millis(32188142);
        let mut db = Store::open(":memory:").unwrap();

        let update1 = RefUpdate::Created {
            name: refname!("refs/heads/feature/1"),
            oid,
        };
        let update2 = RefUpdate::Created {
            name: refname!("refs/heads/feature/2"),
            oid,
        };
        let update3 = RefUpdate::Created {
            name: refname!("refs/heads/feature/3"),
            oid,
        };
        assert!(db.insert(&repo, &update1, time).unwrap());
        assert!(db.insert(&repo, &update2, time).unwrap());
        assert!(db.insert(&repo, &update3, time).unwrap());
        assert!(db
            .set_status(NotificationStatus::ReadAt(time), &[1, 2, 3])
            .unwrap());

        let mut notifs = db.by_repo(&repo, "timestamp").unwrap();

        assert_eq!(
            notifs.next().unwrap().unwrap().status,
            NotificationStatus::ReadAt(time),
        );
        assert_eq!(
            notifs.next().unwrap().unwrap().status,
            NotificationStatus::ReadAt(time),
        );
        assert_eq!(
            notifs.next().unwrap().unwrap().status,
            NotificationStatus::ReadAt(time),
        );
    }

    #[test]
    fn test_duplicate_notifications() {
        let repo = arbitrary::gen::<RepoId>(1);
        let old = arbitrary::oid();
        let master1 = arbitrary::oid();
        let master2 = arbitrary::oid();
        let time1 = LocalTime::from_millis(32188142);
        let time2 = LocalTime::from_millis(32189874);
        let mut db = Store::open(":memory:").unwrap();

        let update1 = RefUpdate::Updated {
            name: refname!("refs/heads/master"),
            old,
            new: master1,
        };
        let update2 = RefUpdate::Updated {
            name: refname!("refs/heads/master"),
            old: master1,
            new: master2,
        };
        assert!(db.insert(&repo, &update1, time1).unwrap());
        assert!(db
            .set_status(NotificationStatus::ReadAt(time1), &[1])
            .unwrap());
        assert!(db.insert(&repo, &update2, time2).unwrap());

        let mut notifs = db.by_repo(&repo, "timestamp").unwrap();

        assert_eq!(
            notifs.next().unwrap().unwrap(),
            Notification {
                id: 1,
                repo,
                remote: None,
                qualified: qualified!("refs/heads/master"),
                update: update2,
                kind: NotificationKind::Branch {
                    name: refname!("master")
                },
                // Status is reset to "unread".
                status: NotificationStatus::Unread,
                timestamp: time2,
            }
        );
        assert!(notifs.next().is_none());
    }

    #[test]
    fn test_cob_notifications() {
        let repo = arbitrary::gen::<RepoId>(1);
        let old = arbitrary::oid();
        let new = arbitrary::oid();
        let timestamp = LocalTime::from_millis(32189874);
        let nid: NodeId = "z6MknSLrJoTcukLrE435hVNQT4JUhbvWLX4kUzqkEStBU8Vi"
            .parse()
            .unwrap();
        let mut db = Store::open(":memory:").unwrap();
        let qualified =
            qualified!("refs/cobs/xyz.radicle.issue/d87dcfe8c2b3200e78b128d9b959cfdf7063fefe");
        let namespaced = qualified.with_namespace((&nid).into());
        let update = RefUpdate::Updated {
            name: namespaced.to_ref_string(),
            old,
            new,
        };

        assert!(db.insert(&repo, &update, timestamp).unwrap());

        let mut notifs = db.by_repo(&repo, "timestamp").unwrap();

        assert_eq!(
            notifs.next().unwrap().unwrap(),
            Notification {
                id: 1,
                repo,
                remote: Some(
                    "z6MknSLrJoTcukLrE435hVNQT4JUhbvWLX4kUzqkEStBU8Vi"
                        .parse()
                        .unwrap()
                ),
                qualified,
                update,
                kind: NotificationKind::Cob {
                    typed_id: cob::TypedId {
                        type_name: cob::issue::TYPENAME.clone(),
                        id: "d87dcfe8c2b3200e78b128d9b959cfdf7063fefe".parse().unwrap(),
                    },
                },
                status: NotificationStatus::Unread,
                timestamp,
            }
        );
        assert!(notifs.next().is_none());
    }
}