1use erbium_net::addr::NetAddr;
20use erbium_net::udp;
21
/// Lower bound of the randomised interval between cookie key rotations
/// (see `CookieKeys::rotate`).
const HOURS_24: std::time::Duration = std::time::Duration::from_secs(24 * 3600);
/// Upper bound (exclusive) of the randomised interval between cookie key
/// rotations (see `CookieKeys::rotate`).
const HOURS_36: std::time::Duration = std::time::Duration::from_secs(36 * 3600);

/// UDP socket type used by the DNS listeners.
type UdpSocket = udp::UdpSocket;
26
27mod acl;
28mod bucket;
29mod cache;
30pub(crate) mod config;
31pub mod dnspkt;
32mod outquery;
33#[cfg(fuzzing)]
34pub mod parse;
35#[cfg(not(fuzzing))]
36mod parse;
37mod router;
38
39use bytes::BytesMut;
40use tokio_util::codec::Decoder;
41
/// Secret key material used to derive DNS server cookies.
type Key = [u8; 8];
43
/// The pair of secrets used for DNS server cookies, plus the time at which
/// they are next due to be rotated.
struct CookieKeys {
    /// Instant after which `needs_rotation` starts returning `true`.
    next_refresh: tokio::time::Instant,
    /// Key used to generate new cookies and tried first when validating.
    current: Key,
    /// The key that was `current` before the most recent rotation.
    previous: Key,
}
49
50impl CookieKeys {
51 fn new() -> Self {
52 Self {
53 next_refresh: tokio::time::Instant::now(),
54 current: Default::default(),
55 previous: Default::default(),
56 }
57 .rotate()
58 .rotate()
59 }
60
61 fn rotate(&self) -> Self {
62 use rand::{RngExt as _, TryRng as _};
63 let mut rng = rand::rngs::SysRng;
64
65 let next_refresh =
66 tokio::time::Instant::now() + rand::rng().random_range(HOURS_24..HOURS_36);
67
68 let mut current: Key = Default::default();
69 rng.try_fill_bytes(&mut current).unwrap();
70
71 Self {
72 next_refresh,
73 current,
74 previous: self.current,
75 }
76 }
77
78 fn needs_rotation(&self) -> bool {
79 self.next_refresh < tokio::time::Instant::now()
80 }
81
82 async fn get_keys(s: &tokio::sync::RwLock<Self>) -> (Key, Key) {
84 if s.read().await.needs_rotation() {
85 let mut cookies = s.write().await;
88 *cookies = cookies.rotate();
89 }
90
91 let cookies = s.read().await;
92 (cookies.current, cookies.previous)
93 }
94
95 async fn get_current_key(s: &tokio::sync::RwLock<Self>) -> Key {
96 Self::get_keys(s).await.0
97 }
98}
99
/// `Default` yields freshly generated random keys (same as `CookieKeys::new`),
/// which lets the global `COOKIE_KEYS` lock be built with `Default::default()`.
impl Default for CookieKeys {
    fn default() -> Self {
        Self::new()
    }
}
105
lazy_static::lazy_static! {
    // Histogram of how long incoming queries take to handle, labelled by protocol.
    static ref IN_QUERY_LATENCY: prometheus::HistogramVec =
        prometheus::register_histogram_vec!("dns_in_query_latency",
            "DNS latency for in queries",
            &["protocol"])
        .unwrap();

    // Counter of incoming query outcomes, labelled by protocol and result.
    static ref IN_QUERY_RESULT: prometheus::IntCounterVec =
        prometheus::register_int_counter_vec!("dns_in_query_result",
            "DNS response codes for in queries",
            &["protocol", "result"])
        .unwrap();

    // Counter of replies that were not sent because of rate limiting.
    static ref IN_QUERY_DROPPED: prometheus::IntCounter =
        prometheus::register_int_counter!("dns_in_query_dropped",
            "DNS queries dropped")
        .unwrap();

    // Process-wide DNS cookie secrets, shared by every listener.
    static ref COOKIE_KEYS: tokio::sync::RwLock<CookieKeys> = Default::default();
}
127
/// Errors produced while listening for, receiving, or answering DNS queries.
#[cfg_attr(test, derive(Debug))]
pub enum Error {
    /// Binding or configuring a listening socket on the given address failed.
    ListenError(std::io::Error, Box<erbium_net::addr::NetAddr>),
    /// Accepting a new TCP connection failed.
    AcceptError(std::io::Error),
    /// Receiving an incoming query failed.
    RecvError(std::io::Error),
    /// The incoming query could not be parsed.
    ParseError(String),
    /// The query was refused by policy (ACL).
    RefusedByAcl(crate::acl::AclError),
    /// The query was denied, with a human readable reason.
    Denied(String),
    /// The query is blocked by configuration.
    Blocked,
    /// No route is configured for the query.
    NoRouteConfigured,
    /// This server is not authoritative for the query.
    NotAuthoritative,
    /// The outgoing (upstream) query failed.
    OutReply(outquery::Error),
}
141
142impl std::fmt::Display for Error {
143 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
144 use Error::*;
145 match self {
146 ListenError(io, addr) => write!(f, "Failed to listen for DNS on {}: {}", addr, io),
147 AcceptError(io) => write!(f, "Failed to accept new TCP connection for DNS: {}", io),
148 RecvError(io) => write!(f, "Failed to receive DNS in query: {}", io),
149 ParseError(msg) => write!(f, "Failed to parse DNS in query: {}", msg),
150 RefusedByAcl(why) => write!(f, "Query refused by policy: {}", why),
151 NotAuthoritative => write!(f, "Not Authoritative"),
152 Blocked => write!(f, "Blocked by configuration"),
153 NoRouteConfigured => write!(f, "No route configured"),
154 Denied(msg) => write!(f, "Denied: {}", msg),
155 OutReply(err) => write!(f, "{}", err),
156 }
157 }
158}
159
160type Bucket = tokio::sync::RwLock<bucket::GenericTokenBucket>;
167struct IpRateLimiter([Bucket; 256]);
168
169impl IpRateLimiter {
170 fn new() -> Self {
171 Self(std::array::from_fn(|_| {
172 Bucket::new(bucket::GenericTokenBucket::new())
173 }))
174 }
175
176 fn hash_ip(seed: u64, ip: std::net::IpAddr) -> usize {
177 use std::hash::Hash as _;
178 use std::hash::Hasher as _;
179 let mut hasher = std::collections::hash_map::DefaultHasher::new();
180 seed.hash(&mut hasher);
181 ip.hash(&mut hasher);
182 hasher.finish() as usize
183 }
184
185 async fn check(&self, ip: std::net::IpAddr, bytes: usize) -> bool {
186 const SEED1: u64 = 0x1234_5678_9ABC_DEF0;
190 const SEED2: u64 = 0x2345_6789_ABCD_EF01;
191
192 let hash1 = Self::hash_ip(SEED1, ip);
193 let hash2 = Self::hash_ip(SEED2, ip);
194
195 let bucket1 = hash1 % self.0.len();
196
197 if self.0[bucket1]
202 .read()
203 .await
204 .check::<bucket::RealTimeClock>(bytes as u32)
205 {
206 self.0[bucket1]
207 .write()
208 .await
209 .deplete::<bucket::RealTimeClock>(bytes as u32);
210 true
211 } else {
212 let mut bucket2 = hash2 % (self.0.len() - 1);
213 if bucket2 == bucket1 {
214 bucket2 = self.0.len() - 1;
215 }
216
217 if self.0[bucket2]
218 .read()
219 .await
220 .check::<bucket::RealTimeClock>(bytes as u32)
221 {
222 self.0[bucket2]
223 .write()
224 .await
225 .deplete::<bucket::RealTimeClock>(bytes as u32);
226 true
227 } else {
228 false
229 }
230 }
231 }
232}
233
234struct DnsCodec {}
235
236impl Decoder for DnsCodec {
237 type Item = dnspkt::DNSPkt;
238 type Error = std::io::Error;
239 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
240 let in_query = parse::PktParser::new(&src[..]).get_dns();
241 match in_query {
242 Ok(p) => Ok(Some(p)),
243 Err(e) => Err(std::io::Error::new(std::io::ErrorKind::InvalidData, e)),
244 }
245 }
246}
247
/// Transport over which a DNS query arrived.
pub enum Protocol {
    Udp,
    Tcp,
}

