1use simple_dns::{CLASS, Packet, PacketFlag, RCODE, ResourceRecord, rdata};
11
12use simple_dns::{QTYPE, TYPE};
13use std::net::{Ipv6Addr, SocketAddr};
14use tokio::net::UdpSocket;
15use tokio::sync::watch;
16use tracing::{debug, info, trace, warn};
17
18use super::pool::{PoolEvent, VirtualIpPool};
19use crate::NodeAddr;
20
/// How long `handle_query` waits for the daemon to answer a forwarded query.
const UPSTREAM_TIMEOUT: std::time::Duration = std::time::Duration::from_millis(5_000);

/// Size in bytes of the buffers used to receive client queries and upstream replies.
const MAX_DNS_SIZE: usize = 4 * 1024;
26
27#[derive(Debug)]
29pub struct DnsAllocation {
30 pub node_addr: NodeAddr,
31 pub virtual_ip: Ipv6Addr,
32 pub mesh_addr: Ipv6Addr,
33 pub is_new: bool,
34}
35
36fn extract_fips_name(packet: &Packet) -> Option<String> {
39 let question = packet.questions.first()?;
40 let name = question.qname.to_string();
41 let lower = name.to_ascii_lowercase();
42 if lower.ends_with(".fips") || lower.ends_with(".fips.") {
43 Some(lower.trim_end_matches('.').to_string())
44 } else {
45 None
46 }
47}
48
49fn extract_aaaa(packet: &Packet) -> Option<Ipv6Addr> {
51 for answer in &packet.answers {
52 if let rdata::RData::AAAA(aaaa) = &answer.rdata {
53 return Some(aaaa.address.into());
54 }
55 }
56 None
57}
58
59fn node_addr_from_mesh(mesh_addr: Ipv6Addr) -> NodeAddr {
62 let bytes = mesh_addr.octets();
63 let mut node_bytes = [0u8; 16];
68 node_bytes[..15].copy_from_slice(&bytes[1..16]);
69 NodeAddr::from_bytes(node_bytes)
70}
71
72fn build_refused(query: &Packet) -> Option<Vec<u8>> {
74 let mut response = Packet::new_reply(query.id());
75 response.set_flags(PacketFlag::RESPONSE | PacketFlag::RECURSION_AVAILABLE);
76 *response.rcode_mut() = RCODE::Refused;
77 response.questions.clone_from(&query.questions);
78 response.build_bytes_vec_compressed().ok()
79}
80
81fn build_servfail(query: &Packet) -> Option<Vec<u8>> {
83 let mut response = Packet::new_reply(query.id());
84 response.set_flags(PacketFlag::RESPONSE | PacketFlag::RECURSION_AVAILABLE);
85 *response.rcode_mut() = RCODE::ServerFailure;
86 response.questions.clone_from(&query.questions);
87 response.build_bytes_vec_compressed().ok()
88}
89
90fn build_nodata(query: &Packet, ttl: u32) -> Option<Vec<u8>> {
93 let mut response = Packet::new_reply(query.id());
94 response.set_flags(PacketFlag::RESPONSE | PacketFlag::RECURSION_AVAILABLE);
95 response.questions.clone_from(&query.questions);
96
97 let question = query.questions.first()?;
100 let soa = rdata::RData::SOA(rdata::SOA {
101 mname: simple_dns::Name::new_unchecked("gateway.fips"),
102 rname: simple_dns::Name::new_unchecked("nobody.fips"),
103 serial: std::time::SystemTime::now()
104 .duration_since(std::time::UNIX_EPOCH)
105 .map(|d| d.as_secs() as u32)
106 .unwrap_or(1),
107 refresh: ttl as i32,
108 retry: ttl as i32,
109 expire: ttl as i32,
110 minimum: ttl,
111 });
112 let soa_record = ResourceRecord::new(question.qname.clone(), CLASS::IN, ttl, soa);
113 response.name_servers.push(soa_record);
114
115 response.build_bytes_vec_compressed().ok()
116}
117
118fn build_aaaa_response(query: &Packet, virtual_ip: Ipv6Addr, ttl: u32) -> Option<Vec<u8>> {
120 let question = query.questions.first()?;
121 let mut response = Packet::new_reply(query.id());
122 response.set_flags(PacketFlag::RESPONSE | PacketFlag::RECURSION_AVAILABLE);
123
124 response.questions.push(question.clone());
126
127 let aaaa = rdata::RData::AAAA(rdata::AAAA {
128 address: virtual_ip.into(),
129 });
130 let record = ResourceRecord::new(question.qname.clone(), CLASS::IN, ttl, aaaa);
131 response.answers.push(record);
132
133 response.build_bytes_vec_compressed().ok()
134}
135
136pub async fn run_dns_resolver(
141 listen_addr: &str,
142 upstream_addr: &str,
143 ttl: u32,
144 pool: std::sync::Arc<tokio::sync::Mutex<VirtualIpPool>>,
145 event_tx: tokio::sync::mpsc::Sender<PoolEvent>,
146 mut shutdown: watch::Receiver<bool>,
147) -> Result<(), std::io::Error> {
148 let socket = UdpSocket::bind(listen_addr).await?;
149 info!(addr = %listen_addr, "Gateway DNS resolver listening");
150
151 let upstream: SocketAddr = upstream_addr
152 .parse()
153 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidInput, e))?;
154
155 let mut buf = vec![0u8; MAX_DNS_SIZE];
156
157 loop {
158 tokio::select! {
159 result = socket.recv_from(&mut buf) => {
160 let (len, client_addr) = result?;
161 let query_bytes = &buf[..len];
162
163 let response = match handle_query(
164 query_bytes,
165 upstream,
166 ttl,
167 &pool,
168 &event_tx,
169 ).await {
170 Some(resp) => resp,
171 None => continue,
172 };
173
174 if let Err(e) = socket.send_to(&response, client_addr).await {
175 debug!(error = %e, "Failed to send DNS response");
176 }
177 }
178 _ = shutdown.changed() => {
179 info!("DNS resolver shutting down");
180 break;
181 }
182 }
183 }
184
185 Ok(())
186}
187
188async fn handle_query(
190 query_bytes: &[u8],
191 upstream: SocketAddr,
192 ttl: u32,
193 pool: &std::sync::Arc<tokio::sync::Mutex<VirtualIpPool>>,
194 event_tx: &tokio::sync::mpsc::Sender<PoolEvent>,
195) -> Option<Vec<u8>> {
196 let query = Packet::parse(query_bytes).ok()?;
197
198 let fips_name = match extract_fips_name(&query) {
200 Some(name) => name,
201 None => {
202 trace!(id = query.id(), "Non-.fips query, returning REFUSED");
203 return build_refused(&query);
204 }
205 };
206
207 debug!(name = %fips_name, id = query.id(), "Forwarding .fips query to daemon");
208
209 let upstream_query_bytes = {
213 let question = query.questions.first()?;
214 let mut aaaa_query = Packet::new_query(query.id());
215 let aaaa_question = simple_dns::Question::new(
216 question.qname.clone(),
217 QTYPE::TYPE(TYPE::AAAA),
218 question.qclass,
219 question.unicast_response,
220 );
221 aaaa_query.questions.push(aaaa_question);
222 match aaaa_query.build_bytes_vec_compressed() {
223 Ok(bytes) => bytes,
224 Err(_) => return build_servfail(&query),
225 }
226 };
227
228 let bind_addr = if upstream.is_ipv4() {
232 "0.0.0.0:0"
233 } else {
234 "[::]:0"
235 };
236 let upstream_socket = match UdpSocket::bind(bind_addr).await {
237 Ok(s) => s,
238 Err(e) => {
239 warn!(error = %e, "Failed to bind upstream socket");
240 return build_servfail(&query);
241 }
242 };
243
244 if let Err(e) = upstream_socket
245 .send_to(&upstream_query_bytes, upstream)
246 .await
247 {
248 warn!(error = %e, "Failed to forward query to daemon");
249 return build_servfail(&query);
250 }
251
252 let mut resp_buf = vec![0u8; MAX_DNS_SIZE];
253 let resp_len =
254 match tokio::time::timeout(UPSTREAM_TIMEOUT, upstream_socket.recv(&mut resp_buf)).await {
255 Ok(Ok(len)) => len,
256 Ok(Err(e)) => {
257 warn!(error = %e, "Upstream recv error");
258 return build_servfail(&query);
259 }
260 Err(_) => {
261 warn!("Upstream DNS timeout");
262 return build_servfail(&query);
263 }
264 };
265
266 let upstream_response = match Packet::parse(&resp_buf[..resp_len]) {
267 Ok(p) => p,
268 Err(_) => return build_servfail(&query),
269 };
270
271 if upstream_response.rcode() != RCODE::NoError {
274 debug!(
275 name = %fips_name,
276 rcode = ?upstream_response.rcode(),
277 "Upstream returned non-success"
278 );
279 let mut err_resp = Packet::new_reply(query.id());
280 err_resp.set_flags(PacketFlag::RESPONSE | PacketFlag::RECURSION_AVAILABLE);
281 *err_resp.rcode_mut() = upstream_response.rcode();
282 err_resp.questions.clone_from(&query.questions);
283 return err_resp.build_bytes_vec_compressed().ok();
284 }
285
286 let mesh_addr = match extract_aaaa(&upstream_response) {
288 Some(addr) => addr,
289 None => {
290 debug!(name = %fips_name, "No AAAA record in upstream response");
291 return build_servfail(&query);
292 }
293 };
294
295 let node_addr = node_addr_from_mesh(mesh_addr);
297
298 let mut pool_guard = pool.lock().await;
300 let (virtual_ip, is_new) = match pool_guard.allocate(node_addr, mesh_addr, &fips_name) {
301 Ok(result) => result,
302 Err(e) => {
303 warn!(error = %e, "Pool allocation failed");
304 return build_servfail(&query);
305 }
306 };
307 drop(pool_guard);
308
309 if is_new {
311 let event = PoolEvent::MappingCreated {
312 virtual_ip,
313 mesh_addr,
314 };
315 if let Err(e) = event_tx.send(event).await {
316 warn!(error = %e, "Failed to send pool event");
317 }
318 }
319
320 debug!(
321 name = %fips_name,
322 virtual_ip = %virtual_ip,
323 mesh_addr = %mesh_addr,
324 is_new,
325 "Resolved .fips query"
326 );
327
328 let client_qtype = query
333 .questions
334 .first()
335 .map(|q| q.qtype)
336 .unwrap_or(QTYPE::TYPE(TYPE::AAAA));
337
338 match client_qtype {
339 QTYPE::TYPE(TYPE::AAAA) | QTYPE::ANY => build_aaaa_response(&query, virtual_ip, ttl),
340 _ => build_nodata(&query, ttl),
343 }
344}
345
#[cfg(test)]
mod tests {
    use super::*;
    use simple_dns::{Name, Question};

