use std::pin::Pin;
use std::task::{Context, Poll};
use bytes::Bytes;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
/// A stream wrapper that yields the bytes of `prefix` to readers before any
/// data read from `inner`.
///
/// Only the read side is affected: writes, flushes and shutdowns are forwarded
/// to `inner` unchanged.
#[derive(Debug)]
pub struct PrefixedStream<S> {
    /// Bytes handed out by reads before `inner` is consulted.
    prefix: Bytes,
    /// How many bytes of `prefix` have already been returned by reads.
    /// Reads only ever advance it by at most the unread length, so it never exceeds `prefix.len()`.
    pos: usize,
    /// The wrapped stream; it is read from once `prefix` is exhausted.
    inner: S,
}
impl<S> PrefixedStream<S> {
    /// Creates a stream that yields `prefix` before any bytes read from `inner`.
    ///
    /// Only reads are affected by the prefix; writes go straight to `inner`.
    pub fn new(prefix: Bytes, inner: S) -> Self {
        Self {
            prefix,
            pos: 0,
            inner,
        }
    }

    /// Returns how many prefix bytes have not yet been handed out by a read.
    pub fn prefix_remaining(&self) -> usize {
        self.prefix.len().saturating_sub(self.pos)
    }

    /// Returns a shared reference to the wrapped stream.
    pub fn get_ref(&self) -> &S {
        &self.inner
    }

    /// Returns a mutable reference to the wrapped stream.
    ///
    /// Reading from it directly bypasses any prefix bytes that are still pending.
    pub fn get_mut(&mut self) -> &mut S {
        &mut self.inner
    }

    /// Consumes the wrapper and returns the inner stream.
    ///
    /// Any prefix bytes that have not been read yet are discarded. Use
    /// [`into_parts`](Self::into_parts) to keep them.
    pub fn into_inner(self) -> S {
        self.inner
    }

    /// Consumes the wrapper and returns the unread prefix bytes together with
    /// the inner stream.
    ///
    /// Once the prefix has been fully consumed, the returned `Bytes` is empty.
    pub fn into_parts(self) -> (Bytes, S) {
        // `pos <= prefix.len()` holds by construction. Clamping keeps `slice` from
        // panicking even if that invariant were ever broken, mirroring the
        // `saturating_sub` in `prefix_remaining`.
        let start = self.pos.min(self.prefix.len());
        (self.prefix.slice(start..), self.inner)
    }
}
impl<S: AsyncRead + Unpin> AsyncRead for PrefixedStream<S> {
    /// Hands out pending prefix bytes first, then delegates to `inner`.
    ///
    /// While prefix bytes remain, a call copies as many of them as fit into `buf`
    /// and returns immediately without polling `inner`. Such a read can therefore
    /// be shorter than `buf`'s capacity. Once the prefix is drained, every call is
    /// forwarded to `inner`.
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        // `S: Unpin` makes `Self: Unpin`, so a plain mutable reborrow is fine here.
        let this = &mut *self;
        let pending: &[u8] = this.prefix.get(this.pos..).unwrap_or_default();

        if pending.is_empty() {
            return Pin::new(&mut this.inner).poll_read(cx, buf);
        }

