1use hickory_client::client::{Client, SyncClient};
4use hickory_client::udp::UdpClientConnection;
5use hickory_server::authority::{Catalog, ZoneType};
6use hickory_server::proto::rr::rdata::{A, AAAA};
7use hickory_server::proto::rr::{DNSClass, LowerName, Name, RData, Record, RecordType};
8use hickory_server::resolver::config::NameServerConfigGroup;
9use hickory_server::server::ServerFuture;
10use hickory_server::store::in_memory::InMemoryAuthority;
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
14use std::str::FromStr;
15use std::sync::Arc;
16use std::time::Duration;
17use tokio::net::{TcpListener, UdpSocket};
18use tokio::sync::RwLock;
19
/// Port the overlay DNS listener uses unless configured otherwise.
///
/// Deliberately not 53 so the overlay resolver can run next to a host
/// resolver without a port clash.
pub const DEFAULT_DNS_PORT: u16 = 15353;

/// Well-known DNS port. Nameservers read from `resolv.conf` carry no port,
/// so they are assumed to listen here.
const STANDARD_DNS_PORT: u16 = 53;

/// Last-resort upstreams (`1.1.1.1`, then `8.8.8.8`) used when no usable
/// nameserver can be detected on the host.
const PUBLIC_FALLBACK_UPSTREAMS: [IpAddr; 2] = [
    IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1)),
    IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)),
];

