use dashmap::DashMap;
use std::net::{IpAddr, SocketAddr, ToSocketAddrs};
use std::sync::Arc;
use std::time::{Duration, Instant};
/// Soft cap on the number of hostnames held by a [`DnsCache`].
///
/// Insertion paths that honor it first purge expired entries and, if the map is
/// still full, drop `MAX_ENTRIES / 10` entries to make room.
const MAX_ENTRIES: usize = 5_000;

/// One cached lookup result for a single hostname.
#[derive(Debug)]
pub(crate) struct DnsEntry {
    /// Addresses as returned by the system lookup of `"<host>:0"` (so port 0),
    /// shared via `Arc` so they can be handed out without copying.
    sockaddrs: Arc<[SocketAddr]>,
    /// The IP part of `sockaddrs`, kept separately so it can be returned
    /// without re-mapping.
    addrs: Vec<IpAddr>,
    /// Instant after which this entry is treated as stale.
    expires: Instant,
}

/// Concurrent, TTL-based cache of hostname lookups.
#[derive(Debug)]
pub struct DnsCache {
    /// Hostname -> cached lookup result.
    pub(crate) cache: DashMap<String, DnsEntry>,
    /// How long a freshly inserted entry remains valid.
    pub(crate) ttl: Duration,
}
impl DnsCache {
    /// Creates an empty cache whose entries stay valid for `ttl` after insertion.
    pub fn new(ttl: Duration) -> Self {
        Self {
            cache: DashMap::with_capacity(128),
            ttl,
        }
    }

    /// Resolves `host` to its IP addresses, consulting the cache first.
    ///
    /// A non-expired cache entry is returned directly. Otherwise the system
    /// resolver is queried on Tokio's blocking pool, and a non-empty result is
    /// cached for `ttl` (subject to [`MAX_ENTRIES`]).
    ///
    /// Returns `None` when the lookup fails or yields no addresses. Failures are
    /// not cached.
    pub async fn resolve(&self, host: &str) -> Option<Vec<IpAddr>> {
        // Copy the addresses out in a single statement so the DashMap read guard
        // is dropped before anything else touches the map or awaits.
        let cached = self
            .cache
            .get(host)
            .filter(|entry| entry.expires > Instant::now())
            .map(|entry| entry.addrs.clone());
        if let Some(ips) = cached {
            return Some(ips);
        }
        let resolved = Self::lookup_blocking(host).await?;
        self.insert_resolved(host.to_string(), resolved)
    }

    /// Resolves every host in `hosts` concurrently and caches each non-empty result.
    ///
    /// Lookups that fail, or whose task cannot be joined, are skipped. They do not
    /// prevent the remaining results from being cached.
    pub async fn pre_resolve(&self, hosts: &[&str]) {
        let mut lookups = tokio::task::JoinSet::new();
        for &host in hosts {
            let host = host.to_string();
            lookups.spawn(async move {
                let resolved = Self::lookup_blocking(&host).await;
                (host, resolved)
            });
        }
        // Drain every task. Stopping at the first join error would abort the
        // lookups still in flight and discard their results.
        while let Some(joined) = lookups.join_next().await {
            if let Ok((host, Some(resolved))) = joined {
                self.insert_resolved(host, resolved);
            }
        }
    }

    /// Number of entries currently stored, including expired ones not yet evicted.
    pub fn len(&self) -> usize {
        self.cache.len()
    }

    /// Returns `true` when no entries are stored.
    pub fn is_empty(&self) -> bool {
        self.cache.is_empty()
    }

    /// Removes every entry whose expiry instant has passed.
    pub fn evict_expired(&self) {
        let now = Instant::now();
        self.cache.retain(|_, v| v.expires > now);
    }

    /// Performs a blocking system lookup of `host` on Tokio's blocking pool.
    ///
    /// Returns `None` if the lookup fails or the blocking task cannot be joined.
    async fn lookup_blocking(host: &str) -> Option<Vec<SocketAddr>> {
        // `to_socket_addrs` needs a port; 0 is a placeholder since only the
        // addresses matter.
        let target = format!("{}:0", host);
        tokio::task::spawn_blocking(move || {
            target
                .to_socket_addrs()
                .ok()
                .map(|addrs| addrs.collect::<Vec<SocketAddr>>())
        })
        .await
        .ok()
        .flatten()
    }

