use futures::prelude::*;
use std::{error, fmt, io};
/// Sends `data` on `socket` as a single message and then closes `socket`.
///
/// The message is an unsigned-varint length prefix (see [`write_varint`]) immediately followed by
/// the raw bytes of `data`.
///
/// # Errors
///
/// Returns the first I/O error raised while writing the prefix, writing the payload, or closing.
pub async fn write_one(
    socket: &mut (impl AsyncWrite + Unpin),
    data: impl AsRef<[u8]>,
) -> Result<(), io::Error> {
    let payload = data.as_ref();
    write_varint(socket, payload.len()).await?;
    socket.write_all(payload).await?;
    socket.close().await
}
/// Sends `data` on `socket` prefixed by its length as an unsigned varint, then flushes.
///
/// Unlike [`write_one`], the socket is left open, so more messages can follow.
///
/// # Errors
///
/// Returns the first I/O error raised while writing the prefix, writing the payload, or flushing.
pub async fn write_with_len_prefix(
    socket: &mut (impl AsyncWrite + Unpin),
    data: impl AsRef<[u8]>,
) -> Result<(), io::Error> {
    let bytes = data.as_ref();
    write_varint(socket, bytes.len()).await?;
    socket.write_all(bytes).await?;
    socket.flush().await
}
/// Writes `len` to `socket` encoded as an unsigned varint.
///
/// This function does not flush. Callers that need the bytes to reach the peer must flush or
/// close the socket themselves.
///
/// # Errors
///
/// Propagates any I/O error from writing the encoded bytes.
pub async fn write_varint(
    socket: &mut (impl AsyncWrite + Unpin),
    len: usize,
) -> Result<(), io::Error> {
    let mut scratch = unsigned_varint::encode::usize_buffer();
    // `encode::usize` returns the prefix of `scratch` that holds the encoded value.
    let encoded = unsigned_varint::encode::usize(len, &mut scratch);
    socket.write_all(encoded).await
}
pub async fn read_varint(socket: &mut (impl AsyncRead + Unpin)) -> Result<usize, io::Error> {
let mut buffer = unsigned_varint::encode::usize_buffer();
let mut buffer_len = 0;
loop {
match socket.read(&mut buffer[buffer_len..buffer_len+1]).await? {
0 => {
if buffer_len == 0 {
return Ok(0);
} else {
return Err(io::ErrorKind::UnexpectedEof.into());
}
}
n => debug_assert_eq!(n, 1),
}
buffer_len += 1;
match unsigned_varint::decode::usize(&buffer[..buffer_len]) {
Ok((len, _)) => return Ok(len),
Err(unsigned_varint::decode::Error::Overflow) => {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
"overflow in variable-length integer"
));
}
Err(_) => {}
}
}
}
/// Reads one message from `socket`: an unsigned-varint length prefix followed by that many bytes.
///
/// If [`read_varint`] reports a length of 0 (including when `socket` is already at EOF), an empty
/// `Vec` is returned.
///
/// # Errors
///
/// - [`ReadOneError::TooLarge`] if the announced length exceeds `max_size`. This is checked
///   before any payload buffer is allocated.
/// - [`ReadOneError::Io`] if reading the prefix or the payload fails.
pub async fn read_one(
    socket: &mut (impl AsyncRead + Unpin),
    max_size: usize,
) -> Result<Vec<u8>, ReadOneError> {
    match read_varint(socket).await? {
        requested if requested > max_size => Err(ReadOneError::TooLarge {
            requested,
            max: max_size,
        }),
        len => {
            let mut payload = vec![0u8; len];
            socket.read_exact(&mut payload).await?;
            Ok(payload)
        }
    }
}
/// Error returned by [`read_one`].
#[derive(Debug)]
pub enum ReadOneError {
    /// An underlying I/O error, e.g. while reading the length prefix or the payload.
    Io(std::io::Error),
    /// The length prefix announced more bytes than the caller-supplied maximum.
    TooLarge {
        /// Number of bytes announced by the length prefix.
        requested: usize,
        /// Maximum number of bytes the caller was willing to accept.
        max: usize,
    },
}

