use core::{net::SocketAddr, time::Duration};
use super::{super::AddressResolver, CachedSocketAddr};
use crate::address::{Domain, HostAddr};
use crossbeam_skiplist::SkipMap;
/// Options for constructing a [`HostAddrResolver`].
///
/// With the `serde` feature enabled, `record_ttl` is (de)serialized through
/// `humantime_serde` and falls back to the default TTL when it is absent.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct HostAddrResolverOptions {
    /// How long a resolved domain stays valid in the resolver's cache.
    #[cfg_attr(
        feature = "serde",
        serde(with = "humantime_serde", default = "default_record_ttl")
    )]
    record_ttl: Duration,
}
impl Default for HostAddrResolverOptions {
fn default() -> Self {
Self::new()
}
}
/// Default lifetime of a cached DNS record.
///
/// Kept `const` so it can seed `HostAddrResolverOptions::new` and also act as
/// the serde `default` for the `record_ttl` field.
const fn default_record_ttl() -> Duration {
    const ONE_MINUTE_SECS: u64 = 60;
    Duration::from_secs(ONE_MINUTE_SECS)
}
impl HostAddrResolverOptions {
    /// Creates options with the default record TTL.
    #[inline]
    #[must_use]
    pub const fn new() -> Self {
        Self {
            record_ttl: default_record_ttl(),
        }
    }

    /// Returns these options with the cache record TTL replaced by `val`.
    ///
    /// Builder-style counterpart of [`set_record_ttl`](Self::set_record_ttl).
    #[inline]
    #[must_use = "this returns the updated options; it does not modify anything in place"]
    pub const fn with_record_ttl(mut self, val: Duration) -> Self {
        self.record_ttl = val;
        self
    }

    /// Sets the cache record TTL in place and returns `self` for chaining.
    #[inline]
    pub fn set_record_ttl(&mut self, val: Duration) -> &mut Self {
        self.record_ttl = val;
        self
    }

    /// Returns how long a resolved record stays valid in the cache.
    #[inline]
    pub const fn record_ttl(&self) -> Duration {
        self.record_ttl
    }
}
pub use resolver::HostAddrResolver;
#[cfg(feature = "agnostic")]
mod resolver {
    use super::*;
    use agnostic::{RuntimeLite, net::ToSocketAddrs};
    use hostaddr::Host;

    /// An [`AddressResolver`] that turns [`HostAddr`]s into [`SocketAddr`]s.
    ///
    /// Domain lookups go through the async runtime `R`. Results are cached
    /// per domain for [`HostAddrResolverOptions::record_ttl`].
    pub struct HostAddrResolver<R> {
        /// Resolved addresses keyed by domain name only. The port is *not*
        /// part of the key, so `resolve` re-applies the requested port on
        /// every cache hit.
        cache: SkipMap<Domain, CachedSocketAddr>,
        /// Lifetime given to every newly cached entry.
        record_ttl: Duration,
        /// Ties the resolver to runtime `R` without storing a value of it.
        _marker: std::marker::PhantomData<R>,
    }

    impl<R> Default for HostAddrResolver<R> {
        /// Creates a resolver with [`HostAddrResolverOptions::default`].
        fn default() -> Self {
            Self::new(Default::default())
        }
    }

    impl<R: RuntimeLite> AddressResolver for HostAddrResolver<R> {
        type Address = HostAddr;
        type ResolvedAddress = SocketAddr;
        type Error = std::io::Error;
        type Runtime = R;
        type Options = HostAddrResolverOptions;

        #[inline]
        async fn new(opts: Self::Options) -> Result<Self, Self::Error> {
            Ok(Self {
                record_ttl: opts.record_ttl,
                cache: Default::default(),
                _marker: Default::default(),
            })
        }

        /// Resolves `address` to a socket address.
        ///
        /// IP hosts are combined with the port directly. Domain hosts are
        /// served from the cache while the entry is unexpired. Otherwise they
        /// are looked up through the runtime's `ToSocketAddrs`, and the first
        /// result is cached.
        ///
        /// # Errors
        ///
        /// - `InvalidInput` if `address` has no port.
        /// - `NotFound` if the lookup yields no addresses.
        /// - Any I/O error reported by the lookup itself.
        async fn resolve(&self, address: &Self::Address) -> Result<SocketAddr, Self::Error> {
            let Some(port) = address.port() else {
                return Err(std::io::Error::new(
                    std::io::ErrorKind::InvalidInput,
                    "address missing port",
                ));
            };
            let address: hostaddr::HostAddr<&Domain> = address.into();
            let host = address.host();
            match host {
                Host::Ip(ip) => Ok(SocketAddr::new(*ip, port)),
                Host::Domain(name) => {
                    if let Some(ent) = self.cache.get(name.as_inner()) {
                        let cached = ent.value();
                        if !cached.is_expired() {
                            // The cache is keyed by domain only, so the stored
                            // address carries the port of whichever lookup
                            // populated it. Always answer with the port that
                            // was asked for.
                            let mut addr = cached.val;
                            addr.set_port(port);
                            return Ok(addr);
                        }
                        ent.remove();
                    }
                    let res = ToSocketAddrs::<Self::Runtime>::to_socket_addrs(&(
                        name.as_inner().as_str(),
                        port,
                    ))
                    .await?;
                    if let Some(addr) = res.into_iter().next() {
                        self.cache.insert(
                            (*name).clone(),
                            CachedSocketAddr::new(addr, self.record_ttl),
                        );
                        return Ok(addr);
                    }
                    Err(std::io::Error::new(
                        std::io::ErrorKind::NotFound,
                        format!("failed to resolve {}", name.as_inner().as_str()),
                    ))
                }
            }
        }
    }

