1use std::cell::Cell;
4use std::io;
5use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, ToSocketAddrs};
6use std::time::{Duration, Instant};
7use std::vec::IntoIter;
8
9use log::info;
10
11use crate::address::address_name;
12use crate::config::DnsConfig;
13use crate::message::{Message, Qr, Question, MESSAGE_LIMIT};
14use crate::record::{Class, Ptr, Record, RecordType, A, AAAA};
15use crate::socket::{DnsSocket, Error};
16
/// A DNS resolver that sends queries through a `DnsSocket` to the name servers
/// listed in its `DnsConfig`.
pub struct DnsResolver {
    // Socket used to send every query and receive every response.
    sock: DnsSocket,
    // Name servers plus the retry, timeout, rotation, and search settings.
    config: DnsConfig,
    // Index of the next name server to use when `config.rotate` is set; a `Cell`
    // so that it can be advanced through `&self`.
    next_ns: Cell<usize>,
}
25
26impl DnsResolver {
27 pub fn new(config: DnsConfig) -> io::Result<DnsResolver> {
29 let bind = bind_addr(&config.name_servers);
30 let sock = DnsSocket::bind((bind, 0))?;
31 DnsResolver::with_sock(sock, config)
32 }
33
34 pub fn bind<A: ToSocketAddrs>(addr: A, config: DnsConfig) -> io::Result<DnsResolver> {
37 let sock = DnsSocket::bind(addr)?;
38 DnsResolver::with_sock(sock, config)
39 }
40
41 fn with_sock(sock: DnsSocket, config: DnsConfig) -> io::Result<DnsResolver> {
42 Ok(DnsResolver {
43 sock,
44 config,
45 next_ns: Cell::new(0),
46 })
47 }
48
49 pub fn resolve_addr(&self, addr: &IpAddr) -> io::Result<String> {
51 convert_error("failed to resolve address", || {
52 let mut out_msg = self.basic_message();
53
54 out_msg.question.push(Question::new(
55 address_name(addr),
56 RecordType::Ptr,
57 Class::Internet,
58 ));
59
60 let mut buf = [0; MESSAGE_LIMIT];
61 let msg = self.send_message(&out_msg, &mut buf)?;
62
63 for rr in msg.answer.into_iter() {
64 if rr.r_type == RecordType::Ptr {
65 let ptr = rr.read_rdata::<Ptr>()?;
66 let mut name = ptr.name;
67 if name.ends_with('.') {
68 name.pop();
69 }
70 return Ok(name);
71 }
72 }
73
74 Err(Error::IoError(io::Error::new(
75 io::ErrorKind::Other,
76 "failed to resolve address: name not found",
77 )))
78 })
79 }
80
81 pub fn resolve_host(&self, host: &str) -> io::Result<ResolveHost> {
83 convert_error("failed to resolve host", || {
84 query_names(host, &self.config, |name| {
85 let mut err;
86 let mut res = Vec::new();
87
88 info!("attempting lookup of name \"{}\"", name);
89
90 if self.config.use_inet6 {
91 err = self
92 .resolve_host_v6(&name, |ip| res.push(IpAddr::V6(ip)))
93 .err();
94
95 if res.is_empty() {
96 err = err.or_else(|| {
97 self.resolve_host_v4(&name, |ip| {
98 res.push(IpAddr::V6(ip.to_ipv6_mapped()))
99 })
100 .err()
101 });
102 }
103 } else {
104 err = self
105 .resolve_host_v4(&name, |ip| res.push(IpAddr::V4(ip)))
106 .err();
107 err = err.or_else(|| {
108 self.resolve_host_v6(&name, |ip| res.push(IpAddr::V6(ip)))
109 .err()
110 });
111 }
112
113 if !res.is_empty() {
114 return Ok(ResolveHost(res.into_iter()));
115 }
116
117 if let Some(e) = err {
118 Err(e)
119 } else {
120 Err(Error::IoError(io::Error::new(
121 io::ErrorKind::Other,
122 "failed to resolve host: name not found",
123 )))
124 }
125 })
126 })
127 }
128
129 pub fn resolve_record<Rec: Record>(&self, name: &str) -> io::Result<Vec<Rec>> {
131 convert_error("failed to resolve record", || {
132 let r_ty = Rec::record_type();
133 let mut msg = self.basic_message();
134
135 msg.question
136 .push(Question::new(name.to_owned(), r_ty, Class::Internet));
137
138 let mut buf = [0; MESSAGE_LIMIT];
139 let reply = self.send_message(&msg, &mut buf)?;
140
141 let mut rec = Vec::new();
142
143 for rr in reply.answer.into_iter() {
144 if rr.r_type == r_ty {
145 rec.push(rr.read_rdata::<Rec>()?);
146 }
147 }
148
149 Ok(rec)
150 })
151 }
152
153 fn resolve_host_v4<F>(&self, host: &str, mut f: F) -> Result<(), Error>
154 where
155 F: FnMut(Ipv4Addr),
156 {
157 let mut out_msg = self.basic_message();
158
159 out_msg.question.push(Question::new(
160 host.to_owned(),
161 RecordType::A,
162 Class::Internet,
163 ));
164
165 let mut buf = [0; MESSAGE_LIMIT];
166 let msg = self.send_message(&out_msg, &mut buf)?;
167
168 for rr in msg.answer.into_iter() {
169 if rr.r_type == RecordType::A {
170 let a = rr.read_rdata::<A>()?;
171 f(a.address);
172 }
173 }
174
175 Ok(())
176 }
177
178 fn resolve_host_v6<F>(&self, host: &str, mut f: F) -> Result<(), Error>
179 where
180 F: FnMut(Ipv6Addr),
181 {
182 let mut out_msg = self.basic_message();
183
184 out_msg.question.push(Question::new(
185 host.to_owned(),
186 RecordType::AAAA,
187 Class::Internet,
188 ));
189
190 let mut buf = [0; MESSAGE_LIMIT];
191 let msg = self.send_message(&out_msg, &mut buf)?;
192
193 for rr in msg.answer.into_iter() {
194 if rr.r_type == RecordType::AAAA {
195 let aaaa = rr.read_rdata::<AAAA>()?;
196 f(aaaa.address);
197 }
198 }
199
200 Ok(())
201 }
202
203 fn basic_message(&self) -> Message {
204 let mut msg = Message::new();
205
206 msg.header.recursion_desired = true;
207 msg
208 }
209
210 #[allow(dropping_references)] pub fn send_message<'buf>(
213 &self,
214 out_msg: &Message,
215 buf: &'buf mut [u8],
216 ) -> Result<Message<'buf>, Error> {
217 let mut last_err = None;
218
219 let buf_ptr = buf as *mut _;
223 drop(buf);
224
225 'retry: for retries in 0..self.config.attempts {
226 let ns_addr = if self.config.rotate {
227 self.next_nameserver()
228 } else {
229 let n = self.config.name_servers.len();
230 self.config.name_servers[retries as usize % n]
231 };
232
233 let mut timeout = self.config.timeout;
234
235 info!("resolver sending message to {}", ns_addr);
236
237 self.sock.send_message(out_msg, ns_addr)?;
238
239 loop {
240 self.sock.get().set_read_timeout(Some(timeout))?;
241
242 let buf = unsafe { &mut *buf_ptr };
244
245 let start = Instant::now();
246
247 match self.sock.recv_message(&ns_addr, buf) {
248 Ok(None) => {
249 let passed = start.elapsed();
250
251 if timeout < passed {
254 timeout = Duration::from_secs(0);
255 } else {
256 timeout -= passed;
257 }
258 }
259 Ok(Some(msg)) => {
260 if msg.header.id == out_msg.header.id && msg.header.qr == Qr::Response {
262 msg.get_error()?;
263 return Ok(msg);
264 }
265 }
266 Err(e) => {
267 if e.is_timeout() {
269 last_err = Some(e);
270 continue 'retry;
271 }
272 return Err(e);
274 }
275 }
276 }
277 }
278
279 Err(last_err.unwrap())
280 }
281
282 fn next_nameserver(&self) -> SocketAddr {
283 let n = self.next_ns.get();
284 self.next_ns.set((n + 1) % self.config.name_servers.len());
285 self.config.name_servers[n]
286 }
287}
288
/// Chooses the wildcard address for the resolver's local socket.
///
/// Returns the IPv6 unspecified address (`::`) when the first name server is an
/// IPv6 address. Otherwise it returns the IPv4 unspecified address
/// (`0.0.0.0`), which includes the case where `name_servers` is empty.
fn bind_addr(name_servers: &[SocketAddr]) -> IpAddr {
    if let Some(SocketAddr::V6(_)) = name_servers.first() {
        IpAddr::V6(Ipv6Addr::UNSPECIFIED)
    } else {
        IpAddr::V4(Ipv4Addr::UNSPECIFIED)
    }
}
295
296fn convert_error<T, F>(desc: &str, f: F) -> io::Result<T>
297where
298 F: FnOnce() -> Result<T, Error>,
299{
300 match f() {
301 Ok(t) => Ok(t),
302 Err(Error::IoError(e)) => Err(e),
303 Err(e) => Err(io::Error::new(
304 io::ErrorKind::Other,
305 format!("{}: {}", desc, e),
306 )),
307 }
308}
309
310fn query_names<F, T>(name: &str, config: &DnsConfig, mut f: F) -> Result<T, Error>
311where
312 F: FnMut(String) -> Result<T, Error>,
313{
314 let use_search =
315 !name.ends_with('.') && name.chars().filter(|&c| c == '.').count() as u32 >= config.n_dots;
316
317 if use_search {
318 let mut err = None;
319
320 for name in with_suffixes(name, &config.search) {
321 match f(name) {
322 Ok(t) => return Ok(t),
323 Err(e) => err = Some(e),
324 }
325 }
326
327 if let Some(e) = err {
328 Err(e)
329 } else {
330 Err(Error::IoError(io::Error::new(
331 io::ErrorKind::Other,
332 "failed to resolve host: name not found",
333 )))
334 }
335 } else {
336 f(name.to_owned())
337 }
338}
339
/// Expands `host` with each search-list suffix and returns the results in order,
/// followed by `host` itself.
///
/// For example, `"a"` with suffixes `["b", "c"]` gives `["a.b", "a.c", "a"]`.
fn with_suffixes(host: &str, suffixes: &[String]) -> Vec<String> {
    let mut names = Vec::with_capacity(suffixes.len() + 1);
    for suffix in suffixes {
        names.push(format!("{}.{}", host, suffix));
    }
    names.push(host.to_owned());
    names
}
348
349pub fn resolve_addr(addr: &IpAddr) -> io::Result<String> {
351 let r = DnsResolver::new(DnsConfig::load_default()?)?;
352 r.resolve_addr(addr)
353}
354
355pub fn resolve_host(host: &str) -> io::Result<ResolveHost> {
371 let r = DnsResolver::new(DnsConfig::load_default()?)?;
372 r.resolve_host(host)
373}
374
/// Iterator over the addresses produced by a host lookup (`resolve_host`).
#[derive(Clone, Debug)]
pub struct ResolveHost(IntoIter<IpAddr>);

impl Iterator for ResolveHost {
    type Item = IpAddr;

    fn next(&mut self) -> Option<IpAddr> {
        self.0.next()
    }

    // Forward the exact remaining count. The default `(0, None)` would stop
    // `collect` and similar adapters from preallocating.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.0.size_hint()
    }
}

// The wrapped `vec::IntoIter` always knows exactly how many addresses remain.
impl ExactSizeIterator for ResolveHost {}
384}