nodecraft/resolver/impls/
dns.rsuse core::time::Duration;
use std::{io, net::SocketAddr};
use agnostic::Runtime;
pub use agnostic::{
dns::{AsyncConnectionProvider, Dns, ResolverConfig, ResolverOpts},
net::Net,
};
use crossbeam_skiplist::SkipMap;
use super::{super::AddressResolver, CachedSocketAddr};
use crate::{DnsName, Kind, NodeAddress};
/// Private classification of the failures wrapped by [`ResolveError`].
#[derive(Debug, thiserror::Error)]
enum ResolveErrorKind {
/// No socket address could be produced for the given DNS name.
#[error("cannot resolve an ip address for {0}")]
NotFound(DnsName),
/// A failure reported by `hickory_resolver`; its message is shown as-is.
#[error(transparent)]
Resolve(#[from] hickory_resolver::error::ResolveError),
}
/// Public, opaque error describing why a DNS name could not be resolved.
///
/// It wraps the private [`ResolveErrorKind`], so the concrete failure kinds
/// are not exposed to users of this module.
#[derive(Debug)]
#[repr(transparent)]
pub struct ResolveError(ResolveErrorKind);
impl core::fmt::Display for ResolveError {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> std::fmt::Result {
self.0.fmt(f)
}
}
impl core::error::Error for ResolveError {}
impl From<ResolveErrorKind> for ResolveError {
fn from(value: ResolveErrorKind) -> Self {
Self(value)
}
}
/// Errors returned by [`DnsResolver`].
#[derive(Debug, thiserror::Error)]
pub enum Error {
/// An I/O error; its message is shown as-is.
#[error(transparent)]
IO(#[from] io::Error),
/// A name resolution failure; its message is shown as-is.
#[error(transparent)]
Resolve(#[from] ResolveError),
}
/// Configuration used to build the DNS client held by [`DnsResolver`].
///
/// Serializable and deserializable when the `serde` feature is enabled.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct DnsOptions {
// Resolver options passed to `Dns::new`.
resolver_opts: ResolverOpts,
// Resolver configuration passed to `Dns::new`.
resolver_config: ResolverConfig,
}
/// Returns the default time-to-live for cached records: 60 seconds.
///
/// Also used as the `serde` default for `DnsResolverOptions::record_ttl`.
const fn default_record_ttl() -> Duration {
  const DEFAULT_TTL_SECS: u64 = 60;
  Duration::from_secs(DEFAULT_TTL_SECS)
}
impl DnsOptions {
  /// Creates options built from `ResolverOpts::default()` and
  /// `ResolverConfig::default()`.
  pub fn new() -> Self {
    let resolver_opts = ResolverOpts::default();
    let resolver_config = ResolverConfig::default();
    Self {
      resolver_opts,
      resolver_config,
    }
  }

  /// Replaces the resolver configuration (builder style, consumes `self`).
  pub fn with_resolver_config(self, c: ResolverConfig) -> Self {
    Self {
      resolver_config: c,
      ..self
    }
  }

  /// Replaces the resolver configuration in place and returns `self` for chaining.
  pub fn set_resolver_config(&mut self, c: ResolverConfig) -> &mut Self {
    let this = self;
    this.resolver_config = c;
    this
  }

  /// Returns the current resolver configuration.
  pub fn resolver_config(&self) -> &ResolverConfig {
    let Self {
      resolver_config, ..
    } = self;
    resolver_config
  }

  /// Replaces the resolver options (builder style, consumes `self`).
  pub fn with_resolver_opts(self, o: ResolverOpts) -> Self {
    Self {
      resolver_opts: o,
      ..self
    }
  }

  /// Replaces the resolver options in place and returns `self` for chaining.
  pub fn set_resolver_opts(&mut self, o: ResolverOpts) -> &mut Self {
    let this = self;
    this.resolver_opts = o;
    this
  }

