ts_netstack_smoltcp_socket/tcp/
stream.rs1use core::{
2 fmt::{Debug, Formatter},
3 net::SocketAddr,
4};
5
6use bytes::Bytes;
7use netcore::{DisplayExt, HasChannel, Response, smoltcp::iface::SocketHandle, tcp};
8
/// Pinned, boxed, `Send + Sync` future that [`TcpStream`] parks in its `read_fut` /
/// `write_fut` fields while an async I/O request is in flight between polls.
#[cfg(any(feature = "tokio", feature = "futures-io"))]
type PinBoxFut<O> =
    core::pin::Pin<alloc::boxed::Box<dyn core::future::Future<Output = O> + Send + Sync>>;
11
12pub struct TcpStream {
14 sender: netcore::Channel,
15 handle: SocketHandle,
16
17 local: SocketAddr,
18 remote: SocketAddr,
19
20 #[cfg(any(feature = "tokio", feature = "futures-io"))]
21 read_fut: Option<PinBoxFut<Result<Bytes, netcore::Error>>>,
22 #[cfg(any(feature = "tokio", feature = "futures-io"))]
29 read_remainder: Option<Bytes>,
30 #[cfg(any(feature = "tokio", feature = "futures-io"))]
31 write_fut: Option<PinBoxFut<Result<usize, netcore::Error>>>,
32}
33
34impl TcpStream {
35 pub(crate) const fn new(
36 sender: netcore::Channel,
37 handle: SocketHandle,
38 remote: SocketAddr,
39 local: SocketAddr,
40 ) -> Self {
41 Self {
42 sender,
43 handle,
44 remote,
45 local,
46
47 #[cfg(any(feature = "tokio", feature = "futures-io"))]
48 read_fut: None,
49
50 #[cfg(any(feature = "tokio", feature = "futures-io"))]
51 read_remainder: None,
52
53 #[cfg(any(feature = "tokio", feature = "futures-io"))]
54 write_fut: None,
55 }
56 }
57}
58
59impl Debug for TcpStream {
60 fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
61 f.debug_struct("TcpStream")
62 .field("handle", &self.handle.as_display_debug())
63 .field("local_endpoint", &self.local)
64 .field("remote_endpoint", &self.remote)
65 .finish()
66 }
67}
68
69impl TcpStream {
70 pub const fn local_addr(&self) -> SocketAddr {
72 self.local
73 }
74
75 pub const fn remote_addr(&self) -> SocketAddr {
77 self.remote
78 }
79
80 pub fn send_blocking(&self, b: &[u8]) -> Result<usize, netcore::Error> {
85 let resp = self.request_blocking(tcp::stream::Command::Send {
86 buf: Bytes::copy_from_slice(b),
87 })?;
88
89 self._send(resp)
90 }
91
92 pub async fn send(&self, b: &[u8]) -> Result<usize, netcore::Error> {
97 let resp = self
98 .request(tcp::stream::Command::Send {
99 buf: Bytes::copy_from_slice(b),
100 })
101 .await?;
102
103 self._send(resp)
104 }
105
106 fn _send(&self, resp: Response) -> Result<usize, netcore::Error> {
107 netcore::try_response_as!(resp, tcp::stream::Response::Sent { n });
108 Ok(n)
109 }
110
111 pub fn recv_blocking(&self, b: &mut [u8]) -> Result<usize, netcore::Error> {
115 let resp = self.request_blocking(tcp::stream::Command::Recv {
116 max_len: Some(b.len()),
117 })?;
118
119 self._recv(resp, b)
120 }
121
122 pub async fn recv(&self, b: &mut [u8]) -> Result<usize, netcore::Error> {
126 let resp = self
127 .request(tcp::stream::Command::Recv {
128 max_len: Some(b.len()),
129 })
130 .await?;
131
132 self._recv(resp, b)
133 }
134
135 pub fn recv_bytes_blocking(&self) -> Result<Bytes, netcore::Error> {
139 let resp = self.request_blocking(tcp::stream::Command::Recv { max_len: None })?;
140
141 self._recv_bytes(resp)
142 }
143
144 pub async fn recv_bytes(&self) -> Result<Bytes, netcore::Error> {
146 let resp = self
147 .request(tcp::stream::Command::Recv { max_len: None })
148 .await?;
149
150 self._recv_bytes(resp)
151 }
152
153 fn _recv(&self, resp: Response, b: &mut [u8]) -> Result<usize, netcore::Error> {
154 let buf = self._recv_bytes(resp)?;
155
156 let n = buf.len().min(b.len());
157 b[..n].copy_from_slice(&buf[..n]);
158
159 Ok(n)
160 }
161
162 fn _recv_bytes(&self, resp: Response) -> Result<Bytes, netcore::Error> {
163 if matches!(resp, Response::TcpStream(tcp::stream::Response::Finished)) {
164 return Ok(Bytes::new());
165 }
166
167 netcore::try_response_as!(resp, tcp::stream::Response::Recv { buf });
168 Ok(buf)
169 }
170
171 #[cfg(any(feature = "tokio", feature = "futures-io"))]
172 fn poll_read(
173 mut self: core::pin::Pin<&mut Self>,
174 cx: &mut core::task::Context,
175 buf: &mut [u8],
176 ) -> core::task::Poll<std::io::Result<usize>> {
177 use netcore::HasChannel;
178
179 debug_assert!(
185 !buf.is_empty() || self.read_remainder.is_none(),
186 "poll_read called with an empty buffer while bytes are buffered — Ok(0) would look like EOF"
187 );
188
189 fn copy_into_buf(mut data: Bytes, buf: &mut [u8]) -> (usize, Bytes) {
194 let n = data.len().min(buf.len());
195 buf[..n].copy_from_slice(&data.split_to(n));
196 (n, data)
197 }
198
199 if let Some(rem) = self.read_remainder.take() {
201 let (n, tail) = copy_into_buf(rem, buf);
202 if !tail.is_empty() {
203 self.read_remainder = Some(tail);
204 }
205 return core::task::Poll::Ready(Ok(n));
206 }
207
208 let handle = self.handle;
209 let cap = buf.len();
210
211 loop {
212 match self.read_fut.as_mut() {
213 None => {
214 let sender = self.sender.clone();
215
216 let _ret = self.read_fut.insert(alloc::boxed::Box::pin(async move {
217 let resp = sender
218 .request(
219 Some(handle),
220 tcp::stream::Command::Recv { max_len: Some(cap) },
221 )
222 .await?;
223
224 match resp.try_into()? {
225 tcp::stream::Response::Recv { buf } => Ok(buf),
226 tcp::stream::Response::Finished => Ok(Bytes::new()),
227 _ => Err(netcore::Error::wrong_type()),
228 }
229 }));
230 }
231
232 Some(x) => {
233 let poll_result = x.as_mut().poll(cx);
234 let ret = core::task::ready!(poll_result)?;
235
236 self.read_fut.take();
237
238 let (n, tail) = copy_into_buf(ret, buf);
242 if !tail.is_empty() {
243 self.read_remainder = Some(tail);
244 }
245
246 break core::task::Poll::Ready(Ok(n));
247 }
248 }
249 }
250 }
251
252 #[cfg(any(feature = "tokio", feature = "futures-io"))]
253 fn poll_write(
254 mut self: core::pin::Pin<&mut Self>,
255 cx: &mut core::task::Context<'_>,
256 buf: &[u8],
257 ) -> core::task::Poll<std::io::Result<usize>> {
258 use netcore::HasChannel;
259
260 let handle = self.handle;
261
262 loop {
263 match &mut self.write_fut {
264 None => {
265 let b = Bytes::copy_from_slice(buf);
266 let sender = self.sender.clone();
267
268 let _ret = self.write_fut.insert(alloc::boxed::Box::pin(async move {
269 let resp = sender
270 .request(Some(handle), tcp::stream::Command::Send { buf: b })
271 .await?;
272
273 netcore::try_response_as!(resp, tcp::stream::Response::Sent { n });
274 Ok(n)
275 }));
276 }
277
278 Some(x) => {
279 let poll_result = x.as_mut().poll(cx);
280 let ret = core::task::ready!(poll_result)?;
281
282 self.write_fut.take();
283
284 break core::task::Poll::Ready(Ok(ret));
285 }
286 }
287 }
288 }
289
290 socket_requestor_impl!();
291}
292
293impl Drop for TcpStream {
294 fn drop(&mut self) {
295 if let Err(e) = self
296 .sender
297 .request_nonblocking(Some(self.handle), tcp::stream::Command::Close)
298 {
299 tracing::warn!(err = %e, "possible socket leak");
300 }
301 }
302}
303
#[cfg(feature = "std")]
impl std::io::Read for TcpStream {
    /// Blocking read through [`TcpStream::recv_blocking`].
    ///
    /// `Ok(0)` signals end of stream. Stack errors are converted into `std::io::Error`.
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        Ok(self.recv_blocking(buf)?)
    }
}
310
#[cfg(feature = "std")]
impl std::io::Write for TcpStream {
    /// Sends as much of `buf` as the stack accepts in a single request; see
    /// [`TcpStream::send_blocking`].
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        self.send_blocking(buf).map_err(netcore::Error::into)
    }

    /// Sends all of `buf`, issuing `Send` requests until every byte is accepted.
    ///
    /// This overrides the default so the payload is copied into a [`Bytes`] only once. Each
    /// retry then re-sends a cheap clone of the unsent tail.
    ///
    /// # Errors
    ///
    /// Propagates request errors. Returns [`std::io::ErrorKind::WriteZero`] if the stack
    /// reports zero bytes accepted, matching the contract of the default `write_all`.
    fn write_all(&mut self, buf: &[u8]) -> std::io::Result<()> {
        let mut remaining = Bytes::copy_from_slice(buf);

        while !remaining.is_empty() {
            // Cloning `Bytes` shares the underlying storage; the payload is not copied.
            let resp = self.request_blocking(tcp::stream::Command::Send {
                buf: remaining.clone(),
            })?;
            netcore::try_response_as!(resp, tcp::stream::Response::Sent { n });

            if n == 0 {
                // Without this check, a stack that stops accepting data would make this loop
                // spin forever.
                return Err(std::io::Error::new(
                    std::io::ErrorKind::WriteZero,
                    "failed to write whole buffer",
                ));
            }

            let _consumed = remaining.split_to(n);
        }

        Ok(())
    }

    /// No-op: this type keeps no local write buffer, because each write is sent as its own
    /// request.
    fn flush(&mut self) -> std::io::Result<()> {
        Ok(())
    }
}
334
#[cfg(feature = "tokio")]
impl tokio::io::AsyncRead for TcpStream {
    /// Reads into the unfilled part of `buf` through the inherent `TcpStream::poll_read`, then
    /// marks the received bytes as filled.
    ///
    /// The unfilled region is initialised first, because the inherent reader takes a plain
    /// `&mut [u8]`.
    fn poll_read(
        self: core::pin::Pin<&mut Self>,
        cx: &mut core::task::Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> core::task::Poll<tokio::io::Result<()>> {
        TcpStream::poll_read(self, cx, buf.initialize_unfilled()).map_ok(|n| buf.advance(n))
    }
}
348
#[cfg(feature = "tokio")]
impl tokio::io::AsyncWrite for TcpStream {
    /// Delegates to the inherent `TcpStream::poll_write`.
    fn poll_write(
        self: core::pin::Pin<&mut Self>,
        cx: &mut core::task::Context<'_>,
        buf: &[u8],
    ) -> core::task::Poll<std::io::Result<usize>> {
        TcpStream::poll_write(self, cx, buf)
    }

