Skip to main content

hashtree_cli/nostrdb_integration/
access.rs

//! Social graph-based write access control using nostrdb.
2
3use nostrdb_social::Ndb;
4use std::collections::HashSet;
5use std::sync::Arc;
6
7use super::SocialGraphStats;
8
9/// Access control that combines allowed_pubkeys with social graph follow distance.
10#[derive(Clone)]
11pub struct SocialGraphAccessControl {
12    ndb: Arc<Ndb>,
13    max_write_distance: u32,
14    allowed_pubkeys: HashSet<String>,
15}
16
17impl SocialGraphAccessControl {
18    pub fn new(ndb: Arc<Ndb>, max_write_distance: u32, allowed_pubkeys: HashSet<String>) -> Self {
19        Self {
20            ndb,
21            max_write_distance,
22            allowed_pubkeys,
23        }
24    }
25
26    /// Check if a pubkey (hex) has write access.
27    /// Returns true if:
28    /// 1. The pubkey is in the allowed_pubkeys set, OR
29    /// 2. The pubkey's follow distance from the root is <= max_write_distance
30    pub fn check_write_access(&self, pubkey_hex: &str) -> bool {
31        if self.allowed_pubkeys.contains(pubkey_hex) {
32            return true;
33        }
34
35        if let Ok(pk_bytes) = hex::decode(pubkey_hex) {
36            if pk_bytes.len() == 32 {
37                let pk: [u8; 32] = pk_bytes.try_into().unwrap();
38                if let Some(distance) = super::get_follow_distance(&self.ndb, &pk) {
39                    return distance <= self.max_write_distance;
40                }
41            }
42        }
43
44        false
45    }
46
47    pub fn stats(&self) -> SocialGraphStats {
48        SocialGraphStats {
49            root: None,
50            total_follows: 0,
51            max_depth: self.max_write_distance,
52            enabled: true,
53        }
54    }
55}
56
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Opens a fresh nostrdb in a temporary directory.
    ///
    /// The `TempDir` is returned alongside the handle so the caller can keep the
    /// directory alive for the duration of the test.
    fn setup() -> (TempDir, Arc<Ndb>) {
        let tmp = TempDir::new().unwrap();
        let ndb = super::super::init_ndb(tmp.path()).unwrap();
        (tmp, ndb)
    }

    /// Like [`setup`], but also sets `root_pk` as the social-graph root.
    fn setup_with_root(root_pk: &[u8; 32]) -> (TempDir, Arc<Ndb>) {
        let (tmp, ndb) = setup();
        super::super::set_social_graph_root(&ndb, root_pk);
        // NOTE(review): presumably gives nostrdb time to apply the root change
        // asynchronously; a fixed sleep can be flaky — TODO confirm/replace with a poll.
        std::thread::sleep(std::time::Duration::from_millis(100));
        (tmp, ndb)
    }

    #[test]
    fn test_allowed_pubkey_passes() {
        let _guard = super::super::test_lock();
        let (_tmp, ndb) = setup();
        let pk_hex = "aa".repeat(32);
        let mut allowed = HashSet::new();
        allowed.insert(pk_hex.clone());

        let ac = SocialGraphAccessControl::new(ndb, 3, allowed);
        assert!(ac.check_write_access(&pk_hex));
    }

    #[test]
    fn test_unknown_pubkey_denied() {
        let _guard = super::super::test_lock();
        let (_tmp, ndb) = setup_with_root(&[1u8; 32]);

        let ac = SocialGraphAccessControl::new(ndb, 3, HashSet::new());
        let unknown = "bb".repeat(32);
        assert!(!ac.check_write_access(&unknown));
    }

    #[test]
    fn test_root_pubkey_within_distance() {
        let _guard = super::super::test_lock();
        let root_pk = [1u8; 32];
        let (_tmp, ndb) = setup_with_root(&root_pk);

        let ac = SocialGraphAccessControl::new(ndb, 3, HashSet::new());
        let root_hex = hex::encode(root_pk);
        assert!(ac.check_write_access(&root_hex));
    }

    #[test]
    fn test_malformed_pubkey_denied() {
        let _guard = super::super::test_lock();
        let (_tmp, ndb) = setup();
        let ac = SocialGraphAccessControl::new(ndb, 3, HashSet::new());

        // Empty, odd-length, non-hex, and wrong-byte-length inputs must all be
        // rejected without panicking.
        let malformed = [
            String::new(),
            "abc".to_string(),
            "zz".repeat(32),
            "aa".repeat(31),
            "aa".repeat(33),
        ];
        for input in &malformed {
            assert!(!ac.check_write_access(input), "accepted malformed input {input:?}");
        }
    }

    #[test]
    fn test_stats_enabled() {
        let _guard = super::super::test_lock();
        let (_tmp, ndb) = setup();
        let ac = SocialGraphAccessControl::new(ndb, 3, HashSet::new());
        let stats = ac.stats();
        assert!(stats.enabled);
        assert_eq!(stats.max_depth, 3);
    }
}
114}