#[cfg(unix)]
use std::os::{
fd::{AsFd, AsRawFd, BorrowedFd},
unix::net::UnixStream,
};
use std::{
fs::File,
future::poll_fn,
io::{
self, BufRead, BufReader, BufWriter, ErrorKind, LineWriter, Read, Stderr, StderrLock,
Stdin, StdinLock, Stdout, StdoutLock, Write,
},
marker::PhantomData,
net::{SocketAddr, TcpListener, TcpStream, UdpSocket},
pin::Pin,
process::{ChildStderr, ChildStdin, ChildStdout},
task::{Context, Poll},
};
use futures_core::Stream;
use futures_io::{AsyncBufRead, AsyncRead, AsyncWrite};
use crate::{
reactor::{Interest, Source},
REACTOR,
};
/// Marker for I/O handles whose blocking `Read`/`Write`/`BufRead` methods may be driven by
/// [`Async`]'s `AsyncRead`/`AsyncWrite`/`AsyncBufRead` implementations, which require this
/// bound (on `T` for `Async<T>`, or on `&T` for `&Async<T>`).
///
/// # Safety
///
/// NOTE(review): the contract is not spelled out anywhere in this file. Presumably an
/// implementor promises that its I/O methods never close, replace or otherwise invalidate the
/// file descriptor that was registered with the reactor, so the reactor keeps watching the
/// descriptor that is actually being read from / written to — confirm before adding new impls.
pub unsafe trait IoSafe {}
// Standard-library handles.
unsafe impl IoSafe for File {}
unsafe impl IoSafe for Stderr {}
unsafe impl IoSafe for Stdin {}
unsafe impl IoSafe for Stdout {}
unsafe impl IoSafe for StderrLock<'_> {}
unsafe impl IoSafe for StdinLock<'_> {}
unsafe impl IoSafe for StdoutLock<'_> {}
unsafe impl IoSafe for TcpStream {}
unsafe impl IoSafe for UdpSocket {}
#[cfg(unix)]
unsafe impl IoSafe for UnixStream {}
unsafe impl IoSafe for ChildStdin {}
unsafe impl IoSafe for ChildStderr {}
unsafe impl IoSafe for ChildStdout {}
// Buffering wrappers, boxes and references are `IoSafe` exactly when the type they wrap is.
unsafe impl<T: IoSafe> IoSafe for BufReader<T> {}
unsafe impl<T: IoSafe + Write> IoSafe for BufWriter<T> {}
unsafe impl<T: IoSafe + Write> IoSafe for LineWriter<T> {}
unsafe impl<T: IoSafe + ?Sized> IoSafe for &mut T {}
unsafe impl<T: IoSafe + ?Sized> IoSafe for Box<T> {}
unsafe impl<T: IoSafe + ?Sized> IoSafe for &T {}
struct GuardedSource(Source);
impl Drop for GuardedSource {
fn drop(&mut self) {
if let Err(err) = REACTOR.with(|r| r.deregister_event(self.0)) {
log::error!("Drop failed due to deregistration failure: {err}");
}
}
}
/// Wraps an I/O handle `T` whose file descriptor is registered with `REACTOR`, so that
/// operations returning `WouldBlock` can park the task until the reactor reports readiness.
pub struct Async<T> {
    // Declared before `inner` on purpose: fields are dropped in declaration order, so the
    // source is deregistered from the reactor before `inner` (and possibly its descriptor)
    // is dropped. Do not reorder.
    source: GuardedSource,
    inner: T,
    // `*const ()` makes `Async` neither `Send` nor `Sync`.
    // NOTE(review): presumably because `REACTOR` is thread-local — confirm.
    _phantom: PhantomData<*const ()>,
}
// Nothing in this file projects a pin onto `inner` (it is only accessed through plain `&`/`&mut`),
// so `Async` is `Unpin` for every `T`.
impl<T> Unpin for Async<T> {}
#[cfg(unix)]
impl<T: AsFd> Async<T> {
pub fn without_nonblocking(inner: T) -> io::Result<Self> {
let source = inner.as_fd().as_raw_fd();
unsafe { REACTOR.with(|r| r.register_event(source))? }
Ok(Self {
inner,
source: GuardedSource(source),
_phantom: PhantomData,
})
}
pub fn new(inner: T) -> io::Result<Self> {
set_nonblocking(inner.as_fd())?;
Self::without_nonblocking(inner)
}
}
/// Puts the descriptor `fd` into non-blocking mode.
///
/// On Linux/Android this uses `ioctl(FIONBIO)`. Elsewhere it reads the file status flags and
/// only issues `F_SETFL` when `O_NONBLOCK` is not already set.
#[cfg(unix)]
pub(crate) fn set_nonblocking(fd: BorrowedFd) -> io::Result<()> {
    #[cfg(any(target_os = "linux", target_os = "android"))]
    {
        rustix::io::ioctl_fionbio(fd, true)?;
    }
    #[cfg(not(any(target_os = "linux", target_os = "android")))]
    {
        use rustix::fs::{fcntl_getfl, fcntl_setfl, OFlags};
        let current = fcntl_getfl(fd)?;
        if !current.contains(OFlags::NONBLOCK) {
            fcntl_setfl(fd, current | OFlags::NONBLOCK)?;
        }
    }
    Ok(())
}
impl<T> Async<T> {
    /// Returns a shared reference to the wrapped I/O handle.
    pub fn get_ref(&self) -> &T {
        &self.inner
    }