impl From<std::io::Error> for ReadOneError {
    fn from(err: std::io::Error) -> ReadOneError {
        ReadOneError::Io(err)
    }
}

impl fmt::Display for ReadOneError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            ReadOneError::Io(ref err) => write!(f, "{}", err),
            // Report both sizes so the failure can be diagnosed from the message alone.
            ReadOneError::TooLarge { requested, max } => write!(
                f,
                "Received data size ({} bytes) exceeds maximum ({} bytes)",
                requested, max
            ),
        }
    }
}

impl error::Error for ReadOneError {
    fn source(&self) -> Option<&(dyn error::Error + 'static)> {
        match *self {
            ReadOneError::Io(ref err) => Some(err),
            ReadOneError::TooLarge { .. } => None,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `write_one` must emit a varint length prefix followed by exactly the payload bytes.
    #[test]
    fn write_one_works() {
        let data = (0..rand::random::<usize>() % 10_000)
            .map(|_| rand::random::<u8>())
            .collect::<Vec<_>>();

        // Use a growable sink. A fixed 10_000-byte buffer cannot hold a 2-byte varint prefix plus
        // a 9_999-byte payload, and `write_all` would then fail with `WriteZero`.
        let mut out = Vec::new();
        futures::executor::block_on(
            write_one(&mut futures::io::Cursor::new(&mut out), data.clone())
        ).unwrap();

        let (out_len, out_data) = unsigned_varint::decode::usize(&out).unwrap();
        assert_eq!(out_len, data.len());
        // No trailing bytes may follow the payload.
        assert_eq!(out_data, &data[..]);
    }

    /// A message written by `write_with_len_prefix` is read back unchanged by `read_one`.
    #[test]
    fn read_one_round_trips_write_with_len_prefix() {
        let data = b"hello world".to_vec();
        let mut encoded = Vec::new();
        futures::executor::block_on(
            write_with_len_prefix(&mut futures::io::Cursor::new(&mut encoded), &data)
        ).unwrap();

        let decoded = futures::executor::block_on(
            read_one(&mut futures::io::Cursor::new(encoded), 1024)
        ).unwrap();
        assert_eq!(decoded, data);
    }

    /// `read_one` rejects a length prefix above `max_size` with `TooLarge`.
    #[test]
    fn read_one_rejects_oversized_payload() {
        let mut encoded = Vec::new();
        futures::executor::block_on(
            write_with_len_prefix(&mut futures::io::Cursor::new(&mut encoded), vec![0u8; 64])
        ).unwrap();

        match futures::executor::block_on(read_one(&mut futures::io::Cursor::new(encoded), 16)) {
            Err(ReadOneError::TooLarge { requested, max }) => {
                assert_eq!(requested, 64);
                assert_eq!(max, 16);
            }
            other => panic!("unexpected result: {:?}", other),
        }
    }

    /// EOF before the first byte yields 0; EOF inside a varint is `UnexpectedEof`.
    #[test]
    fn read_varint_handles_eof() {
        let empty: Vec<u8> = Vec::new();
        let len = futures::executor::block_on(
            read_varint(&mut futures::io::Cursor::new(empty))
        ).unwrap();
        assert_eq!(len, 0);

        // 0x80 has its continuation bit set, so the varint is incomplete.
        let truncated = vec![0x80u8];
        let err = futures::executor::block_on(
            read_varint(&mut futures::io::Cursor::new(truncated))
        ).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof);
    }

    /// An over-long varint is reported as `InvalidData` instead of panicking.
    #[test]
    fn read_varint_rejects_overlong_input() {
        let bytes = vec![0xFFu8; 16];
        let err = futures::executor::block_on(
            read_varint(&mut futures::io::Cursor::new(bytes))
        ).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
    }
}