Skip to main content

microsandbox_network/
udp_relay.rs

1//! Non-DNS UDP relay: handles UDP traffic outside smoltcp.
2//!
3//! smoltcp has no wildcard port binding, so non-DNS UDP is intercepted at
4//! the device level, relayed through host UDP sockets via tokio, and
5//! responses are injected back into `rx_ring` as constructed ethernet frames.
6
7use std::collections::HashMap;
8use std::net::{IpAddr, Ipv4Addr, SocketAddr};
9use std::sync::Arc;
10use std::time::{Duration, Instant};
11
12use bytes::Bytes;
13use smoltcp::wire::{
14    EthernetAddress, EthernetFrame, EthernetProtocol, EthernetRepr, IpProtocol, Ipv4Packet,
15    Ipv6Packet, UdpPacket,
16};
17use tokio::net::UdpSocket;
18use tokio::sync::mpsc;
19
20use crate::shared::SharedState;
21
22//--------------------------------------------------------------------------------------------------
23// Constants
24//--------------------------------------------------------------------------------------------------
25
/// Session idle timeout.
///
/// Applied on both sides of a session: the poll-loop side treats a session
/// whose `last_active` is older than this as expired, and each relay task
/// exits once this long passes without an outbound or inbound datagram.
const SESSION_TIMEOUT: Duration = Duration::from_secs(60);

/// Channel capacity for outbound datagrams to the relay task.
///
/// The poll loop enqueues with `try_send`, so a datagram offered while the
/// channel is full is dropped instead of blocking the poll loop.
const OUTBOUND_CHANNEL_CAPACITY: usize = 64;

/// Buffer size for receiving responses from the real server.
/// Sized to match the MTU (1500) plus generous headroom for
/// reassembled datagrams while avoiding 64 KiB per session.
///
/// One buffer of this size is allocated per relay task.
/// NOTE(review): a response larger than this is presumably truncated by the
/// host `recv` — confirm that is acceptable for the expected traffic.
const RECV_BUF_SIZE: usize = 4096;

/// Ethernet header length.
const ETH_HDR_LEN: usize = 14;

/// IPv4 header length (no options).
const IPV4_HDR_LEN: usize = 20;

