ts_netstack_smoltcp_socket/tcp/
stream.rs1use core::{
2 fmt::{Debug, Formatter},
3 net::SocketAddr,
4};
5
6use bytes::Bytes;
7use netcore::{DisplayExt, HasChannel, Response, smoltcp::iface::SocketHandle, tcp};
8
/// Heap-allocated, pinned, type-erased future used to keep an in-flight netstack request
/// alive across `poll_*` calls. It is `Send + Sync` so the owning stream can be too.
#[cfg(any(feature = "tokio", feature = "futures-io"))]
type PinBoxFut<T> =
    core::pin::Pin<alloc::boxed::Box<dyn core::future::Future<Output = T> + Send + Sync>>;
11
12pub struct TcpStream {
14 sender: netcore::Channel,
15 handle: SocketHandle,
16
17 local: SocketAddr,
18 remote: SocketAddr,
19
20 #[cfg(any(feature = "tokio", feature = "futures-io"))]
21 read_fut: Option<PinBoxFut<Result<Bytes, netcore::Error>>>,
22 #[cfg(any(feature = "tokio", feature = "futures-io"))]
29 read_remainder: Option<Bytes>,
30 #[cfg(any(feature = "tokio", feature = "futures-io"))]
31 write_fut: Option<PinBoxFut<Result<usize, netcore::Error>>>,
32}
33
34impl TcpStream {
35 pub(crate) const fn new(
36 sender: netcore::Channel,
37 handle: SocketHandle,
38 remote: SocketAddr,
39 local: SocketAddr,
40 ) -> Self {
41 Self {
42 sender,
43 handle,
44 remote,
45 local,
46
47 #[cfg(any(feature = "tokio", feature = "futures-io"))]
48 read_fut: None,
49
50 #[cfg(any(feature = "tokio", feature = "futures-io"))]
51 read_remainder: None,
52
53 #[cfg(any(feature = "tokio", feature = "futures-io"))]
54 write_fut: None,
55 }
56 }
57}
58
59impl Debug for TcpStream {
60 fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
61 f.debug_struct("TcpStream")
62 .field("handle", &self.handle.as_display_debug())
63 .field("local_endpoint", &self.local)
64 .field("remote_endpoint", &self.remote)
65 .finish()
66 }
67}
68
69impl TcpStream {
70 pub const fn local_addr(&self) -> SocketAddr {
72 self.local
73 }
74
75 pub const fn remote_addr(&self) -> SocketAddr {
77 self.remote
78 }
79
80 pub fn shutdown_write(&self) {
96 if let Err(e) = self
97 .sender
98 .request_nonblocking(Some(self.handle), tcp::stream::Command::ShutdownWrite)
99 {
100 tracing::debug!(err = %e, "shutdown_write: netstack channel closed");
101 }
102 }
103
104 pub fn send_blocking(&self, b: &[u8]) -> Result<usize, netcore::Error> {
109 let resp = self.request_blocking(tcp::stream::Command::Send {
110 buf: Bytes::copy_from_slice(b),
111 })?;
112
113 self._send(resp)
114 }
115
116 pub async fn send(&self, b: &[u8]) -> Result<usize, netcore::Error> {
121 let resp = self
122 .request(tcp::stream::Command::Send {
123 buf: Bytes::copy_from_slice(b),
124 })
125 .await?;
126
127 self._send(resp)
128 }
129
130 fn _send(&self, resp: Response) -> Result<usize, netcore::Error> {
131 netcore::try_response_as!(resp, tcp::stream::Response::Sent { n });
132 Ok(n)
133 }
134
135 pub fn recv_blocking(&self, b: &mut [u8]) -> Result<usize, netcore::Error> {
139 let resp = self.request_blocking(tcp::stream::Command::Recv {
140 max_len: Some(b.len()),
141 })?;
142
143 self._recv(resp, b)
144 }
145
146 pub async fn recv(&self, b: &mut [u8]) -> Result<usize, netcore::Error> {
150 let resp = self
151 .request(tcp::stream::Command::Recv {
152 max_len: Some(b.len()),
153 })
154 .await?;
155
156 self._recv(resp, b)
157 }
158
159 pub fn recv_bytes_blocking(&self) -> Result<Bytes, netcore::Error> {
163 let resp = self.request_blocking(tcp::stream::Command::Recv { max_len: None })?;
164
165 self._recv_bytes(resp)
166 }
167
168 pub async fn recv_bytes(&self) -> Result<Bytes, netcore::Error> {
170 let resp = self
171 .request(tcp::stream::Command::Recv { max_len: None })
172 .await?;
173
174 self._recv_bytes(resp)
175 }
176
177 fn _recv(&self, resp: Response, b: &mut [u8]) -> Result<usize, netcore::Error> {
178 let buf = self._recv_bytes(resp)?;
179
180 let n = buf.len().min(b.len());
181 b[..n].copy_from_slice(&buf[..n]);
182
183 Ok(n)
184 }
185
186 fn _recv_bytes(&self, resp: Response) -> Result<Bytes, netcore::Error> {
187 if matches!(resp, Response::TcpStream(tcp::stream::Response::Finished)) {
188 return Ok(Bytes::new());
189 }
190
191 netcore::try_response_as!(resp, tcp::stream::Response::Recv { buf });
192 Ok(buf)
193 }
194
195 #[cfg(any(feature = "tokio", feature = "futures-io"))]
196 fn poll_read(
197 mut self: core::pin::Pin<&mut Self>,
198 cx: &mut core::task::Context,
199 buf: &mut [u8],
200 ) -> core::task::Poll<std::io::Result<usize>> {
201 use netcore::HasChannel;
202
203 debug_assert!(
209 !buf.is_empty() || self.read_remainder.is_none(),
210 "poll_read called with an empty buffer while bytes are buffered — Ok(0) would look like EOF"
211 );
212
213 fn copy_into_buf(mut data: Bytes, buf: &mut [u8]) -> (usize, Bytes) {
218 let n = data.len().min(buf.len());
219 buf[..n].copy_from_slice(&data.split_to(n));
220 (n, data)
221 }
222
223 if let Some(rem) = self.read_remainder.take() {
225 let (n, tail) = copy_into_buf(rem, buf);
226 if !tail.is_empty() {
227 self.read_remainder = Some(tail);
228 }
229 return core::task::Poll::Ready(Ok(n));
230 }
231
232 let handle = self.handle;
233 let cap = buf.len();
234
235 loop {
236 match self.read_fut.as_mut() {
237 None => {
238 let sender = self.sender.clone();
239
240 let _ret = self.read_fut.insert(alloc::boxed::Box::pin(async move {
241 let resp = sender
242 .request(
243 Some(handle),
244 tcp::stream::Command::Recv { max_len: Some(cap) },
245 )
246 .await?;
247
248 match resp.try_into()? {
249 tcp::stream::Response::Recv { buf } => Ok(buf),
250 tcp::stream::Response::Finished => Ok(Bytes::new()),
251 _ => Err(netcore::Error::wrong_type()),
252 }
253 }));
254 }
255
256 Some(x) => {
257 let poll_result = x.as_mut().poll(cx);
258 let ret = core::task::ready!(poll_result)?;
259
260 self.read_fut.take();
261
262 let (n, tail) = copy_into_buf(ret, buf);
266 if !tail.is_empty() {
267 self.read_remainder = Some(tail);
268 }
269
270 break core::task::Poll::Ready(Ok(n));
271 }
272 }
273 }
274 }
275
276 #[cfg(any(feature = "tokio", feature = "futures-io"))]
277 fn poll_write(
278 mut self: core::pin::Pin<&mut Self>,
279 cx: &mut core::task::Context<'_>,
280 buf: &[u8],
281 ) -> core::task::Poll<std::io::Result<usize>> {
282 use netcore::HasChannel;
283
284 let handle = self.handle;
285
286 loop {
287 match &mut self.write_fut {
288 None => {
289 let b = Bytes::copy_from_slice(buf);
290 let sender = self.sender.clone();
291
292 let _ret = self.write_fut.insert(alloc::boxed::Box::pin(async move {
293 let resp = sender
294 .request(Some(handle), tcp::stream::Command::Send { buf: b })
295 .await?;
296
297 netcore::try_response_as!(resp, tcp::stream::Response::Sent { n });
298 Ok(n)
299 }));
300 }
301
302 Some(x) => {
303 let poll_result = x.as_mut().poll(cx);
304 let ret = core::task::ready!(poll_result)?;
305
306 self.write_fut.take();
307
308 break core::task::Poll::Ready(Ok(ret));
309 }
310 }
311 }
312 }
313
314 socket_requestor_impl!();
315}
316
317impl Drop for TcpStream {
318 fn drop(&mut self) {
319 if let Err(e) = self
320 .sender
321 .request_nonblocking(Some(self.handle), tcp::stream::Command::Close)
322 {
323 tracing::warn!(err = %e, "possible socket leak");
324 }
325 }
326}
327
#[cfg(feature = "std")]
impl std::io::Read for TcpStream {
    /// Blocking read backed by `recv_blocking`; netstack errors are converted into
    /// `std::io::Error`. `Ok(0)` signals end of stream.
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        match self.recv_blocking(buf) {
            Ok(copied) => Ok(copied),
            Err(err) => Err(err.into()),
        }
    }
}
334
#[cfg(feature = "std")]
impl std::io::Write for TcpStream {
    /// Blocking write backed by [`TcpStream::send_blocking`]; may write fewer than
    /// `buf.len()` bytes.
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        self.send_blocking(buf).map_err(netcore::Error::into)
    }

    /// Sends all of `buf`, issuing `Send` commands until the netstack has accepted every byte.
    ///
    /// `buf` is copied into a [`Bytes`] once. Each retry sends a cheap reference-counted
    /// handle to the remaining bytes instead of copying them again.
    ///
    /// # Errors
    /// - Propagates request failures.
    /// - Returns [`std::io::ErrorKind::WriteZero`] if the netstack reports that it sent zero
    ///   bytes. This matches the default `Write::write_all`; without it the loop would never
    ///   make progress.
    fn write_all(&mut self, buf: &[u8]) -> std::io::Result<()> {
        let mut remaining = Bytes::copy_from_slice(buf);

        while !remaining.is_empty() {
            let resp = self.request_blocking(tcp::stream::Command::Send {
                buf: remaining.clone(),
            })?;
            netcore::try_response_as!(resp, tcp::stream::Response::Sent { n });

            if n == 0 {
                return Err(std::io::Error::new(
                    std::io::ErrorKind::WriteZero,
                    "failed to write whole buffer",
                ));
            }

            // Clamp so an over-reported count cannot make `split_to` panic.
            let _consumed = remaining.split_to(n.min(remaining.len()));
        }

        Ok(())
    }

    /// No-op: this type keeps no write buffer of its own.
    fn flush(&mut self) -> std::io::Result<()> {
        Ok(())
    }
}
358
#[cfg(feature = "tokio")]
impl tokio::io::AsyncRead for TcpStream {
    /// Reads into the unfilled part of `buf` via the inherent `poll_read`, then advances
    /// `buf` by the number of bytes copied.
    ///
    /// Nothing is filled at end of stream.
    fn poll_read(
        self: core::pin::Pin<&mut Self>,
        cx: &mut core::task::Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> core::task::Poll<tokio::io::Result<()>> {
        use core::task::Poll;

        // Make sure the unfilled region is initialized so it can be lent out as `&mut [u8]`.
        let unfilled = buf.initialize_unfilled();
        match self.poll_read(cx, unfilled) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
            Poll::Ready(Ok(copied)) => {
                buf.advance(copied);
                Poll::Ready(Ok(()))
            }
        }
    }
}
372
#[cfg(feature = "tokio")]
impl tokio::io::AsyncWrite for TcpStream {
    /// Delegates to the inherent `poll_write`.
    fn poll_write(
        self: core::pin::Pin<&mut Self>,
        cx: &mut core::task::Context<'_>,
        buf: &[u8],
    ) -> core::task::Poll<std::io::Result<usize>> {
        let polled = self.poll_write(cx, buf);
        polled
    }

