use std::io::{self, Read, Write};
use std::os::unix::io::{AsRawFd, IntoRawFd, RawFd};
#[cfg(feature = "io_timeout")]
use std::time::Duration;
use self::io_impl::co_io_err::Error;
use self::io_impl::net as net_impl;
use super::from_nix_error;
use crate::io as io_impl;
#[cfg(feature = "io_timeout")]
use crate::sync::atomic_dur::AtomicDuration;
use crate::yield_now::yield_with_io;
use nix::sys::socket::{recv, MsgFlags};
/// Turns `O_NONBLOCK` on (`nb == true`) or off (`nb == false`) for the descriptor behind `fd`.
///
/// The current file-status flags are read with `F_GETFL` and written back with only the
/// `O_NONBLOCK` bit changed, so every other status flag is preserved. When the bit is already
/// in the requested state, no `F_SETFL` call is made.
///
/// # Errors
///
/// Returns the OS error (`io::Error::last_os_error`) if either `fcntl` call reports failure.
fn set_nonblocking<T: AsRawFd>(fd: &T, nb: bool) -> io::Result<()> {
    let fd = fd.as_raw_fd();

    // SAFETY: `F_GETFL` takes no third argument and only queries the descriptor's flags;
    // an invalid descriptor is reported via the -1 return value, not undefined behavior.
    let flags = unsafe { libc::fcntl(fd, libc::F_GETFL) };
    if flags == -1 {
        return Err(io::Error::last_os_error());
    }

    let new_flags = if nb {
        flags | libc::O_NONBLOCK
    } else {
        flags & !libc::O_NONBLOCK
    };

    // Avoid a second syscall when nothing would change.
    if new_flags != flags {
        // SAFETY: `F_SETFL` takes an `int` flag set as its third argument, which is exactly
        // what `new_flags` is; failure is again signalled only through the -1 return value.
        if unsafe { libc::fcntl(fd, libc::F_SETFL, new_flags) } == -1 {
            return Err(io::Error::last_os_error());
        }
    }
    Ok(())
}
/// Wrapper around an I/O object `T` (anything exposing a raw file descriptor) that makes its
/// blocking operations cooperate with the runtime's I/O waiting.
///
/// `CoIo::new` registers the descriptor via `io_impl::add_socket` and switches it to
/// non-blocking mode. `peek` and the `Read`/`Write` impls first try the operation directly.
/// Only when it reports "would block" do they wait for readiness through `yield_with_io`.
#[derive(Debug)]
pub struct CoIo<T: AsRawFd> {
    /// The wrapped I/O object; direct reads, writes and `flush` are delegated to it.
    inner: T,
    /// Registration data returned by `io_impl::add_socket` (or handed to `from_raw`).
    /// It is reset before each operation, and its `fd` is used directly by `peek`.
    io: io_impl::IoData,
    /// Timeout passed to the deferred read/peek operations; initialized to `None`.
    #[cfg(feature = "io_timeout")]
    read_timeout: AtomicDuration,
    /// Timeout passed to the deferred write operation; initialized to `None`.
    #[cfg(feature = "io_timeout")]
    write_timeout: AtomicDuration,
}
impl<T> io_impl::AsIoData for CoIo<T>
where
    T: AsRawFd,
{
    /// Borrows the registration data stored next to the wrapped object.
    fn as_io_data(&self) -> &io_impl::IoData {
        let Self { io, .. } = self;
        io
    }
}

impl<T> AsRawFd for CoIo<T>
where
    T: AsRawFd,
{
    /// Reports the descriptor of the wrapped object.
    fn as_raw_fd(&self) -> RawFd {
        AsRawFd::as_raw_fd(&self.inner)
    }
}

