use super::error::Result;
use crate::snapshot::PUBLIC_SUFFIX_LIST;
#[cfg(feature = "reqwest")]
use crate::TLDExtractError;
use crate::TLDTrieTree;
#[cfg(feature = "reqwest")]
use reqwest::IntoUrl;
use std::collections::HashSet;
use std::io;
use std::io::BufRead;
use std::path::PathBuf;
use std::time::SystemTime;
/// Marker line separating public suffixes from private-domain suffixes.
///
/// While parsing a source, every suffix that appears after this exact line is
/// stored as a private suffix instead of a public one.
const PUBLIC_PRIVATE_SUFFIX_SEPARATOR: &str = "// ===BEGIN PRIVATE DOMAINS===";

/// Download locations for the suffix list, tried in order when a remote source
/// is requested without an explicit URL; the first one that succeeds is used.
#[cfg(feature = "reqwest")]
const PUBLIC_SUFFIX_LIST_URLS: &[&str] = &[
    "https://publicsuffix.org/list/public_suffix_list.dat",
    "https://raw.githubusercontent.com/publicsuffix/list/master/public_suffix_list.dat",
];
/// Where the suffix list is read from.
#[derive(Debug, Clone, Default)]
pub enum Source {
    /// Suffix-list content supplied directly as an in-memory string.
    Text(String),
    /// The suffix list snapshot compiled into the crate; this is the default source.
    #[default]
    Snapshot,
    /// A suffix-list file on the local filesystem.
    Local(PathBuf),
    /// Download over the network. `Some(url)` fetches that URL; `None` tries
    /// the built-in list of mirrors.
    #[cfg(feature = "reqwest")]
    Remote(Option<reqwest::Url>),
}
/// Configuration and parsed contents of a public suffix list.
///
/// `Default` yields a list backed by the bundled snapshot with a
/// `last_update` of zero.
#[derive(Debug, Default, Clone)]
pub struct SuffixList {
    /// Primary source, parsed first on every build.
    pub source: Source,
    /// Optional second source, parsed after `source` and merged into the same sets.
    pub extra: Option<Source>,
    /// Suffixes read before the private-domains marker line.
    pub public_suffixes: HashSet<String>,
    /// Suffixes read after the private-domains marker line.
    pub private_suffixes: HashSet<String>,
    /// When `true`, private suffixes are skipped while parsing and left out of the trie.
    pub disable_private_domains: bool,
    /// Maximum age before the list is considered expired; `None` means it never expires.
    pub expire: Option<std::time::Duration>,
    /// Time since the Unix epoch of construction via `new` or of the last successful build.
    pub last_update: std::time::Duration,
}
impl SuffixList {
#[inline]
pub fn new(
source: Source,
disable_private_domains: bool,
expire: Option<std::time::Duration>,
) -> Self {
SuffixList {
source,
extra: None,
public_suffixes: Default::default(),
private_suffixes: Default::default(),
disable_private_domains,
expire,
last_update: now(),
}
}
#[inline]
pub fn private_domains(mut self, disable_private_domains: bool) -> Self {
self.disable_private_domains = disable_private_domains;
self
}
#[inline]
pub fn expire(mut self, expire: std::time::Duration) -> Self {
self.expire = Some(expire);
self
}
#[inline]
pub fn source(mut self, source: Source) -> Self {
self.source = source;
self
}
#[inline]
pub fn extra(mut self, extra: Source) -> Self {
self.extra = Some(extra);
self
}
#[inline]
pub fn is_expired(&self) -> bool {
match self.expire {
Some(s) => {
now().as_secs() - s.as_secs() > self.last_update.as_secs()
}
None => false,
}
}
fn reset(&mut self) {
self.private_suffixes = HashSet::new();
self.public_suffixes = HashSet::new();
}
fn parse_source(&mut self, source: Source) -> Result<()> {
let mut is_private_suffix = false;
let mut tld_lines = Vec::new();
match source {
Source::Local(path) => {
let file = std::fs::File::open(path).unwrap();
let lines = io::BufReader::new(file)
.lines()
.map(|l| l.unwrap_or_default());
tld_lines = lines.collect();
}
#[cfg(feature = "reqwest")]
Source::Remote(u) => match u {
Some(u) => {
tld_lines = get_source_from_url(u)?;
}
None => {
let mut tld_err = TLDExtractError::SuffixListError(String::new());
for u in PUBLIC_SUFFIX_LIST_URLS {
match get_source_from_url(u.trim()) {
Ok(lines) => {
tld_lines = lines;
break;
}
Err(err) => {
tld_err = err;
}
}
}
if tld_lines.is_empty() {
return Err(tld_err);
}
}
},
Source::Snapshot => {
let lines = PUBLIC_SUFFIX_LIST.lines().map(|s| s.to_string());
for line in lines {
is_private_suffix = self.process_line(line, is_private_suffix);
}
}
Source::Text(text) => {
let lines = text.lines().map(|s| s.to_string());
for line in lines {
is_private_suffix = self.process_line(line, is_private_suffix);
}
}
}
for line in tld_lines {
is_private_suffix = self.process_line(line, is_private_suffix);
}
Ok(())
}
#[inline]
pub fn build(&mut self) -> Result<TLDTrieTree> {
self.reset();
self.parse_source(self.source.clone())?;
if let Some(extra) = self.extra.clone() {
self.parse_source(extra)?;
}
let ttt = self.construct_tree();
self.last_update = now();
Ok(ttt)
}
fn process_line(&mut self, raw_line: String, mut is_private_suffix: bool) -> bool {
if is_private_suffix && self.disable_private_domains {
return is_private_suffix;
}
let line = raw_line.trim_end();
if !is_private_suffix && PUBLIC_PRIVATE_SUFFIX_SEPARATOR == line {
is_private_suffix = true;
}
if line.is_empty() || line.starts_with("//") {
return is_private_suffix;
}
if let Ok(suffix) = idna::domain_to_ascii(line) {
if is_private_suffix {
self.private_suffixes.insert(suffix.clone());
if suffix != line {
self.private_suffixes.insert(line.to_string());
}
} else {
self.public_suffixes.insert(suffix.clone());
if suffix != line {
self.public_suffixes.insert(suffix);
}
}
}
is_private_suffix
}
fn construct_tree(&self) -> TLDTrieTree {
let mut trie_tree = TLDTrieTree {
node: Default::default(),
end: false,
};
let mut suffix_list = self.public_suffixes.clone();
if !self.disable_private_domains {
suffix_list.extend(self.private_suffixes.clone());
}
for suffix in suffix_list {
let sp: Vec<&str> = suffix.rsplit('.').collect();
trie_tree.insert(sp);
}
trie_tree
}
}
/// Current wall-clock time, measured as the duration since the Unix epoch.
///
/// If the system clock reports a time before the epoch, a zero duration is
/// returned instead of an error.
fn now() -> std::time::Duration {
    match SystemTime::UNIX_EPOCH.elapsed() {
        Ok(since_epoch) => since_epoch,
        Err(_) => std::time::Duration::default(),
    }
}
/// Downloads a suffix list from `u` and returns its lines.
///
/// # Errors
///
/// Returns an error if the request fails, if the body cannot be read, or if
/// the server answers with a non-success HTTP status. The status check means
/// an error page is never parsed as a suffix list, and lets callers fall back
/// to another mirror.
#[cfg(feature = "reqwest")]
fn get_source_from_url<T>(u: T) -> Result<Vec<String>>
where
    T: IntoUrl,
{
    let response = reqwest::blocking::get(u)?.error_for_status()?;
    let bytes = response.bytes()?;
    // Decode lossily, so a line with an invalid byte keeps its other characters
    // rather than being silently replaced by an empty line.
    Ok(String::from_utf8_lossy(&bytes)
        .lines()
        .map(str::to_owned)
        .collect())
}