1use core::sync::atomic;
2use std::net::{IpAddr, Ipv4Addr, SocketAddr};
3use std::sync::Arc;
4use std::time::Duration;
5
6use socket2::SockAddr;
7use tokio::net::{ToSocketAddrs, UdpSocket};
8use tokio::sync::{mpsc, Mutex};
9use util::ifaces;
10
11use crate::config::*;
12use crate::error::*;
13use crate::message::header::*;
14use crate::message::name::*;
15use crate::message::parser::*;
16use crate::message::question::*;
17use crate::message::resource::a::*;
18use crate::message::resource::*;
19use crate::message::*;
20
21mod conn_test;
22
/// Default destination for outgoing mDNS traffic: the IPv4 mDNS multicast group and port.
pub const DEFAULT_DEST_ADDR: &str = "224.0.0.251:5353";

/// Length of the buffer each datagram is received into.
const INBOUND_BUFFER_SIZE: usize = 65_535;
/// Re-send interval for questions when the config leaves `query_interval` at zero.
const DEFAULT_QUERY_INTERVAL: Duration = Duration::from_millis(1_000);
/// Per-section record limit; the parsing loops run `0..=MAX_MESSAGE_RECORDS`,
/// i.e. they look at up to `MAX_MESSAGE_RECORDS + 1` records.
const MAX_MESSAGE_RECORDS: usize = 3;
/// TTL placed on the `A` records this responder sends.
const RESPONSE_TTL: u32 = 2 * 60;
29
/// An mDNS connection that answers questions for configured local names and
/// resolves names on request.
///
/// Created by `DnsConn::server`, which also spawns the background receive task
/// that shares `socket` and `queries` with this handle.
pub struct DnsConn {
    // Multicast-joined UDP socket, shared with the background receive task.
    socket: Arc<UdpSocket>,
    // Where outgoing questions are sent (parsed from `DEFAULT_DEST_ADDR`).
    dst_addr: SocketAddr,

    // How often `query` re-sends its question while waiting for an answer.
    query_interval: Duration,
    // Pending queries; the receive task delivers matching answers and removes them.
    queries: Arc<Mutex<Vec<Query>>>,

