1use std::collections::HashMap;
4use std::fs::File;
5use std::io;
6use std::net::IpAddr;
7use std::path::Path;
8use std::str::FromStr;
9use std::sync::Arc;
10
11use crate::proto::op::Query;
12use crate::proto::rr::rdata::PTR;
13use crate::proto::rr::{Name, RecordType};
14use crate::proto::rr::{RData, Record};
15use tracing::warn;
16
17use crate::cache::MAX_TTL;
18use crate::lookup::Lookup;
19
20#[derive(Debug, Default)]
21struct LookupType {
22 a: Option<Lookup>,
24 aaaa: Option<Lookup>,
26}
27
28#[derive(Debug, Default)]
30pub struct Hosts {
31 by_name: HashMap<Name, LookupType>,
33}
34
35impl Hosts {
36 #[cfg(any(unix, windows))]
40 pub fn from_system() -> io::Result<Self> {
41 Self::from_file(hosts_path())
42 }
43
44 #[cfg(not(any(unix, windows)))]
46 pub fn from_system() -> io::Result<Self> {
47 Ok(Hosts::default())
48 }
49
50 #[cfg(any(unix, windows))]
52 pub(crate) fn from_file(path: impl AsRef<Path>) -> io::Result<Self> {
53 let file = File::open(path)?;
54 let mut hosts = Self::default();
55 hosts.read_hosts_conf(file)?;
56 Ok(hosts)
57 }
58
59 pub fn lookup_static_host(&self, query: &Query) -> Option<Lookup> {
61 if self.by_name.is_empty() {
62 return None;
63 }
64
65 let mut name = query.name().clone();
66 name.set_fqdn(true);
67 match query.query_type() {
68 RecordType::A | RecordType::AAAA => {
69 let val = self.by_name.get(&name)?;
70 return match query.query_type() {
71 RecordType::A => val.a.clone(),
72 RecordType::AAAA => val.aaaa.clone(),
73 _ => None,
74 };
75 }
76 RecordType::PTR => {}
77 _ => return None,
78 }
79
80 let ip = name.parse_arpa_name().ok()?;
81 let ip_addr = ip.addr();
82 let records = self
83 .by_name
84 .iter()
85 .filter(|(_, v)| match ip_addr {
86 IpAddr::V4(ip) => match v.a.as_ref() {
87 Some(lookup) => lookup
88 .answers()
89 .iter()
90 .any(|r| r.data.ip_addr().map(|it| it == ip).unwrap_or_default()),
91 None => false,
92 },
93 IpAddr::V6(ip) => match v.aaaa.as_ref() {
94 Some(lookup) => lookup
95 .answers()
96 .iter()
97 .any(|r| r.data.ip_addr().map(|it| it == ip).unwrap_or_default()),
98 None => false,
99 },
100 })
101 .map(|(n, _)| Record::from_rdata(name.clone(), MAX_TTL, RData::PTR(PTR(n.clone()))))
102 .collect::<Arc<[Record]>>();
103
104 match records.is_empty() {
105 false => Some(Lookup::new_with_max_ttl(
106 query.clone(),
107 records.iter().cloned(),
108 )),
109 true => None,
110 }
111 }
112
113 pub fn insert(&mut self, mut name: Name, record_type: RecordType, lookup: Lookup) {
115 assert!(record_type == RecordType::A || record_type == RecordType::AAAA);
116
117 name.set_fqdn(true);
118 let lookup_type = self.by_name.entry(name.clone()).or_default();
119
120 let new_lookup = {
121 let old_lookup = match record_type {
122 RecordType::A => lookup_type.a.get_or_insert_with(|| {
123 let query = Query::query(name.clone(), record_type);
124 Lookup::new_with_max_ttl(query, [])
125 }),
126 RecordType::AAAA => lookup_type.aaaa.get_or_insert_with(|| {
127 let query = Query::query(name.clone(), record_type);
128 Lookup::new_with_max_ttl(query, [])
129 }),
130 _ => {
131 tracing::warn!("unsupported IP type from Hosts file: {:#?}", record_type);
132 return;
133 }
134 };
135
136 old_lookup.append(lookup)
137 };
138
139 match record_type {
141 RecordType::A => lookup_type.a = Some(new_lookup),
142 RecordType::AAAA => lookup_type.aaaa = Some(new_lookup),
143 _ => tracing::warn!("unsupported IP type from Hosts file"),
144 }
145 }
146
147 pub fn read_hosts_conf(&mut self, src: impl io::Read) -> io::Result<()> {
149 use std::io::{BufRead, BufReader};
150
151 for (line_index, line) in BufReader::new(src).lines().enumerate() {
158 let line = line?;
159
160 let line = if line_index == 0 && line.starts_with('\u{feff}') {
162 &line[3..]
164 } else {
165 &line
166 };
167
168 let line = match line.split_once('#') {
170 Some((line, _)) => line,
171 None => line,
172 }
173 .trim();
174
175 if line.is_empty() {
176 continue;
177 }
178
179 let mut iter = line.split_whitespace();
180 let addr = match iter.next() {
181 Some(addr) => match IpAddr::from_str(addr) {
182 Ok(addr) => RData::from(addr),
183 Err(_) => {
184 warn!("could not parse an IP from hosts file ({addr:?})");
185 continue;
186 }
187 },
188 None => continue,
189 };
190
191 for domain in iter {
192 let domain = domain.to_lowercase();
193 let Ok(mut name) = Name::from_str(&domain) else {
194 continue;
195 };
196
197 name.set_fqdn(true);
198 let record = Record::from_rdata(name.clone(), MAX_TTL, addr.clone());
199 match addr {
200 RData::A(..) => {
201 let query = Query::query(name.clone(), RecordType::A);
202 let lookup = Lookup::new_with_max_ttl(query, [record]);
203 self.insert(name.clone(), RecordType::A, lookup);
204 }
205 RData::AAAA(..) => {
206 let query = Query::query(name.clone(), RecordType::AAAA);
207 let lookup = Lookup::new_with_max_ttl(query, [record]);
208 self.insert(name.clone(), RecordType::AAAA, lookup);
209 }
210 _ => {
211 warn!("unsupported IP type from Hosts file: {:#?}", addr);
212 continue;
213 }
214 };
215
216 }
218 }
219
220 Ok(())
221 }
222}
223
/// Location of the hosts file on Unix-like systems.
#[cfg(unix)]
fn hosts_path() -> &'static str {
    const UNIX_HOSTS_PATH: &str = "/etc/hosts";
    UNIX_HOSTS_PATH
}
228
/// Location of the hosts file on Windows: `%SystemRoot%\System32\drivers\etc\hosts`.
///
/// If the `SystemRoot` environment variable is missing, the default root `C:\Windows` is
/// used. Callers such as `Hosts::from_system` then get an I/O result instead of a panic.
#[cfg(windows)]
fn hosts_path() -> std::path::PathBuf {
    let system_root = std::env::var_os("SystemRoot").unwrap_or_else(|| r"C:\Windows".into());
    Path::new(&system_root).join(r"System32\drivers\etc\hosts")
}
236
#[cfg(any(unix, windows))]
#[cfg(test)]
mod tests {
    use super::*;
    use std::env;
    use std::net::{Ipv4Addr, Ipv6Addr};

