async_io_mini/
io.rs

1use core::future::{poll_fn, Future};
2use core::pin::pin;
3use core::task::{Context, Poll};
4
5use std::io::{self, Read, Write};
6use std::net::{SocketAddr, TcpListener, TcpStream, UdpSocket};
7use std::os::fd::FromRawFd;
8use std::os::unix::io::{AsFd, AsRawFd, BorrowedFd, OwnedFd, RawFd};
9
10use super::reactor::{Event, REACTOR};
11use super::sys;
12use super::{ready, syscall, syscall_los, syscall_los_eagain};
13
/// Async adapter for I/O types.
///
/// This type puts an I/O handle into non-blocking mode, registers its file descriptor with this
/// crate's reactor, and then provides an async interface for it.
///
/// # Caveats
///
/// [`Async`] is a low-level primitive, and as such it comes with some caveats.
///
/// For higher-level primitives built on top of [`Async`], look into [`async-net`] or
/// [`async-process`] (on Unix).
///
/// The most notable caveat is that it is unsafe to access the inner I/O source mutably
/// using this primitive. Traits like [`AsyncRead`] and [`AsyncWrite`] are not implemented by
/// default unless it is guaranteed that the resource won't be invalidated by reading or writing.
/// See the [`IoSafe`] trait for more information.
///
/// [`async-net`]: https://github.com/smol-rs/async-net
/// [`async-process`]: https://github.com/smol-rs/async-process
/// [`AsyncRead`]: https://docs.rs/futures-io/latest/futures_io/trait.AsyncRead.html
/// [`AsyncWrite`]: https://docs.rs/futures-io/latest/futures_io/trait.AsyncWrite.html
///
/// ### Supported types
///
/// [`Async`] supports all networking types, as well as some OS-specific file descriptors like
/// [timerfd] and [inotify].
///
/// However, do not use [`Async`] with types like [`File`][`std::fs::File`],
/// [`Stdin`][`std::io::Stdin`], [`Stdout`][`std::io::Stdout`], or [`Stderr`][`std::io::Stderr`]
/// because all operating systems have issues with them when put in non-blocking mode.
///
/// [timerfd]: https://github.com/smol-rs/async-io/blob/master/examples/linux-timerfd.rs
/// [inotify]: https://github.com/smol-rs/async-io/blob/master/examples/linux-inotify.rs
///
/// ### Concurrent I/O
///
/// When the `futures-io` feature is enabled, [`&Async<T>`][`Async`] implements [`AsyncRead`]
/// and [`AsyncWrite`] if `&T` implements [`Read`] and [`Write`] respectively, which means tasks
/// can concurrently read and write using shared references.
///
/// But there is a catch: only one task can read at a time, and only one task can write at a
/// time. It is okay to have two tasks where one is reading and the other is writing at the same
/// time, but it is not okay to have two tasks reading at the same time or writing at the same
/// time. If you try to do that, conflicting tasks will just keep waking each other in turn, thus
/// wasting CPU time.
///
/// Besides [`AsyncRead`] and [`AsyncWrite`], this caveat also applies to
/// [`poll_readable()`][`Async::poll_readable()`] and
/// [`poll_writable()`][`Async::poll_writable()`].
///
/// In this implementation [`readable()`][`Async::readable()`] and
/// [`writable()`][`Async::writable()`] are thin wrappers around `poll_readable()` /
/// `poll_writable()`, and [`read_with()`][`Async::read_with()`] /
/// [`write_with()`][`Async::write_with()`] wait on them between attempts, so the same caveat
/// applies to those methods as well.
///
/// ### Closing
///
/// Closing the write side of [`Async`] with [`close()`][`futures_lite::AsyncWriteExt::close()`]
/// simply flushes. If you want to shutdown a TCP or Unix socket, use
/// [`Shutdown`][`std::net::Shutdown`].
///
/// # Examples
///
/// Connect to a server and echo incoming messages back to the server:
///
/// ```no_run
/// use async_io_mini::Async;
/// use futures_lite::io;
/// use std::net::TcpStream;
///
/// # futures_lite::future::block_on(async {
/// // Connect to a local server.
/// let stream = Async::<TcpStream>::connect(([127, 0, 0, 1], 8000)).await?;
///
/// // Echo all messages from the read side of the stream into the write side.
/// io::copy(&stream, &stream).await?;
/// # std::io::Result::Ok(()) });
/// ```
///
/// You can use either predefined async methods or wrap blocking I/O operations in
/// [`Async::read_with()`], [`Async::read_with_mut()`], [`Async::write_with()`], and
/// [`Async::write_with_mut()`]:
///
/// ```no_run
/// use async_io_mini::Async;
/// use std::net::TcpListener;
///
/// # futures_lite::future::block_on(async {
/// let listener = Async::<TcpListener>::bind(([127, 0, 0, 1], 0))?;
///
/// // These two lines are equivalent:
/// let (stream, addr) = listener.accept().await?;
/// let (stream, addr) = listener.read_with(|inner| inner.accept()).await?;
/// # std::io::Result::Ok(()) });
/// ```
#[derive(Debug)]
pub struct Async<T: AsFd> {
    // Always `Some` while the `Async` is alive; `into_inner()` takes it out just before the
    // value is consumed, which is why `Drop` tolerates `None`.
    io: Option<T>,
}

