use std::collections::HashSet;
use std::sync::Arc;
use bytes::Bytes;
use smoltcp::iface::SocketSet;
use smoltcp::socket::udp;
use smoltcp::storage::PacketMetadata;
use smoltcp::wire::{IpEndpoint, IpListenEndpoint};
use tokio::sync::mpsc;
use crate::config::DnsConfig;
use crate::shared::SharedState;
/// UDP port the interceptor's smoltcp socket is bound to.
const DNS_PORT: u16 = 53;
/// Size of the receive scratch buffer in `DnsInterceptor::process`, and the
/// per-slot payload budget used to size the socket's RX/TX buffers.
const DNS_MAX_SIZE: usize = 4096;
/// Number of packet metadata slots in each of the socket's RX and TX buffers.
const DNS_SOCKET_PACKET_SLOTS: usize = 16;
/// Capacity of both the query channel (socket -> resolver task) and the
/// response channel (resolver task -> socket).
const CHANNEL_CAPACITY: usize = 64;
/// Intercepts DNS queries arriving on a smoltcp UDP socket and answers them via
/// a background resolver task.
pub struct DnsInterceptor {
    /// Handle of the UDP socket bound to `DNS_PORT` inside the caller's `SocketSet`.
    socket_handle: smoltcp::iface::SocketHandle,
    /// Forwards received queries to the resolver task.
    query_tx: mpsc::Sender<DnsQuery>,
    /// Receives finished responses that still need to be written to the socket.
    response_rx: mpsc::Receiver<DnsResponse>,
}
/// Blocklist and rebind settings derived from `DnsConfig`, normalized once so
/// per-query matching is plain string comparison.
struct NormalizedDnsConfig {
    /// Exact domain names to block, lower-cased.
    blocked_domains: HashSet<String>,
    /// Blocked suffixes, lower-cased with leading dots removed.
    blocked_suffixes: Vec<String>,
    /// `blocked_suffixes` with a single `.` prepended, index-aligned with it, so
    /// a subdomain match is a plain `ends_with`.
    blocked_suffixes_dotted: Vec<String>,
    /// When set, resolved answers containing addresses classified as private by
    /// `is_private_ipv4` / `is_private_ipv6` are not returned to the client.
    rebind_protection: bool,
}
/// A raw DNS query datagram read from the socket, plus who sent it.
struct DnsQuery {
    /// Wire-format query bytes exactly as received.
    data: Bytes,
    /// Remote endpoint the query came from; the reply is addressed back to it.
    source: IpEndpoint,
}
/// A wire-format DNS reply produced by the resolver task, ready to send.
struct DnsResponse {
    /// Wire-format response bytes.
    data: Bytes,
    /// Endpoint to send the reply to (the original query's source).
    dest: IpEndpoint,
}
impl DnsInterceptor {
pub fn new(
sockets: &mut SocketSet<'_>,
dns_config: DnsConfig,
shared: Arc<SharedState>,
tokio_handle: &tokio::runtime::Handle,
) -> Self {
let rx_meta = vec![PacketMetadata::EMPTY; DNS_SOCKET_PACKET_SLOTS];
let rx_payload = vec![0u8; DNS_MAX_SIZE * DNS_SOCKET_PACKET_SLOTS];
let tx_meta = vec![PacketMetadata::EMPTY; DNS_SOCKET_PACKET_SLOTS];
let tx_payload = vec![0u8; DNS_MAX_SIZE * DNS_SOCKET_PACKET_SLOTS];
let mut socket = udp::Socket::new(
udp::PacketBuffer::new(rx_meta, rx_payload),
udp::PacketBuffer::new(tx_meta, tx_payload),
);
socket
.bind(IpListenEndpoint {
addr: None,
port: DNS_PORT,
})
.expect("failed to bind DNS socket to port 53");
let socket_handle = sockets.add(socket);
let (query_tx, query_rx) = mpsc::channel(CHANNEL_CAPACITY);
let (response_tx, response_rx) = mpsc::channel(CHANNEL_CAPACITY);
let suffixes: Vec<String> = dns_config
.blocked_suffixes
.iter()
.map(|s| s.to_lowercase().trim_start_matches('.').to_string())
.collect();
let suffixes_dotted: Vec<String> = suffixes.iter().map(|s| format!(".{s}")).collect();
let normalized = Arc::new(NormalizedDnsConfig {
blocked_domains: dns_config
.blocked_domains
.iter()
.map(|d| d.to_lowercase())
.collect(),
blocked_suffixes: suffixes,
blocked_suffixes_dotted: suffixes_dotted,
rebind_protection: dns_config.rebind_protection,
});
tokio_handle.spawn(dns_resolver_task(query_rx, response_tx, normalized, shared));
Self {
socket_handle,
query_tx,
response_rx,
}
}
pub fn process(&mut self, sockets: &mut SocketSet<'_>) {
let socket = sockets.get_mut::<udp::Socket>(self.socket_handle);
let mut buf = [0u8; DNS_MAX_SIZE];
while socket.can_recv() {
match socket.recv_slice(&mut buf) {
Ok((n, meta)) => {
let query = DnsQuery {
data: Bytes::copy_from_slice(&buf[..n]),
source: meta.endpoint,
};
if self.query_tx.try_send(query).is_err() {
tracing::debug!("DNS query channel full, dropping query");
}
}
Err(_) => break,
}
}
while socket.can_send() {
match self.response_rx.try_recv() {
Ok(response) => {
let _ = socket.send_slice(&response.data, response.dest);
}
Err(_) => break,
}
}
}
}
async fn dns_resolver_task(
mut query_rx: mpsc::Receiver<DnsQuery>,
response_tx: mpsc::Sender<DnsResponse>,
dns_config: Arc<NormalizedDnsConfig>,
shared: Arc<SharedState>,
) {
let resolver = match hickory_resolver::Resolver::builder_tokio().map(|b| b.build()) {
Ok(r) => r,
Err(e) => {
tracing::error!(error = %e, "failed to create DNS resolver");
return;
}
};
while let Some(query) = query_rx.recv().await {
let response_tx = response_tx.clone();
let dns_config = dns_config.clone();
let shared = shared.clone();
let resolver = resolver.clone();
tokio::spawn(async move {
let result = resolve_query(&query.data, &dns_config, &resolver).await;
match result {
Some(response_data) => {
let response = DnsResponse {
data: response_data,
dest: query.source,
};
if response_tx.send(response).await.is_ok() {
shared.proxy_wake.wake();
}
}
None => {
if let Some(servfail) = build_refused(&query.data) {
let response = DnsResponse {
data: servfail,
dest: query.source,
};
if response_tx.send(response).await.is_ok() {
shared.proxy_wake.wake();
}
}
}
}
});
}
}
/// Resolves one raw DNS query upstream and builds the wire-format answer.
///
/// Returns `None` in any of these cases:
/// - the query cannot be parsed or has no question;
/// - the name is blocklisted;
/// - the upstream lookup fails;
/// - rebind protection finds a private address in the answer;
/// - the response cannot be serialized.
///
/// Only the first question is answered. The caller turns `None` into a REFUSED
/// reply.
///
/// NOTE(review): every upstream lookup error (including what upstream may
/// report as NXDOMAIN) is collapsed into `None`, so clients see REFUSED rather
/// than the upstream response code.
async fn resolve_query(
    raw_query: &[u8],
    dns_config: &NormalizedDnsConfig,
    resolver: &hickory_resolver::TokioResolver,
) -> Option<Bytes> {
    use hickory_proto::op::{Message, MessageType, ResponseCode};
    use hickory_proto::rr::RData;
    use hickory_proto::serialize::binary::{BinDecodable, BinEncodable};

    let query_msg = Message::from_bytes(raw_query).ok()?;
    let question = query_msg.queries().first()?;

    // hickory's `Name` Display form (what `to_string`/`to_utf8` produce) decodes
    // punycode labels to Unicode, so a blocklist entry written as `xn--...`
    // could never match it. Check the ASCII (punycode) spelling as well.
    // Trailing root dots are trimmed because blocklist entries carry none.
    let utf8_name = question.name().to_utf8();
    let ascii_name = question.name().to_ascii();
    let domain = utf8_name.trim_end_matches('.');
    let ascii_domain = ascii_name.trim_end_matches('.');
    if is_domain_blocked(domain, dns_config) || is_domain_blocked(ascii_domain, dns_config) {
        tracing::debug!(domain = %domain, "DNS query blocked");
        return None;
    }

    let lookup = resolver
        .lookup(question.name().clone(), question.query_type())
        .await
        .ok()?;

    if dns_config.rebind_protection {
        // Any A/AAAA record pointing into private space rejects the whole answer.
        let has_private = lookup.records().iter().any(|record| match record.data() {
            RData::A(a) => is_private_ipv4((*a).into()),
            RData::AAAA(aaaa) => is_private_ipv6((*aaaa).into()),
            _ => false,
        });
        if has_private {
            tracing::debug!(
                domain = %domain,
                "DNS rebind protection: response contains private IP"
            );
            return None;
        }
    }

    let mut response_msg = Message::new();
    response_msg.set_id(query_msg.id());
    response_msg.set_message_type(MessageType::Response);
    response_msg.set_op_code(query_msg.op_code());
    response_msg.set_response_code(ResponseCode::NoError);
    response_msg.set_recursion_desired(query_msg.recursion_desired());
    response_msg.set_recursion_available(true);
    response_msg.add_query(question.clone());
    response_msg.insert_answers(lookup.records().to_vec());
    let response_bytes = response_msg.to_bytes().ok()?;
    Some(Bytes::from(response_bytes))
}
/// Returns `true` if `addr` must not appear in an answer when rebind protection
/// is on. Covered ranges:
/// - loopback (127.0.0.0/8);
/// - "this network" (0.0.0.0/8, which includes the unspecified address);
/// - the RFC 1918 private ranges;
/// - shared/CGNAT space (100.64.0.0/10);
/// - link-local (169.254.0.0/16);
/// - limited broadcast (255.255.255.255).
fn is_private_ipv4(addr: std::net::Ipv4Addr) -> bool {
    let [a, b, _, _] = addr.octets();
    addr.is_loopback()                        // 127.0.0.0/8
        || a == 0                             // 0.0.0.0/8 ("this network")
        || a == 10                            // 10.0.0.0/8
        || (a == 172 && (b & 0xf0) == 16)     // 172.16.0.0/12
        || (a == 192 && b == 168)             // 192.168.0.0/16
        || (a == 100 && (b & 0xc0) == 64)     // 100.64.0.0/10
        || (a == 169 && b == 254)             // 169.254.0.0/16
        || addr.is_broadcast()                // 255.255.255.255
}

