hashtree_cli/nostrdb_integration/
access.rs1use nostrdb_social::Ndb;
4use std::collections::HashSet;
5use std::sync::Arc;
6
7use super::SocialGraphStats;
8
9#[derive(Clone)]
11pub struct SocialGraphAccessControl {
12 ndb: Arc<Ndb>,
13 max_write_distance: u32,
14 allowed_pubkeys: HashSet<String>,
15}
16
17impl SocialGraphAccessControl {
18 pub fn new(ndb: Arc<Ndb>, max_write_distance: u32, allowed_pubkeys: HashSet<String>) -> Self {
19 Self {
20 ndb,
21 max_write_distance,
22 allowed_pubkeys,
23 }
24 }
25
26 pub fn check_write_access(&self, pubkey_hex: &str) -> bool {
31 if self.allowed_pubkeys.contains(pubkey_hex) {
32 return true;
33 }
34
35 if let Ok(pk_bytes) = hex::decode(pubkey_hex) {
36 if pk_bytes.len() == 32 {
37 let pk: [u8; 32] = pk_bytes.try_into().unwrap();
38 if let Some(distance) = super::get_follow_distance(&self.ndb, &pk) {
39 return distance <= self.max_write_distance;
40 }
41 }
42 }
43
44 false
45 }
46
47 pub fn stats(&self) -> SocialGraphStats {
48 SocialGraphStats {
49 root: None,
50 total_follows: 0,
51 max_depth: self.max_write_distance,
52 enabled: true,
53 }
54 }
55}
56
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Pubkey installed as the social-graph root by tests that need a graph.
    const ROOT_PK: [u8; 32] = [1u8; 32];

    /// Opens a fresh nostrdb inside a temporary directory.
    ///
    /// The returned `TempDir` must outlive the database handle, so callers
    /// keep it bound for the whole test.
    fn setup() -> (TempDir, Arc<Ndb>) {
        let dir = TempDir::new().unwrap();
        let db = super::super::init_ndb(dir.path()).unwrap();
        (dir, db)
    }

    /// Same as `setup`, but also installs `ROOT_PK` as the social-graph root
    /// and then sleeps for 100 ms.
    /// NOTE(review): the sleep presumably gives nostrdb time to apply the
    /// root asynchronously — confirm.
    fn setup_with_root() -> (TempDir, Arc<Ndb>) {
        let (dir, db) = setup();
        super::super::set_social_graph_root(&db, &ROOT_PK);
        std::thread::sleep(std::time::Duration::from_millis(100));
        (dir, db)
    }

    #[test]
    fn test_allowed_pubkey_passes() {
        let _guard = super::super::test_lock();
        let (_tmp, db) = setup();
        let listed = "aa".repeat(32);
        let allowlist = HashSet::from([listed.clone()]);

        let policy = SocialGraphAccessControl::new(db, 3, allowlist);
        assert!(policy.check_write_access(&listed));
    }

    #[test]
    fn test_unknown_pubkey_denied() {
        let _guard = super::super::test_lock();
        let (_tmp, db) = setup_with_root();

        let policy = SocialGraphAccessControl::new(db, 3, HashSet::new());
        let stranger = "bb".repeat(32);
        assert!(!policy.check_write_access(&stranger));
    }

    #[test]
    fn test_root_pubkey_within_distance() {
        let _guard = super::super::test_lock();
        let (_tmp, db) = setup_with_root();

        let policy = SocialGraphAccessControl::new(db, 3, HashSet::new());
        let root_hex = hex::encode(ROOT_PK);
        assert!(policy.check_write_access(&root_hex));
    }

    #[test]
    fn test_stats_enabled() {
        let _guard = super::super::test_lock();
        let (_tmp, db) = setup();
        let policy = SocialGraphAccessControl::new(db, 3, HashSet::new());
        assert!(policy.stats().enabled);
    }
}