use std::collections::HashMap;
use std::net::{IpAddr, SocketAddr, ToSocketAddrs};
use std::sync::Mutex;
use std::time::{Duration, Instant};
use crate::error::{Error, ErrorKind, Result};
/// Settings that control how [`DnsCache`] resolves and caches hosts.
///
/// The type is `Copy` and is passed by value to every lookup. Different calls may use
/// different settings against the same cache.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct DnsConfig {
    /// How long a successful resolution is served from the cache before the resolver is
    /// consulted again.
    pub ttl: Duration,
    /// When `true`, resolved addresses are reordered so IPv6 and IPv4 entries alternate.
    pub dual_stack: bool,
}

impl Default for DnsConfig {
    /// A 60-second cache TTL with dual-stack interleaving disabled.
    fn default() -> Self {
        Self {
            ttl: Duration::from_secs(60),
            dual_stack: false,
        }
    }
}
/// In-memory cache of hostname lookups, keyed by the host string.
///
/// The map sits behind a `Mutex`, so lookups and inserts go through `&self`.
#[derive(Default)]
pub struct DnsCache {
    /// Cached lookup results, one per host string.
    entries: Mutex<HashMap<String, DnsEntry>>,
}

/// A single cached resolution result.
struct DnsEntry {
    /// The entry is served only while the current instant is strictly before this one.
    expires_at: Instant,
    /// Addresses recorded for the host.
    addrs: Vec<IpAddr>,
}
impl DnsCache {
    /// Creates an empty cache.
    pub fn new() -> Self {
        Self::default()
    }

    /// Resolves `host` to its IP addresses, serving unexpired results from the cache.
    ///
    /// IP literals are parsed directly; they never reach the resolver or the cache. For
    /// other hosts, a cached entry younger than its TTL is reused. Otherwise the system
    /// resolver is queried and the answer is cached for `config.ttl`.
    ///
    /// The cache stores addresses in the order the resolver returned them. Dual-stack
    /// interleaving is applied on every call when `config.dual_stack` is set. The returned
    /// order therefore reflects the `config` of *this* call, not the one in effect when the
    /// entry was cached.
    ///
    /// # Errors
    ///
    /// Returns an [`ErrorKind::Transport`] error if resolution fails or yields no addresses.
    pub fn resolve_host(&self, host: &str, config: DnsConfig) -> Result<Vec<IpAddr>> {
        if let Ok(ip) = host.parse::<IpAddr>() {
            return Ok(vec![ip]);
        }

        let addrs = match self.lookup_cached(host, Instant::now()) {
            Some(addrs) => addrs,
            None => {
                let addrs = Self::resolve_uncached(host)?;
                self.store(host, addrs.clone(), config.ttl);
                addrs
            }
        };

        Ok(if config.dual_stack {
            interleave_dual_stack(addrs)
        } else {
            addrs
        })
    }

    /// Resolves `host` as [`DnsCache::resolve_host`] does and pairs every address with `port`.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`DnsCache::resolve_host`].
    pub fn resolve_socket_addrs(
        &self,
        host: &str,
        port: u16,
        config: DnsConfig,
    ) -> Result<Vec<SocketAddr>> {
        let addrs = self.resolve_host(host, config)?;
        Ok(addrs
            .into_iter()
            .map(|ip| SocketAddr::new(ip, port))
            .collect())
    }

    /// Warms the cache by resolving each host in `hosts`, in order.
    ///
    /// # Errors
    ///
    /// Returns the first resolution error. Hosts resolved before the failing one stay
    /// cached; hosts after it are not attempted.
    pub fn prefetch<I, S>(&self, hosts: I, config: DnsConfig) -> Result<()>
    where
        I: IntoIterator<Item = S>,
        S: AsRef<str>,
    {
        for host in hosts {
            self.resolve_host(host.as_ref(), config)?;
        }
        Ok(())
    }

    /// Returns a copy of the cached addresses for `host` if an entry exists and is still
    /// valid at `now`.
    fn lookup_cached(&self, host: &str, now: Instant) -> Option<Vec<IpAddr>> {
        let entries = self.entries.lock().unwrap_or_else(|err| err.into_inner());
        entries
            .get(host)
            .filter(|entry| entry.expires_at > now)
            .map(|entry| entry.addrs.clone())
    }

    /// Queries the system resolver for `host` and returns the addresses in resolver order.
    /// Fails if the resolver returns nothing.
    fn resolve_uncached(host: &str) -> Result<Vec<IpAddr>> {
        let addrs = (host, 0)
            .to_socket_addrs()
            .map_err(|err| {
                Error::with_source(
                    ErrorKind::Transport,
                    format!("failed to resolve host {host}"),
                    err,
                )
            })?
            .map(|addr| addr.ip())
            .collect::<Vec<_>>();
        if addrs.is_empty() {
            return Err(Error::new(
                ErrorKind::Transport,
                format!("host resolved to no addresses: {host}"),
            ));
        }
        Ok(addrs)
    }

