1use std::{io, ops};
13use std::future::Future;
14use std::net::{IpAddr, SocketAddr};
15use std::pin::Pin;
16use std::sync::Arc;
17use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
18use std::vec::Vec;
19use bytes::Bytes;
20use futures::future::FutureExt;
21#[cfg(feature = "sync")] use tokio::runtime;
22use tokio::io::{AsyncReadExt, AsyncWriteExt};
23use tokio::net::{TcpStream, UdpSocket};
24use tokio::time::timeout;
25use domain::base::iana::Rcode;
26use domain::base::message::Message;
27use domain::base::message_builder::{
28 AdditionalBuilder, MessageBuilder, StreamTarget
29};
30use domain::base::name::{ToDname, ToRelativeDname};
31use domain::base::octets::Octets512;
32use domain::base::question::Question;
33use crate::lookup::addr::{lookup_addr, FoundAddrs};
34use crate::lookup::host::{lookup_host, search_host, FoundHosts};
35use crate::lookup::srv::{lookup_srv, FoundSrvs, SrvError};
36use crate::resolver::{Resolver, SearchNames};
37use self::conf::{
38 ResolvConf, ResolvOptions, SearchSuffix, ServerConf, Transport
39};
40
41
42pub mod conf;
45
46
/// How many times `ServerInfo::udp_bind` retries binding a local UDP socket
/// after a failed attempt before giving up and returning the last error.
///
/// The total number of bind attempts is therefore this value plus one.
const RETRY_RANDOM_PORT: usize = 10;
51
52
/// A DNS stub resolver that forwards queries to configured upstream servers.
///
/// Cloning a resolver clones its server lists, but the rotation counter of
/// each list and the EDNS flag of each server live behind `Arc`s and are
/// therefore shared between clones.
#[derive(Clone, Debug)]
pub struct StubResolver {
    /// Servers whose transport reports `is_preferred()`; tried first.
    preferred: ServerList,

    /// Servers whose transport reports `is_stream()`; used when the
    /// `use_vc` option is set, when there are no preferred servers, or
    /// after a truncated answer.
    stream: ServerList,

    /// The resolver options taken from the configuration.
    options: ResolvOptions,
}
85
86
87impl StubResolver {
88 pub fn new() -> Self {
90 Self::from_conf(ResolvConf::default())
91 }
92
93 pub fn from_conf(conf: ResolvConf) -> Self {
95 StubResolver {
96 preferred: ServerList::from_conf(&conf, |s| {
97 s.transport.is_preferred()
98 }),
99 stream: ServerList::from_conf(&conf, |s| {
100 s.transport.is_stream()
101 }),
102 options: conf.options
103 }
104 }
105
106 pub fn options(&self) -> &ResolvOptions {
107 &self.options
108 }
109
110 pub async fn query<N: ToDname, Q: Into<Question<N>>>(
111 &self, question: Q
112 ) -> Result<Answer, io::Error> {
113 Query::new(self)?.run(
114 Query::create_message(question.into())
115 ).await
116 }
117
118 async fn query_message(
119 &self, message: QueryMessage
120 ) -> Result<Answer, io::Error> {
121 Query::new(self)?.run(message).await
122 }
123}
124
125impl StubResolver {
126 pub async fn lookup_addr(
127 &self, addr: IpAddr
128 ) -> Result<FoundAddrs<&Self>, io::Error> {
129 lookup_addr(&self, addr).await
130 }
131
132 pub async fn lookup_host(
133 &self, qname: impl ToDname
134 ) -> Result<FoundHosts<&Self>, io::Error> {
135 lookup_host(&self, qname).await
136 }
137
138 pub async fn search_host(
139 &self, qname: impl ToRelativeDname
140 ) -> Result<FoundHosts<&Self>, io::Error> {
141 search_host(&self, qname).await
142 }
143
144 pub async fn lookup_srv(
145 &self,
146 service: impl ToRelativeDname,
147 name: impl ToDname,
148 fallback_port: u16
149 ) -> Result<Option<FoundSrvs>, SrvError> {
150 lookup_srv(&self, service, name, fallback_port).await
151 }
152}
153
#[cfg(feature = "sync")]
impl StubResolver {
    /// Runs the future produced by `op` to completion on a freshly built
    /// Tokio runtime, handing `op` a resolver with the default configuration.
    ///
    /// # Panics
    ///
    /// Panics if the runtime cannot be built.
    pub fn run<R, F>(op: F) -> R::Output
    where
        R: Future + Send + 'static,
        R::Output: Send + 'static,
        F: FnOnce(StubResolver) -> R + Send + 'static,
    {
        let conf = ResolvConf::default();
        Self::run_with_conf(conf, op)
    }

