Skip to main content

hickory_resolver/
hosts.rs

//! Static host resolution from the system hosts file (`/etc/hosts` on Unix-like systems,
//! `%SystemRoot%\System32\drivers\etc\hosts` on Windows).
2
3use std::collections::HashMap;
4use std::fs::File;
5use std::io;
6use std::net::IpAddr;
7use std::path::Path;
8use std::str::FromStr;
9use std::sync::Arc;
10
11use crate::proto::op::Query;
12use crate::proto::rr::rdata::PTR;
13use crate::proto::rr::{Name, RecordType};
14use crate::proto::rr::{RData, Record};
15use tracing::warn;
16
17use crate::cache::MAX_TTL;
18use crate::lookup::Lookup;
19
20#[derive(Debug, Default)]
21struct LookupType {
22    /// represents the A record type
23    a: Option<Lookup>,
24    /// represents the AAAA record type
25    aaaa: Option<Lookup>,
26}
27
28/// Configuration for the local hosts file
29#[derive(Debug, Default)]
30pub struct Hosts {
31    /// Name -> RDatas map
32    by_name: HashMap<Name, LookupType>,
33}
34
35impl Hosts {
36    /// Creates a new configuration from the system hosts file,
37    /// only works for Windows and Unix-like OSes,
38    /// will return empty configuration on others
39    #[cfg(any(unix, windows))]
40    pub fn from_system() -> io::Result<Self> {
41        Self::from_file(hosts_path())
42    }
43
44    /// Creates a default configuration for non Windows or Unix-like OSes
45    #[cfg(not(any(unix, windows)))]
46    pub fn from_system() -> io::Result<Self> {
47        Ok(Hosts::default())
48    }
49
50    /// parse configuration from `path`
51    #[cfg(any(unix, windows))]
52    pub(crate) fn from_file(path: impl AsRef<Path>) -> io::Result<Self> {
53        let file = File::open(path)?;
54        let mut hosts = Self::default();
55        hosts.read_hosts_conf(file)?;
56        Ok(hosts)
57    }
58
59    /// Look up the addresses for the given host from the system hosts file.
60    pub fn lookup_static_host(&self, query: &Query) -> Option<Lookup> {
61        if self.by_name.is_empty() {
62            return None;
63        }
64
65        let mut name = query.name().clone();
66        name.set_fqdn(true);
67        match query.query_type() {
68            RecordType::A | RecordType::AAAA => {
69                let val = self.by_name.get(&name)?;
70                return match query.query_type() {
71                    RecordType::A => val.a.clone(),
72                    RecordType::AAAA => val.aaaa.clone(),
73                    _ => None,
74                };
75            }
76            RecordType::PTR => {}
77            _ => return None,
78        }
79
80        let ip = name.parse_arpa_name().ok()?;
81        let ip_addr = ip.addr();
82        let records = self
83            .by_name
84            .iter()
85            .filter(|(_, v)| match ip_addr {
86                IpAddr::V4(ip) => match v.a.as_ref() {
87                    Some(lookup) => lookup
88                        .answers()
89                        .iter()
90                        .any(|r| r.data.ip_addr().map(|it| it == ip).unwrap_or_default()),
91                    None => false,
92                },
93                IpAddr::V6(ip) => match v.aaaa.as_ref() {
94                    Some(lookup) => lookup
95                        .answers()
96                        .iter()
97                        .any(|r| r.data.ip_addr().map(|it| it == ip).unwrap_or_default()),
98                    None => false,
99                },
100            })
101            .map(|(n, _)| Record::from_rdata(name.clone(), MAX_TTL, RData::PTR(PTR(n.clone()))))
102            .collect::<Arc<[Record]>>();
103
104        match records.is_empty() {
105            false => Some(Lookup::new_with_max_ttl(
106                query.clone(),
107                records.iter().cloned(),
108            )),
109            true => None,
110        }
111    }
112
113    /// Insert a new Lookup for the associated `Name` and `RecordType`
114    pub fn insert(&mut self, mut name: Name, record_type: RecordType, lookup: Lookup) {
115        assert!(record_type == RecordType::A || record_type == RecordType::AAAA);
116
117        name.set_fqdn(true);
118        let lookup_type = self.by_name.entry(name.clone()).or_default();
119
120        let new_lookup = {
121            let old_lookup = match record_type {
122                RecordType::A => lookup_type.a.get_or_insert_with(|| {
123                    let query = Query::query(name.clone(), record_type);
124                    Lookup::new_with_max_ttl(query, [])
125                }),
126                RecordType::AAAA => lookup_type.aaaa.get_or_insert_with(|| {
127                    let query = Query::query(name.clone(), record_type);
128                    Lookup::new_with_max_ttl(query, [])
129                }),
130                _ => {
131                    tracing::warn!("unsupported IP type from Hosts file: {:#?}", record_type);
132                    return;
133                }
134            };
135
136            old_lookup.append(lookup)
137        };
138
139        // replace the appended version
140        match record_type {
141            RecordType::A => lookup_type.a = Some(new_lookup),
142            RecordType::AAAA => lookup_type.aaaa = Some(new_lookup),
143            _ => tracing::warn!("unsupported IP type from Hosts file"),
144        }
145    }
146
147    /// parse configuration from `src`
148    pub fn read_hosts_conf(&mut self, src: impl io::Read) -> io::Result<()> {
149        use std::io::{BufRead, BufReader};
150
151        // lines in the src should have the form `addr host1 host2 host3 ...`
152        // line starts with `#` will be regarded with comments and ignored,
153        // also empty line also will be ignored,
154        // if line only include `addr` without `host` will be ignored,
155        // the src will be parsed to map in the form `Name -> LookUp`.
156
157        for (line_index, line) in BufReader::new(src).lines().enumerate() {
158            let line = line?;
159
160            // Remove byte-order mark if present
161            let line = if line_index == 0 && line.starts_with('\u{feff}') {
162                // BOM is 3 bytes
163                &line[3..]
164            } else {
165                &line
166            };
167
168            // Remove comments from the line
169            let line = match line.split_once('#') {
170                Some((line, _)) => line,
171                None => line,
172            }
173            .trim();
174
175            if line.is_empty() {
176                continue;
177            }
178
179            let mut iter = line.split_whitespace();
180            let addr = match iter.next() {
181                Some(addr) => match IpAddr::from_str(addr) {
182                    Ok(addr) => RData::from(addr),
183                    Err(_) => {
184                        warn!("could not parse an IP from hosts file ({addr:?})");
185                        continue;
186                    }
187                },
188                None => continue,
189            };
190
191            for domain in iter {
192                let domain = domain.to_lowercase();
193                let Ok(mut name) = Name::from_str(&domain) else {
194                    continue;
195                };
196
197                name.set_fqdn(true);
198                let record = Record::from_rdata(name.clone(), MAX_TTL, addr.clone());
199                match addr {
200                    RData::A(..) => {
201                        let query = Query::query(name.clone(), RecordType::A);
202                        let lookup = Lookup::new_with_max_ttl(query, [record]);
203                        self.insert(name.clone(), RecordType::A, lookup);
204                    }
205                    RData::AAAA(..) => {
206                        let query = Query::query(name.clone(), RecordType::AAAA);
207                        let lookup = Lookup::new_with_max_ttl(query, [record]);
208                        self.insert(name.clone(), RecordType::AAAA, lookup);
209                    }
210                    _ => {
211                        warn!("unsupported IP type from Hosts file: {:#?}", addr);
212                        continue;
213                    }
214                };
215
216                // TODO: insert reverse lookup as well.
217            }
218        }
219
220        Ok(())
221    }
222}
223
/// Location of the hosts file on Unix-like systems.
#[cfg(unix)]
fn hosts_path() -> &'static str {
    const UNIX_HOSTS_FILE: &str = "/etc/hosts";
    UNIX_HOSTS_FILE
}
228
/// Location of the hosts file on Windows: `%SystemRoot%\System32\drivers\etc\hosts`.
///
/// When `SystemRoot` is not set (for example in a process started with a cleared
/// environment), `windir` is tried next, and `C:\Windows` is used as a last resort. This
/// avoids panicking; if the guessed directory is wrong, opening the file then fails with an
/// ordinary I/O error that the caller can handle.
#[cfg(windows)]
fn hosts_path() -> std::path::PathBuf {
    let system_root = std::env::var_os("SystemRoot")
        .or_else(|| std::env::var_os("windir"))
        .map(std::path::PathBuf::from)
        .unwrap_or_else(|| std::path::PathBuf::from("C:\\Windows"));
    system_root.join("System32\\drivers\\etc\\hosts")
}
236
#[cfg(any(unix, windows))]
#[cfg(test)]
mod tests {
    use super::*;
    use std::env;
    use std::net::{Ipv4Addr, Ipv6Addr};

