use hickory_client::client::{Client, SyncClient};
use hickory_client::udp::UdpClientConnection;
use hickory_server::authority::{Catalog, ZoneType};
use hickory_server::proto::rr::rdata::{A, AAAA};
use hickory_server::proto::rr::{DNSClass, LowerName, Name, RData, Record, RecordType, RrKey};
use hickory_server::server::ServerFuture;
use hickory_server::store::in_memory::InMemoryAuthority;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use tokio::net::{TcpListener, UdpSocket};
use tokio::sync::RwLock;
18
/// Port the overlay DNS server listens on when none is configured.
///
/// It sits above the privileged port range (< 1024).
pub const DEFAULT_DNS_PORT: u16 = 15_353;
21
22#[derive(Debug, Clone, Serialize, Deserialize)]
24pub struct DnsConfig {
25 pub zone: String,
27 pub port: u16,
29 pub bind_addr: IpAddr,
31}
32
33impl DnsConfig {
34 #[must_use]
36 pub fn new(zone: &str, bind_addr: IpAddr) -> Self {
37 Self {
38 zone: zone.to_string(),
39 port: DEFAULT_DNS_PORT,
40 bind_addr,
41 }
42 }
43
44 #[must_use]
46 pub fn with_port(mut self, port: u16) -> Self {
47 self.port = port;
48 self
49 }
50}
51
/// Derives a short hostname for an overlay peer from its IP address.
///
/// * IPv4: `node-<third octet>-<fourth octet>` in decimal, e.g. `10.200.1.100` →
///   `node-1-100`.
/// * IPv6: `node-<last segment>` as four zero-padded lowercase hex digits, e.g.
///   `fd00::abcd` → `node-abcd`.
///
/// Only the trailing part of the address is used. Peers whose addresses differ only
/// in earlier octets or segments map to the same name.
#[must_use]
pub fn peer_hostname(ip: IpAddr) -> String {
    match ip {
        IpAddr::V4(addr) => {
            let [_, _, third, fourth] = addr.octets();
            format!("node-{third}-{fourth}")
        }
        IpAddr::V6(addr) => {
            let [.., last] = addr.segments();
            format!("node-{last:04x}")
        }
    }
}
70
71#[derive(Debug, thiserror::Error)]
73pub enum DnsError {
74 #[error("Invalid domain name: {0}")]
75 InvalidName(String),
76
77 #[error("DNS server error: {0}")]
78 Server(String),
79
80 #[error("DNS client error: {0}")]
81 Client(String),
82
83 #[error("IO error: {0}")]
84 Io(#[from] std::io::Error),
85
86 #[error("Record not found: {0}")]
87 NotFound(String),
88}
89
90#[derive(Clone)]
94pub struct DnsHandle {
95 authority: Arc<InMemoryAuthority>,
96 zone_origin: Name,
97 serial: Arc<RwLock<u32>>,
98}
99
100impl DnsHandle {
101 pub async fn add_record(&self, hostname: &str, ip: IpAddr) -> Result<(), DnsError> {
109 let fqdn = if hostname.ends_with('.') {
111 Name::from_str(hostname)
112 .map_err(|e| DnsError::InvalidName(format!("{hostname}: {e}")))?
113 } else {
114 let name = Name::from_str(hostname)
116 .map_err(|e| DnsError::InvalidName(format!("{hostname}: {e}")))?;
117 name.append_domain(&self.zone_origin)
118 .map_err(|e| DnsError::InvalidName(format!("Failed to append zone: {e}")))?
119 };
120
121 let rdata = match ip {
123 IpAddr::V4(v4) => RData::A(A::from(v4)),
124 IpAddr::V6(v6) => RData::AAAA(AAAA::from(v6)),
125 };
126 let record = Record::from_rdata(fqdn, 300, rdata); let serial = {
130 let mut s = self.serial.write().await;
131 let current = *s;
132 *s = s.wrapping_add(1);
133 current
134 };
135
136 self.authority.upsert(record, serial).await;
138
139 Ok(())
140 }
141
142 pub async fn remove_record(&self, hostname: &str) -> Result<bool, DnsError> {
150 let fqdn = if hostname.ends_with('.') {
151 Name::from_str(hostname)
152 .map_err(|e| DnsError::InvalidName(format!("{hostname}: {e}")))?
153 } else {
154 let name = Name::from_str(hostname)
155 .map_err(|e| DnsError::InvalidName(format!("{hostname}: {e}")))?;
156 name.append_domain(&self.zone_origin)
157 .map_err(|e| DnsError::InvalidName(format!("Failed to append zone: {e}")))?
158 };
159
160 let serial = {
161 let mut s = self.serial.write().await;
162 let current = *s;
163 *s = s.wrapping_add(1);
164 current
165 };
166
167 let a_record = Record::with(fqdn.clone(), RecordType::A, 0);
171 self.authority.upsert(a_record, serial).await;
172
173 let aaaa_record = Record::with(fqdn.clone(), RecordType::AAAA, 0);
174 self.authority.upsert(aaaa_record, serial).await;
175
176 Ok(true)
177 }
178
179 #[must_use]
181 pub fn zone_origin(&self) -> &Name {
182 &self.zone_origin
183 }
184}
185
186pub struct DnsServer {
188 listen_addr: SocketAddr,
189 authority: Arc<InMemoryAuthority>,
190 zone_origin: Name,
191 serial: Arc<RwLock<u32>>,
192}
193
194impl DnsServer {
195 pub fn new(listen_addr: SocketAddr, zone: &str) -> Result<Self, DnsError> {
201 let zone_origin =
202 Name::from_str(zone).map_err(|e| DnsError::InvalidName(format!("{zone}: {e}")))?;
203
204 let authority = Arc::new(InMemoryAuthority::empty(
207 zone_origin.clone(),
208 ZoneType::Primary,
209 false,
210 ));
211
212 Ok(Self {
213 listen_addr,
214 authority,
215 zone_origin,
216 serial: Arc::new(RwLock::new(1)),
217 })
218 }
219
220 pub fn from_config(config: &DnsConfig) -> Result<Self, DnsError> {
226 let listen_addr = SocketAddr::new(config.bind_addr, config.port);
227 Self::new(listen_addr, &config.zone)
228 }
229
230 #[must_use]
235 pub fn handle(&self) -> DnsHandle {
236 DnsHandle {
237 authority: Arc::clone(&self.authority),
238 zone_origin: self.zone_origin.clone(),
239 serial: Arc::clone(&self.serial),
240 }
241 }
242
243 pub async fn add_record(&self, hostname: &str, ip: IpAddr) -> Result<(), DnsError> {
251 self.handle().add_record(hostname, ip).await
252 }
253
254 pub async fn remove_record(&self, hostname: &str) -> Result<bool, DnsError> {
260 self.handle().remove_record(hostname).await
261 }
262
263 #[allow(clippy::unused_async)]
272 pub async fn start(self) -> Result<DnsHandle, DnsError> {
273 let handle = self.handle();
274 let listen_addr = self.listen_addr;
275 let zone_origin = self.zone_origin.clone();
276 let authority = Arc::clone(&self.authority);
277
278 tokio::spawn(async move {
280 if let Err(e) = Self::run_server(listen_addr, zone_origin, authority).await {
281 tracing::error!("DNS server error: {}", e);
282 }
283 });
284
285 Ok(handle)
286 }
287
288 #[allow(clippy::unused_async)]
298 pub async fn start_background(&self) -> Result<DnsHandle, DnsError> {
299 let handle = self.handle();
300 let listen_addr = self.listen_addr;
301 let zone_origin = self.zone_origin.clone();
302 let authority = Arc::clone(&self.authority);
303
304 tokio::spawn(async move {
305 if let Err(e) = Self::run_server(listen_addr, zone_origin, authority).await {
306 tracing::error!("DNS server error: {}", e);
307 }
308 });
309
310 Ok(handle)
311 }
312
313 async fn run_server(
315 listen_addr: SocketAddr,
316 zone_origin: Name,
317 authority: Arc<InMemoryAuthority>,
318 ) -> Result<(), DnsError> {
319 let mut catalog = Catalog::new();
321
322 catalog.upsert(zone_origin.into(), Box::new(authority));
324
325 let mut server = ServerFuture::new(catalog);
327
328 let udp_socket = UdpSocket::bind(listen_addr).await?;
330 server.register_socket(udp_socket);
331
332 let tcp_listener = TcpListener::bind(listen_addr).await?;
334 server.register_listener(tcp_listener, Duration::from_secs(30));
335
336 tracing::info!(addr = %listen_addr, "DNS server listening");
337
338 server
340 .block_until_done()
341 .await
342 .map_err(|e| DnsError::Server(e.to_string()))?;
343
344 Ok(())
345 }
346
347 #[must_use]
349 pub fn listen_addr(&self) -> SocketAddr {
350 self.listen_addr
351 }
352
353 #[must_use]
355 pub fn zone_origin(&self) -> &Name {
356 &self.zone_origin
357 }
358}
359
/// Minimal synchronous DNS client that sends A/AAAA queries over UDP to one server.
///
/// Derives `Debug` and `Clone`, so it can be logged and shared like other public types.
#[derive(Debug, Clone)]
pub struct DnsClient {
    /// Server every query is sent to.
    server_addr: SocketAddr,
}
364
365impl DnsClient {
366 #[must_use]
368 pub fn new(server_addr: SocketAddr) -> Self {
369 Self { server_addr }
370 }
371
372 pub fn query_a(&self, hostname: &str) -> Result<Option<Ipv4Addr>, DnsError> {
378 let name = Name::from_str(hostname)
379 .map_err(|e| DnsError::InvalidName(format!("{hostname}: {e}")))?;
380
381 let conn = UdpClientConnection::new(self.server_addr)
382 .map_err(|e| DnsError::Client(e.to_string()))?;
383
384 let client = SyncClient::new(conn);
385
386 let response = client
387 .query(&name, DNSClass::IN, RecordType::A)
388 .map_err(|e| DnsError::Client(e.to_string()))?;
389
390 for answer in response.answers() {
392 if let Some(RData::A(a_record)) = answer.data() {
393 return Ok(Some((*a_record).into()));
394 }
395 }
396
397 Ok(None)
398 }
399
400 pub fn query_aaaa(&self, hostname: &str) -> Result<Option<Ipv6Addr>, DnsError> {
406 let name = Name::from_str(hostname)
407 .map_err(|e| DnsError::InvalidName(format!("{hostname}: {e}")))?;
408
409 let conn = UdpClientConnection::new(self.server_addr)
410 .map_err(|e| DnsError::Client(e.to_string()))?;
411
412 let client = SyncClient::new(conn);
413
414 let response = client
415 .query(&name, DNSClass::IN, RecordType::AAAA)
416 .map_err(|e| DnsError::Client(e.to_string()))?;
417
418 for answer in response.answers() {
420 if let Some(RData::AAAA(aaaa_record)) = answer.data() {
421 return Ok(Some((*aaaa_record).into()));
422 }
423 }
424
425 Ok(None)
426 }
427
428 pub fn query_addr(&self, hostname: &str) -> Result<Option<IpAddr>, DnsError> {
436 if let Ok(Some(v4)) = self.query_a(hostname) {
438 return Ok(Some(IpAddr::V4(v4)));
439 }
440
441 if let Ok(Some(v6)) = self.query_aaaa(hostname) {
443 return Ok(Some(IpAddr::V6(v6)));
444 }
445
446 Ok(None)
447 }
448}
449
450pub struct ServiceDiscovery {
452 dns_server: SocketAddr,
453 records: RwLock<HashMap<String, IpAddr>>,
454}
455
456impl ServiceDiscovery {
457 #[must_use]
459 pub fn new(dns_server_addr: SocketAddr) -> Self {
460 Self {
461 dns_server: dns_server_addr,
462 records: RwLock::new(HashMap::new()),
463 }
464 }
465
466 pub async fn register(&self, name: &str, ip: IpAddr) {
468 let mut records = self.records.write().await;
469 records.insert(name.to_string(), ip);
470 }
471
472 pub async fn resolve(&self, name: &str) -> Option<IpAddr> {
477 {
479 let records = self.records.read().await;
480 if let Some(ip) = records.get(name) {
481 return Some(*ip);
482 }
483 }
484
485 let client = DnsClient::new(self.dns_server);
487 if let Ok(Some(addr)) = client.query_addr(name) {
488 return Some(addr);
489 }
490
491 None
492 }
493
494 pub async fn unregister(&self, name: &str) {
496 let mut records = self.records.write().await;
497 records.remove(name);
498 }
499
500 pub async fn list_services(&self) -> Vec<String> {
502 let records = self.records.read().await;
503 records.keys().cloned().collect()
504 }
505
506 pub fn dns_server(&self) -> SocketAddr {
508 self.dns_server
509 }
510}
511
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_peer_hostname_v4() {
        assert_eq!(
            peer_hostname(IpAddr::V4(Ipv4Addr::new(10, 200, 0, 1))),
            "node-0-1"
        );
        assert_eq!(
            peer_hostname(IpAddr::V4(Ipv4Addr::new(10, 200, 0, 5))),
            "node-0-5"
        );
        assert_eq!(
            peer_hostname(IpAddr::V4(Ipv4Addr::new(10, 200, 1, 100))),
            "node-1-100"
        );
        assert_eq!(
            peer_hostname(IpAddr::V4(Ipv4Addr::new(192, 168, 255, 254))),
            "node-255-254"
        );
    }

    #[test]
    fn test_peer_hostname_v6() {
        assert_eq!(
            peer_hostname(IpAddr::V6("fd00::1".parse().unwrap())),
            "node-0001"
        );
        assert_eq!(
            peer_hostname(IpAddr::V6("fd00::abcd".parse().unwrap())),
            "node-abcd"
        );
        assert_eq!(
            peer_hostname(IpAddr::V6("fd00:200::ffff".parse().unwrap())),
            "node-ffff"
        );
        assert_eq!(
            peer_hostname(IpAddr::V6("fd00::1:0".parse().unwrap())),
            "node-0000"
        );
    }

    #[test]
    fn test_dns_config() {
        let config = DnsConfig::new("overlay.local.", IpAddr::V4(Ipv4Addr::new(10, 200, 0, 1)));
        assert_eq!(config.zone, "overlay.local.");
        assert_eq!(config.port, DEFAULT_DNS_PORT);
        assert_eq!(config.bind_addr, IpAddr::V4(Ipv4Addr::new(10, 200, 0, 1)));

        let config = config.with_port(5353);
        assert_eq!(config.port, 5353);
    }

    #[test]
    fn test_dns_config_serialization() {
        let config = DnsConfig::new("overlay.local.", IpAddr::V4(Ipv4Addr::new(10, 200, 0, 1)))
            .with_port(15353);

        let json = serde_json::to_string(&config).unwrap();
        let deserialized: DnsConfig = serde_json::from_str(&json).unwrap();

        assert_eq!(deserialized.zone, config.zone);
        assert_eq!(deserialized.port, config.port);
        assert_eq!(deserialized.bind_addr, config.bind_addr);
    }

    #[tokio::test]
    async fn test_service_discovery_local_cache() {
        let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 15353);
        let discovery = ServiceDiscovery::new(addr);

        let ip = IpAddr::V4(Ipv4Addr::new(10, 0, 0, 2));
        discovery.register("test-service", ip).await;

        let resolved = discovery.resolve("test-service").await;
        assert_eq!(resolved, Some(ip));

        discovery.unregister("test-service").await;
        let services = discovery.list_services().await;
        assert!(services.is_empty());
    }

    #[test]
    fn test_dns_server_creation() {
        let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 15353);
        let server = DnsServer::new(addr, "overlay.local.");

        assert!(server.is_ok());
        let server = server.unwrap();
        assert_eq!(server.listen_addr(), addr);
        assert_eq!(server.zone_origin().to_string(), "overlay.local.");
    }

    #[test]
    fn test_dns_server_from_config() {
        let config =
            DnsConfig::new("test.local.", IpAddr::V4(Ipv4Addr::LOCALHOST)).with_port(15353);
        let server = DnsServer::from_config(&config);

        assert!(server.is_ok());
        let server = server.unwrap();
        assert_eq!(server.listen_addr().port(), 15353);
        assert_eq!(server.zone_origin().to_string(), "test.local.");
    }

    #[test]
    fn test_dns_server_invalid_zone() {
        let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 15353);
        // A single DNS label may be at most 63 bytes long, so a 64-byte label is invalid.
        let zone = format!("{}.local.", "a".repeat(64));
        let server = DnsServer::new(addr, &zone);

        assert!(matches!(server, Err(DnsError::InvalidName(_))));
    }

    #[tokio::test]
    async fn test_dns_server_add_record() {
        let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 15353);
        let server = DnsServer::new(addr, "overlay.local.").unwrap();

        let result = server
            .add_record("myservice", IpAddr::V4(Ipv4Addr::new(10, 0, 0, 5)))
            .await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_dns_handle_add_record() {
        let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 15353);
        let server = DnsServer::new(addr, "overlay.local.").unwrap();

        let handle = server.handle();

        let result = handle
            .add_record("service1", IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)))
            .await;
        assert!(result.is_ok());

        let result = handle
            .add_record("service2", IpAddr::V4(Ipv4Addr::new(10, 0, 0, 2)))
            .await;
        assert!(result.is_ok());

        assert_eq!(handle.zone_origin().to_string(), "overlay.local.");
    }

    #[test]
    fn test_dns_client_creation() {
        let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)), 53);
        let client = DnsClient::new(addr);
        assert_eq!(client.server_addr, addr);
    }

    #[tokio::test]
    async fn test_dns_handle_add_aaaa_record() {
        let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 15353);
        let server = DnsServer::new(addr, "overlay.local.").unwrap();
        let handle = server.handle();

        let ipv6: IpAddr = "fd00::1".parse().unwrap();
        let result = handle.add_record("service-v6", ipv6).await;
        assert!(result.is_ok());

        let ipv6_2: IpAddr = "fd00::abcd".parse().unwrap();
        let result = handle.add_record("service-v6-2", ipv6_2).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_dns_server_add_aaaa_record() {
        let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 15353);
        let server = DnsServer::new(addr, "overlay.local.").unwrap();

        let ipv6: IpAddr = "fd00::42".parse().unwrap();
        let result = server.add_record("myservice-v6", ipv6).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_dns_handle_remove_record_covers_both_types() {
        let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 15353);
        let server = DnsServer::new(addr, "overlay.local.").unwrap();
        let handle = server.handle();

        let ipv4 = IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1));
        handle.add_record("dual-service", ipv4).await.unwrap();

        let removed = handle.remove_record("dual-service").await.unwrap();
        assert!(removed);

        let ipv6: IpAddr = "fd00::1".parse().unwrap();
        handle.add_record("v6-service", ipv6).await.unwrap();

        let removed = handle.remove_record("v6-service").await.unwrap();
        assert!(removed);
    }

    #[tokio::test]
    async fn test_service_discovery_local_cache_ipv6() {
        let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 15353);
        let discovery = ServiceDiscovery::new(addr);

        let ipv6: IpAddr = "fd00::beef".parse().unwrap();
        discovery.register("v6-service", ipv6).await;

        let resolved = discovery.resolve("v6-service").await;
        assert_eq!(resolved, Some(ipv6));

        discovery.unregister("v6-service").await;
        let services = discovery.list_services().await;
        assert!(services.is_empty());
    }

    #[tokio::test]
    async fn test_service_discovery_mixed_v4_v6_cache() {
        let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 15353);
        let discovery = ServiceDiscovery::new(addr);

        let ipv4 = IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1));
        let ipv6: IpAddr = "fd00::1".parse().unwrap();

        discovery.register("svc-v4", ipv4).await;
        discovery.register("svc-v6", ipv6).await;

        assert_eq!(discovery.resolve("svc-v4").await, Some(ipv4));
        assert_eq!(discovery.resolve("svc-v6").await, Some(ipv6));

        let mut services = discovery.list_services().await;
        services.sort();
        assert_eq!(services, vec!["svc-v4", "svc-v6"]);
    }

    #[test]
    fn test_dns_config_with_ipv6_bind_addr() {
        let ipv6_bind: IpAddr = "fd00::1".parse().unwrap();
        let config = DnsConfig::new("overlay.local.", ipv6_bind);
        assert_eq!(config.bind_addr, ipv6_bind);
        assert_eq!(config.port, DEFAULT_DNS_PORT);

        let json = serde_json::to_string(&config).unwrap();
        let deserialized: DnsConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.bind_addr, ipv6_bind);
    }

    #[test]
    fn test_dns_server_creation_ipv6_bind() {
        let ipv6_addr: IpAddr = "::1".parse().unwrap();
        let addr = SocketAddr::new(ipv6_addr, 15353);
        let server = DnsServer::new(addr, "overlay.local.");

        assert!(server.is_ok());
        let server = server.unwrap();
        assert_eq!(server.listen_addr(), addr);
    }

    #[test]
    fn test_peer_hostname_uniqueness() {
        let v4_a = peer_hostname(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)));
        let v4_b = peer_hostname(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 2)));
        assert_ne!(v4_a, v4_b);

        let v6_a = peer_hostname(IpAddr::V6("fd00::1".parse().unwrap()));
        let v6_b = peer_hostname(IpAddr::V6("fd00::2".parse().unwrap()));
        assert_ne!(v6_a, v6_b);

        let v4 = peer_hostname(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)));
        let v6 = peer_hostname(IpAddr::V6("fd00::1".parse().unwrap()));
        assert_ne!(v4, v6);
    }
}