Skip to main content

microsandbox_network/
publisher.rs

1//! Published port handling: host-side listeners that forward connections
2//! into the guest VM via smoltcp.
3//!
4//! For each configured [`PublishedPort`], a tokio TCP listener or UDP socket
5//! binds on the host. TCP connections are queued for the poll loop to create
6//! smoltcp sockets into the guest. UDP datagrams are injected as guest-visible
7//! packets, and guest replies to active peers are sent back through the same
8//! host socket.
9
10use std::collections::HashMap;
11use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
12use std::sync::Arc;
13use std::sync::atomic::{AtomicU16, Ordering};
14use std::time::{Duration, Instant};
15
16use bytes::Bytes;
17use parking_lot::Mutex;
18use smoltcp::iface::{Interface, SocketHandle, SocketSet};
19use smoltcp::socket::tcp;
20use smoltcp::wire::{EthernetAddress, IpEndpoint};
21use tokio::io::{AsyncReadExt, AsyncWriteExt};
22use tokio::net::{TcpListener, TcpStream, UdpSocket};
23use tokio::sync::mpsc;
24
25use crate::config::{PortProtocol, PublishedPort};
26use crate::policy::{NetworkPolicy, Protocol};
27use crate::shared::SharedState;
28use crate::udp_relay::{construct_udp_response, extract_udp_payload};
29
30//--------------------------------------------------------------------------------------------------
31// Helpers
32//--------------------------------------------------------------------------------------------------
33
34/// Set zero-linger on a stream so the kernel sends a TCP RST instead of
35/// the default FIN close when the stream drops. Used for deliberate
36/// rejection paths (policy deny, max-inbound exhaustion,
37/// smoltcp-connect failure) so the peer sees `ECONNRESET` rather than
38/// a graceful close that looks like the server simply went away.
39///
40/// Goes through `socket2` rather than tokio's deprecated
41/// `TcpStream::set_linger` so the call site doesn't trip
42/// `#[deny(deprecated)]` in clippy. The cast to `SockRef` is
43/// zero-cost — it borrows the underlying fd.
44fn reject_with_rst(stream: &TcpStream) {
45    let _ = socket2::SockRef::from(stream).set_linger(Some(Duration::ZERO));
46}
47
48//--------------------------------------------------------------------------------------------------
49// Constants
50//--------------------------------------------------------------------------------------------------
51
/// Receive-buffer size of each smoltcp socket created for an inbound TCP connection.
const TCP_RX_BUF_SIZE: usize = 65536;
/// Transmit-buffer size of each smoltcp socket created for an inbound TCP connection.
const TCP_TX_BUF_SIZE: usize = 65536;

/// Capacity of the per-connection channels between the poll loop and relay tasks.
const CHANNEL_CAPACITY: usize = 32;

/// Size of the scratch buffer used when reading from host TCP sockets.
const RELAY_BUF_SIZE: usize = 16384;

/// Receive-buffer size for host-side UDP published-port sockets.
const UDP_RELAY_BUF_SIZE: usize = 65535;

/// How long a UDP peer of a published port may stay silent before its mapping expires.
const UDP_PEER_TIMEOUT: Duration = Duration::from_secs(60);

/// Lowest ephemeral source port used to represent host UDP peers inside the guest.
const UDP_EPHEMERAL_PORT_START: u16 = 49152;

/// How many ports lie in `UDP_EPHEMERAL_PORT_START..=u16::MAX`.
const UDP_EPHEMERAL_PORT_COUNT: usize = (u16::MAX - UDP_EPHEMERAL_PORT_START) as usize + 1;
74
75//--------------------------------------------------------------------------------------------------
76// Types
77//--------------------------------------------------------------------------------------------------
78
/// Manages published port listeners and inbound connections.
///
/// Spawns a tokio listener task for each published port. Accepted TCP
/// connections are queued for the poll loop, which creates smoltcp sockets
/// and initiates connections to the guest. UDP published ports are recorded
/// in `udp_routes` so guest replies can be relayed back to host peers.
pub struct PortPublisher {
    /// Receives accepted connections from listener tasks.
    inbound_rx: mpsc::Receiver<InboundConnection>,
    /// Held to keep the channel open (listener tasks hold clones).
    _inbound_tx: mpsc::Sender<InboundConnection>,
    /// Tracked inbound connections (smoltcp socket → relay state).
    connections: Vec<InboundRelay>,
    /// Guest IP that inbound connections are dialed to. Prefers IPv4 (the
    /// common case — most services bind `0.0.0.0` or dual-stack `::`, both
    /// of which accept v4) and falls back to IPv6 for v6-only sandboxes.
    /// `None` when neither family is active; listeners are not spawned.
    guest_ip: Option<IpAddr>,
    /// Guest IPv4, when active.
    guest_ipv4: Option<Ipv4Addr>,
    /// Guest IPv6, when active.
    guest_ipv6: Option<Ipv6Addr>,
    /// Counter from which local ports are drawn, both for smoltcp sockets
    /// dialing the guest and for the guest-side addresses of UDP peers.
    /// Shared (via `Arc`) with the UDP listener tasks.
    ephemeral_port: Arc<AtomicU16>,
    /// Maximum inbound connections (prevents resource exhaustion from
    /// host-side floods); connections beyond it are rejected with RST.
    max_inbound: usize,
    /// UDP published-port routes, keyed by guest-side port.
    udp_routes: PublishedUdpRoutes,
}
107
108/// An accepted host-side connection waiting to be wired to the guest.
109struct InboundConnection {
110    /// The accepted host-side TCP stream.
111    stream: TcpStream,
112    /// Guest port to connect to.
113    guest_port: u16,
114}
115
116/// Shared UDP published-port route table.
117type PublishedUdpRoutes = Arc<Mutex<HashMap<u16, Vec<PublishedUdpRoute>>>>;
118
119/// A host UDP socket that can send replies for active peers.
120struct PublishedUdpRoute {
121    /// Host bind address for diagnostics.
122    bind_addr: SocketAddr,
123    /// Send guest reply payloads to the UDP listener task.
124    outbound_tx: mpsc::Sender<PublishedUdpOutbound>,
125    /// NAT mappings for peers that recently sent datagrams to this published port.
126    peers: Arc<Mutex<PublishedUdpPeers>>,
127}
128
129/// Guest response payload for a host peer.
130struct PublishedUdpOutbound {
131    peer: SocketAddr,
132    payload: Bytes,
133}
134
/// Bidirectional NAT table for the active peers of one published UDP route.
#[derive(Default)]
struct PublishedUdpPeers {
    /// Reverse lookup: guest-side address → host peer.
    guest_to_host: HashMap<SocketAddr, SocketAddr>,
    /// Forward lookup: host peer → its guest-side identity and activity time.
    host_to_guest: HashMap<SocketAddr, PublishedUdpPeer>,
}

