Skip to main content

fips_core/gateway/
dns.rs

//! Gateway DNS resolver.
//!
//! Forwarding proxy that handles `.fips` queries from LAN hosts: each query
//! is forwarded to the FIPS daemon resolver (e.g. localhost:5354), and the
//! client is answered with a virtual IP address from the pool instead of the
//! mesh address. Queries for any other name are refused.
//!
//! The daemon resolver populates its identity cache as a side effect
//! of resolution, which is required for fips0 routing to work.

10use simple_dns::{CLASS, Packet, PacketFlag, RCODE, ResourceRecord, rdata};
11
12use simple_dns::{QTYPE, TYPE};
13use std::net::{Ipv6Addr, SocketAddr};
14use tokio::net::UdpSocket;
15use tokio::sync::watch;
16use tracing::{debug, info, trace, warn};
17
18use super::pool::{PoolEvent, VirtualIpPool};
19use crate::NodeAddr;
20
/// How long to wait for the upstream daemon resolver to answer a forwarded query.
const UPSTREAM_TIMEOUT: std::time::Duration = std::time::Duration::from_millis(5_000);

/// Size of the buffers used to receive DNS datagrams (client queries and
/// upstream replies).
const MAX_DNS_SIZE: usize = 4 * 1024;
26
27/// Events emitted by the DNS resolver.
28#[derive(Debug)]
29pub struct DnsAllocation {
30    pub node_addr: NodeAddr,
31    pub virtual_ip: Ipv6Addr,
32    pub mesh_addr: Ipv6Addr,
33    pub is_new: bool,
34}
35
36/// Extract the `.fips` query name from a DNS packet.
37/// Returns Some(name) if the query is for a `.fips` domain, None otherwise.
38fn extract_fips_name(packet: &Packet) -> Option<String> {
39    let question = packet.questions.first()?;
40    let name = question.qname.to_string();
41    let lower = name.to_ascii_lowercase();
42    if lower.ends_with(".fips") || lower.ends_with(".fips.") {
43        Some(lower.trim_end_matches('.').to_string())
44    } else {
45        None
46    }
47}
48
49/// Extract the AAAA (IPv6) address from a DNS response.
50fn extract_aaaa(packet: &Packet) -> Option<Ipv6Addr> {
51    for answer in &packet.answers {
52        if let rdata::RData::AAAA(aaaa) = &answer.rdata {
53            return Some(aaaa.address.into());
54        }
55    }
56    None
57}
58
59/// Derive NodeAddr from a FIPS mesh address (fd00::/8).
60/// The NodeAddr is bytes 1-15 of the IPv6 address prepended with the first byte.
61fn node_addr_from_mesh(mesh_addr: Ipv6Addr) -> NodeAddr {
62    let bytes = mesh_addr.octets();
63    // NodeAddr = first 16 bytes of SHA-256(pubkey), which maps to
64    // FipsAddress = fd + NodeAddr[1..16]. So NodeAddr[0] = bytes[1].
65    // Actually, FipsAddress = [0xfd, nodeaddr[0..15]]
66    // So nodeaddr[0..15] = bytes[1..16]
67    let mut node_bytes = [0u8; 16];
68    node_bytes[..15].copy_from_slice(&bytes[1..16]);
69    NodeAddr::from_bytes(node_bytes)
70}
71
72/// Build a REFUSED DNS response.
73fn build_refused(query: &Packet) -> Option<Vec<u8>> {
74    let mut response = Packet::new_reply(query.id());
75    response.set_flags(PacketFlag::RESPONSE | PacketFlag::RECURSION_AVAILABLE);
76    *response.rcode_mut() = RCODE::Refused;
77    response.questions.clone_from(&query.questions);
78    response.build_bytes_vec_compressed().ok()
79}
80
81/// Build a SERVFAIL DNS response.
82fn build_servfail(query: &Packet) -> Option<Vec<u8>> {
83    let mut response = Packet::new_reply(query.id());
84    response.set_flags(PacketFlag::RESPONSE | PacketFlag::RECURSION_AVAILABLE);
85    *response.rcode_mut() = RCODE::ServerFailure;
86    response.questions.clone_from(&query.questions);
87    response.build_bytes_vec_compressed().ok()
88}
89
90/// Build a NODATA response (NOERROR with no answer records).
91/// Signals "this name exists but has no records of the requested type".
92fn build_nodata(query: &Packet, ttl: u32) -> Option<Vec<u8>> {
93    let mut response = Packet::new_reply(query.id());
94    response.set_flags(PacketFlag::RESPONSE | PacketFlag::RECURSION_AVAILABLE);
95    response.questions.clone_from(&query.questions);
96
97    // Add a minimal SOA in the authority section (RFC 2308 §2.2).
98    // This tells the client how long to cache the negative answer.
99    let question = query.questions.first()?;
100    let soa = rdata::RData::SOA(rdata::SOA {
101        mname: simple_dns::Name::new_unchecked("gateway.fips"),
102        rname: simple_dns::Name::new_unchecked("nobody.fips"),
103        serial: std::time::SystemTime::now()
104            .duration_since(std::time::UNIX_EPOCH)
105            .map(|d| d.as_secs() as u32)
106            .unwrap_or(1),
107        refresh: ttl as i32,
108        retry: ttl as i32,
109        expire: ttl as i32,
110        minimum: ttl,
111    });
112    let soa_record = ResourceRecord::new(question.qname.clone(), CLASS::IN, ttl, soa);
113    response.name_servers.push(soa_record);
114
115    response.build_bytes_vec_compressed().ok()
116}
117
118/// Build an AAAA response with the given virtual IP.
119fn build_aaaa_response(query: &Packet, virtual_ip: Ipv6Addr, ttl: u32) -> Option<Vec<u8>> {
120    let question = query.questions.first()?;
121    let mut response = Packet::new_reply(query.id());
122    response.set_flags(PacketFlag::RESPONSE | PacketFlag::RECURSION_AVAILABLE);
123
124    // Echo the question section (required by RFC 1035 §4.1.1)
125    response.questions.push(question.clone());
126
127    let aaaa = rdata::RData::AAAA(rdata::AAAA {
128        address: virtual_ip.into(),
129    });
130    let record = ResourceRecord::new(question.qname.clone(), CLASS::IN, ttl, aaaa);
131    response.answers.push(record);
132
133    response.build_bytes_vec_compressed().ok()
134}
135
136/// Run the gateway DNS resolver.
137///
138/// Listens for DNS queries, forwards `.fips` queries to the upstream
139/// daemon resolver, allocates virtual IPs, and returns them to clients.
140pub async fn run_dns_resolver(
141    listen_addr: &str,
142    upstream_addr: &str,
143    ttl: u32,
144    pool: std::sync::Arc<tokio::sync::Mutex<VirtualIpPool>>,
145    event_tx: tokio::sync::mpsc::Sender<PoolEvent>,
146    mut shutdown: watch::Receiver<bool>,
147) -> Result<(), std::io::Error> {
148    let socket = UdpSocket::bind(listen_addr).await?;
149    info!(addr = %listen_addr, "Gateway DNS resolver listening");
150
151    let upstream: SocketAddr = upstream_addr
152        .parse()
153        .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidInput, e))?;
154
155    let mut buf = vec![0u8; MAX_DNS_SIZE];
156
157    loop {
158        tokio::select! {
159            result = socket.recv_from(&mut buf) => {
160                let (len, client_addr) = result?;
161                let query_bytes = &buf[..len];
162
163                let response = match handle_query(
164                    query_bytes,
165                    upstream,
166                    ttl,
167                    &pool,
168                    &event_tx,
169                ).await {
170                    Some(resp) => resp,
171                    None => continue,
172                };
173
174                if let Err(e) = socket.send_to(&response, client_addr).await {
175                    debug!(error = %e, "Failed to send DNS response");
176                }
177            }
178            _ = shutdown.changed() => {
179                info!("DNS resolver shutting down");
180                break;
181            }
182        }
183    }
184
185    Ok(())
186}
187
188/// Handle a single DNS query. Returns the response bytes to send back.
189async fn handle_query(
190    query_bytes: &[u8],
191    upstream: SocketAddr,
192    ttl: u32,
193    pool: &std::sync::Arc<tokio::sync::Mutex<VirtualIpPool>>,
194    event_tx: &tokio::sync::mpsc::Sender<PoolEvent>,
195) -> Option<Vec<u8>> {
196    let query = Packet::parse(query_bytes).ok()?;
197
198    // Check if this is a .fips query
199    let fips_name = match extract_fips_name(&query) {
200        Some(name) => name,
201        None => {
202            trace!(id = query.id(), "Non-.fips query, returning REFUSED");
203            return build_refused(&query);
204        }
205    };
206
207    debug!(name = %fips_name, id = query.id(), "Forwarding .fips query to daemon");
208
209    // Build an AAAA query for the daemon regardless of what the client asked
210    // (A, AAAA, ANY, etc.).  Mesh addresses are always IPv6, so the daemon
211    // only returns useful answers for AAAA queries.
212    let upstream_query_bytes = {
213        let question = query.questions.first()?;
214        let mut aaaa_query = Packet::new_query(query.id());
215        let aaaa_question = simple_dns::Question::new(
216            question.qname.clone(),
217            QTYPE::TYPE(TYPE::AAAA),
218            question.qclass,
219            question.unicast_response,
220        );
221        aaaa_query.questions.push(aaaa_question);
222        match aaaa_query.build_bytes_vec_compressed() {
223            Ok(bytes) => bytes,
224            Err(_) => return build_servfail(&query),
225        }
226    };
227
228    // Forward to upstream daemon resolver.
229    // Bind to the same address family as the upstream to avoid dual-stack issues
230    // (OpenWrt often has net.ipv6.bindv6only=1).
231    let bind_addr = if upstream.is_ipv4() {
232        "0.0.0.0:0"
233    } else {
234        "[::]:0"
235    };
236    let upstream_socket = match UdpSocket::bind(bind_addr).await {
237        Ok(s) => s,
238        Err(e) => {
239            warn!(error = %e, "Failed to bind upstream socket");
240            return build_servfail(&query);
241        }
242    };
243
244    if let Err(e) = upstream_socket
245        .send_to(&upstream_query_bytes, upstream)
246        .await
247    {
248        warn!(error = %e, "Failed to forward query to daemon");
249        return build_servfail(&query);
250    }
251
252    let mut resp_buf = vec![0u8; MAX_DNS_SIZE];
253    let resp_len =
254        match tokio::time::timeout(UPSTREAM_TIMEOUT, upstream_socket.recv(&mut resp_buf)).await {
255            Ok(Ok(len)) => len,
256            Ok(Err(e)) => {
257                warn!(error = %e, "Upstream recv error");
258                return build_servfail(&query);
259            }
260            Err(_) => {
261                warn!("Upstream DNS timeout");
262                return build_servfail(&query);
263            }
264        };
265
266    let upstream_response = match Packet::parse(&resp_buf[..resp_len]) {
267        Ok(p) => p,
268        Err(_) => return build_servfail(&query),
269    };
270
271    // If upstream returned NXDOMAIN or error, rebuild the response with the
272    // client's original question section (not the AAAA question we sent upstream).
273    if upstream_response.rcode() != RCODE::NoError {
274        debug!(
275            name = %fips_name,
276            rcode = ?upstream_response.rcode(),
277            "Upstream returned non-success"
278        );
279        let mut err_resp = Packet::new_reply(query.id());
280        err_resp.set_flags(PacketFlag::RESPONSE | PacketFlag::RECURSION_AVAILABLE);
281        *err_resp.rcode_mut() = upstream_response.rcode();
282        err_resp.questions.clone_from(&query.questions);
283        return err_resp.build_bytes_vec_compressed().ok();
284    }
285
286    // Extract the fd00:: mesh address from the AAAA response
287    let mesh_addr = match extract_aaaa(&upstream_response) {
288        Some(addr) => addr,
289        None => {
290            debug!(name = %fips_name, "No AAAA record in upstream response");
291            return build_servfail(&query);
292        }
293    };
294
295    // Derive NodeAddr from mesh address
296    let node_addr = node_addr_from_mesh(mesh_addr);
297
298    // Allocate virtual IP from pool
299    let mut pool_guard = pool.lock().await;
300    let (virtual_ip, is_new) = match pool_guard.allocate(node_addr, mesh_addr, &fips_name) {
301        Ok(result) => result,
302        Err(e) => {
303            warn!(error = %e, "Pool allocation failed");
304            return build_servfail(&query);
305        }
306    };
307    drop(pool_guard);
308
309    // Notify NAT module of new mapping
310    if is_new {
311        let event = PoolEvent::MappingCreated {
312            virtual_ip,
313            mesh_addr,
314        };
315        if let Err(e) = event_tx.send(event).await {
316            warn!(error = %e, "Failed to send pool event");
317        }
318    }
319
320    debug!(
321        name = %fips_name,
322        virtual_ip = %virtual_ip,
323        mesh_addr = %mesh_addr,
324        is_new,
325        "Resolved .fips query"
326    );
327
328    // Check what the client originally asked for.
329    // Only return an AAAA record if the client asked for AAAA (or ANY).
330    // For A queries, return an empty NOERROR — the client's resolver will
331    // use the AAAA answer from its parallel AAAA query instead.
332    let client_qtype = query
333        .questions
334        .first()
335        .map(|q| q.qtype)
336        .unwrap_or(QTYPE::TYPE(TYPE::AAAA));
337
338    match client_qtype {
339        QTYPE::TYPE(TYPE::AAAA) | QTYPE::ANY => build_aaaa_response(&query, virtual_ip, ttl),
340        // All other types (A, HTTPS, etc.): return NODATA — the name exists
341        // but has no records of the requested type.
342        _ => build_nodata(&query, ttl),
343    }
344}
345
#[cfg(test)]
mod tests {
    use super::*;
    use simple_dns::{Name, Question};

