use std::future::Future;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
use std::pin::Pin;
use std::slice;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Instant;
use futures_util::{
FutureExt,
future::{self, BoxFuture},
};
use tracing::debug;
use crate::cache::MAX_TTL;
use crate::caching_client::CachingClient;
use crate::config::LookupIpStrategy;
use crate::hosts::Hosts;
use crate::lookup::Lookup;
use crate::net::NetError;
use crate::net::xfer::DnsHandle;
use crate::proto::op::{DnsRequestOptions, Query};
use crate::proto::rr::{Name, RData, Record, RecordType};
/// The result of resolving a name to IP addresses.
///
/// Wraps a [`Lookup`]; [`LookupIp::iter`] yields only the addresses carried by
/// `A` and `AAAA` answer records, skipping any other record types.
#[derive(Debug, Clone)]
pub struct LookupIp(Lookup);

impl LookupIp {
    /// Returns an iterator over the IPv4 and IPv6 addresses in the answer records.
    pub fn iter(&self) -> LookupIpIter<'_> {
        let answers = self.as_lookup().answers();
        LookupIpIter(answers.iter())
    }

    /// Returns the query that produced this lookup.
    pub fn query(&self) -> &Query {
        let Self(inner) = self;
        inner.query()
    }

    /// Returns the `valid_until` instant reported by the underlying [`Lookup`].
    pub fn valid_until(&self) -> Instant {
        let Self(inner) = self;
        inner.valid_until()
    }

    /// Borrows the underlying [`Lookup`], including any non-address records.
    pub fn as_lookup(&self) -> &Lookup {
        let Self(inner) = self;
        inner
    }
}
impl From<Lookup> for LookupIp {
fn from(lookup: Lookup) -> Self {
Self(lookup)
}
}
impl From<LookupIp> for Lookup {
fn from(lookup: LookupIp) -> Self {
lookup.0
}
}
/// Borrowing iterator over the IP addresses in a [`LookupIp`].
///
/// Yields the address of every `A` and `AAAA` record, in order; records of any
/// other type are skipped.
pub struct LookupIpIter<'a>(slice::Iter<'a, Record>);

impl Iterator for LookupIpIter<'_> {
    type Item = IpAddr;

    fn next(&mut self) -> Option<Self::Item> {
        self.0.find_map(|record| match record.data() {
            RData::A(ip) => Some(IpAddr::from(Ipv4Addr::from(*ip))),
            RData::AAAA(ip) => Some(IpAddr::from(Ipv6Addr::from(*ip))),
            _ => None,
        })
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        // Any remaining record may be a non-address record that gets skipped, so
        // the lower bound is 0; at most one address is yielded per remaining record.
        (0, self.0.size_hint().1)
    }
}

