agnostic_net/
tcp.rs

1use std::{future::Future, io, net::SocketAddr, time::Duration};
2
3use agnostic_lite::RuntimeLite;
4
5use super::{
6  Fd, ToSocketAddrs,
7  io::{AsyncRead, AsyncReadWrite, AsyncWrite},
8};
9
// Builds the `InvalidInput` error reported by the `bind`/`connect` helpers
// below when an address source yields no socket addresses at all, i.e. when
// there is no per-address failure that could be surfaced instead.
#[cfg(any(feature = "smol", feature = "tokio"))]
macro_rules! resolve_address_error {
  () => {
    ::std::io::Error::new(
      ::std::io::ErrorKind::InvalidInput,
      "could not resolve to any address",
    )
  };
}
19
// Expands to the `bind`, `accept` and `local_addr` methods shared by the
// runtime-specific `TcpListener` wrappers.
//
// `$ty` names the runtime's listener type, which must provide an async `bind`
// associated function. `$field` is the wrapper struct's field that stores that
// listener.
#[cfg(any(feature = "smol", feature = "tokio"))]
macro_rules! tcp_listener_common_methods {
  ($ty:ident.$field:ident) => {
    async fn bind<A: $crate::ToSocketAddrs<Self::Runtime>>(addr: A) -> ::std::io::Result<Self>
    where
      Self: Sized,
    {
      let addrs = addr.to_socket_addrs().await?;

      // Try every resolved address in turn, remembering the most recent
      // failure so it can be reported if none of them can be bound.
      let mut last_err = ::core::option::Option::None;
      for addr in addrs {
        match $ty::bind(addr).await {
          // Construct the wrapper through `$field` rather than a hard-coded
          // field name, so the macro honours whatever field the caller names.
          ::core::result::Result::Ok(ln) => {
            return ::core::result::Result::Ok(Self { $field: ln })
          }
          ::core::result::Result::Err(e) => last_err = ::core::option::Option::Some(e),
        }
      }

      // `last_err` is only `None` when no address was produced at all.
      ::core::result::Result::Err(last_err.unwrap_or_else(|| resolve_address_error!()))
    }

    async fn accept(&self) -> ::std::io::Result<(Self::Stream, ::std::net::SocketAddr)> {
      self
        .$field
        .accept()
        .await
        .map(|(stream, addr)| (Self::Stream::from(stream), addr))
    }

    fn local_addr(&self) -> ::std::io::Result<::std::net::SocketAddr> {
      self.$field.local_addr()
    }
  };
}
53
// Expands to the `TcpStream` methods shared by the runtime-specific stream
// wrappers.
//
// `$runtime` is a string literal naming the runtime crate, turned into a path
// via `paste` (e.g. "tokio" -> `::tokio::net::TcpStream`). `$field` is the
// wrapper struct's field that stores the runtime stream.
#[cfg(any(feature = "smol", feature = "tokio"))]
macro_rules! tcp_stream_common_methods {
  ($runtime:literal::$field:ident) => {
    async fn connect<A: $crate::ToSocketAddrs<Self::Runtime>>(addr: A) -> ::std::io::Result<Self>
    where
      Self: Sized,
    {
      let addrs = addr.to_socket_addrs().await?;

      // Try every resolved address in turn, keeping the most recent failure.
      let mut last_err = ::core::option::Option::None;

      for addr in addrs {
        ::paste::paste! {
          match ::[< $runtime:snake >]::net::TcpStream::connect(addr).await {
            ::core::result::Result::Ok(stream) => return ::core::result::Result::Ok(Self::from(stream)),
            ::core::result::Result::Err(e) => last_err = ::core::option::Option::Some(e),
          }
        }
      }

      // `last_err` is only `None` when no address was produced at all.
      ::core::result::Result::Err(last_err.unwrap_or_else(|| resolve_address_error!()))
    }

    async fn connect_timeout(
      addr: &::std::net::SocketAddr,
      timeout: ::std::time::Duration,
    ) -> ::std::io::Result<Self>
    where
      Self: Sized,
    {
      // The trait documents a zero timeout as an error. Reject it up front,
      // like `std::net::TcpStream::connect_timeout`, instead of letting the
      // timer expire immediately.
      if timeout.is_zero() {
        return ::core::result::Result::Err(::std::io::Error::new(
          ::std::io::ErrorKind::InvalidInput,
          "cannot set a 0 duration timeout",
        ));
      }

      let res = <Self::Runtime as ::agnostic_lite::RuntimeLite>::timeout(timeout, Self::connect(addr)).await;

      match res {
        // The connect attempt finished in time; forward its own outcome.
        ::core::result::Result::Ok(stream) => stream,
        // The timer fired first; convert the elapsed error into an `io::Error`.
        ::core::result::Result::Err(err) => ::core::result::Result::Err(err.into()),
      }
    }

    async fn peek(&self, buf: &mut [u8]) -> ::std::io::Result<usize> {
      self.$field.peek(buf).await
    }

    fn local_addr(&self) -> ::std::io::Result<::std::net::SocketAddr> {
      self.$field.local_addr()
    }

    fn peer_addr(&self) -> ::std::io::Result<::std::net::SocketAddr> {
      self.$field.peer_addr()
    }

    fn set_ttl(&self, ttl: u32) -> ::std::io::Result<()> {
      self.$field.set_ttl(ttl)
    }

    fn ttl(&self) -> ::std::io::Result<u32> {
      self.$field.ttl()
    }

    fn set_nodelay(&self, nodelay: bool) -> ::std::io::Result<()> {
      self.$field.set_nodelay(nodelay)
    }

    fn nodelay(&self) -> ::std::io::Result<bool> {
      self.$field.nodelay()
    }
  };
}
121
// Forwards the `OwnedReadHalf` methods (`peek`, `peer_addr`, `local_addr`)
// to the inner value stored in the wrapper field named by `$field`.
#[cfg(any(feature = "smol", feature = "tokio"))]
macro_rules! tcp_stream_owned_read_half_common_methods {
  ($field:ident) => {
    async fn peek(
      &mut self,
      buf: &mut [u8],
    ) -> ::core::result::Result<usize, ::std::io::Error> {
      self.$field.peek(buf).await
    }

    fn peer_addr(&self) -> ::core::result::Result<::std::net::SocketAddr, ::std::io::Error> {
      self.$field.peer_addr()
    }

    fn local_addr(&self) -> ::core::result::Result<::std::net::SocketAddr, ::std::io::Error> {
      self.$field.local_addr()
    }
  };
}
138
// Forwards the address accessors of an owned write half (`peer_addr` and
// `local_addr`) to the inner value stored in the wrapper field named by `$field`.
#[cfg(any(feature = "smol", feature = "tokio"))]
macro_rules! tcp_stream_owned_write_half_common_methods {
  ($field:ident) => {
    fn peer_addr(&self) -> ::core::result::Result<::std::net::SocketAddr, ::std::io::Error> {
      self.$field.peer_addr()
    }

    fn local_addr(&self) -> ::core::result::Result<::std::net::SocketAddr, ::std::io::Error> {
      self.$field.local_addr()
    }
  };
}
151
// Generates an `Incoming<'a>` stream wrapper around the runtime's own incoming
// stream type `$ty`. Each yielded item is converted into `$stream` via
// `From`.
//
// All paths are absolute (`::core`, `::std`, `::pin_project_lite`) so that the
// expansion does not depend on what the invoking module has in scope.
#[cfg(feature = "smol")]
macro_rules! tcp_listener_incoming {
  ($ty:ty => $stream:ty) => {
    ::pin_project_lite::pin_project! {
      /// A stream of incoming TCP connections.
      ///
      /// This stream is infinite, i.e., awaiting the next connection will never result in
      /// [`None`]. It is created by the
      /// [`TcpListener::incoming()`](crate::TcpListener::incoming) method.
      pub struct Incoming<'a> {
        #[pin]
        inner: $ty,
      }
    }

    impl ::core::fmt::Debug for Incoming<'_> {
      fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
        ::core::write!(f, "Incoming {{ ... }}")
      }
    }

    impl<'a> ::core::convert::From<$ty> for Incoming<'a> {
      fn from(inner: $ty) -> Self {
        Self { inner }
      }
    }

    impl<'a> ::core::convert::From<Incoming<'a>> for $ty {
      fn from(incoming: Incoming<'a>) -> Self {
        incoming.inner
      }
    }

    impl<'a> ::futures_util::stream::Stream for Incoming<'a> {
      type Item = ::std::io::Result<$stream>;

      fn poll_next(
        self: ::std::pin::Pin<&mut Self>,
        cx: &mut ::std::task::Context<'_>,
      ) -> ::std::task::Poll<::core::option::Option<Self::Item>> {
        // Poll the wrapped stream and convert each successfully accepted
        // connection into `$stream`; errors and `None` pass through unchanged.
        self
          .project()
          .inner
          .poll_next(cx)
          .map(|next| next.map(|res| res.map(<$stream>::from)))
      }
    }
  };
}
200
201/// The abstraction of a owned read half of a TcpStream.
202pub trait OwnedReadHalf: AsyncRead + Unpin + Send + Sync + 'static {
203  /// The async runtime.
204  type Runtime: RuntimeLite;
205
206  /// Returns the local address that this stream is bound to.
207  fn local_addr(&self) -> io::Result<SocketAddr>;
208
209  /// Returns the remote address that this stream is connected to.
210  fn peer_addr(&self) -> io::Result<SocketAddr>;
211
212  /// Receives data on the socket from the remote address to which it is connected, without
213  /// removing that data from the queue.
214  ///
215  /// On success, returns the number of bytes peeked.
216  ///
217  /// Successive calls return the same data. This is accomplished by passing `MSG_PEEK` as a flag
218  /// to the underlying `recv` system call.
219  fn peek(&mut self, buf: &mut [u8]) -> impl Future<Output = io::Result<usize>> + Send;
220}
221
222/// The abstraction of a owned write half of a TcpStream.
223pub trait OwnedWriteHalf: AsyncWrite + Unpin + Send + Sync + 'static {
224  /// The async runtime.
225  type Runtime: RuntimeLite;
226
227  /// Shuts down the write half and without closing the read half.
228  fn forget(self);
229
230  /// Returns the local address that this stream is bound to.
231  fn local_addr(&self) -> io::Result<SocketAddr>;
232
233  /// Returns the remote address that this stream is connected to.
234  fn peer_addr(&self) -> io::Result<SocketAddr>;
235}
236
237/// Error indicating that two halves were not from the same socket, and thus could not be reunited.
238pub trait ReuniteError<T>: core::error::Error + Unpin + Send + Sync + 'static
239where
240  T: TcpStream,
241{
242  /// Consumes the error and returns the read half and write half of the socket.
243  fn into_components(self) -> (T::OwnedReadHalf, T::OwnedWriteHalf);
244}
245
246/// The abstraction of a TCP stream.
247pub trait TcpStream:
248  TryFrom<std::net::TcpStream, Error = io::Error>
249  + Fd
250  + AsyncReadWrite
251  + Unpin
252  + Send
253  + Sync
254  + 'static
255{
256  /// The async runtime.
257  type Runtime: RuntimeLite;
258  /// The owned read half of the stream.
259  type OwnedReadHalf: OwnedReadHalf;
260  /// The owned write half of the stream.
261  type OwnedWriteHalf: OwnedWriteHalf;
262  /// Error indicating that two halves were not from the same socket, and thus could not be reunited.
263  type ReuniteError: ReuniteError<Self>;
264
265  /// Connects to the specified address.
266  fn connect<A: ToSocketAddrs<Self::Runtime>>(
267    addr: A,
268  ) -> impl Future<Output = io::Result<Self>> + Send
269  where
270    Self: Sized;
271
272  /// Opens a TCP connection to a remote host with a timeout.
273  ///
274  /// Unlike `connect`, `connect_timeout` takes a single [`SocketAddr`] since
275  /// timeout must be applied to individual addresses.
276  ///
277  /// It is an error to pass a zero `Duration` to this function.
278  ///
279  /// Unlike other methods on `TcpStream`, this does not correspond to a
280  /// single system call. It instead calls `connect` in nonblocking mode and
281  /// then uses an OS-specific mechanism to await the completion of the
282  /// connection request.
283  fn connect_timeout(
284    addr: &SocketAddr,
285    timeout: Duration,
286  ) -> impl Future<Output = io::Result<Self>> + Send
287  where
288    Self: Sized;
289
290  /// Receives data on the socket from the remote address to which it is connected, without
291  /// removing that data from the queue.
292  ///
293  /// On success, returns the number of bytes peeked.
294  ///
295  /// Successive calls return the same data. This is accomplished by passing `MSG_PEEK` as a flag
296  /// to the underlying `recv` system call.
297  fn peek(&self, buf: &mut [u8]) -> impl Future<Output = io::Result<usize>> + Send;
298
299  /// Returns the local address that this stream is bound to.
300  fn local_addr(&self) -> io::Result<SocketAddr>;
301
302  /// Returns the remote address that this stream is connected to.
303  fn peer_addr(&self) -> io::Result<SocketAddr>;
304
305  /// Sets the time-to-live value for this socket.  
306  fn set_ttl(&self, ttl: u32) -> io::Result<()>;
307
308  /// Gets the time-to-live value of this socket.
309  fn ttl(&self) -> io::Result<u32>;
310
311  /// Sets the value of the `TCP_NODELAY` option on this socket.
312  fn set_nodelay(&self, nodelay: bool) -> io::Result<()>;
313
314  /// Gets the value of the `TCP_NODELAY` option on this socket.
315  fn nodelay(&self) -> io::Result<bool>;
316
317  /// Splits the stream to read and write halves.
318  fn into_split(self) -> (Self::OwnedReadHalf, Self::OwnedWriteHalf);
319
320  /// Shuts down the read, write, or both halves of this connection.
321  fn shutdown(&self, how: std::net::Shutdown) -> io::Result<()> {
322    super::os::shutdown(self, how)
323  }
324
325  /// Creates a new independently owned handle to the underlying socket.
326  ///
327  /// The returned `UdpSocket` is a reference to the same socket that this
328  /// object references. Both handles will read and write the same port, and
329  /// options set on one socket will be propagated to the other.
330  fn try_clone(&self) -> io::Result<Self> {
331    super::os::duplicate::<_, std::net::TcpStream>(self).and_then(Self::try_from)
332  }
333
334  /// Get the value of the `IPV6_V6ONLY` option for this socket.
335  fn only_v6(&self) -> io::Result<bool> {
336    super::os::only_v6(self)
337  }
338
339  /// Gets the value of the `SO_LINGER` option on this socket.
340  ///
341  /// For more information about this option, see [`TcpStream::set_linger`].
342  fn linger(&self) -> io::Result<Option<std::time::Duration>> {
343    super::os::linger(self)
344  }
345
346  /// Sets the value of the `SO_LINGER` option on this socket.
347  ///
348  /// This value controls how the socket is closed when data remains to be sent.
349  /// If `SO_LINGER` is set, the socket will remain open for the specified duration as the system attempts to send pending data.
350  /// Otherwise, the system may close the socket immediately, or wait for a default timeout.
351  fn set_linger(&self, duration: Option<std::time::Duration>) -> io::Result<()> {
352    super::os::set_linger(self, duration)
353  }
354
355  /// Attempts to put the two halves of a TcpStream back together and recover the original socket. Succeeds only if the two halves originated from the same call to [`into_split`][TcpStream::into_split].
356  fn reunite(
357    read: Self::OwnedReadHalf,
358    write: Self::OwnedWriteHalf,
359  ) -> Result<Self, Self::ReuniteError>
360  where
361    Self: Sized;
362}
363
364/// An abstraction layer for TCP listener.
365pub trait TcpListener:
366  TryFrom<std::net::TcpListener, Error = io::Error> + Fd + Unpin + Send + Sync + 'static
367{
368  /// The async runtime.
369  type Runtime: RuntimeLite;
370  /// Stream of incoming connections.
371  type Stream: TcpStream<Runtime = Self::Runtime>;
372
373  /// A stream of incoming TCP connections.
374  ///
375  /// This stream is infinite, i.e awaiting the next connection will never result in [`None`]. It is
376  /// created by the [`TcpListener::incoming()`] method.
377  type Incoming<'a>: futures_util::stream::Stream<Item = io::Result<Self::Stream>>
378    + Send
379    + Sync
380    + Unpin
381    + 'a;
382
383  /// Creates a new TcpListener, which will be bound to the specified address.
384  ///
385  /// The returned listener is ready for accepting connections.
386  ///
387  /// Binding with a port number of 0 will request that the OS assigns a port
388  /// to this listener. The port allocated can be queried via the `local_addr`
389  /// method.
390  ///
391  /// The address type can be any implementor of the [`ToSocketAddrs`] trait.
392  /// If `addr` yields multiple addresses, bind will be attempted with each of
393  /// the addresses until one succeeds and returns the listener. If none of
394  /// the addresses succeed in creating a listener, the error returned from
395  /// the last attempt (the last address) is returned.
396  ///
397  /// This function sets the `SO_REUSEADDR` option on the socket.
398  fn bind<A: ToSocketAddrs<Self::Runtime>>(
399    addr: A,
400  ) -> impl Future<Output = io::Result<Self>> + Send
401  where
402    Self: Sized;
403
404  /// Accepts a new incoming connection from this listener.
405  ///
406  /// This function will yield once a new TCP connection is established. When established,
407  /// the corresponding [`TcpStream`] and the remote peer's address will be returned.
408  fn accept(&self) -> impl Future<Output = io::Result<(Self::Stream, SocketAddr)>> + Send;
409
410  /// Returns a stream of incoming connections.
411  ///
412  /// Iterating over this stream is equivalent to calling [`accept()`][`TcpListener::accept()`]
413  /// in a loop. The stream of connections is infinite, i.e awaiting the next connection will
414  /// never result in [`None`].
415  ///
416  /// See also [`TcpListener::into_incoming`].
417  fn incoming(&self) -> Self::Incoming<'_>;
418
419  /// Turn this into a stream over the connections being received on this
420  /// listener.
421  ///
422  /// The returned stream is infinite and will also not yield
423  /// the peer's [`SocketAddr`] structure. Iterating over it is equivalent to
424  /// calling [`TcpListener::accept`] in a loop.
425  ///
426  /// See also [`TcpListener::incoming`].
427  fn into_incoming(
428    self,
429  ) -> impl futures_util::stream::Stream<Item = io::Result<Self::Stream>> + Send;
430
431  /// Returns the local address that this listener is bound to.
432  ///
433  /// This can be useful, for example, when binding to port 0 to figure out which port was actually bound.
434  fn local_addr(&self) -> io::Result<SocketAddr>;
435
436  /// Sets the time-to-live value for this socket.  
437  fn set_ttl(&self, ttl: u32) -> io::Result<()>;
438
439  /// Gets the time-to-live value of this socket.
440  fn ttl(&self) -> io::Result<u32>;
441
442  /// Creates a new independently owned handle to the underlying socket.
443  ///
444  /// The returned `UdpSocket` is a reference to the same socket that this
445  /// object references. Both handles will read and write the same port, and
446  /// options set on one socket will be propagated to the other.
447  fn try_clone(&self) -> io::Result<Self> {
448    super::os::duplicate::<_, std::net::TcpListener>(self).and_then(Self::try_from)
449  }
450}