adns_server/server/
mod.rs

1use std::{io::ErrorKind, net::SocketAddr, sync::Arc, time::Duration};
2
3use adns_zone::Zone;
4use arc_swap::{ArcSwap, Guard};
5use log::{debug, error, info};
6use tokio::{
7    io::{AsyncReadExt, AsyncWriteExt},
8    net::{TcpListener, TcpStream, UdpSocket},
9    sync::mpsc,
10    task::JoinHandle,
11};
12
13use crate::{metrics, ZoneProvider, ZoneProviderUpdate};
14
/// DNS server that answers queries over UDP and TCP using zones supplied by a
/// [`ZoneProvider`].
pub struct Server {
    /// Address the UDP socket is bound to by [`Server::run`].
    udp_bind: SocketAddr,
    /// Address the TCP listener is bound to by [`Server::run`].
    tcp_bind: SocketAddr,
    /// Receives zones from the zone provider: the first zone received is the
    /// initial zone, and every later one replaces `current_zone`.
    receiver: mpsc::Receiver<Zone>,
    /// Sends [`ZoneProviderUpdate`]s to the zone provider; clones of it are
    /// handed to the query responders.
    update_sender: mpsc::Sender<ZoneProviderUpdate>,
    /// The zone currently being served. It is shared with the serving tasks and
    /// swapped whenever a new zone arrives; it starts out as `Zone::default()`.
    current_zone: Arc<ArcSwap<Zone>>,
}
22
23mod respond;
24mod respond_update;
25
26async fn tcp_transaction(
27    client: &mut TcpStream,
28    updater: &mpsc::Sender<ZoneProviderUpdate>,
29    from: &str,
30    zone: &Zone,
31) -> Result<(), std::io::Error> {
32    let len = client.read_u16().await?;
33    let mut response = vec![0u8; len as usize];
34    client.read_exact(&mut response).await?;
35    if let Some(response) = respond::respond(true, zone, updater, from, &response).await {
36        let response = response.serialize(zone, u16::MAX as usize);
37        for response in response {
38            client.write_u16(response.len() as u16).await?;
39            client.write_all(&response).await?;
40        }
41    }
42    Ok(())
43}
44
45async fn tcp_connection(
46    mut client: TcpStream,
47    updater: mpsc::Sender<ZoneProviderUpdate>,
48    from: &str,
49    zone: Guard<Arc<Zone>>,
50) -> Result<(), std::io::Error> {
51    metrics::TCP_CONNECTIONS.with_label_values(&[from]).inc();
52    defer_lite::defer! {
53        metrics::TCP_CONNECTIONS.with_label_values(&[from]).dec();
54    };
55    loop {
56        match tokio::time::timeout(
57            Duration::from_secs(30),
58            tcp_transaction(&mut client, &updater, from, &zone),
59        )
60        .await
61        {
62            Ok(Ok(())) => (),
63            Ok(Err(e)) => return Err(e),
64            Err(_) => {
65                return Err(std::io::Error::new(
66                    ErrorKind::TimedOut,
67                    "dns transaction timed out",
68                ))
69            }
70        }
71    }
72}
73
74impl Server {
75    pub fn new(
76        udp_bind: SocketAddr,
77        tcp_bind: SocketAddr,
78        mut zone_provider: impl ZoneProvider,
79    ) -> Self {
80        let (sender, receiver) = mpsc::channel(2);
81        let (update_sender, update_receiver) = mpsc::channel(2);
82        tokio::spawn(async move { zone_provider.run(sender, update_receiver).await });
83        Self {
84            udp_bind,
85            tcp_bind,
86            receiver,
87            update_sender,
88            current_zone: Arc::new(ArcSwap::new(Arc::new(Zone::default()))),
89        }
90    }
91
92    pub async fn run(mut self) {
93        info!("Waiting for initial zone load...");
94        match self.receiver.recv().await {
95            Some(zone) => {
96                self.current_zone.store(Arc::new(zone));
97            }
98            None => {
99                error!("Zone provider died before giving us an initial zone");
100                return;
101            }
102        }
103        info!("Initial zone loaded");
104        let udp = match UdpSocket::bind(self.udp_bind).await {
105            Ok(x) => Arc::new(x),
106            Err(e) => {
107                error!("failed to bind to UDP port: {e}");
108                return;
109            }
110        };
111        info!("Listening on {} (UDP)", self.udp_bind);
112        let mut futures: Vec<JoinHandle<()>> = vec![];
113        let current_zone = self.current_zone.clone();
114        let mut receiver = self.receiver;
115        futures.push(tokio::spawn(async move {
116            while let Some(zone) = receiver.recv().await {
117                info!("updating zone...");
118                current_zone.store(Arc::new(zone));
119            }
120        }));
121        let current_zone = self.current_zone.clone();
122        let updater = self.update_sender.clone();
123        futures.push(tokio::spawn(async move {
124            loop {
125                let mut recv_buf = vec![0u8; 512];
126                let (size, from) = match udp.recv_from(&mut recv_buf[..]).await {
127                    Ok(x) => x,
128                    Err(e) => {
129                        error!("udp server failure: {e}");
130                        break;
131                    }
132                };
133                recv_buf.truncate(size);
134                let zone = current_zone.load();
135                let udp = udp.clone();
136                let updater = updater.clone();
137                tokio::spawn(async move {
138                    match respond::respond(false, &zone, &updater, &from.to_string(), &recv_buf)
139                        .await
140                    {
141                        Some(packet) => {
142                            let serialized = packet.serialize(&zone, 512);
143                            if serialized.len() != 1 {
144                                error!("cannot send more than one packet for udp!");
145                                return;
146                            }
147                            if let Err(e) = udp.send_to(&serialized[0], from).await {
148                                debug!("UDP send_to error: {e}");
149                            }
150                        }
151                        None => {
152                            debug!("packet had no response issued");
153                        }
154                    }
155                });
156            }
157        }));
158        let tcp = match TcpListener::bind(self.tcp_bind).await {
159            Ok(x) => x,
160            Err(e) => {
161                error!("failed to bind to TCP port: {e}");
162                return;
163            }
164        };
165        info!("Listening on {} (TCP)", self.tcp_bind);
166        let current_zone = self.current_zone.clone();
167        let updater = self.update_sender.clone();
168        futures.push(tokio::spawn(async move {
169            while let Ok((client, from)) = tcp.accept().await {
170                let zone = current_zone.load();
171                let updater = updater.clone();
172                tokio::spawn(async move {
173                    if let Err(e) = tcp_connection(client, updater, &from.to_string(), zone).await {
174                        debug!("TCP connection error: {e}");
175                    }
176                });
177            }
178        }));
179        let _ = futures::future::select_all(&mut futures).await;
180    }
181}