use std::{
fmt::Write,
net::{IpAddr, Ipv6Addr},
time::Duration,
};
use anyhow::Result;
use futures_lite::{Future, StreamExt};
use hickory_resolver::{AsyncResolver, IntoName, TokioAsyncResolver};
use iroh_base::{key::NodeId, node_addr::NodeAddr};
use once_cell::sync::Lazy;
pub mod node_info;
/// The resolver type used by this module: hickory's tokio-backed async resolver.
pub type DnsResolver = TokioAsyncResolver;

/// Process-wide resolver shared by `default_resolver` and `resolver`.
///
/// Built on first access from `create_default_resolver`; a creation error panics at that
/// first access because of the `expect` below.
static DNS_RESOLVER: Lazy<DnsResolver> = Lazy::new(|| {
    let built = create_default_resolver();
    built.expect("unable to create DNS resolver")
});
pub fn default_resolver() -> &'static DnsResolver {
&DNS_RESOLVER
}
pub fn resolver() -> &'static TokioAsyncResolver {
Lazy::force(&DNS_RESOLVER)
}
/// IPv6 site-local DNS server addresses `fec0:0:0:ffff::1`, `::2` and `::3`.
///
/// `create_default_resolver` skips any system name server whose IP is in this list.
// NOTE(review): these look like Windows' well-known default site-local DNS addresses
// (fec0::/10 site-local addressing is deprecated); presumably they are filtered because they
// are usually unreachable. Confirm the original motivation.
const WINDOWS_BAD_SITE_LOCAL_DNS_SERVERS: [IpAddr; 3] = {
    /// `fec0:0:0:ffff::<last>` as an `IpAddr`.
    const fn site_local_dns(last: u16) -> IpAddr {
        IpAddr::V6(Ipv6Addr::new(0xfec0, 0, 0, 0xffff, 0, 0, 0, last))
    }
    [site_local_dns(1), site_local_dns(2), site_local_dns(3)]
};
/// Builds the resolver used for `DNS_RESOLVER` from the host's DNS configuration.
///
/// - The system configuration is read via `hickory_resolver::system_conf::read_system_conf`.
///   If that fails, `Default` values for the config and options are used instead
///   (NOTE(review): what hickory's defaults contain is not visible here).
/// - The domain, search list and name servers are copied into a fresh `ResolverConfig`.
///   Name servers listed in `WINDOWS_BAD_SITE_LOCAL_DNS_SERVERS` are left out.
/// - The IP lookup strategy is forced to `Ipv4thenIpv6`.
///
/// The current implementation never returns `Err`.
fn create_default_resolver() -> Result<TokioAsyncResolver> {
    use hickory_resolver::config::{LookupIpStrategy, ResolverConfig};

    let (sys, mut opts) = hickory_resolver::system_conf::read_system_conf().unwrap_or_default();

    // The config is rebuilt field by field only so the bad name servers can be dropped.
    let mut config = ResolverConfig::new();
    // Apply the domain before the search list, matching the original order.
    // NOTE(review): hickory's `set_domain` may reset the search list, so keep this ordering.
    if let Some(domain) = sys.domain() {
        config.set_domain(domain.clone());
    }
    sys.search()
        .iter()
        .cloned()
        .for_each(|suffix| config.add_search(suffix));
    sys.name_servers()
        .iter()
        .filter(|ns| !WINDOWS_BAD_SITE_LOCAL_DNS_SERVERS.contains(&ns.socket_addr.ip()))
        .cloned()
        .for_each(|ns| config.add_name_server(ns));

    opts.ip_strategy = LookupIpStrategy::Ipv4thenIpv6;
    Ok(AsyncResolver::tokio(config, opts))
}
/// Extension methods on a DNS resolver.
///
/// Covers timed IPv4/IPv6 address lookups and iroh node lookups (backed by
/// `node_info::TxtAttrs`). Each method also has a `*_staggered` variant that races several
/// attempts started at different times.
///
/// Staggered variants (per `stagger_call`):
/// - One attempt starts immediately, plus one per entry in `delays_ms` (milliseconds).
/// - Each delay is an offset from the start, not a gap after the previous attempt.
/// - The first attempt to succeed wins. An error is returned only if every attempt fails.
// NOTE(review): the returned futures are not bounded by `Send`, so code that is generic over
// `ResolverExt` cannot rely on spawning them on a multi-threaded runtime.
pub trait ResolverExt {
    /// Looks up the IPv4 addresses of `host`.
    ///
    /// Fails if the lookup errors or does not finish within `timeout`.
    fn lookup_ipv4<N: IntoName>(
        &self,
        host: N,
        timeout: Duration,
    ) -> impl Future<Output = Result<impl Iterator<Item = IpAddr>>>;

