use crate::{
authenticated::data::{Data, EncodedData},
Channel,
};
use commonware_codec::{EncodeSize, Error, Read, ReadExt, Write};
use commonware_runtime::{Buf, BufMut, BufferPool, IoBufs};
/// Upper bound on the number of bytes a [`Message::Data`] adds on top of its raw message
/// payload when encoded.
///
/// Breakdown (NOTE(review): assumes `Data` encodes its channel as a `u64` varint and the
/// message length as a `u32` varint — confirm against `Data`'s `Write` impl):
/// - 1 byte for the message-type prefix ([`DATA_PREFIX`]),
/// - up to 10 bytes for the channel,
/// - up to 5 bytes for the message length.
pub const MAX_PAYLOAD_DATA_OVERHEAD: u32 = {
    const PREFIX_LEN: u32 = 1;
    const MAX_CHANNEL_LEN: u32 = 10;
    const MAX_LENGTH_PREFIX_LEN: u32 = 5;
    PREFIX_LEN + MAX_CHANNEL_LEN + MAX_LENGTH_PREFIX_LEN
};

/// Wire prefix byte identifying a [`Message::Data`] frame.
pub const DATA_PREFIX: u8 = 0;

/// Wire prefix byte identifying a [`Message::Ping`] frame.
pub const PING_PREFIX: u8 = 1;
/// A message exchanged between peers.
///
/// On the wire, every message starts with a one-byte prefix ([`DATA_PREFIX`] or
/// [`PING_PREFIX`]) that selects the variant; see the `Write` and `Read` impls for this type.
///
/// NOTE(review): variant order is significant for the derived `Arbitrary` impl used by the
/// codec conformance tests — reordering likely changes the generated inputs; confirm before
/// touching it.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
pub enum Message {
/// Channel-tagged application payload (see [`Data`]).
Data(Data),
/// A payload-less message; encodes as the [`PING_PREFIX`] byte alone.
Ping,
}
impl Message {
    /// Builds an [`EncodedData`] for `message` on `channel`, tagged with [`DATA_PREFIX`]
    /// (the same prefix [`Message::Data`] is written with).
    ///
    /// `pool` is forwarded unchanged to [`EncodedData::new`] (presumably used for buffer
    /// allocation — confirm against `EncodedData`).
    pub(crate) fn encode_data(pool: &BufferPool, channel: Channel, message: IoBufs) -> EncodedData {
        let prefix = DATA_PREFIX;
        EncodedData::new(pool, prefix, channel, message)
    }
}
impl From<Data> for Message {
fn from(data: Data) -> Self {
Self::Data(data)
}
}
impl EncodeSize for Message {
    /// Encoded size in bytes: one prefix byte plus the variant's body
    /// (the encoded [`Data`] for `Data`, nothing for `Ping`).
    fn encode_size(&self) -> usize {
        // Every variant is preceded by a single `u8` prefix (see the `Write` impl).
        const PREFIX_SIZE: usize = 1;
        let body_size = match self {
            Self::Data(data) => data.encode_size(),
            Self::Ping => 0,
        };
        PREFIX_SIZE + body_size
    }
}
impl Write for Message {
    /// Writes the variant's prefix byte, followed by the encoded [`Data`] for `Data`.
    /// `Ping` writes only its prefix.
    fn write(&self, buf: &mut impl BufMut) {
        let (prefix, body) = match self {
            Self::Data(data) => (DATA_PREFIX, Some(data)),
            Self::Ping => (PING_PREFIX, None),
        };
        prefix.write(buf);
        if let Some(data) = body {
            data.write(buf);
        }
    }
}
impl Read for Message {
type Cfg = usize;
fn read_cfg(buf: &mut impl Buf, max_data_length: &Self::Cfg) -> Result<Self, Error> {
let message_type = <u8>::read(buf)?;
match message_type {
DATA_PREFIX => {
let data = Data::read_cfg(buf, &(..=*max_data_length).into())?;
Ok(Self::Data(data))
}
PING_PREFIX => Ok(Self::Ping),
other => Err(Error::InvalidEnum(other)),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use commonware_codec::{Decode as _, Encode as _, Error};
    use commonware_runtime::IoBuf;

    /// The worst case (largest channel, length needing the longest length prefix) should
    /// add exactly `MAX_PAYLOAD_DATA_OVERHEAD` bytes to the message.
    #[test]
    fn test_max_payload_overhead() {
        // A 2^29-byte length is large enough to need a 5-byte varint length prefix
        // (NOTE(review): assumes varint length encoding in `Data`).
        let message = IoBuf::from(vec![0; 1 << 29]);
        let message_len = message.len();
        let payload = Message::Data(Data {
            channel: u64::MAX,
            message,
        });
        assert_eq!(
            payload.encode_size(),
            message_len + MAX_PAYLOAD_DATA_OVERHEAD as usize
        );
    }

    /// A body exactly at the configured limit decodes successfully.
    #[test]
    fn test_decode_data_within_limit() {
        let payload = Message::Data(Data {
            channel: 7,
            message: IoBuf::from(b"ping"),
        });
        let encoded = payload.encode();
        let decoded = Message::decode_cfg(encoded, &4).expect("within limit");
        match decoded {
            Message::Data(data) => {
                assert_eq!(data.channel, 7);
                assert_eq!(data.message, IoBuf::from(b"ping"));
            }
            other => panic!("unexpected message variant: {other:?}"),
        }
    }

    /// A body one byte over the configured limit is rejected with its actual length.
    #[test]
    fn test_decode_data_exceeding_limit() {
        let payload = Message::Data(Data {
            channel: 9,
            message: IoBuf::from(b"hello"),
        });
        let encoded = payload.encode();
        let result = Message::decode_cfg(encoded, &4);
        assert!(matches!(result, Err(Error::InvalidLength(5))));
    }

    /// `Ping` is just its prefix byte and round-trips regardless of the data-length limit.
    #[test]
    fn test_ping_roundtrip() {
        assert_eq!(Message::Ping.encode_size(), 1);
        let encoded = Message::Ping.encode();
        assert_eq!(encoded.len(), 1);
        let decoded = Message::decode_cfg(encoded, &0).expect("ping decodes");
        assert!(matches!(decoded, Message::Ping));
    }

    /// An unrecognized prefix byte is reported as `InvalidEnum` with that byte.
    #[test]
    fn test_decode_unknown_prefix() {
        let unknown = PING_PREFIX + 1;
        let result = Message::decode_cfg(&[unknown][..], &4);
        assert!(matches!(result, Err(Error::InvalidEnum(b)) if b == unknown));
    }

    /// Input with no prefix byte at all is rejected.
    #[test]
    fn test_decode_empty_input() {
        let empty: &[u8] = &[];
        assert!(Message::decode_cfg(empty, &4).is_err());
    }

    #[cfg(feature = "arbitrary")]
    mod conformance {
        use super::*;
        use commonware_codec::conformance::CodecConformance;

        commonware_conformance::conformance_tests! {
            CodecConformance<Message>,
        }
    }
}