Skip to main content

microsandbox_network/
proxy.rs

1//! Bidirectional TCP proxy: smoltcp socket ↔ channels ↔ tokio socket.
2//!
3//! Each outbound guest TCP connection gets a proxy task that opens a real
4//! TCP connection to the destination via tokio and relays data between the
5//! channel pair (connected to the smoltcp socket in the poll loop) and the
6//! real server.
7
8use std::io;
9use std::net::SocketAddr;
10use std::sync::Arc;
11
12use bytes::Bytes;
13use tokio::io::{AsyncReadExt, AsyncWriteExt};
14use tokio::net::TcpStream;
15use tokio::sync::mpsc;
16
17use crate::shared::SharedState;
18
19//--------------------------------------------------------------------------------------------------
20// Constants
21//--------------------------------------------------------------------------------------------------
22
/// Size in bytes (16 KiB) of the scratch buffer used for each read from the
/// real server socket.
const SERVER_READ_BUF_SIZE: usize = 16 * 1024;
25
26//--------------------------------------------------------------------------------------------------
27// Functions
28//--------------------------------------------------------------------------------------------------
29
30/// Spawn a TCP proxy task for a newly established connection.
31///
32/// Connects to `dst` via tokio, then bidirectionally relays data between
33/// the smoltcp socket (via channels) and the real server. Wakes the poll
34/// thread via `shared.proxy_wake` whenever data is sent toward the guest.
35pub fn spawn_tcp_proxy(
36    handle: &tokio::runtime::Handle,
37    dst: SocketAddr,
38    from_smoltcp: mpsc::Receiver<Bytes>,
39    to_smoltcp: mpsc::Sender<Bytes>,
40    shared: Arc<SharedState>,
41) {
42    handle.spawn(async move {
43        if let Err(e) = tcp_proxy_task(dst, from_smoltcp, to_smoltcp, shared).await {
44            tracing::debug!(dst = %dst, error = %e, "TCP proxy task ended");
45        }
46    });
47}
48
49/// Core TCP proxy: connect to real destination and relay bidirectionally.
50async fn tcp_proxy_task(
51    dst: SocketAddr,
52    mut from_smoltcp: mpsc::Receiver<Bytes>,
53    to_smoltcp: mpsc::Sender<Bytes>,
54    shared: Arc<SharedState>,
55) -> io::Result<()> {
56    let stream = TcpStream::connect(dst).await?;
57    let (mut server_rx, mut server_tx) = stream.into_split();
58
59    let mut server_buf = vec![0u8; SERVER_READ_BUF_SIZE];
60
61    // Bidirectional relay using tokio::select!.
62    //
63    // guest → server: receive from channel, write to server socket.
64    // server → guest: read from server socket, send via channel + wake poll.
65    loop {
66        tokio::select! {
67            // Guest → server.
68            data = from_smoltcp.recv() => {
69                match data {
70                    Some(bytes) => {
71                        if let Err(e) = server_tx.write_all(&bytes).await {
72                            tracing::debug!(dst = %dst, error = %e, "write to server failed");
73                            break;
74                        }
75                    }
76                    // Channel closed — smoltcp socket was closed by guest.
77                    None => break,
78                }
79            }
80
81            // Server → guest.
82            result = server_rx.read(&mut server_buf) => {
83                match result {
84                    Ok(0) => break, // Server closed connection.
85                    Ok(n) => {
86                        let data = Bytes::copy_from_slice(&server_buf[..n]);
87                        if to_smoltcp.send(data).await.is_err() {
88                            // Channel closed — poll loop dropped the receiver.
89                            break;
90                        }
91                        // Wake the poll thread so it writes data to the
92                        // smoltcp socket.
93                        shared.proxy_wake.wake();
94                    }
95                    Err(e) => {
96                        tracing::debug!(dst = %dst, error = %e, "read from server failed");
97                        break;
98                    }
99                }
100            }
101        }
102    }
103
104    Ok(())
105}