use std::net::{IpAddr, SocketAddr};
use std::sync::Arc;
use std::time::{Duration, Instant};
use dashmap::DashMap;
use lazy_static::lazy_static;
use crate::error::{CdpError, Result};
/// Maximum age at which a cached resolution is still served (and survives a sweep).
const TTL: Duration = Duration::from_secs(30);
/// `store` skips inserting while the map (after any sweep) holds this many entries or
/// more; the freshly looked-up addresses are still returned to the caller.
const HARD_CAP: usize = 1024;
/// Once the map holds at least this many entries, `store` sweeps out expired ones
/// before inserting.
const REAP_AT: usize = 256;
/// One cached resolution result together with the moment it was stored.
#[derive(Clone)]
struct Entry {
    /// When `store` inserted this entry; its age is compared against `TTL`.
    at: Instant,
    /// Addresses produced by the lookup, shared via `Arc` so hits are cheap to hand out.
    addrs: Arc<[SocketAddr]>,
}
lazy_static! {
/// Process-wide resolution cache, keyed by `key(host, port)` (i.e. `"host:port"`).
/// Entries older than `TTL` are not served by `fresh`; `store` sweeps expired entries
/// once the map reaches `REAP_AT` entries and stops inserting at `HARD_CAP`.
static ref CACHE: DashMap<String, Entry> = DashMap::new();
}
/// Builds the cache key for a host/port pair as `"{host}:{port}"`.
///
/// The port is always the text after the last `:`, so keys stay unambiguous even
/// for IPv6 literal hosts such as `::1`.
#[inline]
fn key(host: &str, port: u16) -> String {
    let port = port.to_string();
    [host, port.as_str()].join(":")
}
/// Reports whether the websocket DNS cache is switched on.
///
/// `CHROMEY_WS_DNS_CACHE` is read on the first call only. The answer is memoised in a
/// `OnceLock`, so later changes to the environment are not observed. The cache is on
/// by default. It is switched off when the variable, after trimming surrounding
/// whitespace, equals `0`, `false`, `no` or `off`, compared ignoring ASCII case (so
/// `FALSE` and `Off` also disable it).
#[inline]
pub fn enabled() -> bool {
    use std::sync::OnceLock;
    static ON: OnceLock<bool> = OnceLock::new();
    *ON.get_or_init(|| flag_enabled(std::env::var("CHROMEY_WS_DNS_CACHE").ok().as_deref()))
}

/// Interprets a raw `CHROMEY_WS_DNS_CACHE` value. `None` (unset or not valid Unicode)
/// means enabled. Any value other than a recognised "off" word also means enabled.
fn flag_enabled(raw: Option<&str>) -> bool {
    const OFF_VALUES: [&str; 4] = ["0", "false", "no", "off"];
    match raw.map(str::trim) {
        // Case-insensitive so that e.g. `FALSE` or `Off` is not silently ignored.
        Some(value) => !OFF_VALUES.iter().any(|off| value.eq_ignore_ascii_case(off)),
        None => true,
    }
}
/// Returns `true` when `host` parses as a bare IPv4 or IPv6 address (no brackets,
/// no port), i.e. when no name resolution would be needed for it.
#[inline]
pub fn is_ip_literal(host: &str) -> bool {
    let parsed: Option<IpAddr> = host.parse().ok();
    parsed.is_some()
}
/// Resolves `host:port` to socket addresses through the process-wide cache.
///
/// A cached list no older than `TTL` is returned directly. Otherwise a lookup is
/// performed and, if it succeeds, the result is stored for later calls. Failed
/// lookups are not cached. This function does not consult `enabled()` itself.
///
/// # Errors
/// Propagates the lookup error: an I/O error from the resolver, or an error when
/// the resolver yields no addresses.
pub async fn resolve(host: &str, port: u16) -> Result<Arc<[SocketAddr]>> {
    let cache_key = key(host, port);
    match fresh(&cache_key) {
        Some(hit) => Ok(hit),
        None => {
            let resolved = lookup(host, port).await?;
            store(cache_key, Arc::clone(&resolved));
            Ok(resolved)
        }
    }
}
/// Drops any cached entry for `host:port`, so the next `resolve` for that pair
/// performs a fresh lookup. Does nothing if no entry exists.
#[inline]
pub fn invalidate(host: &str, port: u16) {
    let stale = key(host, port);
    CACHE.remove(stale.as_str());
}
/// Returns the cached addresses stored under `k` if the entry exists and is no
/// older than `TTL`.
///
/// Expired entries are left in the map. `store` later overwrites them or sweeps
/// them out.
fn fresh(k: &str) -> Option<Arc<[SocketAddr]>> {
    CACHE
        .get(k)
        .filter(|entry| entry.at.elapsed() <= TTL)
        .map(|entry| Arc::clone(&entry.addrs))
}
/// Inserts `addrs` under key `k`, timestamped with the current instant.
///
/// Once the map holds `REAP_AT` or more entries, entries older than `TTL` are swept
/// first. If the map still holds `HARD_CAP` or more entries, nothing is inserted.
/// NOTE: the size checks and the insert are separate map operations, not one atomic
/// step.
fn store(k: String, addrs: Arc<[SocketAddr]>) {
    let needs_sweep = CACHE.len() >= REAP_AT;
    if needs_sweep {
        CACHE.retain(|_, entry| entry.at.elapsed() <= TTL);
    }
    if CACHE.len() >= HARD_CAP {
        return;
    }
    let entry = Entry { addrs, at: Instant::now() };
    CACHE.insert(k, entry);
}
/// Resolves `host:port` with `tokio::net::lookup_host`, bypassing the cache.
///
/// # Errors
/// Returns `CdpError::Io` if the resolver fails. Returns a message error if the
/// resolver succeeds but yields no addresses.
async fn lookup(host: &str, port: u16) -> Result<Arc<[SocketAddr]>> {
    let found: Box<[SocketAddr]> = tokio::net::lookup_host((host, port))
        .await
        .map_err(CdpError::Io)?
        .collect();
    if found.is_empty() {
        Err(CdpError::msg(format!("dns: no addresses for {host}:{port}")))
    } else {
        Ok(Arc::from(found))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::Ipv4Addr;

    /// Loopback socket address on `port`, used as a stand-in lookup result.
    fn loopback(port: u16) -> SocketAddr {
        SocketAddr::from((Ipv4Addr::LOCALHOST, port))
    }

    /// Wraps a single loopback address in the shared-slice shape `store` takes.
    fn one(port: u16) -> Arc<[SocketAddr]> {
        Arc::from(vec![loopback(port)])
    }

    #[test]
    fn ip_literals_are_recognised() {
        for literal in ["127.0.0.1", "::1"] {
            assert!(is_ip_literal(literal));
        }
        for name in ["localhost", "chrome.internal"] {
            assert!(!is_ip_literal(name));
        }
    }

    #[test]
    fn fresh_returns_cached_then_invalidate_clears() {
        let cached = key("unit-host", 9222);
        store(cached.clone(), one(9222));
        assert!(fresh(&cached).is_some(), "fresh entry should be served");
        invalidate("unit-host", 9222);
        assert!(fresh(&cached).is_none(), "invalidated entry must be gone");
    }

    #[test]
    fn key_includes_port_so_same_host_different_port_are_distinct() {
        let (stored, other) = (key("h2", 9222), key("h2", 9223));
        assert_ne!(stored, other);
        store(stored.clone(), one(9222));
        assert!(fresh(&stored).is_some());
        assert!(fresh(&other).is_none());
        invalidate("h2", 9222);
    }
}