pub mod command;
pub mod reply;
use bytes::Bytes;
use std::{
ascii,
error::Error,
ffi::{CStr, CString},
fmt::{self, Debug, Display, Formatter, Write},
io::{self, ErrorKind},
};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
/// Numeric protocol version identifier.
pub type Version = u32;
/// Current protocol version number.
pub const PROTOCOL_VERSION: Version = 6;
/// A single protocol frame: a one-byte kind tag plus an opaque payload.
#[derive(Clone, Eq, Hash, PartialEq)]
pub struct Message {
    /// Tag byte sent ahead of the payload.
    pub kind: u8,
    /// Payload bytes; excludes both the kind tag and the length prefix.
    pub buffer: Bytes,
}

impl Message {
    /// Largest payload `read` accepts and `write` sends.
    ///
    /// The on-wire length prefix counts the kind byte as well, so capping the
    /// payload at 1 MiB - 1 keeps the prefix at or below 1 MiB.
    const MAX_BUFFER_LEN: usize = (1 << 20) - 1;

    /// Builds a message from anything convertible into a kind byte and a payload.
    pub fn new(kind: impl Into<u8>, buffer: impl Into<Bytes>) -> Self {
        let kind: u8 = kind.into();
        let buffer: Bytes = buffer.into();
        Self { kind, buffer }
    }
}
impl Debug for Message {
    /// Renders `kind` as a byte literal (e.g. `b'A'`) instead of a bare number.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let Self { kind, buffer } = self;
        let mut out = f.debug_struct("Message");
        out.field("kind", &Byte(*kind));
        out.field("buffer", buffer);
        out.finish()
    }
}
/// Wrapper whose `Debug` output looks like a Rust byte literal,
/// e.g. `b'a'`, `b'\n'`, `b'\0'` or `b'\xef'`.
pub(crate) struct Byte(pub u8);

impl Debug for Byte {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let value = self.0;
        f.write_str("b'")?;
        if value == b'\0' {
            // `escape_default` would print `\x00`; the short `\0` form reads better.
            f.write_str("\\0")?;
        } else if value == b'"' {
            // A double quote needs no escaping inside single quotes.
            f.write_str("\"")?;
        } else {
            ascii::escape_default(value)
                .try_for_each(|escaped| f.write_char(char::from(escaped)))?;
        }
        f.write_str("'")
    }
}
/// Error reporting a byte that could not be converted into a more specific type.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct TryFromByteError(u8);

impl TryFromByteError {
    pub(crate) fn new(byte: u8) -> Self {
        TryFromByteError(byte)
    }

    /// The byte that failed to convert.
    pub fn byte(&self) -> u8 {
        let Self(byte) = *self;
        byte
    }
}

impl Display for TryFromByteError {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let shown = Byte(self.byte());
        write!(f, "failed to convert byte {:?}", shown)
    }
}

impl Error for TryFromByteError {}
/// Error reporting an integer index that could not be converted into a more specific type.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct TryFromIndexError(i32);

impl TryFromIndexError {
    pub(crate) fn new(index: i32) -> Self {
        TryFromIndexError(index)
    }

    /// The index that failed to convert.
    pub fn index(&self) -> i32 {
        let Self(index) = *self;
        index
    }
}

impl Display for TryFromIndexError {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.write_fmt(format_args!("failed to convert index {}", self.index()))
    }
}

impl Error for TryFromIndexError {}
/// Reads one length-prefixed [`Message`] from `stream`.
///
/// Wire format (integers big-endian):
///
/// ```text
/// | len: u32 | kind: u8 | buffer: len - 1 bytes |
/// ```
///
/// `len` counts the kind byte plus the payload, so a well-formed frame always
/// has `len >= 1`.
///
/// # Errors
///
/// - `ErrorKind::InvalidData` if `len` is zero or the payload would exceed
///   `Message::MAX_BUFFER_LEN` bytes. The header is validated before the kind
///   byte is read.
/// - `ErrorKind::UnexpectedEof` if the stream ends before a full frame arrives.
/// - Any other I/O error reported by `stream`.
pub async fn read<S>(stream: &mut S) -> io::Result<Message>
where
    S: AsyncRead + Unpin,
{
    let len = stream.read_u32().await?;
    // A zero prefix cannot even cover the kind byte, so it is malformed rather than
    // an empty message. Converting with `try_from` (instead of panicking) keeps a
    // peer-controlled value from aborting the process on narrow targets.
    let buffer_len = len
        .checked_sub(1)
        .and_then(|n| usize::try_from(n).ok())
        .filter(|&n| n <= Message::MAX_BUFFER_LEN)
        .ok_or_else(|| {
            io::Error::new(
                ErrorKind::InvalidData,
                format!("invalid message length {}", len),
            )
        })?;
    let kind = stream.read_u8().await?;
    // `buffer_len` is bounded above, so a peer cannot force an oversized allocation.
    let mut buffer = vec![0; buffer_len];
    stream.read_exact(&mut buffer).await?;
    Ok(Message::new(kind, buffer))
}
/// Writes `msg` to `stream` as one length-prefixed frame, then flushes.
///
/// The frame is a big-endian `u32` length (payload length + 1 for the kind
/// byte), the kind byte, and then the payload. This is the inverse of [`read`].
///
/// # Errors
///
/// - `ErrorKind::InvalidData` if the payload is longer than
///   `Message::MAX_BUFFER_LEN` bytes. Nothing is written in that case.
/// - Any I/O error reported by `stream` while writing or flushing.
pub async fn write<S>(stream: &mut S, msg: Message) -> io::Result<()>
where
    S: AsyncWrite + Unpin,
{
    let buffer_len = msg.buffer.len();
    if buffer_len > Message::MAX_BUFFER_LEN {
        return Err(io::Error::new(
            ErrorKind::InvalidData,
            format!(
                "message buffer of {} bytes exceeds maximum length",
                buffer_len
            ),
        ));
    }
    // Cannot fail: `buffer_len + 1 <= MAX_BUFFER_LEN + 1`, which fits in a u32.
    let len = u32::try_from(buffer_len + 1).expect("bounded message length fits in u32");
    // Emit the 5-byte header with a single write rather than a 4-byte write followed
    // by a 1-byte write; on an unbuffered stream each small write can become its own
    // syscall or packet. The bytes on the wire are unchanged.
    let mut header = [0u8; 5];
    header[..4].copy_from_slice(&len.to_be_bytes());
    header[4] = msg.kind;
    stream.write_all(&header).await?;
    stream.write_all(msg.buffer.as_ref()).await?;
    stream.flush().await?;
    Ok(())
}
/// Returned by [`get_c_string`] when the buffer holds no NUL terminator.
struct NoCStringFoundError;

