ts_netstack_smoltcp_socket/tcp/
stream.rs1use core::{
2 fmt::{Debug, Formatter},
3 net::SocketAddr,
4};
5
6use bytes::Bytes;
7use netcore::{DisplayExt, HasChannel, Response, smoltcp::iface::SocketHandle, tcp};
8
/// Heap-allocated, pinned, type-erased future used to keep an in-flight netstack request
/// alive across `Poll::Pending` in the async I/O adapters.
///
/// It is `Send + Sync` so that the owning `TcpStream` stays `Send + Sync` as well.
#[cfg(any(feature = "futures-io", feature = "tokio"))]
type PinBoxFut<T> =
    core::pin::Pin<alloc::boxed::Box<dyn core::future::Future<Output = T> + Sync + Send>>;
11
12pub struct TcpStream {
14 sender: netcore::Channel,
15 handle: SocketHandle,
16
17 local: SocketAddr,
18 remote: SocketAddr,
19
20 #[cfg(any(feature = "tokio", feature = "futures-io"))]
21 read_fut: Option<PinBoxFut<Result<Bytes, netcore::Error>>>,
22 #[cfg(any(feature = "tokio", feature = "futures-io"))]
23 write_fut: Option<PinBoxFut<Result<usize, netcore::Error>>>,
24}
25
26impl TcpStream {
27 pub(crate) const fn new(
28 sender: netcore::Channel,
29 handle: SocketHandle,
30 remote: SocketAddr,
31 local: SocketAddr,
32 ) -> Self {
33 Self {
34 sender,
35 handle,
36 remote,
37 local,
38
39 #[cfg(any(feature = "tokio", feature = "futures-io"))]
40 read_fut: None,
41
42 #[cfg(any(feature = "tokio", feature = "futures-io"))]
43 write_fut: None,
44 }
45 }
46}
47
48impl Debug for TcpStream {
49 fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
50 f.debug_struct("TcpStream")
51 .field("handle", &self.handle.as_display_debug())
52 .field("local_endpoint", &self.local)
53 .field("remote_endpoint", &self.remote)
54 .finish()
55 }
56}
57
58impl TcpStream {
59 pub const fn local_addr(&self) -> SocketAddr {
61 self.local
62 }
63
64 pub const fn remote_addr(&self) -> SocketAddr {
66 self.remote
67 }
68
69 pub fn send_blocking(&self, b: &[u8]) -> Result<usize, netcore::Error> {
74 let resp = self.request_blocking(tcp::stream::Command::Send {
75 buf: Bytes::copy_from_slice(b),
76 })?;
77
78 self._send(resp)
79 }
80
81 pub async fn send(&self, b: &[u8]) -> Result<usize, netcore::Error> {
86 let resp = self
87 .request(tcp::stream::Command::Send {
88 buf: Bytes::copy_from_slice(b),
89 })
90 .await?;
91
92 self._send(resp)
93 }
94
95 fn _send(&self, resp: Response) -> Result<usize, netcore::Error> {
96 netcore::try_response_as!(resp, tcp::stream::Response::Sent { n });
97 Ok(n)
98 }
99
100 pub fn recv_blocking(&self, b: &mut [u8]) -> Result<usize, netcore::Error> {
104 let resp = self.request_blocking(tcp::stream::Command::Recv {
105 max_len: Some(b.len()),
106 })?;
107
108 self._recv(resp, b)
109 }
110
111 pub async fn recv(&self, b: &mut [u8]) -> Result<usize, netcore::Error> {
115 let resp = self
116 .request(tcp::stream::Command::Recv {
117 max_len: Some(b.len()),
118 })
119 .await?;
120
121 self._recv(resp, b)
122 }
123
124 pub fn recv_bytes_blocking(&self) -> Result<Bytes, netcore::Error> {
128 let resp = self.request_blocking(tcp::stream::Command::Recv { max_len: None })?;
129
130 self._recv_bytes(resp)
131 }
132
133 pub async fn recv_bytes(&self) -> Result<Bytes, netcore::Error> {
135 let resp = self
136 .request(tcp::stream::Command::Recv { max_len: None })
137 .await?;
138
139 self._recv_bytes(resp)
140 }
141
142 fn _recv(&self, resp: Response, b: &mut [u8]) -> Result<usize, netcore::Error> {
143 let buf = self._recv_bytes(resp)?;
144
145 let n = buf.len().min(b.len());
146 b[..n].copy_from_slice(&buf[..n]);
147
148 Ok(n)
149 }
150
151 fn _recv_bytes(&self, resp: Response) -> Result<Bytes, netcore::Error> {
152 if matches!(resp, Response::TcpStream(tcp::stream::Response::Finished)) {
153 return Ok(Bytes::new());
154 }
155
156 netcore::try_response_as!(resp, tcp::stream::Response::Recv { buf });
157 Ok(buf)
158 }
159
160 #[cfg(any(feature = "tokio", feature = "futures-io"))]
161 fn poll_read(
162 mut self: core::pin::Pin<&mut Self>,
163 cx: &mut core::task::Context,
164 buf: &mut [u8],
165 ) -> core::task::Poll<std::io::Result<usize>> {
166 use netcore::HasChannel;
167
168 let handle = self.handle;
169 let cap = buf.len();
170
171 loop {
172 match self.read_fut.as_mut() {
173 None => {
174 let sender = self.sender.clone();
175
176 let _ret = self.read_fut.insert(alloc::boxed::Box::pin(async move {
177 let resp = sender
178 .request(
179 Some(handle),
180 tcp::stream::Command::Recv { max_len: Some(cap) },
181 )
182 .await?;
183
184 match resp.try_into()? {
185 tcp::stream::Response::Recv { buf } => Ok(buf),
186 tcp::stream::Response::Finished => Ok(Bytes::new()),
187 _ => Err(netcore::Error::wrong_type()),
188 }
189 }));
190 }
191
192 Some(x) => {
193 let poll_result = x.as_mut().poll(cx);
194 let ret = core::task::ready!(poll_result)?;
195
196 buf[..ret.len()].copy_from_slice(&ret);
197
198 self.read_fut.take();
199
200 break core::task::Poll::Ready(Ok(ret.len()));
201 }
202 }
203 }
204 }
205
206 #[cfg(any(feature = "tokio", feature = "futures-io"))]
207 fn poll_write(
208 mut self: core::pin::Pin<&mut Self>,
209 cx: &mut core::task::Context<'_>,
210 buf: &[u8],
211 ) -> core::task::Poll<std::io::Result<usize>> {
212 use netcore::HasChannel;
213
214 let handle = self.handle;
215
216 loop {
217 match &mut self.write_fut {
218 None => {
219 let b = Bytes::copy_from_slice(buf);
220 let sender = self.sender.clone();
221
222 let _ret = self.write_fut.insert(alloc::boxed::Box::pin(async move {
223 let resp = sender
224 .request(Some(handle), tcp::stream::Command::Send { buf: b })
225 .await?;
226
227 netcore::try_response_as!(resp, tcp::stream::Response::Sent { n });
228 Ok(n)
229 }));
230 }
231
232 Some(x) => {
233 let poll_result = x.as_mut().poll(cx);
234 let ret = core::task::ready!(poll_result)?;
235
236 self.write_fut.take();
237
238 break core::task::Poll::Ready(Ok(ret));
239 }
240 }
241 }
242 }
243
244 socket_requestor_impl!();
245}
246
247impl Drop for TcpStream {
248 fn drop(&mut self) {
249 if let Err(e) = self
250 .sender
251 .request_nonblocking(Some(self.handle), tcp::stream::Command::Close)
252 {
253 tracing::warn!(err = %e, "possible socket leak");
254 }
255 }
256}
257
#[cfg(feature = "std")]
impl std::io::Read for TcpStream {
    /// Blocking read built on `recv_blocking`, with the error converted to `std::io::Error`.
    ///
    /// Returns `Ok(0)` once the peer has finished sending.
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        let received = self.recv_blocking(buf);
        received.map_err(std::io::Error::from)
    }
}
264
#[cfg(feature = "std")]
impl std::io::Write for TcpStream {
    /// Blocking write built on `send_blocking`.
    ///
    /// May accept fewer than `buf.len()` bytes; the return value is the count the netstack
    /// reports as sent.
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        self.send_blocking(buf).map_err(netcore::Error::into)
    }

    /// Sends all of `buf`, re-issuing `Send` commands for the unsent tail.
    ///
    /// `buf` is copied once into a [`Bytes`]. Each retry then sends a cheap clone of the
    /// remaining slice.
    ///
    /// # Errors
    ///
    /// Propagates request failures. Following the `std::io::Write::write_all` contract, it
    /// fails with `ErrorKind::WriteZero` if the stack reports zero bytes sent; otherwise that
    /// case would loop forever.
    fn write_all(&mut self, buf: &[u8]) -> std::io::Result<()> {
        let mut buf = Bytes::copy_from_slice(buf);

        while !buf.is_empty() {
            let resp = self.request_blocking(tcp::stream::Command::Send { buf: buf.clone() })?;
            netcore::try_response_as!(resp, tcp::stream::Response::Sent { n });

            if n == 0 {
                return Err(std::io::Error::new(
                    std::io::ErrorKind::WriteZero,
                    "failed to write whole buffer",
                ));
            }

            let _consumed = buf.split_to(n);
        }

        Ok(())
    }

    /// No-op: this type keeps no write buffer of its own, because every write is issued to
    /// the netstack immediately.
    fn flush(&mut self) -> std::io::Result<()> {
        Ok(())
    }
}
288
#[cfg(feature = "tokio")]
impl tokio::io::AsyncRead for TcpStream {
    /// Adapts the inherent `TcpStream::poll_read` (which fills a plain slice) to tokio's
    /// `ReadBuf`.
    ///
    /// The unfilled region is initialised, handed out as a slice, and `buf` is advanced by
    /// the number of bytes read.
    fn poll_read(
        self: core::pin::Pin<&mut Self>,
        cx: &mut core::task::Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> core::task::Poll<tokio::io::Result<()>> {
        let dst = buf.initialize_unfilled();
        match TcpStream::poll_read(self, cx, dst) {
            core::task::Poll::Pending => core::task::Poll::Pending,
            core::task::Poll::Ready(Err(e)) => core::task::Poll::Ready(Err(e)),
            core::task::Poll::Ready(Ok(n)) => {
                buf.advance(n);
                core::task::Poll::Ready(Ok(()))
            }
        }
    }
}
302
#[cfg(feature = "tokio")]
impl tokio::io::AsyncWrite for TcpStream {
    /// Delegates to the inherent `TcpStream::poll_write`.
    fn poll_write(
        self: core::pin::Pin<&mut Self>,
        cx: &mut core::task::Context<'_>,
        buf: &[u8],
    ) -> core::task::Poll<std::io::Result<usize>> {
        // Type-qualified path selects the inherent helper, not this trait method.
        TcpStream::poll_write(self, cx, buf)
    }

