use std::future::Future;
use std::io;
use std::net::SocketAddr;
use std::net::ToSocketAddrs;
use std::pin::Pin;
use std::sync::Arc;
use std::task::Context;
use std::task::Poll;
use hickory_resolver::TokioResolver;
use hickory_resolver::config::LookupIpStrategy;
use hickory_resolver::net::NetError;
use hyper_util::client::legacy::connect::HttpConnector;
use hyper_util::client::legacy::connect::dns::Name;
use tower::Service;
use crate::configuration::shared::DnsResolutionStrategy;
#[derive(Debug, Clone)]
pub(crate) struct AsyncHyperResolver(Arc<TokioResolver>);
impl AsyncHyperResolver {
fn new_from_system_conf(
dns_resolution_strategy: DnsResolutionStrategy,
) -> Result<Self, io::Error> {
let mut builder = TokioResolver::builder_tokio().map_err(convert_net_error)?;
builder.options_mut().ip_strategy = dns_resolution_strategy.into();
Ok(Self(Arc::new(builder.build().map_err(convert_net_error)?)))
}
}
impl Service<Name> for AsyncHyperResolver {
type Response = std::vec::IntoIter<SocketAddr>;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
type Error = io::Error;
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, name: Name) -> Self::Future {
let resolver = self.0.clone();
Box::pin(async move {
Ok(resolver
.lookup_ip(name.as_str())
.await
.map_err(convert_net_error)?
.iter()
.map(|addr| (addr, 0_u16).to_socket_addrs())
.try_fold(Vec::new(), |mut acc, s_addr| {
acc.extend(s_addr?);
Ok::<_, io::Error>(acc)
})?
.into_iter())
})
}
}
/// Maps the configured DNS resolution strategy onto hickory's
/// `LookupIpStrategy`.
///
/// The mapping is one-to-one. hickory spells the fallback variants
/// `Ipv6thenIpv4` / `Ipv4thenIpv6`. There is no wildcard arm, so adding a
/// configuration variant forces this mapping to be updated.
impl From<DnsResolutionStrategy> for LookupIpStrategy {
    fn from(value: DnsResolutionStrategy) -> LookupIpStrategy {
        match value {
            DnsResolutionStrategy::Ipv4AndIpv6 => Self::Ipv4AndIpv6,
            DnsResolutionStrategy::Ipv4Only => Self::Ipv4Only,
            DnsResolutionStrategy::Ipv4ThenIpv6 => Self::Ipv4thenIpv6,
            DnsResolutionStrategy::Ipv6Only => Self::Ipv6Only,
            DnsResolutionStrategy::Ipv6ThenIpv4 => Self::Ipv6thenIpv4,
        }
    }
}
/// Creates an `HttpConnector` whose DNS lookups go through an
/// `AsyncHyperResolver` configured with `dns_resolution_strategy`.
///
/// # Errors
/// Returns an `io::Error` if the resolver cannot be constructed (see
/// `AsyncHyperResolver::new_from_system_conf`).
pub(crate) fn new_async_http_connector(
    dns_resolution_strategy: DnsResolutionStrategy,
) -> Result<HttpConnector<AsyncHyperResolver>, io::Error> {
    AsyncHyperResolver::new_from_system_conf(dns_resolution_strategy)
        .map(HttpConnector::new_with_resolver)
}
fn convert_net_error(err: NetError) -> io::Error {
match err {
NetError::Busy => io::Error::new(io::ErrorKind::ResourceBusy, err),
NetError::Io(io_err) => io::Error::new(io_err.kind(), io_err),
NetError::Timeout => io::Error::new(io::ErrorKind::TimedOut, err),
_ => io::Error::other(err),
}
}
#[cfg(test)]
mod tests {
    use std::io;
    use std::sync::Arc;

    use hickory_resolver::config::LookupIpStrategy;
    use hickory_resolver::net::NetError;

    use super::convert_net_error;
    use crate::configuration::shared::DnsResolutionStrategy;

    #[test]
    fn busy_maps_to_resource_busy() {
        let err = convert_net_error(NetError::Busy);
        assert_eq!(err.kind(), io::ErrorKind::ResourceBusy);
    }

    #[test]
    fn timeout_maps_to_timed_out() {
        let err = convert_net_error(NetError::Timeout);
        assert_eq!(err.kind(), io::ErrorKind::TimedOut);
    }

    #[test]
    fn io_preserves_kind() {
        let inner = io::Error::new(io::ErrorKind::ConnectionRefused, "refused");
        let err = convert_net_error(NetError::Io(Arc::new(inner)));
        assert_eq!(err.kind(), io::ErrorKind::ConnectionRefused);
    }

    #[test]
    fn io_preserves_message() {
        let inner = io::Error::new(io::ErrorKind::ConnectionRefused, "refused");
        let err = convert_net_error(NetError::Io(Arc::new(inner)));
        assert_eq!(err.to_string(), "refused");
    }

    #[test]
    fn shared_io_preserves_kind() {
        // Keep a second reference alive so the `Arc` cannot be uniquely owned.
        let inner = Arc::new(io::Error::new(io::ErrorKind::ConnectionReset, "reset"));
        let _other_owner = Arc::clone(&inner);
        let err = convert_net_error(NetError::Io(inner));
        assert_eq!(err.kind(), io::ErrorKind::ConnectionReset);
    }

    #[test]
    fn other_variants_map_to_other() {
        let err = convert_net_error(NetError::Message("something went wrong"));
        assert_eq!(err.kind(), io::ErrorKind::Other);
    }

    #[test]
    fn dns_strategy_maps_one_to_one() {
        // `matches!` avoids requiring `PartialEq`/`Debug` on either enum.
        assert!(matches!(
            LookupIpStrategy::from(DnsResolutionStrategy::Ipv4Only),
            LookupIpStrategy::Ipv4Only
        ));
        assert!(matches!(
            LookupIpStrategy::from(DnsResolutionStrategy::Ipv6Only),
            LookupIpStrategy::Ipv6Only
        ));
        assert!(matches!(
            LookupIpStrategy::from(DnsResolutionStrategy::Ipv4AndIpv6),
            LookupIpStrategy::Ipv4AndIpv6
        ));
        assert!(matches!(
            LookupIpStrategy::from(DnsResolutionStrategy::Ipv6ThenIpv4),
            LookupIpStrategy::Ipv6thenIpv4
        ));
        assert!(matches!(
            LookupIpStrategy::from(DnsResolutionStrategy::Ipv4ThenIpv6),
            LookupIpStrategy::Ipv4thenIpv6
        ));
    }
}