    /// Unwraps the I/O handle.
    ///
    /// The remaining `GuardedSource` field is dropped when this returns, which deregisters the
    /// source from the reactor. The descriptor's blocking mode is left unchanged.
    pub fn into_inner(self) -> T {
        self.inner
    }

    /// Runs `f` on a shared reference to the handle.
    ///
    /// If `f` fails with `WouldBlock`, the reactor is armed for `interest` with `cx`'s waker and
    /// `Poll::Pending` is returned. Any other outcome of `f` (success or a different error) is
    /// returned as `Poll::Ready`. A failure to arm the reactor is returned as
    /// `Poll::Ready(Err(_))`.
    ///
    /// # Safety
    ///
    /// NOTE(review): the exact contract is not stated in this file; presumably `f` must not
    /// close or replace the registered file descriptor — confirm.
    unsafe fn poll_event<'a, P, F>(
        &'a self,
        interest: Interest,
        cx: &mut Context,
        f: F,
    ) -> Poll<io::Result<P>>
    where
        F: FnOnce(&'a T) -> io::Result<P>,
    {
        match f(&self.inner) {
            Err(err) if err.kind() == ErrorKind::WouldBlock => {
                REACTOR.with(|r| r.enable_event(self.source.0, interest, cx.waker()))?;
                Poll::Pending
            }
            finished => Poll::Ready(finished),
        }
    }

    /// Like [`Self::poll_event`], but hands `f` an exclusive reference to the handle.
    ///
    /// # Safety
    ///
    /// Same as [`Self::poll_event`].
    unsafe fn poll_event_mut<'a, P, F>(
        &'a mut self,
        interest: Interest,
        cx: &mut Context,
        f: F,
    ) -> Poll<io::Result<P>>
    where
        F: FnOnce(&'a mut T) -> io::Result<P>,
    {
        match f(&mut self.inner) {
            Err(err) if err.kind() == ErrorKind::WouldBlock => {
                REACTOR.with(|r| r.enable_event(self.source.0, interest, cx.waker()))?;
                Poll::Pending
            }
            finished => Poll::Ready(finished),
        }
    }

    /// Attempts a read-like operation `f`, waiting for read readiness on `WouldBlock`.
    ///
    /// # Safety
    ///
    /// NOTE(review): presumably `f` must not close or replace the registered descriptor —
    /// confirm.
    pub unsafe fn poll_read_with<'a, P, F>(&'a self, cx: &mut Context, f: F) -> Poll<io::Result<P>>
    where
        F: FnOnce(&'a T) -> io::Result<P>,
    {
        self.poll_event(Interest::Read, cx, f)
    }

    /// Mutable-access variant of [`Self::poll_read_with`].
    ///
    /// # Safety
    ///
    /// Same as [`Self::poll_read_with`].
    pub unsafe fn poll_read_with_mut<'a, P, F>(
        &'a mut self,
        cx: &mut Context,
        f: F,
    ) -> Poll<io::Result<P>>
    where
        F: FnOnce(&'a mut T) -> io::Result<P>,
    {
        self.poll_event_mut(Interest::Read, cx, f)
    }

    /// Attempts a write-like operation `f`, waiting for write readiness on `WouldBlock`.
    ///
    /// # Safety
    ///
    /// NOTE(review): presumably `f` must not close or replace the registered descriptor —
    /// confirm.
    pub unsafe fn poll_write_with<'a, P, F>(&'a self, cx: &mut Context, f: F) -> Poll<io::Result<P>>
    where
        F: FnOnce(&'a T) -> io::Result<P>,
    {
        self.poll_event(Interest::Write, cx, f)
    }

    /// Mutable-access variant of [`Self::poll_write_with`].
    ///
    /// # Safety
    ///
    /// Same as [`Self::poll_write_with`].
    pub unsafe fn poll_write_with_mut<'a, P, F>(
        &'a mut self,
        cx: &mut Context,
        f: F,
    ) -> Poll<io::Result<P>>
    where
        F: FnOnce(&'a mut T) -> io::Result<P>,
    {
        self.poll_event_mut(Interest::Write, cx, f)
    }

    /// Resolves once the reactor reports the source ready for `interest`.
    ///
    /// The first poll only arms the reactor and always returns `Pending`. Later polls ask the
    /// reactor whether the source is ready and re-arm it when it is not.
    async fn wait_for_event_ready(&self, interest: Interest) -> io::Result<()> {
        let mut armed = false;
        poll_fn(|cx| {
            if armed && REACTOR.with(|r| r.is_event_ready(self.source.0, interest)) {
                return Poll::Ready(Ok(()));
            }
            armed = true;
            REACTOR.with(|r| r.enable_event(self.source.0, interest, cx.waker()))?;
            Poll::Pending
        })
        .await
    }

    /// Waits until the reactor reports the handle as writable.
    pub async fn writable(&self) -> io::Result<()> {
        self.wait_for_event_ready(Interest::Write).await
    }

    /// Waits until the reactor reports the handle as readable.
    pub async fn readable(&self) -> io::Result<()> {
        self.wait_for_event_ready(Interest::Read).await
    }
}
impl<T: Read + IoSafe> AsyncRead for Async<T> {
    /// Reads into `buf`, parking the task on the reactor while `inner` reports `WouldBlock`.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        let this = self.get_mut();
        // SAFETY: `T: IoSafe` — relies on that trait's promise about `Read::read` and the
        // registered descriptor.
        unsafe { this.poll_read_with_mut(cx, |inner| inner.read(buf)) }
    }
}
impl<'a, T> AsyncRead for &'a Async<T>
where
    &'a T: Read + IoSafe,
{
    /// Reads into `buf` through a shared handle, parking the task while the read would block.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        let this: &'a Async<T> = *self;
        // SAFETY: `&'a T: IoSafe` — relies on that trait's promise about `Read::read` and the
        // registered descriptor.
        unsafe { this.poll_read_with(cx, |mut inner| inner.read(buf)) }
    }
}
impl<T: Write + IoSafe> AsyncWrite for Async<T> {
    /// Writes `buf`, parking the task on the reactor while `inner` reports `WouldBlock`.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let this = self.get_mut();
        // SAFETY: `T: IoSafe` — relies on that trait's promise about `Write` and the
        // registered descriptor.
        unsafe { this.poll_write_with_mut(cx, |inner| inner.write(buf)) }
    }

    /// Flushes `inner`, waiting for write readiness whenever the flush would block.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        // SAFETY: as in `poll_write`.
        unsafe { this.poll_write_with_mut(cx, |inner| inner.flush()) }
    }

    /// Closing only flushes; it does not shut down or close `inner`.
    fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
        self.poll_flush(cx)
    }
}
impl<'a, T> AsyncWrite for &'a Async<T>
where
    &'a T: Write + IoSafe,
{
    /// Writes `buf` through a shared handle, parking the task while the write would block.
    fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
        let this: &'a Async<T> = *self;
        // SAFETY: `&'a T: IoSafe` — relies on that trait's promise about `Write` and the
        // registered descriptor.
        unsafe { this.poll_write_with(cx, |mut inner| inner.write(buf)) }
    }

    /// Flushes through a shared handle, waiting for write readiness when it would block.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
        let this: &'a Async<T> = *self;
        // SAFETY: as in `poll_write`.
        unsafe { this.poll_write_with(cx, |mut inner| inner.flush()) }
    }

    /// Closing only flushes; it does not shut down or close the handle.
    fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
        self.poll_flush(cx)
    }
}
impl<T: BufRead + IoSafe> AsyncBufRead for Async<T> {
    /// Returns `inner`'s buffered data, waiting for read readiness while filling would block.
    fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<&[u8]>> {
        // SAFETY: `T: IoSafe` — relies on that trait's promise about `BufRead::fill_buf` and
        // the registered descriptor.
        unsafe { self.get_mut().poll_read_with_mut(cx, |inner| inner.fill_buf()) }
    }

