1use std::sync::OnceLock;
6
7use globset::{GlobBuilder, GlobMatcher};
8use serde::{Deserialize, Serialize};
9
/// A configured domain glob that failed to compile, together with the reason.
#[derive(Debug)]
struct DomainPatternError {
    /// The pattern text exactly as it was configured.
    pattern: String,
    /// The underlying `globset` parse error.
    error: globset::Error,
}
15
16impl std::fmt::Display for DomainPatternError {
17 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
18 write!(f, "invalid domain glob {:?}: {}", self.pattern, self.error)
19 }
20}
21
/// One compiled glob. It keeps its source text so detailed results can report
/// which rule matched.
#[derive(Debug)]
struct CompiledPattern {
    /// The pattern as written in the policy.
    original: String,
    /// Matcher built from `original` by `compile_pattern` (case-insensitive).
    matcher: GlobMatcher,
}
27
/// Compiled form of a `DomainPolicy`'s two pattern lists, in configuration order.
#[derive(Debug, Default)]
struct CompiledDomainPolicy {
    /// Compiled allow globs.
    allow: Vec<CompiledPattern>,
    /// Compiled block globs (consulted before `allow` during evaluation).
    block: Vec<CompiledPattern>,
}
33
34impl CompiledDomainPolicy {
35 fn compile(policy: &DomainPolicy) -> Result<Self, DomainPatternError> {
36 Ok(Self {
37 allow: compile_patterns(policy.allow_patterns())?,
38 block: compile_patterns(policy.block_patterns())?,
39 })
40 }
41}
42
43fn compile_patterns(patterns: &[String]) -> Result<Vec<CompiledPattern>, DomainPatternError> {
44 let mut out = Vec::with_capacity(patterns.len());
45 for p in patterns {
46 let matcher = compile_pattern(p)?;
47 out.push(CompiledPattern {
48 original: p.clone(),
49 matcher,
50 });
51 }
52 Ok(out)
53}
54
55fn compile_pattern(pattern: &str) -> Result<GlobMatcher, DomainPatternError> {
56 let glob = GlobBuilder::new(pattern)
57 .case_insensitive(true)
58 .literal_separator(true)
59 .build()
60 .map_err(|e| DomainPatternError {
61 pattern: pattern.to_string(),
62 error: e,
63 })?;
64
65 Ok(glob.compile_matcher())
66}
67
/// What to do with a domain.
///
/// Serialized in lowercase (`"allow"`, `"block"`, `"log"`). When deserializing,
/// `"deny"` is also accepted as an alias for `Block`.
#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum PolicyAction {
    /// Permit the domain.
    Allow,
    /// Reject the domain. This is also the `Default` value, so a defaulted action
    /// fails closed.
    #[serde(alias = "deny")]
    #[default]
    Block,
    /// Only reachable as a policy's configured default action; no pattern list
    /// produces it. `DomainPolicy::is_allowed` treats it as not allowed.
    /// NOTE(review): presumably means "allow but record" — confirm callers agree.
    Log,
}
81
/// An allow/block list of domain globs plus a fallback action.
///
/// Block patterns are consulted first, then allow patterns, and finally
/// `default_action`. The compiled globs are cached lazily and are never serialized.
#[derive(Debug, Serialize, Deserialize)]
pub struct DomainPolicy {
    /// Globs whose matches are allowed. Empty when absent from serialized input.
    #[serde(default)]
    allow: Vec<String>,
    /// Globs whose matches are blocked. They take precedence over `allow`. Empty
    /// when absent from serialized input.
    #[serde(default)]
    block: Vec<String>,
    /// Action used when no pattern matches. Falls back to `default_action()`
    /// (`Block`) when absent from serialized input.
    #[serde(default = "default_action")]
    default_action: PolicyAction,