    /// Always ready: there is no user-side write buffer to flush.
    fn poll_flush(
        self: core::pin::Pin<&mut Self>,
        _cx: &mut core::task::Context<'_>,
    ) -> core::task::Poll<std::io::Result<()>> {
        core::task::Poll::Ready(Ok(()))
    }

    /// Always ready and sends nothing to the netstack; the socket's `Close` command is only
    /// issued from `Drop`.
    ///
    /// NOTE(review): no half-close is performed here. Confirm this is intended.
    fn poll_shutdown(
        self: core::pin::Pin<&mut Self>,
        _cx: &mut core::task::Context<'_>,
    ) -> core::task::Poll<std::io::Result<()>> {
        core::task::Poll::Ready(Ok(()))
    }
}
332
#[cfg(feature = "futures-io")]
impl futures_io::AsyncRead for TcpStream {
    /// Delegates directly to the inherent `TcpStream::poll_read`, which already works on a
    /// plain byte slice.
    fn poll_read(
        self: core::pin::Pin<&mut Self>,
        cx: &mut core::task::Context<'_>,
        buf: &mut [u8],
    ) -> core::task::Poll<std::io::Result<usize>> {
        // Type-qualified path selects the inherent helper, not this trait method.
        TcpStream::poll_read(self, cx, buf)
    }
}
343
#[cfg(feature = "futures-io")]
impl futures_io::AsyncWrite for TcpStream {
    /// Delegates to the inherent `TcpStream::poll_write`.
    fn poll_write(
        self: core::pin::Pin<&mut Self>,
        cx: &mut core::task::Context<'_>,
        buf: &[u8],
    ) -> core::task::Poll<std::io::Result<usize>> {
        // Type-qualified path selects the inherent helper, not this trait method.
        TcpStream::poll_write(self, cx, buf)
    }

    /// Always ready: there is no user-side write buffer to flush.
    fn poll_flush(
        self: core::pin::Pin<&mut Self>,
        _cx: &mut core::task::Context<'_>,
    ) -> core::task::Poll<std::io::Result<()>> {
        core::task::Poll::Ready(Ok(()))
    }

    /// Always ready and sends nothing to the netstack; the socket's `Close` command is only
    /// issued from `Drop`.
    ///
    /// NOTE(review): the socket is not closed here. Confirm this is intended.
    fn poll_close(
        self: core::pin::Pin<&mut Self>,
        _cx: &mut core::task::Context<'_>,
    ) -> core::task::Poll<std::io::Result<()>> {
        core::task::Poll::Ready(Ok(()))
    }
}