mod codec;
mod doh;
mod doq;
mod dot;
mod framing;
use crate::Client;
use async_lock::OnceCell;
pub(crate) use codec::Resolved;
use codec::{build_query, parse_response};
use dashmap::DashMap;
use doh::Doh;
use doq::Doq;
use dot::Dot;
use futures_lite::future;
use hickory_proto::rr::RecordType;
use std::{
future::Future,
io::{self, ErrorKind},
sync::Arc,
time::{Duration, Instant},
};
use trillium_http::Version;
use trillium_server_common::{Connector, url::Url};
/// How long an answer for the DNS resolver's own hostname (obtained from the
/// connector rather than over the encrypted transport) is cached: five minutes.
const BOOTSTRAP_TTL: Duration = Duration::from_secs(5 * 60);
/// Lower bound applied to every TTL before it is cached.
const MIN_TTL: Duration = Duration::from_secs(1);
/// Upper bound applied to every TTL before it is cached: one hour.
const MAX_TTL: Duration = Duration::from_secs(60 * 60);
/// Time budget for a resolution when the request carries no timeout of its own.
const DEFAULT_DNS_TIMEOUT: Duration = Duration::from_secs(5);

/// Returns the time budget for one DNS resolution.
///
/// When the request has a timeout, half of it is given to DNS. Otherwise
/// `DEFAULT_DNS_TIMEOUT` is used.
fn dns_timeout(request_timeout: Option<Duration>) -> Duration {
    match request_timeout {
        Some(total) => total / 2,
        None => DEFAULT_DNS_TIMEOUT,
    }
}
/// Shared DNS answer cache plus the lookups currently in progress.
///
/// Both maps sit behind `Arc`, so every clone of a `DnsCache` shares the same
/// state.
#[derive(Clone, Debug, Default)]
pub(crate) struct DnsCache {
    /// Answers keyed by hostname, each carrying its own expiry.
    entries: Arc<DashMap<Box<str>, CacheEntry>>,
    /// One cell per hostname being resolved right now; every caller asking
    /// for that hostname meanwhile waits on the same cell.
    in_flight: Arc<DashMap<Box<str>, Arc<OnceCell<Resolved>>>>,
}
/// A cached answer together with the instant after which it is stale.
#[derive(Clone, Debug)]
struct CacheEntry {
    /// The answer being cached.
    resolved: Resolved,
    /// Deadline; `DnsCache::get` serves the entry while `expiry >= now`.
    expiry: Instant,
}
impl DnsCache {
    /// Returns the cached answer for `host` if its entry has not expired.
    ///
    /// An expired entry is evicted as a side effect. The eviction re-checks the
    /// expiry under the write lock, so an entry refreshed by a concurrent
    /// `insert` between our read and the removal survives.
    pub(crate) fn get(&self, host: &str) -> Option<Resolved> {
        let now = Instant::now();
        let entry = self.entries.get(host)?;
        if entry.expiry >= now {
            return Some(entry.resolved.clone());
        }
        // Release the read guard before mutating: DashMap may deadlock if a
        // reference into the map is held while removing from it.
        drop(entry);
        self.entries.remove_if(host, |_, stale| stale.expiry < now);
        None
    }

    /// Caches `resolved` for `host`, clamping `ttl` into `MIN_TTL..=MAX_TTL`.
    pub(crate) fn insert(&self, host: &str, resolved: Resolved, ttl: Duration) {
        let expiry = Instant::now() + ttl.clamp(MIN_TTL, MAX_TTL);
        self.entries
            .insert(host.into(), CacheEntry { resolved, expiry });
    }

