gstthreadshare/runtime/executor/async_wrapper.rs
1// SPDX-License-Identifier: MIT OR Apache-2.0
2// This is based on https://github.com/smol-rs/async-io
3// with adaptations by:
4//
5// Copyright (C) 2021 François Laignel <fengalin@free.fr>
6
7use futures::io::{AsyncRead, AsyncWrite};
8use futures::stream::{self, Stream};
9use futures::{future, pin_mut, ready};
10
11use std::future::Future;
12use std::io::{self, IoSlice, IoSliceMut, Read, Write};
13use std::net::{SocketAddr, TcpListener, TcpStream, UdpSocket};
14use std::pin::Pin;
15use std::sync::Arc;
16use std::task::{Context, Poll};
17
18#[cfg(unix)]
19use std::{
20 os::unix::io::{AsFd, AsRawFd, BorrowedFd, OwnedFd, RawFd},
21 os::unix::net::{SocketAddr as UnixSocketAddr, UnixDatagram, UnixListener, UnixStream},
22 path::Path,
23};
24
25#[cfg(windows)]
26use std::os::windows::io::{AsRawSocket, AsSocket, BorrowedSocket, OwnedSocket, RawSocket};
27
28use rustix::io as rio;
29use rustix::net as rn;
30use rustix::net::addr::SocketAddrArg;
31
32use crate::runtime::RUNTIME_CAT;
33
34use super::scheduler;
35use super::{Reactor, Readable, ReadableOwned, Registration, Source, Writable, WritableOwned};
36
37/// Async adapter for I/O types.
38///
39/// This type puts an I/O handle into non-blocking mode, registers it in
40/// [epoll]/[kqueue]/[event ports]/[wepoll], and then provides an async interface for it.
41///
42/// [epoll]: https://en.wikipedia.org/wiki/Epoll
43/// [kqueue]: https://en.wikipedia.org/wiki/Kqueue
44/// [event ports]: https://illumos.org/man/port_create
45/// [wepoll]: https://github.com/piscisaureus/wepoll
46///
47/// # Caveats
48///
49/// The [`Async`] implementation is specific to the threadshare implementation.
50/// Neither [`async-net`] nor [`async-process`] (on Unix) can be used.
51///
52/// [`async-net`]: https://github.com/smol-rs/async-net
53/// [`async-process`]: https://github.com/smol-rs/async-process
54///
55/// ### Supported types
56///
57/// [`Async`] supports all networking types, as well as some OS-specific file descriptors like
58/// [timerfd] and [inotify].
59///
60/// However, do not use [`Async`] with types like [`File`][`std::fs::File`],
61/// [`Stdin`][`std::io::Stdin`], [`Stdout`][`std::io::Stdout`], or [`Stderr`][`std::io::Stderr`]
62/// because all operating systems have issues with them when put in non-blocking mode.
63///
64/// [timerfd]: https://github.com/smol-rs/async-io/blob/master/examples/linux-timerfd.rs
65/// [inotify]: https://github.com/smol-rs/async-io/blob/master/examples/linux-inotify.rs
66///
67/// ### Concurrent I/O
68///
69/// Note that [`&Async<T>`][`Async`] implements [`AsyncRead`] and [`AsyncWrite`] if `&T`
70/// implements those traits, which means tasks can concurrently read and write using shared
71/// references.
72///
73/// But there is a catch: only one task can read a time, and only one task can write at a time. It
74/// is okay to have two tasks where one is reading and the other is writing at the same time, but
75/// it is not okay to have two tasks reading at the same time or writing at the same time. If you
76/// try to do that, conflicting tasks will just keep waking each other in turn, thus wasting CPU
77/// time.
78///
79/// Besides [`AsyncRead`] and [`AsyncWrite`], this caveat also applies to
80/// [`poll_readable()`][`Async::poll_readable()`] and
81/// [`poll_writable()`][`Async::poll_writable()`].
82///
83/// However, any number of tasks can be concurrently calling other methods like
84/// [`readable()`][`Async::readable()`] or [`read_with()`][`Async::read_with()`].
85///
86/// ### Closing
87///
88/// Closing the write side of [`Async`] with [`close()`][`futures::AsyncWriteExt::close()`]
89/// simply flushes. If you want to shutdown a TCP or Unix socket, use
90/// [`Shutdown`][`std::net::Shutdown`].
91///
92#[derive(Debug)]
93pub struct Async<T: Send + 'static> {
94 /// A source registered in the reactor.
95 pub(super) source: Arc<Source>,
96
97 /// The inner I/O handle.
98 pub(super) io: Option<T>,
99
100 // The [`ThrottlingHandle`] on the [`scheduler::Throttling`] on which this Async wrapper is registered.
101 pub(super) throttling_sched_hdl: Option<scheduler::ThrottlingHandleWeak>,
102}
103
104impl<T: Send + 'static> Unpin for Async<T> {}
105
106#[cfg(unix)]
107impl<T: AsFd + Send + 'static> Async<T> {
108 /// Creates an async I/O handle.
109 ///
110 /// This method will put the handle in non-blocking mode and register it in
111 /// [epoll]/[kqueue]/[event ports]/[IOCP].
112 ///
113 /// On Unix systems, the handle must implement `AsFd`, while on Windows it must implement
114 /// `AsSocket`.
115 ///
116 /// [epoll]: https://en.wikipedia.org/wiki/Epoll
117 /// [kqueue]: https://en.wikipedia.org/wiki/Kqueue
118 /// [event ports]: https://illumos.org/man/port_create
119 /// [IOCP]: https://learn.microsoft.com/en-us/windows/win32/fileio/i-o-completion-ports
120 pub fn new(io: T) -> io::Result<Async<T>> {
121 // Put the file descriptor in non-blocking mode.
122 set_nonblocking(io.as_fd())?;
123
124 Self::new_nonblocking(io)
125 }
126
127 /// Creates an async I/O handle without setting it to non-blocking mode.
128 ///
129 /// This method will register the handle in [epoll]/[kqueue]/[event ports]/[IOCP].
130 ///
131 /// On Unix systems, the handle must implement `AsFd`, while on Windows it must implement
132 /// `AsSocket`.
133 ///
134 /// [epoll]: https://en.wikipedia.org/wiki/Epoll
135 /// [kqueue]: https://en.wikipedia.org/wiki/Kqueue
136 /// [event ports]: https://illumos.org/man/port_create
137 /// [IOCP]: https://learn.microsoft.com/en-us/windows/win32/fileio/i-o-completion-ports
138 ///
139 /// # Caveats
140 ///
141 /// The caller should ensure that the handle is set to non-blocking mode or that it is okay if
142 /// it is not set. If not set to non-blocking mode, I/O operations may block the current thread
143 /// and cause a deadlock in an asynchronous context.
144 pub fn new_nonblocking(io: T) -> io::Result<Async<T>> {
145 // SAFETY: It is impossible to drop the I/O source while it is registered through
146 // this type.
147 let registration = unsafe { Registration::new(io.as_fd()) };
148
149 let source = Reactor::with_mut(|reactor| reactor.insert_io(registration))?;
150 Ok(Async {
151 source,
152 io: Some(io),
153 throttling_sched_hdl: scheduler::Throttling::current()
154 .as_ref()
155 .map(scheduler::ThrottlingHandle::downgrade),
156 })
157 }
158}
159
160#[cfg(unix)]
161impl<T: AsRawFd + Send + 'static> AsRawFd for Async<T> {
162 fn as_raw_fd(&self) -> RawFd {
163 self.get_ref().as_raw_fd()
164 }
165}
166
167#[cfg(unix)]
168impl<T: AsFd + Send + 'static> AsFd for Async<T> {
169 fn as_fd(&self) -> BorrowedFd<'_> {
170 self.get_ref().as_fd()
171 }
172}
173
174#[cfg(unix)]
175impl<T: AsFd + From<OwnedFd> + Send + 'static> TryFrom<OwnedFd> for Async<T> {
176 type Error = io::Error;
177
178 fn try_from(value: OwnedFd) -> Result<Self, Self::Error> {
179 Async::new(value.into())
180 }
181}
182
183#[cfg(unix)]
184impl<T: Into<OwnedFd> + Send + 'static> TryFrom<Async<T>> for OwnedFd {
185 type Error = io::Error;
186
187 fn try_from(value: Async<T>) -> Result<Self, Self::Error> {
188 value.into_inner().map(Into::into)
189 }
190}
191
#[cfg(windows)]
impl<T: AsSocket + Send + 'static> Async<T> {
    /// Creates an async I/O handle.
    ///
    /// The socket is first switched to non-blocking mode, then registered in
    /// [epoll]/[kqueue]/[event ports]/[IOCP] through [`Async::new_nonblocking()`].
    ///
    /// On Unix systems, the handle must implement `AsFd`, while on Windows it must implement
    /// `AsSocket`.
    ///
    /// [epoll]: https://en.wikipedia.org/wiki/Epoll
    /// [kqueue]: https://en.wikipedia.org/wiki/Kqueue
    /// [event ports]: https://illumos.org/man/port_create
    /// [IOCP]: https://learn.microsoft.com/en-us/windows/win32/fileio/i-o-completion-ports
    pub fn new(io: T) -> io::Result<Async<T>> {
        set_nonblocking(io.as_socket())?;
        Async::new_nonblocking(io)
    }

    /// Creates an async I/O handle while leaving its blocking mode untouched.
    ///
    /// The handle is only registered in [epoll]/[kqueue]/[event ports]/[IOCP].
    ///
    /// On Unix systems, the handle must implement `AsFd`, while on Windows it must implement
    /// `AsSocket`.
    ///
    /// [epoll]: https://en.wikipedia.org/wiki/Epoll
    /// [kqueue]: https://en.wikipedia.org/wiki/Kqueue
    /// [event ports]: https://illumos.org/man/port_create
    /// [IOCP]: https://learn.microsoft.com/en-us/windows/win32/fileio/i-o-completion-ports
    ///
    /// # Caveats
    ///
    /// It is up to the caller to make sure the handle is already non-blocking, or that blocking
    /// is acceptable. A blocking handle may stall the current thread on I/O and deadlock an
    /// asynchronous context.
    pub fn new_nonblocking(io: T) -> io::Result<Async<T>> {
        // SAFETY: once registered through this type, the I/O source cannot be dropped while it
        // is still registered.
        let registration = unsafe { Registration::new(io.as_socket()) };
        let source = Reactor::with_mut(|reactor| reactor.insert_io(registration))?;

        // Remember the throttling scheduler we were created on (if any) for `Drop`.
        let throttling_sched_hdl = scheduler::Throttling::current()
            .as_ref()
            .map(scheduler::ThrottlingHandle::downgrade);

        Ok(Async {
            source,
            io: Some(io),
            throttling_sched_hdl,
        })
    }
}
247
#[cfg(windows)]
impl<T: AsRawSocket + Send + 'static> AsRawSocket for Async<T> {
    /// Returns the raw socket of the wrapped I/O handle.
    fn as_raw_socket(&self) -> RawSocket {
        T::as_raw_socket(self.get_ref())
    }
}
254
#[cfg(windows)]
impl<T: AsSocket + Send + 'static> AsSocket for Async<T> {
    /// Borrows the socket of the wrapped I/O handle.
    fn as_socket(&self) -> BorrowedSocket<'_> {
        T::as_socket(self.get_ref())
    }
}
261
#[cfg(windows)]
impl<T: AsSocket + From<OwnedSocket> + Send + 'static> TryFrom<OwnedSocket> for Async<T> {
    type Error = io::Error;

    /// Converts the owned socket into `T`, then wraps it with [`Async::new()`].
    fn try_from(socket: OwnedSocket) -> Result<Self, Self::Error> {
        Self::new(T::from(socket))
    }
}
270
#[cfg(windows)]
impl<T: Into<OwnedSocket> + Send + 'static> TryFrom<Async<T>> for OwnedSocket {
    type Error = io::Error;

    /// Unwraps the handle with [`Async::into_inner()`] and converts it into an owned socket.
    fn try_from(wrapper: Async<T>) -> Result<Self, Self::Error> {
        let inner = wrapper.into_inner()?;
        Ok(inner.into())
    }
}
279
280impl<T: Send + 'static> Async<T> {
281 /// Gets a reference to the inner I/O handle.
282 pub fn get_ref(&self) -> &T {
283 self.io.as_ref().unwrap()
284 }
285
286 /// Gets a mutable reference to the inner I/O handle.
287 ///
288 /// # Safety
289 ///
290 /// The underlying I/O source must not be dropped using this function.
291 pub unsafe fn get_mut(&mut self) -> &mut T {
292 self.io.as_mut().unwrap()
293 }
294
295 /// Unwraps the inner I/O handle.
296 pub fn into_inner(mut self) -> io::Result<T> {
297 let io = self.io.take().unwrap();
298 Reactor::with_mut(|reactor| reactor.remove_io(&self.source))?;
299 Ok(io)
300 }
301
302 /// Waits until the I/O handle is readable.
303 ///
304 /// This method completes when a read operation on this I/O handle wouldn't block.
305 pub fn readable(&self) -> Readable<'_, T> {
306 Source::readable(self)
307 }
308
309 /// Waits until the I/O handle is readable.
310 ///
311 /// This method completes when a read operation on this I/O handle wouldn't block.
312 pub fn readable_owned(self: Arc<Self>) -> ReadableOwned<T> {
313 Source::readable_owned(self)
314 }
315
316 /// Waits until the I/O handle is writable.
317 ///
318 /// This method completes when a write operation on this I/O handle wouldn't block.
319 pub fn writable(&self) -> Writable<'_, T> {
320 Source::writable(self)
321 }
322
323 /// Waits until the I/O handle is writable.
324 ///
325 /// This method completes when a write operation on this I/O handle wouldn't block.
326 pub fn writable_owned(self: Arc<Self>) -> WritableOwned<T> {
327 Source::writable_owned(self)
328 }
329
330 /// Polls the I/O handle for readability.
331 ///
332 /// When this method returns [`Poll::Ready`], that means the OS has delivered an event
333 /// indicating readability since the last time this task has called the method and received
334 /// [`Poll::Pending`].
335 ///
336 /// # Caveats
337 ///
338 /// Two different tasks should not call this method concurrently. Otherwise, conflicting tasks
339 /// will just keep waking each other in turn, thus wasting CPU time.
340 ///
341 /// Note that the [`AsyncRead`] implementation for [`Async`] also uses this method.
342 pub fn poll_readable(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
343 self.source.poll_readable(cx)
344 }
345
346 /// Polls the I/O handle for writability.
347 ///
348 /// When this method returns [`Poll::Ready`], that means the OS has delivered an event
349 /// indicating writability since the last time this task has called the method and received
350 /// [`Poll::Pending`].
351 ///
352 /// # Caveats
353 ///
354 /// Two different tasks should not call this method concurrently. Otherwise, conflicting tasks
355 /// will just keep waking each other in turn, thus wasting CPU time.
356 ///
357 /// Note that the [`AsyncWrite`] implementation for [`Async`] also uses this method.
358 pub fn poll_writable(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
359 self.source.poll_writable(cx)
360 }
361
362 /// Performs a read operation asynchronously.
363 ///
364 /// The I/O handle is registered in the reactor and put in non-blocking mode. This method
365 /// invokes the `op` closure in a loop until it succeeds or returns an error other than
366 /// [`io::ErrorKind::WouldBlock`]. In between iterations of the loop, it waits until the OS
367 /// sends a notification that the I/O handle is readable.
368 ///
369 /// The closure receives a shared reference to the I/O handle.
370 pub async fn read_with<R>(&self, op: impl FnMut(&T) -> io::Result<R>) -> io::Result<R> {
371 let mut op = op;
372 loop {
373 match op(self.get_ref()) {
374 Err(err) if err.kind() == io::ErrorKind::WouldBlock => {}
375 res => return res,
376 }
377 optimistic(self.readable()).await?;
378 }
379 }
380
381 /// Performs a read operation asynchronously.
382 ///
383 /// The I/O handle is registered in the reactor and put in non-blocking mode. This method
384 /// invokes the `op` closure in a loop until it succeeds or returns an error other than
385 /// [`io::ErrorKind::WouldBlock`]. In between iterations of the loop, it waits until the OS
386 /// sends a notification that the I/O handle is readable.
387 ///
388 /// The closure receives a mutable reference to the I/O handle.
389 ///
390 /// # Safety
391 ///
392 /// In the closure, the underlying I/O source must not be dropped.
393 pub async unsafe fn read_with_mut<R>(
394 &mut self,
395 op: impl FnMut(&mut T) -> io::Result<R>,
396 ) -> io::Result<R> {
397 let mut op = op;
398 loop {
399 match op(self.get_mut()) {
400 Err(err) if err.kind() == io::ErrorKind::WouldBlock => {}
401 res => return res,
402 }
403 optimistic(self.readable()).await?;
404 }
405 }
406
407 /// Performs a write operation asynchronously.
408 ///
409 /// The I/O handle is registered in the reactor and put in non-blocking mode. This method
410 /// invokes the `op` closure in a loop until it succeeds or returns an error other than
411 /// [`io::ErrorKind::WouldBlock`]. In between iterations of the loop, it waits until the OS
412 /// sends a notification that the I/O handle is writable.
413 ///
414 /// The closure receives a shared reference to the I/O handle.
415 pub async fn write_with<R>(&self, op: impl FnMut(&T) -> io::Result<R>) -> io::Result<R> {
416 let mut op = op;
417 loop {
418 match op(self.get_ref()) {
419 Err(err) if err.kind() == io::ErrorKind::WouldBlock => {}
420 res => return res,
421 }
422 optimistic(self.writable()).await?;
423 }
424 }
425
426 /// Performs a write operation asynchronously.
427 ///
428 /// The I/O handle is registered in the reactor and put in non-blocking mode. This method
429 /// invokes the `op` closure in a loop until it succeeds or returns an error other than
430 /// [`io::ErrorKind::WouldBlock`]. In between iterations of the loop, it waits until the OS
431 /// sends a notification that the I/O handle is writable.
432 ///
433 /// The closure receives a mutable reference to the I/O handle.
434 ///
435 /// # Safety
436 ///
437 /// The closure receives a mutable reference to the I/O handle. In the closure, the underlying
438 /// I/O source must not be dropped.
439 pub async unsafe fn write_with_mut<R>(
440 &mut self,
441 op: impl FnMut(&mut T) -> io::Result<R>,
442 ) -> io::Result<R> {
443 let mut op = op;
444 loop {
445 match op(self.get_mut()) {
446 Err(err) if err.kind() == io::ErrorKind::WouldBlock => {}
447 res => return res,
448 }
449 optimistic(self.writable()).await?;
450 }
451 }
452}
453
454impl<T: Send + 'static> AsRef<T> for Async<T> {
455 fn as_ref(&self) -> &T {
456 self.get_ref()
457 }
458}
459
460impl<T: Send + 'static> Drop for Async<T> {
461 fn drop(&mut self) {
462 if let Some(io) = self.io.take() {
463 if let Some(throttling_sched_hdl) = self.throttling_sched_hdl.take() {
464 if let Some(sched) = throttling_sched_hdl.upgrade() {
465 let source = Arc::clone(&self.source);
466 sched.spawn_and_unpark(async move {
467 Reactor::with_mut(|reactor| {
468 if let Err(err) = reactor.remove_io(&source) {
469 gst::error!(
470 RUNTIME_CAT,
471 "Failed to remove fd {:?}: {err}",
472 source.registration,
473 );
474 }
475 });
476 drop(io);
477 });
478 }
479 } else {
480 Reactor::with_mut(|reactor| {
481 if let Err(err) = reactor.remove_io(&self.source) {
482 gst::error!(
483 RUNTIME_CAT,
484 "Failed to remove fd {:?}: {err}",
485 self.source.registration,
486 );
487 }
488 });
489 }
490 }
491 }
492}
493
/// Types whose I/O trait implementations do not drop the underlying I/O source.
///
/// The resource contained inside of the [`Async`] must not be invalidated. This invalidation can
/// happen if the inner resource (a [`TcpStream`] or any other `T`) is moved out and dropped
/// before the [`Async`]. Because of this, functions that grant mutable access to the inner type
/// are unsafe, as there is no way to guarantee that the source won't be dropped and a dangling
/// handle won't be left behind.
///
/// Unfortunately this extends to implementations of [`Read`] and [`Write`]. Since methods on those
/// traits take `&mut`, there is no guarantee that the implementor of those traits won't move the
/// source out while the method is being run.
///
/// This trait is an antidote to this predicament. By implementing this trait, the implementor
/// pledges that using any I/O traits won't destroy the source. This way, [`Async`] can implement
/// the `async` version of these I/O traits, like [`AsyncRead`] and [`AsyncWrite`].
///
/// # Safety
///
/// Any I/O trait implementations for this type must not drop the underlying I/O source. Traits
/// affected by this trait include [`Read`], [`Write`], [`Seek`] and [`BufRead`].
///
/// This trait is implemented for common `libstd` I/O types. It is also implemented for
/// immutable reference types, as it is impossible to invalidate any outstanding references
/// while holding an immutable reference, even with interior mutability. Rust's pinning
/// guarantees rest on a similar argument, which is why this approach is considered robust.
///
/// [`BufRead`]: https://doc.rust-lang.org/std/io/trait.BufRead.html
/// [`Read`]: https://doc.rust-lang.org/std/io/trait.Read.html
/// [`Seek`]: https://doc.rust-lang.org/std/io/trait.Seek.html
/// [`Write`]: https://doc.rust-lang.org/std/io/trait.Write.html
///
/// [`AsyncRead`]: https://docs.rs/futures-io/latest/futures_io/trait.AsyncRead.html
/// [`AsyncWrite`]: https://docs.rs/futures-io/latest/futures_io/trait.AsyncWrite.html
pub unsafe trait IoSafe {}
528
529/// Reference types can't be mutated.
530///
531/// The worst thing that can happen is that external state is used to change what kind of pointer
532/// `as_fd()` returns. For instance:
533///
534/// ```
535/// # #[cfg(unix)] {
536/// use std::cell::Cell;
537/// use std::net::TcpStream;
538/// use std::os::unix::io::{AsFd, BorrowedFd};
539///
540/// struct Bar {
541/// flag: Cell<bool>,
542/// a: TcpStream,
543/// b: TcpStream
544/// }
545///
546/// impl AsFd for Bar {
547/// fn as_fd(&self) -> BorrowedFd<'_> {
548/// if self.flag.replace(!self.flag.get()) {
549/// self.a.as_fd()
550/// } else {
551/// self.b.as_fd()
552/// }
553/// }
554/// }
555/// # }
556/// ```
557///
558/// We solve this problem by only calling `as_fd()` once to get the original source. Implementations
559/// like this are considered buggy (but not unsound) and are thus not really supported by `async-io`.
560unsafe impl<T: ?Sized> IoSafe for &T {}
561
562// Can be implemented on top of libstd types.
563unsafe impl IoSafe for std::fs::File {}
564unsafe impl IoSafe for std::io::Stderr {}
565unsafe impl IoSafe for std::io::Stdin {}
566unsafe impl IoSafe for std::io::Stdout {}
567unsafe impl IoSafe for std::io::StderrLock<'_> {}
568unsafe impl IoSafe for std::io::StdinLock<'_> {}
569unsafe impl IoSafe for std::io::StdoutLock<'_> {}
570unsafe impl IoSafe for std::net::TcpStream {}
571
572#[cfg(unix)]
573unsafe impl IoSafe for std::os::unix::net::UnixStream {}
574
575unsafe impl<T: IoSafe + Read> IoSafe for std::io::BufReader<T> {}
576unsafe impl<T: IoSafe + Write> IoSafe for std::io::BufWriter<T> {}
577unsafe impl<T: IoSafe + Write> IoSafe for std::io::LineWriter<T> {}
578unsafe impl<T: IoSafe + ?Sized> IoSafe for &mut T {}
579unsafe impl<T: IoSafe + ?Sized> IoSafe for Box<T> {}
580unsafe impl<T: Clone + IoSafe> IoSafe for std::borrow::Cow<'_, T> {}
581
582impl<T: IoSafe + Read + Send + 'static> AsyncRead for Async<T> {
583 fn poll_read(
584 mut self: Pin<&mut Self>,
585 cx: &mut Context<'_>,
586 buf: &mut [u8],
587 ) -> Poll<io::Result<usize>> {
588 loop {
589 match unsafe { (*self).get_mut() }.read(buf) {
590 Err(err) if err.kind() == io::ErrorKind::WouldBlock => {}
591 res => return Poll::Ready(res),
592 }
593 ready!(self.poll_readable(cx))?;
594 }
595 }
596
597 fn poll_read_vectored(
598 mut self: Pin<&mut Self>,
599 cx: &mut Context<'_>,
600 bufs: &mut [IoSliceMut<'_>],
601 ) -> Poll<io::Result<usize>> {
602 loop {
603 match unsafe { (*self).get_mut() }.read_vectored(bufs) {
604 Err(err) if err.kind() == io::ErrorKind::WouldBlock => {}
605 res => return Poll::Ready(res),
606 }
607 ready!(self.poll_readable(cx))?;
608 }
609 }
610}
611
612// Since this is through a reference, we can't mutate the inner I/O source.
613// Therefore this is safe!
614impl<T: Send + 'static> AsyncRead for &Async<T>
615where
616 for<'a> &'a T: Read,
617{
618 fn poll_read(
619 self: Pin<&mut Self>,
620 cx: &mut Context<'_>,
621 buf: &mut [u8],
622 ) -> Poll<io::Result<usize>> {
623 loop {
624 match (*self).get_ref().read(buf) {
625 Err(err) if err.kind() == io::ErrorKind::WouldBlock => {}
626 res => return Poll::Ready(res),
627 }
628 ready!(self.poll_readable(cx))?;
629 }
630 }
631
632 fn poll_read_vectored(
633 self: Pin<&mut Self>,
634 cx: &mut Context<'_>,
635 bufs: &mut [IoSliceMut<'_>],
636 ) -> Poll<io::Result<usize>> {
637 loop {
638 match (*self).get_ref().read_vectored(bufs) {
639 Err(err) if err.kind() == io::ErrorKind::WouldBlock => {}
640 res => return Poll::Ready(res),
641 }
642 ready!(self.poll_readable(cx))?;
643 }
644 }
645}
646
647impl<T: IoSafe + Write + Send + 'static> AsyncWrite for Async<T> {
648 fn poll_write(
649 mut self: Pin<&mut Self>,
650 cx: &mut Context<'_>,
651 buf: &[u8],
652 ) -> Poll<io::Result<usize>> {
653 loop {
654 match unsafe { (*self).get_mut() }.write(buf) {
655 Err(err) if err.kind() == io::ErrorKind::WouldBlock => {}
656 res => return Poll::Ready(res),
657 }
658 ready!(self.poll_writable(cx))?;
659 }
660 }
661
662 fn poll_write_vectored(
663 mut self: Pin<&mut Self>,
664 cx: &mut Context<'_>,
665 bufs: &[IoSlice<'_>],
666 ) -> Poll<io::Result<usize>> {
667 loop {
668 match unsafe { (*self).get_mut() }.write_vectored(bufs) {
669 Err(err) if err.kind() == io::ErrorKind::WouldBlock => {}
670 res => return Poll::Ready(res),
671 }
672 ready!(self.poll_writable(cx))?;
673 }
674 }
675
676 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
677 loop {
678 match unsafe { (*self).get_mut() }.flush() {
679 Err(err) if err.kind() == io::ErrorKind::WouldBlock => {}
680 res => return Poll::Ready(res),
681 }
682 ready!(self.poll_writable(cx))?;
683 }
684 }
685
686 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
687 self.poll_flush(cx)
688 }
689}
690
691impl<T: Send + 'static> AsyncWrite for &Async<T>
692where
693 for<'a> &'a T: Write,
694{
695 fn poll_write(
696 self: Pin<&mut Self>,
697 cx: &mut Context<'_>,
698 buf: &[u8],
699 ) -> Poll<io::Result<usize>> {
700 loop {
701 match (*self).get_ref().write(buf) {
702 Err(err) if err.kind() == io::ErrorKind::WouldBlock => {}
703 res => return Poll::Ready(res),
704 }
705 ready!(self.poll_writable(cx))?;
706 }
707 }
708
709 fn poll_write_vectored(
710 self: Pin<&mut Self>,
711 cx: &mut Context<'_>,
712 bufs: &[IoSlice<'_>],
713 ) -> Poll<io::Result<usize>> {
714 loop {
715 match (*self).get_ref().write_vectored(bufs) {
716 Err(err) if err.kind() == io::ErrorKind::WouldBlock => {}
717 res => return Poll::Ready(res),
718 }
719 ready!(self.poll_writable(cx))?;
720 }
721 }
722
723 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
724 loop {
725 match (*self).get_ref().flush() {
726 Err(err) if err.kind() == io::ErrorKind::WouldBlock => {}
727 res => return Poll::Ready(res),
728 }
729 ready!(self.poll_writable(cx))?;
730 }
731 }
732
733 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
734 self.poll_flush(cx)
735 }
736}
737
738impl Async<TcpListener> {
739 /// Creates a TCP listener bound to the specified address.
740 ///
741 /// Binding with port number 0 will request an available port from the OS.
742 pub fn bind<A: Into<SocketAddr>>(addr: A) -> io::Result<Async<TcpListener>> {
743 let addr = addr.into();
744 Async::new(TcpListener::bind(addr)?)
745 }
746
747 /// Accepts a new incoming TCP connection.
748 ///
749 /// When a connection is established, it will be returned as a TCP stream together with its
750 /// remote address.
751 pub async fn accept(&self) -> io::Result<(Async<TcpStream>, SocketAddr)> {
752 let (stream, addr) = self.read_with(|io| io.accept()).await?;
753 Ok((Async::new(stream)?, addr))
754 }
755
756 /// Returns a stream of incoming TCP connections.
757 ///
758 /// The stream is infinite, i.e. it never stops with a [`None`].
759 pub fn incoming(&self) -> impl Stream<Item = io::Result<Async<TcpStream>>> + Send + '_ {
760 stream::unfold(self, |listener| async move {
761 let res = listener.accept().await.map(|(stream, _)| stream);
762 Some((res, listener))
763 })
764 }
765}
766
767impl TryFrom<std::net::TcpListener> for Async<std::net::TcpListener> {
768 type Error = io::Error;
769
770 fn try_from(listener: std::net::TcpListener) -> io::Result<Self> {
771 Async::new(listener)
772 }
773}
774
775impl Async<TcpStream> {
776 /// Creates a TCP connection to the specified address.
777 pub async fn connect<A: Into<SocketAddr>>(addr: A) -> io::Result<Async<TcpStream>> {
778 // Figure out how to handle this address.
779 let addr = addr.into();
780 let (domain, sock_addr) = match addr {
781 SocketAddr::V4(v4) => (rn::AddressFamily::INET, v4.as_any()),
782 SocketAddr::V6(v6) => (rn::AddressFamily::INET6, v6.as_any()),
783 };
784
785 // Begin async connect.
786 let socket = connect(sock_addr, domain, Some(rn::ipproto::TCP))?;
787 // Use new_nonblocking because connect already sets socket to non-blocking mode.
788 let stream = Async::new_nonblocking(TcpStream::from(socket))?;
789
790 // The stream becomes writable when connected.
791 stream.writable().await?;
792
793 // Check if there was an error while connecting.
794 match stream.get_ref().take_error()? {
795 None => Ok(stream),
796 Some(err) => Err(err),
797 }
798 }
799
800 /// Reads data from the stream without removing it from the buffer.
801 ///
802 /// Returns the number of bytes read. Successive calls of this method read the same data.
803 pub async fn peek(&self, buf: &mut [u8]) -> io::Result<usize> {
804 self.read_with(|io| io.peek(buf)).await
805 }
806}
807
808impl TryFrom<std::net::TcpStream> for Async<std::net::TcpStream> {
809 type Error = io::Error;
810
811 fn try_from(stream: std::net::TcpStream) -> io::Result<Self> {
812 Async::new(stream)
813 }
814}
815
816impl Async<UdpSocket> {
817 /// Creates a UDP socket bound to the specified address.
818 ///
819 /// Binding with port number 0 will request an available port from the OS.
820 pub fn bind<A: Into<SocketAddr>>(addr: A) -> io::Result<Async<UdpSocket>> {
821 let addr = addr.into();
822 Async::new(UdpSocket::bind(addr)?)
823 }
824
825 /// Receives a single datagram message.
826 ///
827 /// Returns the number of bytes read and the address the message came from.
828 ///
829 /// This method must be called with a valid byte slice of sufficient size to hold the message.
830 /// If the message is too long to fit, excess bytes may get discarded.
831 pub async fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
832 self.read_with(|io| io.recv_from(buf)).await
833 }
834
835 /// Receives a single datagram message without removing it from the queue.
836 ///
837 /// Returns the number of bytes read and the address the message came from.
838 ///
839 /// This method must be called with a valid byte slice of sufficient size to hold the message.
840 /// If the message is too long to fit, excess bytes may get discarded.
841 pub async fn peek_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
842 self.read_with(|io| io.peek_from(buf)).await
843 }
844
845 /// Sends data to the specified address.
846 ///
847 /// Returns the number of bytes written.
848 pub async fn send_to<A: Into<SocketAddr>>(&self, buf: &[u8], addr: A) -> io::Result<usize> {
849 let addr = addr.into();
850 self.write_with(|io| io.send_to(buf, addr)).await
851 }
852
853 /// Receives a single datagram message from the connected peer.
854 ///
855 /// Returns the number of bytes read.
856 ///
857 /// This method must be called with a valid byte slice of sufficient size to hold the message.
858 /// If the message is too long to fit, excess bytes may get discarded.
859 ///
860 /// The [`connect`][`UdpSocket::connect()`] method connects this socket to a remote address.
861 /// This method will fail if the socket is not connected.
862 pub async fn recv(&self, buf: &mut [u8]) -> io::Result<usize> {
863 self.read_with(|io| io.recv(buf)).await
864 }
865
866 /// Receives a single datagram message from the connected peer without removing it from the
867 /// queue.
868 ///
869 /// Returns the number of bytes read and the address the message came from.
870 ///
871 /// This method must be called with a valid byte slice of sufficient size to hold the message.
872 /// If the message is too long to fit, excess bytes may get discarded.
873 ///
874 /// The [`connect`][`UdpSocket::connect()`] method connects this socket to a remote address.
875 /// This method will fail if the socket is not connected.
876 pub async fn peek(&self, buf: &mut [u8]) -> io::Result<usize> {
877 self.read_with(|io| io.peek(buf)).await
878 }
879
880 /// Sends data to the connected peer.
881 ///
882 /// Returns the number of bytes written.
883 ///
884 /// The [`connect`][`UdpSocket::connect()`] method connects this socket to a remote address.
885 /// This method will fail if the socket is not connected.
886 pub async fn send(&self, buf: &[u8]) -> io::Result<usize> {
887 self.write_with(|io| io.send(buf)).await
888 }
889}
890
891impl TryFrom<std::net::UdpSocket> for Async<std::net::UdpSocket> {
892 type Error = io::Error;
893
894 fn try_from(socket: std::net::UdpSocket) -> io::Result<Self> {
895 Async::new(socket)
896 }
897}
898
899impl TryFrom<socket2::Socket> for Async<std::net::UdpSocket> {
900 type Error = io::Error;
901
902 fn try_from(socket: socket2::Socket) -> io::Result<Self> {
903 Async::new(std::net::UdpSocket::from(socket))
904 }
905}
906
907#[cfg(unix)]
908impl Async<UnixListener> {
909 /// Creates a UDS listener bound to the specified path.
910 pub fn bind<P: AsRef<Path>>(path: P) -> io::Result<Async<UnixListener>> {
911 let path = path.as_ref().to_owned();
912 Async::new(UnixListener::bind(path)?)
913 }
914
915 /// Accepts a new incoming UDS stream connection.
916 pub async fn accept(&self) -> io::Result<(Async<UnixStream>, UnixSocketAddr)> {
917 let (stream, addr) = self.read_with(|io| io.accept()).await?;
918 Ok((Async::new(stream)?, addr))
919 }
920
921 /// Returns a stream of incoming UDS connections.
922 ///
923 /// The stream is infinite, i.e. it never stops with a [`None`] item.
924 pub fn incoming(&self) -> impl Stream<Item = io::Result<Async<UnixStream>>> + Send + '_ {
925 stream::unfold(self, |listener| async move {
926 let res = listener.accept().await.map(|(stream, _)| stream);
927 Some((res, listener))
928 })
929 }
930}
931
932#[cfg(unix)]
933impl TryFrom<std::os::unix::net::UnixListener> for Async<std::os::unix::net::UnixListener> {
934 type Error = io::Error;
935
936 fn try_from(listener: std::os::unix::net::UnixListener) -> io::Result<Self> {
937 Async::new(listener)
938 }
939}
940
941#[cfg(unix)]
942impl Async<UnixStream> {
943 /// Creates a UDS stream connected to the specified path.
944 pub async fn connect<P: AsRef<Path>>(path: P) -> io::Result<Async<UnixStream>> {
945 let address = convert_path_to_socket_address(path.as_ref())?;
946
947 // Begin async connect.
948 let socket = connect(address.into(), rn::AddressFamily::UNIX, None)?;
949 // Use new_nonblocking because connect already sets socket to non-blocking mode.
950 let stream = Async::new_nonblocking(UnixStream::from(socket))?;
951
952 // The stream becomes writable when connected.
953 stream.writable().await?;
954
955 // On Linux, it appears the socket may become writable even when connecting fails, so we
956 // must do an extra check here and see if the peer address is retrievable.
957 stream.get_ref().peer_addr()?;
958 Ok(stream)
959 }
960
961 /// Creates an unnamed pair of connected UDS stream sockets.
962 pub fn pair() -> io::Result<(Async<UnixStream>, Async<UnixStream>)> {
963 let (stream1, stream2) = UnixStream::pair()?;
964 Ok((Async::new(stream1)?, Async::new(stream2)?))
965 }
966}
967
968#[cfg(unix)]
969impl TryFrom<std::os::unix::net::UnixStream> for Async<std::os::unix::net::UnixStream> {
970 type Error = io::Error;
971
972 fn try_from(stream: std::os::unix::net::UnixStream) -> io::Result<Self> {
973 Async::new(stream)
974 }
975}
976
977#[cfg(unix)]
978impl Async<UnixDatagram> {
979 /// Creates a UDS datagram socket bound to the specified path.
980 pub fn bind<P: AsRef<Path>>(path: P) -> io::Result<Async<UnixDatagram>> {
981 let path = path.as_ref().to_owned();
982 Async::new(UnixDatagram::bind(path)?)
983 }
984
985 /// Creates a UDS datagram socket not bound to any address.
986 pub fn unbound() -> io::Result<Async<UnixDatagram>> {
987 Async::new(UnixDatagram::unbound()?)
988 }
989
990 /// Creates an unnamed pair of connected Unix datagram sockets.
991 pub fn pair() -> io::Result<(Async<UnixDatagram>, Async<UnixDatagram>)> {
992 let (socket1, socket2) = UnixDatagram::pair()?;
993 Ok((Async::new(socket1)?, Async::new(socket2)?))
994 }
995
996 /// Receives data from the socket.
997 ///
998 /// Returns the number of bytes read and the address the message came from.
999 pub async fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, UnixSocketAddr)> {
1000 self.read_with(|io| io.recv_from(buf)).await
1001 }
1002
1003 /// Sends data to the specified address.
1004 ///
1005 /// Returns the number of bytes written.
1006 pub async fn send_to<P: AsRef<Path>>(&self, buf: &[u8], path: P) -> io::Result<usize> {
1007 self.write_with(|io| io.send_to(buf, &path)).await
1008 }
1009
1010 /// Receives data from the connected peer.
1011 ///
1012 /// Returns the number of bytes read and the address the message came from.
1013 ///
1014 /// The [`connect`][`UnixDatagram::connect()`] method connects this socket to a remote address.
1015 /// This method will fail if the socket is not connected.
1016 pub async fn recv(&self, buf: &mut [u8]) -> io::Result<usize> {
1017 self.read_with(|io| io.recv(buf)).await
1018 }
1019
1020 /// Sends data to the connected peer.
1021 ///
1022 /// Returns the number of bytes written.
1023 ///
1024 /// The [`connect`][`UnixDatagram::connect()`] method connects this socket to a remote address.
1025 /// This method will fail if the socket is not connected.
1026 pub async fn send(&self, buf: &[u8]) -> io::Result<usize> {
1027 self.write_with(|io| io.send(buf)).await
1028 }
1029}
1030
1031#[cfg(unix)]
1032impl TryFrom<std::os::unix::net::UnixDatagram> for Async<std::os::unix::net::UnixDatagram> {
1033 type Error = io::Error;
1034
1035 fn try_from(socket: std::os::unix::net::UnixDatagram) -> io::Result<Self> {
1036 Async::new(socket)
1037 }
1038}
1039
1040/// Polls a future once, waits for a wakeup, and then optimistically assumes the future is ready.
1041async fn optimistic(fut: impl Future<Output = io::Result<()>>) -> io::Result<()> {
1042 let mut polled = false;
1043 pin_mut!(fut);
1044
1045 future::poll_fn(|cx| {
1046 if !polled {
1047 polled = true;
1048 fut.as_mut().poll(cx)
1049 } else {
1050 Poll::Ready(Ok(()))
1051 }
1052 })
1053 .await
1054}
1055
/// Creates a non-blocking stream socket for `domain`/`protocol` and starts connecting it to
/// `addr`.
///
/// The returned socket is of type `SOCK_STREAM` and is in non-blocking mode. Close-on-exec is
/// set on the Unix targets handled below. The connect is only *initiated*: "in progress" and
/// "would block" results are treated as success. The caller must therefore wait for the socket
/// to become writable, and then check the outcome, before using it.
///
/// # Errors
///
/// Returns an error if socket creation fails, if any of the flag/option calls fails, or if
/// `connect` fails with a code other than "in progress"/"would block".
fn connect(
    addr: rn::SocketAddrAny,
    domain: rn::AddressFamily,
    protocol: Option<rn::Protocol>,
) -> io::Result<rustix::fd::OwnedFd> {
    // On Windows, `socket.as_fd()` below needs rustix's `AsFd` trait in scope.
    #[cfg(windows)]
    use rustix::fd::AsFd;

    // One-time platform networking initialisation (a no-op outside Windows).
    setup_networking();

    // These targets accept CLOEXEC and NONBLOCK as socket-creation flags, so both are set
    // atomically by the `socket` call itself.
    #[cfg(any(
        target_os = "android",
        target_os = "dragonfly",
        target_os = "freebsd",
        target_os = "fuchsia",
        target_os = "illumos",
        target_os = "linux",
        target_os = "netbsd",
        target_os = "openbsd"
    ))]
    let socket = rn::socket_with(
        domain,
        rn::SocketType::STREAM,
        rn::SocketFlags::CLOEXEC | rn::SocketFlags::NONBLOCK,
        protocol,
    )?;

    // All other targets: create the socket first, then set the flags with separate calls.
    #[cfg(not(any(
        target_os = "android",
        target_os = "dragonfly",
        target_os = "freebsd",
        target_os = "fuchsia",
        target_os = "illumos",
        target_os = "linux",
        target_os = "netbsd",
        target_os = "openbsd"
    )))]
    let socket = {
        // CLOEXEC can still be requested at creation time here...
        #[cfg(not(any(
            target_os = "aix",
            target_vendor = "apple",
            target_os = "espidf",
            windows,
        )))]
        let flags = rn::SocketFlags::CLOEXEC;
        // ...but not on these targets, which get no creation flags at all.
        #[cfg(any(
            target_os = "aix",
            target_vendor = "apple",
            target_os = "espidf",
            windows,
        ))]
        let flags = rn::SocketFlags::empty();

        // Create the socket.
        let socket = rn::socket_with(domain, rn::SocketType::STREAM, flags, protocol)?;

        // Set cloexec if necessary: AIX and Apple targets set it after creation via fcntl.
        // (No close-on-exec step is taken here for espidf or Windows.)
        #[cfg(any(target_os = "aix", target_vendor = "apple"))]
        rio::fcntl_setfd(&socket, rio::fcntl_getfd(&socket)? | rio::FdFlags::CLOEXEC)?;

        // Set non-blocking mode.
        set_nonblocking(socket.as_fd())?;

        socket
    };

    // Set nosigpipe if necessary. These targets support SO_NOSIGPIPE; presumably it is enabled
    // so that writing to a closed peer reports EPIPE instead of raising SIGPIPE.
    #[cfg(any(
        target_vendor = "apple",
        target_os = "freebsd",
        target_os = "netbsd",
        target_os = "dragonfly",
    ))]
    rn::sockopt::set_socket_nosigpipe(&socket, true)?;

    // Set the handle information to HANDLE_FLAG_INHERIT.
    // NOTE(review): `SetHandleInformation(handle, mask, flags)` sets the masked bits to `flags`,
    // so passing `HANDLE_FLAG_INHERIT` as the last argument makes the socket *inheritable* by
    // child processes; clearing the flag would pass `0` instead. Confirm which is intended.
    // SAFETY: the raw socket handle comes from `socket`, which this function owns and keeps
    // alive for the duration of the call; `SetHandleInformation` does not take ownership.
    #[cfg(windows)]
    unsafe {
        if windows_sys::Win32::Foundation::SetHandleInformation(
            socket.as_raw_socket() as _,
            windows_sys::Win32::Foundation::HANDLE_FLAG_INHERIT,
            windows_sys::Win32::Foundation::HANDLE_FLAG_INHERIT,
        ) == 0
        {
            return Err(io::Error::last_os_error());
        }
    }

    // Start the connection. On a non-blocking socket, "in progress" (Unix) or "would block"
    // means the attempt is underway, so those codes are not errors here. `AGAIN` and
    // `WOULDBLOCK` share a value on some platforms, which makes one pattern unreachable there;
    // hence the `allow`.
    #[allow(unreachable_patterns)]
    match rn::connect(&socket, &addr) {
        Ok(_) => {}
        #[cfg(unix)]
        Err(rio::Errno::INPROGRESS) => {}
        Err(rio::Errno::AGAIN) | Err(rio::Errno::WOULDBLOCK) => {}
        Err(err) => return Err(err.into()),
    }
    Ok(socket)
}
1154
/// Performs one-time, process-wide networking initialisation.
///
/// On Windows, Winsock requires `WSAStartup` before any networking call; it is invoked at most
/// once per process and its result is deliberately ignored. On every other platform this
/// function does nothing.
#[inline]
fn setup_networking() {
    #[cfg(windows)]
    {
        static WSA_INIT: std::sync::Once = std::sync::Once::new();

        WSA_INIT.call_once(|| {
            // Best effort: presumably a failure here surfaces later as an error from the
            // first socket call.
            let _ = rustix::net::wsa_startup();
        });
    }
}
1168
1169#[inline]
1170fn set_nonblocking(
1171 #[cfg(unix)] fd: BorrowedFd<'_>,
1172 #[cfg(windows)] fd: BorrowedSocket<'_>,
1173) -> io::Result<()> {
1174 cfg_if::cfg_if! {
1175 // ioctl(FIONBIO) sets the flag atomically, but we use this only on Linux
1176 // for now, as with the standard library, because it seems to behave
1177 // differently depending on the platform.
1178 // https://github.com/rust-lang/rust/commit/efeb42be2837842d1beb47b51bb693c7474aba3d
1179 // https://github.com/libuv/libuv/blob/e9d91fccfc3e5ff772d5da90e1c4a24061198ca0/src/unix/poll.c#L78-L80
1180 // https://github.com/tokio-rs/mio/commit/0db49f6d5caf54b12176821363d154384357e70a
1181 if #[cfg(any(windows, target_os = "linux"))] {
1182 rustix::io::ioctl_fionbio(fd, true)?;
1183 } else {
1184 let previous = rustix::fs::fcntl_getfl(fd)?;
1185 let new = previous | rustix::fs::OFlags::NONBLOCK;
1186 if new != previous {
1187 rustix::fs::fcntl_setfl(fd, new)?;
1188 }
1189 }
1190 }
1191
1192 Ok(())
1193}
1194
1195/// Converts a `Path` to its socket address representation.
1196///
1197/// This function is abstract socket-aware.
1198#[cfg(unix)]
1199#[inline]
1200fn convert_path_to_socket_address(path: &Path) -> io::Result<rn::SocketAddrUnix> {
1201 // SocketAddrUnix::new() will throw EINVAL when a path with a zero in it is passed in.
1202 // However, some users expect to be able to pass in paths to abstract sockets, which
1203 // triggers this error as it has a zero in it. Therefore, if a path starts with a zero,
1204 // make it an abstract socket.
1205 #[cfg(any(target_os = "linux", target_os = "android"))]
1206 let address = {
1207 use std::os::unix::ffi::OsStrExt;
1208
1209 let path = path.as_os_str();
1210 match path.as_bytes().first() {
1211 Some(0) => rn::SocketAddrUnix::new_abstract_name(path.as_bytes().get(1..).unwrap())?,
1212 _ => rn::SocketAddrUnix::new(path)?,
1213 }
1214 };
1215
1216 // Only Linux and Android support abstract sockets.
1217 #[cfg(not(any(target_os = "linux", target_os = "android")))]
1218 let address = rn::SocketAddrUnix::new(path)?;
1219
1220 Ok(address)
1221}