Skip to main content

microsandbox_agentd/
tcp.rs

//! Guest-side TCP stream session handling.
//!
//! Handles `core.tcp.*` protocol messages by opening TCP sockets from
//! inside the guest and relaying bytes between those sockets and the host.
5
6use std::time::Duration;
7
8use tokio::io::{AsyncReadExt, AsyncWriteExt};
9use tokio::net::TcpStream;
10use tokio::sync::mpsc;
11use tokio::task::JoinHandle;
12
13use microsandbox_protocol::codec;
14use microsandbox_protocol::message::{Message, MessageType};
15use microsandbox_protocol::tcp::{TcpClosed, TcpConnect, TcpConnected, TcpData, TcpEof, TcpFailed};
16
17use crate::session::{RawActivity, RawSessionCompletion, RawSessionOutput, SessionOutput};
18
19//--------------------------------------------------------------------------------------------------
20// Constants
21//--------------------------------------------------------------------------------------------------
22
/// Size of the buffer each read from the guest socket fills (64 KiB).
const TCP_CHUNK_SIZE: usize = 65_536;

/// Capacity of the per-session host->guest command queue.
///
/// Once this many frames are waiting, the agent loop has to wait for space.
/// That turns a slow or stalled destination into backpressure (the serial
/// reader pauses, which throttles the SSH window) rather than unbounded guest
/// memory growth.
const TCP_COMMAND_CAPACITY: usize = 32;

