hashtree_cli/nostrdb_integration/
access.rs1use nostrdb_social::Ndb;
4use std::collections::HashSet;
5use std::sync::Arc;
6
7use super::SocialGraphStats;
8
9#[derive(Clone)]
11pub struct SocialGraphAccessControl {
12 ndb: Arc<Ndb>,
13 max_write_distance: u32,
14 allowed_pubkeys: HashSet<String>,
15}
16
17impl SocialGraphAccessControl {
18 pub fn new(
19 ndb: Arc<Ndb>,
20 max_write_distance: u32,
21 allowed_pubkeys: HashSet<String>,
22 ) -> Self {
23 Self {
24 ndb,
25 max_write_distance,
26 allowed_pubkeys,
27 }
28 }
29
30 pub fn check_write_access(&self, pubkey_hex: &str) -> bool {
35 if self.allowed_pubkeys.contains(pubkey_hex) {
36 return true;
37 }
38
39 if let Ok(pk_bytes) = hex::decode(pubkey_hex) {
40 if pk_bytes.len() == 32 {
41 let pk: [u8; 32] = pk_bytes.try_into().unwrap();
42 if let Some(distance) = super::get_follow_distance(&self.ndb, &pk) {
43 return distance <= self.max_write_distance;
44 }
45 }
46 }
47
48 false
49 }
50
51 pub fn stats(&self) -> SocialGraphStats {
52 SocialGraphStats {
53 root: None,
54 total_follows: 0,
55 max_depth: self.max_write_distance,
56 enabled: true,
57 }
58 }
59}
60
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Opens a fresh nostrdb in a temporary directory. The `TempDir` is
    /// returned so it stays alive for the duration of the test.
    fn setup() -> (TempDir, Arc<Ndb>) {
        let tmp = TempDir::new().unwrap();
        let ndb = super::super::init_ndb(tmp.path()).unwrap();
        (tmp, ndb)
    }

    #[test]
    fn test_allowed_pubkey_passes() {
        let _guard = super::super::test_lock();
        let (_tmp, ndb) = setup();
        let pk_hex = "aa".repeat(32);
        let mut allowed = HashSet::new();
        allowed.insert(pk_hex.clone());

        let ac = SocialGraphAccessControl::new(ndb, 3, allowed);
        assert!(ac.check_write_access(&pk_hex));
    }

    #[test]
    fn test_unknown_pubkey_denied() {
        let _guard = super::super::test_lock();
        let (_tmp, ndb) = setup();
        let root_pk = [1u8; 32];
        super::super::set_social_graph_root(&ndb, &root_pk);
        // NOTE(review): presumably gives ndb time to apply the root change
        // in the background - confirm against set_social_graph_root.
        std::thread::sleep(std::time::Duration::from_millis(100));

        let ac = SocialGraphAccessControl::new(ndb, 3, HashSet::new());
        let unknown = "bb".repeat(32);
        assert!(!ac.check_write_access(&unknown));
    }

    #[test]
    fn test_root_pubkey_within_distance() {
        let _guard = super::super::test_lock();
        let (_tmp, ndb) = setup();
        let root_pk = [1u8; 32];
        super::super::set_social_graph_root(&ndb, &root_pk);
        std::thread::sleep(std::time::Duration::from_millis(100));

        let ac = SocialGraphAccessControl::new(ndb, 3, HashSet::new());
        let root_hex = hex::encode(root_pk);
        assert!(ac.check_write_access(&root_hex));
    }

    #[test]
    fn test_malformed_hex_denied() {
        let _guard = super::super::test_lock();
        let (_tmp, ndb) = setup();
        let ac = SocialGraphAccessControl::new(ndb, 3, HashSet::new());

        // Non-hex characters, odd-length input and empty input never decode
        // to a key, so they must be rejected without panicking.
        assert!(!ac.check_write_access(&"zz".repeat(32)));
        assert!(!ac.check_write_access("abc"));
        assert!(!ac.check_write_access(""));
    }

    #[test]
    fn test_wrong_length_pubkey_denied() {
        let _guard = super::super::test_lock();
        let (_tmp, ndb) = setup();
        let root_pk = [1u8; 32];
        super::super::set_social_graph_root(&ndb, &root_pk);
        std::thread::sleep(std::time::Duration::from_millis(100));

        let ac = SocialGraphAccessControl::new(ndb, 3, HashSet::new());
        // Valid hex that is one byte too short or too long must be rejected,
        // even though its bytes match the root key.
        assert!(!ac.check_write_access(&"01".repeat(31)));
        assert!(!ac.check_write_access(&"01".repeat(33)));
    }

    #[test]
    fn test_stats_enabled() {
        let _guard = super::super::test_lock();
        let (_tmp, ndb) = setup();
        let ac = SocialGraphAccessControl::new(ndb, 3, HashSet::new());
        let stats = ac.stats();
        assert!(stats.enabled);
        assert_eq!(stats.max_depth, 3);
    }
}