Skip to main content

erbium/dns/
mod.rs

1/*   Copyright 2024 Perry Lorier
2 *
3 *  Licensed under the Apache License, Version 2.0 (the "License");
4 *  you may not use this file except in compliance with the License.
5 *  You may obtain a copy of the License at
6 *
7 *      http://www.apache.org/licenses/LICENSE-2.0
8 *
9 *  Unless required by applicable law or agreed to in writing, software
10 *  distributed under the License is distributed on an "AS IS" BASIS,
11 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 *  See the License for the specific language governing permissions and
13 *  limitations under the License.
14 *
15 *  SPDX-License-Identifier: Apache-2.0
16 *
17 *  Infrastructure for DNS services.
18 */
19use erbium_net::addr::NetAddr;
20use erbium_net::udp;
21
/// Shortest delay before the DNS cookie keys are rotated again.
const HOURS_24: std::time::Duration = std::time::Duration::from_secs(24 * 60 * 60);
/// Exclusive upper bound on the delay before the DNS cookie keys are rotated again.
const HOURS_36: std::time::Duration = std::time::Duration::from_secs(36 * 60 * 60);
24
25type UdpSocket = udp::UdpSocket;
26
27mod acl;
28mod bucket;
29mod cache;
30pub(crate) mod config;
31pub mod dnspkt;
32mod outquery;
33#[cfg(fuzzing)]
34pub mod parse;
35#[cfg(not(fuzzing))]
36mod parse;
37mod router;
38
39use bytes::BytesMut;
40use tokio_util::codec::Decoder;
41
42type Key = [u8; 8];
43
44struct CookieKeys {
45    next_refresh: tokio::time::Instant,
46    current: Key,
47    previous: Key,
48}
49
50impl CookieKeys {
51    fn new() -> Self {
52        Self {
53            next_refresh: tokio::time::Instant::now(),
54            current: Default::default(),
55            previous: Default::default(),
56        }
57        .rotate()
58        .rotate()
59    }
60
61    fn rotate(&self) -> Self {
62        use rand::{RngExt as _, TryRng as _};
63        let mut rng = rand::rngs::SysRng;
64
65        let next_refresh =
66            tokio::time::Instant::now() + rand::rng().random_range(HOURS_24..HOURS_36);
67
68        let mut current: Key = Default::default();
69        rng.try_fill_bytes(&mut current).unwrap();
70
71        Self {
72            next_refresh,
73            current,
74            previous: self.current,
75        }
76    }
77
78    fn needs_rotation(&self) -> bool {
79        self.next_refresh < tokio::time::Instant::now()
80    }
81
82    // Gets the current and previous cookie keys, rotating them if they've expired.
83    async fn get_keys(s: &tokio::sync::RwLock<Self>) -> (Key, Key) {
84        if s.read().await.needs_rotation() {
85            // TODO: This only does one rotation, it's possible both keys have expired, in which
86            // case we should rotate both.
87            let mut cookies = s.write().await;
88            *cookies = cookies.rotate();
89        }
90
91        let cookies = s.read().await;
92        (cookies.current, cookies.previous)
93    }
94
95    async fn get_current_key(s: &tokio::sync::RwLock<Self>) -> Key {
96        Self::get_keys(s).await.0
97    }
98}
99
100impl Default for CookieKeys {
101    fn default() -> Self {
102        Self::new()
103    }
104}
105
106lazy_static::lazy_static! {
107    static ref IN_QUERY_LATENCY: prometheus::HistogramVec =
108        prometheus::register_histogram_vec!("dns_in_query_latency",
109            "DNS latency for in queries",
110            &["protocol"])
111        .unwrap();
112
113    /* Result is "RCode" or "RCode (EdeCode)" */
114    static ref IN_QUERY_RESULT: prometheus::IntCounterVec =
115        prometheus::register_int_counter_vec!("dns_in_query_result",
116            "DNS response codes for in queries",
117            &["protocol", "result"])
118        .unwrap();
119
120    static ref IN_QUERY_DROPPED: prometheus::IntCounter =
121        prometheus::register_int_counter!("dns_in_query_dropped",
122            "DNS queries dropped")
123        .unwrap();
124
125    static ref COOKIE_KEYS: tokio::sync::RwLock<CookieKeys> = Default::default();
126}
127
128#[cfg_attr(test, derive(Debug))]
129pub enum Error {
130    ListenError(std::io::Error, Box<erbium_net::addr::NetAddr>),
131    AcceptError(std::io::Error),
132    RecvError(std::io::Error),
133    ParseError(String),
134    RefusedByAcl(crate::acl::AclError),
135    Denied(String),
136    Blocked,
137    NoRouteConfigured,
138    NotAuthoritative,
139    OutReply(outquery::Error),
140}
141
142impl std::fmt::Display for Error {
143    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
144        use Error::*;
145        match self {
146            ListenError(io, addr) => write!(f, "Failed to listen for DNS on {}: {}", addr, io),
147            AcceptError(io) => write!(f, "Failed to accept new TCP connection for DNS: {}", io),
148            RecvError(io) => write!(f, "Failed to receive DNS in query: {}", io),
149            ParseError(msg) => write!(f, "Failed to parse DNS in query: {}", msg),
150            RefusedByAcl(why) => write!(f, "Query refused by policy: {}", why),
151            NotAuthoritative => write!(f, "Not Authoritative"),
152            Blocked => write!(f, "Blocked by configuration"),
153            NoRouteConfigured => write!(f, "No route configured"),
154            Denied(msg) => write!(f, "Denied: {}", msg),
155            OutReply(err) => write!(f, "{}", err),
156        }
157    }
158}
159
160// We want to rate limit some error codes (like REFUSED) to prevent being used in reflection
161// attacks.  We don't want to keep track of a whole bunch of IP addresses tho, so we do a variation
162// on a bloom filter.  We have N token buckets, we hash the IP into *two* of those buckets, and
163// then we try and take some tokens from which ever has more tokens available.  If neither bucket
164// has sufficient tokens available, then we fail.  This means for small amounts of fixed memory
165// we can have a pretty low false positive rate.
166type Bucket = tokio::sync::RwLock<bucket::GenericTokenBucket>;
167struct IpRateLimiter([Bucket; 256]);
168
169impl IpRateLimiter {
170    fn new() -> Self {
171        Self(std::array::from_fn(|_| {
172            Bucket::new(bucket::GenericTokenBucket::new())
173        }))
174    }
175
176    fn hash_ip(seed: u64, ip: std::net::IpAddr) -> usize {
177        use std::hash::Hash as _;
178        use std::hash::Hasher as _;
179        let mut hasher = std::collections::hash_map::DefaultHasher::new();
180        seed.hash(&mut hasher);
181        ip.hash(&mut hasher);
182        hasher.finish() as usize
183    }
184
185    async fn check(&self, ip: std::net::IpAddr, bytes: usize) -> bool {
186        // TODO: Base seeds on time, rotating every 60s or something.
187        // They probably should also be unique per process.
188        // Maybe each seed should be staggered in time.
189        const SEED1: u64 = 0x1234_5678_9ABC_DEF0;
190        const SEED2: u64 = 0x2345_6789_ABCD_EF01;
191
192        let hash1 = Self::hash_ip(SEED1, ip);
193        let hash2 = Self::hash_ip(SEED2, ip);
194
195        let bucket1 = hash1 % self.0.len();
196
197        /* Normally a read() lock like this, when converted to a write() should be tested again,
198         * however since the writes are commutative, and we're more worried about speed than exact
199         * precision this should be fine.
200         */
201        if self.0[bucket1]
202            .read()
203            .await
204            .check::<bucket::RealTimeClock>(bytes as u32)
205        {
206            self.0[bucket1]
207                .write()
208                .await
209                .deplete::<bucket::RealTimeClock>(bytes as u32);
210            true
211        } else {
212            let mut bucket2 = hash2 % (self.0.len() - 1);
213            if bucket2 == bucket1 {
214                bucket2 = self.0.len() - 1;
215            }
216
217            if self.0[bucket2]
218                .read()
219                .await
220                .check::<bucket::RealTimeClock>(bytes as u32)
221            {
222                self.0[bucket2]
223                    .write()
224                    .await
225                    .deplete::<bucket::RealTimeClock>(bytes as u32);
226                true
227            } else {
228                false
229            }
230        }
231    }
232}
233
234struct DnsCodec {}
235
236impl Decoder for DnsCodec {
237    type Item = dnspkt::DNSPkt;
238    type Error = std::io::Error;
239    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
240        let in_query = parse::PktParser::new(&src[..]).get_dns();
241        match in_query {
242            Ok(p) => Ok(Some(p)),
243            Err(e) => Err(std::io::Error::new(std::io::ErrorKind::InvalidData, e)),
244        }
245    }
246}
247
/// The transport a DNS query arrived over.
pub enum Protocol {
    Udp,
    Tcp,
}