/// How a single host peer appears on the guest-side virtual network.
struct PublishedUdpPeer {
    /// When the host peer was last seen sending to the published port.
    last_seen: Instant,
    /// Source address the guest sees for this peer.
    guest_addr: SocketAddr,
}
147
148/// Maximum number of poll iterations to attempt flushing remaining data
149/// after the relay task has exited before force-aborting the socket.
150const DEFERRED_CLOSE_LIMIT: u16 = 64;
151
152/// A single inbound connection relay (host socket ↔ smoltcp socket).
153struct InboundRelay {
154    handle: SocketHandle,
155    /// Send data from smoltcp socket to host relay task.
156    to_host: mpsc::Sender<Bytes>,
157    /// Receive data from host relay task to write to smoltcp socket.
158    from_host: mpsc::Receiver<Bytes>,
159    /// Partial data that couldn't be fully written to smoltcp socket.
160    write_buf: Option<(Bytes, usize)>,
161    /// Counter for deferred close attempts (prevents stalling forever).
162    close_attempts: u16,
163}
164
/// How widely a host listener is reachable, judged from its bind address.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
enum BindExposure {
    /// Bound to an unspecified address: every host interface of that family.
    Wildcard,
    /// Bound to one specific, non-loopback host interface address.
    Interface,
    /// Bound to host loopback only.
    Loopback,
}
174
175//--------------------------------------------------------------------------------------------------
176// Methods
177//--------------------------------------------------------------------------------------------------
178
179impl PortPublisher {
180    /// Create a new publisher and spawn listeners for all published ports.
181    ///
182    /// Listeners are only spawned when at least one of `guest_ipv4` /
183    /// `guest_ipv6` is `Some`; published ports need a smoltcp dial target.
184    /// Each TCP listener task gates accepted connections through the
185    /// supplied [`NetworkPolicy`]'s `evaluate_ingress` before queuing
186    /// them; rejected connections drop with TCP RST (zero-linger) so
187    /// the peer observes `ECONNRESET`.
188    #[allow(clippy::too_many_arguments)]
189    pub fn new(
190        ports: &[PublishedPort],
191        guest_ipv4: Option<Ipv4Addr>,
192        guest_ipv6: Option<Ipv6Addr>,
193        gateway_ipv4: Option<Ipv4Addr>,
194        gateway_ipv6: Option<Ipv6Addr>,
195        gateway_mac: [u8; 6],
196        guest_mac: [u8; 6],
197        policy: Arc<NetworkPolicy>,
198        shared: Arc<SharedState>,
199        tokio_handle: &tokio::runtime::Handle,
200    ) -> Self {
201        let (inbound_tx, inbound_rx) = mpsc::channel(64);
202        let udp_routes = Arc::new(Mutex::new(HashMap::new()));
203        let ephemeral_port = Arc::new(AtomicU16::new(49152));
204
205        let guest_ip = guest_ipv4
206            .map(IpAddr::V4)
207            .or_else(|| guest_ipv6.map(IpAddr::V6));
208
209        if guest_ip.is_some() {
210            Self::spawn_listeners(
211                ports,
212                &inbound_tx,
213                udp_routes.clone(),
214                guest_ipv4,
215                guest_ipv6,
216                gateway_ipv4,
217                gateway_ipv6,
218                ephemeral_port.clone(),
219                gateway_mac,
220                guest_mac,
221                policy,
222                shared,
223                tokio_handle,
224            );
225        } else if !ports.is_empty() {
226            tracing::warn!(
227                count = ports.len(),
228                "skipping published port listeners: guest has no IPv4 or IPv6 address",
229            );
230        }
231
232        Self {
233            inbound_rx,
234            _inbound_tx: inbound_tx,
235            connections: Vec::new(),
236            guest_ip,
237            guest_ipv4,
238            guest_ipv6,
239            ephemeral_port,
240            max_inbound: 256,
241            udp_routes,
242        }
243    }
244
245    /// Accept queued inbound connections: create smoltcp sockets and
246    /// initiate connections to the guest.
247    ///
248    /// Must be called each poll iteration.
249    pub fn accept_inbound(
250        &mut self,
251        iface: &mut Interface,
252        sockets: &mut SocketSet<'_>,
253        shared: &Arc<SharedState>,
254        tokio_handle: &tokio::runtime::Handle,
255    ) {
256        // No guest IP means listeners weren't spawned; the channel is empty
257        // and there's nothing to do.
258        let Some(guest_ip) = self.guest_ip else {
259            return;
260        };
261
262        while let Ok(conn) = self.inbound_rx.try_recv() {
263            if self.connections.len() >= self.max_inbound {
264                tracing::debug!("published port: max inbound connections reached, rejecting");
265                reject_with_rst(&conn.stream);
266                continue;
267            }
268            // Create smoltcp TCP socket.
269            let rx_buf = tcp::SocketBuffer::new(vec![0u8; TCP_RX_BUF_SIZE]);
270            let tx_buf = tcp::SocketBuffer::new(vec![0u8; TCP_TX_BUF_SIZE]);
271            let mut socket = tcp::Socket::new(rx_buf, tx_buf);
272
273            // Connect to the guest.
274            let remote = IpEndpoint::new(guest_ip.into(), conn.guest_port);
275            let local_port = self.alloc_ephemeral_port();
276
277            if socket.connect(iface.context(), remote, local_port).is_err() {
278                tracing::debug!(
279                    guest_port = conn.guest_port,
280                    "failed to connect smoltcp socket to guest",
281                );
282                reject_with_rst(&conn.stream);
283                continue;
284            }
285
286            let handle = sockets.add(socket);
287
288            // Create channel pair for relay.
289            let (to_host_tx, to_host_rx) = mpsc::channel(CHANNEL_CAPACITY);
290            let (from_host_tx, from_host_rx) = mpsc::channel(CHANNEL_CAPACITY);
291
292            // Spawn relay task: host TcpStream ↔ channels.
293            let shared_clone = shared.clone();
294            tokio_handle.spawn(async move {
295                let _ =
296                    inbound_relay_task(conn.stream, to_host_rx, from_host_tx, shared_clone).await;
297            });
298
299            self.connections.push(InboundRelay {
300                handle,
301                to_host: to_host_tx,
302                from_host: from_host_rx,
303                write_buf: None,
304                close_attempts: 0,
305            });
306        }
307    }
308
309    /// Relay data between smoltcp sockets and host relay tasks.
310    pub fn relay_data(&mut self, sockets: &mut SocketSet<'_>) {
311        let mut relay_buf = [0u8; RELAY_BUF_SIZE];
312
313        for relay in &mut self.connections {
314            let socket = sockets.get_mut::<tcp::Socket>(relay.handle);
315
316            // Detect relay task exit — close the smoltcp socket.
317            if relay.to_host.is_closed() {
318                write_host_data(socket, relay);
319                if relay.write_buf.is_none() {
320                    socket.close();
321                } else {
322                    // Abort if we've been trying to flush for too long
323                    // (guest stopped reading, socket send buffer full).
324                    relay.close_attempts += 1;
325                    if relay.close_attempts >= DEFERRED_CLOSE_LIMIT {
326                        socket.abort();
327                    }
328                }
329                continue;
330            }
331
332            // smoltcp → host: read from socket, send via channel.
333            while socket.can_recv() {
334                match socket.recv_slice(&mut relay_buf) {
335                    Ok(n) if n > 0 => {
336                        let data = Bytes::copy_from_slice(&relay_buf[..n]);
337                        if relay.to_host.try_send(data).is_err() {
338                            break;
339                        }
340                    }
341                    _ => break,
342                }
343            }
344
345            // host → smoltcp: write pending data, then drain channel.
346            write_host_data(socket, relay);
347        }
348    }
349
350    /// Relay a guest UDP datagram to a host peer that recently sent traffic
351    /// to a UDP published port.
352    ///
353    /// Returns `true` when the frame belongs to a published-port flow and
354    /// should be consumed by the caller.
355    pub fn relay_udp_outbound(&self, frame: &[u8], src: SocketAddr, dst: SocketAddr) -> bool {
356        if !self.is_guest_ip(src.ip()) {
357            return false;
358        }
359
360        let Some(payload) = extract_udp_payload(frame) else {
361            return false;
362        };
363
364        let routes = self.udp_routes.lock();
365        let Some(routes) = routes.get(&src.port()) else {
366            return false;
367        };
368
369        let now = Instant::now();
370        for route in routes {
371            let mut peers = route.peers.lock();
372            cleanup_udp_peer_mappings(&mut peers, now);
373            let Some(peer) = peers.guest_to_host.get(&dst).copied() else {
374                continue;
375            };
376            drop(peers);
377
378            let outbound = PublishedUdpOutbound {
379                peer,
380                payload: Bytes::copy_from_slice(payload),
381            };
382            if route.outbound_tx.try_send(outbound).is_err() {
383                tracing::debug!(
384                    bind = %route.bind_addr,
385                    peer = %peer,
386                    "published UDP reply dropped because outbound queue is unavailable",
387                );
388            }
389            return true;
390        }
391
392        false
393    }
394
395    /// Remove closed inbound connections.
396    ///
397    /// Only removes sockets in `Closed` state. Sockets in `TimeWait` are
398    /// left for smoltcp's 2*MSL timer to handle naturally.
399    pub fn cleanup_closed(&mut self, sockets: &mut SocketSet<'_>) {
400        self.connections.retain(|relay| {
401            let socket = sockets.get::<tcp::Socket>(relay.handle);
402            let closed = matches!(socket.state(), tcp::State::Closed);
403            if closed {
404                sockets.remove(relay.handle);
405            }
406            !closed
407        });
408        self.cleanup_udp_peers();
409    }
410
411    /// Spawn one tokio listener task per TCP published port.
412    #[allow(clippy::too_many_arguments)]
413    fn spawn_listeners(
414        ports: &[PublishedPort],
415        inbound_tx: &mpsc::Sender<InboundConnection>,
416        udp_routes: PublishedUdpRoutes,
417        guest_ipv4: Option<Ipv4Addr>,
418        guest_ipv6: Option<Ipv6Addr>,
419        gateway_ipv4: Option<Ipv4Addr>,
420        gateway_ipv6: Option<Ipv6Addr>,
421        ephemeral_port: Arc<AtomicU16>,
422        gateway_mac: [u8; 6],
423        guest_mac: [u8; 6],
424        policy: Arc<NetworkPolicy>,
425        shared: Arc<SharedState>,
426        tokio_handle: &tokio::runtime::Handle,
427    ) {
428        for port in ports {
429            let bind_addr = SocketAddr::new(port.host_bind, port.host_port);
430            let guest_port = port.guest_port;
431
432            match port.protocol {
433                PortProtocol::Tcp => {
434                    let tx = inbound_tx.clone();
435                    let policy = policy.clone();
436                    let shared = shared.clone();
437                    tokio_handle.spawn(async move {
438                        if let Err(e) =
439                            tcp_listener_task(bind_addr, guest_port, tx, policy, shared).await
440                        {
441                            tracing::error!(
442                                bind = %bind_addr,
443                                error = %e,
444                                "published TCP port listener failed",
445                            );
446                        }
447                    });
448                }
449                PortProtocol::Udp => {
450                    let Some((guest_ip, gateway_ip)) = udp_ips_for_bind(
451                        port.host_bind,
452                        guest_ipv4,
453                        guest_ipv6,
454                        gateway_ipv4,
455                        gateway_ipv6,
456                    ) else {
457                        tracing::warn!(
458                            bind = %bind_addr,
459                            guest_port,
460                            "skipping UDP published port: guest has no matching gateway/guest IP family",
461                        );
462                        continue;
463                    };
464
465                    let (outbound_tx, outbound_rx) = mpsc::channel(CHANNEL_CAPACITY);
466                    let peers = Arc::new(Mutex::new(PublishedUdpPeers::default()));
467                    udp_routes
468                        .lock()
469                        .entry(guest_port)
470                        .or_default()
471                        .push(PublishedUdpRoute {
472                            bind_addr,
473                            outbound_tx,
474                            peers: peers.clone(),
475                        });
476
477                    let policy = policy.clone();
478                    let shared = shared.clone();
479                    let ephemeral_port = ephemeral_port.clone();
480                    tokio_handle.spawn(async move {
481                        if let Err(e) = udp_listener_task(
482                            bind_addr,
483                            guest_ip,
484                            gateway_ip,
485                            guest_port,
486                            outbound_rx,
487                            peers,
488                            ephemeral_port.clone(),
489                            policy,
490                            shared,
491                            EthernetAddress(gateway_mac),
492                            EthernetAddress(guest_mac),
493                        )
494                        .await
495                        {
496                            tracing::error!(
497                                bind = %bind_addr,
498                                error = %e,
499                                "published UDP port listener failed",
500                            );
501                        }
502                    });
503                }
504            }
505        }
506    }
507
508    fn alloc_ephemeral_port(&self) -> u16 {
509        loop {
510            let port = self.ephemeral_port.fetch_add(1, Ordering::Relaxed);
511            // Wrap around in the ephemeral range.
512            if port == 0 || port < UDP_EPHEMERAL_PORT_START {
513                self.ephemeral_port
514                    .store(UDP_EPHEMERAL_PORT_START, Ordering::Relaxed);
515                continue;
516            }
517            return port;
518        }
519    }
520
521    fn cleanup_udp_peers(&self) {
522        let now = Instant::now();
523        for routes in self.udp_routes.lock().values() {
524            for route in routes {
525                cleanup_udp_peer_mappings(&mut route.peers.lock(), now);
526            }
527        }
528    }
529
530    fn is_guest_ip(&self, ip: IpAddr) -> bool {
531        match ip {
532            IpAddr::V4(ip) => self.guest_ipv4 == Some(ip),
533            IpAddr::V6(ip) => self.guest_ipv6 == Some(ip),
534        }
535    }
536}
537
538//--------------------------------------------------------------------------------------------------
539// Functions
540//--------------------------------------------------------------------------------------------------
541
542/// Listener task: accepts TCP connections on the host, runs each
543/// through the network policy's ingress evaluator, and queues
544/// allowed connections for the publisher's accept loop. Denied
545/// connections are dropped with TCP RST (zero-linger) so the peer
546/// sees `ECONNRESET` rather than a graceful close.
547async fn tcp_listener_task(
548    bind_addr: SocketAddr,
549    guest_port: u16,
550    inbound_tx: mpsc::Sender<InboundConnection>,
551    policy: Arc<NetworkPolicy>,
552    shared: Arc<SharedState>,
553) -> std::io::Result<()> {
554    let listener = TcpListener::bind(bind_addr).await?;
555    log_published_port_listener("TCP", bind_addr, guest_port);
556
557    loop {
558        let (stream, peer) = listener.accept().await?;
559
560        // Policy gate: peer source IP and the guest's listening port.
561        let action = policy.evaluate_ingress(peer, guest_port, Protocol::Tcp, &shared);
562        if action.is_deny() {
563            tracing::debug!(
564                peer = %peer,
565                guest_port,
566                "ingress denied by policy; sending RST",
567            );
568            reject_with_rst(&stream);
569            drop(stream);
570            continue;
571        }
572
573        let conn = InboundConnection { stream, guest_port };
574        if !queue_inbound_connection(&inbound_tx, conn, &shared).await {
575            break; // Publisher dropped.
576        }
577    }
578
579    Ok(())
580}
581
582/// UDP listener task: receives host datagrams, injects them into the guest,
583/// and sends guest replies back to active peers through the same socket.
584#[allow(clippy::too_many_arguments)]
585async fn udp_listener_task(
586    bind_addr: SocketAddr,
587    guest_ip: IpAddr,
588    gateway_ip: IpAddr,
589    guest_port: u16,
590    mut outbound_rx: mpsc::Receiver<PublishedUdpOutbound>,
591    peers: Arc<Mutex<PublishedUdpPeers>>,
592    ephemeral_port: Arc<AtomicU16>,
593    policy: Arc<NetworkPolicy>,
594    shared: Arc<SharedState>,
595    gateway_mac: EthernetAddress,
596    guest_mac: EthernetAddress,
597) -> std::io::Result<()> {
598    let socket = UdpSocket::bind(bind_addr).await?;
599    log_published_port_listener("UDP", bind_addr, guest_port);
600
601    let mut buf = vec![0u8; UDP_RELAY_BUF_SIZE];
602    loop {
603        tokio::select! {
604            inbound = socket.recv_from(&mut buf) => {
605                let (n, peer) = inbound?;
606                let action = policy.evaluate_ingress(peer, guest_port, Protocol::Udp, &shared);
607                if action.is_deny() {
608                    tracing::debug!(
609                        peer = %peer,
610                        guest_port,
611                        "UDP ingress denied by policy",
612                    );
613                    continue;
614                }
615
616                let Some(guest_peer) =
617                    resolve_udp_guest_peer(peer, gateway_ip, &peers, &ephemeral_port)
618                else {
619                    tracing::debug!(
620                        peer = %peer,
621                        guest_port,
622                        "UDP ingress dropped because published-port peer table is full",
623                    );
624                    continue;
625                };
626                inject_udp_datagram_to_guest(
627                    guest_peer,
628                    SocketAddr::new(guest_ip, guest_port),
629                    &buf[..n],
630                    &shared,
631                    gateway_mac,
632                    guest_mac,
633                );
634            }
635            outbound = outbound_rx.recv() => {
636                let Some(outbound) = outbound else {
637                    break;
638                };
639                if let Err(e) = socket.send_to(&outbound.payload, outbound.peer).await {
640                    tracing::debug!(
641                        peer = %outbound.peer,
642                        error = %e,
643                        "published UDP send to host peer failed",
644                    );
645                }
646            }
647        }
648    }
649
650    Ok(())
651}
652
653fn log_published_port_listener(protocol: &'static str, bind_addr: SocketAddr, guest_port: u16) {
654    match bind_exposure(bind_addr.ip()) {
655        BindExposure::Loopback => {
656            tracing::debug!(
657                protocol,
658                bind = %bind_addr,
659                guest_port,
660                "published port listener started on host loopback",
661            );
662        }
663        BindExposure::Wildcard => {
664            tracing::warn!(
665                protocol,
666                bind = %bind_addr,
667                guest_port,
668                windows_firewall_prompt = cfg!(windows),
669                "published port is listening on all host interfaces",
670            );
671        }
672        BindExposure::Interface => {
673            tracing::warn!(
674                protocol,
675                bind = %bind_addr,
676                guest_port,
677                windows_firewall_prompt = cfg!(windows),
678                "published port is listening on a non-loopback host interface",
679            );
680        }
681    }
682}
683
684fn bind_exposure(ip: IpAddr) -> BindExposure {
685    if ip.is_loopback() {
686        BindExposure::Loopback
687    } else if ip.is_unspecified() {
688        BindExposure::Wildcard
689    } else {
690        BindExposure::Interface
691    }
692}
693
694async fn queue_inbound_connection<T>(
695    inbound_tx: &mpsc::Sender<T>,
696    conn: T,
697    shared: &SharedState,
698) -> bool {
699    if inbound_tx.send(conn).await.is_err() {
700        return false;
701    }
702
703    shared.proxy_wake.wake();
704    true
705}
706
/// Pick the `(guest, gateway)` address pair in the same family as `host_bind`.
///
/// Returns `None` when either the guest or the gateway lacks an address in
/// that family.
fn udp_ips_for_bind(
    host_bind: IpAddr,
    guest_ipv4: Option<Ipv4Addr>,
    guest_ipv6: Option<Ipv6Addr>,
    gateway_ipv4: Option<Ipv4Addr>,
    gateway_ipv6: Option<Ipv6Addr>,
) -> Option<(IpAddr, IpAddr)> {
    if host_bind.is_ipv4() {
        guest_ipv4
            .zip(gateway_ipv4)
            .map(|(guest, gateway)| (IpAddr::from(guest), IpAddr::from(gateway)))
    } else {
        guest_ipv6
            .zip(gateway_ipv6)
            .map(|(guest, gateway)| (IpAddr::from(guest), IpAddr::from(gateway)))
    }
}
719
720fn resolve_udp_guest_peer(
721    host_peer: SocketAddr,
722    gateway_ip: IpAddr,
723    peers: &Arc<Mutex<PublishedUdpPeers>>,
724    ephemeral_port: &AtomicU16,
725) -> Option<SocketAddr> {
726    let now = Instant::now();
727    let mut peers = peers.lock();
728    cleanup_udp_peer_mappings(&mut peers, now);
729
730    if let Some(peer) = peers.host_to_guest.get_mut(&host_peer) {
731        peer.last_seen = now;
732        return Some(peer.guest_addr);
733    }
734
735    let guest_addr = (0..UDP_EPHEMERAL_PORT_COUNT).find_map(|_| {
736        let candidate = SocketAddr::new(gateway_ip, next_ephemeral_port(ephemeral_port));
737        if !peers.guest_to_host.contains_key(&candidate) {
738            Some(candidate)
739        } else {
740            None
741        }
742    })?;
743
744    peers.host_to_guest.insert(
745        host_peer,
746        PublishedUdpPeer {
747            guest_addr,
748            last_seen: now,
749        },
750    );
751    peers.guest_to_host.insert(guest_addr, host_peer);
752    Some(guest_addr)
753}
754
755fn cleanup_udp_peer_mappings(peers: &mut PublishedUdpPeers, now: Instant) {
756    peers
757        .host_to_guest
758        .retain(|_, peer| now.duration_since(peer.last_seen) <= UDP_PEER_TIMEOUT);
759    let host_to_guest = &peers.host_to_guest;
760    peers
761        .guest_to_host
762        .retain(|_, host_peer| host_to_guest.contains_key(host_peer));
763}
764
765fn next_ephemeral_port(ephemeral_port: &AtomicU16) -> u16 {
766    loop {
767        let port = ephemeral_port.fetch_add(1, Ordering::Relaxed);
768        if port == 0 || port < UDP_EPHEMERAL_PORT_START {
769            ephemeral_port.store(UDP_EPHEMERAL_PORT_START, Ordering::Relaxed);
770            continue;
771        }
772        return port;
773    }
774}
775
776fn inject_udp_datagram_to_guest(
777    peer: SocketAddr,
778    guest_dst: SocketAddr,
779    payload: &[u8],
780    shared: &SharedState,
781    gateway_mac: EthernetAddress,
782    guest_mac: EthernetAddress,
783) {
784    let Some(frame) = construct_udp_response(peer, guest_dst, payload, gateway_mac, guest_mac)
785    else {
786        tracing::debug!(
787            peer = %peer,
788            guest = %guest_dst,
789            "published UDP datagram dropped because address families differ",
790        );
791        return;
792    };
793
794    if !shared.push_rx_frame_and_wake(frame) {
795        tracing::debug!("published UDP datagram dropped because rx_ring is full");
796    }
797}
798
799/// Relay task: bridges a host TcpStream to channels connected to smoltcp.
800async fn inbound_relay_task(
801    stream: TcpStream,
802    mut to_host_rx: mpsc::Receiver<Bytes>,
803    from_host_tx: mpsc::Sender<Bytes>,
804    shared: Arc<SharedState>,
805) -> std::io::Result<()> {
806    let (mut rx, mut tx) = stream.into_split();
807    let mut buf = vec![0u8; RELAY_BUF_SIZE];
808
809    loop {
810        tokio::select! {
811            // smoltcp → host: data from guest arrives via channel.
812            data = to_host_rx.recv() => {
813                match data {
814                    Some(bytes) => {
815                        // Wake as soon as recv frees channel capacity. Waiting
816                        // for write_all can stall the poll loop behind a slow
817                        // host client.
818                        shared.proxy_wake.wake();
819                        if let Err(e) = tx.write_all(&bytes).await {
820                            tracing::debug!(error = %e, "write to host client failed");
821                            break;
822                        }
823                    }
824                    None => break,
825                }
826            }
827
828            // host → smoltcp: data from host client to write to guest.
829            result = rx.read(&mut buf) => {
830                match result {
831                    Ok(0) => break,
832                    Ok(n) => {
833                        let data = Bytes::copy_from_slice(&buf[..n]);
834                        if from_host_tx.send(data).await.is_err() {
835                            break;
836                        }
837                        shared.proxy_wake.wake();
838                    }
839                    Err(e) => {
840                        tracing::debug!(error = %e, "read from host client failed");
841                        break;
842                    }
843                }
844            }
845        }
846    }
847
848    Ok(())
849}
850
851/// Write data from the host relay channel to the smoltcp socket.
852fn write_host_data(socket: &mut tcp::Socket<'_>, relay: &mut InboundRelay) {
853    // First, try to finish writing any pending partial data.
854    if let Some((data, offset)) = &mut relay.write_buf {
855        if socket.can_send() {
856            match socket.send_slice(&data[*offset..]) {
857                Ok(written) => {
858                    *offset += written;
859                    if *offset >= data.len() {
860                        relay.write_buf = None;
861                    }
862                }
863                Err(_) => return,
864            }
865        } else {
866            return;
867        }
868    }
869
870    // Then drain the channel.
871    while relay.write_buf.is_none() {
872        match relay.from_host.try_recv() {
873            Ok(data) => {
874                if socket.can_send() {
875                    match socket.send_slice(&data) {
876                        Ok(written) if written < data.len() => {
877                            relay.write_buf = Some((data, written));
878                        }
879                        Err(_) => {
880                            relay.write_buf = Some((data, 0));
881                        }
882                        _ => {}
883                    }
884                } else {
885                    relay.write_buf = Some((data, 0));
886                }
887            }
888            Err(_) => break,
889        }
890    }
891}
892
// Unit tests for this module's helpers: poll-loop wakeups, the inbound TCP
// relay task, UDP injection and reply routing, ephemeral-port exhaustion, and
// bind-address classification.
#[cfg(test)]
mod tests {
    use super::*;

