1use hickory_client::client::{Client, SyncClient};
4use hickory_client::udp::UdpClientConnection;
5use hickory_server::authority::{Catalog, ZoneType};
6use hickory_server::proto::rr::rdata::{A, AAAA};
7use hickory_server::proto::rr::{DNSClass, LowerName, Name, RData, Record, RecordType};
8use hickory_server::resolver::config::NameServerConfigGroup;
9use hickory_server::server::ServerFuture;
10use hickory_server::store::in_memory::InMemoryAuthority;
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
14use std::str::FromStr;
15use std::sync::Arc;
16use std::time::Duration;
17use tokio::net::{TcpListener, UdpSocket};
18use tokio::sync::RwLock;
19
/// Port the overlay DNS server listens on unless a config overrides it.
pub const DEFAULT_DNS_PORT: u16 = 15_353;

/// Well-known DNS port assumed for every upstream read from resolv.conf and
/// for the public fallback resolvers.
const STANDARD_DNS_PORT: u16 = 53;

/// Public resolvers used for forwarding when no usable host upstream is found.
const PUBLIC_FALLBACK_UPSTREAMS: [IpAddr; 2] = [
    IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1)),
    IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)),
];

/// Host resolver configuration consulted when detecting upstreams.
pub(crate) const RESOLV_CONF_PATH: &str = "/etc/resolv.conf";
42
43#[derive(Debug, Clone, Serialize, Deserialize)]
45pub struct DnsConfig {
46 pub zone: String,
48 pub port: u16,
50 pub bind_addr: IpAddr,
52 #[serde(default)]
69 pub upstreams: Option<Vec<SocketAddr>>,
70}
71
72impl DnsConfig {
73 #[must_use]
75 pub fn new(zone: &str, bind_addr: IpAddr) -> Self {
76 Self {
77 zone: zone.to_string(),
78 port: DEFAULT_DNS_PORT,
79 bind_addr,
80 upstreams: None,
81 }
82 }
83
84 #[must_use]
86 pub fn with_port(mut self, port: u16) -> Self {
87 self.port = port;
88 self
89 }
90
91 #[must_use]
97 pub fn with_upstreams(mut self, upstreams: Vec<SocketAddr>) -> Self {
98 self.upstreams = Some(upstreams);
99 self
100 }
101}
102
/// Returns `true` if `ip` must not be used as an upstream for the overlay forwarder.
///
/// Loopback addresses (for example a local stub resolver such as `127.0.0.53`)
/// and the unspecified address refer to this host rather than to a reachable
/// resolver, so they are filtered out of host-detected upstreams.
///
/// IPv4-mapped IPv6 addresses (`::ffff:a.b.c.d`) are judged by their embedded
/// IPv4 address. `Ipv6Addr::is_loopback` only recognises `::1`, so without this
/// step `::ffff:127.0.0.1` would slip through as "usable".
fn is_unusable_upstream(ip: IpAddr) -> bool {
    let v4_unusable = |v4: Ipv4Addr| v4.is_loopback() || v4.is_unspecified();
    match ip {
        IpAddr::V4(v4) => v4_unusable(v4),
        IpAddr::V6(v6) => match v6.to_ipv4_mapped() {
            Some(v4) => v4_unusable(v4),
            None => v6.is_loopback() || v6.is_unspecified(),
        },
    }
}
121
122fn parse_resolv_conf(contents: &str) -> Vec<SocketAddr> {
133 let mut out: Vec<SocketAddr> = Vec::new();
134 for line in contents.lines() {
135 let line = line.trim();
136 if line.is_empty() || line.starts_with('#') || line.starts_with(';') {
137 continue;
138 }
139 let mut parts = line.split_whitespace();
140 if parts.next() != Some("nameserver") {
141 continue;
142 }
143 let Some(addr_str) = parts.next() else {
144 continue;
145 };
146 let addr_str = addr_str.split('%').next().unwrap_or(addr_str);
149 let Ok(ip) = IpAddr::from_str(addr_str) else {
150 continue;
151 };
152 if is_unusable_upstream(ip) {
153 continue;
154 }
155 let sock = SocketAddr::new(ip, STANDARD_DNS_PORT);
156 if !out.contains(&sock) {
157 out.push(sock);
158 }
159 }
160 out
161}
162
163pub(crate) fn resolve_upstreams(config: &DnsConfig, resolv_conf_path: &str) -> Vec<SocketAddr> {
185 if let Some(explicit) = &config.upstreams {
186 if !explicit.is_empty() {
187 tracing::debug!(
188 count = explicit.len(),
189 "using explicit overlay DNS upstreams from config (host detection skipped)",
190 );
191 return explicit.clone();
192 }
193 }
194
195 let detected = match std::fs::read_to_string(resolv_conf_path) {
196 Ok(contents) => parse_resolv_conf(&contents),
197 Err(e) => {
198 tracing::warn!(
199 path = resolv_conf_path,
200 error = %e,
201 "could not read host resolv.conf for overlay DNS upstream detection",
202 );
203 Vec::new()
204 }
205 };
206
207 if detected.is_empty() {
208 let fallback: Vec<SocketAddr> = PUBLIC_FALLBACK_UPSTREAMS
209 .iter()
210 .map(|ip| SocketAddr::new(*ip, STANDARD_DNS_PORT))
211 .collect();
212 tracing::warn!(
213 fallback = ?fallback,
214 "no usable host DNS upstreams found (resolv.conf empty, missing, or stub-only); \
215 falling back to public resolvers for overlay forwarding",
216 );
217 fallback
218 } else {
219 tracing::info!(
220 upstreams = ?detected,
221 "overlay DNS forwarding to host upstreams (loopback/stub filtered out)",
222 );
223 detected
224 }
225}
226
227pub(crate) fn build_forward_resolver(
249 upstreams: &[SocketAddr],
250) -> Result<hickory_server::resolver::TokioAsyncResolver, DnsError> {
251 use hickory_server::resolver::config::{ResolverConfig, ResolverOpts};
252
253 if upstreams.is_empty() {
254 return Err(DnsError::Server("no upstreams for forward resolver".into()));
255 }
256
257 let mut group = NameServerConfigGroup::new();
258 let mut by_port: std::collections::BTreeMap<u16, Vec<IpAddr>> =
259 std::collections::BTreeMap::new();
260 for addr in upstreams {
261 by_port.entry(addr.port()).or_default().push(addr.ip());
262 }
263 for (port, ips) in by_port {
264 group.merge(NameServerConfigGroup::from_ips_clear(&ips, port, true));
267 }
268
269 let mut options = ResolverOpts::default();
273 options.timeout = Duration::from_secs(2);
274 options.attempts = 2;
275 options.preserve_intermediates = true;
277
278 let config = ResolverConfig::from_parts(None, vec![], group);
279 Ok(hickory_server::resolver::TokioAsyncResolver::tokio(
280 config, options,
281 ))
282}
283
284struct ForwardingCatalog {
306 catalog: Catalog,
307 zone_origin: LowerName,
308 static_origin: LowerName,
313 resolver: Option<Arc<hickory_server::resolver::TokioAsyncResolver>>,
314}
315
316impl ForwardingCatalog {
317 fn forward_answer_response<'a>(
319 request: &'a hickory_server::server::Request,
320 answers: &'a [Record],
321 ) -> hickory_server::authority::MessageResponse<
322 'a,
323 'a,
324 std::slice::Iter<'a, Record>,
325 std::iter::Empty<&'a Record>,
326 std::iter::Empty<&'a Record>,
327 std::iter::Empty<&'a Record>,
328 > {
329 use hickory_server::authority::MessageResponseBuilder;
330 use hickory_server::proto::op::ResponseCode;
331
332 let mut header = hickory_server::proto::op::Header::response_from_request(request.header());
333 header.set_recursion_available(true);
334 header.set_response_code(ResponseCode::NoError);
335 header.set_authoritative(false);
337
338 MessageResponseBuilder::from_message_request(request).build(
339 header,
340 answers.iter(),
341 std::iter::empty(),
342 std::iter::empty(),
343 std::iter::empty(),
344 )
345 }
346
347 fn forward_code_response(
350 request: &hickory_server::server::Request,
351 code: hickory_server::proto::op::ResponseCode,
352 ) -> hickory_server::authority::MessageResponse<
353 '_,
354 '_,
355 impl Iterator<Item = &Record> + Send,
356 impl Iterator<Item = &Record> + Send,
357 impl Iterator<Item = &Record> + Send,
358 impl Iterator<Item = &Record> + Send,
359 > {
360 use hickory_server::authority::MessageResponseBuilder;
361 MessageResponseBuilder::from_message_request(request).error_msg(request.header(), code)
362 }
363
364 async fn forward<R: hickory_server::server::ResponseHandler>(
367 &self,
368 resolver: &hickory_server::resolver::TokioAsyncResolver,
369 request: &hickory_server::server::Request,
370 mut response_handle: R,
371 ) -> hickory_server::server::ResponseInfo {
372 use hickory_server::proto::op::ResponseCode;
373 use hickory_server::resolver::error::ResolveErrorKind;
374
375 let query = request.request_info().query;
376 let name = Name::from(query.name());
377 let rtype = query.query_type();
378
379 match resolver.lookup(name, rtype).await {
380 Ok(lookup) => {
381 let records: Vec<Record> = lookup.records().to_vec();
382 let response = Self::forward_answer_response(request, &records);
383 Self::send_or_servfail(&mut response_handle, response).await
384 }
385 Err(e) => {
386 let code = match e.kind() {
387 ResolveErrorKind::NoRecordsFound { response_code, .. }
389 if *response_code == ResponseCode::NXDomain =>
390 {
391 ResponseCode::NXDomain
392 }
393 ResolveErrorKind::NoRecordsFound { response_code, .. }
395 if *response_code == ResponseCode::NoError =>
396 {
397 ResponseCode::NoError
398 }
399 _ => {
403 tracing::debug!(error = %e, "overlay DNS upstream forward failed; SERVFAIL");
404 ResponseCode::ServFail
405 }
406 };
407 let response = Self::forward_code_response(request, code);
408 Self::send_or_servfail(&mut response_handle, response).await
409 }
410 }
411 }
412
413 async fn send_or_servfail<'a, R, A, N, S, D>(
416 response_handle: &mut R,
417 response: hickory_server::authority::MessageResponse<'_, 'a, A, N, S, D>,
418 ) -> hickory_server::server::ResponseInfo
419 where
420 R: hickory_server::server::ResponseHandler,
421 A: Iterator<Item = &'a Record> + Send + 'a,
422 N: Iterator<Item = &'a Record> + Send + 'a,
423 S: Iterator<Item = &'a Record> + Send + 'a,
424 D: Iterator<Item = &'a Record> + Send + 'a,
425 {
426 match response_handle.send_response(response).await {
427 Ok(info) => info,
428 Err(e) => {
429 tracing::error!(error = %e, "failed to send overlay DNS forward response");
430 let mut header = hickory_server::proto::op::Header::new();
431 header.set_response_code(hickory_server::proto::op::ResponseCode::ServFail);
432 header.into()
433 }
434 }
435 }
436}
437
438#[async_trait::async_trait]
439impl hickory_server::server::RequestHandler for ForwardingCatalog {
440 async fn handle_request<R: hickory_server::server::ResponseHandler>(
441 &self,
442 request: &hickory_server::server::Request,
443 response_handle: R,
444 ) -> hickory_server::server::ResponseInfo {
445 let query_name = request.request_info().query.name().clone();
449 let is_overlay =
450 self.zone_origin.zone_of(&query_name) || self.static_origin.zone_of(&query_name);
451
452 match (&self.resolver, is_overlay) {
453 (Some(resolver), false) => self.forward(resolver, request, response_handle).await,
454 _ => self.catalog.handle_request(request, response_handle).await,
455 }
456 }
457}
458
/// Derives a short, deterministic hostname for a peer from its overlay address.
///
/// IPv4 addresses use the last two octets in decimal (`10.200.1.100` becomes
/// `node-1-100`). IPv6 addresses use the last 16-bit segment as four lowercase
/// hex digits (`fd00::abcd` becomes `node-abcd`).
#[must_use]
pub fn peer_hostname(ip: IpAddr) -> String {
    let suffix = match ip {
        IpAddr::V4(v4) => {
            let [_, _, third, fourth] = v4.octets();
            format!("{third}-{fourth}")
        }
        IpAddr::V6(v6) => format!("{:04x}", v6.segments()[7]),
    };
    format!("node-{suffix}")
}
477
/// Errors produced by the overlay DNS server, its record handle, and the DNS client helpers.
#[derive(Debug, thiserror::Error)]
pub enum DnsError {
    /// A hostname or zone could not be parsed, or could not be qualified with the zone origin.
    #[error("Invalid domain name: {0}")]
    InvalidName(String),

