Skip to main content

microsandbox_network/
udp_relay.rs

1//! Non-DNS UDP relay: handles UDP traffic outside smoltcp.
2//!
3//! smoltcp has no wildcard port binding, so non-DNS UDP is intercepted at
4//! the device level, relayed through host UDP sockets via tokio, and
5//! responses are injected back into `rx_ring` as constructed ethernet frames.
6
7use std::collections::HashMap;
8use std::net::{IpAddr, Ipv4Addr, SocketAddr};
9use std::sync::Arc;
10use std::time::{Duration, Instant};
11
12use bytes::Bytes;
13use smoltcp::wire::{
14    EthernetAddress, EthernetFrame, EthernetProtocol, EthernetRepr, IpProtocol, Ipv4Packet,
15    Ipv6Packet, UdpPacket,
16};
17use tokio::net::UdpSocket;
18use tokio::sync::mpsc;
19
20use crate::shared::SharedState;
21
22//--------------------------------------------------------------------------------------------------
23// Constants
24//--------------------------------------------------------------------------------------------------
25
/// Session idle timeout.
///
/// Used twice: the poll loop treats a session as expired once this much
/// time has passed since the guest last sent a datagram on it, and each
/// relay task exits after this long without any event on its socket or
/// its outbound channel.
const SESSION_TIMEOUT: Duration = Duration::from_secs(60);

/// Channel capacity for outbound datagrams to the relay task.
///
/// The poll loop enqueues with `try_send`, so datagrams that arrive while
/// the channel is full are dropped rather than blocking the loop.
const OUTBOUND_CHANNEL_CAPACITY: usize = 64;

/// Buffer size for receiving responses from the real server.
/// Sized to match the MTU (1500) plus generous headroom for
/// reassembled datagrams while avoiding 64 KiB per session.
///
/// NOTE(review): a response larger than this is presumably truncated by
/// the socket receive call before it is forwarded — confirm this is
/// acceptable for the expected traffic.
const RECV_BUF_SIZE: usize = 4096;

/// Ethernet header length.
const ETH_HDR_LEN: usize = 14;

/// IPv4 header length (no options).
const IPV4_HDR_LEN: usize = 20;