impl<T> IntoRawFd for CoIo<T>
where
    T: AsRawFd + IntoRawFd,
{
    /// Consumes the wrapper and returns the descriptor obtained from the wrapped object's own
    /// `into_raw_fd`; the remaining fields (including the `IoData`) are dropped afterwards.
    fn into_raw_fd(self) -> RawFd {
        let Self { inner, .. } = self;
        IntoRawFd::into_raw_fd(inner)
    }
}
impl<T: AsRawFd> CoIo<T> {
pub fn new(io: T) -> Result<Self, Error<T>> {
let io_data = match io_impl::add_socket(&io) {
Ok(o) => o,
Err(e) => return Err(Error::new(e, io)),
};
match set_nonblocking(&io, true) {
Ok(_) => {}
Err(e) => return Err(Error::new(e, io)),
}
Ok(CoIo {
inner: io,
io: io_data,
#[cfg(feature = "io_timeout")]
read_timeout: AtomicDuration::new(None),
#[cfg(feature = "io_timeout")]
write_timeout: AtomicDuration::new(None),
})
}
pub(crate) fn from_raw(io: T, io_data: io_impl::IoData) -> Self {
CoIo {
inner: io,
io: io_data,
#[cfg(feature = "io_timeout")]
read_timeout: AtomicDuration::new(None),
#[cfg(feature = "io_timeout")]
write_timeout: AtomicDuration::new(None),
}
}
pub(crate) fn io_reset(&self) {
self.io.reset();
}
#[inline]
pub fn inner(&self) -> &T {
&self.inner
}
#[inline]
pub fn inner_mut(&mut self) -> &mut T {
&mut self.inner
}
pub fn into_inner(self) -> T {
self.inner
}
#[cfg(feature = "io_timeout")]
pub fn read_timeout(&self) -> io::Result<Option<Duration>> {
Ok(self.read_timeout.get())
}
#[cfg(feature = "io_timeout")]
pub fn write_timeout(&self) -> io::Result<Option<Duration>> {
Ok(self.write_timeout.get())
}
#[cfg(feature = "io_timeout")]
pub fn set_read_timeout(&self, dur: Option<Duration>) -> io::Result<()> {
self.read_timeout.store(dur);
Ok(())
}
#[cfg(feature = "io_timeout")]
pub fn set_write_timeout(&self, dur: Option<Duration>) -> io::Result<()> {
self.write_timeout.store(dur);
Ok(())
}
pub fn peek(&self, buf: &mut [u8]) -> io::Result<usize> {
self.io.reset();
match recv(self.io.fd, buf, MsgFlags::MSG_PEEK) {
Ok(n) => return Ok(n),
Err(e) => {
if e == nix::errno::Errno::EAGAIN {
} else {
return Err(from_nix_error(e));
}
}
}
let mut reader = net_impl::SocketPeek::new(
self,
buf,
#[cfg(feature = "io_timeout")]
self.read_timeout.get(),
);
yield_with_io(&reader, reader.is_coroutine);
reader.done()
}
}
impl<T: AsRawFd + Read> Read for CoIo<T> {
    /// Reads into `buf`, waiting for readiness only when the data is not available yet.
    ///
    /// The wrapped object's own `read` is tried first. If it fails with
    /// `io::ErrorKind::WouldBlock`, a `SocketRead` is created (carrying `read_timeout` when
    /// the `io_timeout` feature is enabled), the caller waits through `yield_with_io`, and the
    /// result of `SocketRead::done` is returned.
    ///
    /// # Errors
    ///
    /// Any error other than `WouldBlock` from the direct attempt is returned unchanged.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        self.io.reset();
        match self.inner.read(buf) {
            Ok(n) => return Ok(n),
            // `ErrorKind::WouldBlock` covers raw EAGAIN/EWOULDBLOCK and also would-block errors
            // produced by wrapper types that carry no OS error code.
            Err(e) if e.kind() == io::ErrorKind::WouldBlock => {}
            Err(e) => return Err(e),
        }
        let mut reader = net_impl::SocketRead::new(
            self,
            buf,
            #[cfg(feature = "io_timeout")]
            self.read_timeout.get(),
        );
        yield_with_io(&reader, reader.is_coroutine);
        reader.done()
    }
}
impl<T: AsRawFd + Write> Write for CoIo<T> {
    /// Writes `buf`, waiting for readiness only when the descriptor cannot accept data yet.
    ///
    /// The wrapped object's own `write` is tried first. If it fails with
    /// `io::ErrorKind::WouldBlock`, a `SocketWrite` is created (carrying `write_timeout` when
    /// the `io_timeout` feature is enabled), the caller waits through `yield_with_io`, and the
    /// result of `SocketWrite::done` is returned.
    ///
    /// # Errors
    ///
    /// Any error other than `WouldBlock` from the direct attempt is returned unchanged.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        self.io.reset();
        match self.inner.write(buf) {
            Ok(n) => return Ok(n),
            // `ErrorKind::WouldBlock` covers raw EAGAIN/EWOULDBLOCK and also would-block errors
            // produced by wrapper types that carry no OS error code.
            Err(e) if e.kind() == io::ErrorKind::WouldBlock => {}
            Err(e) => return Err(e),
        }
        let mut writer = net_impl::SocketWrite::new(
            self,
            buf,
            #[cfg(feature = "io_timeout")]
            self.write_timeout.get(),
        );
        yield_with_io(&writer, writer.is_coroutine);
        writer.done()
    }

    /// Delegates directly to the wrapped object's `flush`.
    fn flush(&mut self) -> io::Result<()> {
        self.inner.flush()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: `CoIo::new` accepts a real socket, and `read` returns the wrapped object's
    /// data when its `read` succeeds immediately.
    #[test]
    fn compile_co_io() {
        /// Test double that owns a real UDP socket (so there is a valid descriptor to register)
        /// but serves reads from a fixed byte pattern instead of the socket.
        #[derive(Debug)]
        struct Fd {
            file: std::net::UdpSocket,
        }

        impl Fd {
            fn new() -> Self {
                Fd {
                    // Port 0 lets the OS pick a free port, so the test cannot fail because a
                    // fixed port is already taken (e.g. by a parallel test run).
                    file: std::net::UdpSocket::bind(("127.0.0.1", 0)).unwrap(),
                }
            }
        }

        impl AsRawFd for Fd {
            fn as_raw_fd(&self) -> RawFd {
                self.file.as_raw_fd()
            }
        }

        impl Read for Fd {
            fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
                buf.fill(0x55);
                Ok(buf.len())
            }
        }

        let a = Fd::new();
        let mut io = CoIo::new(a).unwrap();
        let mut buf = [0u8; 100];
        io.read_exact(&mut buf).unwrap();
        assert_eq!(buf, [0x55u8; 100]);
    }

    /// `set_nonblocking` must set and clear `O_NONBLOCK` on the underlying descriptor.
    #[test]
    fn set_nonblocking_toggles_flag() {
        let sock = std::net::UdpSocket::bind(("127.0.0.1", 0)).unwrap();
        let is_nonblocking = |s: &std::net::UdpSocket| {
            // SAFETY: `F_GETFL` only queries flags of a descriptor owned by `s`, which is alive.
            let flags = unsafe { libc::fcntl(s.as_raw_fd(), libc::F_GETFL) };
            assert_ne!(flags, -1);
            flags & libc::O_NONBLOCK != 0
        };

        set_nonblocking(&sock, true).unwrap();
        assert!(is_nonblocking(&sock));
        set_nonblocking(&sock, false).unwrap();
        assert!(!is_nonblocking(&sock));
    }
}