use std::fmt::Debug;
use compio_buf::{BufResult, IoBuf, IoBufMut, IoVectoredBuf, IoVectoredBufMut};
use crate::{AsyncRead, AsyncReadAt, AsyncWrite, AsyncWriteAt, IoResult, sync::bilock::BiLock};
pub fn split<T: AsyncRead + AsyncWrite>(stream: T) -> (ReadHalf<T>, WriteHalf<T>) {
Split::new(stream).split()
}
/// A value that can be consumed and turned into a separate read half and
/// write half.
pub trait Splittable {
/// Type of the half returned in the first tuple position.
type ReadHalf;
/// Type of the half returned in the second tuple position.
type WriteHalf;
/// Consumes `self` and returns the `(read, write)` pair.
fn split(self) -> (Self::ReadHalf, Self::WriteHalf);
}
/// A pair is already split: its first element is the read half and its
/// second element is the write half.
impl<R, W> Splittable for (R, W) {
    type ReadHalf = R;
    type WriteHalf = W;

    /// Returns the two tuple elements unchanged, in their original order.
    fn split(self) -> (Self::ReadHalf, Self::WriteHalf) {
        let (reader, writer) = self;
        (reader, writer)
    }
}
/// Holder for the two [`BiLock`] handles produced from a single stream.
///
/// The first field becomes the [`ReadHalf`] and the second the
/// [`WriteHalf`] when the value is split.
#[derive(Debug)]
pub struct Split<T>(BiLock<T>, BiLock<T>);
impl<T> Split<T> {
pub fn new(stream: T) -> Self {
let (l, r) = BiLock::new(stream);
Split(l, r)
}
}
impl<T: AsyncRead + AsyncWrite> Splittable for Split<T> {
type ReadHalf = ReadHalf<T>;
type WriteHalf = WriteHalf<T>;
fn split(self) -> (Self::ReadHalf, Self::WriteHalf) {
(ReadHalf(self.0), WriteHalf(self.1))
}
}
/// The read side of a split stream.
///
/// Wraps one [`BiLock`] handle; every operation acquires the lock before
/// touching the underlying stream.
#[derive(Debug)]
pub struct ReadHalf<T>(BiLock<T>);
impl<T: Unpin> ReadHalf<T> {
#[track_caller]
pub fn unsplit(self, w: WriteHalf<T>) -> T {
self.0.try_join(w.0).expect("Not the same pair")
}
#[track_caller]
pub fn try_unsplit(self, w: WriteHalf<T>) -> Option<T> {
self.0.try_join(w.0)
}
}
/// Reads are forwarded to the shared stream while the lock is held.
impl<T: AsyncRead> AsyncRead for ReadHalf<T> {
    /// Acquires the lock, then delegates to the inner stream's `read`.
    async fn read<B: IoBufMut>(&mut self, buf: B) -> BufResult<usize, B> {
        let mut stream = self.0.lock().await;
        stream.read(buf).await
    }

    /// Acquires the lock, then delegates to the inner stream's
    /// `read_vectored`.
    async fn read_vectored<V: IoVectoredBufMut>(&mut self, buf: V) -> BufResult<usize, V> {
        let mut stream = self.0.lock().await;
        stream.read_vectored(buf).await
    }
}
impl<T: AsyncReadAt> AsyncReadAt for ReadHalf<T> {
async fn read_at<B: IoBufMut>(&self, buf: B, pos: u64) -> BufResult<usize, B> {
self.0.lock().await.read_at(buf, pos).await
}
}
/// The write side of a split stream.
///
/// Wraps one [`BiLock`] handle; every operation acquires the lock before
/// touching the underlying stream.
#[derive(Debug)]
pub struct WriteHalf<T>(BiLock<T>);
/// Writes are forwarded to the shared stream while the lock is held.
impl<T: AsyncWrite> AsyncWrite for WriteHalf<T> {
    /// Acquires the lock, then delegates to the inner stream's `write`.
    async fn write<B: IoBuf>(&mut self, buf: B) -> BufResult<usize, B> {
        let mut stream = self.0.lock().await;
        stream.write(buf).await
    }

    /// Acquires the lock, then delegates to the inner stream's
    /// `write_vectored`.
    async fn write_vectored<B: IoVectoredBuf>(&mut self, buf: B) -> BufResult<usize, B> {
        let mut stream = self.0.lock().await;
        stream.write_vectored(buf).await
    }

    /// Acquires the lock, then delegates to the inner stream's `flush`.
    async fn flush(&mut self) -> IoResult<()> {
        let mut stream = self.0.lock().await;
        stream.flush().await
    }

    /// Acquires the lock, then delegates to the inner stream's `shutdown`.
    async fn shutdown(&mut self) -> IoResult<()> {
        let mut stream = self.0.lock().await;
        stream.shutdown().await
    }
}
/// Positional writes are forwarded to the shared stream while the lock is
/// held.
impl<T: AsyncWriteAt> AsyncWriteAt for WriteHalf<T> {
    /// Acquires the lock, then delegates to the inner stream's `write_at`,
    /// writing `buf` at offset `pos`.
    async fn write_at<B: IoBuf>(&mut self, buf: B, pos: u64) -> BufResult<usize, B> {
        let mut stream = self.0.lock().await;
        stream.write_at(buf, pos).await
    }

    /// Acquires the lock, then delegates to the inner stream's
    /// `write_vectored_at`.
    async fn write_vectored_at<B: IoVectoredBuf>(
        &mut self,
        buf: B,
        pos: u64,
    ) -> BufResult<usize, B> {
        let mut stream = self.0.lock().await;
        stream.write_vectored_at(buf, pos).await
    }
}