    /// Caches a lookup result for `host` and returns its IP addresses.
    ///
    /// Returns `None`, caching nothing, when `resolved` is empty. Room is made
    /// under [`MAX_ENTRIES`] before inserting.
    fn insert_resolved(&self, host: String, resolved: Vec<SocketAddr>) -> Option<Vec<IpAddr>> {
        if resolved.is_empty() {
            return None;
        }
        let ips: Vec<IpAddr> = resolved.iter().map(SocketAddr::ip).collect();
        self.make_room();
        self.cache.insert(
            host,
            DnsEntry {
                sockaddrs: resolved.into(),
                addrs: ips.clone(),
                expires: Instant::now() + self.ttl,
            },
        );
        Some(ips)
    }

    /// Keeps the cache below [`MAX_ENTRIES`].
    ///
    /// Expired entries are purged first. If the map is still full,
    /// `MAX_ENTRIES / 10` entries are dropped in the map's (unspecified)
    /// iteration order.
    fn make_room(&self) {
        if self.cache.len() < MAX_ENTRIES {
            return;
        }
        self.evict_expired();
        if self.cache.len() < MAX_ENTRIES {
            return;
        }
        // Collect the keys first so the iterator's shard guards are released
        // before any removal takes a write lock.
        let victims: Vec<String> = self
            .cache
            .iter()
            .take(MAX_ENTRIES / 10)
            .map(|entry| entry.key().clone())
            .collect();
        for key in victims {
            self.cache.remove(&key);
        }
    }
}
pub struct DnsCacheResolver(pub Arc<DnsCache>);
impl crate::client::dns::Resolve for DnsCacheResolver {
fn resolve(&self, name: crate::client::dns::Name) -> crate::client::dns::Resolving {
let host = name.as_str().to_string();
let cache = self.0.clone();
Box::pin(async move {
let now = Instant::now();
if let Some(entry) = cache.cache.get(&host) {
if entry.expires > now {
let addrs = entry.sockaddrs.clone();
let iter: crate::client::dns::Addrs = Box::new(ArcSocketAddrIter {
inner: addrs,
pos: 0,
});
return Ok(iter);
}
}
let host_for_resolve = format!("{}:0", host);
let result = tokio::task::spawn_blocking(move || {
host_for_resolve
.to_socket_addrs()
.ok()
.map(|addrs| addrs.collect::<Vec<SocketAddr>>())
})
.await
.map_err(|e| -> Box<dyn std::error::Error + Send + Sync> { Box::new(e) })?
.ok_or_else(|| -> Box<dyn std::error::Error + Send + Sync> {
"dns resolution failed".into()
})?;
let ips: Vec<IpAddr> = result.iter().map(|s| s.ip()).collect();
let sockaddrs: Arc<[SocketAddr]> = result.into();
if !ips.is_empty() {
cache.cache.insert(
host,
DnsEntry {
sockaddrs: sockaddrs.clone(),
addrs: ips,
expires: Instant::now() + cache.ttl,
},
);
}
let iter: crate::client::dns::Addrs = Box::new(ArcSocketAddrIter {
inner: sockaddrs,
pos: 0,
});
Ok(iter)
})
}
}
/// Iterator over a shared, immutable slice of socket addresses.
///
/// Holding an `Arc<[SocketAddr]>` lets cached lookup results be handed out by
/// bumping a reference count instead of copying the address list.
struct ArcSocketAddrIter {
    /// The addresses being iterated.
    inner: Arc<[SocketAddr]>,
    /// Index of the next address to yield.
    pos: usize,
}

impl Iterator for ArcSocketAddrIter {
    type Item = SocketAddr;

    fn next(&mut self) -> Option<SocketAddr> {
        let addr = self.inner.get(self.pos).copied()?;
        self.pos += 1;
        Some(addr)
    }

    /// Exact: the number of addresses not yet yielded, so consumers such as
    /// `collect` can preallocate.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.inner.len().saturating_sub(self.pos);
        (remaining, Some(remaining))
    }
}

impl ExactSizeIterator for ArcSocketAddrIter {}

