Skip to main content

ts_netstack_smoltcp_socket/tcp/
stream.rs

1use core::{
2    fmt::{Debug, Formatter},
3    net::SocketAddr,
4};
5
6use bytes::Bytes;
7use netcore::{DisplayExt, HasChannel, Response, smoltcp::iface::SocketHandle, tcp};
8
/// Heap-allocated, pinned, type-erased future stored between `poll_*` calls.
///
/// Only compiled when one of the async I/O adapters (`tokio` / `futures-io`) is enabled.
#[cfg(any(feature = "tokio", feature = "futures-io"))]
type PinBoxFut<T> =
    core::pin::Pin<alloc::boxed::Box<dyn core::future::Future<Output = T> + Send + Sync>>;
11
/// A TCP stream.
///
/// The stream does no socket I/O itself: every operation is sent as a command over `sender`,
/// addressed to `handle`. Dropping the stream sends a non-blocking `Close` command.
pub struct TcpStream {
    /// Command channel to the netstack; cloned into each in-flight async request.
    sender: netcore::Channel,
    /// Handle identifying this stream's socket in every command sent over `sender`.
    handle: SocketHandle,

    /// Local endpoint, as supplied at construction (reported by `local_addr`).
    local: SocketAddr,
    /// Remote endpoint, as supplied at construction (reported by `remote_addr`).
    remote: SocketAddr,

