1use std::time::Duration;
7
8use tokio::io::{AsyncReadExt, AsyncWriteExt};
9use tokio::net::TcpStream;
10use tokio::sync::mpsc;
11use tokio::task::JoinHandle;
12
13use microsandbox_protocol::codec;
14use microsandbox_protocol::message::{Message, MessageType};
15use microsandbox_protocol::tcp::{TcpClosed, TcpConnect, TcpConnected, TcpData, TcpEof, TcpFailed};
16
17use crate::session::{RawActivity, RawSessionCompletion, RawSessionOutput, SessionOutput};
18
/// Size of the destination read buffer; each `TcpData` frame forwarded from the
/// destination carries at most this many bytes.
const TCP_CHUNK_SIZE: usize = 64 * 1024;

/// Capacity of the bounded command channel feeding the relay task. When it is
/// full, `TcpSession::write_data` / `close_write` wait, which applies
/// backpressure to the caller.
const TCP_COMMAND_CAPACITY: usize = 32;

/// Maximum time allowed for the outbound `TcpStream::connect` before the session
/// is failed with a "timed out" `TcpFailed` frame.
const TCP_CONNECT_TIMEOUT: Duration = Duration::from_secs(30);
36
/// Handle to a spawned task that connects to a TCP destination and relays bytes
/// between that connection and the session output channel.
///
/// Bytes and half-close requests are pushed to the task through a bounded
/// command channel; the task itself reports progress (`TcpConnected`, `TcpData`,
/// `TcpEof`, `TcpFailed`, `TcpClosed`) on the output channel passed to `open`.
pub struct TcpSession {
    // Session id this connection was opened for; also used as the message id of
    // every frame the relay task emits.
    owner_id: u32,
    // Sending half of the bounded command channel consumed by the relay task.
    commands: mpsc::Sender<TcpCommand>,
    // Handle to the spawned connect-and-relay task; `close` aborts it.
    task: JoinHandle<()>,
}
47
/// Requests sent from a `TcpSession` handle to its relay task.
enum TcpCommand {
    /// Bytes to write to the destination stream.
    Data(Vec<u8>),
    /// Shut down the write side of the destination stream (half-close); reading
    /// from the destination continues.
    Eof,
}
52
53impl TcpSession {
58 pub fn owner_id(&self) -> u32 {
60 self.owner_id
61 }
62
63 pub async fn write_data(&self, data: Vec<u8>) -> Result<(), String> {
68 self.commands
69 .send(TcpCommand::Data(data))
70 .await
71 .map_err(|_| "TCP session is closed".to_string())
72 }
73
74 pub async fn close_write(&self) -> Result<(), String> {
79 self.commands
80 .send(TcpCommand::Eof)
81 .await
82 .map_err(|_| "TCP session is closed".to_string())
83 }
84
85 pub fn close(&self) {
92 self.task.abort();
93 }
94
95 pub fn is_finished(&self) -> bool {
97 self.task.is_finished()
98 }
99
100 pub fn open(
109 id: u32,
110 req: TcpConnect,
111 session_tx: &mpsc::UnboundedSender<(u32, SessionOutput)>,
112 ) -> Self {
113 let (commands_tx, commands_rx) = mpsc::channel(TCP_COMMAND_CAPACITY);
114 let output_tx = session_tx.clone();
115 let task = tokio::spawn(async move {
116 connect_and_relay(id, req, commands_rx, output_tx).await;
117 });
118
119 Self {
120 owner_id: id,
121 commands: commands_tx,
122 task,
123 }
124 }
125}
126
127async fn connect_and_relay(
138 id: u32,
139 req: TcpConnect,
140 commands: mpsc::Receiver<TcpCommand>,
141 tx: mpsc::UnboundedSender<(u32, SessionOutput)>,
142) {
143 let connect = TcpStream::connect((req.host.as_str(), req.port));
144 let stream = match tokio::time::timeout(TCP_CONNECT_TIMEOUT, connect).await {
145 Ok(Ok(stream)) => stream,
146 Ok(Err(e)) => {
147 send_raw_tcp_message(
148 id,
149 MessageType::TcpFailed,
150 &TcpFailed {
151 error: format!("connect {}:{}: {e}", req.host, req.port),
152 },
153 RawActivity::guest_message(),
154 Some(RawSessionCompletion::Tcp),
155 &tx,
156 );
157 return;
158 }
159 Err(_elapsed) => {
160 send_raw_tcp_message(
161 id,
162 MessageType::TcpFailed,
163 &TcpFailed {
164 error: format!("connect {}:{} timed out", req.host, req.port),
165 },
166 RawActivity::guest_message(),
167 Some(RawSessionCompletion::Tcp),
168 &tx,
169 );
170 return;
171 }
172 };
173
174 if !send_raw_tcp_message(
175 id,
176 MessageType::TcpConnected,
177 &TcpConnected {},
178 RawActivity::guest_message(),
179 None,
180 &tx,
181 ) {
182 return;
183 }
184
185 relay_tcp_session(id, stream, commands, tx).await;
186}
187
188async fn relay_tcp_session(
189 id: u32,
190 mut stream: TcpStream,
191 mut commands: mpsc::Receiver<TcpCommand>,
192 tx: mpsc::UnboundedSender<(u32, SessionOutput)>,
193) {
194 let mut read_buf = vec![0u8; TCP_CHUNK_SIZE];
195 let mut terminal_sent = false;
196 let mut read_eof = false;
199
200 loop {
201 tokio::select! {
202 read = stream.read(&mut read_buf), if !read_eof => {
203 match read {
204 Ok(0) => {
205 send_raw_tcp_message(
206 id,
207 MessageType::TcpEof,
208 &TcpEof {},
209 RawActivity::guest_message(),
210 None,
211 &tx,
212 );
213 read_eof = true;
214 }
215 Ok(n) => {
216 let data = read_buf[..n].to_vec();
217 if !send_raw_tcp_message(
218 id,
219 MessageType::TcpData,
220 &TcpData { data },
221 RawActivity::tcp_bytes(n),
222 None,
223 &tx,
224 ) {
225 break;
226 }
227 }
228 Err(e) => {
229 terminal_sent = send_raw_tcp_message(
230 id,
231 MessageType::TcpFailed,
232 &TcpFailed {
233 error: format!("read TCP stream: {e}"),
234 },
235 RawActivity::guest_message(),
236 Some(RawSessionCompletion::Tcp),
237 &tx,
238 );
239 break;
240 }
241 }
242 }
243 command = commands.recv() => {
244 match command {
245 Some(TcpCommand::Data(data)) => {
246 if let Err(e) = stream.write_all(&data).await {
247 terminal_sent = send_raw_tcp_message(
248 id,
249 MessageType::TcpFailed,
250 &TcpFailed {
251 error: format!("write TCP stream: {e}"),
252 },
253 RawActivity::guest_message(),
254 Some(RawSessionCompletion::Tcp),
255 &tx,
256 );
257 break;
258 }
259 }
260 Some(TcpCommand::Eof) => {
261 if let Err(e) = stream.shutdown().await {
262 terminal_sent = send_raw_tcp_message(
263 id,
264 MessageType::TcpFailed,
265 &TcpFailed {
266 error: format!("shutdown TCP stream: {e}"),
267 },
268 RawActivity::guest_message(),
269 Some(RawSessionCompletion::Tcp),
270 &tx,
271 );
272 break;
273 }
274 }
275 None => {
276 break;
277 }
278 }
279 }
280 }
281 }
282
283 if !terminal_sent {
284 send_raw_tcp_message(
285 id,
286 MessageType::TcpClosed,
287 &TcpClosed {},
288 RawActivity::guest_message(),
289 Some(RawSessionCompletion::Tcp),
290 &tx,
291 );
292 }
293}
294
295fn encode_tcp_message<T: serde::Serialize>(
296 id: u32,
297 t: MessageType,
298 payload: &T,
299 out_buf: &mut Vec<u8>,
300) -> Result<(), String> {
301 let msg = Message::with_payload(t, id, payload).map_err(|e| format!("encode tcp: {e}"))?;
302 codec::encode_to_buf(&msg, out_buf).map_err(|e| format!("encode tcp frame: {e}"))?;
303 Ok(())
304}
305
306fn send_raw_tcp_message<T: serde::Serialize>(
307 id: u32,
308 t: MessageType,
309 payload: &T,
310 activity: RawActivity,
311 completion: Option<RawSessionCompletion>,
312 tx: &mpsc::UnboundedSender<(u32, SessionOutput)>,
313) -> bool {
314 let mut buf = Vec::new();
315 match encode_tcp_message(id, t, payload, &mut buf) {
316 Ok(()) => tx
317 .send((
318 id,
319 SessionOutput::Raw(RawSessionOutput::new(buf, activity, completion)),
320 ))
321 .is_ok(),
322 Err(e) => {
323 eprintln!("failed to encode tcp message for {id}: {e}");
324 false
325 }
326 }
327}
328
#[cfg(test)]
/// Tests for the TCP relay: connect failure, explicit close, and destination
/// half-close handling.
mod tests {
    use std::time::Duration;

