1use crate::{Domain, DomainSetSharded};
2
3use std::collections::HashSet;
4use std::str::FromStr;
5
6use parking_lot::Mutex;
7
/// Pairs an "allow" collection with a "disallow" collection of the same type.
///
/// `Default` is derived, which requires `T: Default` and builds both sides
/// with `T::default()`.
#[derive(Default)]
struct Filter<T> {
    /// Entries that yield an "allowed" verdict.
    allow: T,
    /// Entries that yield a "blocked" verdict.
    disallow: T,
}

impl<T> Filter<T> {
    /// Wraps the two given collections as-is.
    fn new(allow: T, disallow: T) -> Self {
        Filter {
            disallow,
            allow,
        }
    }
}
27
/// Error returned when a line cannot be turned into an [`AdblockFilter`].
///
/// The parser does not distinguish failure reasons. Comments, rules containing
/// `#`, rules with unsupported options and patterns with unsupported characters
/// all produce this same error.
#[derive(Debug, Default)]
struct AdblockParseError {}

// The default `source()` (returning `None`) is correct: this error never wraps
// another error.
impl std::error::Error for AdblockParseError {}

impl std::fmt::Display for AdblockParseError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Renders the Debug form, i.e. the type name.
        write!(f, "{:?}", self)
    }
}

/// An adblock network rule reduced to the parts a domain filter can use.
#[derive(Clone, Eq, PartialEq, Hash, Debug)]
struct AdblockFilter {
    /// The rule started with `@@`.
    is_exception: bool,
    /// The rule is anchored at a domain boundary (`||`, or a leading `.` / `*.`).
    match_start_domain: bool,
    /// The rule is anchored at the start of the address (`|`, `*://`,
    /// `http://` or `https://`) rather than at a domain boundary.
    match_start_address: bool,
    /// The rule ends at the end of the host (trailing `^`, `|`, `*`, `/*`,
    /// or a `*.php`-style path wildcard); cleared again by a trailing `.`.
    match_end_domain: bool,
    /// Pattern text left after option, anchor and wildcard stripping.
    filter: String,
    /// The rule carried the `badfilter` option.
    is_badfilter: bool,
}

impl FromStr for AdblockFilter {
    type Err = AdblockParseError;