/// Renders the protocol as the upper-case label used in logs and metrics.
impl std::fmt::Display for Protocol {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let label = match self {
            Self::Udp => "UDP",
            Self::Tcp => "TCP",
        };
        f.write_str(label)
    }
}
261
/// Outcome of validating the DNS cookie carried by a query.
#[derive(Eq, PartialEq)]
enum CookieStatus {
    /// The query carried no cookie option, or no server cookie in it.
    Missing,
    /// A server cookie was present but did not verify.
    Bad,
    /// A server cookie was present and verified.
    Good,
}
268
/// A parsed incoming DNS query together with the transport details it
/// arrived with.
pub struct DnsMessage {
    /// The parsed query packet.
    pub in_query: dnspkt::DNSPkt,
    /// Size in bytes of the raw query as received.
    pub in_size: usize,
    /// Local IP address the query was received on.
    pub local_ip: std::net::IpAddr,
    /// Address of the client that sent the query.
    pub remote_addr: NetAddr,
    /// Transport the query arrived over.
    pub protocol: Protocol,
}
276
/// HMAC-SHA256, the keyed digest used to compute server cookies.
type CookieDigest = hmac::Hmac<sha2::Sha256>;
278
279impl DnsMessage {
280 fn calculate_cookie(&self, client: &[u8], key: &[u8]) -> CookieDigest {
283 use hmac::Mac as _;
284 let mut hasher =
287 CookieDigest::new_from_slice(key).expect("should always be able to create a key");
288 hasher.update(client);
289 match self.local_ip {
290 std::net::IpAddr::V4(v4) => hasher.update(&v4.octets()),
291 std::net::IpAddr::V6(v6) => hasher.update(&v6.octets()),
292 }
293 use erbium_net::addr::NetAddrExt as _;
294 match self.remote_addr.ip() {
295 Some(std::net::IpAddr::V4(v4)) => hasher.update(&v4.octets()),
296 Some(std::net::IpAddr::V6(v6)) => hasher.update(&v6.octets()),
297 _ => unreachable!(),
298 };
299
300 hasher
301 }
302
303 fn validate_cookie_key(&self, key: &[u8]) -> CookieStatus {
304 if let Some((client, Some(server))) = self
305 .in_query
306 .edns
307 .as_ref()
308 .and_then(|edns| edns.get_cookie())
309 {
310 use hmac::Mac as _;
311 let myserver = self.calculate_cookie(client, key);
312 if myserver.verify_slice(server).is_ok() {
313 CookieStatus::Good
314 } else {
315 CookieStatus::Bad
316 }
317 } else {
318 CookieStatus::Missing
319 }
320 }
321
322 fn validate_cookie_keys(&self, key: &[u8], oldkey: &[u8]) -> CookieStatus {
326 match self.validate_cookie_key(key) {
327 CookieStatus::Bad => self.validate_cookie_key(oldkey),
328 status => status,
329 }
330 }
331
332 async fn calculate_current_cookie(&self, client: &[u8]) -> [u8; 32] {
333 use hmac::Mac as _;
334 let key = CookieKeys::get_current_key(&COOKIE_KEYS).await;
335 self.calculate_cookie(client, &key)
336 .finalize()
337 .into_bytes()
338 .as_slice()
339 .try_into()
340 .unwrap()
341 }
342
343 async fn validate_cookie(&self) -> CookieStatus {
344 let keys = CookieKeys::get_keys(&COOKIE_KEYS).await;
345 self.validate_cookie_keys(&keys.0, &keys.1)
346 }
347}
348
/// Front end of the DNS service: owns the listening sockets and hands parsed
/// queries to the ACL handler.
struct DnsListenerHandler {
    /// Next stage in the pipeline; every query is passed to it.
    next: acl::DnsAclHandler,
    /// Bound UDP sockets; drained into worker tasks by `run`.
    udp_listeners: Vec<UdpSocket>,
    /// Bound TCP listeners; drained into worker tasks by `run`.
    tcp_listeners: Vec<tokio::net::TcpListener>,
    /// Shared per-IP rate limiter consulted before sending UDP replies.
    rate_limiter: std::sync::Arc<IpRateLimiter>,
}
355
356impl DnsListenerHandler {
357 async fn listen_udp(
358 _conf: &crate::config::SharedConfig,
359 addr: &erbium_net::addr::NetAddr,
360 ) -> Result<UdpSocket, Error> {
361 let mut count: i32 = 0;
362 let udp = loop {
363 match UdpSocket::bind(&[*addr]).await {
364 Ok(sock) => break sock,
365 Err(e) if e.kind() == std::io::ErrorKind::AddrNotAvailable => {
366 if count > 2 {
374 return Err(Error::ListenError(e, Box::new(*addr)));
375 }
376 log::warn!(
377 "Failed to bind DNS UDP to {} ({}): Retrying after {}s",
378 addr,
379 e,
380 1 << count
381 );
382 tokio::time::sleep(std::time::Duration::from_secs(1 << count)).await;
383 count += 1;
384 continue;
385 }
386 Err(e) => return Err(Error::ListenError(e, Box::new(*addr))),
387 }
388 };
389
390 if addr.as_sockaddr_in6().is_some() {
391 udp.set_opt_ipv6_packet_info(true)
392 .map_err(|e| Error::ListenError(e, Box::new(*addr)))?
393 } else {
394 udp.set_opt_ipv4_packet_info(true)
395 .map_err(|e| Error::ListenError(e, Box::new(*addr)))?
396 }
397
398 log::info!(
399 "Listening for DNS on UDP {}",
400 udp.local_addr()
401 .map(|name| format!("{}", name))
402 .unwrap_or_else(|_| "Unknown".into())
403 );
404
405 Ok(udp)
406 }
407
408 async fn listen_tcp(
409 _conf: &crate::config::SharedConfig,
410 addr: &erbium_net::addr::NetAddr,
411 ) -> Result<tokio::net::TcpListener, Error> {
412 use erbium_net::addr::NetAddrExt as _;
413 let tcp = tokio::net::TcpListener::bind(addr.to_std_socket_addr().ok_or_else(|| {
414 Error::ListenError(std::io::ErrorKind::Unsupported.into(), Box::new(*addr))
415 })?)
416 .await
417 .map_err(|e| Error::ListenError(e, Box::new(*addr)))?;
418
419 log::info!(
420 "Listening for DNS on TCP {}",
421 tcp.local_addr()
422 .map(|name| format!("{}", name))
423 .unwrap_or_else(|_| "Unknown".into())
424 );
425
426 Ok(tcp)
427 }
428
429 async fn new(
430 conf: crate::config::SharedConfig,
431 netinfo: &erbium_net::netinfo::SharedNetInfo,
432 ) -> Result<Self, Error> {
433 let mut udp_listeners = vec![];
434 let mut tcp_listeners = vec![];
435 {
436 let roconf = conf.read().await;
437 for addr in &roconf
438 .dns_listeners
439 .as_sockaddrs(&roconf.addresses, netinfo, 53)
440 .await
441 {
442 udp_listeners.push(Self::listen_udp(&conf, addr).await?);
443 tcp_listeners.push(Self::listen_tcp(&conf, addr).await?);
444 }
445 }
446 let rate_limiter = IpRateLimiter::new().into();
447
448 Ok(Self {
449 next: acl::DnsAclHandler::new(conf).await,
450 udp_listeners,
451 tcp_listeners,
452 rate_limiter,
453 })
454 }
455
456 async fn add_edns(edns: &mut dnspkt::EdnsData, msg: &DnsMessage) {
457 if msg
459 .in_query
460 .edns
461 .as_ref()
462 .map(|edns| edns.get_nsid().is_some())
463 .unwrap_or(false)
464 {
465 edns.set_nsid(format!("{}", msg.local_ip).as_bytes());
469 }
470
471 if let Some((client, _server)) = msg
473 .in_query
474 .edns
475 .as_ref()
476 .and_then(|edns| edns.get_cookie())
477 {
478 let server = msg.calculate_current_cookie(client).await;
479 edns.set_cookie(client, &server);
480 }
481 }
482
483 async fn create_in_reply(msg: &DnsMessage, outr: &dnspkt::DNSPkt) -> dnspkt::DNSPkt {
484 let mut edns: dnspkt::EdnsData = Default::default();
485 Self::add_edns(&mut edns, msg).await;
486 dnspkt::DNSPkt {
487 qid: msg.in_query.qid,
488 rd: false,
489 tc: outr.tc,
490 aa: outr.aa,
491 qr: true,
492 opcode: dnspkt::OPCODE_QUERY,
493
494 cd: outr.cd,
495 ad: outr.ad,
496 ra: outr.ra,
497
498 rcode: outr.rcode,
499
500 bufsize: 4096,
501
502 edns_ver: msg.in_query.edns_ver.map(|_| 0),
503 edns_do: false,
504
505 question: msg.in_query.question.clone(),
506 answer: outr.answer.clone(),
507 nameserver: outr.answer.clone(),
508 additional: outr.additional.clone(),
509 edns: Some(edns),
510 }
511 }
512
513 async fn create_in_error(msg: &DnsMessage, err: Error) -> dnspkt::DNSPkt {
514 use Error::*;
515 use dnspkt::*;
516 let mut edns: EdnsData = Default::default();
517 Self::add_edns(&mut edns, msg).await;
518 let rcode;
519 match err {
520 ListenError(..) => unreachable!(),
522 AcceptError(..) => unreachable!(),
523 RecvError(_) => unreachable!(),
524 ParseError(_) => unreachable!(),
525 RefusedByAcl(why) => {
526 rcode = REFUSED;
527 edns.set_extended_dns_error(EDE_PROHIBITED, &why.to_string());
528 }
529 Denied(why) => {
530 rcode = REFUSED;
531 edns.set_extended_dns_error(EDE_PROHIBITED, &why);
532 }
533 Blocked => {
534 rcode = NXDOMAIN;
535 edns.set_extended_dns_error(
536 EDE_BLOCKED,
537 "Server is configured to block these queries",
538 );
539 }
540 NotAuthoritative => {
541 rcode = REFUSED;
542 edns.set_extended_dns_error(EDE_NOT_AUTHORITATIVE, "Not Authoritative");
543 }
544 NoRouteConfigured => {
545 rcode = SERVFAIL;
546 edns.set_extended_dns_error(EDE_NOT_SUPPORTED, "No route configured for suffix");
547 }
548 OutReply(outquery::Error::Timeout) => {
549 rcode = SERVFAIL;
550 edns.set_extended_dns_error(
551 EDE_NO_REACHABLE_AUTHORITY,
552 "Timed out talking to upstream server",
553 );
554 }
555 OutReply(outquery::Error::FailedToSend(io)) => {
556 rcode = SERVFAIL;
557 edns.set_extended_dns_error(EDE_NETWORK_ERROR, &io.to_string());
558 }
559 OutReply(outquery::Error::FailedToSendMsg(msg)) => {
560 rcode = SERVFAIL;
561 edns.set_extended_dns_error(EDE_NETWORK_ERROR, &msg);
562 }
563 OutReply(outquery::Error::FailedToRecv(io)) => {
564 rcode = SERVFAIL;
565 edns.set_extended_dns_error(EDE_NETWORK_ERROR, &io.to_string());
566 }
567 OutReply(outquery::Error::FailedToRecvMsg(msg)) => {
568 rcode = SERVFAIL;
569 edns.set_extended_dns_error(EDE_NETWORK_ERROR, &msg);
570 }
571 OutReply(outquery::Error::TcpConnection(msg)) => {
572 rcode = SERVFAIL;
573 edns.set_extended_dns_error(EDE_NETWORK_ERROR, &msg);
574 }
575 OutReply(outquery::Error::Parse(msg)) => {
576 rcode = SERVFAIL;
577 edns.set_extended_dns_error(EDE_NETWORK_ERROR, &msg);
578 }
579 OutReply(outquery::Error::Internal(_)) => {
580 rcode = SERVFAIL;
581 edns.set_extended_dns_error(EDE_OTHER, "Internal Error");
582 }
583 }
584 dnspkt::DNSPkt {
585 qid: msg.in_query.qid,
586 rd: false,
587 tc: false,
588 aa: false,
589 qr: true,
590 opcode: dnspkt::OPCODE_QUERY,
591 cd: false,
592 ad: false,
593 ra: true,
594 rcode,
595 bufsize: 4096,
596 edns_ver: msg.in_query.edns_ver.map(|_| 0),
597 edns_do: false,
598
599 question: msg.in_query.question.clone(),
600 answer: vec![],
601 additional: vec![],
602 nameserver: vec![],
603 edns: Some(edns),
604 }
605 }
606
607 fn build_dns_message(
608 pkt: &[u8],
609 local_ip: std::net::IpAddr,
610 remote_addr: NetAddr,
611 protocol: Protocol,
612 ) -> Result<DnsMessage, Error> {
613 let in_query = parse::PktParser::new(pkt)
614 .get_dns()
615 .map_err(Error::ParseError)?;
616 Ok(DnsMessage {
617 in_query,
618 local_ip,
619 remote_addr,
620 protocol,
621 in_size: pkt.len(),
622 })
623 }
624
625 async fn recv_in_query(
626 s: &std::sync::Arc<tokio::sync::RwLock<Self>>,
627 msg: &DnsMessage,
628 ) -> Result<dnspkt::DNSPkt, std::convert::Infallible> {
629 log::trace!(
630 "[{:x}] In Query {}: {} ⇐ {}: {:?}",
631 msg.in_query.qid,
632 msg.protocol,
633 msg.local_ip,
634 msg.remote_addr,
635 msg.in_query
636 );
637 let next = &s.read().await.next;
638 let in_reply;
639 match next.handle_query(msg).await {
640 Ok(out_reply) => {
641 in_reply = Self::create_in_reply(msg, &out_reply).await;
642 IN_QUERY_RESULT
643 .with_label_values(&[&msg.protocol.to_string(), &in_reply.status()])
644 .inc();
645 }
646 Err(err) => {
647 in_reply = Self::create_in_error(msg, err).await;
648 IN_QUERY_RESULT
649 .with_label_values(&[&msg.protocol.to_string(), &in_reply.status()])
650 .inc();
651 }
652 }
653 log::trace!("[{:x}] In Reply: {:?}", msg.in_query.qid, in_reply);
654 Ok(in_reply)
655 }
656
657 async fn should_ratelimit(
658 msg: &DnsMessage,
659 in_reply: &dnspkt::DNSPkt,
660 in_reply_serialised: &[u8],
661 rate_limiter: &IpRateLimiter,
662 ) -> bool {
663 if in_reply.rcode != dnspkt::REFUSED {
665 return false;
666 }
667
668 match msg.validate_cookie().await {
669 CookieStatus::Good => {
670 log::trace!("[{:x}] Cookie status: Good", msg.in_query.qid);
672 return false;
673 }
674 CookieStatus::Bad => {
675 log::trace!("[{:x}] Cookie status: Bad", msg.in_query.qid);
676 }
677 CookieStatus::Missing => {
678 log::trace!("[{:x}] Cookie status: Missing", msg.in_query.qid);
679 }
680 }
681
682 let cost = std::cmp::max(
686 (in_reply_serialised.len() * 2).saturating_sub(msg.in_size),
687 200,
688 );
689
690 use erbium_net::addr::NetAddrExt as _;
691
692 !rate_limiter
695 .check(msg.remote_addr.ip().unwrap(), cost)
696 .await
697 }
698
699 async fn run_udp(
700 listener: &std::sync::Arc<UdpSocket>,
701 s: &std::sync::Arc<tokio::sync::RwLock<Self>>,
702 ) -> Result<(), Error> {
703 let local_rate_limiter;
704 {
705 let local_self = s.read().await;
706 local_rate_limiter = local_self.rate_limiter.clone();
707 }
708 let rm = match listener.recv_msg(4096, udp::MsgFlags::empty()).await {
709 Ok(rm) => rm,
710 Err(err) if err.kind() == std::io::ErrorKind::WouldBlock => return Ok(()),
711 Err(err) if err.kind() == std::io::ErrorKind::Interrupted => return Ok(()),
712 Err(err) => return Err(Error::RecvError(err)),
713 };
714 let timer = IN_QUERY_LATENCY.with_label_values(&["UDP"]).start_timer();
715
716 let q = s.clone();
717 let local_listener = listener.clone();
718
719 log::trace!(
720 "Received UDP {:?} ⇒ {:?} ({})",
721 rm.address,
722 rm.local_ip(),
723 rm.buffer.len()
724 );
725
726 tokio::spawn(async move {
727 match Self::build_dns_message(
728 &rm.buffer,
729 rm.local_ip().unwrap(), rm.address.unwrap(), Protocol::Udp,
732 ) {
733 Ok(msg) => {
734 let in_reply = Self::recv_in_query(&q, &msg).await.unwrap();
735 let in_reply_bytes = in_reply.serialise();
736 if !Self::should_ratelimit(
737 &msg,
738 &in_reply,
739 &in_reply_bytes,
740 &local_rate_limiter,
741 )
742 .await
743 {
744 let cmsg = udp::ControlMessage::new().set_send_from(rm.local_ip());
745 local_listener
746 .send_msg(
747 in_reply_bytes.as_slice(),
748 &cmsg,
749 udp::MsgFlags::empty(),
750 Some(&rm.address.unwrap()), )
752 .await
753 .expect("Failed to send reply"); } else {
755 IN_QUERY_DROPPED.inc();
756 log::warn!("[{:x}] Not Sending Reply: Rate Limit", msg.in_query.qid);
757 }
758 }
759 Err(err) => {
760 log::warn!("Failed to handle request: {}", err);
761 IN_QUERY_RESULT
762 .with_label_values(&["UDP", "parse fail"])
763 .inc();
764 }
765 }
766 drop(timer);
767 });
768 Ok(())
769 }
770
771 fn prepare_to_send(pkt: &dnspkt::DNSPkt, size: usize) -> Vec<u8> {
772 let size = std::cmp::max(size, 512);
773 pkt.serialise_with_size(size)
774 }
775
776 async fn run_tcp(
777 s: &std::sync::Arc<tokio::sync::RwLock<Self>>,
778 mut sock: tokio::net::TcpStream,
779 sock_addr: NetAddr,
780 ) -> Result<(), Error> {
781 use tokio::io::AsyncReadExt as _;
782
783 log::trace!(
784 "Received TCP connection {:?} ⇒ {:?}",
785 sock_addr,
786 sock.local_addr().unwrap(), );
788
789 let mut lbytes = [0u8; 2];
790
791 if sock.read(&mut lbytes).await.map_err(Error::RecvError)? != lbytes.len() {
792 return Err(Error::ParseError("Failed to read length".into()));
793 }
794
795 let l = u16::from_be_bytes(lbytes) as usize;
796 let mut buffer = vec![0u8; l];
797
798 sock.read_exact(&mut buffer[..])
799 .await
800 .map_err(Error::RecvError)?;
801 let timer = IN_QUERY_LATENCY.with_label_values(&["TCP"]).start_timer();
802
803 let q = s.clone();
804
805 log::trace!(
806 "Received TCP {:?} ⇒ {:?} ({})",
807 sock_addr,
808 sock.local_addr(),
809 buffer.len()
810 );
811
812 tokio::spawn(async move {
813 use tokio::io::AsyncWriteExt as _;
814 match Self::build_dns_message(
815 &buffer,
816 sock.local_addr().ok().map(|addr| addr.ip()).unwrap(), sock_addr,
818 Protocol::Tcp,
819 ) {
820 Ok(msg) => {
821 let in_reply = Self::recv_in_query(&q, &msg).await.unwrap();
822 let serialised =
823 Self::prepare_to_send(&in_reply, msg.in_query.bufsize as usize);
824 let mut in_reply_bytes = Vec::with_capacity(2 + serialised.len());
825 in_reply_bytes.extend((serialised.len() as u16).to_be_bytes().iter());
826 in_reply_bytes.extend(serialised);
827 if let Err(io) = sock.write(&in_reply_bytes).await {
828 log::warn!("[{:x}] Failed to send DNS reply: {}", msg.in_query.qid, io);
829 IN_QUERY_RESULT
830 .with_label_values(&["TCP", "send fail"])
831 .inc();
832 }
833 drop(timer);
834 }
835 Err(err) => {
836 IN_QUERY_RESULT
837 .with_label_values(&["TCP", "parse fail"])
838 .inc();
839 log::warn!("Failed to handle request: {}", err);
840 }
841 }
842 });
843
844 Ok(())
845 }
846
847 async fn run_tcp_listener(
848 tcp: &tokio::net::TcpListener,
849 s: &std::sync::Arc<tokio::sync::RwLock<Self>>,
850 ) -> Result<(), Error> {
851 let (sock, sock_addr) = tcp.accept().await.map_err(Error::AcceptError)?;
852 let local_s = s.clone();
853
854 tokio::spawn(async move { Self::run_tcp(&local_s, sock, sock_addr.into()).await });
855
856 Ok(())
857 }
858
859 async fn run(s: &std::sync::Arc<tokio::sync::RwLock<Self>>) -> Result<(), Error> {
860 use futures::StreamExt as _;
861 let mut services = futures::stream::FuturesUnordered::new();
862 let mut my_self = s.write().await;
863 for listener in my_self.udp_listeners.drain(..) {
864 let s_clone = s.clone();
865 services.push(tokio::spawn(async move {
866 let shared_listener = listener.into();
867 loop {
868 match Self::run_udp(&shared_listener, &s_clone).await {
869 Ok(()) => (),
870 Err(err) => {
871 log::warn!(
872 "{}: {}",
873 shared_listener
874 .local_addr()
875 .map(|a| format!("{}", a))
876 .unwrap_or_else(|e| format!("<unknown: {}>", e)),
877 err
878 )
879 }
880 }
881 }
882 }));
883 }
884 for listener in my_self.tcp_listeners.drain(..) {
885 let s_clone = s.clone();
886 services.push(tokio::spawn(async move {
887 loop {
888 match Self::run_tcp_listener(&listener, &s_clone).await {
889 Ok(()) => (),
890 Err(err) => {
891 log::warn!(
892 "{}: {}",
893 listener
894 .local_addr()
895 .map(|a| format!("{}", a))
896 .unwrap_or_else(|e| format!("<unknown: {}>", e)),
897 err
898 )
899 }
900 }
901 }
902 }));
903 }
904
905 drop(my_self);
906
907 services.next().await.unwrap().unwrap()
908 }
909}
910
911pub struct DnsService {
912 next: std::sync::Arc<tokio::sync::RwLock<DnsListenerHandler>>,
913}
914
915impl DnsService {
916 pub async fn run(self) -> Result<(), Error> {
917 loop {
918 DnsListenerHandler::run(&self.next).await?;
919 }
920 }
921
922 pub async fn new(
923 conf: crate::config::SharedConfig,
924 netinfo: &erbium_net::netinfo::SharedNetInfo,
925 ) -> Result<Self, Error> {
926 Ok(Self {
927 next: tokio::sync::RwLock::new(DnsListenerHandler::new(conf, netinfo).await?).into(),
928 })
929 }
930}