use crate::encrypted::Error;
use commonware_codec::{
varint::{Decoder, UInt, MAX_U32_VARINT_SIZE},
Encode, EncodeSize, Write,
};
use commonware_runtime::{Buf, IoBuf, IoBufMut, IoBufs, Sink, Stream};
pub(crate) fn build_frame<T>(
payload_len: usize,
max_message_size: u32,
assemble: impl FnOnce(UInt<u32>) -> Result<T, Error>,
) -> Result<T, Error> {
if payload_len > max_message_size as usize {
return Err(Error::SendTooLarge(payload_len));
}
let prefix = UInt(payload_len as u32);
assemble(prefix)
}
pub(crate) fn framed_len(payload_len: usize, max_message_size: u32) -> Result<usize, Error> {
build_frame(payload_len, max_message_size, |prefix| {
Ok(prefix.encode_size() + payload_len)
})
}
/// Appends a single frame to `frame`: the varint length prefix, then the payload
/// written by `append_payload`.
///
/// `append_payload` is given the frame buffer and the offset where the payload
/// begins. It must append exactly `payload_len` bytes.
///
/// Returns the total number of bytes appended (prefix plus payload).
///
/// # Errors
///
/// Returns [`Error::SendTooLarge`] when `payload_len` exceeds `max_message_size`,
/// or any error produced by `append_payload`.
///
/// # Panics
///
/// Panics if `append_payload` appends a byte count different from `payload_len`.
pub(crate) fn append_frame(
    frame: &mut IoBufMut,
    payload_len: usize,
    max_message_size: u32,
    append_payload: impl FnOnce(&mut IoBufMut, usize) -> Result<(), Error>,
) -> Result<usize, Error> {
    build_frame(payload_len, max_message_size, |prefix| {
        let frame_start = frame.len();
        prefix.write(frame);
        let body_start = frame.len();
        append_payload(frame, body_start)?;
        // The prefix already promised payload_len bytes; anything else corrupts the stream.
        let body_written = frame.len() - body_start;
        assert_eq!(body_written, payload_len);
        Ok(frame.len() - frame_start)
    })
}
pub async fn send_frame<S: Sink>(
sink: &mut S,
bufs: impl Into<IoBufs> + Send,
max_message_size: u32,
) -> Result<(), Error> {
let mut bufs = bufs.into();
let frame = build_frame(bufs.len(), max_message_size, |prefix| {
bufs.prepend(IoBuf::from(prefix.encode()));
Ok(bufs)
})?;
sink.send(frame).await.map_err(Error::SendFailed)
}
pub async fn recv_frame<T: Stream>(stream: &mut T, max_message_size: u32) -> Result<IoBufs, Error> {
let (len, skip) = recv_length(stream).await?;
if len > max_message_size as usize {
return Err(Error::RecvTooLarge(len));
}
stream
.recv(skip + len)
.await
.map(|mut bufs| {
bufs.advance(skip);
bufs
})
.map_err(Error::RecvFailed)
}
/// Reads the varint-encoded `u32` length prefix of the next frame.
///
/// Returns `(len, skip)`:
/// - `len` is the decoded payload length.
/// - `skip` is the number of prefix bytes that were decoded from `peek` and not
///   received, so they still precede the payload in the stream. The caller must
///   receive and discard them.
///
/// `skip` is `0` when the prefix was completed on the slow path, where every
/// prefix byte was received from the stream.
///
/// # Errors
///
/// Returns [`Error::InvalidVarint`] if the decoder rejects the prefix, or
/// [`Error::RecvFailed`] if receiving from the stream fails.
async fn recv_length<T: Stream>(stream: &mut T) -> Result<(usize, usize), Error> {
let mut decoder = Decoder::<u32>::new();
// Fast path: decode from up to MAX_U32_VARINT_SIZE bytes exposed by `peek`.
// If the prefix is not completed, this block evaluates to the number of
// peeked bytes already fed to the decoder.
// NOTE(review): the inner scope presumably exists so the value returned by
// `peek` (borrowing `stream`) is dropped before `recv` below — confirm.
let peeked = {
let peeked = stream.peek(MAX_U32_VARINT_SIZE);
for (i, byte) in peeked.iter().enumerate() {
match decoder.feed(*byte) {
// Prefix complete after i + 1 peeked bytes; the caller skips them.
Ok(Some(len)) => return Ok((len as usize, i + 1)),
Ok(None) => continue,
Err(_) => return Err(Error::InvalidVarint),
}
}
peeked.len()
};
// Slow path: the peeked bytes were already fed to the decoder. Receive them
// plus one new byte, drop the already-decoded ones, and continue decoding.
let mut buf = stream.recv(peeked + 1).await.map_err(Error::RecvFailed)?;
buf.advance(peeked);
// NOTE(review): termination relies on `Decoder::<u32>::feed` erroring once the
// prefix is too long for a u32 (exercised by the overflow test) — confirm.
loop {
match decoder.feed(buf.get_u8()) {
// All prefix bytes were received here, so there is nothing left to skip.
Ok(Some(len)) => return Ok((len as usize, 0)),
Ok(None) => {}
Err(_) => return Err(Error::InvalidVarint),
}
// Prefix still incomplete: pull the next byte.
buf = stream.recv(1).await.map_err(Error::RecvFailed)?;
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use commonware_runtime::{deterministic, mocks, BufMut, IoBufMut, Runner};
    use rand::Rng;

    const MAX_MESSAGE_SIZE: u32 = 1024;

    // Note: 1024 encodes as the varint [0x80, 0x08]. The low 7 bits are zero with
    // the continuation bit set (0x80), then 1024 >> 7 = 8 (0x08).

    #[test]
    fn test_send_recv_at_max_message_size() {
        let (mut sink, mut stream) = mocks::Channel::init();
        let executor = deterministic::Runner::default();
        executor.start(|mut context| async move {
            let mut buf = [0u8; MAX_MESSAGE_SIZE as usize];
            context.fill(&mut buf);

            // unwrap (rather than assert!(is_ok())) so a failure reports the error.
            send_frame(&mut sink, buf.to_vec(), MAX_MESSAGE_SIZE)
                .await
                .unwrap();

            let data = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
            assert_eq!(data.len(), buf.len());
            assert_eq!(data.coalesce(), buf);
        });
    }

    #[test]
    fn test_send_recv_multiple() {
        let (mut sink, mut stream) = mocks::Channel::init();
        let executor = deterministic::Runner::default();
        executor.start(|mut context| async move {
            let mut buf1 = [0u8; MAX_MESSAGE_SIZE as usize];
            let mut buf2 = [0u8; (MAX_MESSAGE_SIZE as usize) / 2];
            context.fill(&mut buf1);
            context.fill(&mut buf2);

            send_frame(&mut sink, buf1.to_vec(), MAX_MESSAGE_SIZE)
                .await
                .unwrap();
            send_frame(&mut sink, buf2.to_vec(), MAX_MESSAGE_SIZE)
                .await
                .unwrap();

            // Frames must come back in order and with correct boundaries.
            let data = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
            assert_eq!(data.len(), buf1.len());
            assert_eq!(data.coalesce(), buf1);

            let data = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
            assert_eq!(data.len(), buf2.len());
            assert_eq!(data.coalesce(), buf2);
        });
    }

    #[test]
    fn test_send_frame() {
        let (mut sink, mut stream) = mocks::Channel::init();
        let executor = deterministic::Runner::default();
        executor.start(|mut context| async move {
            let mut buf = [0u8; MAX_MESSAGE_SIZE as usize];
            context.fill(&mut buf);

            send_frame(&mut sink, buf.to_vec(), MAX_MESSAGE_SIZE)
                .await
                .unwrap();

            // Raw wire check: 2-byte varint prefix for 1024, then the payload.
            let read = stream.recv(2).await.unwrap();
            assert_eq!(read.coalesce(), &[0x80, 0x08]);

            let read = stream.recv(MAX_MESSAGE_SIZE as usize).await.unwrap();
            assert_eq!(read.coalesce(), buf);
        });
    }

    #[test]
    fn test_build_frame_closure_error() {
        // Errors returned by the assemble closure must propagate unchanged.
        let result: Result<IoBufs, _> = build_frame(10, MAX_MESSAGE_SIZE, |_prefix| {
            Err(Error::HandshakeError(
                commonware_cryptography::handshake::Error::EncryptionFailed,
            ))
        });
        assert!(matches!(&result, Err(Error::HandshakeError(_))));
    }

    #[test]
    fn test_build_frame_too_large() {
        // The size check must happen before the closure is ever invoked.
        let result: Result<IoBufs, _> = build_frame(
            MAX_MESSAGE_SIZE as usize + 1,
            MAX_MESSAGE_SIZE,
            |_prefix| unreachable!(),
        );
        assert!(
            matches!(&result, Err(Error::SendTooLarge(n)) if *n == MAX_MESSAGE_SIZE as usize + 1)
        );
    }

    #[test]
    fn test_send_frame_too_large() {
        let (mut sink, _) = mocks::Channel::init();
        let executor = deterministic::Runner::default();
        executor.start(|mut context| async move {
            let mut buf = [0u8; MAX_MESSAGE_SIZE as usize];
            context.fill(&mut buf);

            let result = send_frame(&mut sink, buf.to_vec(), MAX_MESSAGE_SIZE - 1).await;
            assert!(
                matches!(&result, Err(Error::SendTooLarge(n)) if *n == MAX_MESSAGE_SIZE as usize)
            );
        });
    }

    #[test]
    fn test_read_frame() {
        let (mut sink, mut stream) = mocks::Channel::init();
        let executor = deterministic::Runner::default();
        executor.start(|mut context| async move {
            let mut msg = [0u8; MAX_MESSAGE_SIZE as usize];
            context.fill(&mut msg);

            // Hand-build the frame: varint(1024) = [0x80, 0x08], then the payload.
            let mut buf = IoBufMut::with_capacity(2 + msg.len());
            buf.put_u8(0x80);
            buf.put_u8(0x08);
            buf.put_slice(&msg);
            sink.send(buf.freeze()).await.unwrap();

            let data = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
            assert_eq!(data.len(), MAX_MESSAGE_SIZE as usize);
            assert_eq!(data.coalesce(), msg);
        });
    }

    #[test]
    fn test_read_frame_too_large() {
        let (mut sink, mut stream) = mocks::Channel::init();
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            // Only the prefix (announcing 1024 bytes) is sent. The limit check must
            // reject it before any payload is read.
            let mut buf = IoBufMut::with_capacity(2);
            buf.put_u8(0x80);
            buf.put_u8(0x08);
            sink.send(buf.freeze()).await.unwrap();

            let result = recv_frame(&mut stream, MAX_MESSAGE_SIZE - 1).await;
            assert!(
                matches!(&result, Err(Error::RecvTooLarge(n)) if *n == MAX_MESSAGE_SIZE as usize)
            );
        });
    }

    #[test]
    fn test_recv_frame_incomplete_varint() {
        let (mut sink, mut stream) = mocks::Channel::init();
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            // A lone continuation byte followed by a closed channel.
            let mut buf = IoBufMut::with_capacity(1);
            buf.put_u8(0x80);
            sink.send(buf.freeze()).await.unwrap();
            drop(sink);

            let result = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await;
            assert!(matches!(&result, Err(Error::RecvFailed(_))));
        });
    }

    #[test]
    fn test_recv_frame_invalid_varint_overflow() {
        let (mut sink, mut stream) = mocks::Channel::init();
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            // Five continuation bytes exceed the width of a u32 varint.
            let mut buf = IoBufMut::with_capacity(6);
            buf.put_slice(&[0xFF; 5]);
            buf.put_u8(0x01);
            sink.send(buf.freeze()).await.unwrap();

            let result = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await;
            assert!(matches!(&result, Err(Error::InvalidVarint)));
        });
    }

    #[test]
    fn test_recv_frame_peek_paths() {
        let executor = deterministic::Runner::default();
        executor.start(|mut context| async move {
            // 300 bytes needs a 2-byte varint prefix, so a partial peek is possible.
            let mut payload = vec![0u8; 300];
            context.fill(&mut payload[..]);

            // Default read buffer: the prefix can be decoded from peeked bytes.
            let (mut sink, mut stream) = mocks::Channel::init();
            send_frame(&mut sink, payload.clone(), MAX_MESSAGE_SIZE)
                .await
                .unwrap();
            let data = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
            assert_eq!(data.coalesce(), &payload[..]);

            // Read buffer size 0: presumably nothing is peekable, forcing the slow path.
            let (mut sink, mut stream) = mocks::Channel::init_with_read_buffer_size(0);
            send_frame(&mut sink, payload.clone(), MAX_MESSAGE_SIZE)
                .await
                .unwrap();
            let data = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
            assert_eq!(data.coalesce(), &payload[..]);

            // Read buffer size 1: presumably a partial peek followed by the slow path.
            let (mut sink, mut stream) = mocks::Channel::init_with_read_buffer_size(1);
            send_frame(&mut sink, payload.clone(), MAX_MESSAGE_SIZE)
                .await
                .unwrap();
            let data = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
            assert_eq!(data.coalesce(), &payload[..]);
        });
    }
}