use std::io;
use std::sync::Arc;
use futures_util::lock::Mutex as AsyncMutex;
use super::{AsyncRead, AsyncWrite};
use crate::io::{IoBuf, IoBufMut, IoBufWithCursor};
/// Copies every byte from `reader` into `writer` until `reader` reports end of stream.
///
/// Data moves through one reusable 8 KiB buffer. Each chunk that is read is written out
/// completely, with short writes retried from the remaining offset, before the next read
/// is issued. After the reader returns `0`, `writer` is flushed once.
///
/// # Returns
///
/// The total number of bytes copied, saturating at `u64::MAX`.
///
/// # Errors
///
/// Returns the first error produced by `reader.read`, `writer.write` or `writer.flush`.
/// Returns an [`io::ErrorKind::WriteZero`] error if the writer accepts zero bytes while
/// data is still pending.
pub async fn copy<R, W>(reader: &mut R, writer: &mut W) -> Result<u64, io::Error>
where
    R: AsyncRead + ?Sized,
    W: AsyncWrite + ?Sized,
{
    // Zero-filled once so the backing memory is initialized, then emptied.
    // NOTE(review): presumably `IoBufMut for Vec<u8>` reads into the spare capacity and
    // sets the length to the number of bytes read — confirm against the trait impl.
    let mut buffer = vec![0u8; 8192];
    buffer.clear();
    let mut copied = 0u64;
    loop {
        let (read, returned_buf) = reader.read(buffer).await;
        let read = read?;
        if read == 0 {
            break;
        }
        // The cursor tracks how much of the chunk the writer has already consumed.
        let mut cursor_buf = IoBufWithCursor::new(returned_buf);
        while cursor_buf.buf_len() > 0 {
            let (w, mut returned_buf) = writer.write(cursor_buf).await;
            let w = w?;
            if w == 0 {
                return Err(io::Error::new(
                    io::ErrorKind::WriteZero,
                    "failed to write whole buffer",
                ));
            }
            returned_buf.advance(w);
            cursor_buf = returned_buf;
        }
        buffer = cursor_buf.into_inner();
        // Discard the chunk that was just written so the next read starts at offset 0.
        // `Vec::clear` does this safely, without an `unsafe` `set_buf_init(0)` call whose
        // effect depends on how that trait method treats a shrinking length.
        buffer.clear();
        copied = copied.saturating_add(read as u64);
    }
    writer.flush().await?;
    Ok(copied)
}
/// Read half of an I/O object, created by `split`.
///
/// Shares ownership of the underlying object with the matching `WriteHalf` through a
/// reference-counted async mutex.
pub struct ReadHalf<T> {
    /// The I/O object, shared with the matching `WriteHalf`.
    inner: Arc<AsyncMutex<T>>,
}
/// Write half of an I/O object, created by `split`.
///
/// Shares ownership of the underlying object with the matching `ReadHalf` through a
/// reference-counted async mutex.
pub struct WriteHalf<T> {
    /// The I/O object, shared with the matching `ReadHalf`.
    inner: Arc<AsyncMutex<T>>,
}
/// Splits `io` into a [`ReadHalf`] and a [`WriteHalf`] that can be used independently.
///
/// Both halves hold an `Arc` to the same `AsyncMutex` wrapping `io`. Each read or write
/// through a half locks that mutex while the operation is in progress, so the two halves
/// never touch `io` at the same time.
pub fn split<T>(io: T) -> (ReadHalf<T>, WriteHalf<T>)
where
    T: AsyncRead + AsyncWrite + 'static,
{
    let shared = Arc::new(AsyncMutex::new(io));
    let read_half = ReadHalf {
        inner: Arc::clone(&shared),
    };
    let write_half = WriteHalf { inner: shared };
    (read_half, write_half)
}
impl<T> ReadHalf<T>
where
T: AsyncRead + AsyncWrite + 'static,
{
pub fn into_inner(self) -> Arc<AsyncMutex<T>> {
self.inner
}
}
impl<T> WriteHalf<T>
where
T: AsyncRead + AsyncWrite + 'static,
{
pub fn into_inner(self) -> Arc<AsyncMutex<T>> {
self.inner
}
}
impl<T> AsyncRead for ReadHalf<T>
where
T: AsyncRead + AsyncWrite + 'static,
{
async fn read<B: crate::io::IoBufMut>(&mut self, buf: B) -> (Result<usize, io::Error>, B) {
let mut guard = self.inner.lock().await;
(*guard).read(buf).await
}
}
impl<T> AsyncWrite for WriteHalf<T>
where
T: AsyncRead + AsyncWrite + 'static,
{
async fn write<B: crate::io::IoBuf>(&mut self, buf: B) -> (Result<usize, io::Error>, B) {
let mut guard = self.inner.lock().await;
(*guard).write(buf).await
}
async fn flush(&mut self) -> Result<(), io::Error> {
let mut guard = self.inner.lock().await;
(*guard).flush().await
}
}
impl<R: AsyncRead + ?Sized> AsyncRead for Box<R> {
    /// Delegates the read to the boxed reader.
    #[inline]
    async fn read<B: crate::io::IoBufMut>(&mut self, buf: B) -> (Result<usize, std::io::Error>, B) {
        R::read(&mut **self, buf).await
    }
}
impl<R: AsyncRead + ?Sized> AsyncRead for &mut R {
    /// Delegates the read to the referenced reader.
    #[inline]
    async fn read<B: crate::io::IoBufMut>(&mut self, buf: B) -> (Result<usize, std::io::Error>, B) {
        R::read(&mut **self, buf).await
    }
}
impl<W: AsyncWrite + ?Sized> AsyncWrite for Box<W> {
    /// Delegates the write to the boxed writer.
    #[inline]
    async fn write<B: crate::io::IoBuf>(&mut self, buf: B) -> (Result<usize, std::io::Error>, B) {
        W::write(&mut **self, buf).await
    }

