use std::collections::HashMap;
use std::io;
use std::net::SocketAddr;
use std::net::ToSocketAddrs;
use std::sync::Arc;
use std::time::Duration;
use std::time::SystemTime;
use futures::future;
use parking_lot::Mutex;
/// Blocking resolver for `ureq` that uses `std`'s `ToSocketAddrs` lookup and
/// memoizes successful results in a `DnsCache` keyed by the full `netloc`
/// string (host and port as passed in by `ureq`).
#[derive(Default)]
pub struct StdDnsResolver {
    cache: DnsCache,
}

impl ureq::Resolver for StdDnsResolver {
    /// Returns the cached addresses for `netloc` when a fresh entry exists;
    /// otherwise performs a blocking lookup and caches the result.
    ///
    /// # Errors
    /// Propagates the `io::Error` from `ToSocketAddrs::to_socket_addrs`
    /// unchanged; failed lookups are not cached.
    fn resolve(&self, netloc: &str) -> io::Result<Vec<SocketAddr>> {
        match self.cache.get(netloc) {
            Some(cached) => Ok(cached),
            None => {
                let addrs: Vec<SocketAddr> = netloc.to_socket_addrs()?.collect();
                self.cache.insert(netloc, addrs.clone());
                Ok(addrs)
            }
        }
    }
}
/// Async `reqwest` resolver that runs `std`'s blocking `ToSocketAddrs` lookup on
/// the blocking pool of a dedicated current-thread tokio runtime, memoizing
/// results in a `DnsCache` keyed by host name.
pub struct AsyncStdDnsResolver {
    cache: DnsCache,
    // Set to `Some` by `Default`; `Drop` takes it out so the runtime can be
    // shut down by value.
    runtime: Option<tokio::runtime::Runtime>,
}

impl Default for AsyncStdDnsResolver {
    /// Creates an empty cache and a current-thread runtime with all drivers enabled.
    ///
    /// # Panics
    /// Panics with "build dns runtime failed" if the tokio runtime cannot be built.
    fn default() -> Self {
        let runtime = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .expect("build dns runtime failed");
        Self {
            runtime: Some(runtime),
            cache: DnsCache::default(),
        }
    }
}

impl Drop for AsyncStdDnsResolver {
    fn drop(&mut self) {
        // NOTE(review): `shutdown_background` is presumably used so that dropping
        // the resolver never blocks (e.g. when it is dropped from inside another
        // async runtime) — confirm intent.
        self.runtime.take().unwrap().shutdown_background();
    }
}
impl reqwest::dns::Resolve for AsyncStdDnsResolver {
fn resolve(&self, name: hyper::client::connect::dns::Name) -> reqwest::dns::Resolving {
if let Some(v) = self.cache.get(name.as_str()) {
return Box::pin(future::ok(Box::new(v.into_iter()) as reqwest::dns::Addrs));
}
debug_assert!(self.runtime.is_some(), "runtime must be valid");
let runtime = self.runtime.as_ref().unwrap().handle().clone();
let cache = self.cache.clone();
let fut = async move {
match runtime
.spawn_blocking(move || {
(name.as_str(), 0).to_socket_addrs().map(|iter| {
let res: Vec<_> = iter.collect();
cache.insert(name.as_str(), res.clone());
res
})
})
.await
{
Ok(v) => v.map(|v| Box::new(v.into_iter()) as reqwest::dns::Addrs),
Err(err) => Err(io::Error::new(
io::ErrorKind::Other,
format!("spawn dns resolving task failed: {err:?}"),
)),
}
.map_err(|err| Box::new(err) as Box<_>)
};
Box::pin(fut)
}
}
/// Async `reqwest` resolver backed by a trust-dns `TokioAsyncResolver` built
/// from the system DNS configuration. Only compiled with the `trust-dns` feature.
#[cfg(feature = "trust-dns")]
pub struct AsyncTrustDnsResolver {
    // Behind an `Arc` so each `resolve` future can own its own handle.
    inner: Arc<trust_dns_resolver::TokioAsyncResolver>,
}

#[cfg(feature = "trust-dns")]
impl AsyncTrustDnsResolver {
    /// Creates a resolver from the system DNS configuration.
    ///
    /// # Errors
    /// Returns the error from `TokioAsyncResolver::from_system_conf`, converted
    /// into an `io::Error`, if the resolver cannot be constructed.
    pub fn new() -> io::Result<Self> {
        let handle = trust_dns_resolver::TokioHandle;
        let inner = Arc::new(trust_dns_resolver::TokioAsyncResolver::from_system_conf(
            handle,
        )?);
        Ok(Self { inner })
    }
}
#[cfg(feature = "trust-dns")]
impl reqwest::dns::Resolve for AsyncTrustDnsResolver {
    /// Looks up the IP addresses for `name` with trust-dns and yields them as
    /// socket addresses with port 0. This resolver keeps no `DnsCache` of its own.
    fn resolve(&self, name: hyper::client::connect::dns::Name) -> reqwest::dns::Resolving {
        let resolver = Arc::clone(&self.inner);
        Box::pin(async move {
            let ips = resolver.lookup_ip(name.as_str()).await?;
            let addrs: reqwest::dns::Addrs =
                Box::new(ips.into_iter().map(|ip| SocketAddr::new(ip, 0)));
            Ok::<_, Box<dyn std::error::Error + Send + Sync>>(addrs)
        })
    }
}
#[derive(Clone)]
struct DnsCacheEntry {
value: Vec<SocketAddr>,
expires_in: SystemTime,
}
#[derive(Clone)]
struct DnsCache {
inner: Arc<Mutex<HashMap<String, DnsCacheEntry>>>,
limits: usize,
default_expire: Duration,
}
impl Default for DnsCache {
fn default() -> Self {
DnsCache {
inner: Arc::default(),
limits: 32,
default_expire: Duration::from_secs(3600),
}
}
}
impl DnsCache {
    /// Returns a copy of the cached addresses for `domain` if an unexpired
    /// entry exists.
    ///
    /// An entry whose expiry time has passed is removed from the map on access
    /// and reported as a miss.
    fn get(&self, domain: &str) -> Option<Vec<SocketAddr>> {
        let mut guard = self.inner.lock();
        let now = SystemTime::now();
        let entry = guard.get(domain)?;
        if entry.expires_in >= now {
            return Some(entry.value.clone());
        }
        // Stale: drop it so it no longer occupies a slot.
        guard.remove(domain);
        None
    }

    /// Caches `value` for `domain`, valid for `default_expire` from now.
    ///
    /// Capacity handling: when the map already holds `limits` entries and
    /// `domain` is a new key, expired entries are purged first, and the whole
    /// map is cleared only if that frees no room. Refreshing a key that is
    /// already cached simply overwrites it and never evicts other entries.
    fn insert(&self, domain: &str, value: Vec<SocketAddr>) {
        let mut guard = self.inner.lock();
        let now = SystemTime::now();
        if guard.len() >= self.limits && !guard.contains_key(domain) {
            guard.retain(|_, entry| entry.expires_in >= now);
            if guard.len() >= self.limits {
                guard.clear();
            }
        }
        guard.insert(
            domain.to_string(),
            DnsCacheEntry {
                value,
                expires_in: now + self.default_expire,
            },
        );
    }
}