    /// Parses a single line of an adblock list.
    ///
    /// Surrounding whitespace is ignored, so lines read from CRLF files (with a
    /// trailing `\r`) parse the same as LF lines.
    ///
    /// # Errors
    ///
    /// Returns [`AdblockParseError`] when any of the following holds:
    /// - the line is a `!` comment or contains `#` (e.g. element-hiding rules);
    /// - after dropping `third-party`/`3p`/`badfilter`, options remain and none
    ///   of them is `document` or `all`;
    /// - any option starts with `domain=`;
    /// - the remaining pattern is empty, is `*`, or contains a character
    ///   outside `[A-Za-z0-9_.*-]`.
    fn from_str(original_rule: &str) -> Result<Self, Self::Err> {
        let mut rule = original_rule.trim();
        if rule.starts_with('!') || rule.contains('#') {
            return Err(AdblockParseError::default());
        }

        let mut is_badfilter = false;
        if let Some(position) = rule.find('$') {
            let (start, options) = rule.split_at(position);
            let tags = options
                .trim_start_matches('$')
                .split(',')
                // Party restrictions are dropped: the rule is treated as unrestricted.
                .filter(|tag| !matches!(*tag, "3p" | "third-party"))
                // `badfilter` is recorded as a flag rather than kept as a tag.
                .filter(|tag| {
                    if *tag == "badfilter" {
                        is_badfilter = true;
                        false
                    } else {
                        true
                    }
                })
                .collect::<Vec<&str>>();

            let targets_whole_document =
                tags.is_empty() || tags.contains(&"document") || tags.contains(&"all");
            if !targets_whole_document || tags.iter().any(|tag| tag.starts_with("domain=")) {
                return Err(AdblockParseError::default());
            }
            rule = start;
        }

        let (rule, is_exception) = rule
            .strip_prefix("@@")
            .map(|rule| (rule, true))
            .unwrap_or((rule, false));

        // Start anchors: `||` anchors at a domain boundary; `|` and a
        // `*://` / `http://` / `https://` prefix anchor at the start of the address.
        let (rule, match_start_domain, match_start_address) =
            if let Some(rule) = rule.strip_prefix("||") {
                (rule, true, false)
            } else {
                let (rule, match_start_address) = rule
                    .strip_prefix('|')
                    .map(|rule| (rule, true))
                    .unwrap_or((rule, false));
                let (rule, match_start_address) = rule
                    .strip_prefix('*')
                    .or_else(|| rule.strip_prefix("https"))
                    .or_else(|| rule.strip_prefix("http"))
                    .unwrap_or(rule)
                    .strip_prefix("://")
                    .map(|rule| (rule, true))
                    .unwrap_or((rule, match_start_address));
                (rule, false, match_start_address)
            };
        // A leading `.` or `*.` also anchors at a domain boundary.
        let (rule, match_start_domain, match_start_address) = rule
            .strip_prefix('*')
            .unwrap_or(rule)
            .strip_prefix('.')
            .map(|rule| (rule, true, false))
            .unwrap_or((rule, match_start_domain, match_start_address));

        // End anchors.
        let (rule, match_end_domain) = rule
            .strip_suffix('|')
            .map(|rule| (rule, true))
            .unwrap_or((rule, false));
        // A trailing `*` or `/*` (optionally followed by a page extension such
        // as `.php`) covers everything on the host, so the host ends there.
        let (rule, match_end_domain) = rule
            .strip_suffix(".php")
            .or_else(|| rule.strip_suffix(".htm"))
            .or_else(|| rule.strip_suffix(".html"))
            .or_else(|| rule.strip_suffix(".xhtml"))
            .unwrap_or(rule)
            .strip_suffix('*')
            .and_then(|rule| rule.strip_suffix('/').or(Some(rule)))
            .map(|rule| (rule, true))
            .unwrap_or((rule, match_end_domain));
        // A trailing `^` separator also ends the host.
        let (rule, match_end_domain) = rule
            .strip_suffix('^')
            .map(|rule| (rule, true))
            .unwrap_or((rule, match_end_domain));
        // A trailing `.` (e.g. `ads.`) makes the pattern a prefix, so it is
        // explicitly *not* anchored at the end.
        let (rule, match_end_domain) = rule
            .strip_suffix('.')
            .map(|rule| (rule, false))
            .unwrap_or((rule, match_end_domain));

        if rule.is_empty()
            || rule == "*"
            || !rule
                .chars()
                .all(|c| c.is_ascii_alphanumeric() || matches!(c, '_' | '-' | '.' | '*'))
        {
            return Err(AdblockParseError::default());
        }
        Ok(Self {
            is_exception,
            match_start_domain,
            match_start_address,
            match_end_domain,
            filter: rule.to_string(),
            is_badfilter,
        })
    }
}
160
161#[derive(Default)]
162pub struct DomainFilterBuilder<H: std::hash::BuildHasher + Default> {
163 domains: Filter<DomainSetSharded<H>>,
164 subdomains: Filter<DomainSetSharded<H>>,
165 ips: Mutex<Filter<HashSet<std::net::IpAddr, H>>>,
166 ip_nets: Mutex<Filter<HashSet<ipnet::IpNet, H>>>,
167 regexes: Mutex<Filter<HashSet<String, H>>>,
168 adblock: Mutex<HashSet<AdblockFilter, H>>,
169}
170
171type DefaultHasher = std::collections::hash_map::RandomState;
172
173pub type DefaultDomainFilterBuilder = DomainFilterBuilder<DefaultHasher>;
174
175impl<H: std::hash::BuildHasher + Default> DomainFilterBuilder<H> {
176 pub fn new() -> Self {
177 Default::default()
178 }
179 pub fn add_allow_domain(&self, domain: Domain) {
180 if let Some(without_www) = domain
181 .strip_prefix("www.")
182 .and_then(|domain| domain.parse::<Domain>().ok())
183 {
184 self.domains.disallow.remove_str(&without_www);
185 self.domains.allow.insert_str(&without_www);
186 } else if let Ok(with_www) = format!("www.{}", &domain).parse::<Domain>() {
187 self.domains.disallow.remove_str(&with_www);
188 self.domains.allow.insert_str(&with_www);
189 }
190 self.domains.disallow.remove_str(&domain);
191 self.domains.allow.insert_str(&domain);
192 }
193 pub fn add_disallow_domain(&self, domain: Domain) {
194 if !self.domains.allow.contains_str(&domain)
195 && !is_subdomain_of_list(&domain, &self.subdomains.allow)
196 {
197 self.domains.disallow.insert_str(&domain);
198 }
199 }
200 pub fn add_allow_subdomain(&self, domain: Domain) {
201 self.subdomains.disallow.remove_str(&domain);
202 self.subdomains.allow.insert_str(&domain);
203 }
204 pub fn add_disallow_subdomain(&self, domain: Domain) {
205 if !self.subdomains.allow.contains_str(&domain) {
206 self.subdomains.disallow.insert_str(&domain);
207 }
208 }
209
210 pub fn add_allow_ip_addr(&self, ip: std::net::IpAddr) {
211 let mut ips = self.ips.lock();
212 let _ = ips.disallow.remove(&ip);
213 ips.allow.insert(ip);
214 }
215 pub fn add_disallow_ip_addr(&self, ip: std::net::IpAddr) {
216 let mut ips = self.ips.lock();
217 if !ips.allow.contains(&ip) {
218 ips.disallow.insert(ip);
219 }
220 }
221
222 pub fn add_allow_ip_subnet(&self, net: ipnet::IpNet) {
223 let mut ip_nets = self.ip_nets.lock();
224 let _ = ip_nets.disallow.remove(&net);
225 ip_nets.allow.insert(net);
226 }
227
228 pub fn add_disallow_ip_subnet(&self, ip: ipnet::IpNet) {
229 let mut ip_nets = self.ip_nets.lock();
230 if !ip_nets.allow.contains(&ip) {
231 ip_nets.disallow.insert(ip);
232 }
233 }
234
235 pub fn add_adblock_rule(&self, rule: &str) {
236 if let Ok(filter) = rule.parse::<AdblockFilter>() {
237 self.adblock.lock().insert(filter);
238 }
239 }
240
241 pub fn add_allow_regex(&self, re: &str) {
242 if !re.is_empty() && regex::Regex::new(re).is_ok() {
243 let mut regexes = self.regexes.lock();
244 regexes.allow.insert(re.to_string());
245 }
246 }
247 pub fn add_disallow_regex(&self, re: &str) {
248 if !re.is_empty() && regex::Regex::new(re).is_ok() {
249 let mut regexes = self.regexes.lock();
250 regexes.disallow.insert(re.to_string());
251 }
252 }
253
254 pub fn to_domain_filter(self) -> DomainFilter<H> {
255 let adblock = std::mem::take(&mut *self.adblock.lock());
256
257 for filter in adblock.iter() {
258 let mut bad_filter = filter.clone();
259 bad_filter.is_badfilter = true;
260 if adblock.contains(&bad_filter) {
261 continue;
262 }
263 if (filter.match_start_domain || filter.match_start_address) && filter.match_end_domain
264 {
265 if let Ok(domain) = filter.filter.parse::<Domain>() {
266 if filter.is_exception {
267 self.add_allow_domain(domain.clone());
268 if filter.match_start_domain {
269 self.add_allow_subdomain(domain);
270 }
271 } else {
272 self.add_disallow_domain(domain.clone());
273 if filter.match_start_domain {
274 self.add_disallow_subdomain(domain);
275 }
276 }
277 }
278 if let Ok(ip) = filter.filter.parse::<std::net::IpAddr>() {
279 if !filter.is_exception {
280 self.add_disallow_ip_addr(ip);
282 }
283 }
284 }
285 }
286
287 let domains = self.domains;
288 let subdomains = self.subdomains;
289 let mut ips = std::mem::take(&mut *self.ips.lock());
290 let ip_nets = std::mem::take(&mut *self.ip_nets.lock());
291 let regexes = std::mem::take(&mut *self.regexes.lock());
292
293 domains.allow.shrink_to_fit();
294 domains.disallow.shrink_to_fit();
295 subdomains.allow.shrink_to_fit();
296 subdomains.disallow.shrink_to_fit();
297 ips.allow.shrink_to_fit();
298 ips.disallow.shrink_to_fit();
299 let ip_nets = Filter {
300 allow: ip_nets.allow.into_iter().collect(),
301 disallow: ip_nets.disallow.into_iter().collect(),
302 };
303 DomainFilter {
304 domains,
305 subdomains,
306 ips,
307 ip_nets: ip_nets,
308 allow_regex: regex::RegexSet::new(®exes.allow).unwrap(),
309 disallow_regex: regex::RegexSet::new(®exes.disallow).unwrap(),
310 }
311 }
312}
313
314fn is_subdomain_of_list<H: std::hash::BuildHasher>(
315 domain: &Domain,
316 filter_list: &DomainSetSharded<H>,
317) -> bool {
318 Domain::str_iter_parent_domains(domain).any(|part| filter_list.contains_str(part))
319}
320
321pub struct DomainFilter<H: std::hash::BuildHasher + Default> {
322 domains: Filter<DomainSetSharded<H>>,
323 subdomains: Filter<DomainSetSharded<H>>,
324 ips: Filter<HashSet<std::net::IpAddr, H>>,
325 ip_nets: Filter<Vec<ipnet::IpNet>>,
326 allow_regex: regex::RegexSet,
327 disallow_regex: regex::RegexSet,
328}
329
330impl<H: std::hash::BuildHasher + Default> Default for DomainFilter<H> {
331 fn default() -> Self {
332 Self {
333 domains: Filter::new(
334 DomainSetSharded::<H>::with_shards(0),
335 DomainSetSharded::<H>::with_shards(0),
336 ),
337 subdomains: Filter::new(
338 DomainSetSharded::<H>::with_shards(0),
339 DomainSetSharded::<H>::with_shards(0),
340 ),
341 ips: Default::default(),
342 ip_nets: Default::default(),
343 allow_regex: regex::RegexSet::empty(),
344 disallow_regex: regex::RegexSet::empty(),
345 }
346 }
347}
348
349impl<H: std::hash::BuildHasher + Default> DomainFilter<H> {
350 fn is_allowed_by_adblock(&self, _location: &str) -> Option<bool> {
351 None
352 }
353
354 pub fn allowed(
355 &self,
356 domain: &Domain,
357 cnames: &[Domain],
358 ips: &[std::net::IpAddr],
359 ) -> Option<bool> {
360 if let Some(result) = self.domain_is_allowed(domain) {
361 Some(result)
362 } else if cnames
363 .iter()
364 .any(|cname| self.domain_is_allowed(cname) == Some(false))
365 || ips.iter().any(|ip| self.ip_is_allowed(ip) == Some(false))
366 {
367 Some(false)
368 } else {
369 None
370 }
371 }
372
373 fn domain_is_allowed(&self, domain: &Domain) -> Option<bool> {
374 if self.domains.allow.contains_str(&domain)
375 || is_subdomain_of_list(&*domain, &self.subdomains.allow)
376 || self.allow_regex.is_match(domain)
377 {
378 Some(true)
379 } else if let Some(blocker_result) = self.is_allowed_by_adblock(&domain) {
380 Some(blocker_result)
381 } else if self.domains.disallow.contains_str(&domain)
382 || is_subdomain_of_list(&*domain, &self.subdomains.disallow)
383 || self.disallow_regex.is_match(domain)
384 {
385 Some(false)
386 } else {
387 None
388 }
389 }
390
391 pub fn ip_is_allowed(&self, ip: &std::net::IpAddr) -> Option<bool> {
392 if self.ips.allow.contains(ip) || self.ip_nets.allow.iter().any(|net| net.contains(ip)) {
393 Some(true)
394 } else if let Some(blocker_result) = self.is_allowed_by_adblock(&ip.to_string()) {
395 Some(blocker_result)
396 } else if self.ips.disallow.contains(ip)
397 || self.ip_nets.disallow.iter().any(|net| net.contains(ip))
398 {
399 Some(false)
400 } else {
401 None
402 }
403 }
404}
405
// With no rules at all, the filter has no opinion.
#[test]
fn default_unblocked() {
    let domain_filter = DefaultDomainFilterBuilder::new().to_domain_filter();
    let target: Domain = "example.org".parse().unwrap();
    assert_eq!(domain_filter.domain_is_allowed(&target), None)
}

