1use std::future::Future;
13use std::net::IpAddr;
14use std::pin::Pin;
15use std::slice;
16use std::sync::Arc;
17use std::task::{Context, Poll};
18use std::time::Instant;
19
20use futures_util::{
21 FutureExt,
22 future::{self, BoxFuture},
23};
24use tracing::debug;
25
26use crate::cache::MAX_TTL;
27use crate::caching_client::CachingClient;
28use crate::config::LookupIpStrategy;
29use crate::hosts::Hosts;
30use crate::lookup::Lookup;
31use crate::net::NetError;
32use crate::net::xfer::DnsHandle;
33use crate::proto::op::{DnsRequestOptions, Query};
34use crate::proto::rr::{Name, RData, Record, RecordType};
35
36#[derive(Debug, Clone)]
40pub struct LookupIp(Lookup);
41
42impl LookupIp {
43 pub fn iter(&self) -> LookupIpIter<'_> {
47 LookupIpIter(self.0.answers().iter())
48 }
49
50 pub fn query(&self) -> &Query {
52 self.0.query()
53 }
54
55 pub fn valid_until(&self) -> Instant {
57 self.0.valid_until()
58 }
59
60 pub fn as_lookup(&self) -> &Lookup {
64 &self.0
65 }
66}
67
68impl From<Lookup> for LookupIp {
69 fn from(lookup: Lookup) -> Self {
70 Self(lookup)
71 }
72}
73
74impl From<LookupIp> for Lookup {
75 fn from(lookup: LookupIp) -> Self {
76 lookup.0
77 }
78}
79
80pub struct LookupIpIter<'a>(slice::Iter<'a, Record>);
82
83impl Iterator for LookupIpIter<'_> {
84 type Item = IpAddr;
85
86 fn next(&mut self) -> Option<Self::Item> {
87 self.0.find_map(|record| match record.data {
88 RData::A(ip) => Some(IpAddr::from(*ip)),
89 RData::AAAA(ip) => Some(IpAddr::from(*ip)),
90 _ => None,
91 })
92 }
93}
94
95pub struct LookupIpFuture<C: DnsHandle + 'static> {
99 client_cache: CachingClient<C>,
100 names: Vec<Name>,
101 strategy: LookupIpStrategy,
102 options: DnsRequestOptions,
103 query: BoxFuture<'static, Result<Lookup, NetError>>,
104 hosts: Arc<Hosts>,
105 finally_ip_addr: Option<RData>,
106}
107
108impl<C: DnsHandle + 'static> LookupIpFuture<C> {
109 pub fn lookup(
117 names: Vec<Name>,
118 strategy: LookupIpStrategy,
119 client_cache: CachingClient<C>,
120 options: DnsRequestOptions,
121 hosts: Arc<Hosts>,
122 finally_ip_addr: Option<RData>,
123 ) -> Self {
124 Self {
125 names,
126 strategy,
127 client_cache,
128 query: future::err("can not lookup IPs for no names".into()).boxed(),
131 options,
132 hosts,
133 finally_ip_addr,
134 }
135 }
136}
137
138impl<C: DnsHandle + 'static> Future for LookupIpFuture<C> {
139 type Output = Result<LookupIp, NetError>;
140
141 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
142 loop {
143 let query = self.query.as_mut().poll(cx);
145
146 let should_retry = match &query {
148 Poll::Pending => return Poll::Pending,
150 Poll::Ready(Ok(lookup)) => lookup.answers().is_empty(),
154 Poll::Ready(Err(_)) => true,
156 };
157
158 if !should_retry {
159 return query.map(|f| f.map(LookupIp::from));
163 }
164
165 if let Some(name) = self.names.pop() {
166 self.query = LookupContext {
169 client: self.client_cache.clone(),
170 options: self.options,
171 hosts: self.hosts.clone(),
172 }
173 .strategic_lookup(name, self.strategy)
174 .boxed();
175 continue;
178 } else if let Some(ip_addr) = self.finally_ip_addr.take() {
179 let record = Record::from_rdata(Name::new(), MAX_TTL, ip_addr);
182 let lookup = Lookup::new_with_max_ttl(Query::new(), [record]);
183 return Poll::Ready(Ok(lookup.into()));
184 }
185
186 return query.map(|f| f.map(LookupIp::from));
191 }
192 }
193}
194
195#[derive(Clone)]
196struct LookupContext<C: DnsHandle> {
197 client: CachingClient<C>,
198 options: DnsRequestOptions,
199 hosts: Arc<Hosts>,
200}
201
202impl<C: DnsHandle> LookupContext<C> {
203 async fn strategic_lookup(
205 self,
206 name: Name,
207 strategy: LookupIpStrategy,
208 ) -> Result<Lookup, NetError> {
209 match strategy {
210 LookupIpStrategy::Ipv4Only => self.ipv4_only(name).await,
211 LookupIpStrategy::Ipv6Only => self.ipv6_only(name).await,
212 LookupIpStrategy::Ipv4AndIpv6 => self.ipv4_and_ipv6(name).await,
213 LookupIpStrategy::Ipv6AndIpv4 => self.ipv6_and_ipv4(name).await,
214 LookupIpStrategy::Ipv6thenIpv4 => self.ipv6_then_ipv4(name).await,
215 LookupIpStrategy::Ipv4thenIpv6 => self.ipv4_then_ipv6(name).await,
216 }
217 }
218
219 async fn ipv4_only(&self, name: Name) -> Result<Lookup, NetError> {
221 self.hosts_lookup(Query::query(name, RecordType::A)).await
222 }
223
224 async fn ipv6_only(&self, name: Name) -> Result<Lookup, NetError> {
226 self.hosts_lookup(Query::query(name, RecordType::AAAA))
227 .await
228 }
229
230 async fn ipv4_and_ipv6(&self, name: Name) -> Result<Lookup, NetError> {
233 self.multi_lookup(name, RecordType::A, RecordType::AAAA)
234 .await
235 }
236
237 async fn ipv6_and_ipv4(&self, name: Name) -> Result<Lookup, NetError> {
240 self.multi_lookup(name, RecordType::AAAA, RecordType::A)
241 .await
242 }
243
244 async fn multi_lookup(
246 &self,
247 name: Name,
248 first_type: RecordType,
249 second_type: RecordType,
250 ) -> Result<Lookup, NetError> {
251 let joined_res = future::join(
252 self.hosts_lookup(Query::query(name.clone(), first_type)),
253 self.hosts_lookup(Query::query(name, second_type)),
254 )
255 .await;
256
257 match joined_res {
258 (Ok(first), Ok(second)) => {
259 let ips = first.append(second);
261 Ok(ips)
262 }
263 (Ok(ips), Err(e)) | (Err(e), Ok(ips)) => {
264 debug!("one of ipv4 or ipv6 lookup failed: {e}");
265 Ok(ips)
266 }
267 (Err(e1), Err(e2)) => {
268 debug!("both of ipv4 or ipv6 lookup failed e1: {e1}, e2: {e2}");
269 Err(e1)
270 }
271 }
272 }
273
274 async fn ipv6_then_ipv4(&self, name: Name) -> Result<Lookup, NetError> {
276 self.rt_then_swap(name, RecordType::AAAA, RecordType::A)
277 .await
278 }
279
280 async fn ipv4_then_ipv6(&self, name: Name) -> Result<Lookup, NetError> {
282 self.rt_then_swap(name, RecordType::A, RecordType::AAAA)
283 .await
284 }
285
286 async fn rt_then_swap(
288 &self,
289 name: Name,
290 first_type: RecordType,
291 second_type: RecordType,
292 ) -> Result<Lookup, NetError> {
293 let res = self
294 .hosts_lookup(Query::query(name.clone(), first_type))
295 .await;
296
297 match res {
298 Ok(ips) if !ips.answers().is_empty() => Ok(ips),
299 _ => self.hosts_lookup(Query::query(name, second_type)).await,
301 }
302 }
303
304 async fn hosts_lookup(&self, query: Query) -> Result<Lookup, NetError> {
306 match self.hosts.lookup_static_host(&query) {
307 Some(lookup) => Ok(lookup),
308 None => self.client.lookup(query, self.options).await,
309 }
310 }
311}
312
#[cfg(test)]
pub(crate) mod tests {
    use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
    use std::sync::{Arc, Mutex};