    // Queuing an inbound connection must both deliver it on the channel and
    // signal `proxy_wake`, so the poll loop notices without another event.
    #[tokio::test]
    async fn queue_inbound_connection_wakes_poll_loop() {
        let shared = SharedState::new(4);
        // Start from a clean wake state so the final assertion is meaningful.
        shared.proxy_wake.drain();

        let (tx, mut rx) = mpsc::channel(1);

        assert!(queue_inbound_connection(&tx, (), &shared).await);
        assert!(rx.try_recv().is_ok());
        assert!(shared.proxy_wake.wait_timeout(Duration::ZERO));
    }

    // The relay must wake the poll loop as soon as it takes a chunk off the
    // to-host channel, not only after the (possibly slow) write to the host
    // client finishes.
    #[tokio::test]
    async fn inbound_relay_wakes_when_to_host_channel_slot_is_freed() {
        let shared = Arc::new(SharedState::new(4));
        shared.proxy_wake.drain();

        let listener = TcpListener::bind(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0))
            .await
            .unwrap();
        let addr = listener.local_addr().unwrap();
        let client = tokio::spawn(TcpStream::connect(addr));
        let (server_stream, _) = listener.accept().await.unwrap();
        let client = client.await.unwrap().unwrap();

        // Shrink the server's send buffer. Combined with a client that never
        // reads, this is meant to keep the large write below from completing.
        socket2::SockRef::from(&server_stream)
            .set_send_buffer_size(4096)
            .unwrap();

