domain_resolv/stub/
mod.rs

//! A stub resolver.
//!
//! The simplest possible resolver relays all messages to one of a set of
//! pre-configured resolvers that do the actual work. This is equivalent to
//! what the resolver part of the C library does. This module provides such a
//! stub resolver that emulates the C resolver as closely as possible, in
//! particular in the way it is configured.
//!
//! The main type is [`StubResolver`], which implements the [`Resolver`] trait
//! and thus can be used with the various lookup functions.
11
12use std::{io, ops};
13use std::future::Future;
14use std::net::{IpAddr, SocketAddr};
15use std::pin::Pin;
16use std::sync::Arc;
17use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
18use std::vec::Vec;
19use bytes::Bytes;
20use futures::future::FutureExt;
21#[cfg(feature = "sync")] use tokio::runtime;
22use tokio::io::{AsyncReadExt, AsyncWriteExt};
23use tokio::net::{TcpStream, UdpSocket};
24use tokio::time::timeout;
25use domain::base::iana::Rcode;
26use domain::base::message::Message;
27use domain::base::message_builder::{
28    AdditionalBuilder, MessageBuilder, StreamTarget
29};
30use domain::base::name::{ToDname, ToRelativeDname};
31use domain::base::octets::Octets512;
32use domain::base::question::Question;
33use crate::lookup::addr::{lookup_addr, FoundAddrs};
34use crate::lookup::host::{lookup_host, search_host, FoundHosts};
35use crate::lookup::srv::{lookup_srv, FoundSrvs, SrvError};
36use crate::resolver::{Resolver, SearchNames};
37use self::conf::{
38    ResolvConf, ResolvOptions, SearchSuffix, ServerConf, Transport
39};
40
41
42//------------ Sub-modules ---------------------------------------------------
43
44pub mod conf;
45
46
47//------------ Module Configuration ------------------------------------------
48
/// How many times `ServerInfo::udp_bind` retries binding a UDP socket to a
/// new OS-chosen port after a failed attempt (e.g. because the address is
/// already in use). The initial attempt is not counted.
const RETRY_RANDOM_PORT: usize = 10;
51
52
53//------------ StubResolver --------------------------------------------------
54
/// A DNS stub resolver.
///
/// This type collects all information necessary to start DNS queries. You
/// can create a new resolver from the system’s configuration with the
/// [`new()`] associated function, or from your own configuration with
/// [`from_conf()`].
///
/// Cloning a stub resolver copies its configuration. The per-server EDNS
/// state and the rotation position of the server lists are kept behind
/// `Arc`s and are therefore shared between clones.
///
/// If you want to run a single query or lookup on a resolver synchronously,
/// you can do so by using the [`run()`] or [`run_with_conf()`] associated
/// functions (available with the `sync` feature).
///
/// [`new()`]: #method.new
/// [`from_conf()`]: #method.from_conf
/// [`run()`]: #method.run
/// [`run_with_conf()`]: #method.run_with_conf
#[derive(Clone, Debug)]
pub struct StubResolver {
    /// Preferred servers.
    ///
    /// These are the configured servers whose transport reports
    /// `is_preferred()`. A query starts with them unless the `use_vc` option
    /// is set or the list is empty.
    preferred: ServerList,

    /// Streaming servers.
    ///
    /// These are the configured servers whose transport reports
    /// `is_stream()`. They are used when the preferred list is not, and after
    /// a truncated answer from a preferred server.
    stream: ServerList,

