1
2use std::collections::HashMap;
3use std::fmt;
4use std::hash::Hash;
5use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
6use std::sync::Arc;
7use std::time::{Duration, Instant};
8
9use error::Result;
10use simple_dns::{Name, PacketBuf, PacketHeader, QCLASS, QTYPE, Question};
11use socket2::{Domain, Protocol, SockAddr, Socket, Type};
12use tokio::net::UdpSocket;
13use tokio::sync::mpsc::{self, UnboundedReceiver, UnboundedSender};
14use tokio::sync::oneshot;
15use tracing::{debug, warn};
16use lazy_static::lazy_static;
17
18mod error;
19
20pub use error::MdnsError;
21
/// UDP port used by multicast DNS.
const MULTICAST_PORT: u16 = 5353;
/// IPv4 link-local multicast group for mDNS (`224.0.0.251`).
const MULTICAST_ADDR_IPV4: Ipv4Addr = Ipv4Addr::new(224, 0, 0, 251);
/// IPv6 link-local multicast group for mDNS (`ff02::fb`).
const MULTICAST_ADDR_IPV6: Ipv6Addr = Ipv6Addr::new(0xff02, 0, 0, 0, 0, 0, 0, 0x00fb);
25
26lazy_static! {
27 pub(crate) static ref MULTICAST_IPV4_SOCKET: SocketAddr =
28 SocketAddr::new(IpAddr::V4(MULTICAST_ADDR_IPV4), MULTICAST_PORT);
29 pub(crate) static ref MULTICAST_IPV6_SOCKET: SocketAddr =
30 SocketAddr::new(IpAddr::V6(MULTICAST_ADDR_IPV6), MULTICAST_PORT);
31}
32
33fn create_socket(addr: &SocketAddr) -> std::io::Result<Socket> {
34 let domain = if addr.is_ipv4() {
35 Domain::IPV4
36 } else {
37 Domain::IPV6
38 };
39
40 let socket = Socket::new(domain, Type::DGRAM, Some(Protocol::UDP))?;
41 socket.set_read_timeout(Some(Duration::from_millis(100)))?;
42 socket.set_reuse_address(true)?;
43
44 #[cfg(not(windows))]
45 socket.set_reuse_port(true)?;
46
47 Ok(socket)
48}
49
50fn sender_socket(addr: &SocketAddr) -> std::io::Result<std::net::UdpSocket> {
51 let sock_addr = if addr.is_ipv4() {
52 SockAddr::from(SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), 0))
53 } else {
54 SockAddr::from(SocketAddr::new(Ipv6Addr::UNSPECIFIED.into(), 0))
55 };
56
57 let socket = create_socket(addr)?;
58 socket.bind(&sock_addr)?;
59
60 Ok(socket.into())
61}
62
63async fn ingest_packets(socket: Arc<UdpSocket>, tx: UnboundedSender<PacketBuf>) {
66 let mut buf = [0u8; 4096];
67
68 loop {
69 match socket.recv_from(&mut buf[..]).await {
70 Ok((count, _)) => {
71 if let Ok(header) = PacketHeader::parse(&buf[0..12]) {
72 if header.query || header.answers_count == 0 {
75 continue;
76 }
77
78 let buf = PacketBuf::from(&buf[..count]);
79 if let Err(e) = tx.send(buf) {
80 warn!("failed to send parsed packet: {}", e);
81 break;
82 }
83 }
84 },
85 Err(e) => {
86 warn!("error receiving packet: {}", e);
87 continue;
88 }
89 }
90 }
91}
92
93pub struct Query {
94 packet_id: u16,
95 query_name: String,
96 completion: Option<oneshot::Sender<Result<PacketBuf>>>,
97 packet: PacketBuf,
98
99 started: Instant,
100 timeout: Duration,
101}
102
103impl Eq for Query {}
104
105impl PartialEq for Query {
106 fn eq(&self, other: &Self) -> bool {
107 self.packet_id == other.packet_id && self.query_name == other.query_name
108 }
109}
110
111impl Ord for Query {
112 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
113 (self.packet_id, &self.query_name).cmp(&(other.packet_id, &other.query_name))
114 }
115}
116
117impl PartialOrd for Query {
118 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
119 Some(self.cmp(other))
120 }
121}
122
123impl Hash for Query {
124 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
125 self.packet_id.hash(state);
126 self.query_name.hash(state);
127 }
128}
129
130impl fmt::Debug for Query {
131 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
132 f.debug_struct("Query")
133 .field("packet_id", &self.packet_id)
134 .field("query_name", &self.query_name)
135 .finish()
136 }
137}
138
/// Lookup key for pending queries: a DNS packet ID paired with the queried name.
///
/// The derived ordering compares `packet_id` first and `query_name` second.
#[derive(Clone, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct QueryKey {
    /// DNS message ID of the query.
    packet_id: u16,
    /// Name that was queried.
    query_name: String,
}
144
145impl From<&Query> for QueryKey {
146 fn from(q: &Query) -> Self {
147 QueryKey {
148 packet_id: q.packet_id,
149 query_name: q.query_name.clone(),
150 }
151 }
152}
153
154async fn process_packets(
159 mut query_rx: UnboundedReceiver<Query>,
160 mut packet_rx: UnboundedReceiver<PacketBuf>,
161 socket: Arc<UdpSocket>,
162) {
163 let mut queries = HashMap::new();
164 let mut cleanup = tokio::time::interval(Duration::from_secs(1));
165
166 loop {
167 tokio::select! {
168 query = query_rx.recv() => {
169 let mut query = match query {
170 Some(query) => query,
171 None => continue
172 };
173 match socket.send_to(&query.packet, *MULTICAST_IPV4_SOCKET).await {
174 Ok(_) => {
175 debug!("inserting query: {:?}", query);
176 queries.insert(QueryKey::from(&query), query);
177 },
178 Err(e) => {
179 if let Some(completion) = query.completion.take() {
182 completion.send(Err(MdnsError::from(e))).ok();
183 }
184 },
185 }
186 },
187
188 packet = packet_rx.recv() => {
189 let packet = match packet {
190 Some(packet) => packet,
191 None => continue
192 };
193
194 let parsed_packet = match packet.to_packet() {
195 Ok(packet) => packet,
196 Err(_) => continue,
197 };
198
199 for answer in parsed_packet.answers {
200 let query_key = QueryKey {
201 packet_id: packet.packet_id(),
202 query_name: answer.name.to_string(),
203 };
204
205 if let Some(mut query) = queries.remove(&query_key) {
206 if let Some(completion) = query.completion.take() {
207 let cloned = PacketBuf::from(&packet[..]);
208 completion.send(Ok(cloned)).ok();
209 }
210 debug!("completed {:?}", query);
211 }
212 }
213 },
214
215 _ = cleanup.tick() => {
216 let mut to_remove = Vec::new();
217 for (key, query) in queries.iter_mut() {
218 if query.started.elapsed() > query.timeout {
219 to_remove.push(key.clone());
220
221 if let Some(completion) = query.completion.take() {
222 completion.send(Err(MdnsError::TimedOut(key.clone()))).ok();
223 }
224
225 debug!("removing timed out query: {:?}", query);
226 }
227 }
228
229 for key in to_remove {
230 queries.remove(&key);
231 }
232 }
233 };
234 }
235}
236
237#[derive(Clone)]
238pub struct MdnsResolver {
239 query_tx: UnboundedSender<Query>
240}
241
242impl MdnsResolver {
243 pub async fn new() -> Result<Self> {
246 let tx_socket = Arc::new(UdpSocket::from_std(sender_socket(&MULTICAST_IPV4_SOCKET)?)?);
247 let tx_socket_clone = Arc::clone(&tx_socket);
248
249 let (query_tx, query_rx) = mpsc::unbounded_channel();
250 let (packet_tx, packet_rx) = mpsc::unbounded_channel();
251
252 tokio::spawn(async move {
253 ingest_packets(tx_socket_clone, packet_tx).await
254 });
255
256 tokio::spawn(async move {
257 process_packets(query_rx, packet_rx, tx_socket).await;
258 });
259
260 Ok(MdnsResolver {
261 query_tx,
262 })
263 }
264
265 pub async fn query_timeout(&self, q: impl AsRef<str>, timeout: Duration) -> Result<PacketBuf> {
268 let packet_id = rand::random();
269 let mut packet = PacketBuf::new(PacketHeader::new_query(packet_id, false), true);
270 let service_name = Name::new(q.as_ref())?;
271 packet.add_question(&Question::new(
272 service_name.clone(),
273 QTYPE::A,
274 QCLASS::IN,
275 true
276 ))?;
277
278 let (tx, rx) = oneshot::channel();
279 self.query_tx.send(Query {
280 packet_id,
281 query_name: q.as_ref().to_string(),
282 completion: Some(tx),
283 packet,
284
285 started: Instant::now(),
286 timeout,
287 })?;
288
289 rx.await?
290 }
291}