Skip to main content

microsandbox_network/
conn.rs

//! Connection tracker: manages smoltcp TCP sockets for the poll loop.
//!
//! Creates a socket when a guest SYN is detected, tracks each connection's
//! lifecycle, relays data between smoltcp sockets and proxy-task channels,
//! and removes connections once their sockets have closed.
6
7use std::collections::{HashMap, HashSet};
8use std::net::SocketAddr;
9
10use bytes::Bytes;
11use smoltcp::iface::{SocketHandle, SocketSet};
12use smoltcp::socket::tcp;
13use smoltcp::wire::IpListenEndpoint;
14use tokio::sync::mpsc;
15
16//--------------------------------------------------------------------------------------------------
17// Constants
18//--------------------------------------------------------------------------------------------------
19
/// Size of each smoltcp socket's receive buffer: 64 KiB.
const TCP_RX_BUF_SIZE: usize = 64 * 1024;

/// Size of each smoltcp socket's transmit buffer: 64 KiB.
const TCP_TX_BUF_SIZE: usize = 64 * 1024;

/// Connection limit used when the caller does not supply one.
const DEFAULT_MAX_CONNECTIONS: usize = 256;

/// Bounded capacity (in messages) of each mpsc channel linking the poll loop
/// and a proxy task.
const CHANNEL_CAPACITY: usize = 32;

