pub mod dnssec;
pub mod parser;
pub mod spawn;
pub mod upstream;
use std::io;
use std::io::{Read, Write};
use std::net::{SocketAddr, TcpListener, TcpStream, UdpSocket};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use cellos_core::{
cloud_event_v1_dns_authority_dnssec_failed, cloud_event_v1_dns_query_permitted,
cloud_event_v1_dns_query_refused, qtype_to_dns_query_type, CloudEventV1,
DnsAuthorityDnssecFailed, DnsAuthorityDnssecFailureReason, DnsQueryDecision, DnsQueryEvent,
DnsQueryReasonCode, DnsQueryType,
};
use dnssec::{DataplaneDnssecOutcome, DataplaneDnssecValidator};
use parser::{parse_query, DnsParseError, DnsQueryView, DNS_HEADER_LEN};
use upstream::{UpstreamExtras, UpstreamTransport};
/// Query types accepted when `DnsProxyConfig::allowed_query_types` is empty.
const DEFAULT_QUERY_TYPES: &[DnsQueryType] = &[
DnsQueryType::A,
DnsQueryType::AAAA,
DnsQueryType::CNAME,
DnsQueryType::HTTPS,
];
/// Size in bytes of the fixed buffers that hold a received UDP query and an
/// upstream response (the TCP path uses it for the upstream response buffer).
const MAX_UDP_PAYLOAD: usize = 1500;
/// Read/write timeout applied to accepted TCP connections when
/// `DnsProxyConfig::tcp_idle_timeout` is zero.
const DEFAULT_TCP_IDLE_TIMEOUT: Duration = Duration::from_secs(30);
/// Configuration shared by the UDP (`run_one_shot`) and TCP
/// (`run_tcp_one_shot`) proxy loops.
#[derive(Debug, Clone)]
pub struct DnsProxyConfig {
/// Not read by the proxy loops in this file — presumably the address the
/// caller binds the listening socket to (TODO confirm against the caller).
pub bind_addr: SocketAddr,
/// Upstream resolver address handed to `upstream::forward`.
pub upstream_addr: SocketAddr,
/// Allowlist entries checked with
/// `cellos_core::hostname_allowlist::matches_allowlist`; queries whose name
/// does not match are answered with REFUSED.
pub hostname_allowlist: Vec<String>,
/// Query types that may be forwarded. An empty list selects
/// `DEFAULT_QUERY_TYPES`.
pub allowed_query_types: Vec<DnsQueryType>,
/// Cell identifier copied into every emitted event.
pub cell_id: String,
/// Run identifier copied into the query and DNSSEC-failure event payloads.
pub run_id: String,
/// Optional value copied verbatim into event payloads.
pub policy_digest: Option<String>,
/// Optional value copied verbatim into event payloads.
pub keyset_id: Option<String>,
/// Optional value copied verbatim into event payloads.
pub issuer_kid: Option<String>,
/// Optional value copied verbatim into event payloads.
pub correlation_id: Option<String>,
/// Identifier of the upstream resolver, recorded in events for queries that
/// reached the forwarding stage.
pub upstream_resolver_id: String,
/// Timeout handed to `upstream::forward` for each forwarded query.
pub upstream_timeout: Duration,
/// Read/write timeout for accepted TCP connections; zero selects
/// `DEFAULT_TCP_IDLE_TIMEOUT`.
pub tcp_idle_timeout: Duration,
/// When set, every upstream response is passed through this validator
/// before being relayed to the client.
pub dnssec_validator: Option<Arc<DataplaneDnssecValidator>>,
/// Upstream transport selector handed to `upstream::forward`.
pub transport: UpstreamTransport,
/// Additional transport settings handed to `upstream::forward`.
pub upstream_extras: UpstreamExtras,
}
/// Sink for the CloudEvents produced by the proxy.
///
/// Implementations must be `Send + Sync` because the TCP path shares a single
/// emitter across its per-connection worker threads.
pub trait DnsQueryEmitter: Send + Sync {
/// Receives one event; the proxy calls this once per event it builds.
fn emit(&self, event: CloudEventV1);
}
/// Counters accumulated by a proxy loop and returned when it stops.
/// All increments use saturating arithmetic.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct DnsProxyStats {
/// UDP datagrams received, or TCP framed messages fully read.
pub queries_total: u64,
/// Queries whose upstream response was relayed to the client.
pub queries_allowed: u64,
/// Queries refused by type/allowlist policy or failed closed by DNSSEC.
pub queries_denied: u64,
/// Messages that could not be parsed as a DNS query (the TCP path also
/// counts zero-length frames here).
pub queries_malformed: u64,
/// Forwarded queries for which `upstream::forward` returned an error.
pub upstream_failures: u64,
}
/// Runs the UDP proxy loop on `socket` until `shutdown` becomes `true`.
///
/// For each datagram the loop parses the query, enforces the query-type and
/// hostname allowlist policy, forwards permitted queries through
/// `upstream::forward`, optionally validates the response with the configured
/// DNSSEC validator, and reports every decision through `emitter`.
///
/// * Unparseable datagrams are counted and reported, but get no reply.
/// * Policy denials are answered with rcode 5 (REFUSED).
/// * DNSSEC fail-closed decisions and upstream errors are answered with
///   rcode 2 (SERVFAIL).
///
/// Replies to the client are best-effort: `send_to` errors are ignored.
///
/// `shutdown` is only re-checked after `recv_from` returns, so the caller is
/// presumably expected to set a read timeout on `socket` (the tests in this
/// file do).
///
/// # Errors
///
/// Returns the first `recv_from` error that is neither a timeout
/// (`WouldBlock`/`TimedOut`) nor `Interrupted`.
pub fn run_one_shot(
cfg: &DnsProxyConfig,
socket: &UdpSocket,
upstream: &UdpSocket,
emitter: &dyn DnsQueryEmitter,
shutdown: &AtomicBool,
) -> io::Result<DnsProxyStats> {
let mut stats = DnsProxyStats::default();
let mut recv_buf = [0u8; MAX_UDP_PAYLOAD];
let mut up_buf = [0u8; MAX_UDP_PAYLOAD];
while !shutdown.load(Ordering::SeqCst) {
// Timeouts and EINTR just loop back to re-check the shutdown flag.
let (n, peer) = match socket.recv_from(&mut recv_buf) {
Ok(t) => t,
Err(e) if is_timeout(&e) => continue,
Err(e) if matches!(e.kind(), io::ErrorKind::Interrupted) => continue,
Err(e) => return Err(e),
};
stats.queries_total = stats.queries_total.saturating_add(1);
let pkt = &recv_buf[..n];
match parse_query(pkt) {
// Malformed input: record and report it, but send nothing back.
Err(parse_err) => {
stats.queries_malformed = stats.queries_malformed.saturating_add(1);
let event = build_event(
cfg,
EventInputs {
view: None,
decision: DnsQueryDecision::Deny,
reason_code: malformed_reason(parse_err),
response_rcode: None,
upstream_resolver_id: None,
upstream_latency_ms: None,
response_target_count: None,
},
);
emit_event(emitter, event);
continue;
}
Ok(view) => {
// Policy gate 1: the query type must be recognised and in the allowed set.
let qtype_known = qtype_to_dns_query_type(view.qtype);
let allowed_types = if cfg.allowed_query_types.is_empty() {
DEFAULT_QUERY_TYPES
} else {
cfg.allowed_query_types.as_slice()
};
let qtype_in_set = qtype_known.is_some_and(|t| allowed_types.contains(&t));
if !qtype_in_set {
stats.queries_denied = stats.queries_denied.saturating_add(1);
let resp = build_refused_response(pkt, &view);
let _ = socket.send_to(&resp, peer);
let event = build_event(
cfg,
EventInputs {
view: Some(&view),
decision: DnsQueryDecision::Deny,
reason_code: DnsQueryReasonCode::DeniedQueryType,
response_rcode: Some(5),
upstream_resolver_id: None,
upstream_latency_ms: None,
response_target_count: Some(0),
},
);
emit_event(emitter, event);
emit_query_refused(cfg, emitter, &view, "denied_query_type");
continue;
}
// Policy gate 2: the query name must match the hostname allowlist.
if !hostname_in_allowlist(&view.qname, &cfg.hostname_allowlist) {
stats.queries_denied = stats.queries_denied.saturating_add(1);
let resp = build_refused_response(pkt, &view);
let _ = socket.send_to(&resp, peer);
let event = build_event(
cfg,
EventInputs {
view: Some(&view),
decision: DnsQueryDecision::Deny,
reason_code: DnsQueryReasonCode::DeniedNotInAllowlist,
response_rcode: Some(5),
upstream_resolver_id: None,
upstream_latency_ms: None,
response_target_count: Some(0),
},
);
emit_event(emitter, event);
emit_query_refused(cfg, emitter, &view, "denied_not_in_allowlist");
continue;
}
// Permitted: announce it before forwarding, then time the upstream round trip.
emit_query_permitted(cfg, emitter, &view);
let started = Instant::now();
let upstream_result = upstream::forward(
cfg.transport,
upstream,
cfg.upstream_addr,
pkt,
&mut up_buf,
cfg.upstream_timeout,
&cfg.upstream_extras,
);
let elapsed_ms = started.elapsed().as_millis() as u64;
match upstream_result {
Ok(resp_len) => {
// Optional DNSSEC gate: only a Servfail action stops the relay.
if let Some(validator) = cfg.dnssec_validator.as_ref() {
let outcome = validator.validate(pkt, &up_buf[..resp_len]);
let action =
decide_dnssec_action(validator.is_require_mode(), &outcome);
match action {
DnssecAction::Forward => {
}
DnssecAction::Servfail { reason } => {
// Fail closed: SERVFAIL to the client, counted as a denial,
// reported as both a query event and a DNSSEC-failure event.
let resp = build_servfail_response(pkt, &view);
let _ = socket.send_to(&resp, peer);
stats.queries_denied = stats.queries_denied.saturating_add(1);
let q_event = build_event(
cfg,
EventInputs {
view: Some(&view),
decision: DnsQueryDecision::Deny,
reason_code: DnsQueryReasonCode::DeniedDnssec,
response_rcode: Some(2),
upstream_resolver_id: Some(
cfg.upstream_resolver_id.clone(),
),
upstream_latency_ms: Some(elapsed_ms),
response_target_count: Some(0),
},
);
emit_event(emitter, q_event);
let dnssec_event = build_dataplane_dnssec_failed_event(
cfg, &view, validator, reason,
);
emit_event(emitter, dnssec_event);
continue;
}
DnssecAction::ForwardUnsignedBestEffort => {
}
}
}
// Relay the upstream response unchanged and report the allow decision.
let resp = &up_buf[..resp_len];
let _ = socket.send_to(resp, peer);
stats.queries_allowed = stats.queries_allowed.saturating_add(1);
let answer_count = parse_response_target_count(resp, view.qtype);
let event = build_event(
cfg,
EventInputs {
view: Some(&view),
decision: DnsQueryDecision::Allow,
reason_code: DnsQueryReasonCode::AllowedByAllowlist,
response_rcode: Some(parse_response_rcode(resp)),
upstream_resolver_id: Some(cfg.upstream_resolver_id.clone()),
upstream_latency_ms: Some(elapsed_ms),
response_target_count: Some(answer_count),
},
);
emit_event(emitter, event);
}
// Upstream error (including timeout): synthesize SERVFAIL for the client.
Err(_e) => {
stats.upstream_failures = stats.upstream_failures.saturating_add(1);
let resp = build_servfail_response(pkt, &view);
let _ = socket.send_to(&resp, peer);
let event = build_event(
cfg,
EventInputs {
view: Some(&view),
decision: DnsQueryDecision::Deny,
reason_code: DnsQueryReasonCode::UpstreamFailure,
response_rcode: Some(2),
upstream_resolver_id: Some(cfg.upstream_resolver_id.clone()),
upstream_latency_ms: Some(elapsed_ms),
response_target_count: Some(0),
},
);
emit_event(emitter, event);
}
}
}
}
}
Ok(stats)
}
/// Runs the DNS-over-TCP accept loop on `listener` until `shutdown` becomes
/// `true`, serving each accepted connection on its own thread.
///
/// The listener is switched to non-blocking mode for the duration of the loop
/// (polled every 50 ms) and switched back before returning. Accepted streams
/// are put back into blocking mode with `cfg.tcp_idle_timeout` (or
/// `DEFAULT_TCP_IDLE_TIMEOUT` when that is zero) as read and write timeout.
///
/// After shutdown is observed, all still-running connection threads are
/// joined and the aggregated counters are returned.
///
/// # Errors
///
/// Returns an error if the listener cannot be made non-blocking, or on the
/// first `accept` error other than `WouldBlock`/`Interrupted`. In the latter
/// case connection threads that are still running are left detached.
pub fn run_tcp_one_shot(
    cfg: &DnsProxyConfig,
    listener: &TcpListener,
    upstream: Arc<UdpSocket>,
    emitter: Arc<dyn DnsQueryEmitter>,
    shutdown: &AtomicBool,
) -> io::Result<DnsProxyStats> {
    listener.set_nonblocking(true)?;
    let stats = Arc::new(Mutex::new(DnsProxyStats::default()));
    let mut workers: Vec<std::thread::JoinHandle<()>> = Vec::new();
    let tcp_idle_timeout = if cfg.tcp_idle_timeout.is_zero() {
        DEFAULT_TCP_IDLE_TIMEOUT
    } else {
        cfg.tcp_idle_timeout
    };
    while !shutdown.load(Ordering::SeqCst) {
        // Drop handles of connection threads that have already exited so the
        // vector does not grow without bound on a long-lived listener. Their
        // result was ignored by the final join anyway.
        workers.retain(|worker| !worker.is_finished());
        match listener.accept() {
            Ok((stream, _peer)) => {
                // The accepted socket may inherit non-blocking mode; the
                // per-connection handler relies on blocking reads with a
                // timeout, so a stream that cannot be switched is dropped.
                if let Err(_e) = stream.set_nonblocking(false) {
                    continue;
                }
                let _ = stream.set_read_timeout(Some(tcp_idle_timeout));
                let _ = stream.set_write_timeout(Some(tcp_idle_timeout));
                let cfg_owned = cfg.clone();
                let upstream = upstream.clone();
                let emitter = emitter.clone();
                let stats = stats.clone();
                let handle = std::thread::spawn(move || {
                    handle_tcp_connection(&cfg_owned, stream, &upstream, &*emitter, &stats);
                });
                workers.push(handle);
            }
            Err(e) if matches!(e.kind(), io::ErrorKind::WouldBlock) => {
                // Nothing pending: back off briefly, then re-check shutdown.
                std::thread::sleep(Duration::from_millis(50));
                continue;
            }
            Err(e) if matches!(e.kind(), io::ErrorKind::Interrupted) => continue,
            Err(e) => {
                let _ = listener.set_nonblocking(false);
                return Err(e);
            }
        }
    }
    for h in workers {
        let _ = h.join();
    }
    let _ = listener.set_nonblocking(false);
    // The counters are plain integers, so a poisoned lock still holds usable
    // data; recover it instead of panicking (the `bump_*` helpers likewise
    // tolerate poisoning).
    let final_stats = *stats
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner());
    Ok(final_stats)
}
/// Serves one accepted DNS-over-TCP connection until the peer closes it, a
/// read/write fails (including the timeouts configured by the caller), or a
/// malformed message is received.
///
/// Each message is framed by a 2-byte big-endian length prefix. Per message
/// the same policy as the UDP path is applied: query-type gate, hostname
/// allowlist gate, upstream forwarding, optional DNSSEC validation. Counters
/// are accumulated into the shared `stats`.
///
/// Differences from the UDP path that are visible here:
/// * a zero-length frame or an unparseable message ends the connection;
/// * when writing a reply fails the function returns immediately, so the
///   query event for that message is not emitted.
fn handle_tcp_connection(
cfg: &DnsProxyConfig,
mut stream: TcpStream,
upstream: &UdpSocket,
emitter: &dyn DnsQueryEmitter,
stats: &Mutex<DnsProxyStats>,
) {
let mut up_buf = [0u8; MAX_UDP_PAYLOAD];
loop {
// Read the 2-byte length prefix; EOF, timeout or any other error ends the connection.
let mut len_buf = [0u8; 2];
if stream.read_exact(&mut len_buf).is_err() {
return;
}
let msg_len = u16::from_be_bytes(len_buf) as usize;
if msg_len == 0 {
// NOTE(review): this bumps `queries_malformed` without bumping
// `queries_total`, unlike the parse-error path below — confirm intended.
bump_malformed(stats);
let event = build_event(
cfg,
EventInputs {
view: None,
decision: DnsQueryDecision::Deny,
reason_code: DnsQueryReasonCode::MalformedQuery,
response_rcode: None,
upstream_resolver_id: None,
upstream_latency_ms: None,
response_target_count: None,
},
);
emit_event(emitter, event);
return;
}
let mut pkt = vec![0u8; msg_len];
if stream.read_exact(&mut pkt).is_err() {
return;
}
bump_total(stats);
match parse_query(&pkt) {
// Malformed message: report it and drop the connection without replying.
Err(parse_err) => {
bump_malformed(stats);
let event = build_event(
cfg,
EventInputs {
view: None,
decision: DnsQueryDecision::Deny,
reason_code: malformed_reason(parse_err),
response_rcode: None,
upstream_resolver_id: None,
upstream_latency_ms: None,
response_target_count: None,
},
);
emit_event(emitter, event);
return;
}
Ok(view) => {
// Policy gate 1: the query type must be recognised and in the allowed set.
let qtype_known = qtype_to_dns_query_type(view.qtype);
let allowed_types = if cfg.allowed_query_types.is_empty() {
DEFAULT_QUERY_TYPES
} else {
cfg.allowed_query_types.as_slice()
};
let qtype_in_set = qtype_known.is_some_and(|t| allowed_types.contains(&t));
if !qtype_in_set {
bump_denied(stats);
let resp = build_refused_response(&pkt, &view);
if write_framed(&mut stream, &resp).is_err() {
return;
}
let event = build_event(
cfg,
EventInputs {
view: Some(&view),
decision: DnsQueryDecision::Deny,
reason_code: DnsQueryReasonCode::DeniedQueryType,
response_rcode: Some(5),
upstream_resolver_id: None,
upstream_latency_ms: None,
response_target_count: Some(0),
},
);
emit_event(emitter, event);
emit_query_refused(cfg, emitter, &view, "denied_query_type");
continue;
}
// Policy gate 2: the query name must match the hostname allowlist.
if !hostname_in_allowlist(&view.qname, &cfg.hostname_allowlist) {
bump_denied(stats);
let resp = build_refused_response(&pkt, &view);
if write_framed(&mut stream, &resp).is_err() {
return;
}
let event = build_event(
cfg,
EventInputs {
view: Some(&view),
decision: DnsQueryDecision::Deny,
reason_code: DnsQueryReasonCode::DeniedNotInAllowlist,
response_rcode: Some(5),
upstream_resolver_id: None,
upstream_latency_ms: None,
response_target_count: Some(0),
},
);
emit_event(emitter, event);
emit_query_refused(cfg, emitter, &view, "denied_not_in_allowlist");
continue;
}
// Permitted: announce it, then forward and time the upstream round trip.
// NOTE(review): the upstream response buffer is MAX_UDP_PAYLOAD bytes even
// on this TCP path; how larger responses are handled depends on
// `upstream::forward`, which is not visible here.
emit_query_permitted(cfg, emitter, &view);
let started = Instant::now();
let upstream_result = upstream::forward(
cfg.transport,
upstream,
cfg.upstream_addr,
&pkt,
&mut up_buf,
cfg.upstream_timeout,
&cfg.upstream_extras,
);
let elapsed_ms = started.elapsed().as_millis() as u64;
match upstream_result {
Ok(resp_len) => {
// Optional DNSSEC gate: only a Servfail action stops the relay.
if let Some(validator) = cfg.dnssec_validator.as_ref() {
let outcome = validator.validate(&pkt, &up_buf[..resp_len]);
let action =
decide_dnssec_action(validator.is_require_mode(), &outcome);
match action {
DnssecAction::Forward | DnssecAction::ForwardUnsignedBestEffort => {
}
DnssecAction::Servfail { reason } => {
// Fail closed: SERVFAIL to the client, counted as a denial,
// reported as both a query event and a DNSSEC-failure event.
let resp = build_servfail_response(&pkt, &view);
if write_framed(&mut stream, &resp).is_err() {
return;
}
bump_denied(stats);
let q_event = build_event(
cfg,
EventInputs {
view: Some(&view),
decision: DnsQueryDecision::Deny,
reason_code: DnsQueryReasonCode::DeniedDnssec,
response_rcode: Some(2),
upstream_resolver_id: Some(
cfg.upstream_resolver_id.clone(),
),
upstream_latency_ms: Some(elapsed_ms),
response_target_count: Some(0),
},
);
emit_event(emitter, q_event);
let dnssec_event = build_dataplane_dnssec_failed_event(
cfg, &view, validator, reason,
);
emit_event(emitter, dnssec_event);
continue;
}
}
}
// Relay the upstream response unchanged and report the allow decision.
let resp = &up_buf[..resp_len];
if write_framed(&mut stream, resp).is_err() {
return;
}
bump_allowed(stats);
let answer_count = parse_response_target_count(resp, view.qtype);
let event = build_event(
cfg,
EventInputs {
view: Some(&view),
decision: DnsQueryDecision::Allow,
reason_code: DnsQueryReasonCode::AllowedByAllowlist,
response_rcode: Some(parse_response_rcode(resp)),
upstream_resolver_id: Some(cfg.upstream_resolver_id.clone()),
upstream_latency_ms: Some(elapsed_ms),
response_target_count: Some(answer_count),
},
);
emit_event(emitter, event);
}
// Upstream error (including timeout): synthesize SERVFAIL for the client.
Err(_e) => {
bump_upstream_failure(stats);
let resp = build_servfail_response(&pkt, &view);
if write_framed(&mut stream, &resp).is_err() {
return;
}
let event = build_event(
cfg,
EventInputs {
view: Some(&view),
decision: DnsQueryDecision::Deny,
reason_code: DnsQueryReasonCode::UpstreamFailure,
response_rcode: Some(2),
upstream_resolver_id: Some(cfg.upstream_resolver_id.clone()),
upstream_latency_ms: Some(elapsed_ms),
response_target_count: Some(0),
},
);
emit_event(emitter, event);
}
}
}
}
}
}
/// Writes `msg` to `stream` as one DNS-over-TCP frame: a 2-byte big-endian
/// length prefix followed by the message bytes.
///
/// The prefix and the message are assembled into a single buffer and handed
/// to the socket in one `write_all`, as RFC 7766 §8 recommends, so the length
/// field is not sent as a separate tiny segment.
///
/// # Errors
///
/// Returns `InvalidData` if `msg` is longer than 65535 bytes, or the
/// underlying I/O error if writing or flushing fails.
fn write_framed(stream: &mut TcpStream, msg: &[u8]) -> io::Result<()> {
    let len = u16::try_from(msg.len()).map_err(|_| {
        io::Error::new(
            io::ErrorKind::InvalidData,
            "DNS message exceeds 65535-byte TCP frame limit",
        )
    })?;
    let mut frame = Vec::with_capacity(2 + msg.len());
    frame.extend_from_slice(&len.to_be_bytes());
    frame.extend_from_slice(msg);
    stream.write_all(&frame)?;
    stream.flush()
}
/// Adds one (saturating) to `queries_total`; does nothing if the mutex is poisoned.
fn bump_total(stats: &Mutex<DnsProxyStats>) {
    let Ok(mut guard) = stats.lock() else {
        return;
    };
    guard.queries_total = guard.queries_total.saturating_add(1);
}
/// Adds one (saturating) to `queries_allowed`; does nothing if the mutex is poisoned.
fn bump_allowed(stats: &Mutex<DnsProxyStats>) {
    let Ok(mut guard) = stats.lock() else {
        return;
    };
    guard.queries_allowed = guard.queries_allowed.saturating_add(1);
}
/// Adds one (saturating) to `queries_denied`; does nothing if the mutex is poisoned.
fn bump_denied(stats: &Mutex<DnsProxyStats>) {
    let Ok(mut guard) = stats.lock() else {
        return;
    };
    guard.queries_denied = guard.queries_denied.saturating_add(1);
}
/// Adds one (saturating) to `queries_malformed`; does nothing if the mutex is poisoned.
fn bump_malformed(stats: &Mutex<DnsProxyStats>) {
    let Ok(mut guard) = stats.lock() else {
        return;
    };
    guard.queries_malformed = guard.queries_malformed.saturating_add(1);
}
/// Adds one (saturating) to `upstream_failures`; does nothing if the mutex is poisoned.
fn bump_upstream_failure(stats: &Mutex<DnsProxyStats>) {
    let Ok(mut guard) = stats.lock() else {
        return;
    };
    guard.upstream_failures = guard.upstream_failures.saturating_add(1);
}
/// Maps a DNSSEC validation outcome to what the proxy does with the response.
///
/// * `Validated` is always forwarded.
/// * `Failed` always yields SERVFAIL with the validator's reason.
/// * `Unsigned` and `Skip` yield SERVFAIL only when `require_mode` is set;
///   otherwise they are forwarded (`Unsigned` as best-effort).
fn decide_dnssec_action(require_mode: bool, outcome: &DataplaneDnssecOutcome) -> DnssecAction {
    match outcome {
        DataplaneDnssecOutcome::Validated => DnssecAction::Forward,
        DataplaneDnssecOutcome::Failed { reason } => DnssecAction::Servfail { reason },
        DataplaneDnssecOutcome::Unsigned if require_mode => DnssecAction::Servfail {
            reason: "unsigned_in_require_mode",
        },
        DataplaneDnssecOutcome::Unsigned => DnssecAction::ForwardUnsignedBestEffort,
        DataplaneDnssecOutcome::Skip if require_mode => DnssecAction::Servfail {
            reason: "unsupported_query_type_in_require_mode",
        },
        DataplaneDnssecOutcome::Skip => DnssecAction::Forward,
    }
}
/// What the proxy does with an upstream response after DNSSEC validation,
/// as decided by `decide_dnssec_action`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum DnssecAction {
/// Relay the upstream response to the client.
Forward,
/// Answer the client with SERVFAIL; `reason` is reported in the
/// DNSSEC-failure event.
Servfail { reason: &'static str },
/// Relay an unsigned response (validator not in require mode). Both proxy
/// loops currently treat this the same as `Forward`.
ForwardUnsignedBestEffort,
}
/// Builds the DNSSEC-failure CloudEvent emitted when the dataplane validator
/// causes a SERVFAIL.
///
/// The two require-mode reasons are reported verbatim; every other `reason`
/// is collapsed to `DnsAuthorityDnssecFailureReason::ValidationFailed`.
///
/// # Panics
///
/// Panics if the payload cannot be serialized into a CloudEvent.
fn build_dataplane_dnssec_failed_event(
    cfg: &DnsProxyConfig,
    view: &DnsQueryView,
    validator: &DataplaneDnssecValidator,
    reason: &'static str,
) -> CloudEventV1 {
    let timestamp = chrono::Utc::now().to_rfc3339_opts(chrono::SecondsFormat::Millis, true);
    let is_require_mode_reason = matches!(
        reason,
        "unsigned_in_require_mode" | "unsupported_query_type_in_require_mode"
    );
    let reported_reason = if is_require_mode_reason {
        reason
    } else {
        DnsAuthorityDnssecFailureReason::ValidationFailed.as_str()
    };
    let failure = DnsAuthorityDnssecFailed {
        schema_version: "1.0.0".into(),
        cell_id: cfg.cell_id.clone(),
        run_id: cfg.run_id.clone(),
        resolver_id: cfg.upstream_resolver_id.clone(),
        hostname: view.qname.clone(),
        reason: reported_reason.into(),
        fail_closed: validator.is_require_mode(),
        trust_anchor_source: validator.trust_anchor_source().to_string(),
        policy_digest: cfg.policy_digest.clone(),
        keyset_id: cfg.keyset_id.clone(),
        issuer_kid: cfg.issuer_kid.clone(),
        correlation_id: cfg.correlation_id.clone(),
        source: Some("dataplane".into()),
        observed_at: timestamp.clone(),
    };
    cloud_event_v1_dns_authority_dnssec_failed("cellos-dns-proxy", &timestamp, &failure)
        .expect("DnsAuthorityDnssecFailed serializes to JSON")
}
/// Per-query values that `build_event` combines with the static
/// `DnsProxyConfig` fields to produce a query event.
struct EventInputs<'a> {
/// Parsed query, or `None` when the message could not be parsed.
view: Option<&'a DnsQueryView>,
/// Allow/deny decision taken for the query.
decision: DnsQueryDecision,
/// Reason backing the decision.
reason_code: DnsQueryReasonCode,
/// Rcode sent to the client, or `None` when no reply was sent.
response_rcode: Option<u8>,
/// Set only when the query was forwarded to the upstream resolver.
upstream_resolver_id: Option<String>,
/// Time spent in `upstream::forward`, in milliseconds, when forwarded.
upstream_latency_ms: Option<u64>,
/// Number of answer records counted by `parse_response_target_count`
/// (0 for synthesized error replies, `None` when no reply was sent).
response_target_count: Option<u32>,
}
/// Builds the per-query CloudEvent from the static configuration and the
/// per-query `inputs`.
///
/// When there is no parsed view, or its name is empty, the query name is
/// reported as `"(unknown)"`. A missing or unrecognised query type is
/// reported as `A`. Each event gets a fresh random `query_id`.
///
/// # Panics
///
/// Panics if the payload cannot be serialized into a CloudEvent.
fn build_event(cfg: &DnsProxyConfig, inputs: EventInputs<'_>) -> CloudEventV1 {
    let timestamp = chrono::Utc::now().to_rfc3339_opts(chrono::SecondsFormat::Millis, true);
    let query_name = inputs
        .view
        .map(|v| v.qname.clone())
        .filter(|name| !name.is_empty())
        .unwrap_or_else(|| "(unknown)".into());
    let query_type = inputs
        .view
        .and_then(|v| qtype_to_dns_query_type(v.qtype))
        .unwrap_or(DnsQueryType::A);
    let record = DnsQueryEvent {
        schema_version: "1.0.0".into(),
        cell_id: cfg.cell_id.clone(),
        run_id: cfg.run_id.clone(),
        query_id: uuid::Uuid::new_v4().to_string(),
        query_name,
        query_type,
        decision: inputs.decision,
        reason_code: inputs.reason_code,
        response_rcode: inputs.response_rcode,
        upstream_resolver_id: inputs.upstream_resolver_id,
        upstream_latency_ms: inputs.upstream_latency_ms,
        response_target_count: inputs.response_target_count,
        keyset_id: cfg.keyset_id.clone(),
        issuer_kid: cfg.issuer_kid.clone(),
        policy_digest: cfg.policy_digest.clone(),
        correlation_id: cfg.correlation_id.clone(),
        observed_at: timestamp.clone(),
    };
    cellos_core::cloud_event_v1_dns_query("cellos-dns-proxy", &timestamp, &record)
        .expect("DnsQueryEvent serializes to JSON")
}
fn emit_event(emitter: &dyn DnsQueryEmitter, event: CloudEventV1) {
emitter.emit(event);
}
/// Returns the upper-case mnemonic used for `t` in the short-form
/// permitted/refused events.
fn dns_query_type_str(t: DnsQueryType) -> &'static str {
    match t {
        // Address and alias records.
        DnsQueryType::A => "A",
        DnsQueryType::AAAA => "AAAA",
        DnsQueryType::CNAME => "CNAME",
        // Service-binding records.
        DnsQueryType::HTTPS => "HTTPS",
        DnsQueryType::SVCB => "SVCB",
        DnsQueryType::SRV => "SRV",
        // Everything else.
        DnsQueryType::MX => "MX",
        DnsQueryType::NS => "NS",
        DnsQueryType::PTR => "PTR",
        DnsQueryType::TXT => "TXT",
    }
}
/// Emits the short-form "query permitted" event for `view`, naming the
/// configured upstream resolver. An unrecognised query type is reported as `A`.
fn emit_query_permitted(cfg: &DnsProxyConfig, emitter: &dyn DnsQueryEmitter, view: &DnsQueryView) {
    let type_label =
        dns_query_type_str(qtype_to_dns_query_type(view.qtype).unwrap_or(DnsQueryType::A));
    let timestamp = chrono::Utc::now().to_rfc3339_opts(chrono::SecondsFormat::Millis, true);
    emit_event(
        emitter,
        cloud_event_v1_dns_query_permitted(
            "cellos-dns-proxy",
            &timestamp,
            &view.qname,
            type_label,
            &cfg.cell_id,
            &cfg.upstream_resolver_id,
        ),
    );
}
/// Emits the short-form "query refused" event for `view` with the given
/// `reason` string. An unrecognised query type is reported as `A`.
fn emit_query_refused(
    cfg: &DnsProxyConfig,
    emitter: &dyn DnsQueryEmitter,
    view: &DnsQueryView,
    reason: &str,
) {
    let type_label =
        dns_query_type_str(qtype_to_dns_query_type(view.qtype).unwrap_or(DnsQueryType::A));
    let timestamp = chrono::Utc::now().to_rfc3339_opts(chrono::SecondsFormat::Millis, true);
    emit_event(
        emitter,
        cloud_event_v1_dns_query_refused(
            "cellos-dns-proxy",
            &timestamp,
            &view.qname,
            type_label,
            &cfg.cell_id,
            reason,
        ),
    );
}
/// Maps a parse error to the reason code reported in the query event.
/// Every parse error currently maps to `MalformedQuery`; the error value
/// itself is ignored.
fn malformed_reason(_e: DnsParseError) -> DnsQueryReasonCode {
DnsQueryReasonCode::MalformedQuery
}
/// Returns whether `qname` matches any entry of `allowlist`, delegating to
/// the shared matcher in `cellos_core::hostname_allowlist`.
fn hostname_in_allowlist(qname: &str, allowlist: &[String]) -> bool {
    use cellos_core::hostname_allowlist::matches_allowlist;
    matches_allowlist(qname, allowlist)
}
/// Returns `true` when `e` is how a socket read timeout surfaces:
/// `WouldBlock` or `TimedOut`.
fn is_timeout(e: &io::Error) -> bool {
    let kind = e.kind();
    kind == io::ErrorKind::WouldBlock || kind == io::ErrorKind::TimedOut
}
/// Builds a REFUSED (rcode 5) reply to `query`.
fn build_refused_response(query: &[u8], view: &DnsQueryView) -> Vec<u8> {
    const RCODE_REFUSED: u8 = 5;
    build_error_response(query, view, RCODE_REFUSED)
}
/// Builds a SERVFAIL (rcode 2) reply to `query`.
fn build_servfail_response(query: &[u8], view: &DnsQueryView) -> Vec<u8> {
    const RCODE_SERVFAIL: u8 = 2;
    build_error_response(query, view, RCODE_SERVFAIL)
}
/// Synthesizes a header-plus-question DNS reply carrying `rcode`.
///
/// The reply reuses the transaction id and flags from `view`, with QR set,
/// RA cleared and the low four bits replaced by `rcode`. The first question
/// of `query` (name, type, class) is echoed when it lies completely inside
/// `query`; QDCOUNT is 1 in that case and 0 otherwise, so the header never
/// promises a question that is not present. All other section counts are 0.
///
/// Never panics, even if `query` is shorter than a DNS header.
fn build_error_response(query: &[u8], view: &DnsQueryView, rcode: u8) -> Vec<u8> {
    // Walk the length-prefixed labels of the first question name; `idx` ends
    // just past the terminating zero byte, or at/after the end of the buffer
    // if the name is truncated.
    let mut idx = DNS_HEADER_LEN;
    while let Some(&label_len) = query.get(idx) {
        idx += 1;
        if label_len == 0 {
            break;
        }
        idx += usize::from(label_len);
    }
    // QTYPE + QCLASS follow the name; `get` yields `None` when they (or the
    // name itself) run past the end of the buffer.
    let question = query.get(DNS_HEADER_LEN..idx + 4).unwrap_or(&[]);
    let qdcount = u16::from(!question.is_empty());

    let mut flags = view.flags;
    flags |= 0x8000; // QR: this is a response
    flags &= !0x0080; // RA: clear recursion-available
    flags = (flags & 0xfff0) | u16::from(rcode & 0x0f);

    let mut resp = Vec::with_capacity(DNS_HEADER_LEN + question.len());
    resp.extend_from_slice(&view.txn_id.to_be_bytes());
    resp.extend_from_slice(&flags.to_be_bytes());
    resp.extend_from_slice(&qdcount.to_be_bytes());
    // ANCOUNT, NSCOUNT, ARCOUNT.
    resp.extend_from_slice(&[0x00; 6]);
    resp.extend_from_slice(question);
    resp
}
/// Extracts the 4-bit rcode from the fourth header byte of `resp`.
/// A response shorter than four bytes is reported as 2 (SERVFAIL).
fn parse_response_rcode(resp: &[u8]) -> u8 {
    resp.get(3).map_or(2, |flags_lo| flags_lo & 0x0f)
}
/// Counts the answer records in `resp` whose type equals `qtype`.
///
/// Only A (1) and AAAA (28) queries are counted; any other `qtype` yields 0.
/// Parsing is best-effort: if the question section is malformed the result
/// is 0, and if an answer record is malformed or truncated the records
/// counted so far are returned. A record is counted only when its complete
/// RDATA lies inside `resp`.
fn parse_response_target_count(resp: &[u8], qtype: u16) -> u32 {
    if !matches!(qtype, 1 | 28) {
        return 0;
    }
    if resp.len() < DNS_HEADER_LEN {
        return 0;
    }
    let qdcount = usize::from(u16::from_be_bytes([resp[4], resp[5]]));
    let ancount = usize::from(u16::from_be_bytes([resp[6], resp[7]]));

    // Skip the question section: name + QTYPE (2) + QCLASS (2) per entry.
    let mut idx = DNS_HEADER_LEN;
    for _ in 0..qdcount {
        let Some(after_name) = skip_name(resp, idx) else {
            return 0;
        };
        idx = after_name + 4;
        if idx > resp.len() {
            return 0;
        }
    }

    let mut count: u32 = 0;
    for _ in 0..ancount {
        let Some(after_name) = skip_name(resp, idx) else {
            return count;
        };
        // Fixed part of a resource record: TYPE(2) CLASS(2) TTL(4) RDLENGTH(2).
        let Some(fixed) = resp.get(after_name..after_name + 10) else {
            return count;
        };
        let rtype = u16::from_be_bytes([fixed[0], fixed[1]]);
        let rdlen = usize::from(u16::from_be_bytes([fixed[8], fixed[9]]));
        idx = after_name + 10 + rdlen;
        // Check the RDATA bounds *before* counting, so a record cut off by
        // truncation is not reported as a resolved target.
        if idx > resp.len() {
            return count;
        }
        if rtype == qtype {
            count = count.saturating_add(1);
        }
    }
    count
}
/// Returns the index just past the DNS name that starts at `idx` in `buf`.
///
/// A name ends either at a zero-length label or at the first two-byte
/// compression pointer (top two bits set); pointers are not followed.
/// Returns `None` if the name runs past the end of `buf`.
fn skip_name(buf: &[u8], mut idx: usize) -> Option<usize> {
    loop {
        match *buf.get(idx)? {
            0 => return Some(idx + 1),
            marker if marker & 0xc0 == 0xc0 => {
                // The pointer's second byte must also be inside the buffer.
                buf.get(idx + 1)?;
                return Some(idx + 2);
            }
            label_len => idx += 1 + usize::from(label_len),
        }
    }
}
#[cfg(test)]
mod tests {
use super::*;
use std::net::UdpSocket;
use std::sync::Mutex;
use std::time::Duration;
// Test emitter that records every event in memory, in emission order.
#[derive(Default)]
struct MemEmitter {
events: Mutex<Vec<CloudEventV1>>,
}
impl DnsQueryEmitter for MemEmitter {
// Appends the event to the in-memory log.
fn emit(&self, event: CloudEventV1) {
self.events.lock().unwrap().push(event);
}
}
// Builds a minimal single-question DNS query for `qname`/`qtype`, class IN,
// with transaction id 0xabcd and flags 0x0100.
fn build_query_packet(qname: &str, qtype: u16) -> Vec<u8> {
    // Header: id, flags, QDCOUNT=1, ANCOUNT/NSCOUNT/ARCOUNT=0.
    let mut packet: Vec<u8> = vec![
        0xab, 0xcd, 0x01, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
    ];
    // Question name as length-prefixed labels, then the root terminator.
    for label in qname.split('.') {
        packet.push(label.len() as u8);
        packet.extend(label.bytes());
    }
    packet.push(0);
    // QTYPE followed by QCLASS = IN.
    packet.extend_from_slice(&qtype.to_be_bytes());
    packet.extend_from_slice(&[0x00, 0x01]);
    packet
}
// Turns `query` into a response (flags 0x8180) carrying `ancount` identical
// A records for 203.0.113.1, each naming the question via a compression
// pointer to offset 12.
fn build_a_response(query: &[u8], ancount: u16) -> Vec<u8> {
    // name pointer | TYPE A | CLASS IN | TTL 300 | RDLENGTH 4 | RDATA
    const A_RECORD: [u8; 16] = [
        0xc0, 0x0c, 0x00, 0x01, 0x00, 0x01, 0x00, 0x00, 0x01, 0x2c, 0x00, 0x04, 203, 0, 113, 1,
    ];
    let mut resp = Vec::with_capacity(query.len() + A_RECORD.len() * usize::from(ancount));
    resp.extend_from_slice(query);
    resp[2] = 0x81;
    resp[3] = 0x80;
    resp[6..8].copy_from_slice(&ancount.to_be_bytes());
    for _ in 0..ancount {
        resp.extend_from_slice(&A_RECORD);
    }
    resp
}
// Starts a fake upstream resolver on a loopback UDP port and returns its
// address plus the thread handle. With `swallow` set it never answers;
// otherwise it replies to each query with `ancount` A records. The thread
// exits on the first receive error, including the 2 s read timeout.
fn spawn_upstream(swallow: bool, ancount: u16) -> (SocketAddr, std::thread::JoinHandle<()>) {
    let sock = UdpSocket::bind("127.0.0.1:0").unwrap();
    sock.set_read_timeout(Some(Duration::from_millis(2000)))
        .unwrap();
    let addr = sock.local_addr().unwrap();
    let worker = std::thread::spawn(move || {
        let mut buf = [0u8; 1500];
        loop {
            let Ok((n, peer)) = sock.recv_from(&mut buf) else {
                break;
            };
            if !swallow {
                let resp = build_a_response(&buf[..n], ancount);
                let _ = sock.send_to(&resp, peer);
            }
        }
    });
    (addr, worker)
}
// Builds the baseline test configuration: the given allowlist and upstream
// address, default query types (empty list), no DNSSEC validator, plain UDP
// upstream transport and a 300 ms upstream timeout.
fn proxy_cfg(allowlist: Vec<&str>, upstream: SocketAddr) -> DnsProxyConfig {
DnsProxyConfig {
bind_addr: "127.0.0.1:0".parse().unwrap(),
upstream_addr: upstream,
hostname_allowlist: allowlist.into_iter().map(String::from).collect(),
// Empty list: the proxy falls back to DEFAULT_QUERY_TYPES.
allowed_query_types: vec![],
cell_id: "test-cell".into(),
run_id: "test-run".into(),
policy_digest: None,
keyset_id: Some("test-keyset".into()),
issuer_kid: Some("test-kid-001".into()),
correlation_id: None,
upstream_resolver_id: "resolver-test-001".into(),
upstream_timeout: Duration::from_millis(300),
// Zero: the TCP path falls back to DEFAULT_TCP_IDLE_TIMEOUT.
tcp_idle_timeout: Duration::ZERO,
dnssec_validator: None,
transport: UpstreamTransport::Do53Udp,
upstream_extras: UpstreamExtras::default(),
}
}
// An allowlisted A query is forwarded, answered with NOERROR, and produces
// a short-form "permitted" event followed by the aggregate allow event.
#[test]
fn proxy_allows_query_in_allowlist() {
let (upstream_addr, _h) = spawn_upstream(false, 2);
let listener = UdpSocket::bind("127.0.0.1:0").unwrap();
// Short read timeout so the proxy loop notices the shutdown flag promptly.
listener
.set_read_timeout(Some(Duration::from_millis(150)))
.unwrap();
let listen_addr = listener.local_addr().unwrap();
let upstream_sock = UdpSocket::bind("127.0.0.1:0").unwrap();
let cfg = proxy_cfg(vec!["api.example.com"], upstream_addr);
let emitter = std::sync::Arc::new(MemEmitter::default());
let shutdown = std::sync::Arc::new(AtomicBool::new(false));
let proxy_handle = {
let emitter = emitter.clone();
let shutdown = shutdown.clone();
let cfg = cfg.clone();
std::thread::spawn(move || {
let _ = run_one_shot(&cfg, &listener, &upstream_sock, &*emitter, &shutdown);
})
};
let client = UdpSocket::bind("127.0.0.1:0").unwrap();
client
.set_read_timeout(Some(Duration::from_secs(1)))
.unwrap();
let q = build_query_packet("api.example.com", 1);
client.send_to(&q, listen_addr).unwrap();
let mut rb = [0u8; 1500];
let (n, _) = client.recv_from(&mut rb).unwrap();
assert!(n > DNS_HEADER_LEN);
let rcode = rb[3] & 0x0f;
assert_eq!(rcode, 0, "expected NOERROR on allow path");
// Stop the proxy before inspecting the recorded events.
shutdown.store(true, Ordering::SeqCst);
proxy_handle.join().unwrap();
let evs = emitter.events.lock().unwrap();
assert_eq!(evs.len(), 2);
assert_eq!(evs[0].ty, "dev.cellos.events.cell.dns.v1.query_permitted");
let permitted_data = evs[0].data.as_ref().unwrap();
assert_eq!(permitted_data["queryName"], "api.example.com");
assert_eq!(permitted_data["queryType"], "A");
assert_eq!(permitted_data["resolver"], "resolver-test-001");
let data = evs[1].data.as_ref().unwrap();
assert_eq!(data["decision"], "allow");
assert_eq!(data["reasonCode"], "allowed_by_allowlist");
assert_eq!(data["responseRcode"], 0);
assert_eq!(data["upstreamResolverId"], "resolver-test-001");
// The fake upstream was started with two A records.
assert_eq!(data["responseTargetCount"], 2);
}
// A query for a name outside the allowlist is answered with REFUSED and
// produces the aggregate deny event followed by the short-form refusal event.
#[test]
fn proxy_denies_query_not_in_allowlist() {
let (upstream_addr, _h) = spawn_upstream(false, 0);
let listener = UdpSocket::bind("127.0.0.1:0").unwrap();
// Short read timeout so the proxy loop notices the shutdown flag promptly.
listener
.set_read_timeout(Some(Duration::from_millis(150)))
.unwrap();
let listen_addr = listener.local_addr().unwrap();
let upstream_sock = UdpSocket::bind("127.0.0.1:0").unwrap();
let cfg = proxy_cfg(vec!["api.example.com"], upstream_addr);
let emitter = std::sync::Arc::new(MemEmitter::default());
let shutdown = std::sync::Arc::new(AtomicBool::new(false));
let proxy_handle = {
let emitter = emitter.clone();
let shutdown = shutdown.clone();
let cfg = cfg.clone();
std::thread::spawn(move || {
let _ = run_one_shot(&cfg, &listener, &upstream_sock, &*emitter, &shutdown);
})
};
let client = UdpSocket::bind("127.0.0.1:0").unwrap();
client
.set_read_timeout(Some(Duration::from_secs(1)))
.unwrap();
let q = build_query_packet("blocked.example.com", 1);
client.send_to(&q, listen_addr).unwrap();
let mut rb = [0u8; 1500];
let (n, _) = client.recv_from(&mut rb).unwrap();
let rcode = rb[3] & 0x0f;
assert_eq!(
rcode, 5,
"expected REFUSED on deny path, got rcode={rcode} n={n}"
);
// Stop the proxy before inspecting the recorded events.
shutdown.store(true, Ordering::SeqCst);
proxy_handle.join().unwrap();
let evs = emitter.events.lock().unwrap();
assert_eq!(evs.len(), 2, "aggregate + short-form refusal event");
let data = evs[0].data.as_ref().unwrap();
assert_eq!(data["decision"], "deny");
assert_eq!(data["reasonCode"], "denied_not_in_allowlist");
assert_eq!(data["responseRcode"], 5);
assert_eq!(evs[1].ty, "dev.cellos.events.cell.dns.v1.query_refused");
let refused = evs[1].data.as_ref().unwrap();
assert_eq!(refused["reason"], "denied_not_in_allowlist");
assert_eq!(refused["queryName"], "blocked.example.com");
}
// A `*.` allowlist entry matches subdomains at any depth but neither the
// bare parent nor a look-alike sibling; an exact entry matches only itself.
#[test]
fn proxy_wildcard_matches_subdomain_only() {
    // (query name, single allowlist entry, expected match)
    let cases = [
        ("foo.cdn.example.com", "*.cdn.example.com", true),
        ("deep.foo.cdn.example.com", "*.cdn.example.com", true),
        ("cdn.example.com", "*.cdn.example.com", false),
        ("evil-cdn.example.com", "*.cdn.example.com", false),
        ("api.example.com", "api.example.com", true),
        ("x.api.example.com", "api.example.com", false),
    ];
    for &(qname, entry, expected) in cases.iter() {
        let allowlist = [String::from(entry)];
        assert_eq!(hostname_in_allowlist(qname, &allowlist), expected);
    }
}
// With only A/AAAA allowed, a type-16 (TXT) query for an allowlisted name is
// answered with REFUSED and reported with reason `denied_query_type`.
#[test]
fn proxy_denies_disallowed_query_type() {
let (upstream_addr, _h) = spawn_upstream(false, 0);
let listener = UdpSocket::bind("127.0.0.1:0").unwrap();
// Short read timeout so the proxy loop notices the shutdown flag promptly.
listener
.set_read_timeout(Some(Duration::from_millis(150)))
.unwrap();
let listen_addr = listener.local_addr().unwrap();
let upstream_sock = UdpSocket::bind("127.0.0.1:0").unwrap();
let mut cfg = proxy_cfg(vec!["api.example.com"], upstream_addr);
cfg.allowed_query_types = vec![DnsQueryType::A, DnsQueryType::AAAA];
let emitter = std::sync::Arc::new(MemEmitter::default());
let shutdown = std::sync::Arc::new(AtomicBool::new(false));
let proxy_handle = {
let emitter = emitter.clone();
let shutdown = shutdown.clone();
let cfg = cfg.clone();
std::thread::spawn(move || {
let _ = run_one_shot(&cfg, &listener, &upstream_sock, &*emitter, &shutdown);
})
};
let client = UdpSocket::bind("127.0.0.1:0").unwrap();
client
.set_read_timeout(Some(Duration::from_secs(1)))
.unwrap();
let q = build_query_packet("api.example.com", 16); client.send_to(&q, listen_addr).unwrap();
let mut rb = [0u8; 1500];
let (_n, _) = client.recv_from(&mut rb).unwrap();
assert_eq!(rb[3] & 0x0f, 5);
// Stop the proxy before inspecting the recorded events.
shutdown.store(true, Ordering::SeqCst);
proxy_handle.join().unwrap();
let evs = emitter.events.lock().unwrap();
assert_eq!(evs.len(), 2);
let data = evs[0].data.as_ref().unwrap();
assert_eq!(data["reasonCode"], "denied_query_type");
assert_eq!(evs[1].ty, "dev.cellos.events.cell.dns.v1.query_refused");
assert_eq!(evs[1].data.as_ref().unwrap()["reason"], "denied_query_type");
}
// Three sequential queries (exact match, wildcard match, blocked) each
// produce exactly two events, in query order.
#[test]
fn proxy_emits_event_per_query() {
let (upstream_addr, _h) = spawn_upstream(false, 1);
let listener = UdpSocket::bind("127.0.0.1:0").unwrap();
// Short read timeout so the proxy loop notices the shutdown flag promptly.
listener
.set_read_timeout(Some(Duration::from_millis(150)))
.unwrap();
let listen_addr = listener.local_addr().unwrap();
let upstream_sock = UdpSocket::bind("127.0.0.1:0").unwrap();
let cfg = proxy_cfg(vec!["api.example.com", "*.cdn.example.com"], upstream_addr);
let emitter = std::sync::Arc::new(MemEmitter::default());
let shutdown = std::sync::Arc::new(AtomicBool::new(false));
let proxy_handle = {
let emitter = emitter.clone();
let shutdown = shutdown.clone();
let cfg = cfg.clone();
std::thread::spawn(move || {
let _ = run_one_shot(&cfg, &listener, &upstream_sock, &*emitter, &shutdown);
})
};
let client = UdpSocket::bind("127.0.0.1:0").unwrap();
client
.set_read_timeout(Some(Duration::from_secs(1)))
.unwrap();
// Wait for each reply before sending the next query so event order is deterministic.
for name in [
"api.example.com",
"img.cdn.example.com",
"blocked.example.com",
] {
let q = build_query_packet(name, 1);
client.send_to(&q, listen_addr).unwrap();
let mut rb = [0u8; 1500];
let _ = client.recv_from(&mut rb).unwrap();
}
shutdown.store(true, Ordering::SeqCst);
proxy_handle.join().unwrap();
let evs = emitter.events.lock().unwrap();
assert_eq!(evs.len(), 6);
// Events 0-1: permitted + aggregate allow for the exact-match name.
assert_eq!(evs[0].ty, "dev.cellos.events.cell.dns.v1.query_permitted");
assert_eq!(
evs[0].data.as_ref().unwrap()["queryName"],
"api.example.com"
);
let data1_agg = evs[1].data.as_ref().unwrap();
assert_eq!(data1_agg["decision"], "allow");
assert_eq!(data1_agg["queryName"], "api.example.com");
assert_eq!(data1_agg["upstreamResolverId"], "resolver-test-001");
// Events 2-3: permitted + aggregate allow for the wildcard-matched name.
assert_eq!(evs[2].ty, "dev.cellos.events.cell.dns.v1.query_permitted");
let data3_agg = evs[3].data.as_ref().unwrap();
assert_eq!(data3_agg["decision"], "allow");
assert_eq!(data3_agg["queryName"], "img.cdn.example.com");
// Events 4-5: aggregate deny + short-form refusal for the blocked name.
let data4_agg = evs[4].data.as_ref().unwrap();
assert_eq!(data4_agg["decision"], "deny");
assert_eq!(data4_agg["queryName"], "blocked.example.com");
assert_eq!(evs[5].ty, "dev.cellos.events.cell.dns.v1.query_refused");
assert_eq!(
evs[5].data.as_ref().unwrap()["reason"],
"denied_not_in_allowlist"
);
}
// When the upstream never answers, the proxy replies with SERVFAIL after the
// upstream timeout and reports `upstream_failure`.
#[test]
fn proxy_returns_servfail_on_upstream_timeout() {
// `swallow = true`: the fake upstream receives queries but never replies.
let (upstream_addr, _h) = spawn_upstream(true, 0); let listener = UdpSocket::bind("127.0.0.1:0").unwrap();
listener
.set_read_timeout(Some(Duration::from_millis(150)))
.unwrap();
let listen_addr = listener.local_addr().unwrap();
let upstream_sock = UdpSocket::bind("127.0.0.1:0").unwrap();
let mut cfg = proxy_cfg(vec!["api.example.com"], upstream_addr);
cfg.upstream_timeout = Duration::from_millis(120);
let emitter = std::sync::Arc::new(MemEmitter::default());
let shutdown = std::sync::Arc::new(AtomicBool::new(false));
let proxy_handle = {
let emitter = emitter.clone();
let shutdown = shutdown.clone();
let cfg = cfg.clone();
std::thread::spawn(move || {
let _ = run_one_shot(&cfg, &listener, &upstream_sock, &*emitter, &shutdown);
})
};
let client = UdpSocket::bind("127.0.0.1:0").unwrap();
// The client waits longer than the proxy's upstream timeout.
client
.set_read_timeout(Some(Duration::from_secs(2)))
.unwrap();
let q = build_query_packet("api.example.com", 1);
client.send_to(&q, listen_addr).unwrap();
let mut rb = [0u8; 1500];
let (_n, _) = client.recv_from(&mut rb).unwrap();
assert_eq!(rb[3] & 0x0f, 2, "expected SERVFAIL on upstream timeout");
shutdown.store(true, Ordering::SeqCst);
proxy_handle.join().unwrap();
let evs = emitter.events.lock().unwrap();
assert_eq!(evs.len(), 2);
// The query passed policy, so "permitted" is emitted before the failure event.
assert_eq!(evs[0].ty, "dev.cellos.events.cell.dns.v1.query_permitted");
let data = evs[1].data.as_ref().unwrap();
assert_eq!(data["reasonCode"], "upstream_failure");
assert_eq!(data["responseRcode"], 2);
}
// Three A answers are counted for an A query; a query type other than
// A/AAAA (here 16) always yields zero.
#[test]
fn parse_response_target_count_counts_a_records() {
let q = build_query_packet("api.example.com", 1);
let r = build_a_response(&q, 3);
assert_eq!(parse_response_target_count(&r, 1), 3);
assert_eq!(parse_response_target_count(&r, 16), 0);
}
}