use std::collections::HashSet;
/// A blocklist of domain names that matches a listed domain and any of its
/// subdomains.
///
/// Entries are stored whitespace-trimmed, without trailing dots, and
/// ASCII-lowercased, so lookups are case-insensitive for ASCII host names.
#[derive(Debug, Clone)]
pub struct HostMatcher {
    /// Normalized blocked domains.
    blocked: HashSet<String>,
}

impl HostMatcher {
    /// Builds a matcher from blocklist lines.
    ///
    /// Each line is trimmed of surrounding whitespace and trailing dots.
    /// Lines that are then empty, or that start with `#`, are skipped; the
    /// remaining entries are ASCII-lowercased before being stored.
    pub fn new(domains: impl IntoIterator<Item = String>) -> Self {
        let blocked = domains
            .into_iter()
            .filter_map(|line| {
                let entry = line.trim().trim_end_matches('.');
                if entry.is_empty() || entry.starts_with('#') {
                    None
                } else {
                    Some(entry.to_ascii_lowercase())
                }
            })
            .collect();
        Self { blocked }
    }

    /// Returns the number of distinct blocked entries.
    pub fn len(&self) -> usize {
        self.blocked.len()
    }

    /// Returns `true` when the matcher holds no entries.
    pub fn is_empty(&self) -> bool {
        self.blocked.is_empty()
    }

    /// Returns `true` if `host`, or any parent domain of `host`, is listed.
    ///
    /// `host` is normalized the same way as the entries (trimmed, trailing
    /// dots removed, ASCII-lowercased). Matching only happens on label
    /// boundaries: with `evil.com` listed, `tracker.evil.com` is blocked but
    /// `notevil.com` is not.
    pub fn is_blocked(&self, host: &str) -> bool {
        let normalized = host.trim().trim_end_matches('.').to_ascii_lowercase();
        // Walk the host and each successively shorter suffix obtained by
        // dropping the leftmost label, stopping at the first listed one.
        std::iter::successors(Some(normalized.as_str()), |&suffix| {
            suffix.split_once('.').map(|(_, parent)| parent)
        })
        .any(|suffix| self.blocked.contains(suffix))
    }
}
/// Extracts the host component from a URL of the form `scheme://authority…`.
///
/// The authority is everything after the first `://` up to the first `/`,
/// `?`, or `#`. Any `userinfo@` prefix and any trailing `:port` are removed.
/// Bracketed IPv6 literals are returned with their brackets, e.g. `[::1]`.
///
/// Returns `None` when `url` contains no `://` separator, or when the
/// authority carries no host at all (for example `file:///etc/hosts` or
/// `http://:8080/`), so callers never receive an empty host string.
pub fn host_of(url: &str) -> Option<&str> {
    let (_, after_scheme) = url.split_once("://")?;

    // The authority ends at the first path, query, or fragment delimiter.
    let authority_end = after_scheme
        .find(['/', '?', '#'])
        .unwrap_or(after_scheme.len());
    let authority = &after_scheme[..authority_end];

    // Userinfo may itself contain '@', so split on the last one.
    let host_and_port = authority
        .rsplit_once('@')
        .map_or(authority, |(_userinfo, rest)| rest);

    let host = if host_and_port.starts_with('[') {
        // IPv6 literal: keep the brackets and ignore anything after the
        // closing one. An unclosed literal is passed through unchanged.
        match host_and_port.find(']') {
            Some(close) => &host_and_port[..=close],
            None => host_and_port,
        }
    } else {
        host_and_port
            .rsplit_once(':')
            .map_or(host_and_port, |(name, _port)| name)
    };

    // An empty string is not a host; report it as absent.
    if host.is_empty() {
        None
    } else {
        Some(host)
    }
}
#[cfg(test)]
#[allow(clippy::panic, clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Builds a matcher from string literals to keep each test short.
    fn matcher_of(entries: &[&str]) -> HostMatcher {
        HostMatcher::new(entries.iter().map(|entry| (*entry).to_owned()))
    }

    #[test]
    fn exact_match_blocked() {
        let matcher = matcher_of(&["evil.com"]);
        assert!(matcher.is_blocked("evil.com"));
    }

    #[test]
    fn subdomain_of_listed_domain_is_blocked() {
        let matcher = matcher_of(&["evil.com"]);
        assert!(matcher.is_blocked("tracker.evil.com"));
        assert!(matcher.is_blocked("a.b.tracker.evil.com"));
    }

    #[test]
    fn unrelated_host_not_blocked() {
        let matcher = matcher_of(&["evil.com"]);
        assert!(!matcher.is_blocked("good.com"));
        // Suffix matches must respect label boundaries.
        assert!(!matcher.is_blocked("notevil.com"));
        assert!(!matcher.is_blocked("totallyevil.com"));
    }

    #[test]
    fn bare_root_in_set_blocks_all_subdomains() {
        let matcher = matcher_of(&["example.com"]);
        assert!(matcher.is_blocked("example.com"));
        assert!(matcher.is_blocked("sub.example.com"));
    }

    #[test]
    fn case_insensitive_match() {
        let matcher = matcher_of(&["Evil.Com"]);
        assert!(matcher.is_blocked("EVIL.COM"));
        assert!(matcher.is_blocked("Tracker.Evil.Com"));
    }

    #[test]
    fn empty_matcher_blocks_nothing() {
        let matcher = matcher_of(&[]);
        assert!(!matcher.is_blocked("evil.com"));
    }

    #[test]
    fn comment_and_blank_lines_ignored() {
        let matcher = matcher_of(&[
            "# this is a comment",
            "",
            " ",
            "tracker.example.com",
            "# another comment",
        ]);
        assert!(matcher.is_blocked("tracker.example.com"));
        // Listing a subdomain must not block its parent.
        assert!(!matcher.is_blocked("example.com"));
    }

    #[test]
    fn single_label_host_no_infinite_loop() {
        let matcher = matcher_of(&["localhost"]);
        assert!(matcher.is_blocked("localhost"));
        assert!(!matcher.is_blocked("otherhost"));
    }

    #[test]
    fn host_of_simple_url() {
        let url = "https://example.com/path";
        assert_eq!(host_of(url), Some("example.com"));
    }

    #[test]
    fn host_of_with_port() {
        let url = "http://example.com:8080/path";
        assert_eq!(host_of(url), Some("example.com"));
    }

    #[test]
    fn host_of_no_path() {
        let url = "https://example.com";
        assert_eq!(host_of(url), Some("example.com"));
    }

    #[test]
    fn host_of_with_query() {
        let url = "https://example.com?foo=bar";
        assert_eq!(host_of(url), Some("example.com"));
    }

    #[test]
    fn host_of_with_fragment() {
        let url = "https://example.com#section";
        assert_eq!(host_of(url), Some("example.com"));
    }

    #[test]
    fn host_of_missing_scheme_separator() {
        assert_eq!(host_of("not-a-url"), None);
    }

    #[test]
    fn host_of_ipv6() {
        let url = "https://[::1]:443/path";
        assert_eq!(host_of(url), Some("[::1]"));
    }

    #[test]
    fn host_of_with_userinfo() {
        let url = "https://user:pass@example.com/path";
        assert_eq!(host_of(url), Some("example.com"));
    }

    #[test]
    fn is_blocked_using_host_of() {
        let matcher = matcher_of(&["fingerprinter.io"]);
        let host = host_of("https://cdn.fingerprinter.io/track.js?v=1")
            .expect("host_of returned None");
        assert!(matcher.is_blocked(host));
    }
}