// A catch-all disallow regex blocks an otherwise unknown domain.
#[test]
fn regex_disallow_all_blocks_domain() {
    let builder = DefaultDomainFilterBuilder::new();
    builder.add_disallow_regex(".");
    let domain_filter = builder.to_domain_filter();
    let target: Domain = "example.org".parse().unwrap();
    assert_eq!(domain_filter.domain_is_allowed(&target), Some(false))
}

// Allow regexes are consulted before disallow regexes.
#[test]
fn regex_allow_overrules_regex_disallow() {
    let builder = DefaultDomainFilterBuilder::new();
    builder.add_disallow_regex(".");
    builder.add_allow_regex(".");
    let domain_filter = builder.to_domain_filter();
    let target: Domain = "example.org".parse().unwrap();
    assert_eq!(domain_filter.domain_is_allowed(&target), Some(true))
}
437
// `||domain^` blocks the domain and everything under it.
#[test]
fn adblock_can_block_domain_and_subdomain() {
    let builder = DefaultDomainFilterBuilder::new();
    builder.add_adblock_rule("||example.com^");
    let domain_filter = builder.to_domain_filter();
    let apex: Domain = "example.com".parse().unwrap();
    let child: Domain = "example_subdomain.example.com".parse().unwrap();
    assert_eq!(domain_filter.domain_is_allowed(&apex), Some(false));
    assert_eq!(domain_filter.domain_is_allowed(&child), Some(false))
}

