Skip to main content

microsandbox_network/
conn.rs

1//! Connection tracker: manages smoltcp TCP sockets for the poll loop.
2//!
3//! Creates sockets on SYN detection, tracks connection lifecycle, relays data
4//! between smoltcp sockets and proxy task channels, and cleans up closed
5//! connections.
6
7use std::collections::{HashMap, HashSet};
8use std::net::SocketAddr;
9use std::sync::Arc;
10use std::sync::atomic::{AtomicU8, Ordering};
11
12use bytes::Bytes;
13use smoltcp::iface::{SocketHandle, SocketSet};
14use smoltcp::socket::tcp;
15use smoltcp::wire::IpListenEndpoint;
16use tokio::sync::mpsc;
17
18//--------------------------------------------------------------------------------------------------
19// Constants
20//--------------------------------------------------------------------------------------------------
21
/// Receive-buffer capacity, in bytes, of every smoltcp TCP socket (64 KiB).
const TCP_RX_BUF_SIZE: usize = 64 * 1024;

/// Transmit-buffer capacity, in bytes, of every smoltcp TCP socket (64 KiB).
const TCP_TX_BUF_SIZE: usize = 64 * 1024;

/// Connection limit used when the caller does not supply one.
const DEFAULT_MAX_CONNECTIONS: usize = 256;

/// Bound on queued chunks in each mpsc channel linking the poll loop to a
/// proxy task.
const CHANNEL_CAPACITY: usize = 32;

