// Source: ts_netstack_smoltcp_socket/tcp/stream.rs

1use core::{
2    fmt::{Debug, Formatter},
3    net::SocketAddr,
4};
5
6use bytes::Bytes;
7use netcore::{DisplayExt, HasChannel, Response, smoltcp::iface::SocketHandle, tcp};
8
/// A heap-allocated, pinned, type-erased future, used to keep an in-flight netstack request
/// alive across `Pending` polls of the async I/O adapters.
#[cfg(any(feature = "tokio", feature = "futures-io"))]
type PinBoxFut<T> =
    core::pin::Pin<alloc::boxed::Box<dyn core::future::Future<Output = T> + Send + Sync>>;
11
12/// A TCP stream.
13pub struct TcpStream {
14    sender: netcore::Channel,
15    handle: SocketHandle,
16
17    local: SocketAddr,
18    remote: SocketAddr,
19
20    #[cfg(any(feature = "tokio", feature = "futures-io"))]
21    read_fut: Option<PinBoxFut<Result<Bytes, netcore::Error>>>,
22    /// Bytes received from a completed `Recv` that did not fit the caller's buffer on the poll that
23    /// produced them, carried to the next `poll_read`. A `Recv` is sized by the buffer length at
24    /// future-creation, but the `AsyncRead` contract permits the caller to re-poll with a *smaller*
25    /// buffer, so the response can exceed the live buffer — copying it whole would panic
26    /// (`copy_from_slice` length mismatch). We copy what fits and stash the tail here (lossless),
27    /// draining it before issuing the next `Recv`.
28    #[cfg(any(feature = "tokio", feature = "futures-io"))]
29    read_remainder: Option<Bytes>,
30    #[cfg(any(feature = "tokio", feature = "futures-io"))]
31    write_fut: Option<PinBoxFut<Result<usize, netcore::Error>>>,
32}
33
34impl TcpStream {
35    pub(crate) const fn new(
36        sender: netcore::Channel,
37        handle: SocketHandle,
38        remote: SocketAddr,
39        local: SocketAddr,
40    ) -> Self {
41        Self {
42            sender,
43            handle,
44            remote,
45            local,
46
47            #[cfg(any(feature = "tokio", feature = "futures-io"))]
48            read_fut: None,
49
50            #[cfg(any(feature = "tokio", feature = "futures-io"))]
51            read_remainder: None,
52
53            #[cfg(any(feature = "tokio", feature = "futures-io"))]
54            write_fut: None,
55        }
56    }
57}
58
59impl Debug for TcpStream {
60    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
61        f.debug_struct("TcpStream")
62            .field("handle", &self.handle.as_display_debug())
63            .field("local_endpoint", &self.local)
64            .field("remote_endpoint", &self.remote)
65            .finish()
66    }
67}
68
69impl TcpStream {
70    /// Report the local endpoint to which this stream is connected.
71    pub const fn local_addr(&self) -> SocketAddr {
72        self.local
73    }
74
75    /// Report the remote endpoint to which this stream is connected.
76    pub const fn remote_addr(&self) -> SocketAddr {
77        self.remote
78    }
79
80    /// Half-close the write side: send a FIN to the peer (`shutdown(SHUT_WR)` / `CloseWrite`) while
81    /// keeping the read side open so the peer's remaining data can still be received. Fire-and-forget
82    /// (non-blocking, like the `Drop`-time `Close`): the FIN is emitted by the netstack in the
83    /// background, so this returns immediately and a caller using shutdown for signaling (e.g. a
84    /// bidirectional splice half-closing one direction) no longer hangs waiting for a FIN that was
85    /// never sent.
86    ///
87    /// After this, **writes fail** (`InvalidState`): the socket has left the sendable state — this is
88    /// the intended `shutdown(SHUT_WR)` POSIX behavior (previously, when this was a no-op, a write
89    /// after shutdown still succeeded). Reads continue until the peer's FIN.
90    ///
91    /// Best-effort delivery: `request_nonblocking` treats a *full* command channel as success and
92    /// drops the command, so under channel saturation the FIN may not be sent — the socket then
93    /// teardown-degrades to the idle/keep-alive timeout reaper instead of a prompt FIN (never a hard
94    /// leak). A channel-*closed* error means the netstack is gone; the socket is already moot.
95    pub fn shutdown_write(&self) {
96        if let Err(e) = self
97            .sender
98            .request_nonblocking(Some(self.handle), tcp::stream::Command::ShutdownWrite)
99        {
100            tracing::debug!(err = %e, "shutdown_write: netstack channel closed");
101        }
102    }
103
104    /// Send bytes to the remote.
105    ///
106    /// Blocks until at least one byte can be queued. The return value is the number of
107    /// bytes actually sent.
108    pub fn send_blocking(&self, b: &[u8]) -> Result<usize, netcore::Error> {
109        let resp = self.request_blocking(tcp::stream::Command::Send {
110            buf: Bytes::copy_from_slice(b),
111        })?;
112
113        self._send(resp)
114    }
115
116    /// Send bytes to the remote.
117    ///
118    /// Blocks until at least one byte can be queued. The return value is the number of
119    /// bytes actually sent.
120    pub async fn send(&self, b: &[u8]) -> Result<usize, netcore::Error> {
121        let resp = self
122            .request(tcp::stream::Command::Send {
123                buf: Bytes::copy_from_slice(b),
124            })
125            .await?;
126
127        self._send(resp)
128    }
129
130    fn _send(&self, resp: Response) -> Result<usize, netcore::Error> {
131        netcore::try_response_as!(resp, tcp::stream::Response::Sent { n });
132        Ok(n)
133    }
134
135    /// Receive bytes from the remote.
136    ///
137    /// Returns the number of bytes actually received (blocks until there is at least one).
138    pub fn recv_blocking(&self, b: &mut [u8]) -> Result<usize, netcore::Error> {
139        let resp = self.request_blocking(tcp::stream::Command::Recv {
140            max_len: Some(b.len()),
141        })?;
142
143        self._recv(resp, b)
144    }
145
146    /// Receive bytes from the remote into the supplied buffer.
147    ///
148    /// Returns the number of bytes actually received (blocks until there is at least one).
149    pub async fn recv(&self, b: &mut [u8]) -> Result<usize, netcore::Error> {
150        let resp = self
151            .request(tcp::stream::Command::Recv {
152                max_len: Some(b.len()),
153            })
154            .await?;
155
156        self._recv(resp, b)
157    }
158
159    /// Receive bytes from the remote.
160    ///
161    /// Returns the number of bytes actually received (blocks until there is at least one).
162    pub fn recv_bytes_blocking(&self) -> Result<Bytes, netcore::Error> {
163        let resp = self.request_blocking(tcp::stream::Command::Recv { max_len: None })?;
164
165        self._recv_bytes(resp)
166    }
167
168    /// Receive bytes from the remote.
169    pub async fn recv_bytes(&self) -> Result<Bytes, netcore::Error> {
170        let resp = self
171            .request(tcp::stream::Command::Recv { max_len: None })
172            .await?;
173
174        self._recv_bytes(resp)
175    }
176
177    fn _recv(&self, resp: Response, b: &mut [u8]) -> Result<usize, netcore::Error> {
178        let buf = self._recv_bytes(resp)?;
179
180        let n = buf.len().min(b.len());
181        b[..n].copy_from_slice(&buf[..n]);
182
183        Ok(n)
184    }
185
186    fn _recv_bytes(&self, resp: Response) -> Result<Bytes, netcore::Error> {
187        if matches!(resp, Response::TcpStream(tcp::stream::Response::Finished)) {
188            return Ok(Bytes::new());
189        }
190
191        netcore::try_response_as!(resp, tcp::stream::Response::Recv { buf });
192        Ok(buf)
193    }
194
195    #[cfg(any(feature = "tokio", feature = "futures-io"))]
196    fn poll_read(
197        mut self: core::pin::Pin<&mut Self>,
198        cx: &mut core::task::Context,
199        buf: &mut [u8],
200    ) -> core::task::Poll<std::io::Result<usize>> {
201        use netcore::HasChannel;
202
203        // Callers must pass a non-empty buffer: an `Ok(0)` return is `AsyncRead`'s EOF signal, so
204        // returning it while `read_remainder` still holds bytes (which a zero-length `buf` would
205        // force) would silently truncate the stream. Every in-tree caller passes a non-empty buffer;
206        // this guards the invariant for the public type so a zero-length read can't be mistaken for
207        // EOF-with-data-pending. `tokio`/`futures-io` themselves never poll a read with an empty buf.
208        debug_assert!(
209            !buf.is_empty() || self.read_remainder.is_none(),
210            "poll_read called with an empty buffer while bytes are buffered — Ok(0) would look like EOF"
211        );
212
213        // Copy up to `buf.len()` bytes out of `data` into `buf`, returning `(written, remainder)`
214        // where `remainder` is the unwritten tail (empty if it all fit). `Bytes::split_to` is a
215        // cheap refcount split, so carrying a remainder is allocation-free. Free fn (not a
216        // self-capturing closure) so it doesn't conflict with the `&mut self.read_fut` borrow below.
217        fn copy_into_buf(mut data: Bytes, buf: &mut [u8]) -> (usize, Bytes) {
218            let n = data.len().min(buf.len());
219            buf[..n].copy_from_slice(&data.split_to(n));
220            (n, data)
221        }
222
223        // Drain any stashed remainder first — never issue a fresh `Recv` while bytes are buffered.
224        if let Some(rem) = self.read_remainder.take() {
225            let (n, tail) = copy_into_buf(rem, buf);
226            if !tail.is_empty() {
227                self.read_remainder = Some(tail);
228            }
229            return core::task::Poll::Ready(Ok(n));
230        }
231
232        let handle = self.handle;
233        let cap = buf.len();
234
235        loop {
236            match self.read_fut.as_mut() {
237                None => {
238                    let sender = self.sender.clone();
239
240                    let _ret = self.read_fut.insert(alloc::boxed::Box::pin(async move {
241                        let resp = sender
242                            .request(
243                                Some(handle),
244                                tcp::stream::Command::Recv { max_len: Some(cap) },
245                            )
246                            .await?;
247
248                        match resp.try_into()? {
249                            tcp::stream::Response::Recv { buf } => Ok(buf),
250                            tcp::stream::Response::Finished => Ok(Bytes::new()),
251                            _ => Err(netcore::Error::wrong_type()),
252                        }
253                    }));
254                }
255
256                Some(x) => {
257                    let poll_result = x.as_mut().poll(cx);
258                    let ret = core::task::ready!(poll_result)?;
259
260                    self.read_fut.take();
261
262                    // Copy what fits into the CURRENT buffer (which the caller may have shrunk since
263                    // the `Recv` was issued at `cap`); stash any tail. A whole-`ret` copy would panic
264                    // when `ret.len() > buf.len()`.
265                    let (n, tail) = copy_into_buf(ret, buf);
266                    if !tail.is_empty() {
267                        self.read_remainder = Some(tail);
268                    }
269
270                    break core::task::Poll::Ready(Ok(n));
271                }
272            }
273        }
274    }
275
276    #[cfg(any(feature = "tokio", feature = "futures-io"))]
277    fn poll_write(
278        mut self: core::pin::Pin<&mut Self>,
279        cx: &mut core::task::Context<'_>,
280        buf: &[u8],
281    ) -> core::task::Poll<std::io::Result<usize>> {
282        use netcore::HasChannel;
283
284        let handle = self.handle;
285
286        loop {
287            match &mut self.write_fut {
288                None => {
289                    let b = Bytes::copy_from_slice(buf);
290                    let sender = self.sender.clone();
291
292                    let _ret = self.write_fut.insert(alloc::boxed::Box::pin(async move {
293                        let resp = sender
294                            .request(Some(handle), tcp::stream::Command::Send { buf: b })
295                            .await?;
296
297                        netcore::try_response_as!(resp, tcp::stream::Response::Sent { n });
298                        Ok(n)
299                    }));
300                }
301
302                Some(x) => {
303                    let poll_result = x.as_mut().poll(cx);
304                    let ret = core::task::ready!(poll_result)?;
305
306                    self.write_fut.take();
307
308                    break core::task::Poll::Ready(Ok(ret));
309                }
310            }
311        }
312    }
313
314    socket_requestor_impl!();
315}
316
317impl Drop for TcpStream {
318    fn drop(&mut self) {
319        if let Err(e) = self
320            .sender
321            .request_nonblocking(Some(self.handle), tcp::stream::Command::Close)
322        {
323            tracing::warn!(err = %e, "possible socket leak");
324        }
325    }
326}
327
#[cfg(feature = "std")]
impl std::io::Read for TcpStream {
    /// Blocking read that delegates to [`TcpStream::recv_blocking`].
    ///
    /// Returns `Ok(0)` once the peer has finished sending. Netstack errors are converted into
    /// `std::io::Error`.
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        match self.recv_blocking(buf) {
            Ok(received) => Ok(received),
            Err(err) => Err(err.into()),
        }
    }
}
334
#[cfg(feature = "std")]
impl std::io::Write for TcpStream {
    /// Blocking write of as much of `buf` as the netstack accepts.
    ///
    /// See [`TcpStream::send_blocking`]. May write fewer than `buf.len()` bytes.
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        self.send_blocking(buf).map_err(netcore::Error::into)
    }

    /// Write all of `buf`, issuing `Send` commands until every byte has been accepted.
    ///
    /// The default `write_all` would go through `write`, which copies the remaining slice on
    /// every partial send. This version copies `buf` into a [`Bytes`] once and then hands out
    /// cheap refcounted views of the unsent tail.
    ///
    /// # Errors
    ///
    /// Propagates netstack errors. Returns [`std::io::ErrorKind::WriteZero`] if the netstack ever
    /// reports zero bytes accepted, as the standard `write_all` does; otherwise the loop could
    /// never make progress and would spin forever.
    fn write_all(&mut self, buf: &[u8]) -> std::io::Result<()> {
        let mut pending = Bytes::copy_from_slice(buf);

        while !pending.is_empty() {
            let resp = self.request_blocking(tcp::stream::Command::Send {
                buf: pending.clone(),
            })?;
            netcore::try_response_as!(resp, tcp::stream::Response::Sent { n });

            if n == 0 {
                return Err(std::io::Error::new(
                    std::io::ErrorKind::WriteZero,
                    "failed to write whole buffer",
                ));
            }

            // Drop the accepted prefix. This assumes the netstack never reports more bytes than
            // it was given (`split_to` panics if `n > pending.len()`).
            let _sent = pending.split_to(n);
        }

        Ok(())
    }

    /// No-op: every write is handed straight to the netstack, so nothing is buffered locally.
    fn flush(&mut self) -> std::io::Result<()> {
        Ok(())
    }
}
358
#[cfg(feature = "tokio")]
impl tokio::io::AsyncRead for TcpStream {
    /// Read into the unfilled part of `buf` through the shared inherent `poll_read`.
    ///
    /// The unfilled region is first initialised (`initialize_unfilled`) so it can be lent out as
    /// a plain `&mut [u8]`. On success, exactly the bytes written are marked as filled.
    fn poll_read(
        self: core::pin::Pin<&mut Self>,
        cx: &mut core::task::Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> core::task::Poll<tokio::io::Result<()>> {
        let unfilled = buf.initialize_unfilled();

        match TcpStream::poll_read(self, cx, unfilled) {
            core::task::Poll::Pending => core::task::Poll::Pending,
            core::task::Poll::Ready(Err(err)) => core::task::Poll::Ready(Err(err)),
            core::task::Poll::Ready(Ok(written)) => {
                buf.advance(written);
                core::task::Poll::Ready(Ok(()))
            }
        }
    }
}
372
#[cfg(feature = "tokio")]
impl tokio::io::AsyncWrite for TcpStream {
    /// Forward to the inherent `poll_write` shared with the futures-io adapter.
    fn poll_write(
        self: core::pin::Pin<&mut Self>,
        cx: &mut core::task::Context<'_>,
        buf: &[u8],
    ) -> core::task::Poll<std::io::Result<usize>> {
        TcpStream::poll_write(self, cx, buf)
    }