/// UDP header length.
const UDP_HDR_LEN: usize = 8;
45
46//--------------------------------------------------------------------------------------------------
47// Types
48//--------------------------------------------------------------------------------------------------
49
/// Relays non-DNS UDP traffic between the guest and the real network.
///
/// Each unique `(guest_src, guest_dst)` pair gets a host-side UDP socket
/// and a tokio relay task. The poll loop calls [`relay_outbound()`] to
/// send guest datagrams; response frames are injected directly into
/// `rx_ring`.
///
/// [`relay_outbound()`]: UdpRelay::relay_outbound
pub struct UdpRelay {
    /// Stack-wide shared state; cloned into every relay task so it can push
    /// response frames and wake the poll thread.
    shared: Arc<SharedState>,
    /// Sessions keyed by `(guest source, guest destination)`.
    sessions: HashMap<(SocketAddr, SocketAddr), UdpSession>,
    /// Source MAC written on frames synthesized for the guest.
    gateway_mac: EthernetAddress,
    /// Destination MAC written on frames synthesized for the guest.
    guest_mac: EthernetAddress,
    /// Runtime handle the per-session relay tasks are spawned on.
    tokio_handle: tokio::runtime::Handle,
}
65
/// A single UDP relay session.
///
/// This is the poll-loop-side handle for one relay task. Dropping it drops
/// `outbound_tx`; the relay task treats the closed channel as its signal to
/// exit.
struct UdpSession {
    /// Channel to send outbound datagrams to the relay task.
    outbound_tx: mpsc::Sender<Bytes>,
    /// Last time this session was used (refreshed on every outbound datagram).
    last_active: Instant,
}
73
74//--------------------------------------------------------------------------------------------------
75// Methods
76//--------------------------------------------------------------------------------------------------
77
78impl UdpRelay {
79    /// Build a new UDP relay.
80    ///
81    /// # Arguments
82    ///
83    /// * `shared` - Stack-wide shared state used to inject response frames into `rx_ring`
84    ///   and wake the poll thread.
85    /// * `gateway_mac` - MAC address stamped as the source on synthesized response frames.
86    /// * `guest_mac` - MAC address stamped as the destination on synthesized response frames.
87    /// * `tokio_handle` - Runtime the per-session relay tasks are spawned on.
88    pub fn new(
89        shared: Arc<SharedState>,
90        gateway_mac: [u8; 6],
91        guest_mac: [u8; 6],
92        tokio_handle: tokio::runtime::Handle,
93    ) -> Self {
94        Self {
95            shared,
96            sessions: HashMap::new(),
97            gateway_mac: EthernetAddress(gateway_mac),
98            guest_mac: EthernetAddress(guest_mac),
99            tokio_handle,
100        }
101    }
102
103    /// Relay an outbound UDP datagram from the guest.
104    ///
105    /// # Arguments
106    ///
107    /// * `frame` - Raw ethernet frame captured from the guest.
108    /// * `src` - Guest source address; keys the session and becomes the destination on
109    ///   response frames.
110    /// * `guest_dst` - Destination the guest wrote on the datagram. Retained as the session
111    ///   key and the source IP on replies.
112    /// * `host_dst` - Address the host socket actually connects to. Usually equal to
113    ///   `guest_dst`; the caller substitutes loopback when `guest_dst` matches the gateway IP.
114    pub fn relay_outbound(
115        &mut self,
116        frame: &[u8],
117        src: SocketAddr,
118        guest_dst: SocketAddr,
119        host_dst: SocketAddr,
120    ) {
121        // Extract UDP payload from the ethernet frame.
122        let Some(payload) = extract_udp_payload(frame) else {
123            return;
124        };
125
126        let key = (src, guest_dst);
127
128        // Create session if it doesn't exist or has expired.
129        if self
130            .sessions
131            .get(&key)
132            .is_none_or(|s| s.last_active.elapsed() > SESSION_TIMEOUT)
133        {
134            self.sessions.remove(&key);
135            if let Some(session) = self.create_session(src, guest_dst, host_dst) {
136                self.sessions.insert(key, session);
137            } else {
138                return;
139            }
140        }
141
142        if let Some(session) = self.sessions.get_mut(&key) {
143            session.last_active = Instant::now();
144            let _ = session
145                .outbound_tx
146                .try_send(Bytes::copy_from_slice(payload));
147        }
148    }
149
150    /// Remove expired sessions.
151    pub fn cleanup_expired(&mut self) {
152        self.sessions
153            .retain(|_, session| session.last_active.elapsed() <= SESSION_TIMEOUT);
154    }
155}
156
157impl UdpRelay {
158    /// Create a new relay session: bind a host UDP socket and spawn a task.
159    fn create_session(
160        &self,
161        guest_src: SocketAddr,
162        guest_dst: SocketAddr,
163        host_dst: SocketAddr,
164    ) -> Option<UdpSession> {
165        let (outbound_tx, outbound_rx) = mpsc::channel(OUTBOUND_CHANNEL_CAPACITY);
166
167        let shared = self.shared.clone();
168        let gateway_mac = self.gateway_mac;
169        let guest_mac = self.guest_mac;
170
171        self.tokio_handle.spawn(async move {
172            if let Err(e) = udp_relay_task(
173                outbound_rx,
174                guest_src,
175                guest_dst,
176                host_dst,
177                shared,
178                gateway_mac,
179                guest_mac,
180            )
181            .await
182            {
183                tracing::debug!(
184                    guest_src = %guest_src,
185                    guest_dst = %guest_dst,
186                    error = %e,
187                    "UDP relay task ended",
188                );
189            }
190        });
191
192        Some(UdpSession {
193            outbound_tx,
194            last_active: Instant::now(),
195        })
196    }
197}
198
199//--------------------------------------------------------------------------------------------------
200// Functions
201//--------------------------------------------------------------------------------------------------
202
203/// Per-session UDP relay loop: forwards guest datagrams to a host socket, stamps the replies
204/// back into frames the guest accepts, and exits on idle timeout or channel close.
205///
206/// Binds an ephemeral host UDP socket in the address family of `host_dst` and `connect()`s it
207/// to that peer. The `connect` restricts the socket to that peer's datagrams, which both sets
208/// the default send target and filters spoofed inbound traffic. Responses are wrapped in a
209/// synthesised ethernet frame (src IP = `guest_dst`, dst = `guest_src`) and pushed into
210/// `rx_ring`.
211///
212/// # Arguments
213///
214/// * `outbound_rx` - Receives UDP payloads from the poll-loop side. Channel close signals
215///   session drop.
216/// * `guest_src` - Guest source address; stamped as the destination on reply frames.
217/// * `guest_dst` - Destination the guest wrote on the datagram. Stamped as the source IP on
218///   reply frames so the guest sees replies from the same address it dialed.
219/// * `host_dst` - Address the host socket connects to. Equal to `guest_dst` for external
220///   destinations; rewritten to loopback by [`crate::stack::resolve_host_dst`] when the guest
221///   addressed the gateway.
222/// * `shared` - Shared state; reply frames go into `rx_ring` and wake the poll thread.
223/// * `gateway_mac` - Source MAC on reply frames (guest sees replies from the gateway's MAC).
224/// * `guest_mac` - Destination MAC on reply frames.
225///
226/// # Errors
227///
228/// Returns [`std::io::Error`] when the initial `bind` or `connect` on
229/// the host UDP socket fails, or when the host-side `recv` fails after
230/// the socket was established.
231#[allow(clippy::too_many_arguments)]
232async fn udp_relay_task(
233    mut outbound_rx: mpsc::Receiver<Bytes>,
234    guest_src: SocketAddr,
235    guest_dst: SocketAddr,
236    host_dst: SocketAddr,
237    shared: Arc<SharedState>,
238    gateway_mac: EthernetAddress,
239    guest_mac: EthernetAddress,
240) -> std::io::Result<()> {
241    // Bind a host UDP socket. Use the same address family as the host destination.
242    let bind_addr: SocketAddr = match host_dst {
243        SocketAddr::V4(_) => (Ipv4Addr::UNSPECIFIED, 0u16).into(),
244        SocketAddr::V6(_) => (std::net::Ipv6Addr::UNSPECIFIED, 0u16).into(),
245    };
246    let socket = UdpSocket::bind(bind_addr).await?;
247    // Connect to the destination to restrict accepted source addresses,
248    // preventing host-network entities from injecting spoofed datagrams.
249    socket.connect(host_dst).await?;
250
251    let mut recv_buf = vec![0u8; RECV_BUF_SIZE];
252    let timeout = SESSION_TIMEOUT;
253
254    loop {
255        tokio::select! {
256            // Outbound: guest → server.
257            data = outbound_rx.recv() => {
258                match data {
259                    Some(payload) => {
260                        let _ = socket.send(&payload).await;
261                    }
262                    // Channel closed — session dropped by poll loop.
263                    None => break,
264                }
265            }
266
267            // Inbound: server → guest (only from the connected destination).
268            result = socket.recv(&mut recv_buf) => {
269                match result {
270                    Ok(n) => {
271                        if let Some(frame) = construct_udp_response(
272                            guest_dst,
273                            guest_src,
274                            &recv_buf[..n],
275                            gateway_mac,
276                            guest_mac,
277                        ) && !shared.push_rx_frame_and_wake(frame) {
278                            tracing::debug!("UDP relay response dropped because rx_ring is full");
279                        }
280                    }
281                    Err(e) => {
282                        tracing::debug!(error = %e, "UDP relay recv failed");
283                        break;
284                    }
285                }
286            }
287
288            // Idle timeout.
289            () = tokio::time::sleep(timeout) => {
290                break;
291            }
292        }
293    }
294
295    Ok(())
296}
297
298/// Construct an ethernet frame containing a UDP response for the guest.
299///
300/// Builds Ethernet + IPv4/IPv6 + UDP headers using smoltcp's wire module.
301pub(crate) fn construct_udp_response(
302    src: SocketAddr,
303    dst: SocketAddr,
304    payload: &[u8],
305    gateway_mac: EthernetAddress,
306    guest_mac: EthernetAddress,
307) -> Option<Vec<u8>> {
308    match (src.ip(), dst.ip()) {
309        (IpAddr::V4(src_ip), IpAddr::V4(dst_ip)) => Some(construct_udp_response_v4(
310            src_ip,
311            src.port(),
312            dst_ip,
313            dst.port(),
314            payload,
315            gateway_mac,
316            guest_mac,
317        )),
318        (IpAddr::V6(src_ip), IpAddr::V6(dst_ip)) => Some(construct_udp_response_v6(
319            src_ip,
320            src.port(),
321            dst_ip,
322            dst.port(),
323            payload,
324            gateway_mac,
325            guest_mac,
326        )),
327        _ => None, // Mismatched address families — shouldn't happen.
328    }
329}
330
331/// Construct an Ethernet + IPv4 + UDP frame.
332fn construct_udp_response_v4(
333    src_ip: Ipv4Addr,
334    src_port: u16,
335    dst_ip: Ipv4Addr,
336    dst_port: u16,
337    payload: &[u8],
338    gateway_mac: EthernetAddress,
339    guest_mac: EthernetAddress,
340) -> Vec<u8> {
341    let udp_len = UDP_HDR_LEN + payload.len();
342    let ip_total_len = IPV4_HDR_LEN + udp_len;
343    let frame_len = ETH_HDR_LEN + ip_total_len;
344    let mut buf = vec![0u8; frame_len];
345
346    // Ethernet header.
347    let eth_repr = EthernetRepr {
348        src_addr: gateway_mac,
349        dst_addr: guest_mac,
350        ethertype: EthernetProtocol::Ipv4,
351    };
352    let mut eth_frame = EthernetFrame::new_unchecked(&mut buf);
353    eth_repr.emit(&mut eth_frame);
354
355    // IPv4 header.
356    let ip_buf = &mut buf[ETH_HDR_LEN..];
357    let mut ip_pkt = Ipv4Packet::new_unchecked(ip_buf);
358    ip_pkt.set_version(4);
359    ip_pkt.set_header_len(20);
360    ip_pkt.set_total_len(ip_total_len as u16);
361    ip_pkt.clear_flags();
362    ip_pkt.set_dont_frag(true);
363    ip_pkt.set_hop_limit(64);
364    ip_pkt.set_next_header(IpProtocol::Udp);
365    ip_pkt.set_src_addr(src_ip);
366    ip_pkt.set_dst_addr(dst_ip);
367    ip_pkt.fill_checksum();
368
369    // UDP header + payload.
370    let udp_buf = &mut buf[ETH_HDR_LEN + IPV4_HDR_LEN..];
371    let mut udp_pkt = UdpPacket::new_unchecked(udp_buf);
372    udp_pkt.set_src_port(src_port);
373    udp_pkt.set_dst_port(dst_port);
374    udp_pkt.set_len(udp_len as u16);
375    udp_pkt.set_checksum(0); // Optional for UDP over IPv4.
376    udp_pkt.payload_mut()[..payload.len()].copy_from_slice(payload);
377
378    buf
379}
380
381/// Construct an Ethernet + IPv6 + UDP frame.
382fn construct_udp_response_v6(
383    src_ip: std::net::Ipv6Addr,
384    src_port: u16,
385    dst_ip: std::net::Ipv6Addr,
386    dst_port: u16,
387    payload: &[u8],
388    gateway_mac: EthernetAddress,
389    guest_mac: EthernetAddress,
390) -> Vec<u8> {
391    let udp_len = UDP_HDR_LEN + payload.len();
392    let ipv6_hdr_len = 40;
393    let frame_len = ETH_HDR_LEN + ipv6_hdr_len + udp_len;
394    let mut buf = vec![0u8; frame_len];
395
396    // Ethernet header.
397    let eth_repr = EthernetRepr {
398        src_addr: gateway_mac,
399        dst_addr: guest_mac,
400        ethertype: EthernetProtocol::Ipv6,
401    };
402    let mut eth_frame = EthernetFrame::new_unchecked(&mut buf);
403    eth_repr.emit(&mut eth_frame);
404
405    // IPv6 header.
406    let ip_buf = &mut buf[ETH_HDR_LEN..];
407    let mut ip_pkt = Ipv6Packet::new_unchecked(ip_buf);
408    ip_pkt.set_version(6);
409    ip_pkt.set_payload_len(udp_len as u16);
410    ip_pkt.set_next_header(IpProtocol::Udp);
411    ip_pkt.set_hop_limit(64);
412    ip_pkt.set_src_addr(src_ip);
413    ip_pkt.set_dst_addr(dst_ip);
414
415    // UDP header + payload.
416    let udp_buf = &mut buf[ETH_HDR_LEN + ipv6_hdr_len..];
417    let mut udp_pkt = UdpPacket::new_unchecked(udp_buf);
418    udp_pkt.set_src_port(src_port);
419    udp_pkt.set_dst_port(dst_port);
420    udp_pkt.set_len(udp_len as u16);
421    // Copy payload BEFORE computing checksum — fill_checksum reads the
422    // payload bytes, so they must be in place first.
423    udp_pkt.payload_mut()[..payload.len()].copy_from_slice(payload);
424    // IPv6 UDP checksum is mandatory per RFC 8200 section 8.1.
425    // A zero checksum causes the receiver to discard the packet.
426    udp_pkt.fill_checksum(
427        &smoltcp::wire::IpAddress::from(src_ip),
428        &smoltcp::wire::IpAddress::from(dst_ip),
429    );
430
431    buf
432}
433
434/// Extract the UDP payload from a raw ethernet frame.
435pub(crate) fn extract_udp_payload(frame: &[u8]) -> Option<&[u8]> {
436    let eth = EthernetFrame::new_checked(frame).ok()?;
437    match eth.ethertype() {
438        EthernetProtocol::Ipv4 => {
439            let ipv4 = Ipv4Packet::new_checked(eth.payload()).ok()?;
440            let udp = UdpPacket::new_checked(ipv4.payload()).ok()?;
441            Some(udp.payload())
442        }
443        EthernetProtocol::Ipv6 => {
444            let ipv6 = Ipv6Packet::new_checked(eth.payload()).ok()?;
445            let udp = UdpPacket::new_checked(ipv6.payload()).ok()?;
446            Some(udp.payload())
447        }
448        _ => None,
449    }
450}
451
452//--------------------------------------------------------------------------------------------------
453// Tests
454//--------------------------------------------------------------------------------------------------
455
#[cfg(test)]
mod tests {
    use super::*;

