halldyll_core/security/
allowlist.rs1use regex::Regex;
4use std::collections::HashSet;
5use url::Url;
6
7pub struct DomainAllowlist {
9 allowed: HashSet<String>,
11 blocked: HashSet<String>,
13 allowed_patterns: Vec<Regex>,
15 blocked_patterns: Vec<Regex>,
17}
18
19impl Default for DomainAllowlist {
20 fn default() -> Self {
21 Self::new()
22 }
23}
24
25impl DomainAllowlist {
26 pub fn new() -> Self {
28 Self {
29 allowed: HashSet::new(),
30 blocked: HashSet::new(),
31 allowed_patterns: Vec::new(),
32 blocked_patterns: Vec::new(),
33 }
34 }
35
36 pub fn allow_domain(&mut self, domain: &str) {
38 self.allowed.insert(domain.to_lowercase());
39 }
40
41 pub fn allow_domains(&mut self, domains: &[&str]) {
43 for domain in domains {
44 self.allow_domain(domain);
45 }
46 }
47
48 pub fn block_domain(&mut self, domain: &str) {
50 self.blocked.insert(domain.to_lowercase());
51 }
52
53 pub fn block_domains(&mut self, domains: &[&str]) {
55 for domain in domains {
56 self.block_domain(domain);
57 }
58 }
59
60 pub fn allow_pattern(&mut self, pattern: &str) -> Result<(), regex::Error> {
62 let regex = Regex::new(pattern)?;
63 self.allowed_patterns.push(regex);
64 Ok(())
65 }
66
67 pub fn block_pattern(&mut self, pattern: &str) -> Result<(), regex::Error> {
69 let regex = Regex::new(pattern)?;
70 self.blocked_patterns.push(regex);
71 Ok(())
72 }
73
74 pub fn is_allowed(&self, url: &Url) -> bool {
76 let domain = match url.host_str() {
77 Some(d) => d.to_lowercase(),
78 None => return false,
79 };
80
81 if self.blocked.contains(&domain) {
83 return false;
84 }
85
86 for pattern in &self.blocked_patterns {
88 if pattern.is_match(&domain) {
89 return false;
90 }
91 }
92
93 if self.allowed.is_empty() && self.allowed_patterns.is_empty() {
95 return true;
96 }
97
98 if self.allowed.contains(&domain) {
100 return true;
101 }
102
103 for allowed in &self.allowed {
105 if domain.ends_with(&format!(".{}", allowed)) {
106 return true;
107 }
108 }
109
110 for pattern in &self.allowed_patterns {
112 if pattern.is_match(&domain) {
113 return true;
114 }
115 }
116
117 false
118 }
119
120 pub fn is_domain_allowed(&self, domain: &str) -> bool {
122 let domain = domain.to_lowercase();
123
124 if self.blocked.contains(&domain) {
126 return false;
127 }
128 for pattern in &self.blocked_patterns {
129 if pattern.is_match(&domain) {
130 return false;
131 }
132 }
133
134 if self.allowed.is_empty() && self.allowed_patterns.is_empty() {
136 return true;
137 }
138
139 if self.allowed.contains(&domain) {
141 return true;
142 }
143 for allowed in &self.allowed {
144 if domain.ends_with(&format!(".{}", allowed)) {
145 return true;
146 }
147 }
148 for pattern in &self.allowed_patterns {
149 if pattern.is_match(&domain) {
150 return true;
151 }
152 }
153
154 false
155 }
156
157 pub fn allowed_domains(&self) -> &HashSet<String> {
159 &self.allowed
160 }
161
162 pub fn blocked_domains(&self) -> &HashSet<String> {
164 &self.blocked
165 }
166}