//! `pushwire_core/binary.rs`: binary frame encoding and decoding.

1use serde::{Serialize, de::DeserializeOwned};
2use std::io::{Cursor, Read, Write};
3use thiserror::Error;
4use zstd::bulk;
5
6use crate::ChannelKind;
7
/// Two-byte magic prefix that starts every binary frame: ASCII `"RP"` (`0x52 0x50`).
pub const MAGIC_HEADER: [u8; 2] = *b"RP";
/// Wire-format version written right after the magic bytes; decoding rejects
/// frames carrying any other value.
pub const VERSION_BYTE: u8 = 1;
12
13/// Flags describing frame properties.
14#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, serde::Deserialize)]
15pub struct BinaryFlags {
16    pub compressed: bool,
17    pub fragmented: bool,
18    pub ack_required: bool,
19}
20
21impl BinaryFlags {
22    pub fn to_byte(self) -> u8 {
23        (self.compressed as u8) | ((self.fragmented as u8) << 1) | ((self.ack_required as u8) << 2)
24    }
25
26    pub fn from_byte(byte: u8) -> Self {
27        Self {
28            compressed: byte & 0b0000_0001 != 0,
29            fragmented: byte & 0b0000_0010 != 0,
30            ack_required: byte & 0b0000_0100 != 0,
31        }
32    }
33}
34
/// Serialization format applied to a frame's payload body.
///
/// The format is not recorded in the frame itself, so the encoder and the
/// decoder must be given the same value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PayloadEncoding {
    /// MessagePack, handled through `rmp_serde`.
    MessagePack,
    /// CBOR, handled through `serde_cbor`.
    Cbor,
}
41
42/// Compression options for encoding/decoding.
43#[derive(Debug, Clone)]
44pub struct CompressionConfig {
45    pub threshold: usize,
46    pub dictionary: Option<CompressionDictionary>,
47}
48
49impl Default for CompressionConfig {
50    fn default() -> Self {
51        Self {
52            threshold: 512,
53            dictionary: None,
54        }
55    }
56}
57
/// A zstd dictionary paired with the numeric id that identifies it on the wire.
///
/// The id is written in front of compressed payloads so the decoder can pick
/// the matching dictionary. Id `0` is read back as "no dictionary", so real
/// dictionaries should use a non-zero id.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CompressionDictionary {
    /// Identifier sent with every payload compressed using this dictionary.
    pub id: u32,
    /// Raw dictionary contents handed to zstd.
    pub bytes: Vec<u8>,
}