/// UDP header length.
const UDP_HDR_LEN: usize = 8;
45
46//--------------------------------------------------------------------------------------------------
47// Types
48//--------------------------------------------------------------------------------------------------
49
/// Relays non-DNS UDP traffic between the guest and the real network.
///
/// Each unique `(guest_src, guest_dst)` pair gets a host-side UDP socket
/// and a tokio relay task. The poll loop calls [`relay_outbound()`] to
/// send guest datagrams; response frames are injected directly into
/// `rx_ring`.
///
/// [`relay_outbound()`]: UdpRelay::relay_outbound
pub struct UdpRelay {
    /// Shared state handed to every relay task; tasks push response
    /// frames into its `rx_ring` and then signal its `rx_wake`.
    shared: Arc<SharedState>,
    /// Live sessions keyed by `(guest source, guest destination)`.
    sessions: HashMap<(SocketAddr, SocketAddr), UdpSession>,
    /// Ethernet source address written into frames injected for the guest.
    gateway_mac: EthernetAddress,
    /// Ethernet destination address written into frames injected for the guest.
    guest_mac: EthernetAddress,
    /// Runtime handle used to spawn one relay task per session.
    tokio_handle: tokio::runtime::Handle,
}
65
/// A single UDP relay session.
///
/// This is the poll-loop side of a session; the host socket itself lives
/// in the spawned relay task. Dropping the session drops `outbound_tx`,
/// which closes the channel and makes the relay task exit.
struct UdpSession {
    /// Channel to send outbound datagrams to the relay task.
    outbound_tx: mpsc::Sender<Bytes>,
    /// Last time this session was used (refreshed on every outbound
    /// datagram from the guest; inbound responses do not touch it).
    last_active: Instant,
}
73
74//--------------------------------------------------------------------------------------------------
75// Methods
76//--------------------------------------------------------------------------------------------------
77
78impl UdpRelay {
79    /// Create a new UDP relay.
80    pub fn new(
81        shared: Arc<SharedState>,
82        gateway_mac: [u8; 6],
83        guest_mac: [u8; 6],
84        tokio_handle: tokio::runtime::Handle,
85    ) -> Self {
86        Self {
87            shared,
88            sessions: HashMap::new(),
89            gateway_mac: EthernetAddress(gateway_mac),
90            guest_mac: EthernetAddress(guest_mac),
91            tokio_handle,
92        }
93    }
94
95    /// Relay an outbound UDP datagram from the guest.
96    ///
97    /// Extracts the UDP payload from the raw ethernet frame, looks up or
98    /// creates a session, and sends the payload to the relay task.
99    pub fn relay_outbound(&mut self, frame: &[u8], src: SocketAddr, dst: SocketAddr) {
100        // Extract UDP payload from the ethernet frame.
101        let Some(payload) = extract_udp_payload(frame) else {
102            return;
103        };
104
105        let key = (src, dst);
106
107        // Create session if it doesn't exist or has expired.
108        if self
109            .sessions
110            .get(&key)
111            .is_none_or(|s| s.last_active.elapsed() > SESSION_TIMEOUT)
112        {
113            self.sessions.remove(&key);
114            if let Some(session) = self.create_session(src, dst) {
115                self.sessions.insert(key, session);
116            } else {
117                return;
118            }
119        }
120
121        if let Some(session) = self.sessions.get_mut(&key) {
122            session.last_active = Instant::now();
123            let _ = session
124                .outbound_tx
125                .try_send(Bytes::copy_from_slice(payload));
126        }
127    }
128
129    /// Remove expired sessions.
130    pub fn cleanup_expired(&mut self) {
131        self.sessions
132            .retain(|_, session| session.last_active.elapsed() <= SESSION_TIMEOUT);
133    }
134}
135
136impl UdpRelay {
137    /// Create a new relay session: bind a host UDP socket and spawn a task.
138    fn create_session(&self, guest_src: SocketAddr, guest_dst: SocketAddr) -> Option<UdpSession> {
139        let (outbound_tx, outbound_rx) = mpsc::channel(OUTBOUND_CHANNEL_CAPACITY);
140
141        let shared = self.shared.clone();
142        let gateway_mac = self.gateway_mac;
143        let guest_mac = self.guest_mac;
144
145        self.tokio_handle.spawn(async move {
146            if let Err(e) = udp_relay_task(
147                outbound_rx,
148                guest_src,
149                guest_dst,
150                shared,
151                gateway_mac,
152                guest_mac,
153            )
154            .await
155            {
156                tracing::debug!(
157                    guest_src = %guest_src,
158                    guest_dst = %guest_dst,
159                    error = %e,
160                    "UDP relay task ended",
161                );
162            }
163        });
164
165        Some(UdpSession {
166            outbound_tx,
167            last_active: Instant::now(),
168        })
169    }
170}
171
172//--------------------------------------------------------------------------------------------------
173// Functions
174//--------------------------------------------------------------------------------------------------
175
176/// Async task that relays UDP between a host socket and the guest.
177async fn udp_relay_task(
178    mut outbound_rx: mpsc::Receiver<Bytes>,
179    guest_src: SocketAddr,
180    guest_dst: SocketAddr,
181    shared: Arc<SharedState>,
182    gateway_mac: EthernetAddress,
183    guest_mac: EthernetAddress,
184) -> std::io::Result<()> {
185    // Bind a host UDP socket. Use the same address family as the destination.
186    let bind_addr: SocketAddr = match guest_dst {
187        SocketAddr::V4(_) => (Ipv4Addr::UNSPECIFIED, 0u16).into(),
188        SocketAddr::V6(_) => (std::net::Ipv6Addr::UNSPECIFIED, 0u16).into(),
189    };
190    let socket = UdpSocket::bind(bind_addr).await?;
191    // Connect to the destination to restrict accepted source addresses,
192    // preventing host-network entities from injecting spoofed datagrams.
193    socket.connect(guest_dst).await?;
194
195    let mut recv_buf = vec![0u8; RECV_BUF_SIZE];
196    let timeout = SESSION_TIMEOUT;
197
198    loop {
199        tokio::select! {
200            // Outbound: guest → server.
201            data = outbound_rx.recv() => {
202                match data {
203                    Some(payload) => {
204                        let _ = socket.send(&payload).await;
205                    }
206                    // Channel closed — session dropped by poll loop.
207                    None => break,
208                }
209            }
210
211            // Inbound: server → guest (only from the connected destination).
212            result = socket.recv(&mut recv_buf) => {
213                match result {
214                    Ok(n) => {
215                        if let Some(frame) = construct_udp_response(
216                            guest_dst,
217                            guest_src,
218                            &recv_buf[..n],
219                            gateway_mac,
220                            guest_mac,
221                        ) {
222                            let _ = shared.rx_ring.push(frame);
223                            shared.rx_wake.wake();
224                        }
225                    }
226                    Err(e) => {
227                        tracing::debug!(error = %e, "UDP relay recv failed");
228                        break;
229                    }
230                }
231            }
232
233            // Idle timeout.
234            () = tokio::time::sleep(timeout) => {
235                break;
236            }
237        }
238    }
239
240    Ok(())
241}
242
243/// Construct an ethernet frame containing a UDP response for the guest.
244///
245/// Builds Ethernet + IPv4/IPv6 + UDP headers using smoltcp's wire module.
246fn construct_udp_response(
247    src: SocketAddr,
248    dst: SocketAddr,
249    payload: &[u8],
250    gateway_mac: EthernetAddress,
251    guest_mac: EthernetAddress,
252) -> Option<Vec<u8>> {
253    match (src.ip(), dst.ip()) {
254        (IpAddr::V4(src_ip), IpAddr::V4(dst_ip)) => Some(construct_udp_response_v4(
255            src_ip,
256            src.port(),
257            dst_ip,
258            dst.port(),
259            payload,
260            gateway_mac,
261            guest_mac,
262        )),
263        (IpAddr::V6(src_ip), IpAddr::V6(dst_ip)) => Some(construct_udp_response_v6(
264            src_ip,
265            src.port(),
266            dst_ip,
267            dst.port(),
268            payload,
269            gateway_mac,
270            guest_mac,
271        )),
272        _ => None, // Mismatched address families — shouldn't happen.
273    }
274}
275
276/// Construct an Ethernet + IPv4 + UDP frame.
277fn construct_udp_response_v4(
278    src_ip: Ipv4Addr,
279    src_port: u16,
280    dst_ip: Ipv4Addr,
281    dst_port: u16,
282    payload: &[u8],
283    gateway_mac: EthernetAddress,
284    guest_mac: EthernetAddress,
285) -> Vec<u8> {
286    let udp_len = UDP_HDR_LEN + payload.len();
287    let ip_total_len = IPV4_HDR_LEN + udp_len;
288    let frame_len = ETH_HDR_LEN + ip_total_len;
289    let mut buf = vec![0u8; frame_len];
290
291    // Ethernet header.
292    let eth_repr = EthernetRepr {
293        src_addr: gateway_mac,
294        dst_addr: guest_mac,
295        ethertype: EthernetProtocol::Ipv4,
296    };
297    let mut eth_frame = EthernetFrame::new_unchecked(&mut buf);
298    eth_repr.emit(&mut eth_frame);
299
300    // IPv4 header.
301    let ip_buf = &mut buf[ETH_HDR_LEN..];
302    let mut ip_pkt = Ipv4Packet::new_unchecked(ip_buf);
303    ip_pkt.set_version(4);
304    ip_pkt.set_header_len(20);
305    ip_pkt.set_total_len(ip_total_len as u16);
306    ip_pkt.clear_flags();
307    ip_pkt.set_dont_frag(true);
308    ip_pkt.set_hop_limit(64);
309    ip_pkt.set_next_header(IpProtocol::Udp);
310    ip_pkt.set_src_addr(src_ip);
311    ip_pkt.set_dst_addr(dst_ip);
312    ip_pkt.fill_checksum();
313
314    // UDP header + payload.
315    let udp_buf = &mut buf[ETH_HDR_LEN + IPV4_HDR_LEN..];
316    let mut udp_pkt = UdpPacket::new_unchecked(udp_buf);
317    udp_pkt.set_src_port(src_port);
318    udp_pkt.set_dst_port(dst_port);
319    udp_pkt.set_len(udp_len as u16);
320    udp_pkt.set_checksum(0); // Optional for UDP over IPv4.
321    udp_pkt.payload_mut()[..payload.len()].copy_from_slice(payload);
322
323    buf
324}
325
326/// Construct an Ethernet + IPv6 + UDP frame.
327fn construct_udp_response_v6(
328    src_ip: std::net::Ipv6Addr,
329    src_port: u16,
330    dst_ip: std::net::Ipv6Addr,
331    dst_port: u16,
332    payload: &[u8],
333    gateway_mac: EthernetAddress,
334    guest_mac: EthernetAddress,
335) -> Vec<u8> {
336    let udp_len = UDP_HDR_LEN + payload.len();
337    let ipv6_hdr_len = 40;
338    let frame_len = ETH_HDR_LEN + ipv6_hdr_len + udp_len;
339    let mut buf = vec![0u8; frame_len];
340
341    // Ethernet header.
342    let eth_repr = EthernetRepr {
343        src_addr: gateway_mac,
344        dst_addr: guest_mac,
345        ethertype: EthernetProtocol::Ipv6,
346    };
347    let mut eth_frame = EthernetFrame::new_unchecked(&mut buf);
348    eth_repr.emit(&mut eth_frame);
349
350    // IPv6 header.
351    let ip_buf = &mut buf[ETH_HDR_LEN..];
352    let mut ip_pkt = Ipv6Packet::new_unchecked(ip_buf);
353    ip_pkt.set_version(6);
354    ip_pkt.set_payload_len(udp_len as u16);
355    ip_pkt.set_next_header(IpProtocol::Udp);
356    ip_pkt.set_hop_limit(64);
357    ip_pkt.set_src_addr(src_ip);
358    ip_pkt.set_dst_addr(dst_ip);
359
360    // UDP header + payload.
361    let udp_buf = &mut buf[ETH_HDR_LEN + ipv6_hdr_len..];
362    let mut udp_pkt = UdpPacket::new_unchecked(udp_buf);
363    udp_pkt.set_src_port(src_port);
364    udp_pkt.set_dst_port(dst_port);
365    udp_pkt.set_len(udp_len as u16);
366    // Copy payload BEFORE computing checksum — fill_checksum reads the
367    // payload bytes, so they must be in place first.
368    udp_pkt.payload_mut()[..payload.len()].copy_from_slice(payload);
369    // IPv6 UDP checksum is mandatory per RFC 8200 section 8.1.
370    // A zero checksum causes the receiver to discard the packet.
371    udp_pkt.fill_checksum(
372        &smoltcp::wire::IpAddress::from(src_ip),
373        &smoltcp::wire::IpAddress::from(dst_ip),
374    );
375
376    buf
377}
378
379/// Extract the UDP payload from a raw ethernet frame.
380fn extract_udp_payload(frame: &[u8]) -> Option<&[u8]> {
381    let eth = EthernetFrame::new_checked(frame).ok()?;
382    match eth.ethertype() {
383        EthernetProtocol::Ipv4 => {
384            let ipv4 = Ipv4Packet::new_checked(eth.payload()).ok()?;
385            let udp = UdpPacket::new_checked(ipv4.payload()).ok()?;
386            Some(udp.payload())
387        }
388        EthernetProtocol::Ipv6 => {
389            let ipv6 = Ipv6Packet::new_checked(eth.payload()).ok()?;
390            let udp = UdpPacket::new_checked(ipv6.payload()).ok()?;
391            Some(udp.payload())
392        }
393        _ => None,
394    }
395}
396
397//--------------------------------------------------------------------------------------------------
398// Tests
399//--------------------------------------------------------------------------------------------------
400
#[cfg(test)]
mod tests {
    use super::*;

