Skip to main content

hotmint_network/
codec.rs

1//! Wire codec for hotmint P2P messages.
2//!
3//! Applies optional zstd compression based on payload size:
4//!
5//! ```text
6//! [0x00][raw CBOR]     — uncompressed (small messages)
7//! [0x01][zstd bytes]   — zstd-compressed CBOR
8//! ```
9//!
10//! This is part of the hotmint wire protocol — all node implementations
11//! (regardless of P2P library) must support this format.
12
13use std::error::Error;
14use std::fmt;
15use std::io::{self, Read};
16
17use serde::{Deserialize, Serialize};
18
/// Only CBOR payloads longer than this many bytes are zstd-compressed;
/// payloads of this size or smaller are always sent raw.
const COMPRESS_THRESHOLD: usize = 256;

/// zstd compression level used by [`encode`] (3 trades speed against ratio).
const ZSTD_LEVEL: i32 = 3;

/// Upper bound (16 MiB) on the size of a decompressed payload.
///
/// Guards against decompression bombs hidden in compressed frames. Should stay
/// equal to `MAX_NOTIFICATION_SIZE` in `service.rs`; update both together.
const MAX_DECOMPRESSED_SIZE: usize = 16 << 20;

/// Frame tag: the bytes after the tag are raw CBOR.
const TAG_RAW: u8 = 0x00;
/// Frame tag: the bytes after the tag are zstd-compressed CBOR.
const TAG_ZSTD: u8 = 0x01;
31
32/// Serialize a value to CBOR, then conditionally zstd-compress.
33pub fn encode<T: Serialize>(value: &T) -> Result<Vec<u8>, EncodeError> {
34    let cbor = serde_cbor_2::to_vec(value).map_err(EncodeError::Cbor)?;
35    if cbor.len() > COMPRESS_THRESHOLD {
36        let compressed =
37            zstd::encode_all(cbor.as_slice(), ZSTD_LEVEL).map_err(EncodeError::Zstd)?;
38        let mut out = Vec::with_capacity(1 + compressed.len());
39        out.push(TAG_ZSTD);
40        out.extend_from_slice(&compressed);
41        Ok(out)
42    } else {
43        let mut out = Vec::with_capacity(1 + cbor.len());
44        out.push(TAG_RAW);
45        out.extend_from_slice(&cbor);
46        Ok(out)
47    }
48}
49
50/// Decode a wire frame: check tag byte, optionally decompress, then CBOR-decode.
51pub fn decode<T: for<'de> Deserialize<'de>>(data: &[u8]) -> Result<T, DecodeError> {
52    if data.is_empty() {
53        return Err(DecodeError::EmptyFrame);
54    }
55    match data[0] {
56        TAG_RAW => serde_cbor_2::from_slice(&data[1..]).map_err(DecodeError::Cbor),
57        TAG_ZSTD => {
58            let decoder = zstd::stream::read::Decoder::new(&data[1..])
59                .map_err(|e| DecodeError::Zstd(e.to_string()))?;
60            let mut decompressed = Vec::with_capacity(data.len().min(MAX_DECOMPRESSED_SIZE));
61            decoder
62                .take(MAX_DECOMPRESSED_SIZE as u64 + 1)
63                .read_to_end(&mut decompressed)
64                .map_err(|e| DecodeError::Zstd(e.to_string()))?;
65            if decompressed.len() > MAX_DECOMPRESSED_SIZE {
66                return Err(DecodeError::DecompressedTooLarge);
67            }
68            serde_cbor_2::from_slice(&decompressed).map_err(DecodeError::Cbor)
69        }
70        tag => Err(DecodeError::UnknownTag(tag)),
71    }
72}
73
74#[derive(Debug)]
75pub enum DecodeError {
76    EmptyFrame,
77    UnknownTag(u8),
78    Cbor(serde_cbor_2::Error),
79    Zstd(String),
80    DecompressedTooLarge,
81}
82
83impl fmt::Display for DecodeError {
84    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
85        match self {
86            Self::EmptyFrame => write!(f, "empty frame"),
87            Self::UnknownTag(tag) => write!(f, "unknown codec tag: 0x{tag:02x}"),
88            Self::Cbor(e) => write!(f, "cbor: {e}"),
89            Self::Zstd(e) => write!(f, "zstd: {e}"),
90            Self::DecompressedTooLarge => write!(
91                f,
92                "decompressed payload exceeds {} bytes",
93                MAX_DECOMPRESSED_SIZE
94            ),
95        }
96    }
97}
98
99impl Error for DecodeError {}
100
101#[derive(Debug)]
102pub enum EncodeError {
103    Cbor(serde_cbor_2::Error),
104    Zstd(io::Error),
105}
106
107impl fmt::Display for EncodeError {
108    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
109        match self {
110            Self::Cbor(e) => write!(f, "cbor: {e}"),
111            Self::Zstd(e) => write!(f, "zstd: {e}"),
112        }
113    }
114}
115
116impl Error for EncodeError {}
117
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn small_message_not_compressed() {
        let data = vec![1u8, 2, 3];
        let encoded = encode(&data).unwrap();
        assert_eq!(encoded[0], TAG_RAW);
        let decoded: Vec<u8> = decode(&encoded).unwrap();
        assert_eq!(decoded, data);
    }

    #[test]
    fn large_message_compressed() {
        let data = vec![42u8; 1024];
        let encoded = encode(&data).unwrap();
        assert_eq!(encoded[0], TAG_ZSTD);
        // Highly repetitive input: the compressed frame must beat the raw CBOR.
        let cbor_len = serde_cbor_2::to_vec(&data).unwrap().len();
        assert!(encoded.len() < cbor_len);
        let decoded: Vec<u8> = decode(&encoded).unwrap();
        assert_eq!(decoded, data);
    }

    #[test]
    fn empty_frame_error() {
        let result: Result<Vec<u8>, _> = decode(&[]);
        assert!(
            matches!(result, Err(DecodeError::EmptyFrame)),
            "expected EmptyFrame, got: {:?}",
            result.err()
        );
    }

    #[test]
    fn unknown_tag_error() {
        let result: Result<Vec<u8>, _> = decode(&[0xFF, 0x00]);
        assert!(
            matches!(result, Err(DecodeError::UnknownTag(0xFF))),
            "expected UnknownTag(0xFF), got: {:?}",
            result.err()
        );
    }

    #[test]
    fn raw_frame_with_invalid_cbor_rejected() {
        // A raw tag with no CBOR body after it.
        let result: Result<Vec<u8>, _> = decode(&[TAG_RAW]);
        assert!(
            matches!(result, Err(DecodeError::Cbor(_))),
            "expected Cbor error, got: {:?}",
            result.err()
        );
    }

    #[test]
    fn corrupt_zstd_frame_rejected() {
        // The bytes after the tag do not start with the zstd magic number.
        let result: Result<Vec<u8>, _> = decode(&[TAG_ZSTD, 0, 1, 2, 3, 4, 5, 6, 7]);
        assert!(
            matches!(result, Err(DecodeError::Zstd(_))),
            "expected Zstd error, got: {:?}",
            result.err()
        );
    }

    #[test]
    fn decompressed_too_large_rejected() {
        // Build a zstd frame that decompresses to just over the limit.
        // A highly compressible byte pattern keeps the compressed size small.
        let oversized: Vec<u8> = vec![0xAAu8; MAX_DECOMPRESSED_SIZE + 1];
        let mut compressed = zstd::encode_all(oversized.as_slice(), ZSTD_LEVEL).unwrap();
        // Prepend the zstd tag byte to form a valid-looking wire frame.
        compressed.insert(0, TAG_ZSTD);
        let result: Result<Vec<u8>, _> = decode(&compressed);
        assert!(
            matches!(result, Err(DecodeError::DecompressedTooLarge)),
            "expected DecompressedTooLarge, got: {:?}",
            result.err()
        );
    }
}