impl CompressionDictionary {
    /// Wraps already-trained dictionary `bytes` under the given `id`.
    pub fn new(id: u32, bytes: Vec<u8>) -> Self {
        CompressionDictionary { id, bytes }
    }
}
70
71/// Binary frame format:
72/// `[magic:2][version:1][flags:1][channel:1][sequence:4][payload:var]`
73///
74/// The channel byte is determined by [`ChannelKind::wire_id()`].
75#[derive(Debug, Clone, PartialEq, Eq)]
76pub struct BinaryFrame<C: ChannelKind, T = Vec<u8>> {
77    pub channel: C,
78    pub flags: BinaryFlags,
79    pub sequence: u32,
80    pub payload: T,
81}
82
83#[derive(Debug, Error, PartialEq, Eq)]
84pub enum BinaryError {
85    #[error("invalid magic header")]
86    InvalidMagic,
87    #[error("unsupported version byte {0}")]
88    UnsupportedVersion(u8),
89    #[error("unknown channel id {0}")]
90    UnknownChannel(u8),
91    #[error("serialization error")]
92    Serialization,
93    #[error("deserialization error")]
94    Deserialization,
95    #[error("frame too short")]
96    FrameTooShort,
97    #[error("compression error")]
98    Compression,
99    #[error("decompression error")]
100    Decompression,
101    #[error("missing compression dictionary {0}")]
102    MissingDictionary(u32),
103}
104
105/// Encode without compression using the default settings.
106pub fn encode_frame<C: ChannelKind, T: Serialize>(
107    frame: &BinaryFrame<C, T>,
108    encoding: PayloadEncoding,
109) -> Result<Vec<u8>, BinaryError> {
110    encode_frame_with_compression(frame, encoding, &CompressionConfig::default())
111}
112
113/// Encode with optional per-frame compression.
114pub fn encode_frame_with_compression<C: ChannelKind, T: Serialize>(
115    frame: &BinaryFrame<C, T>,
116    encoding: PayloadEncoding,
117    compression: &CompressionConfig,
118) -> Result<Vec<u8>, BinaryError> {
119    let mut flags = frame.flags;
120    let mut out = Vec::with_capacity(16);
121    out.extend_from_slice(&MAGIC_HEADER);
122    out.push(VERSION_BYTE);
123
124    let payload_bytes = serialize_payload(&frame.payload, encoding)?;
125    let payload_len = payload_bytes.len();
126    let compressed_attempt: Option<(Vec<u8>, Option<u32>)> = if payload_len < compression.threshold
127    {
128        None
129    } else if let Some(dict) = &compression.dictionary {
130        compress_with_dictionary(&payload_bytes, dict)
131            .ok()
132            .map(|c| (prepend_dict_id(c, dict.id), Some(dict.id)))
133    } else {
134        bulk::compress(&payload_bytes, 3)
135            .ok()
136            .map(|c| (prepend_dict_id(c, 0), None))
137    };
138
139    let (body, _dict_used) = match compressed_attempt {
140        Some((c, id)) if c.len() < payload_len => (c, id),
141        _ => (payload_bytes.clone(), None),
142    };
143
144    if body.len() < payload_len {
145        flags.compressed = true;
146        out.push(flags.to_byte());
147        out.push(frame.channel.wire_id());
148        out.extend_from_slice(&frame.sequence.to_be_bytes());
149        out.extend_from_slice(&body);
150        Ok(out)
151    } else {
152        flags.compressed = false;
153        out.push(flags.to_byte());
154        out.push(frame.channel.wire_id());
155        out.extend_from_slice(&frame.sequence.to_be_bytes());
156        out.extend_from_slice(&payload_bytes);
157        Ok(out)
158    }
159}
160
161/// Decode a frame without any pre-shared dictionaries.
162pub fn decode_frame<C: ChannelKind, T: DeserializeOwned>(
163    bytes: &[u8],
164    encoding: PayloadEncoding,
165) -> Result<BinaryFrame<C, T>, BinaryError> {
166    decode_frame_with_dictionaries(bytes, encoding, &[])
167}
168
169/// Decode a frame using a set of pre-shared dictionaries.
170pub fn decode_frame_with_dictionaries<C: ChannelKind, T: DeserializeOwned>(
171    bytes: &[u8],
172    encoding: PayloadEncoding,
173    dictionaries: &[CompressionDictionary],
174) -> Result<BinaryFrame<C, T>, BinaryError> {
175    if bytes.len() < 9 {
176        return Err(BinaryError::FrameTooShort);
177    }
178
179    if bytes[0..2] != MAGIC_HEADER {
180        return Err(BinaryError::InvalidMagic);
181    }
182    let version = bytes[2];
183    if version != VERSION_BYTE {
184        return Err(BinaryError::UnsupportedVersion(version));
185    }
186
187    let flags = BinaryFlags::from_byte(bytes[3]);
188    let channel = C::from_wire_id(bytes[4]).ok_or(BinaryError::UnknownChannel(bytes[4]))?;
189    let sequence = u32::from_be_bytes([bytes[5], bytes[6], bytes[7], bytes[8]]);
190
191    let payload_bytes = &bytes[9..];
192    let payload = deserialize_payload(payload_bytes, encoding, flags, dictionaries)?;
193
194    Ok(BinaryFrame {
195        channel,
196        flags,
197        sequence,
198        payload,
199    })
200}
201
202/// Train a zstd dictionary from samples.
203pub fn train_dictionary(
204    samples: &[&[u8]],
205    dict_size: usize,
206    id: u32,
207) -> Result<CompressionDictionary, BinaryError> {
208    let dict =
209        zstd::dict::from_samples(samples, dict_size).map_err(|_| BinaryError::Compression)?;
210    Ok(CompressionDictionary::new(id, dict))
211}
212
213fn serialize_payload<T: Serialize>(
214    payload: &T,
215    encoding: PayloadEncoding,
216) -> Result<Vec<u8>, BinaryError> {
217    match encoding {
218        PayloadEncoding::MessagePack => {
219            rmp_serde::to_vec(payload).map_err(|_| BinaryError::Serialization)
220        }
221        PayloadEncoding::Cbor => {
222            serde_cbor::to_vec(payload).map_err(|_| BinaryError::Serialization)
223        }
224    }
225}
226
227fn deserialize_payload<T: DeserializeOwned>(
228    bytes: &[u8],
229    encoding: PayloadEncoding,
230    flags: BinaryFlags,
231    dictionaries: &[CompressionDictionary],
232) -> Result<T, BinaryError> {
233    let data = if flags.compressed {
234        let (dict_id, start) = extract_dict_id(bytes);
235        let compressed = &bytes[start..];
236        if let Some(id) = dict_id {
237            let dict = dictionaries
238                .iter()
239                .find(|d| d.id == id)
240                .ok_or(BinaryError::MissingDictionary(id))?;
241            let mut decoder =
242                zstd::stream::Decoder::with_dictionary(Cursor::new(compressed), &dict.bytes)
243                    .map_err(|_| BinaryError::Decompression)?;
244            let mut buf = Vec::new();
245            decoder
246                .read_to_end(&mut buf)
247                .map_err(|_| BinaryError::Decompression)?;
248            buf
249        } else {
250            zstd::stream::decode_all(Cursor::new(compressed))
251                .map_err(|_| BinaryError::Decompression)?
252        }
253    } else {
254        bytes.to_vec()
255    };
256
257    match encoding {
258        PayloadEncoding::MessagePack => {
259            rmp_serde::from_slice(&data).map_err(|_| BinaryError::Deserialization)
260        }
261        PayloadEncoding::Cbor => {
262            serde_cbor::from_slice(&data).map_err(|_| BinaryError::Deserialization)
263        }
264    }
265}
266
/// Returns `data` with `id` written in front of it as four big-endian bytes.
fn prepend_dict_id(data: Vec<u8>, id: u32) -> Vec<u8> {
    let mut framed = id.to_be_bytes().to_vec();
    framed.extend(data);
    framed
}
273
/// Reads the 4-byte big-endian dictionary id at the start of `bytes`.
///
/// Returns the id (`None` when it is `0`, meaning no dictionary) together
/// with the offset at which the compressed data begins. Inputs shorter than
/// four bytes yield `(None, 0)`.
fn extract_dict_id(bytes: &[u8]) -> (Option<u32>, usize) {
    match bytes {
        [b0, b1, b2, b3, ..] => {
            let raw = u32::from_be_bytes([*b0, *b1, *b2, *b3]);
            let dict_id = if raw != 0 { Some(raw) } else { None };
            (dict_id, 4)
        }
        _ => (None, 0),
    }
}
281
282fn compress_with_dictionary(
283    payload_bytes: &[u8],
284    dict: &CompressionDictionary,
285) -> Result<Vec<u8>, BinaryError> {
286    let mut encoder = zstd::stream::Encoder::with_dictionary(Vec::new(), 3, &dict.bytes)
287        .map_err(|_| BinaryError::Compression)?;
288    encoder
289        .write_all(payload_bytes)
290        .map_err(|_| BinaryError::Compression)?;
291    encoder.finish().map_err(|_| BinaryError::Compression)
292}
293
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ChannelKind;
    use serde::{Deserialize, Serialize};

    /// Test channel used across binary codec tests.
    #[derive(Debug, Clone, Copy, Hash, Eq, PartialEq, Serialize, Deserialize)]
    #[serde(rename_all = "lowercase")]
    enum Ch {
        Data,
        Ui,
    }

    impl ChannelKind for Ch {
        fn priority(&self) -> u8 {
            0
        }
        fn wire_id(&self) -> u8 {
            match self {
                Ch::Data => 0x07,
                Ch::Ui => 0x01,
            }
        }
        fn from_wire_id(id: u8) -> Option<Self> {
            match id {
                0x07 => Some(Ch::Data),
                0x01 => Some(Ch::Ui),
                _ => None,
            }
        }
        fn from_name(s: &str) -> Option<Self> {
            match s {
                "data" => Some(Ch::Data),
                "ui" => Some(Ch::Ui),
                _ => None,
            }
        }
        fn name(&self) -> &'static str {
            match self {
                Ch::Data => "data",
                Ch::Ui => "ui",
            }
        }
        fn is_system(&self) -> bool {
            false
        }
        fn all() -> &'static [Self] {
            &[Self::Data, Self::Ui]
        }
    }

    #[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)]
    struct Payload {
        id: u32,
        msg: String,
    }

    /// Small frame whose payload is far below any compression threshold used here.
    fn base_frame() -> BinaryFrame<Ch, Payload> {
        BinaryFrame {
            channel: Ch::Data,
            flags: BinaryFlags {
                compressed: false,
                fragmented: false,
                ack_required: true,
            },
            sequence: 42,
            payload: Payload {
                id: 1,
                msg: "hello".into(),
            },
        }
    }

    /// Frame with a large, highly repetitive payload that zstd shrinks easily.
    fn large_frame() -> BinaryFrame<Ch, Payload> {
        BinaryFrame {
            channel: Ch::Ui,
            flags: BinaryFlags {
                compressed: false,
                fragmented: false,
                ack_required: false,
            },
            sequence: 1,
            payload: Payload {
                id: 1,
                msg: "x".repeat(2048),
            },
        }
    }

    /// Dictionary with id 7 trained on a handful of small JSON-like samples.
    fn trained_dictionary() -> CompressionDictionary {
        let samples_raw: Vec<Vec<u8>> = (0..10)
            .map(|i| format!("{{\"content\":\"sample_{i}_payload_data\"}}").into_bytes())
            .collect();
        let sample_refs: Vec<&[u8]> = samples_raw.iter().map(|b| b.as_slice()).collect();
        train_dictionary(&sample_refs, 256, 7).unwrap()
    }

    #[test]
    fn flags_roundtrip_bits() {
        let flags = BinaryFlags {
            compressed: true,
            fragmented: true,
            ack_required: false,
        };
        let byte = flags.to_byte();
        assert_eq!(BinaryFlags::from_byte(byte), flags);
    }

    #[test]
    fn flags_roundtrip_all_combinations() {
        for byte in 0u8..8 {
            assert_eq!(BinaryFlags::from_byte(byte).to_byte(), byte);
        }
        // Bits above the three defined flags are dropped.
        assert_eq!(BinaryFlags::from_byte(0b1111_1000).to_byte(), 0);
    }

    #[test]
    fn messagepack_roundtrip() {
        let frame = base_frame();
        let bytes = encode_frame(&frame, PayloadEncoding::MessagePack).unwrap();
        let decoded: BinaryFrame<Ch, Payload> =
            decode_frame(&bytes, PayloadEncoding::MessagePack).unwrap();
        assert_eq!(decoded, frame);
    }

    #[test]
    fn cbor_roundtrip() {
        let frame = base_frame();
        let bytes = encode_frame(&frame, PayloadEncoding::Cbor).unwrap();
        let decoded: BinaryFrame<Ch, Payload> =
            decode_frame(&bytes, PayloadEncoding::Cbor).unwrap();
        assert_eq!(decoded, frame);
    }

    #[test]
    fn compresses_when_beneficial() {
        let frame = large_frame();
        let cfg = CompressionConfig {
            threshold: 256,
            dictionary: None,
        };
        let bytes =
            encode_frame_with_compression(&frame, PayloadEncoding::MessagePack, &cfg).unwrap();
        assert!(BinaryFlags::from_byte(bytes[3]).compressed);
        // Without a dictionary the id prefix is zero.
        assert_eq!(&bytes[9..13], &[0, 0, 0, 0]);

        let decoded: BinaryFrame<Ch, Payload> =
            decode_frame_with_dictionaries(&bytes, PayloadEncoding::MessagePack, &[]).unwrap();
        assert_eq!(decoded.payload.msg.len(), 2048);
        assert_eq!(decoded.payload, frame.payload);
    }

    #[test]
    fn small_payload_stays_uncompressed() {
        let mut frame = base_frame();
        // A stale compressed flag on the input must not leak onto the wire.
        frame.flags.compressed = true;
        let bytes = encode_frame(&frame, PayloadEncoding::MessagePack).unwrap();
        assert!(!BinaryFlags::from_byte(bytes[3]).compressed);

        let decoded: BinaryFrame<Ch, Payload> =
            decode_frame(&bytes, PayloadEncoding::MessagePack).unwrap();
        assert_eq!(decoded.payload, frame.payload);
    }

    #[test]
    fn dictionary_training_and_use() {
        let dict = trained_dictionary();
        let cfg = CompressionConfig {
            threshold: 1,
            dictionary: Some(dict.clone()),
        };
        let frame = base_frame();
        let bytes =
            encode_frame_with_compression(&frame, PayloadEncoding::MessagePack, &cfg).unwrap();
        let decoded: BinaryFrame<Ch, Payload> =
            decode_frame_with_dictionaries(&bytes, PayloadEncoding::MessagePack, &[dict]).unwrap();
        assert_eq!(decoded, frame);
    }

    #[test]
    fn dictionary_compressed_roundtrip() {
        let dict = trained_dictionary();
        let cfg = CompressionConfig {
            threshold: 1,
            dictionary: Some(dict.clone()),
        };
        let frame = large_frame();
        let bytes =
            encode_frame_with_compression(&frame, PayloadEncoding::MessagePack, &cfg).unwrap();

        // The payload is large enough that the dictionary path is actually taken.
        assert!(BinaryFlags::from_byte(bytes[3]).compressed);
        assert_eq!(&bytes[9..13], &dict.id.to_be_bytes());

        let decoded: BinaryFrame<Ch, Payload> =
            decode_frame_with_dictionaries(&bytes, PayloadEncoding::MessagePack, &[dict]).unwrap();
        assert_eq!(decoded.payload, frame.payload);
        assert_eq!(decoded.sequence, frame.sequence);
        assert_eq!(decoded.channel, frame.channel);
    }

    #[test]
    fn reports_missing_dictionary() {
        let dict = trained_dictionary();
        let cfg = CompressionConfig {
            threshold: 1,
            dictionary: Some(dict),
        };
        let bytes =
            encode_frame_with_compression(&large_frame(), PayloadEncoding::MessagePack, &cfg)
                .unwrap();
        let err = decode_frame::<Ch, Payload>(&bytes, PayloadEncoding::MessagePack).unwrap_err();
        assert_eq!(err, BinaryError::MissingDictionary(7));
    }

    #[test]
    fn rejects_short_frame() {
        let bytes = encode_frame(&base_frame(), PayloadEncoding::MessagePack).unwrap();
        let err =
            decode_frame::<Ch, Payload>(&bytes[..8], PayloadEncoding::MessagePack).unwrap_err();
        assert_eq!(err, BinaryError::FrameTooShort);

        let err = decode_frame::<Ch, Payload>(&[], PayloadEncoding::MessagePack).unwrap_err();
        assert_eq!(err, BinaryError::FrameTooShort);
    }

    #[test]
    fn rejects_bad_magic() {
        let mut bytes = encode_frame(&base_frame(), PayloadEncoding::MessagePack).unwrap();
        bytes[0] = 0x00;
        let err = decode_frame::<Ch, Payload>(&bytes, PayloadEncoding::MessagePack).unwrap_err();
        assert_eq!(err, BinaryError::InvalidMagic);
    }

    #[test]
    fn rejects_unsupported_version() {
        let mut bytes = encode_frame(&base_frame(), PayloadEncoding::MessagePack).unwrap();
        bytes[2] = VERSION_BYTE + 1;
        let err = decode_frame::<Ch, Payload>(&bytes, PayloadEncoding::MessagePack).unwrap_err();
        assert_eq!(err, BinaryError::UnsupportedVersion(VERSION_BYTE + 1));
    }

    #[test]
    fn rejects_unknown_channel() {
        let mut bytes = encode_frame(&base_frame(), PayloadEncoding::MessagePack).unwrap();
        bytes[4] = 0xFF;
        let err = decode_frame::<Ch, Payload>(&bytes, PayloadEncoding::MessagePack).unwrap_err();
        assert_eq!(err, BinaryError::UnknownChannel(0xFF));
    }

    #[test]
    fn wire_id_preserved_in_encoding() {
        let frame = base_frame();
        let bytes = encode_frame(&frame, PayloadEncoding::MessagePack).unwrap();
        // Channel byte at offset 4 should be Ch::Data wire_id = 0x07
        assert_eq!(bytes[4], 0x07);
        // Sequence follows as four big-endian bytes.
        assert_eq!(&bytes[5..9], &42u32.to_be_bytes());
    }
}
463        let bytes = encode_frame(&frame, PayloadEncoding::MessagePack).unwrap();
464        // Channel byte at offset 4 should be Ch::Data wire_id = 0x07
465        assert_eq!(bytes[4], 0x07);
466    }
467}