Skip to main content

microsandbox_network/
publisher.rs

//! Published port handling: host-side listeners that forward connections
//! into the guest VM via smoltcp.
//!
//! For each configured TCP [`PublishedPort`], a tokio TCP listener binds on
//! the host (UDP published ports are not implemented yet). When a connection
//! arrives, the poll loop creates a smoltcp socket that connects to the
//! guest, and a relay task bridges the host socket to the smoltcp socket via
//! channels.
8
9use std::net::{IpAddr, Ipv4Addr, SocketAddr};
10use std::sync::Arc;
11use std::sync::atomic::{AtomicU16, Ordering};
12
13use bytes::Bytes;
14use smoltcp::iface::{Interface, SocketHandle, SocketSet};
15use smoltcp::socket::tcp;
16use smoltcp::wire::IpEndpoint;
17use tokio::io::{AsyncReadExt, AsyncWriteExt};
18use tokio::net::{TcpListener, TcpStream};
19use tokio::sync::mpsc;
20
21use crate::config::{PortProtocol, PublishedPort};
22use crate::shared::SharedState;
23
24//--------------------------------------------------------------------------------------------------
25// Constants
26//--------------------------------------------------------------------------------------------------
27
/// Receive buffer size, in bytes, of each guest-facing smoltcp TCP socket
/// created for an inbound connection.
const TCP_RX_BUF_SIZE: usize = 64 * 1024;
/// Transmit buffer size, in bytes, of each guest-facing smoltcp TCP socket
/// created for an inbound connection.
const TCP_TX_BUF_SIZE: usize = 64 * 1024;

/// Capacity, counted in messages (not bytes), of each per-connection relay
/// channel in either direction.
const CHANNEL_CAPACITY: usize = 32;