  /// Returns the current resolver options.
  pub fn resolver_opts(&self) -> &ResolverOpts {
    let Self { resolver_opts, .. } = self;
    resolver_opts
  }
}
impl Default for DnsOptions {
fn default() -> Self {
Self::new()
}
}
/// Options used to construct a [`DnsResolver`].
///
/// Serializable and deserializable when the `serde` feature is enabled.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct DnsResolverOptions {
// Time-to-live given to each cached record; when absent from deserialized
// input it falls back to `default_record_ttl()`.
#[cfg_attr(feature = "serde", serde(default = "default_record_ttl"))]
record_ttl: Duration,
// DNS client configuration; `None` means the resolver is built without a
// DNS client.
dns: Option<DnsOptions>,
}
impl Default for DnsResolverOptions {
fn default() -> Self {
Self::new()
}
}
impl DnsResolverOptions {
  /// Creates options with the default record TTL and default DNS options.
  #[inline]
  pub fn new() -> Self {
    let record_ttl = default_record_ttl();
    let dns = Some(DnsOptions::default());
    Self { record_ttl, dns }
  }

  /// Sets the TTL applied to cached records (builder style, consumes `self`).
  #[inline]
  pub const fn with_record_ttl(mut self, ttl: Duration) -> Self {
    // Plain field assignment keeps this usable as a `const fn`.
    self.record_ttl = ttl;
    self
  }

  /// Sets the TTL applied to cached records in place; returns `self` for chaining.
  #[inline]
  pub fn set_record_ttl(&mut self, ttl: Duration) -> &mut Self {
    let this = self;
    this.record_ttl = ttl;
    this
  }

  /// Returns the TTL applied to cached records.
  #[inline]
  pub const fn record_ttl(&self) -> Duration {
    let Self { record_ttl, .. } = self;
    *record_ttl
  }

  /// Sets the DNS options, or removes them with `None` (builder style).
  #[inline]
  pub fn with_dns(self, dns: Option<DnsOptions>) -> Self {
    Self { dns, ..self }
  }

  /// Sets the DNS options, or removes them with `None`, in place; returns
  /// `self` for chaining.
  #[inline]
  pub fn set_dns(&mut self, dns: Option<DnsOptions>) -> &mut Self {
    let this = self;
    this.dns = dns;
    this
  }

  /// Returns the DNS options, if any are set.
  #[inline]
  pub const fn dns(&self) -> Option<&DnsOptions> {
    match &self.dns {
      Some(options) => Some(options),
      None => None,
    }
  }
}
/// An address resolver that turns a [`NodeAddress`] into a `SocketAddr`,
/// caching the results obtained for DNS names.
pub struct DnsResolver<R: Runtime> {
// Optional DNS client; when `None`, DNS names are resolved only through the
// runtime's `ToSocketAddrs`.
dns: Option<Dns<R::Net>>,
// Time-to-live given to every entry inserted into `cache`.
record_ttl: Duration,
// Resolved addresses keyed by DNS name.
cache: SkipMap<DnsName, CachedSocketAddr>,
}
impl<R: Runtime> AddressResolver for DnsResolver<R> {
type Address = NodeAddress;
type Error = Error;
type ResolvedAddress = SocketAddr;
type Runtime = R;
type Options = DnsResolverOptions;
async fn new(opts: Self::Options) -> Result<Self, Self::Error>
where
Self: Sized,
{
let dns = if let Some(opts) = opts.dns {
Some(Dns::new(
opts.resolver_config,
opts.resolver_opts,
AsyncConnectionProvider::new(),
))
} else {
None
};
Ok(Self {
dns,
record_ttl: opts.record_ttl,
cache: Default::default(),
})
}
async fn resolve(&self, address: &Self::Address) -> Result<Self::ResolvedAddress, Self::Error> {
match &address.kind {
Kind::Ip(ip) => Ok(SocketAddr::new(*ip, address.port)),
Kind::Dns(name) => {
if let Some(ent) = self.cache.get(name.as_str()) {
let val = ent.value();
if !val.is_expired() {
return Ok(val.val);
} else {
ent.remove();
}
}
if let Some(ref dns) = self.dns {
if let Some(ip) = dns
.lookup_ip(name.terminate_str())
.await
.map_err(|e| ResolveError::from(ResolveErrorKind::from(e)))?
.into_iter()
.next()
{
let addr = SocketAddr::new(ip, address.port);
self
.cache
.insert(name.clone(), CachedSocketAddr::new(addr, self.record_ttl));
return Ok(addr);
}
}
let port = address.port;
let tsafe = name.clone();
let res =
agnostic::net::ToSocketAddrs::<R>::to_socket_addrs(&(tsafe.as_str(), port)).await?;
if let Some(addr) = res.into_iter().next() {
self
.cache
.insert(name.clone(), CachedSocketAddr::new(addr, self.record_ttl));
return Ok(addr);
}
Err(Error::Resolve(ResolveError(ResolveErrorKind::NotFound(
name.clone(),
))))
}
}
}
}
#[cfg(test)]
mod tests {
  use super::*;
  use agnostic::tokio::TokioRuntime;