    use futures_executor::block_on;
    use futures_util::future;
    use futures_util::stream::{Stream, once};
    use test_support::subscribe;

    use super::*;
    use crate::net::runtime::TokioRuntimeProvider;
    use crate::net::xfer::DnsHandle;
    use crate::proto::op::{DnsRequest, DnsResponse, Message};
    use crate::proto::rr::{Name, RData, Record};

    /// DNS handle that replays canned responses, popping them from the back of `messages` and
    /// answering with [`empty`] once they run out.
    #[derive(Clone)]
    pub(crate) struct MockDnsHandle {
        messages: Arc<Mutex<Vec<Result<DnsResponse, NetError>>>>,
    }

    impl DnsHandle for MockDnsHandle {
        type Response = Pin<Box<dyn Stream<Item = Result<DnsResponse, NetError>> + Send + Unpin>>;
        type Runtime = TokioRuntimeProvider;

        fn send(&self, _: DnsRequest) -> Self::Response {
            Box::pin(once(future::ready(
                self.messages.lock().unwrap().pop().unwrap_or_else(empty),
            )))
        }
    }

    /// Response with a single root-name `A` answer for `127.0.0.1`.
    pub(crate) fn v4_message() -> Result<DnsResponse, NetError> {
        let mut message = Message::query();
        message.add_query(Query::query(Name::root(), RecordType::A));
        message.insert_answers(vec![Record::from_rdata(
            Name::root(),
            86400,
            RData::A(Ipv4Addr::LOCALHOST.into()),
        )]);

        let resp = DnsResponse::from_message(message.into_response()).unwrap();
        assert!(resp.contains_answer());
        Ok(resp)
    }