    /// Marks `amt` bytes of the buffer as consumed; never touches the reactor.
    fn consume(self: Pin<&mut Self>, amt: usize) {
        let this = self.get_mut();
        BufRead::consume(&mut this.inner, amt);
    }
}
impl Async<TcpListener> {
    /// Binds a listener to `addr` and registers it with the reactor in non-blocking mode.
    pub fn bind<A: Into<SocketAddr>>(addr: A) -> io::Result<Self> {
        let addr: SocketAddr = addr.into();
        let listener = TcpListener::bind(addr)?;
        Async::new(listener)
    }

    /// Polls for one incoming connection. The accepted stream is itself wrapped in `Async`
    /// (and therefore made non-blocking) before being returned.
    fn poll_accept(&self, cx: &mut Context) -> Poll<io::Result<(Async<TcpStream>, SocketAddr)>> {
        // SAFETY: `accept` creates a new descriptor and leaves the listener's own one intact.
        unsafe {
            self.poll_read_with(cx, |listener| {
                let (stream, peer) = listener.accept()?;
                Ok((Async::new(stream)?, peer))
            })
        }
    }

    /// Accepts one incoming connection, returning the stream and the peer's address.
    pub async fn accept(&self) -> io::Result<(Async<TcpStream>, SocketAddr)> {
        poll_fn(|cx| self.poll_accept(cx)).await
    }

