use std::future::Future;
use std::io;
use monoio::io::{CancelHandle, CancelableAsyncReadRent, Canceller};
/// Thin wrapper around monoio's [`Canceller`].
///
/// Callers obtain [`CancelHandle`]s through [`IoCanceller::handle`], pass them to
/// cancelable I/O calls, and trigger cancellation with [`IoCanceller::cancel`].
// NOTE(review): `allow(unused)` presumably silences warnings in builds that never
// construct this type — confirm it is still needed.
#[allow(unused)]
pub struct IoCanceller {
    inner: Canceller,
}

impl std::fmt::Debug for IoCanceller {
    /// Prints only the type name; the wrapped `Canceller` is not shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("IoCanceller")
    }
}

impl IoCanceller {
    /// Creates a wrapper around a newly constructed [`Canceller`].
    pub fn new() -> Self {
        let inner = Canceller::new();
        IoCanceller { inner }
    }

    /// Returns a [`CancelHandle`] tied to the wrapped canceller, suitable for
    /// passing to a cancelable I/O operation.
    // NOTE(review): presumably unused in some builds, hence the allow — confirm.
    #[allow(dead_code)]
    pub fn handle(&self) -> CancelHandle {
        self.inner.handle()
    }

    /// Consumes the wrapper, invokes `Canceller::cancel` on the inner canceller,
    /// and wraps the `Canceller` that call returns.
    // NOTE(review): presumably unused in some builds, hence the allow — confirm.
    #[allow(dead_code)]
    pub fn cancel(self) -> Self {
        let IoCanceller { inner } = self;
        IoCanceller {
            inner: inner.cancel(),
        }
    }
}

impl Default for IoCanceller {
    /// Equivalent to [`IoCanceller::new`].
    fn default() -> Self {
        IoCanceller::new()
    }
}
/// Outcome of a cancelable read: the byte count (or the error), paired with the
/// buffer, which is handed back to the caller in both the success and error cases.
// NOTE(review): `allow(unused)` presumably covers builds that never name this alias.
#[allow(unused)]
pub type CancellableReadResult<T> = (std::io::Result<usize>, T);
/// Reads into `buf` from `stream`, cancelling the read if it has not finished
/// within `timeout`.
///
/// The buffer is always handed back alongside the result. When the timer fires
/// first, the read is cancelled and then awaited to completion so that `buf`
/// can be recovered. After that:
///
/// * if the read reports an error (normally the cancellation itself), the
///   result is an `io::ErrorKind::TimedOut` error;
/// * if the read had in fact completed before the cancellation took effect,
///   its `Ok(n)` is returned, so the `n` bytes already placed in `buf` are not
///   silently discarded behind a timeout error.
///
/// Presumably requires a runtime built with the timer enabled (the tests use
/// `enable_timer = true`).
// NOTE(review): `allow(unused)` presumably covers builds that never call this.
#[allow(unused)]
pub async fn read_with_timeout<T>(
    stream: &mut impl CancelableAsyncReadRent,
    buf: T,
    timeout: std::time::Duration,
) -> CancellableReadResult<T>
where
    T: monoio::buf::IoBufMut,
{
    let canceller = Canceller::new();
    let handle = canceller.handle();
    let recv_fut = stream.cancelable_read(buf, handle);
    let mut recv_fut = std::pin::pin!(recv_fut);
    monoio::select! {
        result = &mut recv_fut => result,
        _ = monoio::time::sleep(timeout) => {
            // Cancellation is best-effort: the read may already have completed by
            // the time the cancel request is processed. Drive it to completion to
            // get `buf` back, and keep a successful result rather than dropping data.
            let _ = canceller.cancel();
            match recv_fut.await {
                (Ok(n), buf) => (Ok(n), buf),
                (Err(_), buf) => (
                    Err(io::Error::new(
                        io::ErrorKind::TimedOut,
                        "read timeout",
                    )),
                    buf,
                ),
            }
        }
    }
}
#[allow(dead_code)]
pub async fn timeout<F, T>(duration: std::time::Duration, future: F) -> Result<T, TimeoutError>
where
F: Future<Output = T>,
{
monoio::select! {
result = future => Ok(result),
_ = monoio::time::sleep(duration) => Err(TimeoutError),
}
}
/// Error returned by `timeout` when the deadline passes before the wrapped
/// future completes. Carries no payload.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
// NOTE(review): `allow(dead_code)` presumably covers builds that never construct it.
#[allow(dead_code)]
pub struct TimeoutError;

