rasi/
net.rs

1//! Future-based TCP/IP manipulation operations.
2
3use std::{
4    fmt::Debug,
5    io::{ErrorKind, Result},
6    net::{Ipv4Addr, Ipv6Addr, Shutdown, SocketAddr, ToSocketAddrs},
7    ops::Deref,
8    sync::{Arc, OnceLock},
9    task::{Context, Poll},
10};
11
12use futures::{future::poll_fn, AsyncRead, AsyncWrite, Stream};
13
14/// A network driver must implement the `Driver-*` traits in this module.
15pub mod syscall {
16    use super::*;
17
18    #[cfg(unix)]
19    pub mod unix {
20        use super::*;
21        pub trait DriverUnixListener: Sync + Send {
22            /// Returns the local socket address of this listener.
23            fn local_addr(&self) -> Result<std::os::unix::net::SocketAddr>;
24
25            /// Polls and accepts a new incoming connection to this listener.
26            ///
27            /// When a connection is established, the corresponding stream and address will be returned.
28            fn poll_next(
29                &self,
30                cx: &mut Context<'_>,
31            ) -> Poll<Result<(crate::net::unix::UnixStream, std::os::unix::net::SocketAddr)>>;
32        }
33
34        pub trait DriverUnixStream: Sync + Send {
35            /// Returns the local address that this stream is connected from.
36            fn local_addr(&self) -> Result<std::os::unix::net::SocketAddr>;
37
38            /// Returns the local address that this stream is connected to.
39            fn peer_addr(&self) -> Result<std::os::unix::net::SocketAddr>;
40
41            /// Shuts down the read, write, or both halves of this connection.
42            ///
43            /// This method will cause all pending and future I/O on the specified portions to return
44            /// immediately with an appropriate value (see the documentation of [`Shutdown`]).
45            ///
46            /// [`Shutdown`]: https://doc.rust-lang.org/std/net/enum.Shutdown.html
47            fn shutdown(&self, how: Shutdown) -> Result<()>;
48
49            /// poll and receives data from the socket.
50            ///
51            /// On success, returns the number of bytes read.
52            fn poll_read(&self, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<Result<usize>>;
53
54            /// Sends data on the socket to the remote address
55            ///
56            /// On success, returns the number of bytes written.
57            fn poll_write(&self, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>>;
58
59            /// Poll and wait underlying tcp connection established event.
60            ///
61            /// This function is no effect for server side socket created
62            /// by [`poll_next`](DriverUnixListener::poll_next) function.
63            fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<()>>;
64        }
65    }
66
67    /// A driver is the main entry to access asynchronously network functions.
68    pub trait Driver: Send + Sync {
69        /// Create a new tcp listener socket with provided `laddrs`.
70        fn tcp_listen(&self, laddrs: &[SocketAddr]) -> Result<TcpListener>;
71
72        #[cfg(unix)]
73        unsafe fn tcp_listener_from_raw_fd(&self, fd: std::os::fd::RawFd) -> Result<TcpListener>;
74
75        #[cfg(windows)]
76        unsafe fn tcp_listener_from_raw_socket(
77            &self,
78            socket: std::os::windows::io::RawSocket,
79        ) -> Result<TcpListener>;
80
81        /// Create a new `TcpStream` and connect to `raddrs`.
82        ///  
83        /// When this function returns a [`TcpStream`] object the underlying
84        /// tcp connection may not be actually established, and the user
85        /// needs to manually call the poll_ready method to wait for the
86        /// connection to be established.
87        fn tcp_connect(&self, raddrs: &SocketAddr) -> Result<TcpStream>;
88
89        #[cfg(unix)]
90        unsafe fn tcp_stream_from_raw_fd(&self, fd: std::os::fd::RawFd) -> Result<TcpStream>;
91
92        #[cfg(windows)]
93        unsafe fn tcp_stream_from_raw_socket(
94            &self,
95            socket: std::os::windows::io::RawSocket,
96        ) -> Result<TcpStream>;
97
98        /// Create new `UdpSocket` which will be bound to the specified `laddrs`
99        fn udp_bind(&self, laddrs: &[SocketAddr]) -> Result<UdpSocket>;
100
101        #[cfg(unix)]
102        unsafe fn udp_from_raw_fd(&self, fd: std::os::fd::RawFd) -> Result<UdpSocket>;
103
104        #[cfg(windows)]
105        unsafe fn udp_from_raw_socket(
106            &self,
107            socket: std::os::windows::io::RawSocket,
108        ) -> Result<UdpSocket>;
109
110        #[cfg(unix)]
111        #[cfg_attr(docsrs, doc(cfg(unix)))]
112        fn unix_listen(&self, path: &std::path::Path) -> Result<crate::net::unix::UnixListener>;
113
114        #[cfg(unix)]
115        #[cfg_attr(docsrs, doc(cfg(unix)))]
116        fn unix_connect(&self, path: &std::path::Path) -> Result<crate::net::unix::UnixStream>;
117    }
118
119    /// Driver-specific `TcpListener` implementation must implement this trait.
120    ///
121    /// When this trait object is dropping, the implementition must close the internal tcp listener socket.
122    pub trait DriverTcpListener: Sync + Send {
123        /// Returns the local socket address of this listener.
124        fn local_addr(&self) -> Result<SocketAddr>;
125
126        /// Gets the value of the IP_TTL option for this socket.
127        /// For more information about this option, see [`set_ttl`](DriverTcpListener::set_ttl).
128        fn ttl(&self) -> Result<u32>;
129
130        /// Sets the value for the IP_TTL option on this socket.
131        /// This value sets the time-to-live field that is used in every packet sent from this socket.
132        fn set_ttl(&self, ttl: u32) -> Result<()>;
133
134        /// Polls and accepts a new incoming connection to this listener.
135        ///
136        /// When a connection is established, the corresponding stream and address will be returned.
137        fn poll_next(&self, cx: &mut Context<'_>) -> Poll<Result<(TcpStream, SocketAddr)>>;
138    }
139
140    /// Driver-specific `TcpStream` implementation must implement this trait.
141    ///
142    /// When this trait object is dropping, the implementition must close the internal tcp listener socket.
143    pub trait DriverTcpStream: Sync + Send + Debug {
144        /// Returns the local address that this stream is connected from.
145        fn local_addr(&self) -> Result<SocketAddr>;
146
147        /// Returns the local address that this stream is connected to.
148        fn peer_addr(&self) -> Result<SocketAddr>;
149
150        /// Gets the value of the IP_TTL option for this socket.
151        /// For more information about this option, see [`set_ttl`](DriverTcpStream::set_ttl).
152        fn ttl(&self) -> Result<u32>;
153
154        /// Sets the value for the IP_TTL option on this socket.
155        /// This value sets the time-to-live field that is used in every packet sent from this socket.
156        fn set_ttl(&self, ttl: u32) -> Result<()>;
157
158        /// Gets the value of the `TCP_NODELAY` option on this socket.
159        ///
160        /// For more information about this option, see [`set_nodelay`](DriverTcpStream::set_nodelay).
161        fn nodelay(&self) -> Result<bool>;
162
163        /// Sets the value of the `TCP_NODELAY` option on this socket.
164        ///
165        /// If set, this option disables the Nagle algorithm. This means that
166        /// segments are always sent as soon as possible, even if there is only a
167        /// small amount of data. When not set, data is buffered until there is a
168        /// sufficient amount to send out, thereby avoiding the frequent sending of
169        /// small packets.
170        fn set_nodelay(&self, nodelay: bool) -> Result<()>;
171
172        /// Shuts down the read, write, or both halves of this connection.
173        ///
174        /// This method will cause all pending and future I/O on the specified portions to return
175        /// immediately with an appropriate value (see the documentation of [`Shutdown`]).
176        ///
177        /// [`Shutdown`]: https://doc.rust-lang.org/std/net/enum.Shutdown.html
178        fn shutdown(&self, how: Shutdown) -> Result<()>;
179
180        /// poll and receives data from the socket.
181        ///
182        /// On success, returns the number of bytes read.
183        fn poll_read(&self, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<Result<usize>>;
184
185        /// Sends data on the socket to the remote address
186        ///
187        /// On success, returns the number of bytes written.
188        fn poll_write(&self, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>>;
189
190        /// Poll and wait underlying tcp connection established event.
191        ///
192        /// This function is no effect for server side socket created
193        /// by [`poll_next`](DriverTcpListener::poll_next) function.
194        fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<()>>;
195    }
196
197    /// Driver-specific `UdpSocket` implementation
198    ///
199    /// When this trait object is dropping, the implementition must close the internal tcp socket.
200    pub trait DriverUdpSocket: Sync + Send {
201        /// Shuts down the read, write, or both halves of this connection.
202        ///
203        /// This method will cause all pending and future I/O on the specified portions to return
204        /// immediately with an appropriate value (see the documentation of [`Shutdown`]).
205        ///
206        /// [`Shutdown`]: https://doc.rust-lang.org/std/net/enum.Shutdown.html
207        fn shutdown(&self, how: Shutdown) -> Result<()>;
208
209        /// Returns the local address that this stream is connected from.
210        fn local_addr(&self) -> Result<SocketAddr>;
211
212        /// Returns the local address that this stream is connected to.
213        fn peer_addr(&self) -> Result<SocketAddr>;
214
215        /// Gets the value of the IP_TTL option for this socket.
216        /// For more information about this option, see [`set_ttl`](DriverUdpSocket::set_ttl).
217        fn ttl(&self) -> Result<u32>;
218
219        /// Sets the value for the IP_TTL option on this socket.
220        /// This value sets the time-to-live field that is used in every packet sent from this socket.
221        fn set_ttl(&self, ttl: u32) -> Result<()>;
222
223        /// Executes an operation of the IP_ADD_MEMBERSHIP type.
224        ///
225        /// This function specifies a new multicast group for this socket to join.
226        /// The address must be a valid multicast address, and interface is the
227        /// address of the local interface with which the system should join the
228        /// multicast group. If it’s equal to INADDR_ANY then an appropriate
229        /// interface is chosen by the system.
230        fn join_multicast_v4(&self, multiaddr: &Ipv4Addr, interface: &Ipv4Addr) -> Result<()>;
231
232        /// Executes an operation of the `IPV6_ADD_MEMBERSHIP` type.
233        ///
234        /// This function specifies a new multicast group for this socket to join.
235        /// The address must be a valid multicast address, and `interface` is the
236        /// index of the interface to join/leave (or 0 to indicate any interface).
237        fn join_multicast_v6(&self, multiaddr: &Ipv6Addr, interface: u32) -> Result<()>;
238
239        /// Executes an operation of the `IP_DROP_MEMBERSHIP` type.
240        ///
241        /// For more information about this option, see
242        /// [`join_multicast_v4`][link].
243        ///
244        /// [link]: #method.join_multicast_v4
245        fn leave_multicast_v4(&self, multiaddr: &Ipv4Addr, interface: &Ipv4Addr) -> Result<()>;
246
247        /// Executes an operation of the `IPV6_DROP_MEMBERSHIP` type.
248        ///
249        /// For more information about this option, see
250        /// [`join_multicast_v6`][link].
251        ///
252        /// [link]: #method.join_multicast_v6
253        fn leave_multicast_v6(&self, multiaddr: &Ipv6Addr, interface: u32) -> Result<()>;
254
255        /// Sets the value of the IP_MULTICAST_LOOP option for this socket.
256        ///
257        /// If enabled, multicast packets will be looped back to the local socket. Note that this might not have any effect on IPv6 sockets.
258        fn set_multicast_loop_v4(&self, on: bool) -> Result<()>;
259
260        /// Sets the value of the IPV6_MULTICAST_LOOP option for this socket.
261        ///
262        /// Controls whether this socket sees the multicast packets it sends itself. Note that this might not have any affect on IPv4 sockets.
263        fn set_multicast_loop_v6(&self, on: bool) -> Result<()>;
264
265        /// Gets the value of the IP_MULTICAST_LOOP option for this socket.
266        fn multicast_loop_v4(&self) -> Result<bool>;
267
268        /// Gets the value of the IPV6_MULTICAST_LOOP option for this socket.
269        fn multicast_loop_v6(&self) -> Result<bool>;
270
271        /// Sets the value of the SO_BROADCAST option for this socket.
272        /// When enabled, this socket is allowed to send packets to a broadcast address.
273        fn set_broadcast(&self, on: bool) -> Result<()>;
274
275        /// Gets the value of the SO_BROADCAST option for this socket.
276        /// For more information about this option, see [`set_broadcast`](DriverUdpSocket::set_broadcast).
277        fn broadcast(&self) -> Result<bool>;
278
279        /// Receives data from the socket.
280        ///
281        /// On success, returns the number of bytes read and the origin.
282        fn poll_recv_from(
283            &self,
284            cx: &mut Context<'_>,
285            buf: &mut [u8],
286        ) -> Poll<Result<(usize, SocketAddr)>>;
287
288        /// Sends data on the socket to the given `target` address.
289        ///
290        /// On success, returns the number of bytes written.
291        fn poll_send_to(
292            &self,
293            cx: &mut Context<'_>,
294            buf: &[u8],
295            peer: SocketAddr,
296        ) -> Poll<Result<usize>>;
297    }
298}
299
300/// A TCP socket server, listening for connections.
301/// After creating a TcpListener by binding it to a socket address,
302/// it listens for incoming TCP connections. These can be accepted
303/// by awaiting elements from the async stream of incoming connections.
304///
305/// The socket will be closed when the value is dropped.
306/// The Transmission Control Protocol is specified in IETF RFC 793.
307/// This type is an async version of std::net::TcpListener.
308pub struct TcpListener(Box<dyn syscall::DriverTcpListener>);
309
310impl<T: syscall::DriverTcpListener + 'static> From<T> for TcpListener {
311    fn from(value: T) -> Self {
312        Self(Box::new(value))
313    }
314}
315
316impl Deref for TcpListener {
317    type Target = dyn syscall::DriverTcpListener;
318    fn deref(&self) -> &Self::Target {
319        &*self.0
320    }
321}
322
323impl TcpListener {
324    /// Returns inner [`syscall::DriverTcpListener`] ptr.
325    pub fn as_raw_ptr(&self) -> &dyn syscall::DriverTcpListener {
326        &*self.0
327    }
328
329    /// See [`poll_next`](syscall::DriverTcpListener::poll_next) for more information.
330    pub async fn accept(&self) -> Result<(TcpStream, SocketAddr)> {
331        poll_fn(|cx| self.poll_next(cx)).await
332    }
333
334    /// Create new `TcpListener` which will be bound to the specified `laddrs`
335    pub async fn bind<L: ToSocketAddrs>(laddrs: L) -> Result<Self> {
336        Self::bind_with(laddrs, get_network_driver()).await
337    }
338
339    /// Use custom `NetworkDriver` to create new `TcpListener` which will be bound to the specified `laddrs`.
340    pub async fn bind_with<L: ToSocketAddrs>(
341        laddrs: L,
342        driver: &dyn syscall::Driver,
343    ) -> Result<Self> {
344        let laddrs = laddrs.to_socket_addrs()?.collect::<Vec<_>>();
345        driver.tcp_listen(&laddrs)
346    }
347
348    #[cfg(unix)]
349    pub unsafe fn from_raw_fd_with(
350        fd: std::os::fd::RawFd,
351        driver: &dyn syscall::Driver,
352    ) -> Result<Self> {
353        driver.tcp_listener_from_raw_fd(fd)
354    }
355
356    #[cfg(unix)]
357    pub unsafe fn from_raw_fd(fd: std::os::fd::RawFd) -> Result<Self> {
358        Self::from_raw_fd_with(fd, get_network_driver())
359    }
360
361    #[cfg(windows)]
362    pub unsafe fn from_raw_socket_with(
363        fd: std::os::windows::io::RawSocket,
364        driver: &dyn syscall::Driver,
365    ) -> Result<Self> {
366        driver.tcp_listener_from_raw_socket(fd)
367    }
368
369    #[cfg(windows)]
370    pub unsafe fn from_raw_socket(fd: std::os::windows::io::RawSocket) -> Result<Self> {
371        Self::from_raw_socket_with(fd, get_network_driver())
372    }
373}
374
375impl Stream for TcpListener {
376    type Item = Result<TcpStream>;
377
378    fn poll_next(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
379        match self.as_raw_ptr().poll_next(cx) {
380            Poll::Ready(Ok((stream, _))) => Poll::Ready(Some(Ok(stream))),
381            Poll::Ready(Err(err)) => {
382                if err.kind() == ErrorKind::BrokenPipe {
383                    Poll::Ready(None)
384                } else {
385                    Poll::Ready(Some(Err(err)))
386                }
387            }
388            Poll::Pending => Poll::Pending,
389        }
390    }
391}
392
393/// A TCP stream between a local and a remote socket.
394///
395/// A `TcpStream` can either be created by connecting to an endpoint, via the [`connect`] method,
396/// or by [accepting] a connection from a [listener].  It can be read or written to using the
397/// [`AsyncRead`], [`AsyncWrite`], and related extension traits in [`futures::io`].
398///
399/// The connection will be closed when the value is dropped. The reading and writing portions of
400/// the connection can also be shut down individually with the [`shutdown`] method.
401///
402/// This type is an async version of [`std::net::TcpStream`].
403///
404/// [`connect`]: struct.TcpStream.html#method.connect
405/// [accepting]: struct.TcpListener.html#method.accept
406/// [listener]: struct.TcpListener.html
407/// [`AsyncRead`]: https://docs.rs/futures/0.3/futures/io/trait.AsyncRead.html
408/// [`AsyncWrite`]: https://docs.rs/futures/0.3/futures/io/trait.AsyncWrite.html
409/// [`futures::io`]: https://docs.rs/futures/0.3/futures/io/index.html
410/// [`shutdown`]: struct.TcpStream.html#method.shutdown
411/// [`std::net::TcpStream`]: https://doc.rust-lang.org/std/net/struct.TcpStream.html
412#[derive(Debug, Clone)]
413pub struct TcpStream(Arc<Box<dyn syscall::DriverTcpStream>>);
414
415impl<T: syscall::DriverTcpStream + 'static> From<T> for TcpStream {
416    fn from(value: T) -> Self {
417        Self(Arc::new(Box::new(value)))
418    }
419}
420
421impl Deref for TcpStream {
422    type Target = dyn syscall::DriverTcpStream;
423    fn deref(&self) -> &Self::Target {
424        &**self.0
425    }
426}
427
428impl TcpStream {
429    /// Returns inner [`syscall::DriverTcpStream`] ptr.
430    pub fn as_raw_ptr(&self) -> &dyn syscall::DriverTcpStream {
431        &**self.0
432    }
433
434    /// Connect to peer by provided `raddrs`.
435    pub async fn connect<R: ToSocketAddrs>(raddrs: R) -> Result<Self> {
436        Self::connect_with(raddrs, get_network_driver()).await
437    }
438
439    /// Use custom `NetworkDriver` to connect to peer by provided `raddrs`.
440    pub async fn connect_with<R: ToSocketAddrs>(
441        raddrs: R,
442        driver: &dyn syscall::Driver,
443    ) -> Result<Self> {
444        let mut last_error = None;
445
446        for raddr in raddrs.to_socket_addrs()?.collect::<Vec<_>>() {
447            match driver.tcp_connect(&raddr) {
448                Ok(stream) => {
449                    // Wait for the asynchronously connecting to complete
450                    match poll_fn(|cx| stream.poll_ready(cx)).await {
451                        Ok(()) => {
452                            return Ok(stream);
453                        }
454                        Err(err) => {
455                            last_error = Some(err);
456                        }
457                    }
458                }
459                Err(err) => last_error = Some(err),
460            }
461        }
462
463        Err(last_error.unwrap())
464    }
465
466    #[cfg(unix)]
467    pub unsafe fn from_raw_fd_with(
468        fd: std::os::fd::RawFd,
469        driver: &dyn syscall::Driver,
470    ) -> Result<Self> {
471        driver.tcp_stream_from_raw_fd(fd)
472    }
473
474    #[cfg(unix)]
475    pub unsafe fn from_raw_fd(fd: std::os::fd::RawFd) -> Result<Self> {
476        Self::from_raw_fd_with(fd, get_network_driver())
477    }
478
479    #[cfg(windows)]
480    pub unsafe fn from_raw_socket_with(
481        fd: std::os::windows::io::RawSocket,
482        driver: &dyn syscall::Driver,
483    ) -> Result<Self> {
484        driver.tcp_stream_from_raw_socket(fd)
485    }
486
487    #[cfg(windows)]
488    pub unsafe fn from_raw_socket(fd: std::os::windows::io::RawSocket) -> Result<Self> {
489        Self::from_raw_socket_with(fd, get_network_driver())
490    }
491}
492
493impl AsyncRead for TcpStream {
494    fn poll_read(
495        self: std::pin::Pin<&mut Self>,
496        cx: &mut Context<'_>,
497        buf: &mut [u8],
498    ) -> Poll<Result<usize>> {
499        self.as_raw_ptr().poll_read(cx, buf)
500    }
501}
502
503impl AsyncWrite for TcpStream {
504    fn poll_write(
505        self: std::pin::Pin<&mut Self>,
506        cx: &mut Context<'_>,
507        buf: &[u8],
508    ) -> Poll<Result<usize>> {
509        self.as_raw_ptr().poll_write(cx, buf)
510    }
511
512    fn poll_flush(self: std::pin::Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
513        Poll::Ready(Ok(()))
514    }
515
516    fn poll_close(self: std::pin::Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
517        self.shutdown(Shutdown::Both)?;
518
519        Poll::Ready(Ok(()))
520    }
521}
522
523/// A UDP socket.
524///
525/// After creating a `UdpSocket` by [`bind`]ing it to a socket address, data can be [sent to] and
526/// [received from] any other socket address.
527///
528/// As stated in the User Datagram Protocol's specification in [IETF RFC 768], UDP is an unordered,
529/// unreliable protocol. Refer to [`TcpListener`] and [`TcpStream`] for async TCP primitives.
530///
531/// This type is an async version of [`std::net::UdpSocket`].
532///
533/// [`bind`]: #method.bind
534/// [received from]: #method.recv_from
535/// [sent to]: #method.send_to
536/// [`TcpListener`]: struct.TcpListener.html
537/// [`TcpStream`]: struct.TcpStream.html
538/// [`std::net`]: https://doc.rust-lang.org/std/net/index.html
539/// [IETF RFC 768]: https://tools.ietf.org/html/rfc768
540/// [`std::net::UdpSocket`]: https://doc.rust-lang.org/std/net/struct.UdpSocket.html
541///
542#[derive(Clone)]
543pub struct UdpSocket(Arc<Box<dyn syscall::DriverUdpSocket>>);
544
545impl<T: syscall::DriverUdpSocket + 'static> From<T> for UdpSocket {
546    fn from(value: T) -> Self {
547        Self(Arc::new(Box::new(value)))
548    }
549}
550
551impl Deref for UdpSocket {
552    type Target = dyn syscall::DriverUdpSocket;
553    fn deref(&self) -> &Self::Target {
554        &**self.0
555    }
556}
557
558impl UdpSocket {
559    /// Returns inner driver-specific implementation.
560    pub fn as_raw_ptr(&self) -> &dyn syscall::DriverUdpSocket {
561        &**self.0
562    }
563
564    /// Create new udp socket which will be bound to the specified `laddrs`
565    pub async fn bind<L: ToSocketAddrs>(laddrs: L) -> Result<Self> {
566        Self::bind_with(laddrs, get_network_driver()).await
567    }
568
569    pub async fn bind_with<L: ToSocketAddrs>(
570        laddrs: L,
571        driver: &dyn syscall::Driver,
572    ) -> Result<Self> {
573        let laddrs = laddrs.to_socket_addrs()?.collect::<Vec<_>>();
574        driver.udp_bind(&laddrs)
575    }
576
577    #[cfg(unix)]
578    pub unsafe fn from_raw_fd_with(
579        fd: std::os::fd::RawFd,
580        driver: &dyn syscall::Driver,
581    ) -> Result<Self> {
582        driver.udp_from_raw_fd(fd)
583    }
584
585    #[cfg(unix)]
586    pub unsafe fn from_raw_fd(fd: std::os::fd::RawFd) -> Result<Self> {
587        Self::from_raw_fd_with(fd, get_network_driver())
588    }
589
590    #[cfg(windows)]
591    pub unsafe fn from_raw_socket_with(
592        fd: std::os::windows::io::RawSocket,
593        driver: &dyn syscall::Driver,
594    ) -> Result<Self> {
595        driver.udp_from_raw_socket(fd)
596    }
597
598    #[cfg(windows)]
599    pub unsafe fn from_raw_socket(fd: std::os::windows::io::RawSocket) -> Result<Self> {
600        Self::from_raw_socket_with(fd, get_network_driver())
601    }
602
603    /// Sends data on the socket to the given `target` address.
604    ///
605    /// On success, returns the number of bytes written.
606    pub async fn send_to<Buf: AsRef<[u8]>, A: ToSocketAddrs>(
607        &self,
608        buf: Buf,
609        target: A,
610    ) -> Result<usize> {
611        let mut last_error = None;
612
613        let buf = buf.as_ref();
614
615        for raddr in target.to_socket_addrs()? {
616            match poll_fn(|cx| self.poll_send_to(cx, buf, raddr)).await {
617                Ok(send_size) => return Ok(send_size),
618                Err(err) => {
619                    last_error = Some(err);
620                }
621            }
622        }
623
624        Err(last_error.unwrap())
625    }
626
627    /// Receives data from the socket.
628    ///
629    /// On success, returns the number of bytes read and the origin.
630    pub async fn recv_from(&self, buf: &mut [u8]) -> Result<(usize, SocketAddr)> {
631        poll_fn(|cx| self.poll_recv_from(cx, buf)).await
632    }
633}
634
635/// Unix-specific sockets implementation.
636#[cfg(unix)]
637#[cfg_attr(docsrs, doc(cfg(unix)))]
638pub mod unix {
639
640    use super::*;
641    use std::path::Path;
642
643    use super::syscall::unix::*;
644
645    /// A unix domain server-side socket.
646    pub struct UnixListener(Box<dyn DriverUnixListener>);
647
648    impl<T: DriverUnixListener + 'static> From<T> for UnixListener {
649        fn from(value: T) -> Self {
650            Self(Box::new(value))
651        }
652    }
653
654    impl Deref for UnixListener {
655        type Target = dyn DriverUnixListener;
656        fn deref(&self) -> &Self::Target {
657            &*self.0
658        }
659    }
660
661    impl UnixListener {
662        /// Returns inner driver-specific implementation.
663        pub fn as_raw_ptr(&self) -> &dyn DriverUnixListener {
664            &*self.0
665        }
666
667        /// See [`poll_next`](syscall::unix::DriverUnixListener::poll_next) for more information.
668        pub async fn accept(&self) -> Result<(UnixStream, std::os::unix::net::SocketAddr)> {
669            poll_fn(|cx| self.poll_next(cx)).await
670        }
671
672        /// Create new `TcpListener` which will be bound to the specified `laddrs`
673        pub async fn bind<P: AsRef<Path>>(path: P) -> Result<Self> {
674            Self::bind_with(path, get_network_driver()).await
675        }
676
677        /// Use custom `NetworkDriver` to create new `UnixListener` which will be bound to the specified `laddrs`.
678        pub async fn bind_with<P: AsRef<Path>>(
679            path: P,
680            driver: &dyn syscall::Driver,
681        ) -> Result<Self> {
682            driver.unix_listen(path.as_ref())
683        }
684    }
685
686    impl Stream for UnixListener {
687        type Item = Result<UnixStream>;
688
689        fn poll_next(
690            self: std::pin::Pin<&mut Self>,
691            cx: &mut Context<'_>,
692        ) -> Poll<Option<Self::Item>> {
693            match self.as_raw_ptr().poll_next(cx) {
694                Poll::Ready(Ok((stream, _))) => Poll::Ready(Some(Ok(stream))),
695                Poll::Ready(Err(err)) => {
696                    if err.kind() == ErrorKind::BrokenPipe {
697                        Poll::Ready(None)
698                    } else {
699                        Poll::Ready(Some(Err(err)))
700                    }
701                }
702                Poll::Pending => Poll::Pending,
703            }
704        }
705    }
706
707    /// A unix domain stream between a local and a remote socket.
708    #[derive(Clone)]
709    pub struct UnixStream(Arc<Box<dyn DriverUnixStream>>);
710
711    impl<T: DriverUnixStream + 'static> From<T> for UnixStream {
712        fn from(value: T) -> Self {
713            Self(Arc::new(Box::new(value)))
714        }
715    }
716
717    impl Deref for UnixStream {
718        type Target = dyn DriverUnixStream;
719        fn deref(&self) -> &Self::Target {
720            &**self.0
721        }
722    }
723
724    impl UnixStream {
725        /// Returns inner driver-specific implementation.
726        pub fn as_raw_ptr(&self) -> &dyn DriverUnixStream {
727            &**self.0
728        }
729
730        /// Connect to peer by provided `raddrs`.
731        pub async fn connect<P: AsRef<Path>>(path: P) -> Result<Self> {
732            Self::connect_with(path, get_network_driver()).await
733        }
734
735        /// Use custom `NetworkDriver` to connect to peer by provided `raddrs`.
736        pub async fn connect_with<P: AsRef<Path>>(
737            path: P,
738            driver: &dyn syscall::Driver,
739        ) -> Result<Self> {
740            let stream = driver.unix_connect(path.as_ref())?;
741
742            // Wait for the asynchronously connecting to complete
743            poll_fn(|cx| stream.poll_ready(cx)).await?;
744
745            Ok(stream)
746        }
747    }
748
749    impl AsyncRead for UnixStream {
750        fn poll_read(
751            self: std::pin::Pin<&mut Self>,
752            cx: &mut Context<'_>,
753            buf: &mut [u8],
754        ) -> Poll<Result<usize>> {
755            self.as_raw_ptr().poll_read(cx, buf)
756        }
757    }
758
759    impl AsyncWrite for UnixStream {
760        fn poll_write(
761            self: std::pin::Pin<&mut Self>,
762            cx: &mut Context<'_>,
763            buf: &[u8],
764        ) -> Poll<Result<usize>> {
765            self.as_raw_ptr().poll_write(cx, buf)
766        }
767
768        fn poll_flush(self: std::pin::Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
769            Poll::Ready(Ok(()))
770        }
771
772        fn poll_close(self: std::pin::Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
773            self.shutdown(Shutdown::Both)?;
774
775            Poll::Ready(Ok(()))
776        }
777    }
778}
779
780static NETWORK_DRIVER: OnceLock<Box<dyn syscall::Driver>> = OnceLock::new();
781
782/// Get global register `NetworkDriver` instance.
783pub fn get_network_driver() -> &'static dyn syscall::Driver {
784    NETWORK_DRIVER
785        .get()
786        .expect("Call register_network_driver first.")
787        .as_ref()
788}
789
790/// Register provided [`syscall::Driver`] as global network implementation.
791///
792/// # Panic
793///
794/// Multiple calls to this function are not permitted!!!
795pub fn register_network_driver<E: syscall::Driver + 'static>(driver: E) {
796    if NETWORK_DRIVER.set(Box::new(driver)).is_err() {
797        panic!("Multiple calls to register_global_network are not permitted!!!");
798    }
799}