use std::collections::HashMap;
use std::fs::File;
use std::io;
use std::net::IpAddr;
use std::path::Path;
use std::str::FromStr;
use std::sync::Arc;
use crate::proto::op::Query;
use crate::proto::rr::rdata::PTR;
use crate::proto::rr::{Name, RecordType};
use crate::proto::rr::{RData, Record};
use tracing::warn;
use crate::cache::MAX_TTL;
use crate::lookup::Lookup;
/// Static lookups held for one host name, with a separate slot per address family.
#[derive(Default, Debug)]
struct LookupType {
    /// `A` (IPv4) records listed for the name, if any.
    a: Option<Lookup>,
    /// `AAAA` (IPv6) records listed for the name, if any.
    aaaa: Option<Lookup>,
}
/// Static host-name/address mappings, such as those read from a hosts file.
#[derive(Default, Debug)]
pub struct Hosts {
    /// Per-name address lookups, keyed by fully-qualified host name.
    by_name: HashMap<Name, LookupType>,
}
impl Hosts {
    /// Builds a [`Hosts`] table from the operating system's hosts file.
    ///
    /// # Errors
    ///
    /// Returns any I/O error raised while opening or reading the hosts file.
    #[cfg(any(unix, windows))]
    pub fn from_system() -> io::Result<Self> {
        Self::from_file(hosts_path())
    }

    /// Returns an empty [`Hosts`] table; no hosts file location is known on this platform.
    #[cfg(not(any(unix, windows)))]
    pub fn from_system() -> io::Result<Self> {
        Ok(Hosts::default())
    }

    /// Builds a [`Hosts`] table by parsing the hosts file at `path`.
    ///
    /// # Errors
    ///
    /// Returns any I/O error raised while opening or reading the file.
    #[cfg(any(unix, windows))]
    pub(crate) fn from_file(path: impl AsRef<Path>) -> io::Result<Self> {
        let file = File::open(path)?;
        let mut hosts = Self::default();
        hosts.read_hosts_conf(file)?;
        Ok(hosts)
    }

    /// Answers `query` from the static entries, if possible.
    ///
    /// `A` and `AAAA` queries return the addresses stored for the queried name. `PTR`
    /// queries for a reverse name return one `PTR` record per host name mapped to the
    /// address encoded in that name. Every other record type, an unknown name, or an
    /// unparseable reverse name yields `None`.
    pub fn lookup_static_host(&self, query: &Query) -> Option<Lookup> {
        if self.by_name.is_empty() {
            return None;
        }

        // Entries are stored under fully-qualified names; normalize the query to match.
        let mut name = query.name().clone();
        name.set_fqdn(true);

        match query.query_type() {
            RecordType::A => self.by_name.get(&name)?.a.clone(),
            RecordType::AAAA => self.by_name.get(&name)?.aaaa.clone(),
            RecordType::PTR => self.reverse_lookup(query, &name),
            _ => None,
        }
    }

    /// Builds the `PTR` answer for the reverse-lookup name `name`.
    ///
    /// Returns `None` when `name` does not parse as a reverse name or when no host entry
    /// of the matching address family contains the encoded address.
    fn reverse_lookup(&self, query: &Query, name: &Name) -> Option<Lookup> {
        let ip = name.parse_arpa_name().ok()?.addr();

        let records = self
            .by_name
            .iter()
            .filter(|(_, entry)| {
                // Only consult the address family that matches the queried address.
                let family = match ip {
                    IpAddr::V4(_) => &entry.a,
                    IpAddr::V6(_) => &entry.aaaa,
                };
                family
                    .iter()
                    .flat_map(|lookup| lookup.answers())
                    .any(|record| record.data().ip_addr() == Some(ip))
            })
            .map(|(host, _)| {
                Record::from_rdata(name.clone(), MAX_TTL, RData::PTR(PTR(host.clone())))
            })
            .collect::<Arc<[Record]>>();

        (!records.is_empty())
            .then(|| Lookup::new_with_max_ttl(query.clone(), records.iter().cloned()))
    }

