Skip to main content

ts_netstack_smoltcp_socket/tcp/
stream.rs

1use core::{
2    fmt::{Debug, Formatter},
3    net::SocketAddr,
4};
5
6use bytes::Bytes;
7use netcore::{DisplayExt, HasChannel, Response, smoltcp::iface::SocketHandle, tcp};
8
/// Heap-allocated, pinned, type-erased future that the `AsyncRead`/`AsyncWrite` impls keep
/// alive between polls.
///
/// The `Send + Sync` bounds mean that holding one of these never stops the containing stream
/// from being `Send`/`Sync`.
#[cfg(any(feature = "tokio", feature = "futures-io"))]
type PinBoxFut<T> =
    core::pin::Pin<alloc::boxed::Box<dyn core::future::Future<Output = T> + Send + Sync>>;
11
12/// A TCP stream.
13pub struct TcpStream {
14    sender: netcore::Channel,
15    handle: SocketHandle,
16
17    local: SocketAddr,
18    remote: SocketAddr,
19
20    #[cfg(any(feature = "tokio", feature = "futures-io"))]
21    read_fut: Option<PinBoxFut<Result<Bytes, netcore::Error>>>,
22    /// Bytes received from a completed `Recv` that did not fit the caller's buffer on the poll that
23    /// produced them, carried to the next `poll_read`. A `Recv` is sized by the buffer length at
24    /// future-creation, but the `AsyncRead` contract permits the caller to re-poll with a *smaller*
25    /// buffer, so the response can exceed the live buffer — copying it whole would panic
26    /// (`copy_from_slice` length mismatch). We copy what fits and stash the tail here (lossless),
27    /// draining it before issuing the next `Recv`.
28    #[cfg(any(feature = "tokio", feature = "futures-io"))]
29    read_remainder: Option<Bytes>,
30    #[cfg(any(feature = "tokio", feature = "futures-io"))]
31    write_fut: Option<PinBoxFut<Result<usize, netcore::Error>>>,
32}
33
34impl TcpStream {
35    pub(crate) const fn new(
36        sender: netcore::Channel,
37        handle: SocketHandle,
38        remote: SocketAddr,
39        local: SocketAddr,
40    ) -> Self {
41        Self {
42            sender,
43            handle,
44            remote,
45            local,
46
47            #[cfg(any(feature = "tokio", feature = "futures-io"))]
48            read_fut: None,
49
50            #[cfg(any(feature = "tokio", feature = "futures-io"))]
51            read_remainder: None,
52
53            #[cfg(any(feature = "tokio", feature = "futures-io"))]
54            write_fut: None,
55        }
56    }
57}
58
59impl Debug for TcpStream {
60    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
61        f.debug_struct("TcpStream")
62            .field("handle", &self.handle.as_display_debug())
63            .field("local_endpoint", &self.local)
64            .field("remote_endpoint", &self.remote)
65            .finish()
66    }
67}
68
69impl TcpStream {
70    /// Report the local endpoint to which this stream is connected.
71    pub const fn local_addr(&self) -> SocketAddr {
72        self.local
73    }
74
75    /// Report the remote endpoint to which this stream is connected.
76    pub const fn remote_addr(&self) -> SocketAddr {
77        self.remote
78    }
79
80    /// Send bytes to the remote.
81    ///
82    /// Blocks until at least one byte can be queued. The return value is the number of
83    /// bytes actually sent.
84    pub fn send_blocking(&self, b: &[u8]) -> Result<usize, netcore::Error> {
85        let resp = self.request_blocking(tcp::stream::Command::Send {
86            buf: Bytes::copy_from_slice(b),
87        })?;
88
89        self._send(resp)
90    }
91
92    /// Send bytes to the remote.
93    ///
94    /// Blocks until at least one byte can be queued. The return value is the number of
95    /// bytes actually sent.
96    pub async fn send(&self, b: &[u8]) -> Result<usize, netcore::Error> {
97        let resp = self
98            .request(tcp::stream::Command::Send {
99                buf: Bytes::copy_from_slice(b),
100            })
101            .await?;
102
103        self._send(resp)
104    }
105
106    fn _send(&self, resp: Response) -> Result<usize, netcore::Error> {
107        netcore::try_response_as!(resp, tcp::stream::Response::Sent { n });
108        Ok(n)
109    }
110
111    /// Receive bytes from the remote.
112    ///
113    /// Returns the number of bytes actually received (blocks until there is at least one).
114    pub fn recv_blocking(&self, b: &mut [u8]) -> Result<usize, netcore::Error> {
115        let resp = self.request_blocking(tcp::stream::Command::Recv {
116            max_len: Some(b.len()),
117        })?;
118
119        self._recv(resp, b)
120    }
121
122    /// Receive bytes from the remote into the supplied buffer.
123    ///
124    /// Returns the number of bytes actually received (blocks until there is at least one).
125    pub async fn recv(&self, b: &mut [u8]) -> Result<usize, netcore::Error> {
126        let resp = self
127            .request(tcp::stream::Command::Recv {
128                max_len: Some(b.len()),
129            })
130            .await?;
131
132        self._recv(resp, b)
133    }
134
135    /// Receive bytes from the remote.
136    ///
137    /// Returns the number of bytes actually received (blocks until there is at least one).
138    pub fn recv_bytes_blocking(&self) -> Result<Bytes, netcore::Error> {
139        let resp = self.request_blocking(tcp::stream::Command::Recv { max_len: None })?;
140
141        self._recv_bytes(resp)
142    }
143
144    /// Receive bytes from the remote.
145    pub async fn recv_bytes(&self) -> Result<Bytes, netcore::Error> {
146        let resp = self
147            .request(tcp::stream::Command::Recv { max_len: None })
148            .await?;
149
150        self._recv_bytes(resp)
151    }
152
153    fn _recv(&self, resp: Response, b: &mut [u8]) -> Result<usize, netcore::Error> {
154        let buf = self._recv_bytes(resp)?;
155
156        let n = buf.len().min(b.len());
157        b[..n].copy_from_slice(&buf[..n]);
158
159        Ok(n)
160    }
161
162    fn _recv_bytes(&self, resp: Response) -> Result<Bytes, netcore::Error> {
163        if matches!(resp, Response::TcpStream(tcp::stream::Response::Finished)) {
164            return Ok(Bytes::new());
165        }
166
167        netcore::try_response_as!(resp, tcp::stream::Response::Recv { buf });
168        Ok(buf)
169    }
170
171    #[cfg(any(feature = "tokio", feature = "futures-io"))]
172    fn poll_read(
173        mut self: core::pin::Pin<&mut Self>,
174        cx: &mut core::task::Context,
175        buf: &mut [u8],
176    ) -> core::task::Poll<std::io::Result<usize>> {
177        use netcore::HasChannel;
178
179        // Callers must pass a non-empty buffer: an `Ok(0)` return is `AsyncRead`'s EOF signal, so
180        // returning it while `read_remainder` still holds bytes (which a zero-length `buf` would
181        // force) would silently truncate the stream. Every in-tree caller passes a non-empty buffer;
182        // this guards the invariant for the public type so a zero-length read can't be mistaken for
183        // EOF-with-data-pending. `tokio`/`futures-io` themselves never poll a read with an empty buf.
184        debug_assert!(
185            !buf.is_empty() || self.read_remainder.is_none(),
186            "poll_read called with an empty buffer while bytes are buffered — Ok(0) would look like EOF"
187        );
188
189        // Copy up to `buf.len()` bytes out of `data` into `buf`, returning `(written, remainder)`
190        // where `remainder` is the unwritten tail (empty if it all fit). `Bytes::split_to` is a
191        // cheap refcount split, so carrying a remainder is allocation-free. Free fn (not a
192        // self-capturing closure) so it doesn't conflict with the `&mut self.read_fut` borrow below.
193        fn copy_into_buf(mut data: Bytes, buf: &mut [u8]) -> (usize, Bytes) {
194            let n = data.len().min(buf.len());
195            buf[..n].copy_from_slice(&data.split_to(n));
196            (n, data)
197        }
198
199        // Drain any stashed remainder first — never issue a fresh `Recv` while bytes are buffered.
200        if let Some(rem) = self.read_remainder.take() {
201            let (n, tail) = copy_into_buf(rem, buf);
202            if !tail.is_empty() {
203                self.read_remainder = Some(tail);
204            }
205            return core::task::Poll::Ready(Ok(n));
206        }
207
208        let handle = self.handle;
209        let cap = buf.len();
210
211        loop {
212            match self.read_fut.as_mut() {
213                None => {
214                    let sender = self.sender.clone();
215
216                    let _ret = self.read_fut.insert(alloc::boxed::Box::pin(async move {
217                        let resp = sender
218                            .request(
219                                Some(handle),
220                                tcp::stream::Command::Recv { max_len: Some(cap) },
221                            )
222                            .await?;
223
224                        match resp.try_into()? {
225                            tcp::stream::Response::Recv { buf } => Ok(buf),
226                            tcp::stream::Response::Finished => Ok(Bytes::new()),
227                            _ => Err(netcore::Error::wrong_type()),
228                        }
229                    }));
230                }
231
232                Some(x) => {
233                    let poll_result = x.as_mut().poll(cx);
234                    let ret = core::task::ready!(poll_result)?;
235
236                    self.read_fut.take();
237
238                    // Copy what fits into the CURRENT buffer (which the caller may have shrunk since
239                    // the `Recv` was issued at `cap`); stash any tail. A whole-`ret` copy would panic
240                    // when `ret.len() > buf.len()`.
241                    let (n, tail) = copy_into_buf(ret, buf);
242                    if !tail.is_empty() {
243                        self.read_remainder = Some(tail);
244                    }
245
246                    break core::task::Poll::Ready(Ok(n));
247                }
248            }
249        }
250    }
251
252    #[cfg(any(feature = "tokio", feature = "futures-io"))]
253    fn poll_write(
254        mut self: core::pin::Pin<&mut Self>,
255        cx: &mut core::task::Context<'_>,
256        buf: &[u8],
257    ) -> core::task::Poll<std::io::Result<usize>> {
258        use netcore::HasChannel;
259
260        let handle = self.handle;
261
262        loop {
263            match &mut self.write_fut {
264                None => {
265                    let b = Bytes::copy_from_slice(buf);
266                    let sender = self.sender.clone();
267
268                    let _ret = self.write_fut.insert(alloc::boxed::Box::pin(async move {
269                        let resp = sender
270                            .request(Some(handle), tcp::stream::Command::Send { buf: b })
271                            .await?;
272
273                        netcore::try_response_as!(resp, tcp::stream::Response::Sent { n });
274                        Ok(n)
275                    }));
276                }
277
278                Some(x) => {
279                    let poll_result = x.as_mut().poll(cx);
280                    let ret = core::task::ready!(poll_result)?;
281
282                    self.write_fut.take();
283
284                    break core::task::Poll::Ready(Ok(ret));
285                }
286            }
287        }
288    }
289
290    socket_requestor_impl!();
291}
292
293impl Drop for TcpStream {
294    fn drop(&mut self) {
295        if let Err(e) = self
296            .sender
297            .request_nonblocking(Some(self.handle), tcp::stream::Command::Close)
298        {
299            tracing::warn!(err = %e, "possible socket leak");
300        }
301    }
302}
303
#[cfg(feature = "std")]
impl std::io::Read for TcpStream {
    /// Blocking read; see [`TcpStream::recv_blocking`]. Netstack errors become `io::Error`s.
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        match self.recv_blocking(buf) {
            Ok(received) => Ok(received),
            Err(err) => Err(err.into()),
        }
    }
}
310
#[cfg(feature = "std")]
impl std::io::Write for TcpStream {
    /// Blocking partial write; see [`TcpStream::send_blocking`].
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        self.send_blocking(buf).map_err(netcore::Error::into)
    }

    /// Write all of `buf`, issuing `Send` requests until every byte has been queued.
    ///
    /// The payload is copied into a [`Bytes`] once. Each retry re-sends only the unsent tail,
    /// and cloning a `Bytes` is a refcount bump, not a copy.
    ///
    /// # Errors
    ///
    /// Fails with [`std::io::ErrorKind::WriteZero`] if the netstack reports that zero bytes were
    /// sent, mirroring `std`'s default `write_all`. Request errors are propagated as-is.
    fn write_all(&mut self, buf: &[u8]) -> std::io::Result<()> {
        let mut remaining = Bytes::copy_from_slice(buf);

        while !remaining.is_empty() {
            let resp = self.request_blocking(tcp::stream::Command::Send {
                buf: remaining.clone(),
            })?;
            netcore::try_response_as!(resp, tcp::stream::Response::Sent { n });

            // A zero-byte send makes no progress; without this check the loop would spin
            // forever re-sending the same bytes.
            if n == 0 {
                return Err(std::io::Error::new(
                    std::io::ErrorKind::WriteZero,
                    "failed to write whole buffer",
                ));
            }

            let _sent = remaining.split_to(n);
        }

        Ok(())
    }

    /// No-op: this type keeps no local write buffer; writes are handed to the netstack directly.
    fn flush(&mut self) -> std::io::Result<()> {
        Ok(())
    }
}
334
#[cfg(feature = "tokio")]
impl tokio::io::AsyncRead for TcpStream {
    fn poll_read(
        self: core::pin::Pin<&mut Self>,
        cx: &mut core::task::Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> core::task::Poll<tokio::io::Result<()>> {
        use core::task::Poll;

        // Initialise the unfilled tail so it can be handed out as a plain `&mut [u8]`. The
        // inherent reader fills part of it, and the bytes it reports are then committed.
        let unfilled = buf.initialize_unfilled();
        match TcpStream::poll_read(self, cx, unfilled) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
            Poll::Ready(Ok(read)) => {
                buf.advance(read);
                Poll::Ready(Ok(()))
            }
        }
    }
}
348
#[cfg(feature = "tokio")]
impl tokio::io::AsyncWrite for TcpStream {
    /// Forward to the inherent `poll_write` shared with the futures-io impl.
    fn poll_write(
        self: core::pin::Pin<&mut Self>,
        cx: &mut core::task::Context<'_>,
        buf: &[u8],
    ) -> core::task::Poll<std::io::Result<usize>> {
        TcpStream::poll_write(self, cx, buf)
    }

