1use serde::{Serialize, de::DeserializeOwned};
2use std::io::{Cursor, Read, Write};
3use thiserror::Error;
4use zstd::bulk;
5
6use crate::ChannelKind;
7
/// Two-byte magic prefix that opens every encoded frame (ASCII `RP`).
pub const MAGIC_HEADER: [u8; 2] = *b"RP";
/// Wire-format version byte written directly after the magic prefix.
pub const VERSION_BYTE: u8 = 0x01;
12
/// Per-frame boolean flags, packed into a single byte on the wire
/// (see `to_byte` / `from_byte` for the bit positions).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, serde::Deserialize)]
pub struct BinaryFlags {
    /// Bit 0. Set when the frame body is zstd-compressed and starts with a
    /// 4-byte dictionary id.
    pub compressed: bool,
    /// Bit 1. Only packed and unpacked by the code visible in this file; its
    /// meaning is defined by the caller.
    pub fragmented: bool,
    /// Bit 2. Only packed and unpacked by the code visible in this file; its
    /// meaning is defined by the caller.
    pub ack_required: bool,
}
20
21impl BinaryFlags {
22 pub fn to_byte(self) -> u8 {
23 (self.compressed as u8) | ((self.fragmented as u8) << 1) | ((self.ack_required as u8) << 2)
24 }
25
26 pub fn from_byte(byte: u8) -> Self {
27 Self {
28 compressed: byte & 0b0000_0001 != 0,
29 fragmented: byte & 0b0000_0010 != 0,
30 ack_required: byte & 0b0000_0100 != 0,
31 }
32 }
33}
34
/// Serialization format used for a frame's payload.
///
/// The format is not recorded in the encoded frame, so the decoder must be
/// given the same value that was used for encoding.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PayloadEncoding {
    /// MessagePack, via `rmp_serde`.
    MessagePack,
    /// CBOR, via `serde_cbor`.
    Cbor,
}
41
/// Settings that control payload compression when encoding a frame.
#[derive(Debug, Clone)]
pub struct CompressionConfig {
    /// Serialized payloads shorter than this many bytes are never compressed.
    pub threshold: usize,
    /// Optional zstd dictionary to compress with; `None` uses plain zstd.
    pub dictionary: Option<CompressionDictionary>,
}
48
49impl Default for CompressionConfig {
50 fn default() -> Self {
51 Self {
52 threshold: 512,
53 dictionary: None,
54 }
55 }
56}
57
/// A zstd dictionary together with the id that identifies it on the wire.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CompressionDictionary {
    /// Identifier written in front of payloads compressed with this dictionary;
    /// the decoder uses it to look the dictionary up. The decoder in this file
    /// reads id 0 as "no dictionary", so a real dictionary should use a non-zero id.
    pub id: u32,
    /// Raw dictionary content handed to zstd.
    pub bytes: Vec<u8>,
}
64
65impl CompressionDictionary {
66 pub fn new(id: u32, bytes: Vec<u8>) -> Self {
67 Self { id, bytes }
68 }
69}
70
/// A single frame: routing metadata plus a payload of type `T`
/// (raw bytes by default).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BinaryFrame<C: ChannelKind, T = Vec<u8>> {
    /// Channel the frame belongs to; encoded as its one-byte wire id.
    pub channel: C,
    /// Frame flags. The encoder overwrites `compressed` with what it actually
    /// did, so the value supplied here for that bit is not what goes on the wire.
    pub flags: BinaryFlags,
    /// Sequence number, encoded as four big-endian bytes.
    pub sequence: u32,
    /// Payload, serialized with the chosen `PayloadEncoding`.
    pub payload: T,
}
82
/// Errors produced while encoding, decoding, or training dictionaries.
#[derive(Debug, Error, PartialEq, Eq)]
pub enum BinaryError {
    /// The first two bytes of the input are not `MAGIC_HEADER`.
    #[error("invalid magic header")]
    InvalidMagic,
    /// The version byte differs from `VERSION_BYTE`; carries the byte that was read.
    #[error("unsupported version byte {0}")]
    UnsupportedVersion(u8),
    /// The channel type did not recognise the wire id; carries that id.
    #[error("unknown channel id {0}")]
    UnknownChannel(u8),
    /// The payload could not be serialized with the chosen encoding.
    #[error("serialization error")]
    Serialization,
    /// The payload bytes could not be deserialized with the chosen encoding.
    #[error("deserialization error")]
    Deserialization,
    /// The input is shorter than the 9-byte frame header.
    #[error("frame too short")]
    FrameTooShort,
    /// zstd compression or dictionary training failed.
    #[error("compression error")]
    Compression,
    /// zstd decompression failed.
    #[error("decompression error")]
    Decompression,
    /// The frame names a dictionary id that was not supplied to the decoder.
    #[error("missing compression dictionary {0}")]
    MissingDictionary(u32),
}
104
105pub fn encode_frame<C: ChannelKind, T: Serialize>(
107 frame: &BinaryFrame<C, T>,
108 encoding: PayloadEncoding,
109) -> Result<Vec<u8>, BinaryError> {
110 encode_frame_with_compression(frame, encoding, &CompressionConfig::default())
111}
112
113pub fn encode_frame_with_compression<C: ChannelKind, T: Serialize>(
115 frame: &BinaryFrame<C, T>,
116 encoding: PayloadEncoding,
117 compression: &CompressionConfig,
118) -> Result<Vec<u8>, BinaryError> {
119 let mut flags = frame.flags;
120 let mut out = Vec::with_capacity(16);
121 out.extend_from_slice(&MAGIC_HEADER);
122 out.push(VERSION_BYTE);
123
124 let payload_bytes = serialize_payload(&frame.payload, encoding)?;
125 let payload_len = payload_bytes.len();
126 let compressed_attempt: Option<(Vec<u8>, Option<u32>)> = if payload_len < compression.threshold
127 {
128 None
129 } else if let Some(dict) = &compression.dictionary {
130 compress_with_dictionary(&payload_bytes, dict)
131 .ok()
132 .map(|c| (prepend_dict_id(c, dict.id), Some(dict.id)))
133 } else {
134 bulk::compress(&payload_bytes, 3)
135 .ok()
136 .map(|c| (prepend_dict_id(c, 0), None))
137 };
138
139 let (body, _dict_used) = match compressed_attempt {
140 Some((c, id)) if c.len() < payload_len => (c, id),
141 _ => (payload_bytes.clone(), None),
142 };
143
144 if body.len() < payload_len {
145 flags.compressed = true;
146 out.push(flags.to_byte());
147 out.push(frame.channel.wire_id());
148 out.extend_from_slice(&frame.sequence.to_be_bytes());
149 out.extend_from_slice(&body);
150 Ok(out)
151 } else {
152 flags.compressed = false;
153 out.push(flags.to_byte());
154 out.push(frame.channel.wire_id());
155 out.extend_from_slice(&frame.sequence.to_be_bytes());
156 out.extend_from_slice(&payload_bytes);
157 Ok(out)
158 }
159}
160
161pub fn decode_frame<C: ChannelKind, T: DeserializeOwned>(
163 bytes: &[u8],
164 encoding: PayloadEncoding,
165) -> Result<BinaryFrame<C, T>, BinaryError> {
166 decode_frame_with_dictionaries(bytes, encoding, &[])
167}
168
169pub fn decode_frame_with_dictionaries<C: ChannelKind, T: DeserializeOwned>(
171 bytes: &[u8],
172 encoding: PayloadEncoding,
173 dictionaries: &[CompressionDictionary],
174) -> Result<BinaryFrame<C, T>, BinaryError> {
175 if bytes.len() < 9 {
176 return Err(BinaryError::FrameTooShort);
177 }
178
179 if bytes[0..2] != MAGIC_HEADER {
180 return Err(BinaryError::InvalidMagic);
181 }
182 let version = bytes[2];
183 if version != VERSION_BYTE {
184 return Err(BinaryError::UnsupportedVersion(version));
185 }
186
187 let flags = BinaryFlags::from_byte(bytes[3]);
188 let channel = C::from_wire_id(bytes[4]).ok_or(BinaryError::UnknownChannel(bytes[4]))?;
189 let sequence = u32::from_be_bytes([bytes[5], bytes[6], bytes[7], bytes[8]]);
190
191 let payload_bytes = &bytes[9..];
192 let payload = deserialize_payload(payload_bytes, encoding, flags, dictionaries)?;
193
194 Ok(BinaryFrame {
195 channel,
196 flags,
197 sequence,
198 payload,
199 })
200}
201
202pub fn train_dictionary(
204 samples: &[&[u8]],
205 dict_size: usize,
206 id: u32,
207) -> Result<CompressionDictionary, BinaryError> {
208 let dict =
209 zstd::dict::from_samples(samples, dict_size).map_err(|_| BinaryError::Compression)?;
210 Ok(CompressionDictionary::new(id, dict))
211}
212
213fn serialize_payload<T: Serialize>(
214 payload: &T,
215 encoding: PayloadEncoding,
216) -> Result<Vec<u8>, BinaryError> {
217 match encoding {
218 PayloadEncoding::MessagePack => {
219 rmp_serde::to_vec(payload).map_err(|_| BinaryError::Serialization)
220 }
221 PayloadEncoding::Cbor => {
222 serde_cbor::to_vec(payload).map_err(|_| BinaryError::Serialization)
223 }
224 }
225}
226
227fn deserialize_payload<T: DeserializeOwned>(
228 bytes: &[u8],
229 encoding: PayloadEncoding,
230 flags: BinaryFlags,
231 dictionaries: &[CompressionDictionary],
232) -> Result<T, BinaryError> {
233 let data = if flags.compressed {
234 let (dict_id, start) = extract_dict_id(bytes);
235 let compressed = &bytes[start..];
236 if let Some(id) = dict_id {
237 let dict = dictionaries
238 .iter()
239 .find(|d| d.id == id)
240 .ok_or(BinaryError::MissingDictionary(id))?;
241 let mut decoder =
242 zstd::stream::Decoder::with_dictionary(Cursor::new(compressed), &dict.bytes)
243 .map_err(|_| BinaryError::Decompression)?;
244 let mut buf = Vec::new();
245 decoder
246 .read_to_end(&mut buf)
247 .map_err(|_| BinaryError::Decompression)?;
248 buf
249 } else {
250 zstd::stream::decode_all(Cursor::new(compressed))
251 .map_err(|_| BinaryError::Decompression)?
252 }
253 } else {
254 bytes.to_vec()
255 };
256
257 match encoding {
258 PayloadEncoding::MessagePack => {
259 rmp_serde::from_slice(&data).map_err(|_| BinaryError::Deserialization)
260 }
261 PayloadEncoding::Cbor => {
262 serde_cbor::from_slice(&data).map_err(|_| BinaryError::Deserialization)
263 }
264 }
265}
266
/// Returns `data` prefixed with `id` as four big-endian bytes.
fn prepend_dict_id(data: Vec<u8>, id: u32) -> Vec<u8> {
    let prefix = id.to_be_bytes();
    [&prefix[..], &data[..]].concat()
}
273
/// Reads the 4-byte big-endian dictionary id at the start of `bytes`.
///
/// Returns the id (`None` when it is 0, meaning "no dictionary") and the offset
/// at which the compressed data starts. With fewer than four bytes available
/// there is no prefix to read, so the result is `(None, 0)`.
fn extract_dict_id(bytes: &[u8]) -> (Option<u32>, usize) {
    match bytes {
        [b0, b1, b2, b3, ..] => {
            let raw_id = u32::from_be_bytes([*b0, *b1, *b2, *b3]);
            (Some(raw_id).filter(|&id| id != 0), 4)
        }
        _ => (None, 0),
    }
}
281
282fn compress_with_dictionary(
283 payload_bytes: &[u8],
284 dict: &CompressionDictionary,
285) -> Result<Vec<u8>, BinaryError> {
286 let mut encoder = zstd::stream::Encoder::with_dictionary(Vec::new(), 3, &dict.bytes)
287 .map_err(|_| BinaryError::Compression)?;
288 encoder
289 .write_all(payload_bytes)
290 .map_err(|_| BinaryError::Compression)?;
291 encoder.finish().map_err(|_| BinaryError::Compression)
292}
293
// Unit tests for the frame codec: flag packing, round trips for both payload
// encodings, compression, dictionary use, and header validation errors.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ChannelKind;
    use serde::{Deserialize, Serialize};

    /// Minimal two-variant channel type used to exercise the codec.
    #[derive(Debug, Clone, Copy, Hash, Eq, PartialEq, Serialize, Deserialize)]
    #[serde(rename_all = "lowercase")]
    enum Ch {
        Data,
        Ui,
    }

    impl ChannelKind for Ch {
        fn priority(&self) -> u8 {
            0
        }
        // Wire ids are deliberately not 0/1 in declaration order, so a test can
        // tell the wire id apart from the enum discriminant.
        fn wire_id(&self) -> u8 {
            match self {
                Ch::Data => 0x07,
                Ch::Ui => 0x01,
            }
        }
        fn from_wire_id(id: u8) -> Option<Self> {
            match id {
                0x07 => Some(Ch::Data),
                0x01 => Some(Ch::Ui),
                _ => None,
            }
        }
        fn from_name(s: &str) -> Option<Self> {
            match s {
                "data" => Some(Ch::Data),
                "ui" => Some(Ch::Ui),
                _ => None,
            }
        }
        fn name(&self) -> &'static str {
            match self {
                Ch::Data => "data",
                Ch::Ui => "ui",
            }
        }
        fn is_system(&self) -> bool {
            false
        }
        fn all() -> &'static [Self] {
            &[Self::Data, Self::Ui]
        }
    }

    /// Small serde payload used by the round-trip tests.
    #[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)]
    struct Payload {
        id: u32,
        msg: String,
    }

    /// A small uncompressed frame on the `Data` channel with `ack_required` set.
    fn base_frame() -> BinaryFrame<Ch, Payload> {
        BinaryFrame {
            channel: Ch::Data,
            flags: BinaryFlags {
                compressed: false,
                fragmented: false,
                ack_required: true,
            },
            sequence: 42,
            payload: Payload {
                id: 1,
                msg: "hello".into(),
            },
        }
    }

    // Packing flags into a byte and unpacking them again yields the same flags.
    #[test]
    fn flags_roundtrip_bits() {
        let flags = BinaryFlags {
            compressed: true,
            fragmented: true,
            ack_required: false,
        };
        let byte = flags.to_byte();
        assert_eq!(BinaryFlags::from_byte(byte), flags);
    }

    // Encode then decode with MessagePack reproduces the original frame.
    #[test]
    fn messagepack_roundtrip() {
        let frame = base_frame();
        let bytes = encode_frame(&frame, PayloadEncoding::MessagePack).unwrap();
        let decoded: BinaryFrame<Ch, Payload> =
            decode_frame(&bytes, PayloadEncoding::MessagePack).unwrap();
        assert_eq!(decoded, frame);
    }

    // Encode then decode with CBOR reproduces the original frame.
    #[test]
    fn cbor_roundtrip() {
        let frame = base_frame();
        let bytes = encode_frame(&frame, PayloadEncoding::Cbor).unwrap();
        let decoded: BinaryFrame<Ch, Payload> =
            decode_frame(&bytes, PayloadEncoding::Cbor).unwrap();
        assert_eq!(decoded, frame);
    }

    // A long, highly repetitive payload above the threshold must go out with
    // the compressed flag set and still decode to the same message length.
    #[test]
    fn compresses_when_beneficial() {
        let frame = BinaryFrame {
            channel: Ch::Ui,
            flags: BinaryFlags {
                compressed: false,
                fragmented: false,
                ack_required: false,
            },
            sequence: 1,
            payload: Payload {
                id: 1,
                msg: "x".repeat(2048),
            },
        };
        let cfg = CompressionConfig {
            threshold: 256,
            dictionary: None,
        };
        let bytes =
            encode_frame_with_compression(&frame, PayloadEncoding::MessagePack, &cfg).unwrap();
        // Byte 3 of the header is the flag byte.
        assert!(BinaryFlags::from_byte(bytes[3]).compressed);

        let decoded: BinaryFrame<Ch, Payload> =
            decode_frame_with_dictionaries(&bytes, PayloadEncoding::MessagePack, &[]).unwrap();
        assert_eq!(decoded.payload.msg.len(), 2048);
    }

    // Trains a dictionary and round-trips a frame with it configured.
    // NOTE(review): the base payload is tiny, and the encoder only keeps a
    // compressed body when it is shorter than the serialized payload, so this
    // frame likely goes out uncompressed -- confirm the dictionary path is
    // actually exercised here.
    #[test]
    fn dictionary_training_and_use() {
        let samples_raw: Vec<Vec<u8>> = (0..10)
            .map(|i| format!("{{\"content\":\"sample_{i}_payload_data\"}}").into_bytes())
            .collect();
        let sample_refs: Vec<&[u8]> = samples_raw.iter().map(|b| b.as_slice()).collect();
        let dict = train_dictionary(&sample_refs, 256, 7).unwrap();
        let cfg = CompressionConfig {
            threshold: 1,
            dictionary: Some(dict.clone()),
        };
        let frame = base_frame();
        let bytes =
            encode_frame_with_compression(&frame, PayloadEncoding::MessagePack, &cfg).unwrap();
        let decoded: BinaryFrame<Ch, Payload> =
            decode_frame_with_dictionaries(&bytes, PayloadEncoding::MessagePack, &[dict]).unwrap();
        assert_eq!(decoded, frame);
    }

    // Corrupting the first magic byte is reported as `InvalidMagic`.
    #[test]
    fn rejects_bad_magic() {
        let mut bytes = encode_frame(&base_frame(), PayloadEncoding::MessagePack).unwrap();
        bytes[0] = 0x00;
        let err = decode_frame::<Ch, Payload>(&bytes, PayloadEncoding::MessagePack).unwrap_err();
        assert_eq!(err, BinaryError::InvalidMagic);
    }

    // A channel byte that `Ch` does not recognise is reported with its value.
    #[test]
    fn rejects_unknown_channel() {
        let mut bytes = encode_frame(&base_frame(), PayloadEncoding::MessagePack).unwrap();
        bytes[4] = 0xFF;
        let err = decode_frame::<Ch, Payload>(&bytes, PayloadEncoding::MessagePack).unwrap_err();
        assert_eq!(err, BinaryError::UnknownChannel(0xFF));
    }

    // Byte 4 of the header carries the channel's wire id (0x07 for `Ch::Data`).
    #[test]
    fn wire_id_preserved_in_encoding() {
        let frame = base_frame();
        let bytes = encode_frame(&frame, PayloadEncoding::MessagePack).unwrap();
        assert_eq!(bytes[4], 0x07);
    }
}