use regex::Regex;
use std::collections::HashSet;
use url::Url;
/// Allow/block rules for domain names, given either as exact domains or as
/// regex patterns.
///
/// Block rules take precedence over allow rules, and a list without any allow
/// rules (neither domains nor patterns) permits every domain that is not
/// blocked.
#[derive(Debug, Clone)]
pub struct DomainAllowlist {
    /// Allowed domains, stored lowercased.
    allowed: HashSet<String>,
    /// Blocked domains, stored lowercased.
    blocked: HashSet<String>,
    /// Regexes that allow any (lowercased) domain they match.
    allowed_patterns: Vec<Regex>,
    /// Regexes that block any (lowercased) domain they match.
    blocked_patterns: Vec<Regex>,
}
impl Default for DomainAllowlist {
fn default() -> Self {
Self::new()
}
}
impl DomainAllowlist {
pub fn new() -> Self {
Self {
allowed: HashSet::new(),
blocked: HashSet::new(),
allowed_patterns: Vec::new(),
blocked_patterns: Vec::new(),
}
}
pub fn allow_domain(&mut self, domain: &str) {
self.allowed.insert(domain.to_lowercase());
}
pub fn allow_domains(&mut self, domains: &[&str]) {
for domain in domains {
self.allow_domain(domain);
}
}
pub fn block_domain(&mut self, domain: &str) {
self.blocked.insert(domain.to_lowercase());
}
pub fn block_domains(&mut self, domains: &[&str]) {
for domain in domains {
self.block_domain(domain);
}
}
pub fn allow_pattern(&mut self, pattern: &str) -> Result<(), regex::Error> {
let regex = Regex::new(pattern)?;
self.allowed_patterns.push(regex);
Ok(())
}
pub fn block_pattern(&mut self, pattern: &str) -> Result<(), regex::Error> {
let regex = Regex::new(pattern)?;
self.blocked_patterns.push(regex);
Ok(())
}
pub fn is_allowed(&self, url: &Url) -> bool {
let domain = match url.host_str() {
Some(d) => d.to_lowercase(),
None => return false,
};
if self.blocked.contains(&domain) {
return false;
}
for pattern in &self.blocked_patterns {
if pattern.is_match(&domain) {
return false;
}
}
if self.allowed.is_empty() && self.allowed_patterns.is_empty() {
return true;
}
if self.allowed.contains(&domain) {
return true;
}
for allowed in &self.allowed {
if domain.ends_with(&format!(".{}", allowed)) {
return true;
}
}
for pattern in &self.allowed_patterns {
if pattern.is_match(&domain) {
return true;
}
}
false
}
pub fn is_domain_allowed(&self, domain: &str) -> bool {
let domain = domain.to_lowercase();
if self.blocked.contains(&domain) {
return false;
}
for pattern in &self.blocked_patterns {
if pattern.is_match(&domain) {
return false;
}
}
if self.allowed.is_empty() && self.allowed_patterns.is_empty() {
return true;
}
if self.allowed.contains(&domain) {
return true;
}
for allowed in &self.allowed {
if domain.ends_with(&format!(".{}", allowed)) {
return true;
}
}
for pattern in &self.allowed_patterns {
if pattern.is_match(&domain) {
return true;
}
}
false
}
pub fn allowed_domains(&self) -> &HashSet<String> {
&self.allowed
}
pub fn blocked_domains(&self) -> &HashSet<String> {
&self.blocked
}
}