/// Size of the scratch buffer used when reading from smoltcp sockets: 16 KiB.
const RELAY_BUF_SIZE: usize = 16 * 1024;
34
35//--------------------------------------------------------------------------------------------------
36// Types
37//--------------------------------------------------------------------------------------------------
38
39/// Tracks TCP connections between guest and proxy tasks.
40///
41/// Each guest TCP connection maps to a smoltcp socket and a pair of channels
42/// connecting it to a tokio proxy task. The tracker handles:
43///
44/// - **Socket creation** — on SYN detection, before smoltcp processes the frame.
45/// - **Data relay** — shuttles bytes between smoltcp sockets and channels.
46/// - **Lifecycle detection** — identifies newly-established connections for
47///   proxy spawning.
48/// - **Cleanup** — removes closed sockets from the socket set.
49pub struct ConnectionTracker {
50    /// Active connections keyed by smoltcp socket handle.
51    connections: HashMap<SocketHandle, Connection>,
52    /// Secondary index for O(1) duplicate-SYN detection by (src, dst) 4-tuple.
53    connection_keys: HashSet<(SocketAddr, SocketAddr)>,
54    /// Max concurrent connections (from NetworkConfig).
55    max_connections: usize,
56}
57
/// Number of poll iterations spent trying to flush leftover proxy → guest
/// data after the proxy task has exited; once it is reached the socket is
/// aborted rather than closed gracefully.
const DEFERRED_CLOSE_LIMIT: u16 = 64;
61
62/// Internal state for a single tracked TCP connection.
63struct Connection {
64    /// Guest source address (from the guest's SYN).
65    src: SocketAddr,
66    /// Original destination (from the guest's SYN).
67    dst: SocketAddr,
68    /// Sends data from smoltcp socket to proxy task (guest → server).
69    to_proxy: mpsc::Sender<Bytes>,
70    /// Receives data from proxy task to write to smoltcp socket (server → guest).
71    from_proxy: mpsc::Receiver<Bytes>,
72    /// Proxy-side channel ends, held until the connection is ESTABLISHED.
73    /// Taken by [`ConnectionTracker::take_new_connections()`].
74    proxy_channels: Option<ProxyChannels>,
75    /// Whether a proxy task has been spawned for this connection.
76    proxy_spawned: bool,
77    /// Partial data from proxy that couldn't be fully written to smoltcp socket.
78    write_buf: Option<(Bytes, usize)>,
79    /// Data read from smoltcp socket that couldn't be sent to proxy (channel full).
80    /// Must be sent before reading more from the socket to preserve stream order.
81    read_buf: Option<Bytes>,
82    /// Counter for deferred close attempts (prevents stalling forever).
83    close_attempts: u16,
84}
85
86/// Proxy-side channel ends, created at socket creation time and taken when
87/// the connection becomes ESTABLISHED.
88struct ProxyChannels {
89    /// Receive data from smoltcp socket (guest → proxy task).
90    from_smoltcp: mpsc::Receiver<Bytes>,
91    /// Send data to smoltcp socket (proxy task → guest).
92    to_smoltcp: mpsc::Sender<Bytes>,
93}
94
95/// Information for spawning a proxy task for a newly established connection.
96///
97/// Returned by [`ConnectionTracker::take_new_connections()`]. The poll loop
98/// passes this to the proxy task spawner.
99pub struct NewConnection {
100    /// Original destination the guest was connecting to.
101    pub dst: SocketAddr,
102    /// Receive data from smoltcp socket (guest → proxy task).
103    pub from_smoltcp: mpsc::Receiver<Bytes>,
104    /// Send data to smoltcp socket (proxy task → guest).
105    pub to_smoltcp: mpsc::Sender<Bytes>,
106}
107
108//--------------------------------------------------------------------------------------------------
109// Methods
110//--------------------------------------------------------------------------------------------------
111
112impl ConnectionTracker {
113    /// Create a new tracker with the given connection limit.
114    pub fn new(max_connections: Option<usize>) -> Self {
115        Self {
116            connections: HashMap::new(),
117            connection_keys: HashSet::new(),
118            max_connections: max_connections.unwrap_or(DEFAULT_MAX_CONNECTIONS),
119        }
120    }
121
122    /// Returns `true` if a tracked socket already exists for this exact
123    /// connection (same source AND destination). O(1) via HashSet lookup.
124    pub fn has_socket_for(&self, src: &SocketAddr, dst: &SocketAddr) -> bool {
125        self.connection_keys.contains(&(*src, *dst))
126    }
127
128    /// Create a smoltcp TCP socket for an incoming SYN and register it.
129    ///
130    /// The socket is put into LISTEN state on the destination IP + port so
131    /// smoltcp will complete the three-way handshake when it processes the
132    /// SYN frame. Binding to the specific destination IP (not just port)
133    /// prevents socket dispatch ambiguity when multiple connections target
134    /// different IPs on the same port.
135    ///
136    /// Returns `false` if at `max_connections` limit.
137    pub fn create_tcp_socket(
138        &mut self,
139        src: SocketAddr,
140        dst: SocketAddr,
141        sockets: &mut SocketSet<'_>,
142    ) -> bool {
143        if self.connections.len() >= self.max_connections {
144            return false;
145        }
146
147        // Create smoltcp TCP socket with buffers.
148        let rx_buf = tcp::SocketBuffer::new(vec![0u8; TCP_RX_BUF_SIZE]);
149        let tx_buf = tcp::SocketBuffer::new(vec![0u8; TCP_TX_BUF_SIZE]);
150        let mut socket = tcp::Socket::new(rx_buf, tx_buf);
151
152        // Listen on the specific destination IP + port. With any_ip mode,
153        // binding to the IP ensures the correct socket accepts each SYN
154        // when multiple connections target the same port on different IPs.
155        let listen_endpoint = IpListenEndpoint {
156            addr: Some(dst.ip().into()),
157            port: dst.port(),
158        };
159        if socket.listen(listen_endpoint).is_err() {
160            return false;
161        }
162
163        let handle = sockets.add(socket);
164
165        // Create channel pairs for proxy task communication.
166        //
167        // smoltcp → proxy (guest sends data, proxy relays to server):
168        let (to_proxy_tx, to_proxy_rx) = mpsc::channel(CHANNEL_CAPACITY);
169        // proxy → smoltcp (server sends data, proxy relays to guest):
170        let (from_proxy_tx, from_proxy_rx) = mpsc::channel(CHANNEL_CAPACITY);
171
172        self.connection_keys.insert((src, dst));
173        self.connections.insert(
174            handle,
175            Connection {
176                src,
177                dst,
178                to_proxy: to_proxy_tx,
179                from_proxy: from_proxy_rx,
180                proxy_channels: Some(ProxyChannels {
181                    from_smoltcp: to_proxy_rx,
182                    to_smoltcp: from_proxy_tx,
183                }),
184                proxy_spawned: false,
185                write_buf: None,
186                read_buf: None,
187                close_attempts: 0,
188            },
189        );
190
191        true
192    }
193
194    /// Relay data between smoltcp sockets and proxy task channels.
195    ///
196    /// For each connection with a spawned proxy:
197    /// - Reads data from the smoltcp socket and sends it to the proxy channel.
198    /// - Receives data from the proxy channel and writes it to the smoltcp socket.
199    pub fn relay_data(&mut self, sockets: &mut SocketSet<'_>) {
200        let mut relay_buf = [0u8; RELAY_BUF_SIZE];
201
202        for (&handle, conn) in &mut self.connections {
203            if !conn.proxy_spawned {
204                continue;
205            }
206
207            let socket = sockets.get_mut::<tcp::Socket>(handle);
208
209            // Detect proxy task exit: when the proxy drops its channel
210            // ends, close the smoltcp socket so the guest gets a FIN.
211            if conn.to_proxy.is_closed() {
212                write_proxy_data(socket, conn);
213                if conn.write_buf.is_none() {
214                    socket.close();
215                } else {
216                    // Abort if we've been trying to flush for too long
217                    // (guest stopped reading, socket send buffer full).
218                    conn.close_attempts += 1;
219                    if conn.close_attempts >= DEFERRED_CLOSE_LIMIT {
220                        socket.abort();
221                    }
222                }
223                continue;
224            }
225
226            // smoltcp → proxy: flush read_buf first, then read from socket.
227            if let Some(pending) = conn.read_buf.take()
228                && let Err(e) = conn.to_proxy.try_send(pending)
229            {
230                conn.read_buf = Some(e.into_inner());
231            }
232
233            if conn.read_buf.is_none() {
234                while socket.can_recv() {
235                    match socket.recv_slice(&mut relay_buf) {
236                        Ok(n) if n > 0 => {
237                            let data = Bytes::copy_from_slice(&relay_buf[..n]);
238                            if let Err(e) = conn.to_proxy.try_send(data) {
239                                conn.read_buf = Some(e.into_inner());
240                                break;
241                            }
242                        }
243                        _ => break,
244                    }
245                }
246            }
247
248            // proxy → smoltcp: write pending data, then drain channel.
249            write_proxy_data(socket, conn);
250        }
251    }
252
253    /// Collect newly-established connections that need proxy tasks.
254    ///
255    /// Returns a list of [`NewConnection`] structs containing the channel ends
256    /// for the proxy task. The poll loop is responsible for spawning the task.
257    pub fn take_new_connections(&mut self, sockets: &mut SocketSet<'_>) -> Vec<NewConnection> {
258        let mut new = Vec::new();
259
260        for (&handle, conn) in &mut self.connections {
261            if conn.proxy_spawned {
262                continue;
263            }
264
265            let socket = sockets.get::<tcp::Socket>(handle);
266            if socket.state() == tcp::State::Established {
267                conn.proxy_spawned = true;
268
269                if let Some(channels) = conn.proxy_channels.take() {
270                    new.push(NewConnection {
271                        dst: conn.dst,
272                        from_smoltcp: channels.from_smoltcp,
273                        to_smoltcp: channels.to_smoltcp,
274                    });
275                }
276            }
277        }
278
279        new
280    }
281
282    /// Remove closed connections and their sockets.
283    ///
284    /// Only removes sockets in the `Closed` state. Sockets in `TimeWait`
285    /// are left for smoltcp to handle naturally (2*MSL timer), preventing
286    /// delayed duplicate segments from being accepted by a reused port.
287    pub fn cleanup_closed(&mut self, sockets: &mut SocketSet<'_>) {
288        let keys = &mut self.connection_keys;
289        self.connections.retain(|&handle, conn| {
290            let socket = sockets.get::<tcp::Socket>(handle);
291            if matches!(socket.state(), tcp::State::Closed) {
292                keys.remove(&(conn.src, conn.dst));
293                sockets.remove(handle);
294                false
295            } else {
296                true
297            }
298        });
299    }
300}
301
302//--------------------------------------------------------------------------------------------------
303// Functions
304//--------------------------------------------------------------------------------------------------
305
306/// Try to write proxy data to the smoltcp socket.
307fn write_proxy_data(socket: &mut tcp::Socket<'_>, conn: &mut Connection) {
308    // First, try to finish writing any pending partial data.
309    if let Some((data, offset)) = &mut conn.write_buf {
310        if socket.can_send() {
311            match socket.send_slice(&data[*offset..]) {
312                Ok(written) => {
313                    *offset += written;
314                    if *offset >= data.len() {
315                        conn.write_buf = None;
316                    }
317                }
318                Err(_) => return,
319            }
320        } else {
321            return;
322        }
323    }
324
325    // Then drain the channel.
326    while conn.write_buf.is_none() {
327        match conn.from_proxy.try_recv() {
328            Ok(data) => {
329                if socket.can_send() {
330                    match socket.send_slice(&data) {
331                        Ok(written) if written < data.len() => {
332                            conn.write_buf = Some((data, written));
333                        }
334                        Err(_) => {
335                            conn.write_buf = Some((data, 0));
336                        }
337                        _ => {}
338                    }
339                } else {
340                    conn.write_buf = Some((data, 0));
341                }
342            }
343            Err(_) => break,
344        }
345    }
346}