Skip to main content

hashtree_cli/socialgraph/
access.rs

1use std::collections::HashSet;
2use std::sync::Arc;
3
4use super::{SocialGraphBackend, SocialGraphStats};
5
6#[derive(Clone)]
7pub struct SocialGraphAccessControl {
8    store: Arc<dyn SocialGraphBackend>,
9    max_write_distance: u32,
10    allowed_pubkeys: HashSet<String>,
11}
12
13impl SocialGraphAccessControl {
14    pub fn new(
15        store: Arc<dyn SocialGraphBackend>,
16        max_write_distance: u32,
17        allowed_pubkeys: HashSet<String>,
18    ) -> Self {
19        Self {
20            store,
21            max_write_distance,
22            allowed_pubkeys,
23        }
24    }
25
26    pub fn check_write_access(&self, pubkey_hex: &str) -> bool {
27        if self.allowed_pubkeys.contains(pubkey_hex) {
28            return true;
29        }
30
31        let Ok(pk_bytes) = hex::decode(pubkey_hex) else {
32            return false;
33        };
34        let Ok(pk) = <[u8; 32]>::try_from(pk_bytes.as_slice()) else {
35            return false;
36        };
37
38        super::get_follow_distance(self.store.as_ref(), &pk)
39            .map(|distance| distance <= self.max_write_distance)
40            .unwrap_or(false)
41    }
42
43    pub fn stats(&self) -> SocialGraphStats {
44        self.store.stats().unwrap_or_else(|_| SocialGraphStats {
45            enabled: true,
46            max_depth: self.max_write_distance,
47            ..Default::default()
48        })
49    }
50}
51
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// A pubkey present on the allow-list is granted write access.
    #[test]
    fn test_allowed_pubkey_passes() {
        // Hold the shared test lock for the duration of the test.
        let _guard = crate::socialgraph::test_lock();
        let scratch_dir = TempDir::new().unwrap();
        let store = crate::socialgraph::open_social_graph_store(scratch_dir.path()).unwrap();

        let pubkey = "aa".repeat(32);
        let allow_list = HashSet::from([pubkey.clone()]);

        let gate = SocialGraphAccessControl::new(store, 1, allow_list);
        assert!(gate.check_write_access(&pubkey));
    }
}