1use std::time::Duration;
7
8use tokio::io::{AsyncReadExt, AsyncWriteExt};
9use tokio::net::TcpStream;
10use tokio::sync::mpsc;
11use tokio::task::JoinHandle;
12
13use microsandbox_protocol::codec;
14use microsandbox_protocol::message::{Message, MessageType};
15use microsandbox_protocol::tcp::{TcpClosed, TcpConnect, TcpConnected, TcpData, TcpEof, TcpFailed};
16
17use crate::session::SessionOutput;
18
/// Maximum number of bytes read from the destination socket at once (64 KiB); each
/// `TcpData` frame produced by the relay carries at most this many bytes.
const TCP_CHUNK_SIZE: usize = 65_536;

/// Capacity of the bounded command queue between a [`TcpSession`] handle and its relay
/// task; `write_data` / `close_write` wait for room once this many commands are queued.
const TCP_COMMAND_CAPACITY: usize = 32;

/// Upper bound on how long the outbound `TcpStream::connect` may take before the session
/// reports `TcpFailed` with a "timed out" error.
const TCP_CONNECT_TIMEOUT: Duration = Duration::from_secs(30);
36
37pub struct TcpSession {
43 owner_id: u32,
44 commands: mpsc::Sender<TcpCommand>,
45 task: JoinHandle<()>,
46}
47
/// Work item passed from a [`TcpSession`] handle to its relay task.
enum TcpCommand {
    /// Shut down the write direction of the destination socket (half-close).
    Eof,
    /// Bytes to write to the destination socket.
    Data(Vec<u8>),
}
52
53impl TcpSession {
58 pub fn owner_id(&self) -> u32 {
60 self.owner_id
61 }
62
63 pub async fn write_data(&self, data: Vec<u8>) -> Result<(), String> {
68 self.commands
69 .send(TcpCommand::Data(data))
70 .await
71 .map_err(|_| "TCP session is closed".to_string())
72 }
73
74 pub async fn close_write(&self) -> Result<(), String> {
79 self.commands
80 .send(TcpCommand::Eof)
81 .await
82 .map_err(|_| "TCP session is closed".to_string())
83 }
84
85 pub fn close(&self) {
92 self.task.abort();
93 }
94
95 pub fn is_finished(&self) -> bool {
97 self.task.is_finished()
98 }
99
100 pub fn open(
109 id: u32,
110 req: TcpConnect,
111 session_tx: &mpsc::UnboundedSender<(u32, SessionOutput)>,
112 ) -> Self {
113 let (commands_tx, commands_rx) = mpsc::channel(TCP_COMMAND_CAPACITY);
114 let output_tx = session_tx.clone();
115 let task = tokio::spawn(async move {
116 connect_and_relay(id, req, commands_rx, output_tx).await;
117 });
118
119 Self {
120 owner_id: id,
121 commands: commands_tx,
122 task,
123 }
124 }
125}
126
127async fn connect_and_relay(
138 id: u32,
139 req: TcpConnect,
140 commands: mpsc::Receiver<TcpCommand>,
141 tx: mpsc::UnboundedSender<(u32, SessionOutput)>,
142) {
143 let connect = TcpStream::connect((req.host.as_str(), req.port));
144 let stream = match tokio::time::timeout(TCP_CONNECT_TIMEOUT, connect).await {
145 Ok(Ok(stream)) => stream,
146 Ok(Err(e)) => {
147 send_raw_tcp_message(
148 id,
149 MessageType::TcpFailed,
150 &TcpFailed {
151 error: format!("connect {}:{}: {e}", req.host, req.port),
152 },
153 &tx,
154 );
155 return;
156 }
157 Err(_elapsed) => {
158 send_raw_tcp_message(
159 id,
160 MessageType::TcpFailed,
161 &TcpFailed {
162 error: format!("connect {}:{} timed out", req.host, req.port),
163 },
164 &tx,
165 );
166 return;
167 }
168 };
169
170 if !send_raw_tcp_message(id, MessageType::TcpConnected, &TcpConnected {}, &tx) {
171 return;
172 }
173
174 relay_tcp_session(id, stream, commands, tx).await;
175}
176
177async fn relay_tcp_session(
178 id: u32,
179 mut stream: TcpStream,
180 mut commands: mpsc::Receiver<TcpCommand>,
181 tx: mpsc::UnboundedSender<(u32, SessionOutput)>,
182) {
183 let mut read_buf = vec![0u8; TCP_CHUNK_SIZE];
184 let mut terminal_sent = false;
185 let mut read_eof = false;
188
189 loop {
190 tokio::select! {
191 read = stream.read(&mut read_buf), if !read_eof => {
192 match read {
193 Ok(0) => {
194 send_raw_tcp_message(id, MessageType::TcpEof, &TcpEof {}, &tx);
195 read_eof = true;
196 }
197 Ok(n) => {
198 if !send_raw_tcp_message(
199 id,
200 MessageType::TcpData,
201 &TcpData {
202 data: read_buf[..n].to_vec(),
203 },
204 &tx,
205 ) {
206 break;
207 }
208 }
209 Err(e) => {
210 terminal_sent = send_raw_tcp_message(
211 id,
212 MessageType::TcpFailed,
213 &TcpFailed {
214 error: format!("read TCP stream: {e}"),
215 },
216 &tx,
217 );
218 break;
219 }
220 }
221 }
222 command = commands.recv() => {
223 match command {
224 Some(TcpCommand::Data(data)) => {
225 if let Err(e) = stream.write_all(&data).await {
226 terminal_sent = send_raw_tcp_message(
227 id,
228 MessageType::TcpFailed,
229 &TcpFailed {
230 error: format!("write TCP stream: {e}"),
231 },
232 &tx,
233 );
234 break;
235 }
236 }
237 Some(TcpCommand::Eof) => {
238 if let Err(e) = stream.shutdown().await {
239 terminal_sent = send_raw_tcp_message(
240 id,
241 MessageType::TcpFailed,
242 &TcpFailed {
243 error: format!("shutdown TCP stream: {e}"),
244 },
245 &tx,
246 );
247 break;
248 }
249 }
250 None => {
251 break;
252 }
253 }
254 }
255 }
256 }
257
258 if !terminal_sent {
259 send_raw_tcp_message(id, MessageType::TcpClosed, &TcpClosed {}, &tx);
260 }
261}
262
263fn encode_tcp_message<T: serde::Serialize>(
264 id: u32,
265 t: MessageType,
266 payload: &T,
267 out_buf: &mut Vec<u8>,
268) -> Result<(), String> {
269 let msg = Message::with_payload(t, id, payload).map_err(|e| format!("encode tcp: {e}"))?;
270 codec::encode_to_buf(&msg, out_buf).map_err(|e| format!("encode tcp frame: {e}"))?;
271 Ok(())
272}
273
274fn send_raw_tcp_message<T: serde::Serialize>(
275 id: u32,
276 t: MessageType,
277 payload: &T,
278 tx: &mpsc::UnboundedSender<(u32, SessionOutput)>,
279) -> bool {
280 let mut buf = Vec::new();
281 match encode_tcp_message(id, t, payload, &mut buf) {
282 Ok(()) => tx.send((id, SessionOutput::Raw(buf))).is_ok(),
283 Err(e) => {
284 eprintln!("failed to encode tcp message for {id}: {e}");
285 false
286 }
287 }
288}
289
#[cfg(test)]
mod tests {
    //! Integration-style tests that drive `TcpSession` against real loopback sockets and
    //! decode the frames it emits on the session channel.