    /// Resolver options.
    options: ResolvOptions,
}
85
86
87impl StubResolver {
88    /// Creates a new resolver using the system’s default configuration.
89    pub fn new() -> Self {
90        Self::from_conf(ResolvConf::default())
91    }
92
93    /// Creates a new resolver using the given configuraiton.
94    pub fn from_conf(conf: ResolvConf) -> Self {
95        StubResolver {
96            preferred: ServerList::from_conf(&conf, |s| {
97                s.transport.is_preferred()
98            }),
99            stream: ServerList::from_conf(&conf, |s| {
100                s.transport.is_stream()
101            }),
102            options: conf.options
103        }
104    }
105
106    pub fn options(&self) -> &ResolvOptions {
107        &self.options
108    }
109
110    pub async fn query<N: ToDname, Q: Into<Question<N>>>(
111        &self, question: Q
112    ) -> Result<Answer, io::Error> {
113        Query::new(self)?.run(
114            Query::create_message(question.into())
115        ).await
116    }
117
118    async fn query_message(
119        &self, message: QueryMessage
120    ) -> Result<Answer, io::Error> {
121        Query::new(self)?.run(message).await
122    }
123}
124
// Convenience wrappers around the generic lookup functions.
//
// `Resolver` is implemented for `&StubResolver`, so the lookup functions are
// handed `&self` (a `&&StubResolver`), and the returned results borrow the
// resolver (`&Self`).
impl StubResolver {
    /// Looks up the host names for `addr` via
    /// [`crate::lookup::addr::lookup_addr`].
    pub async fn lookup_addr(
        &self, addr: IpAddr
    ) -> Result<FoundAddrs<&Self>, io::Error> {
        lookup_addr(&self, addr).await
    }

    /// Looks up the addresses of the absolute name `qname` via
    /// [`crate::lookup::host::lookup_host`].
    pub async fn lookup_host(
        &self, qname: impl ToDname
    ) -> Result<FoundHosts<&Self>, io::Error> {
        lookup_host(&self, qname).await
    }

    /// Looks up the addresses of the relative name `qname` via
    /// [`crate::lookup::host::search_host`] (which presumably tries it with
    /// the resolver's search suffixes — see `SearchNames` below).
    pub async fn search_host(
        &self, qname: impl ToRelativeDname
    ) -> Result<FoundHosts<&Self>, io::Error> {
        search_host(&self, qname).await
    }

    /// Looks up SRV records for `service` under `name` via
    /// [`crate::lookup::srv::lookup_srv`]. `fallback_port` is passed through
    /// unchanged.
    pub async fn lookup_srv(
        &self,
        service: impl ToRelativeDname,
        name: impl ToDname,
        fallback_port: u16
    ) -> Result<Option<FoundSrvs>, SrvError> {
        lookup_srv(&self, service, name, fallback_port).await
    }
}
153
#[cfg(feature = "sync")]
impl StubResolver {
    /// Synchronously performs a DNS operation atop a standard resolver.
    ///
    /// This associated function removes almost all boilerplate for the case
    /// that you want to perform some DNS operation, either a query or a
    /// lookup, on a resolver using the system’s configuration and wait for
    /// the result.
    ///
    /// The only argument is a closure taking a `StubResolver` by value and
    /// returning a future. Whatever that future resolves to is returned.
    ///
    /// # Panics
    ///
    /// Panics if the Tokio runtime used to drive the future cannot be
    /// created.
    pub fn run<R, F>(op: F) -> R::Output
    where
        R: Future + Send + 'static,
        R::Output: Send + 'static,
        F: FnOnce(StubResolver) -> R + Send + 'static,
    {
        Self::run_with_conf(ResolvConf::default(), op)
    }

    /// Synchronously performs a DNS operation atop a configured resolver.
    ///
    /// This is like [`run()`] but also takes a resolver configuration for
    /// tailor-making your own resolver.
    ///
    /// # Panics
    ///
    /// Panics if the Tokio runtime used to drive the future cannot be
    /// created.
    ///
    /// [`run()`]: #method.run
    pub fn run_with_conf<R, F>(
        conf: ResolvConf,
        op: F
    ) -> R::Output
    where
        R: Future + Send + 'static,
        R::Output: Send + 'static,
        F: FnOnce(StubResolver) -> R + Send + 'static,
    {
        let resolver = Self::from_conf(conf);
        // `enable_all` turns on the I/O and time drivers, which the socket
        // and timeout code of the resolver relies on.
        let mut runtime = runtime::Builder::new()
            .basic_scheduler()
            .enable_all()
            .build()
            .expect("failed to create Tokio runtime for StubResolver::run");
        runtime.block_on(op(resolver))
    }
}
198
199impl Default for StubResolver {
200    fn default() -> Self {
201        Self::new()
202    }
203}
204
// The trait is implemented on references so the returned query futures can
// borrow the resolver for `'a` rather than owning a clone of it.
impl<'a> Resolver for &'a StubResolver {
    type Octets = Bytes;
    type Answer = Answer;
    // Boxed so the otherwise unnameable `async fn` future can be named here.
    type Query = Pin<Box<dyn Future<Output = Result<Answer, io::Error>> + 'a>>;

