use std::{
future::Future,
net::{IpAddr, SocketAddr},
pin::Pin,
sync::Arc,
task::{Context, Poll},
time::Instant,
};
use tower::{Layer, Service, ServiceExt as _};
use crate::{
resolver::{
pipeline::{
BoxError, DnsRequest, PipelineResponse,
cache_layer::CacheService,
forward::ForwardService,
layers::DecisionStack,
middleware::{ClassifyRejection, ProtectiveConfig, build_protective_service},
},
state::ResolverState,
upstream::{
DEFAULT_FAILOVER_BUDGET, DEFAULT_QUERY_TIMEOUT, LatencyWeightedSelector,
RandomSelector, SharedUpstreamPool, UpstreamConfig, UpstreamPool, UpstreamSelector,
UpstreamTransport,
},
},
storage::{
settings::SelectionStrategy,
upstreams::{Transport, Upstream},
},
telemetry::{QueryEvent, TelemetrySink},
time::Clock,
};
use tokio_util::task::TaskTracker;
/// Tower [`Layer`] that wraps a service in [`TelemetryService`], which records a
/// [`QueryEvent`] into the shared [`TelemetrySink`] for every completed request.
///
/// Cloning is cheap (a single `Arc` bump). It matches the `Clone` already
/// derived on `TelemetryService`, so the layer can be reused across stacks or
/// held in a cloneable builder.
#[derive(Clone)]
pub struct TelemetryLayer {
    /// Sink handed (via `Arc` clone) to every service this layer produces.
    sink: Arc<TelemetrySink>,
}
impl TelemetryLayer {
pub fn new(sink: Arc<TelemetrySink>) -> Self {
Self { sink }
}
}
impl<S> Layer<S> for TelemetryLayer {
type Service = TelemetryService<S>;
fn layer(&self, inner: S) -> Self::Service {
TelemetryService {
sink: self.sink.clone(),
inner,
}
}
}
/// Service produced by `TelemetryLayer`. It forwards each request to `inner`
/// and records one `QueryEvent` into `sink` once the inner call completes.
#[derive(Clone)]
pub struct TelemetryService<S> {
    /// Wrapped service that actually handles the request.
    inner: S,
    /// Destination for per-query events, shared via `Arc` across clones.
    sink: Arc<TelemetrySink>,
}
impl<S> Service<DnsRequest> for TelemetryService<S>
where
S: Service<DnsRequest, Response = PipelineResponse, Error = BoxError> + Clone + Send + 'static,
S::Future: Send + 'static,
{
type Response = PipelineResponse;
type Error = BoxError;
type Future = Pin<Box<dyn Future<Output = Result<PipelineResponse, BoxError>> + Send>>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, req: DnsRequest) -> Self::Future {
let start = Instant::now();
let ts = Clock::now_millis();
let client = req.client();
let qname = req.question().name.clone();
let qtype = req.question().qtype;
let sink = self.sink.clone();
let clone = self.inner.clone();
let mut inner = std::mem::replace(&mut self.inner, clone);
Box::pin(async move {
let result = inner.call(req).await;
let latency = start.elapsed();
let (outcome, rcode, upstream) = match &result {
Ok(resp) => (resp.outcome, None, resp.upstream),
Err(e) => {
let (o, rc) = e.rejection_policy();
(o, Some(rc), None)
}
};
let mut ev = QueryEvent::new(client, qname, qtype, outcome)
.with_ts(ts)
.with_latency(latency);
if let Some(rc) = rcode {
ev = ev.with_rcode(rc);
}
if let Some(up) = upstream {
ev = ev.with_upstream(up);
}
sink.record(ev);
result
})
}
}
/// Assembles the full client-facing resolver stack and type-erases it.
///
/// From the inside out: upstream forwarding, then caching, then the decision
/// layers, then the protective middleware configured by `config`, and finally
/// telemetry. Because telemetry is outermost, it observes every completed
/// request, including errors produced by the protective layer.
pub fn build_engine(
    state: Arc<ResolverState>,
    pool: Arc<SharedUpstreamPool>,
    telemetry: Arc<TelemetrySink>,
    config: &ProtectiveConfig,
) -> tower::util::BoxCloneService<DnsRequest, PipelineResponse, BoxError> {
    let forwarder = ForwardService::new(pool, Arc::clone(&state));
    let caching = CacheService::new(Arc::clone(&state), forwarder);
    let decisions = DecisionStack::new(state, caching);
    let guarded = build_protective_service(config, decisions);
    let telemetry_layer = TelemetryLayer::new(telemetry);
    telemetry_layer.layer(guarded).boxed_clone()
}
/// Assembles the forwarding → caching → decision stack with no protective
/// middleware and no telemetry, and type-erases it.
pub fn build_internal_service(
    state: Arc<ResolverState>,
    pool: Arc<SharedUpstreamPool>,
) -> tower::util::BoxCloneService<DnsRequest, PipelineResponse, BoxError> {
    let forwarder = ForwardService::new(pool, Arc::clone(&state));
    let caching = CacheService::new(Arc::clone(&state), forwarder);
    let decisions = DecisionStack::new(state, caching);
    decisions.boxed_clone()
}
/// Connects an [`UpstreamPool`] over `configs`, with the selector chosen by
/// `strategy`.
///
/// `LatencyWeighted` uses [`LatencyWeightedSelector`]. `Random` and `Parallel`
/// both use [`RandomSelector`]. Only `Parallel` also enables fan-out, with
/// width `parallel_fanout`; the other strategies ignore that argument.
/// `DEFAULT_FAILOVER_BUDGET` and `DEFAULT_QUERY_TIMEOUT` are always used.
pub async fn build_upstream_pool(
    configs: &[UpstreamConfig],
    tracker: &TaskTracker,
    strategy: SelectionStrategy,
    parallel_fanout: u32,
) -> UpstreamPool {
    let (selector, fanout): (Arc<dyn UpstreamSelector>, Option<usize>) = match strategy {
        SelectionStrategy::LatencyWeighted => (Arc::new(LatencyWeightedSelector), None),
        SelectionStrategy::Random => (Arc::new(RandomSelector), None),
        SelectionStrategy::Parallel => (Arc::new(RandomSelector), Some(parallel_fanout as usize)),
    };

    let pool = UpstreamPool::connect(
        configs,
        tracker,
        selector,
        DEFAULT_FAILOVER_BUDGET,
        DEFAULT_QUERY_TIMEOUT,
    )
    .await;

    match fanout {
        Some(width) => pool.with_parallel_fanout(width),
        None => pool,
    }
}
/// Error returned when a stored upstream address cannot be turned into a
/// socket address, i.e. it is neither an IP literal nor an `IP:port` pair.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct UnmappableUpstream;

