ts_netstack_smoltcp_socket/tcp/
stream.rs1use core::{
2 fmt::{Debug, Formatter},
3 net::SocketAddr,
4};
5
6use bytes::Bytes;
7use netcore::{DisplayExt, HasChannel, Response, smoltcp::iface::SocketHandle, tcp};
8
/// Pinned, boxed, type-erased future.
///
/// `TcpStream` stores one of these per direction, so that an in-flight netstack request
/// survives between successive `poll_read` / `poll_write` calls.
#[cfg(any(feature = "tokio", feature = "futures-io"))]
type PinBoxFut<T> = core::pin::Pin<alloc::boxed::Box<dyn Send + Sync + Future<Output = T>>>;
11
12pub struct TcpStream {
14 sender: netcore::Channel,
15 handle: SocketHandle,
16
17 local: SocketAddr,
18 remote: SocketAddr,
19
20 #[cfg(any(feature = "tokio", feature = "futures-io"))]
21 read_fut: Option<PinBoxFut<Result<Bytes, netcore::Error>>>,
22 #[cfg(any(feature = "tokio", feature = "futures-io"))]
29 read_remainder: Option<Bytes>,
30 #[cfg(any(feature = "tokio", feature = "futures-io"))]
31 write_fut: Option<PinBoxFut<Result<usize, netcore::Error>>>,
32}
33
34impl TcpStream {
35 pub(crate) const fn new(
36 sender: netcore::Channel,
37 handle: SocketHandle,
38 remote: SocketAddr,
39 local: SocketAddr,
40 ) -> Self {
41 Self {
42 sender,
43 handle,
44 remote,
45 local,
46
47 #[cfg(any(feature = "tokio", feature = "futures-io"))]
48 read_fut: None,
49
50 #[cfg(any(feature = "tokio", feature = "futures-io"))]
51 read_remainder: None,
52
53 #[cfg(any(feature = "tokio", feature = "futures-io"))]
54 write_fut: None,
55 }
56 }
57}
58
59impl Debug for TcpStream {
60 fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
61 f.debug_struct("TcpStream")
62 .field("handle", &self.handle.as_display_debug())
63 .field("local_endpoint", &self.local)
64 .field("remote_endpoint", &self.remote)
65 .finish()
66 }
67}
68
69impl TcpStream {
70 pub const fn local_addr(&self) -> SocketAddr {
72 self.local
73 }
74
75 pub const fn remote_addr(&self) -> SocketAddr {
77 self.remote
78 }
79
80 pub fn shutdown_write(&self) {
96 if let Err(e) = self
97 .sender
98 .request_nonblocking(Some(self.handle), tcp::stream::Command::ShutdownWrite)
99 {
100 tracing::debug!(err = %e, "shutdown_write: netstack channel closed");
101 }
102 }
103
104 pub fn send_blocking(&self, b: &[u8]) -> Result<usize, netcore::Error> {
109 let resp = self.request_blocking(tcp::stream::Command::Send {
110 buf: Bytes::copy_from_slice(b),
111 })?;
112
113 self._send(resp)
114 }
115
116 pub async fn send(&self, b: &[u8]) -> Result<usize, netcore::Error> {
121 let resp = self
122 .request(tcp::stream::Command::Send {
123 buf: Bytes::copy_from_slice(b),
124 })
125 .await?;
126
127 self._send(resp)
128 }
129
130 fn _send(&self, resp: Response) -> Result<usize, netcore::Error> {
131 netcore::try_response_as!(resp, tcp::stream::Response::Sent { n });
132 Ok(n)
133 }
134
135 pub fn recv_blocking(&self, b: &mut [u8]) -> Result<usize, netcore::Error> {
139 let resp = self.request_blocking(tcp::stream::Command::Recv {
140 max_len: Some(b.len()),
141 })?;
142
143 self._recv(resp, b)
144 }
145
146 pub async fn recv(&self, b: &mut [u8]) -> Result<usize, netcore::Error> {
150 let resp = self
151 .request(tcp::stream::Command::Recv {
152 max_len: Some(b.len()),
153 })
154 .await?;
155
156 self._recv(resp, b)
157 }
158
159 pub fn recv_bytes_blocking(&self) -> Result<Bytes, netcore::Error> {
163 let resp = self.request_blocking(tcp::stream::Command::Recv { max_len: None })?;
164
165 self._recv_bytes(resp)
166 }
167
168 pub async fn recv_bytes(&self) -> Result<Bytes, netcore::Error> {
170 let resp = self
171 .request(tcp::stream::Command::Recv { max_len: None })
172 .await?;
173
174 self._recv_bytes(resp)
175 }
176
177 fn _recv(&self, resp: Response, b: &mut [u8]) -> Result<usize, netcore::Error> {
178 let buf = self._recv_bytes(resp)?;
179
180 let n = buf.len().min(b.len());
181 b[..n].copy_from_slice(&buf[..n]);
182
183 Ok(n)
184 }
185
186 fn _recv_bytes(&self, resp: Response) -> Result<Bytes, netcore::Error> {
187 if matches!(resp, Response::TcpStream(tcp::stream::Response::Finished)) {
188 return Ok(Bytes::new());
189 }
190
191 netcore::try_response_as!(resp, tcp::stream::Response::Recv { buf });
192 Ok(buf)
193 }
194
195 #[cfg(any(feature = "tokio", feature = "futures-io"))]
196 fn poll_read(
197 mut self: core::pin::Pin<&mut Self>,
198 cx: &mut core::task::Context,
199 buf: &mut [u8],
200 ) -> core::task::Poll<std::io::Result<usize>> {
201 use netcore::HasChannel;
202
203 debug_assert!(
209 !buf.is_empty() || self.read_remainder.is_none(),
210 "poll_read called with an empty buffer while bytes are buffered — Ok(0) would look like EOF"
211 );
212
213 fn copy_into_buf(mut data: Bytes, buf: &mut [u8]) -> (usize, Bytes) {
218 let n = data.len().min(buf.len());
219 buf[..n].copy_from_slice(&data.split_to(n));
220 (n, data)
221 }
222
223 if let Some(rem) = self.read_remainder.take() {
225 let (n, tail) = copy_into_buf(rem, buf);
226 if !tail.is_empty() {
227 self.read_remainder = Some(tail);
228 }
229 return core::task::Poll::Ready(Ok(n));
230 }
231
232 let handle = self.handle;
233 let cap = buf.len();
234
235 loop {
236 match self.read_fut.as_mut() {
237 None => {
238 let sender = self.sender.clone();
239
240 let _ret = self.read_fut.insert(alloc::boxed::Box::pin(async move {
241 let resp = sender
242 .request(
243 Some(handle),
244 tcp::stream::Command::Recv { max_len: Some(cap) },
245 )
246 .await?;
247
248 if matches!(
254 resp,
255 netcore::Response::Error(netcore::Error::Internal(
256 netcore::InternalErrorKind::BadSocketHandle
257 ))
258 ) {
259 return Ok(Bytes::new());
260 }
261
262 match resp.try_into()? {
263 tcp::stream::Response::Recv { buf } => Ok(buf),
264 tcp::stream::Response::Finished => Ok(Bytes::new()),
265 _ => Err(netcore::Error::wrong_type()),
266 }
267 }));
268 }
269
270 Some(x) => {
271 let poll_result = x.as_mut().poll(cx);
272 let ret = core::task::ready!(poll_result)?;
273
274 self.read_fut.take();
275
276 let (n, tail) = copy_into_buf(ret, buf);
280 if !tail.is_empty() {
281 self.read_remainder = Some(tail);
282 }
283
284 break core::task::Poll::Ready(Ok(n));
285 }
286 }
287 }
288 }
289
290 #[cfg(any(feature = "tokio", feature = "futures-io"))]
291 fn poll_write(
292 mut self: core::pin::Pin<&mut Self>,
293 cx: &mut core::task::Context<'_>,
294 buf: &[u8],
295 ) -> core::task::Poll<std::io::Result<usize>> {
296 use netcore::HasChannel;
297
298 let handle = self.handle;
299
300 loop {
301 match &mut self.write_fut {
302 None => {
303 let b = Bytes::copy_from_slice(buf);
304 let sender = self.sender.clone();
305
306 let _ret = self.write_fut.insert(alloc::boxed::Box::pin(async move {
307 let resp = sender
308 .request(Some(handle), tcp::stream::Command::Send { buf: b })
309 .await?;
310
311 if matches!(
317 resp,
318 netcore::Response::Error(netcore::Error::Internal(
319 netcore::InternalErrorKind::BadSocketHandle
320 ))
321 ) {
322 return Err(netcore::Error::ConnectionReset);
323 }
324
325 netcore::try_response_as!(resp, tcp::stream::Response::Sent { n });
326 Ok(n)
327 }));
328 }
329
330 Some(x) => {
331 let poll_result = x.as_mut().poll(cx);
332 let ret = core::task::ready!(poll_result)?;
333
334 self.write_fut.take();
335
336 break core::task::Poll::Ready(Ok(ret));
337 }
338 }
339 }
340 }
341
342 socket_requestor_impl!();
343}
344
345impl Drop for TcpStream {
346 fn drop(&mut self) {
347 if let Err(e) = self
348 .sender
349 .request_nonblocking(Some(self.handle), tcp::stream::Command::Close)
350 {
351 tracing::warn!(err = %e, "possible socket leak");
352 }
353 }
354}
355
#[cfg(feature = "std")]
impl std::io::Read for TcpStream {
    /// Blocking read into `buf`; delegates to `recv_blocking` and converts the error type.
    ///
    /// Returns `Ok(0)` at end of stream.
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        let outcome = self.recv_blocking(buf);
        outcome.map_err(Into::into)
    }
}
362
#[cfg(feature = "std")]
impl std::io::Write for TcpStream {
    /// Blocking single send; delegates to `send_blocking`.
    ///
    /// May accept fewer than `buf.len()` bytes.
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        self.send_blocking(buf).map_err(netcore::Error::into)
    }

    /// Sends all of `buf`, re-issuing `Send` for whatever the netstack did not accept.
    ///
    /// `buf` is copied once. Each retry passes a clone of the remaining [`Bytes`], which
    /// shares the underlying storage instead of copying it again.
    ///
    /// # Errors
    ///
    /// Returns [`std::io::ErrorKind::WriteZero`] if the netstack reports zero bytes sent,
    /// matching std's default `Write::write_all`. Otherwise it propagates request errors.
    fn write_all(&mut self, buf: &[u8]) -> std::io::Result<()> {
        let mut pending = Bytes::copy_from_slice(buf);

        while !pending.is_empty() {
            let resp = self.request_blocking(tcp::stream::Command::Send {
                buf: pending.clone(),
            })?;
            netcore::try_response_as!(resp, tcp::stream::Response::Sent { n });

            // Without progress, the loop would spin forever re-sending the same bytes.
            if n == 0 {
                return Err(std::io::Error::new(
                    std::io::ErrorKind::WriteZero,
                    "failed to write whole buffer",
                ));
            }

            let _sent = pending.split_to(n);
        }

        Ok(())
    }

    /// No-op: the stream keeps no local write buffer; each write completes its `Send` request.
    fn flush(&mut self) -> std::io::Result<()> {
        Ok(())
    }
}
386
#[cfg(feature = "tokio")]
impl tokio::io::AsyncRead for TcpStream {
    /// Adapts tokio's `ReadBuf` to the inherent slice-based `poll_read`.
    ///
    /// The unfilled region is initialised (`initialize_unfilled`) so it can be lent out as
    /// `&mut [u8]`. On success the `ReadBuf` is advanced by the number of bytes read.
    fn poll_read(
        self: core::pin::Pin<&mut Self>,
        cx: &mut core::task::Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> core::task::Poll<tokio::io::Result<()>> {
        let unfilled = buf.initialize_unfilled();
        TcpStream::poll_read(self, cx, unfilled).map_ok(|n| buf.advance(n))
    }
}
400
#[cfg(feature = "tokio")]
impl tokio::io::AsyncWrite for TcpStream {
    /// Forwards to the inherent `poll_write`, which is shared with the `futures-io` impl.
    fn poll_write(
        self: core::pin::Pin<&mut Self>,
        cx: &mut core::task::Context<'_>,
        buf: &[u8],
    ) -> core::task::Poll<std::io::Result<usize>> {
        TcpStream::poll_write(self, cx, buf)
    }