    /// Merges `lookup` into the entry for `name` under `record_type`.
    ///
    /// `name` is made fully qualified before it is used as the key. When no lookup is
    /// stored yet for that name and type, an empty one built from
    /// `Query::query(name, record_type)` is created first. `lookup` is then appended to
    /// the stored lookup.
    ///
    /// # Panics
    ///
    /// Panics if `record_type` is neither [`RecordType::A`] nor [`RecordType::AAAA`].
    pub fn insert(&mut self, mut name: Name, record_type: RecordType, lookup: Lookup) {
        // Validate before touching the map so a bad call leaves no empty entry behind.
        assert!(
            matches!(record_type, RecordType::A | RecordType::AAAA),
            "hosts entries only support A and AAAA records, got {record_type:?}"
        );
        name.set_fqdn(true);

        let entry = self.by_name.entry(name.clone()).or_default();
        let slot = match record_type {
            RecordType::A => &mut entry.a,
            RecordType::AAAA => &mut entry.aaaa,
            _ => unreachable!("record type was validated above"),
        };

        let merged = slot
            .get_or_insert_with(|| Lookup::new_with_max_ttl(Query::query(name, record_type), []))
            .append(lookup);
        *slot = Some(merged);
    }

    /// Parses hosts-file text from `src` and merges its entries into `self`.
    ///
    /// Each non-blank line has the form `<ip> <name> [<name> ...]`. Anything after `#`
    /// is ignored, and a leading U+FEFF byte-order mark on the first line is skipped.
    /// Host names are lowercased and made fully qualified. Lines with an unparseable
    /// address, and individual unparseable names, are skipped with a warning.
    ///
    /// # Errors
    ///
    /// Returns any I/O error raised while reading from `src`.
    pub fn read_hosts_conf(&mut self, src: impl io::Read) -> io::Result<()> {
        use std::io::{BufRead, BufReader};

        for (line_index, line) in BufReader::new(src).lines().enumerate() {
            let line = line?;
            // Only the first line of the file can start with a byte-order mark.
            let line = if line_index == 0 {
                line.strip_prefix('\u{feff}').unwrap_or(line.as_str())
            } else {
                line.as_str()
            };
            // Drop any trailing comment, then surrounding whitespace.
            let line = line
                .split_once('#')
                .map_or(line, |(content, _)| content)
                .trim();

            let mut fields = line.split_whitespace();
            let Some(addr) = fields.next() else {
                continue;
            };
            let ip = match IpAddr::from_str(addr) {
                Ok(ip) => ip,
                Err(_) => {
                    warn!("could not parse an IP from hosts file ({addr:?})");
                    continue;
                }
            };
            let record_type = match ip {
                IpAddr::V4(_) => RecordType::A,
                IpAddr::V6(_) => RecordType::AAAA,
            };
            let rdata = RData::from(ip);

            for host in fields {
                let host = host.to_lowercase();
                let mut name = match Name::from_str(&host) {
                    Ok(name) => name,
                    Err(err) => {
                        warn!("could not parse a host name from hosts file ({host:?}): {err}");
                        continue;
                    }
                };
                name.set_fqdn(true);

                let record = Record::from_rdata(name.clone(), MAX_TTL, rdata.clone());
                let query = Query::query(name.clone(), record_type);
                self.insert(name, record_type, Lookup::new_with_max_ttl(query, [record]));
            }
        }
        Ok(())
    }
}
/// Location of the system hosts file on Unix-like platforms.
#[cfg(unix)]
fn hosts_path() -> &'static str {
    const UNIX_HOSTS_FILE: &str = "/etc/hosts";
    UNIX_HOSTS_FILE
}
/// Location of the system hosts file on Windows:
/// `%SystemRoot%\System32\drivers\etc\hosts`.
///
/// If `SystemRoot` is not set, `C:\Windows` is used in its place.
#[cfg(windows)]
fn hosts_path() -> std::path::PathBuf {
    // `SystemRoot` can be missing, e.g. when a process is spawned with a cleared
    // environment. Falling back instead of panicking lets a missing file surface as an
    // `io::Error` from `File::open` in the caller.
    let system_root = std::env::var_os("SystemRoot")
        .unwrap_or_else(|| std::ffi::OsString::from("C:\\Windows"));
    Path::new(&system_root).join("System32\\drivers\\etc\\hosts")
}
#[cfg(any(unix, windows))]
#[cfg(test)]
mod tests {
    use super::*;
    use std::env;
    use std::net::{Ipv4Addr, Ipv6Addr};