    /// Response with a single root-name `AAAA` answer for `::1`.
    pub(crate) fn v6_message() -> Result<DnsResponse, NetError> {
        let mut message = Message::query();
        message.add_query(Query::query(Name::root(), RecordType::AAAA));
        message.insert_answers(vec![Record::from_rdata(
            Name::root(),
            86400,
            RData::AAAA(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1).into()),
        )]);

        let resp = DnsResponse::from_message(message.into_response()).unwrap();
        assert!(resp.contains_answer());
        Ok(resp)
    }

    /// Response with no queries and no answers.
    pub(crate) fn empty() -> Result<DnsResponse, NetError> {
        Ok(DnsResponse::from_message(Message::query().into_response()).unwrap())
    }

    /// Transport-level failure.
    pub(crate) fn error() -> Result<DnsResponse, NetError> {
        Err(NetError::from("forced test failure"))
    }

    /// Builds a [`MockDnsHandle`]; note that `messages` are consumed from the back.
    pub(crate) fn mock(messages: Vec<Result<DnsResponse, NetError>>) -> MockDnsHandle {
        MockDnsHandle {
            messages: Arc::new(Mutex::new(messages)),
        }
    }

    const V4_LOCALHOST: IpAddr = IpAddr::V4(Ipv4Addr::LOCALHOST);
    const V6_LOCALHOST: IpAddr = IpAddr::V6(Ipv6Addr::LOCALHOST);

    /// Lookup context with default options and an empty hosts table, backed by a mock handle
    /// replaying `messages`.
    fn context(messages: Vec<Result<DnsResponse, NetError>>) -> LookupContext<MockDnsHandle> {
        LookupContext {
            client: CachingClient::new(0, mock(messages), false),
            options: DnsRequestOptions::default(),
            hosts: Arc::new(Hosts::default()),
        }
    }

    /// Unwraps a lookup result and returns the IP address of every answer record, in order.
    fn answer_ips(result: Result<Lookup, NetError>) -> Vec<IpAddr> {
        result
            .unwrap()
            .answers()
            .iter()
            .map(|r| r.data.ip_addr().unwrap())
            .collect()
    }

    /// `Ipv4Only` [`LookupIpFuture`] over `names`, backed by a mock handle replaying `messages`.
    fn ip_future(
        names: Vec<Name>,
        messages: Vec<Result<DnsResponse, NetError>>,
        finally_ip_addr: Option<RData>,
    ) -> LookupIpFuture<MockDnsHandle> {
        LookupIpFuture::lookup(
            names,
            LookupIpStrategy::Ipv4Only,
            CachingClient::new(0, mock(messages), false),
            DnsRequestOptions::default(),
            Arc::new(Hosts::default()),
            finally_ip_addr,
        )
    }

    #[test]
    fn test_ipv4_only_strategy() {
        subscribe();

        let cx = context(vec![v4_message()]);
        assert_eq!(
            answer_ips(block_on(cx.ipv4_only(Name::root()))),
            vec![V4_LOCALHOST]
        );
    }

    #[test]
    fn test_ipv6_only_strategy() {
        subscribe();

        let cx = context(vec![v6_message()]);
        assert_eq!(
            answer_ips(block_on(cx.ipv6_only(Name::root()))),
            vec![V6_LOCALHOST]
        );
    }

    #[test]
    fn test_ipv4_and_ipv6_strategy() {
        subscribe();

        // The mock pops from the back, so the last message answers the first (A) query.
        let cx = context(vec![v6_message(), v4_message()]);
        assert_eq!(
            answer_ips(block_on(cx.ipv4_and_ipv6(Name::root()))),
            vec![V4_LOCALHOST, V6_LOCALHOST]
        );

        let cx = context(vec![empty(), v4_message()]);
        assert_eq!(
            answer_ips(block_on(cx.ipv4_and_ipv6(Name::root()))),
            vec![V4_LOCALHOST]
        );

        let cx = context(vec![error(), v4_message()]);
        assert_eq!(
            answer_ips(block_on(cx.ipv4_and_ipv6(Name::root()))),
            vec![V4_LOCALHOST]
        );

        let cx = context(vec![v6_message(), empty()]);
        assert_eq!(
            answer_ips(block_on(cx.ipv4_and_ipv6(Name::root()))),
            vec![V6_LOCALHOST]
        );

        let cx = context(vec![v6_message(), error()]);
        assert_eq!(
            answer_ips(block_on(cx.ipv4_and_ipv6(Name::root()))),
            vec![V6_LOCALHOST]
        );
    }

    #[test]
    fn test_ipv6_and_ipv4_strategy() {
        subscribe();

        let cx = context(vec![v4_message(), v6_message()]);
        assert_eq!(
            answer_ips(block_on(cx.ipv6_and_ipv4(Name::root()))),
            vec![V6_LOCALHOST, V4_LOCALHOST]
        );

        let cx = context(vec![v4_message(), empty()]);
        assert_eq!(
            answer_ips(block_on(cx.ipv6_and_ipv4(Name::root()))),
            vec![V4_LOCALHOST]
        );

        let cx = context(vec![v4_message(), error()]);
        assert_eq!(
            answer_ips(block_on(cx.ipv6_and_ipv4(Name::root()))),
            vec![V4_LOCALHOST]
        );

        let cx = context(vec![empty(), v6_message()]);
        assert_eq!(
            answer_ips(block_on(cx.ipv6_and_ipv4(Name::root()))),
            vec![V6_LOCALHOST]
        );

        let cx = context(vec![error(), v6_message()]);
        assert_eq!(
            answer_ips(block_on(cx.ipv6_and_ipv4(Name::root()))),
            vec![V6_LOCALHOST]
        );
    }

    #[test]
    fn test_ipv6_then_ipv4_strategy() {
        subscribe();

        let cx = context(vec![v6_message()]);
        assert_eq!(
            answer_ips(block_on(cx.ipv6_then_ipv4(Name::root()))),
            vec![V6_LOCALHOST]
        );

        let cx = context(vec![v4_message(), empty()]);
        assert_eq!(
            answer_ips(block_on(cx.ipv6_then_ipv4(Name::root()))),
            vec![V4_LOCALHOST]
        );

        let cx = context(vec![v4_message(), error()]);
        assert_eq!(
            answer_ips(block_on(cx.ipv6_then_ipv4(Name::root()))),
            vec![V4_LOCALHOST]
        );
    }

    #[test]
    fn test_ipv4_then_ipv6_strategy() {
        subscribe();

        let cx = context(vec![v4_message()]);
        assert_eq!(
            answer_ips(block_on(cx.ipv4_then_ipv6(Name::root()))),
            vec![V4_LOCALHOST]
        );

        let cx = context(vec![v6_message(), empty()]);
        assert_eq!(
            answer_ips(block_on(cx.ipv4_then_ipv6(Name::root()))),
            vec![V6_LOCALHOST]
        );

        let cx = context(vec![v6_message(), error()]);
        assert_eq!(
            answer_ips(block_on(cx.ipv4_then_ipv6(Name::root()))),
            vec![V6_LOCALHOST]
        );
    }

    #[test]
    fn test_lookup_ip_future_without_names_errors() {
        subscribe();

        assert!(block_on(ip_future(vec![], vec![], None)).is_err());
    }

    #[test]
    fn test_lookup_ip_future_without_names_uses_fallback() {
        subscribe();

        let fallback = Ipv4Addr::new(10, 0, 0, 1);
        let lookup = block_on(ip_future(vec![], vec![], Some(RData::A(fallback.into())))).unwrap();
        assert_eq!(
            lookup.iter().collect::<Vec<IpAddr>>(),
            vec![IpAddr::V4(fallback)]
        );
    }

    #[test]
    fn test_lookup_ip_future_retries_next_name() {
        subscribe();

        // Names are consumed from the back: `example.com.` is tried first and hits the error,
        // then the root name is answered by `v4_message`.
        let names = vec![Name::root(), Name::from_ascii("example.com.").unwrap()];
        let lookup = block_on(ip_future(names, vec![v4_message(), error()], None)).unwrap();
        assert_eq!(lookup.iter().collect::<Vec<IpAddr>>(), vec![V4_LOCALHOST]);
    }

    #[test]
    fn test_lookup_ip_future_falls_back_after_failures() {
        subscribe();

        let fallback = Ipv4Addr::new(10, 0, 0, 1);
        let lookup = block_on(ip_future(
            vec![Name::root()],
            vec![error()],
            Some(RData::A(fallback.into())),
        ))
        .unwrap();
        assert_eq!(
            lookup.iter().collect::<Vec<IpAddr>>(),
            vec![IpAddr::V4(fallback)]
        );
    }

    #[test]
    fn test_lookup_ip_future_prefers_answers_over_fallback() {
        subscribe();

        let lookup = block_on(ip_future(
            vec![Name::root()],
            vec![v4_message()],
            Some(RData::A(Ipv4Addr::new(10, 0, 0, 1).into())),
        ))
        .unwrap();
        assert_eq!(lookup.iter().collect::<Vec<IpAddr>>(), vec![V4_LOCALHOST]);
    }
}