adns_server/server/
mod.rs1use std::{io::ErrorKind, net::SocketAddr, sync::Arc, time::Duration};
2
3use adns_zone::Zone;
4use arc_swap::{ArcSwap, Guard};
5use log::{debug, error, info};
6use tokio::{
7 io::{AsyncReadExt, AsyncWriteExt},
8 net::{TcpListener, TcpStream, UdpSocket},
9 sync::mpsc,
10 task::JoinHandle,
11};
12
13use crate::{metrics, ZoneProvider, ZoneProviderUpdate};
14
15pub struct Server {
16 udp_bind: SocketAddr,
17 tcp_bind: SocketAddr,
18 receiver: mpsc::Receiver<Zone>,
19 update_sender: mpsc::Sender<ZoneProviderUpdate>,
20 current_zone: Arc<ArcSwap<Zone>>,
21}
22
23mod respond;
24mod respond_update;
25
26async fn tcp_transaction(
27 client: &mut TcpStream,
28 updater: &mpsc::Sender<ZoneProviderUpdate>,
29 from: &str,
30 zone: &Zone,
31) -> Result<(), std::io::Error> {
32 let len = client.read_u16().await?;
33 let mut response = vec![0u8; len as usize];
34 client.read_exact(&mut response).await?;
35 if let Some(response) = respond::respond(true, zone, updater, from, &response).await {
36 let response = response.serialize(zone, u16::MAX as usize);
37 for response in response {
38 client.write_u16(response.len() as u16).await?;
39 client.write_all(&response).await?;
40 }
41 }
42 Ok(())
43}
44
45async fn tcp_connection(
46 mut client: TcpStream,
47 updater: mpsc::Sender<ZoneProviderUpdate>,
48 from: &str,
49 zone: Guard<Arc<Zone>>,
50) -> Result<(), std::io::Error> {
51 metrics::TCP_CONNECTIONS.with_label_values(&[from]).inc();
52 defer_lite::defer! {
53 metrics::TCP_CONNECTIONS.with_label_values(&[from]).dec();
54 };
55 loop {
56 match tokio::time::timeout(
57 Duration::from_secs(30),
58 tcp_transaction(&mut client, &updater, from, &zone),
59 )
60 .await
61 {
62 Ok(Ok(())) => (),
63 Ok(Err(e)) => return Err(e),
64 Err(_) => {
65 return Err(std::io::Error::new(
66 ErrorKind::TimedOut,
67 "dns transaction timed out",
68 ))
69 }
70 }
71 }
72}
73
74impl Server {
75 pub fn new(
76 udp_bind: SocketAddr,
77 tcp_bind: SocketAddr,
78 mut zone_provider: impl ZoneProvider,
79 ) -> Self {
80 let (sender, receiver) = mpsc::channel(2);
81 let (update_sender, update_receiver) = mpsc::channel(2);
82 tokio::spawn(async move { zone_provider.run(sender, update_receiver).await });
83 Self {
84 udp_bind,
85 tcp_bind,
86 receiver,
87 update_sender,
88 current_zone: Arc::new(ArcSwap::new(Arc::new(Zone::default()))),
89 }
90 }
91
92 pub async fn run(mut self) {
93 info!("Waiting for initial zone load...");
94 match self.receiver.recv().await {
95 Some(zone) => {
96 self.current_zone.store(Arc::new(zone));
97 }
98 None => {
99 error!("Zone provider died before giving us an initial zone");
100 return;
101 }
102 }
103 info!("Initial zone loaded");
104 let udp = match UdpSocket::bind(self.udp_bind).await {
105 Ok(x) => Arc::new(x),
106 Err(e) => {
107 error!("failed to bind to UDP port: {e}");
108 return;
109 }
110 };
111 info!("Listening on {} (UDP)", self.udp_bind);
112 let mut futures: Vec<JoinHandle<()>> = vec![];
113 let current_zone = self.current_zone.clone();
114 let mut receiver = self.receiver;
115 futures.push(tokio::spawn(async move {
116 while let Some(zone) = receiver.recv().await {
117 info!("updating zone...");
118 current_zone.store(Arc::new(zone));
119 }
120 }));
121 let current_zone = self.current_zone.clone();
122 let updater = self.update_sender.clone();
123 futures.push(tokio::spawn(async move {
124 loop {
125 let mut recv_buf = vec![0u8; 512];
126 let (size, from) = match udp.recv_from(&mut recv_buf[..]).await {
127 Ok(x) => x,
128 Err(e) => {
129 error!("udp server failure: {e}");
130 break;
131 }
132 };
133 recv_buf.truncate(size);
134 let zone = current_zone.load();
135 let udp = udp.clone();
136 let updater = updater.clone();
137 tokio::spawn(async move {
138 match respond::respond(false, &zone, &updater, &from.to_string(), &recv_buf)
139 .await
140 {
141 Some(packet) => {
142 let serialized = packet.serialize(&zone, 512);
143 if serialized.len() != 1 {
144 error!("cannot send more than one packet for udp!");
145 return;
146 }
147 if let Err(e) = udp.send_to(&serialized[0], from).await {
148 debug!("UDP send_to error: {e}");
149 }
150 }
151 None => {
152 debug!("packet had no response issued");
153 }
154 }
155 });
156 }
157 }));
158 let tcp = match TcpListener::bind(self.tcp_bind).await {
159 Ok(x) => x,
160 Err(e) => {
161 error!("failed to bind to TCP port: {e}");
162 return;
163 }
164 };
165 info!("Listening on {} (TCP)", self.tcp_bind);
166 let current_zone = self.current_zone.clone();
167 let updater = self.update_sender.clone();
168 futures.push(tokio::spawn(async move {
169 while let Ok((client, from)) = tcp.accept().await {
170 let zone = current_zone.load();
171 let updater = updater.clone();
172 tokio::spawn(async move {
173 if let Err(e) = tcp_connection(client, updater, &from.to_string(), zone).await {
174 debug!("TCP connection error: {e}");
175 }
176 });
177 }
178 }));
179 let _ = futures::future::select_all(&mut futures).await;
180 }
181}