        // Capacity 1: the second send below can only succeed after the relay
        // has received the first chunk.
        let (to_host_tx, to_host_rx) = mpsc::channel(1);
        let (from_host_tx, _from_host_rx) = mpsc::channel(1);
        let task = tokio::spawn(inbound_relay_task(
            server_stream,
            to_host_rx,
            from_host_tx,
            shared.clone(),
        ));

        // 64 MiB is presumably far more than the loopback socket buffers can
        // absorb, so the relay's write_all stays pending. TODO confirm on
        // platforms with large receive-buffer autotuning.
        to_host_tx
            .send(Bytes::from(vec![b'a'; 64 * 1024 * 1024]))
            .await
            .unwrap();

        tokio::time::timeout(
            Duration::from_secs(1),
            to_host_tx.send(Bytes::from_static(b"next")),
        )
        .await
        .unwrap()
        .unwrap();

        // The wake must already have happened, before write_all finished.
        assert!(shared.proxy_wake.wait_timeout(Duration::ZERO));

        drop(client);
        drop(to_host_tx);
        task.abort();
        let _ = task.await;
    }

    // Injecting a published UDP datagram must place a frame on `rx_ring` and
    // account exactly that frame's length in `rx_bytes`.
    #[test]
    fn inject_udp_datagram_to_guest_counts_rx_bytes() {
        let shared = SharedState::new(4);
        let peer = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(172, 16, 0, 1)), 50000);
        let guest = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(172, 16, 0, 2)), 5353);

        inject_udp_datagram_to_guest(
            peer,
            guest,
            b"hello",
            &shared,
            EthernetAddress([0x02, 0, 0, 0, 0, 1]),
            EthernetAddress([0x02, 0, 0, 0, 0, 2]),
        );

        let frame = shared.rx_ring.pop().expect("published UDP frame");
        assert_eq!(shared.rx_bytes(), frame.len() as u64);
    }

    // A guest reply addressed to a peer with an active host <-> guest mapping
    // on the published port must be queued for the host socket, addressed to
    // the original host peer and carrying the reply payload.
    #[test]
    fn relay_udp_outbound_queues_reply_for_active_peer() {
        let (inbound_tx, inbound_rx) = mpsc::channel(1);
        let (outbound_tx, mut outbound_rx) = mpsc::channel(1);
        let routes = Arc::new(Mutex::new(HashMap::new()));
        let peers = Arc::new(Mutex::new(PublishedUdpPeers::default()));
        let guest_ip = Ipv4Addr::new(172, 16, 0, 2);
        let host_peer = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 50000);
        let guest_peer = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(172, 16, 0, 1)), 49152);

        // Register the bidirectional peer mapping by hand.
        {
            let mut peers = peers.lock();
            peers.host_to_guest.insert(
                host_peer,
                PublishedUdpPeer {
                    guest_addr: guest_peer,
                    last_seen: Instant::now(),
                },
            );
            peers.guest_to_host.insert(guest_peer, host_peer);
        }
        routes.lock().insert(
            5353,
            vec![PublishedUdpRoute {
                bind_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 5353),
                outbound_tx,
                peers,
            }],
        );

        let publisher = PortPublisher {
            inbound_rx,
            _inbound_tx: inbound_tx,
            connections: Vec::new(),
            guest_ip: Some(IpAddr::V4(guest_ip)),
            guest_ipv4: Some(guest_ip),
            guest_ipv6: None,
            ephemeral_port: Arc::new(AtomicU16::new(49152)),
            max_inbound: 256,
            udp_routes: routes,
        };
        // Frame as the guest would send it: from its published port 5353 to
        // the mapped guest-side peer address.
        let src = SocketAddr::new(IpAddr::V4(guest_ip), 5353);
        let frame = construct_udp_response(
            src,
            guest_peer,
            b"pong",
            EthernetAddress([0x02, 0, 0, 0, 0, 1]),
            EthernetAddress([0x02, 0, 0, 0, 0, 2]),
        )
        .unwrap();

        assert!(publisher.relay_udp_outbound(&frame, src, guest_peer));
        let outbound = outbound_rx.try_recv().unwrap();
        assert_eq!(outbound.peer, host_peer);
        assert_eq!(outbound.payload.as_ref(), b"pong");
    }

    // Without any registered peer mapping, the publisher must not claim a
    // guest frame sent from the published port.
    #[test]
    fn relay_udp_outbound_ignores_inactive_peer() {
        let (inbound_tx, inbound_rx) = mpsc::channel(1);
        let (outbound_tx, _outbound_rx) = mpsc::channel(1);
        let routes = Arc::new(Mutex::new(HashMap::new()));
        let guest_ip = Ipv4Addr::new(172, 16, 0, 2);
        let peer = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 50000);

        routes.lock().insert(
            5353,
            vec![PublishedUdpRoute {
                bind_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 5353),
                outbound_tx,
                peers: Arc::new(Mutex::new(PublishedUdpPeers::default())),
            }],
        );

        let publisher = PortPublisher {
            inbound_rx,
            _inbound_tx: inbound_tx,
            connections: Vec::new(),
            guest_ip: Some(IpAddr::V4(guest_ip)),
            guest_ipv4: Some(guest_ip),
            guest_ipv6: None,
            ephemeral_port: Arc::new(AtomicU16::new(49152)),
            max_inbound: 256,
            udp_routes: routes,
        };
        let src = SocketAddr::new(IpAddr::V4(guest_ip), 5353);
        let frame = construct_udp_response(
            src,
            peer,
            b"pong",
            EthernetAddress([0x02, 0, 0, 0, 0, 1]),
            EthernetAddress([0x02, 0, 0, 0, 0, 2]),
        )
        .unwrap();

        assert!(!publisher.relay_udp_outbound(&frame, src, peer));
    }

    // Once every guest-side port from UDP_EPHEMERAL_PORT_START through
    // u16::MAX is mapped, a new host peer must not receive a mapping.
    #[test]
    fn resolve_udp_guest_peer_returns_none_when_ephemeral_ports_exhausted() {
        let peers = Arc::new(Mutex::new(PublishedUdpPeers::default()));
        let gateway_ip = IpAddr::V4(Ipv4Addr::new(172, 16, 0, 1));
        let now = Instant::now();

        {
            let mut peers = peers.lock();
            for port in UDP_EPHEMERAL_PORT_START..=u16::MAX {
                let host_peer = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), port);
                let guest_addr = SocketAddr::new(gateway_ip, port);
                peers.host_to_guest.insert(
                    host_peer,
                    PublishedUdpPeer {
                        guest_addr,
                        last_seen: now,
                    },
                );
                peers.guest_to_host.insert(guest_addr, host_peer);
            }
        }

        // Host port 40000 was not among the registered peers above.
        let next = resolve_udp_guest_peer(
            SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 40000),
            gateway_ip,
            &peers,
            &AtomicU16::new(UDP_EPHEMERAL_PORT_START),
        );

        assert!(next.is_none());
    }

    // Loopback, wildcard, and specific-interface bind addresses (IPv4 and
    // IPv6 where covered) must map to distinct exposure classes.
    #[test]
    fn bind_exposure_keeps_loopback_distinct_from_lan_binds() {
        assert_eq!(
            bind_exposure(IpAddr::V4(Ipv4Addr::LOCALHOST)),
            BindExposure::Loopback
        );
        assert_eq!(
            bind_exposure(IpAddr::V6(Ipv6Addr::LOCALHOST)),
            BindExposure::Loopback
        );
        assert_eq!(
            bind_exposure(IpAddr::V4(Ipv4Addr::UNSPECIFIED)),
            BindExposure::Wildcard
        );
        assert_eq!(
            bind_exposure(IpAddr::V6(Ipv6Addr::UNSPECIFIED)),
            BindExposure::Wildcard
        );
        assert_eq!(
            bind_exposure(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 10))),
            BindExposure::Interface
        );
    }
}