/// Host resolver configuration consulted when detecting upstreams.
pub(crate) const RESOLV_CONF_PATH: &str = "/etc/resolv.conf";
42
43#[derive(Debug, Clone, Serialize, Deserialize)]
45pub struct DnsConfig {
46 pub zone: String,
48 pub port: u16,
50 pub bind_addr: IpAddr,
52 #[serde(default)]
69 pub upstreams: Option<Vec<SocketAddr>>,
70}
71
72impl DnsConfig {
73 #[must_use]
75 pub fn new(zone: &str, bind_addr: IpAddr) -> Self {
76 Self {
77 zone: zone.to_string(),
78 port: DEFAULT_DNS_PORT,
79 bind_addr,
80 upstreams: None,
81 }
82 }
83
84 #[must_use]
86 pub fn with_port(mut self, port: u16) -> Self {
87 self.port = port;
88 self
89 }
90
91 #[must_use]
97 pub fn with_upstreams(mut self, upstreams: Vec<SocketAddr>) -> Self {
98 self.upstreams = Some(upstreams);
99 self
100 }
101}
102
/// Reports whether `ip` must not be used as a forwarding upstream.
///
/// Loopback addresses (e.g. the `127.0.0.53` systemd-resolved stub) and the
/// unspecified address are rejected. Forwarding to them would either loop
/// back into a local stub or go nowhere.
///
/// IPv4-mapped IPv6 addresses (`::ffff:a.b.c.d`) are judged by their embedded
/// IPv4 address. Otherwise `::ffff:127.0.0.53` or `::ffff:0.0.0.0` would slip
/// past the IPv6 loopback/unspecified checks, which only match `::1` and `::`.
fn is_unusable_upstream(ip: IpAddr) -> bool {
    let ip = match ip {
        IpAddr::V6(v6) => v6.to_ipv4_mapped().map_or(ip, IpAddr::V4),
        IpAddr::V4(_) => ip,
    };
    ip.is_loopback() || ip.is_unspecified()
}
121
122fn parse_resolv_conf(contents: &str) -> Vec<SocketAddr> {
133 let mut out: Vec<SocketAddr> = Vec::new();
134 for line in contents.lines() {
135 let line = line.trim();
136 if line.is_empty() || line.starts_with('#') || line.starts_with(';') {
137 continue;
138 }
139 let mut parts = line.split_whitespace();
140 if parts.next() != Some("nameserver") {
141 continue;
142 }
143 let Some(addr_str) = parts.next() else {
144 continue;
145 };
146 let addr_str = addr_str.split('%').next().unwrap_or(addr_str);
149 let Ok(ip) = IpAddr::from_str(addr_str) else {
150 continue;
151 };
152 if is_unusable_upstream(ip) {
153 continue;
154 }
155 let sock = SocketAddr::new(ip, STANDARD_DNS_PORT);
156 if !out.contains(&sock) {
157 out.push(sock);
158 }
159 }
160 out
161}
162
163pub(crate) fn resolve_upstreams(config: &DnsConfig, resolv_conf_path: &str) -> Vec<SocketAddr> {
185 if let Some(explicit) = &config.upstreams {
186 if !explicit.is_empty() {
187 tracing::debug!(
188 count = explicit.len(),
189 "using explicit overlay DNS upstreams from config (host detection skipped)",
190 );
191 return explicit.clone();
192 }
193 }
194
195 let detected = match std::fs::read_to_string(resolv_conf_path) {
196 Ok(contents) => parse_resolv_conf(&contents),
197 Err(e) => {
198 tracing::warn!(
199 path = resolv_conf_path,
200 error = %e,
201 "could not read host resolv.conf for overlay DNS upstream detection",
202 );
203 Vec::new()
204 }
205 };
206
207 if detected.is_empty() {
208 let fallback: Vec<SocketAddr> = PUBLIC_FALLBACK_UPSTREAMS
209 .iter()
210 .map(|ip| SocketAddr::new(*ip, STANDARD_DNS_PORT))
211 .collect();
212 tracing::warn!(
213 fallback = ?fallback,
214 "no usable host DNS upstreams found (resolv.conf empty, missing, or stub-only); \
215 falling back to public resolvers for overlay forwarding",
216 );
217 fallback
218 } else {
219 tracing::info!(
220 upstreams = ?detected,
221 "overlay DNS forwarding to host upstreams (loopback/stub filtered out)",
222 );
223 detected
224 }
225}
226
227pub(crate) fn build_forward_resolver(
249 upstreams: &[SocketAddr],
250) -> Result<hickory_server::resolver::TokioAsyncResolver, DnsError> {
251 use hickory_server::resolver::config::{ResolverConfig, ResolverOpts};
252
253 if upstreams.is_empty() {
254 return Err(DnsError::Server("no upstreams for forward resolver".into()));
255 }
256
257 let mut group = NameServerConfigGroup::new();
258 let mut by_port: std::collections::BTreeMap<u16, Vec<IpAddr>> =
259 std::collections::BTreeMap::new();
260 for addr in upstreams {
261 by_port.entry(addr.port()).or_default().push(addr.ip());
262 }
263 for (port, ips) in by_port {
264 group.merge(NameServerConfigGroup::from_ips_clear(&ips, port, true));
267 }
268
269 let mut options = ResolverOpts::default();
273 options.timeout = Duration::from_secs(2);
274 options.attempts = 2;
275 options.preserve_intermediates = true;
277
278 let config = ResolverConfig::from_parts(None, vec![], group);
279 Ok(hickory_server::resolver::TokioAsyncResolver::tokio(
280 config, options,
281 ))
282}
283
284struct ForwardingCatalog {
306 catalog: Catalog,
307 zone_origin: LowerName,
308 resolver: Option<Arc<hickory_server::resolver::TokioAsyncResolver>>,
309}
310
311impl ForwardingCatalog {
312 fn forward_answer_response<'a>(
314 request: &'a hickory_server::server::Request,
315 answers: &'a [Record],
316 ) -> hickory_server::authority::MessageResponse<
317 'a,
318 'a,
319 std::slice::Iter<'a, Record>,
320 std::iter::Empty<&'a Record>,
321 std::iter::Empty<&'a Record>,
322 std::iter::Empty<&'a Record>,
323 > {
324 use hickory_server::authority::MessageResponseBuilder;
325 use hickory_server::proto::op::ResponseCode;
326
327 let mut header = hickory_server::proto::op::Header::response_from_request(request.header());
328 header.set_recursion_available(true);
329 header.set_response_code(ResponseCode::NoError);
330 header.set_authoritative(false);
332
333 MessageResponseBuilder::from_message_request(request).build(
334 header,
335 answers.iter(),
336 std::iter::empty(),
337 std::iter::empty(),
338 std::iter::empty(),
339 )
340 }
341
342 fn forward_code_response(
345 request: &hickory_server::server::Request,
346 code: hickory_server::proto::op::ResponseCode,
347 ) -> hickory_server::authority::MessageResponse<
348 '_,
349 '_,
350 impl Iterator<Item = &Record> + Send,
351 impl Iterator<Item = &Record> + Send,
352 impl Iterator<Item = &Record> + Send,
353 impl Iterator<Item = &Record> + Send,
354 > {
355 use hickory_server::authority::MessageResponseBuilder;
356 MessageResponseBuilder::from_message_request(request).error_msg(request.header(), code)
357 }
358
359 async fn forward<R: hickory_server::server::ResponseHandler>(
362 &self,
363 resolver: &hickory_server::resolver::TokioAsyncResolver,
364 request: &hickory_server::server::Request,
365 mut response_handle: R,
366 ) -> hickory_server::server::ResponseInfo {
367 use hickory_server::proto::op::ResponseCode;
368 use hickory_server::resolver::error::ResolveErrorKind;
369
370 let query = request.request_info().query;
371 let name = Name::from(query.name());
372 let rtype = query.query_type();
373
374 match resolver.lookup(name, rtype).await {
375 Ok(lookup) => {
376 let records: Vec<Record> = lookup.records().to_vec();
377 let response = Self::forward_answer_response(request, &records);
378 Self::send_or_servfail(&mut response_handle, response).await
379 }
380 Err(e) => {
381 let code = match e.kind() {
382 ResolveErrorKind::NoRecordsFound { response_code, .. }
384 if *response_code == ResponseCode::NXDomain =>
385 {
386 ResponseCode::NXDomain
387 }
388 ResolveErrorKind::NoRecordsFound { response_code, .. }
390 if *response_code == ResponseCode::NoError =>
391 {
392 ResponseCode::NoError
393 }
394 _ => {
398 tracing::debug!(error = %e, "overlay DNS upstream forward failed; SERVFAIL");
399 ResponseCode::ServFail
400 }
401 };
402 let response = Self::forward_code_response(request, code);
403 Self::send_or_servfail(&mut response_handle, response).await
404 }
405 }
406 }
407
408 async fn send_or_servfail<'a, R, A, N, S, D>(
411 response_handle: &mut R,
412 response: hickory_server::authority::MessageResponse<'_, 'a, A, N, S, D>,
413 ) -> hickory_server::server::ResponseInfo
414 where
415 R: hickory_server::server::ResponseHandler,
416 A: Iterator<Item = &'a Record> + Send + 'a,
417 N: Iterator<Item = &'a Record> + Send + 'a,
418 S: Iterator<Item = &'a Record> + Send + 'a,
419 D: Iterator<Item = &'a Record> + Send + 'a,
420 {
421 match response_handle.send_response(response).await {
422 Ok(info) => info,
423 Err(e) => {
424 tracing::error!(error = %e, "failed to send overlay DNS forward response");
425 let mut header = hickory_server::proto::op::Header::new();
426 header.set_response_code(hickory_server::proto::op::ResponseCode::ServFail);
427 header.into()
428 }
429 }
430 }
431}
432
433#[async_trait::async_trait]
434impl hickory_server::server::RequestHandler for ForwardingCatalog {
435 async fn handle_request<R: hickory_server::server::ResponseHandler>(
436 &self,
437 request: &hickory_server::server::Request,
438 response_handle: R,
439 ) -> hickory_server::server::ResponseInfo {
440 let query_name = request.request_info().query.name().clone();
444 let is_overlay = self.zone_origin.zone_of(&query_name);
445
446 match (&self.resolver, is_overlay) {
447 (Some(resolver), false) => self.forward(resolver, request, response_handle).await,
448 _ => self.catalog.handle_request(request, response_handle).await,
449 }
450 }
451}
452
/// Derives a peer hostname from its overlay address.
///
/// The name format depends on the address family:
/// - IPv4: `node-<third octet>-<fourth octet>` in decimal, so `10.200.1.100`
///   becomes `node-1-100`.
/// - IPv6: `node-<last 16-bit segment>` as four lowercase hex digits, so
///   `fd00::abcd` becomes `node-abcd`.
///
/// Only the trailing part of the address is used. Addresses that differ
/// solely in higher bits therefore map to the same name.
#[must_use]
pub fn peer_hostname(ip: IpAddr) -> String {
    match ip {
        IpAddr::V4(v4) => {
            let [_, _, third, fourth] = v4.octets();
            format!("node-{third}-{fourth}")
        }
        IpAddr::V6(v6) => {
            let [.., last] = v6.segments();
            format!("node-{last:04x}")
        }
    }
}
471
472#[derive(Debug, thiserror::Error)]
474pub enum DnsError {
475 #[error("Invalid domain name: {0}")]
476 InvalidName(String),
477
478 #[error("DNS server error: {0}")]
479 Server(String),
480
481 #[error("DNS client error: {0}")]
482 Client(String),
483
484 #[error("IO error: {0}")]
485 Io(#[from] std::io::Error),
486
487 #[error("Record not found: {0}")]
488 NotFound(String),
489}
490
491#[derive(Clone)]
495pub struct DnsHandle {
496 authority: Arc<InMemoryAuthority>,
497 zone_origin: Name,
498 serial: Arc<RwLock<u32>>,
499}
500
501impl DnsHandle {
502 pub async fn add_record(&self, hostname: &str, ip: IpAddr) -> Result<(), DnsError> {
510 let fqdn = if hostname.ends_with('.') {
512 Name::from_str(hostname)
513 .map_err(|e| DnsError::InvalidName(format!("{hostname}: {e}")))?
514 } else {
515 let name = Name::from_str(hostname)
517 .map_err(|e| DnsError::InvalidName(format!("{hostname}: {e}")))?;
518 name.append_domain(&self.zone_origin)
519 .map_err(|e| DnsError::InvalidName(format!("Failed to append zone: {e}")))?
520 };
521
522 let rdata = match ip {
524 IpAddr::V4(v4) => RData::A(A::from(v4)),
525 IpAddr::V6(v6) => RData::AAAA(AAAA::from(v6)),
526 };
527 let record = Record::from_rdata(fqdn, 300, rdata); let serial = {
531 let mut s = self.serial.write().await;
532 let current = *s;
533 *s = s.wrapping_add(1);
534 current
535 };
536
537 self.authority.upsert(record, serial).await;
539
540 Ok(())
541 }
542
543 pub async fn remove_record(&self, hostname: &str) -> Result<bool, DnsError> {
551 let fqdn = if hostname.ends_with('.') {
552 Name::from_str(hostname)
553 .map_err(|e| DnsError::InvalidName(format!("{hostname}: {e}")))?
554 } else {
555 let name = Name::from_str(hostname)
556 .map_err(|e| DnsError::InvalidName(format!("{hostname}: {e}")))?;
557 name.append_domain(&self.zone_origin)
558 .map_err(|e| DnsError::InvalidName(format!("Failed to append zone: {e}")))?
559 };
560
561 let serial = {
562 let mut s = self.serial.write().await;
563 let current = *s;
564 *s = s.wrapping_add(1);
565 current
566 };
567
568 let a_record = Record::with(fqdn.clone(), RecordType::A, 0);
572 self.authority.upsert(a_record, serial).await;
573
574 let aaaa_record = Record::with(fqdn.clone(), RecordType::AAAA, 0);
575 self.authority.upsert(aaaa_record, serial).await;
576
577 Ok(true)
578 }
579
580 #[must_use]
582 pub fn zone_origin(&self) -> &Name {
583 &self.zone_origin
584 }
585}
586
587pub struct DnsServer {
589 listen_addr: SocketAddr,
590 authority: Arc<InMemoryAuthority>,
591 zone_origin: Name,
592 serial: Arc<RwLock<u32>>,
593 upstreams: Vec<SocketAddr>,
605}
606
607impl DnsServer {
608 pub fn new(listen_addr: SocketAddr, zone: &str) -> Result<Self, DnsError> {
618 let upstreams =
619 resolve_upstreams(&DnsConfig::new(zone, listen_addr.ip()), RESOLV_CONF_PATH);
620 Self::new_with_upstreams(listen_addr, zone, upstreams)
621 }
622
623 pub fn new_with_upstreams(
633 listen_addr: SocketAddr,
634 zone: &str,
635 upstreams: Vec<SocketAddr>,
636 ) -> Result<Self, DnsError> {
637 let zone_origin =
638 Name::from_str(zone).map_err(|e| DnsError::InvalidName(format!("{zone}: {e}")))?;
639
640 let authority = Arc::new(InMemoryAuthority::empty(
643 zone_origin.clone(),
644 ZoneType::Primary,
645 false,
646 ));
647
648 Ok(Self {
649 listen_addr,
650 authority,
651 zone_origin,
652 serial: Arc::new(RwLock::new(1)),
653 upstreams,
654 })
655 }
656
657 pub fn from_config(config: &DnsConfig) -> Result<Self, DnsError> {
666 let listen_addr = SocketAddr::new(config.bind_addr, config.port);
667 let upstreams = resolve_upstreams(config, RESOLV_CONF_PATH);
668 Self::new_with_upstreams(listen_addr, &config.zone, upstreams)
669 }
670
671 #[must_use]
673 pub fn upstreams(&self) -> &[SocketAddr] {
674 &self.upstreams
675 }
676
677 fn build_catalog(
688 zone_origin: Name,
689 authority: Arc<InMemoryAuthority>,
690 upstreams: &[SocketAddr],
691 ) -> ForwardingCatalog {
692 let lower_origin = LowerName::from(zone_origin.clone());
693
694 let mut catalog = Catalog::new();
695 catalog.upsert(zone_origin.into(), Box::new(authority));
697
698 let resolver = if upstreams.is_empty() {
699 None
700 } else {
701 match build_forward_resolver(upstreams) {
702 Ok(r) => {
703 tracing::debug!(
704 upstreams = ?upstreams,
705 "overlay DNS forwarder ready for non-overlay queries",
706 );
707 Some(Arc::new(r))
708 }
709 Err(e) => {
710 tracing::error!(
711 error = %e,
712 "failed to build overlay DNS forwarder; non-overlay queries \
713 will be refused (overlay zone still served)",
714 );
715 None
716 }
717 }
718 };
719
720 ForwardingCatalog {
721 catalog,
722 zone_origin: lower_origin,
723 resolver,
724 }
725 }
726
727 #[must_use]
732 pub fn handle(&self) -> DnsHandle {
733 DnsHandle {
734 authority: Arc::clone(&self.authority),
735 zone_origin: self.zone_origin.clone(),
736 serial: Arc::clone(&self.serial),
737 }
738 }
739
740 pub async fn add_record(&self, hostname: &str, ip: IpAddr) -> Result<(), DnsError> {
748 self.handle().add_record(hostname, ip).await
749 }
750
751 pub async fn remove_record(&self, hostname: &str) -> Result<bool, DnsError> {
757 self.handle().remove_record(hostname).await
758 }
759
760 #[allow(clippy::unused_async)]
769 pub async fn start(self) -> Result<DnsHandle, DnsError> {
770 let handle = self.handle();
771 let listen_addr = self.listen_addr;
772 let zone_origin = self.zone_origin.clone();
773 let authority = Arc::clone(&self.authority);
774 let upstreams = self.upstreams.clone();
775
776 tokio::spawn(async move {
778 if let Err(e) = Self::run_server(listen_addr, zone_origin, authority, upstreams).await {
779 tracing::error!("DNS server error: {}", e);
780 }
781 });
782
783 Ok(handle)
784 }
785
786 #[allow(clippy::unused_async)]
796 pub async fn start_background(&self) -> Result<DnsHandle, DnsError> {
797 let handle = self.handle();
798 let listen_addr = self.listen_addr;
799 let zone_origin = self.zone_origin.clone();
800 let authority = Arc::clone(&self.authority);
801 let upstreams = self.upstreams.clone();
802
803 tokio::spawn(async move {
804 if let Err(e) = Self::run_server(listen_addr, zone_origin, authority, upstreams).await {
805 tracing::error!("DNS server error: {}", e);
806 }
807 });
808
809 Ok(handle)
810 }
811
812 #[allow(clippy::unused_async)]
841 pub async fn bind_windows_fallback(&self, bind_ip: IpAddr) -> Result<DnsHandle, DnsError> {
842 self.bind_secondary(SocketAddr::new(bind_ip, 53)).await
843 }
844
845 #[allow(clippy::unused_async)]
860 pub async fn bind_secondary(&self, listen_addr: SocketAddr) -> Result<DnsHandle, DnsError> {
861 let handle = self.handle();
862 let zone_origin = self.zone_origin.clone();
863 let authority = Arc::clone(&self.authority);
864 let upstreams = self.upstreams.clone();
865
866 let udp_socket = UdpSocket::bind(listen_addr).await?;
870 let tcp_listener = TcpListener::bind(listen_addr).await?;
871
872 tokio::spawn(async move {
873 let catalog = Self::build_catalog(zone_origin, authority, &upstreams);
874 let mut server = ServerFuture::new(catalog);
875 server.register_socket(udp_socket);
876 server.register_listener(tcp_listener, Duration::from_secs(30));
877 tracing::info!(
878 addr = %listen_addr,
879 "secondary DNS listener started",
880 );
881 if let Err(e) = server.block_until_done().await {
882 tracing::error!("secondary DNS listener error: {}", e);
883 }
884 });
885
886 Ok(handle)
887 }
888
889 async fn run_server(
891 listen_addr: SocketAddr,
892 zone_origin: Name,
893 authority: Arc<InMemoryAuthority>,
894 upstreams: Vec<SocketAddr>,
895 ) -> Result<(), DnsError> {
896 let catalog = Self::build_catalog(zone_origin, authority, &upstreams);
899
900 let mut server = ServerFuture::new(catalog);
902
903 let udp_socket = UdpSocket::bind(listen_addr).await?;
905 server.register_socket(udp_socket);
906
907 let tcp_listener = TcpListener::bind(listen_addr).await?;
909 server.register_listener(tcp_listener, Duration::from_secs(30));
910
911 tracing::info!(addr = %listen_addr, "DNS server listening");
912
913 server
915 .block_until_done()
916 .await
917 .map_err(|e| DnsError::Server(e.to_string()))?;
918
919 Ok(())
920 }
921
922 #[must_use]
924 pub fn listen_addr(&self) -> SocketAddr {
925 self.listen_addr
926 }
927
928 #[must_use]
930 pub fn zone_origin(&self) -> &Name {
931 &self.zone_origin
932 }
933}
934
/// Minimal client that sends one-off UDP queries to a single DNS server.
pub struct DnsClient {
    /// Server every query is sent to.
    server_addr: SocketAddr,
}
939
940impl DnsClient {
941 #[must_use]
943 pub fn new(server_addr: SocketAddr) -> Self {
944 Self { server_addr }
945 }
946
947 pub fn query_a(&self, hostname: &str) -> Result<Option<Ipv4Addr>, DnsError> {
953 let name = Name::from_str(hostname)
954 .map_err(|e| DnsError::InvalidName(format!("{hostname}: {e}")))?;
955
956 let conn = UdpClientConnection::new(self.server_addr)
957 .map_err(|e| DnsError::Client(e.to_string()))?;
958
959 let client = SyncClient::new(conn);
960
961 let response = client
962 .query(&name, DNSClass::IN, RecordType::A)
963 .map_err(|e| DnsError::Client(e.to_string()))?;
964
965 for answer in response.answers() {
967 if let Some(RData::A(a_record)) = answer.data() {
968 return Ok(Some((*a_record).into()));
969 }
970 }
971
972 Ok(None)
973 }
974
975 pub fn query_aaaa(&self, hostname: &str) -> Result<Option<Ipv6Addr>, DnsError> {
981 let name = Name::from_str(hostname)
982 .map_err(|e| DnsError::InvalidName(format!("{hostname}: {e}")))?;
983
984 let conn = UdpClientConnection::new(self.server_addr)
985 .map_err(|e| DnsError::Client(e.to_string()))?;
986
987 let client = SyncClient::new(conn);
988
989 let response = client
990 .query(&name, DNSClass::IN, RecordType::AAAA)
991 .map_err(|e| DnsError::Client(e.to_string()))?;
992
993 for answer in response.answers() {
995 if let Some(RData::AAAA(aaaa_record)) = answer.data() {
996 return Ok(Some((*aaaa_record).into()));
997 }
998 }
999
1000 Ok(None)
1001 }
1002
1003 pub fn query_addr(&self, hostname: &str) -> Result<Option<IpAddr>, DnsError> {
1011 if let Ok(Some(v4)) = self.query_a(hostname) {
1013 return Ok(Some(IpAddr::V4(v4)));
1014 }
1015
1016 if let Ok(Some(v6)) = self.query_aaaa(hostname) {
1018 return Ok(Some(IpAddr::V6(v6)));
1019 }
1020
1021 Ok(None)
1022 }
1023}
1024
1025pub struct ServiceDiscovery {
1027 dns_server: SocketAddr,
1028 records: RwLock<HashMap<String, IpAddr>>,
1029}
1030
1031impl ServiceDiscovery {
1032 #[must_use]
1034 pub fn new(dns_server_addr: SocketAddr) -> Self {
1035 Self {
1036 dns_server: dns_server_addr,
1037 records: RwLock::new(HashMap::new()),
1038 }
1039 }
1040
1041 pub async fn register(&self, name: &str, ip: IpAddr) {
1043 let mut records = self.records.write().await;
1044 records.insert(name.to_string(), ip);
1045 }
1046
1047 pub async fn resolve(&self, name: &str) -> Option<IpAddr> {
1052 {
1054 let records = self.records.read().await;
1055 if let Some(ip) = records.get(name) {
1056 return Some(*ip);
1057 }
1058 }
1059
1060 let client = DnsClient::new(self.dns_server);
1062 if let Ok(Some(addr)) = client.query_addr(name) {
1063 return Some(addr);
1064 }
1065
1066 None
1067 }
1068
1069 pub async fn unregister(&self, name: &str) {
1071 let mut records = self.records.write().await;
1072 records.remove(name);
1073 }
1074
1075 pub async fn list_services(&self) -> Vec<String> {
1077 let records = self.records.read().await;
1078 records.keys().cloned().collect()
1079 }
1080
1081 pub fn dns_server(&self) -> SocketAddr {
1083 self.dns_server
1084 }
1085}
1086
1087#[cfg(test)]
1088mod tests {
1089 use super::*;
1090
1091 #[test]
1092 fn test_peer_hostname_v4() {
1093 assert_eq!(
1095 peer_hostname(IpAddr::V4(Ipv4Addr::new(10, 200, 0, 1))),
1096 "node-0-1"
1097 );
1098 assert_eq!(
1099 peer_hostname(IpAddr::V4(Ipv4Addr::new(10, 200, 0, 5))),
1100 "node-0-5"
1101 );
1102 assert_eq!(
1103 peer_hostname(IpAddr::V4(Ipv4Addr::new(10, 200, 1, 100))),
1104 "node-1-100"
1105 );
1106 assert_eq!(
1107 peer_hostname(IpAddr::V4(Ipv4Addr::new(192, 168, 255, 254))),
1108 "node-255-254"
1109 );
1110 }
1111
1112 #[test]
1113 fn test_peer_hostname_v6() {
1114 assert_eq!(
1116 peer_hostname(IpAddr::V6("fd00::1".parse().unwrap())),
1117 "node-0001"
1118 );
1119 assert_eq!(
1120 peer_hostname(IpAddr::V6("fd00::abcd".parse().unwrap())),
1121 "node-abcd"
1122 );
1123 assert_eq!(
1124 peer_hostname(IpAddr::V6("fd00:200::ffff".parse().unwrap())),
1125 "node-ffff"
1126 );
1127 assert_eq!(
1129 peer_hostname(IpAddr::V6("fd00::1:0".parse().unwrap())),
1130 "node-0000"
1131 );
1132 }
1133
1134 #[test]
1135 fn test_dns_config() {
1136 let config = DnsConfig::new("overlay.local.", IpAddr::V4(Ipv4Addr::new(10, 200, 0, 1)));
1137 assert_eq!(config.zone, "overlay.local.");
1138 assert_eq!(config.port, DEFAULT_DNS_PORT);
1139 assert_eq!(config.bind_addr, IpAddr::V4(Ipv4Addr::new(10, 200, 0, 1)));
1140
1141 let config = config.with_port(5353);
1143 assert_eq!(config.port, 5353);
1144 }
1145
1146 #[test]
1147 fn test_dns_config_serialization() {
1148 let config = DnsConfig::new("overlay.local.", IpAddr::V4(Ipv4Addr::new(10, 200, 0, 1)))
1149 .with_port(15353);
1150
1151 let json = serde_json::to_string(&config).unwrap();
1152 let deserialized: DnsConfig = serde_json::from_str(&json).unwrap();
1153
1154 assert_eq!(deserialized.zone, config.zone);
1155 assert_eq!(deserialized.port, config.port);
1156 assert_eq!(deserialized.bind_addr, config.bind_addr);
1157 }
1158
1159 #[tokio::test]
1160 async fn test_service_discovery_local_cache() {
1161 let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 15353);
1163 let discovery = ServiceDiscovery::new(addr);
1164
1165 let ip = IpAddr::V4(Ipv4Addr::new(10, 0, 0, 2));
1166 discovery.register("test-service", ip).await;
1167
1168 let resolved = discovery.resolve("test-service").await;
1169 assert_eq!(resolved, Some(ip));
1170
1171 discovery.unregister("test-service").await;
1173 let services = discovery.list_services().await;
1174 assert!(services.is_empty());
1175 }
1176
1177 #[test]
1178 fn test_dns_server_creation() {
1179 let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 15353);
1180 let server = DnsServer::new(addr, "overlay.local.");
1181
1182 assert!(server.is_ok());
1183 let server = server.unwrap();
1184 assert_eq!(server.listen_addr(), addr);
1185 assert_eq!(server.zone_origin().to_string(), "overlay.local.");
1186 }
1187
1188 #[test]
1189 fn test_dns_server_from_config() {
1190 let config =
1191 DnsConfig::new("test.local.", IpAddr::V4(Ipv4Addr::LOCALHOST)).with_port(15353);
1192 let server = DnsServer::from_config(&config);
1193
1194 assert!(server.is_ok());
1195 let server = server.unwrap();
1196 assert_eq!(server.listen_addr().port(), 15353);
1197 assert_eq!(server.zone_origin().to_string(), "test.local.");
1198 }
1199
1200 #[test]
1201 fn test_dns_server_invalid_zone() {
1202 let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 15353);
1203 let server = DnsServer::new(addr, "overlay.local.");
1205 assert!(server.is_ok());
1206 }
1207
1208 #[tokio::test]
1209 async fn test_dns_server_add_record() {
1210 let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 15353);
1211 let server = DnsServer::new(addr, "overlay.local.").unwrap();
1212
1213 let result = server
1214 .add_record("myservice", IpAddr::V4(Ipv4Addr::new(10, 0, 0, 5)))
1215 .await;
1216 assert!(result.is_ok());
1217 }
1218
1219 #[tokio::test]
1220 async fn test_dns_handle_add_record() {
1221 let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 15353);
1222 let server = DnsServer::new(addr, "overlay.local.").unwrap();
1223
1224 let handle = server.handle();
1226
1227 let result = handle
1228 .add_record("service1", IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)))
1229 .await;
1230 assert!(result.is_ok());
1231
1232 let result = handle
1233 .add_record("service2", IpAddr::V4(Ipv4Addr::new(10, 0, 0, 2)))
1234 .await;
1235 assert!(result.is_ok());
1236
1237 assert_eq!(handle.zone_origin().to_string(), "overlay.local.");
1239 }
1240
1241 #[test]
1242 fn test_dns_client_creation() {
1243 let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)), 53);
1244 let client = DnsClient::new(addr);
1245 assert_eq!(client.server_addr, addr);
1246 }
1247
1248 #[tokio::test]
1249 async fn test_dns_handle_add_aaaa_record() {
1250 let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 15353);
1251 let server = DnsServer::new(addr, "overlay.local.").unwrap();
1252 let handle = server.handle();
1253
1254 let ipv6: IpAddr = "fd00::1".parse().unwrap();
1256 let result = handle.add_record("service-v6", ipv6).await;
1257 assert!(result.is_ok());
1258
1259 let ipv6_2: IpAddr = "fd00::abcd".parse().unwrap();
1261 let result = handle.add_record("service-v6-2", ipv6_2).await;
1262 assert!(result.is_ok());
1263 }
1264
1265 #[tokio::test]
1266 async fn test_dns_server_add_aaaa_record() {
1267 let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 15353);
1268 let server = DnsServer::new(addr, "overlay.local.").unwrap();
1269
1270 let ipv6: IpAddr = "fd00::42".parse().unwrap();
1272 let result = server.add_record("myservice-v6", ipv6).await;
1273 assert!(result.is_ok());
1274 }
1275
1276 #[tokio::test]
1277 async fn test_dns_handle_remove_record_covers_both_types() {
1278 let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 15353);
1279 let server = DnsServer::new(addr, "overlay.local.").unwrap();
1280 let handle = server.handle();
1281
1282 let ipv4 = IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1));
1284 handle.add_record("dual-service", ipv4).await.unwrap();
1285
1286 let removed = handle.remove_record("dual-service").await.unwrap();
1288 assert!(removed);
1289
1290 let ipv6: IpAddr = "fd00::1".parse().unwrap();
1292 handle.add_record("v6-service", ipv6).await.unwrap();
1293
1294 let removed = handle.remove_record("v6-service").await.unwrap();
1296 assert!(removed);
1297 }
1298
1299 #[tokio::test]
1300 async fn test_service_discovery_local_cache_ipv6() {
1301 let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 15353);
1302 let discovery = ServiceDiscovery::new(addr);
1303
1304 let ipv6: IpAddr = "fd00::beef".parse().unwrap();
1306 discovery.register("v6-service", ipv6).await;
1307
1308 let resolved = discovery.resolve("v6-service").await;
1310 assert_eq!(resolved, Some(ipv6));
1311
1312 discovery.unregister("v6-service").await;
1314 let services = discovery.list_services().await;
1315 assert!(services.is_empty());
1316 }
1317
1318 #[tokio::test]
1319 async fn test_service_discovery_mixed_v4_v6_cache() {
1320 let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 15353);
1321 let discovery = ServiceDiscovery::new(addr);
1322
1323 let ipv4 = IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1));
1324 let ipv6: IpAddr = "fd00::1".parse().unwrap();
1325
1326 discovery.register("svc-v4", ipv4).await;
1327 discovery.register("svc-v6", ipv6).await;
1328
1329 assert_eq!(discovery.resolve("svc-v4").await, Some(ipv4));
1330 assert_eq!(discovery.resolve("svc-v6").await, Some(ipv6));
1331
1332 let mut services = discovery.list_services().await;
1333 services.sort();
1334 assert_eq!(services, vec!["svc-v4", "svc-v6"]);
1335 }
1336
1337 #[test]
1338 fn test_dns_config_with_ipv6_bind_addr() {
1339 let ipv6_bind: IpAddr = "fd00::1".parse().unwrap();
1340 let config = DnsConfig::new("overlay.local.", ipv6_bind);
1341 assert_eq!(config.bind_addr, ipv6_bind);
1342 assert_eq!(config.port, DEFAULT_DNS_PORT);
1343
1344 let json = serde_json::to_string(&config).unwrap();
1346 let deserialized: DnsConfig = serde_json::from_str(&json).unwrap();
1347 assert_eq!(deserialized.bind_addr, ipv6_bind);
1348 }
1349
1350 #[test]
1351 fn test_dns_server_creation_ipv6_bind() {
1352 let ipv6_addr: IpAddr = "::1".parse().unwrap();
1353 let addr = SocketAddr::new(ipv6_addr, 15353);
1354 let server = DnsServer::new(addr, "overlay.local.");
1355
1356 assert!(server.is_ok());
1357 let server = server.unwrap();
1358 assert_eq!(server.listen_addr(), addr);
1359 }
1360
1361 #[tokio::test]
1368 async fn test_bind_windows_fallback_errors_or_shares_authority() {
1369 let primary = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0);
1370 let server = DnsServer::new(primary, "overlay.local.").unwrap();
1371 let bind_ip: IpAddr = "127.0.0.2".parse().unwrap();
1372
1373 match server.bind_windows_fallback(bind_ip).await {
1374 Ok(handle) => {
1375 assert_eq!(handle.zone_origin().to_string(), "overlay.local.");
1379 handle
1380 .add_record("dual", IpAddr::V4(Ipv4Addr::new(10, 0, 0, 9)))
1381 .await
1382 .expect("add_record via fallback handle");
1383 }
1384 Err(DnsError::Io(_)) => {
1385 }
1389 Err(other) => panic!("unexpected error from bind_windows_fallback: {other}"),
1390 }
1391 }
1392
1393 #[test]
1394 fn test_peer_hostname_uniqueness() {
1395 let v4_a = peer_hostname(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)));
1397 let v4_b = peer_hostname(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 2)));
1398 assert_ne!(v4_a, v4_b);
1399
1400 let v6_a = peer_hostname(IpAddr::V6("fd00::1".parse().unwrap()));
1401 let v6_b = peer_hostname(IpAddr::V6("fd00::2".parse().unwrap()));
1402 assert_ne!(v6_a, v6_b);
1403
1404 let v4 = peer_hostname(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)));
1406 let v6 = peer_hostname(IpAddr::V6("fd00::1".parse().unwrap()));
1407 assert_ne!(v4, v6);
1408 }
1409
1410 #[test]
1413 fn test_parse_resolv_conf_filters_stub_and_loopback() {
1414 let contents = "\
1417 # generated by netbird\n\
1418 nameserver 127.0.0.53\n\
1419 nameserver 127.0.0.1\n\
1420 nameserver 192.168.1.1\n\
1421 search example.com\n\
1422 options edns0\n";
1423 let parsed = parse_resolv_conf(contents);
1424 assert_eq!(
1425 parsed,
1426 vec![SocketAddr::new(
1427 IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)),
1428 53
1429 )],
1430 "127.0.0.53 stub and 127.0.0.1 loopback must be filtered out",
1431 );
1432 }
1433
1434 #[test]
1435 fn test_parse_resolv_conf_dedup_and_comments() {
1436 let contents = "\
1437 ; a comment\n\
1438 nameserver 8.8.8.8\n\
1439 nameserver 8.8.8.8\n\
1440 nameserver fe80::1%eth0\n\
1441 nameserver 0.0.0.0\n";
1442 let parsed = parse_resolv_conf(contents);
1443 assert_eq!(parsed.len(), 2);
1446 assert_eq!(
1447 parsed[0],
1448 SocketAddr::new(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)), 53)
1449 );
1450 assert_eq!(parsed[1].ip(), "fe80::1".parse::<IpAddr>().unwrap());
1451 }
1452
1453 #[test]
1454 fn test_resolve_upstreams_config_override_wins() {
1455 let explicit = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(10, 9, 9, 9)), 5300);
1458 let config = DnsConfig::new("overlay.local.", IpAddr::V4(Ipv4Addr::LOCALHOST))
1459 .with_upstreams(vec![explicit]);
1460 let resolved = resolve_upstreams(&config, "/nonexistent/resolv.conf");
1461 assert_eq!(resolved, vec![explicit]);
1462 }
1463
1464 #[test]
1465 fn test_resolve_upstreams_falls_back_to_public_when_missing() {
1466 let config = DnsConfig::new("overlay.local.", IpAddr::V4(Ipv4Addr::LOCALHOST));
1468 let resolved = resolve_upstreams(&config, "/definitely/not/a/real/resolv.conf");
1469 assert_eq!(
1470 resolved,
1471 vec![
1472 SocketAddr::new(IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1)), 53),
1473 SocketAddr::new(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)), 53),
1474 ],
1475 );
1476 }
1477
1478 async fn spawn_stub_upstream(answer_ip: Ipv4Addr) -> SocketAddr {
1486 use hickory_server::proto::op::{Message, MessageType, ResponseCode};
1487
1488 let sock = UdpSocket::bind(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0))
1489 .await
1490 .expect("bind stub upstream");
1491 let addr = sock.local_addr().expect("stub local_addr");
1492
1493 tokio::spawn(async move {
1494 let mut buf = vec![0u8; 1500];
1495 loop {
1496 let Ok((len, from)) = sock.recv_from(&mut buf).await else {
1497 break;
1498 };
1499 let Ok(request) = Message::from_vec(&buf[..len]) else {
1500 continue;
1501 };
1502 let mut resp = Message::new();
1503 resp.set_id(request.id());
1504 resp.set_message_type(MessageType::Response);
1505 resp.set_recursion_available(true);
1506 resp.set_response_code(ResponseCode::NoError);
1507 for q in request.queries() {
1508 resp.add_query(q.clone());
1509 if q.query_type() == RecordType::A {
1510 let rec =
1511 Record::from_rdata(q.name().clone(), 60, RData::A(A::from(answer_ip)));
1512 resp.add_answer(rec);
1513 }
1514 }
1515 if let Ok(bytes) = resp.to_vec() {
1516 let _ = sock.send_to(&bytes, from).await;
1517 }
1518 }
1519 });
1520
1521 addr
1522 }
1523
1524 async fn raw_query_a(
1528 server: SocketAddr,
1529 name: &str,
1530 ) -> Result<Option<Ipv4Addr>, hickory_server::proto::op::ResponseCode> {
1531 use hickory_server::proto::op::{Message, MessageType, Query, ResponseCode};
1532
1533 let client = UdpSocket::bind(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0))
1534 .await
1535 .expect("bind client");
1536
1537 let qname = Name::from_str(name).expect("query name");
1538 let mut msg = Message::new();
1539 msg.set_id(0x1234);
1540 msg.set_message_type(MessageType::Query);
1541 msg.set_recursion_desired(true);
1542 msg.add_query(Query::query(qname, RecordType::A));
1543 let bytes = msg.to_vec().expect("encode query");
1544
1545 client.send_to(&bytes, server).await.expect("send query");
1546
1547 let mut buf = vec![0u8; 1500];
1548 let len = tokio::time::timeout(Duration::from_secs(12), client.recv(&mut buf))
1553 .await
1554 .expect("query timed out")
1555 .expect("recv response");
1556 let resp = Message::from_vec(&buf[..len]).expect("decode response");
1557
1558 if resp.response_code() != ResponseCode::NoError {
1559 return Err(resp.response_code());
1560 }
1561 for ans in resp.answers() {
1562 if let Some(RData::A(a)) = ans.data() {
1563 return Ok(Some((*a).into()));
1564 }
1565 }
1566 Ok(None)
1567 }
1568
1569 #[tokio::test]
1570 async fn test_forwarding_overlay_answered_and_nonoverlay_forwarded() {
1571 let upstream_answer = Ipv4Addr::new(203, 0, 113, 7);
1573 let upstream = spawn_stub_upstream(upstream_answer).await;
1574
1575 let bound = {
1579 let probe = UdpSocket::bind(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0))
1580 .await
1581 .unwrap();
1582 let a = probe.local_addr().unwrap();
1583 drop(probe);
1584 a
1585 };
1586
1587 let overlay_ip = Ipv4Addr::new(10, 200, 0, 5);
1590 let server =
1591 DnsServer::new_with_upstreams(bound, "overlay.local.", vec![upstream]).unwrap();
1592 let handle = server.handle();
1593 handle
1594 .add_record("svc", IpAddr::V4(overlay_ip))
1595 .await
1596 .unwrap();
1597 let _running = server.start().await.unwrap();
1598
1599 tokio::time::sleep(Duration::from_millis(150)).await;
1601
1602 let overlay = raw_query_a(bound, "svc.overlay.local.")
1604 .await
1605 .expect("overlay query should not SERVFAIL");
1606 assert_eq!(
1607 overlay,
1608 Some(overlay_ip),
1609 "overlay name must be answered from InMemoryAuthority",
1610 );
1611
1612 let forwarded = raw_query_a(bound, "example.com.")
1614 .await
1615 .expect("forwarded query should not SERVFAIL");
1616 assert_eq!(
1617 forwarded,
1618 Some(upstream_answer),
1619 "non-overlay name must be forwarded to the upstream stub",
1620 );
1621 }
1622
1623 #[tokio::test]
1624 async fn test_forwarding_total_upstream_failure_is_servfail_not_panic() {
1625 use hickory_server::proto::op::ResponseCode;
1626
1627 let dead_upstream = {
1631 let s = UdpSocket::bind(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0))
1633 .await
1634 .unwrap();
1635 let a = s.local_addr().unwrap();
1636 drop(s);
1637 a
1638 };
1639
1640 let bound = {
1641 let s = UdpSocket::bind(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0))
1642 .await
1643 .unwrap();
1644 let a = s.local_addr().unwrap();
1645 drop(s);
1646 a
1647 };
1648
1649 let server =
1650 DnsServer::new_with_upstreams(bound, "overlay.local.", vec![dead_upstream]).unwrap();
1651 let handle = server.handle();
1652 handle
1653 .add_record("svc", IpAddr::V4(Ipv4Addr::new(10, 200, 0, 9)))
1654 .await
1655 .unwrap();
1656 let _running = server.start().await.unwrap();
1657 tokio::time::sleep(Duration::from_millis(150)).await;
1658
1659 let overlay = raw_query_a(bound, "svc.overlay.local.")
1661 .await
1662 .expect("overlay query should still succeed");
1663 assert_eq!(overlay, Some(Ipv4Addr::new(10, 200, 0, 9)));
1664
1665 match raw_query_a(bound, "example.com.").await {
1668 Err(ResponseCode::ServFail) => {} Err(other) => panic!("expected SERVFAIL, got {other:?}"),
1670 Ok(answer) => panic!("expected SERVFAIL, got answer {answer:?}"),
1671 }
1672 }
1673}