    /// Gateway-side MAC used by the frame-structure tests.
    const GATEWAY_MAC: EthernetAddress = EthernetAddress([0x02, 0x00, 0x00, 0x00, 0x00, 0x01]);
    /// Guest-side MAC used by the frame-structure tests.
    const GUEST_MAC: EthernetAddress = EthernetAddress([0x02, 0x00, 0x00, 0x00, 0x00, 0x02]);
    /// All-zero MAC for tests that only look at the payload.
    const ZERO_MAC: EthernetAddress = EthernetAddress([0; 6]);

    /// Parse an IPv6 literal, panicking on malformed test input.
    fn v6(text: &str) -> std::net::Ipv6Addr {
        text.parse::<std::net::Ipv6Addr>().unwrap()
    }

    #[test]
    fn construct_v4_response_has_correct_structure() {
        let server = Ipv4Addr::new(8, 8, 8, 8);
        let guest = Ipv4Addr::new(100, 96, 0, 2);
        let frame =
            construct_udp_response_v4(server, 53, guest, 12345, b"hello", GATEWAY_MAC, GUEST_MAC);

        assert_eq!(frame.len(), ETH_HDR_LEN + IPV4_HDR_LEN + UDP_HDR_LEN + 5);

        // Round-trip through the checked parsers, layer by layer.
        let eth = EthernetFrame::new_checked(&frame).unwrap();
        assert_eq!(eth.ethertype(), EthernetProtocol::Ipv4);
        assert_eq!(eth.dst_addr(), GUEST_MAC);

        let ipv4 = Ipv4Packet::new_checked(eth.payload()).unwrap();
        assert_eq!(Ipv4Addr::from(ipv4.src_addr()), server);
        assert_eq!(Ipv4Addr::from(ipv4.dst_addr()), guest);
        assert_eq!(ipv4.next_header(), IpProtocol::Udp);

        let udp = UdpPacket::new_checked(ipv4.payload()).unwrap();
        assert_eq!(udp.src_port(), 53);
        assert_eq!(udp.dst_port(), 12345);
        assert_eq!(udp.payload(), b"hello");
    }

