1use std::{future::Future, io, net::SocketAddr, time::Duration};
2
3use agnostic_lite::RuntimeLite;
4
5use super::{
6 Fd, ToSocketAddrs,
7 io::{AsyncRead, AsyncReadWrite, AsyncWrite},
8};
9
// Produces the `io::Error` reported by the bind/connect helpers below when
// address resolution yielded no addresses at all, so there is no per-address
// error to hand back instead.
#[cfg(any(feature = "smol", feature = "tokio"))]
macro_rules! resolve_address_error {
    () => {
        ::std::io::Error::new(
            ::std::io::ErrorKind::InvalidInput,
            "could not resolve to any address",
        )
    };
}
19
// Expands to the `bind`, `accept` and `local_addr` trait methods shared by the
// runtime-specific listener wrappers.
//
// * `$ty`    - the runtime's listener type; its inherent `bind` is called.
// * `$field` - the name of the wrapper struct's field that stores that listener.
#[cfg(any(feature = "smol", feature = "tokio"))]
macro_rules! tcp_listener_common_methods {
    ($ty:ident.$field:ident) => {
        async fn bind<A: $crate::ToSocketAddrs<Self::Runtime>>(addr: A) -> ::std::io::Result<Self>
        where
            Self: Sized,
        {
            let addrs = addr.to_socket_addrs().await?;

            // Try each resolved address in order and keep the first listener that
            // binds; remember the most recent failure to report if none succeed.
            let mut last_err = ::core::option::Option::None;
            for addr in addrs {
                match $ty::bind(addr).await {
                    ::core::result::Result::Ok(ln) => {
                        // Store into the caller-named field so this agrees with
                        // `accept`/`local_addr` below instead of assuming `ln`.
                        return ::core::result::Result::Ok(Self { $field: ln });
                    }
                    ::core::result::Result::Err(e) => {
                        last_err = ::core::option::Option::Some(e);
                    }
                }
            }

            // `last_err` is only `None` when resolution produced no addresses.
            ::core::result::Result::Err(last_err.unwrap_or_else(|| resolve_address_error!()))
        }

        /// Waits for an incoming connection, converting the runtime's stream into
        /// `Self::Stream` and returning it together with the address reported by the
        /// runtime's `accept`.
        async fn accept(&self) -> ::std::io::Result<(Self::Stream, ::std::net::SocketAddr)> {
            self.$field
                .accept()
                .await
                .map(|(stream, addr)| (Self::Stream::from(stream), addr))
        }

        /// Returns the local address reported by the wrapped listener.
        fn local_addr(&self) -> ::std::io::Result<::std::net::SocketAddr> {
            self.$field.local_addr()
        }
    };
}
53
// Expands to the `TcpStream` trait methods shared by the runtime-specific
// stream wrappers.
//
// * `$runtime` - string literal naming the runtime crate (e.g. "tokio"); it is
//                snake-cased by `paste` to build `::<runtime>::net::TcpStream`.
// * `$field`   - the name of the wrapper struct's field holding the runtime stream.
//
// Every path is fully qualified so the expansion does not depend on, or get
// confused by, names in scope at the invocation site.
#[cfg(any(feature = "smol", feature = "tokio"))]
macro_rules! tcp_stream_common_methods {
    ($runtime:literal::$field:ident) => {
        async fn connect<A: $crate::ToSocketAddrs<Self::Runtime>>(addr: A) -> ::std::io::Result<Self>
        where
            Self: Sized,
        {
            let addrs = addr.to_socket_addrs().await?;

            // Try each resolved address in order; keep the most recent failure so it
            // can be reported if no address accepts the connection.
            let mut last_err = ::core::option::Option::None;

            for addr in addrs {
                paste::paste! {
                    match ::[< $runtime:snake >]::net::TcpStream::connect(addr).await {
                        ::core::result::Result::Ok(stream) => {
                            return ::core::result::Result::Ok(Self::from(stream));
                        }
                        ::core::result::Result::Err(e) => {
                            last_err = ::core::option::Option::Some(e);
                        }
                    }
                }
            }

            // `last_err` is only `None` when resolution produced no addresses.
            ::core::result::Result::Err(last_err.unwrap_or_else(|| resolve_address_error!()))
        }

        /// Runs `connect(addr)` under the runtime's timeout. If the timeout fires
        /// first, the runtime's timeout error is converted into an `io::Error`.
        async fn connect_timeout(
            addr: &::std::net::SocketAddr,
            timeout: ::std::time::Duration,
        ) -> ::std::io::Result<Self>
        where
            Self: Sized,
        {
            let res = <Self::Runtime as ::agnostic_lite::RuntimeLite>::timeout(
                timeout,
                Self::connect(addr),
            )
            .await;

            match res {
                // Completed in time: forward the connect result (success or failure).
                ::core::result::Result::Ok(stream) => stream,
                // Fully qualified so an `Err` in scope at the call site cannot shadow it.
                ::core::result::Result::Err(err) => ::core::result::Result::Err(err.into()),
            }
        }

        /// Delegates to the wrapped stream's `peek`.
        async fn peek(&self, buf: &mut [u8]) -> ::std::io::Result<usize> {
            self.$field.peek(buf).await
        }

        /// Delegates to the wrapped stream's `local_addr`.
        fn local_addr(&self) -> ::std::io::Result<::std::net::SocketAddr> {
            self.$field.local_addr()
        }

        /// Delegates to the wrapped stream's `peer_addr`.
        fn peer_addr(&self) -> ::std::io::Result<::std::net::SocketAddr> {
            self.$field.peer_addr()
        }

        /// Delegates to the wrapped stream's `set_ttl`.
        fn set_ttl(&self, ttl: u32) -> ::std::io::Result<()> {
            self.$field.set_ttl(ttl)
        }

        /// Delegates to the wrapped stream's `ttl`.
        fn ttl(&self) -> ::std::io::Result<u32> {
            self.$field.ttl()
        }

        /// Delegates to the wrapped stream's `set_nodelay`.
        fn set_nodelay(&self, nodelay: bool) -> ::std::io::Result<()> {
            self.$field.set_nodelay(nodelay)
        }

        /// Delegates to the wrapped stream's `nodelay`.
        fn nodelay(&self) -> ::std::io::Result<bool> {
            self.$field.nodelay()
        }
    };
}
121
// Expands to the `OwnedReadHalf` trait methods for a wrapper whose field
// `$field` holds the runtime's read half; every method forwards to it.
#[cfg(any(feature = "smol", feature = "tokio"))]
macro_rules! tcp_stream_owned_read_half_common_methods {
    ($field:ident) => {
        /// Forwards to the inner read half's `peek`.
        async fn peek(&mut self, buf: &mut [u8]) -> ::std::io::Result<usize> {
            let pending = self.$field.peek(buf);
            pending.await
        }

        /// Forwards to the inner read half's `peer_addr`.
        fn peer_addr(&self) -> ::std::io::Result<::std::net::SocketAddr> {
            let inner = &self.$field;
            inner.peer_addr()
        }

        /// Forwards to the inner read half's `local_addr`.
        fn local_addr(&self) -> ::std::io::Result<::std::net::SocketAddr> {
            let inner = &self.$field;
            inner.local_addr()
        }
    };
}
138
// Expands to the address accessors of `OwnedWriteHalf` for a wrapper whose
// field `$field` holds the runtime's write half; both forward to it.
#[cfg(any(feature = "smol", feature = "tokio"))]
macro_rules! tcp_stream_owned_write_half_common_methods {
    ($field:ident) => {
        /// Forwards to the inner write half's `peer_addr`.
        fn peer_addr(&self) -> ::std::io::Result<::std::net::SocketAddr> {
            let inner = &self.$field;
            inner.peer_addr()
        }

        /// Forwards to the inner write half's `local_addr`.
        fn local_addr(&self) -> ::std::io::Result<::std::net::SocketAddr> {
            let inner = &self.$field;
            inner.local_addr()
        }
    };
}
151
// Generates an `Incoming<'a>` stream wrapper for the smol backend.
//
// * `$ty`     - the runtime's incoming-connection stream; it is expected to name
//               the `'a` lifetime (otherwise `Incoming<'a>` has an unused lifetime).
// * `$stream` - this crate's stream type; it must implement `From` for the
//               success values yielded by `$ty`.
#[cfg(feature = "smol")]
macro_rules! tcp_listener_incoming {
    ($ty:ty => $stream:ty) => {
        pin_project_lite::pin_project! {
            /// A stream of incoming TCP connections, each converted into this
            /// crate's stream type.
            pub struct Incoming<'a> {
                #[pin]
                inner: $ty,
            }
        }

        impl ::core::fmt::Debug for Incoming<'_> {
            fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
                // The inner stream is not required to be `Debug`; print a placeholder.
                f.write_str("Incoming { ... }")
            }
        }

        impl<'a> ::core::convert::From<$ty> for Incoming<'a> {
            fn from(inner: $ty) -> Self {
                Self { inner }
            }
        }

        impl<'a> ::core::convert::From<Incoming<'a>> for $ty {
            fn from(incoming: Incoming<'a>) -> Self {
                incoming.inner
            }
        }

        impl<'a> ::futures_util::stream::Stream for Incoming<'a> {
            type Item = ::std::io::Result<$stream>;

            fn poll_next(
                self: ::std::pin::Pin<&mut Self>,
                cx: &mut ::std::task::Context<'_>,
            ) -> ::std::task::Poll<::core::option::Option<Self::Item>> {
                // Call through the trait path so this does not rely on a `Stream`
                // trait being imported (or unambiguous) at the expansion site.
                ::futures_util::stream::Stream::poll_next(self.project().inner, cx)
                    // Poll -> Option -> Result: convert only successfully accepted streams.
                    .map(|item| item.map(|conn| conn.map(<$stream>::from)))
            }
        }
    };
}
200
201pub trait OwnedReadHalf: AsyncRead + Unpin + Send + Sync + 'static {
203 type Runtime: RuntimeLite;
205
206 fn local_addr(&self) -> io::Result<SocketAddr>;
208
209 fn peer_addr(&self) -> io::Result<SocketAddr>;
211
212 fn peek(&mut self, buf: &mut [u8]) -> impl Future<Output = io::Result<usize>> + Send;
220}
221
222pub trait OwnedWriteHalf: AsyncWrite + Unpin + Send + Sync + 'static {
224 type Runtime: RuntimeLite;
226
227 fn forget(self);
229
230 fn local_addr(&self) -> io::Result<SocketAddr>;
232
233 fn peer_addr(&self) -> io::Result<SocketAddr>;
235}
236
237pub trait ReuniteError<T>: core::error::Error + Unpin + Send + Sync + 'static
239where
240 T: TcpStream,
241{
242 fn into_components(self) -> (T::OwnedReadHalf, T::OwnedWriteHalf);
244}
245
246pub trait TcpStream:
248 TryFrom<std::net::TcpStream, Error = io::Error>
249 + Fd
250 + AsyncReadWrite
251 + Unpin
252 + Send
253 + Sync
254 + 'static
255{
256 type Runtime: RuntimeLite;
258 type OwnedReadHalf: OwnedReadHalf;
260 type OwnedWriteHalf: OwnedWriteHalf;
262 type ReuniteError: ReuniteError<Self>;
264
265 fn connect<A: ToSocketAddrs<Self::Runtime>>(
267 addr: A,
268 ) -> impl Future<Output = io::Result<Self>> + Send
269 where
270 Self: Sized;
271
272 fn connect_timeout(
284 addr: &SocketAddr,
285 timeout: Duration,
286 ) -> impl Future<Output = io::Result<Self>> + Send
287 where
288 Self: Sized;
289
290 fn peek(&self, buf: &mut [u8]) -> impl Future<Output = io::Result<usize>> + Send;
298
299 fn local_addr(&self) -> io::Result<SocketAddr>;
301
302 fn peer_addr(&self) -> io::Result<SocketAddr>;
304
305 fn set_ttl(&self, ttl: u32) -> io::Result<()>;
307
308 fn ttl(&self) -> io::Result<u32>;
310
311 fn set_nodelay(&self, nodelay: bool) -> io::Result<()>;
313
314 fn nodelay(&self) -> io::Result<bool>;
316
317 fn into_split(self) -> (Self::OwnedReadHalf, Self::OwnedWriteHalf);
319
320 fn shutdown(&self, how: std::net::Shutdown) -> io::Result<()> {
322 super::os::shutdown(self, how)
323 }
324
325 fn try_clone(&self) -> io::Result<Self> {
331 super::os::duplicate::<_, std::net::TcpStream>(self).and_then(Self::try_from)
332 }
333
334 fn only_v6(&self) -> io::Result<bool> {
336 super::os::only_v6(self)
337 }
338
339 fn linger(&self) -> io::Result<Option<std::time::Duration>> {
343 super::os::linger(self)
344 }
345
346 fn set_linger(&self, duration: Option<std::time::Duration>) -> io::Result<()> {
352 super::os::set_linger(self, duration)
353 }
354
355 fn reunite(
357 read: Self::OwnedReadHalf,
358 write: Self::OwnedWriteHalf,
359 ) -> Result<Self, Self::ReuniteError>
360 where
361 Self: Sized;
362}
363
364pub trait TcpListener:
366 TryFrom<std::net::TcpListener, Error = io::Error> + Fd + Unpin + Send + Sync + 'static
367{
368 type Runtime: RuntimeLite;
370 type Stream: TcpStream<Runtime = Self::Runtime>;
372
373 type Incoming<'a>: futures_util::stream::Stream<Item = io::Result<Self::Stream>>
378 + Send
379 + Sync
380 + Unpin
381 + 'a;
382
383 fn bind<A: ToSocketAddrs<Self::Runtime>>(
399 addr: A,
400 ) -> impl Future<Output = io::Result<Self>> + Send
401 where
402 Self: Sized;
403
404 fn accept(&self) -> impl Future<Output = io::Result<(Self::Stream, SocketAddr)>> + Send;
409
410 fn incoming(&self) -> Self::Incoming<'_>;
418
419 fn into_incoming(
428 self,
429 ) -> impl futures_util::stream::Stream<Item = io::Result<Self::Stream>> + Send;
430
431 fn local_addr(&self) -> io::Result<SocketAddr>;
435
436 fn set_ttl(&self, ttl: u32) -> io::Result<()>;
438
439 fn ttl(&self) -> io::Result<u32>;
441
442 fn try_clone(&self) -> io::Result<Self> {
448 super::os::duplicate::<_, std::net::TcpListener>(self).and_then(Self::try_from)
449 }
450}