// ts_netstack_smoltcp_socket/tcp/stream.rs

1use core::{
2    fmt::{Debug, Formatter},
3    net::SocketAddr,
4};
5
6use bytes::Bytes;
7use netcore::{DisplayExt, HasChannel, Response, smoltcp::iface::SocketHandle, tcp};
8
/// Pinned, boxed, `Send + Sync` future used to keep an in-flight read or write request
/// alive between successive `poll_*` calls.
#[cfg(any(feature = "futures-io", feature = "tokio"))]
type PinBoxFut<T> =
    core::pin::Pin<alloc::boxed::Box<dyn core::future::Future<Output = T> + Send + Sync>>;
11
12/// A TCP stream.
13pub struct TcpStream {
14    sender: netcore::Channel,
15    handle: SocketHandle,
16
17    local: SocketAddr,
18    remote: SocketAddr,
19
20    #[cfg(any(feature = "tokio", feature = "futures-io"))]
21    read_fut: Option<PinBoxFut<Result<Bytes, netcore::Error>>>,
22    #[cfg(any(feature = "tokio", feature = "futures-io"))]
23    write_fut: Option<PinBoxFut<Result<usize, netcore::Error>>>,
24}
25
26impl TcpStream {
27    pub(crate) const fn new(
28        sender: netcore::Channel,
29        handle: SocketHandle,
30        remote: SocketAddr,
31        local: SocketAddr,
32    ) -> Self {
33        Self {
34            sender,
35            handle,
36            remote,
37            local,
38
39            #[cfg(any(feature = "tokio", feature = "futures-io"))]
40            read_fut: None,
41
42            #[cfg(any(feature = "tokio", feature = "futures-io"))]
43            write_fut: None,
44        }
45    }
46}
47
48impl Debug for TcpStream {
49    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
50        f.debug_struct("TcpStream")
51            .field("handle", &self.handle.as_display_debug())
52            .field("local_endpoint", &self.local)
53            .field("remote_endpoint", &self.remote)
54            .finish()
55    }
56}
57
58impl TcpStream {
59    /// Report the local endpoint to which this stream is connected.
60    pub const fn local_addr(&self) -> SocketAddr {
61        self.local
62    }
63
64    /// Report the remote endpoint to which this stream is connected.
65    pub const fn remote_addr(&self) -> SocketAddr {
66        self.remote
67    }
68
69    /// Send bytes to the remote.
70    ///
71    /// Blocks until at least one byte can be queued. The return value is the number of
72    /// bytes actually sent.
73    pub fn send_blocking(&self, b: &[u8]) -> Result<usize, netcore::Error> {
74        let resp = self.request_blocking(tcp::stream::Command::Send {
75            buf: Bytes::copy_from_slice(b),
76        })?;
77
78        self._send(resp)
79    }
80
81    /// Send bytes to the remote.
82    ///
83    /// Blocks until at least one byte can be queued. The return value is the number of
84    /// bytes actually sent.
85    pub async fn send(&self, b: &[u8]) -> Result<usize, netcore::Error> {
86        let resp = self
87            .request(tcp::stream::Command::Send {
88                buf: Bytes::copy_from_slice(b),
89            })
90            .await?;
91
92        self._send(resp)
93    }
94
95    fn _send(&self, resp: Response) -> Result<usize, netcore::Error> {
96        netcore::try_response_as!(resp, tcp::stream::Response::Sent { n });
97        Ok(n)
98    }
99
100    /// Receive bytes from the remote.
101    ///
102    /// Returns the number of bytes actually received (blocks until there is at least one).
103    pub fn recv_blocking(&self, b: &mut [u8]) -> Result<usize, netcore::Error> {
104        let resp = self.request_blocking(tcp::stream::Command::Recv {
105            max_len: Some(b.len()),
106        })?;
107
108        self._recv(resp, b)
109    }
110
111    /// Receive bytes from the remote into the supplied buffer.
112    ///
113    /// Returns the number of bytes actually received (blocks until there is at least one).
114    pub async fn recv(&self, b: &mut [u8]) -> Result<usize, netcore::Error> {
115        let resp = self
116            .request(tcp::stream::Command::Recv {
117                max_len: Some(b.len()),
118            })
119            .await?;
120
121        self._recv(resp, b)
122    }
123
124    /// Receive bytes from the remote.
125    ///
126    /// Returns the number of bytes actually received (blocks until there is at least one).
127    pub fn recv_bytes_blocking(&self) -> Result<Bytes, netcore::Error> {
128        let resp = self.request_blocking(tcp::stream::Command::Recv { max_len: None })?;
129
130        self._recv_bytes(resp)
131    }
132
133    /// Receive bytes from the remote.
134    pub async fn recv_bytes(&self) -> Result<Bytes, netcore::Error> {
135        let resp = self
136            .request(tcp::stream::Command::Recv { max_len: None })
137            .await?;
138
139        self._recv_bytes(resp)
140    }
141
142    fn _recv(&self, resp: Response, b: &mut [u8]) -> Result<usize, netcore::Error> {
143        let buf = self._recv_bytes(resp)?;
144
145        let n = buf.len().min(b.len());
146        b[..n].copy_from_slice(&buf[..n]);
147
148        Ok(n)
149    }
150
151    fn _recv_bytes(&self, resp: Response) -> Result<Bytes, netcore::Error> {
152        if matches!(resp, Response::TcpStream(tcp::stream::Response::Finished)) {
153            return Ok(Bytes::new());
154        }
155
156        netcore::try_response_as!(resp, tcp::stream::Response::Recv { buf });
157        Ok(buf)
158    }
159
160    #[cfg(any(feature = "tokio", feature = "futures-io"))]
161    fn poll_read(
162        mut self: core::pin::Pin<&mut Self>,
163        cx: &mut core::task::Context,
164        buf: &mut [u8],
165    ) -> core::task::Poll<std::io::Result<usize>> {
166        use netcore::HasChannel;
167
168        let handle = self.handle;
169        let cap = buf.len();
170
171        loop {
172            match self.read_fut.as_mut() {
173                None => {
174                    let sender = self.sender.clone();
175
176                    let _ret = self.read_fut.insert(alloc::boxed::Box::pin(async move {
177                        let resp = sender
178                            .request(
179                                Some(handle),
180                                tcp::stream::Command::Recv { max_len: Some(cap) },
181                            )
182                            .await?;
183
184                        match resp.try_into()? {
185                            tcp::stream::Response::Recv { buf } => Ok(buf),
186                            tcp::stream::Response::Finished => Ok(Bytes::new()),
187                            _ => Err(netcore::Error::wrong_type()),
188                        }
189                    }));
190                }
191
192                Some(x) => {
193                    let poll_result = x.as_mut().poll(cx);
194                    let ret = core::task::ready!(poll_result)?;
195
196                    buf[..ret.len()].copy_from_slice(&ret);
197
198                    self.read_fut.take();
199
200                    break core::task::Poll::Ready(Ok(ret.len()));
201                }
202            }
203        }
204    }
205
206    #[cfg(any(feature = "tokio", feature = "futures-io"))]
207    fn poll_write(
208        mut self: core::pin::Pin<&mut Self>,
209        cx: &mut core::task::Context<'_>,
210        buf: &[u8],
211    ) -> core::task::Poll<std::io::Result<usize>> {
212        use netcore::HasChannel;
213
214        let handle = self.handle;
215
216        loop {
217            match &mut self.write_fut {
218                None => {
219                    let b = Bytes::copy_from_slice(buf);
220                    let sender = self.sender.clone();
221
222                    let _ret = self.write_fut.insert(alloc::boxed::Box::pin(async move {
223                        let resp = sender
224                            .request(Some(handle), tcp::stream::Command::Send { buf: b })
225                            .await?;
226
227                        netcore::try_response_as!(resp, tcp::stream::Response::Sent { n });
228                        Ok(n)
229                    }));
230                }
231
232                Some(x) => {
233                    let poll_result = x.as_mut().poll(cx);
234                    let ret = core::task::ready!(poll_result)?;
235
236                    self.write_fut.take();
237
238                    break core::task::Poll::Ready(Ok(ret));
239                }
240            }
241        }
242    }
243
244    socket_requestor_impl!();
245}
246
247impl Drop for TcpStream {
248    fn drop(&mut self) {
249        if let Err(e) = self
250            .sender
251            .request_nonblocking(Some(self.handle), tcp::stream::Command::Close)
252        {
253            tracing::warn!(err = %e, "possible socket leak");
254        }
255    }
256}
257
#[cfg(feature = "std")]
impl std::io::Read for TcpStream {
    /// Blocking read: forwards to `recv_blocking` and converts the error into `std::io::Error`.
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        match self.recv_blocking(buf) {
            Ok(received) => Ok(received),
            Err(e) => Err(e.into()),
        }
    }
}
264
#[cfg(feature = "std")]
impl std::io::Write for TcpStream {
    /// Blocking write: forwards to `send_blocking`, which waits until at least one byte is
    /// queued and returns how many bytes were sent.
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        self.send_blocking(buf).map_err(netcore::Error::into)
    }

    /// Queue all of `buf`, issuing blocking `Send` requests until nothing remains.
    ///
    /// The input is copied into one [`Bytes`] up front. Each request then sends a cheap
    /// clone of the still-unsent tail.
    ///
    /// # Errors
    ///
    /// Returns [`std::io::ErrorKind::WriteZero`] if the netstack reports zero bytes sent
    /// while data remains, matching std's default `write_all`. Request and response errors
    /// are propagated as-is.
    fn write_all(&mut self, buf: &[u8]) -> std::io::Result<()> {
        let mut buf = Bytes::copy_from_slice(buf);

        while !buf.is_empty() {
            let resp = self.request_blocking(tcp::stream::Command::Send { buf: buf.clone() })?;
            netcore::try_response_as!(resp, tcp::stream::Response::Sent { n });

            if n == 0 {
                // Without this check, a socket that stops accepting data would spin forever.
                return Err(std::io::Error::new(
                    std::io::ErrorKind::WriteZero,
                    "failed to write whole buffer",
                ));
            }

            // Clamp so a reply claiming more bytes than were offered cannot make
            // `split_to` panic.
            let _consumed = buf.split_to(n.min(buf.len()));
        }

        Ok(())
    }

    /// No-op: each write is handed straight to the netstack, so nothing is buffered here.
    fn flush(&mut self) -> std::io::Result<()> {
        Ok(())
    }
}
288
#[cfg(feature = "tokio")]
impl tokio::io::AsyncRead for TcpStream {
    /// Reads into the unfilled region of `buf` (zero-initializing it first), then advances
    /// the filled cursor by the number of bytes received.
    fn poll_read(
        self: core::pin::Pin<&mut Self>,
        cx: &mut core::task::Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> core::task::Poll<tokio::io::Result<()>> {
        let dst = buf.initialize_unfilled();

        match self.poll_read(cx, dst) {
            core::task::Poll::Pending => core::task::Poll::Pending,
            core::task::Poll::Ready(Err(e)) => core::task::Poll::Ready(Err(e)),
            core::task::Poll::Ready(Ok(received)) => {
                buf.advance(received);
                core::task::Poll::Ready(Ok(()))
            }
        }
    }
}
302
#[cfg(feature = "tokio")]
impl tokio::io::AsyncWrite for TcpStream {
    /// Delegates to the inherent `poll_write`, which keeps a pending send across calls.
    fn poll_write(
        self: core::pin::Pin<&mut Self>,
        cx: &mut core::task::Context<'_>,
        buf: &[u8],
    ) -> core::task::Poll<std::io::Result<usize>> {
        TcpStream::poll_write(self, cx, buf)
    }

