Skip to main content

microsandbox_network/
stack.rs

1//! smoltcp interface setup, frame classification, and poll loop.
2//!
3//! This module contains the core networking event loop that runs on a
4//! dedicated OS thread. It bridges guest ethernet frames (via
5//! [`SmoltcpDevice`]) to smoltcp's TCP/IP stack and services connections
6//! through tokio proxy tasks.
7
8use std::collections::HashSet;
9use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
10use std::sync::Arc;
11use std::sync::atomic::Ordering;
12
13use smoltcp::iface::{Config, Interface, SocketSet};
14use smoltcp::time::Instant;
15
16use smoltcp::wire::{
17    EthernetAddress, EthernetFrame, EthernetProtocol, HardwareAddress, Icmpv4Packet, Icmpv4Repr,
18    Icmpv6Packet, Icmpv6Repr, IpAddress, IpCidr, IpProtocol, Ipv4Packet, Ipv4Repr, Ipv6Packet,
19    Ipv6Repr, TcpPacket, UdpPacket,
20};
21
22use crate::config::{DnsConfig, PublishedPort};
23use crate::conn::ConnectionTracker;
24use crate::device::SmoltcpDevice;
25use crate::dns::common::ports::DnsPortType;
26use crate::dns::{
27    interceptor::DnsInterceptor,
28    proxies::{dot::DotProxy, tcp::DnsTcpProxy},
29};
30use crate::icmp_relay::IcmpRelay;
31use crate::policy::{NetworkPolicy, Protocol};
32use crate::proxy;
33use crate::publisher::PortPublisher;
34use crate::shared::SharedState;
35use crate::tls::{proxy as tls_proxy, state::TlsState};
36use crate::udp_relay::UdpRelay;
37
38//--------------------------------------------------------------------------------------------------
39// Types
40//--------------------------------------------------------------------------------------------------
41
/// Result of classifying a guest ethernet frame before smoltcp processes it.
///
/// Pre-inspection allows the poll loop to:
/// - Create TCP sockets before smoltcp sees a SYN (preventing auto-RST).
/// - Handle non-DNS UDP outside smoltcp (smoltcp lacks wildcard port binding).
/// - Route DNS queries to the interception handler.
///
/// A `FrameAction` is a small plain value (at most two socket addresses), so
/// it is `Copy` and comparable, which keeps logging and testing simple.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FrameAction {
    /// TCP SYN to a new destination — create a smoltcp socket before
    /// letting smoltcp process the frame.
    TcpSyn { src: SocketAddr, dst: SocketAddr },

    /// Non-DNS UDP datagram — handle entirely outside smoltcp via the UDP
    /// relay.
    UdpRelay { src: SocketAddr, dst: SocketAddr },

    /// DNS query (UDP to port 53) — let smoltcp's bound UDP socket handle it.
    Dns,

    /// Everything else (ARP, NDP, ICMP, TCP data/ACK/FIN, etc.) — let
    /// smoltcp process normally.
    Passthrough,
}
64
65/// Resolved network parameters for the poll loop. Created by
66/// `SmoltcpNetwork::new()` from `NetworkConfig` + sandbox slot.
67pub struct PollLoopConfig {
68    /// Gateway MAC address (smoltcp's identity on the virtual LAN).
69    pub gateway_mac: [u8; 6],
70    /// Guest MAC address.
71    pub guest_mac: [u8; 6],
72    /// Gateway addresses (IPv4 + IPv6) owned by the smoltcp virtual
73    /// stack.
74    pub gateway: GatewayIps,
75    /// Guest IPv4 address.
76    pub guest_ipv4: Ipv4Addr,
77    /// IP-level MTU (e.g. 1500).
78    pub mtu: usize,
79}
80
/// The pair of addresses (one per IP family) that the smoltcp stack answers
/// to as the sandbox gateway.
///
/// Both families are always populated. `resolve_host_dst` compares guest
/// destinations against these to decide when a dial should target loopback
/// instead.
#[derive(Debug, Clone, Copy)]
pub struct GatewayIps {
    /// IPv4 address of the gateway.
    pub ipv4: Ipv4Addr,
    /// IPv6 address of the gateway.
    pub ipv6: Ipv6Addr,
}
91
92//--------------------------------------------------------------------------------------------------
93// Functions
94//--------------------------------------------------------------------------------------------------
95
96/// Classify a raw ethernet frame for pre-inspection.
97///
98/// Uses smoltcp's wire module for zero-copy parsing. Returns
99/// [`FrameAction::Passthrough`] for any frame that cannot be parsed or
100/// doesn't match a special case.
101pub fn classify_frame(frame: &[u8]) -> FrameAction {
102    let Ok(eth) = EthernetFrame::new_checked(frame) else {
103        return FrameAction::Passthrough;
104    };
105
106    match eth.ethertype() {
107        EthernetProtocol::Ipv4 => classify_ipv4(eth.payload()),
108        EthernetProtocol::Ipv6 => classify_ipv6(eth.payload()),
109        _ => FrameAction::Passthrough, // ARP, etc.
110    }
111}
112
113/// Create and configure the smoltcp [`Interface`].
114///
115/// The interface is configured as the **gateway**: it owns the gateway IP
116/// addresses and responds to ARP/NDP for them. `any_ip` mode is enabled so
117/// smoltcp accepts traffic destined for arbitrary remote IPs (not just the
118/// gateway), combined with default routes.
119pub fn create_interface(device: &mut SmoltcpDevice, config: &PollLoopConfig) -> Interface {
120    let hw_addr = HardwareAddress::Ethernet(EthernetAddress(config.gateway_mac));
121    let iface_config = Config::new(hw_addr);
122    let mut iface = Interface::new(iface_config, device, smoltcp_now());
123
124    // Configure gateway IP addresses.
125    iface.update_ip_addrs(|addrs| {
126        addrs
127            .push(IpCidr::new(
128                IpAddress::Ipv4(config.gateway.ipv4),
129                // /30 subnet: gateway + guest.
130                30,
131            ))
132            .expect("failed to add gateway IPv4 address");
133        addrs
134            .push(IpCidr::new(IpAddress::Ipv6(config.gateway.ipv6), 64))
135            .expect("failed to add gateway IPv6 address");
136    });
137
138    // Default routes so smoltcp accepts traffic for all destinations.
139    iface
140        .routes_mut()
141        .add_default_ipv4_route(config.gateway.ipv4)
142        .expect("failed to add default IPv4 route");
143    iface
144        .routes_mut()
145        .add_default_ipv6_route(config.gateway.ipv6)
146        .expect("failed to add default IPv6 route");
147
148    // Accept traffic destined for any IP, not just gateway addresses.
149    iface.set_any_ip(true);
150
151    iface
152}
153
154/// Main smoltcp poll loop. Runs on a dedicated OS thread.
155///
156/// Processes guest frames with pre-inspection, drives smoltcp's TCP/IP stack,
157/// and sleeps via `poll(2)` between events.
158///
159/// # Phases per iteration
160///
161/// 1. **Drain guest frames** — pop from `tx_ring`, classify, pre-inspect.
162/// 2. **smoltcp egress + maintenance** — transmit queued packets, run timers.
163/// 3. **Service connections** — relay data between smoltcp sockets and proxy
164///    tasks (added by later tasks).
165/// 4. **Sleep** — `poll(2)` on `tx_wake` + `proxy_wake` pipes with smoltcp's
166///    requested timeout.
167///
168/// # Arguments
169///
170/// * `shared` - Stack-wide shared state: `tx_ring` / `rx_ring` for the virtio-net boundary
171///   and the wake eventfds.
172/// * `config` - Resolved per-sandbox parameters (gateway / guest MAC + IPv4 + IPv6, MTU).
173/// * `network_policy` - User-provided egress policy. Evaluated against the sandbox's
174///   gateway IPs (stored on [`SharedState`]) so `DestinationGroup::Host` rules match.
175/// * `dns_config` - DNS interception settings (block lists, upstreams, timeout).
176/// * `tls_state` - Optional TLS MITM state; drives interception of intercepted ports and DoT
177///   when present.
178/// * `published_ports` - Host → guest port publishes; the publisher accepts inbound
179///   connections on the host-bind address and forwards into the guest.
180/// * `max_connections` - Optional cap on concurrent guest connections tracked by
181///   [`ConnectionTracker`]; `None` uses the default.
182/// * `tokio_handle` - Runtime handle used for proxy tasks, DNS forwarding, port publishing,
183///   and ICMP relays.
184#[allow(clippy::too_many_arguments)]
185pub fn smoltcp_poll_loop(
186    shared: Arc<SharedState>,
187    config: PollLoopConfig,
188    network_policy: NetworkPolicy,
189    dns_config: DnsConfig,
190    tls_state: Option<Arc<TlsState>>,
191    published_ports: Vec<PublishedPort>,
192    max_connections: Option<usize>,
193    tokio_handle: tokio::runtime::Handle,
194) {
195    let mut device = SmoltcpDevice::new(shared.clone(), config.mtu);
196    let mut iface = create_interface(&mut device, &config);
197    let mut sockets = SocketSet::new(vec![]);
198    let mut conn_tracker = ConnectionTracker::new(max_connections);
199
200    // The DNS forwarder needs to know which IPs count as "the gateway"
201    // (so it routes guest queries to those addresses through the
202    // configured upstream) and a policy evaluator (so guest-chosen
203    // `@target` resolvers are gated by egress rules just like any
204    // other outbound).
205    let gateway_ips: Arc<HashSet<IpAddr>> = Arc::new(HashSet::from([
206        IpAddr::V4(config.gateway.ipv4),
207        IpAddr::V6(config.gateway.ipv6),
208    ]));
209    // Gateway IPs must be on SharedState before any egress evaluation runs,
210    // so `DestinationGroup::Host` rules can resolve to the right address.
211    shared.set_gateway_ips(config.gateway.ipv4, config.gateway.ipv6);
212    let network_policy = Arc::new(network_policy);
213
214    let (mut dns_interceptor, dns_forwarder_handle) = DnsInterceptor::new(
215        &mut sockets,
216        dns_config,
217        shared.clone(),
218        &tokio_handle,
219        gateway_ips,
220        network_policy.clone(),
221        config.gateway,
222    );
223    let mut port_publisher = PortPublisher::new(&published_ports, config.guest_ipv4, &tokio_handle);
224    let mut udp_relay = UdpRelay::new(
225        shared.clone(),
226        config.gateway_mac,
227        config.guest_mac,
228        tokio_handle.clone(),
229    );
230    let icmp_relay = IcmpRelay::new(
231        shared.clone(),
232        config.gateway_mac,
233        config.guest_mac,
234        tokio_handle.clone(),
235    );
236
237    // Rate-limit cleanup operations: run at most once per second.
238    let mut last_cleanup = std::time::Instant::now();
239
240    // poll(2) file descriptors for sleeping.
241    let mut poll_fds = [
242        libc::pollfd {
243            fd: shared.tx_wake.as_raw_fd(),
244            events: libc::POLLIN,
245            revents: 0,
246        },
247        libc::pollfd {
248            fd: shared.proxy_wake.as_raw_fd(),
249            events: libc::POLLIN,
250            revents: 0,
251        },
252    ];
253
254    loop {
255        let now = smoltcp_now();
256
257        // ── Phase 1: Drain all guest frames with pre-inspection ──────────
258        while let Some(frame) = device.stage_next_frame() {
259            if handle_gateway_icmp_echo(frame, &config, &shared) {
260                device.drop_staged_frame();
261                continue;
262            }
263
264            if icmp_relay.relay_outbound_if_echo(frame, &config, &network_policy) {
265                device.drop_staged_frame();
266                continue;
267            }
268
269            match classify_frame(frame) {
270                FrameAction::TcpSyn { src, dst } => {
271                    let allow = match DnsPortType::from_tcp(dst.port()) {
272                        // Plain DNS: the interceptor enforces policy at
273                        // the application layer (block list + rebind
274                        // protection); bypass the network egress check.
275                        DnsPortType::Dns => true,
276                        // DoT: intercept only when TLS MITM is
277                        // configured. Without it, the block list can't
278                        // apply (traffic is encrypted end-to-end), so
279                        // we refuse to force a fall-back to plain
280                        // TCP/53. When TLS MITM is configured, bypass
281                        // egress policy the same way plain DNS does —
282                        // policy for the upstream resolver is applied
283                        // per query by the forwarder.
284                        DnsPortType::EncryptedDns => {
285                            if tls_state.is_some() {
286                                true
287                            } else {
288                                tracing::debug!(%dst, "DoT port refused (TLS interception not configured); stub should fall back to TCP/53");
289                                false
290                            }
291                        }
292                        // Alternative DNS protocol we can't proxy:
293                        // refuse outright — no socket means smoltcp
294                        // emits RST, which the guest's stub treats as
295                        // "upstream unavailable" and falls back to
296                        // plain TCP/53.
297                        DnsPortType::AlternativeDns => {
298                            tracing::debug!(%dst, "alternative-DNS TCP port refused; stub should fall back to TCP/53");
299                            false
300                        }
301                        // Other: regular outbound — apply egress policy.
302                        DnsPortType::Other => network_policy
303                            .evaluate_egress(dst, Protocol::Tcp, &shared)
304                            .is_allow(),
305                    };
306                    if allow && !conn_tracker.has_socket_for(&src, &dst) {
307                        conn_tracker.create_tcp_socket(src, dst, &mut sockets);
308                    }
309                    // Let smoltcp process — matching socket completes
310                    // handshake, no socket means auto-RST.
311                    iface.poll_ingress_single(now, &mut device, &mut sockets);
312                }
313
314                FrameAction::UdpRelay { src, dst } => {
315                    // QUIC blocking: drop UDP to intercepted ports when
316                    // TLS interception is active.
317                    if let Some(ref tls) = tls_state
318                        && tls.config.intercepted_ports.contains(&dst.port())
319                        && tls.config.block_quic_on_intercept
320                    {
321                        device.drop_staged_frame();
322                        continue;
323                    }
324
325                    match DnsPortType::from_udp(dst.port()) {
326                        // Dns: unreachable here — classify_transport
327                        // routes UDP/53 to FrameAction::Dns, not
328                        // UdpRelay. Defensive drop covers regressions.
329                        DnsPortType::Dns => {
330                            device.drop_staged_frame();
331                            continue;
332                        }
333                        // EncryptedDns: unreachable here —
334                        // `DnsPortType::from_udp` never returns it
335                        // today (DoT is TCP-only; UDP/853 is DoQ and
336                        // returns AlternativeDns). Defensive drop.
337                        DnsPortType::EncryptedDns => {
338                            device.drop_staged_frame();
339                            continue;
340                        }
341                        // Alternative DNS protocols on well-known UDP
342                        // ports are dropped — forces fall-back to UDP/53.
343                        DnsPortType::AlternativeDns => {
344                            tracing::debug!(%dst, "alternative-DNS UDP port dropped; stub should fall back to UDP/53");
345                            device.drop_staged_frame();
346                            continue;
347                        }
348                        DnsPortType::Other => {}
349                    }
350
351                    // Policy check.
352                    if network_policy
353                        .evaluate_egress(dst, Protocol::Udp, &shared)
354                        .is_deny()
355                    {
356                        device.drop_staged_frame();
357                        continue;
358                    }
359
360                    // Resolve the host-side destination for the dial.
361                    // `dst` stays unchanged so reply frames are stamped
362                    // with the IP the guest expects.
363                    let host_dst = resolve_host_dst(dst, config.gateway);
364                    udp_relay.relay_outbound(frame, src, dst, host_dst);
365                    device.drop_staged_frame();
366                }
367
368                FrameAction::Dns | FrameAction::Passthrough => {
369                    // ARP, ICMP, DNS (port 53), TCP data — smoltcp handles.
370                    iface.poll_ingress_single(now, &mut device, &mut sockets);
371                }
372            }
373        }
374
375        // ── Phase 2: Ingress egress + maintenance ─────────────────────────
376        // Flush frames generated by Phase 1 ingress (ACKs, SYN-ACKs, etc.)
377        // before relaying data so smoltcp has up-to-date state.
378        loop {
379            let result = iface.poll_egress(now, &mut device, &mut sockets);
380            if matches!(result, smoltcp::iface::PollResult::None) {
381                break;
382            }
383        }
384        iface.poll_maintenance(now);
385
386        // Coalesced wake: if Phase 1/2 emitted any frames, wake the
387        // NetWorker once instead of per-frame.
388        if device.frames_emitted.swap(false, Ordering::Relaxed) {
389            shared.rx_wake.wake();
390        }
391
392        // ── Phase 3: Service connections + relay data ────────────────────
393        // Relay proxy data INTO smoltcp sockets first, then a single egress
394        // pass flushes everything. This eliminates the former "Phase 2b"
395        // double-egress pattern.
396        conn_tracker.relay_data(&mut sockets);
397        dns_interceptor.process(&mut sockets);
398
399        // Accept queued inbound connections from published port listeners.
400        port_publisher.accept_inbound(&mut iface, &mut sockets, &shared, &tokio_handle);
401        port_publisher.relay_data(&mut sockets);
402
403        // Detect newly-established connections and spawn proxy tasks.
404        let new_conns = conn_tracker.take_new_connections(&mut sockets);
405        for conn in new_conns {
406            if let Some(ref tls_state) = tls_state
407                && tls_state
408                    .config
409                    .intercepted_ports
410                    .contains(&conn.dst.port())
411            {
412                // TLS-intercepted port — spawn TLS MITM proxy.
413                let conn_dst = resolve_host_dst(conn.dst, config.gateway);
414                tls_proxy::spawn_tls_proxy(
415                    &tokio_handle,
416                    conn_dst,
417                    conn.from_smoltcp,
418                    conn.to_smoltcp,
419                    shared.clone(),
420                    tls_state.clone(),
421                );
422                continue;
423            }
424            if conn.dst.port() == 53 {
425                // DNS over TCP: route through the same forwarder the UDP
426                // path uses. The forwarder applies the domain block list
427                // and rebind protection to every query and routes
428                // upstream based on `conn.dst.ip()` — the configured
429                // upstream for queries to the gateway, direct forward
430                // to the chosen `@target` (subject to egress policy)
431                // otherwise. No gateway→loopback rewrite here: the
432                // forwarder dials the configured upstream, not the
433                // gateway.
434                DnsTcpProxy::spawn(
435                    &tokio_handle,
436                    conn.dst,
437                    conn.from_smoltcp,
438                    conn.to_smoltcp,
439                    dns_forwarder_handle.clone(),
440                    shared.clone(),
441                );
442                continue;
443            }
444            if conn.dst.port() == 853
445                && let Some(ref tls_state) = tls_state
446            {
447                // DNS over TLS: terminate TLS at the gateway with a
448                // per-domain cert, hand the inner DNS frames to the
449                // same forwarder plain DNS uses. Policy for the
450                // chosen `@target` resolver is applied per-query by
451                // the forwarder (block list + rebind + egress).
452                DotProxy::spawn(
453                    &tokio_handle,
454                    conn.dst,
455                    conn.from_smoltcp,
456                    conn.to_smoltcp,
457                    dns_forwarder_handle.clone(),
458                    tls_state.clone(),
459                    shared.clone(),
460                );
461                continue;
462            }
463            // Plain TCP proxy.
464            let dst = resolve_host_dst(conn.dst, config.gateway);
465            proxy::spawn_tcp_proxy(
466                &tokio_handle,
467                dst,
468                conn.from_smoltcp,
469                conn.to_smoltcp,
470                shared.clone(),
471            );
472        }
473
474        // Rate-limited cleanup: TIME_WAIT is 60s, session timeout is 60s,
475        // so checking once per second is more than sufficient.
476        if last_cleanup.elapsed() >= std::time::Duration::from_secs(1) {
477            conn_tracker.cleanup_closed(&mut sockets);
478            port_publisher.cleanup_closed(&mut sockets);
479            udp_relay.cleanup_expired();
480            shared.cleanup_resolved_hostnames();
481            last_cleanup = std::time::Instant::now();
482        }
483
484        // ── Phase 4: Flush relay data + sleep ────────────────────────────
485        // Single egress pass flushes all data written by Phase 3.
486        loop {
487            let result = iface.poll_egress(now, &mut device, &mut sockets);
488            if matches!(result, smoltcp::iface::PollResult::None) {
489                break;
490            }
491        }
492
493        // Coalesced wake: if Phase 3/4 emitted any frames, wake once.
494        if device.frames_emitted.swap(false, Ordering::Relaxed) {
495            shared.rx_wake.wake();
496        }
497
498        let timeout_ms = iface
499            .poll_delay(now, &sockets)
500            .map(|d| d.total_millis().min(i32::MAX as u64) as i32)
501            .unwrap_or(100); // 100ms fallback when no timers pending.
502
503        // SAFETY: poll_fds is a valid array of pollfd structs with valid fds.
504        unsafe {
505            libc::poll(
506                poll_fds.as_mut_ptr(),
507                poll_fds.len() as libc::nfds_t,
508                timeout_ms,
509            );
510        }
511
512        // Conditional drain: only drain pipes that actually have data.
513        if poll_fds[0].revents & libc::POLLIN != 0 {
514            shared.tx_wake.drain();
515        }
516        if poll_fds[1].revents & libc::POLLIN != 0 {
517            shared.proxy_wake.drain();
518        }
519    }
520}
521
522//--------------------------------------------------------------------------------------------------
523// Functions: Helpers
524//--------------------------------------------------------------------------------------------------
525
526/// Map a guest-wire destination to its host-socket equivalent.
527///
528/// Gateway IPs rewrite to loopback (`127.0.0.1` / `::1`); everything else
529/// passes through. Shared by the TCP proxy dispatch and the UDP relay.
530///
531/// # Arguments
532///
533/// * `dst` - Destination from the guest's packet.
534/// * `gateway` - Per-sandbox gateway IPs that trigger the loopback rewrite.
535pub(crate) fn resolve_host_dst(dst: SocketAddr, gateway: GatewayIps) -> SocketAddr {
536    match dst.ip() {
537        IpAddr::V4(v4) if v4 == gateway.ipv4 => {
538            SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), dst.port())
539        }
540        IpAddr::V6(v6) if v6 == gateway.ipv6 => {
541            SocketAddr::new(IpAddr::V6(Ipv6Addr::LOCALHOST), dst.port())
542        }
543        _ => dst,
544    }
545}
546
547/// Get the current time as a smoltcp [`Instant`] using a monotonic clock.
548///
549/// Uses `std::time::Instant` (monotonic) instead of `SystemTime` (wall
550/// clock) to avoid issues with NTP clock step corrections that could
551/// cause smoltcp timers to misbehave.
552fn smoltcp_now() -> Instant {
553    static EPOCH: std::sync::OnceLock<std::time::Instant> = std::sync::OnceLock::new();
554    let epoch = EPOCH.get_or_init(std::time::Instant::now);
555    let elapsed = epoch.elapsed();
556    Instant::from_millis(elapsed.as_millis() as i64)
557}
558
559/// Reply locally to ICMP echo requests aimed at the sandbox gateway.
560///
561/// `any_ip` is required so smoltcp accepts guest traffic for arbitrary remote
562/// destinations, but that would make smoltcp's automatic ICMP echo replies
563/// spoof remote hosts. Handle only the real gateway IPs here and leave all
564/// other ICMP traffic untouched.
565fn handle_gateway_icmp_echo(frame: &[u8], config: &PollLoopConfig, shared: &SharedState) -> bool {
566    let Ok(eth) = EthernetFrame::new_checked(frame) else {
567        return false;
568    };
569
570    let reply = match eth.ethertype() {
571        EthernetProtocol::Ipv4 => gateway_icmpv4_echo_reply(&eth, config),
572        EthernetProtocol::Ipv6 => gateway_icmpv6_echo_reply(&eth, config),
573        _ => None,
574    };
575    let Some(reply) = reply else {
576        return false;
577    };
578
579    let reply_len = reply.len();
580    if shared.rx_ring.push(reply).is_ok() {
581        shared.add_rx_bytes(reply_len);
582        shared.rx_wake.wake();
583    }
584
585    true
586}
587
588/// Build an IPv4 ICMP echo reply when the guest pings the gateway IPv4.
589fn gateway_icmpv4_echo_reply(
590    eth: &EthernetFrame<&[u8]>,
591    config: &PollLoopConfig,
592) -> Option<Vec<u8>> {
593    let ipv4 = Ipv4Packet::new_checked(eth.payload()).ok()?;
594    if ipv4.dst_addr() != config.gateway.ipv4 || ipv4.next_header() != IpProtocol::Icmp {
595        return None;
596    }
597
598    let icmp = Icmpv4Packet::new_checked(ipv4.payload()).ok()?;
599    let Icmpv4Repr::EchoRequest {
600        ident,
601        seq_no,
602        data,
603    } = Icmpv4Repr::parse(&icmp, &smoltcp::phy::ChecksumCapabilities::default()).ok()?
604    else {
605        return None;
606    };
607
608    let ipv4_repr = Ipv4Repr {
609        src_addr: config.gateway.ipv4,
610        dst_addr: ipv4.src_addr(),
611        next_header: IpProtocol::Icmp,
612        payload_len: 8 + data.len(),
613        hop_limit: 64,
614    };
615    let icmp_repr = Icmpv4Repr::EchoReply {
616        ident,
617        seq_no,
618        data,
619    };
620    let mut reply = vec![0u8; 14 + ipv4_repr.buffer_len() + icmp_repr.buffer_len()];
621
622    let mut reply_eth = EthernetFrame::new_unchecked(&mut reply);
623    reply_eth.set_src_addr(EthernetAddress(config.gateway_mac));
624    reply_eth.set_dst_addr(eth.src_addr());
625    reply_eth.set_ethertype(EthernetProtocol::Ipv4);
626
627    ipv4_repr.emit(
628        &mut Ipv4Packet::new_unchecked(&mut reply[14..34]),
629        &smoltcp::phy::ChecksumCapabilities::default(),
630    );
631    icmp_repr.emit(
632        &mut Icmpv4Packet::new_unchecked(&mut reply[34..]),
633        &smoltcp::phy::ChecksumCapabilities::default(),
634    );
635
636    Some(reply)
637}
638
639/// Build an IPv6 ICMP echo reply when the guest pings the gateway IPv6.
640fn gateway_icmpv6_echo_reply(
641    eth: &EthernetFrame<&[u8]>,
642    config: &PollLoopConfig,
643) -> Option<Vec<u8>> {
644    let ipv6 = Ipv6Packet::new_checked(eth.payload()).ok()?;
645    if ipv6.dst_addr() != config.gateway.ipv6 || ipv6.next_header() != IpProtocol::Icmpv6 {
646        return None;
647    }
648
649    let icmp = Icmpv6Packet::new_checked(ipv6.payload()).ok()?;
650    let Icmpv6Repr::EchoRequest {
651        ident,
652        seq_no,
653        data,
654    } = Icmpv6Repr::parse(
655        &ipv6.src_addr(),
656        &ipv6.dst_addr(),
657        &icmp,
658        &smoltcp::phy::ChecksumCapabilities::default(),
659    )
660    .ok()?
661    else {
662        return None;
663    };
664
665    let ipv6_repr = Ipv6Repr {
666        src_addr: config.gateway.ipv6,
667        dst_addr: ipv6.src_addr(),
668        next_header: IpProtocol::Icmpv6,
669        payload_len: icmp_repr_buffer_len_v6(data),
670        hop_limit: 64,
671    };
672    let icmp_repr = Icmpv6Repr::EchoReply {
673        ident,
674        seq_no,
675        data,
676    };
677    let ipv6_hdr_len = 40;
678    let mut reply = vec![0u8; 14 + ipv6_hdr_len + icmp_repr.buffer_len()];
679
680    let mut reply_eth = EthernetFrame::new_unchecked(&mut reply);
681    reply_eth.set_src_addr(EthernetAddress(config.gateway_mac));
682    reply_eth.set_dst_addr(eth.src_addr());
683    reply_eth.set_ethertype(EthernetProtocol::Ipv6);
684
685    ipv6_repr.emit(&mut Ipv6Packet::new_unchecked(&mut reply[14..54]));
686    icmp_repr.emit(
687        &config.gateway.ipv6,
688        &ipv6.src_addr(),
689        &mut Icmpv6Packet::new_unchecked(&mut reply[54..]),
690        &smoltcp::phy::ChecksumCapabilities::default(),
691    );
692
693    Some(reply)
694}
695
696fn icmp_repr_buffer_len_v6(data: &[u8]) -> usize {
697    Icmpv6Repr::EchoReply {
698        ident: 0,
699        seq_no: 0,
700        data,
701    }
702    .buffer_len()
703}
704
705/// Classify an IPv4 packet payload (after stripping the Ethernet header).
706fn classify_ipv4(payload: &[u8]) -> FrameAction {
707    let Ok(ipv4) = Ipv4Packet::new_checked(payload) else {
708        return FrameAction::Passthrough;
709    };
710    classify_transport(
711        ipv4.next_header(),
712        ipv4.src_addr().into(),
713        ipv4.dst_addr().into(),
714        ipv4.payload(),
715    )
716}
717
718/// Classify an IPv6 packet payload (after stripping the Ethernet header).
719fn classify_ipv6(payload: &[u8]) -> FrameAction {
720    let Ok(ipv6) = Ipv6Packet::new_checked(payload) else {
721        return FrameAction::Passthrough;
722    };
723    classify_transport(
724        ipv6.next_header(),
725        ipv6.src_addr().into(),
726        ipv6.dst_addr().into(),
727        ipv6.payload(),
728    )
729}
730
731/// Classify the transport-layer protocol (shared by IPv4 and IPv6).
732fn classify_transport(
733    protocol: IpProtocol,
734    src_ip: std::net::IpAddr,
735    dst_ip: std::net::IpAddr,
736    transport_payload: &[u8],
737) -> FrameAction {
738    match protocol {
739        IpProtocol::Tcp => {
740            let Ok(tcp) = TcpPacket::new_checked(transport_payload) else {
741                return FrameAction::Passthrough;
742            };
743            if tcp.syn() && !tcp.ack() {
744                FrameAction::TcpSyn {
745                    src: SocketAddr::new(src_ip, tcp.src_port()),
746                    dst: SocketAddr::new(dst_ip, tcp.dst_port()),
747                }
748            } else {
749                FrameAction::Passthrough
750            }
751        }
752        IpProtocol::Udp => {
753            let Ok(udp) = UdpPacket::new_checked(transport_payload) else {
754                return FrameAction::Passthrough;
755            };
756            // The plain-DNS port (UDP/53) lives in dns::common::ports so
757            // the alternative-DNS refusal logic and this dispatcher
758            // share one source of truth for "which UDP ports are DNS".
759            if DnsPortType::from_udp(udp.dst_port()) == DnsPortType::Dns {
760                FrameAction::Dns
761            } else {
762                FrameAction::UdpRelay {
763                    src: SocketAddr::new(src_ip, udp.src_port()),
764                    dst: SocketAddr::new(dst_ip, udp.dst_port()),
765                }
766            }
767        }
768        _ => FrameAction::Passthrough, // ICMP, etc.
769    }
770}
771
772//--------------------------------------------------------------------------------------------------
773// Tests
774//--------------------------------------------------------------------------------------------------
775
776#[cfg(test)]
777mod tests {
778    use super::*;
779    use std::sync::Arc;
780
781    use smoltcp::phy::ChecksumCapabilities;
782    use smoltcp::wire::{
783        ArpOperation, ArpPacket, ArpRepr, EthernetRepr, Icmpv4Packet, Icmpv4Repr, Ipv4Repr,
784    };
785
786    use crate::device::SmoltcpDevice;
787    use crate::shared::SharedState;
788
789    /// Build a minimal Ethernet + IPv4 + TCP SYN frame.
790    fn build_tcp_syn_frame(
791        src_ip: [u8; 4],
792        dst_ip: [u8; 4],
793        src_port: u16,
794        dst_port: u16,
795    ) -> Vec<u8> {
796        let mut frame = vec![0u8; 14 + 20 + 20]; // eth + ipv4 + tcp
797
798        // Ethernet header.
799        frame[12] = 0x08; // EtherType: IPv4
800        frame[13] = 0x00;
801
802        // IPv4 header.
803        let ip = &mut frame[14..34];
804        ip[0] = 0x45; // Version + IHL
805        let total_len = 40u16; // 20 (IP) + 20 (TCP)
806        ip[2..4].copy_from_slice(&total_len.to_be_bytes());
807        ip[6] = 0x40; // Don't Fragment
808        ip[8] = 64; // TTL
809        ip[9] = 6; // Protocol: TCP
810        ip[12..16].copy_from_slice(&src_ip);
811        ip[16..20].copy_from_slice(&dst_ip);
812
813        // TCP header.
814        let tcp = &mut frame[34..54];
815        tcp[0..2].copy_from_slice(&src_port.to_be_bytes());
816        tcp[2..4].copy_from_slice(&dst_port.to_be_bytes());
817        tcp[12] = 0x50; // Data offset: 5 words
818        tcp[13] = 0x02; // SYN flag
819
820        frame
821    }
822
823    /// Build a minimal Ethernet + IPv4 + UDP frame.
824    fn build_udp_frame(src_ip: [u8; 4], dst_ip: [u8; 4], src_port: u16, dst_port: u16) -> Vec<u8> {
825        let mut frame = vec![0u8; 14 + 20 + 8]; // eth + ipv4 + udp
826
827        // Ethernet header.
828        frame[12] = 0x08;
829        frame[13] = 0x00;
830
831        // IPv4 header.
832        let ip = &mut frame[14..34];
833        ip[0] = 0x45;
834        let total_len = 28u16; // 20 (IP) + 8 (UDP)
835        ip[2..4].copy_from_slice(&total_len.to_be_bytes());
836        ip[8] = 64;
837        ip[9] = 17; // Protocol: UDP
838        ip[12..16].copy_from_slice(&src_ip);
839        ip[16..20].copy_from_slice(&dst_ip);
840
841        // UDP header.
842        let udp = &mut frame[34..42];
843        udp[0..2].copy_from_slice(&src_port.to_be_bytes());
844        udp[2..4].copy_from_slice(&dst_port.to_be_bytes());
845        let udp_len = 8u16;
846        udp[4..6].copy_from_slice(&udp_len.to_be_bytes());
847
848        frame
849    }
850
851    /// Build a minimal Ethernet + IPv4 + ICMP echo request frame.
852    fn build_icmpv4_echo_frame(
853        src_mac: [u8; 6],
854        dst_mac: [u8; 6],
855        src_ip: [u8; 4],
856        dst_ip: [u8; 4],
857        ident: u16,
858        seq_no: u16,
859        data: &[u8],
860    ) -> Vec<u8> {
861        let ipv4_repr = Ipv4Repr {
862            src_addr: Ipv4Addr::from(src_ip).into(),
863            dst_addr: Ipv4Addr::from(dst_ip).into(),
864            next_header: IpProtocol::Icmp,
865            payload_len: 8 + data.len(),
866            hop_limit: 64,
867        };
868        let icmp_repr = Icmpv4Repr::EchoRequest {
869            ident,
870            seq_no,
871            data,
872        };
873        let frame_len = 14 + ipv4_repr.buffer_len() + icmp_repr.buffer_len();
874        let mut frame = vec![0u8; frame_len];
875
876        let mut eth_frame = EthernetFrame::new_unchecked(&mut frame);
877        EthernetRepr {
878            src_addr: EthernetAddress(src_mac),
879            dst_addr: EthernetAddress(dst_mac),
880            ethertype: EthernetProtocol::Ipv4,
881        }
882        .emit(&mut eth_frame);
883
884        ipv4_repr.emit(
885            &mut Ipv4Packet::new_unchecked(&mut frame[14..34]),
886            &ChecksumCapabilities::default(),
887        );
888        icmp_repr.emit(
889            &mut Icmpv4Packet::new_unchecked(&mut frame[34..]),
890            &ChecksumCapabilities::default(),
891        );
892
893        frame
894    }
895
896    /// Build a minimal Ethernet + ARP request frame.
897    fn build_arp_request_frame(src_mac: [u8; 6], src_ip: [u8; 4], target_ip: [u8; 4]) -> Vec<u8> {
898        let mut frame = vec![0u8; 14 + 28];
899
900        let mut eth_frame = EthernetFrame::new_unchecked(&mut frame);
901        EthernetRepr {
902            src_addr: EthernetAddress(src_mac),
903            dst_addr: EthernetAddress([0xff; 6]),
904            ethertype: EthernetProtocol::Arp,
905        }
906        .emit(&mut eth_frame);
907
908        ArpRepr::EthernetIpv4 {
909            operation: ArpOperation::Request,
910            source_hardware_addr: EthernetAddress(src_mac),
911            source_protocol_addr: Ipv4Addr::from(src_ip).into(),
912            target_hardware_addr: EthernetAddress([0x00; 6]),
913            target_protocol_addr: Ipv4Addr::from(target_ip).into(),
914        }
915        .emit(&mut ArpPacket::new_unchecked(&mut frame[14..]));
916
917        frame
918    }
919
920    #[test]
921    fn classify_tcp_syn() {
922        let frame = build_tcp_syn_frame([10, 0, 0, 2], [93, 184, 216, 34], 54321, 443);
923        match classify_frame(&frame) {
924            FrameAction::TcpSyn { src, dst } => {
925                assert_eq!(
926                    src,
927                    SocketAddr::new(Ipv4Addr::new(10, 0, 0, 2).into(), 54321)
928                );
929                assert_eq!(
930                    dst,
931                    SocketAddr::new(Ipv4Addr::new(93, 184, 216, 34).into(), 443)
932                );
933            }
934            _ => panic!("expected TcpSyn"),
935        }
936    }
937
938    #[test]
939    fn classify_tcp_ack_is_passthrough() {
940        let mut frame = build_tcp_syn_frame([10, 0, 0, 2], [93, 184, 216, 34], 54321, 443);
941        // Change flags to ACK only (not SYN).
942        frame[34 + 13] = 0x10; // ACK flag
943        assert!(matches!(classify_frame(&frame), FrameAction::Passthrough));
944    }
945
946    #[test]
947    fn classify_udp_dns() {
948        let frame = build_udp_frame([10, 0, 0, 2], [10, 0, 0, 1], 12345, 53);
949        assert!(matches!(classify_frame(&frame), FrameAction::Dns));
950    }
951
952    #[test]
953    fn classify_udp_non_dns() {
954        let frame = build_udp_frame([10, 0, 0, 2], [8, 8, 8, 8], 12345, 443);
955        match classify_frame(&frame) {
956            FrameAction::UdpRelay { src, dst } => {
957                assert_eq!(src.port(), 12345);
958                assert_eq!(dst.port(), 443);
959            }
960            _ => panic!("expected UdpRelay"),
961        }
962    }
963
964    #[test]
965    fn classify_arp_is_passthrough() {
966        let mut frame = vec![0u8; 42]; // ARP frame
967        frame[12] = 0x08;
968        frame[13] = 0x06; // EtherType: ARP
969        assert!(matches!(classify_frame(&frame), FrameAction::Passthrough));
970    }
971
972    #[test]
973    fn classify_garbage_is_passthrough() {
974        assert!(matches!(classify_frame(&[]), FrameAction::Passthrough));
975        assert!(matches!(classify_frame(&[0; 5]), FrameAction::Passthrough));
976    }
977
978    #[test]
979    fn gateway_replies_to_icmp_echo_requests() {
980        fn drive_one_frame(
981            device: &mut SmoltcpDevice,
982            iface: &mut Interface,
983            sockets: &mut SocketSet<'_>,
984            shared: &Arc<SharedState>,
985            poll_config: &PollLoopConfig,
986            now: Instant,
987        ) {
988            let frame = device.stage_next_frame().expect("expected staged frame");
989            if handle_gateway_icmp_echo(frame, poll_config, shared) {
990                device.drop_staged_frame();
991                return;
992            }
993            let _ = iface.poll_ingress_single(now, device, sockets);
994            let _ = iface.poll_egress(now, device, sockets);
995        }
996
997        let shared = Arc::new(SharedState::new(4));
998        let poll_config = PollLoopConfig {
999            gateway_mac: [0x02, 0x00, 0x00, 0x00, 0x00, 0x01],
1000            guest_mac: [0x02, 0x00, 0x00, 0x00, 0x00, 0x02],
1001            gateway: GatewayIps {
1002                ipv4: Ipv4Addr::new(100, 96, 0, 1),
1003                ipv6: Ipv6Addr::LOCALHOST,
1004            },
1005            guest_ipv4: Ipv4Addr::new(100, 96, 0, 2),
1006            mtu: 1500,
1007        };
1008        let mut device = SmoltcpDevice::new(shared.clone(), poll_config.mtu);
1009        let mut iface = create_interface(&mut device, &poll_config);
1010        let mut sockets = SocketSet::new(vec![]);
1011        let now = smoltcp_now();
1012
1013        // Mirror the real guest flow: resolve the gateway MAC before sending
1014        // the ICMP echo request.
1015        shared
1016            .tx_ring
1017            .push(build_arp_request_frame(
1018                poll_config.guest_mac,
1019                poll_config.guest_ipv4.octets(),
1020                poll_config.gateway.ipv4.octets(),
1021            ))
1022            .unwrap();
1023        shared
1024            .tx_ring
1025            .push(build_icmpv4_echo_frame(
1026                poll_config.guest_mac,
1027                poll_config.gateway_mac,
1028                poll_config.guest_ipv4.octets(),
1029                poll_config.gateway.ipv4.octets(),
1030                0x1234,
1031                0xABCD,
1032                b"ping",
1033            ))
1034            .unwrap();
1035
1036        drive_one_frame(
1037            &mut device,
1038            &mut iface,
1039            &mut sockets,
1040            &shared,
1041            &poll_config,
1042            now,
1043        );
1044        let _ = shared.rx_ring.pop().expect("expected ARP reply");
1045
1046        drive_one_frame(
1047            &mut device,
1048            &mut iface,
1049            &mut sockets,
1050            &shared,
1051            &poll_config,
1052            now,
1053        );
1054
1055        let reply = shared.rx_ring.pop().expect("expected ICMP echo reply");
1056        let eth = EthernetFrame::new_checked(&reply).expect("valid ethernet frame");
1057        assert_eq!(eth.src_addr(), EthernetAddress(poll_config.gateway_mac));
1058        assert_eq!(eth.dst_addr(), EthernetAddress(poll_config.guest_mac));
1059        assert_eq!(eth.ethertype(), EthernetProtocol::Ipv4);
1060
1061        let ipv4 = Ipv4Packet::new_checked(eth.payload()).expect("valid IPv4 packet");
1062        assert_eq!(Ipv4Addr::from(ipv4.src_addr()), poll_config.gateway.ipv4);
1063        assert_eq!(Ipv4Addr::from(ipv4.dst_addr()), poll_config.guest_ipv4);
1064        assert_eq!(ipv4.next_header(), IpProtocol::Icmp);
1065
1066        let icmp = Icmpv4Packet::new_checked(ipv4.payload()).expect("valid ICMP packet");
1067        let icmp_repr = Icmpv4Repr::parse(&icmp, &ChecksumCapabilities::default())
1068            .expect("valid ICMP echo reply");
1069        assert_eq!(
1070            icmp_repr,
1071            Icmpv4Repr::EchoReply {
1072                ident: 0x1234,
1073                seq_no: 0xABCD,
1074                data: b"ping",
1075            }
1076        );
1077    }
1078
1079    fn test_gateway() -> GatewayIps {
1080        GatewayIps {
1081            ipv4: Ipv4Addr::new(100, 96, 0, 1),
1082            ipv6: "fd42:6d73:62::1".parse().unwrap(),
1083        }
1084    }
1085
1086    #[test]
1087    fn resolve_host_dst_matches_ipv4() {
1088        let gw = test_gateway();
1089        let dst = SocketAddr::new(IpAddr::V4(gw.ipv4), 8080);
1090        assert_eq!(
1091            resolve_host_dst(dst, gw),
1092            SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 8080)
1093        );
1094    }
1095
1096    #[test]
1097    fn resolve_host_dst_matches_ipv6() {
1098        let gw = test_gateway();
1099        let dst = SocketAddr::new(IpAddr::V6(gw.ipv6), 8080);
1100        assert_eq!(
1101            resolve_host_dst(dst, gw),
1102            SocketAddr::new(IpAddr::V6(Ipv6Addr::LOCALHOST), 8080)
1103        );
1104    }
1105
1106    #[test]
1107    fn resolve_host_dst_passes_through_non_gateway() {
1108        let gw = test_gateway();
1109        let dst = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)), 443);
1110        assert_eq!(resolve_host_dst(dst, gw), dst);
1111    }
1112
1113    #[test]
1114    fn external_icmp_echo_requests_are_not_answered_locally() {
1115        fn drive_one_frame(
1116            device: &mut SmoltcpDevice,
1117            iface: &mut Interface,
1118            sockets: &mut SocketSet<'_>,
1119            shared: &Arc<SharedState>,
1120            poll_config: &PollLoopConfig,
1121            now: Instant,
1122        ) {
1123            let frame = device.stage_next_frame().expect("expected staged frame");
1124            if handle_gateway_icmp_echo(frame, poll_config, shared) {
1125                device.drop_staged_frame();
1126                return;
1127            }
1128            let _ = iface.poll_ingress_single(now, device, sockets);
1129            let _ = iface.poll_egress(now, device, sockets);
1130        }
1131
1132        let shared = Arc::new(SharedState::new(4));
1133        let poll_config = PollLoopConfig {
1134            gateway_mac: [0x02, 0x00, 0x00, 0x00, 0x00, 0x01],
1135            guest_mac: [0x02, 0x00, 0x00, 0x00, 0x00, 0x02],
1136            gateway: GatewayIps {
1137                ipv4: Ipv4Addr::new(100, 96, 0, 1),
1138                ipv6: Ipv6Addr::LOCALHOST,
1139            },
1140            guest_ipv4: Ipv4Addr::new(100, 96, 0, 2),
1141            mtu: 1500,
1142        };
1143        let mut device = SmoltcpDevice::new(shared.clone(), poll_config.mtu);
1144        let mut iface = create_interface(&mut device, &poll_config);
1145        let mut sockets = SocketSet::new(vec![]);
1146        let now = smoltcp_now();
1147
1148        shared
1149            .tx_ring
1150            .push(build_arp_request_frame(
1151                poll_config.guest_mac,
1152                poll_config.guest_ipv4.octets(),
1153                poll_config.gateway.ipv4.octets(),
1154            ))
1155            .unwrap();
1156        shared
1157            .tx_ring
1158            .push(build_icmpv4_echo_frame(
1159                poll_config.guest_mac,
1160                poll_config.gateway_mac,
1161                poll_config.guest_ipv4.octets(),
1162                [142, 251, 216, 46],
1163                0x1234,
1164                0xABCD,
1165                b"ping",
1166            ))
1167            .unwrap();
1168
1169        drive_one_frame(
1170            &mut device,
1171            &mut iface,
1172            &mut sockets,
1173            &shared,
1174            &poll_config,
1175            now,
1176        );
1177        let _ = shared.rx_ring.pop().expect("expected ARP reply");
1178
1179        drive_one_frame(
1180            &mut device,
1181            &mut iface,
1182            &mut sockets,
1183            &shared,
1184            &poll_config,
1185            now,
1186        );
1187        assert!(
1188            shared.rx_ring.pop().is_none(),
1189            "external ICMP should not be answered locally"
1190        );
1191    }
1192}