use std::pin::Pin;
use std::task::{Context, Poll};
/// Collapses a raw I/O poll outcome into the `bool`-flavoured form used by
/// [`PollResult`].
///
/// The byte count carried by a successful poll is discarded and replaced by
/// `reached_limit`; `Pending` and errors are forwarded untouched.
fn ret_reduce(
    ret: Poll<std::io::Result<usize>>,
    reached_limit: bool,
) -> Poll<std::io::Result<bool>> {
    match ret {
        Poll::Pending => Poll::Pending,
        Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
        Poll::Ready(Ok(_byte_count)) => Poll::Ready(Ok(reached_limit)),
    }
}
/// Outcome of a single [`poll_read`] / [`poll_write`] call.
#[derive(Debug)]
pub struct PollResult {
    /// Number of bytes transferred during the call. It is filled in on every
    /// return path, including `Pending` and errors, so progress made before
    /// the stopping condition is not lost.
    pub delta: usize,
    /// Why the transfer stopped: `Pending` / `Ready(Err(_))` are forwarded
    /// from the underlying reader or writer. `Ready(Ok(flag))` means the
    /// reader/writer returned a non-error result that ended the loop, where
    /// `flag` is `true` when the transfer's bound was hit (read budget or
    /// output capacity for `poll_read`, an exhausted input buffer for
    /// `poll_write`) and `false` otherwise (zero bytes moved for a
    /// non-empty buffer).
    pub ret: Poll<std::io::Result<bool>>,
}
/// Copies bytes from `input` into `output` until the reader would block,
/// fails, reaches end of stream, or `delta_limit` bytes have been moved.
///
/// Data is staged through a local 8192-byte buffer, so each call to the
/// reader asks for at most 8192 bytes, and never more than `output` can hold
/// or the remaining `delta_limit` budget allows.
///
/// # Returns
///
/// A [`PollResult`] whose `delta` is the number of bytes appended to
/// `output` during this call (reported for every outcome), and whose `ret`
/// is:
/// - `Poll::Ready(Ok(true))` when the limit was reached — `delta_limit`
///   bytes were transferred or `output` has no remaining capacity. The
///   reader is not polled again in that case.
/// - `Poll::Ready(Ok(false))` when the reader returned `Ok(0)` for a
///   non-empty buffer (end of stream).
/// - `Poll::Pending` when the reader would block; per the `AsyncRead`
///   contract the reader has arranged for `cx`'s waker to be notified.
/// - `Poll::Ready(Err(_))` when the reader failed.
///
/// # Panics
///
/// Panics if the reader claims to have read more bytes than the buffer it
/// was given, which would otherwise copy stale buffer contents into
/// `output` and overrun `delta_limit`.
pub fn poll_read<I, O>(
    mut input: Pin<&mut I>,
    output: &mut O,
    cx: &mut Context<'_>,
    delta_limit: usize,
) -> PollResult
where
    I: futures_io::AsyncRead,
    O: bytes::BufMut,
{
    let mut rdbuf = [0u8; 8192];
    let start = output.remaining_mut();
    loop {
        // Invariant: `delta <= delta_limit`, because every read is capped by
        // `buflim`, which never exceeds the remaining budget.
        let delta = start - output.remaining_mut();
        let buflim = rdbuf
            .len()
            .min(output.remaining_mut())
            .min(delta_limit - delta);
        if buflim == 0 {
            // Limit reached. Do not poll the reader with an empty buffer:
            // it may answer `Pending` or an error, which would misreport a
            // transfer that is actually complete.
            return PollResult {
                delta,
                ret: Poll::Ready(Ok(true)),
            };
        }
        match input.as_mut().poll_read(cx, &mut rdbuf[..buflim]) {
            Poll::Ready(Ok(n)) if n != 0 => {
                assert!(
                    n <= buflim,
                    "AsyncRead::poll_read reported more bytes than the buffer holds"
                );
                output.put_slice(&rdbuf[..n]);
            }
            // EOF (`Ok(0)` with a non-empty buffer), `Pending`, or an error;
            // `output` was not touched, so `delta` is still current.
            ret => {
                return PollResult {
                    delta,
                    ret: ret_reduce(ret, false),
                }
            }
        }
    }
}
/// Writes as much of `input` into `output` as the writer accepts without
/// blocking, advancing `input` past every byte written.
///
/// `input` may be non-contiguous; each of its chunks is handed to the writer
/// in turn.
///
/// # Returns
///
/// A [`PollResult`] whose `delta` is the number of bytes consumed from
/// `input` during this call (reported for every outcome), and whose `ret`
/// is:
/// - `Poll::Ready(Ok(true))` once `input` has no bytes remaining. The writer
///   is not polled again with an empty slice in that case.
/// - `Poll::Ready(Ok(false))` when the writer reported writing zero bytes
///   of a non-empty chunk.
/// - `Poll::Pending` when the writer would block; per the `AsyncWrite`
///   contract the writer has arranged for `cx`'s waker to be notified.
/// - `Poll::Ready(Err(_))` when the writer failed.
pub fn poll_write<I, O>(
    input: &mut I,
    mut output: Pin<&mut O>,
    cx: &mut Context<'_>,
) -> PollResult
where
    I: bytes::Buf,
    O: futures_io::AsyncWrite,
{
    let start = input.remaining();
    loop {
        if !input.has_remaining() {
            // Everything has been written. Polling the writer with an empty
            // slice could return `Pending` (e.g. its buffer just filled up)
            // and misreport a finished write as blocked.
            return PollResult {
                delta: start - input.remaining(),
                ret: Poll::Ready(Ok(true)),
            };
        }
        match output.as_mut().poll_write(cx, input.bytes()) {
            Poll::Ready(Ok(n)) if n != 0 => input.advance(n),
            // Zero bytes accepted, `Pending`, or an error; `input` still has
            // data left, so the bound was not reached.
            ret => {
                return PollResult {
                    delta: start - input.remaining(),
                    ret: ret_reduce(ret, false),
                }
            }
        }
    }
}