/// Upper bound on one guest-side connect attempt.
///
/// The connect runs in the per-session task, so this limits only that task's
/// lifetime and never blocks the agent's serial loop.
const TCP_CONNECT_TIMEOUT: Duration = Duration::from_secs(30);
36
37//--------------------------------------------------------------------------------------------------
38// Types
39//--------------------------------------------------------------------------------------------------
40
41/// Tracks an active guest-originated TCP stream.
42pub struct TcpSession {
43    owner_id: u32,
44    commands: mpsc::Sender<TcpCommand>,
45    task: JoinHandle<()>,
46}
47
48enum TcpCommand {
49    Data(Vec<u8>),
50    Eof,
51}
52
53//--------------------------------------------------------------------------------------------------
54// Methods
55//--------------------------------------------------------------------------------------------------
56
57impl TcpSession {
58    /// Correlation ID whose relay client owns this TCP stream.
59    pub fn owner_id(&self) -> u32 {
60        self.owner_id
61    }
62
63    /// Queue stream data to write to the guest socket.
64    ///
65    /// Awaits queue space when the per-session relay is behind, so a stalled
66    /// destination backpressures the caller instead of growing memory.
67    pub async fn write_data(&self, data: Vec<u8>) -> Result<(), String> {
68        self.commands
69            .send(TcpCommand::Data(data))
70            .await
71            .map_err(|_| "TCP session is closed".to_string())
72    }
73
74    /// Close the guest socket write half.
75    ///
76    /// Ordered after any queued data, so the destination sees the write shutdown
77    /// only once it has received everything sent before it.
78    pub async fn close_write(&self) -> Result<(), String> {
79        self.commands
80            .send(TcpCommand::Eof)
81            .await
82            .map_err(|_| "TCP session is closed".to_string())
83    }
84
85    /// Tear down the TCP session.
86    ///
87    /// Aborts the relay task directly rather than queuing a command, so teardown
88    /// never waits behind a full command queue. Dropping the task closes the
89    /// guest socket. The host has already closed its side before asking for this,
90    /// so no terminal frame is owed back to it.
91    pub fn close(&self) {
92        self.task.abort();
93    }
94
95    /// Returns whether the background relay task has finished.
96    pub fn is_finished(&self) -> bool {
97        self.task.is_finished()
98    }
99
100    /// Open a TCP stream from inside the guest and start relaying it.
101    ///
102    /// The OS connect runs inside the spawned task, not on the caller's serial
103    /// loop, so a hanging or slow destination can never wedge the agent. The
104    /// task reports `core.tcp.connected` on success or a terminal
105    /// `core.tcp.failed` on error/timeout over `session_tx`; the host correlates
106    /// either reply by id. The returned session is live immediately, with
107    /// commands queued until the connect completes.
108    pub fn open(
109        id: u32,
110        req: TcpConnect,
111        session_tx: &mpsc::UnboundedSender<(u32, SessionOutput)>,
112    ) -> Self {
113        let (commands_tx, commands_rx) = mpsc::channel(TCP_COMMAND_CAPACITY);
114        let output_tx = session_tx.clone();
115        let task = tokio::spawn(async move {
116            connect_and_relay(id, req, commands_rx, output_tx).await;
117        });
118
119        Self {
120            owner_id: id,
121            commands: commands_tx,
122            task,
123        }
124    }
125}
126
127//--------------------------------------------------------------------------------------------------
128// Functions: Helpers
129//--------------------------------------------------------------------------------------------------
130
131/// Connects to the destination, reports the outcome, then relays the stream.
132///
133/// Runs entirely inside the per-session task. On a connect error or timeout it
134/// emits a terminal `core.tcp.failed`; the agent loop removes the session when
135/// that frame flows past. On success it emits `core.tcp.connected` and hands off
136/// to the relay loop.
137async fn connect_and_relay(
138    id: u32,
139    req: TcpConnect,
140    commands: mpsc::Receiver<TcpCommand>,
141    tx: mpsc::UnboundedSender<(u32, SessionOutput)>,
142) {
143    let connect = TcpStream::connect((req.host.as_str(), req.port));
144    let stream = match tokio::time::timeout(TCP_CONNECT_TIMEOUT, connect).await {
145        Ok(Ok(stream)) => stream,
146        Ok(Err(e)) => {
147            send_raw_tcp_message(
148                id,
149                MessageType::TcpFailed,
150                &TcpFailed {
151                    error: format!("connect {}:{}: {e}", req.host, req.port),
152                },
153                RawActivity::guest_message(),
154                Some(RawSessionCompletion::Tcp),
155                &tx,
156            );
157            return;
158        }
159        Err(_elapsed) => {
160            send_raw_tcp_message(
161                id,
162                MessageType::TcpFailed,
163                &TcpFailed {
164                    error: format!("connect {}:{} timed out", req.host, req.port),
165                },
166                RawActivity::guest_message(),
167                Some(RawSessionCompletion::Tcp),
168                &tx,
169            );
170            return;
171        }
172    };
173
174    if !send_raw_tcp_message(
175        id,
176        MessageType::TcpConnected,
177        &TcpConnected {},
178        RawActivity::guest_message(),
179        None,
180        &tx,
181    ) {
182        return;
183    }
184
185    relay_tcp_session(id, stream, commands, tx).await;
186}
187
188async fn relay_tcp_session(
189    id: u32,
190    mut stream: TcpStream,
191    mut commands: mpsc::Receiver<TcpCommand>,
192    tx: mpsc::UnboundedSender<(u32, SessionOutput)>,
193) {
194    let mut read_buf = vec![0u8; TCP_CHUNK_SIZE];
195    let mut terminal_sent = false;
196    // The destination half-closed its write side. We stop reading but keep the
197    // loop alive so host->destination data still flows until the host closes.
198    let mut read_eof = false;
199
200    loop {
201        tokio::select! {
202            read = stream.read(&mut read_buf), if !read_eof => {
203                match read {
204                    Ok(0) => {
205                        send_raw_tcp_message(
206                            id,
207                            MessageType::TcpEof,
208                            &TcpEof {},
209                            RawActivity::guest_message(),
210                            None,
211                            &tx,
212                        );
213                        read_eof = true;
214                    }
215                    Ok(n) => {
216                        let data = read_buf[..n].to_vec();
217                        if !send_raw_tcp_message(
218                            id,
219                            MessageType::TcpData,
220                            &TcpData { data },
221                            RawActivity::tcp_bytes(n),
222                            None,
223                            &tx,
224                        ) {
225                            break;
226                        }
227                    }
228                    Err(e) => {
229                        terminal_sent = send_raw_tcp_message(
230                            id,
231                            MessageType::TcpFailed,
232                            &TcpFailed {
233                                error: format!("read TCP stream: {e}"),
234                            },
235                            RawActivity::guest_message(),
236                            Some(RawSessionCompletion::Tcp),
237                            &tx,
238                        );
239                        break;
240                    }
241                }
242            }
243            command = commands.recv() => {
244                match command {
245                    Some(TcpCommand::Data(data)) => {
246                        if let Err(e) = stream.write_all(&data).await {
247                            terminal_sent = send_raw_tcp_message(
248                                id,
249                                MessageType::TcpFailed,
250                                &TcpFailed {
251                                    error: format!("write TCP stream: {e}"),
252                                },
253                                RawActivity::guest_message(),
254                                Some(RawSessionCompletion::Tcp),
255                                &tx,
256                            );
257                            break;
258                        }
259                    }
260                    Some(TcpCommand::Eof) => {
261                        if let Err(e) = stream.shutdown().await {
262                            terminal_sent = send_raw_tcp_message(
263                                id,
264                                MessageType::TcpFailed,
265                                &TcpFailed {
266                                    error: format!("shutdown TCP stream: {e}"),
267                                },
268                                RawActivity::guest_message(),
269                                Some(RawSessionCompletion::Tcp),
270                                &tx,
271                            );
272                            break;
273                        }
274                    }
275                    None => {
276                        break;
277                    }
278                }
279            }
280        }
281    }
282
283    if !terminal_sent {
284        send_raw_tcp_message(
285            id,
286            MessageType::TcpClosed,
287            &TcpClosed {},
288            RawActivity::guest_message(),
289            Some(RawSessionCompletion::Tcp),
290            &tx,
291        );
292    }
293}
294
295fn encode_tcp_message<T: serde::Serialize>(
296    id: u32,
297    t: MessageType,
298    payload: &T,
299    out_buf: &mut Vec<u8>,
300) -> Result<(), String> {
301    let msg = Message::with_payload(t, id, payload).map_err(|e| format!("encode tcp: {e}"))?;
302    codec::encode_to_buf(&msg, out_buf).map_err(|e| format!("encode tcp frame: {e}"))?;
303    Ok(())
304}
305
306fn send_raw_tcp_message<T: serde::Serialize>(
307    id: u32,
308    t: MessageType,
309    payload: &T,
310    activity: RawActivity,
311    completion: Option<RawSessionCompletion>,
312    tx: &mpsc::UnboundedSender<(u32, SessionOutput)>,
313) -> bool {
314    let mut buf = Vec::new();
315    match encode_tcp_message(id, t, payload, &mut buf) {
316        Ok(()) => tx
317            .send((
318                id,
319                SessionOutput::Raw(RawSessionOutput::new(buf, activity, completion)),
320            ))
321            .is_ok(),
322        Err(e) => {
323            eprintln!("failed to encode tcp message for {id}: {e}");
324            false
325        }
326    }
327}
328
329//--------------------------------------------------------------------------------------------------
330// Tests
331//--------------------------------------------------------------------------------------------------
332
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use microsandbox_protocol::message::FLAG_TERMINAL;
    use tokio::net::TcpListener;

    use super::*;

    #[tokio::test]
    async fn connect_failure_sends_terminal_failed() {
        let (tx, mut rx) = mpsc::unbounded_channel();
        let session = open_local(7, 0, &tx);

        // Connecting to port 0 fails; the task reports it over the channel.
        let frame = next_frame(&mut rx).await;
        assert_eq!(frame.t, MessageType::TcpFailed);
        assert_eq!(frame.flags, FLAG_TERMINAL);
        let failed: TcpFailed = frame.payload().unwrap();
        assert!(failed.error.contains("connect 127.0.0.1:0"));

        wait_finished(&session).await;
    }

    #[tokio::test]
    async fn close_request_finishes_session_task() {
        let (listener, port) = bind_loopback().await;
        let (tx, mut rx) = mpsc::unbounded_channel();

        // Keep the accepted socket open so only `close()` can end the session.
        let peer = tokio::spawn(async move {
            let (_socket, _) = listener.accept().await.unwrap();
            tokio::time::sleep(Duration::from_secs(5)).await;
        });

        let session = open_local(9, port, &tx);
        assert_eq!(next_frame(&mut rx).await.t, MessageType::TcpConnected);

        session.close();
        wait_finished(&session).await;

        peer.abort();
    }

    #[tokio::test]
    async fn destination_eof_keeps_session_open_for_host_writes() {
        let (listener, port) = bind_loopback().await;
        let (tx, mut rx) = mpsc::unbounded_channel();

        // The destination half-closes its write side first, then drains its
        // read side so it still receives whatever the host sends afterwards.
        let (received_tx, received_rx) = tokio::sync::oneshot::channel();
        let peer = tokio::spawn(async move {
            let (mut socket, _) = listener.accept().await.unwrap();
            socket.shutdown().await.unwrap();
            let mut received = Vec::new();
            socket.read_to_end(&mut received).await.unwrap();
            let _ = received_tx.send(received);
        });

        let session = open_local(11, port, &tx);
        assert_eq!(next_frame(&mut rx).await.t, MessageType::TcpConnected);

        // The destination's FIN arrives as a non-terminal TcpEof, and the
        // session keeps running.
        let eof = next_frame(&mut rx).await;
        assert_eq!(eof.t, MessageType::TcpEof);
        assert_ne!(eof.flags, FLAG_TERMINAL);
        assert!(!session.is_finished());

        // Host -> destination traffic still flows after that EOF.
        session.write_data(b"after-eof".to_vec()).await.unwrap();
        session.close_write().await.unwrap();
        let received = tokio::time::timeout(Duration::from_secs(1), received_rx)
            .await
            .unwrap()
            .unwrap();
        assert_eq!(received, b"after-eof");

        // Only an explicit close ends the session.
        session.close();
        wait_finished(&session).await;

        peer.await.unwrap();
    }

    /// Binds a listener on an ephemeral loopback port and returns it with
    /// that port.
    async fn bind_loopback() -> (TcpListener, u16) {
        let listener = TcpListener::bind(("127.0.0.1", 0)).await.unwrap();
        let port = listener.local_addr().unwrap().port();
        (listener, port)
    }

    /// Opens a session to `127.0.0.1:port` under correlation id `id`.
    fn open_local(
        id: u32,
        port: u16,
        tx: &mpsc::UnboundedSender<(u32, SessionOutput)>,
    ) -> TcpSession {
        TcpSession::open(
            id,
            TcpConnect {
                host: "127.0.0.1".to_string(),
                port,
            },
            tx,
        )
    }

    /// Polls until the relay task reports finished, panicking after one second.
    async fn wait_finished(session: &TcpSession) {
        let poll = async {
            loop {
                if session.is_finished() {
                    break;
                }
                tokio::time::sleep(Duration::from_millis(10)).await;
            }
        };
        tokio::time::timeout(Duration::from_secs(1), poll)
            .await
            .unwrap();
    }

    /// Receives the next session output and decodes it as one protocol frame.
    async fn next_frame(rx: &mut mpsc::UnboundedReceiver<(u32, SessionOutput)>) -> Message {
        let (_id, output) = rx.recv().await.unwrap();
        let SessionOutput::Raw(mut raw) = output else {
            panic!("expected SessionOutput::Raw frame");
        };
        codec::try_decode_from_buf(&mut raw.frame).unwrap().unwrap()
    }
}