use serde::{Serialize, de::DeserializeOwned};
use std::io::{Cursor, Read, Write};
use thiserror::Error;
use zstd::bulk;
use crate::ChannelKind;
/// Two-byte marker written at the start of every encoded frame (ASCII `RP`).
/// The decoder rejects input whose first two bytes differ.
pub const MAGIC_HEADER: [u8; 2] = [0x52, 0x50];
/// Wire-format version written as the third byte of every frame.
/// The decoder rejects any other value with `BinaryError::UnsupportedVersion`.
pub const VERSION_BYTE: u8 = 1;
/// Boolean flags carried in the fourth byte of a frame header.
///
/// Each field maps to one bit of that byte (see `to_byte` / `from_byte`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, serde::Deserialize)]
pub struct BinaryFlags {
/// Bit 0: the payload body is zstd-compressed. The encoder overwrites this
/// with the actual outcome, so the value supplied by the caller is not trusted.
pub compressed: bool,
/// Bit 1. Carried through unchanged; not interpreted by the codec in this file.
pub fragmented: bool,
/// Bit 2. Carried through unchanged; not interpreted by the codec in this file.
pub ack_required: bool,
}
impl BinaryFlags {
    /// Packs the flags into a single byte.
    ///
    /// Bit 0 is `compressed`, bit 1 is `fragmented`, bit 2 is `ack_required`;
    /// the remaining five bits are always zero.
    pub fn to_byte(self) -> u8 {
        let mut packed = 0u8;
        if self.compressed {
            packed |= 0b0000_0001;
        }
        if self.fragmented {
            packed |= 0b0000_0010;
        }
        if self.ack_required {
            packed |= 0b0000_0100;
        }
        packed
    }