    /// Always ready: writes go straight to the netstack, so nothing is buffered here.
    fn poll_flush(
        self: core::pin::Pin<&mut Self>,
        _cx: &mut core::task::Context<'_>,
    ) -> core::task::Poll<std::io::Result<()>> {
        core::task::Poll::Ready(Ok(()))
    }

    /// Half-close via [`TcpStream::shutdown_write`] (send FIN, keep reading).
    ///
    /// The FIN is requested without waiting, so this completes immediately.
    fn poll_shutdown(
        self: core::pin::Pin<&mut Self>,
        _cx: &mut core::task::Context<'_>,
    ) -> core::task::Poll<std::io::Result<()>> {
        let stream: &TcpStream = &self;
        stream.shutdown_write();
        core::task::Poll::Ready(Ok(()))
    }
}
398
#[cfg(feature = "futures-io")]
impl futures_io::AsyncRead for TcpStream {
    /// Forward to the inherent `poll_read`, which already speaks the `&mut [u8]` interface used
    /// by `futures-io`.
    fn poll_read(
        self: core::pin::Pin<&mut Self>,
        cx: &mut core::task::Context<'_>,
        buf: &mut [u8],
    ) -> core::task::Poll<std::io::Result<usize>> {
        TcpStream::poll_read(self, cx, buf)
    }
}
409
#[cfg(feature = "futures-io")]
impl futures_io::AsyncWrite for TcpStream {
    /// Forward to the inherent `poll_write` shared with the tokio adapter.
    fn poll_write(
        self: core::pin::Pin<&mut Self>,
        cx: &mut core::task::Context<'_>,
        buf: &[u8],
    ) -> core::task::Poll<std::io::Result<usize>> {
        TcpStream::poll_write(self, cx, buf)
    }

    /// Always ready: writes go straight to the netstack, so nothing is buffered here.
    fn poll_flush(
        self: core::pin::Pin<&mut Self>,
        _cx: &mut core::task::Context<'_>,
    ) -> core::task::Poll<std::io::Result<()>> {
        core::task::Poll::Ready(Ok(()))
    }

    /// "Close" as a half-close via [`TcpStream::shutdown_write`] (send FIN, keep reading).
    ///
    /// The FIN is requested without waiting, so this completes immediately. The socket itself is
    /// closed when the stream is dropped.
    fn poll_close(
        self: core::pin::Pin<&mut Self>,
        _cx: &mut core::task::Context<'_>,
    ) -> core::task::Poll<std::io::Result<()>> {
        let stream: &TcpStream = &self;
        stream.shutdown_write();
        core::task::Poll::Ready(Ok(()))
    }
}