halldyll_core/security/
allowlist.rs

//! Allowlist - allow and block rules for domains, by exact name or regex pattern
2
3use regex::Regex;
4use std::collections::HashSet;
5use url::Url;
6
7/// List of allowed domains
8pub struct DomainAllowlist {
9    /// Allowed domains (empty = all)
10    allowed: HashSet<String>,
11    /// Blocked domains
12    blocked: HashSet<String>,
13    /// Allowed regex patterns
14    allowed_patterns: Vec<Regex>,
15    /// Blocked regex patterns
16    blocked_patterns: Vec<Regex>,
17}
18
19impl Default for DomainAllowlist {
20    fn default() -> Self {
21        Self::new()
22    }
23}
24
25impl DomainAllowlist {
26    /// New empty list (all allowed by default)
27    pub fn new() -> Self {
28        Self {
29            allowed: HashSet::new(),
30            blocked: HashSet::new(),
31            allowed_patterns: Vec::new(),
32            blocked_patterns: Vec::new(),
33        }
34    }
35
36    /// Add an allowed domain
37    pub fn allow_domain(&mut self, domain: &str) {
38        self.allowed.insert(domain.to_lowercase());
39    }
40
41    /// Add multiple allowed domains
42    pub fn allow_domains(&mut self, domains: &[&str]) {
43        for domain in domains {
44            self.allow_domain(domain);
45        }
46    }
47
48    /// Add a blocked domain
49    pub fn block_domain(&mut self, domain: &str) {
50        self.blocked.insert(domain.to_lowercase());
51    }
52
53    /// Add multiple blocked domains
54    pub fn block_domains(&mut self, domains: &[&str]) {
55        for domain in domains {
56            self.block_domain(domain);
57        }
58    }
59
60    /// Add an allowed regex pattern
61    pub fn allow_pattern(&mut self, pattern: &str) -> Result<(), regex::Error> {
62        let regex = Regex::new(pattern)?;
63        self.allowed_patterns.push(regex);
64        Ok(())
65    }
66
67    /// Add a blocked regex pattern
68    pub fn block_pattern(&mut self, pattern: &str) -> Result<(), regex::Error> {
69        let regex = Regex::new(pattern)?;
70        self.blocked_patterns.push(regex);
71        Ok(())
72    }
73
74    /// Check if a URL is allowed
75    pub fn is_allowed(&self, url: &Url) -> bool {
76        let domain = match url.host_str() {
77            Some(d) => d.to_lowercase(),
78            None => return false,
79        };
80
81        // 1. Check blocked domains (priority)
82        if self.blocked.contains(&domain) {
83            return false;
84        }
85
86        // 2. Check blocked patterns
87        for pattern in &self.blocked_patterns {
88            if pattern.is_match(&domain) {
89                return false;
90            }
91        }
92
93        // 3. If no allowed list, everything is allowed
94        if self.allowed.is_empty() && self.allowed_patterns.is_empty() {
95            return true;
96        }
97
98        // 4. Check allowed domains
99        if self.allowed.contains(&domain) {
100            return true;
101        }
102
103        // 5. Check allowed subdomains
104        for allowed in &self.allowed {
105            if domain.ends_with(&format!(".{}", allowed)) {
106                return true;
107            }
108        }
109
110        // 6. Check allowed patterns
111        for pattern in &self.allowed_patterns {
112            if pattern.is_match(&domain) {
113                return true;
114            }
115        }
116
117        false
118    }
119
120    /// Check if a domain is allowed
121    pub fn is_domain_allowed(&self, domain: &str) -> bool {
122        let domain = domain.to_lowercase();
123
124        // Blocked?
125        if self.blocked.contains(&domain) {
126            return false;
127        }
128        for pattern in &self.blocked_patterns {
129            if pattern.is_match(&domain) {
130                return false;
131            }
132        }
133
134        // No list = allowed
135        if self.allowed.is_empty() && self.allowed_patterns.is_empty() {
136            return true;
137        }
138
139        // Allowed?
140        if self.allowed.contains(&domain) {
141            return true;
142        }
143        for allowed in &self.allowed {
144            if domain.ends_with(&format!(".{}", allowed)) {
145                return true;
146            }
147        }
148        for pattern in &self.allowed_patterns {
149            if pattern.is_match(&domain) {
150                return true;
151            }
152        }
153
154        false
155    }
156
157    /// Return the allowed domains
158    pub fn allowed_domains(&self) -> &HashSet<String> {
159        &self.allowed
160    }
161
162    /// Return the blocked domains
163    pub fn blocked_domains(&self) -> &HashSet<String> {
164        &self.blocked
165    }
166}