impl std::fmt::Display for TimeoutError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("operation timed out")
    }
}

// No underlying cause to expose, so the default `source()` (returning `None`) applies.
impl std::error::Error for TimeoutError {}
#[cfg(test)]
mod tests {
    use super::*;
    use monoio::io::{AsyncWriteRentExt, CancelableAsyncReadRent};
    use monoio::net::{TcpListener, TcpStream};
    use std::time::{Duration, Instant};

    /// Establishes a loopback TCP connection and returns `(server, client)`.
    ///
    /// Callers keep the server half bound for the whole test so the connection
    /// stays open while the client side is being read from.
    async fn loopback_pair() -> (TcpStream, TcpStream) {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = listener.local_addr().unwrap();
        let connecting = monoio::spawn(async move { TcpStream::connect(addr).await });
        // Give the spawned connect a chance to run before accepting.
        monoio::time::sleep(Duration::from_millis(10)).await;
        let (server, _) = listener.accept().await.unwrap();
        let client = connecting.await.unwrap();
        (server, client)
    }

    #[monoio::test]
    async fn test_canceller_basic() {
        let canceller = IoCanceller::new();
        let handle = canceller.handle();
        let _twin = handle.clone();
        let _fresh = canceller.cancel();
    }

    #[monoio::test(enable_timer = true)]
    async fn test_timeout_success() {
        let outcome = timeout(Duration::from_secs(1), async { 42 }).await;
        assert_eq!(outcome, Ok(42));
    }

    #[monoio::test(enable_timer = true)]
    async fn test_timeout_elapsed() {
        let started = Instant::now();
        let slow = monoio::time::sleep(Duration::from_secs(1));
        let outcome = timeout(Duration::from_millis(10), slow).await;
        assert!(outcome.is_err());
        assert!(started.elapsed() < Duration::from_millis(100));
    }

    #[monoio::test(enable_timer = true)]
    async fn test_cancel_read() {
        let (_server, mut client) = loopback_pair().await;
        let canceller = Canceller::new();
        let handle = canceller.handle();
        // Fire the cancellation from a separate task while the read is pending.
        monoio::spawn(async move {
            monoio::time::sleep(Duration::from_millis(50)).await;
            let _ = canceller.cancel();
        });
        let (res, buf) = client.cancelable_read(vec![0u8; 1024], handle).await;
        assert!(res.is_err());
        assert_eq!(buf.len(), 1024);
    }

    #[monoio::test(enable_timer = true)]
    async fn test_read_with_timeout_cancelled() {
        let (_server, mut client) = loopback_pair().await;
        let scratch = vec![0u8; 1024];
        let started = Instant::now();
        let (res, buf) = read_with_timeout(&mut client, scratch, Duration::from_millis(50)).await;
        assert!(res.is_err());
        assert_eq!(res.unwrap_err().kind(), io::ErrorKind::TimedOut);
        assert_eq!(buf.len(), 1024);
        assert!(started.elapsed() >= Duration::from_millis(40));
    }

    #[monoio::test(enable_timer = true)]
    async fn test_read_with_timeout_success() {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = listener.local_addr().unwrap();
        monoio::spawn(async move {
            let (mut peer, _) = listener.accept().await.unwrap();
            let _ = peer.write_all(b"hello world").await;
        });
        let mut client = TcpStream::connect(addr).await.unwrap();
        let scratch = vec![0u8; 1024];
        let (res, buf) = read_with_timeout(&mut client, scratch, Duration::from_secs(1)).await;
        let n = res.expect("read should succeed");
        assert_eq!(&buf[..n], b"hello world");
    }
}