dns_resolver/
lib.rs

1use std::io;
2use std::net::{IpAddr, SocketAddr};
3use std::ops::Deref;
4use std::str::FromStr;
5use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
6use std::sync::{Arc, Mutex};
7use std::time::Duration;
8
9use domain::base::iana::{Rcode, Rtype};
10use domain::base::message::Message;
11use domain::base::message_builder::{AdditionalBuilder, MessageBuilder, StreamTarget};
12use domain::base::name::{Dname, ToDname};
13use domain::base::question::Question;
14use domain::rdata::A;
15use lru_time_cache::LruCache;
16use octseq::array::Array;
17
/// How long an entry stays in the resolver's `lookup_host` cache (ten minutes).
const DEFAULT_CACHE_EXPIRE: Duration = Duration::from_secs(600);
19
20#[cfg(not(feature = "tokio-runtime"))]
21use futures_util::{AsyncReadExt, AsyncWriteExt};
22
23#[cfg(feature = "slings-runtime")]
24use slings::{
25    net::{TcpStream, UdpSocket},
26    time::timeout,
27};
28
29#[cfg(feature = "awak-runtime")]
30use awak::{
31    net::{TcpStream, UdpSocket},
32    time::timeout,
33};
34
35#[cfg(feature = "tokio-runtime")]
36use tokio::{
37    io::{AsyncReadExt, AsyncWriteExt},
38    net::{TcpStream, UdpSocket},
39    time::timeout,
40};
41
42mod conf;
43
44pub use conf::{ResolvConf, ResolvOptions};
45use conf::{ServerConf, Transport};
46
/// Number of additional bind attempts `udp_bind` makes after the first failed
/// attempt to bind an ephemeral local UDP port, before returning the error.
const RETRY_RANDOM_PORT: usize = 10;
48
49pub struct Resolver {
50    preferred: ServerList,
51    stream: ServerList,
52    options: ResolvOptions,
53    lru_cache: Mutex<LruCache<String, Vec<IpAddr>>>,
54}
55
56impl Resolver {
57    pub fn new() -> Self {
58        Self::from_conf(ResolvConf::default())
59    }
60
61    pub fn from_conf(conf: ResolvConf) -> Self {
62        Resolver {
63            preferred: ServerList::from_conf(&conf, |s| s.transport.is_preferred()),
64            stream: ServerList::from_conf(&conf, |s| s.transport.is_stream()),
65            options: conf.options,
66            lru_cache: Mutex::new(LruCache::with_expiry_duration(DEFAULT_CACHE_EXPIRE)),
67        }
68    }
69
70    fn options(&self) -> &ResolvOptions {
71        &self.options
72    }
73
74    pub async fn query<N: ToDname, Q: Into<Question<N>>>(&self, question: Q) -> io::Result<Answer> {
75        Query::new(self)?
76            .run(Query::create_message(question.into()))
77            .await
78    }
79
80    fn try_resolve_from_cache(&self, key: &str) -> Option<Vec<IpAddr>> {
81        self.lru_cache.lock().unwrap().get(key).cloned()
82    }
83
84    fn insert_into_cache(&self, key: &str, val: Vec<IpAddr>) {
85        self.lru_cache.lock().unwrap().insert(key.to_string(), val);
86    }
87
88    pub async fn lookup_host<T: AsRef<str>>(&self, host: T) -> io::Result<Vec<IpAddr>> {
89        let host = &host.as_ref();
90        if let Some(v) = self.try_resolve_from_cache(host) {
91            return Ok(v);
92        }
93
94        let qname = &Dname::<Vec<u8>>::from_str(host)
95            .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;
96        let answer = self.query((&qname, Rtype::A)).await?;
97        let name = answer.canonical_name();
98        let records = answer
99            .answer()
100            .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?
101            .limit_to::<A>();
102
103        let mut ips = vec![];
104        for record in records.flatten() {
105            if Some(*record.owner()) == name {
106                ips.push(record.data().addr().into());
107            }
108        }
109        self.insert_into_cache(host, ips.clone());
110        Ok(ips)
111    }
112
113    pub async fn query_message(&self, message: QueryMessage) -> io::Result<Answer> {
114        Query::new(self)?.run(message).await
115    }
116}
117
118impl Default for Resolver {
119    fn default() -> Self {
120        Self::new()
121    }
122}
123
124pub struct Query<'a> {
125    resolver: &'a Resolver,
126    preferred: bool,
127    attempt: usize,
128    counter: ServerListCounter,
129    error: io::Result<Answer>,
130}
131
132impl<'a> Query<'a> {
133    pub fn new(resolver: &'a Resolver) -> io::Result<Self> {
134        let (preferred, counter) = if resolver.options().use_vc || resolver.preferred.is_empty() {
135            if resolver.stream.is_empty() {
136                return Err(io::Error::new(
137                    io::ErrorKind::NotFound,
138                    "no servers available",
139                ));
140            }
141            (false, resolver.stream.counter(resolver.options().rotate))
142        } else {
143            (true, resolver.preferred.counter(resolver.options().rotate))
144        };
145        Ok(Query {
146            resolver,
147            preferred,
148            attempt: 0,
149            counter,
150            error: Err(io::Error::new(io::ErrorKind::TimedOut, "all timed out")),
151        })
152    }
153
154    pub async fn run(mut self, mut message: QueryMessage) -> io::Result<Answer> {
155        loop {
156            match self.run_query(&mut message).await {
157                Ok(answer) => {
158                    if answer.header().rcode() == Rcode::FormErr
159                        && self.current_server().does_edns()
160                    {
161                        self.current_server().disable_edns();
162                        continue;
163                    } else if answer.header().rcode() == Rcode::ServFail {
164                        self.update_error_servfail(answer);
165                    } else if answer.header().tc()
166                        && self.preferred
167                        && !self.resolver.options().ign_tc
168                    {
169                        if self.switch_to_stream() {
170                            continue;
171                        } else {
172                            return Ok(answer);
173                        }
174                    } else {
175                        return Ok(answer);
176                    }
177                }
178                Err(err) => self.update_error(err),
179            }
180            if !self.next_server() {
181                return self.error;
182            }
183        }
184    }
185
186    fn create_message(question: Question<impl ToDname>) -> QueryMessage {
187        let mut message =
188            MessageBuilder::from_target(StreamTarget::new(Default::default()).unwrap()).unwrap();
189        message.header_mut().set_rd(true);
190        let mut message = message.question();
191        message.push(question).unwrap();
192        message.additional()
193    }
194
195    async fn run_query(&mut self, message: &mut QueryMessage) -> io::Result<Answer> {
196        let server = self.current_server();
197        server.prepare_message(message);
198        server.query(message).await
199    }
200
201    fn current_server(&self) -> &ServerInfo {
202        let list = if self.preferred {
203            &self.resolver.preferred
204        } else {
205            &self.resolver.stream
206        };
207        self.counter.info(list)
208    }
209
210    fn update_error(&mut self, err: io::Error) {
211        if err.kind() != io::ErrorKind::TimedOut && self.error.is_err() {
212            self.error = Err(err)
213        }
214    }
215
216    fn update_error_servfail(&mut self, answer: Answer) {
217        self.error = Ok(answer)
218    }
219
220    fn switch_to_stream(&mut self) -> bool {
221        if !self.preferred {
222            return false;
223        }
224        self.preferred = false;
225        self.attempt = 0;
226        self.counter = self.resolver.stream.counter(self.resolver.options().rotate);
227        true
228    }
229
230    fn next_server(&mut self) -> bool {
231        if self.counter.next() {
232            return true;
233        }
234        self.attempt += 1;
235        if self.attempt >= self.resolver.options().attempts {
236            return false;
237        }
238        self.counter = if self.preferred {
239            self.resolver
240                .preferred
241                .counter(self.resolver.options().rotate)
242        } else {
243            self.resolver.stream.counter(self.resolver.options().rotate)
244        };
245        true
246    }
247}
248
249pub type QueryMessage = AdditionalBuilder<StreamTarget<Array<512>>>;
250
251#[derive(Clone)]
252pub struct Answer {
253    message: Message<Vec<u8>>,
254}
255
256impl Answer {
257    pub fn is_final(&self) -> bool {
258        (self.message.header().rcode() == Rcode::NoError
259            || self.message.header().rcode() == Rcode::NXDomain)
260            && !self.message.header().tc()
261    }
262
263    pub fn is_truncated(&self) -> bool {
264        self.message.header().tc()
265    }
266
267    pub fn into_message(self) -> Message<Vec<u8>> {
268        self.message
269    }
270}
271
272impl From<Message<Vec<u8>>> for Answer {
273    fn from(message: Message<Vec<u8>>) -> Self {
274        Answer { message }
275    }
276}
277
278#[derive(Clone, Debug)]
279struct ServerInfo {
280    conf: ServerConf,
281    edns: Arc<AtomicBool>,
282}
283
284impl ServerInfo {
285    pub fn does_edns(&self) -> bool {
286        self.edns.load(Ordering::Relaxed)
287    }
288
289    pub fn disable_edns(&self) {
290        self.edns.store(false, Ordering::Relaxed);
291    }
292
293    pub fn prepare_message(&self, query: &mut QueryMessage) {
294        query.rewind();
295        if self.does_edns() {
296            query
297                .opt(|opt| {
298                    opt.set_udp_payload_size(self.conf.udp_payload_size);
299                    Ok(())
300                })
301                .unwrap();
302        }
303    }
304
305    pub async fn query(&self, query: &QueryMessage) -> io::Result<Answer> {
306        let res = match self.conf.transport {
307            Transport::Udp => {
308                timeout(
309                    self.conf.request_timeout,
310                    Self::udp_query(query, self.conf.addr, self.conf.recv_size),
311                )
312                .await
313            }
314            Transport::Tcp => {
315                timeout(
316                    self.conf.request_timeout,
317                    Self::tcp_query(query, self.conf.addr),
318                )
319                .await
320            }
321        };
322        match res {
323            Ok(Ok(answer)) => Ok(answer),
324            Ok(Err(err)) => Err(err),
325            Err(_) => Err(io::Error::new(io::ErrorKind::TimedOut, "request timed out")),
326        }
327    }
328
329    pub async fn tcp_query(query: &QueryMessage, addr: SocketAddr) -> io::Result<Answer> {
330        let sock = &mut TcpStream::connect(&addr).await?;
331        sock.write_all(query.as_target().as_stream_slice()).await?;
332
333        loop {
334            let mut len_buf = [0u8; 2];
335            sock.read_exact(&mut len_buf).await?;
336            let len = u16::from_be_bytes(len_buf) as u64;
337            let mut buf = Vec::new();
338            sock.take(len).read_to_end(&mut buf).await?;
339            if let Ok(answer) = Message::from_octets(buf) {
340                if answer.is_answer(&query.as_message()) {
341                    return Ok(answer.into());
342                }
343            } else {
344                return Err(io::Error::new(io::ErrorKind::Other, "short buf"));
345            }
346        }
347    }
348
349    pub async fn udp_query(
350        query: &QueryMessage,
351        addr: SocketAddr,
352        recv_size: usize,
353    ) -> io::Result<Answer> {
354        let sock = Self::udp_bind(addr.is_ipv4()).await?;
355        #[cfg(not(feature = "awak-runtime"))]
356        sock.connect(addr).await?;
357        #[cfg(feature = "awak-runtime")]
358        sock.connect(addr)?;
359        let sent = sock.send(query.as_target().as_dgram_slice()).await?;
360        if sent != query.as_target().as_dgram_slice().len() {
361            return Err(io::Error::new(io::ErrorKind::Other, "short UDP send"));
362        }
363        loop {
364            let mut buf = vec![0; recv_size];
365            let len = sock.recv(&mut buf).await?;
366            buf.truncate(len);
367            let answer = match Message::from_octets(buf) {
368                Ok(answer) => answer,
369                Err(_) => continue,
370            };
371            if !answer.is_answer(&query.as_message()) {
372                continue;
373            }
374            return Ok(answer.into());
375        }
376    }
377
378    async fn udp_bind(v4: bool) -> io::Result<UdpSocket> {
379        let mut i = 0;
380        loop {
381            let local: SocketAddr = if v4 {
382                ([0u8; 4], 0).into()
383            } else {
384                ([0u16; 8], 0).into()
385            };
386            #[cfg(feature = "tokio-runtime")]
387            let binder = UdpSocket::bind(&local).await;
388            #[cfg(not(feature = "tokio-runtime"))]
389            let binder = UdpSocket::bind(local);
390            match binder {
391                Ok(sock) => return Ok(sock),
392                Err(err) => {
393                    if i == RETRY_RANDOM_PORT {
394                        return Err(err);
395                    } else {
396                        i += 1
397                    }
398                }
399            }
400        }
401    }
402}
403
404impl From<ServerConf> for ServerInfo {
405    fn from(conf: ServerConf) -> Self {
406        ServerInfo {
407            conf,
408            edns: Arc::new(AtomicBool::new(true)),
409        }
410    }
411}
412
413impl<'a> From<&'a ServerConf> for ServerInfo {
414    fn from(conf: &'a ServerConf) -> Self {
415        conf.clone().into()
416    }
417}
418
419#[derive(Clone, Debug)]
420struct ServerList {
421    servers: Vec<ServerInfo>,
422    start: Arc<AtomicUsize>,
423}
424
425impl ServerList {
426    pub fn from_conf<F>(conf: &ResolvConf, filter: F) -> Self
427    where
428        F: Fn(&ServerConf) -> bool,
429    {
430        ServerList {
431            servers: {
432                conf.servers
433                    .iter()
434                    .filter(|f| filter(f))
435                    .map(Into::into)
436                    .collect()
437            },
438            start: Arc::new(AtomicUsize::new(0)),
439        }
440    }
441
442    pub fn is_empty(&self) -> bool {
443        self.servers.is_empty()
444    }
445
446    pub fn counter(&self, rotate: bool) -> ServerListCounter {
447        let res = ServerListCounter::new(self);
448        if rotate {
449            self.rotate()
450        }
451        res
452    }
453
454    pub fn iter(&self) -> ServerListIter {
455        ServerListIter::new(self)
456    }
457
458    pub fn rotate(&self) {
459        self.start.fetch_add(1, Ordering::SeqCst);
460    }
461}
462
463impl<'a> IntoIterator for &'a ServerList {
464    type Item = &'a ServerInfo;
465    type IntoIter = ServerListIter<'a>;
466
467    fn into_iter(self) -> Self::IntoIter {
468        self.iter()
469    }
470}
471
472impl Deref for ServerList {
473    type Target = [ServerInfo];
474
475    fn deref(&self) -> &Self::Target {
476        self.servers.as_ref()
477    }
478}
479
480#[derive(Clone, Debug)]
481struct ServerListCounter {
482    cur: usize,
483    end: usize,
484}
485
486impl ServerListCounter {
487    fn new(list: &ServerList) -> Self {
488        if list.servers.is_empty() {
489            return ServerListCounter { cur: 0, end: 0 };
490        }
491
492        let start = list.start.load(Ordering::Relaxed) % list.servers.len();
493        ServerListCounter {
494            cur: start,
495            end: start + list.servers.len(),
496        }
497    }
498
499    #[allow(clippy::should_implement_trait)]
500    pub fn next(&mut self) -> bool {
501        let next = self.cur + 1;
502        if next < self.end {
503            self.cur = next;
504            true
505        } else {
506            false
507        }
508    }
509
510    pub fn info<'a>(&self, list: &'a ServerList) -> &'a ServerInfo {
511        &list[self.cur % list.servers.len()]
512    }
513}
514
515#[derive(Clone, Debug)]
516struct ServerListIter<'a> {
517    servers: &'a ServerList,
518    counter: ServerListCounter,
519}
520
521impl<'a> ServerListIter<'a> {
522    fn new(list: &'a ServerList) -> Self {
523        ServerListIter {
524            servers: list,
525            counter: ServerListCounter::new(list),
526        }
527    }
528}
529
530impl<'a> Iterator for ServerListIter<'a> {
531    type Item = &'a ServerInfo;
532
533    fn next(&mut self) -> Option<Self::Item> {
534        if self.counter.next() {
535            Some(self.counter.info(self.servers))
536        } else {
537            None
538        }
539    }
540}
541
542impl Deref for Answer {
543    type Target = Message<Vec<u8>>;
544
545    fn deref(&self) -> &Self::Target {
546        &self.message
547    }
548}
549
550impl AsRef<Message<Vec<u8>>> for Answer {
551    fn as_ref(&self) -> &Message<Vec<u8>> {
552        &self.message
553    }
554}