Skip to main content

hush_proxy/
policy.rs

1//! Egress policy enforcement
2//!
3//! Provides domain allowlist/blocklist policy evaluation.
4
5use std::sync::OnceLock;
6
7use globset::{GlobBuilder, GlobMatcher};
8use serde::{Deserialize, Serialize};
9
10#[derive(Debug)]
11struct DomainPatternError {
12    pattern: String,
13    error: globset::Error,
14}
15
16impl std::fmt::Display for DomainPatternError {
17    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
18        write!(f, "invalid domain glob {:?}: {}", self.pattern, self.error)
19    }
20}
21
22#[derive(Debug)]
23struct CompiledPattern {
24    original: String,
25    matcher: GlobMatcher,
26}
27
28#[derive(Debug, Default)]
29struct CompiledDomainPolicy {
30    allow: Vec<CompiledPattern>,
31    block: Vec<CompiledPattern>,
32}
33
34impl CompiledDomainPolicy {
35    fn compile(policy: &DomainPolicy) -> Result<Self, DomainPatternError> {
36        Ok(Self {
37            allow: compile_patterns(policy.allow_patterns())?,
38            block: compile_patterns(policy.block_patterns())?,
39        })
40    }
41}
42
43fn compile_patterns(patterns: &[String]) -> Result<Vec<CompiledPattern>, DomainPatternError> {
44    let mut out = Vec::with_capacity(patterns.len());
45    for p in patterns {
46        let matcher = compile_pattern(p)?;
47        out.push(CompiledPattern {
48            original: p.clone(),
49            matcher,
50        });
51    }
52    Ok(out)
53}
54
55fn compile_pattern(pattern: &str) -> Result<GlobMatcher, DomainPatternError> {
56    let glob = GlobBuilder::new(pattern)
57        .case_insensitive(true)
58        .literal_separator(true)
59        .build()
60        .map_err(|e| DomainPatternError {
61            pattern: pattern.to_string(),
62            error: e,
63        })?;
64
65    Ok(glob.compile_matcher())
66}
67
68/// Policy action for a domain
69#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
70#[serde(rename_all = "lowercase")]
71pub enum PolicyAction {
72    /// Allow the connection
73    Allow,
74    /// Block the connection
75    #[serde(alias = "deny")]
76    #[default]
77    Block,
78    /// Log but allow
79    Log,
80}
81
82/// Domain policy configuration
83#[derive(Debug, Serialize, Deserialize)]
84pub struct DomainPolicy {
85    /// Allowed domain patterns (glob syntax)
86    #[serde(default)]
87    allow: Vec<String>,
88    /// Blocked domain patterns
89    #[serde(default)]
90    block: Vec<String>,
91    /// Default action when no pattern matches
92    #[serde(default = "default_action")]
93    default_action: PolicyAction,
94
95    #[serde(skip)]
96    compiled: OnceLock<Result<CompiledDomainPolicy, DomainPatternError>>,
97}
98
99fn default_action() -> PolicyAction {
100    PolicyAction::Block
101}
102
103impl Default for DomainPolicy {
104    fn default() -> Self {
105        Self {
106            allow: Vec::new(),
107            block: Vec::new(),
108            default_action: default_action(),
109            compiled: OnceLock::new(),
110        }
111    }
112}
113
114impl Clone for DomainPolicy {
115    fn clone(&self) -> Self {
116        Self {
117            allow: self.allow.clone(),
118            block: self.block.clone(),
119            default_action: self.default_action.clone(),
120            compiled: OnceLock::new(),
121        }
122    }
123}
124
125impl DomainPolicy {
126    /// Create a new policy with default deny
127    pub fn new() -> Self {
128        Self::default()
129    }
130
131    /// Create a permissive policy (default allow)
132    pub fn permissive() -> Self {
133        Self {
134            default_action: PolicyAction::Allow,
135            ..Self::default()
136        }
137    }
138
139    /// Add an allowed domain pattern
140    pub fn allow(mut self, pattern: impl Into<String>) -> Self {
141        self.allow.push(pattern.into());
142        self.compiled = OnceLock::new();
143        self
144    }
145
146    /// Add a blocked domain pattern
147    pub fn block(mut self, pattern: impl Into<String>) -> Self {
148        self.block.push(pattern.into());
149        self.compiled = OnceLock::new();
150        self
151    }
152
153    /// Evaluate a domain against the policy
154    pub fn evaluate(&self, domain: &str) -> PolicyAction {
155        let compiled = match self.compiled() {
156            Ok(c) => c,
157            Err(_) => return PolicyAction::Block,
158        };
159
160        // Check blocklist first (block takes precedence)
161        for pattern in &compiled.block {
162            if pattern.matcher.is_match(domain) {
163                return PolicyAction::Block;
164            }
165        }
166
167        // Check allowlist
168        for pattern in &compiled.allow {
169            if pattern.matcher.is_match(domain) {
170                return PolicyAction::Allow;
171            }
172        }
173
174        // Default action
175        self.default_action.clone()
176    }
177
178    /// Check if a domain is allowed
179    pub fn is_allowed(&self, domain: &str) -> bool {
180        matches!(self.evaluate(domain), PolicyAction::Allow)
181    }
182}
183
184/// Policy evaluation result with details
185#[derive(Clone, Debug, Serialize, Deserialize)]
186pub struct PolicyResult {
187    /// The evaluated domain
188    pub domain: String,
189    /// The resulting action
190    pub action: PolicyAction,
191    /// The pattern that matched (if any)
192    pub matched_pattern: Option<String>,
193    /// Whether this was a default action
194    pub is_default: bool,
195}
196
197impl DomainPolicy {
198    /// Evaluate with detailed result
199    pub fn evaluate_detailed(&self, domain: &str) -> PolicyResult {
200        let compiled = match self.compiled() {
201            Ok(c) => c,
202            Err(_) => {
203                return PolicyResult {
204                    domain: domain.to_string(),
205                    action: PolicyAction::Block,
206                    matched_pattern: None,
207                    is_default: true,
208                };
209            }
210        };
211
212        // Check blocklist first
213        for pattern in &compiled.block {
214            if pattern.matcher.is_match(domain) {
215                return PolicyResult {
216                    domain: domain.to_string(),
217                    action: PolicyAction::Block,
218                    matched_pattern: Some(pattern.original.clone()),
219                    is_default: false,
220                };
221            }
222        }
223
224        // Check allowlist
225        for pattern in &compiled.allow {
226            if pattern.matcher.is_match(domain) {
227                return PolicyResult {
228                    domain: domain.to_string(),
229                    action: PolicyAction::Allow,
230                    matched_pattern: Some(pattern.original.clone()),
231                    is_default: false,
232                };
233            }
234        }
235
236        // Default action
237        PolicyResult {
238            domain: domain.to_string(),
239            action: self.default_action.clone(),
240            matched_pattern: None,
241            is_default: true,
242        }
243    }
244
245    pub fn allow_patterns(&self) -> &[String] {
246        &self.allow
247    }
248
249    pub fn block_patterns(&self) -> &[String] {
250        &self.block
251    }
252
253    pub fn set_default_action(&mut self, default_action: PolicyAction) {
254        self.default_action = default_action;
255        self.compiled = OnceLock::new();
256    }
257
258    pub fn extend_allow<I>(&mut self, patterns: I)
259    where
260        I: IntoIterator<Item = String>,
261    {
262        self.allow.extend(patterns);
263        self.compiled = OnceLock::new();
264    }
265
266    pub fn extend_block<I>(&mut self, patterns: I)
267    where
268        I: IntoIterator<Item = String>,
269    {
270        self.block.extend(patterns);
271        self.compiled = OnceLock::new();
272    }
273
274    fn compiled(&self) -> std::result::Result<&CompiledDomainPolicy, &DomainPatternError> {
275        match self
276            .compiled
277            .get_or_init(|| CompiledDomainPolicy::compile(self))
278        {
279            Ok(c) => Ok(c),
280            Err(e) => Err(e),
281        }
282    }
283}
284
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_deny() {
        let policy = DomainPolicy::new();
        assert!(!policy.is_allowed("example.com"));
    }

    #[test]
    fn test_permissive() {
        let policy = DomainPolicy::permissive();
        assert!(policy.is_allowed("example.com"));
    }

    #[test]
    fn test_allowlist() {
        let policy = DomainPolicy::new()
            .allow("example.com")
            .allow("*.allowed.org");

        assert!(policy.is_allowed("example.com"));
        assert!(policy.is_allowed("sub.allowed.org"));
        assert!(!policy.is_allowed("other.com"));
    }

    #[test]
    fn test_blocklist_precedence() {
        let policy = DomainPolicy::permissive().block("bad.example.com");

        assert!(policy.is_allowed("good.example.com"));
        assert!(!policy.is_allowed("bad.example.com"));
    }

    #[test]
    fn test_wildcard_block() {
        let policy = DomainPolicy::permissive()
            .block("*.blocked.com")
            .block("blocked.com");

        assert!(policy.is_allowed("allowed.com"));
        assert!(!policy.is_allowed("sub.blocked.com"));
        assert!(!policy.is_allowed("blocked.com"));
    }

    #[test]
    fn test_evaluate_detailed() {
        let policy = DomainPolicy::new().allow("*.example.com");

        let result = policy.evaluate_detailed("sub.example.com");
        assert_eq!(result.action, PolicyAction::Allow);
        assert_eq!(result.matched_pattern, Some("*.example.com".to_string()));
        assert!(!result.is_default);

        let result = policy.evaluate_detailed("other.com");
        assert_eq!(result.action, PolicyAction::Block);
        assert!(result.matched_pattern.is_none());
        assert!(result.is_default);
    }

    #[test]
    fn test_patterns_match_case_insensitively() {
        let policy = DomainPolicy::new().allow("Example.COM");

        assert!(policy.is_allowed("example.com"));
        assert!(policy.is_allowed("EXAMPLE.com"));
    }

    #[test]
    fn test_wildcard_requires_subdomain() {
        let policy = DomainPolicy::new().allow("*.allowed.org");

        assert!(policy.is_allowed("deep.sub.allowed.org"));
        assert!(!policy.is_allowed("allowed.org"));
        assert!(!policy.is_allowed("notallowed.org"));
    }

    #[test]
    fn test_block_overrides_allow_in_detailed_result() {
        let policy = DomainPolicy::new()
            .allow("*.example.com")
            .block("bad.example.com");

        assert!(policy.is_allowed("good.example.com"));

        let result = policy.evaluate_detailed("bad.example.com");
        assert_eq!(result.action, PolicyAction::Block);
        assert_eq!(result.matched_pattern.as_deref(), Some("bad.example.com"));
        assert!(!result.is_default);
    }

    #[test]
    fn test_invalid_pattern_fails_closed() {
        // `[` opens a character class that is never closed.
        let policy = DomainPolicy::permissive().allow("[");

        assert_eq!(policy.evaluate("example.com"), PolicyAction::Block);

        let result = policy.evaluate_detailed("example.com");
        assert_eq!(result.action, PolicyAction::Block);
        assert!(result.matched_pattern.is_none());
        assert!(result.is_default);
    }

    #[test]
    fn test_extend_invalidates_cache() {
        let mut policy = DomainPolicy::new();
        assert!(!policy.is_allowed("example.com"));

        policy.extend_allow(vec!["example.com".to_string()]);
        assert!(policy.is_allowed("example.com"));

        policy.extend_block(vec!["example.com".to_string()]);
        assert!(!policy.is_allowed("example.com"));
    }

    #[test]
    fn test_set_default_action() {
        let mut policy = DomainPolicy::new();
        policy.set_default_action(PolicyAction::Log);

        assert_eq!(policy.evaluate("example.com"), PolicyAction::Log);
    }

    #[test]
    fn test_clone_keeps_patterns() {
        let policy = DomainPolicy::new().allow("example.com");
        assert!(policy.is_allowed("example.com"));

        let copy = policy.clone();
        assert!(copy.is_allowed("example.com"));
        assert_eq!(copy.allow_patterns(), policy.allow_patterns());
    }
}