    /// Looks up the IPv6 addresses of `host`.
    ///
    /// Fails if the lookup errors or does not finish within `timeout`.
    fn lookup_ipv6<N: IntoName>(
        &self,
        host: N,
        timeout: Duration,
    ) -> impl Future<Output = Result<impl Iterator<Item = IpAddr>>>;

    /// Runs the IPv4 and IPv6 lookups for `host` concurrently, each bounded by `timeout`.
    ///
    /// Succeeds if at least one family resolves; the other family's error is discarded.
    /// When both succeed, IPv4 addresses are yielded before IPv6 ones. Fails only when
    /// both lookups fail, with an error describing both failures.
    fn lookup_ipv4_ipv6<N: IntoName + Clone>(
        &self,
        host: N,
        timeout: Duration,
    ) -> impl Future<Output = Result<impl Iterator<Item = IpAddr>>>;

    /// Looks up iroh node information published under the DNS `name`.
    ///
    /// The result is converted (via `node_info::NodeInfo`) into a [`NodeAddr`]. No timeout
    /// is applied at this layer.
    fn lookup_by_name(&self, name: &str) -> impl Future<Output = Result<NodeAddr>>;

    /// Looks up iroh node information for `node_id` under `origin`.
    ///
    /// The result is converted (via `node_info::NodeInfo`) into a [`NodeAddr`]. No timeout
    /// is applied at this layer.
    fn lookup_by_id(
        &self,
        node_id: &NodeId,
        origin: &str,
    ) -> impl Future<Output = Result<NodeAddr>>;

    /// Staggered variant of `lookup_ipv4`. `timeout` applies to each attempt separately.
    fn lookup_ipv4_staggered<N: IntoName + Clone>(
        &self,
        host: N,
        timeout: Duration,
        delays_ms: &[u64],
    ) -> impl Future<Output = Result<impl Iterator<Item = IpAddr>>>;

    /// Staggered variant of `lookup_ipv6`. `timeout` applies to each attempt separately.
    fn lookup_ipv6_staggered<N: IntoName + Clone>(
        &self,
        host: N,
        timeout: Duration,
        delays_ms: &[u64],
    ) -> impl Future<Output = Result<impl Iterator<Item = IpAddr>>>;

    /// Staggered variant of `lookup_ipv4_ipv6`. `timeout` applies to each attempt's
    /// per-family lookups separately.
    fn lookup_ipv4_ipv6_staggered<N: IntoName + Clone>(
        &self,
        host: N,
        timeout: Duration,
        delays_ms: &[u64],
    ) -> impl Future<Output = Result<impl Iterator<Item = IpAddr>>>;

    /// Staggered variant of `lookup_by_name`.
    fn lookup_by_name_staggered(
        &self,
        name: &str,
        delays_ms: &[u64],
    ) -> impl Future<Output = Result<NodeAddr>>;

    /// Staggered variant of `lookup_by_id`.
    fn lookup_by_id_staggered(
        &self,
        node_id: &NodeId,
        origin: &str,
        delays_ms: &[u64],
    ) -> impl Future<Output = Result<NodeAddr>>;
}
impl ResolverExt for DnsResolver {
    /// IPv4 lookup bounded by `timeout`.
    ///
    /// The first `?` surfaces the timeout error and the second surfaces the resolver error.
    async fn lookup_ipv4<N: IntoName>(
        &self,
        host: N,
        timeout: Duration,
    ) -> Result<impl Iterator<Item = IpAddr>> {
        let query = self.ipv4_lookup(host);
        let records = tokio::time::timeout(timeout, query).await??;
        Ok(records.into_iter().map(|record| IpAddr::V4(record.0)))
    }