    /// Delegates the flush to the boxed writer.
    #[inline]
    async fn flush(&mut self) -> Result<(), std::io::Error> {
        W::flush(&mut **self).await
    }
}
impl<W: AsyncWrite + ?Sized> AsyncWrite for &mut W {
    /// Delegates the write to the referenced writer.
    #[inline]
    async fn write<B: crate::io::IoBuf>(&mut self, buf: B) -> (Result<usize, std::io::Error>, B) {
        W::write(&mut **self, buf).await
    }

    /// Delegates the flush to the referenced writer.
    #[inline]
    async fn flush(&mut self) -> Result<(), std::io::Error> {
        W::flush(&mut **self).await
    }
}
/// Copies data in both directions between `a` and `b`.
///
/// Each stream is split into a read half and a write half with `split`. Two `copy`
/// operations, `a → b` and `b → a`, are then driven concurrently on the current task.
///
/// # Returns
///
/// `(a_to_b, b_to_a)`: the number of bytes copied from `a` to `b`, and from `b` to `a`.
///
/// # Errors
///
/// Returns as soon as either direction fails. The other direction's future is then
/// dropped instead of being driven to completion. Previously both directions were always
/// awaited, so an error could be withheld for as long as the other side stayed open.
///
/// NOTE(review): each stream's halves share one `AsyncMutex`, and every read or write keeps
/// that lock for the whole operation. While the `a → b` copy is parked in a read on `a`, the
/// `b → a` copy cannot write to `a`. A protocol where `a` waits for a reply before sending
/// more data can therefore stall. Reaching EOF in one direction also does not shut down the
/// opposite writer. Confirm both points before using this for request/response proxying.
pub async fn copy_bidirectional<A, B>(a: A, b: B) -> Result<(u64, u64), io::Error>
where
    A: AsyncRead + AsyncWrite + 'static,
    B: AsyncRead + AsyncWrite + 'static,
{
    let (mut a_r, mut a_w) = split(a);
    let (mut b_r, mut b_w) = split(b);
    let a_to_b = copy(&mut a_r, &mut b_w);
    let b_to_a = copy(&mut b_r, &mut a_w);
    // `try_join` resolves with the first error instead of waiting for both directions.
    let (n1, n2) = futures_util::future::try_join(a_to_b, b_to_a).await?;
    Ok((n1, n2))
}