// Once `pos` reaches the end it never moves back, so `None` repeats forever.
impl std::iter::FusedIterator for ArcSocketAddrIter {}
/// Returns the process-wide [`DnsCacheResolver`], building it on first call.
///
/// Every call hands back a clone of the same `Arc`, so all callers share one
/// cache configured with a 300-second TTL.
pub fn shared_dns_cache() -> Arc<DnsCacheResolver> {
    use std::sync::OnceLock;

    static SHARED: OnceLock<Arc<DnsCacheResolver>> = OnceLock::new();

    let resolver = SHARED.get_or_init(|| {
        let ttl = Duration::from_secs(300);
        let cache = Arc::new(DnsCache::new(ttl));
        Arc::new(DnsCacheResolver(cache))
    });
    Arc::clone(resolver)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn resolve_localhost_returns_result() {
        let cache = DnsCache::new(Duration::from_secs(60));
        let result = cache.resolve("localhost").await;
        assert!(result.is_some());
        assert!(!result.unwrap().is_empty());
    }

    #[tokio::test]
    async fn cache_hit_returns_same_result() {
        let cache = DnsCache::new(Duration::from_secs(60));
        let first = cache.resolve("localhost").await;
        assert_eq!(cache.len(), 1);
        let second = cache.resolve("localhost").await;
        assert_eq!(first, second);
        assert_eq!(cache.len(), 1);
    }

    #[tokio::test]
    async fn expired_entry_triggers_re_resolve() {
        let cache = DnsCache::new(Duration::from_millis(1));
        let _ = cache.resolve("localhost").await;
        assert_eq!(cache.len(), 1);
        tokio::time::sleep(Duration::from_millis(10)).await;
        let result = cache.resolve("localhost").await;
        assert!(result.is_some());
        assert_eq!(cache.len(), 1);
    }

    #[tokio::test]
    async fn unknown_host_returns_none() {
        let cache = DnsCache::new(Duration::from_secs(60));
        let result = cache
            .resolve("this.host.definitely.does.not.exist.example")
            .await;
        assert!(result.is_none());
        assert!(cache.is_empty());
    }

    #[tokio::test]
    async fn pre_resolve_populates_cache() {
        let cache = DnsCache::new(Duration::from_secs(60));
        cache.pre_resolve(&["localhost"]).await;
        assert_eq!(cache.len(), 1);
        let result = cache.resolve("localhost").await;
        assert!(result.is_some());
    }

    #[tokio::test]
    async fn pre_resolve_skips_unresolvable_hosts() {
        let cache = DnsCache::new(Duration::from_secs(60));
        cache
            .pre_resolve(&["localhost", "this.host.definitely.does.not.exist.example"])
            .await;
        // Only the resolvable host is cached; the failure does not affect it.
        assert_eq!(cache.len(), 1);
        assert!(cache.cache.contains_key("localhost"));
    }

    #[tokio::test]
    async fn evict_expired_removes_stale_entries() {
        let cache = DnsCache::new(Duration::from_millis(1));
        let _ = cache.resolve("localhost").await;
        assert_eq!(cache.len(), 1);
        tokio::time::sleep(Duration::from_millis(10)).await;
        cache.evict_expired();
        assert!(cache.is_empty());
    }

    #[test]
    fn new_cache_is_empty() {
        let cache = DnsCache::new(Duration::from_secs(60));
        assert!(cache.is_empty());
        assert_eq!(cache.len(), 0);
    }

    #[tokio::test]
    async fn resolver_cache_hit_returns_socket_addrs() {
        let dns = Arc::new(DnsCache::new(Duration::from_secs(60)));
        let _ = dns.resolve("localhost").await;
        assert_eq!(dns.len(), 1);
        let resolver = DnsCacheResolver(dns);
        let name = "localhost".parse().expect("valid name");
        let addrs: Vec<SocketAddr> = crate::client::dns::Resolve::resolve(&resolver, name)
            .await
            .expect("should resolve")
            .collect();
        assert!(!addrs.is_empty());
    }

    #[tokio::test]
    async fn resolver_cache_miss_populates_cache() {
        let dns = Arc::new(DnsCache::new(Duration::from_secs(60)));
        let resolver = DnsCacheResolver(Arc::clone(&dns));
        let name = "localhost".parse().expect("valid name");
        let addrs: Vec<SocketAddr> = crate::client::dns::Resolve::resolve(&resolver, name)
            .await
            .expect("should resolve")
            .collect();
        assert!(!addrs.is_empty());
        // The miss path must have stored the result for later callers.
        assert_eq!(dns.len(), 1);
    }

    #[test]
    fn arc_socket_addr_iter_yields_in_order() {
        let a: SocketAddr = "127.0.0.1:0".parse().expect("valid addr");
        let b: SocketAddr = "[::1]:0".parse().expect("valid addr");
        let iter = ArcSocketAddrIter {
            inner: Arc::from(vec![a, b]),
            pos: 0,
        };
        assert_eq!(iter.collect::<Vec<_>>(), vec![a, b]);
    }
}