    /// Directory holding the resolver test fixtures, rooted at `TDNS_WORKSPACE_ROOT` when
    /// that variable is set and at `../..` otherwise.
    fn tests_dir() -> String {
        let root = env::var("TDNS_WORKSPACE_ROOT").unwrap_or_else(|_| String::from("../.."));
        format!("{root}/crates/resolver/tests")
    }

    /// Parses `text` into a name and marks it fully qualified.
    fn fqdn(text: &str) -> Name {
        let mut name = Name::from_str(text).unwrap();
        name.set_fqdn(true);
        name
    }

    /// Asserts that `hosts` answers a `query_type` query for `name` with exactly one
    /// record carrying `rdata`.
    fn assert_single_answer(hosts: &Hosts, name: &Name, query_type: RecordType, rdata: RData) {
        let lookup = hosts
            .lookup_static_host(&Query::query(name.clone(), query_type))
            .unwrap();
        let expected = [Record::from_rdata(name.clone(), MAX_TTL, rdata)];
        assert_eq!(lookup.answers(), &expected);
    }

    #[test]
    fn test_read_hosts_conf() {
        let hosts = Hosts::from_file(format!("{}/hosts", tests_dir())).unwrap();

        let localhost = Name::from_str("localhost.").unwrap();
        assert_single_answer(
            &hosts,
            &localhost,
            RecordType::A,
            RData::A(Ipv4Addr::LOCALHOST.into()),
        );
        assert_single_answer(
            &hosts,
            &localhost,
            RecordType::AAAA,
            RData::AAAA(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1).into()),
        );

        for (host, ip) in [
            ("broadcasthost", Ipv4Addr::new(255, 255, 255, 255)),
            ("example.com", Ipv4Addr::new(10, 0, 1, 102)),
            ("a.example.com", Ipv4Addr::new(10, 0, 1, 111)),
            ("b.example.com", Ipv4Addr::new(10, 0, 1, 111)),
        ] {
            assert_single_answer(&hosts, &fqdn(host), RecordType::A, RData::A(ip.into()));
        }

        // Two hosts share 10.0.1.111, so the reverse lookup yields two PTR records whose
        // order is not fixed; sort them before comparing.
        let v4_ptr = Name::from_str("111.1.0.10.in-addr.arpa.").unwrap();
        let mut answers = hosts
            .lookup_static_host(&Query::query(v4_ptr.clone(), RecordType::PTR))
            .unwrap()
            .answers()
            .to_vec();
        answers.sort_by_key(|record| match &record.data {
            RData::PTR(ptr) => Some(ptr.0.clone()),
            _ => None,
        });
        let expected: Vec<Record> = ["a.example.com.", "b.example.com."]
            .into_iter()
            .map(|host| {
                Record::from_rdata(v4_ptr.clone(), MAX_TTL, RData::PTR(PTR(host.parse().unwrap())))
            })
            .collect();
        assert_eq!(answers, expected);

        let v6_ptr = Name::from_str(
            "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa.",
        )
        .unwrap();
        assert_single_answer(
            &hosts,
            &v6_ptr,
            RecordType::PTR,
            RData::PTR(PTR("localhost.".parse().unwrap())),
        );
    }
}