    /// Resolves `host`, serving from the cache when possible.
    ///
    /// On a miss, concurrent callers for the same host share one in-flight
    /// cell, so `query` is only awaited by the caller that initializes it. A
    /// successful answer is cached with the TTL `query` returned. Errors are
    /// returned and never cached. `query` is dropped unpolled on a cache hit.
    ///
    /// If this future is dropped mid-lookup (for example by a timeout), its
    /// cell stays in `in_flight` uninitialized and is reused by the next caller.
    pub(crate) async fn resolve_coalesced(
        &self,
        host: &str,
        query: impl Future<Output = io::Result<(Resolved, Duration)>>,
    ) -> io::Result<Resolved> {
        if let Some(hit) = self.get(host) {
            return Ok(hit);
        }
        let cell = self
            .in_flight
            .entry(host.into())
            .or_insert_with(|| Arc::new(OnceCell::new()))
            .clone();
        let resolved = cell
            .get_or_try_init(|| async {
                // A lookup that finished (and retired its cell) between the
                // cache miss above and our `entry` call has already cached an
                // answer; reuse it instead of querying upstream again.
                if let Some(hit) = self.get(host) {
                    return Ok::<_, io::Error>(hit);
                }
                let (resolved, ttl) = query.await?;
                self.insert(host, resolved.clone(), ttl);
                Ok(resolved)
            })
            .await
            .cloned();
        // Retire only the cell we used. Another caller may already have
        // replaced it with a fresh cell for a new lookup, which must stay
        // visible so later callers coalesce onto it.
        self.in_flight
            .remove_if(host, |_, current| Arc::ptr_eq(current, &cell));
        resolved
    }
}
/// An encrypted-DNS resolver: a transport to the upstream server plus its
/// answer cache.
#[derive(Clone, Debug)]
pub(crate) struct Resolver {
    /// Cached answers and in-flight lookups.
    cache: DnsCache,
    /// How queries reach the upstream resolver.
    transport: DnsTransport,
}
/// The protocol used to reach the upstream resolver.
#[derive(Clone, Debug)]
enum DnsTransport {
    /// DNS over HTTPS.
    Doh(Doh),
    /// DNS over TLS.
    Dot(Dot),
    /// DNS over QUIC.
    Doq(Doq),
}
impl Resolver {
pub(crate) fn doh(resolver: Url) -> Self {
Self {
cache: DnsCache::default(),
transport: DnsTransport::Doh(Doh::new(resolver, None)),
}
}
pub(crate) fn doh3(resolver: Url) -> Self {
Self {
cache: DnsCache::default(),
transport: DnsTransport::Doh(Doh::new(resolver, Some(Version::Http3))),
}
}
pub(crate) fn dot(resolver: Url) -> Self {
Self {
cache: DnsCache::default(),
transport: DnsTransport::Dot(Dot::new(resolver)),
}
}
pub(crate) fn doq(resolver: Url) -> Self {
Self {
cache: DnsCache::default(),
transport: DnsTransport::Doq(Doq::new(resolver)),
}
}
pub(crate) async fn resolve(
&self,
client: &Client,
host: &str,
port: u16,
request_timeout: Option<Duration>,
) -> io::Result<Resolved> {
let kind = self.transport.kind();
let endpoint = self.transport.resolver_endpoint();
let timeout = dns_timeout(request_timeout);
log::debug!("resolving {host}:{port} via {kind} ({endpoint})");
let resolved = client
.connector()
.runtime()
.timeout(
timeout,
self.cache
.resolve_coalesced(host, Box::pin(self.query_host(client, host, port))),
)
.await
.unwrap_or_else(|| {
Err(io::Error::new(
ErrorKind::TimedOut,
format!(
"{kind} resolution of {host} via {endpoint} timed out after {timeout:?}; \
the resolver may be unreachable or may not speak {kind}"
),
))
});
match &resolved {
Ok(r) => log::debug!(
"resolved {host} to {} address(es), {} service binding(s)",
r.addrs.len(),
r.services.len()
),
Err(e) => log::debug!("resolution of {host} failed: {e}"),
}
resolved
}
async fn query_host(
&self,
client: &Client,
host: &str,
port: u16,
) -> io::Result<(Resolved, Duration)> {
if self.transport.resolver_host() == Some(host) {
let addrs = client
.connector()
.resolve(host, port)
.await?
.into_iter()
.map(|addr| addr.ip())
.collect();
return Ok((
Resolved {
addrs,
services: Vec::new(),
},
BOOTSTRAP_TTL,
));
}
let (a, (aaaa, https)) = future::try_zip(
self.query(client, build_query(host, port, RecordType::A)?),
future::try_zip(
self.query(client, build_query(host, port, RecordType::AAAA)?),
self.query(client, build_query(host, port, RecordType::HTTPS)?),
),
)
.await?;
let mut resolved = Resolved::default();
let mut min_ttl = MAX_TTL;
for (part, ttl) in [a, aaaa, https] {
resolved.merge(part);
min_ttl = min_ttl.min(ttl);
}
resolved.services.sort_by_key(|s| s.priority);
if !resolved.has_addrs() {
return Err(io::Error::new(
ErrorKind::NotFound,
format!("DNS resolver returned no addresses for {host}"),
));
}
Ok((resolved, min_ttl))
}
async fn query(&self, client: &Client, query: Vec<u8>) -> io::Result<(Resolved, Duration)> {
let bytes = self.transport.exchange(client, query).await?;
parse_response(&bytes)
}
}
impl DnsTransport {
fn kind(&self) -> &'static str {
match self {
DnsTransport::Doh(_) => "DoH",
DnsTransport::Dot(_) => "DoT",
DnsTransport::Doq(_) => "DoQ",
}
}
fn resolver_endpoint(&self) -> &Url {
match self {
DnsTransport::Doh(doh) => doh.resolver(),
DnsTransport::Dot(dot) => dot.resolver(),
DnsTransport::Doq(doq) => doq.resolver(),
}
}
fn resolver_host(&self) -> Option<&str> {
match self {
DnsTransport::Doh(doh) => doh.host(),
DnsTransport::Dot(dot) => dot.host(),
DnsTransport::Doq(doq) => doq.host(),
}
}
async fn exchange(&self, client: &Client, query: Vec<u8>) -> io::Result<Vec<u8>> {
match self {
DnsTransport::Doh(doh) => doh.exchange(client, query).await,
DnsTransport::Dot(dot) => dot.exchange(client, query).await,
DnsTransport::Doq(doq) => doq.exchange(client, query).await,
}
}
}
impl Client {
    /// Builds the `Url` of an encrypted-DNS resolver from a full URL or a bare
    /// host.
    ///
    /// Anything containing `://` is parsed verbatim. Otherwise `resolver` is
    /// read as `host[:port][/path]` under the `https` scheme. A bare IPv6
    /// literal is wrapped in brackets, because without them `Url::parse`
    /// would read its colons as a port separator. `default_port` is appended
    /// only when no port is already present.
    ///
    /// # Panics
    ///
    /// Panics with `message` if the result is not a valid URL.
    fn encrypted_dns_url(resolver: &str, default_port: Option<u16>, message: &str) -> Url {
        if resolver.contains("://") {
            return Url::parse(resolver).expect(message);
        }
        // Only the part before the first `/` names the server; the rest is a path.
        let (authority, path) = resolver.split_at(resolver.find('/').unwrap_or(resolver.len()));
        let authority = if authority.parse::<std::net::Ipv6Addr>().is_ok() {
            format!("[{authority}]")
        } else {
            authority.to_owned()
        };
        // In a bracketed IPv6 authority only `]:` introduces a port.
        let has_port = match authority.strip_prefix('[') {
            Some(bracketed) => bracketed.contains("]:"),
            None => authority.contains(':'),
        };
        let port = match default_port {
            Some(port) if !has_port => format!(":{port}"),
            _ => String::new(),
        };
        Url::parse(&format!("https://{authority}{port}{path}")).expect(message)
    }