    /// Server-side failure, e.g. no upstreams for the forwarder or the server loop exiting with an error.
    #[error("DNS server error: {0}")]
    Server(String),

    /// Client-side failure while creating a connection or performing a query.
    #[error("DNS client error: {0}")]
    Client(String),

    /// Underlying I/O error, converted automatically via `?` (e.g. from socket binds).
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),

    /// A requested record was not found.
    /// NOTE(review): not constructed anywhere in the visible part of this file; confirm callers.
    #[error("Record not found: {0}")]
    NotFound(String),
}
496
497#[derive(Clone)]
501pub struct DnsHandle {
502 authority: Arc<InMemoryAuthority>,
503 zone_origin: Name,
504 serial: Arc<RwLock<u32>>,
505}
506
507impl DnsHandle {
508 pub async fn add_record(&self, hostname: &str, ip: IpAddr) -> Result<(), DnsError> {
516 let fqdn = if hostname.ends_with('.') {
518 Name::from_str(hostname)
519 .map_err(|e| DnsError::InvalidName(format!("{hostname}: {e}")))?
520 } else {
521 let name = Name::from_str(hostname)
523 .map_err(|e| DnsError::InvalidName(format!("{hostname}: {e}")))?;
524 name.append_domain(&self.zone_origin)
525 .map_err(|e| DnsError::InvalidName(format!("Failed to append zone: {e}")))?
526 };
527
528 let rdata = match ip {
530 IpAddr::V4(v4) => RData::A(A::from(v4)),
531 IpAddr::V6(v6) => RData::AAAA(AAAA::from(v6)),
532 };
533 let record = Record::from_rdata(fqdn, 300, rdata); let serial = {
537 let mut s = self.serial.write().await;
538 let current = *s;
539 *s = s.wrapping_add(1);
540 current
541 };
542
543 self.authority.upsert(record, serial).await;
545
546 Ok(())
547 }
548
549 pub async fn remove_record(&self, hostname: &str) -> Result<bool, DnsError> {
557 let fqdn = if hostname.ends_with('.') {
558 Name::from_str(hostname)
559 .map_err(|e| DnsError::InvalidName(format!("{hostname}: {e}")))?
560 } else {
561 let name = Name::from_str(hostname)
562 .map_err(|e| DnsError::InvalidName(format!("{hostname}: {e}")))?;
563 name.append_domain(&self.zone_origin)
564 .map_err(|e| DnsError::InvalidName(format!("Failed to append zone: {e}")))?
565 };
566
567 let serial = {
568 let mut s = self.serial.write().await;
569 let current = *s;
570 *s = s.wrapping_add(1);
571 current
572 };
573
574 let a_record = Record::with(fqdn.clone(), RecordType::A, 0);
578 self.authority.upsert(a_record, serial).await;
579
580 let aaaa_record = Record::with(fqdn.clone(), RecordType::AAAA, 0);
581 self.authority.upsert(aaaa_record, serial).await;
582
583 Ok(true)
584 }
585
586 #[must_use]
588 pub fn zone_origin(&self) -> &Name {
589 &self.zone_origin
590 }
591
592 pub async fn lookup_a(&self, fqdn: &str) -> Option<IpAddr> {
604 use hickory_server::authority::{Authority, LookupOptions};
605
606 let name = Name::from_str(fqdn).ok()?;
607 let lower = LowerName::from(name);
608 let lookup = self
609 .authority
610 .lookup(&lower, RecordType::A, LookupOptions::default())
611 .await
612 .ok()?;
613 lookup.iter().find_map(|record| match record.data() {
614 Some(RData::A(a)) => Some(IpAddr::V4((*a).into())),
615 _ => None,
616 })
617 }
618}
619
620pub const STATIC_ZONE_ORIGIN: &str = "zlayer.local.";
625
626pub struct DnsServer {
628 listen_addr: SocketAddr,
629 authority: Arc<InMemoryAuthority>,
630 zone_origin: Name,
631 static_authority: Arc<InMemoryAuthority>,
633 static_origin: Name,
635 serial: Arc<RwLock<u32>>,
636 upstreams: Vec<SocketAddr>,
648}
649
650impl DnsServer {
651 pub fn new(listen_addr: SocketAddr, zone: &str) -> Result<Self, DnsError> {
661 let upstreams =
662 resolve_upstreams(&DnsConfig::new(zone, listen_addr.ip()), RESOLV_CONF_PATH);
663 Self::new_with_upstreams(listen_addr, zone, upstreams)
664 }
665
666 pub fn new_with_upstreams(
676 listen_addr: SocketAddr,
677 zone: &str,
678 upstreams: Vec<SocketAddr>,
679 ) -> Result<Self, DnsError> {
680 let zone_origin =
681 Name::from_str(zone).map_err(|e| DnsError::InvalidName(format!("{zone}: {e}")))?;
682
683 let authority = Arc::new(InMemoryAuthority::empty(
686 zone_origin.clone(),
687 ZoneType::Primary,
688 false,
689 ));
690
691 let static_origin = Name::from_str(STATIC_ZONE_ORIGIN)
703 .map_err(|e| DnsError::InvalidName(format!("{STATIC_ZONE_ORIGIN}: {e}")))?;
704 let static_authority = if static_origin == zone_origin {
705 Arc::clone(&authority)
706 } else {
707 Arc::new(InMemoryAuthority::empty(
708 static_origin.clone(),
709 ZoneType::Primary,
710 false,
711 ))
712 };
713
714 Ok(Self {
715 listen_addr,
716 authority,
717 zone_origin,
718 static_authority,
719 static_origin,
720 serial: Arc::new(RwLock::new(1)),
721 upstreams,
722 })
723 }
724
725 pub fn from_config(config: &DnsConfig) -> Result<Self, DnsError> {
734 let listen_addr = SocketAddr::new(config.bind_addr, config.port);
735 let upstreams = resolve_upstreams(config, RESOLV_CONF_PATH);
736 Self::new_with_upstreams(listen_addr, &config.zone, upstreams)
737 }
738
739 #[must_use]
741 pub fn upstreams(&self) -> &[SocketAddr] {
742 &self.upstreams
743 }
744
745 fn build_catalog(
756 zone_origin: Name,
757 authority: Arc<InMemoryAuthority>,
758 static_origin: Name,
759 static_authority: Arc<InMemoryAuthority>,
760 upstreams: &[SocketAddr],
761 ) -> ForwardingCatalog {
762 let lower_origin = LowerName::from(zone_origin.clone());
763 let lower_static_origin = LowerName::from(static_origin.clone());
764
765 let mut catalog = Catalog::new();
766 catalog.upsert(zone_origin.into(), Box::new(authority));
768 if lower_static_origin != lower_origin {
773 catalog.upsert(static_origin.into(), Box::new(static_authority));
774 }
775
776 let resolver = if upstreams.is_empty() {
777 None
778 } else {
779 match build_forward_resolver(upstreams) {
780 Ok(r) => {
781 tracing::debug!(
782 upstreams = ?upstreams,
783 "overlay DNS forwarder ready for non-overlay queries",
784 );
785 Some(Arc::new(r))
786 }
787 Err(e) => {
788 tracing::error!(
789 error = %e,
790 "failed to build overlay DNS forwarder; non-overlay queries \
791 will be refused (overlay zone still served)",
792 );
793 None
794 }
795 }
796 };
797
798 ForwardingCatalog {
799 catalog,
800 zone_origin: lower_origin,
801 static_origin: lower_static_origin,
802 resolver,
803 }
804 }
805
806 #[must_use]
811 pub fn handle(&self) -> DnsHandle {
812 DnsHandle {
813 authority: Arc::clone(&self.authority),
814 zone_origin: self.zone_origin.clone(),
815 serial: Arc::clone(&self.serial),
816 }
817 }
818
819 pub async fn add_record(&self, hostname: &str, ip: IpAddr) -> Result<(), DnsError> {
827 self.handle().add_record(hostname, ip).await
828 }
829
830 pub async fn add_static_record(&self, hostname: &str, ip: IpAddr) -> Result<(), DnsError> {
843 let fqdn = if hostname.ends_with('.') {
844 Name::from_str(hostname)
845 .map_err(|e| DnsError::InvalidName(format!("{hostname}: {e}")))?
846 } else {
847 let name = Name::from_str(hostname)
848 .map_err(|e| DnsError::InvalidName(format!("{hostname}: {e}")))?;
849 name.append_domain(&self.static_origin)
850 .map_err(|e| DnsError::InvalidName(format!("Failed to append static zone: {e}")))?
851 };
852
853 let rdata = match ip {
854 IpAddr::V4(v4) => RData::A(A::from(v4)),
855 IpAddr::V6(v6) => RData::AAAA(AAAA::from(v6)),
856 };
857 let record = Record::from_rdata(fqdn, 300, rdata);
858
859 let serial = {
860 let mut s = self.serial.write().await;
861 let current = *s;
862 *s = s.wrapping_add(1);
863 current
864 };
865 self.static_authority.upsert(record, serial).await;
866 Ok(())
867 }
868
869 #[cfg(test)]
873 pub(crate) async fn lookup_static_a(&self, fqdn: &str) -> Option<IpAddr> {
874 use hickory_server::authority::{Authority, LookupOptions};
875
876 let name = Name::from_str(fqdn).ok()?;
877 let lower = LowerName::from(name);
878 let lookup = self
879 .static_authority
880 .lookup(&lower, RecordType::A, LookupOptions::default())
881 .await
882 .ok()?;
883 lookup.iter().find_map(|record| match record.data() {
884 Some(RData::A(a)) => Some(IpAddr::V4((*a).into())),
885 _ => None,
886 })
887 }
888
889 pub async fn remove_record(&self, hostname: &str) -> Result<bool, DnsError> {
895 self.handle().remove_record(hostname).await
896 }
897
898 #[allow(clippy::unused_async)]
907 pub async fn start(self) -> Result<DnsHandle, DnsError> {
908 let handle = self.handle();
909 let listen_addr = self.listen_addr;
910 let zone_origin = self.zone_origin.clone();
911 let authority = Arc::clone(&self.authority);
912 let static_origin = self.static_origin.clone();
913 let static_authority = Arc::clone(&self.static_authority);
914 let upstreams = self.upstreams.clone();
915
916 tokio::spawn(async move {
918 if let Err(e) = Self::run_server(
919 listen_addr,
920 zone_origin,
921 authority,
922 static_origin,
923 static_authority,
924 upstreams,
925 )
926 .await
927 {
928 tracing::error!("DNS server error: {}", e);
929 }
930 });
931
932 Ok(handle)
933 }
934
935 #[allow(clippy::unused_async)]
945 pub async fn start_background(&self) -> Result<DnsHandle, DnsError> {
946 let handle = self.handle();
947 let listen_addr = self.listen_addr;
948 let zone_origin = self.zone_origin.clone();
949 let authority = Arc::clone(&self.authority);
950 let static_origin = self.static_origin.clone();
951 let static_authority = Arc::clone(&self.static_authority);
952 let upstreams = self.upstreams.clone();
953
954 tokio::spawn(async move {
955 if let Err(e) = Self::run_server(
956 listen_addr,
957 zone_origin,
958 authority,
959 static_origin,
960 static_authority,
961 upstreams,
962 )
963 .await
964 {
965 tracing::error!("DNS server error: {}", e);
966 }
967 });
968
969 Ok(handle)
970 }
971
972 #[allow(clippy::unused_async)]
1001 pub async fn bind_windows_fallback(&self, bind_ip: IpAddr) -> Result<DnsHandle, DnsError> {
1002 self.bind_secondary(SocketAddr::new(bind_ip, 53)).await
1003 }
1004
1005 #[allow(clippy::unused_async)]
1020 pub async fn bind_secondary(&self, listen_addr: SocketAddr) -> Result<DnsHandle, DnsError> {
1021 let handle = self.handle();
1022 let zone_origin = self.zone_origin.clone();
1023 let authority = Arc::clone(&self.authority);
1024 let static_origin = self.static_origin.clone();
1025 let static_authority = Arc::clone(&self.static_authority);
1026 let upstreams = self.upstreams.clone();
1027
1028 let udp_socket = UdpSocket::bind(listen_addr).await?;
1032 let tcp_listener = TcpListener::bind(listen_addr).await?;
1033
1034 tokio::spawn(async move {
1035 let catalog = Self::build_catalog(
1036 zone_origin,
1037 authority,
1038 static_origin,
1039 static_authority,
1040 &upstreams,
1041 );
1042 let mut server = ServerFuture::new(catalog);
1043 server.register_socket(udp_socket);
1044 server.register_listener(tcp_listener, Duration::from_secs(30));
1045 tracing::info!(
1046 addr = %listen_addr,
1047 "secondary DNS listener started",
1048 );
1049 if let Err(e) = server.block_until_done().await {
1050 tracing::error!("secondary DNS listener error: {}", e);
1051 }
1052 });
1053
1054 Ok(handle)
1055 }
1056
1057 async fn run_server(
1059 listen_addr: SocketAddr,
1060 zone_origin: Name,
1061 authority: Arc<InMemoryAuthority>,
1062 static_origin: Name,
1063 static_authority: Arc<InMemoryAuthority>,
1064 upstreams: Vec<SocketAddr>,
1065 ) -> Result<(), DnsError> {
1066 let catalog = Self::build_catalog(
1069 zone_origin,
1070 authority,
1071 static_origin,
1072 static_authority,
1073 &upstreams,
1074 );
1075
1076 let mut server = ServerFuture::new(catalog);
1078
1079 let udp_socket = UdpSocket::bind(listen_addr).await?;
1081 server.register_socket(udp_socket);
1082
1083 let tcp_listener = TcpListener::bind(listen_addr).await?;
1085 server.register_listener(tcp_listener, Duration::from_secs(30));
1086
1087 tracing::info!(addr = %listen_addr, "DNS server listening");
1088
1089 server
1091 .block_until_done()
1092 .await
1093 .map_err(|e| DnsError::Server(e.to_string()))?;
1094
1095 Ok(())
1096 }
1097
1098 #[must_use]
1100 pub fn listen_addr(&self) -> SocketAddr {
1101 self.listen_addr
1102 }
1103
1104 #[must_use]
1106 pub fn zone_origin(&self) -> &Name {
1107 &self.zone_origin
1108 }
1109}
1110
1111pub struct DnsClient {
1113 server_addr: SocketAddr,
1114}
1115
1116impl DnsClient {
1117 #[must_use]
1119 pub fn new(server_addr: SocketAddr) -> Self {
1120 Self { server_addr }
1121 }
1122
1123 pub fn query_a(&self, hostname: &str) -> Result<Option<Ipv4Addr>, DnsError> {
1129 let name = Name::from_str(hostname)
1130 .map_err(|e| DnsError::InvalidName(format!("{hostname}: {e}")))?;
1131
1132 let conn = UdpClientConnection::new(self.server_addr)
1133 .map_err(|e| DnsError::Client(e.to_string()))?;
1134
1135 let client = SyncClient::new(conn);
1136
1137 let response = client
1138 .query(&name, DNSClass::IN, RecordType::A)
1139 .map_err(|e| DnsError::Client(e.to_string()))?;
1140
1141 for answer in response.answers() {
1143 if let Some(RData::A(a_record)) = answer.data() {
1144 return Ok(Some((*a_record).into()));
1145 }
1146 }
1147
1148 Ok(None)
1149 }
1150
1151 pub fn query_aaaa(&self, hostname: &str) -> Result<Option<Ipv6Addr>, DnsError> {
1157 let name = Name::from_str(hostname)
1158 .map_err(|e| DnsError::InvalidName(format!("{hostname}: {e}")))?;
1159
1160 let conn = UdpClientConnection::new(self.server_addr)
1161 .map_err(|e| DnsError::Client(e.to_string()))?;
1162
1163 let client = SyncClient::new(conn);
1164
1165 let response = client
1166 .query(&name, DNSClass::IN, RecordType::AAAA)
1167 .map_err(|e| DnsError::Client(e.to_string()))?;
1168
1169 for answer in response.answers() {
1171 if let Some(RData::AAAA(aaaa_record)) = answer.data() {
1172 return Ok(Some((*aaaa_record).into()));
1173 }
1174 }
1175
1176 Ok(None)
1177 }
1178
1179 pub fn query_addr(&self, hostname: &str) -> Result<Option<IpAddr>, DnsError> {
1187 if let Ok(Some(v4)) = self.query_a(hostname) {
1189 return Ok(Some(IpAddr::V4(v4)));
1190 }
1191
1192 if let Ok(Some(v6)) = self.query_aaaa(hostname) {
1194 return Ok(Some(IpAddr::V6(v6)));
1195 }
1196
1197 Ok(None)
1198 }
1199}
1200
1201pub struct ServiceDiscovery {
1203 dns_server: SocketAddr,
1204 records: RwLock<HashMap<String, IpAddr>>,
1205}
1206
1207impl ServiceDiscovery {
1208 #[must_use]
1210 pub fn new(dns_server_addr: SocketAddr) -> Self {
1211 Self {
1212 dns_server: dns_server_addr,
1213 records: RwLock::new(HashMap::new()),
1214 }
1215 }
1216
1217 pub async fn register(&self, name: &str, ip: IpAddr) {
1219 let mut records = self.records.write().await;
1220 records.insert(name.to_string(), ip);
1221 }
1222
1223 pub async fn resolve(&self, name: &str) -> Option<IpAddr> {
1228 {
1230 let records = self.records.read().await;
1231 if let Some(ip) = records.get(name) {
1232 return Some(*ip);
1233 }
1234 }
1235
1236 let client = DnsClient::new(self.dns_server);
1238 if let Ok(Some(addr)) = client.query_addr(name) {
1239 return Some(addr);
1240 }
1241
1242 None
1243 }
1244
1245 pub async fn unregister(&self, name: &str) {
1247 let mut records = self.records.write().await;
1248 records.remove(name);
1249 }
1250
1251 pub async fn list_services(&self) -> Vec<String> {
1253 let records = self.records.read().await;
1254 records.keys().cloned().collect()
1255 }
1256
1257 pub fn dns_server(&self) -> SocketAddr {
1259 self.dns_server
1260 }
1261}
1262
/// Renders the contents of a scoped resolver file.
///
/// The result is a `nameserver <ip>` line, followed by a `port <p>` line only
/// when `port` is given. Every line is newline-terminated.
#[cfg(any(target_os = "macos", test))]
fn scoped_resolver_body(node_ip: std::net::IpAddr, port: Option<u16>) -> String {
    let port_line = port.map(|p| format!("port {p}\n")).unwrap_or_default();
    format!("nameserver {node_ip}\n{port_line}")
}
1282
/// Writes `/etc/resolver/<zone>` so lookups for `zone` are sent to `node_ip`
/// (and `port`, when given).
///
/// The write is skipped if the file already holds identical contents.
///
/// # Errors
/// - `PermissionDenied` if not running as root.
/// - `InvalidInput` if `zone` is not a single plain file name (empty, `.`,
///   `..`, or containing `/`).
/// - Any I/O error from creating `/etc/resolver` or writing the file.
#[cfg(target_os = "macos")]
// Needed for the single `geteuid` FFI call below.
#[allow(unsafe_code)]
pub fn write_scoped_resolver(
    zone: &str,
    node_ip: std::net::IpAddr,
    port: Option<u16>,
) -> std::io::Result<()> {
    // SAFETY: `geteuid` takes no arguments, touches no caller memory and cannot fail.
    if unsafe { libc::geteuid() } != 0 {
        return Err(std::io::Error::new(
            std::io::ErrorKind::PermissionDenied,
            "writing /etc/resolver requires root",
        ));
    }

    // `zone` becomes a file name under /etc/resolver and this runs as root, so
    // refuse anything that could escape that directory (e.g. "../...").
    if zone.is_empty() || zone == "." || zone == ".." || zone.contains('/') {
        return Err(std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            format!("invalid scoped resolver zone: {zone:?}"),
        ));
    }

    std::fs::create_dir_all("/etc/resolver")?;

    let body = scoped_resolver_body(node_ip, port);
    let path = format!("/etc/resolver/{zone}");

    // Idempotent: leave an identical existing file untouched.
    if let Ok(existing) = std::fs::read_to_string(&path) {
        if existing == body {
            return Ok(());
        }
    }

    std::fs::write(&path, body)
}
1328
/// Deletes `/etc/resolver/<zone>`. A file that is already absent counts as success.
///
/// # Errors
/// - `PermissionDenied` if not running as root.
/// - `InvalidInput` if `zone` is not a single plain file name (empty, `.`,
///   `..`, or containing `/`).
/// - Any I/O error other than `NotFound` from removing the file.
#[cfg(target_os = "macos")]
// Needed for the single `geteuid` FFI call below.
#[allow(unsafe_code)]
pub fn remove_scoped_resolver(zone: &str) -> std::io::Result<()> {
    // SAFETY: `geteuid` takes no arguments, touches no caller memory and cannot fail.
    if unsafe { libc::geteuid() } != 0 {
        return Err(std::io::Error::new(
            std::io::ErrorKind::PermissionDenied,
            "removing /etc/resolver entries requires root",
        ));
    }

    // `zone` becomes a file name under /etc/resolver and this runs as root, so
    // refuse anything that could escape that directory (e.g. "../...").
    if zone.is_empty() || zone == "." || zone == ".." || zone.contains('/') {
        return Err(std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            format!("invalid scoped resolver zone: {zone:?}"),
        ));
    }

    let path = format!("/etc/resolver/{zone}");
    match std::fs::remove_file(&path) {
        Err(e) if e.kind() != std::io::ErrorKind::NotFound => Err(e),
        _ => Ok(()),
    }
}
1355
1356#[cfg(test)]
1357mod tests {
1358 use super::*;
1359
1360 #[test]
1361 fn test_scoped_resolver_body_with_port() {
1362 let body = scoped_resolver_body(IpAddr::V4(Ipv4Addr::new(10, 200, 0, 1)), Some(15353));
1363 assert_eq!(body, "nameserver 10.200.0.1\nport 15353\n");
1364 }
1365
1366 #[test]
1367 fn test_scoped_resolver_body_without_port() {
1368 let body = scoped_resolver_body(IpAddr::V4(Ipv4Addr::new(10, 200, 0, 1)), None);
1370 assert_eq!(body, "nameserver 10.200.0.1\n");
1371 }
1372
1373 #[test]
1374 fn test_scoped_resolver_body_v6() {
1375 let body = scoped_resolver_body(IpAddr::V6("fd00::1".parse().unwrap()), Some(53));
1376 assert_eq!(body, "nameserver fd00::1\nport 53\n");
1377 }
1378
1379 #[test]
1380 fn test_peer_hostname_v4() {
1381 assert_eq!(
1383 peer_hostname(IpAddr::V4(Ipv4Addr::new(10, 200, 0, 1))),
1384 "node-0-1"
1385 );
1386 assert_eq!(
1387 peer_hostname(IpAddr::V4(Ipv4Addr::new(10, 200, 0, 5))),
1388 "node-0-5"
1389 );
1390 assert_eq!(
1391 peer_hostname(IpAddr::V4(Ipv4Addr::new(10, 200, 1, 100))),
1392 "node-1-100"
1393 );
1394 assert_eq!(
1395 peer_hostname(IpAddr::V4(Ipv4Addr::new(192, 168, 255, 254))),
1396 "node-255-254"
1397 );
1398 }
1399
1400 #[test]
1401 fn test_peer_hostname_v6() {
1402 assert_eq!(
1404 peer_hostname(IpAddr::V6("fd00::1".parse().unwrap())),
1405 "node-0001"
1406 );
1407 assert_eq!(
1408 peer_hostname(IpAddr::V6("fd00::abcd".parse().unwrap())),
1409 "node-abcd"
1410 );
1411 assert_eq!(
1412 peer_hostname(IpAddr::V6("fd00:200::ffff".parse().unwrap())),
1413 "node-ffff"
1414 );
1415 assert_eq!(
1417 peer_hostname(IpAddr::V6("fd00::1:0".parse().unwrap())),
1418 "node-0000"
1419 );
1420 }
1421
1422 #[test]
1423 fn test_dns_config() {
1424 let config = DnsConfig::new("overlay.local.", IpAddr::V4(Ipv4Addr::new(10, 200, 0, 1)));
1425 assert_eq!(config.zone, "overlay.local.");
1426 assert_eq!(config.port, DEFAULT_DNS_PORT);
1427 assert_eq!(config.bind_addr, IpAddr::V4(Ipv4Addr::new(10, 200, 0, 1)));
1428
1429 let config = config.with_port(5353);
1431 assert_eq!(config.port, 5353);
1432 }
1433
1434 #[test]
1435 fn test_dns_config_serialization() {
1436 let config = DnsConfig::new("overlay.local.", IpAddr::V4(Ipv4Addr::new(10, 200, 0, 1)))
1437 .with_port(15353);
1438
1439 let json = serde_json::to_string(&config).unwrap();
1440 let deserialized: DnsConfig = serde_json::from_str(&json).unwrap();
1441
1442 assert_eq!(deserialized.zone, config.zone);
1443 assert_eq!(deserialized.port, config.port);
1444 assert_eq!(deserialized.bind_addr, config.bind_addr);
1445 }
1446
1447 #[tokio::test]
1448 async fn test_service_discovery_local_cache() {
1449 let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 15353);
1451 let discovery = ServiceDiscovery::new(addr);
1452
1453 let ip = IpAddr::V4(Ipv4Addr::new(10, 0, 0, 2));
1454 discovery.register("test-service", ip).await;
1455
1456 let resolved = discovery.resolve("test-service").await;
1457 assert_eq!(resolved, Some(ip));
1458
1459 discovery.unregister("test-service").await;
1461 let services = discovery.list_services().await;
1462 assert!(services.is_empty());
1463 }
1464
1465 #[test]
1466 fn test_dns_server_creation() {
1467 let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 15353);
1468 let server = DnsServer::new(addr, "overlay.local.");
1469
1470 assert!(server.is_ok());
1471 let server = server.unwrap();
1472 assert_eq!(server.listen_addr(), addr);
1473 assert_eq!(server.zone_origin().to_string(), "overlay.local.");
1474 }
1475
1476 #[test]
1477 fn test_dns_server_from_config() {
1478 let config =
1479 DnsConfig::new("test.local.", IpAddr::V4(Ipv4Addr::LOCALHOST)).with_port(15353);
1480 let server = DnsServer::from_config(&config);
1481
1482 assert!(server.is_ok());
1483 let server = server.unwrap();
1484 assert_eq!(server.listen_addr().port(), 15353);
1485 assert_eq!(server.zone_origin().to_string(), "test.local.");
1486 }
1487
1488 #[test]
1489 fn test_dns_server_invalid_zone() {
1490 let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 15353);
1491 let server = DnsServer::new(addr, "overlay.local.");
1493 assert!(server.is_ok());
1494 }
1495
1496 #[tokio::test]
1497 async fn test_dns_server_add_record() {
1498 let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 15353);
1499 let server = DnsServer::new(addr, "overlay.local.").unwrap();
1500
1501 let result = server
1502 .add_record("myservice", IpAddr::V4(Ipv4Addr::new(10, 0, 0, 5)))
1503 .await;
1504 assert!(result.is_ok());
1505 }
1506
1507 #[tokio::test]
1508 async fn test_add_static_record_resolves_in_static_zone() {
1509 let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 15353);
1510 let server = DnsServer::new(addr, "mydeploy.local.").unwrap();
1513 let node_ip = IpAddr::V4(Ipv4Addr::new(10, 200, 0, 1));
1514
1515 server.add_static_record("host", node_ip).await.unwrap();
1517 assert_eq!(
1518 server.lookup_static_a("host.zlayer.local.").await,
1519 Some(node_ip)
1520 );
1521
1522 server
1524 .add_static_record("daemon.mydeploy.zlayer.local.", node_ip)
1525 .await
1526 .unwrap();
1527 assert_eq!(
1528 server
1529 .lookup_static_a("daemon.mydeploy.zlayer.local.")
1530 .await,
1531 Some(node_ip)
1532 );
1533 }
1534
1535 #[tokio::test]
1536 async fn test_dns_handle_add_record() {
1537 let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 15353);
1538 let server = DnsServer::new(addr, "overlay.local.").unwrap();
1539
1540 let handle = server.handle();
1542
1543 let result = handle
1544 .add_record("service1", IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)))
1545 .await;
1546 assert!(result.is_ok());
1547
1548 let result = handle
1549 .add_record("service2", IpAddr::V4(Ipv4Addr::new(10, 0, 0, 2)))
1550 .await;
1551 assert!(result.is_ok());
1552
1553 assert_eq!(handle.zone_origin().to_string(), "overlay.local.");
1555 }
1556
/// `DnsClient::new` stores the server address it was given.
#[test]
fn test_dns_client_creation() {
    let target = SocketAddr::from((Ipv4Addr::new(8, 8, 8, 8), 53));
    assert_eq!(DnsClient::new(target).server_addr, target);
}
1563
1564 #[tokio::test]
1565 async fn test_dns_handle_add_aaaa_record() {
1566 let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 15353);
1567 let server = DnsServer::new(addr, "overlay.local.").unwrap();
1568 let handle = server.handle();
1569
1570 let ipv6: IpAddr = "fd00::1".parse().unwrap();
1572 let result = handle.add_record("service-v6", ipv6).await;
1573 assert!(result.is_ok());
1574
1575 let ipv6_2: IpAddr = "fd00::abcd".parse().unwrap();
1577 let result = handle.add_record("service-v6-2", ipv6_2).await;
1578 assert!(result.is_ok());
1579 }
1580
1581 #[tokio::test]
1582 async fn test_dns_server_add_aaaa_record() {
1583 let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 15353);
1584 let server = DnsServer::new(addr, "overlay.local.").unwrap();
1585
1586 let ipv6: IpAddr = "fd00::42".parse().unwrap();
1588 let result = server.add_record("myservice-v6", ipv6).await;
1589 assert!(result.is_ok());
1590 }
1591
1592 #[tokio::test]
1593 async fn test_dns_handle_remove_record_covers_both_types() {
1594 let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 15353);
1595 let server = DnsServer::new(addr, "overlay.local.").unwrap();
1596 let handle = server.handle();
1597
1598 let ipv4 = IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1));
1600 handle.add_record("dual-service", ipv4).await.unwrap();
1601
1602 let removed = handle.remove_record("dual-service").await.unwrap();
1604 assert!(removed);
1605
1606 let ipv6: IpAddr = "fd00::1".parse().unwrap();
1608 handle.add_record("v6-service", ipv6).await.unwrap();
1609
1610 let removed = handle.remove_record("v6-service").await.unwrap();
1612 assert!(removed);
1613 }
1614
1615 #[tokio::test]
1616 async fn test_service_discovery_local_cache_ipv6() {
1617 let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 15353);
1618 let discovery = ServiceDiscovery::new(addr);
1619
1620 let ipv6: IpAddr = "fd00::beef".parse().unwrap();
1622 discovery.register("v6-service", ipv6).await;
1623
1624 let resolved = discovery.resolve("v6-service").await;
1626 assert_eq!(resolved, Some(ipv6));
1627
1628 discovery.unregister("v6-service").await;
1630 let services = discovery.list_services().await;
1631 assert!(services.is_empty());
1632 }
1633
1634 #[tokio::test]
1635 async fn test_service_discovery_mixed_v4_v6_cache() {
1636 let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 15353);
1637 let discovery = ServiceDiscovery::new(addr);
1638
1639 let ipv4 = IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1));
1640 let ipv6: IpAddr = "fd00::1".parse().unwrap();
1641
1642 discovery.register("svc-v4", ipv4).await;
1643 discovery.register("svc-v6", ipv6).await;
1644
1645 assert_eq!(discovery.resolve("svc-v4").await, Some(ipv4));
1646 assert_eq!(discovery.resolve("svc-v6").await, Some(ipv6));
1647
1648 let mut services = discovery.list_services().await;
1649 services.sort();
1650 assert_eq!(services, vec!["svc-v4", "svc-v6"]);
1651 }
1652
/// `DnsConfig` keeps an IPv6 bind address, uses the default port, and
/// preserves the address through a JSON round-trip.
#[test]
fn test_dns_config_with_ipv6_bind_addr() {
    let ipv6_bind = IpAddr::V6("fd00::1".parse().unwrap());
    let config = DnsConfig::new("overlay.local.", ipv6_bind);
    assert_eq!(config.bind_addr, ipv6_bind);
    assert_eq!(config.port, DEFAULT_DNS_PORT);

    let encoded = serde_json::to_string(&config).unwrap();
    let round_tripped: DnsConfig = serde_json::from_str(&encoded).unwrap();
    assert_eq!(round_tripped.bind_addr, ipv6_bind);
}
1665
/// A server created on the IPv6 loopback (`[::1]:15353`) reports that exact
/// listen address.
#[test]
fn test_dns_server_creation_ipv6_bind() {
    let addr = SocketAddr::from((Ipv6Addr::LOCALHOST, 15353));
    let server = DnsServer::new(addr, "overlay.local.");

    assert!(server.is_ok());
    assert_eq!(server.unwrap().listen_addr(), addr);
}
1676
1677 #[tokio::test]
1684 async fn test_bind_windows_fallback_errors_or_shares_authority() {
1685 let primary = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0);
1686 let server = DnsServer::new(primary, "overlay.local.").unwrap();
1687 let bind_ip: IpAddr = "127.0.0.2".parse().unwrap();
1688
1689 match server.bind_windows_fallback(bind_ip).await {
1690 Ok(handle) => {
1691 assert_eq!(handle.zone_origin().to_string(), "overlay.local.");
1695 handle
1696 .add_record("dual", IpAddr::V4(Ipv4Addr::new(10, 0, 0, 9)))
1697 .await
1698 .expect("add_record via fallback handle");
1699 }
1700 Err(DnsError::Io(_)) => {
1701 }
1705 Err(other) => panic!("unexpected error from bind_windows_fallback: {other}"),
1706 }
1707 }
1708
/// `peer_hostname` yields distinct names for the following pairs:
/// - distinct IPv4 peers;
/// - distinct IPv6 peers;
/// - an IPv4 peer versus an IPv6 peer.
#[test]
fn test_peer_hostname_uniqueness() {
    let v4 = |last: u8| peer_hostname(IpAddr::V4(Ipv4Addr::new(10, 0, 0, last)));
    let v6 = |text: &str| peer_hostname(IpAddr::V6(text.parse().unwrap()));

    assert_ne!(v4(1), v4(2));
    assert_ne!(v6("fd00::1"), v6("fd00::2"));
    assert_ne!(v4(1), v6("fd00::1"));
}
1725
/// `parse_resolv_conf` drops the local stub (`127.0.0.53`) and loopback
/// (`127.0.0.1`) nameservers. It also ignores comment, `search` and `options`
/// lines, and keeps the real nameserver on port 53.
#[test]
fn test_parse_resolv_conf_filters_stub_and_loopback() {
    let contents = concat!(
        "# generated by netbird\n",
        "nameserver 127.0.0.53\n",
        "nameserver 127.0.0.1\n",
        "nameserver 192.168.1.1\n",
        "search example.com\n",
        "options edns0\n",
    );
    let expected = vec![SocketAddr::from((Ipv4Addr::new(192, 168, 1, 1), 53))];

    let parsed = parse_resolv_conf(contents);
    assert_eq!(
        parsed, expected,
        "127.0.0.53 stub and 127.0.0.1 loopback must be filtered out",
    );
}
1749
/// `parse_resolv_conf` behaves as follows:
/// - ignores `;` comments;
/// - collapses the duplicate `8.8.8.8`;
/// - drops the unspecified `0.0.0.0`;
/// - keeps the zone-scoped link-local `fe80::1%eth0`.
///
/// Only the IP of the link-local entry is compared.
#[test]
fn test_parse_resolv_conf_dedup_and_comments() {
    let contents = concat!(
        "; a comment\n",
        "nameserver 8.8.8.8\n",
        "nameserver 8.8.8.8\n",
        "nameserver fe80::1%eth0\n",
        "nameserver 0.0.0.0\n",
    );

    let parsed = parse_resolv_conf(contents);
    assert_eq!(parsed.len(), 2);
    assert_eq!(parsed[0], SocketAddr::from((Ipv4Addr::new(8, 8, 8, 8), 53)));
    assert_eq!(parsed[1].ip(), IpAddr::V6("fe80::1".parse().unwrap()));
}
1768
/// Explicit `upstreams` in the config are returned verbatim. The resolv.conf
/// path given does not exist, so the file cannot contribute to the result.
#[test]
fn test_resolve_upstreams_config_override_wins() {
    let explicit = SocketAddr::from((Ipv4Addr::new(10, 9, 9, 9), 5300));
    let config = DnsConfig::new("overlay.local.", Ipv4Addr::LOCALHOST.into())
        .with_upstreams(vec![explicit]);

    assert_eq!(
        resolve_upstreams(&config, "/nonexistent/resolv.conf"),
        vec![explicit]
    );
}
1779
/// Suppose the config has no explicit upstreams and the resolv.conf path does
/// not exist. Then the resolver falls back to 1.1.1.1 and 8.8.8.8 on port 53,
/// in that order.
#[test]
fn test_resolve_upstreams_falls_back_to_public_when_missing() {
    let config = DnsConfig::new("overlay.local.", IpAddr::V4(Ipv4Addr::LOCALHOST));
    let resolved = resolve_upstreams(&config, "/definitely/not/a/real/resolv.conf");

    let expected: Vec<SocketAddr> = [[1, 1, 1, 1], [8, 8, 8, 8]]
        .into_iter()
        .map(|octets: [u8; 4]| SocketAddr::from((Ipv4Addr::from(octets), 53)))
        .collect();
    assert_eq!(resolved, expected);
}
1793
1794 async fn spawn_stub_upstream(answer_ip: Ipv4Addr) -> SocketAddr {
1802 use hickory_server::proto::op::{Message, MessageType, ResponseCode};
1803
1804 let sock = UdpSocket::bind(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0))
1805 .await
1806 .expect("bind stub upstream");
1807 let addr = sock.local_addr().expect("stub local_addr");
1808
1809 tokio::spawn(async move {
1810 let mut buf = vec![0u8; 1500];
1811 loop {
1812 let Ok((len, from)) = sock.recv_from(&mut buf).await else {
1813 break;
1814 };
1815 let Ok(request) = Message::from_vec(&buf[..len]) else {
1816 continue;
1817 };
1818 let mut resp = Message::new();
1819 resp.set_id(request.id());
1820 resp.set_message_type(MessageType::Response);
1821 resp.set_recursion_available(true);
1822 resp.set_response_code(ResponseCode::NoError);
1823 for q in request.queries() {
1824 resp.add_query(q.clone());
1825 if q.query_type() == RecordType::A {
1826 let rec =
1827 Record::from_rdata(q.name().clone(), 60, RData::A(A::from(answer_ip)));
1828 resp.add_answer(rec);
1829 }
1830 }
1831 if let Ok(bytes) = resp.to_vec() {
1832 let _ = sock.send_to(&bytes, from).await;
1833 }
1834 }
1835 });
1836
1837 addr
1838 }
1839
1840 async fn raw_query_a(
1844 server: SocketAddr,
1845 name: &str,
1846 ) -> Result<Option<Ipv4Addr>, hickory_server::proto::op::ResponseCode> {
1847 use hickory_server::proto::op::{Message, MessageType, Query, ResponseCode};
1848
1849 let client = UdpSocket::bind(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0))
1850 .await
1851 .expect("bind client");
1852
1853 let qname = Name::from_str(name).expect("query name");
1854 let mut msg = Message::new();
1855 msg.set_id(0x1234);
1856 msg.set_message_type(MessageType::Query);
1857 msg.set_recursion_desired(true);
1858 msg.add_query(Query::query(qname, RecordType::A));
1859 let bytes = msg.to_vec().expect("encode query");
1860
1861 client.send_to(&bytes, server).await.expect("send query");
1862
1863 let mut buf = vec![0u8; 1500];
1864 let len = tokio::time::timeout(Duration::from_secs(12), client.recv(&mut buf))
1869 .await
1870 .expect("query timed out")
1871 .expect("recv response");
1872 let resp = Message::from_vec(&buf[..len]).expect("decode response");
1873
1874 if resp.response_code() != ResponseCode::NoError {
1875 return Err(resp.response_code());
1876 }
1877 for ans in resp.answers() {
1878 if let Some(RData::A(a)) = ans.data() {
1879 return Ok(Some((*a).into()));
1880 }
1881 }
1882 Ok(None)
1883 }
1884
1885 #[tokio::test]
1886 async fn test_forwarding_overlay_answered_and_nonoverlay_forwarded() {
1887 let upstream_answer = Ipv4Addr::new(203, 0, 113, 7);
1889 let upstream = spawn_stub_upstream(upstream_answer).await;
1890
1891 let bound = {
1895 let probe = UdpSocket::bind(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0))
1896 .await
1897 .unwrap();
1898 let a = probe.local_addr().unwrap();
1899 drop(probe);
1900 a
1901 };
1902
1903 let overlay_ip = Ipv4Addr::new(10, 200, 0, 5);
1906 let server =
1907 DnsServer::new_with_upstreams(bound, "overlay.local.", vec![upstream]).unwrap();
1908 let handle = server.handle();
1909 handle
1910 .add_record("svc", IpAddr::V4(overlay_ip))
1911 .await
1912 .unwrap();
1913 let _running = server.start().await.unwrap();
1914
1915 tokio::time::sleep(Duration::from_millis(150)).await;
1917
1918 let overlay = raw_query_a(bound, "svc.overlay.local.")
1920 .await
1921 .expect("overlay query should not SERVFAIL");
1922 assert_eq!(
1923 overlay,
1924 Some(overlay_ip),
1925 "overlay name must be answered from InMemoryAuthority",
1926 );
1927
1928 let forwarded = raw_query_a(bound, "example.com.")
1930 .await
1931 .expect("forwarded query should not SERVFAIL");
1932 assert_eq!(
1933 forwarded,
1934 Some(upstream_answer),
1935 "non-overlay name must be forwarded to the upstream stub",
1936 );
1937 }
1938
1939 #[tokio::test]
1940 async fn test_forwarding_total_upstream_failure_is_servfail_not_panic() {
1941 use hickory_server::proto::op::ResponseCode;
1942
1943 let dead_upstream = {
1947 let s = UdpSocket::bind(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0))
1949 .await
1950 .unwrap();
1951 let a = s.local_addr().unwrap();
1952 drop(s);
1953 a
1954 };
1955
1956 let bound = {
1957 let s = UdpSocket::bind(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0))
1958 .await
1959 .unwrap();
1960 let a = s.local_addr().unwrap();
1961 drop(s);
1962 a
1963 };
1964
1965 let server =
1966 DnsServer::new_with_upstreams(bound, "overlay.local.", vec![dead_upstream]).unwrap();
1967 let handle = server.handle();
1968 handle
1969 .add_record("svc", IpAddr::V4(Ipv4Addr::new(10, 200, 0, 9)))
1970 .await
1971 .unwrap();
1972 let _running = server.start().await.unwrap();
1973 tokio::time::sleep(Duration::from_millis(150)).await;
1974
1975 let overlay = raw_query_a(bound, "svc.overlay.local.")
1977 .await
1978 .expect("overlay query should still succeed");
1979 assert_eq!(overlay, Some(Ipv4Addr::new(10, 200, 0, 9)));
1980
1981 match raw_query_a(bound, "example.com.").await {
1984 Err(ResponseCode::ServFail) => {} Err(other) => panic!("expected SERVFAIL, got {other:?}"),
1986 Ok(answer) => panic!("expected SERVFAIL, got answer {answer:?}"),
1987 }
1988 }
1989
1990 #[tokio::test]
2002 async fn test_colliding_static_zone_does_not_evict_service_records() {
2003 let bound = {
2005 let probe = UdpSocket::bind(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0))
2006 .await
2007 .unwrap();
2008 let a = probe.local_addr().unwrap();
2009 drop(probe);
2010 a
2011 };
2012
2013 let server = DnsServer::new_with_upstreams(bound, "zlayer.local.", vec![]).unwrap();
2015
2016 let svc_ip = Ipv4Addr::new(10, 200, 0, 42);
2017 let static_ip = Ipv4Addr::new(10, 200, 0, 1);
2018
2019 server
2022 .add_record("forgejodb.service.forgejo-stack", IpAddr::V4(svc_ip))
2023 .await
2024 .unwrap();
2025 server
2026 .add_static_record("host", IpAddr::V4(static_ip))
2027 .await
2028 .unwrap();
2029
2030 let _running = server.start().await.unwrap();
2031 tokio::time::sleep(Duration::from_millis(150)).await;
2032
2033 let svc = raw_query_a(bound, "forgejodb.service.forgejo-stack.zlayer.local.")
2035 .await
2036 .expect(
2037 "service-discovery query must not NXDOMAIN/SERVFAIL (catalog eviction regression)",
2038 );
2039 assert_eq!(
2040 svc,
2041 Some(svc_ip),
2042 "service-discovery record must survive a colliding static zone",
2043 );
2044
2045 let stat = raw_query_a(bound, "host.zlayer.local.")
2047 .await
2048 .expect("static node-stable query must not NXDOMAIN/SERVFAIL");
2049 assert_eq!(
2050 stat,
2051 Some(static_ip),
2052 "static record must resolve when its zone is shared with the primary",
2053 );
2054 }
2055}