1pub mod dnssec;
55pub mod parser;
56pub mod spawn;
57pub mod upstream;
58
59use std::io;
60use std::io::{Read, Write};
61use std::net::{SocketAddr, TcpListener, TcpStream, UdpSocket};
62use std::sync::atomic::{AtomicBool, Ordering};
63use std::sync::{Arc, Mutex};
64use std::time::{Duration, Instant};
65
66use cellos_core::{
67 cloud_event_v1_dns_authority_dnssec_failed, cloud_event_v1_dns_query_permitted,
68 cloud_event_v1_dns_query_refused, qtype_to_dns_query_type, CloudEventV1,
69 DnsAuthorityDnssecFailed, DnsAuthorityDnssecFailureReason, DnsQueryDecision, DnsQueryEvent,
70 DnsQueryReasonCode, DnsQueryType,
71};
72
73use dnssec::{DataplaneDnssecOutcome, DataplaneDnssecValidator};
74use parser::{parse_query, DnsParseError, DnsQueryView, DNS_HEADER_LEN};
75use upstream::{UpstreamExtras, UpstreamTransport};
76
/// Query types accepted when `DnsProxyConfig::allowed_query_types` is empty.
const DEFAULT_QUERY_TYPES: &[DnsQueryType] = &[
    DnsQueryType::A,
    DnsQueryType::AAAA,
    DnsQueryType::CNAME,
    DnsQueryType::HTTPS,
];

/// Size in bytes of the fixed buffers that hold a received client datagram
/// and an upstream response.
const MAX_UDP_PAYLOAD: usize = 1500;