    /// Caches `addrs` for `host` for `ttl`, first dropping any entries that have expired.
    fn store(&self, host: &str, addrs: Vec<IpAddr>, ttl: Duration) {
        // Fallback lifetime used when `now + ttl` overflows `Instant` (e.g. `Duration::MAX`
        // passed to mean "cache forever"). Plain `+` would panic in that case.
        const MAX_CACHE_TTL: Duration = Duration::from_secs(365 * 24 * 60 * 60);

        // Start the TTL when the answer arrives. Resolution can block for a while, and
        // starting the clock before the query would shorten the entry's lifetime.
        let now = Instant::now();
        let expires_at = match now
            .checked_add(ttl)
            .or_else(|| now.checked_add(MAX_CACHE_TTL))
        {
            Some(expires_at) => expires_at,
            // Still not representable after clamping: return this answer uncached.
            None => return,
        };

        let mut entries = self.entries.lock().unwrap_or_else(|err| err.into_inner());
        // Expired entries are never served again. Prune them so the map does not grow
        // without bound as many distinct hosts are resolved.
        entries.retain(|_, entry| entry.expires_at > now);
        entries.insert(host.to_owned(), DnsEntry { addrs, expires_at });
    }
}
/// Reorders `addrs` so that IPv6 and IPv4 addresses alternate, as RFC 8305 §4
/// (Happy Eyeballs v2) recommends.
///
/// The family of the first address in `addrs` leads, so the resolver's (and the system's)
/// address-family preference is honoured. Order within each family is preserved. When one
/// family runs out, the rest of the other family is appended in order. An empty input
/// yields an empty output.
fn interleave_dual_stack(addrs: Vec<IpAddr>) -> Vec<IpAddr> {
    let total = addrs.len();
    // For empty input the choice is irrelevant; default to IPv6 as before.
    let ipv6_leads = addrs.first().map_or(true, IpAddr::is_ipv6);
    let (ipv6, ipv4): (Vec<IpAddr>, Vec<IpAddr>) =
        addrs.into_iter().partition(|addr| addr.is_ipv6());

    let (mut leading, mut trailing) = if ipv6_leads {
        (ipv6.into_iter(), ipv4.into_iter())
    } else {
        (ipv4.into_iter(), ipv6.into_iter())
    };

    let mut ordered = Vec::with_capacity(total);
    loop {
        match (leading.next(), trailing.next()) {
            (None, None) => break,
            (first, second) => ordered.extend(first.into_iter().chain(second)),
        }
    }
    ordered
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses an IP literal; panics on malformed test input.
    fn ip(text: &str) -> IpAddr {
        text.parse().unwrap()
    }

    #[test]
    fn resolves_and_caches_localhost() {
        let cache = DnsCache::new();
        let first = cache
            .resolve_socket_addrs("localhost", 8080, DnsConfig::default())
            .unwrap();
        let second = cache
            .resolve_socket_addrs("localhost", 8080, DnsConfig::default())
            .unwrap();
        assert!(!first.is_empty());
        assert_eq!(first, second);
    }

    #[test]
    fn ip_literals_bypass_resolver_and_cache() {
        let cache = DnsCache::new();
        assert_eq!(
            cache
                .resolve_host("127.0.0.1", DnsConfig::default())
                .unwrap(),
            vec![ip("127.0.0.1")]
        );
        assert_eq!(
            cache
                .resolve_socket_addrs("::1", 443, DnsConfig::default())
                .unwrap(),
            vec![SocketAddr::new(ip("::1"), 443)]
        );
        // Literals are returned directly and never stored.
        assert!(cache.entries.lock().unwrap().is_empty());
    }

    #[test]
    fn dual_stack_interleaves_ipv6_and_ipv4_addresses() {
        let ordered = interleave_dual_stack(vec![
            ip("::1"),
            ip("::2"),
            ip("127.0.0.1"),
            ip("127.0.0.2"),
        ]);
        assert_eq!(
            ordered,
            vec![ip("::1"), ip("127.0.0.1"), ip("::2"), ip("127.0.0.2")]
        );
    }

    #[test]
    fn dual_stack_handles_single_family_and_empty_input() {
        assert_eq!(
            interleave_dual_stack(vec![ip("10.0.0.1"), ip("10.0.0.2")]),
            vec![ip("10.0.0.1"), ip("10.0.0.2")]
        );
        assert!(interleave_dual_stack(Vec::new()).is_empty());
    }
}