    /// Starts a query for `question`; see `StubResolver::query`.
    fn query<N, Q>(&self, question: Q) -> Self::Query
    where N: ToDname, Q: Into<Question<N>> {
        // Building the message here means the boxed future only captures the
        // resolver and the finished message, not `question` or `N`.
        let message = Query::create_message(question.into());
        self.query_message(message).boxed()
    }
}
216
217impl<'a> SearchNames for &'a StubResolver {
218    type Name = SearchSuffix;
219    type Iter = SearchIter<'a>;
220
221    fn search_iter(&self) -> Self::Iter {
222        SearchIter {
223            resolver: self,
224            pos: 0
225        }
226    }
227}
228
229
230//------------ Query ---------------------------------------------------------
231
/// The state of a single query run on behalf of a [`StubResolver`].
///
/// It tracks which server list and which server in it is currently used,
/// how many rounds through the list have been made, and the best error seen
/// so far.
pub struct Query<'a> {
    /// The resolver whose configuration we are using.
    resolver: &'a StubResolver,

    /// Whether we are still using the preferred server list (`true`) or have
    /// switched to the stream servers (`false`).
    preferred: bool,

    /// The number of completed rounds through the current server list,
    /// starting at zero.
    attempt: usize,

    /// The position in the current server list that we are trying.
    counter: ServerListCounter,

    /// The preferred error to return.
    ///
    /// Every time we finish a single query, we see if we can update this with
    /// a better one. If we finally have to fail, we return this value. It is
    /// a `Result` so we can return a SERVFAIL answer if that is the only
    /// answer we get. (Remember, SERVFAIL is also returned for a bogus answer,
    /// which the caller may want to know about.)
    error: Result<Answer, io::Error>,
}
254
255impl<'a> Query<'a> {
256    pub fn new(
257        resolver: &'a StubResolver,
258    ) -> Result<Self, io::Error> {
259        let (preferred, counter) = if
260            resolver.options().use_vc ||
261            resolver.preferred.is_empty()
262        {
263            if resolver.stream.is_empty() {
264                return Err(
265                    io::Error::new(
266                        io::ErrorKind::NotFound,
267                        "no servers available"
268                    )
269                )
270            }
271            (false, resolver.stream.counter(resolver.options().rotate))
272        }
273        else {
274            (true, resolver.preferred.counter(resolver.options().rotate))
275        };
276        Ok(Query {
277            resolver,
278            preferred,
279            attempt: 0,
280            counter,
281            error: Err(io::Error::new(
282                io::ErrorKind::TimedOut,
283                "all timed out"
284            ))
285        })
286    }
287
288    pub async fn run(
289        mut self,
290        mut message: QueryMessage,
291    ) -> Result<Answer, io::Error> {
292        loop {
293            match self.run_query(&mut message).await {
294                Ok(answer) => {
295                    if answer.header().rcode() == Rcode::FormErr
296                        && self.current_server().does_edns()
297                    {
298                        // FORMERR with EDNS: turn off EDNS and try again.
299                        self.current_server().disable_edns();
300                        continue
301                    }
302                    else if answer.header().rcode() == Rcode::ServFail {
303                        // SERVFAIL: go to next server.
304                        self.update_error_servfail(answer);
305                    }
306                    else if answer.header().tc() && self.preferred
307                        && !self.resolver.options().ign_tc
308                    {
309                        // Truncated. If we can, switch to stream transports
310                        // and try again. Otherwise return the truncated
311                        // answer.
312                        if self.switch_to_stream() {
313                            continue
314                        }
315                        else {
316                            return Ok(answer)
317                        }
318                    }
319                    else {
320                        // I guess we have an answer ...
321                        return Ok(answer);
322                    }
323                }
324                Err(err) => self.update_error(err),
325            }
326            if !self.next_server() {
327                return self.error
328            }
329        }
330    }
331
332    fn create_message(
333        question: Question<impl ToDname>
334    ) -> QueryMessage {
335        let mut message = MessageBuilder::from_target(
336            StreamTarget::new(Octets512::new()).unwrap()
337        ).unwrap();
338        message.header_mut().set_rd(true);
339        let mut message = message.question();
340        message.push(question).unwrap();
341        message.additional()
342    }
343
344    async fn run_query(
345        &mut self, message: &mut QueryMessage
346    ) -> Result<Answer, io::Error> {
347        let server = self.current_server();
348        server.prepare_message(message);
349        server.query(message).await
350    }
351
352    fn current_server(&self) -> &ServerInfo {
353        let list = if self.preferred { &self.resolver.preferred }
354                   else { &self.resolver.stream };
355        self.counter.info(list)
356    }
357
358    fn update_error(&mut self, err: io::Error) {
359        // We keep the last error except for timeouts or if we have a servfail
360        // answer already. Since we start with a timeout, we still get a that
361        // if everything times out.
362        if err.kind() != io::ErrorKind::TimedOut && self.error.is_err() {
363            self.error = Err(err)
364        }
365    }
366
367    fn update_error_servfail(&mut self, answer: Answer) {
368        self.error = Ok(answer)
369    }
370
371    fn switch_to_stream(&mut self) -> bool {
372        if !self.preferred {
373            // We already did this.
374            return false
375        }
376        self.preferred = false;
377        self.attempt = 0;
378        self.counter = self.resolver.stream.counter(
379            self.resolver.options().rotate
380        );
381        true
382    }
383
384    fn next_server(&mut self) -> bool {
385        if self.counter.next() {
386            return true
387        }
388        self.attempt += 1;
389        if self.attempt >= self.resolver.options().attempts {
390            return false
391        }
392        self.counter = if self.preferred {
393            self.resolver.preferred.counter(self.resolver.options().rotate)
394        }
395        else {
396            self.resolver.stream.counter(self.resolver.options().rotate)
397        };
398        true
399    }
400}
401
402
403//------------ QueryMessage --------------------------------------------------
404
// XXX This needs to be re-evaluated if we start adding (more) options to the
//     OPT record of the query, since the message lives in an `Octets512`.
/// The message builder used for outgoing queries.
///
/// It is a stream target over an `Octets512` buffer, positioned at the
/// additional section so an OPT record can be appended per server.
pub(super) type QueryMessage = AdditionalBuilder<StreamTarget<Octets512>>;
407
408
409//------------ Answer --------------------------------------------------------
410
411/// The answer to a question.
412///
413/// This type is a wrapper around the DNS [`Message`] containing the answer
414/// that provides some additional information.
415#[derive(Clone)]
416pub struct Answer {
417    message: Message<Bytes>,
418}
419
420impl Answer {
421    /// Returns whether the answer is a final answer to be returned.
422    pub fn is_final(&self) -> bool {
423        (self.message.header().rcode() == Rcode::NoError
424            || self.message.header().rcode() == Rcode::NXDomain)
425        && !self.message.header().tc()
426    }
427
428    /// Returns whether the answer is truncated.
429    pub fn is_truncated(&self) -> bool {
430        self.message.header().tc()
431    }
432
433    pub fn into_message(self) -> Message<Bytes> {
434        self.message
435    }
436}
437
438impl From<Message<Bytes>> for Answer {
439    fn from(message: Message<Bytes>) -> Self {
440        Answer { message }
441    }
442}
443
444
445//------------ ServerInfo ----------------------------------------------------
446
/// A configured server together with what we have learned about it.
#[derive(Clone, Debug)]
struct ServerInfo {
    /// The basic server configuration.
    conf: ServerConf,