    /// Builds a query with the given ID and a single AAAA/IN question for `qname`.
    fn aaaa_query(id: u16, qname: &str) -> Packet<'_> {
        let mut packet = Packet::new_query(id);
        packet.questions.push(Question::new(
            Name::new_unchecked(qname),
            QTYPE::TYPE(TYPE::AAAA),
            CLASS::IN.into(),
            false,
        ));
        packet
    }

    #[test]
    fn test_node_addr_from_mesh() {
        // The only non-zero octet of fd00::1 is the last one. After dropping the
        // first mesh octet, it must land at node byte 14. Node byte 0 comes
        // from mesh octet 1, which is zero.
        let node = node_addr_from_mesh("fd00::1".parse().unwrap());
        let bytes = node.as_bytes();
        assert_eq!(bytes[0], 0);
        assert_eq!(bytes[14], 1);
    }

    #[test]
    fn test_extract_fips_name() {
        let packet = aaaa_query(1, "test.fips");
        assert_eq!(extract_fips_name(&packet), Some("test.fips".to_string()));
    }

    #[test]
    fn test_extract_non_fips_name() {
        let packet = aaaa_query(1, "example.com");
        assert!(extract_fips_name(&packet).is_none());
    }

    #[test]
    fn test_build_aaaa_response() {
        let query = aaaa_query(42, "test.fips");
        let vip: Ipv6Addr = "fd01::1".parse().unwrap();

        let encoded = build_aaaa_response(&query, vip, 60).unwrap();
        let response = Packet::parse(&encoded).unwrap();

        assert_eq!(response.id(), 42);
        assert_eq!(response.answers.len(), 1);
        match &response.answers[0].rdata {
            rdata::RData::AAAA(aaaa) => assert_eq!(Ipv6Addr::from(aaaa.address), vip),
            _ => panic!("Expected AAAA record"),
        }
    }
}