    /// In-flight `Recv` request created by `poll_read`; `None` when no read is pending.
    #[cfg(any(feature = "tokio", feature = "futures-io"))]
    read_fut: Option<PinBoxFut<Result<Bytes, netcore::Error>>>,
    /// Bytes received from a completed `Recv` that did not fit the caller's buffer on the poll that
    /// produced them, carried to the next `poll_read`. A `Recv` is sized by the buffer length at
    /// future-creation, but the `AsyncRead` contract permits the caller to re-poll with a *smaller*
    /// buffer, so the response can exceed the live buffer — copying it whole would panic
    /// (`copy_from_slice` length mismatch). We copy what fits and stash the tail here (lossless),
    /// draining it before issuing the next `Recv`.
    #[cfg(any(feature = "tokio", feature = "futures-io"))]
    read_remainder: Option<Bytes>,
    /// In-flight `Send` request created by `poll_write`; `None` when no write is pending.
    #[cfg(any(feature = "tokio", feature = "futures-io"))]
    write_fut: Option<PinBoxFut<Result<usize, netcore::Error>>>,
}
33
34impl TcpStream {
35    pub(crate) const fn new(
36        sender: netcore::Channel,
37        handle: SocketHandle,
38        remote: SocketAddr,
39        local: SocketAddr,
40    ) -> Self {
41        Self {
42            sender,
43            handle,
44            remote,
45            local,
46
47            #[cfg(any(feature = "tokio", feature = "futures-io"))]
48            read_fut: None,
49
50            #[cfg(any(feature = "tokio", feature = "futures-io"))]
51            read_remainder: None,
52
53            #[cfg(any(feature = "tokio", feature = "futures-io"))]
54            write_fut: None,
55        }
56    }
57}
58
59impl Debug for TcpStream {
60    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
61        f.debug_struct("TcpStream")
62            .field("handle", &self.handle.as_display_debug())
63            .field("local_endpoint", &self.local)
64            .field("remote_endpoint", &self.remote)
65            .finish()
66    }
67}
68
69impl TcpStream {
70    /// Report the local endpoint to which this stream is connected.
71    pub const fn local_addr(&self) -> SocketAddr {
72        self.local
73    }
74
75    /// Report the remote endpoint to which this stream is connected.
76    pub const fn remote_addr(&self) -> SocketAddr {
77        self.remote
78    }
79
80    /// Half-close the write side: send a FIN to the peer (`shutdown(SHUT_WR)` / `CloseWrite`) while
81    /// keeping the read side open so the peer's remaining data can still be received. Fire-and-forget
82    /// (non-blocking, like the `Drop`-time `Close`): the FIN is emitted by the netstack in the
83    /// background, so this returns immediately and a caller using shutdown for signaling (e.g. a
84    /// bidirectional splice half-closing one direction) no longer hangs waiting for a FIN that was
85    /// never sent.
86    ///
87    /// After this, **writes fail** (`InvalidState`): the socket has left the sendable state — this is
88    /// the intended `shutdown(SHUT_WR)` POSIX behavior (previously, when this was a no-op, a write
89    /// after shutdown still succeeded). Reads continue until the peer's FIN.
90    ///
91    /// Best-effort delivery: `request_nonblocking` treats a *full* command channel as success and
92    /// drops the command, so under channel saturation the FIN may not be sent — the socket then
93    /// teardown-degrades to the idle/keep-alive timeout reaper instead of a prompt FIN (never a hard
94    /// leak). A channel-*closed* error means the netstack is gone; the socket is already moot.
95    pub fn shutdown_write(&self) {
96        if let Err(e) = self
97            .sender
98            .request_nonblocking(Some(self.handle), tcp::stream::Command::ShutdownWrite)
99        {
100            tracing::debug!(err = %e, "shutdown_write: netstack channel closed");
101        }
102    }
103
104    /// Send bytes to the remote.
105    ///
106    /// Blocks until at least one byte can be queued. The return value is the number of
107    /// bytes actually sent.
108    pub fn send_blocking(&self, b: &[u8]) -> Result<usize, netcore::Error> {
109        let resp = self.request_blocking(tcp::stream::Command::Send {
110            buf: Bytes::copy_from_slice(b),
111        })?;
112
113        self._send(resp)
114    }
115
116    /// Send bytes to the remote.
117    ///
118    /// Blocks until at least one byte can be queued. The return value is the number of
119    /// bytes actually sent.
120    pub async fn send(&self, b: &[u8]) -> Result<usize, netcore::Error> {
121        let resp = self
122            .request(tcp::stream::Command::Send {
123                buf: Bytes::copy_from_slice(b),
124            })
125            .await?;
126
127        self._send(resp)
128    }
129
    /// Shared tail of `send` / `send_blocking`: interpret the netstack's reply to a `Send`.
    ///
    /// Returns the `n` carried by a `Sent` response. Any other reply is handled by
    /// `try_response_as!` — presumably an early `Err` return; see the macro's definition in
    /// `netcore` to confirm.
    fn _send(&self, resp: Response) -> Result<usize, netcore::Error> {
        netcore::try_response_as!(resp, tcp::stream::Response::Sent { n });
        Ok(n)
    }
134
135    /// Receive bytes from the remote.
136    ///
137    /// Returns the number of bytes actually received (blocks until there is at least one).
138    pub fn recv_blocking(&self, b: &mut [u8]) -> Result<usize, netcore::Error> {
139        let resp = self.request_blocking(tcp::stream::Command::Recv {
140            max_len: Some(b.len()),
141        })?;
142
143        self._recv(resp, b)
144    }
145
146    /// Receive bytes from the remote into the supplied buffer.
147    ///
148    /// Returns the number of bytes actually received (blocks until there is at least one).
149    pub async fn recv(&self, b: &mut [u8]) -> Result<usize, netcore::Error> {
150        let resp = self
151            .request(tcp::stream::Command::Recv {
152                max_len: Some(b.len()),
153            })
154            .await?;
155
156        self._recv(resp, b)
157    }
158
159    /// Receive bytes from the remote.
160    ///
161    /// Returns the number of bytes actually received (blocks until there is at least one).
162    pub fn recv_bytes_blocking(&self) -> Result<Bytes, netcore::Error> {
163        let resp = self.request_blocking(tcp::stream::Command::Recv { max_len: None })?;
164
165        self._recv_bytes(resp)
166    }
167
168    /// Receive bytes from the remote.
169    pub async fn recv_bytes(&self) -> Result<Bytes, netcore::Error> {
170        let resp = self
171            .request(tcp::stream::Command::Recv { max_len: None })
172            .await?;
173
174        self._recv_bytes(resp)
175    }
176
177    fn _recv(&self, resp: Response, b: &mut [u8]) -> Result<usize, netcore::Error> {
178        let buf = self._recv_bytes(resp)?;
179
180        let n = buf.len().min(b.len());
181        b[..n].copy_from_slice(&buf[..n]);
182
183        Ok(n)
184    }
185
186    fn _recv_bytes(&self, resp: Response) -> Result<Bytes, netcore::Error> {
187        if matches!(resp, Response::TcpStream(tcp::stream::Response::Finished)) {
188            return Ok(Bytes::new());
189        }
190
191        netcore::try_response_as!(resp, tcp::stream::Response::Recv { buf });
192        Ok(buf)
193    }
194
195    #[cfg(any(feature = "tokio", feature = "futures-io"))]
196    fn poll_read(
197        mut self: core::pin::Pin<&mut Self>,
198        cx: &mut core::task::Context,
199        buf: &mut [u8],
200    ) -> core::task::Poll<std::io::Result<usize>> {
201        use netcore::HasChannel;
202
203        // Callers must pass a non-empty buffer: an `Ok(0)` return is `AsyncRead`'s EOF signal, so
204        // returning it while `read_remainder` still holds bytes (which a zero-length `buf` would
205        // force) would silently truncate the stream. Every in-tree caller passes a non-empty buffer;
206        // this guards the invariant for the public type so a zero-length read can't be mistaken for
207        // EOF-with-data-pending. `tokio`/`futures-io` themselves never poll a read with an empty buf.
208        debug_assert!(
209            !buf.is_empty() || self.read_remainder.is_none(),
210            "poll_read called with an empty buffer while bytes are buffered — Ok(0) would look like EOF"
211        );
212
213        // Copy up to `buf.len()` bytes out of `data` into `buf`, returning `(written, remainder)`
214        // where `remainder` is the unwritten tail (empty if it all fit). `Bytes::split_to` is a
215        // cheap refcount split, so carrying a remainder is allocation-free. Free fn (not a
216        // self-capturing closure) so it doesn't conflict with the `&mut self.read_fut` borrow below.
217        fn copy_into_buf(mut data: Bytes, buf: &mut [u8]) -> (usize, Bytes) {
218            let n = data.len().min(buf.len());
219            buf[..n].copy_from_slice(&data.split_to(n));
220            (n, data)
221        }
222
223        // Drain any stashed remainder first — never issue a fresh `Recv` while bytes are buffered.
224        if let Some(rem) = self.read_remainder.take() {
225            let (n, tail) = copy_into_buf(rem, buf);
226            if !tail.is_empty() {
227                self.read_remainder = Some(tail);
228            }
229            return core::task::Poll::Ready(Ok(n));
230        }
231
232        let handle = self.handle;
233        let cap = buf.len();
234
235        loop {
236            match self.read_fut.as_mut() {
237                None => {
238                    let sender = self.sender.clone();
239
240                    let _ret = self.read_fut.insert(alloc::boxed::Box::pin(async move {
241                        let resp = sender
242                            .request(
243                                Some(handle),
244                                tcp::stream::Command::Recv { max_len: Some(cap) },
245                            )
246                            .await?;
247
248                        // A reaped socket (tsr-9ue: the netstack autonomously closed + freed an
249                        // idle/dead accepted stream) answers a first-touch `Recv` with
250                        // `missing_socket`. Surface it as a clean end-of-stream — an empty `Bytes`,
251                        // exactly like `Finished` — so it reads as a normal `Ok(0)` EOF rather than
252                        // a confusing generic internal `io::Error`.
253                        if matches!(
254                            resp,
255                            netcore::Response::Error(netcore::Error::Internal(
256                                netcore::InternalErrorKind::BadSocketHandle
257                            ))
258                        ) {
259                            return Ok(Bytes::new());
260                        }
261
262                        match resp.try_into()? {
263                            tcp::stream::Response::Recv { buf } => Ok(buf),
264                            tcp::stream::Response::Finished => Ok(Bytes::new()),
265                            _ => Err(netcore::Error::wrong_type()),
266                        }
267                    }));
268                }
269
270                Some(x) => {
271                    let poll_result = x.as_mut().poll(cx);
272                    let ret = core::task::ready!(poll_result)?;
273
274                    self.read_fut.take();
275
276                    // Copy what fits into the CURRENT buffer (which the caller may have shrunk since
277                    // the `Recv` was issued at `cap`); stash any tail. A whole-`ret` copy would panic
278                    // when `ret.len() > buf.len()`.
279                    let (n, tail) = copy_into_buf(ret, buf);
280                    if !tail.is_empty() {
281                        self.read_remainder = Some(tail);
282                    }
283
284                    break core::task::Poll::Ready(Ok(n));
285                }
286            }
287        }
288    }
289
290    #[cfg(any(feature = "tokio", feature = "futures-io"))]
291    fn poll_write(
292        mut self: core::pin::Pin<&mut Self>,
293        cx: &mut core::task::Context<'_>,
294        buf: &[u8],
295    ) -> core::task::Poll<std::io::Result<usize>> {
296        use netcore::HasChannel;
297
298        let handle = self.handle;
299
300        loop {
301            match &mut self.write_fut {
302                None => {
303                    let b = Bytes::copy_from_slice(buf);
304                    let sender = self.sender.clone();
305
306                    let _ret = self.write_fut.insert(alloc::boxed::Box::pin(async move {
307                        let resp = sender
308                            .request(Some(handle), tcp::stream::Command::Send { buf: b })
309                            .await?;
310
311                        // A reaped socket (tsr-9ue) answers a first-touch `Send` with
312                        // `missing_socket`. Writing to a torn-down connection is POSIX
313                        // `ECONNRESET`, so remap to `ConnectionReset` — `From<Error> for
314                        // io::Error` then yields `ErrorKind::ConnectionReset` — instead of letting
315                        // it fall through `try_response_as!` as a generic internal error.
316                        if matches!(
317                            resp,
318                            netcore::Response::Error(netcore::Error::Internal(
319                                netcore::InternalErrorKind::BadSocketHandle
320                            ))
321                        ) {
322                            return Err(netcore::Error::ConnectionReset);
323                        }
324
325                        netcore::try_response_as!(resp, tcp::stream::Response::Sent { n });
326                        Ok(n)
327                    }));
328                }
329
330                Some(x) => {
331                    let poll_result = x.as_mut().poll(cx);
332                    let ret = core::task::ready!(poll_result)?;
333
334                    self.write_fut.take();
335
336                    break core::task::Poll::Ready(Ok(ret));
337                }
338            }
339        }
340    }
341
342    socket_requestor_impl!();
343}
344
345impl Drop for TcpStream {
346    fn drop(&mut self) {
347        if let Err(e) = self
348            .sender
349            .request_nonblocking(Some(self.handle), tcp::stream::Command::Close)
350        {
351            tracing::warn!(err = %e, "possible socket leak");
352        }
353    }
354}
355
#[cfg(feature = "std")]
impl std::io::Read for TcpStream {
    /// Blocking read: delegates to `recv_blocking`, converting the error to `std::io::Error`.
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        match self.recv_blocking(buf) {
            Ok(n) => Ok(n),
            Err(e) => Err(e.into()),
        }
    }
}
362
#[cfg(feature = "std")]
impl std::io::Write for TcpStream {
    /// Blocking write: delegates to `send_blocking`, converting the error to `std::io::Error`.
    /// May write fewer bytes than `buf.len()`.
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        self.send_blocking(buf).map_err(netcore::Error::into)
    }

    /// Write the whole of `buf`, issuing as many blocking `Send` requests as needed.
    ///
    /// `buf` is copied once into a `Bytes`; each iteration sends a cheap refcounted clone of the
    /// unsent remainder and then drops the prefix the netstack reports as sent.
    ///
    /// # Errors
    ///
    /// Propagates request/response failures. If the netstack reports that zero bytes were sent
    /// while data remains, returns `ErrorKind::WriteZero` (matching the `std::io::Write::write_all`
    /// contract) instead of retrying forever.
    fn write_all(&mut self, buf: &[u8]) -> std::io::Result<()> {
        let mut buf = Bytes::copy_from_slice(buf);

        while !buf.is_empty() {
            let resp = self.request_blocking(tcp::stream::Command::Send { buf: buf.clone() })?;
            netcore::try_response_as!(resp, tcp::stream::Response::Sent { n });

            // No progress: bail out rather than spinning on the same unsent bytes.
            if n == 0 {
                return Err(std::io::Error::new(
                    std::io::ErrorKind::WriteZero,
                    "failed to write whole buffer",
                ));
            }

            let _consumed = buf.split_to(n);
        }

        Ok(())
    }

    /// No-op: this type keeps no write buffer of its own, so there is nothing to flush here.
    fn flush(&mut self) -> std::io::Result<()> {
        Ok(())
    }
}
386
#[cfg(feature = "tokio")]
impl tokio::io::AsyncRead for TcpStream {
    /// Adapt the inherent slice-based `poll_read` to tokio's `ReadBuf` interface.
    ///
    /// The unfilled part of `buf` is initialized and handed to the inherent method; on success
    /// the filled cursor is advanced by the number of bytes read (zero at end-of-stream).
    fn poll_read(
        self: core::pin::Pin<&mut Self>,
        cx: &mut core::task::Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> core::task::Poll<tokio::io::Result<()>> {
        use core::task::Poll;

        let unfilled = buf.initialize_unfilled();

        match self.poll_read(cx, unfilled) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
            Poll::Ready(Ok(n)) => {
                buf.advance(n);
                Poll::Ready(Ok(()))
            }
        }
    }
}
400
#[cfg(feature = "tokio")]
impl tokio::io::AsyncWrite for TcpStream {
    /// Forward to the inherent `TcpStream::poll_write`.
    fn poll_write(
        self: core::pin::Pin<&mut Self>,
        cx: &mut core::task::Context<'_>,
        buf: &[u8],
    ) -> core::task::Poll<std::io::Result<usize>> {
        // Explicit path: inherent associated functions take precedence over this trait method.
        TcpStream::poll_write(self, cx, buf)
    }

