use std::net::UdpSocket as StdUdpSocket;
use std::time::Duration;
use hickory_proto::op::{Header, Message, MessageType, OpCode, Query, ResponseCode};
use hickory_proto::rr::{Name, RecordType};
use hickory_proto::serialize::binary::{BinDecodable, BinEncodable};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::UdpSocket;
use dnsink::bloom::BloomFilter;
use dnsink::config::{
Config, FeedsConfig, ListenConfig, LoggingConfig, RateLimitConfig, UpstreamConfig,
};
use dnsink::proxy::DnsProxy;
use dnsink::trie::DomainTrie;
/// Returns a loopback port that is currently free for both UDP and TCP.
///
/// `test_config` leaves `tcp_address` unset, and the TCP tests connect to the
/// same port the proxy listens on for UDP. A port that is free only for UDP
/// can therefore make the proxy's startup fail. The UDP probe is held open
/// while TCP is probed, so both checks refer to the same port number. All probe
/// sockets are released before returning, so another process could still grab
/// the port in between; this narrows that race but cannot close it.
///
/// # Panics
/// Panics if an ephemeral UDP socket cannot be bound. Also panics if no port
/// free for both protocols is found within a bounded number of attempts.
fn free_port() -> u16 {
    const MAX_ATTEMPTS: usize = 64;
    for _ in 0..MAX_ATTEMPTS {
        let udp = StdUdpSocket::bind("127.0.0.1:0").unwrap();
        let port = udp.local_addr().unwrap().port();
        // `udp` is still alive here, so the OS cannot hand this port out again
        // before the TCP probe runs.
        if std::net::TcpListener::bind(("127.0.0.1", port)).is_ok() {
            return port;
        }
    }
    panic!(
        "no loopback port free for both UDP and TCP after {} attempts",
        MAX_ATTEMPTS
    );
}
/// Encodes a DNS query for the `A` record of `domain` into wire format.
///
/// The message carries the fixed id `1234`, opcode `Query`, and the
/// recursion-desired flag.
///
/// # Panics
/// Panics if `domain` is not a valid ASCII name or the message fails to encode.
fn make_query(domain: &str) -> Vec<u8> {
    let qname = Name::from_ascii(domain).unwrap();

    let mut hdr = Header::new();
    hdr.set_id(1234);
    hdr.set_op_code(OpCode::Query);
    hdr.set_recursion_desired(true);

    let mut message = Message::new();
    message.set_header(hdr);
    message.add_query(Query::query(qname, RecordType::A));
    message.to_bytes().unwrap()
}
fn test_config(port: u16) -> Config {
Config {
listen: ListenConfig {
address: "127.0.0.1".to_string(),
port,
tcp_address: None,
},
upstream: UpstreamConfig {
address: "8.8.8.8".to_string(),
port: 53,
timeout_ms: 5000,
protocol: Default::default(),
doh_url: None,
},
blocklist: None,
feeds: FeedsConfig {
urlhaus: false,
openphish: false,
phishtank_api_key: None,
oisd: false,
refresh_secs: 0,
},
logging: LoggingConfig::default(),
tunneling_detection: Default::default(),
metrics: Default::default(),
ratelimit: Default::default(),
}
}
/// Returns a Bloom filter and a domain trie, both preloaded with the two
/// blocked test domains `evil.com` and `malware.org`.
///
/// The filter is constructed with the entry count and `0.01`. The second
/// argument is presumably the target false-positive rate; confirm against
/// `BloomFilter::new`. The filter is always returned as `Some`.
fn test_blocklist() -> (Option<BloomFilter>, DomainTrie) {
    const BLOCKED: [&str; 2] = ["evil.com", "malware.org"];

    let mut trie = DomainTrie::new();
    let mut bloom = BloomFilter::new(BLOCKED.len(), 0.01);
    for domain in BLOCKED.iter() {
        trie.insert(domain);
        bloom.insert(&domain.to_string());
    }

    (Some(bloom), trie)
}
/// Starts a [`DnsProxy`] on `port` in a background tokio task and returns its
/// shared metrics handle. The proxy is built from [`test_config`] and
/// [`test_blocklist`].
///
/// The proxy gets a fixed 100 ms to start before this returns. The task
/// `unwrap`s the result of `DnsProxy::run`, so an error there becomes a panic
/// inside the task. If the task has already finished when the delay elapses,
/// that panic (or any other early exit) is re-raised here. The test then fails
/// with the real cause instead of a later "timeout waiting for response".
///
/// # Panics
/// Panics if `DnsProxy::new` fails or if the proxy task exits within the
/// startup delay.
async fn spawn_proxy(port: u16) -> std::sync::Arc<dnsink::proxy::QueryMetrics> {
    let config = test_config(port);
    let (bloom, trie) = test_blocklist();
    let proxy = DnsProxy::new(config, bloom, trie).unwrap();
    let metrics = proxy.metrics();
    let task = tokio::spawn(async move {
        proxy.run().await.unwrap();
    });
    tokio::time::sleep(Duration::from_millis(100)).await;
    if task.is_finished() {
        match task.await {
            // Re-raise the task's own panic so its message reaches the test output.
            Err(join_err) if join_err.is_panic() => {
                std::panic::resume_unwind(join_err.into_panic())
            }
            outcome => panic!("proxy task exited during startup: {:?}", outcome),
        }
    }
    // A still-running task is detached when `task` is dropped here.
    metrics
}
/// A UDP query for a blocklisted domain (`evil.com`) gets an NXDOMAIN
/// response that echoes the query id set by `make_query`.
#[tokio::test]
async fn udp_blocked_domain_returns_nxdomain() {
    let port = free_port();
    spawn_proxy(port).await;

    let socket = UdpSocket::bind("127.0.0.1:0").await.unwrap();
    let proxy_addr = format!("127.0.0.1:{port}");
    socket
        .send_to(&make_query("evil.com"), proxy_addr)
        .await
        .unwrap();

    let mut buf = vec![0u8; 4096];
    let reply = tokio::time::timeout(Duration::from_secs(2), socket.recv(&mut buf)).await;
    let len = reply.expect("timeout waiting for response").unwrap();

    let response = Message::from_bytes(&buf[..len]).unwrap();
    assert_eq!(response.response_code(), ResponseCode::NXDomain);
    assert_eq!(response.message_type(), MessageType::Response);
    assert_eq!(response.id(), 1234);
}
/// A UDP query for a subdomain (`sub.evil.com`) of a blocklisted domain is
/// also answered with NXDOMAIN.
#[tokio::test]
async fn udp_subdomain_of_blocked_domain_returns_nxdomain() {
    let port = free_port();
    spawn_proxy(port).await;

    let socket = UdpSocket::bind("127.0.0.1:0").await.unwrap();
    let proxy_addr = format!("127.0.0.1:{port}");
    socket
        .send_to(&make_query("sub.evil.com"), proxy_addr)
        .await
        .unwrap();

    let mut buf = vec![0u8; 4096];
    let reply = tokio::time::timeout(Duration::from_secs(2), socket.recv(&mut buf)).await;
    let len = reply.expect("timeout waiting for response").unwrap();

    let response = Message::from_bytes(&buf[..len]).unwrap();
    assert_eq!(response.response_code(), ResponseCode::NXDomain);
}
/// A non-blocklisted name (`google.com`) is forwarded to the upstream
/// configured by `test_config`. The reply must be a successful response
/// with answers and the client's query id.
///
/// Ignored by default because it needs real outbound network access.
#[tokio::test]
#[ignore = "requires network access to the 8.8.8.8:53 upstream"]
async fn udp_clean_domain_gets_forwarded() {
    let port = free_port();
    spawn_proxy(port).await;
    let client = UdpSocket::bind("127.0.0.1:0").await.unwrap();
    let query = make_query("google.com");
    client
        .send_to(&query, format!("127.0.0.1:{port}"))
        .await
        .unwrap();
    let mut buf = vec![0u8; 4096];
    let len = tokio::time::timeout(Duration::from_secs(5), client.recv(&mut buf))
        .await
        .expect("timeout — is network available?")
        .unwrap();
    let response = Message::from_bytes(&buf[..len]).unwrap();
    assert_eq!(response.response_code(), ResponseCode::NoError);
    assert_eq!(response.message_type(), MessageType::Response);
    // Clients match replies to queries by id, so the forwarding path must
    // preserve it, just like the blocked paths asserted elsewhere.
    assert_eq!(response.id(), 1234);
    assert!(!response.answers().is_empty(), "should have DNS answers");
}
/// Over TCP, a query for a blocklisted domain gets an NXDOMAIN response
/// that echoes the query id set by `make_query`.
#[tokio::test]
async fn tcp_blocked_domain_returns_nxdomain() {
    let port = free_port();
    spawn_proxy(port).await;

    let mut conn = tokio::net::TcpStream::connect(format!("127.0.0.1:{port}"))
        .await
        .unwrap();
    let query = make_query("evil.com");

    // TCP framing: each DNS message is preceded by its length as a big-endian u16.
    let prefix = (query.len() as u16).to_be_bytes();
    conn.write_all(&prefix).await.unwrap();
    conn.write_all(&query).await.unwrap();

    let reply_len = conn.read_u16().await.unwrap() as usize;
    let mut reply = vec![0u8; reply_len];
    conn.read_exact(&mut reply).await.unwrap();

    let response = Message::from_bytes(&reply).unwrap();
    assert_eq!(response.response_code(), ResponseCode::NXDomain);
    assert_eq!(response.message_type(), MessageType::Response);
    assert_eq!(response.id(), 1234);
}
#[tokio::test]
async fn metrics_update_after_queries() {
let port = free_port();
let metrics = spawn_proxy(port).await;
let client = UdpSocket::bind("127.0.0.1:0").await.unwrap();
let addr = format!("127.0.0.1:{port}");
let mut buf = vec![0u8; 4096];
client
.send_to(&make_query("evil.com"), &addr)
.await
.unwrap();
tokio::time::timeout(Duration::from_secs(2), client.recv(&mut buf))
.await
.unwrap()
.unwrap();
client
.send_to(&make_query("malware.org"), &addr)
.await
.unwrap();
tokio::time::timeout(Duration::from_secs(2), client.recv(&mut buf))
.await
.unwrap()
.unwrap();
let snap = metrics.snapshot();
assert_eq!(snap.total, 2);
assert_eq!(snap.blocked, 2);
assert_eq!(snap.allowed, 0);
let top = metrics.top_blocked(10);
assert_eq!(top.len(), 2);
}
#[tokio::test]
async fn udp_rate_limit_does_not_drop_when_under_capacity() {
use std::sync::atomic::Ordering;
let port = free_port();
let mut config = test_config(port);
config.ratelimit = RateLimitConfig {
enabled: true,
requests_per_minute: 60,
burst: 5,
};
let (bloom, trie) = test_blocklist();
let proxy = DnsProxy::new(config, bloom, trie).unwrap();
let metrics = proxy.metrics();
tokio::spawn(async move { proxy.run().await.unwrap() });
tokio::time::sleep(Duration::from_millis(100)).await;
let client = UdpSocket::bind("127.0.0.1:0").await.unwrap();
client
.send_to(&make_query("evil.com"), format!("127.0.0.1:{port}"))
.await
.unwrap();
let mut buf = vec![0u8; 4096];
let len = tokio::time::timeout(Duration::from_secs(2), client.recv(&mut buf))
.await
.expect("timeout — query under burst should always come back")
.unwrap();
let response = Message::from_bytes(&buf[..len]).unwrap();
assert_eq!(response.response_code(), ResponseCode::NXDomain);
assert_eq!(
metrics.ratelimited.load(Ordering::Relaxed),
0,
"single query within burst must not increment the limiter counter",
);
assert_eq!(metrics.blocked.load(Ordering::Relaxed), 1);
}
#[tokio::test]
async fn tcp_rate_limit_silently_closes_excess_connections() {
use std::sync::atomic::Ordering;
let port = free_port();
let mut config = test_config(port);
config.ratelimit = RateLimitConfig {
enabled: true,
requests_per_minute: 60,
burst: 2,
};
let (bloom, trie) = test_blocklist();
let proxy = DnsProxy::new(config, bloom, trie).unwrap();
let metrics = proxy.metrics();
tokio::spawn(async move { proxy.run().await.unwrap() });
tokio::time::sleep(Duration::from_millis(100)).await;
let target = format!("127.0.0.1:{port}");
let query = make_query("evil.com");
let len_bytes = (query.len() as u16).to_be_bytes();
let mut completed = 0;
let mut closed = 0;
for _ in 0..5 {
let mut stream = tokio::net::TcpStream::connect(&target).await.unwrap();
let _ = stream.write_all(&len_bytes).await;
let _ = stream.write_all(&query).await;
match tokio::time::timeout(Duration::from_millis(500), stream.read_u16()).await {
Ok(Ok(_)) => completed += 1,
Ok(Err(_)) => closed += 1,
Err(_) => closed += 1,
}
}
assert_eq!(
completed, 2,
"expected exactly burst=2 connections to receive responses, got {completed}",
);
assert_eq!(
closed, 3,
"expected 3 connections to be silently closed by the limiter, got {closed}",
);
assert_eq!(
metrics.ratelimited.load(Ordering::Relaxed),
3,
"5 connects − burst 2 = 3 silently dropped",
);
}
#[tokio::test]
async fn udp_rate_limit_silently_drops_excess_and_increments_counter() {
use std::sync::atomic::Ordering;
let port = free_port();
let mut config = test_config(port);
config.ratelimit = RateLimitConfig {
enabled: true,
requests_per_minute: 60,
burst: 2,
};
let (bloom, trie) = test_blocklist();
let proxy = DnsProxy::new(config, bloom, trie).unwrap();
let metrics = proxy.metrics();
tokio::spawn(async move { proxy.run().await.unwrap() });
tokio::time::sleep(Duration::from_millis(100)).await;
let client = UdpSocket::bind("127.0.0.1:0").await.unwrap();
let target = format!("127.0.0.1:{port}");
let query = make_query("evil.com");
for _ in 0..5 {
client.send_to(&query, &target).await.unwrap();
}
let mut got = 0;
let mut buf = vec![0u8; 4096];
while got < 5 {
match tokio::time::timeout(Duration::from_millis(300), client.recv(&mut buf)).await {
Ok(Ok(_)) => got += 1,
_ => break,
}
}
assert_eq!(got, 2, "expected exactly burst=2 responses, got {got}");
assert_eq!(
metrics.ratelimited.load(Ordering::Relaxed),
3,
"5 sends − burst 2 = 3 silently dropped",
);
assert_eq!(
metrics.blocked.load(Ordering::Relaxed),
2,
"the two allowed-through queries hit the blocklist (evil.com)",
);
}