    /// Runs the future produced by `op` to completion on a freshly built
    /// Tokio runtime (basic scheduler, all drivers enabled), handing `op` a
    /// resolver created from `conf`.
    ///
    /// # Panics
    ///
    /// Panics if the runtime cannot be built.
    pub fn run_with_conf<R, F>(
        conf: ResolvConf,
        op: F
    ) -> R::Output
    where
        R: Future + Send + 'static,
        R::Output: Send + 'static,
        F: FnOnce(StubResolver) -> R + Send + 'static,
    {
        let resolver = Self::from_conf(conf);
        let mut rt = runtime::Builder::new()
            .basic_scheduler()
            .enable_all()
            .build()
            .unwrap();
        let fut = op(resolver);
        rt.block_on(fut)
    }
}
198
199impl Default for StubResolver {
200 fn default() -> Self {
201 Self::new()
202 }
203}
204
205impl<'a> Resolver for &'a StubResolver {
206 type Octets = Bytes;
207 type Answer = Answer;
208 type Query = Pin<Box<dyn Future<Output = Result<Answer, io::Error>> + 'a>>;
209
210 fn query<N, Q>(&self, question: Q) -> Self::Query
211 where N: ToDname, Q: Into<Question<N>> {
212 let message = Query::create_message(question.into());
213 self.query_message(message).boxed()
214 }
215}
216
217impl<'a> SearchNames for &'a StubResolver {
218 type Name = SearchSuffix;
219 type Iter = SearchIter<'a>;
220
221 fn search_iter(&self) -> Self::Iter {
222 SearchIter {
223 resolver: self,
224 pos: 0
225 }
226 }
227}
228
229
/// The state of a single query while it is run against the configured
/// servers.
pub struct Query<'a> {
    /// The resolver providing the server lists and options.
    resolver: &'a StubResolver,

    /// Whether the preferred server list is in use; `false` means the
    /// stream server list.
    preferred: bool,

    /// How many complete passes over the current server list have been
    /// made; reset when switching to the stream list.
    attempt: usize,

    /// The position within the current server list.
    counter: ServerListCounter,

    /// The result returned if every server fails: initially an
    /// "all timed out" error, replaced by later non-timeout errors or by
    /// a SERVFAIL answer.
    error: Result<Answer, io::Error>,
}
254
255impl<'a> Query<'a> {
256 pub fn new(
257 resolver: &'a StubResolver,
258 ) -> Result<Self, io::Error> {
259 let (preferred, counter) = if
260 resolver.options().use_vc ||
261 resolver.preferred.is_empty()
262 {
263 if resolver.stream.is_empty() {
264 return Err(
265 io::Error::new(
266 io::ErrorKind::NotFound,
267 "no servers available"
268 )
269 )
270 }
271 (false, resolver.stream.counter(resolver.options().rotate))
272 }
273 else {
274 (true, resolver.preferred.counter(resolver.options().rotate))
275 };
276 Ok(Query {
277 resolver,
278 preferred,
279 attempt: 0,
280 counter,
281 error: Err(io::Error::new(
282 io::ErrorKind::TimedOut,
283 "all timed out"
284 ))
285 })
286 }
287
288 pub async fn run(
289 mut self,
290 mut message: QueryMessage,
291 ) -> Result<Answer, io::Error> {
292 loop {
293 match self.run_query(&mut message).await {
294 Ok(answer) => {
295 if answer.header().rcode() == Rcode::FormErr
296 && self.current_server().does_edns()
297 {
298 self.current_server().disable_edns();
300 continue
301 }
302 else if answer.header().rcode() == Rcode::ServFail {
303 self.update_error_servfail(answer);
305 }
306 else if answer.header().tc() && self.preferred
307 && !self.resolver.options().ign_tc
308 {
309 if self.switch_to_stream() {
313 continue
314 }
315 else {
316 return Ok(answer)
317 }
318 }
319 else {
320 return Ok(answer);
322 }
323 }
324 Err(err) => self.update_error(err),
325 }
326 if !self.next_server() {
327 return self.error
328 }
329 }
330 }
331
332 fn create_message(
333 question: Question<impl ToDname>
334 ) -> QueryMessage {
335 let mut message = MessageBuilder::from_target(
336 StreamTarget::new(Octets512::new()).unwrap()
337 ).unwrap();
338 message.header_mut().set_rd(true);
339 let mut message = message.question();
340 message.push(question).unwrap();
341 message.additional()
342 }
343
344 async fn run_query(
345 &mut self, message: &mut QueryMessage
346 ) -> Result<Answer, io::Error> {
347 let server = self.current_server();
348 server.prepare_message(message);
349 server.query(message).await
350 }
351
352 fn current_server(&self) -> &ServerInfo {
353 let list = if self.preferred { &self.resolver.preferred }
354 else { &self.resolver.stream };
355 self.counter.info(list)
356 }
357
358 fn update_error(&mut self, err: io::Error) {
359 if err.kind() != io::ErrorKind::TimedOut && self.error.is_err() {
363 self.error = Err(err)
364 }
365 }
366
367 fn update_error_servfail(&mut self, answer: Answer) {
368 self.error = Ok(answer)
369 }
370
371 fn switch_to_stream(&mut self) -> bool {
372 if !self.preferred {
373 return false
375 }
376 self.preferred = false;
377 self.attempt = 0;
378 self.counter = self.resolver.stream.counter(
379 self.resolver.options().rotate
380 );
381 true
382 }
383
384 fn next_server(&mut self) -> bool {
385 if self.counter.next() {
386 return true
387 }
388 self.attempt += 1;
389 if self.attempt >= self.resolver.options().attempts {
390 return false
391 }
392 self.counter = if self.preferred {
393 self.resolver.preferred.counter(self.resolver.options().rotate)
394 }
395 else {
396 self.resolver.stream.counter(self.resolver.options().rotate)
397 };
398 true
399 }
400}
401
402
/// The builder type used for outgoing queries.
///
/// It is positioned at the additional section, so an OPT record can be added
/// per server. It writes into an `Octets512` buffer wrapped in a
/// `StreamTarget`, so the same message can be sent as a datagram or with
/// stream length framing.
pub(super) type QueryMessage = AdditionalBuilder<StreamTarget<Octets512>>;
407
408
/// The answer to a query.
///
/// A thin wrapper around the response message; it dereferences to
/// `Message<Bytes>`.
#[derive(Clone)]
pub struct Answer {
    /// The response message as received from the server.
    message: Message<Bytes>,
}
419
420impl Answer {
421 pub fn is_final(&self) -> bool {
423 (self.message.header().rcode() == Rcode::NoError
424 || self.message.header().rcode() == Rcode::NXDomain)
425 && !self.message.header().tc()
426 }
427
428 pub fn is_truncated(&self) -> bool {
430 self.message.header().tc()
431 }
432
433 pub fn into_message(self) -> Message<Bytes> {
434 self.message
435 }
436}
437
438impl From<Message<Bytes>> for Answer {
439 fn from(message: Message<Bytes>) -> Self {
440 Answer { message }
441 }
442}
443
444
/// A configured server together with its runtime state.
#[derive(Clone, Debug)]
struct ServerInfo {
    /// The server's configuration.
    conf: ServerConf,