// `|domain^` is anchored at the address start, so subdomains are untouched.
#[test]
fn adblock_does_not_block_subdomain_for_exact() {
    let builder = DefaultDomainFilterBuilder::new();
    builder.add_adblock_rule("|example.com^");
    let domain_filter = builder.to_domain_filter();
    let apex: Domain = "example.com".parse().unwrap();
    let child: Domain = "example_subdomain.example.com".parse().unwrap();
    assert_eq!(domain_filter.domain_is_allowed(&apex), Some(false));
    assert_eq!(domain_filter.domain_is_allowed(&child), None)
}

// A matching `$badfilter` rule cancels the original rule.
#[test]
fn adblock_does_not_block_filter_that_has_badfilter() {
    let builder = DefaultDomainFilterBuilder::new();
    builder.add_adblock_rule("||cedexis.net^$third-party");
    builder.add_adblock_rule("||cedexis.net^$third-party,badfilter");
    let domain_filter = builder.to_domain_filter();
    let target: Domain = "cedexis.net".parse().unwrap();
    assert_eq!(domain_filter.domain_is_allowed(&target), None);
}

// Anchored rules whose pattern is an IP address block that IP.
#[test]
fn adblock_can_block_ip() {
    let builder = DefaultDomainFilterBuilder::new();
    builder.add_adblock_rule("||177.33.90.14^");
    let domain_filter = builder.to_domain_filter();
    let address: std::net::IpAddr = "177.33.90.14".parse().unwrap();
    assert_eq!(domain_filter.ip_is_allowed(&address), Some(false))
}
489
// The `$document` option is accepted.
#[test]
fn adblock_can_block_domain_document() {
    let builder = DefaultDomainFilterBuilder::new();
    builder.add_adblock_rule("||ditwrite.com^$document");
    let domain_filter = builder.to_domain_filter();
    let target: Domain = "ditwrite.com".parse().unwrap();
    assert_eq!(domain_filter.domain_is_allowed(&target), Some(false))
}