/// Read/write timeout applied to accepted TCP connections when
/// `DnsProxyConfig::tcp_idle_timeout` is zero.
const DEFAULT_TCP_IDLE_TIMEOUT: Duration = Duration::from_secs(30);
99
/// Configuration shared by the UDP and TCP proxy loops.
#[derive(Debug, Clone)]
pub struct DnsProxyConfig {
    /// Listen address. NOTE(review): not read by the loops in this file, which
    /// receive already-bound sockets — presumably consumed by the caller.
    pub bind_addr: SocketAddr,
    /// Address handed to `upstream::forward` as the resolver to query.
    pub upstream_addr: SocketAddr,
    /// Patterns passed to `cellos_core::hostname_allowlist::matches_allowlist`.
    pub hostname_allowlist: Vec<String>,
    /// Permitted query types; when empty, `DEFAULT_QUERY_TYPES` is used.
    pub allowed_query_types: Vec<DnsQueryType>,
    /// Copied into every emitted event payload.
    pub cell_id: String,
    /// Copied into every emitted event payload.
    pub run_id: String,
    /// Copied verbatim into emitted event payloads when present.
    pub policy_digest: Option<String>,
    /// Copied verbatim into emitted event payloads when present.
    pub keyset_id: Option<String>,
    /// Copied verbatim into emitted event payloads when present.
    pub issuer_kid: Option<String>,
    /// Copied verbatim into emitted event payloads when present.
    pub correlation_id: Option<String>,
    /// Resolver identifier reported in events for queries that reached upstream.
    pub upstream_resolver_id: String,
    /// Timeout handed to `upstream::forward`.
    pub upstream_timeout: Duration,
    /// Per-connection TCP read/write timeout; zero selects
    /// `DEFAULT_TCP_IDLE_TIMEOUT`.
    pub tcp_idle_timeout: Duration,
    /// When set, every upstream response is validated before being forwarded.
    pub dnssec_validator: Option<Arc<DataplaneDnssecValidator>>,
    /// Transport selector handed to `upstream::forward`.
    pub transport: UpstreamTransport,
    /// Additional transport settings handed to `upstream::forward`.
    pub upstream_extras: UpstreamExtras,
}
163
/// Sink for the CloudEvents produced by the proxy.
///
/// The `Send + Sync` bound lets one emitter be shared with the per-connection
/// TCP worker threads.
pub trait DnsQueryEmitter: Send + Sync {
    /// Receives one event; takes ownership of it.
    fn emit(&self, event: CloudEventV1);
}
173
/// Saturating counters accumulated by a proxy loop and returned when it exits.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct DnsProxyStats {
    /// UDP datagrams received, or complete TCP frames read.
    pub queries_total: u64,
    /// Queries whose upstream response was forwarded to the client.
    pub queries_allowed: u64,
    /// Queries refused by type or allowlist, or failed by DNSSEC policy.
    pub queries_denied: u64,
    /// Messages that failed parsing (and zero-length TCP frames).
    pub queries_malformed: u64,
    /// Queries for which `upstream::forward` returned an error.
    pub upstream_failures: u64,
}
185
/// Serves DNS queries arriving on `socket` until `shutdown` is observed set.
///
/// Each datagram is parsed, checked against the permitted query types and the
/// hostname allowlist, and then either answered locally (REFUSED / SERVFAIL)
/// or forwarded through `upstream::forward`. Every outcome is reported via
/// `emitter`. Sends to the client are best-effort: their errors are ignored.
///
/// The shutdown flag is only re-checked when `recv_from` returns, so prompt
/// shutdown relies on `socket` having a read timeout.
///
/// # Errors
///
/// Returns the `recv_from` error when it is neither a timeout
/// (`WouldBlock` / `TimedOut`) nor `Interrupted`.
pub fn run_one_shot(
    cfg: &DnsProxyConfig,
    socket: &UdpSocket,
    upstream: &UdpSocket,
    emitter: &dyn DnsQueryEmitter,
    shutdown: &AtomicBool,
) -> io::Result<DnsProxyStats> {
    let mut stats = DnsProxyStats::default();
    let mut recv_buf = [0u8; MAX_UDP_PAYLOAD];
    let mut up_buf = [0u8; MAX_UDP_PAYLOAD];

    while !shutdown.load(Ordering::SeqCst) {
        let (n, peer) = match socket.recv_from(&mut recv_buf) {
            Ok(t) => t,
            // Timeouts and signal interruptions just loop back to the shutdown check.
            Err(e) if is_timeout(&e) => continue,
            Err(e) if matches!(e.kind(), io::ErrorKind::Interrupted) => continue,
            Err(e) => return Err(e),
        };
        stats.queries_total = stats.queries_total.saturating_add(1);

        let pkt = &recv_buf[..n];
        match parse_query(pkt) {
            Err(parse_err) => {
                // Unparseable datagram: record it and send nothing back.
                stats.queries_malformed = stats.queries_malformed.saturating_add(1);
                let event = build_event(
                    cfg,
                    EventInputs {
                        view: None,
                        decision: DnsQueryDecision::Deny,
                        reason_code: malformed_reason(parse_err),
                        response_rcode: None,
                        upstream_resolver_id: None,
                        upstream_latency_ms: None,
                        response_target_count: None,
                    },
                );
                emit_event(emitter, event);
                continue;
            }
            Ok(view) => {
                // Gate 1: the query type must be recognised and in the permitted set.
                let qtype_known = qtype_to_dns_query_type(view.qtype);
                let allowed_types = if cfg.allowed_query_types.is_empty() {
                    DEFAULT_QUERY_TYPES
                } else {
                    cfg.allowed_query_types.as_slice()
                };
                let qtype_in_set = qtype_known.is_some_and(|t| allowed_types.contains(&t));
                if !qtype_in_set {
                    stats.queries_denied = stats.queries_denied.saturating_add(1);
                    let resp = build_refused_response(pkt, &view);
                    let _ = socket.send_to(&resp, peer);
                    let event = build_event(
                        cfg,
                        EventInputs {
                            view: Some(&view),
                            decision: DnsQueryDecision::Deny,
                            reason_code: DnsQueryReasonCode::DeniedQueryType,
                            response_rcode: Some(5),
                            upstream_resolver_id: None,
                            upstream_latency_ms: None,
                            response_target_count: Some(0),
                        },
                    );
                    emit_event(emitter, event);
                    emit_query_refused(cfg, emitter, &view, "denied_query_type");
                    continue;
                }

                // Gate 2: the queried name must match the hostname allowlist.
                if !hostname_in_allowlist(&view.qname, &cfg.hostname_allowlist) {
                    stats.queries_denied = stats.queries_denied.saturating_add(1);
                    let resp = build_refused_response(pkt, &view);
                    let _ = socket.send_to(&resp, peer);
                    let event = build_event(
                        cfg,
                        EventInputs {
                            view: Some(&view),
                            decision: DnsQueryDecision::Deny,
                            reason_code: DnsQueryReasonCode::DeniedNotInAllowlist,
                            response_rcode: Some(5),
                            upstream_resolver_id: None,
                            upstream_latency_ms: None,
                            response_target_count: Some(0),
                        },
                    );
                    emit_event(emitter, event);
                    emit_query_refused(cfg, emitter, &view, "denied_not_in_allowlist");
                    continue;
                }

                // Both gates passed: announce the permit before contacting upstream.
                emit_query_permitted(cfg, emitter, &view);

                let started = Instant::now();
                let upstream_result = upstream::forward(
                    cfg.transport,
                    upstream,
                    cfg.upstream_addr,
                    pkt,
                    &mut up_buf,
                    cfg.upstream_timeout,
                    &cfg.upstream_extras,
                );
                let elapsed_ms = started.elapsed().as_millis() as u64;
                match upstream_result {
                    Ok(resp_len) => {
                        if let Some(validator) = cfg.dnssec_validator.as_ref() {
                            let outcome = validator.validate(pkt, &up_buf[..resp_len]);
                            let action =
                                decide_dnssec_action(validator.is_require_mode(), &outcome);
                            match action {
                                DnssecAction::Forward => {
                                    // Fall through to the normal forwarding path below.
                                }
                                DnssecAction::Servfail { reason } => {
                                    // Withhold the upstream answer: reply SERVFAIL and emit
                                    // both the query event and the DNSSEC failure event.
                                    let resp = build_servfail_response(pkt, &view);
                                    let _ = socket.send_to(&resp, peer);
                                    stats.queries_denied = stats.queries_denied.saturating_add(1);
                                    let q_event = build_event(
                                        cfg,
                                        EventInputs {
                                            view: Some(&view),
                                            decision: DnsQueryDecision::Deny,
                                            reason_code: DnsQueryReasonCode::DeniedDnssec,
                                            response_rcode: Some(2),
                                            upstream_resolver_id: Some(
                                                cfg.upstream_resolver_id.clone(),
                                            ),
                                            upstream_latency_ms: Some(elapsed_ms),
                                            response_target_count: Some(0),
                                        },
                                    );
                                    emit_event(emitter, q_event);
                                    let dnssec_event = build_dataplane_dnssec_failed_event(
                                        cfg, &view, validator, reason,
                                    );
                                    emit_event(emitter, dnssec_event);
                                    continue;
                                }
                                DnssecAction::ForwardUnsignedBestEffort => {
                                    // Handled the same as `Forward` here.
                                }
                            }
                        }
                        // Relay the upstream response bytes unchanged.
                        let resp = &up_buf[..resp_len];
                        let _ = socket.send_to(resp, peer);
                        stats.queries_allowed = stats.queries_allowed.saturating_add(1);
                        let answer_count = parse_response_target_count(resp, view.qtype);
                        let event = build_event(
                            cfg,
                            EventInputs {
                                view: Some(&view),
                                decision: DnsQueryDecision::Allow,
                                reason_code: DnsQueryReasonCode::AllowedByAllowlist,
                                response_rcode: Some(parse_response_rcode(resp)),
                                upstream_resolver_id: Some(cfg.upstream_resolver_id.clone()),
                                upstream_latency_ms: Some(elapsed_ms),
                                response_target_count: Some(answer_count),
                            },
                        );
                        emit_event(emitter, event);
                    }
                    Err(_e) => {
                        // Upstream failed: answer SERVFAIL locally and record the failure.
                        stats.upstream_failures = stats.upstream_failures.saturating_add(1);
                        let resp = build_servfail_response(pkt, &view);
                        let _ = socket.send_to(&resp, peer);
                        let event = build_event(
                            cfg,
                            EventInputs {
                                view: Some(&view),
                                decision: DnsQueryDecision::Deny,
                                reason_code: DnsQueryReasonCode::UpstreamFailure,
                                response_rcode: Some(2),
                                upstream_resolver_id: Some(cfg.upstream_resolver_id.clone()),
                                upstream_latency_ms: Some(elapsed_ms),
                                response_target_count: Some(0),
                            },
                        );
                        emit_event(emitter, event);
                    }
                }
            }
        }
    }

    Ok(stats)
}
409
410pub fn run_tcp_one_shot(
440 cfg: &DnsProxyConfig,
441 listener: &TcpListener,
442 upstream: Arc<UdpSocket>,
443 emitter: Arc<dyn DnsQueryEmitter>,
444 shutdown: &AtomicBool,
445) -> io::Result<DnsProxyStats> {
446 listener.set_nonblocking(true)?;
448
449 let stats = Arc::new(Mutex::new(DnsProxyStats::default()));
450 let mut workers: Vec<std::thread::JoinHandle<()>> = Vec::new();
451
452 let tcp_idle_timeout = if cfg.tcp_idle_timeout.is_zero() {
457 DEFAULT_TCP_IDLE_TIMEOUT
458 } else {
459 cfg.tcp_idle_timeout
460 };
461
462 while !shutdown.load(Ordering::SeqCst) {
463 match listener.accept() {
464 Ok((stream, _peer)) => {
465 if let Err(_e) = stream.set_nonblocking(false) {
469 continue;
470 }
471 let _ = stream.set_read_timeout(Some(tcp_idle_timeout));
477 let _ = stream.set_write_timeout(Some(tcp_idle_timeout));
478
479 let cfg_owned = cfg.clone();
480 let upstream = upstream.clone();
481 let emitter = emitter.clone();
482 let stats = stats.clone();
483 let handle = std::thread::spawn(move || {
484 handle_tcp_connection(&cfg_owned, stream, &upstream, &*emitter, &stats);
485 });
486 workers.push(handle);
487 }
488 Err(e) if matches!(e.kind(), io::ErrorKind::WouldBlock) => {
489 std::thread::sleep(Duration::from_millis(50));
490 continue;
491 }
492 Err(e) if matches!(e.kind(), io::ErrorKind::Interrupted) => continue,
493 Err(e) => {
494 let _ = listener.set_nonblocking(false);
495 return Err(e);
496 }
497 }
498 }
499
500 for h in workers {
502 let _ = h.join();
503 }
504
505 let _ = listener.set_nonblocking(false);
506 let final_stats = *stats.lock().expect("dns proxy stats mutex poisoned");
507 Ok(final_stats)
508}
509
/// Serves length-prefixed DNS messages on one TCP connection until the peer
/// closes, a read/write fails, or a malformed message is seen.
///
/// Each message is a two-byte big-endian length followed by that many bytes.
/// The policy gates, upstream forwarding and event emission mirror the UDP
/// loop; counters are updated through the shared `stats` mutex.
///
/// NOTE(review): whenever `write_framed` fails this function returns before
/// the corresponding event is emitted (and, on the allow / DNSSEC paths,
/// before the counter is bumped), whereas the UDP loop ignores send errors
/// and always emits — confirm this difference is intended.
fn handle_tcp_connection(
    cfg: &DnsProxyConfig,
    mut stream: TcpStream,
    upstream: &UdpSocket,
    emitter: &dyn DnsQueryEmitter,
    stats: &Mutex<DnsProxyStats>,
) {
    let mut up_buf = [0u8; MAX_UDP_PAYLOAD];
    loop {
        let mut len_buf = [0u8; 2];
        // Any read error (EOF, timeout, reset) ends the connection silently.
        if stream.read_exact(&mut len_buf).is_err() {
            return;
        }
        let msg_len = u16::from_be_bytes(len_buf) as usize;
        if msg_len == 0 {
            // A zero-length frame is reported as malformed and closes the
            // connection. NOTE(review): it is not counted in `queries_total`.
            bump_malformed(stats);
            let event = build_event(
                cfg,
                EventInputs {
                    view: None,
                    decision: DnsQueryDecision::Deny,
                    reason_code: DnsQueryReasonCode::MalformedQuery,
                    response_rcode: None,
                    upstream_resolver_id: None,
                    upstream_latency_ms: None,
                    response_target_count: None,
                },
            );
            emit_event(emitter, event);
            return;
        }
        let mut pkt = vec![0u8; msg_len];
        if stream.read_exact(&mut pkt).is_err() {
            return;
        }

        bump_total(stats);

        match parse_query(&pkt) {
            Err(parse_err) => {
                // Unparseable message: record it and drop the connection.
                bump_malformed(stats);
                let event = build_event(
                    cfg,
                    EventInputs {
                        view: None,
                        decision: DnsQueryDecision::Deny,
                        reason_code: malformed_reason(parse_err),
                        response_rcode: None,
                        upstream_resolver_id: None,
                        upstream_latency_ms: None,
                        response_target_count: None,
                    },
                );
                emit_event(emitter, event);
                return;
            }
            Ok(view) => {
                // Gate 1: the query type must be recognised and in the permitted set.
                let qtype_known = qtype_to_dns_query_type(view.qtype);
                let allowed_types = if cfg.allowed_query_types.is_empty() {
                    DEFAULT_QUERY_TYPES
                } else {
                    cfg.allowed_query_types.as_slice()
                };
                let qtype_in_set = qtype_known.is_some_and(|t| allowed_types.contains(&t));
                if !qtype_in_set {
                    bump_denied(stats);
                    let resp = build_refused_response(&pkt, &view);
                    if write_framed(&mut stream, &resp).is_err() {
                        return;
                    }
                    let event = build_event(
                        cfg,
                        EventInputs {
                            view: Some(&view),
                            decision: DnsQueryDecision::Deny,
                            reason_code: DnsQueryReasonCode::DeniedQueryType,
                            response_rcode: Some(5),
                            upstream_resolver_id: None,
                            upstream_latency_ms: None,
                            response_target_count: Some(0),
                        },
                    );
                    emit_event(emitter, event);
                    emit_query_refused(cfg, emitter, &view, "denied_query_type");
                    continue;
                }

                // Gate 2: the queried name must match the hostname allowlist.
                if !hostname_in_allowlist(&view.qname, &cfg.hostname_allowlist) {
                    bump_denied(stats);
                    let resp = build_refused_response(&pkt, &view);
                    if write_framed(&mut stream, &resp).is_err() {
                        return;
                    }
                    let event = build_event(
                        cfg,
                        EventInputs {
                            view: Some(&view),
                            decision: DnsQueryDecision::Deny,
                            reason_code: DnsQueryReasonCode::DeniedNotInAllowlist,
                            response_rcode: Some(5),
                            upstream_resolver_id: None,
                            upstream_latency_ms: None,
                            response_target_count: Some(0),
                        },
                    );
                    emit_event(emitter, event);
                    emit_query_refused(cfg, emitter, &view, "denied_not_in_allowlist");
                    continue;
                }

                // Both gates passed: announce the permit before contacting upstream.
                emit_query_permitted(cfg, emitter, &view);

                let started = Instant::now();
                let upstream_result = upstream::forward(
                    cfg.transport,
                    upstream,
                    cfg.upstream_addr,
                    &pkt,
                    &mut up_buf,
                    cfg.upstream_timeout,
                    &cfg.upstream_extras,
                );
                let elapsed_ms = started.elapsed().as_millis() as u64;
                match upstream_result {
                    Ok(resp_len) => {
                        if let Some(validator) = cfg.dnssec_validator.as_ref() {
                            let outcome = validator.validate(&pkt, &up_buf[..resp_len]);
                            let action =
                                decide_dnssec_action(validator.is_require_mode(), &outcome);
                            match action {
                                DnssecAction::Forward | DnssecAction::ForwardUnsignedBestEffort => {
                                    // Fall through to the normal forwarding path below.
                                }
                                DnssecAction::Servfail { reason } => {
                                    // Withhold the upstream answer and reply SERVFAIL.
                                    let resp = build_servfail_response(&pkt, &view);
                                    if write_framed(&mut stream, &resp).is_err() {
                                        return;
                                    }
                                    bump_denied(stats);
                                    let q_event = build_event(
                                        cfg,
                                        EventInputs {
                                            view: Some(&view),
                                            decision: DnsQueryDecision::Deny,
                                            reason_code: DnsQueryReasonCode::DeniedDnssec,
                                            response_rcode: Some(2),
                                            upstream_resolver_id: Some(
                                                cfg.upstream_resolver_id.clone(),
                                            ),
                                            upstream_latency_ms: Some(elapsed_ms),
                                            response_target_count: Some(0),
                                        },
                                    );
                                    emit_event(emitter, q_event);
                                    let dnssec_event = build_dataplane_dnssec_failed_event(
                                        cfg, &view, validator, reason,
                                    );
                                    emit_event(emitter, dnssec_event);
                                    continue;
                                }
                            }
                        }
                        // Relay the upstream response bytes unchanged, re-framed for TCP.
                        let resp = &up_buf[..resp_len];
                        if write_framed(&mut stream, resp).is_err() {
                            return;
                        }
                        bump_allowed(stats);
                        let answer_count = parse_response_target_count(resp, view.qtype);
                        let event = build_event(
                            cfg,
                            EventInputs {
                                view: Some(&view),
                                decision: DnsQueryDecision::Allow,
                                reason_code: DnsQueryReasonCode::AllowedByAllowlist,
                                response_rcode: Some(parse_response_rcode(resp)),
                                upstream_resolver_id: Some(cfg.upstream_resolver_id.clone()),
                                upstream_latency_ms: Some(elapsed_ms),
                                response_target_count: Some(answer_count),
                            },
                        );
                        emit_event(emitter, event);
                    }
                    Err(_e) => {
                        // Upstream failed: answer SERVFAIL locally and record the failure.
                        bump_upstream_failure(stats);
                        let resp = build_servfail_response(&pkt, &view);
                        if write_framed(&mut stream, &resp).is_err() {
                            return;
                        }
                        let event = build_event(
                            cfg,
                            EventInputs {
                                view: Some(&view),
                                decision: DnsQueryDecision::Deny,
                                reason_code: DnsQueryReasonCode::UpstreamFailure,
                                response_rcode: Some(2),
                                upstream_resolver_id: Some(cfg.upstream_resolver_id.clone()),
                                upstream_latency_ms: Some(elapsed_ms),
                                response_target_count: Some(0),
                            },
                        );
                        emit_event(emitter, event);
                    }
                }
            }
        }
    }
}
741
/// Writes `msg` to `stream` using DNS-over-TCP framing: a two-byte big-endian
/// length followed by the message bytes.
///
/// The prefix and the message are assembled into one buffer and handed to a
/// single `write_all`, so the length field is not emitted as a separate tiny
/// write ahead of the payload (the approach RFC 7766 §8 recommends).
///
/// # Errors
///
/// Returns `InvalidData` when `msg` is longer than 65535 bytes (nothing is
/// written in that case), or the underlying I/O error from writing/flushing.
fn write_framed(stream: &mut TcpStream, msg: &[u8]) -> io::Result<()> {
    let len = u16::try_from(msg.len()).map_err(|_| {
        io::Error::new(
            io::ErrorKind::InvalidData,
            "DNS message exceeds 65535-byte TCP frame limit",
        )
    })?;
    let mut frame = Vec::with_capacity(2 + msg.len());
    frame.extend_from_slice(&len.to_be_bytes());
    frame.extend_from_slice(msg);
    stream.write_all(&frame)?;
    stream.flush()
}
756
757fn bump_total(stats: &Mutex<DnsProxyStats>) {
758 if let Ok(mut s) = stats.lock() {
759 s.queries_total = s.queries_total.saturating_add(1);
760 }
761}
762
763fn bump_allowed(stats: &Mutex<DnsProxyStats>) {
764 if let Ok(mut s) = stats.lock() {
765 s.queries_allowed = s.queries_allowed.saturating_add(1);
766 }
767}
768
769fn bump_denied(stats: &Mutex<DnsProxyStats>) {
770 if let Ok(mut s) = stats.lock() {
771 s.queries_denied = s.queries_denied.saturating_add(1);
772 }
773}
774
775fn bump_malformed(stats: &Mutex<DnsProxyStats>) {
776 if let Ok(mut s) = stats.lock() {
777 s.queries_malformed = s.queries_malformed.saturating_add(1);
778 }
779}
780
781fn bump_upstream_failure(stats: &Mutex<DnsProxyStats>) {
782 if let Ok(mut s) = stats.lock() {
783 s.upstream_failures = s.upstream_failures.saturating_add(1);
784 }
785}
786
787fn decide_dnssec_action(require_mode: bool, outcome: &DataplaneDnssecOutcome) -> DnssecAction {
803 match (outcome, require_mode) {
804 (DataplaneDnssecOutcome::Validated, _) => DnssecAction::Forward,
805 (DataplaneDnssecOutcome::Failed { reason }, _) => DnssecAction::Servfail { reason },
806 (DataplaneDnssecOutcome::Unsigned, true) => DnssecAction::Servfail {
807 reason: "unsigned_in_require_mode",
808 },
809 (DataplaneDnssecOutcome::Unsigned, false) => DnssecAction::ForwardUnsignedBestEffort,
810 (DataplaneDnssecOutcome::Skip, true) => DnssecAction::Servfail {
811 reason: "unsupported_query_type_in_require_mode",
812 },
813 (DataplaneDnssecOutcome::Skip, false) => DnssecAction::Forward,
814 }
815}
816
/// What the proxy does with an upstream response after DNSSEC evaluation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum DnssecAction {
    /// Relay the upstream response to the client.
    Forward,
    /// Reply SERVFAIL instead of the upstream response; `reason` feeds the
    /// DNSSEC failure event.
    Servfail { reason: &'static str },
    /// Relay an unsigned response because the validator is not in require mode.
    ForwardUnsignedBestEffort,
}
830
831fn build_dataplane_dnssec_failed_event(
839 cfg: &DnsProxyConfig,
840 view: &DnsQueryView,
841 validator: &DataplaneDnssecValidator,
842 reason: &'static str,
843) -> CloudEventV1 {
844 let observed_at = chrono::Utc::now().to_rfc3339_opts(chrono::SecondsFormat::Millis, true);
845 let reason_str = match reason {
853 "unsigned_in_require_mode" => "unsigned_in_require_mode",
854 "unsupported_query_type_in_require_mode" => "unsupported_query_type_in_require_mode",
855 _ => DnsAuthorityDnssecFailureReason::ValidationFailed.as_str(),
856 };
857 let payload = DnsAuthorityDnssecFailed {
858 schema_version: "1.0.0".into(),
859 cell_id: cfg.cell_id.clone(),
860 run_id: cfg.run_id.clone(),
861 resolver_id: cfg.upstream_resolver_id.clone(),
862 hostname: view.qname.clone(),
863 reason: reason_str.into(),
864 fail_closed: validator.is_require_mode(),
870 trust_anchor_source: validator.trust_anchor_source().to_string(),
871 policy_digest: cfg.policy_digest.clone(),
872 keyset_id: cfg.keyset_id.clone(),
873 issuer_kid: cfg.issuer_kid.clone(),
874 correlation_id: cfg.correlation_id.clone(),
875 source: Some("dataplane".into()),
879 observed_at: observed_at.clone(),
880 };
881 cloud_event_v1_dns_authority_dnssec_failed("cellos-dns-proxy", &observed_at, &payload)
882 .expect("DnsAuthorityDnssecFailed serializes to JSON")
883}
884
/// Per-query values that `build_event` combines with the proxy configuration
/// to form a `DnsQueryEvent`.
struct EventInputs<'a> {
    /// Parsed query, or `None` when the message could not be parsed.
    view: Option<&'a DnsQueryView>,
    /// Allow/deny verdict recorded in the event.
    decision: DnsQueryDecision,
    /// Reason recorded alongside the verdict.
    reason_code: DnsQueryReasonCode,
    /// RCODE sent to the client; `None` when no response was sent.
    response_rcode: Option<u8>,
    /// Resolver identifier; set only when upstream was contacted.
    upstream_resolver_id: Option<String>,
    /// Time spent in the upstream call, in milliseconds.
    upstream_latency_ms: Option<u64>,
    /// Value from `parse_response_target_count` on the allow path; `Some(0)`
    /// on deny paths that sent a response.
    response_target_count: Option<u32>,
}
894
895fn build_event(cfg: &DnsProxyConfig, inputs: EventInputs<'_>) -> CloudEventV1 {
896 let observed_at = chrono::Utc::now().to_rfc3339_opts(chrono::SecondsFormat::Millis, true);
897 let (qname, qtype) = match inputs.view {
898 Some(v) => (
899 v.qname.clone(),
900 qtype_to_dns_query_type(v.qtype).unwrap_or(DnsQueryType::A),
901 ),
902 None => (String::new(), DnsQueryType::A),
903 };
904 let payload = DnsQueryEvent {
905 schema_version: "1.0.0".into(),
906 cell_id: cfg.cell_id.clone(),
907 run_id: cfg.run_id.clone(),
908 query_id: uuid::Uuid::new_v4().to_string(),
909 query_name: if qname.is_empty() {
913 "(unknown)".into()
914 } else {
915 qname
916 },
917 query_type: qtype,
918 decision: inputs.decision,
919 reason_code: inputs.reason_code,
920 response_rcode: inputs.response_rcode,
921 upstream_resolver_id: inputs.upstream_resolver_id,
922 upstream_latency_ms: inputs.upstream_latency_ms,
923 response_target_count: inputs.response_target_count,
924 keyset_id: cfg.keyset_id.clone(),
925 issuer_kid: cfg.issuer_kid.clone(),
926 policy_digest: cfg.policy_digest.clone(),
927 correlation_id: cfg.correlation_id.clone(),
928 observed_at: observed_at.clone(),
929 };
930 cellos_core::cloud_event_v1_dns_query("cellos-dns-proxy", &observed_at, &payload)
931 .expect("DnsQueryEvent serializes to JSON")
932}
933
/// Hands `event` to `emitter`; the single call site through which this file
/// emits events.
fn emit_event(emitter: &dyn DnsQueryEmitter, event: CloudEventV1) {
    emitter.emit(event);
}
937
/// Returns the upper-case record-type label for `t`, as used in the
/// permitted/refused events.
fn dns_query_type_str(t: DnsQueryType) -> &'static str {
    // Exhaustive on purpose: a new `DnsQueryType` variant must be given a label here.
    match t {
        DnsQueryType::A => "A",
        DnsQueryType::AAAA => "AAAA",
        DnsQueryType::CNAME => "CNAME",
        DnsQueryType::TXT => "TXT",
        DnsQueryType::MX => "MX",
        DnsQueryType::SRV => "SRV",
        DnsQueryType::NS => "NS",
        DnsQueryType::PTR => "PTR",
        DnsQueryType::HTTPS => "HTTPS",
        DnsQueryType::SVCB => "SVCB",
    }
}
954
955fn emit_query_permitted(cfg: &DnsProxyConfig, emitter: &dyn DnsQueryEmitter, view: &DnsQueryView) {
961 let qtype = qtype_to_dns_query_type(view.qtype).unwrap_or(DnsQueryType::A);
962 let observed_at = chrono::Utc::now().to_rfc3339_opts(chrono::SecondsFormat::Millis, true);
963 let event = cloud_event_v1_dns_query_permitted(
964 "cellos-dns-proxy",
965 &observed_at,
966 &view.qname,
967 dns_query_type_str(qtype),
968 &cfg.cell_id,
969 &cfg.upstream_resolver_id,
970 );
971 emit_event(emitter, event);
972}
973
974fn emit_query_refused(
980 cfg: &DnsProxyConfig,
981 emitter: &dyn DnsQueryEmitter,
982 view: &DnsQueryView,
983 reason: &str,
984) {
985 let qtype = qtype_to_dns_query_type(view.qtype).unwrap_or(DnsQueryType::A);
986 let observed_at = chrono::Utc::now().to_rfc3339_opts(chrono::SecondsFormat::Millis, true);
987 let event = cloud_event_v1_dns_query_refused(
988 "cellos-dns-proxy",
989 &observed_at,
990 &view.qname,
991 dns_query_type_str(qtype),
992 &cfg.cell_id,
993 reason,
994 );
995 emit_event(emitter, event);
996}
997
/// Maps a parse error to an event reason code; every parse error currently
/// maps to `MalformedQuery`.
fn malformed_reason(_e: DnsParseError) -> DnsQueryReasonCode {
    DnsQueryReasonCode::MalformedQuery
}
1004
/// Returns whether `qname` is permitted by `allowlist`, delegating to
/// `cellos_core::hostname_allowlist::matches_allowlist`.
fn hostname_in_allowlist(qname: &str, allowlist: &[String]) -> bool {
    cellos_core::hostname_allowlist::matches_allowlist(qname, allowlist)
}
1015
/// Returns `true` when `e` is a `WouldBlock` or `TimedOut` error, the two
/// kinds the UDP receive loop treats as a timeout.
fn is_timeout(e: &io::Error) -> bool {
    let kind = e.kind();
    kind == io::ErrorKind::WouldBlock || kind == io::ErrorKind::TimedOut
}
1022
/// Builds a response to `query` with RCODE 5 (REFUSED).
fn build_refused_response(query: &[u8], view: &DnsQueryView) -> Vec<u8> {
    build_error_response(query, view, 5)
}
1030
/// Builds a response to `query` with RCODE 2 (SERVFAIL).
fn build_servfail_response(query: &[u8], view: &DnsQueryView) -> Vec<u8> {
    build_error_response(query, view, 2)
}
1034
1035fn build_error_response(query: &[u8], view: &DnsQueryView, rcode: u8) -> Vec<u8> {
1036 let mut question_end = DNS_HEADER_LEN;
1040 let mut idx = DNS_HEADER_LEN;
1041 while idx < query.len() {
1042 let b = query[idx];
1043 if b == 0 {
1044 idx += 1;
1045 break;
1046 }
1047 idx += 1 + b as usize;
1049 }
1050 if idx + 4 <= query.len() {
1051 question_end = idx + 4;
1052 }
1053 let mut resp = Vec::with_capacity(question_end);
1054 resp.extend_from_slice(&view.txn_id.to_be_bytes());
1055 let mut flags = view.flags;
1057 flags |= 0x8000; flags &= !0x0080; flags = (flags & 0xfff0) | u16::from(rcode & 0x0f);
1060 resp.extend_from_slice(&flags.to_be_bytes());
1061 resp.extend_from_slice(&[0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]);
1063 resp.extend_from_slice(&query[DNS_HEADER_LEN..question_end]);
1065 resp
1066}
1067
/// Extracts the RCODE (low four bits of the fourth header byte) from a DNS
/// response, or returns 2 when `resp` is shorter than four bytes.
fn parse_response_rcode(resp: &[u8]) -> u8 {
    match resp.get(3) {
        Some(flags_low) => flags_low & 0x0f,
        None => 2,
    }
}
1075
1076fn parse_response_target_count(resp: &[u8], qtype: u16) -> u32 {
1087 if !matches!(qtype, 1 | 28) {
1088 return 0;
1089 }
1090 if resp.len() < DNS_HEADER_LEN {
1091 return 0;
1092 }
1093 let qdcount = u16::from_be_bytes([resp[4], resp[5]]) as usize;
1094 let ancount = u16::from_be_bytes([resp[6], resp[7]]) as usize;
1095 let mut idx = DNS_HEADER_LEN;
1096 for _ in 0..qdcount {
1098 idx = match skip_name(resp, idx) {
1099 Some(n) => n,
1100 None => return 0,
1101 };
1102 idx += 4; if idx > resp.len() {
1104 return 0;
1105 }
1106 }
1107 let mut count: u32 = 0;
1108 for _ in 0..ancount {
1109 idx = match skip_name(resp, idx) {
1110 Some(n) => n,
1111 None => return count,
1112 };
1113 if idx + 10 > resp.len() {
1114 return count;
1115 }
1116 let rtype = u16::from_be_bytes([resp[idx], resp[idx + 1]]);
1117 let rdlen = u16::from_be_bytes([resp[idx + 8], resp[idx + 9]]) as usize;
1118 idx += 10;
1119 if rtype == qtype {
1120 count = count.saturating_add(1);
1121 }
1122 idx += rdlen;
1123 if idx > resp.len() {
1124 return count;
1125 }
1126 }
1127 count
1128}
1129
/// Returns the index just past the domain name that starts at `idx` in `buf`.
///
/// A name ends at a zero-length label (one byte) or at a compression pointer
/// (two bytes, top two bits `11`); pointers are not followed. Returns `None`
/// when the name runs past the end of `buf`, or when a length byte has the
/// reserved top-bit patterns `01` or `10`, which are not valid label lengths
/// and would otherwise be walked as 64–191-byte labels.
fn skip_name(buf: &[u8], mut idx: usize) -> Option<usize> {
    loop {
        let len = *buf.get(idx)?;
        match len & 0xc0 {
            0x00 => {
                if len == 0 {
                    return Some(idx + 1);
                }
                idx += 1 + usize::from(len);
            }
            // Compression pointer: needs its second byte to be present.
            0xc0 => return (idx + 1 < buf.len()).then_some(idx + 2),
            // Reserved label types.
            _ => return None,
        }
    }
}
1151
1152#[cfg(test)]
1153mod tests {
1154 use super::*;
1155 use std::net::UdpSocket;
1156 use std::sync::Mutex;
1157 use std::time::Duration;
1158
    /// Test emitter that stores every emitted event in memory, in emission order.
    #[derive(Default)]
    struct MemEmitter {
        events: Mutex<Vec<CloudEventV1>>,
    }
    impl DnsQueryEmitter for MemEmitter {
        // Appends the event; panics if the mutex is poisoned.
        fn emit(&self, event: CloudEventV1) {
            self.events.lock().unwrap().push(event);
        }
    }
1169
1170 fn build_query_packet(qname: &str, qtype: u16) -> Vec<u8> {
1171 let mut p = Vec::new();
1172 p.extend_from_slice(&[
1173 0xab, 0xcd, 0x01, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
1174 ]);
1175 for label in qname.split('.') {
1176 p.push(label.len() as u8);
1177 p.extend_from_slice(label.as_bytes());
1178 }
1179 p.push(0);
1180 p.extend_from_slice(&qtype.to_be_bytes());
1181 p.extend_from_slice(&[0x00, 0x01]);
1182 p
1183 }
1184
1185 fn build_a_response(query: &[u8], ancount: u16) -> Vec<u8> {
1187 let mut resp = query.to_vec();
1189 resp[2] = 0x81;
1190 resp[3] = 0x80;
1191 resp[6] = (ancount >> 8) as u8;
1192 resp[7] = (ancount & 0xff) as u8;
1193 for _ in 0..ancount {
1194 resp.extend_from_slice(&[0xc0, 0x0c]); resp.extend_from_slice(&[0x00, 0x01]); resp.extend_from_slice(&[0x00, 0x01]); resp.extend_from_slice(&[0x00, 0x00, 0x01, 0x2c]); resp.extend_from_slice(&[0x00, 0x04]); resp.extend_from_slice(&[203, 0, 113, 1]);
1200 }
1201 resp
1202 }
1203
1204 fn spawn_upstream(swallow: bool, ancount: u16) -> (SocketAddr, std::thread::JoinHandle<()>) {
1207 let sock = UdpSocket::bind("127.0.0.1:0").unwrap();
1208 let addr = sock.local_addr().unwrap();
1209 sock.set_read_timeout(Some(Duration::from_millis(2000)))
1210 .unwrap();
1211 let h = std::thread::spawn(move || {
1212 let mut buf = [0u8; 1500];
1213 while let Ok((n, peer)) = sock.recv_from(&mut buf) {
1214 if swallow {
1215 continue;
1217 }
1218 let resp = build_a_response(&buf[..n], ancount);
1219 let _ = sock.send_to(&resp, peer);
1220 }
1221 });
1222 (addr, h)
1223 }
1224
    /// Builds a test configuration with the given allowlist and upstream address.
    ///
    /// Query types are left empty (so `DEFAULT_QUERY_TYPES` applies), DNSSEC
    /// validation is off, the upstream timeout is 300 ms and the transport is
    /// plain UDP.
    fn proxy_cfg(allowlist: Vec<&str>, upstream: SocketAddr) -> DnsProxyConfig {
        DnsProxyConfig {
            bind_addr: "127.0.0.1:0".parse().unwrap(),
            upstream_addr: upstream,
            hostname_allowlist: allowlist.into_iter().map(String::from).collect(),
            allowed_query_types: vec![],
            cell_id: "test-cell".into(),
            run_id: "test-run".into(),
            policy_digest: None,
            keyset_id: Some("test-keyset".into()),
            issuer_kid: Some("test-kid-001".into()),
            correlation_id: None,
            upstream_resolver_id: "resolver-test-001".into(),
            upstream_timeout: Duration::from_millis(300),
            // Zero selects DEFAULT_TCP_IDLE_TIMEOUT in the TCP loop.
            tcp_idle_timeout: Duration::ZERO,
            dnssec_validator: None,
            transport: UpstreamTransport::Do53Udp,
            upstream_extras: UpstreamExtras::default(),
        }
    }
1258
    // An allowlisted A query is forwarded upstream, answered with NOERROR, and
    // produces a `query_permitted` event followed by an allow query event.
    #[test]
    fn proxy_allows_query_in_allowlist() {
        let (upstream_addr, _h) = spawn_upstream(false, 2);
        let listener = UdpSocket::bind("127.0.0.1:0").unwrap();
        // The read timeout lets the proxy loop notice the shutdown flag.
        listener
            .set_read_timeout(Some(Duration::from_millis(150)))
            .unwrap();
        let listen_addr = listener.local_addr().unwrap();
        let upstream_sock = UdpSocket::bind("127.0.0.1:0").unwrap();
        let cfg = proxy_cfg(vec!["api.example.com"], upstream_addr);
        let emitter = std::sync::Arc::new(MemEmitter::default());
        let shutdown = std::sync::Arc::new(AtomicBool::new(false));

        let proxy_handle = {
            let emitter = emitter.clone();
            let shutdown = shutdown.clone();
            let cfg = cfg.clone();
            std::thread::spawn(move || {
                let _ = run_one_shot(&cfg, &listener, &upstream_sock, &*emitter, &shutdown);
            })
        };

        let client = UdpSocket::bind("127.0.0.1:0").unwrap();
        client
            .set_read_timeout(Some(Duration::from_secs(1)))
            .unwrap();
        let q = build_query_packet("api.example.com", 1);
        client.send_to(&q, listen_addr).unwrap();
        let mut rb = [0u8; 1500];
        let (n, _) = client.recv_from(&mut rb).unwrap();
        assert!(n > DNS_HEADER_LEN);
        let rcode = rb[3] & 0x0f;
        assert_eq!(rcode, 0, "expected NOERROR on allow path");

        shutdown.store(true, Ordering::SeqCst);
        proxy_handle.join().unwrap();
        let evs = emitter.events.lock().unwrap();
        assert_eq!(evs.len(), 2);
        // First event: the short-form permit emitted before contacting upstream.
        assert_eq!(evs[0].ty, "dev.cellos.events.cell.dns.v1.query_permitted");
        let permitted_data = evs[0].data.as_ref().unwrap();
        assert_eq!(permitted_data["queryName"], "api.example.com");
        assert_eq!(permitted_data["queryType"], "A");
        assert_eq!(permitted_data["resolver"], "resolver-test-001");
        // Second event: the query event describing the forwarded response.
        let data = evs[1].data.as_ref().unwrap();
        assert_eq!(data["decision"], "allow");
        assert_eq!(data["reasonCode"], "allowed_by_allowlist");
        assert_eq!(data["responseRcode"], 0);
        assert_eq!(data["upstreamResolverId"], "resolver-test-001");
        assert_eq!(data["responseTargetCount"], 2);
    }
1312
    // A query for a name outside the allowlist gets REFUSED and produces a
    // deny query event followed by a `query_refused` event.
    #[test]
    fn proxy_denies_query_not_in_allowlist() {
        let (upstream_addr, _h) = spawn_upstream(false, 0);
        let listener = UdpSocket::bind("127.0.0.1:0").unwrap();
        // The read timeout lets the proxy loop notice the shutdown flag.
        listener
            .set_read_timeout(Some(Duration::from_millis(150)))
            .unwrap();
        let listen_addr = listener.local_addr().unwrap();
        let upstream_sock = UdpSocket::bind("127.0.0.1:0").unwrap();
        let cfg = proxy_cfg(vec!["api.example.com"], upstream_addr);
        let emitter = std::sync::Arc::new(MemEmitter::default());
        let shutdown = std::sync::Arc::new(AtomicBool::new(false));

        let proxy_handle = {
            let emitter = emitter.clone();
            let shutdown = shutdown.clone();
            let cfg = cfg.clone();
            std::thread::spawn(move || {
                let _ = run_one_shot(&cfg, &listener, &upstream_sock, &*emitter, &shutdown);
            })
        };

        let client = UdpSocket::bind("127.0.0.1:0").unwrap();
        client
            .set_read_timeout(Some(Duration::from_secs(1)))
            .unwrap();
        let q = build_query_packet("blocked.example.com", 1);
        client.send_to(&q, listen_addr).unwrap();
        let mut rb = [0u8; 1500];
        let (n, _) = client.recv_from(&mut rb).unwrap();
        let rcode = rb[3] & 0x0f;
        assert_eq!(
            rcode, 5,
            "expected REFUSED on deny path, got rcode={rcode} n={n}"
        );

        shutdown.store(true, Ordering::SeqCst);
        proxy_handle.join().unwrap();
        let evs = emitter.events.lock().unwrap();
        assert_eq!(evs.len(), 2, "aggregate + short-form refusal event");
        // First event: the query event with the deny decision.
        let data = evs[0].data.as_ref().unwrap();
        assert_eq!(data["decision"], "deny");
        assert_eq!(data["reasonCode"], "denied_not_in_allowlist");
        assert_eq!(data["responseRcode"], 5);
        // Second event: the short-form refusal.
        assert_eq!(evs[1].ty, "dev.cellos.events.cell.dns.v1.query_refused");
        let refused = evs[1].data.as_ref().unwrap();
        assert_eq!(refused["reason"], "denied_not_in_allowlist");
        assert_eq!(refused["queryName"], "blocked.example.com");
    }
1365
    // Pins the allowlist matching rules this proxy relies on: `*.` patterns
    // match subdomains at any depth but not the apex or look-alike names, and
    // exact entries match only themselves.
    #[test]
    fn proxy_wildcard_matches_subdomain_only() {
        // Wildcard: direct and nested subdomains match.
        assert!(hostname_in_allowlist(
            "foo.cdn.example.com",
            &["*.cdn.example.com".into()]
        ));
        assert!(hostname_in_allowlist(
            "deep.foo.cdn.example.com",
            &["*.cdn.example.com".into()]
        ));
        // Wildcard: the apex and a name merely ending in the same text do not.
        assert!(!hostname_in_allowlist(
            "cdn.example.com",
            &["*.cdn.example.com".into()]
        ));
        assert!(!hostname_in_allowlist(
            "evil-cdn.example.com",
            &["*.cdn.example.com".into()]
        ));
        // Exact entry: matches itself but not its subdomains.
        assert!(hostname_in_allowlist(
            "api.example.com",
            &["api.example.com".into()]
        ));
        assert!(!hostname_in_allowlist(
            "x.api.example.com",
            &["api.example.com".into()]
        ));
    }
1395
/// An allowlisted hostname is still refused when the query's record type is
/// not listed in `allowed_query_types`.
///
/// Expects a REFUSED reply on the wire and exactly two emitted events: the
/// aggregate event followed by the `query_refused` event, both carrying the
/// `denied_query_type` reason.
#[test]
fn proxy_denies_disallowed_query_type() {
    // RCODE lives in the low nibble of header byte 3; 5 is REFUSED.
    const RCODE_REFUSED: u8 = 5;

    let (upstream_addr, _upstream) = spawn_upstream(false, 0);

    let proxy_sock = UdpSocket::bind("127.0.0.1:0").unwrap();
    // NOTE(review): the short read timeout presumably lets the serve loop
    // notice `shutdown` promptly — confirm against run_one_shot.
    proxy_sock
        .set_read_timeout(Some(Duration::from_millis(150)))
        .unwrap();
    let proxy_addr = proxy_sock.local_addr().unwrap();
    let upstream_sock = UdpSocket::bind("127.0.0.1:0").unwrap();

    // The name is allowlisted; only A and AAAA record types are permitted.
    let mut cfg = proxy_cfg(vec!["api.example.com"], upstream_addr);
    cfg.allowed_query_types = vec![DnsQueryType::A, DnsQueryType::AAAA];

    let emitter = std::sync::Arc::new(MemEmitter::default());
    let shutdown = std::sync::Arc::new(AtomicBool::new(false));

    let emitter_for_proxy = emitter.clone();
    let shutdown_for_proxy = shutdown.clone();
    let proxy_thread = std::thread::spawn(move || {
        let _ = run_one_shot(
            &cfg,
            &proxy_sock,
            &upstream_sock,
            &*emitter_for_proxy,
            &shutdown_for_proxy,
        );
    });

    let client = UdpSocket::bind("127.0.0.1:0").unwrap();
    client
        .set_read_timeout(Some(Duration::from_secs(1)))
        .unwrap();

    // qtype 16 (TXT in RFC 1035) is outside the permitted A/AAAA set.
    let query = build_query_packet("api.example.com", 16);
    client.send_to(&query, proxy_addr).unwrap();
    let mut reply = [0u8; 1500];
    client.recv_from(&mut reply).unwrap();
    assert_eq!(reply[3] & 0x0f, RCODE_REFUSED);

    // Stop the proxy and wait for it before inspecting what it emitted.
    shutdown.store(true, Ordering::SeqCst);
    proxy_thread.join().unwrap();

    let events = emitter.events.lock().unwrap();
    assert_eq!(events.len(), 2);
    let (aggregate, refused) = (&events[0], &events[1]);
    assert_eq!(
        aggregate.data.as_ref().unwrap()["reasonCode"],
        "denied_query_type"
    );
    assert_eq!(refused.ty, "dev.cellos.events.cell.dns.v1.query_refused");
    assert_eq!(
        refused.data.as_ref().unwrap()["reason"],
        "denied_query_type"
    );
}
1440
/// Sends three queries (exact allowlist hit, wildcard hit, miss) and checks
/// that each one yields a pair of events in the order the test expects.
///
/// Permitted queries produce a `query_permitted` event followed by an
/// aggregate `allow` event; the denied query produces an aggregate `deny`
/// event followed by a `query_refused` event. Six events in total.
#[test]
fn proxy_emits_event_per_query() {
    let (upstream_addr, _upstream) = spawn_upstream(false, 1);

    let proxy_sock = UdpSocket::bind("127.0.0.1:0").unwrap();
    // NOTE(review): the short read timeout presumably lets the serve loop
    // notice `shutdown` promptly — confirm against run_one_shot.
    proxy_sock
        .set_read_timeout(Some(Duration::from_millis(150)))
        .unwrap();
    let proxy_addr = proxy_sock.local_addr().unwrap();
    let upstream_sock = UdpSocket::bind("127.0.0.1:0").unwrap();

    let cfg = proxy_cfg(vec!["api.example.com", "*.cdn.example.com"], upstream_addr);
    let emitter = std::sync::Arc::new(MemEmitter::default());
    let shutdown = std::sync::Arc::new(AtomicBool::new(false));

    let emitter_for_proxy = emitter.clone();
    let shutdown_for_proxy = shutdown.clone();
    let proxy_thread = std::thread::spawn(move || {
        let _ = run_one_shot(
            &cfg,
            &proxy_sock,
            &upstream_sock,
            &*emitter_for_proxy,
            &shutdown_for_proxy,
        );
    });

    let client = UdpSocket::bind("127.0.0.1:0").unwrap();
    client
        .set_read_timeout(Some(Duration::from_secs(1)))
        .unwrap();

    // Exact match, wildcard match, then a name that is not allowlisted.
    let queried_names = [
        "api.example.com",
        "img.cdn.example.com",
        "blocked.example.com",
    ];
    // Replies are awaited but not inspected, so one buffer can be reused.
    let mut reply = [0u8; 1500];
    for name in queried_names {
        let query = build_query_packet(name, 1);
        client.send_to(&query, proxy_addr).unwrap();
        client.recv_from(&mut reply).unwrap();
    }

    // Stop the proxy and wait for it before inspecting what it emitted.
    shutdown.store(true, Ordering::SeqCst);
    proxy_thread.join().unwrap();

    let events = emitter.events.lock().unwrap();
    assert_eq!(events.len(), 6);

    // Query 1: api.example.com, permitted.
    let api_permitted = &events[0];
    assert_eq!(
        api_permitted.ty,
        "dev.cellos.events.cell.dns.v1.query_permitted"
    );
    assert_eq!(
        api_permitted.data.as_ref().unwrap()["queryName"],
        "api.example.com"
    );
    let api_aggregate = events[1].data.as_ref().unwrap();
    assert_eq!(api_aggregate["decision"], "allow");
    assert_eq!(api_aggregate["queryName"], "api.example.com");
    assert_eq!(api_aggregate["upstreamResolverId"], "resolver-test-001");

    // Query 2: img.cdn.example.com, permitted via the wildcard entry.
    assert_eq!(
        events[2].ty,
        "dev.cellos.events.cell.dns.v1.query_permitted"
    );
    let cdn_aggregate = events[3].data.as_ref().unwrap();
    assert_eq!(cdn_aggregate["decision"], "allow");
    assert_eq!(cdn_aggregate["queryName"], "img.cdn.example.com");

    // Query 3: blocked.example.com, denied; the aggregate event comes first.
    let blocked_aggregate = events[4].data.as_ref().unwrap();
    assert_eq!(blocked_aggregate["decision"], "deny");
    assert_eq!(blocked_aggregate["queryName"], "blocked.example.com");
    let blocked_refused = &events[5];
    assert_eq!(
        blocked_refused.ty,
        "dev.cellos.events.cell.dns.v1.query_refused"
    );
    assert_eq!(
        blocked_refused.data.as_ref().unwrap()["reason"],
        "denied_not_in_allowlist"
    );
}
1510
/// When the upstream does not answer within `upstream_timeout`, the client
/// must receive SERVFAIL and the aggregate event must record
/// `upstream_failure` with response code 2.
///
/// Two events are expected: `query_permitted` first (the name is
/// allowlisted), then the aggregate event describing the failure.
#[test]
fn proxy_returns_servfail_on_upstream_timeout() {
    // RCODE lives in the low nibble of header byte 3; 2 is SERVFAIL.
    const RCODE_SERVFAIL: u8 = 2;

    // NOTE(review): `true` presumably makes the fake upstream swallow queries
    // without replying — confirm against spawn_upstream.
    let (upstream_addr, _upstream) = spawn_upstream(true, 0);

    let proxy_sock = UdpSocket::bind("127.0.0.1:0").unwrap();
    proxy_sock
        .set_read_timeout(Some(Duration::from_millis(150)))
        .unwrap();
    let proxy_addr = proxy_sock.local_addr().unwrap();
    let upstream_sock = UdpSocket::bind("127.0.0.1:0").unwrap();

    let mut cfg = proxy_cfg(vec!["api.example.com"], upstream_addr);
    cfg.upstream_timeout = Duration::from_millis(120);

    let emitter = std::sync::Arc::new(MemEmitter::default());
    let shutdown = std::sync::Arc::new(AtomicBool::new(false));

    let emitter_for_proxy = emitter.clone();
    let shutdown_for_proxy = shutdown.clone();
    let proxy_thread = std::thread::spawn(move || {
        let _ = run_one_shot(
            &cfg,
            &proxy_sock,
            &upstream_sock,
            &*emitter_for_proxy,
            &shutdown_for_proxy,
        );
    });

    let client = UdpSocket::bind("127.0.0.1:0").unwrap();
    // The client waits up to 2s, well beyond the 120ms upstream timeout.
    client
        .set_read_timeout(Some(Duration::from_secs(2)))
        .unwrap();

    let query = build_query_packet("api.example.com", 1);
    client.send_to(&query, proxy_addr).unwrap();
    let mut reply = [0u8; 1500];
    client.recv_from(&mut reply).unwrap();
    assert_eq!(
        reply[3] & 0x0f,
        RCODE_SERVFAIL,
        "expected SERVFAIL on upstream timeout"
    );

    // Stop the proxy and wait for it before inspecting what it emitted.
    shutdown.store(true, Ordering::SeqCst);
    proxy_thread.join().unwrap();

    let events = emitter.events.lock().unwrap();
    assert_eq!(events.len(), 2);
    assert_eq!(
        events[0].ty,
        "dev.cellos.events.cell.dns.v1.query_permitted"
    );
    let aggregate = events[1].data.as_ref().unwrap();
    assert_eq!(aggregate["reasonCode"], "upstream_failure");
    assert_eq!(aggregate["responseRcode"], 2);
}
1558
/// `parse_response_target_count` reports 3 for qtype 1 on a response built
/// with `build_a_response(&query, 3)`, and 0 when asked about qtype 16 on
/// the same bytes.
#[test]
fn parse_response_target_count_counts_a_records() {
    let query = build_query_packet("api.example.com", 1);
    let response = build_a_response(&query, 3);

    // Same response bytes, two different record types of interest.
    let matching_qtype_count = parse_response_target_count(&response, 1);
    let other_qtype_count = parse_response_target_count(&response, 16);

    assert_eq!(matching_qtype_count, 3);
    assert_eq!(other_qtype_count, 0);
}
1567}