use super::copy::CopyBuffer;
use crate::io::{AsyncRead, AsyncWrite};
use std::future::Future;
use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};
/// Progress of the copy in a single direction (one reader, one writer).
///
/// `transfer_one_direction` advances a value of this type through the
/// variants in declaration order: `Running` -> `ShuttingDown` -> `Done`.
enum TransferState {
/// The copy is still in progress; holds the `CopyBuffer` whose `poll_copy`
/// is driving it.
Running(CopyBuffer),
/// `poll_copy` completed with the stored count; the writer's
/// `poll_shutdown` has not completed yet.
ShuttingDown(u64),
/// The writer's shutdown completed; the stored count is what this
/// direction reports as its result.
Done(u64),
}
/// Future that drives a copy in both directions between `a` and `b`.
///
/// Each direction keeps its own [`TransferState`], so the two directions
/// progress independently of one another.
struct CopyBidirectional<'a, A, B>
where
    A: ?Sized,
    B: ?Sized,
{
    /// First endpoint: read from for `a_to_b`, written to for `b_to_a`.
    a: &'a mut A,
    /// Second endpoint: read from for `b_to_a`, written to for `a_to_b`.
    b: &'a mut B,
    /// State of the `a` -> `b` direction.
    a_to_b: TransferState,
    /// State of the `b` -> `a` direction.
    b_to_a: TransferState,
}
/// Advances one direction of the transfer (`r` -> `w`) as far as possible.
///
/// The state machine stored in `state` is stepped until it either reaches
/// `TransferState::Done` or one of the underlying poll operations is not
/// ready yet:
///
/// * `Running`: polls `CopyBuffer::poll_copy` with `r` as reader and `w` as
///   writer; when it completes, the returned count is remembered.
/// * `ShuttingDown`: polls `poll_shutdown` on `w`.
/// * `Done`: returns the remembered count.
///
/// Returns `Poll::Pending` when the current step is pending, and
/// `Poll::Ready(Err(_))` as soon as either poll operation reports an error.
/// On error `state` is left untouched.
fn transfer_one_direction<A, B>(
    cx: &mut Context<'_>,
    state: &mut TransferState,
    r: &mut A,
    w: &mut B,
) -> Poll<io::Result<u64>>
where
    A: AsyncRead + AsyncWrite + Unpin + ?Sized,
    B: AsyncRead + AsyncWrite + Unpin + ?Sized,
{
    let mut reader = Pin::new(r);
    let mut writer = Pin::new(w);

    loop {
        // Each arm either returns from the function or yields the state to
        // move to; the single assignment below performs the transition.
        let next = match state {
            TransferState::Done(total) => return Poll::Ready(Ok(*total)),
            TransferState::Running(buf) => {
                let copied = ready!(buf.poll_copy(cx, reader.as_mut(), writer.as_mut()))?;
                TransferState::ShuttingDown(copied)
            }
            TransferState::ShuttingDown(total) => {
                ready!(writer.as_mut().poll_shutdown(cx))?;
                TransferState::Done(*total)
            }
        };
        *state = next;
    }
}
impl<'a, A, B> Future for CopyBidirectional<'a, A, B>
where
    A: AsyncRead + AsyncWrite + Unpin + ?Sized,
    B: AsyncRead + AsyncWrite + Unpin + ?Sized,
{
    /// On success, `(count for a -> b, count for b -> a)`.
    type Output = io::Result<(u64, u64)>;

    /// Polls both directions once, `a` -> `b` first and then `b` -> `a`.
    ///
    /// An error from either direction is returned immediately; an error in
    /// the first direction is returned before the second one is polled.
    /// The future resolves only once both directions are finished.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // A single reborrow lets the fields be borrowed disjointly below.
        let this = &mut *self;

        let forward = transfer_one_direction(cx, &mut this.a_to_b, &mut *this.a, &mut *this.b)?;
        let backward = transfer_one_direction(cx, &mut this.b_to_a, &mut *this.b, &mut *this.a)?;

        // Both directions are always polled before readiness is checked, so
        // neither can stall the other. A direction that already finished
        // stays in `TransferState::Done` and reports the same count again on
        // every later poll, so waiting on the slower one loses nothing.
        match (forward, backward) {
            (Poll::Ready(sent), Poll::Ready(received)) => Poll::Ready(Ok((sent, received))),
            _ => Poll::Pending,
        }
    }
}
/// Copies data in both directions between `a` and `b`.
///
/// Data read from `a` is written to `b` and data read from `b` is written to
/// `a`; both directions are driven by the same future. When the copy in one
/// direction completes, the writer of that direction is shut down while the
/// other direction keeps running. The future resolves once both directions
/// have finished copying and shutting down.
///
/// # Errors
///
/// Returns the first I/O error reported while copying or shutting down in
/// either direction.
///
/// # Return value
///
/// On success, a tuple of the counts reported for the `a` -> `b` direction
/// and the `b` -> `a` direction, in that order (presumably the number of
/// bytes copied; confirm against `CopyBuffer::poll_copy`).
#[cfg_attr(docsrs, doc(cfg(feature = "io-util")))]
pub async fn copy_bidirectional<A, B>(a: &mut A, b: &mut B) -> Result<(u64, u64), std::io::Error>
where
    A: AsyncRead + AsyncWrite + Unpin + ?Sized,
    B: AsyncRead + AsyncWrite + Unpin + ?Sized,
{
    // Each direction starts out copying with its own fresh buffer.
    let a_to_b = TransferState::Running(CopyBuffer::new());
    let b_to_a = TransferState::Running(CopyBuffer::new());

    let transfer = CopyBidirectional {
        a,
        b,
        a_to_b,
        b_to_a,
    };
    transfer.await
}