// NOTE(review): unanchored rules such as this one are never converted by
// `to_domain_filter`, and `DomainFilter::is_allowed_by_adblock` is a stub that
// always returns `None`. With the code as written, this test is expected to
// fail until partial-pattern matching is implemented.
#[test]
fn adblock_can_block_with_partial_domains() {
    let builder = DefaultDomainFilterBuilder::new();
    builder.add_adblock_rule("-ad.example.com");
    let domain_filter = builder.to_domain_filter();
    let target: Domain = "2-ad.example.com".parse().unwrap();
    assert_eq!(domain_filter.domain_is_allowed(&target), Some(false))
}

// `@@||domain^` allows the domain and its subdomains, even against a catch-all block.
#[test]
fn adblock_can_whitelist_domain() {
    let builder = DefaultDomainFilterBuilder::new();
    builder.add_disallow_regex(".");
    builder.add_adblock_rule("@@||example.com^");
    let domain_filter = builder.to_domain_filter();
    let apex: Domain = "example.com".parse().unwrap();
    let child: Domain = "example_subdomain.example.com".parse().unwrap();
    assert_eq!(domain_filter.domain_is_allowed(&apex), Some(true));
    assert_eq!(domain_filter.domain_is_allowed(&child), Some(true))
}

// `@@|domain^` allows only the exact domain.
#[test]
fn adblock_does_not_whitelist_domain_for_exact() {
    let builder = DefaultDomainFilterBuilder::new();
    builder.add_adblock_rule("@@|example.com^");
    let domain_filter = builder.to_domain_filter();
    let apex: Domain = "example.com".parse().unwrap();
    let child: Domain = "example_subdomain.example.com".parse().unwrap();
    assert_eq!(domain_filter.domain_is_allowed(&apex), Some(true));
    assert_eq!(domain_filter.domain_is_allowed(&child), None)
}

