encrypted-dns 0.9.20

A modern encrypted DNS server (DNSCrypt v2, Anonymized DNSCrypt, DoH)
use std::fs::File;
use std::io::{BufRead, BufReader};
use std::path::Path;
use std::sync::Arc;

use rustc_hash::FxHashMap;

use crate::errors::*;

/// Rule storage shared by all clones of a [`BlackList`] through an `Arc`,
/// so cloning a `BlackList` does not copy the rule set.
#[derive(Debug)]
struct BlackListInner {
    /// Blacklisted domain names as raw bytes. Only the keys matter; the values are `()`.
    map: FxHashMap<Vec<u8>, ()>,
}

/// A set of blacklisted domain names, queried with `BlackList::find`.
///
/// Cloning is cheap: clones share the same rule set through an `Arc`.
#[derive(Clone, Debug)]
pub struct BlackList {
    inner: Arc<BlackListInner>,
    /// Largest number of dot-separated labels in any rule. `find` only looks at
    /// this many labels from the right of a query. It is 0 when there are no
    /// non-empty rules, in which case `find` always returns `false`.
    max_labels: usize,
}

/// Counts the dot-separated labels in `qname` (`b"a.b.c"` has 3); an empty name has 0.
///
/// Every `.` starts a new label, so empty labels are counted too:
/// `b"a..b"` has 3 and `b"."` has 2.
fn label_count(qname: &[u8]) -> usize {
    match qname {
        [] => 0,
        name => 1 + name.iter().filter(|&&byte| byte == b'.').count(),
    }
}

impl BlackList {
    /// Builds a blacklist from a set of domain names (the map's values are unused).
    ///
    /// Keys are matched case-insensitively. `find` lowercases every query before
    /// looking it up, so any key containing ASCII uppercase bytes is lowercased
    /// here; otherwise such a key could never match.
    pub fn new(map: FxHashMap<Vec<u8>, ()>) -> Self {
        let map: FxHashMap<Vec<u8>, ()> =
            if map.keys().any(|k| k.iter().any(u8::is_ascii_uppercase)) {
                map.into_iter()
                    .map(|(mut k, v)| {
                        k.make_ascii_lowercase();
                        (k, v)
                    })
                    .collect()
            } else {
                // Common case (e.g. `load`): keys are already lowercase, so skip the rebuild.
                map
            };
        // The longest rule bounds how many trailing labels of a query can ever matter.
        let max_labels = map.keys().map(|k| label_count(k)).max().unwrap_or(0);
        let inner = Arc::new(BlackListInner { map });
        BlackList { inner, max_labels }
    }

    /// Loads a blacklist from a text file containing one domain name per line.
    ///
    /// Each line is handled as follows:
    /// - Surrounding whitespace is trimmed.
    /// - Blank lines and lines starting with `#` are skipped.
    /// - Any number of leading `*.` prefixes and trailing dots are stripped.
    /// - The name is lowercased.
    ///
    /// Every rule also matches its subdomains (see `BlackList::find`).
    ///
    /// # Errors
    ///
    /// Returns an error if the file cannot be opened or read, or if a rule is
    /// empty after normalization (e.g. a line containing only `*.` or `.`).
    /// The reported line number is 1-based.
    pub fn load(path: impl AsRef<Path>) -> Result<Self, Error> {
        let mut map = FxHashMap::default();
        let fp = BufReader::new(File::open(path)?);
        for (idx, line) in fp.lines().enumerate() {
            let line = line?;
            let mut rule = line.trim();
            if rule.is_empty() || rule.starts_with('#') {
                continue;
            }
            while let Some(rest) = rule.strip_prefix("*.") {
                rule = rest;
            }
            let qname = rule.trim_end_matches('.').as_bytes().to_ascii_lowercase();
            if qname.is_empty() {
                // `enumerate` is zero-based; report the 1-based line number an editor shows.
                bail!("Unexpected blacklist rule at line {}", idx + 1)
            }
            map.insert(qname, ());
        }
        Ok(BlackList::new(map))
    }

    /// Returns `true` if `qname` or any of its parent domains is blacklisted.
    ///
    /// Matching is ASCII case-insensitive and works on whole labels. For example,
    /// a rule for `example.com` matches `example.com` and `www.example.com`, but
    /// neither `com` nor `badexample.com`.
    ///
    /// The query is only lowercased, not otherwise normalized. In particular, a
    /// trailing root dot is not stripped, so pass names in the same form as the
    /// rules (`load` stores rules without trailing dots).
    pub fn find(&self, qname: &[u8]) -> bool {
        if self.max_labels == 0 {
            return false;
        }
        // No rule has more than `max_labels` labels, so only the rightmost
        // `max_labels` labels of the query can take part in a match. Trim first
        // so that only that suffix gets lowercased (copied).
        let start = qname
            .iter()
            .enumerate()
            .rev()
            .filter(|&(_, &b)| b == b'.')
            .nth(self.max_labels - 1)
            .map_or(0, |(i, _)| i + 1);
        let qname = qname[start..].to_ascii_lowercase();
        let map = &self.inner.map;
        // Try the name itself, then each parent domain, dropping one leading label at a time.
        let mut suffix = qname.as_slice();
        while !suffix.is_empty() {
            if map.contains_key(suffix) {
                return true;
            }
            match suffix.iter().position(|&b| b == b'.') {
                Some(pos) => suffix = &suffix[pos + 1..],
                None => break,
            }
        }
        false
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `BlackList` straight from in-memory rules, without file I/O.
    fn make(entries: &[&str]) -> BlackList {
        let rules: FxHashMap<Vec<u8>, ()> = entries
            .iter()
            .map(|rule| (rule.as_bytes().to_vec(), ()))
            .collect();
        BlackList::new(rules)
    }

    /// Convenience wrapper so queries can be written as `&str`.
    fn hit(list: &BlackList, name: &str) -> bool {
        list.find(name.as_bytes())
    }

    #[test]
    fn exact_match() {
        let list = make(&["evil.example"]);
        assert!(hit(&list, "evil.example"));
    }

    #[test]
    fn suffix_match_single_label() {
        let list = make(&["evil.example"]);
        assert!(hit(&list, "sub.evil.example"));
    }

    #[test]
    fn deep_subdomain_matches_short_suffix() {
        let list = make(&["blocked.example"]);
        assert!(hit(&list, "a.b.c.d.e.f.g.h.blocked.example"));
    }

    #[test]
    fn unrelated_name_does_not_match() {
        let list = make(&["blocked.example"]);
        for name in ["good.example", "a.b.c.good.example"] {
            assert!(!hit(&list, name), "{} should not be blocked", name);
        }
    }

    #[test]
    fn parent_of_blocked_is_not_blocked() {
        let list = make(&["sub.example.com"]);
        for name in ["example.com", "com"] {
            assert!(!hit(&list, name), "{} should not be blocked", name);
        }
    }

    #[test]
    fn empty_blacklist() {
        let list = make(&[]);
        assert!(!hit(&list, "anything.example"));
    }

    #[test]
    fn case_insensitive() {
        let list = make(&["evil.example"]);
        assert!(hit(&list, "SUB.Evil.EXAMPLE"));
    }
}