    /// Lazily built compiled form of `allow`/`block`, or the first compile error.
    /// Skipped by serde. Methods that change the pattern lists reset it so it is
    /// rebuilt on next use.
    #[serde(skip)]
    compiled: OnceLock<Result<CompiledDomainPolicy, DomainPatternError>>,
}
98
/// Fallback action for a fresh or deserialized `DomainPolicy`: block (fail closed).
///
/// Used both as the serde field default and by `Default for DomainPolicy`.
/// It is spelled out explicitly rather than via `PolicyAction::default()`, so a
/// change to the enum's default cannot silently make policies permissive.
fn default_action() -> PolicyAction {
    PolicyAction::Block
}
102
103impl Default for DomainPolicy {
104 fn default() -> Self {
105 Self {
106 allow: Vec::new(),
107 block: Vec::new(),
108 default_action: default_action(),
109 compiled: OnceLock::new(),
110 }
111 }
112}
113
114impl Clone for DomainPolicy {
115 fn clone(&self) -> Self {
116 Self {
117 allow: self.allow.clone(),
118 block: self.block.clone(),
119 default_action: self.default_action.clone(),
120 compiled: OnceLock::new(),
121 }
122 }
123}
124
125impl DomainPolicy {
126 pub fn new() -> Self {
128 Self::default()
129 }
130
131 pub fn permissive() -> Self {
133 Self {
134 default_action: PolicyAction::Allow,
135 ..Self::default()
136 }
137 }
138
139 pub fn allow(mut self, pattern: impl Into<String>) -> Self {
141 self.allow.push(pattern.into());
142 self.compiled = OnceLock::new();
143 self
144 }
145
146 pub fn block(mut self, pattern: impl Into<String>) -> Self {
148 self.block.push(pattern.into());
149 self.compiled = OnceLock::new();
150 self
151 }
152
153 pub fn evaluate(&self, domain: &str) -> PolicyAction {
155 let compiled = match self.compiled() {
156 Ok(c) => c,
157 Err(_) => return PolicyAction::Block,
158 };
159
160 for pattern in &compiled.block {
162 if pattern.matcher.is_match(domain) {
163 return PolicyAction::Block;
164 }
165 }
166
167 for pattern in &compiled.allow {
169 if pattern.matcher.is_match(domain) {
170 return PolicyAction::Allow;
171 }
172 }
173
174 self.default_action.clone()
176 }
177
178 pub fn is_allowed(&self, domain: &str) -> bool {
180 matches!(self.evaluate(domain), PolicyAction::Allow)
181 }
182}
183
184#[derive(Clone, Debug, Serialize, Deserialize)]
186pub struct PolicyResult {
187 pub domain: String,
189 pub action: PolicyAction,
191 pub matched_pattern: Option<String>,
193 pub is_default: bool,
195}
196
197impl DomainPolicy {
198 pub fn evaluate_detailed(&self, domain: &str) -> PolicyResult {
200 let compiled = match self.compiled() {
201 Ok(c) => c,
202 Err(_) => {
203 return PolicyResult {
204 domain: domain.to_string(),
205 action: PolicyAction::Block,
206 matched_pattern: None,
207 is_default: true,
208 };
209 }
210 };
211
212 for pattern in &compiled.block {
214 if pattern.matcher.is_match(domain) {
215 return PolicyResult {
216 domain: domain.to_string(),
217 action: PolicyAction::Block,
218 matched_pattern: Some(pattern.original.clone()),
219 is_default: false,
220 };
221 }
222 }
223
224 for pattern in &compiled.allow {
226 if pattern.matcher.is_match(domain) {
227 return PolicyResult {
228 domain: domain.to_string(),
229 action: PolicyAction::Allow,
230 matched_pattern: Some(pattern.original.clone()),
231 is_default: false,
232 };
233 }
234 }
235
236 PolicyResult {
238 domain: domain.to_string(),
239 action: self.default_action.clone(),
240 matched_pattern: None,
241 is_default: true,
242 }
243 }
244
245 pub fn allow_patterns(&self) -> &[String] {
246 &self.allow
247 }
248
249 pub fn block_patterns(&self) -> &[String] {
250 &self.block
251 }
252
253 pub fn set_default_action(&mut self, default_action: PolicyAction) {
254 self.default_action = default_action;
255 self.compiled = OnceLock::new();
256 }
257
258 pub fn extend_allow<I>(&mut self, patterns: I)
259 where
260 I: IntoIterator<Item = String>,
261 {
262 self.allow.extend(patterns);
263 self.compiled = OnceLock::new();
264 }
265
266 pub fn extend_block<I>(&mut self, patterns: I)
267 where
268 I: IntoIterator<Item = String>,
269 {
270 self.block.extend(patterns);
271 self.compiled = OnceLock::new();
272 }
273
274 fn compiled(&self) -> std::result::Result<&CompiledDomainPolicy, &DomainPatternError> {
275 match self
276 .compiled
277 .get_or_init(|| CompiledDomainPolicy::compile(self))
278 {
279 Ok(c) => Ok(c),
280 Err(e) => Err(e),
281 }
282 }
283}
284
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_deny() {
        let policy = DomainPolicy::new();
        assert!(!policy.is_allowed("example.com"));
    }

    #[test]
    fn test_permissive() {
        let policy = DomainPolicy::permissive();
        assert!(policy.is_allowed("example.com"));
    }

    #[test]
    fn test_allowlist() {
        let policy = DomainPolicy::new()
            .allow("example.com")
            .allow("*.allowed.org");

        assert!(policy.is_allowed("example.com"));
        assert!(policy.is_allowed("sub.allowed.org"));
        assert!(!policy.is_allowed("other.com"));
    }

    #[test]
    fn test_blocklist_precedence() {
        let policy = DomainPolicy::permissive().block("bad.example.com");

        assert!(policy.is_allowed("good.example.com"));
        assert!(!policy.is_allowed("bad.example.com"));
    }

    #[test]
    fn test_block_beats_allow() {
        let policy = DomainPolicy::new()
            .allow("*.example.com")
            .block("bad.example.com");

        assert!(policy.is_allowed("good.example.com"));
        assert!(!policy.is_allowed("bad.example.com"));

        let result = policy.evaluate_detailed("bad.example.com");
        assert_eq!(result.action, PolicyAction::Block);
        assert_eq!(result.matched_pattern, Some("bad.example.com".to_string()));
        assert!(!result.is_default);
    }

    #[test]
    fn test_wildcard_block() {
        let policy = DomainPolicy::permissive()
            .block("*.blocked.com")
            .block("blocked.com");

        assert!(policy.is_allowed("allowed.com"));
        assert!(!policy.is_allowed("sub.blocked.com"));
        assert!(!policy.is_allowed("blocked.com"));
    }

    #[test]
    fn test_wildcard_scope() {
        let policy = DomainPolicy::new().allow("*.example.com");

        // `*` spans dots, so nested subdomains match...
        assert!(policy.is_allowed("a.b.example.com"));
        // ...but the bare apex does not.
        assert!(!policy.is_allowed("example.com"));
    }

    #[test]
    fn test_case_insensitive() {
        let policy = DomainPolicy::new().allow("Example.COM");
        assert!(policy.is_allowed("eXaMpLe.com"));
    }

    #[test]
    fn test_invalid_pattern_fails_closed() {
        let policy = DomainPolicy::permissive().allow("[");

        assert_eq!(policy.evaluate("example.com"), PolicyAction::Block);
        assert!(!policy.is_allowed("example.com"));
        assert_eq!(
            policy.evaluate_detailed("example.com").action,
            PolicyAction::Block
        );
    }

    #[test]
    fn test_mutation_invalidates_cache() {
        let mut policy = DomainPolicy::new();
        assert!(!policy.is_allowed("late.example.com"));

        policy.extend_allow(vec!["late.example.com".to_string()]);
        assert!(policy.is_allowed("late.example.com"));

        policy.extend_block(vec!["late.example.com".to_string()]);
        assert!(!policy.is_allowed("late.example.com"));
    }

    #[test]
    fn test_set_default_action_after_evaluation() {
        let mut policy = DomainPolicy::new().block("bad.com");
        assert!(!policy.is_allowed("ok.com"));

        policy.set_default_action(PolicyAction::Allow);
        assert!(policy.is_allowed("ok.com"));
        assert!(!policy.is_allowed("bad.com"));
    }

    #[test]
    fn test_log_default() {
        let mut policy = DomainPolicy::new();
        policy.set_default_action(PolicyAction::Log);

        assert_eq!(policy.evaluate("x.com"), PolicyAction::Log);
        let result = policy.evaluate_detailed("x.com");
        assert_eq!(result.action, PolicyAction::Log);
        assert!(result.is_default);
    }

    #[test]
    fn test_clone_is_independent() {
        let base = DomainPolicy::new().allow("a.example");
        assert!(base.is_allowed("a.example"));

        let extended = base.clone().allow("b.example");
        assert!(extended.is_allowed("a.example"));
        assert!(extended.is_allowed("b.example"));
        assert!(!base.is_allowed("b.example"));
    }

    #[test]
    fn test_evaluate_detailed() {
        let policy = DomainPolicy::new().allow("*.example.com");

        let result = policy.evaluate_detailed("sub.example.com");
        assert_eq!(result.action, PolicyAction::Allow);
        assert_eq!(result.matched_pattern, Some("*.example.com".to_string()));
        assert!(!result.is_default);

        let result = policy.evaluate_detailed("other.com");
        assert_eq!(result.action, PolicyAction::Block);
        assert!(result.matched_pattern.is_none());
        assert!(result.is_default);
    }
}