    /// Returns a stream of accepted connections (peer addresses are discarded).
    pub fn incoming(&self) -> IncomingTcp<'_> {
        IncomingTcp { listener: self }
    }
}
/// Stream of connections accepted by an [`Async<TcpListener>`]. It never yields `None`;
/// accept errors are yielded as `Some(Err(_))`.
#[must_use = "Streams do nothing unless polled"]
pub struct IncomingTcp<'a> {
    listener: &'a Async<TcpListener>,
}
impl Stream for IncomingTcp<'_> {
    type Item = io::Result<Async<TcpStream>>;

    /// Accepts the next connection, dropping the peer address.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
        match self.listener.poll_accept(cx) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(accepted) => Poll::Ready(Some(accepted.map(|(stream, _peer)| stream))),
        }
    }
}
impl Async<TcpStream> {
    /// Opens a TCP connection to `addr` without blocking the thread.
    ///
    /// A non-blocking socket is created and registered, the connection is initiated, and the
    /// future resolves once the socket becomes writable and the connection is confirmed.
    ///
    /// # Errors
    ///
    /// Returns an error in any of these cases:
    /// - the socket cannot be created or registered;
    /// - `connect` fails immediately;
    /// - the connection attempt fails asynchronously, in which case the error recorded on the
    ///   socket (e.g. `ConnectionRefused`) is returned.
    pub async fn connect<A: Into<SocketAddr>>(addr: A) -> io::Result<Self> {
        let addr = addr.into();
        // `tcp_socket` already creates the socket in non-blocking mode.
        let stream = Async::without_nonblocking(tcp_socket(&addr)?)?;
        connect(&stream.inner, &addr)?;
        // A non-blocking connect signals completion (success or failure) by becoming writable.
        stream.wait_for_event_ready(Interest::Write).await?;
        // Report the real failure stored in SO_ERROR. Without this, the `peer_addr` check below
        // would only surface a generic "not connected" error for e.g. a refused connection.
        if let Some(err) = stream.inner.take_error()? {
            return Err(err);
        }
        // Final confirmation that the socket is actually connected.
        stream.inner.peer_addr()?;
        Ok(stream)
    }

