use core::future::{poll_fn, Future};
use core::pin::pin;
use core::task::{Context, Poll};
use std::io::{self, Read, Write};
use std::net::{SocketAddr, TcpListener, TcpStream, UdpSocket};
use std::os::fd::FromRawFd;
use std::os::unix::io::{AsFd, AsRawFd, BorrowedFd, OwnedFd, RawFd};
use super::reactor::{Event, REACTOR};
use super::sys;
use super::{ready, syscall, syscall_los, syscall_los_eagain};
/// Adapter that drives a non-blocking I/O handle `T` through the reactor.
///
/// The handle's file descriptor is registered with the reactor when the adapter is
/// created. `io` is only `None` after [`Async::into_inner`] has moved the handle out.
#[derive(Debug)]
pub struct Async<T: AsFd> {
    /// The wrapped I/O handle.
    io: Option<T>,
}

// The handle is never pinned structurally (it is only reached through `&`/`&mut`),
// so the adapter can be `Unpin` regardless of whether `T` is.
impl<T> Unpin for Async<T> where T: AsFd {}
impl<T: AsFd> Async<T> {
    /// Wraps `io`, first switching its file descriptor to non-blocking mode.
    ///
    /// # Errors
    ///
    /// Returns an error if the descriptor cannot be made non-blocking, or if starting the
    /// reactor or registering the descriptor with it fails.
    pub fn new(io: T) -> io::Result<Self> {
        let fd = io.as_fd();
        set_nonblocking(fd)?;
        Self::new_nonblocking(io)
    }

    /// Wraps `io`, assuming its descriptor is already in non-blocking mode.
    ///
    /// Starts the reactor (if needed) and registers the handle's raw descriptor with it.
    ///
    /// # Errors
    ///
    /// Returns an error if starting the reactor or registering the descriptor fails.
    pub fn new_nonblocking(io: T) -> io::Result<Self> {
        REACTOR.start()?;
        let raw = io.as_fd().as_raw_fd();
        REACTOR.register(raw)?;
        Ok(Self { io: Some(io) })
    }
}
impl<T> AsRawFd for Async<T>
where
    T: AsFd + AsRawFd,
{
    /// Returns the raw descriptor of the wrapped handle.
    fn as_raw_fd(&self) -> RawFd {
        AsRawFd::as_raw_fd(self.get_ref())
    }
}

impl<T> AsFd for Async<T>
where
    T: AsFd,
{
    /// Borrows the descriptor of the wrapped handle.
    fn as_fd(&self) -> BorrowedFd<'_> {
        AsFd::as_fd(self.get_ref())
    }
}
impl<T: AsFd + From<OwnedFd>> TryFrom<OwnedFd> for Async<T> {
type Error = io::Error;
fn try_from(value: OwnedFd) -> Result<Self, Self::Error> {
Async::new(value.into())
}
}
impl<T: AsFd + Into<OwnedFd>> TryFrom<Async<T>> for OwnedFd {
type Error = io::Error;
fn try_from(value: Async<T>) -> Result<Self, Self::Error> {
value.into_inner().map(Into::into)
}
}
impl<T: AsFd> Async<T> {
    /// Returns a shared reference to the wrapped handle.
    ///
    /// # Panics
    ///
    /// Panics if the handle has already been moved out. In this file that only happens
    /// inside [`Async::into_inner`], which consumes `self`.
    pub fn get_ref(&self) -> &T {
        self.io.as_ref().unwrap()
    }

    /// Returns an exclusive reference to the wrapped handle.
    ///
    /// # Safety
    ///
    /// The caller must not drop or replace the handle through the returned reference
    /// (for example with `mem::replace`). The reactor registration was made for the
    /// original descriptor, but `Drop`/`into_inner` deregister whatever descriptor the
    /// handle reports at that later time.
    pub unsafe fn get_mut(&mut self) -> &mut T {
        self.io.as_mut().unwrap()
    }

    /// Deregisters the handle from the reactor and returns it.
    ///
    /// # Errors
    ///
    /// Returns the reactor's error if deregistration fails. In that case the handle is
    /// dropped together with `self`.
    pub fn into_inner(mut self) -> io::Result<T> {
        let fd = self.as_fd().as_raw_fd();
        REACTOR.deregister(fd)?;
        Ok(self.io.take().unwrap())
    }