// The inner slice iterator is fused, and `find_map` keeps returning `None` once it
// is exhausted, so this iterator never resumes after yielding `None`.
impl std::iter::FusedIterator for LookupIpIter<'_> {}
/// Future that resolves a list of candidate names to IP addresses.
///
/// Names are tried one at a time until one produces a non-empty answer. If every
/// candidate fails or comes back empty, the optional fallback address is returned
/// instead.
pub struct LookupIpFuture<C: DnsHandle + 'static> {
    /// Candidate names still to try; consumed with `Vec::pop`, i.e. from the back.
    names: Vec<Name>,
    /// Which record types to query for each name, and in what order.
    strategy: LookupIpStrategy,
    /// The lookup currently in flight.
    query: BoxFuture<'static, Result<Lookup, NetError>>,
    /// Address returned once all names are exhausted without a usable answer.
    finally_ip_addr: Option<RData>,
    /// Client used for queries not answered by `hosts`.
    client_cache: CachingClient<C>,
    /// Request options applied to every query this future issues.
    options: DnsRequestOptions,
    /// Static host entries, consulted before `client_cache`.
    hosts: Arc<Hosts>,
}
impl<C: DnsHandle + 'static> LookupIpFuture<C> {
pub fn lookup(
names: Vec<Name>,
strategy: LookupIpStrategy,
client_cache: CachingClient<C>,
options: DnsRequestOptions,
hosts: Arc<Hosts>,
finally_ip_addr: Option<RData>,
) -> Self {
Self {
names,
strategy,
client_cache,
query: future::err("can not lookup IPs for no names".into()).boxed(),
options,
hosts,
finally_ip_addr,
}
}
}
impl<C: DnsHandle + 'static> Future for LookupIpFuture<C> {
    type Output = Result<LookupIp, NetError>;

    /// Polls the in-flight lookup and moves on to the next candidate whenever a
    /// lookup fails or returns no answers.
    ///
    /// Resolution order:
    /// 1. the first lookup with a non-empty answer is returned;
    /// 2. otherwise the next name popped from `names` is looked up with `strategy`;
    /// 3. once `names` is empty, `finally_ip_addr` (if set) is returned as a single
    ///    record with `MAX_TTL`;
    /// 4. failing that, the last result seen (an empty lookup or an error) is returned.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        loop {
            let outcome = match self.query.as_mut().poll(cx) {
                Poll::Ready(outcome) => outcome,
                Poll::Pending => return Poll::Pending,
            };

            let usable = matches!(&outcome, Ok(lookup) if !lookup.answers().is_empty());
            if usable {
                return Poll::Ready(outcome.map(LookupIp::from));
            }

            match self.names.pop() {
                Some(next_name) => {
                    let context = LookupContext {
                        client: self.client_cache.clone(),
                        options: self.options,
                        hosts: self.hosts.clone(),
                    };
                    let strategy = self.strategy;
                    // The replacement future is polled on the next pass through the loop.
                    self.query = context.strategic_lookup(next_name, strategy).boxed();
                }
                None => {
                    // No candidates left: prefer the fallback address over the last result.
                    let result = match self.finally_ip_addr.take() {
                        Some(ip_addr) => {
                            let record = Record::from_rdata(Name::new(), MAX_TTL, ip_addr);
                            let fallback = Lookup::new_with_max_ttl(Query::new(), [record]);
                            Ok(LookupIp::from(fallback))
                        }
                        None => outcome.map(LookupIp::from),
                    };
                    return Poll::Ready(result);
                }
            }
        }
    }
}
/// State shared by the per-name lookups issued by [`LookupIpFuture`].
#[derive(Clone)]
struct LookupContext<C: DnsHandle> {
    /// Static host entries, checked before `client` is asked.
    hosts: Arc<Hosts>,
    /// Client that performs the query when `hosts` has no entry.
    client: CachingClient<C>,
    /// Options passed along with every query sent through `client`.
    options: DnsRequestOptions,
}
impl<C: DnsHandle> LookupContext<C> {
    /// Resolves `name` using the record types and ordering selected by `strategy`.
    ///
    /// Takes `self` by value so the returned future owns the context instead of
    /// borrowing it.
    async fn strategic_lookup(
        self,
        name: Name,
        strategy: LookupIpStrategy,
    ) -> Result<Lookup, NetError> {
        match strategy {
            LookupIpStrategy::Ipv4Only => self.ipv4_only(name).await,
            LookupIpStrategy::Ipv6Only => self.ipv6_only(name).await,
            LookupIpStrategy::Ipv4AndIpv6 => self.ipv4_and_ipv6(name).await,
            LookupIpStrategy::Ipv6AndIpv4 => self.ipv6_and_ipv4(name).await,
            LookupIpStrategy::Ipv6thenIpv4 => self.ipv6_then_ipv4(name).await,
            LookupIpStrategy::Ipv4thenIpv6 => self.ipv4_then_ipv6(name).await,
        }
    }

    /// Queries only `A` records for `name`.
    async fn ipv4_only(&self, name: Name) -> Result<Lookup, NetError> {
        self.hosts_lookup(Query::query(name, RecordType::A)).await
    }

    /// Queries only `AAAA` records for `name`.
    async fn ipv6_only(&self, name: Name) -> Result<Lookup, NetError> {
        self.hosts_lookup(Query::query(name, RecordType::AAAA))
            .await
    }

    /// Queries `A` and `AAAA` concurrently; `A` answers come first in the result.
    async fn ipv4_and_ipv6(&self, name: Name) -> Result<Lookup, NetError> {
        self.multi_lookup(name, RecordType::A, RecordType::AAAA)
            .await
    }