    /// Build a single-question IN-class query for `name` with type `qtype`.
    fn make_query(id: u16, name: &'static str, qtype: TYPE) -> Packet<'static> {
        let mut packet = Packet::new_query(id);
        packet.questions.push(Question::new(
            Name::new_unchecked(name),
            QTYPE::TYPE(qtype),
            CLASS::IN.into(),
            false,
        ));
        packet
    }

    #[test]
    fn test_node_addr_from_mesh() {
        // fd00::1 → node bytes [0, 0, ..., 1, 0]: mesh bytes 1..16 land in
        // positions 0..15 and the unrecoverable last byte stays zero.
        let mesh: Ipv6Addr = "fd00::1".parse().unwrap();
        let node = node_addr_from_mesh(mesh);
        let bytes = node.as_bytes();
        assert_eq!(bytes[0], 0);
        assert_eq!(bytes[14], 1);
        assert_eq!(bytes[15], 0);
    }

    #[test]
    fn test_node_addr_from_mesh_copies_all_fifteen_bytes() {
        let mesh: Ipv6Addr = "fd12:3456:789a:bcde:f012:3456:789a:bcde".parse().unwrap();
        let node = node_addr_from_mesh(mesh);
        let expected: [u8; 16] = [
            0x12, 0x34, 0x56, 0x78, 0x9a, 0xbc, 0xde, 0xf0, 0x12, 0x34, 0x56, 0x78, 0x9a, 0xbc,
            0xde, 0x00,
        ];
        let bytes = node.as_bytes();
        assert_eq!(bytes[..], expected[..]);
    }

    #[test]
    fn test_extract_fips_name() {
        let packet = make_query(1, "test.fips", TYPE::AAAA);
        assert_eq!(extract_fips_name(&packet), Some("test.fips".to_string()));
    }

    #[test]
    fn test_extract_fips_name_is_case_insensitive() {
        let packet = make_query(1, "Test.FIPS", TYPE::AAAA);
        assert_eq!(extract_fips_name(&packet), Some("test.fips".to_string()));
    }

    #[test]
    fn test_extract_non_fips_name() {
        let packet = make_query(1, "example.com", TYPE::AAAA);
        assert!(extract_fips_name(&packet).is_none());
    }

    #[test]
    fn test_extract_fips_name_without_question() {
        assert!(extract_fips_name(&Packet::new_query(1)).is_none());
    }

    #[test]
    fn test_extract_aaaa() {
        let mut packet = Packet::new_reply(1);
        assert_eq!(extract_aaaa(&packet), None);

        let mesh: Ipv6Addr = "fd00::42".parse().unwrap();
        packet.answers.push(ResourceRecord::new(
            Name::new_unchecked("test.fips"),
            CLASS::IN,
            60,
            rdata::RData::AAAA(rdata::AAAA {
                address: mesh.into(),
            }),
        ));
        assert_eq!(extract_aaaa(&packet), Some(mesh));
    }

    #[test]
    fn test_build_refused() {
        let query = make_query(7, "example.com", TYPE::AAAA);
        let response = build_refused(&query).unwrap();
        let response = Packet::parse(&response).unwrap();

        assert_eq!(response.id(), 7);
        assert_eq!(response.rcode(), RCODE::Refused);
        assert_eq!(response.questions.len(), 1);
        assert!(response.answers.is_empty());
    }

    #[test]
    fn test_build_nodata() {
        let query = make_query(9, "test.fips", TYPE::A);
        let response = build_nodata(&query, 30).unwrap();
        let response = Packet::parse(&response).unwrap();

        assert_eq!(response.id(), 9);
        assert_eq!(response.rcode(), RCODE::NoError);
        assert_eq!(response.questions.len(), 1);
        assert!(response.answers.is_empty());
        assert_eq!(response.name_servers.len(), 1);
        assert!(matches!(
            response.name_servers[0].rdata,
            rdata::RData::SOA(_)
        ));
    }

    #[test]
    fn test_build_aaaa_response() {
        let query = make_query(42, "test.fips", TYPE::AAAA);
        let vip: Ipv6Addr = "fd01::1".parse().unwrap();
        let response_bytes = build_aaaa_response(&query, vip, 60).unwrap();
        let response = Packet::parse(&response_bytes).unwrap();

        assert_eq!(response.id(), 42);
        assert_eq!(response.answers.len(), 1);
        if let rdata::RData::AAAA(aaaa) = &response.answers[0].rdata {
            assert_eq!(Ipv6Addr::from(aaaa.address), vip);
        } else {
            panic!("Expected AAAA record");
        }
    }
}