Skip to main content

gstthreadshare/runtime/executor/
async_wrapper.rs

1// SPDX-License-Identifier: MIT OR Apache-2.0
2// This is based on https://github.com/smol-rs/async-io
3// with adaptations by:
4//
5// Copyright (C) 2021 François Laignel <fengalin@free.fr>
6
7use futures::io::{AsyncRead, AsyncWrite};
8use futures::stream::{self, Stream};
9use futures::{future, pin_mut, ready};
10
11use std::future::Future;
12use std::io::{self, IoSlice, IoSliceMut, Read, Write};
13use std::net::{SocketAddr, TcpListener, TcpStream, UdpSocket};
14use std::pin::Pin;
15use std::sync::Arc;
16use std::task::{Context, Poll};
17
18#[cfg(unix)]
19use std::{
20    os::unix::io::{AsFd, AsRawFd, BorrowedFd, OwnedFd, RawFd},
21    os::unix::net::{SocketAddr as UnixSocketAddr, UnixDatagram, UnixListener, UnixStream},
22    path::Path,
23};
24
25#[cfg(windows)]
26use std::os::windows::io::{AsRawSocket, AsSocket, BorrowedSocket, OwnedSocket, RawSocket};
27
28use rustix::io as rio;
29use rustix::net as rn;
30use rustix::net::addr::SocketAddrArg;
31
32use crate::runtime::RUNTIME_CAT;
33
34use super::scheduler;
35use super::{Reactor, Readable, ReadableOwned, Registration, Source, Writable, WritableOwned};
36
37/// Async adapter for I/O types.
38///
39/// This type puts an I/O handle into non-blocking mode, registers it in
40/// [epoll]/[kqueue]/[event ports]/[wepoll], and then provides an async interface for it.
41///
42/// [epoll]: https://en.wikipedia.org/wiki/Epoll
43/// [kqueue]: https://en.wikipedia.org/wiki/Kqueue
44/// [event ports]: https://illumos.org/man/port_create
45/// [wepoll]: https://github.com/piscisaureus/wepoll
46///
47/// # Caveats
48///
49/// The [`Async`] implementation is specific to the threadshare implementation.
50/// Neither [`async-net`] nor [`async-process`] (on Unix) can be used.
51///
52/// [`async-net`]: https://github.com/smol-rs/async-net
53/// [`async-process`]: https://github.com/smol-rs/async-process
54///
55/// ### Supported types
56///
57/// [`Async`] supports all networking types, as well as some OS-specific file descriptors like
58/// [timerfd] and [inotify].
59///
60/// However, do not use [`Async`] with types like [`File`][`std::fs::File`],
61/// [`Stdin`][`std::io::Stdin`], [`Stdout`][`std::io::Stdout`], or [`Stderr`][`std::io::Stderr`]
62/// because all operating systems have issues with them when put in non-blocking mode.
63///
64/// [timerfd]: https://github.com/smol-rs/async-io/blob/master/examples/linux-timerfd.rs
65/// [inotify]: https://github.com/smol-rs/async-io/blob/master/examples/linux-inotify.rs
66///
67/// ### Concurrent I/O
68///
69/// Note that [`&Async<T>`][`Async`] implements [`AsyncRead`] and [`AsyncWrite`] if `&T`
70/// implements those traits, which means tasks can concurrently read and write using shared
71/// references.
72///
73/// But there is a catch: only one task can read a time, and only one task can write at a time. It
74/// is okay to have two tasks where one is reading and the other is writing at the same time, but
75/// it is not okay to have two tasks reading at the same time or writing at the same time. If you
76/// try to do that, conflicting tasks will just keep waking each other in turn, thus wasting CPU
77/// time.
78///
79/// Besides [`AsyncRead`] and [`AsyncWrite`], this caveat also applies to
80/// [`poll_readable()`][`Async::poll_readable()`] and
81/// [`poll_writable()`][`Async::poll_writable()`].
82///
83/// However, any number of tasks can be concurrently calling other methods like
84/// [`readable()`][`Async::readable()`] or [`read_with()`][`Async::read_with()`].
85///
86/// ### Closing
87///
88/// Closing the write side of [`Async`] with [`close()`][`futures::AsyncWriteExt::close()`]
89/// simply flushes. If you want to shutdown a TCP or Unix socket, use
90/// [`Shutdown`][`std::net::Shutdown`].
91///
92#[derive(Debug)]
93pub struct Async<T: Send + 'static> {
94    /// A source registered in the reactor.
95    pub(super) source: Arc<Source>,
96
97    /// The inner I/O handle.
98    pub(super) io: Option<T>,
99
100    // The [`ThrottlingHandle`] on the [`scheduler::Throttling`] on which this Async wrapper is registered.
101    pub(super) throttling_sched_hdl: Option<scheduler::ThrottlingHandleWeak>,
102}
103
104impl<T: Send + 'static> Unpin for Async<T> {}
105
106#[cfg(unix)]
107impl<T: AsFd + Send + 'static> Async<T> {
108    /// Creates an async I/O handle.
109    ///
110    /// This method will put the handle in non-blocking mode and register it in
111    /// [epoll]/[kqueue]/[event ports]/[IOCP].
112    ///
113    /// On Unix systems, the handle must implement `AsFd`, while on Windows it must implement
114    /// `AsSocket`.
115    ///
116    /// [epoll]: https://en.wikipedia.org/wiki/Epoll
117    /// [kqueue]: https://en.wikipedia.org/wiki/Kqueue
118    /// [event ports]: https://illumos.org/man/port_create
119    /// [IOCP]: https://learn.microsoft.com/en-us/windows/win32/fileio/i-o-completion-ports
120    pub fn new(io: T) -> io::Result<Async<T>> {
121        // Put the file descriptor in non-blocking mode.
122        set_nonblocking(io.as_fd())?;
123
124        Self::new_nonblocking(io)
125    }
126
127    /// Creates an async I/O handle without setting it to non-blocking mode.
128    ///
129    /// This method will register the handle in [epoll]/[kqueue]/[event ports]/[IOCP].
130    ///
131    /// On Unix systems, the handle must implement `AsFd`, while on Windows it must implement
132    /// `AsSocket`.
133    ///
134    /// [epoll]: https://en.wikipedia.org/wiki/Epoll
135    /// [kqueue]: https://en.wikipedia.org/wiki/Kqueue
136    /// [event ports]: https://illumos.org/man/port_create
137    /// [IOCP]: https://learn.microsoft.com/en-us/windows/win32/fileio/i-o-completion-ports
138    ///
139    /// # Caveats
140    ///
141    /// The caller should ensure that the handle is set to non-blocking mode or that it is okay if
142    /// it is not set. If not set to non-blocking mode, I/O operations may block the current thread
143    /// and cause a deadlock in an asynchronous context.
144    pub fn new_nonblocking(io: T) -> io::Result<Async<T>> {
145        // SAFETY: It is impossible to drop the I/O source while it is registered through
146        // this type.
147        let registration = unsafe { Registration::new(io.as_fd()) };
148
149        let source = Reactor::with_mut(|reactor| reactor.insert_io(registration))?;
150        Ok(Async {
151            source,
152            io: Some(io),
153            throttling_sched_hdl: scheduler::Throttling::current()
154                .as_ref()
155                .map(scheduler::ThrottlingHandle::downgrade),
156        })
157    }
158}
159
160#[cfg(unix)]
161impl<T: AsRawFd + Send + 'static> AsRawFd for Async<T> {
162    fn as_raw_fd(&self) -> RawFd {
163        self.get_ref().as_raw_fd()
164    }
165}
166
167#[cfg(unix)]
168impl<T: AsFd + Send + 'static> AsFd for Async<T> {
169    fn as_fd(&self) -> BorrowedFd<'_> {
170        self.get_ref().as_fd()
171    }
172}
173
174#[cfg(unix)]
175impl<T: AsFd + From<OwnedFd> + Send + 'static> TryFrom<OwnedFd> for Async<T> {
176    type Error = io::Error;
177
178    fn try_from(value: OwnedFd) -> Result<Self, Self::Error> {
179        Async::new(value.into())
180    }
181}
182
183#[cfg(unix)]
184impl<T: Into<OwnedFd> + Send + 'static> TryFrom<Async<T>> for OwnedFd {
185    type Error = io::Error;
186
187    fn try_from(value: Async<T>) -> Result<Self, Self::Error> {
188        value.into_inner().map(Into::into)
189    }
190}
191
#[cfg(windows)]
impl<T: AsSocket + Send + 'static> Async<T> {
    /// Creates an async I/O handle.
    ///
    /// The handle is first switched to non-blocking mode, then registered in
    /// [epoll]/[kqueue]/[event ports]/[IOCP].
    ///
    /// The handle must implement `AsFd` on Unix systems, and `AsSocket` on Windows.
    ///
    /// [epoll]: https://en.wikipedia.org/wiki/Epoll
    /// [kqueue]: https://en.wikipedia.org/wiki/Kqueue
    /// [event ports]: https://illumos.org/man/port_create
    /// [IOCP]: https://learn.microsoft.com/en-us/windows/win32/fileio/i-o-completion-ports
    pub fn new(io: T) -> io::Result<Async<T>> {
        // Readiness-based I/O only makes sense on a non-blocking socket.
        let socket = io.as_socket();
        set_nonblocking(socket)?;

        Self::new_nonblocking(io)
    }

    /// Creates an async I/O handle, leaving its blocking mode untouched.
    ///
    /// The handle is only registered in [epoll]/[kqueue]/[event ports]/[IOCP].
    ///
    /// The handle must implement `AsFd` on Unix systems, and `AsSocket` on Windows.
    ///
    /// [epoll]: https://en.wikipedia.org/wiki/Epoll
    /// [kqueue]: https://en.wikipedia.org/wiki/Kqueue
    /// [event ports]: https://illumos.org/man/port_create
    /// [IOCP]: https://learn.microsoft.com/en-us/windows/win32/fileio/i-o-completion-ports
    ///
    /// # Caveats
    ///
    /// It is up to the caller to make sure the handle is already non-blocking, or that a
    /// blocking handle is acceptable. With a blocking handle, I/O operations may block the
    /// current thread and deadlock an asynchronous context.
    pub fn new_nonblocking(io: T) -> io::Result<Async<T>> {
        // SAFETY: the socket belongs to `io`, which is moved into the returned `Async`;
        // the I/O source cannot be dropped while it is registered through this type.
        let registration = unsafe { Registration::new(io.as_socket()) };
        let source = Reactor::with_mut(|reactor| reactor.insert_io(registration))?;

        // Remember the throttling scheduler active at creation time (if any).
        let throttling_sched_hdl = scheduler::Throttling::current()
            .as_ref()
            .map(scheduler::ThrottlingHandle::downgrade);

        Ok(Async {
            source,
            io: Some(io),
            throttling_sched_hdl,
        })
    }
}
247
#[cfg(windows)]
impl<T: AsRawSocket + Send + 'static> AsRawSocket for Async<T> {
    /// Forwards to the raw socket of the inner handle.
    fn as_raw_socket(&self) -> RawSocket {
        <T as AsRawSocket>::as_raw_socket(self.get_ref())
    }
}
254
#[cfg(windows)]
impl<T: AsSocket + Send + 'static> AsSocket for Async<T> {
    /// Borrows the socket of the inner handle.
    fn as_socket(&self) -> BorrowedSocket<'_> {
        <T as AsSocket>::as_socket(self.get_ref())
    }
}
261
#[cfg(windows)]
impl<T: AsSocket + From<OwnedSocket> + Send + 'static> TryFrom<OwnedSocket> for Async<T> {
    type Error = io::Error;

    /// Builds `T` from the owned socket, then wraps it with [`Async::new`].
    fn try_from(value: OwnedSocket) -> Result<Self, Self::Error> {
        let inner = T::from(value);
        Self::new(inner)
    }
}
270
#[cfg(windows)]
impl<T: Into<OwnedSocket> + Send + 'static> TryFrom<Async<T>> for OwnedSocket {
    type Error = io::Error;

    /// Deregisters the wrapper via `into_inner` and converts the inner handle to an
    /// `OwnedSocket`.
    fn try_from(value: Async<T>) -> Result<Self, Self::Error> {
        let inner = value.into_inner()?;
        Ok(inner.into())
    }
}
279
280impl<T: Send + 'static> Async<T> {
281    /// Gets a reference to the inner I/O handle.
282    pub fn get_ref(&self) -> &T {
283        self.io.as_ref().unwrap()
284    }
285
286    /// Gets a mutable reference to the inner I/O handle.
287    ///
288    /// # Safety
289    ///
290    /// The underlying I/O source must not be dropped using this function.
291    pub unsafe fn get_mut(&mut self) -> &mut T {
292        self.io.as_mut().unwrap()
293    }
294
295    /// Unwraps the inner I/O handle.
296    pub fn into_inner(mut self) -> io::Result<T> {
297        let io = self.io.take().unwrap();
298        Reactor::with_mut(|reactor| reactor.remove_io(&self.source))?;
299        Ok(io)
300    }
301
302    /// Waits until the I/O handle is readable.
303    ///
304    /// This method completes when a read operation on this I/O handle wouldn't block.
305    pub fn readable(&self) -> Readable<'_, T> {
306        Source::readable(self)
307    }
308
309    /// Waits until the I/O handle is readable.
310    ///
311    /// This method completes when a read operation on this I/O handle wouldn't block.
312    pub fn readable_owned(self: Arc<Self>) -> ReadableOwned<T> {
313        Source::readable_owned(self)
314    }
315
316    /// Waits until the I/O handle is writable.
317    ///
318    /// This method completes when a write operation on this I/O handle wouldn't block.
319    pub fn writable(&self) -> Writable<'_, T> {
320        Source::writable(self)
321    }
322
323    /// Waits until the I/O handle is writable.
324    ///
325    /// This method completes when a write operation on this I/O handle wouldn't block.
326    pub fn writable_owned(self: Arc<Self>) -> WritableOwned<T> {
327        Source::writable_owned(self)
328    }
329
330    /// Polls the I/O handle for readability.
331    ///
332    /// When this method returns [`Poll::Ready`], that means the OS has delivered an event
333    /// indicating readability since the last time this task has called the method and received
334    /// [`Poll::Pending`].
335    ///
336    /// # Caveats
337    ///
338    /// Two different tasks should not call this method concurrently. Otherwise, conflicting tasks
339    /// will just keep waking each other in turn, thus wasting CPU time.
340    ///
341    /// Note that the [`AsyncRead`] implementation for [`Async`] also uses this method.
342    pub fn poll_readable(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
343        self.source.poll_readable(cx)
344    }
345
346    /// Polls the I/O handle for writability.
347    ///
348    /// When this method returns [`Poll::Ready`], that means the OS has delivered an event
349    /// indicating writability since the last time this task has called the method and received
350    /// [`Poll::Pending`].
351    ///
352    /// # Caveats
353    ///
354    /// Two different tasks should not call this method concurrently. Otherwise, conflicting tasks
355    /// will just keep waking each other in turn, thus wasting CPU time.
356    ///
357    /// Note that the [`AsyncWrite`] implementation for [`Async`] also uses this method.
358    pub fn poll_writable(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
359        self.source.poll_writable(cx)
360    }
361
362    /// Performs a read operation asynchronously.
363    ///
364    /// The I/O handle is registered in the reactor and put in non-blocking mode. This method
365    /// invokes the `op` closure in a loop until it succeeds or returns an error other than
366    /// [`io::ErrorKind::WouldBlock`]. In between iterations of the loop, it waits until the OS
367    /// sends a notification that the I/O handle is readable.
368    ///
369    /// The closure receives a shared reference to the I/O handle.
370    pub async fn read_with<R>(&self, op: impl FnMut(&T) -> io::Result<R>) -> io::Result<R> {
371        let mut op = op;
372        loop {
373            match op(self.get_ref()) {
374                Err(err) if err.kind() == io::ErrorKind::WouldBlock => {}
375                res => return res,
376            }
377            optimistic(self.readable()).await?;
378        }
379    }
380
381    /// Performs a read operation asynchronously.
382    ///
383    /// The I/O handle is registered in the reactor and put in non-blocking mode. This method
384    /// invokes the `op` closure in a loop until it succeeds or returns an error other than
385    /// [`io::ErrorKind::WouldBlock`]. In between iterations of the loop, it waits until the OS
386    /// sends a notification that the I/O handle is readable.
387    ///
388    /// The closure receives a mutable reference to the I/O handle.
389    ///
390    /// # Safety
391    ///
392    /// In the closure, the underlying I/O source must not be dropped.
393    pub async unsafe fn read_with_mut<R>(
394        &mut self,
395        op: impl FnMut(&mut T) -> io::Result<R>,
396    ) -> io::Result<R> {
397        unsafe {
398            let mut op = op;
399            loop {
400                match op(self.get_mut()) {
401                    Err(err) if err.kind() == io::ErrorKind::WouldBlock => {}
402                    res => return res,
403                }
404                optimistic(self.readable()).await?;
405            }
406        }
407    }
408
409    /// Performs a write operation asynchronously.
410    ///
411    /// The I/O handle is registered in the reactor and put in non-blocking mode. This method
412    /// invokes the `op` closure in a loop until it succeeds or returns an error other than
413    /// [`io::ErrorKind::WouldBlock`]. In between iterations of the loop, it waits until the OS
414    /// sends a notification that the I/O handle is writable.
415    ///
416    /// The closure receives a shared reference to the I/O handle.
417    pub async fn write_with<R>(&self, op: impl FnMut(&T) -> io::Result<R>) -> io::Result<R> {
418        let mut op = op;
419        loop {
420            match op(self.get_ref()) {
421                Err(err) if err.kind() == io::ErrorKind::WouldBlock => {}
422                res => return res,
423            }
424            optimistic(self.writable()).await?;
425        }
426    }
427
428    /// Performs a write operation asynchronously.
429    ///
430    /// The I/O handle is registered in the reactor and put in non-blocking mode. This method
431    /// invokes the `op` closure in a loop until it succeeds or returns an error other than
432    /// [`io::ErrorKind::WouldBlock`]. In between iterations of the loop, it waits until the OS
433    /// sends a notification that the I/O handle is writable.
434    ///
435    /// The closure receives a mutable reference to the I/O handle.
436    ///
437    /// # Safety
438    ///
439    /// The closure receives a mutable reference to the I/O handle. In the closure, the underlying
440    /// I/O source must not be dropped.
441    pub async unsafe fn write_with_mut<R>(
442        &mut self,
443        op: impl FnMut(&mut T) -> io::Result<R>,
444    ) -> io::Result<R> {
445        unsafe {
446            let mut op = op;
447            loop {
448                match op(self.get_mut()) {
449                    Err(err) if err.kind() == io::ErrorKind::WouldBlock => {}
450                    res => return res,
451                }
452                optimistic(self.writable()).await?;
453            }
454        }
455    }
456}
457
458impl<T: Send + 'static> AsRef<T> for Async<T> {
459    fn as_ref(&self) -> &T {
460        self.get_ref()
461    }
462}
463
464impl<T: Send + 'static> Drop for Async<T> {
465    fn drop(&mut self) {
466        if let Some(io) = self.io.take() {
467            match self.throttling_sched_hdl.take() {
468                Some(throttling_sched_hdl) => {
469                    if let Some(sched) = throttling_sched_hdl.upgrade() {
470                        let source = Arc::clone(&self.source);
471                        sched.spawn_and_unpark(async move {
472                            Reactor::with_mut(|reactor| {
473                                if let Err(err) = reactor.remove_io(&source) {
474                                    gst::error!(
475                                        RUNTIME_CAT,
476                                        "Failed to remove fd {:?}: {err}",
477                                        source.registration,
478                                    );
479                                }
480                            });
481                            drop(io);
482                        });
483                    }
484                }
485                _ => {
486                    Reactor::with_mut(|reactor| {
487                        if let Err(err) = reactor.remove_io(&self.source) {
488                            gst::error!(
489                                RUNTIME_CAT,
490                                "Failed to remove fd {:?}: {err}",
491                                self.source.registration,
492                            );
493                        }
494                    });
495                }
496            }
497        }
498    }
499}
500
/// Marker for types whose I/O trait implementations never drop the underlying I/O source.
///
/// An [`Async`] must never end up holding an invalidated resource. That could happen if the
/// inner `T` (a [`TcpStream`], a [`UnixListener`], ...) were moved out and dropped while the
/// [`Async`] is still alive. This is why every function handing out mutable access to the inner
/// type is `unsafe`: nothing guarantees that the source won't be dropped, leaving a dangling
/// handle behind.
///
/// The same concern applies to [`Read`] and [`Write`]: their methods take `&mut self`, so an
/// implementor could, in principle, move the source out while the method runs.
///
/// Implementing this trait is a promise that using the type's I/O traits won't destroy the
/// source. With that promise, [`Async`] can safely provide the `async` counterparts of these
/// traits, such as [`AsyncRead`] and [`AsyncWrite`].
///
/// # Safety
///
/// The I/O trait implementations of the type must not drop the underlying I/O source. This
/// covers [`Read`], [`Write`], [`Seek`] and [`BufRead`].
///
/// Implementations are provided for `libstd` types. Immutable references are covered as well:
/// no outstanding reference can be invalidated through a shared reference, even with interior
/// mutability. Rust's pinning model relies on a similar guarantee, which makes this approach
/// robust.
///
/// [`BufRead`]: https://doc.rust-lang.org/std/io/trait.BufRead.html
/// [`Read`]: https://doc.rust-lang.org/std/io/trait.Read.html
/// [`Seek`]: https://doc.rust-lang.org/std/io/trait.Seek.html
/// [`Write`]: https://doc.rust-lang.org/std/io/trait.Write.html
///
/// [`AsyncRead`]: https://docs.rs/futures-io/latest/futures_io/trait.AsyncRead.html
/// [`AsyncWrite`]: https://docs.rs/futures-io/latest/futures_io/trait.AsyncWrite.html
pub unsafe trait IoSafe {}
535
536/// Reference types can't be mutated.
537///
538/// The worst thing that can happen is that external state is used to change what kind of pointer
539/// `as_fd()` returns. For instance:
540///
541/// ```
542/// # #[cfg(unix)] {
543/// use std::cell::Cell;
544/// use std::net::TcpStream;
545/// use std::os::unix::io::{AsFd, BorrowedFd};
546///
547/// struct Bar {
548///     flag: Cell<bool>,
549///     a: TcpStream,
550///     b: TcpStream
551/// }
552///
553/// impl AsFd for Bar {
554///     fn as_fd(&self) -> BorrowedFd<'_> {
555///         if self.flag.replace(!self.flag.get()) {
556///             self.a.as_fd()
557///         } else {
558///             self.b.as_fd()
559///         }
560///     }
561/// }
562/// # }
563/// ```
564///
565/// We solve this problem by only calling `as_fd()` once to get the original source. Implementations
566/// like this are considered buggy (but not unsound) and are thus not really supported by `async-io`.
567unsafe impl<T: ?Sized> IoSafe for &T {}
568
569// Can be implemented on top of libstd types.
570unsafe impl IoSafe for std::fs::File {}
571unsafe impl IoSafe for std::io::Stderr {}
572unsafe impl IoSafe for std::io::Stdin {}
573unsafe impl IoSafe for std::io::Stdout {}
574unsafe impl IoSafe for std::io::StderrLock<'_> {}
575unsafe impl IoSafe for std::io::StdinLock<'_> {}
576unsafe impl IoSafe for std::io::StdoutLock<'_> {}
577unsafe impl IoSafe for std::net::TcpStream {}
578
579#[cfg(unix)]
580unsafe impl IoSafe for std::os::unix::net::UnixStream {}
581
582unsafe impl<T: IoSafe + Read> IoSafe for std::io::BufReader<T> {}
583unsafe impl<T: IoSafe + Write> IoSafe for std::io::BufWriter<T> {}
584unsafe impl<T: IoSafe + Write> IoSafe for std::io::LineWriter<T> {}
585unsafe impl<T: IoSafe + ?Sized> IoSafe for &mut T {}
586unsafe impl<T: IoSafe + ?Sized> IoSafe for Box<T> {}
587unsafe impl<T: Clone + IoSafe> IoSafe for std::borrow::Cow<'_, T> {}
588
589impl<T: IoSafe + Read + Send + 'static> AsyncRead for Async<T> {
590    fn poll_read(
591        mut self: Pin<&mut Self>,
592        cx: &mut Context<'_>,
593        buf: &mut [u8],
594    ) -> Poll<io::Result<usize>> {
595        loop {
596            match unsafe { (*self).get_mut() }.read(buf) {
597                Err(err) if err.kind() == io::ErrorKind::WouldBlock => {}
598                res => return Poll::Ready(res),
599            }
600            ready!(self.poll_readable(cx))?;
601        }
602    }
603
604    fn poll_read_vectored(
605        mut self: Pin<&mut Self>,
606        cx: &mut Context<'_>,
607        bufs: &mut [IoSliceMut<'_>],
608    ) -> Poll<io::Result<usize>> {
609        loop {
610            match unsafe { (*self).get_mut() }.read_vectored(bufs) {
611                Err(err) if err.kind() == io::ErrorKind::WouldBlock => {}
612                res => return Poll::Ready(res),
613            }
614            ready!(self.poll_readable(cx))?;
615        }
616    }
617}
618
619// Since this is through a reference, we can't mutate the inner I/O source.
620// Therefore this is safe!
621impl<T: Send + 'static> AsyncRead for &Async<T>
622where
623    for<'a> &'a T: Read,
624{
625    fn poll_read(
626        self: Pin<&mut Self>,
627        cx: &mut Context<'_>,
628        buf: &mut [u8],
629    ) -> Poll<io::Result<usize>> {
630        loop {
631            match (*self).get_ref().read(buf) {
632                Err(err) if err.kind() == io::ErrorKind::WouldBlock => {}
633                res => return Poll::Ready(res),
634            }
635            ready!(self.poll_readable(cx))?;
636        }
637    }
638
639    fn poll_read_vectored(
640        self: Pin<&mut Self>,
641        cx: &mut Context<'_>,
642        bufs: &mut [IoSliceMut<'_>],
643    ) -> Poll<io::Result<usize>> {
644        loop {
645            match (*self).get_ref().read_vectored(bufs) {
646                Err(err) if err.kind() == io::ErrorKind::WouldBlock => {}
647                res => return Poll::Ready(res),
648            }
649            ready!(self.poll_readable(cx))?;
650        }
651    }
652}
653
654impl<T: IoSafe + Write + Send + 'static> AsyncWrite for Async<T> {
655    fn poll_write(
656        mut self: Pin<&mut Self>,
657        cx: &mut Context<'_>,
658        buf: &[u8],
659    ) -> Poll<io::Result<usize>> {
660        loop {
661            match unsafe { (*self).get_mut() }.write(buf) {
662                Err(err) if err.kind() == io::ErrorKind::WouldBlock => {}
663                res => return Poll::Ready(res),
664            }
665            ready!(self.poll_writable(cx))?;
666        }
667    }
668
669    fn poll_write_vectored(
670        mut self: Pin<&mut Self>,
671        cx: &mut Context<'_>,
672        bufs: &[IoSlice<'_>],
673    ) -> Poll<io::Result<usize>> {
674        loop {
675            match unsafe { (*self).get_mut() }.write_vectored(bufs) {
676                Err(err) if err.kind() == io::ErrorKind::WouldBlock => {}
677                res => return Poll::Ready(res),
678            }
679            ready!(self.poll_writable(cx))?;
680        }
681    }
682
683    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
684        loop {
685            match unsafe { (*self).get_mut() }.flush() {
686                Err(err) if err.kind() == io::ErrorKind::WouldBlock => {}
687                res => return Poll::Ready(res),
688            }
689            ready!(self.poll_writable(cx))?;
690        }
691    }
692
693    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
694        self.poll_flush(cx)
695    }
696}
697
698impl<T: Send + 'static> AsyncWrite for &Async<T>
699where
700    for<'a> &'a T: Write,
701{
702    fn poll_write(
703        self: Pin<&mut Self>,
704        cx: &mut Context<'_>,
705        buf: &[u8],
706    ) -> Poll<io::Result<usize>> {
707        loop {
708            match (*self).get_ref().write(buf) {
709                Err(err) if err.kind() == io::ErrorKind::WouldBlock => {}
710                res => return Poll::Ready(res),
711            }
712            ready!(self.poll_writable(cx))?;
713        }
714    }
715
716    fn poll_write_vectored(
717        self: Pin<&mut Self>,
718        cx: &mut Context<'_>,
719        bufs: &[IoSlice<'_>],
720    ) -> Poll<io::Result<usize>> {
721        loop {
722            match (*self).get_ref().write_vectored(bufs) {
723                Err(err) if err.kind() == io::ErrorKind::WouldBlock => {}
724                res => return Poll::Ready(res),
725            }
726            ready!(self.poll_writable(cx))?;
727        }
728    }
729
730    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
731        loop {
732            match (*self).get_ref().flush() {
733                Err(err) if err.kind() == io::ErrorKind::WouldBlock => {}
734                res => return Poll::Ready(res),
735            }
736            ready!(self.poll_writable(cx))?;
737        }
738    }
739
740    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
741        self.poll_flush(cx)
742    }
743}
744
745impl Async<TcpListener> {
746    /// Creates a TCP listener bound to the specified address.
747    ///
748    /// Binding with port number 0 will request an available port from the OS.
749    pub fn bind<A: Into<SocketAddr>>(addr: A) -> io::Result<Async<TcpListener>> {
750        let addr = addr.into();
751        Async::new(TcpListener::bind(addr)?)
752    }
753
754    /// Accepts a new incoming TCP connection.
755    ///
756    /// When a connection is established, it will be returned as a TCP stream together with its
757    /// remote address.
758    pub async fn accept(&self) -> io::Result<(Async<TcpStream>, SocketAddr)> {
759        let (stream, addr) = self.read_with(|io| io.accept()).await?;
760        Ok((Async::new(stream)?, addr))
761    }
762
763    /// Returns a stream of incoming TCP connections.
764    ///
765    /// The stream is infinite, i.e. it never stops with a [`None`].
766    pub fn incoming(&self) -> impl Stream<Item = io::Result<Async<TcpStream>>> + Send + '_ {
767        stream::unfold(self, |listener| async move {
768            let res = listener.accept().await.map(|(stream, _)| stream);
769            Some((res, listener))
770        })
771    }
772}
773
774impl TryFrom<std::net::TcpListener> for Async<std::net::TcpListener> {
775    type Error = io::Error;
776
777    fn try_from(listener: std::net::TcpListener) -> io::Result<Self> {
778        Async::new(listener)
779    }
780}
781
782impl Async<TcpStream> {
783    /// Creates a TCP connection to the specified address.
784    pub async fn connect<A: Into<SocketAddr>>(addr: A) -> io::Result<Async<TcpStream>> {
785        // Figure out how to handle this address.
786        let addr = addr.into();
787        let (domain, sock_addr) = match addr {
788            SocketAddr::V4(v4) => (rn::AddressFamily::INET, v4.as_any()),
789            SocketAddr::V6(v6) => (rn::AddressFamily::INET6, v6.as_any()),
790        };
791
792        // Begin async connect.
793        let socket = connect(sock_addr, domain, Some(rn::ipproto::TCP))?;
794        // Use new_nonblocking because connect already sets socket to non-blocking mode.
795        let stream = Async::new_nonblocking(TcpStream::from(socket))?;
796
797        // The stream becomes writable when connected.
798        stream.writable().await?;
799
800        // Check if there was an error while connecting.
801        match stream.get_ref().take_error()? {
802            None => Ok(stream),
803            Some(err) => Err(err),
804        }
805    }
806
807    /// Reads data from the stream without removing it from the buffer.
808    ///
809    /// Returns the number of bytes read. Successive calls of this method read the same data.
810    pub async fn peek(&self, buf: &mut [u8]) -> io::Result<usize> {
811        self.read_with(|io| io.peek(buf)).await
812    }
813}
814
815impl TryFrom<std::net::TcpStream> for Async<std::net::TcpStream> {
816    type Error = io::Error;
817
818    fn try_from(stream: std::net::TcpStream) -> io::Result<Self> {
819        Async::new(stream)
820    }
821}
822
823impl Async<UdpSocket> {
824    /// Creates a UDP socket bound to the specified address.
825    ///
826    /// Binding with port number 0 will request an available port from the OS.
827    pub fn bind<A: Into<SocketAddr>>(addr: A) -> io::Result<Async<UdpSocket>> {
828        let addr = addr.into();
829        Async::new(UdpSocket::bind(addr)?)
830    }
831
832    /// Receives a single datagram message.
833    ///
834    /// Returns the number of bytes read and the address the message came from.
835    ///
836    /// This method must be called with a valid byte slice of sufficient size to hold the message.
837    /// If the message is too long to fit, excess bytes may get discarded.
838    pub async fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
839        self.read_with(|io| io.recv_from(buf)).await
840    }
841
842    /// Receives a single datagram message without removing it from the queue.
843    ///
844    /// Returns the number of bytes read and the address the message came from.
845    ///
846    /// This method must be called with a valid byte slice of sufficient size to hold the message.
847    /// If the message is too long to fit, excess bytes may get discarded.
848    pub async fn peek_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
849        self.read_with(|io| io.peek_from(buf)).await
850    }
851
852    /// Sends data to the specified address.
853    ///
854    /// Returns the number of bytes written.
855    pub async fn send_to<A: Into<SocketAddr>>(&self, buf: &[u8], addr: A) -> io::Result<usize> {
856        let addr = addr.into();
857        self.write_with(|io| io.send_to(buf, addr)).await
858    }
859
860    /// Receives a single datagram message from the connected peer.
861    ///
862    /// Returns the number of bytes read.
863    ///
864    /// This method must be called with a valid byte slice of sufficient size to hold the message.
865    /// If the message is too long to fit, excess bytes may get discarded.
866    ///
867    /// The [`connect`][`UdpSocket::connect()`] method connects this socket to a remote address.
868    /// This method will fail if the socket is not connected.
869    pub async fn recv(&self, buf: &mut [u8]) -> io::Result<usize> {
870        self.read_with(|io| io.recv(buf)).await
871    }
872
873    /// Receives a single datagram message from the connected peer without removing it from the
874    /// queue.
875    ///
876    /// Returns the number of bytes read and the address the message came from.
877    ///
878    /// This method must be called with a valid byte slice of sufficient size to hold the message.
879    /// If the message is too long to fit, excess bytes may get discarded.
880    ///
881    /// The [`connect`][`UdpSocket::connect()`] method connects this socket to a remote address.
882    /// This method will fail if the socket is not connected.
883    pub async fn peek(&self, buf: &mut [u8]) -> io::Result<usize> {
884        self.read_with(|io| io.peek(buf)).await
885    }
886
887    /// Sends data to the connected peer.
888    ///
889    /// Returns the number of bytes written.
890    ///
891    /// The [`connect`][`UdpSocket::connect()`] method connects this socket to a remote address.
892    /// This method will fail if the socket is not connected.
893    pub async fn send(&self, buf: &[u8]) -> io::Result<usize> {
894        self.write_with(|io| io.send(buf)).await
895    }
896}
897
898impl TryFrom<std::net::UdpSocket> for Async<std::net::UdpSocket> {
899    type Error = io::Error;
900
901    fn try_from(socket: std::net::UdpSocket) -> io::Result<Self> {
902        Async::new(socket)
903    }
904}
905
906impl TryFrom<socket2::Socket> for Async<std::net::UdpSocket> {
907    type Error = io::Error;
908
909    fn try_from(socket: socket2::Socket) -> io::Result<Self> {
910        Async::new(std::net::UdpSocket::from(socket))
911    }
912}
913
914#[cfg(unix)]
915impl Async<UnixListener> {
916    /// Creates a UDS listener bound to the specified path.
917    pub fn bind<P: AsRef<Path>>(path: P) -> io::Result<Async<UnixListener>> {
918        let path = path.as_ref().to_owned();
919        Async::new(UnixListener::bind(path)?)
920    }
921
922    /// Accepts a new incoming UDS stream connection.
923    pub async fn accept(&self) -> io::Result<(Async<UnixStream>, UnixSocketAddr)> {
924        let (stream, addr) = self.read_with(|io| io.accept()).await?;
925        Ok((Async::new(stream)?, addr))
926    }
927
928    /// Returns a stream of incoming UDS connections.
929    ///
930    /// The stream is infinite, i.e. it never stops with a [`None`] item.
931    pub fn incoming(&self) -> impl Stream<Item = io::Result<Async<UnixStream>>> + Send + '_ {
932        stream::unfold(self, |listener| async move {
933            let res = listener.accept().await.map(|(stream, _)| stream);
934            Some((res, listener))
935        })
936    }
937}
938
939#[cfg(unix)]
940impl TryFrom<std::os::unix::net::UnixListener> for Async<std::os::unix::net::UnixListener> {
941    type Error = io::Error;
942
943    fn try_from(listener: std::os::unix::net::UnixListener) -> io::Result<Self> {
944        Async::new(listener)
945    }
946}
947
948#[cfg(unix)]
949impl Async<UnixStream> {
950    /// Creates a UDS stream connected to the specified path.
951    pub async fn connect<P: AsRef<Path>>(path: P) -> io::Result<Async<UnixStream>> {
952        let address = convert_path_to_socket_address(path.as_ref())?;
953
954        // Begin async connect.
955        let socket = connect(address.into(), rn::AddressFamily::UNIX, None)?;
956        // Use new_nonblocking because connect already sets socket to non-blocking mode.
957        let stream = Async::new_nonblocking(UnixStream::from(socket))?;
958
959        // The stream becomes writable when connected.
960        stream.writable().await?;
961
962        // On Linux, it appears the socket may become writable even when connecting fails, so we
963        // must do an extra check here and see if the peer address is retrievable.
964        stream.get_ref().peer_addr()?;
965        Ok(stream)
966    }
967
968    /// Creates an unnamed pair of connected UDS stream sockets.
969    pub fn pair() -> io::Result<(Async<UnixStream>, Async<UnixStream>)> {
970        let (stream1, stream2) = UnixStream::pair()?;
971        Ok((Async::new(stream1)?, Async::new(stream2)?))
972    }
973}
974
975#[cfg(unix)]
976impl TryFrom<std::os::unix::net::UnixStream> for Async<std::os::unix::net::UnixStream> {
977    type Error = io::Error;
978
979    fn try_from(stream: std::os::unix::net::UnixStream) -> io::Result<Self> {
980        Async::new(stream)
981    }
982}
983
984#[cfg(unix)]
985impl Async<UnixDatagram> {
986    /// Creates a UDS datagram socket bound to the specified path.
987    pub fn bind<P: AsRef<Path>>(path: P) -> io::Result<Async<UnixDatagram>> {
988        let path = path.as_ref().to_owned();
989        Async::new(UnixDatagram::bind(path)?)
990    }
991
992    /// Creates a UDS datagram socket not bound to any address.
993    pub fn unbound() -> io::Result<Async<UnixDatagram>> {
994        Async::new(UnixDatagram::unbound()?)
995    }
996
997    /// Creates an unnamed pair of connected Unix datagram sockets.
998    pub fn pair() -> io::Result<(Async<UnixDatagram>, Async<UnixDatagram>)> {
999        let (socket1, socket2) = UnixDatagram::pair()?;
1000        Ok((Async::new(socket1)?, Async::new(socket2)?))
1001    }
1002
1003    /// Receives data from the socket.
1004    ///
1005    /// Returns the number of bytes read and the address the message came from.
1006    pub async fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, UnixSocketAddr)> {
1007        self.read_with(|io| io.recv_from(buf)).await
1008    }
1009
1010    /// Sends data to the specified address.
1011    ///
1012    /// Returns the number of bytes written.
1013    pub async fn send_to<P: AsRef<Path>>(&self, buf: &[u8], path: P) -> io::Result<usize> {
1014        self.write_with(|io| io.send_to(buf, &path)).await
1015    }
1016
1017    /// Receives data from the connected peer.
1018    ///
1019    /// Returns the number of bytes read and the address the message came from.
1020    ///
1021    /// The [`connect`][`UnixDatagram::connect()`] method connects this socket to a remote address.
1022    /// This method will fail if the socket is not connected.
1023    pub async fn recv(&self, buf: &mut [u8]) -> io::Result<usize> {
1024        self.read_with(|io| io.recv(buf)).await
1025    }
1026
1027    /// Sends data to the connected peer.
1028    ///
1029    /// Returns the number of bytes written.
1030    ///
1031    /// The [`connect`][`UnixDatagram::connect()`] method connects this socket to a remote address.
1032    /// This method will fail if the socket is not connected.
1033    pub async fn send(&self, buf: &[u8]) -> io::Result<usize> {
1034        self.write_with(|io| io.send(buf)).await
1035    }
1036}
1037
1038#[cfg(unix)]
1039impl TryFrom<std::os::unix::net::UnixDatagram> for Async<std::os::unix::net::UnixDatagram> {
1040    type Error = io::Error;
1041
1042    fn try_from(socket: std::os::unix::net::UnixDatagram) -> io::Result<Self> {
1043        Async::new(socket)
1044    }
1045}
1046
1047/// Polls a future once, waits for a wakeup, and then optimistically assumes the future is ready.
1048async fn optimistic(fut: impl Future<Output = io::Result<()>>) -> io::Result<()> {
1049    let mut polled = false;
1050    pin_mut!(fut);
1051
1052    future::poll_fn(|cx| {
1053        if !polled {
1054            polled = true;
1055            fut.as_mut().poll(cx)
1056        } else {
1057            Poll::Ready(Ok(()))
1058        }
1059    })
1060    .await
1061}
1062
1063fn connect(
1064    addr: rn::SocketAddrAny,
1065    domain: rn::AddressFamily,
1066    protocol: Option<rn::Protocol>,
1067) -> io::Result<rustix::fd::OwnedFd> {
1068    #[cfg(windows)]
1069    use rustix::fd::AsFd;
1070
1071    setup_networking();
1072
1073    #[cfg(any(
1074        target_os = "android",
1075        target_os = "dragonfly",
1076        target_os = "freebsd",
1077        target_os = "fuchsia",
1078        target_os = "illumos",
1079        target_os = "linux",
1080        target_os = "netbsd",
1081        target_os = "openbsd"
1082    ))]
1083    let socket = rn::socket_with(
1084        domain,
1085        rn::SocketType::STREAM,
1086        rn::SocketFlags::CLOEXEC | rn::SocketFlags::NONBLOCK,
1087        protocol,
1088    )?;
1089
1090    #[cfg(not(any(
1091        target_os = "android",
1092        target_os = "dragonfly",
1093        target_os = "freebsd",
1094        target_os = "fuchsia",
1095        target_os = "illumos",
1096        target_os = "linux",
1097        target_os = "netbsd",
1098        target_os = "openbsd"
1099    )))]
1100    let socket = {
1101        #[cfg(not(any(
1102            target_os = "aix",
1103            target_vendor = "apple",
1104            target_os = "espidf",
1105            windows,
1106        )))]
1107        let flags = rn::SocketFlags::CLOEXEC;
1108        #[cfg(any(
1109            target_os = "aix",
1110            target_vendor = "apple",
1111            target_os = "espidf",
1112            windows,
1113        ))]
1114        let flags = rn::SocketFlags::empty();
1115
1116        // Create the socket.
1117        let socket = rn::socket_with(domain, rn::SocketType::STREAM, flags, protocol)?;
1118
1119        // Set cloexec if necessary.
1120        #[cfg(any(target_os = "aix", target_vendor = "apple"))]
1121        rio::fcntl_setfd(&socket, rio::fcntl_getfd(&socket)? | rio::FdFlags::CLOEXEC)?;
1122
1123        // Set non-blocking mode.
1124        set_nonblocking(socket.as_fd())?;
1125
1126        socket
1127    };
1128
1129    // Set nosigpipe if necessary.
1130    #[cfg(any(
1131        target_vendor = "apple",
1132        target_os = "freebsd",
1133        target_os = "netbsd",
1134        target_os = "dragonfly",
1135    ))]
1136    rn::sockopt::set_socket_nosigpipe(&socket, true)?;
1137
1138    // Set the handle information to HANDLE_FLAG_INHERIT.
1139    #[cfg(windows)]
1140    unsafe {
1141        if windows_sys::Win32::Foundation::SetHandleInformation(
1142            socket.as_raw_socket() as _,
1143            windows_sys::Win32::Foundation::HANDLE_FLAG_INHERIT,
1144            windows_sys::Win32::Foundation::HANDLE_FLAG_INHERIT,
1145        ) == 0
1146        {
1147            return Err(io::Error::last_os_error());
1148        }
1149    }
1150
1151    #[allow(unreachable_patterns)]
1152    match rn::connect(&socket, &addr) {
1153        Ok(_) => {}
1154        #[cfg(unix)]
1155        Err(rio::Errno::INPROGRESS) => {}
1156        Err(rio::Errno::AGAIN) | Err(rio::Errno::WOULDBLOCK) => {}
1157        Err(err) => return Err(err.into()),
1158    }
1159    Ok(socket)
1160}
1161
/// Performs one-time, process-wide networking initialization where the platform needs it.
///
/// Only Windows requires anything: `WSAStartup` must run before any networking call, and a
/// [`std::sync::Once`] ensures it is attempted a single time. Elsewhere this is a no-op.
#[inline]
fn setup_networking() {
    #[cfg(windows)]
    {
        static WSA_STARTUP: std::sync::Once = std::sync::Once::new();

        // The startup result is deliberately ignored.
        WSA_STARTUP.call_once(|| {
            let _ = rustix::net::wsa_startup();
        });
    }
}
1175
1176#[inline]
1177fn set_nonblocking(
1178    #[cfg(unix)] fd: BorrowedFd<'_>,
1179    #[cfg(windows)] fd: BorrowedSocket<'_>,
1180) -> io::Result<()> {
1181    cfg_if::cfg_if! {
1182        // ioctl(FIONBIO) sets the flag atomically, but we use this only on Linux
1183        // for now, as with the standard library, because it seems to behave
1184        // differently depending on the platform.
1185        // https://github.com/rust-lang/rust/commit/efeb42be2837842d1beb47b51bb693c7474aba3d
1186        // https://github.com/libuv/libuv/blob/e9d91fccfc3e5ff772d5da90e1c4a24061198ca0/src/unix/poll.c#L78-L80
1187        // https://github.com/tokio-rs/mio/commit/0db49f6d5caf54b12176821363d154384357e70a
1188        if #[cfg(any(windows, target_os = "linux"))] {
1189            rustix::io::ioctl_fionbio(fd, true)?;
1190        } else {
1191            let previous = rustix::fs::fcntl_getfl(fd)?;
1192            let new = previous | rustix::fs::OFlags::NONBLOCK;
1193            if new != previous {
1194                rustix::fs::fcntl_setfl(fd, new)?;
1195            }
1196        }
1197    }
1198
1199    Ok(())
1200}
1201
1202/// Converts a `Path` to its socket address representation.
1203///
1204/// This function is abstract socket-aware.
1205#[cfg(unix)]
1206#[inline]
1207fn convert_path_to_socket_address(path: &Path) -> io::Result<rn::SocketAddrUnix> {
1208    // SocketAddrUnix::new() will throw EINVAL when a path with a zero in it is passed in.
1209    // However, some users expect to be able to pass in paths to abstract sockets, which
1210    // triggers this error as it has a zero in it. Therefore, if a path starts with a zero,
1211    // make it an abstract socket.
1212    #[cfg(any(target_os = "linux", target_os = "android"))]
1213    let address = {
1214        use std::os::unix::ffi::OsStrExt;
1215
1216        let path = path.as_os_str();
1217        match path.as_bytes().first() {
1218            Some(0) => rn::SocketAddrUnix::new_abstract_name(path.as_bytes().get(1..).unwrap())?,
1219            _ => rn::SocketAddrUnix::new(path)?,
1220        }
1221    };
1222
1223    // Only Linux and Android support abstract sockets.
1224    #[cfg(not(any(target_os = "linux", target_os = "android")))]
1225    let address = rn::SocketAddrUnix::new(path)?;
1226
1227    Ok(address)
1228}