    /// Whether EDNS is used with this server. Held in an `Arc`, so clones
    /// share it and disabling EDNS on one affects all of them.
    edns: Arc<AtomicBool>,
}
457
458impl ServerInfo {
459 pub fn does_edns(&self) -> bool {
460 self.edns.load(Ordering::Relaxed)
461 }
462
463 pub fn disable_edns(&self) {
464 self.edns.store(false, Ordering::Relaxed);
465 }
466
467 pub fn prepare_message(&self, query: &mut QueryMessage) {
468 query.rewind();
469 if self.does_edns() {
470 query.opt(|opt| {
471 opt.set_udp_payload_size(self.conf.udp_payload_size);
472 Ok(())
473 }).unwrap();
474 }
475 }
476
477 pub async fn query(
478 &self, query: &QueryMessage
479 ) -> Result<Answer, io::Error> {
480 let res = match self.conf.transport {
481 Transport::Udp => {
482 timeout(
483 self.conf.request_timeout,
484 Self::udp_query(query, self.conf.addr, self.conf.recv_size)
485 ).await
486 }
487 Transport::Tcp => {
488 timeout(
489 self.conf.request_timeout,
490 Self::tcp_query(query, self.conf.addr)
491 ).await
492 }
493 };
494 match res {
495 Ok(Ok(answer)) => Ok(answer),
496 Ok(Err(err)) => Err(err),
497 Err(_) => {
498 Err(io::Error::new(
499 io::ErrorKind::TimedOut,
500 "request timed out"
501 ))
502 }
503 }
504 }
505
506 pub async fn tcp_query(
507 query: &QueryMessage, addr: SocketAddr
508 ) -> Result<Answer, io::Error> {
509 let mut sock = TcpStream::connect(&addr).await?;
510 sock.write_all(query.as_target().as_stream_slice()).await?;
511
512 loop {
515 let mut buf = Vec::new();
516 let len = sock.read_u16().await? as u64;
517 AsyncReadExt::take(&mut sock, len).read_to_end(&mut buf).await?;
518 if let Ok(answer) = Message::from_octets(buf.into()) {
519 if answer.is_answer(&query.as_message()) {
520 return Ok(answer.into())
521 }
522 }
524 else {
525 return Err(io::Error::new(io::ErrorKind::Other, "short buf"))
526 }
527 }
528 }
529
530 pub async fn udp_query(
531 query: &QueryMessage, addr: SocketAddr, recv_size: usize
532 ) -> Result<Answer, io::Error> {
533 let mut sock = Self::udp_bind(addr.is_ipv4()).await?;
534 sock.connect(addr).await?;
535 let sent = sock.send(query.as_target().as_dgram_slice()).await?;
536 if sent != query.as_target().as_dgram_slice().len() {
537 return Err(io::Error::new(io::ErrorKind::Other, "short UDP send"))
538 }
539 loop {
540 let mut buf = vec![0; recv_size]; let len = sock.recv(&mut buf).await?;
542 buf.truncate(len);
543
544 let answer = match Message::from_octets(buf.into()) {
546 Ok(answer) => answer,
547 Err(_) => continue,
548 };
549 if !answer.is_answer(&query.as_message()) {
550 continue
551 }
552 return Ok(answer.into())
553 }
554 }
555
556 async fn udp_bind(v4: bool) -> Result<UdpSocket, io::Error> {
557 let mut i = 0;
558 loop {
559 let local: SocketAddr = if v4 { ([0u8; 4], 0).into() }
560 else { ([0u16; 8], 0).into() };
561 match UdpSocket::bind(&local).await {
562 Ok(sock) => return Ok(sock),
563 Err(err) => {
564 if i == RETRY_RANDOM_PORT {
565 return Err(err);
566 }
567 else {
568 i += 1
569 }
570 }
571 }
572 }
573 }
574}
575
576impl From<ServerConf> for ServerInfo {
577 fn from(conf: ServerConf) -> Self {
578 ServerInfo {
579 conf,
580 edns: Arc::new(AtomicBool::new(true))
581 }
582 }
583}
584
585impl<'a> From<&'a ServerConf> for ServerInfo {
586 fn from(conf: &'a ServerConf) -> Self {
587 conf.clone().into()
588 }
589}
590
591
/// An ordered list of servers with a shared rotation offset.
#[derive(Clone, Debug)]
struct ServerList {
    /// The servers in configuration order.
    servers: Vec<ServerInfo>,