    use microsandbox_protocol::message::FLAG_TERMINAL;
    use tokio::net::TcpListener;

    use super::*;

    /// A connect that fails must produce one terminal `TcpFailed` frame whose
    /// error names the target, and the session task must then finish.
    #[tokio::test]
    async fn connect_failure_sends_terminal_failed() {
        let (session_tx, mut session_rx) = mpsc::unbounded_channel();

        // Port 0 is not a connectable destination, so the connect attempt errors out.
        let session = TcpSession::open(
            7,
            TcpConnect {
                host: "127.0.0.1".to_string(),
                port: 0,
            },
            &session_tx,
        );

        let msg = recv_message(&mut session_rx).await;
        // The failure frame carries the terminal flag.
        assert_eq!(msg.t, MessageType::TcpFailed);
        assert_eq!(msg.flags, FLAG_TERMINAL);
        let failed: TcpFailed = msg.payload().unwrap();
        assert!(failed.error.contains("connect 127.0.0.1:0"));

        wait_finished(&session).await;
    }

    /// `close()` must stop the relay task even while the peer keeps the
    /// connection open.
    #[tokio::test]
    async fn close_request_finishes_session_task() {
        let listener = TcpListener::bind(("127.0.0.1", 0)).await.unwrap();
        let port = listener.local_addr().unwrap().port();
        let (session_tx, mut session_rx) = mpsc::unbounded_channel();
        // Accept and hold the socket well past the test's 1s wait, so only
        // `close()` can end the session.
        let accept_task = tokio::spawn(async move {
            let (_socket, _) = listener.accept().await.unwrap();
            tokio::time::sleep(Duration::from_secs(5)).await;
        });

        let session = TcpSession::open(
            9,
            TcpConnect {
                host: "127.0.0.1".to_string(),
                port,
            },
            &session_tx,
        );

        let connected = recv_message(&mut session_rx).await;
        assert_eq!(connected.t, MessageType::TcpConnected);

        session.close();
        wait_finished(&session).await;

        accept_task.abort();
    }