    /// Location of the resolver test fixtures. `TDNS_WORKSPACE_ROOT` overrides the
    /// default workspace root of `../..`.
    fn tests_dir() -> String {
        let workspace_root = match env::var("TDNS_WORKSPACE_ROOT") {
            Ok(root) => root,
            Err(_) => String::from("../.."),
        };
        format!("{workspace_root}/crates/resolver/tests")
    }

    /// Parses `text` as a [`Name`] and marks it fully qualified.
    fn fqdn(text: &str) -> Name {
        let mut name = Name::from_str(text).unwrap();
        name.set_fqdn(true);
        name
    }

    /// Asserts that querying `name` for `query_type` yields exactly one record holding `rdata`.
    fn assert_single_answer(hosts: &Hosts, name: &Name, query_type: RecordType, rdata: RData) {
        let query = Query::query(name.clone(), query_type);
        let lookup = hosts.lookup_static_host(&query).unwrap();
        let expected = [Record::from_rdata(name.clone(), MAX_TTL, rdata)];
        assert_eq!(lookup.answers(), &expected);
    }

    #[test]
    fn test_read_hosts_conf() {
        let hosts = Hosts::from_file(format!("{}/hosts", tests_dir())).unwrap();

        let localhost = fqdn("localhost.");
        assert_single_answer(
            &hosts,
            &localhost,
            RecordType::A,
            RData::A(Ipv4Addr::LOCALHOST.into()),
        );
        assert_single_answer(
            &hosts,
            &localhost,
            RecordType::AAAA,
            RData::AAAA(Ipv6Addr::LOCALHOST.into()),
        );

        let ipv4_hosts = [
            ("broadcasthost", Ipv4Addr::new(255, 255, 255, 255)),
            ("example.com", Ipv4Addr::new(10, 0, 1, 102)),
            ("a.example.com", Ipv4Addr::new(10, 0, 1, 111)),
            ("b.example.com", Ipv4Addr::new(10, 0, 1, 111)),
        ];
        for (host, addr) in ipv4_hosts {
            assert_single_answer(&hosts, &fqdn(host), RecordType::A, RData::A(addr.into()));
        }

        // Two names share 10.0.1.111, so the reverse lookup yields both; sort for a stable order.
        let reverse_v4 = fqdn("111.1.0.10.in-addr.arpa.");
        let lookup = hosts
            .lookup_static_host(&Query::query(reverse_v4.clone(), RecordType::PTR))
            .unwrap();
        let mut answers = lookup.answers().to_vec();
        answers.sort_by_key(|record| match record.data() {
            RData::PTR(ptr) => Some(ptr.0.clone()),
            _ => None,
        });
        let expected: Vec<Record> = ["a.example.com.", "b.example.com."]
            .iter()
            .map(|target| {
                Record::from_rdata(
                    reverse_v4.clone(),
                    MAX_TTL,
                    RData::PTR(PTR(target.parse().unwrap())),
                )
            })
            .collect();
        assert_eq!(answers, expected);

        let reverse_v6 = fqdn(
            "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa.",
        );
        assert_single_answer(
            &hosts,
            &reverse_v6,
            RecordType::PTR,
            RData::PTR(PTR("localhost.".parse().unwrap())),
        );
    }
}