    /// Always ready: nothing is buffered locally.
    fn poll_flush(
        self: core::pin::Pin<&mut Self>,
        _cx: &mut core::task::Context<'_>,
    ) -> core::task::Poll<std::io::Result<()>> {
        use core::task::Poll;

        Poll::Ready(Ok(()))
    }

    fn poll_shutdown(
        self: core::pin::Pin<&mut Self>,
        _cx: &mut core::task::Context<'_>,
    ) -> core::task::Poll<std::io::Result<()>> {
        use core::task::Poll;

        // NOTE(npry): an explicit shutdown can't mean much for us. The socket has to be
        // closable out-of-band anyway, because we can't rely on an async runtime driving us.
        // As a result, this reports success without confirming that the socket is closed.
        // Dependents that use close as a signal before dropping the socket could hang here.
        Poll::Ready(Ok(()))
    }
}
378
#[cfg(feature = "futures-io")]
impl futures_io::AsyncRead for TcpStream {
    /// Forward to the inherent `poll_read`, which already takes a plain byte slice.
    fn poll_read(
        self: core::pin::Pin<&mut Self>,
        cx: &mut core::task::Context<'_>,
        buf: &mut [u8],
    ) -> core::task::Poll<std::io::Result<usize>> {
        TcpStream::poll_read(self, cx, buf)
    }
}
389
#[cfg(feature = "futures-io")]
impl futures_io::AsyncWrite for TcpStream {
    /// Forward to the inherent `poll_write` shared with the tokio impl.
    fn poll_write(
        self: core::pin::Pin<&mut Self>,
        cx: &mut core::task::Context<'_>,
        buf: &[u8],
    ) -> core::task::Poll<std::io::Result<usize>> {
        TcpStream::poll_write(self, cx, buf)
    }

    /// Always ready: nothing is buffered locally.
    fn poll_flush(
        self: core::pin::Pin<&mut Self>,
        _cx: &mut core::task::Context<'_>,
    ) -> core::task::Poll<std::io::Result<()>> {
        use core::task::Poll;

        Poll::Ready(Ok(()))
    }

    fn poll_close(
        self: core::pin::Pin<&mut Self>,
        _cx: &mut core::task::Context<'_>,
    ) -> core::task::Poll<std::io::Result<()>> {
        use core::task::Poll;

        // Closing is handled out-of-band (dropping the stream sends `Close`), so this reports
        // success without confirming that the socket is actually closed. Dependents that use
        // close as a signal before dropping the socket could hang.
        Poll::Ready(Ok(()))
    }
}