    use std::time::Duration;

    use microsandbox_protocol::message::FLAG_TERMINAL;
    use tokio::net::TcpListener;

    use super::*;

    /// A connect that fails must yield one terminal `TcpFailed` frame and end the task.
    #[tokio::test]
    async fn connect_failure_sends_terminal_failed() {
        let (session_tx, mut session_rx) = mpsc::unbounded_channel();

        // Port 0 is expected to be refused, so the connect attempt fails.
        let session = TcpSession::open(
            7,
            TcpConnect {
                host: "127.0.0.1".to_string(),
                port: 0,
            },
            &session_tx,
        );

        let msg = recv_message(&mut session_rx).await;
        assert_eq!(msg.t, MessageType::TcpFailed);
        assert_eq!(msg.flags, FLAG_TERMINAL);
        let failed: TcpFailed = msg.payload().unwrap();
        assert!(failed.error.contains("connect 127.0.0.1:0"));

        wait_finished(&session).await;
    }

    /// `close()` must stop the relay task even while the destination is idle.
    #[tokio::test]
    async fn close_request_finishes_session_task() {
        let listener = TcpListener::bind(("127.0.0.1", 0)).await.unwrap();
        let port = listener.local_addr().unwrap().port();
        let (session_tx, mut session_rx) = mpsc::unbounded_channel();
        // Accept and hold the socket open (without reading or writing) longer than the
        // 1-second `wait_finished` budget, so only `close()` can end the session in time.
        let accept_task = tokio::spawn(async move {
            let (_socket, _) = listener.accept().await.unwrap();
            tokio::time::sleep(Duration::from_secs(5)).await;
        });

        let session = TcpSession::open(
            9,
            TcpConnect {
                host: "127.0.0.1".to_string(),
                port,
            },
            &session_tx,
        );

        let connected = recv_message(&mut session_rx).await;
        assert_eq!(connected.t, MessageType::TcpConnected);

        session.close();
        wait_finished(&session).await;

        accept_task.abort();
    }