    /// Normalizes a DoH resolver into a URL, using `/dns-query` when no path
    /// was given.
    fn doh_resolver_url(resolver: &str) -> Url {
        let mut url =
            Self::encrypted_dns_url(resolver, None, "DoH resolver must be a valid URL or host");
        if matches!(url.path(), "" | "/") {
            url.set_path("/dns-query");
        }
        url
    }

    /// Installs `resolver`, warning if it replaces one configured earlier.
    fn set_resolver(&mut self, resolver: Resolver) {
        if self.resolver.is_some() {
            log::warn!(
                "replacing an already-configured DNS resolver; encrypted-DNS resolvers are \
                 mutually exclusive"
            );
        }
        self.resolver = Some(resolver);
    }

    /// Resolves hostnames over DNS-over-HTTPS using `resolver`.
    ///
    /// `resolver` may be a full URL or a bare host, optionally with a port and
    /// path. Bare IPv6 literals are accepted. When no path is given,
    /// `/dns-query` is used. This replaces any previously configured
    /// encrypted-DNS resolver, with a warning.
    ///
    /// # Panics
    ///
    /// Panics if `resolver` cannot be turned into a valid URL.
    #[must_use]
    pub fn with_doh(mut self, resolver: impl AsRef<str>) -> Self {
        let url = Self::doh_resolver_url(resolver.as_ref());
        self.set_resolver(Resolver::doh(url));
        self
    }

    /// Like [`with_doh`](Self::with_doh), but requests HTTP/3 for the DoH
    /// exchange.
    ///
    /// # Panics
    ///
    /// Panics if this client has no HTTP/3 support (`h3()` is `None`), or if
    /// `resolver` cannot be turned into a valid URL.
    #[must_use]
    pub fn with_doh3(mut self, resolver: impl AsRef<str>) -> Self {
        assert!(
            self.h3().is_some(),
            "with_doh3 requires an HTTP/3-capable client; build it with Client::new_with_quic"
        );
        let url = Self::doh_resolver_url(resolver.as_ref());
        self.set_resolver(Resolver::doh3(url));
        self
    }