    /// Always ready: this type keeps no write buffer of its own to flush.
    fn poll_flush(
        self: core::pin::Pin<&mut Self>,
        _cx: &mut core::task::Context<'_>,
    ) -> core::task::Poll<std::io::Result<()>> {
        let flushed: std::io::Result<()> = Ok(());
        core::task::Poll::Ready(flushed)
    }

    /// Half-close the write side via `shutdown_write` and report completion immediately; the
    /// shutdown command is queued without waiting for it to take effect.
    fn poll_shutdown(
        self: core::pin::Pin<&mut Self>,
        _cx: &mut core::task::Context<'_>,
    ) -> core::task::Poll<std::io::Result<()>> {
        TcpStream::shutdown_write(&self);

        let done: std::io::Result<()> = Ok(());
        core::task::Poll::Ready(done)
    }
}
426
#[cfg(feature = "tokio")]
#[cfg(test)]
mod reaped_socket_mapping_tests {
    use core::net::SocketAddr;

    use netcore::{HasChannel, Netstack, smoltcp::iface::SocketHandle, udp};
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    use super::TcpStream;

    /// Build a `TcpStream` whose `SocketHandle` refers to an already-freed slot of a live netstack.
    ///
    /// A real netstack runs on a background thread and answers every command. The handle is
    /// obtained from a UDP bind that is immediately closed, so the slot is empty: each TCP command
    /// the stream issues for it fails `get_socket_mut!`'s existence check and is answered
    /// `missing_socket` — the post-reap (tsr-9ue) state as seen by the consumer. Because the
    /// driver thread keeps running, the stream's async `Recv`/`Send` futures actually resolve.
    fn stream_over_reaped_handle() -> TcpStream {
        let mut stack = Netstack::new(
            netcore::Config::default(),
            netcore::smoltcp::time::Instant::ZERO,
        );
        let chan = stack.command_channel();

        // The driver must be running BEFORE the first `request_blocking` below: that call waits
        // for its response, so without a driver it would deadlock. The same thread later answers
        // the stream's async `Recv`/`Send`.
        std::thread::spawn(move || {
            while let Ok(cmd) = stack.wait_for_cmd_blocking(None) {
                stack.process_one_cmd(cmd);
            }
        });

        let udp_endpoint = SocketAddr::from(([127, 0, 0, 1], 9200));

        // Step 1: obtain a real handle value from a UDP bind (answered by the driver thread).
        let bind_resp = chan
            .request_blocking(
                None,
                udp::Command::Bind {
                    endpoint: udp_endpoint,
                },
            )
            .expect("channel open");
        let handle: SocketHandle = match bind_resp {
            netcore::Response::Udp(udp::Response::Bound { handle, .. }) => handle,
            other => panic!("expected Bound, got {other:?}"),
        };

        // Step 2: close it so the slot is freed. The handle now refers to nothing — the reaped
        // state a first-touch TCP command then sees as `missing_socket`.
        let close_resp = chan
            .request_blocking(Some(handle), udp::Command::Close)
            .expect("channel open");
        assert!(matches!(close_resp, netcore::Response::Ok));

        let remote = SocketAddr::from(([127, 0, 0, 1], 9200));
        let local = SocketAddr::from(([127, 0, 0, 1], 50100));
        TcpStream::new(chan, handle, remote, local)
    }