    /// MAC used as the frame source in the structure tests.
    const GATEWAY_MAC: EthernetAddress = EthernetAddress([0x02, 0x00, 0x00, 0x00, 0x00, 0x01]);
    /// MAC used as the frame destination in the structure tests.
    const GUEST_MAC: EthernetAddress = EthernetAddress([0x02, 0x00, 0x00, 0x00, 0x00, 0x02]);
    /// All-zero MAC for tests that only look at the payload.
    const ZERO_MAC: EthernetAddress = EthernetAddress([0; 6]);
    /// Fixed IPv6 header length.
    const IPV6_HDR_LEN: usize = 40;

    /// Parse an IPv6 literal, panicking on malformed test input.
    fn v6(text: &str) -> std::net::Ipv6Addr {
        text.parse().unwrap()
    }

    #[test]
    fn construct_v4_response_has_correct_structure() {
        let server = Ipv4Addr::new(8, 8, 8, 8);
        let guest = Ipv4Addr::new(100, 96, 0, 2);
        let frame =
            construct_udp_response_v4(server, 53, guest, 12345, b"hello", GATEWAY_MAC, GUEST_MAC);

        assert_eq!(frame.len(), ETH_HDR_LEN + IPV4_HDR_LEN + UDP_HDR_LEN + 5);

        // Round-trip through the checked parsers, layer by layer.
        let eth = EthernetFrame::new_checked(&frame).unwrap();
        assert_eq!(eth.ethertype(), EthernetProtocol::Ipv4);
        assert_eq!(eth.dst_addr(), GUEST_MAC);

        let ipv4 = Ipv4Packet::new_checked(eth.payload()).unwrap();
        assert_eq!(ipv4.src_addr(), server);
        assert_eq!(ipv4.dst_addr(), guest);
        assert_eq!(ipv4.next_header(), IpProtocol::Udp);

        let udp = UdpPacket::new_checked(ipv4.payload()).unwrap();
        assert_eq!(udp.src_port(), 53);
        assert_eq!(udp.dst_port(), 12345);
        assert_eq!(udp.payload(), b"hello");
    }