    /// Whether this server supports EDNS.
    ///
    /// We start out assuming it does and unset it if we get a FORMERR. The
    /// flag sits behind an `Arc`, so clones of this value share it.
    edns: Arc<AtomicBool>,
}
457
458impl ServerInfo {
459    pub fn does_edns(&self) -> bool {
460        self.edns.load(Ordering::Relaxed)
461    }
462        
463    pub fn disable_edns(&self) {
464        self.edns.store(false, Ordering::Relaxed);
465    }
466
467    pub fn prepare_message(&self, query: &mut QueryMessage) {
468        query.rewind();
469        if self.does_edns() {
470            query.opt(|opt| {
471                opt.set_udp_payload_size(self.conf.udp_payload_size);
472                Ok(())
473            }).unwrap();
474        }
475    }
476
477    pub async fn query(
478        &self, query: &QueryMessage
479    ) -> Result<Answer, io::Error> {
480        let res = match self.conf.transport {
481            Transport::Udp => {
482               timeout(
483                   self.conf.request_timeout,
484                   Self::udp_query(query, self.conf.addr, self.conf.recv_size)
485                ).await
486            }
487            Transport::Tcp => {
488               timeout(
489                   self.conf.request_timeout,
490                   Self::tcp_query(query, self.conf.addr)
491                ).await
492            }
493        };
494        match res {
495            Ok(Ok(answer)) => Ok(answer),
496            Ok(Err(err)) => Err(err),
497            Err(_) => {
498                Err(io::Error::new(
499                    io::ErrorKind::TimedOut,
500                    "request timed out"
501                ))
502            }
503        }
504    }
505
506    pub async fn tcp_query(
507        query: &QueryMessage, addr: SocketAddr
508    ) -> Result<Answer, io::Error> {
509        let mut sock = TcpStream::connect(&addr).await?;
510        sock.write_all(query.as_target().as_stream_slice()).await?;
511
512        // This loop can be infinite because we have a timeout on this whole
513        // thing, anyway.
514        loop {
515            let mut buf = Vec::new();
516            let len = sock.read_u16().await? as u64;
517            AsyncReadExt::take(&mut sock, len).read_to_end(&mut buf).await?;
518            if let Ok(answer) = Message::from_octets(buf.into()) {
519                if answer.is_answer(&query.as_message()) {
520                    return Ok(answer.into())
521                }
522                // else try with the next message.
523            }
524            else {
525                return Err(io::Error::new(io::ErrorKind::Other, "short buf"))
526            }
527        }
528    }
529
530    pub async fn udp_query(
531        query: &QueryMessage, addr: SocketAddr, recv_size: usize
532    ) -> Result<Answer, io::Error> {
533        let mut sock = Self::udp_bind(addr.is_ipv4()).await?;
534        sock.connect(addr).await?;
535        let sent = sock.send(query.as_target().as_dgram_slice()).await?;
536        if sent != query.as_target().as_dgram_slice().len() {
537            return Err(io::Error::new(io::ErrorKind::Other, "short UDP send"))
538        }
539        loop {
540            let mut buf = vec![0; recv_size]; // XXX use uninit'ed mem here.
541            let len = sock.recv(&mut buf).await?;
542            buf.truncate(len);
543            
544            // We ignore garbage since there is a timer on this whole thing.
545            let answer = match Message::from_octets(buf.into()) {
546                Ok(answer) => answer,
547                Err(_) => continue,
548            };
549            if !answer.is_answer(&query.as_message()) {
550                continue
551            }
552            return Ok(answer.into())
553        }
554    }
555
556    async fn udp_bind(v4: bool) -> Result<UdpSocket, io::Error> {
557        let mut i = 0;
558        loop {
559            let local: SocketAddr = if v4 { ([0u8; 4], 0).into() }
560                        else { ([0u16; 8], 0).into() };
561            match UdpSocket::bind(&local).await {
562                Ok(sock) => return Ok(sock),
563                Err(err) => {
564                    if i == RETRY_RANDOM_PORT {
565                        return Err(err);
566                    }
567                    else {
568                        i += 1
569                    }
570                }
571            }
572        }
573    }
574}
575
576impl From<ServerConf> for ServerInfo {
577    fn from(conf: ServerConf) -> Self {
578        ServerInfo {
579            conf,
580            edns: Arc::new(AtomicBool::new(true))
581        }
582    }
583}
584
585impl<'a> From<&'a ServerConf> for ServerInfo {
586    fn from(conf: &'a ServerConf) -> Self {
587        conf.clone().into()
588    }
589}
590
591
592//------------ ServerList ----------------------------------------------------
593
/// An ordered list of servers together with a shared rotation position.
#[derive(Clone, Debug)]
struct ServerList {
    /// The actual list of servers.
    servers: Vec<ServerInfo>,

