Skip to main content

hashtree_cli/nostrdb_integration/
access.rs

1//! Social graph-based write access control using nostrdb.
2
3use nostrdb_social::Ndb;
4use std::collections::HashSet;
5use std::sync::Arc;
6
7use super::SocialGraphStats;
8
9/// Access control that combines allowed_pubkeys with social graph follow distance.
10#[derive(Clone)]
11pub struct SocialGraphAccessControl {
12    ndb: Arc<Ndb>,
13    max_write_distance: u32,
14    allowed_pubkeys: HashSet<String>,
15}
16
17impl SocialGraphAccessControl {
18    pub fn new(
19        ndb: Arc<Ndb>,
20        max_write_distance: u32,
21        allowed_pubkeys: HashSet<String>,
22    ) -> Self {
23        Self {
24            ndb,
25            max_write_distance,
26            allowed_pubkeys,
27        }
28    }
29
30    /// Check if a pubkey (hex) has write access.
31    /// Returns true if:
32    /// 1. The pubkey is in the allowed_pubkeys set, OR
33    /// 2. The pubkey's follow distance from the root is <= max_write_distance
34    pub fn check_write_access(&self, pubkey_hex: &str) -> bool {
35        if self.allowed_pubkeys.contains(pubkey_hex) {
36            return true;
37        }
38
39        if let Ok(pk_bytes) = hex::decode(pubkey_hex) {
40            if pk_bytes.len() == 32 {
41                let pk: [u8; 32] = pk_bytes.try_into().unwrap();
42                if let Some(distance) = super::get_follow_distance(&self.ndb, &pk) {
43                    return distance <= self.max_write_distance;
44                }
45            }
46        }
47
48        false
49    }
50
51    pub fn stats(&self) -> SocialGraphStats {
52        SocialGraphStats {
53            root: None,
54            total_follows: 0,
55            max_depth: self.max_write_distance,
56            enabled: true,
57        }
58    }
59}
60
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Opens a fresh nostrdb inside a temporary directory.
    ///
    /// The `TempDir` is handed back with the handle so the caller keeps the
    /// directory alive for as long as the database is in use.
    fn setup() -> (TempDir, Arc<Ndb>) {
        let dir = TempDir::new().unwrap();
        let db = super::super::init_ndb(dir.path()).unwrap();
        (dir, db)
    }

    /// Registers `root` as the social-graph root, then pauses for 100 ms.
    fn set_root_and_wait(db: &Arc<Ndb>, root: &[u8; 32]) {
        super::super::set_social_graph_root(db, root);
        // NOTE(review): fixed delay, presumably to let the root update become
        // visible to lookups; this makes the tests timing-sensitive.
        std::thread::sleep(std::time::Duration::from_millis(100));
    }

    #[test]
    fn test_allowed_pubkey_passes() {
        let _guard = super::super::test_lock();
        let (_tmp, db) = setup();

        let key = "aa".repeat(32);
        let allowlist: HashSet<String> = std::iter::once(key.clone()).collect();

        let policy = SocialGraphAccessControl::new(db, 3, allowlist);
        assert!(policy.check_write_access(&key));
    }

    #[test]
    fn test_unknown_pubkey_denied() {
        let _guard = super::super::test_lock();
        let (_tmp, db) = setup();
        set_root_and_wait(&db, &[1u8; 32]);

        let policy = SocialGraphAccessControl::new(db, 3, HashSet::new());
        let stranger = "bb".repeat(32);
        assert!(!policy.check_write_access(&stranger));
    }

    #[test]
    fn test_root_pubkey_within_distance() {
        let _guard = super::super::test_lock();
        let (_tmp, db) = setup();
        let root = [1u8; 32];
        set_root_and_wait(&db, &root);

        let policy = SocialGraphAccessControl::new(db, 3, HashSet::new());
        assert!(policy.check_write_access(&hex::encode(root)));
    }

    #[test]
    fn test_stats_enabled() {
        let _guard = super::super::test_lock();
        let (_tmp, db) = setup();
        let policy = SocialGraphAccessControl::new(db, 3, HashSet::new());
        assert!(policy.stats().enabled);
    }
}