    /// Reads into `buf` without removing the data from the socket's receive queue,
    /// waiting for read readiness while no data is available.
    pub async fn peek(&self, buf: &mut [u8]) -> io::Result<usize> {
        // SAFETY: `peek` only reads from the socket and leaves its descriptor intact.
        poll_fn(|cx| unsafe { self.poll_read_with(cx, |inner| inner.peek(buf)) }).await
    }
}
/// Creates a non-blocking, close-on-exec TCP socket whose address family matches `addr`.
///
/// On Linux/Android both flags are applied atomically at creation time. On other Unix
/// targets they are applied afterwards with `fcntl`.
#[cfg(unix)]
fn tcp_socket(addr: &SocketAddr) -> io::Result<TcpStream> {
    use rustix::net::*;
    let af = match addr {
        SocketAddr::V4(_) => AddressFamily::INET,
        SocketAddr::V6(_) => AddressFamily::INET6,
    };
    let type_ = SocketType::STREAM;
    #[cfg(any(target_os = "linux", target_os = "android"))]
    let socket = socket_with(
        af,
        type_,
        SocketFlags::NONBLOCK | SocketFlags::CLOEXEC,
        None,
    )?;
    #[cfg(not(any(target_os = "linux", target_os = "android")))]
    let socket = {
        let socket = socket_with(af, type_, SocketFlags::empty(), None)?;
        // O_NONBLOCK is a file *status* flag (F_GETFL/F_SETFL).
        set_nonblocking(socket.as_fd())?;
        // Close-on-exec is a file *descriptor* flag: F_SETFL silently ignores O_CLOEXEC, so it
        // must be set through F_GETFD/F_SETFD instead.
        let fd_flags = rustix::io::fcntl_getfd(&socket)?;
        rustix::io::fcntl_setfd(&socket, fd_flags | rustix::io::FdFlags::CLOEXEC)?;
        socket
    };
    Ok(socket.into())
}
/// Starts connecting `tcp` to `addr` without waiting for the connection to complete.
///
/// `EINPROGRESS`/`EWOULDBLOCK` are treated as success: the caller is expected to wait for the
/// socket to become writable. Any other error is returned.
#[cfg(unix)]
fn connect(tcp: &TcpStream, addr: &SocketAddr) -> io::Result<()> {
    use rustix::io::Errno;
    let outcome = rustix::net::connect(tcp.as_fd(), addr);
    match outcome {
        Ok(()) | Err(Errno::INPROGRESS | Errno::WOULDBLOCK) => Ok(()),
        Err(errno) => Err(io::Error::from(errno)),
    }
}
impl Async<UdpSocket> {
    /// Binds a UDP socket to `addr` and registers it with the reactor in non-blocking mode.
    pub fn bind<A: Into<SocketAddr>>(addr: A) -> io::Result<Async<UdpSocket>> {
        let addr: SocketAddr = addr.into();
        let socket = UdpSocket::bind(addr)?;
        Async::new(socket)
    }

    /// Receives one datagram into `buf`, returning its length and sender.
    pub async fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
        let received =
            poll_fn(|cx| unsafe { self.poll_read_with(cx, |sock| sock.recv_from(buf)) });
        received.await
    }

    /// Like [`Self::recv_from`], but leaves the datagram in the receive queue.
    pub async fn peek_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
        let peeked = poll_fn(|cx| unsafe { self.poll_read_with(cx, |sock| sock.peek_from(buf)) });
        peeked.await
    }

    /// Sends `buf` as one datagram to `addr`, returning the number of bytes sent.
    pub async fn send_to<A: Into<SocketAddr>>(&self, buf: &[u8], addr: A) -> io::Result<usize> {
        let target: SocketAddr = addr.into();
        let sent = poll_fn(|cx| unsafe { self.poll_write_with(cx, |sock| sock.send_to(buf, target)) });
        sent.await
    }

    /// Sets the default peer for [`Self::send`] and filters [`Self::recv`] to that peer.
    pub fn connect<A: Into<SocketAddr>>(&self, addr: A) -> io::Result<()> {
        let peer: SocketAddr = addr.into();
        self.get_ref().connect(peer)
    }

    /// Receives one datagram from the connected peer into `buf`.
    pub async fn recv(&self, buf: &mut [u8]) -> io::Result<usize> {
        let received = poll_fn(|cx| unsafe { self.poll_read_with(cx, |sock| sock.recv(buf)) });
        received.await
    }

    /// Like [`Self::recv`], but leaves the datagram in the receive queue.
    pub async fn peek(&self, buf: &mut [u8]) -> io::Result<usize> {
        let peeked = poll_fn(|cx| unsafe { self.poll_read_with(cx, |sock| sock.peek(buf)) });
        peeked.await
    }

    /// Sends `buf` as one datagram to the connected peer.
    pub async fn send(&self, buf: &[u8]) -> io::Result<usize> {
        let sent = poll_fn(|cx| unsafe { self.poll_write_with(cx, |sock| sock.send(buf)) });
        sent.await
    }
}
#[cfg(test)]
mod tests {
    use std::{future::Future, io::stderr, pin::pin, sync::Arc};
    use rustix::pipe::pipe;
    use crate::{block_on, test::MockWaker};
    use super::*;

