//! kapiti 0.0.3
//!
//! The Kapiti DNS Server — documentation.
use std::path::{Path, PathBuf};
use std::vec::Vec;

use anyhow::{Context, Result};
use hyper::{Client, Uri};
use sha2::{Digest, Sha256};
use tracing::trace;

use crate::filter::{path, reader, updater};
use crate::{fetcher, hyper_smol};

/// An iterator over a domain and each of its parent domains, most specific first.
/// For example, `www.domain.com` yields `www.domain.com`, `domain.com`, then `com`.
///
/// Items are produced by repeatedly dropping the leftmost label and its `'.'`.
/// An empty input yields nothing, and a trailing dot does not produce an extra empty item
/// (`com.` yields only `com.`).
struct DomainParentIter<'a> {
    /// The suffix that will be yielded next, or `None` once the iterator is exhausted.
    remaining: Option<&'a str>,
}

impl<'a> DomainParentIter<'a> {
    /// Creates an iterator over `domain` and its parents.
    ///
    /// Takes `&str` so that string slices can be passed directly; `&String` arguments
    /// coerce automatically.
    fn new(domain: &'a str) -> DomainParentIter<'a> {
        DomainParentIter {
            remaining: if domain.is_empty() { None } else { Some(domain) },
        }
    }
}

impl<'a> Iterator for DomainParentIter<'a> {
    type Item = &'a str;

    fn next(&mut self) -> Option<&'a str> {
        let current = self.remaining?;
        // The parent is everything after the first '.'; once nothing follows the dot
        // (or there is no dot at all) the iteration is finished.
        self.remaining = current
            .split_once('.')
            .map(|(_, parent)| parent)
            .filter(|parent| !parent.is_empty());
        Some(current)
    }
}

// Once `remaining` becomes `None` it is never refilled, so the iterator is fused.
impl std::iter::FusedIterator for DomainParentIter<'_> {}

/// The override and block filter files consulted by [`Filter::check`].
///
/// Each list is ordered: when two files in the same list mention the exact same domain,
/// the earlier file wins.
pub struct Filter {
    /// Files consulted before `block_files` at each domain level.
    override_files: Vec<reader::FilterFile>,
    /// Files consulted after `override_files` at each domain level.
    block_files: Vec<reader::FilterFile>,
}

impl Default for Filter {
    fn default() -> Self {
        Filter::new()
    }
}

impl Filter {
    /// Creates a filter with no override files and no block files.
    pub fn new() -> Filter {
        Filter {
            override_files: vec![],
            block_files: vec![],
        }
    }

    /// Merges each of `files` into the override list via `upsert_entries`, which matches
    /// existing files by `FileInfo::local_path`.
    pub fn update_override_entries(&mut self, files: Vec<reader::FilterFile>) {
        for file in files {
            upsert_entries(&mut self.override_files, file);
        }
    }

    /// Merges each of `files` into the block list via `upsert_entries`, which matches
    /// existing files by `FileInfo::local_path`.
    pub fn update_block_entries(&mut self, files: Vec<reader::FilterFile>) {
        for file in files {
            upsert_entries(&mut self.block_files, file);
        }
    }

    /// Builds a filter file from `block_names` with `reader::block_hardcoded` and merges it
    /// into the block list.
    ///
    /// # Errors
    /// Propagates any error returned by `reader::block_hardcoded`.
    pub fn set_hardcoded_block(&mut self, block_names: &[&str]) -> Result<()> {
        let hardcoded_entries = reader::block_hardcoded(block_names)?;
        upsert_entries(&mut self.block_files, hardcoded_entries);
        Ok(())
    }

    /// Search all filters for the domain in ancestor order.
    ///
    /// For example, check all filters for 'www.example.com', then all again for 'example.com'.
    /// This allows file B with 'www.example.com' to take precedence over file A with 'example.com'.
    /// Meanwhile if two files mention the exact same name then the first file in the list wins,
    /// and at a given domain level override files are always consulted before block files.
    /// So if file A says "127.0.0.1" and file B says "172.16.0.1" then "127.0.0.1" wins.
    ///
    /// Returns the matching file's info together with the matched entry, or `None` when no
    /// file mentions `host` or any of its parents.
    pub fn check(
        &self,
        host: &String,
    ) -> Option<(&Option<reader::FileInfo>, &reader::FilterEntry)> {
        // NOTE: wildcards like 'foo*.example.com', 'foo.*.example.com' are not supported,
        // and it'd probably be too much trouble to deal with trying to support them.
        // Meanwhile due to how we parse the domains, wildcards like '*.example.com' are already supported.
        for domain_str in DomainParentIter::new(host) {
            let domain = domain_str.to_string();
            // Overrides first (let it through / use the provided value), then blocks.
            let hit = self
                .override_files
                .iter()
                .chain(&self.block_files)
                .find_map(|file| file.content.get(&domain).map(|entry| (&file.info, entry)));
            if hit.is_some() {
                return hit;
            }
        }
        None
    }
}

