use std::collections::HashSet;
use std::net::{IpAddr, SocketAddr};
use std::sync::Arc;
use std::time::Duration;
use bytes::Bytes;
use futures::StreamExt;
use hickory_client::client::Client;
use hickory_client::proto::op::{Message, MessageType, ResponseCode};
use hickory_client::proto::rr::rdata::{A, AAAA};
use hickory_client::proto::rr::{RData, Record, RecordType};
use hickory_client::proto::serialize::binary::{BinDecodable, BinEncodable};
use hickory_client::proto::xfer::{DnsHandle, DnsRequest};
use tokio::sync::{OnceCell, watch};
use super::client::{build_direct_client, build_tcp_client, build_udp_client};
use super::common::config::NormalizedDnsConfig;
use super::common::filter::{is_domain_blocked, is_private_ipv4, is_private_ipv6};
use super::common::transport::Transport;
use super::nameserver::{read_host_dns_servers, resolve_nameservers};
use crate::policy::NetworkPolicy;
use crate::shared::{ResolvedHostnameFamily, SharedState};
use crate::stack::GatewayIps;
/// Lower bound, in seconds, applied to every record TTL before an upstream
/// answer is cached (see `extract_addrs_and_ttl`), so a record advertising a
/// TTL of zero still produces a non-zero cache lifetime.
const RESOLVED_HOSTNAME_MIN_TTL_SECS: u32 = 1;
/// TTL, in seconds, placed on the record synthesized for the host alias name
/// (see `synthesize_host_alias_response`).
const HOST_ALIAS_TTL_SECS: u32 = 60;
/// Receiving half of the watch channel created by `DnsForwarder::spawn`.
///
/// The channel starts out holding `None`; the background build task publishes
/// `Some(forwarder)` on success and drops the sender without sending on
/// failure.
pub(crate) type DnsForwarderHandle = watch::Receiver<Option<Arc<DnsForwarder>>>;
/// Forwards guest DNS queries to an upstream resolver, applying domain
/// blocking, network-policy checks on the chosen resolver, optional rebind
/// protection, and local synthesis of the host alias name.
pub(crate) struct DnsForwarder {
/// UDP client for the configured upstream; cloned for each UDP query that is
/// routed to the configured resolver.
configured_udp: Client,
/// TCP client for the configured upstream, created on first use by
/// `configured_client` (it serves both `Transport::Tcp` and `Transport::Dot`).
configured_tcp: OnceCell<Client>,
/// Address of the configured upstream: the first entry of the resolved
/// nameserver list chosen in `build`.
configured_upstream: SocketAddr,
/// Gateway addresses; a query whose original destination is in this set is
/// sent to the configured upstream instead of being forwarded directly.
gateway_ips: Arc<HashSet<IpAddr>>,
/// Egress policy consulted before forwarding directly to a resolver the
/// guest addressed itself.
network_policy: Arc<NetworkPolicy>,
/// Shared state: handed to the policy evaluation and used to cache (or
/// clear) the addresses resolved for a hostname.
shared: Arc<SharedState>,
/// Gateway IPv4/IPv6 addresses returned in synthesized host alias answers.
gateway: GatewayIps,
/// Normalized DNS settings read by this type: nameservers, query timeout,
/// rebind protection flag, and the input to `is_domain_blocked`.
config: Arc<NormalizedDnsConfig>,
}
/// Outcome of `DnsForwarder::select_upstream`: either a ready client or the
/// status that `forward` should answer with instead of forwarding.
enum UpstreamChoice {
/// A client through which the query is sent upstream.
Client(Client),
/// The resolver is not permitted; `forward` answers `REFUSED`.
Refused,
/// No client could be built; `forward` answers `SERVFAIL`.
ServFail,
}
/// Pure routing decision produced by `decide_upstream`, before any client is
/// built.
#[derive(Debug, PartialEq, Eq)]
enum UpstreamDecision {
/// Use the configured upstream resolver.
Configured,
/// Forward directly to this resolver address (destination IP plus the
/// transport's upstream port).
Direct(SocketAddr),
/// The network policy denies egress to the requested resolver.
Refused,
}
impl DnsForwarder {
pub(crate) async fn forward(
&self,
raw_query: &[u8],
original_dst: Option<IpAddr>,
transport: Transport,
sni: Option<&str>,
) -> Option<Bytes> {
let query_msg = Message::from_bytes(raw_query).ok()?;
let guest_id = query_msg.id();
let question = query_msg.queries().first()?;
let query_type = question.query_type();
let domain = question.name().to_string();
let domain = domain.trim_end_matches('.').to_owned();
if is_domain_blocked(&domain, &self.config) {
tracing::debug!(domain = %domain, "DNS query blocked by domain policy");
return build_status_response(&query_msg, ResponseCode::Refused);
}
if is_host_alias_query(&domain)
&& let Some(response) =
synthesize_host_alias_response(&query_msg, self.gateway, query_type)
{
return Some(response);
}
let client = match self.select_upstream(original_dst, transport, sni).await {
UpstreamChoice::Client(c) => c,
UpstreamChoice::Refused => {
tracing::debug!(
domain = %domain,
?original_dst,
"DNS resolver denied by network policy"
);
return build_status_response(&query_msg, ResponseCode::Refused);
}
UpstreamChoice::ServFail => {
return build_status_response(&query_msg, ResponseCode::ServFail);
}
};
let mut send = client.send(DnsRequest::from(query_msg.clone()));
let response = match send.next().await {
Some(Ok(resp)) => resp,
Some(Err(e)) => {
tracing::warn!(domain = %domain, error = %e, "upstream DNS send failed");
return build_status_response(&query_msg, ResponseCode::ServFail);
}
None => {
tracing::warn!(domain = %domain, "upstream DNS closed stream without a response");
return build_status_response(&query_msg, ResponseCode::ServFail);
}
};
let mut response_msg: Message = response.into();
if self.config.rebind_protection {
for record in response_msg.answers() {
let is_private = match record.data() {
RData::A(a) => is_private_ipv4((*a).into()),
RData::AAAA(aaaa) => is_private_ipv6((*aaaa).into()),
_ => false,
};
if is_private {
tracing::debug!(
domain = %domain,
"DNS rebind protection: response contains private IP"
);
return build_status_response(&query_msg, ResponseCode::Refused);
}
}
}
if let Some(family) = family_for_query_type(query_type) {
if let Some((addrs, ttl)) = extract_addrs_and_ttl(&response_msg, family) {
self.shared
.cache_resolved_hostname(&domain, family, addrs, ttl);
} else {
self.shared.clear_resolved_hostname(&domain, family);
}
}
response_msg.set_id(guest_id);
let response_bytes = response_msg.to_bytes().ok()?;
if transport == Transport::Udp {
let max_size = query_msg.max_payload() as usize;
if response_bytes.len() > max_size {
tracing::debug!(
domain = %domain,
response_size = response_bytes.len(),
advertised = max_size,
"DNS response exceeds guest UDP buffer; setting TC=1"
);
return build_truncated_response(&query_msg).map(Bytes::from);
}
}
Some(Bytes::from(response_bytes))
}
async fn select_upstream(
&self,
original_dst: Option<IpAddr>,
transport: Transport,
sni: Option<&str>,
) -> UpstreamChoice {
match decide_upstream(
&self.gateway_ips,
&self.network_policy,
&self.shared,
original_dst,
transport,
) {
UpstreamDecision::Configured => self.configured_client(transport).await,
UpstreamDecision::Refused => UpstreamChoice::Refused,
UpstreamDecision::Direct(addr) => {
match build_direct_client(addr, transport, sni, self.config.query_timeout).await {
Some(client) => UpstreamChoice::Client(client),
None => UpstreamChoice::ServFail,
}
}
}
}
async fn configured_client(&self, transport: Transport) -> UpstreamChoice {
match transport {
Transport::Udp => UpstreamChoice::Client(self.configured_udp.clone()),
Transport::Tcp | Transport::Dot => {
let timeout = self.config.query_timeout;
let upstream = self.configured_upstream;
let result = self
.configured_tcp
.get_or_try_init(|| async move {
build_tcp_client(upstream, timeout).await.ok_or(())
})
.await;
match result {
Ok(c) => UpstreamChoice::Client(c.clone()),
Err(()) => UpstreamChoice::ServFail,
}
}
}
}
#[allow(clippy::too_many_arguments)]
pub(super) fn spawn(
handle: &tokio::runtime::Handle,
config: Arc<NormalizedDnsConfig>,
gateway_ips: Arc<HashSet<IpAddr>>,
network_policy: Arc<NetworkPolicy>,
shared: Arc<SharedState>,
gateway: GatewayIps,
) -> DnsForwarderHandle {
let (forwarder_tx, forwarder_rx) = watch::channel(None);
handle.spawn(async move {
let Some(forwarder) =
Self::build(config, gateway_ips, network_policy, shared, gateway).await
else {
return;
};
let _ = forwarder_tx.send(Some(forwarder));
});
forwarder_rx
}
async fn build(
config: Arc<NormalizedDnsConfig>,
gateway_ips: Arc<HashSet<IpAddr>>,
network_policy: Arc<NetworkPolicy>,
shared: Arc<SharedState>,
gateway: GatewayIps,
) -> Option<Arc<Self>> {
let upstreams = if !config.nameservers.is_empty() {
match resolve_nameservers(&config.nameservers).await {
Ok(s) if !s.is_empty() => s,
Ok(_) => {
tracing::error!("no configured nameservers resolved to an address");
return None;
}
Err(e) => {
tracing::error!(error = %e, "failed to resolve configured nameservers");
return None;
}
}
} else {
match read_host_dns_servers().await {
Ok(s) if !s.is_empty() => s,
Ok(_) => {
tracing::error!("no upstream DNS servers discovered from host");
return None;
}
Err(e) => {
tracing::error!(error = %e, "failed to read host DNS configuration");
return None;
}
}
};
let upstream = upstreams[0];
let configured_udp = build_udp_client(upstream, config.query_timeout).await?;
Some(Arc::new(Self {
configured_udp,
configured_tcp: OnceCell::new(),
configured_upstream: upstream,
gateway_ips,
network_policy,
shared,
gateway,
config,
}))
}
pub(crate) async fn wait(mut handle: DnsForwarderHandle) -> Option<Arc<Self>> {
if let Some(f) = handle.borrow().clone() {
return Some(f);
}
handle.changed().await.ok()?;
handle.borrow().clone()
}
}
/// Decides where a query should be sent, without building any client.
///
/// * Destination unknown, or one of `gateway_ips`: use the configured
///   upstream.
/// * Any other destination: evaluate the egress policy for that address on
///   the transport's upstream port; a denial yields `Refused`, otherwise the
///   query goes directly to that address.
fn decide_upstream(
    gateway_ips: &HashSet<IpAddr>,
    policy: &NetworkPolicy,
    shared: &SharedState,
    original_dst: Option<IpAddr>,
    transport: Transport,
) -> UpstreamDecision {
    // Only a known, non-gateway destination is subject to the policy check.
    let resolver_ip = match original_dst {
        Some(ip) if !gateway_ips.contains(&ip) => ip,
        _ => return UpstreamDecision::Configured,
    };

    let target = SocketAddr::new(resolver_ip, transport.upstream_port());
    let denied = policy
        .evaluate_egress(target, transport.policy_protocol(), shared)
        .is_deny();

    if denied {
        UpstreamDecision::Refused
    } else {
        UpstreamDecision::Direct(target)
    }
}
/// Builds an answer-less reply to `query` carrying the status `rcode`.
///
/// The reply copies the query's id, opcode and RD flag, sets RA, and echoes
/// the first question when the query has one. Returns `None` only if the
/// reply cannot be serialized.
fn build_status_response(query: &Message, rcode: ResponseCode) -> Option<Bytes> {
    let mut reply = Message::new();

    // Header fields mirrored from the query.
    reply.set_id(query.id());
    reply.set_op_code(query.op_code());
    reply.set_recursion_desired(query.recursion_desired());

    // Header fields that mark this as our response.
    reply.set_message_type(MessageType::Response);
    reply.set_response_code(rcode);
    reply.set_recursion_available(true);

    if let Some(first_question) = query.queries().first().cloned() {
        reply.add_query(first_question);
    }

    let wire = reply.to_bytes().ok()?;
    Some(Bytes::from(wire))
}
/// Maps an address query type to the hostname-cache family it populates.
///
/// `A` maps to IPv4 and `AAAA` to IPv6; every other record type has no
/// associated family and yields `None`.
fn family_for_query_type(query_type: RecordType) -> Option<ResolvedHostnameFamily> {
    let family = match query_type {
        RecordType::A => ResolvedHostnameFamily::Ipv4,
        RecordType::AAAA => ResolvedHostnameFamily::Ipv6,
        _ => return None,
    };
    Some(family)
}
/// Collects the addresses of `family` from the answer section of `response`.
///
/// Returns the addresses in answer order together with the smallest TTL among
/// the matching records, where each TTL is first raised to at least
/// `RESOLVED_HOSTNAME_MIN_TTL_SECS`. Returns `None` when no answer record
/// matches the family.
fn extract_addrs_and_ttl(
    response: &Message,
    family: ResolvedHostnameFamily,
) -> Option<(Vec<IpAddr>, Duration)> {
    let (addrs, ttls): (Vec<IpAddr>, Vec<u32>) = response
        .answers()
        .iter()
        .filter_map(|record| {
            let ip = match (family, record.data()) {
                (ResolvedHostnameFamily::Ipv4, RData::A(a)) => IpAddr::V4((*a).into()),
                (ResolvedHostnameFamily::Ipv6, RData::AAAA(aaaa)) => IpAddr::V6((*aaaa).into()),
                _ => return None,
            };
            Some((ip, record.ttl().max(RESOLVED_HOSTNAME_MIN_TTL_SECS)))
        })
        .unzip();

    // `min` is `None` exactly when no record matched.
    let shortest_secs = ttls.into_iter().min()?;
    Some((addrs, Duration::from_secs(u64::from(shortest_secs))))
}
/// Reports whether `query_name` is the host alias name.
///
/// Trailing dots are ignored and the comparison is ASCII case-insensitive.
fn is_host_alias_query(query_name: &str) -> bool {
    let without_root_dot = query_name.trim_end_matches('.');
    without_root_dot.eq_ignore_ascii_case(crate::HOST_ALIAS)
}
/// Builds a local, authoritative answer for a host alias query.
///
/// An `A` query is answered with the gateway IPv4 address and an `AAAA` query
/// with the gateway IPv6 address, each with a TTL of `HOST_ALIAS_TTL_SECS`.
/// Returns `None` for any other query type, when the query has no question,
/// or when the reply cannot be serialized.
fn synthesize_host_alias_response(
    query: &Message,
    gateway: GatewayIps,
    qtype: RecordType,
) -> Option<Bytes> {
    let question = query.queries().first()?;
    let answer_data = if qtype == RecordType::A {
        RData::A(A::from(gateway.ipv4))
    } else if qtype == RecordType::AAAA {
        RData::AAAA(AAAA::from(gateway.ipv6))
    } else {
        return None;
    };
    let answer = Record::from_rdata(question.name().clone(), HOST_ALIAS_TTL_SECS, answer_data);

    let mut reply = Message::new();

    // Header fields mirrored from the query.
    reply.set_id(query.id());
    reply.set_op_code(query.op_code());
    reply.set_recursion_desired(query.recursion_desired());

    // Header fields that mark this as our own authoritative answer.
    reply.set_message_type(MessageType::Response);
    reply.set_response_code(ResponseCode::NoError);
    reply.set_recursion_available(true);
    reply.set_authoritative(true);

    reply.add_query(question.clone());
    reply.add_answer(answer);

    let wire = reply.to_bytes().ok()?;
    Some(Bytes::from(wire))
}
/// Builds an empty `NOERROR` reply to `query` with the TC bit set.
///
/// The reply copies the query's id, opcode and RD flag, sets RA, and echoes
/// the first question when the query has one. Returns `None` only if the
/// reply cannot be serialized.
fn build_truncated_response(query: &Message) -> Option<Vec<u8>> {
    let mut reply = Message::new();

    // Header fields mirrored from the query.
    reply.set_id(query.id());
    reply.set_op_code(query.op_code());
    reply.set_recursion_desired(query.recursion_desired());

    // Header fields that mark this as a truncated response.
    reply.set_message_type(MessageType::Response);
    reply.set_response_code(ResponseCode::NoError);
    reply.set_recursion_available(true);
    reply.set_truncated(true);

    if let Some(first_question) = query.queries().first().cloned() {
        reply.add_query(first_question);
    }

    match reply.to_bytes() {
        Ok(wire) => Some(wire),
        Err(_) => None,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::policy::Protocol;
    use hickory_client::proto::op::{Edns, MessageType, OpCode, Query};
    use hickory_client::proto::rr::{DNSClass, Name, RecordType};
    use std::net::{Ipv4Addr, Ipv6Addr};

    /// Transaction id stamped on every query produced by `make_query`.
    const QUERY_ID: u16 = 0x4242;

    /// Builds a recursion-desired, class `IN` standard query for `name`.
    fn make_query(name: &str, qtype: RecordType) -> Message {
        let mut question = Query::new();
        question.set_name(Name::from_ascii(name).expect("valid dns name"));
        question.set_query_type(qtype);
        question.set_query_class(DNSClass::IN);

        let mut msg = Message::new();
        msg.set_id(QUERY_ID);
        msg.set_message_type(MessageType::Query);
        msg.set_op_code(OpCode::Query);
        msg.set_recursion_desired(true);
        msg.add_query(question);
        msg
    }

    /// Serialized bytes back into a `Message`, panicking on malformed input.
    fn parse_response(bytes: &[u8]) -> Message {
        Message::from_bytes(bytes).expect("parse response")
    }

    #[test]
    fn build_status_response_preserves_header_and_question() {
        let query = make_query("slack.com.", RecordType::AAAA);
        let wire = build_status_response(&query, ResponseCode::Refused).expect("built");
        let reply = parse_response(&wire);

        assert_eq!(reply.id(), 0x4242);
        assert_eq!(reply.response_code(), ResponseCode::Refused);
        assert_eq!(reply.message_type(), MessageType::Response);
        assert_eq!(reply.op_code(), OpCode::Query);
        assert!(reply.recursion_desired());
        assert!(reply.recursion_available());
        assert_eq!(reply.queries().len(), 1);
        assert_eq!(reply.queries()[0].query_type(), RecordType::AAAA);
        assert_eq!(reply.answers().len(), 0);
    }

    #[test]
    fn build_status_response_servfail_variant() {
        let query = make_query("example.com.", RecordType::A);
        let wire = build_status_response(&query, ResponseCode::ServFail).expect("built");
        let reply = parse_response(&wire);

        assert_eq!(reply.response_code(), ResponseCode::ServFail);
        assert_eq!(reply.answers().len(), 0);
    }

    #[test]
    fn build_truncated_response_sets_tc_and_keeps_question() {
        let query = make_query("example.com.", RecordType::TXT);
        let wire = build_truncated_response(&query).expect("built");
        let reply = parse_response(&wire);

        assert_eq!(reply.id(), 0x4242);
        assert_eq!(reply.message_type(), MessageType::Response);
        assert_eq!(reply.response_code(), ResponseCode::NoError);
        assert!(reply.truncated(), "TC bit should be set");
        assert_eq!(reply.queries().len(), 1);
        assert_eq!(reply.queries()[0].query_type(), RecordType::TXT);
        assert!(reply.answers().is_empty());
    }

    #[test]
    fn edns_opt_round_trips_through_wire() {
        let mut edns = Edns::new();
        edns.set_max_payload(4096);
        edns.set_dnssec_ok(true);
        edns.set_version(0);

        let mut query = make_query("example.com.", RecordType::A);
        *query.extensions_mut() = Some(edns);

        let wire = query.to_bytes().expect("serialize");
        let parsed = Message::from_bytes(&wire).expect("parse");
        let opt = parsed.extensions().as_ref().expect("OPT preserved");

        assert_eq!(opt.max_payload(), 4096);
        assert!(opt.flags().dnssec_ok, "DO bit preserved");
        assert_eq!(parsed.max_payload(), 4096);
    }

    #[test]
    fn max_payload_defaults_to_512_without_opt() {
        let plain = make_query("example.com.", RecordType::A);

        assert!(plain.extensions().is_none());
        assert_eq!(plain.max_payload(), 512);
    }

    /// Gateway addresses used by the routing tests: `10.0.0.1` and `::1`.
    fn gateway_set() -> HashSet<IpAddr> {
        [
            IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)),
            IpAddr::V6(Ipv6Addr::LOCALHOST),
        ]
        .into_iter()
        .collect()
    }

    /// Shorthand for a known IPv4 original destination.
    fn v4(a: u8, b: u8, c: u8, d: u8) -> Option<IpAddr> {
        Some(IpAddr::V4(Ipv4Addr::new(a, b, c, d)))
    }

    /// Allow-by-default policy with one rule denying TCP egress to `cidr` on
    /// every port.
    fn deny_tcp_to(cidr: &str) -> NetworkPolicy {
        use crate::policy::{Action, Destination, Direction, Rule};

        NetworkPolicy {
            default_egress: Action::Allow,
            default_ingress: Action::Allow,
            rules: vec![Rule {
                direction: Direction::Egress,
                destination: Destination::Cidr(cidr.parse().unwrap()),
                protocols: vec![Protocol::Tcp],
                ports: vec![],
                action: Action::Deny,
            }],
        }
    }

    /// Gateway set and shared state reused by every decision within a test.
    struct Routing {
        gateways: HashSet<IpAddr>,
        shared: SharedState,
    }

    impl Routing {
        fn new() -> Self {
            Self {
                gateways: gateway_set(),
                shared: SharedState::new(4),
            }
        }

        /// Runs `decide_upstream` against this fixture.
        fn decide(
            &self,
            policy: &NetworkPolicy,
            dst: Option<IpAddr>,
            transport: Transport,
        ) -> UpstreamDecision {
            decide_upstream(&self.gateways, policy, &self.shared, dst, transport)
        }
    }

    #[test]
    fn decide_upstream_configured_when_dst_is_gateway_v4() {
        let routing = Routing::new();
        let policy = NetworkPolicy::allow_all();

        assert_eq!(
            routing.decide(&policy, v4(10, 0, 0, 1), Transport::Udp),
            UpstreamDecision::Configured
        );
    }

    #[test]
    fn decide_upstream_configured_when_dst_is_gateway_v6() {
        let routing = Routing::new();
        let policy = NetworkPolicy::allow_all();
        let gateway_v6 = Some(IpAddr::V6(Ipv6Addr::LOCALHOST));

        assert_eq!(
            routing.decide(&policy, gateway_v6, Transport::Tcp),
            UpstreamDecision::Configured
        );
    }

    #[test]
    fn decide_upstream_configured_when_dst_unknown() {
        let routing = Routing::new();
        let policy = NetworkPolicy::allow_all();

        assert_eq!(
            routing.decide(&policy, None, Transport::Udp),
            UpstreamDecision::Configured
        );
    }

    #[test]
    fn decide_upstream_direct_when_dst_external_and_policy_allows() {
        let routing = Routing::new();
        let policy = NetworkPolicy::allow_all();

        assert_eq!(
            routing.decide(&policy, v4(1, 1, 1, 1), Transport::Udp),
            UpstreamDecision::Direct(SocketAddr::from(([1, 1, 1, 1], 53)))
        );
    }

    #[test]
    fn decide_upstream_refused_when_policy_denies_resolver() {
        let routing = Routing::new();
        let policy = NetworkPolicy::public_only();

        assert_eq!(
            routing.decide(&policy, v4(192, 168, 1, 53), Transport::Udp),
            UpstreamDecision::Refused
        );
    }

    #[test]
    fn decide_upstream_refused_when_policy_denies_all() {
        let routing = Routing::new();
        let policy = NetworkPolicy::none();

        assert_eq!(
            routing.decide(&policy, v4(1, 1, 1, 1), Transport::Tcp),
            UpstreamDecision::Refused
        );
        // The gateway itself is never subject to the policy check.
        assert_eq!(
            routing.decide(&policy, v4(10, 0, 0, 1), Transport::Tcp),
            UpstreamDecision::Configured
        );
    }

    #[test]
    fn decide_upstream_uses_correct_transport_protocol() {
        let routing = Routing::new();
        let policy = deny_tcp_to("8.8.8.8/32");
        let resolver = v4(8, 8, 8, 8);

        assert_eq!(
            routing.decide(&policy, resolver, Transport::Udp),
            UpstreamDecision::Direct(SocketAddr::from(([8, 8, 8, 8], 53)))
        );
        assert_eq!(
            routing.decide(&policy, resolver, Transport::Tcp),
            UpstreamDecision::Refused
        );
    }

    #[test]
    fn decide_upstream_dot_configured_when_dst_is_gateway() {
        let routing = Routing::new();
        let policy = NetworkPolicy::allow_all();

        assert_eq!(
            routing.decide(&policy, v4(10, 0, 0, 1), Transport::Dot),
            UpstreamDecision::Configured
        );
    }

    #[test]
    fn decide_upstream_dot_direct_targets_port_853() {
        let routing = Routing::new();
        let policy = NetworkPolicy::allow_all();

        assert_eq!(
            routing.decide(&policy, v4(1, 1, 1, 1), Transport::Dot),
            UpstreamDecision::Direct(SocketAddr::from(([1, 1, 1, 1], 853))),
        );
    }

    #[test]
    fn decide_upstream_dot_refused_when_policy_denies_853() {
        let routing = Routing::new();
        let policy = deny_tcp_to("1.1.1.1/32");

        assert_eq!(
            routing.decide(&policy, v4(1, 1, 1, 1), Transport::Dot),
            UpstreamDecision::Refused
        );
    }
}