use std::path::{Path, PathBuf};
use std::vec::Vec;
use anyhow::{Context, Result};
use hyper::{Client, Uri};
use sha2::{Digest, Sha256};
use tracing::trace;
use crate::filter::{path, reader, updater};
use crate::{fetcher, hyper_smol};
/// Iterator over a domain name followed by each of its parent domains, most specific first.
///
/// For `"www.domain.com"` it yields `"www.domain.com"`, `"domain.com"`, then `"com"`.
/// An empty domain yields nothing. The name is split on every `.`; no suffix list is
/// consulted, so the final item is always the last label (e.g. `"nz"`).
struct DomainParentIter<'a> {
    /// The full domain being walked.
    domain: &'a str,
    /// Byte offset of the next suffix to yield; equals `domain.len()` once exhausted.
    start_idx: usize,
}

impl<'a> DomainParentIter<'a> {
    /// Creates an iterator whose first item is the whole of `domain`.
    ///
    /// Takes `&str` so callers can pass `&String` (deref-coerced) or string literals.
    fn new(domain: &'a str) -> DomainParentIter<'a> {
        DomainParentIter {
            domain,
            start_idx: 0,
        }
    }
}

impl<'a> Iterator for DomainParentIter<'a> {
    type Item = &'a str;

    fn next(&mut self) -> Option<&'a str> {
        // An empty remainder means every suffix has already been yielded (or the
        // domain was empty to begin with).
        let remainder = self
            .domain
            .get(self.start_idx..)
            .filter(|rest| !rest.is_empty())?;
        // Skip past the next '.' (a single byte, so the new offset is a char boundary),
        // or jump to the end when this is the last label.
        self.start_idx = match remainder.find('.') {
            Some(dot) => self.start_idx + dot + 1,
            None => self.domain.len(),
        };
        Some(remainder)
    }
}
/// Domain filter assembled from override files and block files.
///
/// Lookups walk from the full host up through its parent domains; at every level the
/// override files are consulted before the block files.
#[derive(Default)]
pub struct Filter {
    /// Files consulted first at each domain level.
    override_files: Vec<reader::FilterFile>,
    /// Files consulted after `override_files` at each domain level.
    block_files: Vec<reader::FilterFile>,
}

impl Filter {
    /// Creates a filter with no override or block files.
    pub fn new() -> Filter {
        Filter::default()
    }

    /// Merges each of `files` into the override set via `upsert_entries`.
    pub fn update_override_entries(&mut self, files: Vec<reader::FilterFile>) {
        for file in files {
            upsert_entries(&mut self.override_files, file);
        }
    }

    /// Merges each of `files` into the block set via `upsert_entries`.
    pub fn update_block_entries(&mut self, files: Vec<reader::FilterFile>) {
        for file in files {
            upsert_entries(&mut self.block_files, file);
        }
    }

    /// Builds entries for `block_names` with `reader::block_hardcoded` and merges them
    /// into the block set.
    ///
    /// # Errors
    /// Propagates any error from `reader::block_hardcoded`; the filter is left unchanged
    /// in that case.
    pub fn set_hardcoded_block(&mut self, block_names: &[&str]) -> Result<()> {
        let hardcoded_entries = reader::block_hardcoded(block_names)?;
        upsert_entries(&mut self.block_files, hardcoded_entries);
        Ok(())
    }

