1#![allow(clippy::type_complexity)]
2use std::collections::{BTreeMap, BTreeSet};
3use std::marker::PhantomData;
4use std::path::Path;
5use std::{fmt, io, ops::Not as _, str::FromStr, time};
6
7use sqlite as sql;
8use thiserror::Error;
9
10use crate::node::{Alias, AliasStore};
11use crate::prelude::{NodeId, RepoId};
12
13use super::{FollowPolicy, Policy, Scope, SeedPolicy, SeedingPolicy};
14
/// Busy timeout applied to read-only connections (see `Store::<Read>::reader`).
const DB_READ_TIMEOUT: time::Duration = time::Duration::from_millis(3_000);
/// Busy timeout applied to read-write connections (see `Store::<Write>::open`).
const DB_WRITE_TIMEOUT: time::Duration = time::Duration::from_millis(6_000);
19
20#[derive(Error, Debug)]
21pub enum Error {
22 #[error("i/o error: {0}")]
24 Io(#[from] io::Error),
25 #[error("internal error: {0}")]
27 Internal(#[from] sql::Error),
28}
29
30pub struct Read;
32pub struct Write;
34
35pub type StoreReader = Store<Read>;
37pub type StoreWriter = Store<Write>;
39
40pub struct Store<T> {
42 db: sql::Connection,
43 _marker: PhantomData<T>,
44}
45
46impl<T> fmt::Debug for Store<T> {
47 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
48 write!(f, "Store(..)")
49 }
50}
51
52impl Store<Read> {
53 const SCHEMA: &'static str = include_str!("schema.sql");
54
55 pub fn reader<P: AsRef<Path>>(path: P) -> Result<Self, Error> {
58 let mut db =
59 sql::Connection::open_with_flags(path, sqlite::OpenFlags::new().with_read_only())?;
60 db.set_busy_timeout(DB_READ_TIMEOUT.as_millis() as usize)?;
61 db.execute(Self::SCHEMA)?;
62
63 Ok(Self {
64 db,
65 _marker: PhantomData,
66 })
67 }
68
69 pub fn memory() -> Result<Self, Error> {
71 let db = sql::Connection::open_with_flags(
72 ":memory:",
73 sqlite::OpenFlags::new().with_read_only(),
74 )?;
75 db.execute(Self::SCHEMA)?;
76
77 Ok(Self {
78 db,
79 _marker: PhantomData,
80 })
81 }
82}
83
84impl Store<Write> {
85 const SCHEMA: &'static str = include_str!("schema.sql");
86
87 pub fn open<P: AsRef<Path>>(path: P) -> Result<Self, Error> {
90 let mut db = sql::Connection::open(path)?;
91 db.set_busy_timeout(DB_WRITE_TIMEOUT.as_millis() as usize)?;
92 db.execute(Self::SCHEMA)?;
93
94 Ok(Self {
95 db,
96 _marker: PhantomData,
97 })
98 }
99
100 pub fn memory() -> Result<Self, Error> {
102 let db = sql::Connection::open(":memory:")?;
103 db.execute(Self::SCHEMA)?;
104
105 Ok(Self {
106 db,
107 _marker: PhantomData,
108 })
109 }
110
111 pub fn read_only(self) -> StoreReader {
113 Store {
114 db: self.db,
115 _marker: PhantomData,
116 }
117 }
118
119 pub fn follow(&mut self, id: &NodeId, alias: Option<&Alias>) -> Result<bool, Error> {
121 let mut stmt = self.db.prepare(
122 "INSERT INTO `following` (id, alias)
123 VALUES (?1, ?2)
124 ON CONFLICT DO UPDATE
125 SET alias = ?2 WHERE alias != ?2",
126 )?;
127
128 stmt.bind((1, id))?;
129 stmt.bind((2, alias.map_or("", |alias| alias.as_str())))?;
130 stmt.next()?;
131
132 Ok(self.db.change_count() > 0)
133 }
134
135 pub fn seed(&mut self, id: &RepoId, scope: Scope) -> Result<bool, Error> {
137 let mut stmt = self.db.prepare(
138 "INSERT INTO `seeding` (id, scope)
139 VALUES (?1, ?2)
140 ON CONFLICT DO UPDATE
141 SET scope = ?2 WHERE scope != ?2",
142 )?;
143
144 stmt.bind((1, id))?;
145 stmt.bind((2, scope))?;
146 stmt.next()?;
147
148 Ok(self.db.change_count() > 0)
149 }
150
151 pub fn set_follow_policy(&mut self, id: &NodeId, policy: Policy) -> Result<bool, Error> {
153 let mut stmt = self.db.prepare(
154 "INSERT INTO `following` (id, policy)
155 VALUES (?1, ?2)
156 ON CONFLICT DO UPDATE
157 SET policy = ?2 WHERE policy != ?2",
158 )?;
159
160 stmt.bind((1, id))?;
161 stmt.bind((2, policy))?;
162 stmt.next()?;
163
164 Ok(self.db.change_count() > 0)
165 }
166
167 pub fn set_seed_policy(&mut self, id: &RepoId, policy: Policy) -> Result<bool, Error> {
169 let mut stmt = self.db.prepare(
170 "INSERT INTO `seeding` (id, policy)
171 VALUES (?1, ?2)
172 ON CONFLICT DO UPDATE
173 SET policy = ?2 WHERE policy != ?2",
174 )?;
175
176 stmt.bind((1, id))?;
177 stmt.bind((2, policy))?;
178 stmt.next()?;
179
180 Ok(self.db.change_count() > 0)
181 }
182
183 pub fn unfollow(&mut self, id: &NodeId) -> Result<bool, Error> {
185 let mut stmt = self.db.prepare("DELETE FROM `following` WHERE id = ?")?;
186
187 stmt.bind((1, id))?;
188 stmt.next()?;
189
190 Ok(self.db.change_count() > 0)
191 }
192
193 pub fn unseed(&mut self, id: &RepoId) -> Result<bool, Error> {
195 let mut stmt = self.db.prepare("DELETE FROM `seeding` WHERE id = ?")?;
196
197 stmt.bind((1, id))?;
198 stmt.next()?;
199
200 Ok(self.db.change_count() > 0)
201 }
202
203 pub fn unblock_rid(&mut self, id: &RepoId) -> Result<bool, Error> {
205 let mut stmt = self
206 .db
207 .prepare("DELETE FROM `seeding` WHERE id = ? AND policy = 'block'")?;
208
209 stmt.bind((1, id))?;
210 stmt.next()?;
211
212 Ok(self.db.change_count() > 0)
213 }
214
215 pub fn unblock_nid(&mut self, id: &NodeId) -> Result<bool, Error> {
217 let mut stmt = self
218 .db
219 .prepare("DELETE FROM `following` WHERE id = ? AND policy = 'block'")?;
220
221 stmt.bind((1, id))?;
222 stmt.next()?;
223
224 Ok(self.db.change_count() > 0)
225 }
226}
227
228impl<T> Store<T> {
231 pub fn is_following(&self, id: &NodeId) -> Result<bool, Error> {
233 Ok(matches!(
234 self.follow_policy(id)?,
235 Some(FollowPolicy {
236 policy: Policy::Allow,
237 ..
238 })
239 ))
240 }
241
242 pub fn is_seeding(&self, id: &RepoId) -> Result<bool, Error> {
244 Ok(matches!(
245 self.seed_policy(id)?,
246 Some(SeedPolicy { policy, .. })
247 if policy.is_allow()
248 ))
249 }
250
251 pub fn follow_policy(&self, id: &NodeId) -> Result<Option<FollowPolicy>, Error> {
253 let mut stmt = self
254 .db
255 .prepare("SELECT alias, policy FROM `following` WHERE id = ?")?;
256
257 stmt.bind((1, id))?;
258
259 if let Some(Ok(row)) = stmt.into_iter().next() {
260 let alias = row.try_read::<&str, _>("alias")?;
261 let alias = alias
262 .is_empty()
263 .not()
264 .then_some(alias.to_owned())
265 .and_then(|s| Alias::from_str(&s).ok());
266 let policy = row.try_read::<Policy, _>("policy")?;
267
268 return Ok(Some(FollowPolicy {
269 nid: *id,
270 alias,
271 policy,
272 }));
273 }
274 Ok(None)
275 }
276
277 pub fn seed_policy(&self, id: &RepoId) -> Result<Option<SeedPolicy>, Error> {
279 let mut stmt = self
280 .db
281 .prepare("SELECT scope, policy FROM `seeding` WHERE id = ?")?;
282
283 stmt.bind((1, id))?;
284
285 if let Some(Ok(row)) = stmt.into_iter().next() {
286 let policy = match row.try_read::<Policy, _>("policy")? {
287 Policy::Allow => SeedingPolicy::Allow {
288 scope: row.try_read::<Scope, _>("scope")?,
289 },
290 Policy::Block => SeedingPolicy::Block,
291 };
292 return Ok(Some(SeedPolicy { rid: *id, policy }));
293 }
294 Ok(None)
295 }
296
297 pub fn follow_policies(&self) -> Result<FollowPolicies<'_>, Error> {
299 let stmt = self
300 .db
301 .prepare("SELECT id, alias, policy FROM `following`")?;
302 Ok(FollowPolicies {
303 inner: stmt.into_iter(),
304 })
305 }
306
307 pub fn seed_policies(&self) -> Result<SeedPolicies<'_>, Error> {
309 let stmt = self.db.prepare("SELECT id, scope, policy FROM `seeding`")?;
310 Ok(SeedPolicies {
311 inner: stmt.into_iter(),
312 })
313 }
314
315 pub fn nodes_by_alias<'a>(&'a self, alias: &Alias) -> Result<NodeAliasIter<'a>, Error> {
316 let mut stmt = self
317 .db
318 .prepare("SELECT id, alias FROM `following` WHERE UPPER(alias) LIKE ?")?;
319 let query = format!("%{}%", alias.as_str().to_uppercase());
320 stmt.bind((1, sql::Value::String(query)))?;
321 Ok(NodeAliasIter {
322 inner: stmt.into_iter(),
323 })
324 }
325}
326
327pub struct FollowPolicies<'a> {
328 inner: sql::CursorWithOwnership<'a>,
329}
330
331impl Iterator for FollowPolicies<'_> {
332 type Item = Result<FollowPolicy, Error>;
333
334 fn next(&mut self) -> Option<Self::Item> {
335 let row = self.inner.next()?;
336 let Ok(row) = row else { return self.next() };
337
338 let id = match row.try_read("id") {
339 Ok(id) => id,
340 Err(err) => return Some(Err(err.into())),
341 };
342
343 let alias = match row.try_read::<&str, _>("alias") {
344 Ok(alias) => alias.to_owned(),
345 Err(err) => return Some(Err(err.into())),
346 };
347
348 let alias = alias
349 .is_empty()
350 .not()
351 .then_some(alias.to_owned())
352 .and_then(|s| Alias::from_str(&s).ok());
353
354 let policy = match row.try_read::<Policy, _>("policy") {
355 Ok(policy) => policy,
356 Err(err) => return Some(Err(err.into())),
357 };
358
359 Some(Ok(FollowPolicy {
360 nid: id,
361 alias,
362 policy,
363 }))
364 }
365}
366
367pub struct SeedPolicies<'a> {
368 inner: sql::CursorWithOwnership<'a>,
369}
370
371impl Iterator for SeedPolicies<'_> {
372 type Item = Result<SeedPolicy, Error>;
373
374 fn next(&mut self) -> Option<Self::Item> {
375 let row = self.inner.next()?;
376 let Ok(row) = row else { return self.next() };
377
378 let id = match row.try_read("id") {
379 Ok(id) => id,
380 Err(err) => return Some(Err(err.into())),
381 };
382
383 let policy = match row.try_read::<Policy, _>("policy") {
384 Ok(policy) => policy,
385 Err(err) => return Some(Err(err.into())),
386 };
387
388 match policy {
389 Policy::Allow => match row.try_read::<Scope, _>("scope") {
390 Ok(scope) => Some(Ok(SeedPolicy {
391 rid: id,
392 policy: SeedingPolicy::Allow { scope },
393 })),
394 Err(err) => Some(Err(err.into())),
395 },
396 Policy::Block => Some(Ok(SeedPolicy {
397 rid: id,
398 policy: SeedingPolicy::Block,
399 })),
400 }
401 }
402}
403
404pub struct NodeAliasIter<'a> {
405 inner: sql::CursorWithOwnership<'a>,
406}
407
408impl NodeAliasIter<'_> {
409 fn parse_row(row: sql::Row) -> Result<(NodeId, Alias), Error> {
410 let nid = row.try_read::<NodeId, _>("id")?;
411 let alias = row.try_read::<Alias, _>("alias")?;
412 Ok((nid, alias))
413 }
414}
415
416impl Iterator for NodeAliasIter<'_> {
417 type Item = Result<(NodeId, Alias), Error>;
418
419 fn next(&mut self) -> Option<Self::Item> {
420 let row = self.inner.next()?;
421 Some(row.map_err(Error::from).and_then(Self::parse_row))
422 }
423}
424
425impl<T> AliasStore for Store<T> {
426 fn alias(&self, nid: &NodeId) -> Option<Alias> {
429 self.follow_policy(nid)
430 .map(|node| node.and_then(|n| n.alias))
431 .unwrap_or(None)
432 }
433
434 fn reverse_lookup(&self, alias: &Alias) -> BTreeMap<Alias, BTreeSet<NodeId>> {
435 let Ok(iter) = self.nodes_by_alias(alias) else {
436 return BTreeMap::new();
437 };
438 iter.flatten()
439 .fold(BTreeMap::new(), |mut result, (node, alias)| {
440 let nodes = result.entry(alias).or_default();
441 nodes.insert(node);
442 result
443 })
444 }
445}
446
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod test {
    use crate::{assert_matches, node};

    use super::*;
    use crate::test::arbitrary;

    #[test]
    fn test_follow_and_unfollow_node() {
        let id = arbitrary::gen::<NodeId>(1);
        let mut db = Store::open(":memory:").unwrap();
        let eve = Alias::new("eve");

        assert!(db.follow(&id, Some(&eve)).unwrap());
        assert!(db.is_following(&id).unwrap());
        assert!(!db.follow(&id, Some(&eve)).unwrap());
        assert!(db.unfollow(&id).unwrap());
        assert!(!db.is_following(&id).unwrap());
    }

    #[test]
    fn test_seed_and_unseed_repo() {
        let id = arbitrary::gen::<RepoId>(1);
        let mut db = Store::open(":memory:").unwrap();

        assert!(db.seed(&id, Scope::All).unwrap());
        assert!(db.is_seeding(&id).unwrap());
        assert!(!db.seed(&id, Scope::All).unwrap());
        assert!(db.unseed(&id).unwrap());
        assert!(!db.is_seeding(&id).unwrap());
    }

    #[test]
    fn test_node_policies() {
        let ids = arbitrary::vec::<NodeId>(3);
        let mut db = Store::open(":memory:").unwrap();

        for id in &ids {
            assert!(db.follow(id, None).unwrap());
        }
        let mut entries = db.follow_policies().unwrap();
        assert_matches!(entries.next(), Some(Ok(FollowPolicy { nid, .. })) if nid == ids[0]);
        assert_matches!(entries.next(), Some(Ok(FollowPolicy { nid, .. })) if nid == ids[1]);
        assert_matches!(entries.next(), Some(Ok(FollowPolicy { nid, .. })) if nid == ids[2]);
        // Exactly the followed nodes are listed, nothing more.
        assert!(entries.next().is_none());
    }

    #[test]
    fn test_repo_policies() {
        let ids = arbitrary::vec::<RepoId>(3);
        let mut db = Store::open(":memory:").unwrap();

        for id in &ids {
            assert!(db.seed(id, Scope::All).unwrap());
        }
        let mut entries = db.seed_policies().unwrap();
        assert_matches!(entries.next(), Some(Ok(SeedPolicy { rid, .. })) if rid == ids[0]);
        assert_matches!(entries.next(), Some(Ok(SeedPolicy { rid, .. })) if rid == ids[1]);
        assert_matches!(entries.next(), Some(Ok(SeedPolicy { rid, .. })) if rid == ids[2]);
        // Exactly the seeded repositories are listed, nothing more.
        assert!(entries.next().is_none());
    }

    #[test]
    fn test_update_alias() {
        let id = arbitrary::gen::<NodeId>(1);
        let mut db = Store::open(":memory:").unwrap();

        assert!(db.follow(&id, Some(&Alias::new("eve"))).unwrap());
        assert_eq!(
            db.follow_policy(&id).unwrap().unwrap().alias,
            Some(Alias::new("eve"))
        );
        assert!(db.follow(&id, None).unwrap());
        assert_eq!(db.follow_policy(&id).unwrap().unwrap().alias, None);
        assert!(!db.follow(&id, None).unwrap());
        assert!(db.follow(&id, Some(&Alias::new("alice"))).unwrap());
        assert_eq!(
            db.follow_policy(&id).unwrap().unwrap().alias,
            Some(Alias::new("alice"))
        );
    }

    #[test]
    fn test_update_scope() {
        let id = arbitrary::gen::<RepoId>(1);
        let mut db = Store::open(":memory:").unwrap();

        assert!(db.seed(&id, Scope::All).unwrap());
        assert_eq!(
            db.seed_policy(&id).unwrap().unwrap().scope(),
            Some(Scope::All)
        );
        assert!(db.seed(&id, Scope::Followed).unwrap());
        assert_eq!(
            db.seed_policy(&id).unwrap().unwrap().scope(),
            Some(Scope::Followed)
        );
    }

    #[test]
    fn test_repo_policy() {
        let id = arbitrary::gen::<RepoId>(1);
        let mut db = Store::open(":memory:").unwrap();

        assert!(db.seed(&id, Scope::All).unwrap());
        assert!(db.seed_policy(&id).unwrap().unwrap().is_allow());
        assert!(db.set_seed_policy(&id, Policy::Block).unwrap());
        assert!(!db.seed_policy(&id).unwrap().unwrap().is_allow());
        assert_eq!(db.seed_policy(&id).unwrap().unwrap().scope(), None);
    }

    #[test]
    fn test_node_policy() {
        let id = arbitrary::gen::<NodeId>(1);
        let mut db = Store::open(":memory:").unwrap();

        assert!(db.follow(&id, None).unwrap());
        assert_eq!(
            db.follow_policy(&id).unwrap().unwrap().policy,
            Policy::Allow
        );
        assert!(db.set_follow_policy(&id, Policy::Block).unwrap());
        assert_eq!(
            db.follow_policy(&id).unwrap().unwrap().policy,
            Policy::Block
        );
    }

    #[test]
    fn test_unblock_repo() {
        let id = arbitrary::gen::<RepoId>(1);
        let mut db = Store::open(":memory:").unwrap();

        assert!(db.seed(&id, Scope::All).unwrap());
        // An allowed repository is not affected by unblocking.
        assert!(!db.unblock_rid(&id).unwrap());
        assert!(db.is_seeding(&id).unwrap());

        assert!(db.set_seed_policy(&id, Policy::Block).unwrap());
        assert!(db.unblock_rid(&id).unwrap());
        assert!(db.seed_policy(&id).unwrap().is_none());
    }

    #[test]
    fn test_unblock_node() {
        let id = arbitrary::gen::<NodeId>(1);
        let mut db = Store::open(":memory:").unwrap();

        assert!(db.follow(&id, None).unwrap());
        // A followed node is not affected by unblocking.
        assert!(!db.unblock_nid(&id).unwrap());
        assert!(db.is_following(&id).unwrap());

        assert!(db.set_follow_policy(&id, Policy::Block).unwrap());
        assert!(!db.is_following(&id).unwrap());
        assert!(db.unblock_nid(&id).unwrap());
        assert!(db.follow_policy(&id).unwrap().is_none());
    }

    #[test]
    fn test_node_aliases() {
        let mut db = Store::open(":memory:").unwrap();
        let input = node::properties::AliasInput::new();
        let (short, short_ids) = input.short();
        let (long, long_ids) = input.long();

        for id in short_ids {
            db.follow(id, Some(short)).unwrap();
        }

        for id in long_ids {
            db.follow(id, Some(long)).unwrap();
        }

        node::properties::test_reverse_lookup(&db, input)
    }
}