use dashmap::DashMap;
use hickory_resolver::TokioResolver;
use std::net::SocketAddr;
use std::sync::Arc;
use std::sync::OnceLock;
use std::time::{Duration, Instant};
use crate::{Error, Result};
/// Cache of resolved socket addresses, keyed by host and port.
///
/// Clones are cheap and share state: the entry map and the lazily-built
/// resolver both live behind an `Arc`.
#[derive(Clone)]
pub struct DnsCache {
    /// Resolver, created on first use and shared by every clone.
    resolver: Arc<OnceLock<TokioResolver>>,
    /// Cached lookup results.
    inner: Arc<DashMap<String, Entry>>,
    /// Lifetime assigned to entries whose DNS answer carried no usable TTL.
    fallback_ttl: Duration,
}
impl Default for DnsCache {
fn default() -> Self {
Self::new(Duration::from_secs(300))
}
}
/// One cached resolution result plus the bookkeeping needed to judge freshness.
#[derive(Clone)]
struct Entry {
    /// Moment the entry was stored.
    inserted: Instant,
    /// How long after `inserted` the entry is still considered fresh.
    ttl: Duration,
    /// Resolved addresses, already paired with the requested port.
    addrs: Vec<SocketAddr>,
}
impl DnsCache {
    /// Creates an empty cache.
    ///
    /// `fallback_ttl` is how long an entry stays fresh when the DNS answer
    /// contains no records from which a TTL can be read. The resolver is not
    /// built here; it is created on the first lookup that needs it.
    pub fn new(fallback_ttl: Duration) -> Self {
        Self {
            inner: Arc::new(DashMap::new()),
            fallback_ttl,
            resolver: Arc::new(OnceLock::new()),
        }
    }

    /// Returns the shared resolver, building it on first use.
    ///
    /// # Errors
    ///
    /// Returns `Error::DnsResolution` (with an empty `host`) if the resolver
    /// builder cannot be created or the build step fails.
    fn resolver(&self) -> Result<&TokioResolver> {
        if let Some(r) = self.resolver.get() {
            return Ok(r);
        }
        // A fallible `get_or_try_init` is not available on stable `OnceLock`, so
        // the resolver is built outside the cell. If two callers race here, both
        // build one; `get_or_init` keeps whichever was stored first and the other
        // is dropped.
        let built = TokioResolver::builder_tokio()
            .map_err(|e| Error::DnsResolution {
                host: String::new(),
                reason: format!("builder: {e}"),
            })?
            .build()
            .map_err(|e| Error::DnsResolution {
                host: String::new(),
                reason: format!("build: {e}"),
            })?;
        Ok(self.resolver.get_or_init(|| built))
    }

    /// Resolves `host` to socket addresses on `port`, consulting the cache first.
    ///
    /// IP literals are returned directly, without touching the cache or the
    /// resolver. This covers the bare forms (`127.0.0.1`, `::1`) and the
    /// bracketed IPv6 form (`[::1]`).
    ///
    /// For names, a cached entry younger than its TTL is returned as-is. The
    /// cache key is the ASCII-lowercased host plus the port, because DNS names
    /// are case-insensitive. Otherwise a fresh lookup is performed and cached.
    /// Its TTL is the smallest TTL among the answer records, or `fallback_ttl`
    /// when there are none, and never less than one second. An expired entry is
    /// not removed eagerly; it is overwritten by the next successful lookup.
    ///
    /// # Errors
    ///
    /// Returns `Error::DnsResolution` in these cases:
    /// - the resolver cannot be built;
    /// - the lookup fails;
    /// - the answer contains no A/AAAA records.
    pub async fn resolve(&self, host: &str, port: u16) -> Result<Vec<SocketAddr>> {
        // URL parsers commonly hand back IPv6 hosts in brackets (RFC 3986 IP-literal);
        // strip them so such hosts are not sent to DNS, where they cannot resolve.
        let literal = host
            .strip_prefix('[')
            .and_then(|h| h.strip_suffix(']'))
            .unwrap_or(host);
        if let Ok(ip) = literal.parse::<std::net::IpAddr>() {
            return Ok(vec![SocketAddr::new(ip, port)]);
        }

        // DNS names compare case-insensitively (RFC 4343), so normalise the key to
        // avoid caching and re-resolving the same name under different spellings.
        let key = format!("{}:{port}", host.to_ascii_lowercase());
        // The DashMap read guard is confined to this `if let` and released before
        // any `.await` below.
        if let Some(entry) = self.inner.get(&key) {
            if entry.inserted.elapsed() < entry.ttl {
                return Ok(entry.addrs.clone());
            }
        }

        let resolver = self.resolver()?;
        let lookup = resolver
            .lookup_ip(host)
            .await
            .map_err(|e| Error::DnsResolution {
                host: host.to_string(),
                reason: e.to_string(),
            })?;

        // Honour the shortest TTL in the answer; clamp to at least one second so a
        // zero TTL does not force a new lookup on every call.
        let ttl = lookup
            .as_lookup()
            .answers()
            .iter()
            .map(|r| Duration::from_secs(u64::from(r.ttl)))
            .min()
            .unwrap_or(self.fallback_ttl)
            .max(Duration::from_secs(1));

        let addrs: Vec<SocketAddr> = lookup.iter().map(|ip| SocketAddr::new(ip, port)).collect();
        if addrs.is_empty() {
            return Err(Error::DnsResolution {
                host: host.to_string(),
                reason: "no A/AAAA records".into(),
            });
        }

        self.inner.insert(
            key,
            Entry {
                addrs: addrs.clone(),
                inserted: Instant::now(),
                ttl,
            },
        );
        Ok(addrs)
    }
}