    /// After the destination half-closes, the session reports a non-terminal `TcpEof`,
    /// stays alive, and still delivers host writes followed by the host's half-close.
    #[tokio::test]
    async fn destination_eof_keeps_session_open_for_host_writes() {
        let listener = TcpListener::bind(("127.0.0.1", 0)).await.unwrap();
        let port = listener.local_addr().unwrap().port();
        let (session_tx, mut session_rx) = mpsc::unbounded_channel();

        // Carries everything the destination read, once the relay half-closes its side.
        let (got_tx, got_rx) = tokio::sync::oneshot::channel();
        // Destination: shut down its write side immediately, then read until the relay's EOF.
        let accept_task = tokio::spawn(async move {
            let (mut socket, _) = listener.accept().await.unwrap();
            socket.shutdown().await.unwrap();
            let mut buf = Vec::new();
            socket.read_to_end(&mut buf).await.unwrap();
            let _ = got_tx.send(buf);
        });

        let session = TcpSession::open(
            11,
            TcpConnect {
                host: "127.0.0.1".to_string(),
                port,
            },
            &session_tx,
        );

        let connected = recv_message(&mut session_rx).await;
        assert_eq!(connected.t, MessageType::TcpConnected);

        // The destination's shutdown surfaces as a non-terminal EOF; the task keeps running.
        let eof = recv_message(&mut session_rx).await;
        assert_eq!(eof.t, MessageType::TcpEof);
        assert_ne!(eof.flags, FLAG_TERMINAL);
        assert!(!session.is_finished());

        // Host writes after the destination EOF, then half-closes so `read_to_end` returns.
        session.write_data(b"after-eof".to_vec()).await.unwrap();
        session.close_write().await.unwrap();
        let received = tokio::time::timeout(Duration::from_secs(1), got_rx)
            .await
            .unwrap()
            .unwrap();
        assert_eq!(received, b"after-eof");

        // Both directions are now closed, but the session only ends when explicitly closed.
        session.close();
        wait_finished(&session).await;

        accept_task.await.unwrap();
    }

    /// Polls every 10 ms until the session task finishes; panics if that takes over 1 second.
    async fn wait_finished(session: &TcpSession) {
        tokio::time::timeout(Duration::from_secs(1), async {
            while !session.is_finished() {
                tokio::time::sleep(Duration::from_millis(10)).await;
            }
        })
        .await
        .unwrap();
    }

    /// Decodes one message from the front of `buf`.
    ///
    /// Both layers of the decoder's result are unwrapped, so a malformed or incomplete
    /// frame panics the test.
    fn decode_one_message(buf: &mut Vec<u8>) -> Message {
        codec::try_decode_from_buf(buf).unwrap().unwrap()
    }

    /// Receives the next session output, requires it to be a raw frame, and decodes it.
    /// Panics if the channel is closed or the output is not `SessionOutput::Raw`.
    async fn recv_message(rx: &mut mpsc::UnboundedReceiver<(u32, SessionOutput)>) -> Message {
        let (_id, output) = rx.recv().await.unwrap();
        let SessionOutput::Raw(mut bytes) = output else {
            panic!("expected SessionOutput::Raw frame");
        };
        decode_one_message(&mut bytes)
    }
}