    /// A destination EOF is reported as a non-terminal `TcpEof`. The session stays
    /// alive, and later host writes plus `close_write` still reach the peer.
    #[tokio::test]
    async fn destination_eof_keeps_session_open_for_host_writes() {
        let listener = TcpListener::bind(("127.0.0.1", 0)).await.unwrap();
        let port = listener.local_addr().unwrap().port();
        let (session_tx, mut session_rx) = mpsc::unbounded_channel();

        // Carries everything the peer read until our write side was shut down.
        let (got_tx, got_rx) = tokio::sync::oneshot::channel();
        // Peer: half-close its write side immediately, then read until the relay
        // shuts down its own write side.
        let accept_task = tokio::spawn(async move {
            let (mut socket, _) = listener.accept().await.unwrap();
            socket.shutdown().await.unwrap();
            let mut buf = Vec::new();
            socket.read_to_end(&mut buf).await.unwrap();
            let _ = got_tx.send(buf);
        });

        let session = TcpSession::open(
            11,
            TcpConnect {
                host: "127.0.0.1".to_string(),
                port,
            },
            &session_tx,
        );

        let connected = recv_message(&mut session_rx).await;
        assert_eq!(connected.t, MessageType::TcpConnected);

        // The peer's half-close must arrive as a non-terminal EOF, and the task
        // must still be running.
        let eof = recv_message(&mut session_rx).await;
        assert_eq!(eof.t, MessageType::TcpEof);
        assert_ne!(eof.flags, FLAG_TERMINAL);
        assert!(!session.is_finished());

        // Writes after the destination EOF must still be delivered; `close_write`
        // lets the peer's `read_to_end` complete.
        session.write_data(b"after-eof".to_vec()).await.unwrap();
        session.close_write().await.unwrap();
        let received = tokio::time::timeout(Duration::from_secs(1), got_rx)
            .await
            .unwrap()
            .unwrap();
        assert_eq!(received, b"after-eof");

        // Both directions are done, but the task only stops on an explicit close.
        session.close();
        wait_finished(&session).await;

        accept_task.await.unwrap();
    }

    /// Polls `is_finished` every 10ms; panics if the task has not stopped within 1s.
    async fn wait_finished(session: &TcpSession) {
        tokio::time::timeout(Duration::from_secs(1), async {
            while !session.is_finished() {
                tokio::time::sleep(Duration::from_millis(10)).await;
            }
        })
        .await
        .unwrap();
    }

    /// Decodes the first frame in `buf`; panics if decoding errors or no complete
    /// frame is present.
    fn decode_one_message(buf: &mut Vec<u8>) -> Message {
        codec::try_decode_from_buf(buf).unwrap().unwrap()
    }

    /// Receives the next session output and decodes its raw frame; panics if the
    /// channel is closed or the output is not `SessionOutput::Raw`.
    async fn recv_message(rx: &mut mpsc::UnboundedReceiver<(u32, SessionOutput)>) -> Message {
        let (_id, output) = rx.recv().await.unwrap();
        let SessionOutput::Raw(mut output) = output else {
            panic!("expected SessionOutput::Raw frame");
        };
        decode_one_message(&mut output.frame)
    }
}