/// IPv6 counterpart of `is_private_ipv4`. Covered ranges:
/// - loopback and unspecified;
/// - unique-local (fc00::/7);
/// - link-local (fe80::/10);
/// - deprecated site-local (fec0::/10);
/// - any IPv4-mapped address (`::ffff:a.b.c.d`) whose embedded IPv4 address is
///   private.
///
/// The mapped check matters: a dual-stack client connecting to, say,
/// `::ffff:127.0.0.1` reaches IPv4 loopback. Without it, an AAAA record could
/// bypass the IPv4 rules.
fn is_private_ipv6(addr: std::net::Ipv6Addr) -> bool {
    if let Some(v4) = addr.to_ipv4_mapped() {
        return is_private_ipv4(v4);
    }
    let first = addr.segments()[0];
    addr.is_loopback()
        || addr.is_unspecified()
        || (first & 0xfe00) == 0xfc00 // fc00::/7 unique-local
        || (first & 0xffc0) == 0xfe80 // fe80::/10 link-local
        || (first & 0xffc0) == 0xfec0 // fec0::/10 site-local (deprecated)
}
/// Reports whether `domain` is blocked by `config`.
///
/// A domain is blocked when, compared case-insensitively, it equals an
/// exact-match entry, equals a blocked suffix, or ends with `.` + that suffix
/// (i.e. is a subdomain of it). `config` entries are expected to be lower-cased
/// already.
fn is_domain_blocked(domain: &str, config: &NormalizedDnsConfig) -> bool {
    let needle = domain.to_lowercase();
    config.blocked_domains.contains(&needle)
        || config
            .blocked_suffixes
            .iter()
            .zip(&config.blocked_suffixes_dotted)
            .any(|(bare, dotted)| *bare == needle || needle.ends_with(dotted.as_str()))
}
/// Builds a REFUSED reply for `raw_query`.
///
/// The reply echoes the query id, every question, the OPCODE and the RD flag.
/// RFC 1035 §4.1.1 says OPCODE and RD are copied from the query into the
/// response; `resolve_query` already does this for successful answers.
///
/// Returns `None` when `raw_query` cannot be parsed as a DNS message or the
/// reply cannot be serialized. In that case no reply is sent.
fn build_refused(raw_query: &[u8]) -> Option<Bytes> {
    use hickory_proto::op::{Message, MessageType, ResponseCode};
    use hickory_proto::serialize::binary::{BinDecodable, BinEncodable};

    let query_msg = Message::from_bytes(raw_query).ok()?;
    let mut response = Message::new();
    response.set_id(query_msg.id());
    for q in query_msg.queries() {
        response.add_query(q.clone());
    }
    response.set_message_type(MessageType::Response);
    response.set_op_code(query_msg.op_code());
    response.set_response_code(ResponseCode::Refused);
    response.set_recursion_desired(query_msg.recursion_desired());
    response.set_recursion_available(true);
    let encoded = response.to_bytes().ok()?;
    Some(Bytes::from(encoded))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `NormalizedDnsConfig` by lower-casing every entry and stripping
    /// leading dots from suffixes; rebind protection is off.
    fn config_with(domains: &[&str], suffixes: &[&str]) -> NormalizedDnsConfig {
        let mut blocked_suffixes = Vec::with_capacity(suffixes.len());
        let mut blocked_suffixes_dotted = Vec::with_capacity(suffixes.len());
        for raw in suffixes {
            let bare = raw.to_lowercase().trim_start_matches('.').to_string();
            blocked_suffixes_dotted.push(format!(".{bare}"));
            blocked_suffixes.push(bare);
        }
        NormalizedDnsConfig {
            blocked_domains: domains.iter().map(|d| d.to_lowercase()).collect(),
            blocked_suffixes,
            blocked_suffixes_dotted,
            rebind_protection: false,
        }
    }

    /// Asserts `is_domain_blocked` returns `expected` for every listed domain.
    fn check(config: &NormalizedDnsConfig, cases: &[(&str, bool)]) {
        for &(domain, expected) in cases {
            assert_eq!(is_domain_blocked(domain, config), expected, "{domain}");
        }
    }

    #[test]
    fn test_exact_domain_blocked() {
        let config = config_with(&["evil.com"], &[]);
        check(
            &config,
            &[
                ("evil.com", true),
                ("Evil.COM", true),
                ("not-evil.com", false),
                ("sub.evil.com", false),
            ],
        );
    }

    #[test]
    fn test_suffix_domain_blocked() {
        let config = config_with(&[], &[".evil.com"]);
        check(
            &config,
            &[
                ("sub.evil.com", true),
                ("deep.sub.evil.com", true),
                ("evil.com", true),
                ("notevil.com", false),
            ],
        );
    }

    #[test]
    fn test_no_blocks_nothing_blocked() {
        let config = config_with(&[], &[]);
        check(&config, &[("anything.com", false)]);
    }
}