    // Set to `true` by the receive task when it stops.
    is_server_closed: Arc<atomic::AtomicBool>,
    // Sending on this (or dropping it) makes the receive task stop.
    close_server: mpsc::Sender<()>,
}
41
42struct Query {
43 name_with_suffix: String,
44 query_result_chan: mpsc::Sender<QueryResult>,
45}
46
47struct QueryResult {
48 answer: ResourceHeader,
49 addr: SocketAddr,
50}
51
52impl DnsConn {
53 pub fn server(addr: SocketAddr, config: Config) -> Result<Self> {
55 let socket = socket2::Socket::new(
56 socket2::Domain::IPV4,
57 socket2::Type::DGRAM,
58 Some(socket2::Protocol::UDP),
59 )?;
60
61 #[cfg(all(target_family = "unix", feature = "reuse_port"))]
62 socket.set_reuse_port(true)?;
63
64 socket.set_reuse_address(true)?;
65 socket.set_broadcast(true)?;
66 socket.set_nonblocking(true)?;
67
68 socket.bind(&SockAddr::from(addr))?;
69 {
70 let mut join_error_count = 0;
71 let interfaces = match ifaces::ifaces() {
72 Ok(e) => e,
73 Err(e) => {
74 log::error!("Error getting interfaces: {e:?}");
75 return Err(Error::Other(e.to_string()));
76 }
77 };
78
79 for interface in &interfaces {
80 if let Some(SocketAddr::V4(e)) = interface.addr {
81 if let Err(e) = socket.join_multicast_v4(&Ipv4Addr::new(224, 0, 0, 251), e.ip())
82 {
83 log::trace!("Error connecting multicast, error: {e:?}");
84 join_error_count += 1;
85 continue;
86 }
87
88 log::trace!("Connected to interface address {e:?}");
89 }
90 }
91
92 if join_error_count >= interfaces.len() {
93 return Err(Error::ErrJoiningMulticastGroup);
94 }
95 }
96
97 let socket = UdpSocket::from_std(socket.into())?;
98
99 let local_names = config
100 .local_names
101 .iter()
102 .map(|l| l.to_string() + ".")
103 .collect();
104
105 let dst_addr: SocketAddr = DEFAULT_DEST_ADDR.parse()?;
106
107 let is_server_closed = Arc::new(atomic::AtomicBool::new(false));
108
109 let (close_server_send, close_server_rcv) = mpsc::channel(1);
110
111 let c = DnsConn {
112 query_interval: if config.query_interval != Duration::from_secs(0) {
113 config.query_interval
114 } else {
115 DEFAULT_QUERY_INTERVAL
116 },
117
118 queries: Arc::new(Mutex::new(vec![])),
119 socket: Arc::new(socket),
120 dst_addr,
121 is_server_closed: Arc::clone(&is_server_closed),
122 close_server: close_server_send,
123 };
124
125 let queries = c.queries.clone();
126 let socket = Arc::clone(&c.socket);
127
128 tokio::spawn(async move {
129 DnsConn::start(
130 close_server_rcv,
131 is_server_closed,
132 socket,
133 local_names,
134 dst_addr,
135 queries,
136 )
137 .await
138 });
139
140 Ok(c)
141 }
142
143 pub async fn close(&self) -> Result<()> {
145 log::info!("Closing connection");
146 if self.is_server_closed.load(atomic::Ordering::SeqCst) {
147 return Err(Error::ErrConnectionClosed);
148 }
149
150 log::trace!("Sending close command to server");
151 match self.close_server.send(()).await {
152 Ok(_) => {
153 log::trace!("Close command sent");
154 Ok(())
155 }
156 Err(e) => {
157 log::warn!("Error sending close command to server: {e:?}");
158 Err(Error::ErrConnectionClosed)
159 }
160 }
161 }
162
163 pub async fn query(
166 &self,
167 name: &str,
168 mut close_query_signal: mpsc::Receiver<()>,
169 ) -> Result<(ResourceHeader, SocketAddr)> {
170 if self.is_server_closed.load(atomic::Ordering::SeqCst) {
171 return Err(Error::ErrConnectionClosed);
172 }
173
174 let name_with_suffix = name.to_owned() + ".";
175
176 let (query_tx, mut query_rx) = mpsc::channel(1);
177 {
178 let mut queries = self.queries.lock().await;
179 queries.push(Query {
180 name_with_suffix: name_with_suffix.clone(),
181 query_result_chan: query_tx,
182 });
183 }
184
185 log::trace!("Sending query");
186 self.send_question(&name_with_suffix).await;
187
188 loop {
189 tokio::select! {
190 _ = tokio::time::sleep(self.query_interval) => {
191 log::trace!("Sending query");
192 self.send_question(&name_with_suffix).await
193 },
194
195 _ = close_query_signal.recv() => {
196 log::info!("Query close signal received.");
197 return Err(Error::ErrConnectionClosed)
198 },
199
200 res_opt = query_rx.recv() =>{
201 log::info!("Received query result");
202 if let Some(res) = res_opt{
203 return Ok((res.answer, res.addr));
204 }
205 }
206 }
207 }
208 }
209
210 async fn send_question(&self, name: &str) {
211 let packed_name = match Name::new(name) {
212 Ok(pn) => pn,
213 Err(err) => {
214 log::warn!("Failed to construct mDNS packet: {err}");
215 return;
216 }
217 };
218
219 let raw_query = {
220 let mut msg = Message {
221 header: Header::default(),
222 questions: vec![Question {
223 typ: DnsType::A,
224 class: DNSCLASS_INET,
225 name: packed_name,
226 }],
227 ..Default::default()
228 };
229
230 match msg.pack() {
231 Ok(v) => v,
232 Err(err) => {
233 log::error!("Failed to construct mDNS packet {err}");
234 return;
235 }
236 }
237 };
238
239 log::trace!("{:?} sending {:?}...", self.socket.local_addr(), raw_query);
240 if let Err(err) = self.socket.send_to(&raw_query, self.dst_addr).await {
241 log::error!("Failed to send mDNS packet {err}");
242 }
243 }
244
245 async fn start(
246 mut closed_rx: mpsc::Receiver<()>,
247 close_server: Arc<atomic::AtomicBool>,
248 socket: Arc<UdpSocket>,
249 local_names: Vec<String>,
250 dst_addr: SocketAddr,
251 queries: Arc<Mutex<Vec<Query>>>,
252 ) -> Result<()> {
253 log::info!("Looping and listening {:?}", socket.local_addr());
254
255 let mut b = vec![0u8; INBOUND_BUFFER_SIZE];
256 let (mut n, mut src);
257
258 loop {
259 tokio::select! {
260 _ = closed_rx.recv() => {
261 log::info!("Closing server connection");
262 close_server.store(true, atomic::Ordering::SeqCst);
263
264 return Ok(());
265 }
266
267 result = socket.recv_from(&mut b) => {
268 match result{
269 Ok((len, addr)) => {
270 n = len;
271 src = addr;
272 log::trace!("Received new connection from {addr:?}");
273 },
274
275 Err(err) => {
276 log::error!("Error receiving from socket connection: {err:?}");
277 continue;
278 },
279 }
280 }
281 }
282
283 let mut p = Parser::default();
284 if let Err(err) = p.start(&b[..n]) {
285 log::error!("Failed to parse mDNS packet {err}");
286 continue;
287 }
288
289 run(&mut p, &socket, &local_names, src, dst_addr, &queries).await
290 }
291 }
292}
293
294async fn run(
295 p: &mut Parser<'_>,
296 socket: &Arc<UdpSocket>,
297 local_names: &[String],
298 src: SocketAddr,
299 dst_addr: SocketAddr,
300 queries: &Arc<Mutex<Vec<Query>>>,
301) {
302 let mut interface_addr = None;
303 for _ in 0..=MAX_MESSAGE_RECORDS {
304 let q = match p.question() {
305 Ok(q) => q,
306 Err(err) => {
307 if Error::ErrSectionDone == err {
308 log::trace!("Parsing has completed");
309 break;
310 } else {
311 log::error!("Failed to parse mDNS packet {err}");
312 return;
313 }
314 }
315 };
316
317 for local_name in local_names {
318 if *local_name == q.name.data {
319 let interface_addr = match interface_addr {
320 Some(addr) => addr,
321 None => match get_interface_addr_for_ip(src).await {
322 Ok(addr) => {
323 interface_addr.replace(addr);
324 addr
325 }
326 Err(e) => {
327 log::warn!(
328 "Failed to get local interface to communicate with {}: {:?}",
329 &src,
330 e
331 );
332 continue;
333 }
334 },
335 };
336
337 log::trace!(
338 "Found local name: {} to send answer, IP {}, interface addr {}",
339 local_name,
340 src.ip(),
341 interface_addr
342 );
343 if let Err(e) =
344 send_answer(socket, &interface_addr, &q.name.data, src.ip(), dst_addr).await
345 {
346 log::error!("Error sending answer to client: {e:?}");
347 continue;
348 };
349 }
350 }
351 }
352
353 let _ = p.skip_all_questions();
355
356 for _ in 0..=MAX_MESSAGE_RECORDS {
357 let a = match p.answer_header() {
358 Ok(a) => a,
359 Err(err) => {
360 if Error::ErrSectionDone != err {
361 log::warn!("Failed to parse mDNS packet {err}");
362 }
363 return;
364 }
365 };
366
367 if a.typ != DnsType::A && a.typ != DnsType::Aaaa {
368 continue;
369 }
370
371 let mut qs = queries.lock().await;
372 for j in (0..qs.len()).rev() {
373 if qs[j].name_with_suffix == a.name.data {
374 let _ = qs[j]
375 .query_result_chan
376 .send(QueryResult {
377 answer: a.clone(),
378 addr: src,
379 })
380 .await;
381 qs.remove(j);
382 }
383 }
384 }
385}
386
387async fn send_answer(
388 socket: &Arc<UdpSocket>,
389 interface_addr: &SocketAddr,
390 name: &str,
391 dst: IpAddr,
392 dst_addr: SocketAddr,
393) -> Result<()> {
394 let raw_answer = {
395 let mut msg = Message {
396 header: Header {
397 response: true,
398 authoritative: true,
399 ..Default::default()
400 },
401
402 answers: vec![Resource {
403 header: ResourceHeader {
404 typ: DnsType::A,
405 class: DNSCLASS_INET,
406 name: Name::new(name)?,
407 ttl: RESPONSE_TTL,
408 ..Default::default()
409 },
410 body: Some(Box::new(AResource {
411 a: match interface_addr.ip() {
412 IpAddr::V4(ip) => ip.octets(),
413 IpAddr::V6(_) => {
414 return Err(Error::Other("Unexpected IpV6 addr".to_owned()))
415 }
416 },
417 })),
418 }],
419 ..Default::default()
420 };
421
422 msg.pack()?
423 };
424
425 socket.send_to(&raw_answer, dst_addr).await?;
426 log::trace!("Sent answer to IP {dst}");
427
428 Ok(())
429}
430
431async fn get_interface_addr_for_ip(addr: impl ToSocketAddrs) -> std::io::Result<SocketAddr> {
432 let socket = UdpSocket::bind("0.0.0.0:0").await?;
433 socket.connect(addr).await?;
434 socket.local_addr()
435}