// The `$third-party` option is dropped, so the rule still applies.
#[test]
fn adblock_third_party_does_block_domain() {
    let builder = DefaultDomainFilterBuilder::new();
    builder.add_adblock_rule("||example.com^$third-party");
    let domain_filter = builder.to_domain_filter();
    let apex: Domain = "example.com".parse().unwrap();
    let child: Domain = "example_subdomain.example.com".parse().unwrap();
    assert_eq!(domain_filter.domain_is_allowed(&apex), Some(false));
    assert_eq!(domain_filter.domain_is_allowed(&child), Some(false))
}

// A scheme prefix counts as a start anchor.
#[test]
fn adblock_https_does_block_domain() {
    let builder = DefaultDomainFilterBuilder::new();
    builder.add_adblock_rule("https://r.i.ua^");
    let domain_filter = builder.to_domain_filter();
    let target: Domain = "r.i.ua".parse().unwrap();
    assert_eq!(domain_filter.domain_is_allowed(&target), Some(false));
}
568
// A subdomain disallow entry blocks names underneath it.
#[test]
fn subdomain_disallow_blocks() {
    let builder = DefaultDomainFilterBuilder::new();
    builder.add_disallow_subdomain("example.com".parse().unwrap());
    let domain_filter = builder.to_domain_filter();
    let child: Domain = "example_subdomain.example.com".parse().unwrap();
    assert_eq!(domain_filter.domain_is_allowed(&child), Some(false))
}