/// Size, in bytes, of the scratch buffer used for a single read from a host
/// or smoltcp socket while relaying.
const RELAY_BUF_SIZE: usize = 16 * 1024;
37
38//--------------------------------------------------------------------------------------------------
39// Types
40//--------------------------------------------------------------------------------------------------
41
/// Manages published port listeners and inbound connections.
///
/// Spawns a tokio listener for each published TCP port. When connections
/// arrive, they are queued for the poll loop, which creates smoltcp sockets
/// and initiates connections to the guest.
pub struct PortPublisher {
    /// Receives accepted connections from listener tasks.
    inbound_rx: mpsc::Receiver<InboundConnection>,
    /// Held so the channel never reports "disconnected", even if every
    /// listener task (each of which holds a clone) has exited.
    _inbound_tx: mpsc::Sender<InboundConnection>,
    /// Tracked inbound connections (smoltcp socket → relay state).
    connections: Vec<InboundRelay>,
    /// Guest IPv4 address, used as the smoltcp connect target.
    guest_ipv4: Ipv4Addr,
    /// Next candidate local port for guest-bound smoltcp sockets.
    ephemeral_port: Arc<AtomicU16>,
    /// Maximum number of tracked inbound connections. Queued connections
    /// beyond this are dropped, which limits resource use under host-side
    /// floods.
    max_inbound: usize,
}
61
/// An accepted host-side connection waiting to be wired to the guest.
///
/// Produced by a listener task and consumed by `PortPublisher::accept_inbound`.
struct InboundConnection {
    /// The accepted host-side TCP stream.
    stream: TcpStream,
    /// Guest port the smoltcp socket for this connection will connect to.
    guest_port: u16,
}
69
/// Number of poll iterations spent trying to flush host data still parked for
/// the guest after the relay task has exited. Once this many attempts have
/// been made, the smoltcp socket is aborted instead of closed gracefully, so
/// a guest that stopped reading cannot keep the connection alive forever.
const DEFERRED_CLOSE_LIMIT: u16 = 64;
73
/// A single inbound connection relay (host socket ↔ smoltcp socket).
///
/// Pairs the guest-facing smoltcp socket with the channel ends used to
/// exchange data with that connection's relay task.
struct InboundRelay {
    /// Handle of the guest-facing smoltcp TCP socket in the `SocketSet`.
    handle: SocketHandle,
    /// Carries data read from the smoltcp socket to the host relay task.
    /// When this channel reports closed, the relay task has exited.
    to_host: mpsc::Sender<Bytes>,
    /// Carries data read from the host client by the relay task, to be
    /// written into the smoltcp socket.
    from_host: mpsc::Receiver<Bytes>,
    /// A chunk from `from_host` that the smoltcp socket has not fully
    /// accepted yet, together with the number of its bytes already written.
    write_buf: Option<(Bytes, usize)>,
    /// Poll iterations spent trying to flush `write_buf` after the relay task
    /// exited; once it reaches `DEFERRED_CLOSE_LIMIT` the socket is aborted.
    close_attempts: u16,
}
86
87//--------------------------------------------------------------------------------------------------
88// Methods
89//--------------------------------------------------------------------------------------------------
90
91impl PortPublisher {
92    /// Create a new publisher and spawn listeners for all published ports.
93    pub fn new(
94        ports: &[PublishedPort],
95        guest_ipv4: Ipv4Addr,
96        tokio_handle: &tokio::runtime::Handle,
97    ) -> Self {
98        let (inbound_tx, inbound_rx) = mpsc::channel(64);
99
100        // Spawn a listener for each published TCP port.
101        for port in ports {
102            if port.protocol == PortProtocol::Tcp {
103                let tx = inbound_tx.clone();
104                let bind_addr = SocketAddr::new(port.host_bind, port.host_port);
105                let guest_port = port.guest_port;
106                tokio_handle.spawn(async move {
107                    if let Err(e) = tcp_listener_task(bind_addr, guest_port, tx).await {
108                        tracing::error!(
109                            bind = %bind_addr,
110                            error = %e,
111                            "published port listener failed",
112                        );
113                    }
114                });
115            }
116            // TODO: UDP published ports.
117        }
118
119        Self {
120            inbound_rx,
121            _inbound_tx: inbound_tx,
122            connections: Vec::new(),
123            guest_ipv4,
124            ephemeral_port: Arc::new(AtomicU16::new(49152)),
125            max_inbound: 256,
126        }
127    }
128
129    /// Accept queued inbound connections: create smoltcp sockets and
130    /// initiate connections to the guest.
131    ///
132    /// Must be called each poll iteration.
133    pub fn accept_inbound(
134        &mut self,
135        iface: &mut Interface,
136        sockets: &mut SocketSet<'_>,
137        shared: &Arc<SharedState>,
138        tokio_handle: &tokio::runtime::Handle,
139    ) {
140        while let Ok(conn) = self.inbound_rx.try_recv() {
141            if self.connections.len() >= self.max_inbound {
142                tracing::debug!("published port: max inbound connections reached, rejecting");
143                continue;
144            }
145            // Create smoltcp TCP socket.
146            let rx_buf = tcp::SocketBuffer::new(vec![0u8; TCP_RX_BUF_SIZE]);
147            let tx_buf = tcp::SocketBuffer::new(vec![0u8; TCP_TX_BUF_SIZE]);
148            let mut socket = tcp::Socket::new(rx_buf, tx_buf);
149
150            // Connect to the guest.
151            let remote = IpEndpoint::new(IpAddr::V4(self.guest_ipv4).into(), conn.guest_port);
152            let local_port = self.alloc_ephemeral_port();
153
154            if socket.connect(iface.context(), remote, local_port).is_err() {
155                tracing::debug!(
156                    guest_port = conn.guest_port,
157                    "failed to connect smoltcp socket to guest",
158                );
159                continue;
160            }
161
162            let handle = sockets.add(socket);
163
164            // Create channel pair for relay.
165            let (to_host_tx, to_host_rx) = mpsc::channel(CHANNEL_CAPACITY);
166            let (from_host_tx, from_host_rx) = mpsc::channel(CHANNEL_CAPACITY);
167
168            // Spawn relay task: host TcpStream ↔ channels.
169            let shared_clone = shared.clone();
170            tokio_handle.spawn(async move {
171                let _ =
172                    inbound_relay_task(conn.stream, to_host_rx, from_host_tx, shared_clone).await;
173            });
174
175            self.connections.push(InboundRelay {
176                handle,
177                to_host: to_host_tx,
178                from_host: from_host_rx,
179                write_buf: None,
180                close_attempts: 0,
181            });
182        }
183    }
184
185    /// Relay data between smoltcp sockets and host relay tasks.
186    pub fn relay_data(&mut self, sockets: &mut SocketSet<'_>) {
187        let mut relay_buf = [0u8; RELAY_BUF_SIZE];
188
189        for relay in &mut self.connections {
190            let socket = sockets.get_mut::<tcp::Socket>(relay.handle);
191
192            // Detect relay task exit — close the smoltcp socket.
193            if relay.to_host.is_closed() {
194                write_host_data(socket, relay);
195                if relay.write_buf.is_none() {
196                    socket.close();
197                } else {
198                    // Abort if we've been trying to flush for too long
199                    // (guest stopped reading, socket send buffer full).
200                    relay.close_attempts += 1;
201                    if relay.close_attempts >= DEFERRED_CLOSE_LIMIT {
202                        socket.abort();
203                    }
204                }
205                continue;
206            }
207
208            // smoltcp → host: read from socket, send via channel.
209            while socket.can_recv() {
210                match socket.recv_slice(&mut relay_buf) {
211                    Ok(n) if n > 0 => {
212                        let data = Bytes::copy_from_slice(&relay_buf[..n]);
213                        if relay.to_host.try_send(data).is_err() {
214                            break;
215                        }
216                    }
217                    _ => break,
218                }
219            }
220
221            // host → smoltcp: write pending data, then drain channel.
222            write_host_data(socket, relay);
223        }
224    }
225
226    /// Remove closed inbound connections.
227    ///
228    /// Only removes sockets in `Closed` state. Sockets in `TimeWait` are
229    /// left for smoltcp's 2*MSL timer to handle naturally.
230    pub fn cleanup_closed(&mut self, sockets: &mut SocketSet<'_>) {
231        self.connections.retain(|relay| {
232            let socket = sockets.get::<tcp::Socket>(relay.handle);
233            let closed = matches!(socket.state(), tcp::State::Closed);
234            if closed {
235                sockets.remove(relay.handle);
236            }
237            !closed
238        });
239    }
240}
241
242impl PortPublisher {
243    fn alloc_ephemeral_port(&self) -> u16 {
244        loop {
245            let port = self.ephemeral_port.fetch_add(1, Ordering::Relaxed);
246            // Wrap around in the ephemeral range.
247            if port == 0 || port < 49152 {
248                self.ephemeral_port.store(49152, Ordering::Relaxed);
249                continue;
250            }
251            return port;
252        }
253    }
254}
255
256//--------------------------------------------------------------------------------------------------
257// Functions
258//--------------------------------------------------------------------------------------------------
259
260/// Listener task: accepts TCP connections on the host and queues them.
261async fn tcp_listener_task(
262    bind_addr: SocketAddr,
263    guest_port: u16,
264    inbound_tx: mpsc::Sender<InboundConnection>,
265) -> std::io::Result<()> {
266    let listener = TcpListener::bind(bind_addr).await?;
267    tracing::debug!(bind = %bind_addr, guest_port, "published port listener started");
268
269    loop {
270        let (stream, _peer) = listener.accept().await?;
271        let conn = InboundConnection { stream, guest_port };
272        if inbound_tx.send(conn).await.is_err() {
273            break; // Publisher dropped.
274        }
275    }
276
277    Ok(())
278}
279
280/// Relay task: bridges a host TcpStream to channels connected to smoltcp.
281async fn inbound_relay_task(
282    stream: TcpStream,
283    mut to_host_rx: mpsc::Receiver<Bytes>,
284    from_host_tx: mpsc::Sender<Bytes>,
285    shared: Arc<SharedState>,
286) -> std::io::Result<()> {
287    let (mut rx, mut tx) = stream.into_split();
288    let mut buf = vec![0u8; RELAY_BUF_SIZE];
289
290    loop {
291        tokio::select! {
292            // smoltcp → host: data from guest arrives via channel.
293            data = to_host_rx.recv() => {
294                match data {
295                    Some(bytes) => {
296                        if let Err(e) = tx.write_all(&bytes).await {
297                            tracing::debug!(error = %e, "write to host client failed");
298                            break;
299                        }
300                    }
301                    None => break,
302                }
303            }
304
305            // host → smoltcp: data from host client to write to guest.
306            result = rx.read(&mut buf) => {
307                match result {
308                    Ok(0) => break,
309                    Ok(n) => {
310                        let data = Bytes::copy_from_slice(&buf[..n]);
311                        if from_host_tx.send(data).await.is_err() {
312                            break;
313                        }
314                        shared.proxy_wake.wake();
315                    }
316                    Err(e) => {
317                        tracing::debug!(error = %e, "read from host client failed");
318                        break;
319                    }
320                }
321            }
322        }
323    }
324
325    Ok(())
326}
327
328/// Write data from the host relay channel to the smoltcp socket.
329fn write_host_data(socket: &mut tcp::Socket<'_>, relay: &mut InboundRelay) {
330    // First, try to finish writing any pending partial data.
331    if let Some((data, offset)) = &mut relay.write_buf {
332        if socket.can_send() {
333            match socket.send_slice(&data[*offset..]) {
334                Ok(written) => {
335                    *offset += written;
336                    if *offset >= data.len() {
337                        relay.write_buf = None;
338                    }
339                }
340                Err(_) => return,
341            }
342        } else {
343            return;
344        }
345    }
346
347    // Then drain the channel.
348    while relay.write_buf.is_none() {
349        match relay.from_host.try_recv() {
350            Ok(data) => {
351                if socket.can_send() {
352                    match socket.send_slice(&data) {
353                        Ok(written) if written < data.len() => {
354                            relay.write_buf = Some((data, written));
355                        }
356                        Err(_) => {
357                            relay.write_buf = Some((data, 0));
358                        }
359                        _ => {}
360                    }
361                } else {
362                    relay.write_buf = Some((data, 0));
363                }
364            }
365            Err(_) => break,
366        }
367    }
368}