use std::io;
use futures::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
/// Reads one length-prefixed message from `socket`.
///
/// The length is read first as an unsigned varint via [`read_varint`]. Exactly that many
/// payload bytes are then read into a freshly allocated buffer. If the stream is already
/// at EOF before the first byte of the length, [`read_varint`] yields `0` and an empty
/// `Vec` is returned.
///
/// # Errors
///
/// - [`io::ErrorKind::InvalidData`] if the announced length exceeds `max_size`. This check
///   runs before anything is allocated.
/// - Any error produced while reading the length prefix or the payload is propagated.
pub async fn read_length_prefixed(
    socket: &mut (impl AsyncRead + Unpin),
    max_size: usize,
) -> io::Result<Vec<u8>> {
    let announced = read_varint(socket).await?;

    if announced <= max_size {
        let mut payload = vec![0_u8; announced];
        socket.read_exact(&mut payload).await?;
        return Ok(payload);
    }

    let message = format!(
        "Received data size ({} bytes) exceeds maximum ({} bytes)",
        announced, max_size
    );
    Err(io::Error::new(io::ErrorKind::InvalidData, message))
}
pub async fn read_varint(socket: &mut (impl AsyncRead + Unpin)) -> Result<usize, io::Error> {
let mut buffer = unsigned_varint::encode::usize_buffer();
let mut buffer_len = 0;
loop {
match socket.read(&mut buffer[buffer_len..buffer_len + 1]).await? {
0 => {
if buffer_len == 0 {
return Ok(0);
} else {
return Err(io::ErrorKind::UnexpectedEof.into());
}
}
n => debug_assert_eq!(n, 1),
}
buffer_len += 1;
match unsigned_varint::decode::usize(&buffer[..buffer_len]) {
Ok((len, _)) => return Ok(len),
Err(unsigned_varint::decode::Error::Overflow) => {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
"overflow in variable-length integer",
));
}
Err(_) => {}
}
}
}
/// Writes `data` to `socket` as a single length-prefixed message, then flushes.
///
/// The byte length of `data` is written first as an unsigned varint via
/// [`write_varint`], followed by the bytes of `data` themselves.
///
/// # Errors
///
/// Propagates any I/O error from writing the prefix, writing the payload, or flushing.
pub async fn write_length_prefixed(
    socket: &mut (impl AsyncWrite + Unpin),
    data: impl AsRef<[u8]>,
) -> Result<(), io::Error> {
    let payload: &[u8] = data.as_ref();

    write_varint(socket, payload.len()).await?;
    socket.write_all(payload).await?;

    // Flush so the whole frame is pushed out before the caller continues.
    socket.flush().await
}
/// Writes `len` to `socket` as an unsigned varint.
///
/// The value is encoded with `unsigned_varint::encode::usize` into a stack scratch buffer,
/// and only the encoded prefix of that buffer is written. The socket is not flushed.
///
/// # Errors
///
/// Propagates any I/O error from writing the encoded bytes.
pub async fn write_varint(
    socket: &mut (impl AsyncWrite + Unpin),
    len: usize,
) -> Result<(), io::Error> {
    let mut scratch = unsigned_varint::encode::usize_buffer();
    // The encoder returns exactly the used prefix of `scratch`.
    let encoded: &[u8] = unsigned_varint::encode::usize(len, &mut scratch);
    socket.write_all(encoded).await
}