    /// Always ready: `poll_write` only reports success once its `Send` request has
    /// completed, so nothing is buffered locally.
    fn poll_flush(
        self: core::pin::Pin<&mut Self>,
        _cx: &mut core::task::Context<'_>,
    ) -> core::task::Poll<std::io::Result<()>> {
        core::task::Poll::Ready(Ok(()))
    }

    /// Queues a write-half shutdown (fire-and-forget) and reports completion immediately.
    fn poll_shutdown(
        self: core::pin::Pin<&mut Self>,
        _cx: &mut core::task::Context<'_>,
    ) -> core::task::Poll<std::io::Result<()>> {
        let stream: &TcpStream = &self;
        stream.shutdown_write();
        core::task::Poll::Ready(Ok(()))
    }
}
426
#[cfg(feature = "tokio")]
#[cfg(test)]
mod reaped_socket_mapping_tests {
    //! Checks how the async I/O impls treat a socket handle the netstack has already
    //! reaped: reads must look like EOF, writes must fail with `ConnectionReset`.

    use core::net::SocketAddr;

    use netcore::{HasChannel, Netstack, smoltcp::iface::SocketHandle, udp};
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    use super::TcpStream;

    /// Builds a `TcpStream` whose handle refers to a socket that has already been closed.
    ///
    /// A netstack is driven on a background thread. A UDP socket is bound on it and then
    /// closed, and the stale handle is wrapped in a `TcpStream`.
    fn stream_over_reaped_handle() -> TcpStream {
        let remote = SocketAddr::from(([127, 0, 0, 1], 9200));
        let local = SocketAddr::from(([127, 0, 0, 1], 50100));

        let mut stack = Netstack::new(
            netcore::Config::default(),
            netcore::smoltcp::time::Instant::ZERO,
        );
        let chan = stack.command_channel();

        // Serve commands until the channel closes, so the blocking requests below complete.
        std::thread::spawn(move || {
            while let Ok(cmd) = stack.wait_for_cmd_blocking(None) {
                stack.process_one_cmd(cmd);
            }
        });

        let bind_reply = chan
            .request_blocking(None, udp::Command::Bind { endpoint: remote })
            .expect("channel open");
        let handle: SocketHandle = match bind_reply {
            netcore::Response::Udp(udp::Response::Bound { handle, .. }) => handle,
            other => panic!("expected Bound, got {other:?}"),
        };

        // After this, `handle` no longer names a live socket in the netstack.
        let close_reply = chan
            .request_blocking(Some(handle), udp::Command::Close)
            .expect("channel open");
        assert!(matches!(close_reply, netcore::Response::Ok));

        TcpStream::new(chan, handle, remote, local)
    }