/// Resolves a configured filter source to a local file, downloading it first if it is a URL.
///
/// `filter_path_or_url` is treated as a local path when it does not parse as a URI, or when it
/// parses without a scheme (filesystem paths can look like scheme-less URIs). Otherwise it is
/// fetched with `updater::update_file` into `filters_dir`. The destination file name is the
/// SHA-256 of the URL string plus a `.sha256.<ZSTD_EXTENSION>` suffix.
///
/// Returns the resulting [`reader::FileInfo`] and whether the file was updated (`true`) or the
/// update was skipped (`false`). Local paths always report `false`.
///
/// # Errors
/// Propagates failures from `updater::update_file`, and fails if the download path is not
/// valid UTF-8.
pub async fn update_if_url(
    fetch_client: &Client<hyper_smol::SmolConnector>,
    filters_dir: &PathBuf,
    filter_path_or_url: &String,
    timeout_ms: u64,
) -> Result<(reader::FileInfo, bool)> {
    // Check if this is a file or URL
    let filter_uri = match Uri::try_from(filter_path_or_url) {
        Ok(uri) => uri,
        Err(_) => {
            trace!(
                "Assuming that non-url filter path is local: {}",
                filter_path_or_url
            );
            return Ok((local_filter_info(filter_path_or_url), false));
        }
    };
    // Filesystem paths can get parsed as URLs with no scheme
    if filter_uri.scheme().is_none() {
        trace!(
            "Assuming that no-schema filter path is local: {}",
            filter_path_or_url
        );
        return Ok((local_filter_info(filter_path_or_url), false));
    }

    // Upper bound passed to the fetcher (10 MiB; presumably a maximum download size in bytes).
    let max_download_size = 10 * 1024 * 1024;
    let filter_fetcher = fetcher::Fetcher::new(max_download_size, None);
    // We download files to the exact SHA of the URL string we were provided.
    // This is an easy way to avoid filename collisions in URLs: example1.com/hosts vs example2.com/hosts
    // If the user changes the URL string then that changes the SHA, perfect for "cache invalidation" purposes.
    let hosts_path_sha = Sha256::digest(filter_path_or_url.as_bytes());
    let download_path = Path::new(filters_dir).join(format!(
        "{:x}.sha256.{}",
        hosts_path_sha,
        path::ZSTD_EXTENSION
    ));
    let downloaded = updater::update_file(
        fetch_client,
        &filter_fetcher,
        filter_path_or_url,
        download_path.as_path(),
        timeout_ms,
    )
    .await?;
    let local_path = download_path
        .to_str()
        .with_context(|| format!("busted download path: {:?}", download_path))?
        .to_string();
    Ok((
        reader::FileInfo {
            source_path: filter_path_or_url.clone(),
            local_path,
        },
        downloaded,
    ))
}

/// Builds the [`reader::FileInfo`] for a filter that is read directly from the local
/// filesystem: both the source and local path are the configured string.
fn local_filter_info(path: &str) -> reader::FileInfo {
    reader::FileInfo {
        source_path: path.to_owned(),
        local_path: path.to_owned(),
    }
}

/// Adds, replaces, or removes `new_file` in `entries`, keyed by `FileInfo::local_path`.
///
/// - If `new_file` has `info` and an entry with the same `local_path` exists, that entry is
///   overwritten in place, keeping its position and therefore its precedence. If
///   `new_file.content` is empty, the entry is removed instead.
/// - Otherwise `new_file` is appended, unless its content is empty.
///
/// Files without `info` never match an existing entry, so they are always appended
/// (when non-empty).
fn upsert_entries(entries: &mut Vec<reader::FilterFile>, new_file: reader::FilterFile) {
    // Because matches are overwritten in place, at most one entry per local path exists
    // as long as all additions go through this function.
    let existing_idx = new_file.info.as_ref().and_then(|new_info| {
        entries.iter().position(|entry| {
            entry
                .info
                .as_ref()
                .map_or(false, |info| info.local_path == new_info.local_path)
        })
    });

    match existing_idx {
        // `remove` (not `swap_remove`) keeps the remaining files in precedence order.
        Some(idx) if new_file.content.is_empty() => {
            entries.remove(idx);
        }
        // Overwrite rather than `insert`: inserting would leave the stale version behind the
        // new one, where it would resurface once the new version is later removed.
        Some(idx) => entries[idx] = new_file,
        None if !new_file.content.is_empty() => entries.push(new_file),
        None => {}
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `DomainParentIter` over an owned copy of `domain` and gathers every item it yields.
    fn parents_of(domain: &str) -> Vec<String> {
        let owned = domain.to_string();
        DomainParentIter::new(&owned).map(String::from).collect()
    }

    #[test]
    fn iter_empty() {
        assert!(parents_of("").is_empty());
    }

    #[test]
    fn iter_com() {
        assert_eq!(parents_of("com"), ["com"]);
    }

    #[test]
    fn iter_domaincom() {
        assert_eq!(parents_of("domain.com"), ["domain.com", "com"]);
    }

    #[test]
    fn iter_wwwdomaincom() {
        assert_eq!(
            parents_of("www.domain.com"),
            ["www.domain.com", "domain.com", "com"]
        );
    }

    #[test]
    fn iter_wwwngeeknz() {
        assert_eq!(
            parents_of("www.n.geek.nz"),
            ["www.n.geek.nz", "n.geek.nz", "geek.nz", "nz"]
        );
    }

    #[test]
    fn iter_averylongteststringwithmanysegments() {
        let expected = [
            "a.very-long.test.string.with-many.segments",
            "very-long.test.string.with-many.segments",
            "test.string.with-many.segments",
            "string.with-many.segments",
            "with-many.segments",
            "segments",
        ];
        assert_eq!(
            parents_of("a.very-long.test.string.with-many.segments"),
            expected
        );
    }
}