dns_resolver/
lib.rs

1use std::io;
2use std::net::{IpAddr, SocketAddr};
3use std::ops::Deref;
4use std::str::FromStr;
5use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
6use std::sync::{Arc, Mutex};
7use std::time::Duration;
8
9use domain::base::iana::{Rcode, Rtype};
10use domain::base::message::Message;
11use domain::base::message_builder::{AdditionalBuilder, MessageBuilder, StreamTarget};
12use domain::base::name::{Name, ToName};
13use domain::base::question::Question;
14use domain::rdata::A;
15use lru_time_cache::LruCache;
16
/// How long a `Resolver::lookup_host` result stays in the cache: ten minutes.
const DEFAULT_CACHE_EXPIRE: Duration = Duration::from_secs(600);
18
19cfg_if::cfg_if! {
20    if #[cfg(feature = "slings-runtime")] {
21        use slings::{
22            net::{TcpStream, UdpSocket},
23            time::timeout,
24        };
25        use futures_util::{AsyncReadExt, AsyncWriteExt};
26    }
27    else if #[cfg(feature = "awak-runtime")] {
28        use awak::{
29            net::{TcpStream, UdpSocket},
30            time::timeout,
31        };
32        use futures_util::{AsyncReadExt, AsyncWriteExt};
33    }
34    else if #[cfg(feature = "tokio-runtime")] {
35        use tokio::{
36            io::{AsyncReadExt, AsyncWriteExt},
37            net::{TcpStream, UdpSocket},
38            time::timeout,
39        };
40    }
41}
42
43mod conf;
44
45pub use conf::{ResolvConf, ResolvOptions};
46use conf::{ServerConf, Transport};
47
/// Number of additional bind attempts `ServerInfo::udp_bind` makes after a
/// failed bind before returning the last error (at most
/// `RETRY_RANDOM_PORT + 1` binds in total).
const RETRY_RANDOM_PORT: usize = 10;
49
50pub struct Resolver {
51    preferred: ServerList,
52    stream: ServerList,
53    options: ResolvOptions,
54    lru_cache: Mutex<LruCache<String, Vec<IpAddr>>>,
55}
56
57impl Resolver {
58    pub fn new() -> Self {
59        Self::from_conf(ResolvConf::default())
60    }
61
62    pub fn from_conf(conf: ResolvConf) -> Self {
63        Resolver {
64            preferred: ServerList::from_conf(&conf, |s| s.transport.is_preferred()),
65            stream: ServerList::from_conf(&conf, |s| s.transport.is_stream()),
66            options: conf.options,
67            lru_cache: Mutex::new(LruCache::with_expiry_duration(DEFAULT_CACHE_EXPIRE)),
68        }
69    }
70
71    fn options(&self) -> &ResolvOptions {
72        &self.options
73    }
74
75    pub async fn query<N: ToName, Q: Into<Question<N>>>(&self, question: Q) -> io::Result<Answer> {
76        Query::new(self)?
77            .run(Query::create_message(question.into()))
78            .await
79    }
80
81    fn try_resolve_from_cache(&self, key: &str) -> Option<Vec<IpAddr>> {
82        self.lru_cache.lock().unwrap().get(key).cloned()
83    }
84
85    fn insert_into_cache(&self, key: &str, val: Vec<IpAddr>) {
86        self.lru_cache.lock().unwrap().insert(key.to_string(), val);
87    }
88
89    pub async fn lookup_host<T: AsRef<str>>(&self, host: T) -> io::Result<Vec<IpAddr>> {
90        let host = &host.as_ref();
91        if let Some(v) = self.try_resolve_from_cache(host) {
92            return Ok(v);
93        }
94
95        let qname = &Name::<Vec<u8>>::from_str(host)
96            .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;
97        let answer = self.query((&qname, Rtype::A)).await?;
98        let name = answer.canonical_name();
99        let records = answer
100            .answer()
101            .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?
102            .limit_to::<A>();
103
104        let mut ips = vec![];
105        for record in records.flatten() {
106            if Some(*record.owner()) == name {
107                ips.push(record.data().addr().into());
108            }
109        }
110        self.insert_into_cache(host, ips.clone());
111        Ok(ips)
112    }
113
114    pub async fn query_message(&self, message: QueryMessage) -> io::Result<Answer> {
115        Query::new(self)?.run(message).await
116    }
117}
118
119impl Default for Resolver {
120    fn default() -> Self {
121        Self::new()
122    }
123}
124
125pub struct Query<'a> {
126    resolver: &'a Resolver,
127    preferred: bool,
128    attempt: usize,
129    counter: ServerListCounter,
130    error: io::Result<Answer>,
131}
132
133impl<'a> Query<'a> {
134    pub fn new(resolver: &'a Resolver) -> io::Result<Self> {
135        let (preferred, counter) = if resolver.options().use_vc || resolver.preferred.is_empty() {
136            if resolver.stream.is_empty() {
137                return Err(io::Error::new(
138                    io::ErrorKind::NotFound,
139                    "no servers available",
140                ));
141            }
142            (false, resolver.stream.counter(resolver.options().rotate))
143        } else {
144            (true, resolver.preferred.counter(resolver.options().rotate))
145        };
146        Ok(Query {
147            resolver,
148            preferred,
149            attempt: 0,
150            counter,
151            error: Err(io::Error::new(io::ErrorKind::TimedOut, "all timed out")),
152        })
153    }
154
155    pub async fn run(mut self, mut message: QueryMessage) -> io::Result<Answer> {
156        loop {
157            match self.run_query(&mut message).await {
158                Ok(answer) => {
159                    if answer.header().rcode() == Rcode::FORMERR
160                        && self.current_server().does_edns()
161                    {
162                        self.current_server().disable_edns();
163                        continue;
164                    } else if answer.header().rcode() == Rcode::SERVFAIL {
165                        self.update_error_servfail(answer);
166                    } else if answer.header().tc()
167                        && self.preferred
168                        && !self.resolver.options().ign_tc
169                    {
170                        if self.switch_to_stream() {
171                            continue;
172                        } else {
173                            return Ok(answer);
174                        }
175                    } else {
176                        return Ok(answer);
177                    }
178                }
179                Err(err) => self.update_error(err),
180            }
181            if !self.next_server() {
182                return self.error;
183            }
184        }
185    }
186
187    fn create_message(question: Question<impl ToName>) -> QueryMessage {
188        let mut message =
189            MessageBuilder::from_target(StreamTarget::new_vec()).unwrap();
190        message.header_mut().set_rd(true);
191        let mut message = message.question();
192        message.push(question).unwrap();
193        message.additional()
194    }
195
196    async fn run_query(&mut self, message: &mut QueryMessage) -> io::Result<Answer> {
197        let server = self.current_server();
198        server.prepare_message(message);
199        server.query(message).await
200    }
201
202    fn current_server(&self) -> &ServerInfo {
203        let list = if self.preferred {
204            &self.resolver.preferred
205        } else {
206            &self.resolver.stream
207        };
208        self.counter.info(list)
209    }
210
211    fn update_error(&mut self, err: io::Error) {
212        if err.kind() != io::ErrorKind::TimedOut && self.error.is_err() {
213            self.error = Err(err)
214        }
215    }
216
217    fn update_error_servfail(&mut self, answer: Answer) {
218        self.error = Ok(answer)
219    }
220
221    fn switch_to_stream(&mut self) -> bool {
222        if !self.preferred {
223            return false;
224        }
225        self.preferred = false;
226        self.attempt = 0;
227        self.counter = self.resolver.stream.counter(self.resolver.options().rotate);
228        true
229    }
230
231    fn next_server(&mut self) -> bool {
232        if self.counter.next() {
233            return true;
234        }
235        self.attempt += 1;
236        if self.attempt >= self.resolver.options().attempts {
237            return false;
238        }
239        self.counter = if self.preferred {
240            self.resolver
241                .preferred
242                .counter(self.resolver.options().rotate)
243        } else {
244            self.resolver.stream.counter(self.resolver.options().rotate)
245        };
246        true
247    }
248}
249
250pub type QueryMessage = AdditionalBuilder<StreamTarget<Vec<u8>>>;
251
252#[derive(Clone)]
253pub struct Answer {
254    message: Message<Vec<u8>>,
255}
256
257impl Answer {
258    pub fn is_final(&self) -> bool {
259        (self.message.header().rcode() == Rcode::NOERROR
260            || self.message.header().rcode() == Rcode::NXDOMAIN)
261            && !self.message.header().tc()
262    }
263
264    pub fn is_truncated(&self) -> bool {
265        self.message.header().tc()
266    }
267
268    pub fn into_message(self) -> Message<Vec<u8>> {
269        self.message
270    }
271}
272
273impl From<Message<Vec<u8>>> for Answer {
274    fn from(message: Message<Vec<u8>>) -> Self {
275        Answer { message }
276    }
277}
278
279#[derive(Clone, Debug)]
280struct ServerInfo {
281    conf: ServerConf,
282    edns: Arc<AtomicBool>,
283}
284
285impl ServerInfo {
286    pub fn does_edns(&self) -> bool {
287        self.edns.load(Ordering::Relaxed)
288    }
289
290    pub fn disable_edns(&self) {
291        self.edns.store(false, Ordering::Relaxed);
292    }
293
294    pub fn prepare_message(&self, query: &mut QueryMessage) {
295        query.rewind();
296        if self.does_edns() {
297            query
298                .opt(|opt| {
299                    opt.set_udp_payload_size(self.conf.udp_payload_size);
300                    Ok(())
301                })
302                .unwrap();
303        }
304    }
305
306    pub async fn query(&self, query: &QueryMessage) -> io::Result<Answer> {
307        let res = match self.conf.transport {
308            Transport::Udp => {
309                timeout(
310                    self.conf.request_timeout,
311                    Self::udp_query(query, self.conf.addr, self.conf.recv_size),
312                )
313                .await
314            }
315            Transport::Tcp => {
316                timeout(
317                    self.conf.request_timeout,
318                    Self::tcp_query(query, self.conf.addr),
319                )
320                .await
321            }
322        };
323        match res {
324            Ok(Ok(answer)) => Ok(answer),
325            Ok(Err(err)) => Err(err),
326            Err(_) => Err(io::Error::new(io::ErrorKind::TimedOut, "request timed out")),
327        }
328    }
329
330    pub async fn tcp_query(query: &QueryMessage, addr: SocketAddr) -> io::Result<Answer> {
331        let sock = &mut TcpStream::connect(&addr).await?;
332        sock.write_all(query.as_target().as_stream_slice()).await?;
333
334        loop {
335            let mut len_buf = [0u8; 2];
336            sock.read_exact(&mut len_buf).await?;
337            let len = u16::from_be_bytes(len_buf) as u64;
338            let mut buf = Vec::new();
339            sock.take(len).read_to_end(&mut buf).await?;
340            if let Ok(answer) = Message::from_octets(buf) {
341                if answer.is_answer(&query.as_message()) {
342                    return Ok(answer.into());
343                }
344            } else {
345                return Err(io::Error::new(io::ErrorKind::Other, "short buf"));
346            }
347        }
348    }
349
350    pub async fn udp_query(
351        query: &QueryMessage,
352        addr: SocketAddr,
353        recv_size: usize,
354    ) -> io::Result<Answer> {
355        let sock = Self::udp_bind(addr.is_ipv4()).await?;
356        #[cfg(not(feature = "awak-runtime"))]
357        sock.connect(addr).await?;
358        #[cfg(feature = "awak-runtime")]
359        sock.connect(addr)?;
360        let sent = sock.send(query.as_target().as_dgram_slice()).await?;
361        if sent != query.as_target().as_dgram_slice().len() {
362            return Err(io::Error::new(io::ErrorKind::Other, "short UDP send"));
363        }
364        loop {
365            let mut buf = vec![0; recv_size];
366            let len = sock.recv(&mut buf).await?;
367            buf.truncate(len);
368            let answer = match Message::from_octets(buf) {
369                Ok(answer) => answer,
370                Err(_) => continue,
371            };
372            if !answer.is_answer(&query.as_message()) {
373                continue;
374            }
375            return Ok(answer.into());
376        }
377    }
378
379    async fn udp_bind(v4: bool) -> io::Result<UdpSocket> {
380        let mut i = 0;
381        loop {
382            let local: SocketAddr = if v4 {
383                ([0u8; 4], 0).into()
384            } else {
385                ([0u16; 8], 0).into()
386            };
387            #[cfg(feature = "tokio-runtime")]
388            let binder = UdpSocket::bind(&local).await;
389            #[cfg(not(feature = "tokio-runtime"))]
390            let binder = UdpSocket::bind(local);
391            match binder {
392                Ok(sock) => return Ok(sock),
393                Err(err) => {
394                    if i == RETRY_RANDOM_PORT {
395                        return Err(err);
396                    } else {
397                        i += 1
398                    }
399                }
400            }
401        }
402    }
403}
404
405impl From<ServerConf> for ServerInfo {
406    fn from(conf: ServerConf) -> Self {
407        ServerInfo {
408            conf,
409            edns: Arc::new(AtomicBool::new(true)),
410        }
411    }
412}
413
414impl<'a> From<&'a ServerConf> for ServerInfo {
415    fn from(conf: &'a ServerConf) -> Self {
416        conf.clone().into()
417    }
418}
419
420#[derive(Clone, Debug)]
421struct ServerList {
422    servers: Vec<ServerInfo>,
423    start: Arc<AtomicUsize>,
424}
425
426impl ServerList {
427    pub fn from_conf<F>(conf: &ResolvConf, filter: F) -> Self
428    where
429        F: Fn(&ServerConf) -> bool,
430    {
431        ServerList {
432            servers: {
433                conf.servers
434                    .iter()
435                    .filter(|f| filter(f))
436                    .map(Into::into)
437                    .collect()
438            },
439            start: Arc::new(AtomicUsize::new(0)),
440        }
441    }
442
443    pub fn is_empty(&self) -> bool {
444        self.servers.is_empty()
445    }
446
447    pub fn counter(&self, rotate: bool) -> ServerListCounter {
448        let res = ServerListCounter::new(self);
449        if rotate {
450            self.rotate()
451        }
452        res
453    }
454
455    pub fn iter(&self) -> ServerListIter<'_> {
456        ServerListIter::new(self)
457    }
458
459    pub fn rotate(&self) {
460        self.start.fetch_add(1, Ordering::SeqCst);
461    }
462}
463
464impl<'a> IntoIterator for &'a ServerList {
465    type Item = &'a ServerInfo;
466    type IntoIter = ServerListIter<'a>;
467
468    fn into_iter(self) -> Self::IntoIter {
469        self.iter()
470    }
471}
472
473impl Deref for ServerList {
474    type Target = [ServerInfo];
475
476    fn deref(&self) -> &Self::Target {
477        self.servers.as_ref()
478    }
479}
480
481#[derive(Clone, Debug)]
482struct ServerListCounter {
483    cur: usize,
484    end: usize,
485}
486
487impl ServerListCounter {
488    fn new(list: &ServerList) -> Self {
489        if list.servers.is_empty() {
490            return ServerListCounter { cur: 0, end: 0 };
491        }
492
493        let start = list.start.load(Ordering::Relaxed) % list.servers.len();
494        ServerListCounter {
495            cur: start,
496            end: start + list.servers.len(),
497        }
498    }
499
500    #[allow(clippy::should_implement_trait)]
501    pub fn next(&mut self) -> bool {
502        let next = self.cur + 1;
503        if next < self.end {
504            self.cur = next;
505            true
506        } else {
507            false
508        }
509    }
510
511    pub fn info<'a>(&self, list: &'a ServerList) -> &'a ServerInfo {
512        &list[self.cur % list.servers.len()]
513    }
514}
515
516#[derive(Clone, Debug)]
517struct ServerListIter<'a> {
518    servers: &'a ServerList,
519    counter: ServerListCounter,
520}
521
522impl<'a> ServerListIter<'a> {
523    fn new(list: &'a ServerList) -> Self {
524        ServerListIter {
525            servers: list,
526            counter: ServerListCounter::new(list),
527        }
528    }
529}
530
531impl<'a> Iterator for ServerListIter<'a> {
532    type Item = &'a ServerInfo;
533
534    fn next(&mut self) -> Option<Self::Item> {
535        if self.counter.next() {
536            Some(self.counter.info(self.servers))
537        } else {
538            None
539        }
540    }
541}
542
543impl Deref for Answer {
544    type Target = Message<Vec<u8>>;
545
546    fn deref(&self) -> &Self::Target {
547        &self.message
548    }
549}
550
551impl AsRef<Message<Vec<u8>>> for Answer {
552    fn as_ref(&self) -> &Message<Vec<u8>> {
553        &self.message
554    }
555}