    #[tokio::test]
    async fn poll_read_on_reaped_socket_is_eof() {
        let mut stream = stream_over_reaped_handle();
        let mut buf = [0u8; 64];

        let outcome = stream.read(&mut buf).await;
        let n = outcome.expect("read on a reaped socket must be Ok(0), not an error");

        assert_eq!(n, 0, "a reaped socket must read as EOF (Ok(0))");
    }

    #[tokio::test]
    async fn poll_write_on_reaped_socket_is_connection_reset() {
        let mut stream = stream_over_reaped_handle();

        let outcome = stream.write(b"payload").await;
        let err = outcome.expect_err("write to a reaped socket must error");

        assert_eq!(
            err.kind(),
            std::io::ErrorKind::ConnectionReset,
            "writing to a reaped socket must surface as ConnectionReset"
        );
    }
}
514
#[cfg(feature = "futures-io")]
impl futures_io::AsyncRead for TcpStream {
    /// Forwards straight to the inherent `poll_read`, which already has the slice-based
    /// shape `futures-io` expects.
    fn poll_read(
        self: core::pin::Pin<&mut Self>,
        cx: &mut core::task::Context<'_>,
        buf: &mut [u8],
    ) -> core::task::Poll<std::io::Result<usize>> {
        TcpStream::poll_read(self, cx, buf)
    }
}
525
#[cfg(feature = "futures-io")]
impl futures_io::AsyncWrite for TcpStream {
    /// Forwards to the inherent `poll_write`, which is shared with the `tokio` impl.
    fn poll_write(
        self: core::pin::Pin<&mut Self>,
        cx: &mut core::task::Context<'_>,
        buf: &[u8],
    ) -> core::task::Poll<std::io::Result<usize>> {
        TcpStream::poll_write(self, cx, buf)
    }

    /// Always ready: `poll_write` only reports success once its `Send` request has
    /// completed, so nothing is buffered locally.
    fn poll_flush(
        self: core::pin::Pin<&mut Self>,
        _cx: &mut core::task::Context<'_>,
    ) -> core::task::Poll<std::io::Result<()>> {
        core::task::Poll::Ready(Ok(()))
    }

    /// Queues a write-half shutdown (fire-and-forget) and reports completion immediately.
    fn poll_close(
        self: core::pin::Pin<&mut Self>,
        _cx: &mut core::task::Context<'_>,
    ) -> core::task::Poll<std::io::Result<()>> {
        let stream: &TcpStream = &self;
        stream.shutdown_write();
        core::task::Poll::Ready(Ok(()))
    }
}