    /// Always ready: writes are not buffered on this side.
    fn poll_flush(
        self: core::pin::Pin<&mut Self>,
        _cx: &mut core::task::Context<'_>,
    ) -> core::task::Poll<std::io::Result<()>> {
        core::task::Poll::Ready(Ok(()))
    }

    /// Always reports success immediately without closing anything.
    fn poll_shutdown(
        self: core::pin::Pin<&mut Self>,
        _cx: &mut core::task::Context<'_>,
    ) -> core::task::Poll<std::io::Result<()>> {
        // NOTE(npry): explicit shutdown semantics don't fit this type. The socket has to be
        // closable out-of-band anyway (on drop), because no async runtime is guaranteed to
        // be driving it. As a result, shutdown here does not confirm that the socket is
        // closed. Dependents that use close as a signal before dropping the socket could
        // therefore hang.
        core::task::Poll::Ready(Ok(()))
    }
}
332
#[cfg(feature = "futures-io")]
impl futures_io::AsyncRead for TcpStream {
    /// Delegates to the inherent `poll_read`, which keeps a pending receive across calls.
    fn poll_read(
        self: core::pin::Pin<&mut Self>,
        cx: &mut core::task::Context<'_>,
        buf: &mut [u8],
    ) -> core::task::Poll<std::io::Result<usize>> {
        TcpStream::poll_read(self, cx, buf)
    }
}
343
#[cfg(feature = "futures-io")]
impl futures_io::AsyncWrite for TcpStream {
    /// Delegates to the inherent `poll_write`, which keeps a pending send across calls.
    fn poll_write(
        self: core::pin::Pin<&mut Self>,
        cx: &mut core::task::Context<'_>,
        buf: &[u8],
    ) -> core::task::Poll<std::io::Result<usize>> {
        TcpStream::poll_write(self, cx, buf)
    }

    /// Always ready: writes are not buffered on this side.
    fn poll_flush(
        self: core::pin::Pin<&mut Self>,
        _cx: &mut core::task::Context<'_>,
    ) -> core::task::Poll<std::io::Result<()>> {
        core::task::Poll::Ready(Ok(()))
    }

    /// Always reports success immediately without closing anything. The socket is closed
    /// out-of-band when the stream is dropped, so this does not confirm closure.
    fn poll_close(
        self: core::pin::Pin<&mut Self>,
        _cx: &mut core::task::Context<'_>,
    ) -> core::task::Poll<std::io::Result<()>> {
        core::task::Poll::Ready(Ok(()))
    }
}