    /// Waits until [`Async::poll_readable`] reports readiness.
    pub async fn readable(&self) -> io::Result<()> {
        poll_fn(|cx| self.poll_readable(cx)).await
    }

    /// Waits until [`Async::poll_writable`] reports readiness.
    pub async fn writable(&self) -> io::Result<()> {
        poll_fn(|cx| self.poll_writable(cx)).await
    }

    /// Polls for read readiness.
    ///
    /// Returns `Ready(Ok(()))` when `REACTOR.fetch_or_set` reports the read event for this
    /// descriptor. Otherwise returns `Pending`, having handed `cx`'s waker to the reactor
    /// (presumably to be woken on the next read event; confirm against the reactor).
    pub fn poll_readable(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let fd = self.as_fd().as_raw_fd();
        let ready = REACTOR.fetch_or_set(fd, Event::Read, cx.waker())?;
        if ready {
            Poll::Ready(Ok(()))
        } else {
            Poll::Pending
        }
    }

    /// Polls for write readiness; the write-side counterpart of [`Async::poll_readable`].
    pub fn poll_writable(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let fd = self.as_fd().as_raw_fd();
        let ready = REACTOR.fetch_or_set(fd, Event::Write, cx.waker())?;
        if ready {
            Poll::Ready(Ok(()))
        } else {
            Poll::Pending
        }
    }

    /// Runs `op` on the handle until it returns something other than `WouldBlock`.
    ///
    /// Between attempts it waits for read readiness. `REACTOR.fetch(.., Event::Read)` is
    /// called once up front (assumed to reset any stale readiness flag; confirm).
    pub async fn read_with<R>(&self, mut op: impl FnMut(&T) -> io::Result<R>) -> io::Result<R> {
        let fd = self.as_fd().as_raw_fd();
        REACTOR.fetch(fd, Event::Read)?;
        loop {
            match op(self.get_ref()) {
                Err(err) if err.kind() == io::ErrorKind::WouldBlock => (),
                finished => break finished,
            }
            optimistic(self.readable()).await?;
        }
    }

    /// Like [`Async::read_with`], but `op` receives `&mut T`.
    ///
    /// # Safety
    ///
    /// `op` must not drop or replace the handle it is given (see [`Async::get_mut`]).
    pub async unsafe fn read_with_mut<R>(
        &mut self,
        mut op: impl FnMut(&mut T) -> io::Result<R>,
    ) -> io::Result<R> {
        let fd = self.as_fd().as_raw_fd();
        REACTOR.fetch(fd, Event::Read)?;
        loop {
            match op(self.get_mut()) {
                Err(err) if err.kind() == io::ErrorKind::WouldBlock => (),
                finished => break finished,
            }
            optimistic(self.readable()).await?;
        }
    }

    /// Runs `op` on the handle until it returns something other than `WouldBlock`.
    ///
    /// Between attempts it waits for write readiness. `REACTOR.fetch(.., Event::Write)` is
    /// called once up front (assumed to reset any stale readiness flag; confirm).
    pub async fn write_with<R>(&self, mut op: impl FnMut(&T) -> io::Result<R>) -> io::Result<R> {
        let fd = self.as_fd().as_raw_fd();
        REACTOR.fetch(fd, Event::Write)?;
        loop {
            match op(self.get_ref()) {
                Err(err) if err.kind() == io::ErrorKind::WouldBlock => (),
                finished => break finished,
            }
            optimistic(self.writable()).await?;
        }
    }

    /// Like [`Async::write_with`], but `op` receives `&mut T`.
    ///
    /// # Safety
    ///
    /// `op` must not drop or replace the handle it is given (see [`Async::get_mut`]).
    pub async unsafe fn write_with_mut<R>(
        &mut self,
        mut op: impl FnMut(&mut T) -> io::Result<R>,
    ) -> io::Result<R> {
        let fd = self.as_fd().as_raw_fd();
        REACTOR.fetch(fd, Event::Write)?;
        loop {
            match op(self.get_mut()) {
                Err(err) if err.kind() == io::ErrorKind::WouldBlock => (),
                finished => break finished,
            }
            optimistic(self.writable()).await?;
        }
    }
}
impl<T: AsFd> AsRef<T> for Async<T> {
    /// Borrows the wrapped handle; equivalent to [`Async::get_ref`].
    fn as_ref(&self) -> &T {
        self.get_ref()
    }
}
impl<T: AsFd> Drop for Async<T> {
    /// Deregisters the handle from the reactor, unless `into_inner` already took it.
    fn drop(&mut self) {
        let Some(io) = self.io.as_ref() else {
            return;
        };
        // `drop` has no way to report failure, so a deregistration error is discarded.
        let _ = REACTOR.deregister(io.as_fd().as_raw_fd());
    }
}
/// Marker for I/O types whose `Read`/`Write` implementations never drop or replace
/// the underlying I/O handle.
///
/// The `futures-io` impls for `Async<T>` require `T: IoSafe` before calling the unsafe
/// [`Async::get_mut`] to perform reads and writes.
///
/// # Safety
///
/// Implementors guarantee that their I/O trait methods leave the underlying handle
/// (and therefore its file descriptor) in place.
pub unsafe trait IoSafe {}