    /// IPv6 lookup bounded by `timeout`; error handling mirrors `lookup_ipv4`.
    async fn lookup_ipv6<N: IntoName>(
        &self,
        host: N,
        timeout: Duration,
    ) -> Result<impl Iterator<Item = IpAddr>> {
        let query = self.ipv6_lookup(host);
        let records = tokio::time::timeout(timeout, query).await??;
        Ok(records.into_iter().map(|record| IpAddr::V6(record.0)))
    }

    /// Queries both address families concurrently.
    ///
    /// A failure in one family is tolerated as long as the other produced a result.
    async fn lookup_ipv4_ipv6<N: IntoName + Clone>(
        &self,
        host: N,
        timeout: Duration,
    ) -> Result<impl Iterator<Item = IpAddr>> {
        let v4_query = self.lookup_ipv4(host.clone(), timeout);
        let v6_query = self.lookup_ipv6(host, timeout);
        let (v4, v6) = tokio::join!(v4_query, v6_query);
        match (v4, v6) {
            (Ok(v4_addrs), Ok(v6_addrs)) => Ok(LookupIter::Both(v4_addrs.chain(v6_addrs))),
            (Ok(v4_addrs), Err(_)) => Ok(LookupIter::Ipv4(v4_addrs)),
            (Err(_), Ok(v6_addrs)) => Ok(LookupIter::Ipv6(v6_addrs)),
            (Err(v4_err), Err(v6_err)) => {
                anyhow::bail!("Ipv4: {:?}, Ipv6: {:?}", v4_err, v6_err)
            }
        }
    }

    /// Fetches the node's TXT attributes by name and converts them, via `NodeInfo`, into a
    /// `NodeAddr`.
    async fn lookup_by_name(&self, name: &str) -> Result<NodeAddr> {
        let info: node_info::NodeInfo =
            node_info::TxtAttrs::<node_info::IrohAttr>::lookup_by_name(self, name)
                .await?
                .into();
        Ok(info.into())
    }

    /// Fetches the node's TXT attributes by id under `origin` and converts them, via
    /// `NodeInfo`, into a `NodeAddr`.
    async fn lookup_by_id(&self, node_id: &NodeId, origin: &str) -> Result<NodeAddr> {
        let info: node_info::NodeInfo =
            node_info::TxtAttrs::<node_info::IrohAttr>::lookup_by_id(self, node_id, origin)
                .await?
                .into();
        Ok(info.into())
    }

    async fn lookup_ipv4_staggered<N: IntoName + Clone>(
        &self,
        host: N,
        timeout: Duration,
        delays_ms: &[u64],
    ) -> Result<impl Iterator<Item = IpAddr>> {
        stagger_call(|| self.lookup_ipv4(host.clone(), timeout), delays_ms).await
    }

    async fn lookup_ipv6_staggered<N: IntoName + Clone>(
        &self,
        host: N,
        timeout: Duration,
        delays_ms: &[u64],
    ) -> Result<impl Iterator<Item = IpAddr>> {
        stagger_call(|| self.lookup_ipv6(host.clone(), timeout), delays_ms).await
    }

    async fn lookup_ipv4_ipv6_staggered<N: IntoName + Clone>(
        &self,
        host: N,
        timeout: Duration,
        delays_ms: &[u64],
    ) -> Result<impl Iterator<Item = IpAddr>> {
        stagger_call(|| self.lookup_ipv4_ipv6(host.clone(), timeout), delays_ms).await
    }

    async fn lookup_by_name_staggered(&self, name: &str, delays_ms: &[u64]) -> Result<NodeAddr> {
        stagger_call(|| self.lookup_by_name(name), delays_ms).await
    }

    async fn lookup_by_id_staggered(
        &self,
        node_id: &NodeId,
        origin: &str,
        delays_ms: &[u64],
    ) -> Result<NodeAddr> {
        stagger_call(|| self.lookup_by_id(node_id, origin), delays_ms).await
    }
}
/// Address iterator returned by the combined IPv4/IPv6 lookup.
///
/// The enum lets one function return a single concrete `impl Iterator` type regardless of
/// which address families resolved.
enum LookupIter<A, B> {
    /// Only the IPv4 lookup produced addresses.
    Ipv4(A),
    /// Only the IPv6 lookup produced addresses.
    Ipv6(B),
    /// Both lookups succeeded; the first iterator is drained before the second.
    Both(std::iter::Chain<A, B>),
}