    /// Queries `AAAA` and `A` concurrently; `AAAA` answers come first in the result.
    async fn ipv6_and_ipv4(&self, name: Name) -> Result<Lookup, NetError> {
        self.multi_lookup(name, RecordType::AAAA, RecordType::A)
            .await
    }

    /// Runs the `first_type` and `second_type` queries concurrently and merges their
    /// answers, `first_type` answers first.
    ///
    /// If exactly one query fails, its error is logged and the other result is
    /// returned. If both fail, the `first_type` error is returned.
    async fn multi_lookup(
        &self,
        name: Name,
        first_type: RecordType,
        second_type: RecordType,
    ) -> Result<Lookup, NetError> {
        let joined_res = future::join(
            self.hosts_lookup(Query::query(name.clone(), first_type)),
            self.hosts_lookup(Query::query(name, second_type)),
        )
        .await;
        match joined_res {
            (Ok(first), Ok(second)) => {
                let ips = first.append(second);
                Ok(ips)
            }
            (Ok(ips), Err(e)) | (Err(e), Ok(ips)) => {
                debug!("one of ipv4 or ipv6 lookup failed: {e}");
                Ok(ips)
            }
            (Err(e1), Err(e2)) => {
                debug!("both of ipv4 or ipv6 lookup failed e1: {e1}, e2: {e2}");
                Err(e1)
            }
        }
    }

    /// Queries `AAAA` first, falling back to `A` if that fails or has no answers.
    async fn ipv6_then_ipv4(&self, name: Name) -> Result<Lookup, NetError> {
        self.rt_then_swap(name, RecordType::AAAA, RecordType::A)
            .await
    }

    /// Queries `A` first, falling back to `AAAA` if that fails or has no answers.
    async fn ipv4_then_ipv6(&self, name: Name) -> Result<Lookup, NetError> {
        self.rt_then_swap(name, RecordType::A, RecordType::AAAA)
            .await
    }

    /// Queries `first_type`. If that fails or returns no answers, queries
    /// `second_type` and returns that result (success or error) unchanged.
    ///
    /// An error from the first query is logged at debug level before falling back.
    /// This matches how `multi_lookup` reports a swallowed error, so the first
    /// failure is not lost when the fallback also fails.
    async fn rt_then_swap(
        &self,
        name: Name,
        first_type: RecordType,
        second_type: RecordType,
    ) -> Result<Lookup, NetError> {
        let first = self
            .hosts_lookup(Query::query(name.clone(), first_type))
            .await;
        match first {
            Ok(ips) if !ips.answers().is_empty() => return Ok(ips),
            // An empty answer is not an error; just try the other record type.
            Ok(_) => {}
            Err(e) => {
                debug!("{first_type:?} lookup failed, falling back to {second_type:?}: {e}");
            }
        }
        self.hosts_lookup(Query::query(name, second_type)).await
    }

    /// Answers `query` from the static hosts table when it has an entry. Otherwise the
    /// query is forwarded to the caching client with this context's options.
    async fn hosts_lookup(&self, query: Query) -> Result<Lookup, NetError> {
        match self.hosts.lookup_static_host(&query) {
            Some(lookup) => Ok(lookup),
            None => self.client.lookup(query, self.options).await,
        }
    }
}
#[cfg(test)]
pub(crate) mod tests {
    use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
    use std::sync::{Arc, Mutex};

    use futures_executor::block_on;
    use futures_util::future;
    use futures_util::stream::{Stream, once};
    use test_support::subscribe;

    use super::*;
    use crate::net::runtime::TokioRuntimeProvider;
    use crate::net::xfer::DnsHandle;
    use crate::proto::op::{DnsRequest, DnsResponse, Message};
    use crate::proto::rr::{Name, RData, Record};

    /// `DnsHandle` that replays canned responses. Each `send` pops from the *back*
    /// of the list and answers with [`empty`] once the list is exhausted.
    #[derive(Clone)]
    pub(crate) struct MockDnsHandle {
        messages: Arc<Mutex<Vec<Result<DnsResponse, NetError>>>>,
    }

    impl DnsHandle for MockDnsHandle {
        type Response = Pin<Box<dyn Stream<Item = Result<DnsResponse, NetError>> + Send + Unpin>>;
        type Runtime = TokioRuntimeProvider;