impl std::fmt::Display for UnmappableUpstream {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const MESSAGE: &str = "upstream address is not an IP or IP:port";
        formatter.write_str(MESSAGE)
    }
}

impl std::error::Error for UnmappableUpstream {}
impl TryFrom<&Upstream> for UpstreamConfig {
    type Error = UnmappableUpstream;

    /// Maps a stored upstream row onto a runtime [`UpstreamConfig`].
    ///
    /// `row.address` must be an IP literal, with or without a port:
    /// `1.1.1.1`, `1.1.1.1:5353`, `2606:4700::1111`, `[2606:4700::1111]`, or
    /// `[2606:4700::1111]:853`. When no port is given, the transport's default
    /// is used: 53 for UDP/TCP, 853 for DoT, 443 for DoH.
    ///
    /// # Errors
    ///
    /// Returns [`UnmappableUpstream`] for any other address form, including
    /// hostnames and DoH URLs; no name resolution is attempted here.
    fn try_from(row: &Upstream) -> Result<Self, Self::Error> {
        let transport = match row.transport {
            Transport::Udp => UpstreamTransport::Udp,
            Transport::Tcp => UpstreamTransport::Tcp,
            Transport::Dot => UpstreamTransport::Dot,
            Transport::Doh => UpstreamTransport::Doh,
        };
        let default_port = match transport {
            UpstreamTransport::Udp | UpstreamTransport::Tcp => 53u16,
            UpstreamTransport::Dot => 853,
            UpstreamTransport::Doh => 443,
        };
        let addr =
            parse_ip_socket_addr(&row.address, default_port).ok_or(UnmappableUpstream)?;
        Ok(UpstreamConfig {
            addr,
            transport,
            tls_server_name: row.tls_server_name.clone(),
            // NOTE(review): no HTTP endpoint is derived here, even for DoH rows —
            // confirm the DoH transport falls back to a default path.
            http_endpoint: None,
        })
    }
}