    #[test]
    fn construct_v6_response_has_correct_structure() {
        let body = b"hello ipv6";
        let src = v6("2001:db8::1");
        let dst = v6("fd42:6d73:62::2");
        let frame = construct_udp_response_v6(src, 53, dst, 12345, body, GATEWAY_MAC, GUEST_MAC);

        assert_eq!(
            frame.len(),
            ETH_HDR_LEN + IPV6_HDR_LEN + UDP_HDR_LEN + body.len()
        );

        // Round-trip through the checked parsers, layer by layer.
        let eth = EthernetFrame::new_checked(&frame).unwrap();
        assert_eq!(eth.ethertype(), EthernetProtocol::Ipv6);

        let ipv6 = Ipv6Packet::new_checked(eth.payload()).unwrap();
        assert_eq!(ipv6.next_header(), IpProtocol::Udp);

        let udp = UdpPacket::new_checked(ipv6.payload()).unwrap();
        assert_eq!(udp.src_port(), 53);
        assert_eq!(udp.dst_port(), 12345);
        assert_eq!(udp.payload(), b"hello ipv6");
        // Verify checksum is non-zero (mandatory for IPv6 UDP per RFC 8200).
        assert_ne!(udp.checksum(), 0, "IPv6 UDP checksum must not be zero");
        // Verify checksum is correct.
        let checksum_ok = udp.verify_checksum(
            &smoltcp::wire::IpAddress::from(src),
            &smoltcp::wire::IpAddress::from(dst),
        );
        assert!(checksum_ok, "IPv6 UDP checksum must be valid");
    }

    #[test]
    fn extract_payload_from_v6_udp_frame() {
        let frame = construct_udp_response_v6(
            v6("2001:db8::1"),
            80,
            v6("fd42:6d73:62::2"),
            54321,
            b"v6 data",
            ZERO_MAC,
            ZERO_MAC,
        );
        assert_eq!(extract_udp_payload(&frame).unwrap(), b"v6 data");
    }

    #[test]
    fn extract_payload_from_v4_udp_frame() {
        // Build a frame, then pull the payload back out of it.
        let frame = construct_udp_response_v4(
            Ipv4Addr::new(1, 2, 3, 4),
            80,
            Ipv4Addr::new(10, 0, 0, 2),
            54321,
            b"test data",
            ZERO_MAC,
            ZERO_MAC,
        );
        assert_eq!(extract_udp_payload(&frame).unwrap(), b"test data");
    }
}