        fn send(&self, _: DnsRequest) -> Self::Response {
            Box::pin(once(future::ready(
                self.messages.lock().unwrap().pop().unwrap_or_else(empty),
            )))
        }
    }

    /// A response carrying one `A` record for the root name: `127.0.0.1`.
    pub(crate) fn v4_message() -> Result<DnsResponse, NetError> {
        let mut message = Message::query();
        message.add_query(Query::query(Name::root(), RecordType::A));
        message.insert_answers(vec![Record::from_rdata(
            Name::root(),
            86400,
            RData::A(Ipv4Addr::LOCALHOST.into()),
        )]);
        let resp = DnsResponse::from_message(message.into_response()).unwrap();
        assert!(resp.contains_answer());
        Ok(resp)
    }

    /// A response carrying one `AAAA` record for the root name: `::1`.
    pub(crate) fn v6_message() -> Result<DnsResponse, NetError> {
        let mut message = Message::query();
        message.add_query(Query::query(Name::root(), RecordType::AAAA));
        message.insert_answers(vec![Record::from_rdata(
            Name::root(),
            86400,
            RData::AAAA(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1).into()),
        )]);
        let resp = DnsResponse::from_message(message.into_response()).unwrap();
        assert!(resp.contains_answer());
        Ok(resp)
    }

    /// A successful response with no answers.
    pub(crate) fn empty() -> Result<DnsResponse, NetError> {
        Ok(DnsResponse::from_message(Message::query().into_response()).unwrap())
    }

    /// A failed response.
    pub(crate) fn error() -> Result<DnsResponse, NetError> {
        Err(NetError::from("forced test failure"))
    }

    /// Builds a [`MockDnsHandle`] that replays `messages`, last element first.
    pub(crate) fn mock(messages: Vec<Result<DnsResponse, NetError>>) -> MockDnsHandle {
        MockDnsHandle {
            messages: Arc::new(Mutex::new(messages)),
        }
    }

    /// The IPv6 loopback address carried by [`v6_message`].
    const V6_LOCALHOST: Ipv6Addr = Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1);

    /// Builds a lookup context with default options and an empty hosts table, whose
    /// client replays `messages` (last element first).
    fn context(messages: Vec<Result<DnsResponse, NetError>>) -> LookupContext<MockDnsHandle> {
        LookupContext {
            client: CachingClient::new(0, mock(messages), false),
            options: DnsRequestOptions::default(),
            hosts: Arc::new(Hosts::default()),
        }
    }

    /// Unwraps a lookup result and collects the address of every answer record.
    fn answer_ips(result: Result<Lookup, NetError>) -> Vec<IpAddr> {
        result
            .unwrap()
            .answers()
            .iter()
            .map(|r| r.data().ip_addr().unwrap())
            .collect()
    }

    /// Builds a `LookupIpFuture` over `names` using the `Ipv4Only` strategy.
    fn ipv4_future(
        names: Vec<Name>,
        messages: Vec<Result<DnsResponse, NetError>>,
        finally_ip_addr: Option<RData>,
    ) -> LookupIpFuture<MockDnsHandle> {
        LookupIpFuture::lookup(
            names,
            LookupIpStrategy::Ipv4Only,
            CachingClient::new(0, mock(messages), false),
            DnsRequestOptions::default(),
            Arc::new(Hosts::default()),
            finally_ip_addr,
        )
    }

    #[test]
    fn test_ipv4_only_strategy() {
        subscribe();
        let cx = context(vec![v4_message()]);
        assert_eq!(
            answer_ips(block_on(cx.ipv4_only(Name::root()))),
            vec![Ipv4Addr::LOCALHOST]
        );
    }

    #[test]
    fn test_ipv6_only_strategy() {
        subscribe();
        let cx = context(vec![v6_message()]);
        assert_eq!(
            answer_ips(block_on(cx.ipv6_only(Name::root()))),
            vec![V6_LOCALHOST]
        );
    }

    #[test]
    fn test_ipv4_and_ipv6_strategy() {
        subscribe();
        let v4 = IpAddr::V4(Ipv4Addr::LOCALHOST);
        let v6 = IpAddr::V6(V6_LOCALHOST);
        // The mock pops from the back, so the A query is served the last entry.
        let cases = vec![
            (vec![v6_message(), v4_message()], vec![v4, v6]),
            (vec![empty(), v4_message()], vec![v4]),
            (vec![error(), v4_message()], vec![v4]),
            (vec![v6_message(), empty()], vec![v6]),
            (vec![v6_message(), error()], vec![v6]),
        ];
        for (i, (messages, expected)) in cases.into_iter().enumerate() {
            let cx = context(messages);
            assert_eq!(
                answer_ips(block_on(cx.ipv4_and_ipv6(Name::root()))),
                expected,
                "case {i}"
            );
        }
    }

    #[test]
    fn test_ipv6_and_ipv4_strategy() {
        subscribe();
        let v4 = IpAddr::V4(Ipv4Addr::LOCALHOST);
        let v6 = IpAddr::V6(V6_LOCALHOST);
        // The mock pops from the back, so the AAAA query is served the last entry.
        let cases = vec![
            (vec![v4_message(), v6_message()], vec![v6, v4]),
            (vec![v4_message(), empty()], vec![v4]),
            (vec![v4_message(), error()], vec![v4]),
            (vec![empty(), v6_message()], vec![v6]),
            (vec![error(), v6_message()], vec![v6]),
        ];
        for (i, (messages, expected)) in cases.into_iter().enumerate() {
            let cx = context(messages);
            assert_eq!(
                answer_ips(block_on(cx.ipv6_and_ipv4(Name::root()))),
                expected,
                "case {i}"
            );
        }
    }

    #[test]
    fn test_ipv6_then_ipv4_strategy() {
        subscribe();
        let v4 = IpAddr::V4(Ipv4Addr::LOCALHOST);
        let v6 = IpAddr::V6(V6_LOCALHOST);
        let cases = vec![
            (vec![v6_message()], vec![v6]),
            (vec![v4_message(), empty()], vec![v4]),
            (vec![v4_message(), error()], vec![v4]),
        ];
        for (i, (messages, expected)) in cases.into_iter().enumerate() {
            let cx = context(messages);
            assert_eq!(
                answer_ips(block_on(cx.ipv6_then_ipv4(Name::root()))),
                expected,
                "case {i}"
            );
        }
    }

    #[test]
    fn test_ipv4_then_ipv6_strategy() {
        subscribe();
        let v4 = IpAddr::V4(Ipv4Addr::LOCALHOST);
        let v6 = IpAddr::V6(V6_LOCALHOST);
        let cases = vec![
            (vec![v4_message()], vec![v4]),
            (vec![v6_message(), empty()], vec![v6]),
            (vec![v6_message(), error()], vec![v6]),
        ];
        for (i, (messages, expected)) in cases.into_iter().enumerate() {
            let cx = context(messages);
            assert_eq!(
                answer_ips(block_on(cx.ipv4_then_ipv6(Name::root()))),
                expected,
                "case {i}"
            );
        }
    }

    #[test]
    fn test_lookup_ip_future_prefers_answer_over_fallback() {
        subscribe();
        let future = ipv4_future(
            vec![Name::root()],
            vec![v4_message()],
            Some(RData::A(Ipv4Addr::new(10, 0, 0, 1).into())),
        );
        assert_eq!(
            block_on(future).unwrap().iter().collect::<Vec<IpAddr>>(),
            vec![IpAddr::V4(Ipv4Addr::LOCALHOST)]
        );
    }

    #[test]
    fn test_lookup_ip_future_falls_back_after_failure() {
        subscribe();
        let future = ipv4_future(
            vec![Name::root()],
            vec![error()],
            Some(RData::A(Ipv4Addr::LOCALHOST.into())),
        );
        assert_eq!(
            block_on(future).unwrap().iter().collect::<Vec<IpAddr>>(),
            vec![IpAddr::V4(Ipv4Addr::LOCALHOST)]
        );
    }

    #[test]
    fn test_lookup_ip_future_without_names_or_fallback_fails() {
        subscribe();
        let future = ipv4_future(Vec::new(), Vec::new(), None);
        assert!(block_on(future).is_err());
    }
}