        let take = buf.remaining().min(pending.len());
        buf.put_slice(&pending[..take]);
        this.pos += take;
        Poll::Ready(Ok(()))
    }
}
/// Writes are transparent: the prefix only affects the read side, so every
/// write-path method forwards directly to `inner`.
impl<S: AsyncWrite + Unpin> AsyncWrite for PrefixedStream<S> {
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        data: &[u8],
    ) -> Poll<std::io::Result<usize>> {
        Pin::new(&mut self.inner).poll_write(cx, data)
    }

    // Forward vectored writes explicitly. Without this, the trait's default
    // implementation would write only the first non-empty slice per call,
    // losing the inner stream's scatter/gather support.
    fn poll_write_vectored(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[std::io::IoSlice<'_>],
    ) -> Poll<std::io::Result<usize>> {
        Pin::new(&mut self.inner).poll_write_vectored(cx, bufs)
    }

    // Report the inner stream's capability so callers (e.g. buffered writers)
    // choose the efficient vectored path when it is available.
    fn is_write_vectored(&self) -> bool {
        self.inner.is_write_vectored()
    }

    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        Pin::new(&mut self.inner).poll_flush(cx)
    }

    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        Pin::new(&mut self.inner).poll_shutdown(cx)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};

    /// The full output is the prefix followed by everything from the inner stream.
    #[tokio::test]
    async fn test_prefixed_stream_read() {
        let (mut client, server) = duplex(1024);
        let prefix = Bytes::from_static(b"prefix:");
        let mut prefixed = PrefixedStream::new(prefix, server);
        client.write_all(b"suffix").await.unwrap();
        drop(client);
        let mut buf = vec![0u8; 1024];
        let mut total = Vec::new();
        loop {
            let n = prefixed.read(&mut buf).await.unwrap();
            if n == 0 {
                break;
            }
            total.extend_from_slice(&buf[..n]);
        }
        assert_eq!(total, b"prefix:suffix");
    }

    /// A prefix longer than the read buffer is handed out across several reads.
    #[tokio::test]
    async fn test_prefixed_stream_partial_read() {
        let (_client, server) = duplex(1024);
        let prefix = Bytes::from_static(b"hello world");
        let mut prefixed = PrefixedStream::new(prefix, server);
        let mut buf = [0u8; 5];
        let n = prefixed.read(&mut buf).await.unwrap();
        assert_eq!(n, 5);
        assert_eq!(&buf, b"hello");
        let n = prefixed.read(&mut buf).await.unwrap();
        assert_eq!(n, 5);
        assert_eq!(&buf, b" worl");
        let n = prefixed.read(&mut buf).await.unwrap();
        assert_eq!(n, 1);
        assert_eq!(&buf[..1], b"d");
    }

    /// Writes bypass the prefix and reach the peer unchanged.
    #[tokio::test]
    async fn test_prefixed_stream_write_passthrough() {
        let (mut client, server) = duplex(1024);
        let prefix = Bytes::from_static(b"prefix");
        let mut prefixed = PrefixedStream::new(prefix, server);
        prefixed.write_all(b"hello").await.unwrap();
        let mut buf = [0u8; 10];
        let n = client.read(&mut buf).await.unwrap();
        assert_eq!(&buf[..n], b"hello");
    }

    /// `prefix_remaining` counts down as prefix bytes are read.
    #[tokio::test]
    async fn test_prefix_remaining() {
        let (_client, server) = duplex(1024);
        let prefix = Bytes::from_static(b"hello");
        let mut prefixed = PrefixedStream::new(prefix, server);
        assert_eq!(prefixed.prefix_remaining(), 5);
        let mut buf = [0u8; 3];
        let n = prefixed.read(&mut buf).await.unwrap();
        assert_eq!(n, 3);
        assert_eq!(prefixed.prefix_remaining(), 2);
    }

    /// With an empty prefix the wrapper is a transparent pass-through.
    #[tokio::test]
    async fn test_empty_prefix_reads_inner_directly() {
        let (mut client, server) = duplex(1024);
        let mut prefixed = PrefixedStream::new(Bytes::new(), server);
        assert_eq!(prefixed.prefix_remaining(), 0);
        client.write_all(b"data").await.unwrap();
        drop(client);
        let mut total = Vec::new();
        prefixed.read_to_end(&mut total).await.unwrap();
        assert_eq!(total, b"data");
    }

    /// A read served from the prefix stops at the prefix boundary, even when
    /// the buffer has room and the inner stream already has data available.
    #[tokio::test]
    async fn test_prefix_read_stops_at_boundary() {
        let (mut client, server) = duplex(1024);
        let mut prefixed = PrefixedStream::new(Bytes::from_static(b"abc"), server);
        client.write_all(b"def").await.unwrap();
        let mut buf = [0u8; 16];
        let n = prefixed.read(&mut buf).await.unwrap();
        assert_eq!(&buf[..n], b"abc");
        assert_eq!(prefixed.prefix_remaining(), 0);
        let n = prefixed.read(&mut buf).await.unwrap();
        assert_eq!(&buf[..n], b"def");
    }
}