    /// Always ready: there is no local write buffer to drain.
    fn poll_flush(
        self: core::pin::Pin<&mut Self>,
        _cx: &mut core::task::Context<'_>,
    ) -> core::task::Poll<std::io::Result<()>> {
        core::task::Poll::Ready(Ok(()))
    }

    /// Reports success immediately without sending any request.
    ///
    /// NOTE(review): this does not half-close the connection. A `Close` command is only sent
    /// when the stream is dropped.
    fn poll_shutdown(
        self: core::pin::Pin<&mut Self>,
        _cx: &mut core::task::Context<'_>,
    ) -> core::task::Poll<std::io::Result<()>> {
        core::task::Poll::Ready(Ok(()))
    }
}
378
#[cfg(feature = "futures-io")]
impl futures_io::AsyncRead for TcpStream {
    /// Delegates to the inherent `TcpStream::poll_read`.
    ///
    /// `Ok(0)` signals end of stream.
    fn poll_read(
        self: core::pin::Pin<&mut Self>,
        cx: &mut core::task::Context<'_>,
        buf: &mut [u8],
    ) -> core::task::Poll<std::io::Result<usize>> {
        TcpStream::poll_read(self, cx, buf)
    }
}
389
#[cfg(feature = "futures-io")]
impl futures_io::AsyncWrite for TcpStream {
    /// Delegates to the inherent `TcpStream::poll_write`.
    fn poll_write(
        self: core::pin::Pin<&mut Self>,
        cx: &mut core::task::Context<'_>,
        buf: &[u8],
    ) -> core::task::Poll<std::io::Result<usize>> {
        TcpStream::poll_write(self, cx, buf)
    }

    /// Always ready: there is no local write buffer to drain.
    fn poll_flush(
        self: core::pin::Pin<&mut Self>,
        _cx: &mut core::task::Context<'_>,
    ) -> core::task::Poll<std::io::Result<()>> {
        core::task::Poll::Ready(Ok(()))
    }

    /// Reports success immediately without sending any request.
    ///
    /// NOTE(review): the socket itself is only closed when the stream is dropped.
    fn poll_close(
        self: core::pin::Pin<&mut Self>,
        _cx: &mut core::task::Context<'_>,
    ) -> core::task::Poll<std::io::Result<()>> {
        core::task::Poll::Ready(Ok(()))
    }
}