    /// Directory holding the resolver test fixtures.
    ///
    /// `TDNS_WORKSPACE_ROOT` overrides the workspace root, which otherwise defaults to `../..`.
    fn tests_dir() -> String {
        let server_path = env::var("TDNS_WORKSPACE_ROOT").unwrap_or_else(|_| "../..".to_owned());
        format!("{server_path}/crates/resolver/tests")
    }

    /// Asserts that a `record_type` query for `name` (made fully qualified) is answered with
    /// exactly one record carrying `rdata`.
    fn assert_single_answer(hosts: &Hosts, name: &str, record_type: RecordType, rdata: RData) {
        let mut name = Name::from_str(name).unwrap();
        name.set_fqdn(true);
        let lookup = hosts
            .lookup_static_host(&Query::query(name.clone(), record_type))
            .unwrap_or_else(|| panic!("no {record_type:?} answer for {name:?}"));
        assert_eq!(
            lookup.answers(),
            &[Record::from_rdata(name, MAX_TTL, rdata)]
        );
    }

    /// Returns the `PTR` answers for the reverse-lookup `name`, sorted by target name so the
    /// comparison does not depend on answer order.
    fn sorted_ptr_answers(hosts: &Hosts, name: &Name) -> Vec<Record> {
        let mut answers = hosts
            .lookup_static_host(&Query::query(name.clone(), RecordType::PTR))
            .unwrap()
            .answers()
            .to_vec();
        answers.sort_by_key(|r| match &r.data {
            RData::PTR(ptr) => Some(ptr.0.clone()),
            _ => None,
        });
        answers
    }

