use std::io::Result;
use std::pin::Pin;
use std::task::Context;
use std::task::Poll;
use tokio::io::{split, AsyncRead, AsyncWrite, ReadBuf, ReadHalf, WriteHalf};
/// Read side of a [`PrefixedStream`].
///
/// Yields the bytes of `prefix` first, and only once they are exhausted reads
/// from the wrapped [`ReadHalf`].
#[derive(Debug)]
pub struct PrefixedRead<S>
{
    /// Bytes still to be handed out before reading from `read`; replaced by
    /// `None` once every prefix byte has been returned.
    prefix: Option<Vec<u8>>,
    /// Underlying read half that supplies data after the prefix.
    read: ReadHalf<S>,
}

/// A split stream whose reads start with a caller-supplied byte prefix.
///
/// Reads go through a [`PrefixedRead`]; writes go directly to the underlying
/// [`WriteHalf`].
#[derive(Debug)]
pub struct PrefixedStream<S>
{
    /// Prefix-aware read side.
    read: PrefixedRead<S>,
    /// Write side of the split stream.
    write: WriteHalf<S>,
}
impl<S> PrefixedStream<S>
where
    S: AsyncWrite + AsyncRead,
{
    /// Wraps `stream` so that reads first return the bytes of `prefix` and then
    /// continue with data read from `stream`.
    ///
    /// `stream` is split with [`split`] into a read half and a write half; the
    /// prefix is attached to the read half only.
    pub fn new(prefix: Vec<u8>, stream: S) -> Self
    {
        let (read, write) = split(stream);
        log::trace!("Prefix: {:?}", prefix);
        Self {
            read: PrefixedRead {
                prefix: Some(prefix),
                read,
            },
            write,
        }
    }

    /// Consumes the stream and returns its prefix-aware read half together
    /// with the raw write half.
    pub fn into_split(self) -> (PrefixedRead<S>, WriteHalf<S>)
    {
        (self.read, self.write)
    }
}
impl<S: AsyncRead> AsyncRead for PrefixedStream<S>
{
    /// Reads through the prefix-aware read half ([`PrefixedRead`]).
    fn poll_read(self: Pin<&mut Self>, cx: &mut Context, buf: &mut ReadBuf) -> Poll<Result<()>>
    {
        // `PrefixedStream` is `Unpin` (its fields are a `PrefixedRead` and a tokio
        // `WriteHalf`, both `Unpin`; the compiler checks this), so the field can be
        // re-pinned with `Pin::new` instead of an `unsafe` pin projection.
        Pin::new(&mut self.get_mut().read).poll_read(cx, buf)
    }
}
impl<S: AsyncRead> AsyncRead for PrefixedRead<S>
{
    /// Serves buffered prefix bytes first, then forwards to the inner read half.
    ///
    /// While prefix bytes remain, a call copies as many of them as fit into
    /// `buf` and returns `Ready(Ok(()))` without touching the inner reader.
    /// Once the prefix is used up (or was empty) it is dropped, and every later
    /// call is forwarded to the wrapped [`ReadHalf`].
    fn poll_read(self: Pin<&mut Self>, cx: &mut Context, buf: &mut ReadBuf)
        -> Poll<Result<()>>
    {
        // `PrefixedRead` is `Unpin` (a `Vec` plus a tokio `ReadHalf`; the compiler
        // checks this), so plain mutable access is safe and no `unsafe` pin
        // projection is required.
        let this = self.get_mut();

        if let Some(prefix) = this.prefix.as_mut() {
            if !prefix.is_empty() {
                // Copy as much of the prefix as the caller's buffer can hold.
                // `n <= buf.remaining()`, so `put_slice` cannot panic.
                let n = prefix.len().min(buf.remaining());
                buf.put_slice(&prefix[..n]);
                // Remove the consumed bytes in place (no new allocation).
                prefix.drain(..n);
                if prefix.is_empty() {
                    this.prefix = None;
                }
                return Poll::Ready(Ok(()));
            }
            // Empty prefix: discard it and fall through to the inner reader.
            this.prefix = None;
        }

        Pin::new(&mut this.read).poll_read(cx, buf)
    }
}
impl<S: AsyncWrite> AsyncWrite for PrefixedStream<S>
{
    /// Writes `buf` to the underlying write half; the prefix only affects reads.
    fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<Result<usize>>
    {
        // `PrefixedStream` is `Unpin` (its fields are a `PrefixedRead` and a tokio
        // `WriteHalf`, both `Unpin`; the compiler checks this), so `Pin::new`
        // replaces the previous `unsafe` pin projection.
        Pin::new(&mut self.get_mut().write).poll_write(cx, buf)
    }

    /// Forwards vectored writes to the write half instead of relying on
    /// `AsyncWrite`'s default implementation.
    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context,
        bufs: &[std::io::IoSlice<'_>],
    ) -> Poll<Result<usize>>
    {
        Pin::new(&mut self.get_mut().write).poll_write_vectored(cx, bufs)
    }

    /// Reports whether the underlying write half supports vectored writes.
    fn is_write_vectored(&self) -> bool
    {
        self.write.is_write_vectored()
    }

    /// Flushes the underlying write half.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<()>>
    {
        Pin::new(&mut self.get_mut().write).poll_flush(cx)
    }

    /// Shuts down the underlying write half.
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<()>>
    {
        Pin::new(&mut self.get_mut().write).poll_shutdown(cx)
    }
}