1pub mod dns_parser;
2pub mod mac_addr;
3pub mod net_utils;
4
5pub use dns_parser::QueryType;
6use dns_parser::{Builder, Packet, QueryClass, RData, ResourceRecord};
7use smol::channel::{bounded, Receiver};
8use smol::net::UdpSocket;
9use smol::prelude::*;
10use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
11use std::time::{Duration, Instant};
12use thiserror::*;
13
/// Multicast group that queries are sent to; `create_socket` joins this group.
const MULTICAST_ADDR: Ipv4Addr = Ipv4Addr::new(224, 0, 0, 251);
/// UDP port the resolver socket binds to and sends queries to; responses whose
/// source port differs are rejected by `valid_source_address`.
const MULTICAST_PORT: u16 = 5353;
16
/// Errors produced by the mDNS resolver in this module.
#[derive(Debug, Error)]
pub enum Error {
    /// A socket operation (create, bind, send, receive, multicast setup) failed.
    #[error(transparent)]
    Io(#[from] std::io::Error),
    /// Receiving from the response channel failed: the channel is empty and
    /// its sending side (the background query task) has gone away.
    #[error(transparent)]
    ChanRecv(#[from] smol::channel::RecvError),
    /// Forwarding a response failed because the receiving side was dropped.
    #[error(transparent)]
    ChanSend(#[from] smol::channel::SendError<Response>),
    /// The DNS query packet could not be encoded (see `make_query`).
    #[error("failed to build DNS packet")]
    DnsPacketBuildError,
    /// NOTE(review): not constructed anywhere in the visible code; presumably
    /// reserved for callers or other modules — confirm before relying on it.
    #[error("Timed out")]
    Timeout,
    /// Returned by `resolve` when both `base_repeat_interval` and
    /// `timeout_after` are `None`.
    #[error("The QueryParameters were invalid")]
    InvalidQueryParams,
    /// NOTE(review): not constructed anywhere in the visible code; presumably
    /// used by the `net_utils` helpers or callers — confirm.
    #[error("Unable to determine local interface")]
    LocalInterfaceUnknown,
}
35
/// Result type used throughout this module, with [`Error`] as the error type.
pub type Result<T> = std::result::Result<T, Error>;
37
/// Combines an IPv4 address and a port into a [`SocketAddr`].
fn sockaddr(ip: Ipv4Addr, port: u16) -> SocketAddr {
    SocketAddr::V4(std::net::SocketAddrV4::new(ip, port))
}
42
/// Opens the UDP socket used both to send queries and to receive responses.
///
/// The socket is an IPv4 UDP socket bound to `0.0.0.0:MULTICAST_PORT` with
/// `SO_REUSEADDR` enabled and `SO_REUSEPORT` attempted (presumably so it can
/// coexist with other mDNS software on this host — confirm). It is wrapped
/// for async use, has multicast loopback disabled, and joins the
/// `MULTICAST_ADDR` group with an unspecified interface (the OS chooses).
///
/// # Errors
///
/// Returns [`Error::Io`] if creating, configuring, binding or registering
/// the socket fails (except `set_reuse_port`, whose failure is ignored).
async fn create_socket() -> Result<UdpSocket> {
    let socket = socket2::Socket::new(
        socket2::Domain::IPV4,
        socket2::Type::DGRAM,
        Some(socket2::Protocol::UDP),
    )?;
    // Reuse options must be set before `bind` to have any effect.
    socket.set_reuse_address(true)?;
    // Best effort: deliberately ignore failure here.
    let _ = socket.set_reuse_port(true);

    let addr = sockaddr(Ipv4Addr::UNSPECIFIED, MULTICAST_PORT);
    socket.bind(&addr.into())?;

    // Convert to a std socket, then register it with smol's reactor.
    let socket = UdpSocket::from(smol::Async::new(socket.into())?);
    // Do not loop our own multicast transmissions back to local listeners.
    socket.set_multicast_loop_v4(false)?;
    socket.join_multicast_v4(MULTICAST_ADDR, Ipv4Addr::UNSPECIFIED)?;
    Ok(socket)
}
60
/// Owned copy of the resource records in one parsed DNS packet, grouped by
/// the section they appeared in.
#[derive(Debug)]
pub struct Response {
    /// Records from the packet's answer section.
    pub answers: Vec<Record>,
    /// Records from the packet's nameserver (authority) section.
    pub nameservers: Vec<Record>,
    /// Records from the packet's additional section.
    pub additional: Vec<Record>,
}
68
/// A host or service instance extracted from a [`Response`] by `Response::hosts`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Host {
    /// Name of the answer this entry came from: the owner name of an
    /// `A`/`AAAA` answer, or the name a `PTR` answer points to.
    pub name: String,
    /// Machine name: the owner name for `A`/`AAAA` answers, or the `SRV`
    /// target for `PTR` answers (`None` when no `SRV` record was found).
    pub host_name: Option<String>,
    /// Addresses found for `host_name`.
    pub ip_address: Vec<IpAddr>,
    /// Each address in `ip_address` paired with the `SRV` port; empty when
    /// no port is known.
    pub socket_address: Vec<SocketAddr>,
    /// Moment after which this entry should be considered stale (derived
    /// from the TTL of the originating answer).
    pub expires: Instant,
}

impl Host {
    /// Reports whether `expires` still lies in the future.
    pub fn valid(&self) -> bool {
        let now = Instant::now();
        self.expires > now
    }
}
93
94impl Response {
95 fn new(packet: &Packet) -> Self {
96 Self {
97 answers: packet.answers.iter().map(Record::new).collect(),
98 nameservers: packet.nameservers.iter().map(Record::new).collect(),
99 additional: packet.additional.iter().map(Record::new).collect(),
100 }
101 }
102
103 fn all_records(&self) -> impl Iterator<Item = &Record> {
104 self.answers
105 .iter()
106 .chain(self.additional.iter())
107 .chain(self.nameservers.iter())
108 }
109
110 pub fn hosts(&self) -> Vec<Host> {
112 let mut result = vec![];
113
114 for ans in &self.answers {
115 match &ans.kind {
116 RecordKind::A(addr) => {
117 result.push(Host {
118 name: ans.name.clone(),
119 host_name: Some(ans.name.clone()),
120 ip_address: vec![(*addr).into()],
121 socket_address: vec![],
122 expires: Instant::now() + Duration::from_secs(ans.ttl.into()),
123 });
124 }
125 RecordKind::AAAA(addr) => {
126 result.push(Host {
127 name: ans.name.clone(),
128 host_name: Some(ans.name.clone()),
129 ip_address: vec![(*addr).into()],
130 socket_address: vec![],
131 expires: Instant::now() + Duration::from_secs(ans.ttl.into()),
132 });
133 }
134 RecordKind::PTR(name) => {
135 let name = name.clone();
136 let mut found_port = None;
137 let mut host_name = None;
138 let mut ip_address = vec![];
139 let mut socket_address = vec![];
140
141 for r in self.all_records() {
142 if r.name != name {
143 continue;
144 }
145
146 match &r.kind {
147 RecordKind::SRV { port, target, .. } => {
148 found_port.replace(*port);
149 host_name.replace(target.clone());
150 }
151 _ => {}
152 }
153 }
154
155 if let Some(host_name) = host_name.as_ref() {
156 for r in self.all_records() {
157 if &r.name != host_name {
158 continue;
159 }
160
161 match &r.kind {
162 RecordKind::A(addr) => {
163 ip_address.push(addr.clone().into());
164 }
165 RecordKind::AAAA(addr) => {
166 ip_address.push(addr.clone().into());
167 }
168 _ => {}
169 }
170 }
171 }
172
173 if let Some(port) = found_port {
174 for addr in &ip_address {
175 socket_address.push(SocketAddr::new(*addr, port));
176 }
177 }
178
179 result.push(Host {
180 name,
181 host_name,
182 ip_address,
183 socket_address,
184 expires: Instant::now() + Duration::from_secs(ans.ttl.into()),
185 });
186 }
187 _ => {}
188 }
189 }
190 result
191 }
192}
193
/// Owned copy of a single DNS resource record.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Record {
    /// Owner name of the record, rendered as a string by the parser.
    pub name: String,
    /// Record class as decoded by the parser.
    pub class: dns_parser::Class,
    /// Time-to-live, in seconds.
    pub ttl: u32,
    /// Decoded record data.
    pub kind: RecordKind,
}
202
203impl Record {
204 fn new(rr: &ResourceRecord) -> Self {
205 Self {
206 name: rr.name.to_string(),
207 class: rr.cls,
208 ttl: rr.ttl,
209 kind: RecordKind::new(&rr.data),
210 }
211 }
212}
213
/// Decoded payload of a [`Record`], one variant per supported record type.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RecordKind {
    /// IPv4 address record.
    A(Ipv4Addr),
    /// IPv6 address record.
    AAAA(Ipv6Addr),
    /// Alias record; holds the canonical (target) name.
    CNAME(String),
    /// Pointer record; holds the target name. `Response::hosts` treats it as
    /// a service-instance name.
    PTR(String),
    /// Name-server record; holds the server's name.
    NS(String),
    /// Mail-exchange record.
    MX {
        preference: u16,
        exchange: String,
    },
    /// Service-location record: the service is offered by `target` on `port`.
    SRV {
        priority: u16,
        weight: u16,
        port: u16,
        target: String,
    },
    /// Start-of-authority record.
    SOA {
        primary_ns: String,
        mailbox: String,
        serial: u32,
        refresh: u32,
        retry: u32,
        expire: u32,
        minimum_ttl: u32,
    },
    /// Text record; each chunk is decoded as UTF-8, with invalid sequences
    /// replaced by U+FFFD.
    TXT(Vec<String>),
    /// Any record type the parser does not decode: its type plus the data
    /// bytes as the parser handed them over.
    Unimplemented {
        kind: dns_parser::Type,
        data: Vec<u8>,
    },
}
247
248impl RecordKind {
249 fn new(data: &RData) -> Self {
250 match data {
251 RData::A(dns_parser::rdata::a::Record(addr)) => Self::A(*addr),
252 RData::AAAA(dns_parser::rdata::aaaa::Record(addr)) => Self::AAAA(*addr),
253 RData::CNAME(name) => Self::CNAME(name.to_string()),
254 RData::NS(name) => Self::NS(name.to_string()),
255 RData::PTR(name) => Self::PTR(name.to_string()),
256 RData::MX(dns_parser::rdata::mx::Record {
257 preference,
258 exchange,
259 }) => Self::MX {
260 preference: *preference,
261 exchange: exchange.to_string(),
262 },
263 RData::SRV(dns_parser::rdata::srv::Record {
264 priority,
265 weight,
266 port,
267 target,
268 }) => Self::SRV {
269 priority: *priority,
270 weight: *weight,
271 port: *port,
272 target: target.to_string(),
273 },
274 RData::TXT(txt) => Self::TXT(
275 txt.iter()
276 .map(|b| String::from_utf8_lossy(b).into_owned())
277 .collect(),
278 ),
279 RData::SOA(dns_parser::rdata::soa::Record {
280 primary_ns,
281 mailbox,
282 serial,
283 refresh,
284 retry,
285 expire,
286 minimum_ttl,
287 }) => Self::SOA {
288 primary_ns: primary_ns.to_string(),
289 mailbox: mailbox.to_string(),
290 serial: *serial,
291 refresh: *refresh,
292 retry: *retry,
293 expire: *expire,
294 minimum_ttl: *minimum_ttl,
295 },
296 RData::Unknown(kind, data) => Self::Unimplemented {
297 kind: *kind,
298 data: data.to_vec(),
299 },
300 }
301 }
302}
303
304pub async fn resolve_one<S: AsRef<str>>(
308 service_name: S,
309 params: QueryParameters,
310) -> Result<Response> {
311 let responses = resolve(service_name, params).await?;
312 let response = responses.recv().await?;
313 Ok(response)
314}
315
316#[derive(Clone, Copy, Debug, PartialEq, Eq)]
322pub struct QueryParameters {
323 pub query_type: QueryType,
324 pub base_repeat_interval: Option<Duration>,
326 pub max_repeat_interval: Option<Duration>,
328 pub exponential_backoff: bool,
333 pub timeout_after: Option<Duration>,
338}
339
340impl QueryParameters {
341 pub const DISCOVERY: QueryParameters = QueryParameters {
345 query_type: QueryType::PTR,
346 base_repeat_interval: Some(Duration::from_secs(2)),
347 exponential_backoff: true,
348 max_repeat_interval: Some(Duration::from_secs(300)),
349 timeout_after: None,
350 };
351
352 pub const SERVICE_LOOKUP: QueryParameters = QueryParameters {
356 query_type: QueryType::PTR,
357 base_repeat_interval: Some(Duration::from_secs(2)),
358 exponential_backoff: true,
359 max_repeat_interval: None,
360 timeout_after: Some(Duration::from_secs(60)),
361 };
362
363 pub const HOST_LOOKUP: QueryParameters = QueryParameters {
367 query_type: QueryType::A,
368 base_repeat_interval: Some(Duration::from_secs(1)),
369 exponential_backoff: true,
370 max_repeat_interval: None,
371 timeout_after: Some(Duration::from_secs(60)),
372 };
373
374 pub fn with_timeout(mut self, timeout: Duration) -> Self {
375 self.timeout_after.replace(timeout);
376 self
377 }
378}
379
380fn make_query(service_name: &str, query_type: QueryType) -> Result<Vec<u8>> {
381 let mut builder = Builder::new_query(rand::random(), false);
382 let prefer_unicast = false;
383 builder.add_question(&service_name, prefer_unicast, query_type, QueryClass::IN);
384 Ok(builder.build().map_err(|_| Error::DnsPacketBuildError)?)
385}
386
387fn valid_source_address(addr: SocketAddr) -> bool {
396 if addr.port() != MULTICAST_PORT {
397 false
398 } else {
399 fn masked(addr: &[u8], mask: &[u8]) -> Vec<u8> {
401 assert_eq!(addr.len(), mask.len());
402 addr.iter().zip(mask.iter()).map(|(a, m)| a & m).collect()
403 }
404
405 let ifaces = match crate::net_utils::get_if_addrs() {
406 Ok(i) => i,
407 Err(err) => {
408 log::error!("error while listing local interfaces: {}", err);
409 return false;
410 }
411 };
412
413 for iface in ifaces {
414 let matches_iface = match (&iface.addr, addr.ip()) {
415 (crate::net_utils::IfAddr::V4(a), IpAddr::V4(source)) => {
416 masked(&a.ip.octets(), &a.netmask.octets())
417 == masked(&source.octets(), &a.netmask.octets())
418 }
419
420 (crate::net_utils::IfAddr::V6(a), IpAddr::V6(source)) => {
421 masked(&a.ip.octets(), &a.netmask.octets())
422 == masked(&source.octets(), &a.netmask.octets())
423 }
424 _ => false,
425 };
426
427 if matches_iface {
428 return true;
429 }
430 }
431
432 false
433 }
434}
435
436pub async fn resolve<S: AsRef<str>>(
441 service_name: S,
442 params: QueryParameters,
443) -> Result<Receiver<Response>> {
444 if params.base_repeat_interval.is_none() && params.timeout_after.is_none() {
445 return Err(Error::InvalidQueryParams);
446 }
447
448 let service_name = service_name.as_ref().to_string();
449 let deadline = params.timeout_after.map(|d| Instant::now() + d);
450
451 let data = make_query(&service_name, params.query_type)?;
452
453 let socket = create_socket().await?;
454 let addr = sockaddr(MULTICAST_ADDR, MULTICAST_PORT);
455
456 socket.send_to(&data, addr).await?;
457
458 let (tx, rx) = bounded(8);
459
460 smol::spawn(async move {
461 let mut retry_interval = params.base_repeat_interval;
462 let mut last_send = Instant::now();
463
464 loop {
465 let now = Instant::now();
466
467 if let Some(deadline) = deadline {
468 if now >= deadline {
469 log::trace!("resolve loop completing because {now:?} >= {deadline:?}");
470 break;
471 }
472 }
473
474 let recv_deadline = match retry_interval {
475 Some(retry) => match deadline {
476 Some(overall) => (last_send + retry).min(overall),
477 None => last_send + retry,
478 },
479 None => match deadline {
480 Some(overall) => overall,
481 None => {
482 log::error!("resolve loop aborting because params are invalid");
486 return Err(Error::InvalidQueryParams);
487 }
488 },
489 };
490
491 let mut buf = [0u8; 4096];
492
493 let recv = async {
494 let (len, addr) = socket.recv_from(&mut buf).await?;
495 Result::Ok(Some((len, addr)))
496 };
497
498 let timer = async {
499 let timer = smol::Timer::at(recv_deadline);
500 timer.await;
501 Result::Ok(None)
502 };
503
504 if let Some((len, addr)) = recv.or(timer).await? {
505 match Packet::parse(&buf[..len]) {
506 Ok(dns) => {
507 let response = Response::new(&dns);
508 if !valid_source_address(addr) {
509 log::trace!(
510 "ignoring response {response:?} from {addr:?} which is not local",
511 );
512 } else {
513 let matched = response
514 .answers
515 .iter()
516 .any(|answer| answer.name == service_name);
517 if matched {
518 tx.send(response).await?;
519 }
520 }
521 }
522 Err(e) => {
523 log::trace!("failed to parse packet: {e:?} received from {addr:?}");
524 }
525 }
526 } else {
527 log::trace!("resolve loop read timeout; send another query");
528 let data = make_query(&service_name, params.query_type)?;
530 socket.send_to(&data, addr).await?;
531 last_send = Instant::now();
532
533 match retry_interval.take() {
535 None => {
536 break;
538 }
539 Some(retry) => {
540 let base = params.base_repeat_interval.unwrap();
541
542 let retry = if params.exponential_backoff {
543 retry + retry
544 } else {
545 retry + base
546 };
547
548 let retry = params
549 .max_repeat_interval
550 .map(|max| retry.min(max))
551 .unwrap_or(retry);
552
553 retry_interval.replace(retry);
554 }
555 }
556 log::trace!("updated retry_interval is now {retry_interval:?}");
557 }
558 }
559
560 log::trace!("resolve loop completing OK");
561 Result::Ok(())
562 })
563 .detach();
564
565 Ok(rx)
566}