    /// Builds the `PTR` record expected for the reverse-lookup `name` pointing at `target`.
    fn ptr_record(name: &Name, target: &str) -> Record {
        Record::from_rdata(name.clone(), MAX_TTL, RData::PTR(PTR(target.parse().unwrap())))
    }

    #[test]
    fn test_read_hosts_conf() {
        let path = format!("{}/hosts", tests_dir());
        let hosts = Hosts::from_file(path).unwrap();

        // Forward lookups.
        let forward = [
            ("localhost.", RecordType::A, RData::A(Ipv4Addr::LOCALHOST.into())),
            ("localhost.", RecordType::AAAA, RData::AAAA(Ipv6Addr::LOCALHOST.into())),
            ("broadcasthost", RecordType::A, RData::A(Ipv4Addr::new(255, 255, 255, 255).into())),
            ("example.com", RecordType::A, RData::A(Ipv4Addr::new(10, 0, 1, 102).into())),
            ("a.example.com", RecordType::A, RData::A(Ipv4Addr::new(10, 0, 1, 111).into())),
            ("b.example.com", RecordType::A, RData::A(Ipv4Addr::new(10, 0, 1, 111).into())),
        ];
        for (name, record_type, rdata) in forward {
            assert_single_answer(&hosts, name, record_type, rdata);
        }

        // Reverse lookup of an IPv4 address shared by two host names yields both names.
        let name = Name::from_str("111.1.0.10.in-addr.arpa.").unwrap();
        assert_eq!(
            sorted_ptr_answers(&hosts, &name),
            vec![
                ptr_record(&name, "a.example.com."),
                ptr_record(&name, "b.example.com."),
            ]
        );

        // Reverse lookup of the IPv6 loopback address.
        let name = Name::from_str(
            "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa.",
        )
        .unwrap();
        assert_eq!(
            sorted_ptr_answers(&hosts, &name),
            vec![ptr_record(&name, "localhost.")]
        );
    }
}