/// Size of the scratch buffer handed to each `recv_slice` call on a smoltcp
/// socket (16 KiB).
const RELAY_BUF_SIZE: usize = 16 * 1024;
36
37//--------------------------------------------------------------------------------------------------
38// Types
39//--------------------------------------------------------------------------------------------------
40
/// Final outcome of an outbound proxy task's attempt to reach its upstream.
///
/// Each variant has an explicit `u8` discriminant so the value can be stored
/// in, and recovered from, the atomic inside `ProxyConnectState`.
#[repr(u8)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ProxyConnectStatus {
    /// The proxy task has not reported an outcome yet.
    Pending = 0,
    /// The upstream connection was established.
    Connected = 1,
    /// Egress policy rejected the connection, so upstream was never dialed.
    PolicyDenied = 2,
    /// Upstream was dialed, but the connection attempt failed.
    UpstreamConnectFailed = 3,
}
54
/// Status cell shared between an outbound proxy task and the poll loop.
///
/// The proxy task records its outcome through the `mark_*` methods. When the
/// task exits, the smoltcp poll loop reads it back to decide whether the
/// guest should see a clean close or a TCP reset. The status is stored as the
/// raw `u8` discriminant of [`ProxyConnectStatus`] in an atomic, so both sides
/// can share one instance behind an `Arc` without locking.
///
/// `Debug` is derived so the state (and anything embedding it) can be
/// formatted with `{:?}` in logs and diagnostics.
#[derive(Debug)]
pub struct ProxyConnectState {
    /// Raw [`ProxyConnectStatus`] discriminant.
    status: AtomicU8,
}
62
63/// Tracks TCP connections between guest and proxy tasks.
64///
65/// Each guest TCP connection maps to a smoltcp socket and a pair of channels
66/// connecting it to a tokio proxy task. The tracker handles:
67///
68/// - **Socket creation** — on SYN detection, before smoltcp processes the frame.
69/// - **Data relay** — shuttles bytes between smoltcp sockets and channels.
70/// - **Lifecycle detection** — identifies newly-established connections for
71///   proxy spawning.
72/// - **Cleanup** — removes closed sockets from the socket set.
73pub struct ConnectionTracker {
74    /// Active connections keyed by smoltcp socket handle.
75    connections: HashMap<SocketHandle, Connection>,
76    /// Secondary index for O(1) duplicate-SYN detection by (src, dst) 4-tuple.
77    connection_keys: HashSet<(SocketAddr, SocketAddr)>,
78    /// Max concurrent connections (from NetworkConfig).
79    max_connections: usize,
80}
81
/// Number of poll iterations allowed for draining leftover proxy data into
/// the socket after the proxy task is gone; once reached, the socket is
/// force-aborted instead of waiting any longer.
const DEFERRED_CLOSE_LIMIT: u16 = 64;
85
86/// Internal state for a single tracked TCP connection.
87struct Connection {
88    /// Guest source address (from the guest's SYN).
89    src: SocketAddr,
90    /// Original destination (from the guest's SYN).
91    dst: SocketAddr,
92    /// Sends data from smoltcp socket to proxy task (guest → server).
93    to_proxy: mpsc::Sender<Bytes>,
94    /// Receives data from proxy task to write to smoltcp socket (server → guest).
95    from_proxy: mpsc::Receiver<Bytes>,
96    /// Proxy-side channel ends, held until the connection is ESTABLISHED.
97    /// Taken by [`ConnectionTracker::take_new_connections()`].
98    proxy_channels: Option<ProxyChannels>,
99    /// Whether a proxy task has been spawned for this connection.
100    proxy_spawned: bool,
101    /// Status reported by the proxy task before it exits.
102    proxy_connect: Arc<ProxyConnectState>,
103    /// Partial data from proxy that couldn't be fully written to smoltcp socket.
104    write_buf: Option<(Bytes, usize)>,
105    /// Data read from smoltcp socket that couldn't be sent to proxy (channel full).
106    /// Must be sent before reading more from the socket to preserve stream order.
107    read_buf: Option<Bytes>,
108    /// Counter for deferred close attempts (prevents stalling forever).
109    close_attempts: u16,
110}
111
112/// Proxy-side channel ends, created at socket creation time and taken when
113/// the connection becomes ESTABLISHED.
114struct ProxyChannels {
115    /// Receive data from smoltcp socket (guest → proxy task).
116    from_smoltcp: mpsc::Receiver<Bytes>,
117    /// Send data to smoltcp socket (proxy task → guest).
118    to_smoltcp: mpsc::Sender<Bytes>,
119}
120
121/// Information for spawning a proxy task for a newly established connection.
122///
123/// Returned by [`ConnectionTracker::take_new_connections()`]. The poll loop
124/// passes this to the proxy task spawner.
125pub struct NewConnection {
126    /// Original destination the guest was connecting to.
127    pub dst: SocketAddr,
128    /// Receive data from smoltcp socket (guest → proxy task).
129    pub from_smoltcp: mpsc::Receiver<Bytes>,
130    /// Send data to smoltcp socket (proxy task → guest).
131    pub to_smoltcp: mpsc::Sender<Bytes>,
132    /// Status the proxy task updates before it exits.
133    pub proxy_connect: Arc<ProxyConnectState>,
134}
135
136//--------------------------------------------------------------------------------------------------
137// Methods
138//--------------------------------------------------------------------------------------------------
139
140impl ProxyConnectStatus {
141    fn as_u8(self) -> u8 {
142        self as u8
143    }
144
145    fn from_u8(value: u8) -> Self {
146        match value {
147            value if value == Self::Connected as u8 => Self::Connected,
148            value if value == Self::PolicyDenied as u8 => Self::PolicyDenied,
149            value if value == Self::UpstreamConnectFailed as u8 => Self::UpstreamConnectFailed,
150            _ => Self::Pending,
151        }
152    }
153}
154
155impl ProxyConnectState {
156    /// Create a new pending proxy connection status.
157    pub fn new() -> Self {
158        Self {
159            status: AtomicU8::new(ProxyConnectStatus::Pending.as_u8()),
160        }
161    }
162
163    /// Mark the proxy as successfully connected to upstream.
164    pub fn mark_connected(&self) {
165        self.store(ProxyConnectStatus::Connected);
166    }
167
168    /// Mark the proxy as denied by egress policy before dialing upstream.
169    pub fn mark_policy_denied(&self) {
170        self.store(ProxyConnectStatus::PolicyDenied);
171    }
172
173    /// Mark the proxy as failed while dialing upstream.
174    pub fn mark_upstream_connect_failed(&self) {
175        self.store(ProxyConnectStatus::UpstreamConnectFailed);
176    }
177
178    /// Load the latest proxy connection status.
179    pub fn status(&self) -> ProxyConnectStatus {
180        ProxyConnectStatus::from_u8(self.status.load(Ordering::Acquire))
181    }
182
183    fn store(&self, status: ProxyConnectStatus) {
184        self.status.store(status.as_u8(), Ordering::Release);
185    }
186}
187
188impl Default for ProxyConnectState {
189    fn default() -> Self {
190        Self::new()
191    }
192}
193
194impl ConnectionTracker {
195    /// Create a new tracker with the given connection limit.
196    pub fn new(max_connections: Option<usize>) -> Self {
197        Self {
198            connections: HashMap::new(),
199            connection_keys: HashSet::new(),
200            max_connections: max_connections.unwrap_or(DEFAULT_MAX_CONNECTIONS),
201        }
202    }
203
204    /// Returns `true` if a tracked socket already exists for this exact
205    /// connection (same source AND destination). O(1) via HashSet lookup.
206    pub fn has_socket_for(&self, src: &SocketAddr, dst: &SocketAddr) -> bool {
207        self.connection_keys.contains(&(*src, *dst))
208    }
209
210    /// Create a smoltcp TCP socket for an incoming SYN and register it.
211    ///
212    /// The socket is put into LISTEN state on the destination IP + port so
213    /// smoltcp will complete the three-way handshake when it processes the
214    /// SYN frame. Binding to the specific destination IP (not just port)
215    /// prevents socket dispatch ambiguity when multiple connections target
216    /// different IPs on the same port.
217    ///
218    /// Returns `false` if at `max_connections` limit.
219    pub fn create_tcp_socket(
220        &mut self,
221        src: SocketAddr,
222        dst: SocketAddr,
223        sockets: &mut SocketSet<'_>,
224    ) -> bool {
225        if self.connections.len() >= self.max_connections {
226            return false;
227        }
228
229        // Create smoltcp TCP socket with buffers.
230        let rx_buf = tcp::SocketBuffer::new(vec![0u8; TCP_RX_BUF_SIZE]);
231        let tx_buf = tcp::SocketBuffer::new(vec![0u8; TCP_TX_BUF_SIZE]);
232        let mut socket = tcp::Socket::new(rx_buf, tx_buf);
233
234        // Listen on the specific destination IP + port. With any_ip mode,
235        // binding to the IP ensures the correct socket accepts each SYN
236        // when multiple connections target the same port on different IPs.
237        let listen_endpoint = IpListenEndpoint {
238            addr: Some(dst.ip().into()),
239            port: dst.port(),
240        };
241        if socket.listen(listen_endpoint).is_err() {
242            return false;
243        }
244
245        let handle = sockets.add(socket);
246
247        // Create channel pairs for proxy task communication.
248        //
249        // smoltcp → proxy (guest sends data, proxy relays to server):
250        let (to_proxy_tx, to_proxy_rx) = mpsc::channel(CHANNEL_CAPACITY);
251        // proxy → smoltcp (server sends data, proxy relays to guest):
252        let (from_proxy_tx, from_proxy_rx) = mpsc::channel(CHANNEL_CAPACITY);
253
254        self.connection_keys.insert((src, dst));
255        self.connections.insert(
256            handle,
257            Connection {
258                src,
259                dst,
260                to_proxy: to_proxy_tx,
261                from_proxy: from_proxy_rx,
262                proxy_channels: Some(ProxyChannels {
263                    from_smoltcp: to_proxy_rx,
264                    to_smoltcp: from_proxy_tx,
265                }),
266                proxy_spawned: false,
267                proxy_connect: Arc::new(ProxyConnectState::new()),
268                write_buf: None,
269                read_buf: None,
270                close_attempts: 0,
271            },
272        );
273
274        true
275    }
276
277    /// Relay data between smoltcp sockets and proxy task channels.
278    ///
279    /// For each connection with a spawned proxy:
280    /// - Reads data from the smoltcp socket and sends it to the proxy channel.
281    /// - Receives data from the proxy channel and writes it to the smoltcp socket.
282    pub fn relay_data(&mut self, sockets: &mut SocketSet<'_>) {
283        let mut relay_buf = [0u8; RELAY_BUF_SIZE];
284
285        for (&handle, conn) in &mut self.connections {
286            if !conn.proxy_spawned {
287                continue;
288            }
289
290            let socket = sockets.get_mut::<tcp::Socket>(handle);
291
292            // Already torn down (e.g. abort fired on a previous pass).
293            // Leave it for `cleanup_closed` to evict.
294            if matches!(socket.state(), tcp::State::Closed) {
295                continue;
296            }
297
298            // Detect proxy task exit: when the proxy drops its channel
299            // ends, close the smoltcp socket so the guest gets a FIN.
300            //
301            // If the proxy attempted and failed to reach upstream,
302            // an RST via `abort()` is instead sent so happy-eyeballs
303            // clients fall back to another family instead of committing
304            // to this half-open connection.
305            if conn.to_proxy.is_closed() {
306                if matches!(
307                    conn.proxy_connect.status(),
308                    ProxyConnectStatus::UpstreamConnectFailed
309                ) {
310                    tracing::debug!(
311                        src = %conn.src,
312                        dst = %conn.dst,
313                        "upstream connect failed; aborting smoltcp socket (RST to guest)"
314                    );
315                    socket.abort();
316                    continue;
317                }
318                write_proxy_data(socket, conn);
319                if conn.write_buf.is_none() {
320                    socket.close();
321                } else {
322                    // Abort if we've been trying to flush for too long
323                    // (guest stopped reading, socket send buffer full).
324                    conn.close_attempts += 1;
325                    if conn.close_attempts >= DEFERRED_CLOSE_LIMIT {
326                        socket.abort();
327                    }
328                }
329                continue;
330            }
331
332            // smoltcp → proxy: flush read_buf first, then read from socket.
333            if let Some(pending) = conn.read_buf.take()
334                && let Err(e) = conn.to_proxy.try_send(pending)
335            {
336                conn.read_buf = Some(e.into_inner());
337            }
338
339            if conn.read_buf.is_none() {
340                while socket.can_recv() {
341                    match socket.recv_slice(&mut relay_buf) {
342                        Ok(n) if n > 0 => {
343                            let data = Bytes::copy_from_slice(&relay_buf[..n]);
344                            if let Err(e) = conn.to_proxy.try_send(data) {
345                                conn.read_buf = Some(e.into_inner());
346                                break;
347                            }
348                        }
349                        _ => break,
350                    }
351                }
352            }
353
354            // proxy → smoltcp: write pending data, then drain channel.
355            write_proxy_data(socket, conn);
356        }
357    }
358
359    /// Collect newly-established connections that need proxy tasks.
360    ///
361    /// Returns a list of [`NewConnection`] structs containing the channel ends
362    /// for the proxy task. The poll loop is responsible for spawning the task.
363    pub fn take_new_connections(&mut self, sockets: &mut SocketSet<'_>) -> Vec<NewConnection> {
364        let mut new = Vec::new();
365
366        for (&handle, conn) in &mut self.connections {
367            if conn.proxy_spawned {
368                continue;
369            }
370
371            let socket = sockets.get::<tcp::Socket>(handle);
372            if socket.state() == tcp::State::Established {
373                conn.proxy_spawned = true;
374
375                if let Some(channels) = conn.proxy_channels.take() {
376                    new.push(NewConnection {
377                        dst: conn.dst,
378                        from_smoltcp: channels.from_smoltcp,
379                        to_smoltcp: channels.to_smoltcp,
380                        proxy_connect: conn.proxy_connect.clone(),
381                    });
382                }
383            }
384        }
385
386        new
387    }
388
389    /// Remove closed connections and their sockets.
390    ///
391    /// Only removes sockets in the `Closed` state. Sockets in `TimeWait`
392    /// are left for smoltcp to handle naturally (2*MSL timer), preventing
393    /// delayed duplicate segments from being accepted by a reused port.
394    pub fn cleanup_closed(&mut self, sockets: &mut SocketSet<'_>) {
395        let keys = &mut self.connection_keys;
396        self.connections.retain(|&handle, conn| {
397            let socket = sockets.get::<tcp::Socket>(handle);
398            if matches!(socket.state(), tcp::State::Closed) {
399                keys.remove(&(conn.src, conn.dst));
400                sockets.remove(handle);
401                false
402            } else {
403                true
404            }
405        });
406    }
407}
408
409//--------------------------------------------------------------------------------------------------
410// Functions
411//--------------------------------------------------------------------------------------------------
412
413/// Try to write proxy data to the smoltcp socket.
414fn write_proxy_data(socket: &mut tcp::Socket<'_>, conn: &mut Connection) {
415    // First, try to finish writing any pending partial data.
416    if let Some((data, offset)) = &mut conn.write_buf {
417        if socket.can_send() {
418            match socket.send_slice(&data[*offset..]) {
419                Ok(written) => {
420                    *offset += written;
421                    if *offset >= data.len() {
422                        conn.write_buf = None;
423                    }
424                }
425                Err(_) => return,
426            }
427        } else {
428            return;
429        }
430    }
431
432    // Then drain the channel.
433    while conn.write_buf.is_none() {
434        match conn.from_proxy.try_recv() {
435            Ok(data) => {
436                if socket.can_send() {
437                    match socket.send_slice(&data) {
438                        Ok(written) if written < data.len() => {
439                            conn.write_buf = Some((data, written));
440                        }
441                        Err(_) => {
442                            conn.write_buf = Some((data, 0));
443                        }
444                        _ => {}
445                    }
446                } else {
447                    conn.write_buf = Some((data, 0));
448                }
449            }
450            Err(_) => break,
451        }
452    }
453}