// A subdomain allow entry wins over a catch-all disallow regex.
#[test]
fn subdomain_allow_whitelists_domains() {
    let builder = DefaultDomainFilterBuilder::new();
    builder.add_disallow_regex(".");
    builder.add_allow_subdomain("example.com".parse().unwrap());
    let domain_filter = builder.to_domain_filter();
    let child: Domain = "example_subdomain.example.com".parse().unwrap();
    assert_eq!(domain_filter.domain_is_allowed(&child), Some(true))
}

// A subdomain disallow entry does not cover the listed domain itself.
#[test]
fn subdomain_disallow_does_not_block_domain() {
    let builder = DefaultDomainFilterBuilder::new();
    builder.add_disallow_subdomain("example.com".parse().unwrap());
    let domain_filter = builder.to_domain_filter();
    let apex: Domain = "example.com".parse().unwrap();
    assert_eq!(domain_filter.domain_is_allowed(&apex), None)
}
602
// A blocked CNAME target blocks the queried domain.
#[test]
fn blocked_cname_blocks_base() {
    let builder = DefaultDomainFilterBuilder::new();
    builder.add_disallow_domain("tracker.com".parse().unwrap());
    let domain_filter = builder.to_domain_filter();
    let base: Domain = "example.com".parse().unwrap();
    let cnames: [Domain; 1] = ["tracker.com".parse().unwrap()];
    assert_eq!(domain_filter.allowed(&base, &cnames, &[]), Some(false))
}

// A blocked resolved IP blocks the queried domain.
#[test]
fn blocked_ip_blocks_base() {
    let builder = DefaultDomainFilterBuilder::new();
    builder.add_disallow_ip_addr("8.8.8.8".parse().unwrap());
    let domain_filter = builder.to_domain_filter();
    let base: Domain = "example.com".parse().unwrap();
    let resolved: [std::net::IpAddr; 1] = ["8.8.8.8".parse().unwrap()];
    assert_eq!(domain_filter.allowed(&base, &[], &resolved), Some(false))
}

// A resolved IP inside a blocked subnet blocks the queried domain.
#[test]
fn blocked_ip_net_blocks_base() {
    let builder = DefaultDomainFilterBuilder::new();
    builder.add_disallow_ip_subnet("8.8.8.0/24".parse().unwrap());
    let domain_filter = builder.to_domain_filter();
    let base: Domain = "example.com".parse().unwrap();
    let resolved: [std::net::IpAddr; 1] = ["8.8.8.8".parse().unwrap()];
    assert_eq!(domain_filter.allowed(&base, &[], &resolved), Some(false))
}

// An allowed IP cannot rescue a domain that is itself blocked.
#[test]
fn ignores_allowed_ips() {
    let builder = DefaultDomainFilterBuilder::new();
    builder.add_disallow_domain("example.com".parse().unwrap());
    builder.add_allow_ip_addr("8.8.8.8".parse().unwrap());
    let domain_filter = builder.to_domain_filter();
    let base: Domain = "example.com".parse().unwrap();
    let resolved: [std::net::IpAddr; 1] = ["8.8.8.8".parse().unwrap()];
    assert_eq!(domain_filter.allowed(&base, &[], &resolved), Some(false))
}

// An allowed IP alone does not produce an "allowed" verdict.
#[test]
fn unblocked_ips_do_not_allow() {
    let builder = DefaultDomainFilterBuilder::new();
    builder.add_allow_ip_addr("8.8.8.8".parse().unwrap());
    let domain_filter = builder.to_domain_filter();
    let base: Domain = "example.com".parse().unwrap();
    let resolved: [std::net::IpAddr; 1] = ["8.8.8.8".parse().unwrap()];
    assert_eq!(domain_filter.allowed(&base, &[], &resolved), None)
}