Skip to main content

microsandbox_agentd/
tcp.rs

//! Guest-side TCP stream session handling.
//!
//! Handles `core.tcp.*` protocol messages by opening TCP sockets from
//! inside the guest and relaying bytes between those sockets and the host.
5
6use std::time::Duration;
7
8use tokio::io::{AsyncReadExt, AsyncWriteExt};
9use tokio::net::TcpStream;
10use tokio::sync::mpsc;
11use tokio::task::JoinHandle;
12
13use microsandbox_protocol::codec;
14use microsandbox_protocol::message::{Message, MessageType};
15use microsandbox_protocol::tcp::{TcpClosed, TcpConnect, TcpConnected, TcpData, TcpEof, TcpFailed};
16
17use crate::session::SessionOutput;
18
19//--------------------------------------------------------------------------------------------------
20// Constants
21//--------------------------------------------------------------------------------------------------
22
/// Size in bytes of the buffer used for each read from the guest socket (64 KiB).
const TCP_CHUNK_SIZE: usize = 64 * 1024;

/// Number of host->guest command frames that may be queued per session before
/// the sender has to wait.
///
/// Bounding the queue turns a slow or stalled destination into backpressure
/// (the serial reader pauses, which throttles the SSH window) instead of
/// unbounded guest memory growth.
const TCP_COMMAND_CAPACITY: usize = 32;