    /// Dropping an `Async` removes its source from the reactor.
    #[test]
    fn deregister_on_drop() {
        // Unwrap so a registration failure is reported as such. Otherwise an `Err` would be
        // dropped and the test would fail on a misleading "reactor is empty" assertion.
        let io = Async::without_nonblocking(stderr()).unwrap();
        assert!(!REACTOR.with(|r| r.is_empty()));
        drop(io);
        assert!(REACTOR.with(|r| r.is_empty()));
    }

    /// `into_inner` also removes the source from the reactor.
    #[test]
    fn deregister_into_inner() {
        let io = Async::without_nonblocking(stderr()).unwrap();
        assert!(!REACTOR.with(|r| r.is_empty()));
        let _inner = io.into_inner();
        assert!(REACTOR.with(|r| r.is_empty()));
    }

    /// Accept and connect are pending on first poll and both complete under `block_on`.
    #[test]
    fn tcp() {
        let accept_waker = Arc::new(MockWaker::default());
        let connect_waker = Arc::new(MockWaker::default());
        let listener = Async::<TcpListener>::bind(([127, 0, 0, 1], 0)).unwrap();
        let addr = listener.get_ref().local_addr().unwrap();
        let mut accept = pin!(listener.accept());
        assert!(accept
            .as_mut()
            .poll(&mut Context::from_waker(&accept_waker.clone().into()))
            .is_pending());
        let mut connect = pin!(Async::<TcpStream>::connect(addr));
        assert!(connect
            .as_mut()
            .poll(&mut Context::from_waker(&connect_waker.clone().into()))
            .is_pending());
        block_on(async {
            let _accepted = accept.await.unwrap();
            let _connected = connect.await.unwrap();
        });
        // A fresh connection attempt is still pending on its first poll.
        let mut connect = pin!(Async::<TcpStream>::connect(addr));
        assert!(connect
            .as_mut()
            .poll(&mut Context::from_waker(&connect_waker.into()))
            .is_pending());
    }

    /// `writable`/`readable` resolve after the reactor reports readiness on a pipe.
    #[test]
    fn writable_readable() {
        let wr_waker = Arc::new(MockWaker::default());
        let rd_waker = Arc::new(MockWaker::default());
        let (read, write) = pipe().unwrap();
        // `Async::new` switches both ends to non-blocking mode itself.
        let reader = Async::new(read).unwrap();
        let writer = Async::new(write).unwrap();
        let mut writable = pin!(writer.writable());
        assert!(writable
            .as_mut()
            .poll(&mut Context::from_waker(&wr_waker.clone().into()))
            .is_pending());
        REACTOR.with(|r| r.wait()).unwrap();
        assert!(wr_waker.get());
        assert!(writable
            .as_mut()
            .poll(&mut Context::from_waker(&wr_waker.clone().into()))
            .is_ready());
        let mut readable = pin!(reader.readable());
        assert!(readable
            .as_mut()
            .poll(&mut Context::from_waker(&rd_waker.clone().into()))
            .is_pending());
        // SAFETY: writing one byte to the pipe leaves the registered descriptor intact.
        unsafe {
            assert!(writer
                .poll_write_with(&mut Context::from_waker(&wr_waker.clone().into()), |w| {
                    rustix::io::write(w, &[0]).map_err(Into::into)
                })
                .is_ready());
        };
        REACTOR.with(|r| r.wait()).unwrap();
        assert!(rd_waker.get());
        assert!(readable
            .as_mut()
            .poll(&mut Context::from_waker(&rd_waker.clone().into()))
            .is_ready());
    }
}