    impl<R> HostAddrResolver<R> {
        /// Creates a resolver with an empty cache configured by `opts`.
        ///
        /// Unlike the `AddressResolver::new` constructor, this is synchronous
        /// and infallible.
        pub fn new(opts: HostAddrResolverOptions) -> Self {
            Self {
                record_ttl: opts.record_ttl,
                cache: Default::default(),
                _marker: Default::default(),
            }
        }
    }

    #[cfg(test)]
    mod tests {
        use super::*;

        #[tokio::test]
        async fn test_dns_resolver() {
            use agnostic::tokio::TokioRuntime;
            let resolver = HostAddrResolver::<TokioRuntime>::default();
            let google_addr = HostAddr::try_from("google.com:8080").unwrap();
            let ip = resolver.resolve(&google_addr).await.unwrap();
            #[cfg(feature = "std")]
            println!("google.com:8080 resolved to: {}", ip);
        }

        #[tokio::test]
        async fn test_dns_resolver_with_record_ttl() {
            use agnostic::tokio::TokioRuntime;
            let resolver = HostAddrResolver::<TokioRuntime>::new(
                HostAddrResolverOptions::new().with_record_ttl(Duration::from_millis(100)),
            );
            let google_addr = HostAddr::try_from("google.com:8080").unwrap();
            resolver.resolve(&google_addr).await.unwrap();
            resolver.resolve(&google_addr).await.unwrap();
            let ip_addr = HostAddr::try_from(("127.0.0.1", 8080)).unwrap();
            resolver.resolve(&ip_addr).await.unwrap();
            let dns_name = Domain::try_from("google.com").unwrap();
            assert!(!resolver.cache.get(&dns_name).unwrap().value().is_expired());
            tokio::time::sleep(Duration::from_millis(100)).await;
            assert!(resolver.cache.get(&dns_name).unwrap().value().is_expired());
            resolver.resolve(&google_addr).await.unwrap();
            let bad_addr =
                HostAddr::try_from("adasdjkljasidjaosdjaisudnaisudibasd.com:8080").unwrap();
            assert!(resolver.resolve(&bad_addr).await.is_err());
        }

        /// Regression: a cache hit must not leak the port of an earlier lookup.
        #[tokio::test]
        async fn test_cached_domain_uses_requested_port() {
            use agnostic::tokio::TokioRuntime;
            let resolver = HostAddrResolver::<TokioRuntime>::default();
            let http = HostAddr::try_from("google.com:80").unwrap();
            let https = HostAddr::try_from("google.com:443").unwrap();
            assert_eq!(resolver.resolve(&http).await.unwrap().port(), 80);
            // Served from the cache entry populated by the lookup above.
            assert_eq!(resolver.resolve(&https).await.unwrap().port(), 443);
        }
    }
}
#[cfg(not(feature = "agnostic"))]
mod resolver {
use super::*;
pub struct HostAddrResolver {
cache: SkipMap<Domain, CachedSocketAddr>,
record_ttl: Duration,
}
impl AddressResolver for HostAddrResolver {
type Address = HostAddr;
type ResolvedAddress = SocketAddr;
type Error = std::io::Error;
type Options = HostAddrResolverOptions;
#[inline]
async fn new(opts: Self::Options) -> Result<Self, Self::Error> {
Ok(Self {
record_ttl: opts.record_ttl,
cache: Default::default(),
})
}
async fn resolve(&self, address: &Self::Address) -> Result<SocketAddr, Self::Error> {
match address.as_inner() {
Either::Left(addr) => Ok(addr),
Either::Right((port, name)) => {
if let Some(ent) = self.cache.get(name) {
let val = ent.value();
if !val.is_expired() {
return Ok(val.val);
} else {
ent.remove();
}
}
let res = ToSocketAddrs::to_socket_addrs(&(name.as_str(), port))?;
if let Some(addr) = res.into_iter().next() {
self
.cache
.insert(name.clone(), CachedSocketAddr::new(addr, self.record_ttl));
return Ok(addr);
}
Err(std::io::Error::new(
std::io::ErrorKind::NotFound,
format!("failed to resolve {}", name),
))
}
}
}
}
impl Default for HostAddrResolver {
fn default() -> Self {
Self::new(Default::default())
}
}
impl HostAddrResolver {
pub fn new(opts: HostAddrResolverOptions) -> Self {
Self {
record_ttl: opts.record_ttl,
cache: Default::default(),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_dns_resolver() {
let resolver = HostAddrResolver::default();
let google_addr = HostAddr::try_from("google.com:8080").unwrap();
let ip = resolver.resolve(&google_addr).await.unwrap();
#[cfg(feature = "std")]
println!("google.com:8080 resolved to: {}", ip);
}
#[tokio::test]
async fn test_dns_resolver_with_record_ttl() {
let resolver = HostAddrResolver::new(
HostAddrResolverOptions::new().with_record_ttl(Duration::from_millis(100)),
);
let google_addr = HostAddr::try_from("google.com:8080").unwrap();
resolver.resolve(&google_addr).await.unwrap();
let dns_name = Domain::try_from("google.com").unwrap();
assert!(!resolver.cache.get(&dns_name).unwrap().value().is_expired());
tokio::time::sleep(Duration::from_millis(100)).await;
assert!(resolver.cache.get(&dns_name).unwrap().value().is_expired());
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Exercises the default, builder-style and in-place TTL accessors.
    #[test]
    fn test_opts() {
        let defaults = HostAddrResolverOptions::default();
        assert_eq!(defaults.record_ttl(), default_record_ttl());

        let ten = Duration::from_secs(10);
        let mut custom = defaults.with_record_ttl(ten);
        assert_eq!(custom.record_ttl(), ten);

        let eleven = Duration::from_secs(11);
        assert_eq!(custom.set_record_ttl(eleven).record_ttl(), eleven);
    }
}