use dashmap::DashMap;
use std::net::{IpAddr, SocketAddr};
use std::sync::Arc;
use std::time::{Duration, Instant};
/// Entry count at which `DnsCache::resolve` evicts entries before inserting a new one.
const MAX_ENTRIES: usize = 5_000;
/// Starting interval (seconds, before jitter) of the proxy DNS refresh loop; also the
/// value the interval is reset to when a change is detected.
const PROXY_REFRESH_MIN_SECS: u64 = 30;
/// Upper bound (seconds, before jitter) for the refresh interval as it doubles.
const PROXY_REFRESH_MAX_SECS: u64 = 240;
/// One cached resolution result for a host.
pub(crate) struct DnsEntry {
// The resolved addresses as socket addresses with port 0, behind an `Arc` so the
// resolver adapter can hand them out by cloning the pointer.
sockaddrs: Arc<[SocketAddr]>,
// The resolved addresses, cloned out to callers of `DnsCache::resolve`.
addrs: Vec<IpAddr>,
// Instant after which the entry is treated as stale (insertion time + cache TTL).
expires: Instant,
}
/// Returns the process-wide hickory resolver, constructing it on first call.
///
/// The resolver is built from `ResolverConfig::default()` and
/// `ResolverOpts::default()` with the Tokio connection provider, and is stored
/// in a `OnceLock` so every caller shares the same instance.
fn async_resolver() -> &'static hickory_resolver::TokioResolver {
    use hickory_resolver::name_server::TokioConnectionProvider;
    use std::sync::OnceLock;

    static RESOLVER: OnceLock<hickory_resolver::TokioResolver> = OnceLock::new();

    RESOLVER.get_or_init(|| {
        let config = hickory_resolver::config::ResolverConfig::default();
        let provider = TokioConnectionProvider::default();
        let builder = hickory_resolver::Resolver::builder_with_config(config, provider);
        let options = hickory_resolver::config::ResolverOpts::default();
        builder.with_options(options).build()
    })
}
/// Concurrent host -> addresses cache with a single TTL applied to every entry.
pub struct DnsCache {
// Keyed by the host string exactly as it was passed in (no normalisation is applied).
pub(crate) cache: DashMap<String, DnsEntry>,
// Lifetime given to each entry at insertion time.
pub(crate) ttl: Duration,
}
impl DnsCache {
pub fn new(ttl: Duration) -> Self {
Self {
cache: DashMap::with_capacity(128),
ttl,
}
}
pub async fn resolve(&self, host: &str) -> Option<Vec<IpAddr>> {
if let Some(entry) = self.cache.get(host) {
if entry.expires > Instant::now() {
return Some(entry.addrs.clone());
}
}
let lookup = async_resolver().lookup_ip(host).await.ok()?;
let ips: Vec<IpAddr> = lookup.iter().collect();
if ips.is_empty() {
return None;
}
let sockaddrs: Vec<SocketAddr> = ips.iter().map(|ip| SocketAddr::new(*ip, 0)).collect();
if self.cache.len() >= MAX_ENTRIES {
self.evict_expired();
if self.cache.len() >= MAX_ENTRIES {
let to_remove = MAX_ENTRIES / 10;
let keys_to_remove: Vec<String> = self
.cache
.iter()
.take(to_remove)
.map(|entry| entry.key().clone())
.collect();
for k in keys_to_remove {
self.cache.remove(&k);
}
}
}
self.cache.insert(
host.to_string(),
DnsEntry {
sockaddrs: sockaddrs.into(),
addrs: ips.clone(),
expires: Instant::now() + self.ttl,
},
);
Some(ips)
}
pub async fn pre_resolve(&self, hosts: &[&str]) {
let mut set = tokio::task::JoinSet::new();
for &host in hosts {
let host = host.to_string();
let ttl = self.ttl;
set.spawn(async move {
let lookup = async_resolver().lookup_ip(&host).await.ok()?;
let ips: Vec<IpAddr> = lookup.iter().collect();
if ips.is_empty() {
return None;
}
let sockaddrs: Vec<SocketAddr> =
ips.iter().map(|ip| SocketAddr::new(*ip, 0)).collect();
Some((host, sockaddrs, ips, ttl))
});
}
while let Some(Ok(Some((host, sockaddrs, ips, ttl)))) = set.join_next().await {
self.cache.insert(
host,
DnsEntry {
sockaddrs: sockaddrs.into(),
addrs: ips,
expires: Instant::now() + ttl,
},
);
}
}
pub fn len(&self) -> usize {
self.cache.len()
}
pub fn is_empty(&self) -> bool {
self.cache.is_empty()
}
pub fn evict_expired(&self) {
let now = Instant::now();
self.cache.retain(|_, v| v.expires > now);
}
pub fn invalidate(&self, host: &str) {
self.cache.remove(host);
}
pub async fn resolve_or_refresh(&self, host: &str) -> Option<Vec<IpAddr>> {
if let Some(ips) = self.resolve(host).await {
return Some(ips);
}
None
}
pub(crate) async fn resolve_hash(&self, host: &str) -> Option<u64> {
let ips = self.resolve(host).await?;
let mut hash: u64 = 0xcbf29ce484222325;
for ip in ips {
let bits = match ip {
IpAddr::V4(v4) => u32::from(v4) as u64,
IpAddr::V6(v6) => {
let b = u128::from(v6);
(b as u64) ^ ((b >> 64) as u64)
}
};
hash ^= bits.wrapping_mul(0x100000001b3);
}
Some(hash)
}
}
/// Wrapper around a shared [`DnsCache`] that implements `crate::client::dns::Resolve`.
pub struct DnsCacheResolver(pub Arc<DnsCache>);
impl crate::client::dns::Resolve for DnsCacheResolver {
fn resolve(&self, name: crate::client::dns::Name) -> crate::client::dns::Resolving {
let host = name.as_str().to_string();
let cache = self.0.clone();
Box::pin(async move {
let now = Instant::now();
if let Some(entry) = cache.cache.get(&host) {
if entry.expires > now {
let addrs = entry.sockaddrs.clone();
let iter: crate::client::dns::Addrs = Box::new(ArcSocketAddrIter {
inner: addrs,
pos: 0,
});
return Ok(iter);
}
}
let lookup = async_resolver()
.lookup_ip(&host)
.await
.map_err(|e| -> Box<dyn std::error::Error + Send + Sync> { Box::new(e) })?;
let ips: Vec<IpAddr> = lookup.iter().collect();
if ips.is_empty() {
return Err("dns resolution returned no addresses".into());
}
let sockaddrs: Arc<[SocketAddr]> = ips
.iter()
.map(|ip| SocketAddr::new(*ip, 0))
.collect::<Vec<_>>()
.into();
cache.cache.insert(
host,
DnsEntry {
sockaddrs: sockaddrs.clone(),
addrs: ips,
expires: Instant::now() + cache.ttl,
},
);
let iter: crate::client::dns::Addrs = Box::new(ArcSocketAddrIter {
inner: sockaddrs,
pos: 0,
});
Ok(iter)
})
}
}
/// Iterator over a shared slice of socket addresses that yields them by value
/// without copying the slice.
struct ArcSocketAddrIter {
    /// The shared addresses being iterated.
    inner: Arc<[SocketAddr]>,
    /// Index of the next address to yield.
    pos: usize,
}