// References: a shared reference cannot replace its referent; an exclusive one is safe
// when the referent is.
unsafe impl<T: ?Sized> IoSafe for &T {}
unsafe impl<T: IoSafe + ?Sized> IoSafe for &mut T {}
unsafe impl<T: Clone + IoSafe + ?Sized> IoSafe for std::borrow::Cow<'_, T> {}

// Standard-library handles.
unsafe impl IoSafe for std::fs::File {}
unsafe impl IoSafe for std::net::TcpStream {}
unsafe impl IoSafe for std::io::Stdin {}
unsafe impl IoSafe for std::io::Stdout {}
unsafe impl IoSafe for std::io::Stderr {}
unsafe impl IoSafe for std::io::StdinLock<'_> {}
unsafe impl IoSafe for std::io::StdoutLock<'_> {}
unsafe impl IoSafe for std::io::StderrLock<'_> {}
unsafe impl IoSafe for std::process::ChildStdin {}
unsafe impl IoSafe for std::process::ChildStdout {}
unsafe impl IoSafe for std::process::ChildStderr {}

// Buffering wrappers are as safe as the handle they wrap.
unsafe impl<T: IoSafe + Read> IoSafe for std::io::BufReader<T> {}
unsafe impl<T: IoSafe + Write> IoSafe for std::io::BufWriter<T> {}
unsafe impl<T: IoSafe + Write> IoSafe for std::io::LineWriter<T> {}
#[cfg(feature = "futures-io")]
impl<T: AsFd + IoSafe + Read> futures_io::AsyncRead for Async<T> {
    /// Reads into `buf`, waiting for read readiness whenever the handle reports `WouldBlock`.
    fn poll_read(
        self: core::pin::Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        // `Async<T>` is unconditionally `Unpin`, so the pin can be unwrapped.
        let this = core::pin::Pin::into_inner(self);
        loop {
            // SAFETY: `T: IoSafe` guarantees `read` does not drop or replace the handle.
            match unsafe { this.get_mut() }.read(buf) {
                Err(err) if err.kind() == io::ErrorKind::WouldBlock => (),
                finished => break Poll::Ready(finished),
            }
            ready!(this.poll_readable(cx))?;
        }
    }