/// Splits the leading NUL-terminated string off the front of `buf`.
///
/// On success, `buf` is advanced past the terminator. If `buf` contains no NUL
/// byte, it is left untouched and [`NoCStringFoundError`] is returned.
fn get_c_string(buf: &mut Bytes) -> Result<CString, NoCStringFoundError> {
    let terminator = match buf.iter().position(|&byte| byte == 0) {
        Some(index) => index,
        None => return Err(NoCStringFoundError),
    };
    let raw = buf.split_to(terminator + 1);
    // `terminator` is the first NUL, so `raw` ends with exactly one NUL and has
    // no interior NULs; this conversion cannot fail.
    let c_str = CStr::from_bytes_with_nul(raw.as_ref()).unwrap();
    Ok(CString::from(c_str))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `Byte`'s `Debug` output mirrors a Rust byte literal.
    #[test]
    fn byte_debug_ok() {
        let cases: [(u8, &str); 8] = [
            (b'\0', r"b'\0'"),
            (b'\n', r"b'\n'"),
            (b'\'', r"b'\''"),
            (b'"', r#"b'"'"#),
            (b'a', r"b'a'"),
            (b' ', r"b' '"),
            (b'\x07', r"b'\x07'"),
            (b'\xef', r"b'\xef'"),
        ];
        for &(byte, expected) in cases.iter() {
            assert_eq!(format!("{:?}", Byte(byte)), expected);
        }
    }

    /// `Message`'s `Debug` output shows `kind` as a byte literal.
    #[test]
    fn message_debug_ok() {
        let cases = [
            (
                Message::new(b'A', "abc\0"),
                r#"Message { kind: b'A', buffer: b"abc\0" }"#,
            ),
            (
                Message::new(b'\0', ""),
                r#"Message { kind: b'\0', buffer: b"" }"#,
            ),
        ];
        for (msg, expected) in cases.iter() {
            assert_eq!(format!("{:?}", msg), *expected);
        }
    }

    #[test]
    fn try_from_byte_error_debug() {
        let err = TryFromByteError::new(b'x');
        assert_eq!(err.to_string(), "failed to convert byte b'x'");
    }

    #[tokio::test]
    async fn read_message() {
        let (mut client, mut server) = tokio::io::duplex(100);
        client.write_all(b"\0\0\0\x02xyz").await.unwrap();
        // Closing the writing half lets the reader hit EOF after the buffered bytes.
        drop(client);

        assert_eq!(read(&mut server).await.unwrap(), Message::new(b'x', "y"));

        // Only the single byte "z" remains, which is too short for a length prefix.
        let err = read(&mut server).await.unwrap_err();
        assert_eq!(err.kind(), ErrorKind::UnexpectedEof);
    }

    #[tokio::test]
    async fn write_message() {
        let (mut client, mut server) = tokio::io::duplex(100);
        write(&mut server, Message::new(b'x', "abc")).await.unwrap();
        server.write_u8(1).await.unwrap();
        drop(server);

        let mut frame = [0u8; 8];
        client.read_exact(&mut frame).await.unwrap();
        assert_eq!(&frame, b"\0\0\0\x04xabc");

        // The trailing byte written after the frame must be untouched by `write`.
        assert_eq!(client.read_u8().await.unwrap(), 1);

        let err = client.read_u8().await.unwrap_err();
        assert_eq!(err.kind(), ErrorKind::UnexpectedEof);
    }
}