mdns_resolver/
lib.rs

1
2use std::collections::HashMap;
3use std::fmt;
4use std::hash::Hash;
5use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
6use std::sync::Arc;
7use std::time::{Duration, Instant};
8
9use error::Result;
10use simple_dns::{Name, PacketBuf, PacketHeader, QCLASS, QTYPE, Question};
11use socket2::{Domain, Protocol, SockAddr, Socket, Type};
12use tokio::net::UdpSocket;
13use tokio::sync::mpsc::{self, UnboundedReceiver, UnboundedSender};
14use tokio::sync::oneshot;
15use tracing::{debug, warn};
16use lazy_static::lazy_static;
17
18mod error;
19
20pub use error::MdnsError;
21
22const MULTICAST_PORT: u16 = 5353;
23const MULTICAST_ADDR_IPV4: Ipv4Addr = Ipv4Addr::new(224, 0, 0, 251);
24const MULTICAST_ADDR_IPV6: Ipv6Addr = Ipv6Addr::new(0xFF02, 0, 0, 0, 0, 0, 0, 0xFB);
25
26lazy_static! {
27  pub(crate) static ref MULTICAST_IPV4_SOCKET: SocketAddr =
28      SocketAddr::new(IpAddr::V4(MULTICAST_ADDR_IPV4), MULTICAST_PORT);
29  pub(crate) static ref MULTICAST_IPV6_SOCKET: SocketAddr =
30      SocketAddr::new(IpAddr::V6(MULTICAST_ADDR_IPV6), MULTICAST_PORT);
31}
32
33fn create_socket(addr: &SocketAddr) -> std::io::Result<Socket> {
34  let domain = if addr.is_ipv4() {
35      Domain::IPV4
36  } else {
37      Domain::IPV6
38  };
39
40  let socket = Socket::new(domain, Type::DGRAM, Some(Protocol::UDP))?;
41  socket.set_read_timeout(Some(Duration::from_millis(100)))?;
42  socket.set_reuse_address(true)?;
43
44  #[cfg(not(windows))]
45  socket.set_reuse_port(true)?;
46
47  Ok(socket)
48}
49
50fn sender_socket(addr: &SocketAddr) -> std::io::Result<std::net::UdpSocket> {
51  let sock_addr = if addr.is_ipv4() {
52    SockAddr::from(SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), 0))
53  } else {
54    SockAddr::from(SocketAddr::new(Ipv6Addr::UNSPECIFIED.into(), 0))
55  };
56
57  let socket = create_socket(addr)?;
58  socket.bind(&sock_addr)?;
59
60  Ok(socket.into())
61}
62
63/// Continuously reads packets from the given socket and publishes them to the
64/// provided channel.
65async fn ingest_packets(socket: Arc<UdpSocket>, tx: UnboundedSender<PacketBuf>) {
66  let mut buf = [0u8; 4096];
67
68  loop {
69    match socket.recv_from(&mut buf[..]).await {
70      Ok((count, _)) => {
71        if let Ok(header) = PacketHeader::parse(&buf[0..12]) {
72          // filter out some obvious noise (e.g. other queries, particularly
73          // our own)
74          if header.query || header.answers_count == 0 {
75            continue;
76          }
77
78          let buf = PacketBuf::from(&buf[..count]);
79          if let Err(e) = tx.send(buf) {
80            warn!("failed to send parsed packet: {}", e);
81            break;
82          }
83        }
84      },
85      Err(e) => {
86        warn!("error receiving packet: {}", e);
87        continue;
88      }
89    }
90  }
91}
92
93pub struct Query {
94  packet_id: u16,
95  query_name: String,
96  completion: Option<oneshot::Sender<Result<PacketBuf>>>,
97  packet: PacketBuf,
98
99  started: Instant,
100  timeout: Duration,
101}
102
103impl Eq for Query {}
104
105impl PartialEq for Query {
106  fn eq(&self, other: &Self) -> bool {
107    self.packet_id == other.packet_id && self.query_name == other.query_name
108  }
109}
110
111impl Ord for Query {
112  fn cmp(&self, other: &Self) -> std::cmp::Ordering {
113    (self.packet_id, &self.query_name).cmp(&(other.packet_id, &other.query_name))
114  }
115}
116
117impl PartialOrd for Query {
118  fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
119    Some(self.cmp(other))
120  }
121}
122
123impl Hash for Query {
124  fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
125    self.packet_id.hash(state);
126    self.query_name.hash(state);
127  }
128}
129
130impl fmt::Debug for Query {
131  fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
132    f.debug_struct("Query")
133      .field("packet_id", &self.packet_id)
134      .field("query_name", &self.query_name)
135      .finish()
136  }
137}
138
139#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone)]
140pub struct QueryKey {
141  packet_id: u16,
142  query_name: String,
143}
144
145impl From<&Query> for QueryKey {
146  fn from(q: &Query) -> Self {
147    QueryKey {
148      packet_id: q.packet_id,
149      query_name: q.query_name.clone(),
150    }
151  }
152}
153
154/// An async task to coordinate the sending of queries and processing of
155/// responses.
156/// Note that this function never terminates; it should be executed in the
157/// background using `tokio::spawn`.
158async fn process_packets(
159  mut query_rx: UnboundedReceiver<Query>,
160  mut packet_rx: UnboundedReceiver<PacketBuf>,
161  socket: Arc<UdpSocket>,
162) {
163  let mut queries = HashMap::new();
164  let mut cleanup = tokio::time::interval(Duration::from_secs(1));
165
166  loop {
167    tokio::select! {
168      query = query_rx.recv() => {
169        let mut query = match query {
170          Some(query) => query,
171          None => continue
172        };
173        match socket.send_to(&query.packet, *MULTICAST_IPV4_SOCKET).await {
174          Ok(_) => {
175            debug!("inserting query: {:?}", query);
176            queries.insert(QueryKey::from(&query), query);
177          },
178          Err(e) => {
179            // we couldn't send the query, send off the completion immediately
180            // with an error (ignoring any errors in the send).
181            if let Some(completion) = query.completion.take() {
182              completion.send(Err(MdnsError::from(e))).ok();
183            }
184          },
185        }
186      },
187
188      packet = packet_rx.recv() => {
189        let packet = match packet {
190          Some(packet) => packet,
191          None => continue
192        };
193
194        let parsed_packet = match packet.to_packet() {
195          Ok(packet) => packet,
196          Err(_) => continue,
197        };
198
199        for answer in parsed_packet.answers {
200          let query_key = QueryKey {
201            packet_id: packet.packet_id(),
202            query_name: answer.name.to_string(),
203          };
204
205          if let Some(mut query) = queries.remove(&query_key) {
206            if let Some(completion) = query.completion.take() {
207              let cloned = PacketBuf::from(&packet[..]);
208              completion.send(Ok(cloned)).ok();
209            }
210            debug!("completed {:?}", query);
211          }
212        }
213      },
214
215      _ = cleanup.tick() => {
216        let mut to_remove = Vec::new();
217        for (key, query) in queries.iter_mut() {
218          if query.started.elapsed() > query.timeout {
219            to_remove.push(key.clone());
220
221            if let Some(completion) = query.completion.take() {
222              completion.send(Err(MdnsError::TimedOut(key.clone()))).ok();
223            }
224
225            debug!("removing timed out query: {:?}", query);
226          }
227        }
228
229        for key in to_remove {
230          queries.remove(&key);
231        }
232      }
233    };
234  }
235}
236
237#[derive(Clone)]
238pub struct MdnsResolver {
239  query_tx: UnboundedSender<Query>
240}
241
242impl MdnsResolver {
243  /// Attempts to create a new MdnsResolver and begins listening for packets on
244  /// the necessary UDP sockets.
245  pub async fn new() -> Result<Self> {
246    let tx_socket = Arc::new(UdpSocket::from_std(sender_socket(&MULTICAST_IPV4_SOCKET)?)?);
247    let tx_socket_clone = Arc::clone(&tx_socket);
248
249    let (query_tx, query_rx) = mpsc::unbounded_channel();
250    let (packet_tx, packet_rx) = mpsc::unbounded_channel();
251
252    tokio::spawn(async move {
253      ingest_packets(tx_socket_clone, packet_tx).await
254    });
255
256    tokio::spawn(async move {
257      process_packets(query_rx, packet_rx, tx_socket).await;
258    });
259
260    Ok(MdnsResolver {
261      query_tx,
262    })
263  }
264
265  /// Submit a query with the given timeout.
266  /// Note that timeouts are processed at 1s intervals.
267  pub async fn query_timeout(&self, q: impl AsRef<str>, timeout: Duration) -> Result<PacketBuf> {
268    let packet_id = rand::random();
269    let mut packet = PacketBuf::new(PacketHeader::new_query(packet_id, false), true);
270    let service_name = Name::new(q.as_ref())?;
271    packet.add_question(&Question::new(
272      service_name.clone(),
273      QTYPE::A,
274      QCLASS::IN,
275      true
276    ))?;
277
278    let (tx, rx) = oneshot::channel();
279    self.query_tx.send(Query {
280      packet_id,
281      query_name: q.as_ref().to_string(),
282      completion: Some(tx),
283      packet,
284
285      started: Instant::now(),
286      timeout,
287    })?;
288
289    rx.await?
290  }
291}