// `Async` never hands out a pinned reference to `T` (the I/O trait impls only reach `T` through
// ordinary `&`/`&mut` borrows), so it can be `Unpin` regardless of whether `T` is.
impl<T: AsFd> Unpin for Async<T> {}
118
119impl<T: AsFd> Async<T> {
120    /// Creates an async I/O handle.
121    ///
122    /// This method will put the handle in non-blocking mode and register it in
123    /// [epoll]/[kqueue]/[event ports]/[IOCP].
124    ///
125    /// On Unix systems, the handle must implement `AsFd`, while on Windows it must implement
126    /// `AsSocket`.
127    ///
128    /// [epoll]: https://en.wikipedia.org/wiki/Epoll
129    /// [kqueue]: https://en.wikipedia.org/wiki/Kqueue
130    /// [event ports]: https://illumos.org/man/port_create
131    /// [IOCP]: https://learn.microsoft.com/en-us/windows/win32/fileio/i-o-completion-ports
132    ///
133    /// # Examples
134    ///
135    /// ```
136    /// use async_io_mini::Async;
137    /// use std::net::{SocketAddr, TcpListener};
138    ///
139    /// # futures_lite::future::block_on(async {
140    /// let listener = TcpListener::bind(SocketAddr::from(([127, 0, 0, 1], 0)))?;
141    /// let listener = Async::new(listener)?;
142    /// # std::io::Result::Ok(()) });
143    /// ```
144    pub fn new(io: T) -> io::Result<Self> {
145        // Put the file descriptor in non-blocking mode.
146        set_nonblocking(io.as_fd())?;
147
148        Self::new_nonblocking(io)
149    }
150
151    /// Creates an async I/O handle without setting it to non-blocking mode.
152    ///
153    /// This method will register the handle in [epoll]/[kqueue]/[event ports]/[IOCP].
154    ///
155    /// On Unix systems, the handle must implement `AsFd`, while on Windows it must implement
156    /// `AsSocket`.
157    ///
158    /// [epoll]: https://en.wikipedia.org/wiki/Epoll
159    /// [kqueue]: https://en.wikipedia.org/wiki/Kqueue
160    /// [event ports]: https://illumos.org/man/port_create
161    /// [IOCP]: https://learn.microsoft.com/en-us/windows/win32/fileio/i-o-completion-ports
162    ///
163    /// # Caveats
164    ///
165    /// The caller should ensure that the handle is set to non-blocking mode or that it is okay if
166    /// it is not set. If not set to non-blocking mode, I/O operations may block the current thread
167    /// and cause a deadlock in an asynchronous context.
168    pub fn new_nonblocking(io: T) -> io::Result<Self> {
169        REACTOR.start()?;
170        // SAFETY: It is impossible to drop the I/O source while it is registered.
171        REACTOR.register(io.as_fd().as_raw_fd())?;
172
173        Ok(Self { io: Some(io) })
174    }
175}
176
177impl<T: AsFd + AsRawFd> AsRawFd for Async<T> {
178    fn as_raw_fd(&self) -> RawFd {
179        self.get_ref().as_raw_fd()
180    }
181}
182
183impl<T: AsFd> AsFd for Async<T> {
184    fn as_fd(&self) -> BorrowedFd<'_> {
185        self.get_ref().as_fd()
186    }
187}
188
189impl<T: AsFd + From<OwnedFd>> TryFrom<OwnedFd> for Async<T> {
190    type Error = io::Error;
191
192    fn try_from(value: OwnedFd) -> Result<Self, Self::Error> {
193        Async::new(value.into())
194    }
195}
196
197impl<T: AsFd + Into<OwnedFd>> TryFrom<Async<T>> for OwnedFd {
198    type Error = io::Error;
199
200    fn try_from(value: Async<T>) -> Result<Self, Self::Error> {
201        value.into_inner().map(Into::into)
202    }
203}
204
205impl<T: AsFd> Async<T> {
206    /// Gets a reference to the inner I/O handle.
207    ///
208    /// # Examples
209    ///
210    /// ```
211    /// use async_io_mini::Async;
212    /// use std::net::TcpListener;
213    ///
214    /// # futures_lite::future::block_on(async {
215    /// let listener = Async::<TcpListener>::bind(([127, 0, 0, 1], 0))?;
216    /// let inner = listener.get_ref();
217    /// # std::io::Result::Ok(()) });
218    /// ```
219    pub fn get_ref(&self) -> &T {
220        self.io.as_ref().unwrap()
221    }
222
223    /// Gets a mutable reference to the inner I/O handle.
224    ///
225    /// # Safety
226    ///
227    /// The underlying I/O source must not be dropped using this function.
228    ///
229    /// # Examples
230    ///
231    /// ```
232    /// use async_io_mini::Async;
233    /// use std::net::TcpListener;
234    ///
235    /// # futures_lite::future::block_on(async {
236    /// let mut listener = Async::<TcpListener>::bind(([127, 0, 0, 1], 0))?;
237    /// let inner = unsafe { listener.get_mut() };
238    /// # std::io::Result::Ok(()) });
239    /// ```
240    pub unsafe fn get_mut(&mut self) -> &mut T {
241        self.io.as_mut().unwrap()
242    }
243
244    /// Unwraps the inner I/O handle.
245    ///
246    /// This method will **not** put the I/O handle back into blocking mode.
247    ///
248    /// # Examples
249    ///
250    /// ```
251    /// use async_io_mini::Async;
252    /// use std::net::TcpListener;
253    ///
254    /// # futures_lite::future::block_on(async {
255    /// let listener = Async::<TcpListener>::bind(([127, 0, 0, 1], 0))?;
256    /// let inner = listener.into_inner()?;
257    ///
258    /// // Put the listener back into blocking mode.
259    /// inner.set_nonblocking(false)?;
260    /// # std::io::Result::Ok(()) });
261    /// ```
262    pub fn into_inner(mut self) -> io::Result<T> {
263        REACTOR.deregister(self.as_fd().as_raw_fd())?;
264        Ok(self.io.take().unwrap())
265    }
266
267    /// Waits until the I/O handle is readable.
268    ///
269    /// This method completes when a read operation on this I/O handle wouldn't block.
270    ///
271    /// # Examples
272    ///
273    /// ```no_run
274    /// use async_io_mini::Async;
275    /// use std::net::TcpListener;
276    ///
277    /// # futures_lite::future::block_on(async {
278    /// let mut listener = Async::<TcpListener>::bind(([127, 0, 0, 1], 0))?;
279    ///
280    /// // Wait until a client can be accepted.
281    /// listener.readable().await?;
282    /// # std::io::Result::Ok(()) });
283    /// ```
284    pub async fn readable(&self) -> io::Result<()> {
285        poll_fn(|cx| self.poll_readable(cx)).await
286    }
287
288    /// Waits until the I/O handle is writable.
289    ///
290    /// This method completes when a write operation on this I/O handle wouldn't block.
291    ///
292    /// # Examples
293    ///
294    /// ```
295    /// use async_io_mini::Async;
296    /// use std::net::{TcpStream, ToSocketAddrs};
297    ///
298    /// # futures_lite::future::block_on(async {
299    /// let addr = "example.com:80".to_socket_addrs()?.next().unwrap();
300    /// let stream = Async::<TcpStream>::connect(addr).await?;
301    ///
302    /// // Wait until the stream is writable.
303    /// stream.writable().await?;
304    /// # std::io::Result::Ok(()) });
305    /// ```
306    pub async fn writable(&self) -> io::Result<()> {
307        poll_fn(|cx| self.poll_writable(cx)).await
308    }
309
310    /// Polls the I/O handle for readability.
311    ///
312    /// When this method returns [`Poll::Ready`], that means the OS has delivered an event
313    /// indicating readability since the last time this task has called the method and received
314    /// [`Poll::Pending`].
315    ///
316    /// # Caveats
317    ///
318    /// Two different tasks should not call this method concurrently. Otherwise, conflicting tasks
319    /// will just keep waking each other in turn, thus wasting CPU time.
320    ///
321    /// Note that the [`AsyncRead`] implementation for [`Async`] also uses this method.
322    ///
323    /// # Examples
324    ///
325    /// ```no_run
326    /// use async_io_mini::Async;
327    /// use futures_lite::future;
328    /// use std::net::TcpListener;
329    ///
330    /// # futures_lite::future::block_on(async {
331    /// let mut listener = Async::<TcpListener>::bind(([127, 0, 0, 1], 0))?;
332    ///
333    /// // Wait until a client can be accepted.
334    /// future::poll_fn(|cx| listener.poll_readable(cx)).await?;
335    /// # std::io::Result::Ok(()) });
336    /// ```
337    pub fn poll_readable(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
338        if REACTOR.fetch_or_set(self.as_fd().as_raw_fd(), Event::Read, cx.waker())? {
339            Poll::Ready(Ok(()))
340        } else {
341            Poll::Pending
342        }
343    }
344
345    /// Polls the I/O handle for writability.
346    ///
347    /// When this method returns [`Poll::Ready`], that means the OS has delivered an event
348    /// indicating writability since the last time this task has called the method and received
349    /// [`Poll::Pending`].
350    ///
351    /// # Caveats
352    ///
353    /// Two different tasks should not call this method concurrently. Otherwise, conflicting tasks
354    /// will just keep waking each other in turn, thus wasting CPU time.
355    ///
356    /// Note that the [`AsyncWrite`] implementation for [`Async`] also uses this method.
357    ///
358    /// # Examples
359    ///
360    /// ```
361    /// use async_io_mini::Async;
362    /// use futures_lite::future;
363    /// use std::net::{TcpStream, ToSocketAddrs};
364    ///
365    /// # futures_lite::future::block_on(async {
366    /// let addr = "example.com:80".to_socket_addrs()?.next().unwrap();
367    /// let stream = Async::<TcpStream>::connect(addr).await?;
368    ///
369    /// // Wait until the stream is writable.
370    /// future::poll_fn(|cx| stream.poll_writable(cx)).await?;
371    /// # std::io::Result::Ok(()) });
372    /// ```
373    pub fn poll_writable(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
374        if REACTOR.fetch_or_set(self.as_fd().as_raw_fd(), Event::Write, cx.waker())? {
375            Poll::Ready(Ok(()))
376        } else {
377            Poll::Pending
378        }
379    }
380
381    /// Performs a read operation asynchronously.
382    ///
383    /// The I/O handle is registered in the reactor and put in non-blocking mode. This method
384    /// invokes the `op` closure in a loop until it succeeds or returns an error other than
385    /// [`io::ErrorKind::WouldBlock`]. In between iterations of the loop, it waits until the OS
386    /// sends a notification that the I/O handle is readable.
387    ///
388    /// The closure receives a shared reference to the I/O handle.
389    ///
390    /// # Examples
391    ///
392    /// ```no_run
393    /// use async_io_mini::Async;
394    /// use std::net::TcpListener;
395    ///
396    /// # futures_lite::future::block_on(async {
397    /// let listener = Async::<TcpListener>::bind(([127, 0, 0, 1], 0))?;
398    ///
399    /// // Accept a new client asynchronously.
400    /// let (stream, addr) = listener.read_with(|l| l.accept()).await?;
401    /// # std::io::Result::Ok(()) });
402    /// ```
403    pub async fn read_with<R>(&self, op: impl FnMut(&T) -> io::Result<R>) -> io::Result<R> {
404        REACTOR.fetch(self.as_fd().as_raw_fd(), Event::Read)?;
405
406        let mut op = op;
407        loop {
408            match op(self.get_ref()) {
409                Err(err) if err.kind() == io::ErrorKind::WouldBlock => {}
410                res => return res,
411            }
412            optimistic(self.readable()).await?;
413        }
414    }
415
416    /// Performs a read operation asynchronously.
417    ///
418    /// The I/O handle is registered in the reactor and put in non-blocking mode. This method
419    /// invokes the `op` closure in a loop until it succeeds or returns an error other than
420    /// [`io::ErrorKind::WouldBlock`]. In between iterations of the loop, it waits until the OS
421    /// sends a notification that the I/O handle is readable.
422    ///
423    /// The closure receives a mutable reference to the I/O handle.
424    ///
425    /// # Safety
426    ///
427    /// In the closure, the underlying I/O source must not be dropped.
428    ///
429    /// # Examples
430    ///
431    /// ```no_run
432    /// use async_io_mini::Async;
433    /// use std::net::TcpListener;
434    ///
435    /// # futures_lite::future::block_on(async {
436    /// let mut listener = Async::<TcpListener>::bind(([127, 0, 0, 1], 0))?;
437    ///
438    /// // Accept a new client asynchronously.
439    /// let (stream, addr) = unsafe { listener.read_with_mut(|l| l.accept()).await? };
440    /// # std::io::Result::Ok(()) });
441    /// ```
442    pub async unsafe fn read_with_mut<R>(
443        &mut self,
444        op: impl FnMut(&mut T) -> io::Result<R>,
445    ) -> io::Result<R> {
446        REACTOR.fetch(self.as_fd().as_raw_fd(), Event::Read)?;
447
448        let mut op = op;
449        loop {
450            match op(self.get_mut()) {
451                Err(err) if err.kind() == io::ErrorKind::WouldBlock => {}
452                res => return res,
453            }
454            optimistic(self.readable()).await?;
455        }
456    }
457
458    /// Performs a write operation asynchronously.
459    ///
460    /// The I/O handle is registered in the reactor and put in non-blocking mode. This method
461    /// invokes the `op` closure in a loop until it succeeds or returns an error other than
462    /// [`io::ErrorKind::WouldBlock`]. In between iterations of the loop, it waits until the OS
463    /// sends a notification that the I/O handle is writable.
464    ///
465    /// The closure receives a shared reference to the I/O handle.
466    ///
467    /// # Examples
468    ///
469    /// ```no_run
470    /// use async_io_mini::Async;
471    /// use std::net::UdpSocket;
472    ///
473    /// # futures_lite::future::block_on(async {
474    /// let socket = Async::<UdpSocket>::bind(([127, 0, 0, 1], 8000))?;
475    /// socket.get_ref().connect("127.0.0.1:9000")?;
476    ///
477    /// let msg = b"hello";
478    /// let len = socket.write_with(|s| s.send(msg)).await?;
479    /// # std::io::Result::Ok(()) });
480    /// ```
481    pub async fn write_with<R>(&self, op: impl FnMut(&T) -> io::Result<R>) -> io::Result<R> {
482        REACTOR.fetch(self.as_fd().as_raw_fd(), Event::Write)?;
483
484        let mut op = op;
485        loop {
486            match op(self.get_ref()) {
487                Err(err) if err.kind() == io::ErrorKind::WouldBlock => {}
488                res => return res,
489            }
490            optimistic(self.writable()).await?;
491        }
492    }
493
494    /// Performs a write operation asynchronously.
495    ///
496    /// The I/O handle is registered in the reactor and put in non-blocking mode. This method
497    /// invokes the `op` closure in a loop until it succeeds or returns an error other than
498    /// [`io::ErrorKind::WouldBlock`]. In between iterations of the loop, it waits until the OS
499    /// sends a notification that the I/O handle is writable.
500    ///
501    /// # Safety
502    ///
503    /// The closure receives a mutable reference to the I/O handle. In the closure, the underlying
504    /// I/O source must not be dropped.
505    ///
506    /// # Examples
507    ///
508    /// ```no_run
509    /// use async_io_mini::Async;
510    /// use std::net::UdpSocket;
511    ///
512    /// # futures_lite::future::block_on(async {
513    /// let mut socket = Async::<UdpSocket>::bind(([127, 0, 0, 1], 8000))?;
514    /// socket.get_ref().connect("127.0.0.1:9000")?;
515    ///
516    /// let msg = b"hello";
517    /// let len = unsafe { socket.write_with_mut(|s| s.send(msg)).await? };
518    /// # std::io::Result::Ok(()) });
519    /// ```
520    pub async unsafe fn write_with_mut<R>(
521        &mut self,
522        op: impl FnMut(&mut T) -> io::Result<R>,
523    ) -> io::Result<R> {
524        REACTOR.fetch(self.as_fd().as_raw_fd(), Event::Write)?;
525
526        let mut op = op;
527        loop {
528            match op(self.get_mut()) {
529                Err(err) if err.kind() == io::ErrorKind::WouldBlock => {}
530                res => return res,
531            }
532            optimistic(self.writable()).await?;
533        }
534    }
535}
536
537impl<T: AsFd> AsRef<T> for Async<T> {
538    fn as_ref(&self) -> &T {
539        self.io.as_ref().unwrap()
540    }
541}
542
543impl<T: AsFd> Drop for Async<T> {
544    fn drop(&mut self) {
545        if let Some(io) = &self.io {
546            REACTOR.deregister(io.as_fd().as_raw_fd()).ok();
547        }
548    }
549}
550
/// Types whose I/O trait implementations do not drop the underlying I/O source.
///
/// The resource contained inside of the [`Async`] cannot be invalidated. This invalidation can
/// happen if the inner resource (the [`TcpStream`], [`UdpSocket`] or other `T`) is moved out
/// and dropped before the [`Async`]. Because of this, functions that grant mutable access to
/// the inner type are unsafe, as there is no way to guarantee that the source won't be dropped
/// and a dangling handle won't be left behind.
///
/// Unfortunately this extends to implementations of [`Read`] and [`Write`]. Since methods on those
/// traits take `&mut`, there is no guarantee that the implementor of those traits won't move the
/// source out while the method is being run.
///
/// This trait is an antidote to this predicament. By implementing this trait, the user pledges
/// that using any I/O traits won't destroy the source. This way, [`Async`] can implement the
/// `async` version of these I/O traits, like [`AsyncRead`] and [`AsyncWrite`].
///
/// # Safety
///
/// Any I/O trait implementations for this type must not drop the underlying I/O source. Traits
/// affected by this trait include [`Read`], [`Write`], [`Seek`] and [`BufRead`].
///
/// This trait is implemented by default on top of `libstd` types. In addition, it is implemented
/// for immutable reference types, as it is impossible to invalidate any outstanding references
/// while holding an immutable reference, even with interior mutability. Rust's current pinning
/// system relies on similar guarantees, which is why this approach is believed to be robust.
///
/// [`BufRead`]: https://doc.rust-lang.org/std/io/trait.BufRead.html
/// [`Read`]: https://doc.rust-lang.org/std/io/trait.Read.html
/// [`Seek`]: https://doc.rust-lang.org/std/io/trait.Seek.html
/// [`Write`]: https://doc.rust-lang.org/std/io/trait.Write.html
///
/// [`AsyncRead`]: https://docs.rs/futures-io/latest/futures_io/trait.AsyncRead.html
/// [`AsyncWrite`]: https://docs.rs/futures-io/latest/futures_io/trait.AsyncWrite.html
pub unsafe trait IoSafe {}
585
586/// Reference types can't be mutated.
587///
588/// The worst thing that can happen is that external state is used to change what kind of pointer
589/// `as_fd()` returns. For instance:
590///
591/// ```
592/// # #[cfg(unix)] {
593/// use std::cell::Cell;
594/// use std::net::TcpStream;
595/// use std::os::unix::io::{AsFd, BorrowedFd};
596///
597/// struct Bar {
598///     flag: Cell<bool>,
599///     a: TcpStream,
600///     b: TcpStream
601/// }
602///
603/// impl AsFd for Bar {
604///     fn as_fd(&self) -> BorrowedFd<'_> {
605///         if self.flag.replace(!self.flag.get()) {
606///             self.a.as_fd()
607///         } else {
608///             self.b.as_fd()
609///         }
610///     }
611/// }
612/// # }
613/// ```
614///
615/// We solve this problem by only calling `as_fd()` once to get the original source. Implementations
616/// like this are considered buggy (but not unsound) and are thus not really supported by `async-io`.
617unsafe impl<T: ?Sized> IoSafe for &T {}
618
619// Can be implemented on top of libstd types.
620unsafe impl IoSafe for std::fs::File {}
621unsafe impl IoSafe for std::io::Stderr {}
622unsafe impl IoSafe for std::io::Stdin {}
623unsafe impl IoSafe for std::io::Stdout {}
624unsafe impl IoSafe for std::io::StderrLock<'_> {}
625unsafe impl IoSafe for std::io::StdinLock<'_> {}
626unsafe impl IoSafe for std::io::StdoutLock<'_> {}
627unsafe impl IoSafe for std::net::TcpStream {}
628unsafe impl IoSafe for std::process::ChildStdin {}
629unsafe impl IoSafe for std::process::ChildStdout {}
630unsafe impl IoSafe for std::process::ChildStderr {}
631
632unsafe impl<T: IoSafe + Read> IoSafe for std::io::BufReader<T> {}
633unsafe impl<T: IoSafe + Write> IoSafe for std::io::BufWriter<T> {}
634unsafe impl<T: IoSafe + Write> IoSafe for std::io::LineWriter<T> {}
635unsafe impl<T: IoSafe + ?Sized> IoSafe for &mut T {}
636//unsafe impl<T: IoSafe + ?Sized> IoSafe for alloc::boxed::Box<T> {}
637unsafe impl<T: Clone + IoSafe + ?Sized> IoSafe for std::borrow::Cow<'_, T> {}
638
/// Non-blocking reads through a uniquely borrowed [`Async`].
///
/// `T: IoSafe` is required because reading goes through [`Async::get_mut()`], which is only
/// sound when `T`'s `Read` implementation never drops the underlying I/O source.
#[cfg(feature = "futures-io")]
impl<T: AsFd + IoSafe + Read> futures_io::AsyncRead for Async<T> {
    fn poll_read(
        mut self: core::pin::Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        // `Async` is `Unpin`, so the pin can be dereferenced mutably.
        let this = &mut *self;
        loop {
            // SAFETY: `T: IoSafe` promises that `read` does not drop the I/O source.
            let attempt = unsafe { this.get_mut() }.read(buf);
            match attempt {
                Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
                    // Not ready: park until readable (or consume a pending event), then retry.
                    ready!(this.poll_readable(cx))?
                }
                done => return Poll::Ready(done),
            }
        }
    }

    fn poll_read_vectored(
        mut self: core::pin::Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &mut [std::io::IoSliceMut<'_>],
    ) -> Poll<io::Result<usize>> {
        let this = &mut *self;
        loop {
            // SAFETY: `T: IoSafe` promises that `read_vectored` does not drop the I/O source.
            let attempt = unsafe { this.get_mut() }.read_vectored(bufs);
            match attempt {
                Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
                    ready!(this.poll_readable(cx))?
                }
                done => return Poll::Ready(done),
            }
        }
    }
}
669
// Reading through a shared reference can never move the inner I/O source out, so no `IoSafe`
// bound (and no `unsafe`) is needed here.
#[cfg(feature = "futures-io")]
impl<T: AsFd> futures_io::AsyncRead for &Async<T>
where
    for<'a> &'a T: Read,
{
    fn poll_read(
        self: core::pin::Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        let this: &Async<T> = *self;
        loop {
            let attempt = this.get_ref().read(buf);
            match attempt {
                Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
                    // Not ready: park until readable (or consume a pending event), then retry.
                    ready!(this.poll_readable(cx))?
                }
                done => return Poll::Ready(done),
            }
        }
    }

    fn poll_read_vectored(
        self: core::pin::Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &mut [std::io::IoSliceMut<'_>],
    ) -> Poll<io::Result<usize>> {
        let this: &Async<T> = *self;
        loop {
            let attempt = this.get_ref().read_vectored(bufs);
            match attempt {
                Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
                    ready!(this.poll_readable(cx))?
                }
                done => return Poll::Ready(done),
            }
        }
    }
}
705
/// Non-blocking writes through a uniquely borrowed [`Async`].
///
/// `T: IoSafe` is required because writing goes through [`Async::get_mut()`], which is only
/// sound when `T`'s `Write` implementation never drops the underlying I/O source.
#[cfg(feature = "futures-io")]
impl<T: AsFd + IoSafe + Write> futures_io::AsyncWrite for Async<T> {
    fn poll_write(
        mut self: core::pin::Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        // `Async` is `Unpin`, so the pin can be dereferenced mutably.
        let this = &mut *self;
        loop {
            // SAFETY: `T: IoSafe` promises that `write` does not drop the I/O source.
            let attempt = unsafe { this.get_mut() }.write(buf);
            match attempt {
                Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
                    // Not ready: park until writable (or consume a pending event), then retry.
                    ready!(this.poll_writable(cx))?
                }
                done => return Poll::Ready(done),
            }
        }
    }

    fn poll_write_vectored(
        mut self: core::pin::Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[std::io::IoSlice<'_>],
    ) -> Poll<io::Result<usize>> {
        let this = &mut *self;
        loop {
            // SAFETY: `T: IoSafe` promises that `write_vectored` does not drop the I/O source.
            let attempt = unsafe { this.get_mut() }.write_vectored(bufs);
            match attempt {
                Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
                    ready!(this.poll_writable(cx))?
                }
                done => return Poll::Ready(done),
            }
        }
    }

    fn poll_flush(
        mut self: core::pin::Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<io::Result<()>> {
        let this = &mut *self;
        loop {
            // SAFETY: `T: IoSafe` promises that `flush` does not drop the I/O source.
            let attempt = unsafe { this.get_mut() }.flush();
            match attempt {
                Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
                    ready!(this.poll_writable(cx))?
                }
                done => return Poll::Ready(done),
            }
        }
    }

    /// Closing only flushes; use `shutdown` on the socket to close the write half.
    fn poll_close(self: core::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        futures_io::AsyncWrite::poll_flush(self, cx)
    }
}
753
/// Non-blocking writes through a shared reference to [`Async`].
///
/// A shared reference can never move the inner I/O source out, so no `IoSafe` bound is needed.
#[cfg(feature = "futures-io")]
impl<T: AsFd> futures_io::AsyncWrite for &Async<T>
where
    for<'a> &'a T: Write,
{
    fn poll_write(
        self: core::pin::Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let this: &Async<T> = *self;
        loop {
            let attempt = this.get_ref().write(buf);
            match attempt {
                Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
                    // Not ready: park until writable (or consume a pending event), then retry.
                    ready!(this.poll_writable(cx))?
                }
                done => return Poll::Ready(done),
            }
        }
    }

    fn poll_write_vectored(
        self: core::pin::Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[std::io::IoSlice<'_>],
    ) -> Poll<io::Result<usize>> {
        let this: &Async<T> = *self;
        loop {
            let attempt = this.get_ref().write_vectored(bufs);
            match attempt {
                Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
                    ready!(this.poll_writable(cx))?
                }
                done => return Poll::Ready(done),
            }
        }
    }

    fn poll_flush(self: core::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let this: &Async<T> = *self;
        loop {
            let attempt = this.get_ref().flush();
            match attempt {
                Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
                    ready!(this.poll_writable(cx))?
                }
                done => return Poll::Ready(done),
            }
        }
    }

    /// Closing only flushes; use `shutdown` on the socket to close the write half.
    fn poll_close(self: core::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        futures_io::AsyncWrite::poll_flush(self, cx)
    }
}
801
802impl Async<TcpListener> {
803    /// Creates a TCP listener bound to the specified address.
804    ///
805    /// Binding with port number 0 will request an available port from the OS.
806    ///
807    /// # Examples
808    ///
809    /// ```
810    /// use async_io_mini::Async;
811    /// use std::net::TcpListener;
812    ///
813    /// # futures_lite::future::block_on(async {
814    /// let listener = Async::<TcpListener>::bind(([127, 0, 0, 1], 0))?;
815    /// println!("Listening on {}", listener.get_ref().local_addr()?);
816    /// # std::io::Result::Ok(()) });
817    /// ```
818    pub fn bind<A: Into<SocketAddr>>(addr: A) -> io::Result<Async<TcpListener>> {
819        let addr = addr.into();
820        Async::new(TcpListener::bind(addr)?)
821    }
822
823    /// Accepts a new incoming TCP connection.
824    ///
825    /// When a connection is established, it will be returned as a TCP stream together with its
826    /// remote address.
827    ///
828    /// # Examples
829    ///
830    /// ```no_run
831    /// use async_io_mini::Async;
832    /// use std::net::TcpListener;
833    ///
834    /// # futures_lite::future::block_on(async {
835    /// let listener = Async::<TcpListener>::bind(([127, 0, 0, 1], 8000))?;
836    /// let (stream, addr) = listener.accept().await?;
837    /// println!("Accepted client: {}", addr);
838    /// # std::io::Result::Ok(()) });
839    /// ```
840    pub async fn accept(&self) -> io::Result<(Async<TcpStream>, SocketAddr)> {
841        let (stream, addr) = self.read_with(|io| io.accept()).await?;
842        Ok((Async::new(stream)?, addr))
843    }
844
845    /// Returns a stream of incoming TCP connections.
846    ///
847    /// The stream is infinite, i.e. it never stops with a [`None`].
848    ///
849    /// # Examples
850    ///
851    /// ```no_run
852    /// use async_io_mini::Async;
853    /// use futures_lite::{pin, stream::StreamExt};
854    /// use std::net::TcpListener;
855    ///
856    /// # futures_lite::future::block_on(async {
857    /// let listener = Async::<TcpListener>::bind(([127, 0, 0, 1], 8000))?;
858    /// let incoming = listener.incoming();
859    /// pin!(incoming);
860    ///
861    /// while let Some(stream) = incoming.next().await {
862    ///     let stream = stream?;
863    ///     println!("Accepted client: {}", stream.get_ref().peer_addr()?);
864    /// }
865    /// # std::io::Result::Ok(()) });
866    /// ```
867    #[cfg(feature = "futures-lite")]
868    pub fn incoming(
869        &self,
870    ) -> impl futures_lite::Stream<Item = io::Result<Async<TcpStream>>> + Send + '_ {
871        futures_lite::stream::unfold(self, |listener| async move {
872            let res = listener.accept().await.map(|(stream, _)| stream);
873            Some((res, listener))
874        })
875    }
876}
877
878impl TryFrom<std::net::TcpListener> for Async<std::net::TcpListener> {
879    type Error = io::Error;
880
881    fn try_from(listener: std::net::TcpListener) -> io::Result<Self> {
882        Async::new(listener)
883    }
884}
885
886impl Async<TcpStream> {
887    /// Creates a TCP connection to the specified address.
888    ///
889    /// # Examples
890    ///
891    /// ```
892    /// use async_io_mini::Async;
893    /// use std::net::{TcpStream, ToSocketAddrs};
894    ///
895    /// # futures_lite::future::block_on(async {
896    /// let addr = "example.com:80".to_socket_addrs()?.next().unwrap();
897    /// let stream = Async::<TcpStream>::connect(addr).await?;
898    /// # std::io::Result::Ok(()) });
899    /// ```
900    pub async fn connect<A: Into<SocketAddr>>(addr: A) -> io::Result<Async<TcpStream>> {
901        // Figure out how to handle this address.
902        let addr = addr.into();
903
904        let socket = match addr {
905            SocketAddr::V4(v4) => {
906                let addr = sys::sockaddr_in {
907                    sin_family: sys::AF_INET as _,
908                    sin_port: u16::to_be(v4.port()),
909                    sin_addr: sys::in_addr {
910                        s_addr: u32::from_ne_bytes(v4.ip().octets()),
911                    },
912                    #[cfg(target_os = "espidf")]
913                    sin_len: Default::default(),
914                    sin_zero: Default::default(),
915                };
916
917                connect(
918                    &addr as *const _ as *const _,
919                    core::mem::size_of_val(&addr),
920                    sys::AF_INET,
921                    sys::SOCK_STREAM,
922                    0,
923                )
924            }
925            SocketAddr::V6(v6) => {
926                let addr = sys::sockaddr_in6 {
927                    sin6_family: sys::AF_INET6 as _,
928                    sin6_port: u16::to_be(v6.port()),
929                    sin6_flowinfo: 0,
930                    sin6_addr: sys::in6_addr {
931                        s6_addr: v6.ip().octets(),
932                    },
933                    sin6_scope_id: 0,
934                    #[cfg(target_os = "espidf")]
935                    sin6_len: Default::default(),
936                };
937
938                connect(
939                    &addr as *const _ as *const _,
940                    core::mem::size_of_val(&addr),
941                    sys::AF_INET6,
942                    sys::SOCK_STREAM,
943                    6,
944                )
945            }
946        }?;
947
948        // Use new_nonblocking because connect already sets socket to non-blocking mode.
949        let stream = Async::new_nonblocking(TcpStream::from(socket))?;
950
951        // The stream becomes writable when connected.
952        stream.writable().await?;
953
954        // Check if there was an error while connecting.
955        match stream.get_ref().take_error()? {
956            None => Ok(stream),
957            Some(err) => Err(err),
958        }
959    }
960
961    /// Reads data from the stream without removing it from the buffer.
962    ///
963    /// Returns the number of bytes read. Successive calls of this method read the same data.
964    ///
965    /// # Examples
966    ///
967    /// ```
968    /// use async_io_mini::Async;
969    /// use futures_lite::{io::AsyncWriteExt, stream::StreamExt};
970    /// use std::net::{TcpStream, ToSocketAddrs};
971    ///
972    /// # futures_lite::future::block_on(async {
973    /// let addr = "example.com:80".to_socket_addrs()?.next().unwrap();
974    /// let mut stream = Async::<TcpStream>::connect(addr).await?;
975    ///
976    /// stream
977    ///     .write_all(b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n")
978    ///     .await?;
979    ///
980    /// let mut buf = [0u8; 1024];
981    /// let len = stream.peek(&mut buf).await?;
982    /// # std::io::Result::Ok(()) });
983    /// ```
984    pub async fn peek(&self, buf: &mut [u8]) -> io::Result<usize> {
985        self.read_with(|io| io.peek(buf)).await
986    }
987}
988
989impl TryFrom<std::net::TcpStream> for Async<std::net::TcpStream> {
990    type Error = io::Error;
991
992    fn try_from(stream: std::net::TcpStream) -> io::Result<Self> {
993        Async::new(stream)
994    }
995}
996
997impl Async<UdpSocket> {
998    /// Creates a UDP socket bound to the specified address.
999    ///
1000    /// Binding with port number 0 will request an available port from the OS.
1001    ///
1002    /// # Examples
1003    ///
1004    /// ```
1005    /// use async_io_mini::Async;
1006    /// use std::net::UdpSocket;
1007    ///
1008    /// # futures_lite::future::block_on(async {
1009    /// let socket = Async::<UdpSocket>::bind(([127, 0, 0, 1], 0))?;
1010    /// println!("Bound to {}", socket.get_ref().local_addr()?);
1011    /// # std::io::Result::Ok(()) });
1012    /// ```
1013    pub fn bind<A: Into<SocketAddr>>(addr: A) -> io::Result<Async<UdpSocket>> {
1014        let addr = addr.into();
1015        Async::new(UdpSocket::bind(addr)?)
1016    }
1017
1018    /// Receives a single datagram message.
1019    ///
1020    /// Returns the number of bytes read and the address the message came from.
1021    ///
1022    /// This method must be called with a valid byte slice of sufficient size to hold the message.
1023    /// If the message is too long to fit, excess bytes may get discarded.
1024    ///
1025    /// # Examples
1026    ///
1027    /// ```no_run
1028    /// use async_io_mini::Async;
1029    /// use std::net::UdpSocket;
1030    ///
1031    /// # futures_lite::future::block_on(async {
1032    /// let socket = Async::<UdpSocket>::bind(([127, 0, 0, 1], 8000))?;
1033    ///
1034    /// let mut buf = [0u8; 1024];
1035    /// let (len, addr) = socket.recv_from(&mut buf).await?;
1036    /// # std::io::Result::Ok(()) });
1037    /// ```
1038    pub async fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
1039        self.read_with(|io| io.recv_from(buf)).await
1040    }
1041
1042    /// Receives a single datagram message without removing it from the queue.
1043    ///
1044    /// Returns the number of bytes read and the address the message came from.
1045    ///
1046    /// This method must be called with a valid byte slice of sufficient size to hold the message.
1047    /// If the message is too long to fit, excess bytes may get discarded.
1048    ///
1049    /// # Examples
1050    ///
1051    /// ```no_run
1052    /// use async_io_mini::Async;
1053    /// use std::net::UdpSocket;
1054    ///
1055    /// # futures_lite::future::block_on(async {
1056    /// let socket = Async::<UdpSocket>::bind(([127, 0, 0, 1], 8000))?;
1057    ///
1058    /// let mut buf = [0u8; 1024];
1059    /// let (len, addr) = socket.peek_from(&mut buf).await?;
1060    /// # std::io::Result::Ok(()) });
1061    /// ```
1062    pub async fn peek_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
1063        self.read_with(|io| io.peek_from(buf)).await
1064    }
1065
1066    /// Sends data to the specified address.
1067    ///
1068    /// Returns the number of bytes writen.
1069    ///
1070    /// # Examples
1071    ///
1072    /// ```no_run
1073    /// use async_io_mini::Async;
1074    /// use std::net::UdpSocket;
1075    ///
1076    /// # futures_lite::future::block_on(async {
1077    /// let socket = Async::<UdpSocket>::bind(([127, 0, 0, 1], 0))?;
1078    /// let addr = socket.get_ref().local_addr()?;
1079    ///
1080    /// let msg = b"hello";
1081    /// let len = socket.send_to(msg, addr).await?;
1082    /// # std::io::Result::Ok(()) });
1083    /// ```
1084    pub async fn send_to<A: Into<SocketAddr>>(&self, buf: &[u8], addr: A) -> io::Result<usize> {
1085        let addr = addr.into();
1086        self.write_with(|io| io.send_to(buf, addr)).await
1087    }
1088
1089    /// Receives a single datagram message from the connected peer.
1090    ///
1091    /// Returns the number of bytes read.
1092    ///
1093    /// This method must be called with a valid byte slice of sufficient size to hold the message.
1094    /// If the message is too long to fit, excess bytes may get discarded.
1095    ///
1096    /// The [`connect`][`UdpSocket::connect()`] method connects this socket to a remote address.
1097    /// This method will fail if the socket is not connected.
1098    ///
1099    /// # Examples
1100    ///
1101    /// ```no_run
1102    /// use async_io_mini::Async;
1103    /// use std::net::UdpSocket;
1104    ///
1105    /// # futures_lite::future::block_on(async {
1106    /// let socket = Async::<UdpSocket>::bind(([127, 0, 0, 1], 8000))?;
1107    /// socket.get_ref().connect("127.0.0.1:9000")?;
1108    ///
1109    /// let mut buf = [0u8; 1024];
1110    /// let len = socket.recv(&mut buf).await?;
1111    /// # std::io::Result::Ok(()) });
1112    /// ```
1113    pub async fn recv(&self, buf: &mut [u8]) -> io::Result<usize> {
1114        self.read_with(|io| io.recv(buf)).await
1115    }
1116
1117    /// Receives a single datagram message from the connected peer without removing it from the
1118    /// queue.
1119    ///
1120    /// Returns the number of bytes read and the address the message came from.
1121    ///
1122    /// This method must be called with a valid byte slice of sufficient size to hold the message.
1123    /// If the message is too long to fit, excess bytes may get discarded.
1124    ///
1125    /// The [`connect`][`UdpSocket::connect()`] method connects this socket to a remote address.
1126    /// This method will fail if the socket is not connected.
1127    ///
1128    /// # Examples
1129    ///
1130    /// ```no_run
1131    /// use async_io_mini::Async;
1132    /// use std::net::UdpSocket;
1133    ///
1134    /// # futures_lite::future::block_on(async {
1135    /// let socket = Async::<UdpSocket>::bind(([127, 0, 0, 1], 8000))?;
1136    /// socket.get_ref().connect("127.0.0.1:9000")?;
1137    ///
1138    /// let mut buf = [0u8; 1024];
1139    /// let len = socket.peek(&mut buf).await?;
1140    /// # std::io::Result::Ok(()) });
1141    /// ```
1142    pub async fn peek(&self, buf: &mut [u8]) -> io::Result<usize> {
1143        self.read_with(|io| io.peek(buf)).await
1144    }
1145
1146    /// Sends data to the connected peer.
1147    ///
1148    /// Returns the number of bytes written.
1149    ///
1150    /// The [`connect`][`UdpSocket::connect()`] method connects this socket to a remote address.
1151    /// This method will fail if the socket is not connected.
1152    ///
1153    /// # Examples
1154    ///
1155    /// ```no_run
1156    /// use async_io_mini::Async;
1157    /// use std::net::UdpSocket;
1158    ///
1159    /// # futures_lite::future::block_on(async {
1160    /// let socket = Async::<UdpSocket>::bind(([127, 0, 0, 1], 8000))?;
1161    /// socket.get_ref().connect("127.0.0.1:9000")?;
1162    ///
1163    /// let msg = b"hello";
1164    /// let len = socket.send(msg).await?;
1165    /// # std::io::Result::Ok(()) });
1166    /// ```
1167    pub async fn send(&self, buf: &[u8]) -> io::Result<usize> {
1168        self.write_with(|io| io.send(buf)).await
1169    }
1170}
1171
1172impl TryFrom<std::net::UdpSocket> for Async<std::net::UdpSocket> {
1173    type Error = io::Error;
1174
1175    fn try_from(socket: std::net::UdpSocket) -> io::Result<Self> {
1176        Async::new(socket)
1177    }
1178}
1179
/// Gives `fut` exactly one poll, then optimistically reports success.
///
/// If that first poll completes, its result is returned unchanged. Otherwise the returned
/// future stays pending until it is polled again (normally after a wakeup) and then resolves
/// to `Ok(())` without polling `fut` a second time. `fut` itself is kept alive (pinned)
/// until the returned future is dropped.
async fn optimistic(fut: impl Future<Output = io::Result<()>>) -> io::Result<()> {
    let mut inner = pin!(fut);
    let mut first_poll = true;

    poll_fn(move |cx| {
        // `replace` both reads the flag and clears it, so only the very first poll
        // reaches the inner future.
        if core::mem::replace(&mut first_poll, false) {
            inner.as_mut().poll(cx)
        } else {
            Poll::Ready(Ok(()))
        }
    })
    .await
}
1195
1196fn connect(
1197    addr: *const sys::sockaddr,
1198    addr_len: usize,
1199    domain: sys::c_int,
1200    ty: sys::c_int,
1201    protocol: sys::c_int,
1202) -> io::Result<OwnedFd> {
1203    // Create the socket.
1204    let socket = unsafe { OwnedFd::from_raw_fd(syscall_los!(sys::socket(domain, ty, protocol))?) };
1205
1206    // Set non-blocking mode.
1207    set_nonblocking(socket.as_fd())?;
1208
1209    syscall_los_eagain!(unsafe { sys::connect(socket.as_raw_fd(), addr, addr_len as _) })?;
1210
1211    Ok(socket)
1212}
1213
1214fn set_nonblocking(fd: BorrowedFd) -> io::Result<()> {
1215    let previous = unsafe { sys::fcntl(fd.as_raw_fd(), sys::F_GETFL) };
1216    let new = previous | sys::O_NONBLOCK;
1217    if new != previous {
1218        syscall!(unsafe { sys::fcntl(fd.as_raw_fd(), sys::F_SETFL, new) })?;
1219    }
1220
1221    Ok(())
1222}