/// Upper bound on a single guest-side connect attempt.
///
/// The connect runs in the per-session task, so this only bounds that task's
/// lifetime; it never blocks the agent's serial loop.
const TCP_CONNECT_TIMEOUT: Duration = Duration::from_secs(30);
36
37//--------------------------------------------------------------------------------------------------
38// Types
39//--------------------------------------------------------------------------------------------------
40
41/// Tracks an active guest-originated TCP stream.
42pub struct TcpSession {
43    owner_id: u32,
44    commands: mpsc::Sender<TcpCommand>,
45    task: JoinHandle<()>,
46}
47
/// Command queued from the session handle to its relay task.
#[derive(Debug)]
enum TcpCommand {
    /// Bytes to write to the guest socket.
    Data(Vec<u8>),
    /// Shut down the write half of the guest socket.
    Eof,
}
52
53//--------------------------------------------------------------------------------------------------
54// Methods
55//--------------------------------------------------------------------------------------------------
56
57impl TcpSession {
58    /// Correlation ID whose relay client owns this TCP stream.
59    pub fn owner_id(&self) -> u32 {
60        self.owner_id
61    }
62
63    /// Queue stream data to write to the guest socket.
64    ///
65    /// Awaits queue space when the per-session relay is behind, so a stalled
66    /// destination backpressures the caller instead of growing memory.
67    pub async fn write_data(&self, data: Vec<u8>) -> Result<(), String> {
68        self.commands
69            .send(TcpCommand::Data(data))
70            .await
71            .map_err(|_| "TCP session is closed".to_string())
72    }
73
74    /// Close the guest socket write half.
75    ///
76    /// Ordered after any queued data, so the destination sees the write shutdown
77    /// only once it has received everything sent before it.
78    pub async fn close_write(&self) -> Result<(), String> {
79        self.commands
80            .send(TcpCommand::Eof)
81            .await
82            .map_err(|_| "TCP session is closed".to_string())
83    }
84
85    /// Tear down the TCP session.
86    ///
87    /// Aborts the relay task directly rather than queuing a command, so teardown
88    /// never waits behind a full command queue. Dropping the task closes the
89    /// guest socket. The host has already closed its side before asking for this,
90    /// so no terminal frame is owed back to it.
91    pub fn close(&self) {
92        self.task.abort();
93    }
94
95    /// Returns whether the background relay task has finished.
96    pub fn is_finished(&self) -> bool {
97        self.task.is_finished()
98    }
99
100    /// Open a TCP stream from inside the guest and start relaying it.
101    ///
102    /// The OS connect runs inside the spawned task, not on the caller's serial
103    /// loop, so a hanging or slow destination can never wedge the agent. The
104    /// task reports `core.tcp.connected` on success or a terminal
105    /// `core.tcp.failed` on error/timeout over `session_tx`; the host correlates
106    /// either reply by id. The returned session is live immediately, with
107    /// commands queued until the connect completes.
108    pub fn open(
109        id: u32,
110        req: TcpConnect,
111        session_tx: &mpsc::UnboundedSender<(u32, SessionOutput)>,
112    ) -> Self {
113        let (commands_tx, commands_rx) = mpsc::channel(TCP_COMMAND_CAPACITY);
114        let output_tx = session_tx.clone();
115        let task = tokio::spawn(async move {
116            connect_and_relay(id, req, commands_rx, output_tx).await;
117        });
118
119        Self {
120            owner_id: id,
121            commands: commands_tx,
122            task,
123        }
124    }
125}
126
//--------------------------------------------------------------------------------------------------
// Functions: Helpers
//--------------------------------------------------------------------------------------------------
130
131/// Connects to the destination, reports the outcome, then relays the stream.
132///
133/// Runs entirely inside the per-session task. On a connect error or timeout it
134/// emits a terminal `core.tcp.failed`; the agent loop removes the session when
135/// that frame flows past. On success it emits `core.tcp.connected` and hands off
136/// to the relay loop.
137async fn connect_and_relay(
138    id: u32,
139    req: TcpConnect,
140    commands: mpsc::Receiver<TcpCommand>,
141    tx: mpsc::UnboundedSender<(u32, SessionOutput)>,
142) {
143    let connect = TcpStream::connect((req.host.as_str(), req.port));
144    let stream = match tokio::time::timeout(TCP_CONNECT_TIMEOUT, connect).await {
145        Ok(Ok(stream)) => stream,
146        Ok(Err(e)) => {
147            send_raw_tcp_message(
148                id,
149                MessageType::TcpFailed,
150                &TcpFailed {
151                    error: format!("connect {}:{}: {e}", req.host, req.port),
152                },
153                &tx,
154            );
155            return;
156        }
157        Err(_elapsed) => {
158            send_raw_tcp_message(
159                id,
160                MessageType::TcpFailed,
161                &TcpFailed {
162                    error: format!("connect {}:{} timed out", req.host, req.port),
163                },
164                &tx,
165            );
166            return;
167        }
168    };
169
170    if !send_raw_tcp_message(id, MessageType::TcpConnected, &TcpConnected {}, &tx) {
171        return;
172    }
173
174    relay_tcp_session(id, stream, commands, tx).await;
175}
176
177async fn relay_tcp_session(
178    id: u32,
179    mut stream: TcpStream,
180    mut commands: mpsc::Receiver<TcpCommand>,
181    tx: mpsc::UnboundedSender<(u32, SessionOutput)>,
182) {
183    let mut read_buf = vec![0u8; TCP_CHUNK_SIZE];
184    let mut terminal_sent = false;
185    // The destination half-closed its write side. We stop reading but keep the
186    // loop alive so host->destination data still flows until the host closes.
187    let mut read_eof = false;
188
189    loop {
190        tokio::select! {
191            read = stream.read(&mut read_buf), if !read_eof => {
192                match read {
193                    Ok(0) => {
194                        send_raw_tcp_message(id, MessageType::TcpEof, &TcpEof {}, &tx);
195                        read_eof = true;
196                    }
197                    Ok(n) => {
198                        if !send_raw_tcp_message(
199                            id,
200                            MessageType::TcpData,
201                            &TcpData {
202                                data: read_buf[..n].to_vec(),
203                            },
204                            &tx,
205                        ) {
206                            break;
207                        }
208                    }
209                    Err(e) => {
210                        terminal_sent = send_raw_tcp_message(
211                            id,
212                            MessageType::TcpFailed,
213                            &TcpFailed {
214                                error: format!("read TCP stream: {e}"),
215                            },
216                            &tx,
217                        );
218                        break;
219                    }
220                }
221            }
222            command = commands.recv() => {
223                match command {
224                    Some(TcpCommand::Data(data)) => {
225                        if let Err(e) = stream.write_all(&data).await {
226                            terminal_sent = send_raw_tcp_message(
227                                id,
228                                MessageType::TcpFailed,
229                                &TcpFailed {
230                                    error: format!("write TCP stream: {e}"),
231                                },
232                                &tx,
233                            );
234                            break;
235                        }
236                    }
237                    Some(TcpCommand::Eof) => {
238                        if let Err(e) = stream.shutdown().await {
239                            terminal_sent = send_raw_tcp_message(
240                                id,
241                                MessageType::TcpFailed,
242                                &TcpFailed {
243                                    error: format!("shutdown TCP stream: {e}"),
244                                },
245                                &tx,
246                            );
247                            break;
248                        }
249                    }
250                    None => {
251                        break;
252                    }
253                }
254            }
255        }
256    }
257
258    if !terminal_sent {
259        send_raw_tcp_message(id, MessageType::TcpClosed, &TcpClosed {}, &tx);
260    }
261}
262
263fn encode_tcp_message<T: serde::Serialize>(
264    id: u32,
265    t: MessageType,
266    payload: &T,
267    out_buf: &mut Vec<u8>,
268) -> Result<(), String> {
269    let msg = Message::with_payload(t, id, payload).map_err(|e| format!("encode tcp: {e}"))?;
270    codec::encode_to_buf(&msg, out_buf).map_err(|e| format!("encode tcp frame: {e}"))?;
271    Ok(())
272}
273
274fn send_raw_tcp_message<T: serde::Serialize>(
275    id: u32,
276    t: MessageType,
277    payload: &T,
278    tx: &mpsc::UnboundedSender<(u32, SessionOutput)>,
279) -> bool {
280    let mut buf = Vec::new();
281    match encode_tcp_message(id, t, payload, &mut buf) {
282        Ok(()) => tx.send((id, SessionOutput::Raw(buf))).is_ok(),
283        Err(e) => {
284            eprintln!("failed to encode tcp message for {id}: {e}");
285            false
286        }
287    }
288}
289
290//--------------------------------------------------------------------------------------------------
291// Tests
292//--------------------------------------------------------------------------------------------------
293
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use microsandbox_protocol::message::FLAG_TERMINAL;
    use tokio::net::TcpListener;

    use super::*;

    /// Longest a test waits for one frame from the session before failing.
    ///
    /// Without a bound, a regression that stops a frame from being sent would
    /// hang the test run instead of failing it.
    const RECV_TIMEOUT: Duration = Duration::from_secs(5);

    /// A failed connect is reported as a terminal `TcpFailed` frame and the
    /// session task then finishes.
    #[tokio::test]
    async fn connect_failure_sends_terminal_failed() {
        let (session_tx, mut session_rx) = mpsc::unbounded_channel();

        let session = TcpSession::open(
            7,
            TcpConnect {
                host: "127.0.0.1".to_string(),
                port: 0,
            },
            &session_tx,
        );

        // The connect runs in the task and reports failure over session_tx.
        let msg = recv_message(&mut session_rx).await;
        assert_eq!(msg.t, MessageType::TcpFailed);
        assert_eq!(msg.flags, FLAG_TERMINAL);
        let failed: TcpFailed = msg.payload().unwrap();
        assert!(failed.error.contains("connect 127.0.0.1:0"));

        wait_finished(&session).await;
    }

    /// `close` ends the relay task even while the destination keeps the
    /// connection open.
    #[tokio::test]
    async fn close_request_finishes_session_task() {
        let listener = TcpListener::bind(("127.0.0.1", 0)).await.unwrap();
        let port = listener.local_addr().unwrap().port();
        let (session_tx, mut session_rx) = mpsc::unbounded_channel();
        let accept_task = tokio::spawn(async move {
            let (_socket, _) = listener.accept().await.unwrap();
            tokio::time::sleep(Duration::from_secs(5)).await;
        });

        let session = TcpSession::open(
            9,
            TcpConnect {
                host: "127.0.0.1".to_string(),
                port,
            },
            &session_tx,
        );

        let connected = recv_message(&mut session_rx).await;
        assert_eq!(connected.t, MessageType::TcpConnected);

        session.close();
        wait_finished(&session).await;

        accept_task.abort();
    }

    /// A destination-side EOF is forwarded as a non-terminal `TcpEof`, after
    /// which host writes still reach the destination.
    #[tokio::test]
    async fn destination_eof_keeps_session_open_for_host_writes() {
        let listener = TcpListener::bind(("127.0.0.1", 0)).await.unwrap();
        let port = listener.local_addr().unwrap().port();
        let (session_tx, mut session_rx) = mpsc::unbounded_channel();

        // The destination half-closes its write side, then keeps reading so it
        // still receives whatever the host sends after the EOF.
        let (got_tx, got_rx) = tokio::sync::oneshot::channel();
        let accept_task = tokio::spawn(async move {
            let (mut socket, _) = listener.accept().await.unwrap();
            socket.shutdown().await.unwrap();
            let mut buf = Vec::new();
            socket.read_to_end(&mut buf).await.unwrap();
            let _ = got_tx.send(buf);
        });

        let session = TcpSession::open(
            11,
            TcpConnect {
                host: "127.0.0.1".to_string(),
                port,
            },
            &session_tx,
        );

        let connected = recv_message(&mut session_rx).await;
        assert_eq!(connected.t, MessageType::TcpConnected);

        // The destination's FIN surfaces as a non-terminal TcpEof, and the
        // session stays alive.
        let eof = recv_message(&mut session_rx).await;
        assert_eq!(eof.t, MessageType::TcpEof);
        assert_ne!(eof.flags, FLAG_TERMINAL);
        assert!(!session.is_finished());

        // The host can still reach the destination after that EOF.
        session.write_data(b"after-eof".to_vec()).await.unwrap();
        session.close_write().await.unwrap();
        let received = tokio::time::timeout(Duration::from_secs(1), got_rx)
            .await
            .unwrap()
            .unwrap();
        assert_eq!(received, b"after-eof");

        // An explicit close tears the session down.
        session.close();
        wait_finished(&session).await;

        accept_task.await.unwrap();
    }

    /// Polls until the session task has finished, panicking after one second.
    async fn wait_finished(session: &TcpSession) {
        tokio::time::timeout(Duration::from_secs(1), async {
            while !session.is_finished() {
                tokio::time::sleep(Duration::from_millis(10)).await;
            }
        })
        .await
        .unwrap();
    }

    /// Decodes exactly one protocol message from the front of `buf`.
    fn decode_one_message(buf: &mut Vec<u8>) -> Message {
        codec::try_decode_from_buf(buf).unwrap().unwrap()
    }

    /// Receives the next raw frame from the session output channel and decodes
    /// it, panicking if none arrives within `RECV_TIMEOUT`.
    async fn recv_message(rx: &mut mpsc::UnboundedReceiver<(u32, SessionOutput)>) -> Message {
        let (_id, output) = tokio::time::timeout(RECV_TIMEOUT, rx.recv())
            .await
            .expect("timed out waiting for a session frame")
            .expect("session output channel closed");
        let SessionOutput::Raw(mut bytes) = output else {
            panic!("expected SessionOutput::Raw frame");
        };
        decode_one_message(&mut bytes)
    }
}