use crate::util::io::read_available_non_blocking;
use anyhow::Error;
use bytes::BytesMut;
use serde_bare::error::Error as sbError;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
/// Marker for the sending half of a protocol stream: any `AsyncWrite` that is
/// also `Send` and `Unpin`.
pub trait SendingStream
where
    Self: AsyncWrite + Unpin + Send,
{
}
impl SendingStream for quinn::SendStream {}
#[cfg(test)]
impl SendingStream for tokio_test::io::Mock {}
/// Marker for the receiving half of a protocol stream: any `AsyncRead` that is
/// also `Send` and `Unpin`.
pub trait ReceivingStream
where
    Self: AsyncRead + Unpin + Send,
{
}
impl ReceivingStream for quinn::RecvStream {}
#[cfg(test)]
impl ReceivingStream for tokio_test::io::Mock {}
/// A sending half and a receiving half held together (for example, the two
/// halves of a QUIC bidirectional stream).
#[derive(Debug)]
pub struct SendReceivePair<S, R>
where
    S: SendingStream,
    R: ReceivingStream,
{
    /// Outbound half.
    pub send: S,
    /// Inbound half.
    pub recv: R,
}
impl<S, R> From<(S, R)> for SendReceivePair<S, R>
where
    S: SendingStream,
    R: ReceivingStream,
{
    /// Builds a pair from a `(send, recv)` tuple; the first element becomes the
    /// sending half and the second the receiving half.
    fn from((send, recv): (S, R)) -> Self {
        Self { send, recv }
    }
}
/// Length prefix sent in front of every framed [`ProtocolMessage`] payload.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct MessageHeader {
    /// Length in bytes of the encoded payload that follows this header.
    pub size: u32,
}

impl MessageHeader {
    /// Encoded size of a header on the wire: its single `u32` field, which the
    /// BARE encoding writes as a fixed-width 4-byte value.
    pub const SIZE: u32 = u32::BITS / 8;
}