impl std::fmt::Display for Protocol {
    /// Writes the upper case protocol name ("UDP" or "TCP").
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let name = match self {
            Self::Udp => "UDP",
            Self::Tcp => "TCP",
        };
        f.write_str(name)
    }
}
261
/// Result of checking the DNS cookie on an incoming query.
#[derive(PartialEq, Eq)]
enum CookieStatus {
    /// The query did not carry both a client and a server cookie.
    Missing,
    /// The server cookie did not verify against the key it was checked with.
    Bad,
    /// The server cookie verified.
    Good,
}
268
269pub struct DnsMessage {
270    pub in_query: dnspkt::DNSPkt,
271    pub in_size: usize,
272    pub local_ip: std::net::IpAddr,
273    pub remote_addr: NetAddr,
274    pub protocol: Protocol,
275}
276
277type CookieDigest = hmac::Hmac<sha2::Sha256>;
278
279impl DnsMessage {
280    // Calculate the value of the cookie based on a key.
281    // This uses the client cookie, the source and dest ip addresses for generating the cookie.
282    fn calculate_cookie(&self, client: &[u8], key: &[u8]) -> CookieDigest {
283        use hmac::Mac as _;
284        // I'm not a crypto expert, but I am to understand that Hmac is the safest way to use a
285        // hash function to avoid length extension attacks.
286        let mut hasher =
287            CookieDigest::new_from_slice(key).expect("should always be able to create a key");
288        hasher.update(client);
289        match self.local_ip {
290            std::net::IpAddr::V4(v4) => hasher.update(&v4.octets()),
291            std::net::IpAddr::V6(v6) => hasher.update(&v6.octets()),
292        }
293        use erbium_net::addr::NetAddrExt as _;
294        match self.remote_addr.ip() {
295            Some(std::net::IpAddr::V4(v4)) => hasher.update(&v4.octets()),
296            Some(std::net::IpAddr::V6(v6)) => hasher.update(&v6.octets()),
297            _ => unreachable!(),
298        };
299
300        hasher
301    }
302
303    fn validate_cookie_key(&self, key: &[u8]) -> CookieStatus {
304        if let Some((client, Some(server))) = self
305            .in_query
306            .edns
307            .as_ref()
308            .and_then(|edns| edns.get_cookie())
309        {
310            use hmac::Mac as _;
311            let myserver = self.calculate_cookie(client, key);
312            if myserver.verify_slice(server).is_ok() {
313                CookieStatus::Good
314            } else {
315                CookieStatus::Bad
316            }
317        } else {
318            CookieStatus::Missing
319        }
320    }
321
322    // To support key rotation, we provide a new key, and an old key, we first
323    // check if they match using the new key, if so we accept it, if not, then
324    // we try again with the older key.
325    fn validate_cookie_keys(&self, key: &[u8], oldkey: &[u8]) -> CookieStatus {
326        match self.validate_cookie_key(key) {
327            CookieStatus::Bad => self.validate_cookie_key(oldkey),
328            status => status,
329        }
330    }
331
332    async fn calculate_current_cookie(&self, client: &[u8]) -> [u8; 32] {
333        use hmac::Mac as _;
334        let key = CookieKeys::get_current_key(&COOKIE_KEYS).await;
335        self.calculate_cookie(client, &key)
336            .finalize()
337            .into_bytes()
338            .as_slice()
339            .try_into()
340            .unwrap()
341    }
342
343    async fn validate_cookie(&self) -> CookieStatus {
344        let keys = CookieKeys::get_keys(&COOKIE_KEYS).await;
345        self.validate_cookie_keys(&keys.0, &keys.1)
346    }
347}
348
349struct DnsListenerHandler {
350    next: acl::DnsAclHandler,
351    udp_listeners: Vec<UdpSocket>,
352    tcp_listeners: Vec<tokio::net::TcpListener>,
353    rate_limiter: std::sync::Arc<IpRateLimiter>,
354}
355
356impl DnsListenerHandler {
357    async fn listen_udp(
358        _conf: &crate::config::SharedConfig,
359        addr: &erbium_net::addr::NetAddr,
360    ) -> Result<UdpSocket, Error> {
361        let mut count: i32 = 0;
362        let udp = loop {
363            match UdpSocket::bind(&[*addr]).await {
364                Ok(sock) => break sock,
365                Err(e) if e.kind() == std::io::ErrorKind::AddrNotAvailable => {
366                    // Due to duplicate address detection, the IPv6 address we're binding to might
367                    // still be in the "tentative" state, which prevents binding.  Retry a few
368                    // times with exponential backoff to see if it will become ready.
369                    //
370                    // Ideally we would just not bind to it, and get a signal later from netinfo
371                    // when it becomes ready and bind to it then, but that would require a massive
372                    // restructuring of netinfo.
373                    if count > 2 {
374                        return Err(Error::ListenError(e, Box::new(*addr)));
375                    }
376                    log::warn!(
377                        "Failed to bind DNS UDP to {} ({}): Retrying after {}s",
378                        addr,
379                        e,
380                        1 << count
381                    );
382                    tokio::time::sleep(std::time::Duration::from_secs(1 << count)).await;
383                    count += 1;
384                    continue;
385                }
386                Err(e) => return Err(Error::ListenError(e, Box::new(*addr))),
387            }
388        };
389
390        if addr.as_sockaddr_in6().is_some() {
391            udp.set_opt_ipv6_packet_info(true)
392                .map_err(|e| Error::ListenError(e, Box::new(*addr)))?
393        } else {
394            udp.set_opt_ipv4_packet_info(true)
395                .map_err(|e| Error::ListenError(e, Box::new(*addr)))?
396        }
397
398        log::info!(
399            "Listening for DNS on UDP {}",
400            udp.local_addr()
401                .map(|name| format!("{}", name))
402                .unwrap_or_else(|_| "Unknown".into())
403        );
404
405        Ok(udp)
406    }
407
408    async fn listen_tcp(
409        _conf: &crate::config::SharedConfig,
410        addr: &erbium_net::addr::NetAddr,
411    ) -> Result<tokio::net::TcpListener, Error> {
412        use erbium_net::addr::NetAddrExt as _;
413        let tcp = tokio::net::TcpListener::bind(addr.to_std_socket_addr().ok_or_else(|| {
414            Error::ListenError(std::io::ErrorKind::Unsupported.into(), Box::new(*addr))
415        })?)
416        .await
417        .map_err(|e| Error::ListenError(e, Box::new(*addr)))?;
418
419        log::info!(
420            "Listening for DNS on TCP {}",
421            tcp.local_addr()
422                .map(|name| format!("{}", name))
423                .unwrap_or_else(|_| "Unknown".into())
424        );
425
426        Ok(tcp)
427    }
428
429    async fn new(
430        conf: crate::config::SharedConfig,
431        netinfo: &erbium_net::netinfo::SharedNetInfo,
432    ) -> Result<Self, Error> {
433        let mut udp_listeners = vec![];
434        let mut tcp_listeners = vec![];
435        {
436            let roconf = conf.read().await;
437            for addr in &roconf
438                .dns_listeners
439                .as_sockaddrs(&roconf.addresses, netinfo, 53)
440                .await
441            {
442                udp_listeners.push(Self::listen_udp(&conf, addr).await?);
443                tcp_listeners.push(Self::listen_tcp(&conf, addr).await?);
444            }
445        }
446        let rate_limiter = IpRateLimiter::new().into();
447
448        Ok(Self {
449            next: acl::DnsAclHandler::new(conf).await,
450            udp_listeners,
451            tcp_listeners,
452            rate_limiter,
453        })
454    }
455
456    async fn add_edns(edns: &mut dnspkt::EdnsData, msg: &DnsMessage) {
457        // If they requested NSID, then return it.
458        if msg
459            .in_query
460            .edns
461            .as_ref()
462            .map(|edns| edns.get_nsid().is_some())
463            .unwrap_or(false)
464        {
465            // We fill in NSID with the receiving interface IP.
466            // TODO: This might not be particularly interesting if this is a VIP.  We might want to
467            // find some more useful information to put in here.
468            edns.set_nsid(format!("{}", msg.local_ip).as_bytes());
469        }
470
471        // Handle DNS COOKIE (RFC7873)
472        if let Some((client, _server)) = msg
473            .in_query
474            .edns
475            .as_ref()
476            .and_then(|edns| edns.get_cookie())
477        {
478            let server = msg.calculate_current_cookie(client).await;
479            edns.set_cookie(client, &server);
480        }
481    }
482
483    async fn create_in_reply(msg: &DnsMessage, outr: &dnspkt::DNSPkt) -> dnspkt::DNSPkt {
484        let mut edns: dnspkt::EdnsData = Default::default();
485        Self::add_edns(&mut edns, msg).await;
486        dnspkt::DNSPkt {
487            qid: msg.in_query.qid,
488            rd: false,
489            tc: outr.tc,
490            aa: outr.aa,
491            qr: true,
492            opcode: dnspkt::OPCODE_QUERY,
493
494            cd: outr.cd,
495            ad: outr.ad,
496            ra: outr.ra,
497
498            rcode: outr.rcode,
499
500            bufsize: 4096,
501
502            edns_ver: msg.in_query.edns_ver.map(|_| 0),
503            edns_do: false,
504
505            question: msg.in_query.question.clone(),
506            answer: outr.answer.clone(),
507            nameserver: outr.answer.clone(),
508            additional: outr.additional.clone(),
509            edns: Some(edns),
510        }
511    }
512
513    async fn create_in_error(msg: &DnsMessage, err: Error) -> dnspkt::DNSPkt {
514        use Error::*;
515        use dnspkt::*;
516        let mut edns: EdnsData = Default::default();
517        Self::add_edns(&mut edns, msg).await;
518        let rcode;
519        match err {
520            /* These errors mean we never get a packet to reply to. */
521            ListenError(..) => unreachable!(),
522            AcceptError(..) => unreachable!(),
523            RecvError(_) => unreachable!(),
524            ParseError(_) => unreachable!(),
525            RefusedByAcl(why) => {
526                rcode = REFUSED;
527                edns.set_extended_dns_error(EDE_PROHIBITED, &why.to_string());
528            }
529            Denied(why) => {
530                rcode = REFUSED;
531                edns.set_extended_dns_error(EDE_PROHIBITED, &why);
532            }
533            Blocked => {
534                rcode = NXDOMAIN;
535                edns.set_extended_dns_error(
536                    EDE_BLOCKED,
537                    "Server is configured to block these queries",
538                );
539            }
540            NotAuthoritative => {
541                rcode = REFUSED;
542                edns.set_extended_dns_error(EDE_NOT_AUTHORITATIVE, "Not Authoritative");
543            }
544            NoRouteConfigured => {
545                rcode = SERVFAIL;
546                edns.set_extended_dns_error(EDE_NOT_SUPPORTED, "No route configured for suffix");
547            }
548            OutReply(outquery::Error::Timeout) => {
549                rcode = SERVFAIL;
550                edns.set_extended_dns_error(
551                    EDE_NO_REACHABLE_AUTHORITY,
552                    "Timed out talking to upstream server",
553                );
554            }
555            OutReply(outquery::Error::FailedToSend(io)) => {
556                rcode = SERVFAIL;
557                edns.set_extended_dns_error(EDE_NETWORK_ERROR, &io.to_string());
558            }
559            OutReply(outquery::Error::FailedToSendMsg(msg)) => {
560                rcode = SERVFAIL;
561                edns.set_extended_dns_error(EDE_NETWORK_ERROR, &msg);
562            }
563            OutReply(outquery::Error::FailedToRecv(io)) => {
564                rcode = SERVFAIL;
565                edns.set_extended_dns_error(EDE_NETWORK_ERROR, &io.to_string());
566            }
567            OutReply(outquery::Error::FailedToRecvMsg(msg)) => {
568                rcode = SERVFAIL;
569                edns.set_extended_dns_error(EDE_NETWORK_ERROR, &msg);
570            }
571            OutReply(outquery::Error::TcpConnection(msg)) => {
572                rcode = SERVFAIL;
573                edns.set_extended_dns_error(EDE_NETWORK_ERROR, &msg);
574            }
575            OutReply(outquery::Error::Parse(msg)) => {
576                rcode = SERVFAIL;
577                edns.set_extended_dns_error(EDE_NETWORK_ERROR, &msg);
578            }
579            OutReply(outquery::Error::Internal(_)) => {
580                rcode = SERVFAIL;
581                edns.set_extended_dns_error(EDE_OTHER, "Internal Error");
582            }
583        }
584        dnspkt::DNSPkt {
585            qid: msg.in_query.qid,
586            rd: false,
587            tc: false,
588            aa: false,
589            qr: true,
590            opcode: dnspkt::OPCODE_QUERY,
591            cd: false,
592            ad: false,
593            ra: true,
594            rcode,
595            bufsize: 4096,
596            edns_ver: msg.in_query.edns_ver.map(|_| 0),
597            edns_do: false,
598
599            question: msg.in_query.question.clone(),
600            answer: vec![],
601            additional: vec![],
602            nameserver: vec![],
603            edns: Some(edns),
604        }
605    }
606
607    fn build_dns_message(
608        pkt: &[u8],
609        local_ip: std::net::IpAddr,
610        remote_addr: NetAddr,
611        protocol: Protocol,
612    ) -> Result<DnsMessage, Error> {
613        let in_query = parse::PktParser::new(pkt)
614            .get_dns()
615            .map_err(Error::ParseError)?;
616        Ok(DnsMessage {
617            in_query,
618            local_ip,
619            remote_addr,
620            protocol,
621            in_size: pkt.len(),
622        })
623    }
624
625    async fn recv_in_query(
626        s: &std::sync::Arc<tokio::sync::RwLock<Self>>,
627        msg: &DnsMessage,
628    ) -> Result<dnspkt::DNSPkt, std::convert::Infallible> {
629        log::trace!(
630            "[{:x}] In Query {}: {} ⇐ {}: {:?}",
631            msg.in_query.qid,
632            msg.protocol,
633            msg.local_ip,
634            msg.remote_addr,
635            msg.in_query
636        );
637        let next = &s.read().await.next;
638        let in_reply;
639        match next.handle_query(msg).await {
640            Ok(out_reply) => {
641                in_reply = Self::create_in_reply(msg, &out_reply).await;
642                IN_QUERY_RESULT
643                    .with_label_values(&[&msg.protocol.to_string(), &in_reply.status()])
644                    .inc();
645            }
646            Err(err) => {
647                in_reply = Self::create_in_error(msg, err).await;
648                IN_QUERY_RESULT
649                    .with_label_values(&[&msg.protocol.to_string(), &in_reply.status()])
650                    .inc();
651            }
652        }
653        log::trace!("[{:x}] In Reply: {:?}", msg.in_query.qid, in_reply);
654        Ok(in_reply)
655    }
656
657    async fn should_ratelimit(
658        msg: &DnsMessage,
659        in_reply: &dnspkt::DNSPkt,
660        in_reply_serialised: &[u8],
661        rate_limiter: &IpRateLimiter,
662    ) -> bool {
663        // Currently we only ratelimit REFUSEDs.
664        if in_reply.rcode != dnspkt::REFUSED {
665            return false;
666        }
667
668        match msg.validate_cookie().await {
669            CookieStatus::Good => {
670                // If we can tell it's not spoofed, don't ratelimit.
671                log::trace!("[{:x}] Cookie status: Good", msg.in_query.qid);
672                return false;
673            }
674            CookieStatus::Bad => {
675                log::trace!("[{:x}] Cookie status: Bad", msg.in_query.qid);
676            }
677            CookieStatus::Missing => {
678                log::trace!("[{:x}] Cookie status: Missing", msg.in_query.qid);
679            }
680        }
681
682        // For each byte larger than the incoming request, we charge it at 2× the cost.
683        // For each byte smaller or equal than the incoming request, we charge it at 1× the cost.
684        // But always charge at least 200.
685        let cost = std::cmp::max(
686            (in_reply_serialised.len() * 2).saturating_sub(msg.in_size),
687            200,
688        );
689
690        use erbium_net::addr::NetAddrExt as _;
691
692        // We bill this to the remote address.
693        // TODO: Should we bill this to the subnet?  Eg, /56 for v6 and /24 for v4?
694        !rate_limiter
695            .check(msg.remote_addr.ip().unwrap(), cost)
696            .await
697    }
698
699    async fn run_udp(
700        listener: &std::sync::Arc<UdpSocket>,
701        s: &std::sync::Arc<tokio::sync::RwLock<Self>>,
702    ) -> Result<(), Error> {
703        let local_rate_limiter;
704        {
705            let local_self = s.read().await;
706            local_rate_limiter = local_self.rate_limiter.clone();
707        }
708        let rm = match listener.recv_msg(4096, udp::MsgFlags::empty()).await {
709            Ok(rm) => rm,
710            Err(err) if err.kind() == std::io::ErrorKind::WouldBlock => return Ok(()),
711            Err(err) if err.kind() == std::io::ErrorKind::Interrupted => return Ok(()),
712            Err(err) => return Err(Error::RecvError(err)),
713        };
714        let timer = IN_QUERY_LATENCY.with_label_values(&["UDP"]).start_timer();
715
716        let q = s.clone();
717        let local_listener = listener.clone();
718
719        log::trace!(
720            "Received UDP {:?} ⇒ {:?} ({})",
721            rm.address,
722            rm.local_ip(),
723            rm.buffer.len()
724        );
725
726        tokio::spawn(async move {
727            match Self::build_dns_message(
728                &rm.buffer,
729                rm.local_ip().unwrap(), /* TODO: Error? */
730                rm.address.unwrap(),    /* TODO: Error? */
731                Protocol::Udp,
732            ) {
733                Ok(msg) => {
734                    let in_reply = Self::recv_in_query(&q, &msg).await.unwrap();
735                    let in_reply_bytes = in_reply.serialise();
736                    if !Self::should_ratelimit(
737                        &msg,
738                        &in_reply,
739                        &in_reply_bytes,
740                        &local_rate_limiter,
741                    )
742                    .await
743                    {
744                        let cmsg = udp::ControlMessage::new().set_send_from(rm.local_ip());
745                        local_listener
746                            .send_msg(
747                                in_reply_bytes.as_slice(),
748                                &cmsg,
749                                udp::MsgFlags::empty(),
750                                Some(&rm.address.unwrap()), /* TODO: Error? */
751                            )
752                            .await
753                            .expect("Failed to send reply"); // TODO: Better error handling
754                    } else {
755                        IN_QUERY_DROPPED.inc();
756                        log::warn!("[{:x}] Not Sending Reply: Rate Limit", msg.in_query.qid);
757                    }
758                }
759                Err(err) => {
760                    log::warn!("Failed to handle request: {}", err);
761                    IN_QUERY_RESULT
762                        .with_label_values(&["UDP", "parse fail"])
763                        .inc();
764                }
765            }
766            drop(timer);
767        });
768        Ok(())
769    }
770
771    fn prepare_to_send(pkt: &dnspkt::DNSPkt, size: usize) -> Vec<u8> {
772        let size = std::cmp::max(size, 512);
773        pkt.serialise_with_size(size)
774    }
775
776    async fn run_tcp(
777        s: &std::sync::Arc<tokio::sync::RwLock<Self>>,
778        mut sock: tokio::net::TcpStream,
779        sock_addr: NetAddr,
780    ) -> Result<(), Error> {
781        use tokio::io::AsyncReadExt as _;
782
783        log::trace!(
784            "Received TCP connection {:?} ⇒ {:?}",
785            sock_addr,
786            sock.local_addr().unwrap(), /* TODO: Error? */
787        );
788
789        let mut lbytes = [0u8; 2];
790
791        if sock.read(&mut lbytes).await.map_err(Error::RecvError)? != lbytes.len() {
792            return Err(Error::ParseError("Failed to read length".into()));
793        }
794
795        let l = u16::from_be_bytes(lbytes) as usize;
796        let mut buffer = vec![0u8; l];
797
798        sock.read_exact(&mut buffer[..])
799            .await
800            .map_err(Error::RecvError)?;
801        let timer = IN_QUERY_LATENCY.with_label_values(&["TCP"]).start_timer();
802
803        let q = s.clone();
804
805        log::trace!(
806            "Received TCP {:?} ⇒ {:?} ({})",
807            sock_addr,
808            sock.local_addr(),
809            buffer.len()
810        );
811
812        tokio::spawn(async move {
813            use tokio::io::AsyncWriteExt as _;
814            match Self::build_dns_message(
815                &buffer,
816                sock.local_addr().ok().map(|addr| addr.ip()).unwrap(), /* TODO: Error? */
817                sock_addr,
818                Protocol::Tcp,
819            ) {
820                Ok(msg) => {
821                    let in_reply = Self::recv_in_query(&q, &msg).await.unwrap();
822                    let serialised =
823                        Self::prepare_to_send(&in_reply, msg.in_query.bufsize as usize);
824                    let mut in_reply_bytes = Vec::with_capacity(2 + serialised.len());
825                    in_reply_bytes.extend((serialised.len() as u16).to_be_bytes().iter());
826                    in_reply_bytes.extend(serialised);
827                    if let Err(io) = sock.write(&in_reply_bytes).await {
828                        log::warn!("[{:x}] Failed to send DNS reply: {}", msg.in_query.qid, io);
829                        IN_QUERY_RESULT
830                            .with_label_values(&["TCP", "send fail"])
831                            .inc();
832                    }
833                    drop(timer);
834                }
835                Err(err) => {
836                    IN_QUERY_RESULT
837                        .with_label_values(&["TCP", "parse fail"])
838                        .inc();
839                    log::warn!("Failed to handle request: {}", err);
840                }
841            }
842        });
843
844        Ok(())
845    }
846
847    async fn run_tcp_listener(
848        tcp: &tokio::net::TcpListener,
849        s: &std::sync::Arc<tokio::sync::RwLock<Self>>,
850    ) -> Result<(), Error> {
851        let (sock, sock_addr) = tcp.accept().await.map_err(Error::AcceptError)?;
852        let local_s = s.clone();
853
854        tokio::spawn(async move { Self::run_tcp(&local_s, sock, sock_addr.into()).await });
855
856        Ok(())
857    }
858
859    async fn run(s: &std::sync::Arc<tokio::sync::RwLock<Self>>) -> Result<(), Error> {
860        use futures::StreamExt as _;
861        let mut services = futures::stream::FuturesUnordered::new();
862        let mut my_self = s.write().await;
863        for listener in my_self.udp_listeners.drain(..) {
864            let s_clone = s.clone();
865            services.push(tokio::spawn(async move {
866                let shared_listener = listener.into();
867                loop {
868                    match Self::run_udp(&shared_listener, &s_clone).await {
869                        Ok(()) => (),
870                        Err(err) => {
871                            log::warn!(
872                                "{}: {}",
873                                shared_listener
874                                    .local_addr()
875                                    .map(|a| format!("{}", a))
876                                    .unwrap_or_else(|e| format!("<unknown: {}>", e)),
877                                err
878                            )
879                        }
880                    }
881                }
882            }));
883        }
884        for listener in my_self.tcp_listeners.drain(..) {
885            let s_clone = s.clone();
886            services.push(tokio::spawn(async move {
887                loop {
888                    match Self::run_tcp_listener(&listener, &s_clone).await {
889                        Ok(()) => (),
890                        Err(err) => {
891                            log::warn!(
892                                "{}: {}",
893                                listener
894                                    .local_addr()
895                                    .map(|a| format!("{}", a))
896                                    .unwrap_or_else(|e| format!("<unknown: {}>", e)),
897                                err
898                            )
899                        }
900                    }
901                }
902            }));
903        }
904
905        drop(my_self);
906
907        services.next().await.unwrap().unwrap()
908    }
909}
910
911pub struct DnsService {
912    next: std::sync::Arc<tokio::sync::RwLock<DnsListenerHandler>>,
913}
914
915impl DnsService {
916    pub async fn run(self) -> Result<(), Error> {
917        loop {
918            DnsListenerHandler::run(&self.next).await?;
919        }
920    }
921
922    pub async fn new(
923        conf: crate::config::SharedConfig,
924        netinfo: &erbium_net::netinfo::SharedNetInfo,
925    ) -> Result<Self, Error> {
926        Ok(Self {
927            next: tokio::sync::RwLock::new(DnsListenerHandler::new(conf, netinfo).await?).into(),
928        })
929    }
930}