/// Parses `raw` as an IP literal with an optional port, using `default_port`
/// when no port is present.
///
/// Accepts `IP:port`, `[IPv6]:port`, a bare IP, and a bracketed bare IPv6
/// literal (`[::1]`). The last form needs explicit handling because
/// `SocketAddr` requires a port and `IpAddr` rejects the brackets.
fn parse_ip_socket_addr(raw: &str, default_port: u16) -> Option<SocketAddr> {
    if let Ok(sa) = raw.parse::<SocketAddr>() {
        return Some(sa);
    }
    let ip = match raw.strip_prefix('[').and_then(|rest| rest.strip_suffix(']')) {
        // Brackets are only meaningful around IPv6 literals; `[1.1.1.1]` stays invalid.
        Some(inner) => IpAddr::V6(inner.parse::<std::net::Ipv6Addr>().ok()?),
        None => raw.parse::<IpAddr>().ok()?,
    };
    Some(SocketAddr::new(ip, default_port))
}
#[cfg(test)]
mod tests {
    use std::net::{IpAddr, Ipv4Addr, SocketAddr};

    use super::*;
    use crate::storage::upstreams::{Transport, Upstream};

    /// Builds an enabled row with fixed `id`/`sort_order`; only the fields the
    /// conversion reads vary between tests.
    fn make_row(address: &str, transport: Transport, tls_server_name: Option<&str>) -> Upstream {
        Upstream {
            id: 1,
            address: String::from(address),
            transport,
            tls_server_name: tls_server_name.map(String::from),
            enabled: true,
            sort_order: 0,
        }
    }

    /// IPv4 socket address from its four octets and a port.
    fn v4(octets: [u8; 4], port: u16) -> SocketAddr {
        SocketAddr::new(IpAddr::V4(Ipv4Addr::from(octets)), port)
    }

    #[test]
    fn udp_bare_ip_gets_default_port_53() {
        let cfg = UpstreamConfig::try_from(&make_row("1.1.1.1", Transport::Udp, None))
            .expect("must map");
        assert_eq!(cfg.addr, v4([1, 1, 1, 1], 53));
        assert_eq!(cfg.transport, UpstreamTransport::Udp);
        assert_eq!(cfg.tls_server_name, None);
    }

    #[test]
    fn udp_explicit_port_preserved() {
        let cfg = UpstreamConfig::try_from(&make_row("9.9.9.9:8053", Transport::Udp, None))
            .expect("must map");
        assert_eq!(cfg.addr, v4([9, 9, 9, 9], 8053));
    }

    #[test]
    fn dot_bare_ip_gets_default_port_853_with_sni() {
        let row = make_row("1.1.1.1", Transport::Dot, Some("cloudflare-dns.com"));
        let cfg = UpstreamConfig::try_from(&row).expect("must map");
        assert_eq!(cfg.addr, v4([1, 1, 1, 1], 853));
        assert_eq!(cfg.transport, UpstreamTransport::Dot);
        assert_eq!(cfg.tls_server_name.as_deref(), Some("cloudflare-dns.com"));
    }

    #[test]
    fn doh_url_hostname_is_unmappable() {
        let row = make_row(
            "https://cloudflare-dns.com/dns-query",
            Transport::Doh,
            Some("cloudflare-dns.com"),
        );
        assert_eq!(
            UpstreamConfig::try_from(&row).err(),
            Some(UnmappableUpstream),
            "DoH URL / hostname must not map to an UpstreamConfig in v0.1"
        );
    }

    #[test]
    fn plain_hostname_is_unmappable() {
        let outcome = UpstreamConfig::try_from(&make_row("dns.quad9.net", Transport::Udp, None));
        assert_eq!(outcome.err(), Some(UnmappableUpstream));
    }
}