impl<A: Iterator<Item = IpAddr>, B: Iterator<Item = IpAddr>> Iterator for LookupIter<A, B> {
    type Item = IpAddr;

    fn next(&mut self) -> Option<Self::Item> {
        match self {
            LookupIter::Ipv4(iter) => iter.next(),
            LookupIter::Ipv6(iter) => iter.next(),
            LookupIter::Both(iter) => iter.next(),
        }
    }

    /// Forwards the wrapped iterator's bounds.
    ///
    /// Without this, the default `(0, None)` would keep `collect()` and similar consumers
    /// from preallocating.
    fn size_hint(&self) -> (usize, Option<usize>) {
        match self {
            LookupIter::Ipv4(iter) => iter.size_hint(),
            LookupIter::Ipv6(iter) => iter.size_hint(),
            LookupIter::Both(iter) => iter.size_hint(),
        }
    }
}
async fn stagger_call<T, F: Fn() -> Fut, Fut: Future<Output = Result<T>>>(
f: F,
delays_ms: &[u64],
) -> Result<T> {
let mut calls = futures_buffered::FuturesUnorderedBounded::new(delays_ms.len() + 1);
for delay in std::iter::once(&0u64).chain(delays_ms) {
let delay = std::time::Duration::from_millis(*delay);
let fut = f();
let staggered_fut = async move {
tokio::time::sleep(delay).await;
fut.await
};
calls.push(staggered_fut)
}
let mut errors = vec![];
while let Some(call_result) = calls.next().await {
match call_result {
Ok(t) => return Ok(t),
Err(e) => errors.push(e),
}
}
anyhow::bail!(
"no calls succeed: [ {}]",
errors.into_iter().fold(String::new(), |mut summary, e| {
write!(summary, "{e} ").expect("infallible");
summary
})
)
}
#[cfg(test)]
pub(crate) mod tests {
    use std::sync::atomic::AtomicUsize;
    use super::*;
    use crate::defaults::staging::NA_RELAY_HOSTNAME;

    /// Per-attempt timeout for the live DNS lookup test.
    const TIMEOUT: Duration = Duration::from_secs(5);
    /// Extra attempts start 200 ms and 300 ms after the first one.
    const STAGGERING_DELAYS: &[u64] = &[200, 300];

    /// Resolves the staging NA relay hostname (both address families, staggered) and checks
    /// that at least one address comes back.
    // NOTE(review): this performs real DNS queries, so it needs network access and depends
    // on the staging relay's DNS records.
    #[tokio::test]
    async fn test_dns_lookup_ipv4_ipv6() {
        let _logging = iroh_test::logging::setup();
        let resolver = default_resolver();
        let res: Vec<_> = resolver
            .lookup_ipv4_ipv6_staggered(NA_RELAY_HOSTNAME, TIMEOUT, STAGGERING_DELAYS)
            .await
            .unwrap()
            .collect();
        assert!(!res.is_empty());
        dbg!(res);
    }

    /// Checks that `stagger_call` returns the first attempt to *succeed*.
    ///
    /// An earlier failure is skipped and a slower success is ignored.
    #[tokio::test]
    async fn stagger_basic() {
        let _logging = iroh_test::logging::setup();
        // Result for the n-th invocation of `f`. With two delays only indices 0..=2 are used.
        const CALL_RESULTS: &[Result<u8, u8>] = &[Err(2), Ok(3), Ok(5), Ok(7)];
        static DONE_CALL: AtomicUsize = AtomicUsize::new(0);
        // The counter is bumped when `f` is called, not when its future runs. `stagger_call`
        // calls `f` once per delay in order: 0 ms, then 1000 ms, then 15 ms. So index 0
        // gets no delay, index 1 waits 1000 ms and index 2 waits 15 ms.
        let f = || {
            let r_pos = DONE_CALL.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
            async move {
                tracing::info!(r_pos, "call");
                CALL_RESULTS[r_pos].map_err(|e| anyhow::anyhow!("{e}"))
            }
        };
        // Index 0 fails immediately with 2. Index 2 then returns 5 at ~15 ms, long before
        // index 1 would return 3 at ~1000 ms, so the winner is 5.
        let delays = [1000, 15];
        let result = stagger_call(f, &delays).await.unwrap();
        assert_eq!(result, 5)
    }
}