use crate::Channel;
use commonware_codec::{varint::UInt, EncodeSize, Error, RangeCfg, Read, ReadExt as _, Write};
use commonware_runtime::{Buf, BufMut, BufferPool, IoBuf, IoBufs};
use std::collections::HashMap;
/// A message addressed to a specific channel.
///
/// On the wire the channel is written first (as a varint), followed by the
/// encoded message buffer.
#[derive(Clone, Debug, PartialEq)]
pub struct Data {
    /// Identifier of the channel this message belongs to.
    pub channel: u64,
    /// Raw bytes of the message.
    pub message: IoBuf,
}
impl EncodeSize for Data {
    /// Returns the number of bytes [`Write::write`] emits for this value:
    /// the varint-encoded channel plus the encoded message.
    fn encode_size(&self) -> usize {
        let channel_len = UInt(self.channel).encode_size();
        let message_len = self.message.encode_size();
        channel_len + message_len
    }
}
impl Write for Data {
    /// Serializes the value into `buf`: varint channel first, then the message.
    fn write(&self, buf: &mut impl BufMut) {
        let Self { channel, message } = self;
        // Field order here must mirror the order used by `Read::read_cfg`.
        UInt(*channel).write(buf);
        message.write(buf);
    }
}
impl Read for Data {
    /// Range forwarded to the message buffer's decoder, which uses it to bound
    /// the accepted message length.
    type Cfg = RangeCfg<usize>;

    /// Decodes a [`Data`] from `buf`: a varint channel followed by the message.
    ///
    /// # Errors
    ///
    /// Propagates any error from decoding the channel varint or the message
    /// buffer (for example, a message length outside `range`).
    fn read_cfg(buf: &mut impl Buf, range: &Self::Cfg) -> Result<Self, Error> {
        let id: u64 = UInt::read(buf)?.into();
        let body = IoBuf::read_cfg(buf, range)?;
        Ok(Self {
            channel: id,
            message: body,
        })
    }
}
/// A channel identifier paired with an already-encoded payload.
///
/// When built through [`EncodedData::new`], `payload` holds a freshly written
/// header followed by the caller's message buffers.
#[derive(Clone, Debug)]
pub struct EncodedData {
    /// Channel the payload is destined for.
    pub channel: Channel,
    /// Encoded bytes: header buffer first, then the message buffers.
    pub payload: IoBufs,
}
impl EncodedData {
    /// Checks that this value's channel is a key of `rate_limits` and hands the
    /// value back unchanged.
    ///
    /// # Panics
    ///
    /// Panics if `rate_limits` has no entry for the channel.
    pub fn validate_channel<V>(self, rate_limits: &HashMap<u64, V>) -> Self {
        let channel_is_known = rate_limits.contains_key(&self.channel);
        assert!(channel_is_known, "outbound message on invalid channel");
        self
    }

    /// Builds an [`EncodedData`] by prepending a header to `message`.
    ///
    /// The header is allocated from `pool` and contains, in order: the `prefix`
    /// byte, the varint-encoded `channel`, and the encoded total length of
    /// `message`. The message buffers themselves are reused as-is; only the
    /// header is newly written.
    ///
    /// # Panics
    ///
    /// Panics if the number of header bytes written differs from the size
    /// computed up front.
    pub fn new(pool: &BufferPool, prefix: u8, channel: Channel, mut message: IoBufs) -> Self {
        let channel_varint = UInt(channel);
        let message_len = message.len();

        // Size the header exactly so a single allocation suffices.
        let expected_len =
            prefix.encode_size() + channel_varint.encode_size() + message_len.encode_size();

        let mut header = pool.alloc(expected_len);
        prefix.write(&mut header);
        channel_varint.write(&mut header);
        message_len.write(&mut header);
        assert_eq!(header.len(), expected_len, "data header size mismatch");

        // Place the header in front of the existing message buffers.
        message.prepend(header.freeze());
        Self {
            channel,
            payload: message,
        }
    }
}
#[cfg(feature = "arbitrary")]
impl arbitrary::Arbitrary<'_> for Data {
    /// Generates a [`Data`] with an arbitrary channel and a message whose
    /// requested length is between 0 and 1024 bytes (inclusive).
    fn arbitrary(u: &mut arbitrary::Unstructured<'_>) -> arbitrary::Result<Self> {
        // Draw order matters: channel, then message length, then message bytes.
        let channel: u64 = u.arbitrary()?;
        let len: usize = u.int_in_range(0..=1024)?;
        let message = IoBuf::copy_from_slice(u.bytes(len)?);
        Ok(Self { channel, message })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use commonware_codec::{Decode as _, Encode as _, Error};
    use commonware_runtime::{deterministic, BufferPooler as _, Runner as _};

    /// Round-trips a `Data` value and checks that the length range is enforced
    /// when decoding.
    #[test]
    fn test_data_codec() {
        let original = Data {
            channel: 12345,
            message: IoBuf::from(b"Hello, world!"),
        };

        // A range containing exactly the 13-byte message length decodes fine.
        let round_trip = Data::decode_cfg(original.encode(), &(13..=13).into()).unwrap();
        assert_eq!(original, round_trip);

        // A range that ends below the message length is rejected.
        let range_below = Data::decode_cfg(original.encode(), &(0..13).into());
        assert!(matches!(range_below, Err(Error::InvalidLength(13))));

        // A range that starts above the message length is rejected.
        let range_above = Data::decode_cfg(original.encode(), &(14..).into());
        assert!(matches!(range_above, Err(Error::InvalidLength(13))));
    }

    /// Malformed input must fail to decode even with an unbounded length range.
    #[test]
    fn test_decode_invalid() {
        let malformed = [3, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9];
        let outcome = Data::decode_cfg(&malformed[..], &(..).into());
        assert!(outcome.is_err());
    }

    /// `EncodedData::new` must yield the prefix byte followed by the same bytes
    /// that `Data::encode` produces for the coalesced message.
    #[test]
    fn test_encoded_data_new_matches_data_encode() {
        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            const PREFIX: u8 = 7;
            const CHANNEL: u64 = 12345;

            // Use several chunks to exercise the multi-buffer path.
            let mut message = IoBufs::from(IoBuf::from(b"hello "));
            message.append(IoBuf::from(b"world"));
            message.append(IoBuf::from(b"!"));

            // Reference encoding: prefix byte + `Data::encode` output.
            let reference = Data {
                channel: CHANNEL,
                message: message.clone().coalesce(),
            };
            let mut expected = IoBufs::from(reference.encode());
            expected.prepend(IoBuf::from(vec![PREFIX]));

            let encoded =
                EncodedData::new(context.network_buffer_pool(), PREFIX, CHANNEL, message);
            assert_eq!(encoded.channel, CHANNEL);
            assert_eq!(encoded.payload.coalesce(), expected.coalesce());
        });
    }

    #[cfg(feature = "arbitrary")]
    mod conformance {
        use super::*;
        use commonware_codec::conformance::CodecConformance;

        commonware_conformance::conformance_tests! {
            CodecConformance<Data>,
        }
    }
}