impl ProtocolMessage for MessageHeader {}
/// A message that can be exchanged on the wire using the BARE encoding
/// (via `serde_bare`).
///
/// Framed messages are written as an encoded [`MessageHeader`], whose `size` is
/// the payload length, followed by the encoded payload. Framed reads and writes
/// both enforce [`ProtocolMessage::WIRE_ENCODING_LIMIT`] on the payload size.
///
/// The `Sync` bound is required so that the `Send` futures returned by the async
/// methods may hold `&Self`.
pub trait ProtocolMessage
where
    Self: serde::Serialize + serde::de::DeserializeOwned + Sync,
{
    /// Maximum permitted size, in bytes, of this type's encoded payload.
    /// Implementors may override it.
    const WIRE_ENCODING_LIMIT: u32 = 1_048_576;

    /// Checks that a payload of `size` bytes is within [`Self::WIRE_ENCODING_LIMIT`].
    ///
    /// # Errors
    /// Returns an error naming this type if `size` exceeds the limit.
    fn check_size(size: u32) -> Result<(), Error> {
        Self::check_size_usize(size as usize)
    }

    /// `usize` variant of [`Self::check_size`], for lengths of in-memory buffers.
    ///
    /// # Errors
    /// Returns an error naming this type if `size` exceeds the limit.
    fn check_size_usize(size: usize) -> Result<(), Error> {
        anyhow::ensure!(
            size <= Self::WIRE_ENCODING_LIMIT as usize,
            "Wire message size {} was too long for {} (limit: {})",
            size,
            std::any::type_name::<Self>(),
            Self::WIRE_ENCODING_LIMIT
        );
        Ok(())
    }

    /// Decodes a message from an exact, unframed byte slice.
    ///
    /// # Errors
    /// Returns the `serde_bare` error if the slice is not a valid encoding.
    fn from_slice(slice: &[u8]) -> Result<Self, sbError> {
        serde_bare::from_slice(slice)
    }

    /// Reads exactly `size` bytes from `reader` and decodes them as a message.
    ///
    /// No size limit is applied here; callers reading untrusted lengths should
    /// use [`Self::check_size`] first (as [`Self::from_reader_framed`] does).
    ///
    /// # Errors
    /// Fails if `size` does not fit in `usize`, on I/O error (including EOF before
    /// `size` bytes), or if the bytes are not a valid encoding.
    fn from_reader<R>(reader: &mut R, size: u32) -> Result<Self, Error>
    where
        R: std::io::Read,
    {
        let mut buffer = BytesMut::zeroed(usize::try_from(size)?);
        reader.read_exact(&mut buffer)?;
        Ok(serde_bare::from_slice(&buffer)?)
    }

    /// Async counterpart of [`Self::from_reader`]: reads exactly `size` bytes and
    /// decodes them. No size limit is applied here.
    ///
    /// # Errors
    /// Same conditions as [`Self::from_reader`].
    fn from_reader_async<R>(
        reader: &mut R,
        size: u32,
    ) -> impl Future<Output = Result<Self, Error>> + Send
    where
        R: AsyncReadExt + std::marker::Unpin + Send,
    {
        async move {
            let mut buffer = BytesMut::zeroed(usize::try_from(size)?);
            reader.read_exact(&mut buffer).await?;
            Ok(serde_bare::from_slice(&buffer)?)
        }
    }

    /// Encodes this message without framing.
    ///
    /// # Errors
    /// Returns the `serde_bare` error if serialization fails.
    fn to_vec(&self) -> Result<Vec<u8>, sbError> {
        serde_bare::to_vec(&self)
    }

    /// Reads a [`MessageHeader`] and then a payload of the announced size.
    ///
    /// # Errors
    /// Fails if the announced size exceeds [`Self::WIRE_ENCODING_LIMIT`], on I/O
    /// error, or on invalid encoding.
    fn from_reader_framed<R>(reader: &mut R) -> Result<Self, Error>
    where
        R: std::io::Read,
    {
        let header = MessageHeader::from_reader(reader, MessageHeader::SIZE)?;
        Self::check_size(header.size)?;
        Self::from_reader(reader, header.size)
    }

    /// Async counterpart of [`Self::from_reader_framed`].
    ///
    /// When the announced size is over the limit, it makes a best-effort attempt
    /// to capture whatever bytes are immediately available (up to 256, including
    /// the header). If they are valid UTF-8, it logs them as protocol garbage
    /// before returning the size error.
    ///
    /// # Errors
    /// Same conditions as [`Self::from_reader_framed`].
    fn from_reader_async_framed<R>(
        reader: &mut R,
    ) -> impl Future<Output = Result<Self, Error>> + Send
    where
        R: AsyncReadExt + std::marker::Unpin + Send,
    {
        async move {
            let header = MessageHeader::from_reader_async(reader, MessageHeader::SIZE).await?;
            if let Err(e) = Self::check_size(header.size) {
                let mut raw = BytesMut::zeroed(256);
                let mut buf = tokio::io::ReadBuf::new(&mut raw);
                if let Ok(hdr) = header.to_vec() {
                    buf.put_slice(&hdr);
                    // Decode only the bytes actually captured. The remainder of
                    // `raw` is zero padding, which `trim()` does not remove and
                    // would otherwise be logged as NUL characters.
                    if read_available_non_blocking(reader, &mut buf).await.is_ok()
                        && let Ok(s) = str::from_utf8(buf.filled())
                    {
                        tracing::warn!("Received protocol garbage: {}", s.trim());
                    }
                }
                return Err(e);
            }
            Self::from_reader_async(reader, header.size).await
        }
    }

    /// Writes a [`MessageHeader`] followed by this message's encoding.
    ///
    /// # Errors
    /// Fails if serialization fails, if the encoding exceeds
    /// [`Self::WIRE_ENCODING_LIMIT`] (nothing is written in that case), or on I/O
    /// error.
    fn to_writer_framed<W>(&self, writer: &mut W) -> Result<(), Error>
    where
        W: std::io::Write,
    {
        let vec = self.to_vec()?;
        Self::check_size_usize(vec.len())?;
        // Cannot fail: the length was just bounded by a `u32` limit.
        let header = MessageHeader {
            size: u32::try_from(vec.len())?,
        }
        .to_vec()?;
        writer.write_all(&header)?;
        Ok(writer.write_all(&vec)?)
    }

    /// Async counterpart of [`Self::to_writer_framed`]. The writer is not flushed.
    ///
    /// # Errors
    /// Same conditions as [`Self::to_writer_framed`].
    fn to_writer_async_framed<W>(
        &self,
        writer: &mut W,
    ) -> impl Future<Output = Result<(), Error>> + Send
    where
        W: AsyncWriteExt + std::marker::Unpin + Send,
    {
        async move {
            let vec = self.to_vec()?;
            Self::check_size_usize(vec.len())?;
            // Cannot fail: the length was just bounded by a `u32` limit.
            let header = MessageHeader {
                size: u32::try_from(vec.len())?,
            }
            .to_vec()?;
            writer.write_all(&header).await?;
            Ok(writer.write_all(&vec).await?)
        }
    }

    /// Length in bytes of this message's unframed encoding.
    ///
    /// # Errors
    /// Returns an error if serialization fails.
    fn encoded_size(&self) -> Result<usize, Error> {
        Ok(self.to_vec()?.len())
    }
}
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use super::{Error, MessageHeader, ProtocolMessage};
    use pretty_assertions::assert_eq;
    use serde::{Deserialize, Serialize};
    use std::io::Cursor;

    /// Message type with a deliberately tiny wire limit, so the limit checks can
    /// be exercised with short buffers.
    #[derive(Debug, PartialEq, Serialize, Deserialize)]
    struct TestMessage {
        data: Vec<u8>,
    }
    impl ProtocolMessage for TestMessage {
        const WIRE_ENCODING_LIMIT: u32 = 16;
    }

    /// Framed input whose 4-byte header encodes a size of 18, over the limit.
    const OVER_LIMIT_FRAME: [u8; 22] = [
        18, 0, 0, 0, 17, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    ];

    /// A message whose encoding is larger than `TestMessage::WIRE_ENCODING_LIMIT`.
    fn oversized_message() -> TestMessage {
        TestMessage {
            data: vec![0u8; TestMessage::WIRE_ENCODING_LIMIT as usize + 1],
        }
    }

    #[test]
    fn test_sync_framed_roundtrip() -> Result<(), Error> {
        let msg = TestMessage {
            data: vec![1, 2, 3],
        };
        let mut buf = Vec::new();
        msg.to_writer_framed(&mut buf)?;
        let decoded = TestMessage::from_reader_framed(&mut Cursor::new(buf))?;
        assert_eq!(msg, decoded);
        Ok(())
    }

    #[tokio::test]
    async fn test_async_framed_roundtrip() -> Result<(), Error> {
        let msg = TestMessage {
            data: vec![1, 2, 3],
        };
        let mut buf = Vec::new();
        msg.to_writer_async_framed(&mut buf).await?;
        let decoded = TestMessage::from_reader_async_framed(&mut Cursor::new(buf)).await?;
        assert_eq!(msg, decoded);
        Ok(())
    }

    #[test]
    fn test_slicing() {
        let msg = TestMessage {
            data: vec![4, 5, 6],
        };
        let vec = msg.to_vec().unwrap();
        let decoded = TestMessage::from_slice(&vec).unwrap();
        assert_eq!(msg, decoded);
    }

    /// Framed reads consume exactly `MessageHeader::SIZE` bytes for the header,
    /// so the encoding must always have that length.
    #[test]
    fn header_encoding_matches_size_constant() {
        for size in [0u32, 1, 0x1234_5678, u32::MAX] {
            let encoded = MessageHeader { size }.to_vec().unwrap();
            assert_eq!(encoded.len(), MessageHeader::SIZE as usize);
        }
    }

    #[test]
    fn deserialize_limit() {
        let _ = TestMessage::from_reader_framed(&mut Cursor::new(OVER_LIMIT_FRAME))
            .expect_err("an error was expected");
    }

    #[test]
    fn serialize_limit() {
        let mut buf = Vec::new();
        let _ = oversized_message()
            .to_writer_framed(&mut buf)
            .expect_err("an error was expected");
    }

    #[tokio::test]
    async fn deserialize_limit_async() {
        let _ = TestMessage::from_reader_async_framed(&mut Cursor::new(OVER_LIMIT_FRAME))
            .await
            .expect_err("an error was expected");
    }

    #[tokio::test]
    async fn serialize_limit_async() {
        let mut buf = Vec::new();
        let _ = oversized_message()
            .to_writer_async_framed(&mut buf)
            .await
            .expect_err("an error was expected");
    }

    #[test]
    fn deserialize_junk_over_long() {
        for size in [1u32 << 31, u32::MAX] {
            let buf = MessageHeader { size }.to_vec().unwrap();
            let _ = TestMessage::from_reader_framed(&mut Cursor::new(buf))
                .expect_err("an error was expected");
        }
    }

    #[tokio::test]
    async fn deserialize_junk_over_long_async() {
        for size in [1u32 << 31, u32::MAX] {
            let buf = MessageHeader { size }.to_vec().unwrap();
            let _ = TestMessage::from_reader_async_framed(&mut Cursor::new(buf))
                .await
                .expect_err("an error was expected");
        }
    }

    /// A peer speaking plain text: the first four bytes ("GET ") decode as a huge
    /// size, which must be rejected. This also exercises the garbage-logging path.
    #[tokio::test]
    async fn deserialize_text_garbage_async() {
        let buf = b"GET / HTTP/1.1\r\n\r\n".to_vec();
        let _ = TestMessage::from_reader_async_framed(&mut Cursor::new(buf))
            .await
            .expect_err("an error was expected");
    }

    #[test]
    fn deserialize_junk_zero_data() {
        let buf = MessageHeader { size: 0 }.to_vec().unwrap();
        let _ = TestMessage::from_reader_framed(&mut Cursor::new(buf))
            .expect_err("an error was expected");
    }

    #[test]
    fn deserialize_junk_insufficient_data() {
        // The payload claims a 10-byte vector but only 3 bytes follow.
        let mut bogus_payload = vec![10u8, 1, 2, 3];
        let mut buf = MessageHeader {
            size: u32::try_from(bogus_payload.len()).unwrap(),
        }
        .to_vec()
        .unwrap();
        buf.append(&mut bogus_payload);
        let _ = TestMessage::from_reader_framed(&mut Cursor::new(buf))
            .expect_err("an error was expected");
    }
}