    /// Looks up `host` and then each of its parent domains, most specific first.
    ///
    /// At each level, override files are searched before block files, each group in
    /// insertion order. Returns the info of the first matching file together with its
    /// entry, or `None` when no file matches at any level.
    pub fn check(
        &self,
        host: &String,
    ) -> Option<(&Option<reader::FileInfo>, &reader::FilterEntry)> {
        DomainParentIter::new(host).find_map(|domain_str| {
            // Lookup key is an owned String, as in the original; presumably `content`
            // is keyed by `String` — TODO confirm whether a `&str` lookup would do.
            let domain = domain_str.to_string();
            self.override_files
                .iter()
                .chain(&self.block_files)
                .find_map(|file| file.content.get(&domain).map(|entry| (&file.info, entry)))
        })
    }
}
/// Resolves a filter source to a local file, downloading it first when it is a URL.
///
/// A source that does not parse as a URI, or parses without a scheme, is treated as a
/// local path: it is returned unchanged as both `source_path` and `local_path`, with
/// `false`. Otherwise the source is fetched with `updater::update_file` into
/// `filters_dir`, under a file name derived from the SHA-256 of the source string. The
/// returned flag is whatever `update_file` reports (presumably whether new content was
/// downloaded — TODO confirm).
///
/// # Errors
/// Fails if the download path is not valid UTF-8 (checked before any network I/O) or if
/// `updater::update_file` fails.
pub async fn update_if_url(
    fetch_client: &Client<hyper_smol::SmolConnector>,
    filters_dir: &PathBuf,
    filter_path_or_url: &String,
    timeout_ms: u64,
) -> Result<(reader::FileInfo, bool)> {
    let has_scheme = match Uri::try_from(filter_path_or_url) {
        Ok(filter_uri) if filter_uri.scheme().is_some() => true,
        Ok(_) => {
            trace!(
                "Assuming that no-schema filter path is local: {}",
                filter_path_or_url
            );
            false
        }
        Err(_) => {
            trace!(
                "Assuming that non-url filter path is local: {}",
                filter_path_or_url
            );
            false
        }
    };
    if !has_scheme {
        let local = reader::FileInfo {
            source_path: filter_path_or_url.clone(),
            local_path: filter_path_or_url.clone(),
        };
        return Ok((local, false));
    }

    // NOTE(review): the first argument looks like a 10 MiB download size limit — confirm
    // against `fetcher::Fetcher::new`.
    let fetcher = fetcher::Fetcher::new(10 * 1024 * 1024, None);
    // Name the cached download after a hash of the source so it is stable across runs
    // and safe to use as a file name.
    let hosts_path_sha = Sha256::digest(filter_path_or_url.as_bytes());
    let download_path = Path::new(filters_dir).join(format!(
        "{:x}.sha256.{}",
        hosts_path_sha,
        path::ZSTD_EXTENSION
    ));
    // Validate the path before downloading so a non-UTF-8 `filters_dir` fails fast
    // instead of after a wasted fetch.
    let local_path = download_path
        .to_str()
        .with_context(|| format!("busted download path: {:?}", download_path))?
        .to_string();

    let downloaded = updater::update_file(
        fetch_client,
        &fetcher,
        filter_path_or_url,
        download_path.as_path(),
        timeout_ms,
    )
    .await?;
    Ok((
        reader::FileInfo {
            source_path: filter_path_or_url.clone(),
            local_path,
        },
        downloaded,
    ))
}
/// Merges `new_file` into `entries`, matching files by `info.local_path`.
///
/// * If an existing entry has the same `local_path`, it is replaced in place by
///   `new_file`, or removed when `new_file` has no content.
/// * Otherwise a non-empty `new_file` is appended and an empty one is dropped.
///
/// Files without `info` never match an existing entry, so each non-empty one is appended.
/// Replacing in place keeps exactly one entry per path and preserves the file's position.
/// That position matters because `Filter::check` returns the first file that matches.
fn upsert_entries(entries: &mut Vec<reader::FilterFile>, new_file: reader::FilterFile) {
    let existing_idx = new_file.info.as_ref().and_then(|new_info| {
        entries.iter().position(|entry| {
            matches!(&entry.info, Some(info) if info.local_path == new_info.local_path)
        })
    });
    match existing_idx {
        // An empty reload of a known file means it no longer contributes entries.
        Some(idx) if new_file.content.is_empty() => {
            entries.remove(idx);
        }
        // Overwrite (not insert) so the stale version does not linger behind the new one.
        Some(idx) => entries[idx] = new_file,
        None if !new_file.content.is_empty() => entries.push(new_file),
        None => {}
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Walks `DomainParentIter` over `domain`, checking that it yields exactly
    /// `expected` in order and then reports exhaustion.
    fn assert_parents(domain: &str, expected: &[&str]) {
        let owned = domain.to_string();
        let mut iter = DomainParentIter::new(&owned);
        for want in expected {
            assert_eq!(Some(*want), iter.next());
        }
        assert_eq!(None, iter.next());
    }

    #[test]
    fn iter_empty() {
        assert_parents("", &[]);
    }

    #[test]
    fn iter_com() {
        assert_parents("com", &["com"]);
    }

    #[test]
    fn iter_domaincom() {
        assert_parents("domain.com", &["domain.com", "com"]);
    }

    #[test]
    fn iter_wwwdomaincom() {
        assert_parents("www.domain.com", &["www.domain.com", "domain.com", "com"]);
    }

    #[test]
    fn iter_wwwngeeknz() {
        assert_parents(
            "www.n.geek.nz",
            &["www.n.geek.nz", "n.geek.nz", "geek.nz", "nz"],
        );
    }

    #[test]
    fn iter_averylongteststringwithmanysegments() {
        assert_parents(
            "a.very-long.test.string.with-many.segments",
            &[
                "a.very-long.test.string.with-many.segments",
                "very-long.test.string.with-many.segments",
                "test.string.with-many.segments",
                "string.with-many.segments",
                "with-many.segments",
                "segments",
            ],
        );
    }
}