use bincode::{
BorrowDecode,
Encode,
borrow_decode_from_slice,
config,
encode_to_vec,
error::{DecodeError, EncodeError},
};
use tokio::io::{self, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
/// Upper bound, in bytes, on how large the buffer used to decode a preamble
/// may grow; reads that would push it past this size are rejected.
const MAX_PREAMBLE_LEN: usize = 1 << 10; // 1 KiB
/// Marker trait for types usable as a connection preamble.
///
/// A preamble must be bincode-borrow-decodable for every lifetime (with `()`
/// as the bincode context parameter) and be `Send + Sync + 'static`. Because
/// of the blanket impl below, every type meeting those bounds implements this
/// trait automatically; it never needs to be implemented by hand.
pub trait Preamble: for<'de> BorrowDecode<'de, ()> + Send + Sync + 'static {}

impl<T> Preamble for T where T: for<'de> BorrowDecode<'de, ()> + Send + Sync + 'static {}
/// Appends exactly `additional` bytes read from `reader` to the end of `buf`.
///
/// The buffer is never allowed to grow beyond [`MAX_PREAMBLE_LEN`] bytes,
/// which bounds how much data is buffered while a preamble is being decoded.
///
/// # Errors
///
/// - [`DecodeError::Other`] if `buf.len() + additional` would exceed
///   [`MAX_PREAMBLE_LEN`], including when that sum overflows `usize`.
/// - [`DecodeError::Io`] if the reader fails, or reaches EOF (reported as
///   [`io::ErrorKind::UnexpectedEof`]) before `additional` bytes were read.
///   The error's `additional` field is the number of bytes still missing.
///
/// On an I/O error, `buf` may already have been extended with zero-filled
/// space past the bytes that were actually read.
async fn read_more<R>(
    reader: &mut R,
    buf: &mut Vec<u8>,
    additional: usize,
) -> Result<(), DecodeError>
where
    R: AsyncRead + Unpin,
{
    let start = buf.len();
    // `additional` is reported by the decoder and may be derived from a
    // peer-controlled length prefix, so it can be arbitrarily large. A plain
    // `start + additional` would panic on overflow in debug builds and could
    // wrap underneath the limit check in release builds.
    let end = match start.checked_add(additional) {
        Some(end) if end <= MAX_PREAMBLE_LEN => end,
        _ => return Err(DecodeError::Other("preamble too long")),
    };
    buf.resize(end, 0);
    // Absolute index of the first byte of `buf` not yet filled from `reader`.
    let mut filled = start;
    while filled < end {
        let chunk = buf
            .get_mut(filled..end)
            .ok_or(DecodeError::Other("preamble buffer range invalid"))?;
        match reader.read(chunk).await {
            // A zero-length read on a non-empty chunk means EOF.
            Ok(0) => {
                return Err(DecodeError::Io {
                    inner: io::Error::from(io::ErrorKind::UnexpectedEof),
                    additional: end - filled,
                });
            }
            Ok(n) => filled += n,
            Err(inner) => {
                return Err(DecodeError::Io {
                    inner,
                    additional: end - filled,
                });
            }
        }
    }
    Ok(())
}
/// Reads and decodes a preamble of type `T` from `reader`.
///
/// Decoding is attempted on the bytes buffered so far; whenever the decoder
/// reports [`DecodeError::UnexpectedEnd`], exactly the number of bytes it asks
/// for is read from `reader` and decoding is retried. The encoding is
/// bincode's standard configuration with big-endian, fixed-width integers,
/// the same configuration used by [`write_preamble`].
///
/// Returns the decoded value together with any buffered bytes that were not
/// consumed by the decoder.
///
/// # Errors
///
/// - Any non-`UnexpectedEnd` error produced by decoding `T`.
/// - [`DecodeError::Io`] if the reader fails or hits EOF mid-preamble.
/// - [`DecodeError::Other`] if the preamble would need more than
///   [`MAX_PREAMBLE_LEN`] bytes.
pub async fn read_preamble<R, T>(reader: &mut R) -> Result<(T, Vec<u8>), DecodeError>
where
    R: AsyncRead + Unpin,
    for<'de> T: BorrowDecode<'de, ()>,
{
    // NOTE(review): this config sets no byte limit. bincode may size container
    // allocations from an (untrusted) length prefix before it discovers the
    // bytes are missing; consider `.with_limit::<MAX_PREAMBLE_LEN>()` after
    // confirming it does not reject legitimate preambles.
    let config = config::standard()
        .with_big_endian()
        .with_fixed_int_encoding();
    // Start empty: the decoder itself reports how many bytes it needs, and a
    // zero-sized `T` must decode without waiting for any input.
    let mut buf = Vec::new();
    loop {
        match borrow_decode_from_slice::<T, _>(&buf, config) {
            Ok((value, consumed)) => {
                let leftover = buf.split_off(consumed);
                return Ok((value, leftover));
            }
            // Always request at least one byte so a decoder that reports
            // `additional: 0` cannot make this loop spin without progress;
            // `read_more` caps total growth at `MAX_PREAMBLE_LEN`.
            Err(DecodeError::UnexpectedEnd { additional }) => {
                read_more(reader, &mut buf, additional.max(1)).await?;
            }
            Err(e) => return Err(e),
        }
    }
}
pub async fn write_preamble<W, T>(writer: &mut W, preamble: &T) -> Result<(), EncodeError>
where
W: AsyncWrite + Unpin,
T: Encode,
{
let config = config::standard()
.with_big_endian()
.with_fixed_int_encoding();
let bytes = encode_to_vec(preamble, config)?;
writer
.write_all(&bytes)
.await
.map_err(|inner| EncodeError::Io { inner, index: 0 })?;
writer.flush().await.map_err(|inner| EncodeError::Io {
inner,
index: bytes.len(),
})?;
Ok(())
}