    /// The rotation offset. Incremented by `rotate` and reduced modulo the
    /// list length by `ServerListCounter::new`. Held in an `Arc`, so clones
    /// of the list share it.
    start: Arc<AtomicUsize>,
}
609
610impl ServerList {
611 pub fn from_conf<F>(conf: &ResolvConf, filter: F) -> Self
612 where F: Fn(&ServerConf) -> bool {
613 ServerList {
614 servers: {
615 conf.servers.iter().filter(|f| filter(*f))
616 .map(Into::into).collect()
617 },
618 start: Arc::new(AtomicUsize::new(0)),
619 }
620 }
621
622 pub fn is_empty(&self) -> bool {
623 self.servers.is_empty()
624 }
625
626 pub fn counter(&self, rotate: bool) -> ServerListCounter {
627 let res = ServerListCounter::new(self);
628 if rotate {
629 self.rotate()
630 }
631 res
632 }
633
634 pub fn iter(&self) -> ServerListIter {
635 ServerListIter::new(self)
636 }
637
638 pub fn rotate(&self) {
639 self.start.fetch_add(1, Ordering::SeqCst);
640 }
641}
642
643impl<'a> IntoIterator for &'a ServerList {
644 type Item = &'a ServerInfo;
645 type IntoIter = ServerListIter<'a>;
646
647 fn into_iter(self) -> Self::IntoIter {
648 self.iter()
649 }
650}
651
652impl ops::Deref for ServerList {
653 type Target = [ServerInfo];
654
655 fn deref(&self) -> &Self::Target {
656 self.servers.as_ref()
657 }
658}
659
660
/// A position within a `ServerList` that walks every server once.
#[derive(Clone, Debug)]
struct ServerListCounter {
    /// The current position. It may exceed the list length; `info` reduces
    /// it modulo the length.
    cur: usize,
    /// One past the last position to visit.
    end: usize,
}
668
669impl ServerListCounter {
670 fn new(list: &ServerList) -> Self {
671 if list.servers.is_empty() {
672 return ServerListCounter { cur: 0, end: 0 };
673 }
674
675 let start = list.start.load(Ordering::Relaxed) % list.servers.len();
678 ServerListCounter {
679 cur: start,
680 end: start + list.servers.len(),
681 }
682 }
683
684 #[allow(clippy::should_implement_trait)]
685 pub fn next(&mut self) -> bool {
686 let next = self.cur + 1;
687 if next < self.end {
688 self.cur = next;
689 true
690 }
691 else {
692 false
693 }
694 }
695
696 pub fn info<'a>(&self, list: &'a ServerList) -> &'a ServerInfo {
697 &list[self.cur % list.servers.len()]
698 }
699}
700
701
702
/// An iterator over the servers of a `ServerList`, starting at the list's
/// rotation offset.
#[derive(Clone, Debug)]
struct ServerListIter<'a> {
    /// The list being iterated over.
    servers: &'a ServerList,
    /// The position of the next server to yield.
    counter: ServerListCounter,
}
710
711impl<'a> ServerListIter<'a> {
712 fn new(list: &'a ServerList) -> Self {
713 ServerListIter {
714 servers: list,
715 counter: ServerListCounter::new(list)
716 }
717 }
718}
719
720impl<'a> Iterator for ServerListIter<'a> {
721 type Item = &'a ServerInfo;
722
723 fn next(&mut self) -> Option<Self::Item> {
724 if self.counter.next() {
725 Some(self.counter.info(self.servers))
726 }
727 else {
728 None
729 }
730 }
731}
732
733
734impl ops::Deref for Answer {
735 type Target = Message<Bytes>;
736
737 fn deref(&self) -> &Self::Target {
738 &self.message
739 }
740}
741
742impl AsRef<Message<Bytes>> for Answer {
743 fn as_ref(&self) -> &Message<Bytes> {
744 &self.message
745 }
746}
747
748
/// An iterator over the search suffixes in a resolver's options.
#[derive(Clone, Debug)]
pub struct SearchIter<'a> {
    /// The resolver whose `search` option is iterated over.
    resolver: &'a StubResolver,
    /// Index of the next suffix to yield.
    pos: usize,
}
756
757impl<'a> Iterator for SearchIter<'a> {
758 type Item = SearchSuffix;
759
760 fn next(&mut self) -> Option<Self::Item> {
761 if let Some(res) = self.resolver.options().search.get(self.pos) {
762 self.pos += 1;
763 Some(res.clone())
764 }
765 else {
766 None
767 }
768 }
769}
770
771