    /// Where to start accessing the list.
    ///
    /// In rotate mode, this value keeps growing and has to be used modulo the
    /// length of `servers`. It lives behind an `Arc`, so clones of the list
    /// share it.
    ///
    /// When it eventually wraps around the end of usize’s range, there will
    /// be a jump in rotation. Since that happens only very rarely, we accept
    /// that in favour of simpler code.
    start: Arc<AtomicUsize>,
}
609
610impl ServerList {
611    pub fn from_conf<F>(conf: &ResolvConf, filter: F) -> Self
612    where F: Fn(&ServerConf) -> bool {
613        ServerList {
614            servers: {
615                conf.servers.iter().filter(|f| filter(*f))
616                    .map(Into::into).collect()
617            },
618            start: Arc::new(AtomicUsize::new(0)),
619        }
620    }
621
622    pub fn is_empty(&self) -> bool {
623        self.servers.is_empty()
624    }
625
626    pub fn counter(&self, rotate: bool) -> ServerListCounter {
627        let res = ServerListCounter::new(self);
628        if rotate {
629            self.rotate()
630        }
631        res
632    }
633
634    pub fn iter(&self) -> ServerListIter {
635        ServerListIter::new(self)
636    }
637
638    pub fn rotate(&self) {
639        self.start.fetch_add(1, Ordering::SeqCst);
640    }
641}
642
643impl<'a> IntoIterator for &'a ServerList {
644    type Item = &'a ServerInfo;
645    type IntoIter = ServerListIter<'a>;
646
647    fn into_iter(self) -> Self::IntoIter {
648        self.iter()
649    }
650}
651
652impl ops::Deref for ServerList {
653    type Target = [ServerInfo];
654
655    fn deref(&self) -> &Self::Target {
656        self.servers.as_ref()
657    }
658}
659
660
661//------------ ServerListCounter ---------------------------------------------
662
663#[derive(Clone, Debug)]
664struct ServerListCounter {
665    cur: usize,
666    end: usize,
667}
668
669impl ServerListCounter {
670    fn new(list: &ServerList) -> Self {
671        if list.servers.is_empty() {
672            return ServerListCounter { cur: 0, end: 0 };
673        }
674
675        // We modulo the start value here to prevent hick-ups towards the
676        // end of usize’s range.
677        let start = list.start.load(Ordering::Relaxed) % list.servers.len();
678        ServerListCounter {
679            cur: start,
680            end: start + list.servers.len(),
681        }
682    }
683
684    #[allow(clippy::should_implement_trait)]
685    pub fn next(&mut self) -> bool {
686        let next = self.cur + 1;
687        if next < self.end {
688            self.cur = next;
689            true
690        }
691        else {
692            false
693        }
694    }
695
696    pub fn info<'a>(&self, list: &'a ServerList) -> &'a ServerInfo {
697        &list[self.cur % list.servers.len()]
698    }
699}
700
701
702
703//------------ ServerListIter ------------------------------------------------
704
705#[derive(Clone, Debug)]
706struct ServerListIter<'a> {
707    servers: &'a ServerList,
708    counter: ServerListCounter,
709}
710
711impl<'a> ServerListIter<'a> {
712    fn new(list: &'a ServerList) -> Self {
713        ServerListIter {
714            servers: list,
715            counter: ServerListCounter::new(list)
716        }
717    }
718}
719
720impl<'a> Iterator for ServerListIter<'a> {
721    type Item = &'a ServerInfo;
722
723    fn next(&mut self) -> Option<Self::Item> {
724        if self.counter.next() {
725            Some(self.counter.info(self.servers))
726        }
727        else {
728            None
729        }
730    }
731}
732
733
734impl ops::Deref for Answer {
735    type Target = Message<Bytes>;
736
737    fn deref(&self) -> &Self::Target {
738        &self.message
739    }
740}
741
742impl AsRef<Message<Bytes>> for Answer {
743    fn as_ref(&self) -> &Message<Bytes> {
744        &self.message
745    }
746}
747
748
749//------------ SearchIter ----------------------------------------------------
750
/// An iterator over the search suffixes configured in a resolver's options.
#[derive(Clone, Debug)]
pub struct SearchIter<'a> {
    /// The resolver whose options provide the suffixes.
    resolver: &'a StubResolver,

    /// Index of the next suffix to return.
    pos: usize,
}
756
757impl<'a> Iterator for SearchIter<'a> {
758    type Item = SearchSuffix;
759
760    fn next(&mut self) -> Option<Self::Item> {
761        if let Some(res) = self.resolver.options().search.get(self.pos) {
762            self.pos += 1;
763            Some(res.clone())
764        }
765        else {
766            None
767        }
768    }
769}
770
771