// blockconvert/domain_filter.rs

1use crate::{Domain, DomainSetSharded};
2
3use std::collections::HashSet;
4use std::str::FromStr;
5
6use parking_lot::Mutex;
7
/// A pair of rule collections of the same kind: the entries that are explicitly
/// allowed and the entries that are explicitly disallowed.
#[derive(Default)]
struct Filter<T> {
    allow: T,
    disallow: T,
}

impl<T> Filter<T> {
    /// Bundles an allow collection and a disallow collection.
    fn new(allow: T, disallow: T) -> Self {
        Filter { allow, disallow }
    }
}
27
/// Error returned when a line cannot be turned into an `AdblockFilter`.
///
/// It carries no details: comments, element-hiding rules, rules with unsupported
/// options and patterns that are not domain/IP-like are all rejected the same way.
#[derive(Debug, Default)]
struct AdblockParseError {}

// `source` keeps its default implementation (`None`): this error never wraps another.
impl std::error::Error for AdblockParseError {}

impl std::fmt::Display for AdblockParseError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // User-facing text; `Debug` remains available for diagnostics.
        f.write_str("unsupported or invalid adblock rule")
    }
}
42
/// An adblock rule reduced to the parts that matter for domain-level filtering.
///
/// Field order is significant: it determines the derived `Debug` output and hash.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct AdblockFilter {
    /// Rule was prefixed with `@@`: matching domains are allowed instead of blocked.
    is_exception: bool,
    /// Rule was anchored with `||` (or started with `.` / `*.`); when applied, the
    /// subdomains of `filter` are matched as well.
    match_start_domain: bool,
    /// Rule was anchored at the start of the address (`|`, `http://`, `https://`, `*://`).
    match_start_address: bool,
    /// Rule ends at a domain boundary (`^`, `|`, or a trailing `*` / `/*`).
    match_end_domain: bool,
    /// The remaining pattern, e.g. `example.com`; it may still contain `*`.
    filter: String,
    /// Rule carried the `$badfilter` option, so it cancels its twin without that option.
    is_badfilter: bool,
}
52
53impl FromStr for AdblockFilter {
54    type Err = AdblockParseError;
55    fn from_str(original_rule: &str) -> Result<Self, Self::Err> {
56        let mut rule = original_rule;
57        if rule.starts_with('!') {
58            // Remove comments
59            return Err(AdblockParseError::default());
60        }
61        if rule.contains('#')
62        // Element Hiding
63        {
64            return Err(AdblockParseError::default());
65        }
66        let mut is_badfilter = false;
67        if let Some(position) = rule.find('$') {
68            let (start, tags) = rule.split_at(position);
69            let tags = tags
70                .trim_start_matches('$')
71                .split(',')
72                .filter(|tag| !matches!(*tag, "3p" | "third-party"))
73                .filter(|tag| {
74                    if *tag == "badfilter" {
75                        is_badfilter = true;
76                        false
77                    } else {
78                        true
79                    }
80                })
81                .collect::<Vec<&str>>();
82
83            if !(tags.is_empty() || tags.contains(&"document") || tags.contains(&"all"))
84                || tags.iter().any(|tag| tag.starts_with("domain="))
85            {
86                return Err(AdblockParseError::default());
87            }
88            rule = start;
89        }
90        let (rule, is_exception) = rule
91            .strip_prefix("@@")
92            .map(|rule| (rule, true))
93            .unwrap_or((rule, false));
94
95        let (rule, match_start_domain, match_start_address) =
96            if let Some(rule) = rule.strip_prefix("||") {
97                (rule, true, false)
98            } else {
99                let (rule, match_start_address) = rule
100                    .strip_prefix('|')
101                    .map(|rule| (rule, true))
102                    .unwrap_or((rule, false));
103                let (rule, match_start_address) = rule
104                    .strip_prefix('*')
105                    .or_else(|| rule.strip_prefix("https"))
106                    .or_else(|| rule.strip_prefix("http"))
107                    .unwrap_or(rule)
108                    .strip_prefix("://")
109                    .map(|rule| (rule, true))
110                    .unwrap_or((rule, match_start_address));
111                (rule, false, match_start_address)
112            };
113        let (rule, match_start_domain, match_start_address) = rule
114            .strip_prefix('*')
115            .unwrap_or(rule)
116            .strip_prefix('.')
117            .map(|rule| (rule, true, false))
118            .unwrap_or((rule, match_start_domain, match_start_address));
119
120        let (rule, match_end_domain) = rule
121            .strip_suffix('|')
122            .map(|rule| (rule, true))
123            .unwrap_or((rule, false));
124        let (rule, match_end_domain) = rule
125            .strip_suffix(".php")
126            .or_else(|| rule.strip_suffix(".htm"))
127            .or_else(|| rule.strip_suffix(".html"))
128            .or_else(|| rule.strip_suffix(".xhtml"))
129            .unwrap_or(rule)
130            .strip_suffix('*')
131            .and_then(|rule| rule.strip_suffix('/').or(Some(rule)))
132            .map(|rule| (rule, true))
133            .unwrap_or((rule, match_end_domain));
134        let (rule, match_end_domain) = rule
135            .strip_suffix('^')
136            .map(|rule| (rule, true))
137            .unwrap_or((rule, match_end_domain));
138        let (rule, match_end_domain) = rule
139            .strip_suffix('.')
140            .map(|rule| (rule, false))
141            .unwrap_or((rule, match_end_domain));
142        if rule.is_empty()
143            || rule == "*"
144            || !rule
145                .chars()
146                .all(|c| c.is_ascii_alphanumeric() || matches!(c, '_' | '-' | '.' | '*'))
147        {
148            return Err(AdblockParseError::default());
149        }
150        Ok(Self {
151            is_exception,
152            match_start_domain,
153            match_start_address,
154            match_end_domain,
155            filter: rule.to_string(),
156            is_badfilter,
157        })
158    }
159}
160
161#[derive(Default)]
162pub struct DomainFilterBuilder<H: std::hash::BuildHasher + Default> {
163    domains: Filter<DomainSetSharded<H>>,
164    subdomains: Filter<DomainSetSharded<H>>,
165    ips: Mutex<Filter<HashSet<std::net::IpAddr, H>>>,
166    ip_nets: Mutex<Filter<HashSet<ipnet::IpNet, H>>>,
167    regexes: Mutex<Filter<HashSet<String, H>>>,
168    adblock: Mutex<HashSet<AdblockFilter, H>>,
169}
170
171type DefaultHasher = std::collections::hash_map::RandomState;
172
173pub type DefaultDomainFilterBuilder = DomainFilterBuilder<DefaultHasher>;
174
175impl<H: std::hash::BuildHasher + Default> DomainFilterBuilder<H> {
176    pub fn new() -> Self {
177        Default::default()
178    }
179    pub fn add_allow_domain(&self, domain: Domain) {
180        if let Some(without_www) = domain
181            .strip_prefix("www.")
182            .and_then(|domain| domain.parse::<Domain>().ok())
183        {
184            self.domains.disallow.remove_str(&without_www);
185            self.domains.allow.insert_str(&without_www);
186        } else if let Ok(with_www) = format!("www.{}", &domain).parse::<Domain>() {
187            self.domains.disallow.remove_str(&with_www);
188            self.domains.allow.insert_str(&with_www);
189        }
190        self.domains.disallow.remove_str(&domain);
191        self.domains.allow.insert_str(&domain);
192    }
193    pub fn add_disallow_domain(&self, domain: Domain) {
194        if !self.domains.allow.contains_str(&domain)
195            && !is_subdomain_of_list(&domain, &self.subdomains.allow)
196        {
197            self.domains.disallow.insert_str(&domain);
198        }
199    }
200    pub fn add_allow_subdomain(&self, domain: Domain) {
201        self.subdomains.disallow.remove_str(&domain);
202        self.subdomains.allow.insert_str(&domain);
203    }
204    pub fn add_disallow_subdomain(&self, domain: Domain) {
205        if !self.subdomains.allow.contains_str(&domain) {
206            self.subdomains.disallow.insert_str(&domain);
207        }
208    }
209
210    pub fn add_allow_ip_addr(&self, ip: std::net::IpAddr) {
211        let mut ips = self.ips.lock();
212        let _ = ips.disallow.remove(&ip);
213        ips.allow.insert(ip);
214    }
215    pub fn add_disallow_ip_addr(&self, ip: std::net::IpAddr) {
216        let mut ips = self.ips.lock();
217        if !ips.allow.contains(&ip) {
218            ips.disallow.insert(ip);
219        }
220    }
221
222    pub fn add_allow_ip_subnet(&self, net: ipnet::IpNet) {
223        let mut ip_nets = self.ip_nets.lock();
224        let _ = ip_nets.disallow.remove(&net);
225        ip_nets.allow.insert(net);
226    }
227
228    pub fn add_disallow_ip_subnet(&self, ip: ipnet::IpNet) {
229        let mut ip_nets = self.ip_nets.lock();
230        if !ip_nets.allow.contains(&ip) {
231            ip_nets.disallow.insert(ip);
232        }
233    }
234
235    pub fn add_adblock_rule(&self, rule: &str) {
236        if let Ok(filter) = rule.parse::<AdblockFilter>() {
237            self.adblock.lock().insert(filter);
238        }
239    }
240
241    pub fn add_allow_regex(&self, re: &str) {
242        if !re.is_empty() && regex::Regex::new(re).is_ok() {
243            let mut regexes = self.regexes.lock();
244            regexes.allow.insert(re.to_string());
245        }
246    }
247    pub fn add_disallow_regex(&self, re: &str) {
248        if !re.is_empty() && regex::Regex::new(re).is_ok() {
249            let mut regexes = self.regexes.lock();
250            regexes.disallow.insert(re.to_string());
251        }
252    }
253
254    pub fn to_domain_filter(self) -> DomainFilter<H> {
255        let adblock = std::mem::take(&mut *self.adblock.lock());
256
257        for filter in adblock.iter() {
258            let mut bad_filter = filter.clone();
259            bad_filter.is_badfilter = true;
260            if adblock.contains(&bad_filter) {
261                continue;
262            }
263            if (filter.match_start_domain || filter.match_start_address) && filter.match_end_domain
264            {
265                if let Ok(domain) = filter.filter.parse::<Domain>() {
266                    if filter.is_exception {
267                        self.add_allow_domain(domain.clone());
268                        if filter.match_start_domain {
269                            self.add_allow_subdomain(domain);
270                        }
271                    } else {
272                        self.add_disallow_domain(domain.clone());
273                        if filter.match_start_domain {
274                            self.add_disallow_subdomain(domain);
275                        }
276                    }
277                }
278                if let Ok(ip) = filter.filter.parse::<std::net::IpAddr>() {
279                    if !filter.is_exception {
280                        // Currently no IP exception filters, so no benefit from implementing
281                        self.add_disallow_ip_addr(ip);
282                    }
283                }
284            }
285        }
286
287        let domains = self.domains;
288        let subdomains = self.subdomains;
289        let mut ips = std::mem::take(&mut *self.ips.lock());
290        let ip_nets = std::mem::take(&mut *self.ip_nets.lock());
291        let regexes = std::mem::take(&mut *self.regexes.lock());
292
293        domains.allow.shrink_to_fit();
294        domains.disallow.shrink_to_fit();
295        subdomains.allow.shrink_to_fit();
296        subdomains.disallow.shrink_to_fit();
297        ips.allow.shrink_to_fit();
298        ips.disallow.shrink_to_fit();
299        let ip_nets = Filter {
300            allow: ip_nets.allow.into_iter().collect(),
301            disallow: ip_nets.disallow.into_iter().collect(),
302        };
303        DomainFilter {
304            domains,
305            subdomains,
306            ips,
307            ip_nets: ip_nets,
308            allow_regex: regex::RegexSet::new(&regexes.allow).unwrap(),
309            disallow_regex: regex::RegexSet::new(&regexes.disallow).unwrap(),
310        }
311    }
312}
313
314fn is_subdomain_of_list<H: std::hash::BuildHasher>(
315    domain: &Domain,
316    filter_list: &DomainSetSharded<H>,
317) -> bool {
318    Domain::str_iter_parent_domains(domain).any(|part| filter_list.contains_str(part))
319}
320
321pub struct DomainFilter<H: std::hash::BuildHasher + Default> {
322    domains: Filter<DomainSetSharded<H>>,
323    subdomains: Filter<DomainSetSharded<H>>,
324    ips: Filter<HashSet<std::net::IpAddr, H>>,
325    ip_nets: Filter<Vec<ipnet::IpNet>>,
326    allow_regex: regex::RegexSet,
327    disallow_regex: regex::RegexSet,
328}
329
330impl<H: std::hash::BuildHasher + Default> Default for DomainFilter<H> {
331    fn default() -> Self {
332        Self {
333            domains: Filter::new(
334                DomainSetSharded::<H>::with_shards(0),
335                DomainSetSharded::<H>::with_shards(0),
336            ),
337            subdomains: Filter::new(
338                DomainSetSharded::<H>::with_shards(0),
339                DomainSetSharded::<H>::with_shards(0),
340            ),
341            ips: Default::default(),
342            ip_nets: Default::default(),
343            allow_regex: regex::RegexSet::empty(),
344            disallow_regex: regex::RegexSet::empty(),
345        }
346    }
347}
348
349impl<H: std::hash::BuildHasher + Default> DomainFilter<H> {
350    fn is_allowed_by_adblock(&self, _location: &str) -> Option<bool> {
351        None
352    }
353
354    pub fn allowed(
355        &self,
356        domain: &Domain,
357        cnames: &[Domain],
358        ips: &[std::net::IpAddr],
359    ) -> Option<bool> {
360        if let Some(result) = self.domain_is_allowed(domain) {
361            Some(result)
362        } else if cnames
363            .iter()
364            .any(|cname| self.domain_is_allowed(cname) == Some(false))
365            || ips.iter().any(|ip| self.ip_is_allowed(ip) == Some(false))
366        {
367            Some(false)
368        } else {
369            None
370        }
371    }
372
373    fn domain_is_allowed(&self, domain: &Domain) -> Option<bool> {
374        if self.domains.allow.contains_str(&domain)
375            || is_subdomain_of_list(&*domain, &self.subdomains.allow)
376            || self.allow_regex.is_match(domain)
377        {
378            Some(true)
379        } else if let Some(blocker_result) = self.is_allowed_by_adblock(&domain) {
380            Some(blocker_result)
381        } else if self.domains.disallow.contains_str(&domain)
382            || is_subdomain_of_list(&*domain, &self.subdomains.disallow)
383            || self.disallow_regex.is_match(domain)
384        {
385            Some(false)
386        } else {
387            None
388        }
389    }
390
391    pub fn ip_is_allowed(&self, ip: &std::net::IpAddr) -> Option<bool> {
392        if self.ips.allow.contains(ip) || self.ip_nets.allow.iter().any(|net| net.contains(ip)) {
393            Some(true)
394        } else if let Some(blocker_result) = self.is_allowed_by_adblock(&ip.to_string()) {
395            Some(blocker_result)
396        } else if self.ips.disallow.contains(ip)
397            || self.ip_nets.disallow.iter().any(|net| net.contains(ip))
398        {
399            Some(false)
400        } else {
401            None
402        }
403    }
404}
405
#[test]
fn default_unblocked() {
    // A builder with no rules must leave every domain undecided.
    let filter = DefaultDomainFilterBuilder::new().to_domain_filter();
    let domain: Domain = "example.org".parse().unwrap();
    assert_eq!(filter.domain_is_allowed(&domain), None)
}
415
#[test]
fn regex_disallow_all_blocks_domain() {
    // `.` matches any non-empty name, so everything is disallowed.
    let builder = DefaultDomainFilterBuilder::new();
    builder.add_disallow_regex(".");
    let filter = builder.to_domain_filter();
    let domain: Domain = "example.org".parse().unwrap();
    assert_eq!(filter.domain_is_allowed(&domain), Some(false))
}
#[test]
fn regex_allow_overrules_regex_disallow() {
    // Allow rules are consulted before disallow rules.
    let builder = DefaultDomainFilterBuilder::new();
    builder.add_disallow_regex(".");
    builder.add_allow_regex(".");
    let filter = builder.to_domain_filter();
    let domain: Domain = "example.org".parse().unwrap();
    assert_eq!(filter.domain_is_allowed(&domain), Some(true))
}
437
#[test]
fn adblock_can_block_domain_and_subdomain() {
    // `||` anchors at a domain boundary, so subdomains are blocked too.
    let builder = DefaultDomainFilterBuilder::new();
    builder.add_adblock_rule("||example.com^");
    let filter = builder.to_domain_filter();
    let base: Domain = "example.com".parse().unwrap();
    let child: Domain = "example_subdomain.example.com".parse().unwrap();
    assert_eq!(filter.domain_is_allowed(&base), Some(false));
    assert_eq!(filter.domain_is_allowed(&child), Some(false))
}
#[test]
fn adblock_does_not_block_subdomain_for_exact() {
    // A single `|` anchors at the start of the address: only the exact domain matches.
    let builder = DefaultDomainFilterBuilder::new();
    builder.add_adblock_rule("|example.com^");
    let filter = builder.to_domain_filter();
    let base: Domain = "example.com".parse().unwrap();
    let child: Domain = "example_subdomain.example.com".parse().unwrap();
    assert_eq!(filter.domain_is_allowed(&base), Some(false));
    assert_eq!(filter.domain_is_allowed(&child), None)
}
466
#[test]
fn adblock_does_not_block_filter_that_has_badfilter() {
    // The `$badfilter` twin cancels the original rule entirely.
    let builder = DefaultDomainFilterBuilder::new();
    builder.add_adblock_rule("||cedexis.net^$third-party");
    builder.add_adblock_rule("||cedexis.net^$third-party,badfilter");
    let filter = builder.to_domain_filter();
    let domain: Domain = "cedexis.net".parse().unwrap();
    assert_eq!(filter.domain_is_allowed(&domain), None);
}
478
#[test]
fn adblock_can_block_ip() {
    let builder = DefaultDomainFilterBuilder::new();
    builder.add_adblock_rule("||177.33.90.14^");
    let filter = builder.to_domain_filter();
    let ip: std::net::IpAddr = "177.33.90.14".parse().unwrap();
    assert_eq!(filter.ip_is_allowed(&ip), Some(false))
}
489
#[test]
fn adblock_can_block_domain_document() {
    // `$document` rules apply to the whole page, so they are kept.
    let builder = DefaultDomainFilterBuilder::new();
    builder.add_adblock_rule("||ditwrite.com^$document");
    let filter = builder.to_domain_filter();
    let domain: Domain = "ditwrite.com".parse().unwrap();
    assert_eq!(filter.domain_is_allowed(&domain), Some(false))
}
500
/// Unanchored adblock patterns (here `-ad.example.com`) should block any domain
/// containing them.
///
/// This cannot pass yet. `to_domain_filter` only applies rules that are anchored at the
/// start and end at a domain boundary, and `DomainFilter::is_allowed_by_adblock` is a
/// stub that returns `None`. The test is kept, but ignored, as the specification for
/// that missing feature.
#[test]
#[ignore = "unanchored adblock patterns are not applied yet (is_allowed_by_adblock is a stub)"]
fn adblock_can_block_with_partial_domains() {
    let filter = DefaultDomainFilterBuilder::new();
    filter.add_adblock_rule("-ad.example.com");
    let filter = filter.to_domain_filter();
    assert_eq!(
        filter.domain_is_allowed(&"2-ad.example.com".parse().unwrap()),
        Some(false)
    )
}
511
#[test]
fn adblock_can_whitelist_domain() {
    // An `@@||` exception overrides a catch-all disallow for the domain and its children.
    let builder = DefaultDomainFilterBuilder::new();
    builder.add_disallow_regex(".");
    builder.add_adblock_rule("@@||example.com^");
    let filter = builder.to_domain_filter();
    let base: Domain = "example.com".parse().unwrap();
    let child: Domain = "example_subdomain.example.com".parse().unwrap();
    assert_eq!(filter.domain_is_allowed(&base), Some(true));
    assert_eq!(filter.domain_is_allowed(&child), Some(true))
}
527
#[test]
fn adblock_does_not_whitelist_domain_for_exact() {
    // `@@|` only allows the exact domain, leaving subdomains undecided.
    let builder = DefaultDomainFilterBuilder::new();
    builder.add_adblock_rule("@@|example.com^");
    let filter = builder.to_domain_filter();
    let base: Domain = "example.com".parse().unwrap();
    let child: Domain = "example_subdomain.example.com".parse().unwrap();
    assert_eq!(filter.domain_is_allowed(&base), Some(true));
    assert_eq!(filter.domain_is_allowed(&child), None)
}
542
#[test]
fn adblock_third_party_does_block_domain() {
    // The `$third-party` option is dropped, so the rule still applies.
    let builder = DefaultDomainFilterBuilder::new();
    builder.add_adblock_rule("||example.com^$third-party");
    let filter = builder.to_domain_filter();
    let base: Domain = "example.com".parse().unwrap();
    let child: Domain = "example_subdomain.example.com".parse().unwrap();
    assert_eq!(filter.domain_is_allowed(&base), Some(false));
    assert_eq!(filter.domain_is_allowed(&child), Some(false))
}
557
#[test]
fn adblock_https_does_block_domain() {
    // A scheme prefix anchors at the start of the address like `|`.
    let builder = DefaultDomainFilterBuilder::new();
    builder.add_adblock_rule("https://r.i.ua^");
    let filter = builder.to_domain_filter();
    let domain: Domain = "r.i.ua".parse().unwrap();
    assert_eq!(filter.domain_is_allowed(&domain), Some(false));
}
568
#[test]
fn subdomain_disallow_blocks() {
    let builder = DefaultDomainFilterBuilder::new();
    builder.add_disallow_subdomain("example.com".parse().unwrap());
    let filter = builder.to_domain_filter();
    let child: Domain = "example_subdomain.example.com".parse().unwrap();
    assert_eq!(filter.domain_is_allowed(&child), Some(false))
}
579
#[test]
fn subdomain_allow_whitelists_domains() {
    // A subdomain allow beats a catch-all disallow regex.
    let builder = DefaultDomainFilterBuilder::new();
    builder.add_disallow_regex(".");
    builder.add_allow_subdomain("example.com".parse().unwrap());
    let filter = builder.to_domain_filter();
    let child: Domain = "example_subdomain.example.com".parse().unwrap();
    assert_eq!(filter.domain_is_allowed(&child), Some(true))
}
591
#[test]
fn subdomain_disallow_does_not_block_domain() {
    // A subdomain entry matches only strict subdomains, not the entry itself.
    let builder = DefaultDomainFilterBuilder::new();
    builder.add_disallow_subdomain("example.com".parse().unwrap());
    let filter = builder.to_domain_filter();
    let base: Domain = "example.com".parse().unwrap();
    assert_eq!(filter.domain_is_allowed(&base), None)
}
602
#[test]
fn blocked_cname_blocks_base() {
    // An undecided domain is blocked when one of its CNAMEs is blocked.
    let builder = DefaultDomainFilterBuilder::new();
    builder.add_disallow_domain("tracker.com".parse().unwrap());
    let filter = builder.to_domain_filter();
    let domain: Domain = "example.com".parse().unwrap();
    let cnames: Vec<Domain> = vec!["tracker.com".parse().unwrap()];
    assert_eq!(filter.allowed(&domain, &cnames, &[]), Some(false))
}
617
#[test]
fn blocked_ip_blocks_base() {
    // An undecided domain is blocked when one of its IPs is blocked.
    let builder = DefaultDomainFilterBuilder::new();
    builder.add_disallow_ip_addr("8.8.8.8".parse().unwrap());
    let filter = builder.to_domain_filter();
    let domain: Domain = "example.com".parse().unwrap();
    let resolved: Vec<std::net::IpAddr> = vec!["8.8.8.8".parse().unwrap()];
    assert_eq!(filter.allowed(&domain, &[], &resolved), Some(false))
}
632
#[test]
fn blocked_ip_net_blocks_base() {
    // An IP inside a disallowed subnet blocks the undecided domain.
    let builder = DefaultDomainFilterBuilder::new();
    builder.add_disallow_ip_subnet("8.8.8.0/24".parse().unwrap());
    let filter = builder.to_domain_filter();
    let domain: Domain = "example.com".parse().unwrap();
    let resolved: Vec<std::net::IpAddr> = vec!["8.8.8.8".parse().unwrap()];
    assert_eq!(filter.allowed(&domain, &[], &resolved), Some(false))
}
647
#[test]
fn ignores_allowed_ips() {
    // A verdict on the domain itself wins over an allowed IP.
    let builder = DefaultDomainFilterBuilder::new();
    builder.add_disallow_domain("example.com".parse().unwrap());
    builder.add_allow_ip_addr("8.8.8.8".parse().unwrap());
    let filter = builder.to_domain_filter();
    let domain: Domain = "example.com".parse().unwrap();
    let resolved: Vec<std::net::IpAddr> = vec!["8.8.8.8".parse().unwrap()];
    assert_eq!(filter.allowed(&domain, &[], &resolved), Some(false))
}
663
#[test]
fn unblocked_ips_do_not_allow() {
    // Allowed IPs never turn an undecided domain into an allowed one.
    let builder = DefaultDomainFilterBuilder::new();
    builder.add_allow_ip_addr("8.8.8.8".parse().unwrap());
    let filter = builder.to_domain_filter();
    let domain: Domain = "example.com".parse().unwrap();
    let resolved: Vec<std::net::IpAddr> = vec!["8.8.8.8".parse().unwrap()];
    assert_eq!(filter.allowed(&domain, &[], &resolved), None)
}