impl Iterator for ArcSocketAddrIter {
    type Item = SocketAddr;

    /// Yields the next address, or `None` once the slice is exhausted.
    fn next(&mut self) -> Option<SocketAddr> {
        let addr = self.inner.get(self.pos).copied()?;
        self.pos += 1;
        Some(addr)
    }

    /// Reports the exact number of addresses left so consumers such as
    /// `collect` can allocate once.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.inner.len().saturating_sub(self.pos);
        (remaining, Some(remaining))
    }
}
impl DnsCacheResolver {
fn parse_proxy_host(proxy_url: &str) -> Option<String> {
url::Url::parse(proxy_url)
.ok()
.and_then(|u| u.host_str().map(|h| h.to_string()))
}
pub async fn prefetch_proxy_hosts(&self, proxy_urls: &[String]) {
let hosts: Vec<&str> = proxy_urls
.iter()
.filter_map(|u| {
url::Url::parse(u)
.ok()
.and_then(|parsed| parsed.host_str().map(|_| u.as_str()))
})
.collect::<Vec<_>>();
let mut unique = Vec::with_capacity(hosts.len());
let mut seen = std::collections::HashSet::new();
for url_str in proxy_urls {
if let Some(host) = Self::parse_proxy_host(url_str) {
if seen.insert(host.clone()) {
unique.push(host);
}
}
}
if !unique.is_empty() {
let refs: Vec<&str> = unique.iter().map(|s| s.as_str()).collect();
self.0.pre_resolve(&refs).await;
}
}
pub fn invalidate_proxy(&self, proxy_url: &str) {
if let Some(host) = Self::parse_proxy_host(proxy_url) {
self.0.invalidate(&host);
}
}
pub fn spawn_proxy_dns_refresh(&self, proxy_urls: &[String]) -> tokio::task::JoinHandle<()> {
let mut unique_hosts = Vec::new();
let mut seen = std::collections::HashSet::new();
for url_str in proxy_urls {
if let Some(host) = Self::parse_proxy_host(url_str) {
if seen.insert(host.clone()) {
unique_hosts.push(host);
}
}
}
if unique_hosts.is_empty() {
return tokio::spawn(async {});
}
let cache = self.0.clone();
tokio::spawn(async move {
let mut last_hashes: Vec<Option<u64>> = Vec::with_capacity(unique_hosts.len());
for host in &unique_hosts {
last_hashes.push(cache.resolve_hash(host).await);
}
let mut interval_secs = PROXY_REFRESH_MIN_SECS;
let mut jitter_counter: u64 = 0;
loop {
jitter_counter = jitter_counter.wrapping_add(0x9e3779b97f4a7c15);
let factor = 0.8 + (((jitter_counter >> 33) as f64) / (u32::MAX as f64)) * 0.4;
let sleep_ms = (interval_secs as f64 * factor * 1000.0) as u64;
tokio::time::sleep(Duration::from_millis(sleep_ms)).await;
let mut any_changed = false;
for (i, host) in unique_hosts.iter().enumerate() {
let current = cache.resolve_hash(host).await;
if current != last_hashes[i] {
log::info!("proxy DNS changed for {host}, resetting refresh interval");
last_hashes[i] = current;
any_changed = true;
}
}
if any_changed {
interval_secs = PROXY_REFRESH_MIN_SECS;
} else {
interval_secs = (interval_secs * 2).min(PROXY_REFRESH_MAX_SECS);
}
}
})
}
}
/// Returns a handle to the process-wide DNS cache resolver.
///
/// The resolver is created on first call with a 300-second TTL; every call
/// returns a clone of the same `Arc`.
pub fn shared_dns_cache() -> Arc<DnsCacheResolver> {
    use std::sync::OnceLock;

    static CACHE: OnceLock<Arc<DnsCacheResolver>> = OnceLock::new();

    let shared = CACHE.get_or_init(|| {
        let ttl = Duration::from_secs(300);
        let dns = Arc::new(DnsCache::new(ttl));
        Arc::new(DnsCacheResolver(dns))
    });
    Arc::clone(shared)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Cache with the one-minute TTL used by most tests.
    fn minute_cache() -> DnsCache {
        DnsCache::new(Duration::from_secs(60))
    }

    /// A first lookup of `localhost` yields at least one address.
    #[tokio::test]
    async fn resolve_localhost_returns_result() {
        let dns = minute_cache();
        let ips = dns.resolve("localhost").await;
        assert!(ips.is_some());
        assert!(!ips.unwrap().is_empty());
    }

    /// A second lookup returns the same addresses and adds no entry.
    #[tokio::test]
    async fn cache_hit_returns_same_result() {
        let dns = minute_cache();
        let initial = dns.resolve("localhost").await;
        assert_eq!(dns.len(), 1);
        let repeated = dns.resolve("localhost").await;
        assert_eq!(initial, repeated);
        assert_eq!(dns.len(), 1);
    }

    /// After the TTL has passed, resolving again still succeeds and the entry
    /// is replaced rather than duplicated.
    #[tokio::test]
    async fn expired_entry_triggers_re_resolve() {
        let dns = DnsCache::new(Duration::from_millis(1));
        let _ = dns.resolve("localhost").await;
        assert_eq!(dns.len(), 1);
        tokio::time::sleep(Duration::from_millis(10)).await;
        let refreshed = dns.resolve("localhost").await;
        assert!(refreshed.is_some());
        assert_eq!(dns.len(), 1);
    }

    /// A failed lookup returns `None` and leaves the cache untouched.
    #[tokio::test]
    async fn unknown_host_returns_none() {
        let dns = minute_cache();
        let missing = dns
            .resolve("this.host.definitely.does.not.exist.example")
            .await;
        assert!(missing.is_none());
        assert!(dns.is_empty());
    }

    /// `pre_resolve` stores an entry that a later `resolve` can return.
    #[tokio::test]
    async fn pre_resolve_populates_cache() {
        let dns = minute_cache();
        dns.pre_resolve(&["localhost"]).await;
        assert!(!dns.is_empty());
        let ips = dns.resolve("localhost").await;
        assert!(ips.is_some());
    }

    /// `evict_expired` removes entries whose TTL has passed.
    #[tokio::test]
    async fn evict_expired_removes_stale_entries() {
        let dns = DnsCache::new(Duration::from_millis(1));
        let _ = dns.resolve("localhost").await;
        assert_eq!(dns.len(), 1);
        tokio::time::sleep(Duration::from_millis(10)).await;
        dns.evict_expired();
        assert!(dns.is_empty());
    }

    /// A freshly constructed cache reports itself as empty.
    #[test]
    fn new_cache_is_empty() {
        let dns = minute_cache();
        assert!(dns.is_empty());
        assert_eq!(dns.len(), 0);
    }

    /// The resolver adapter returns socket addresses for a host that is
    /// already cached.
    #[tokio::test]
    async fn resolver_cache_hit_returns_socket_addrs() {
        let dns = Arc::new(minute_cache());
        let _ = dns.resolve("localhost").await;
        assert_eq!(dns.len(), 1);

        let resolver = DnsCacheResolver(dns);
        let name = "localhost".parse().expect("valid name");
        let resolved: Vec<SocketAddr> = crate::client::dns::Resolve::resolve(&resolver, name)
            .await
            .expect("should resolve")
            .collect();
        assert!(!resolved.is_empty());
    }
}