    /// Always ready: this type keeps no write buffer of its own.
    fn poll_flush(
        self: core::pin::Pin<&mut Self>,
        _cx: &mut core::task::Context<'_>,
    ) -> core::task::Poll<std::io::Result<()>> {
        use core::task::Poll;

        Poll::Ready(Ok(()))
    }

    /// Queues a write-half shutdown with the netstack and reports completion right away,
    /// without waiting for the netstack to act on it.
    fn poll_shutdown(
        self: core::pin::Pin<&mut Self>,
        _cx: &mut core::task::Context<'_>,
    ) -> core::task::Poll<std::io::Result<()>> {
        use core::task::Poll;

        self.shutdown_write();
        Poll::Ready(Ok(()))
    }
}
398
#[cfg(feature = "futures-io")]
impl futures_io::AsyncRead for TcpStream {
    /// Delegates to the inherent `poll_read`. `Ok(0)` signals end of stream.
    fn poll_read(
        self: core::pin::Pin<&mut Self>,
        cx: &mut core::task::Context<'_>,
        buf: &mut [u8],
    ) -> core::task::Poll<std::io::Result<usize>> {
        let polled = self.poll_read(cx, buf);
        polled
    }
}
409
#[cfg(feature = "futures-io")]
impl futures_io::AsyncWrite for TcpStream {
    /// Delegates to the inherent `poll_write`.
    fn poll_write(
        self: core::pin::Pin<&mut Self>,
        cx: &mut core::task::Context<'_>,
        buf: &[u8],
    ) -> core::task::Poll<std::io::Result<usize>> {
        let polled = self.poll_write(cx, buf);
        polled
    }

    /// Always ready: this type keeps no write buffer of its own.
    fn poll_flush(
        self: core::pin::Pin<&mut Self>,
        _cx: &mut core::task::Context<'_>,
    ) -> core::task::Poll<std::io::Result<()>> {
        use core::task::Poll;

        Poll::Ready(Ok(()))
    }

    /// Queues a write-half shutdown with the netstack and reports completion right away,
    /// without waiting for the netstack to act on it.
    fn poll_close(
        self: core::pin::Pin<&mut Self>,
        _cx: &mut core::task::Context<'_>,
    ) -> core::task::Poll<std::io::Result<()>> {
        use core::task::Poll;

        self.shutdown_write();
        Poll::Ready(Ok(()))
    }
}