    #[test]
    fn construct_v6_response_has_correct_structure() {
        let payload = b"hello ipv6";
        let src = v6("2001:db8::1");
        let dst = v6("fd42:6d73:62::2");
        let frame =
            construct_udp_response_v6(src, 53, dst, 12345, payload, GATEWAY_MAC, GUEST_MAC);

        let ipv6_hdr_len = 40;
        assert_eq!(
            frame.len(),
            ETH_HDR_LEN + ipv6_hdr_len + UDP_HDR_LEN + payload.len()
        );

        // Round-trip through the checked parsers, layer by layer.
        let eth = EthernetFrame::new_checked(&frame).unwrap();
        assert_eq!(eth.ethertype(), EthernetProtocol::Ipv6);

        let ipv6 = Ipv6Packet::new_checked(eth.payload()).unwrap();
        assert_eq!(ipv6.next_header(), IpProtocol::Udp);

        let udp = UdpPacket::new_checked(ipv6.payload()).unwrap();
        assert_eq!(udp.src_port(), 53);
        assert_eq!(udp.dst_port(), 12345);
        assert_eq!(udp.payload(), b"hello ipv6");
        // The checksum is mandatory for UDP over IPv6 (RFC 8200), so it
        // must be present...
        assert_ne!(udp.checksum(), 0, "IPv6 UDP checksum must not be zero");
        // ...and it must verify against the pseudo-header.
        let pseudo_src = smoltcp::wire::IpAddress::from(src);
        let pseudo_dst = smoltcp::wire::IpAddress::from(dst);
        assert!(
            udp.verify_checksum(&pseudo_src, &pseudo_dst),
            "IPv6 UDP checksum must be valid"
        );
    }

    #[test]
    fn extract_payload_from_v6_udp_frame() {
        let frame = construct_udp_response_v6(
            v6("2001:db8::1"),
            80,
            v6("fd42:6d73:62::2"),
            54321,
            b"v6 data",
            ZERO_MAC,
            ZERO_MAC,
        );
        let payload = extract_udp_payload(&frame).unwrap();
        assert_eq!(payload, b"v6 data");
    }

    #[test]
    fn extract_payload_from_v4_udp_frame() {
        // Build a frame, then make sure the extractor recovers its payload.
        let frame = construct_udp_response_v4(
            Ipv4Addr::new(1, 2, 3, 4),
            80,
            Ipv4Addr::new(10, 0, 0, 2),
            54321,
            b"test data",
            ZERO_MAC,
            ZERO_MAC,
        );
        let payload = extract_udp_payload(&frame).unwrap();
        assert_eq!(payload, b"test data");
    }
}