    /// Unpacks flags from a byte produced by [`BinaryFlags::to_byte`].
    ///
    /// Only the three lowest bits are inspected; any higher bits are ignored.
    pub fn from_byte(byte: u8) -> Self {
        let bit_set = |position: u8| (byte >> position) & 1 == 1;
        Self {
            compressed: bit_set(0),
            fragmented: bit_set(1),
            ack_required: bit_set(2),
        }
    }
}
/// Serialization format used for a frame's payload.
///
/// The chosen encoding is not recorded in the frame header, so the encoder and
/// the decoder must be called with the same value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PayloadEncoding {
/// MessagePack, handled through `rmp_serde`.
MessagePack,
/// CBOR, handled through `serde_cbor`.
Cbor,
}
/// Settings that control payload compression during encoding.
#[derive(Debug, Clone)]
pub struct CompressionConfig {
/// Minimum serialized payload length, in bytes, at which compression is
/// attempted. Shorter payloads are always sent uncompressed.
pub threshold: usize,
/// Optional zstd dictionary used when compressing; `None` means plain zstd.
pub dictionary: Option<CompressionDictionary>,
}
impl Default for CompressionConfig {
fn default() -> Self {
Self {
threshold: 512,
dictionary: None,
}
}
}
/// A zstd dictionary together with the identifier used to reference it on the wire.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CompressionDictionary {
/// Identifier written as a 4-byte big-endian prefix ahead of the compressed
/// payload. The decoder interprets a prefix of 0 as "no dictionary", so 0 is
/// not a usable id.
pub id: u32,
/// Raw dictionary content handed to zstd.
pub bytes: Vec<u8>,
}
impl CompressionDictionary {
pub fn new(id: u32, bytes: Vec<u8>) -> Self {
Self { id, bytes }
}
}
/// A single frame: header fields plus a typed payload.
///
/// `C` identifies the channel type; `T` is the payload type and defaults to raw bytes.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BinaryFrame<C: ChannelKind, T = Vec<u8>> {
/// Channel the frame belongs to; encoded as its one-byte `wire_id()`.
pub channel: C,
/// Header flags. `compressed` is recomputed by the encoder.
pub flags: BinaryFlags,
/// Sequence number, written as four big-endian bytes.
pub sequence: u32,
/// Payload value, serialized with the `PayloadEncoding` passed to the codec.
pub payload: T,
}
/// Errors produced while encoding or decoding frames.
///
/// Underlying library errors are not preserved; each variant only records the
/// stage that failed.
#[derive(Debug, Error, PartialEq, Eq)]
pub enum BinaryError {
/// The first two bytes did not match `MAGIC_HEADER`.
#[error("invalid magic header")]
InvalidMagic,
/// The version byte (carried here) did not match `VERSION_BYTE`.
#[error("unsupported version byte {0}")]
UnsupportedVersion(u8),
/// The channel byte (carried here) was not recognised by `ChannelKind::from_wire_id`.
#[error("unknown channel id {0}")]
UnknownChannel(u8),
/// The payload could not be serialized with the selected encoding.
#[error("serialization error")]
Serialization,
/// The payload bytes could not be deserialized into the requested type.
#[error("deserialization error")]
Deserialization,
/// The input was shorter than the fixed 9-byte header.
#[error("frame too short")]
FrameTooShort,
/// zstd compression or dictionary training failed.
#[error("compression error")]
Compression,
/// zstd decompression failed.
#[error("decompression error")]
Decompression,
/// The frame references a dictionary id (carried here) that was not supplied to the decoder.
#[error("missing compression dictionary {0}")]
MissingDictionary(u32),
}
/// Encodes `frame` using the default compression settings.
///
/// Equivalent to calling [`encode_frame_with_compression`] with
/// [`CompressionConfig::default`] (512-byte threshold, no dictionary).
///
/// # Errors
///
/// Propagates any error from [`encode_frame_with_compression`].
pub fn encode_frame<C: ChannelKind, T: Serialize>(
    frame: &BinaryFrame<C, T>,
    encoding: PayloadEncoding,
) -> Result<Vec<u8>, BinaryError> {
    let defaults = CompressionConfig::default();
    encode_frame_with_compression(frame, encoding, &defaults)
}
/// Encodes `frame` into its wire representation, compressing the payload when
/// that makes the frame smaller.
///
/// Wire layout: magic (2 bytes), version (1), flags (1), channel wire id (1),
/// sequence (4, big-endian), then the payload body. A compressed body is a
/// 4-byte big-endian dictionary id (0 meaning "no dictionary") followed by the
/// zstd stream; an uncompressed body is the serialized payload as-is.
///
/// Compression is attempted only when the serialized payload is at least
/// `compression.threshold` bytes long, and its result is kept only when the
/// prefixed output is strictly smaller than the plain payload. A failed
/// compression attempt is not an error: the frame is emitted uncompressed.
/// The `compressed` flag on the wire always reflects what was actually done,
/// regardless of `frame.flags.compressed`.
///
/// A configured dictionary whose id is 0 is ignored and plain zstd is used
/// instead, because the decoder reads a 0 prefix as "no dictionary" and could
/// not decode a dictionary-compressed body tagged that way.
///
/// # Errors
///
/// Returns [`BinaryError::Serialization`] if the payload cannot be serialized.
pub fn encode_frame_with_compression<C: ChannelKind, T: Serialize>(
    frame: &BinaryFrame<C, T>,
    encoding: PayloadEncoding,
    compression: &CompressionConfig,
) -> Result<Vec<u8>, BinaryError> {
    // magic (2) + version (1) + flags (1) + channel (1) + sequence (4)
    const HEADER_LEN: usize = 9;

    let payload_bytes = serialize_payload(&frame.payload, encoding)?;
    let payload_len = payload_bytes.len();

    let compressed_body: Option<Vec<u8>> = if payload_len < compression.threshold {
        None
    } else {
        // Id 0 is the on-wire sentinel for "no dictionary", so a dictionary
        // carrying that id cannot be referenced and is skipped.
        let dictionary = compression.dictionary.as_ref().filter(|d| d.id != 0);
        let attempt = match dictionary {
            Some(dict) => compress_with_dictionary(&payload_bytes, dict)
                .ok()
                .map(|c| prepend_dict_id(c, dict.id)),
            None => bulk::compress(&payload_bytes, 3)
                .ok()
                .map(|c| prepend_dict_id(c, 0)),
        };
        // Keep the compressed form only when it actually saves bytes.
        attempt.filter(|c| c.len() < payload_len)
    };

    let mut flags = frame.flags;
    flags.compressed = compressed_body.is_some();
    let body = compressed_body.unwrap_or(payload_bytes);

    let mut out = Vec::with_capacity(HEADER_LEN + body.len());
    out.extend_from_slice(&MAGIC_HEADER);
    out.push(VERSION_BYTE);
    out.push(flags.to_byte());
    out.push(frame.channel.wire_id());
    out.extend_from_slice(&frame.sequence.to_be_bytes());
    out.extend_from_slice(&body);
    Ok(out)
}
pub fn decode_frame<C: ChannelKind, T: DeserializeOwned>(
bytes: &[u8],
encoding: PayloadEncoding,
) -> Result<BinaryFrame<C, T>, BinaryError> {
decode_frame_with_dictionaries(bytes, encoding, &[])
}
/// Decodes a frame from `bytes`, resolving dictionary-compressed payloads
/// against `dictionaries`.
///
/// The header is validated in this order: length, magic, version, channel.
/// Everything after the 9-byte header is treated as the payload body.
///
/// # Errors
///
/// * [`BinaryError::FrameTooShort`] if fewer than 9 bytes are supplied.
/// * [`BinaryError::InvalidMagic`] if the first two bytes are not `MAGIC_HEADER`.
/// * [`BinaryError::UnsupportedVersion`] if the version byte is not `VERSION_BYTE`.
/// * [`BinaryError::UnknownChannel`] if the channel byte is not recognised by `C`.
/// * Any error from decompressing or deserializing the payload.
pub fn decode_frame_with_dictionaries<C: ChannelKind, T: DeserializeOwned>(
    bytes: &[u8],
    encoding: PayloadEncoding,
    dictionaries: &[CompressionDictionary],
) -> Result<BinaryFrame<C, T>, BinaryError> {
    // Split the fixed header fields from the body in a single step.
    let (magic, version, flag_byte, channel_id, sequence_bytes, body) = match bytes {
        [m0, m1, version, flag_byte, channel_id, s0, s1, s2, s3, body @ ..] => (
            [*m0, *m1],
            *version,
            *flag_byte,
            *channel_id,
            [*s0, *s1, *s2, *s3],
            body,
        ),
        _ => return Err(BinaryError::FrameTooShort),
    };

    if magic != MAGIC_HEADER {
        return Err(BinaryError::InvalidMagic);
    }
    if version != VERSION_BYTE {
        return Err(BinaryError::UnsupportedVersion(version));
    }

    let flags = BinaryFlags::from_byte(flag_byte);
    let channel = match C::from_wire_id(channel_id) {
        Some(known) => known,
        None => return Err(BinaryError::UnknownChannel(channel_id)),
    };
    let sequence = u32::from_be_bytes(sequence_bytes);
    let payload = deserialize_payload(body, encoding, flags, dictionaries)?;

    Ok(BinaryFrame {
        channel,
        flags,
        sequence,
        payload,
    })
}
/// Trains a zstd dictionary from `samples` and tags it with `id`.
///
/// `dict_size` is passed to `zstd::dict::from_samples` as the size limit for
/// the generated dictionary.
///
/// Note that the decoder in this module reads a dictionary id of 0 as
/// "no dictionary", so callers should pick a non-zero `id`.
///
/// # Errors
///
/// Returns [`BinaryError::Compression`] if zstd fails to build the dictionary.
pub fn train_dictionary(
    samples: &[&[u8]],
    dict_size: usize,
    id: u32,
) -> Result<CompressionDictionary, BinaryError> {
    zstd::dict::from_samples(samples, dict_size)
        .map(|trained| CompressionDictionary::new(id, trained))
        .map_err(|_| BinaryError::Compression)
}
/// Serializes `payload` to bytes using the requested `encoding`.
///
/// # Errors
///
/// Returns [`BinaryError::Serialization`] if the underlying serializer fails;
/// the original error detail is discarded.
fn serialize_payload<T: Serialize>(
    payload: &T,
    encoding: PayloadEncoding,
) -> Result<Vec<u8>, BinaryError> {
    let encoded = match encoding {
        PayloadEncoding::MessagePack => rmp_serde::to_vec(payload).ok(),
        PayloadEncoding::Cbor => serde_cbor::to_vec(payload).ok(),
    };
    encoded.ok_or(BinaryError::Serialization)
}
/// Turns a frame's payload body back into a value of type `T`.
///
/// When `flags.compressed` is set, the body must start with a 4-byte
/// big-endian dictionary id (0 meaning "no dictionary") followed by a zstd
/// stream; the stream is decompressed first, using the matching entry of
/// `dictionaries` when an id is present. Otherwise the body is deserialized
/// directly, without copying it.
///
/// # Errors
///
/// * [`BinaryError::Decompression`] if a compressed body is shorter than the
///   4-byte id prefix or zstd fails to decompress it.
/// * [`BinaryError::MissingDictionary`] if the referenced dictionary id is not
///   present in `dictionaries`.
/// * [`BinaryError::Deserialization`] if the bytes do not decode into `T`.
fn deserialize_payload<T: DeserializeOwned>(
    bytes: &[u8],
    encoding: PayloadEncoding,
    flags: BinaryFlags,
    dictionaries: &[CompressionDictionary],
) -> Result<T, BinaryError> {
    // Borrow the body when it is already plain; own it only after decompression.
    let raw: std::borrow::Cow<'_, [u8]> = if flags.compressed {
        // The encoder always writes the id prefix, so anything shorter is a
        // truncated or corrupt body rather than a prefix-less zstd stream.
        if bytes.len() < 4 {
            return Err(BinaryError::Decompression);
        }
        let (dict_id, start) = extract_dict_id(bytes);
        let compressed = &bytes[start..];
        // NOTE(review): decompressed size is unbounded here; if frames can come
        // from untrusted peers, consider enforcing an output limit.
        let inflated = match dict_id {
            Some(id) => {
                let dict = dictionaries
                    .iter()
                    .find(|d| d.id == id)
                    .ok_or(BinaryError::MissingDictionary(id))?;
                let mut decoder =
                    zstd::stream::Decoder::with_dictionary(Cursor::new(compressed), &dict.bytes)
                        .map_err(|_| BinaryError::Decompression)?;
                let mut buf = Vec::new();
                decoder
                    .read_to_end(&mut buf)
                    .map_err(|_| BinaryError::Decompression)?;
                buf
            }
            None => zstd::stream::decode_all(Cursor::new(compressed))
                .map_err(|_| BinaryError::Decompression)?,
        };
        std::borrow::Cow::Owned(inflated)
    } else {
        std::borrow::Cow::Borrowed(bytes)
    };

    let data: &[u8] = &raw;
    match encoding {
        PayloadEncoding::MessagePack => {
            rmp_serde::from_slice(data).map_err(|_| BinaryError::Deserialization)
        }
        PayloadEncoding::Cbor => {
            serde_cbor::from_slice(data).map_err(|_| BinaryError::Deserialization)
        }
    }
}
/// Returns `data` prefixed with `id` encoded as four big-endian bytes.
fn prepend_dict_id(data: Vec<u8>, id: u32) -> Vec<u8> {
    let prefix = id.to_be_bytes();
    [prefix.as_slice(), data.as_slice()].concat()
}
/// Reads the 4-byte big-endian dictionary id at the start of `bytes`.
///
/// Returns the id (`None` when it is 0, meaning "no dictionary") and the
/// offset at which the compressed data begins. When fewer than four bytes are
/// available, returns `(None, 0)`.
fn extract_dict_id(bytes: &[u8]) -> (Option<u32>, usize) {
    match bytes {
        [b0, b1, b2, b3, ..] => {
            let raw = u32::from_be_bytes([*b0, *b1, *b2, *b3]);
            (Some(raw).filter(|&value| value != 0), 4)
        }
        _ => (None, 0),
    }
}
/// Compresses `payload_bytes` with zstd at level 3 using `dict`.
///
/// The returned bytes are the bare zstd stream; no dictionary id is prepended.
///
/// # Errors
///
/// Returns [`BinaryError::Compression`] if creating the encoder, writing the
/// payload, or finishing the stream fails.
fn compress_with_dictionary(
    payload_bytes: &[u8],
    dict: &CompressionDictionary,
) -> Result<Vec<u8>, BinaryError> {
    // Every step yields an `io::Result`, so run them together and map once.
    let run = || -> std::io::Result<Vec<u8>> {
        let mut writer = zstd::stream::Encoder::with_dictionary(Vec::new(), 3, &dict.bytes)?;
        writer.write_all(payload_bytes)?;
        writer.finish()
    };
    run().map_err(|_| BinaryError::Compression)
}
// Unit tests for the frame codec: flag packing, round trips in both payload
// encodings, compression with and without a dictionary, and header validation.
#[cfg(test)]
mod tests {
use super::*;
use crate::ChannelKind;
use serde::{Deserialize, Serialize};
// Minimal two-variant channel type used as the `C` parameter in the tests.
#[derive(Debug, Clone, Copy, Hash, Eq, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
enum Ch {
Data,
Ui,
}
impl ChannelKind for Ch {
fn priority(&self) -> u8 {
0
}
// Wire ids are deliberately not 0/1 in declaration order, so a test can tell
// the wire id apart from the enum discriminant.
fn wire_id(&self) -> u8 {
match self {
Ch::Data => 0x07,
Ch::Ui => 0x01,
}
}
fn from_wire_id(id: u8) -> Option<Self> {
match id {
0x07 => Some(Ch::Data),
0x01 => Some(Ch::Ui),
_ => None,
}
}
fn from_name(s: &str) -> Option<Self> {
match s {
"data" => Some(Ch::Data),
"ui" => Some(Ch::Ui),
_ => None,
}
}
fn name(&self) -> &'static str {
match self {
Ch::Data => "data",
Ch::Ui => "ui",
}
}
fn is_system(&self) -> bool {
false
}
fn all() -> &'static [Self] {
&[Self::Data, Self::Ui]
}
}
// Small structured payload used as the `T` parameter in the tests.
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)]
struct Payload {
id: u32,
msg: String,
}
// A short, uncompressed frame on the `Data` channel; its payload is far below
// the default compression threshold.
fn base_frame() -> BinaryFrame<Ch, Payload> {
BinaryFrame {
channel: Ch::Data,
flags: BinaryFlags {
compressed: false,
fragmented: false,
ack_required: true,
},
sequence: 42,
payload: Payload {
id: 1,
msg: "hello".into(),
},
}
}
// Packing flags into a byte and unpacking them yields the same flags.
#[test]
fn flags_roundtrip_bits() {
let flags = BinaryFlags {
compressed: true,
fragmented: true,
ack_required: false,
};
let byte = flags.to_byte();
assert_eq!(BinaryFlags::from_byte(byte), flags);
}
// Encode then decode with MessagePack returns an identical frame.
#[test]
fn messagepack_roundtrip() {
let frame = base_frame();
let bytes = encode_frame(&frame, PayloadEncoding::MessagePack).unwrap();
let decoded: BinaryFrame<Ch, Payload> =
decode_frame(&bytes, PayloadEncoding::MessagePack).unwrap();
assert_eq!(decoded, frame);
}
// Encode then decode with CBOR returns an identical frame.
#[test]
fn cbor_roundtrip() {
let frame = base_frame();
let bytes = encode_frame(&frame, PayloadEncoding::Cbor).unwrap();
let decoded: BinaryFrame<Ch, Payload> =
decode_frame(&bytes, PayloadEncoding::Cbor).unwrap();
assert_eq!(decoded, frame);
}
// A large, highly repetitive payload above the threshold is sent compressed
// (flag bit set in header byte 3) and still decodes to the original content.
#[test]
fn compresses_when_beneficial() {
let frame = BinaryFrame {
channel: Ch::Ui,
flags: BinaryFlags {
compressed: false,
fragmented: false,
ack_required: false,
},
sequence: 1,
payload: Payload {
id: 1,
msg: "x".repeat(2048),
},
};
let cfg = CompressionConfig {
threshold: 256,
dictionary: None,
};
let bytes =
encode_frame_with_compression(&frame, PayloadEncoding::MessagePack, &cfg).unwrap();
assert!(BinaryFlags::from_byte(bytes[3]).compressed);
let decoded: BinaryFrame<Ch, Payload> =
decode_frame_with_dictionaries(&bytes, PayloadEncoding::MessagePack, &[]).unwrap();
assert_eq!(decoded.payload.msg.len(), 2048);
}
// Trains a dictionary (id 7) and round-trips a frame with it configured.
// The payload is tiny, so the encoder may fall back to the uncompressed form;
// the test only asserts that the round trip is lossless either way.
#[test]
fn dictionary_training_and_use() {
let samples_raw: Vec<Vec<u8>> = (0..10)
.map(|i| format!("{{\"content\":\"sample_{i}_payload_data\"}}").into_bytes())
.collect();
let sample_refs: Vec<&[u8]> = samples_raw.iter().map(|b| b.as_slice()).collect();
let dict = train_dictionary(&sample_refs, 256, 7).unwrap();
let cfg = CompressionConfig {
threshold: 1,
dictionary: Some(dict.clone()),
};
let frame = base_frame();
let bytes =
encode_frame_with_compression(&frame, PayloadEncoding::MessagePack, &cfg).unwrap();
let decoded: BinaryFrame<Ch, Payload> =
decode_frame_with_dictionaries(&bytes, PayloadEncoding::MessagePack, &[dict]).unwrap();
assert_eq!(decoded, frame);
}
// Corrupting the first magic byte is reported as `InvalidMagic`.
#[test]
fn rejects_bad_magic() {
let mut bytes = encode_frame(&base_frame(), PayloadEncoding::MessagePack).unwrap();
bytes[0] = 0x00;
let err = decode_frame::<Ch, Payload>(&bytes, PayloadEncoding::MessagePack).unwrap_err();
assert_eq!(err, BinaryError::InvalidMagic);
}
// Overwriting the channel byte (offset 4) with an unmapped id is reported as
// `UnknownChannel` carrying that id.
#[test]
fn rejects_unknown_channel() {
let mut bytes = encode_frame(&base_frame(), PayloadEncoding::MessagePack).unwrap();
bytes[4] = 0xFF;
let err = decode_frame::<Ch, Payload>(&bytes, PayloadEncoding::MessagePack).unwrap_err();
assert_eq!(err, BinaryError::UnknownChannel(0xFF));
}
// The header carries the channel's `wire_id()` (0x07 for `Data`) at offset 4.
#[test]
fn wire_id_preserved_in_encoding() {
let frame = base_frame();
let bytes = encode_frame(&frame, PayloadEncoding::MessagePack).unwrap();
assert_eq!(bytes[4], 0x07);
}
}