  /// Resolver type exercised by the async tests in this module.
  type TestResolver = DnsResolver<TokioRuntime>;

  /// Reports whether the cached record for `name` has expired.
  ///
  /// Panics when no record for `name` is cached.
  fn cached_record_expired(resolver: &TestResolver, name: &DnsName) -> bool {
    resolver
      .cache
      .get(name.as_str())
      .unwrap()
      .value()
      .is_expired()
  }

  /// Resolving a public DNS name with default options succeeds.
  #[tokio::test]
  async fn test_dns_resolver() {
    let resolver = TestResolver::new(Default::default()).await.unwrap();

    let google_addr = NodeAddress::try_from("google.com:8080").unwrap();
    let ip = resolver.resolve(&google_addr).await.unwrap();
    println!("google.com:8080 resolved to: {}", ip);
  }

  /// A cached record is fresh right after resolution and expired once the
  /// configured TTL has elapsed.
  #[tokio::test]
  async fn test_dns_resolver_with_record_ttl() {
    let ttl = Duration::from_millis(100);
    let options = DnsResolverOptions::default().with_record_ttl(ttl);
    let resolver = TestResolver::new(options).await.unwrap();

    let google_addr = NodeAddress::try_from("google.com:8080").unwrap();
    resolver.resolve(&google_addr).await.unwrap();

    let dns_name = DnsName::try_from("google.com").unwrap();
    assert!(!cached_record_expired(&resolver, &dns_name));

    tokio::time::sleep(ttl).await;
    assert!(cached_record_expired(&resolver, &dns_name));
  }

  /// Without a DNS client the resolver still resolves names and IPs, caches
  /// results, expires them, and reports an error for an unresolvable name.
  #[tokio::test]
  async fn test_dns_resolver_without_dns() {
    let ttl = Duration::from_millis(100);
    let options = DnsResolverOptions::default()
      .with_dns(None)
      .with_record_ttl(ttl);
    let resolver = TestResolver::new(options).await.unwrap();

    let google_addr = NodeAddress::try_from("google.com:8080").unwrap();
    // The second call goes through the cached path.
    resolver.resolve(&google_addr).await.unwrap();
    resolver.resolve(&google_addr).await.unwrap();

    let ip_addr = NodeAddress::try_from(("127.0.0.1", 8080)).unwrap();
    resolver.resolve(&ip_addr).await.unwrap();

    let dns_name = DnsName::try_from("google.com").unwrap();
    assert!(!cached_record_expired(&resolver, &dns_name));

    tokio::time::sleep(ttl).await;
    assert!(cached_record_expired(&resolver, &dns_name));

    // Resolving again after expiry takes the stale-entry path.
    resolver.resolve(&google_addr).await.unwrap();

    let err = ResolveError::from(ResolveErrorKind::NotFound(dns_name.clone()));
    println!("{err}");
    println!("{err:?}");

    let bad_addr = NodeAddress::try_from("adasdjkljasidjaosdjaisudnaisudibasd.com:8080").unwrap();
    let outcome = resolver.resolve(&bad_addr).await;
    assert!(outcome.is_err());
  }

  /// Smoke test touching every builder, setter and getter of the option types.
  #[test]
  fn test_opts() {
    let dns_opts = DnsOptions::new().with_resolver_config(Default::default());
    dns_opts.resolver_config();

    let mut dns_opts = dns_opts.with_resolver_opts(Default::default());
    dns_opts.resolver_opts();
    dns_opts.set_resolver_config(Default::default());
    dns_opts.set_resolver_opts(Default::default());

    let mut resolver_opts = DnsResolverOptions::new().with_dns(Some(dns_opts));
    resolver_opts.dns();
    resolver_opts.set_dns(Some(Default::default()));
    resolver_opts.set_record_ttl(Duration::from_secs(100));
    resolver_opts.record_ttl();
  }
}