    /// Vectored variant of `poll_read`.
    fn poll_read_vectored(
        self: core::pin::Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &mut [std::io::IoSliceMut<'_>],
    ) -> Poll<io::Result<usize>> {
        let this = core::pin::Pin::into_inner(self);
        loop {
            // SAFETY: `T: IoSafe` guarantees `read_vectored` does not drop or replace the handle.
            match unsafe { this.get_mut() }.read_vectored(bufs) {
                Err(err) if err.kind() == io::ErrorKind::WouldBlock => (),
                finished => break Poll::Ready(finished),
            }
            ready!(this.poll_readable(cx))?;
        }
    }
}
#[cfg(feature = "futures-io")]
impl<T: AsFd> futures_io::AsyncRead for &Async<T>
where
    for<'a> &'a T: Read,
{
    /// Reads through a shared reference to the handle, waiting for read readiness on
    /// `WouldBlock`.
    fn poll_read(
        self: core::pin::Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        let this: &Async<T> = *self;
        loop {
            match this.get_ref().read(buf) {
                Err(err) if err.kind() == io::ErrorKind::WouldBlock => (),
                finished => break Poll::Ready(finished),
            }
            ready!(this.poll_readable(cx))?;
        }
    }

    /// Vectored variant of `poll_read`.
    fn poll_read_vectored(
        self: core::pin::Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &mut [std::io::IoSliceMut<'_>],
    ) -> Poll<io::Result<usize>> {
        let this: &Async<T> = *self;
        loop {
            match this.get_ref().read_vectored(bufs) {
                Err(err) if err.kind() == io::ErrorKind::WouldBlock => (),
                finished => break Poll::Ready(finished),
            }
            ready!(this.poll_readable(cx))?;
        }
    }
}
#[cfg(feature = "futures-io")]
impl<T: AsFd + IoSafe + Write> futures_io::AsyncWrite for Async<T> {
    /// Writes `buf`, waiting for write readiness whenever the handle reports `WouldBlock`.
    fn poll_write(
        self: core::pin::Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        // `Async<T>` is unconditionally `Unpin`, so the pin can be unwrapped.
        let this = core::pin::Pin::into_inner(self);
        loop {
            // SAFETY: `T: IoSafe` guarantees `write` does not drop or replace the handle.
            match unsafe { this.get_mut() }.write(buf) {
                Err(err) if err.kind() == io::ErrorKind::WouldBlock => (),
                finished => break Poll::Ready(finished),
            }
            ready!(this.poll_writable(cx))?;
        }
    }

    /// Vectored variant of `poll_write`.
    fn poll_write_vectored(
        self: core::pin::Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[std::io::IoSlice<'_>],
    ) -> Poll<io::Result<usize>> {
        let this = core::pin::Pin::into_inner(self);
        loop {
            // SAFETY: `T: IoSafe` guarantees `write_vectored` does not drop or replace the handle.
            match unsafe { this.get_mut() }.write_vectored(bufs) {
                Err(err) if err.kind() == io::ErrorKind::WouldBlock => (),
                finished => break Poll::Ready(finished),
            }
            ready!(this.poll_writable(cx))?;
        }
    }

    /// Flushes the handle, waiting for write readiness on `WouldBlock`.
    fn poll_flush(self: core::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let this = core::pin::Pin::into_inner(self);
        loop {
            // SAFETY: `T: IoSafe` guarantees `flush` does not drop or replace the handle.
            match unsafe { this.get_mut() }.flush() {
                Err(err) if err.kind() == io::ErrorKind::WouldBlock => (),
                finished => break Poll::Ready(finished),
            }
            ready!(this.poll_writable(cx))?;
        }
    }

    /// Closing only flushes; the handle itself is released when the `Async` is dropped.
    fn poll_close(self: core::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        self.poll_flush(cx)
    }
}
#[cfg(feature = "futures-io")]
impl<T: AsFd> futures_io::AsyncWrite for &Async<T>
where
    for<'a> &'a T: Write,
{
    /// Writes through a shared reference to the handle, waiting for write readiness on
    /// `WouldBlock`.
    fn poll_write(
        self: core::pin::Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let this: &Async<T> = *self;
        loop {
            match this.get_ref().write(buf) {
                Err(err) if err.kind() == io::ErrorKind::WouldBlock => (),
                finished => break Poll::Ready(finished),
            }
            ready!(this.poll_writable(cx))?;
        }
    }

    /// Vectored variant of `poll_write`.
    fn poll_write_vectored(
        self: core::pin::Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[std::io::IoSlice<'_>],
    ) -> Poll<io::Result<usize>> {
        let this: &Async<T> = *self;
        loop {
            match this.get_ref().write_vectored(bufs) {
                Err(err) if err.kind() == io::ErrorKind::WouldBlock => (),
                finished => break Poll::Ready(finished),
            }
            ready!(this.poll_writable(cx))?;
        }
    }

    /// Flushes through a shared reference, waiting for write readiness on `WouldBlock`.
    fn poll_flush(self: core::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let this: &Async<T> = *self;
        loop {
            match this.get_ref().flush() {
                Err(err) if err.kind() == io::ErrorKind::WouldBlock => (),
                finished => break Poll::Ready(finished),
            }
            ready!(this.poll_writable(cx))?;
        }
    }

    /// Closing only flushes; the handle itself is released when the `Async` is dropped.
    fn poll_close(self: core::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        self.poll_flush(cx)
    }
}
impl Async<TcpListener> {
    /// Creates a TCP listener bound to `addr` and registers it with the reactor.
    ///
    /// # Errors
    ///
    /// Returns an error if binding fails, or if the listener cannot be made non-blocking
    /// and registered.
    pub fn bind<A: Into<SocketAddr>>(addr: A) -> io::Result<Async<TcpListener>> {
        let listener = TcpListener::bind(addr.into())?;
        Async::new(listener)
    }

    /// Waits for an incoming connection and returns it with the peer's address.
    ///
    /// The accepted stream is wrapped with [`Async::new`], which makes it non-blocking.
    pub async fn accept(&self) -> io::Result<(Async<TcpStream>, SocketAddr)> {
        let (stream, peer) = self.read_with(TcpListener::accept).await?;
        Ok((Async::new(stream)?, peer))
    }

    /// Returns an endless stream of accepted connections; peer addresses are discarded.
    #[cfg(feature = "futures-lite")]
    pub fn incoming(
        &self,
    ) -> impl futures_lite::Stream<Item = io::Result<Async<TcpStream>>> + Send + '_ {
        futures_lite::stream::unfold(self, |this| async move {
            let next = this.accept().await.map(|(stream, _peer)| stream);
            Some((next, this))
        })
    }
}

impl TryFrom<std::net::TcpListener> for Async<std::net::TcpListener> {
    type Error = io::Error;

    /// Wraps an existing listener via [`Async::new`].
    fn try_from(listener: std::net::TcpListener) -> io::Result<Self> {
        Self::new(listener)
    }
}
impl Async<TcpStream> {
    /// Opens a non-blocking TCP connection to `addr`.
    ///
    /// Creates a non-blocking socket, starts the connection, registers the socket with the
    /// reactor, and waits until it becomes writable. It then reports any error the
    /// connection attempt left on the socket (via [`TcpStream::take_error`]).
    ///
    /// For IPv6 targets the flow label and scope ID of `addr` are passed through, so
    /// link-local addresses with an interface scope (e.g. `fe80::1%2`) can be reached.
    ///
    /// # Errors
    ///
    /// Returns an error if creating, configuring, or connecting the socket fails, if
    /// registering it with the reactor fails, or if the connection attempt itself failed.
    pub async fn connect<A: Into<SocketAddr>>(addr: A) -> io::Result<Async<TcpStream>> {
        let addr = addr.into();
        let socket = match addr {
            SocketAddr::V4(v4) => {
                let addr = sys::sockaddr_in {
                    sin_family: sys::AF_INET as _,
                    // Port and address are stored in network byte order.
                    sin_port: u16::to_be(v4.port()),
                    sin_addr: sys::in_addr {
                        s_addr: u32::from_ne_bytes(v4.ip().octets()),
                    },
                    #[cfg(target_os = "espidf")]
                    sin_len: Default::default(),
                    sin_zero: Default::default(),
                };
                connect(
                    &addr as *const _ as *const _,
                    core::mem::size_of_val(&addr),
                    sys::AF_INET,
                    sys::SOCK_STREAM,
                    0,
                )
            }
            SocketAddr::V6(v6) => {
                let addr = sys::sockaddr_in6 {
                    sin6_family: sys::AF_INET6 as _,
                    sin6_port: u16::to_be(v6.port()),
                    // Pass these through instead of zeroing them. Without the scope ID the
                    // stack cannot tell which interface a link-local address belongs to.
                    sin6_flowinfo: v6.flowinfo() as _,
                    sin6_addr: sys::in6_addr {
                        s6_addr: v6.ip().octets(),
                    },
                    sin6_scope_id: v6.scope_id() as _,
                    #[cfg(target_os = "espidf")]
                    sin6_len: Default::default(),
                };
                connect(
                    &addr as *const _ as *const _,
                    core::mem::size_of_val(&addr),
                    sys::AF_INET6,
                    sys::SOCK_STREAM,
                    // 6 == IPPROTO_TCP.
                    6,
                )
            }
        }?;
        // `connect` has already switched the socket to non-blocking mode.
        let stream = Async::new_nonblocking(TcpStream::from(socket))?;
        // A non-blocking connect has finished (successfully or not) once the socket is writable.
        stream.writable().await?;
        match stream.get_ref().take_error()? {
            None => Ok(stream),
            Some(err) => Err(err),
        }
    }

    /// Reads into `buf` without removing the data from the receive queue.
    pub async fn peek(&self, buf: &mut [u8]) -> io::Result<usize> {
        self.read_with(|io| io.peek(buf)).await
    }
}

impl TryFrom<std::net::TcpStream> for Async<std::net::TcpStream> {
    type Error = io::Error;

    /// Wraps an existing stream via [`Async::new`].
    fn try_from(stream: std::net::TcpStream) -> io::Result<Self> {
        Async::new(stream)
    }
}
impl Async<UdpSocket> {
    /// Creates a UDP socket bound to `addr` and registers it with the reactor.
    ///
    /// # Errors
    ///
    /// Returns an error if binding fails, or if the socket cannot be made non-blocking
    /// and registered.
    pub fn bind<A: Into<SocketAddr>>(addr: A) -> io::Result<Async<UdpSocket>> {
        let socket = UdpSocket::bind(addr.into())?;
        Async::new(socket)
    }

    /// Receives a datagram into `buf`, returning its length and sender.
    pub async fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
        self.read_with(|socket| socket.recv_from(buf)).await
    }

    /// Like [`Async::recv_from`], but leaves the datagram queued.
    pub async fn peek_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
        self.read_with(|socket| socket.peek_from(buf)).await
    }

    /// Sends `buf` as a datagram to `addr`.
    pub async fn send_to<A: Into<SocketAddr>>(&self, buf: &[u8], addr: A) -> io::Result<usize> {
        let target: SocketAddr = addr.into();
        self.write_with(|socket| socket.send_to(buf, target)).await
    }

    /// Receives a datagram from the connected peer into `buf`.
    pub async fn recv(&self, buf: &mut [u8]) -> io::Result<usize> {
        self.read_with(|socket| socket.recv(buf)).await
    }

    /// Like [`Async::recv`], but leaves the datagram queued.
    pub async fn peek(&self, buf: &mut [u8]) -> io::Result<usize> {
        self.read_with(|socket| socket.peek(buf)).await
    }

    /// Sends `buf` to the connected peer.
    pub async fn send(&self, buf: &[u8]) -> io::Result<usize> {
        self.write_with(|socket| socket.send(buf)).await
    }
}

impl TryFrom<std::net::UdpSocket> for Async<std::net::UdpSocket> {
    type Error = io::Error;