    /// Resolves hostnames over DNS-over-TLS using `resolver`.
    ///
    /// `resolver` may be a full URL or a bare host (IPv6 literals accepted).
    /// A bare host without a port uses port 853. This replaces any previously
    /// configured encrypted-DNS resolver, with a warning.
    ///
    /// # Panics
    ///
    /// Panics if `resolver` cannot be turned into a valid URL.
    #[must_use]
    pub fn with_dot(mut self, resolver: impl AsRef<str>) -> Self {
        let url = Self::encrypted_dns_url(
            resolver.as_ref(),
            Some(853),
            "with_dot requires a valid resolver host or URL",
        );
        self.set_resolver(Resolver::dot(url));
        self
    }

    /// Resolves hostnames over DNS-over-QUIC using `resolver`.
    ///
    /// `resolver` may be a full URL or a bare host (IPv6 literals accepted).
    /// A bare host without a port uses port 853. This replaces any previously
    /// configured encrypted-DNS resolver, with a warning.
    ///
    /// # Panics
    ///
    /// Panics if this client has no HTTP/3 support (`h3()` is `None`), or if
    /// `resolver` cannot be turned into a valid URL.
    #[must_use]
    pub fn with_doq(mut self, resolver: impl AsRef<str>) -> Self {
        assert!(
            self.h3().is_some(),
            "with_doq requires an HTTP/3-capable client; build it with Client::new_with_quic"
        );
        let url = Self::encrypted_dns_url(
            resolver.as_ref(),
            Some(853),
            "with_doq requires a valid resolver host or URL",
        );
        self.set_resolver(Resolver::doq(url));
        self
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};

    /// A two-address answer shared by the cache tests.
    fn sample() -> Resolved {
        Resolved {
            addrs: vec![
                IpAddr::V4(Ipv4Addr::new(192, 0, 2, 9)),
                IpAddr::V6(Ipv6Addr::LOCALHOST),
            ],
            services: Vec::new(),
        }
    }

    #[test]
    fn cache_round_trips_and_expires() {
        let cache = DnsCache::default();
        cache.insert("example.com", sample(), Duration::from_secs(300));
        assert_eq!(cache.get("example.com").unwrap().addrs.len(), 2);
        assert!(cache.get("absent.example").is_none());
        // A zero TTL is clamped up to MIN_TTL, so the entry is readable at once.
        cache.insert("floor.example", sample(), Duration::ZERO);
        assert!(cache.get("floor.example").is_some());
    }

    #[test]
    fn expired_entries_are_evicted_on_read() {
        let cache = DnsCache::default();
        // `checked_sub` returns None if the instant is not representable; skip then.
        if let Some(expiry) = Instant::now().checked_sub(Duration::from_secs(1)) {
            cache.entries.insert(
                "stale.example".into(),
                CacheEntry {
                    resolved: sample(),
                    expiry,
                },
            );
            assert!(cache.get("stale.example").is_none());
            assert!(!cache.entries.contains_key("stale.example"));
        }
    }

    #[test]
    fn dns_timeout_halves_request_timeout() {
        assert_eq!(dns_timeout(None), DEFAULT_DNS_TIMEOUT);
        assert_eq!(
            dns_timeout(Some(Duration::from_secs(10))),
            Duration::from_secs(5)
        );
    }

    #[test]
    fn coalesced_resolution_caches_and_clears_in_flight() {
        let cache = DnsCache::default();
        let first = future::block_on(cache.resolve_coalesced("coalesce.example", async {
            Ok::<_, io::Error>((sample(), Duration::from_secs(60)))
        }))
        .unwrap();
        assert_eq!(first.addrs, sample().addrs);
        assert!(cache.in_flight.is_empty());

        // Served from the cache: the failing query is never awaited.
        let second = future::block_on(cache.resolve_coalesced("coalesce.example", async {
            Err::<(Resolved, Duration), _>(io::Error::new(
                ErrorKind::NotFound,
                "query should not run",
            ))
        }))
        .unwrap();
        assert_eq!(second.addrs, sample().addrs);
    }

    #[test]
    fn failed_resolution_is_not_cached() {
        let cache = DnsCache::default();
        let err = future::block_on(cache.resolve_coalesced("fail.example", async {
            Err::<(Resolved, Duration), _>(io::Error::new(ErrorKind::NotFound, "no such host"))
        }))
        .unwrap_err();
        assert_eq!(err.kind(), ErrorKind::NotFound);
        assert!(cache.get("fail.example").is_none());
        assert!(cache.in_flight.is_empty());
    }
}