use std::net::{SocketAddr, UdpSocket};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex};
use std::time::Duration;
use cellos_core::{CloudEventV1, DnsQueryType};
use cellos_supervisor::dns_proxy::{run_one_shot, DnsProxyConfig, DnsQueryEmitter};
/// Test sink that records every event handed to it, in the order `emit` is called.
///
/// The tests share it between the proxy thread (via an `Arc` clone) and the test thread,
/// which inspects `events` after joining the proxy.
#[derive(Default)]
struct CollectingEmitter {
    /// Captured events, behind a mutex so both threads can reach them through `&self`.
    events: Mutex<Vec<CloudEventV1>>,
}

impl DnsQueryEmitter for CollectingEmitter {
    /// Appends `event` to the captured list.
    fn emit(&self, event: CloudEventV1) {
        let mut captured = self.events.lock().unwrap();
        captured.push(event);
    }
}
/// Encodes a minimal single-question DNS query (QTYPE A, QCLASS IN) for `qname`.
///
/// The header carries ID 0x4242, RD set, QDCOUNT 1, and all other counts zero.
/// A single trailing dot is accepted (`"api.example.com."` encodes like `"api.example.com"`),
/// and `""` / `"."` encode the root name.
///
/// # Panics
/// If `qname` contains an empty interior label (`"a..b"`) or a label longer than the
/// 63-octet limit of RFC 1035 §2.3.4. Such names cannot be encoded, and a raw `as u8` cast
/// would silently emit a malformed packet instead.
fn build_a_query(qname: &str) -> Vec<u8> {
    const HEADER: [u8; 12] = [
        0x42, 0x42, 0x01, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
    ];
    // QNAME needs at most len + 2 bytes (leading length octet + root), then 4 for QTYPE/QCLASS.
    let mut p = Vec::with_capacity(HEADER.len() + qname.len() + 2 + 4);
    p.extend_from_slice(&HEADER);

    // Without this, a trailing dot (or the empty name) would yield a zero-length label,
    // i.e. an extra root terminator in the middle of the packet.
    let name = qname.strip_suffix('.').unwrap_or(qname);
    if !name.is_empty() {
        for label in name.split('.') {
            assert!(
                (1..=63).contains(&label.len()),
                "invalid DNS label {label:?} in {qname:?}"
            );
            // Checked above: 1..=63 always fits in a u8 length octet.
            p.push(label.len() as u8);
            p.extend_from_slice(label.as_bytes());
        }
    }
    p.push(0); // root label terminates QNAME
    p.extend_from_slice(&[0x00, 0x01, 0x00, 0x01]); // QTYPE A, QCLASS IN
    p
}
/// Builds a fake NOERROR answer to `query` containing `answers` identical A records.
///
/// The query bytes are copied verbatim, then the flags are overwritten with 0x8180
/// (QR, RD, RA set; RCODE 0) and ANCOUNT is set to `answers`. Each answer points back at
/// the QNAME at offset 12 and resolves to 203.0.113.1 with a 300 s TTL.
///
/// Panics if `query` is shorter than 8 bytes.
fn synth_a_response(query: &[u8], answers: u16) -> Vec<u8> {
    const ANSWER_RR: [u8; 16] = [
        0xc0, 0x0c, // NAME: compression pointer to offset 12
        0x00, 0x01, // TYPE A
        0x00, 0x01, // CLASS IN
        0x00, 0x00, 0x01, 0x2c, // TTL 300
        0x00, 0x04, // RDLENGTH 4
        203, 0, 113, 1, // RDATA 203.0.113.1
    ];

    let mut response = Vec::with_capacity(query.len() + ANSWER_RR.len() * usize::from(answers));
    response.extend_from_slice(query);
    response[2..4].copy_from_slice(&[0x81, 0x80]);
    response[6..8].copy_from_slice(&answers.to_be_bytes());
    (0..answers).for_each(|_| response.extend_from_slice(&ANSWER_RR));
    response
}
/// Starts a throwaway UDP "resolver" on an ephemeral loopback port and returns its address.
///
/// A detached thread answers every datagram with `synth_a_response(query, 1)`. It stops on
/// the first receive error, which includes hitting the 3-second read timeout while idle.
fn spawn_synthetic_upstream() -> SocketAddr {
    let resolver = UdpSocket::bind("127.0.0.1:0").expect("bind upstream");
    let bound = resolver.local_addr().unwrap();
    resolver
        .set_read_timeout(Some(Duration::from_secs(3)))
        .unwrap();

    std::thread::spawn(move || {
        let mut datagram = [0u8; 1500];
        loop {
            match resolver.recv_from(&mut datagram) {
                Ok((len, client)) => {
                    let reply = synth_a_response(&datagram[..len], 1);
                    // Best effort: a failed send only means that client times out.
                    let _ = resolver.send_to(&reply, client);
                }
                Err(_) => break,
            }
        }
    });

    bound
}
/// End-to-end allow/deny check for the UDP DNS proxy.
///
/// Runs `run_one_shot` on a background thread in front of a synthetic upstream and sends
/// four A queries: two allowlisted (exact and wildcard) and two not (wildcard parent and
/// unlisted). It asserts the RCODE on the wire and the CloudEvents captured by
/// `CollectingEmitter`: one aggregate `dns_query` event per query, plus one
/// `query_permitted` or `query_refused` event per query.
#[test]
fn end_to_end_allow_and_deny() {
    // RCODE values from RFC 1035 §4.1.1 that the proxy is expected to return.
    const RCODE_NOERROR: u8 = 0;
    const RCODE_REFUSED: u8 = 5;

    let upstream = spawn_synthetic_upstream();
    let listener = UdpSocket::bind("127.0.0.1:0").expect("bind listener");
    // Short read timeout; presumably this lets run_one_shot poll `shutdown` between reads.
    listener
        .set_read_timeout(Some(Duration::from_millis(150)))
        .unwrap();
    let listen_addr = listener.local_addr().unwrap();
    let upstream_sock = UdpSocket::bind("127.0.0.1:0").unwrap();
    let cfg = DnsProxyConfig {
        bind_addr: listen_addr,
        upstream_addr: upstream,
        hostname_allowlist: vec!["api.example.com".into(), "*.cdn.example.com".into()],
        allowed_query_types: vec![DnsQueryType::A, DnsQueryType::AAAA],
        cell_id: "it-cell-001".into(),
        run_id: "it-run-001".into(),
        policy_digest: Some(
            "sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855".into(),
        ),
        keyset_id: Some("it-keyset".into()),
        issuer_kid: Some("it-kid-0001".into()),
        correlation_id: Some("it-corr-001".into()),
        upstream_resolver_id: "resolver-it-001".into(),
        upstream_timeout: Duration::from_millis(400),
        tcp_idle_timeout: Duration::ZERO,
        dnssec_validator: None,
        transport: cellos_supervisor::dns_proxy::upstream::UpstreamTransport::Do53Udp,
        upstream_extras: cellos_supervisor::dns_proxy::upstream::UpstreamExtras::default(),
    };
    let emitter = Arc::new(CollectingEmitter::default());
    let shutdown = Arc::new(AtomicBool::new(false));
    let proxy_handle = {
        let emitter = Arc::clone(&emitter);
        let shutdown = Arc::clone(&shutdown);
        // `cfg`, `listener` and `upstream_sock` are not used again on this thread, so they
        // are moved into the proxy thread rather than cloned.
        std::thread::spawn(move || {
            // NOTE(review): the proxy's result is discarded; an early failure only shows up
            // indirectly, as client receive timeouts below.
            let _ = run_one_shot(&cfg, &listener, &upstream_sock, &*emitter, &shutdown);
        })
    };

    let client = UdpSocket::bind("127.0.0.1:0").unwrap();
    client
        .set_read_timeout(Some(Duration::from_secs(2)))
        .unwrap();
    // Sends one A query through the proxy and returns the RCODE of its reply. Each call
    // uses a fresh buffer and checks the datagram length. This way a short reply can never
    // be judged by stale header bytes left over from a previous answer.
    let rcode_for = |qname: &str| -> u8 {
        client.send_to(&build_a_query(qname), listen_addr).unwrap();
        let mut buf = [0u8; 1500];
        let (n, _) = client.recv_from(&mut buf).unwrap();
        assert!(
            n >= 12,
            "reply for {qname} is shorter than a DNS header: {n} bytes"
        );
        buf[3] & 0x0f
    };

    assert_eq!(
        rcode_for("api.example.com"),
        RCODE_NOERROR,
        "allow path should return NOERROR"
    );
    assert_eq!(
        rcode_for("img.cdn.example.com"),
        RCODE_NOERROR,
        "wildcard should match a subdomain"
    );
    assert_eq!(
        rcode_for("cdn.example.com"),
        RCODE_REFUSED,
        "wildcard should not match parent domain"
    );
    assert_eq!(
        rcode_for("evil.example.org"),
        RCODE_REFUSED,
        "unlisted name should be refused"
    );

    shutdown.store(true, Ordering::SeqCst);
    proxy_handle.join().unwrap();

    let evs = emitter.events.lock().unwrap();
    assert_eq!(evs.len(), 8, "two events per query observed");

    let aggregates: Vec<&CloudEventV1> = evs
        .iter()
        .filter(|e| e.ty == "dev.cellos.events.cell.observability.v1.dns_query")
        .collect();
    assert_eq!(aggregates.len(), 4, "one aggregate per query");
    let decisions: Vec<&str> = aggregates
        .iter()
        .map(|e| e.data.as_ref().unwrap()["decision"].as_str().unwrap())
        .collect();
    assert_eq!(decisions, vec!["allow", "allow", "deny", "deny"]);
    let reasons: Vec<&str> = aggregates
        .iter()
        .map(|e| e.data.as_ref().unwrap()["reasonCode"].as_str().unwrap())
        .collect();
    assert_eq!(
        reasons,
        vec![
            "allowed_by_allowlist",
            "allowed_by_allowlist",
            "denied_not_in_allowlist",
            "denied_not_in_allowlist"
        ]
    );
    for ev in &aggregates {
        let d = ev.data.as_ref().unwrap();
        assert_eq!(d["cellId"], "it-cell-001");
        assert_eq!(d["runId"], "it-run-001");
        assert_eq!(d["keysetId"], "it-keyset");
        assert_eq!(d["issuerKid"], "it-kid-0001");
        assert_eq!(d["correlationId"], "it-corr-001");
        assert_eq!(d["schemaVersion"], "1.0.0");
        // Only forwarded (allowed) queries name the upstream resolver.
        if d["decision"].as_str().unwrap() == "allow" {
            assert_eq!(d["upstreamResolverId"], "resolver-it-001");
        } else {
            assert!(d.get("upstreamResolverId").is_none());
        }
    }

    let permitted: Vec<&CloudEventV1> = evs
        .iter()
        .filter(|e| e.ty == "dev.cellos.events.cell.dns.v1.query_permitted")
        .collect();
    assert_eq!(permitted.len(), 2, "two allow-path permits");
    for ev in &permitted {
        let d = ev.data.as_ref().unwrap();
        assert_eq!(d["cellId"], "it-cell-001");
        assert_eq!(d["resolver"], "resolver-it-001");
        assert_eq!(d["schemaVersion"], "1.0.0");
        assert_eq!(d["queryType"], "A");
    }

    let refused: Vec<&CloudEventV1> = evs
        .iter()
        .filter(|e| e.ty == "dev.cellos.events.cell.dns.v1.query_refused")
        .collect();
    assert_eq!(refused.len(), 2, "two deny-path refusals");
    for ev in &refused {
        let d = ev.data.as_ref().unwrap();
        assert_eq!(d["cellId"], "it-cell-001");
        assert_eq!(d["reason"], "denied_not_in_allowlist");
        assert_eq!(d["schemaVersion"], "1.0.0");
    }
}
/// HIGH-D1 regression: a datagram carrying two questions (one allowlisted, one not) must be
/// rejected. The client gets no reply, nothing is forwarded upstream, exactly one
/// `malformed_query` deny aggregate is emitted, and no `query_permitted` event is emitted.
#[test]
fn rejects_multi_question_packet_without_forwarding() {
    use std::io::ErrorKind;
    use std::sync::atomic::AtomicU32;

    let upstream_sock = UdpSocket::bind("127.0.0.1:0").expect("bind upstream");
    let upstream_addr = upstream_sock.local_addr().unwrap();
    // The timeout only bounds each poll so the loop below can re-check `upstream_stop`.
    upstream_sock
        .set_read_timeout(Some(Duration::from_millis(50)))
        .unwrap();
    let upstream_seen = Arc::new(AtomicU32::new(0));
    let upstream_stop = Arc::new(AtomicBool::new(false));
    // The fake upstream must keep listening for the proxy's whole lifetime. If it stopped at
    // its first read timeout or at a fixed deadline, a slow run could forward the packet
    // after the upstream had quit. That forward would go uncounted, and the HIGH-D1
    // assertion below would pass vacuously.
    let upstream_thread = {
        let seen = Arc::clone(&upstream_seen);
        let stop = Arc::clone(&upstream_stop);
        std::thread::spawn(move || {
            let mut buf = [0u8; 1500];
            loop {
                // Sample the flag *before* receiving. It is only set after the proxy thread
                // has been joined, so anything the proxy sent should already be queued on
                // this loopback socket and is counted before the loop exits.
                let stopping = stop.load(Ordering::SeqCst);
                match upstream_sock.recv_from(&mut buf) {
                    Ok((n, peer)) => {
                        seen.fetch_add(1, Ordering::SeqCst);
                        let resp = synth_a_response(&buf[..n], 1);
                        let _ = upstream_sock.send_to(&resp, peer);
                    }
                    Err(e)
                        if matches!(e.kind(), ErrorKind::WouldBlock | ErrorKind::TimedOut) =>
                    {
                        if stopping {
                            break;
                        }
                    }
                    Err(e) => panic!("synthetic upstream recv failed: {e}"),
                }
            }
        })
    };

    let listener = UdpSocket::bind("127.0.0.1:0").expect("bind listener");
    listener
        .set_read_timeout(Some(Duration::from_millis(150)))
        .unwrap();
    let listen_addr = listener.local_addr().unwrap();
    let proxy_upstream_sock = UdpSocket::bind("127.0.0.1:0").unwrap();
    let cfg = DnsProxyConfig {
        bind_addr: listen_addr,
        upstream_addr,
        hostname_allowlist: vec!["allowed.example.com".into()],
        allowed_query_types: vec![DnsQueryType::A, DnsQueryType::AAAA],
        cell_id: "it-cell-d1".into(),
        run_id: "it-run-d1".into(),
        policy_digest: None,
        keyset_id: None,
        issuer_kid: None,
        correlation_id: None,
        upstream_resolver_id: "resolver-it-d1".into(),
        upstream_timeout: Duration::from_millis(400),
        tcp_idle_timeout: Duration::ZERO,
        dnssec_validator: None,
        transport: cellos_supervisor::dns_proxy::upstream::UpstreamTransport::Do53Udp,
        upstream_extras: cellos_supervisor::dns_proxy::upstream::UpstreamExtras::default(),
    };
    let emitter = Arc::new(CollectingEmitter::default());
    let shutdown = Arc::new(AtomicBool::new(false));
    let proxy_handle = {
        let emitter = Arc::clone(&emitter);
        let shutdown = Arc::clone(&shutdown);
        // `cfg` is not used again on this thread, so it is moved rather than cloned.
        std::thread::spawn(move || {
            let _ = run_one_shot(&cfg, &listener, &proxy_upstream_sock, &*emitter, &shutdown);
        })
    };

    // Attack datagram: the allowlisted question first, then a smuggled second question for
    // a non-allowlisted name. It is spliced from two single-question encodings: rewrite the
    // header to ID 0x1234 and QDCOUNT 2, then append the second query's question section
    // (everything after its 12-byte header).
    let mut atk = build_a_query("allowed.example.com");
    atk[0..2].copy_from_slice(&0x1234u16.to_be_bytes());
    atk[4..6].copy_from_slice(&2u16.to_be_bytes());
    atk.extend_from_slice(&build_a_query("attacker.tld")[12..]);

    let client = UdpSocket::bind("127.0.0.1:0").unwrap();
    client
        .set_read_timeout(Some(Duration::from_millis(400)))
        .unwrap();
    client.send_to(&atk, listen_addr).unwrap();
    let mut buf = [0u8; 1500];
    match client.recv_from(&mut buf) {
        Ok((n, _)) => {
            panic!("proxy must NOT respond to multi-question packets, got {n} bytes back")
        }
        Err(e) => {
            // The timeout surfaces as WouldBlock on Unix and TimedOut on Windows.
            assert!(
                matches!(e.kind(), ErrorKind::WouldBlock | ErrorKind::TimedOut),
                "expected recv timeout, got {e:?}"
            );
        }
    }

    shutdown.store(true, Ordering::SeqCst);
    proxy_handle.join().unwrap();
    // Stop the upstream only after the proxy has exited, so any forward it made is observed.
    upstream_stop.store(true, Ordering::SeqCst);
    upstream_thread
        .join()
        .expect("synthetic upstream thread panicked");
    assert_eq!(
        upstream_seen.load(Ordering::SeqCst),
        0,
        "HIGH-D1: multi-question packet must not be forwarded upstream"
    );

    let evs = emitter.events.lock().unwrap();
    let aggregates: Vec<&CloudEventV1> = evs
        .iter()
        .filter(|e| e.ty == "dev.cellos.events.cell.observability.v1.dns_query")
        .collect();
    assert_eq!(aggregates.len(), 1, "one malformed_query event expected");
    let d = aggregates[0].data.as_ref().unwrap();
    assert_eq!(d["decision"], "deny");
    assert_eq!(d["reasonCode"], "malformed_query");
    let permitted = evs
        .iter()
        .filter(|e| e.ty == "dev.cellos.events.cell.dns.v1.query_permitted")
        .count();
    assert_eq!(
        permitted, 0,
        "HIGH-D1: no per-question permit event must be emitted for a rejected multi-question packet"
    );
}
/// Pins `qtype_to_dns_query_type`. Each supported numeric QTYPE maps to its
/// `DnsQueryType` variant, and unsupported codes map to `None`.
///
/// Despite the name, only the `u16 -> DnsQueryType` direction is exercised here.
#[test]
fn qtype_mapping_helper_round_trip() {
    use cellos_core::qtype_to_dns_query_type;

    let supported: [(u16, DnsQueryType); 10] = [
        (1, DnsQueryType::A),
        (2, DnsQueryType::NS),
        (5, DnsQueryType::CNAME),
        (12, DnsQueryType::PTR),
        (15, DnsQueryType::MX),
        (16, DnsQueryType::TXT),
        (28, DnsQueryType::AAAA),
        (33, DnsQueryType::SRV),
        (64, DnsQueryType::SVCB),
        (65, DnsQueryType::HTTPS),
    ];
    let unsupported: [u16; 4] = [0, 3, 255, 10000];

    for (qt, variant) in supported {
        assert_eq!(
            qtype_to_dns_query_type(qt),
            Some(variant),
            "qtype mapping for {qt}"
        );
    }
    for qt in unsupported {
        assert_eq!(qtype_to_dns_query_type(qt), None, "qtype mapping for {qt}");
    }
}