    /// Part 3: a reaped socket's `Recv` resolves to `missing_socket`, which `poll_read` maps to a
    /// clean end-of-stream — `Ok(0)` — not a generic internal `io::Error`.
    #[tokio::test]
    async fn poll_read_on_reaped_socket_is_eof() {
        let mut stream = stream_over_reaped_handle();
        let mut scratch = [0u8; 64];

        let read_result = stream.read(&mut scratch).await;
        let n = read_result.expect("read on a reaped socket must be Ok(0), not an error");

        assert_eq!(n, 0, "a reaped socket must read as EOF (Ok(0))");
    }

    /// Part 3: a reaped socket's `Send` resolves to `missing_socket`, which `poll_write` maps to
    /// `ErrorKind::ConnectionReset` (POSIX `ECONNRESET` for writing to a torn-down connection).
    #[tokio::test]
    async fn poll_write_on_reaped_socket_is_connection_reset() {
        let mut stream = stream_over_reaped_handle();

        let write_result = stream.write(b"payload").await;
        let err = write_result.expect_err("write to a reaped socket must error");

        assert_eq!(
            err.kind(),
            std::io::ErrorKind::ConnectionReset,
            "writing to a reaped socket must surface as ConnectionReset"
        );
    }
}
514
#[cfg(feature = "futures-io")]
impl futures_io::AsyncRead for TcpStream {
    /// Forward to the inherent `TcpStream::poll_read`, which has the same slice-based signature.
    fn poll_read(
        self: core::pin::Pin<&mut Self>,
        cx: &mut core::task::Context<'_>,
        buf: &mut [u8],
    ) -> core::task::Poll<std::io::Result<usize>> {
        // Explicit path: inherent associated functions take precedence over this trait method.
        TcpStream::poll_read(self, cx, buf)
    }
}
525
#[cfg(feature = "futures-io")]
impl futures_io::AsyncWrite for TcpStream {
    /// Forward to the inherent `TcpStream::poll_write`.
    fn poll_write(
        self: core::pin::Pin<&mut Self>,
        cx: &mut core::task::Context<'_>,
        buf: &[u8],
    ) -> core::task::Poll<std::io::Result<usize>> {
        // Explicit path: inherent associated functions take precedence over this trait method.
        TcpStream::poll_write(self, cx, buf)
    }

    /// Always ready: this type keeps no write buffer of its own to flush.
    fn poll_flush(
        self: core::pin::Pin<&mut Self>,
        _cx: &mut core::task::Context<'_>,
    ) -> core::task::Poll<std::io::Result<()>> {
        let flushed: std::io::Result<()> = Ok(());
        core::task::Poll::Ready(flushed)
    }

    /// Half-close the write side via `shutdown_write` and report completion immediately; the
    /// shutdown command is queued without waiting for it to take effect.
    fn poll_close(
        self: core::pin::Pin<&mut Self>,
        _cx: &mut core::task::Context<'_>,
    ) -> core::task::Poll<std::io::Result<()>> {
        TcpStream::shutdown_write(&self);

        let done: std::io::Result<()> = Ok(());
        core::task::Poll::Ready(done)
    }
}