    /// Wraps an existing socket via [`Async::new`].
    fn try_from(socket: std::net::UdpSocket) -> io::Result<Self> {
        Self::new(socket)
    }
}
/// Polls `fut` exactly once, then resolves optimistically.
///
/// If the first poll is `Ready`, its result is returned. If it is `Pending`, the next
/// poll (i.e. after a wake-up) returns `Ok(())` without polling `fut` again. The caller
/// then retries its I/O operation; a spurious wake-up just produces another
/// `WouldBlock` and another wait.
async fn optimistic(fut: impl Future<Output = io::Result<()>>) -> io::Result<()> {
    let mut inner = pin!(fut);
    let mut first_poll = true;
    poll_fn(move |cx| {
        if first_poll {
            first_poll = false;
            return inner.as_mut().poll(cx);
        }
        Poll::Ready(Ok(()))
    })
    .await
}
/// Creates a socket for `domain`/`ty`/`protocol`, makes it non-blocking, and starts
/// connecting it to the `addr_len`-byte socket address at `addr`.
///
/// `syscall_los_eagain!` presumably tolerates an in-progress connect (EINPROGRESS/EAGAIN);
/// the caller (`Async::<TcpStream>::connect`) waits for writability and then checks
/// `take_error`. The socket is closed automatically if any step fails, since it is held in an
/// `OwnedFd`.
fn connect(
    addr: *const sys::sockaddr,
    addr_len: usize,
    domain: sys::c_int,
    ty: sys::c_int,
    protocol: sys::c_int,
) -> io::Result<OwnedFd> {
    let raw = unsafe { syscall_los!(sys::socket(domain, ty, protocol)) }?;
    // SAFETY: `raw` was just returned by a successful `socket` call, so it is an open
    // descriptor that nothing else owns.
    let socket = unsafe { OwnedFd::from_raw_fd(raw) };
    set_nonblocking(socket.as_fd())?;
    // SAFETY: the caller supplies `addr` pointing at a socket address of `addr_len` bytes,
    // and `socket` is a valid open descriptor.
    syscall_los_eagain!(unsafe { sys::connect(socket.as_raw_fd(), addr, addr_len as _) })?;
    Ok(socket)
}
fn set_nonblocking(fd: BorrowedFd) -> io::Result<()> {
let previous = unsafe { sys::fcntl(fd.as_raw_fd(), sys::F_GETFL) };
let new = previous | sys::O_NONBLOCK;
if new != previous {
syscall!(unsafe { sys::fcntl(fd.as_raw_fd(), sys::F_SETFL, new) })?;
}
Ok(())
}