Skip to main content

hotmint_network/
codec.rs

1//! Wire codec for hotmint P2P messages.
2//!
3//! Applies optional zstd compression based on payload size:
4//!
5//! ```text
6//! [0x00][raw postcard]     — uncompressed (small messages)
7//! [0x01][zstd bytes]       — zstd-compressed postcard
8//! ```
9//!
10//! This is part of the hotmint wire protocol — all node implementations
11//! (regardless of P2P library) must support this format.
12
13use std::error::Error;
14use std::fmt;
15use std::io::{self, Read};
16
17use serde::{Deserialize, Serialize};
18
/// Serialized-payload length (in bytes) above which zstd compression is
/// attempted. Payloads of this length or shorter are always sent raw.
const COMPRESS_THRESHOLD: usize = 256;

/// Compression level handed to zstd; 3 favours speed while still giving a
/// useful ratio.
const ZSTD_LEVEL: i32 = 3;

/// Largest payload (16 MiB) a compressed frame is allowed to expand to.
///
/// Bounds the work a peer can force on us with a decompression bomb. Kept
/// equal to `MAX_NOTIFICATION_SIZE` in `service.rs`.
const MAX_DECOMPRESSED_SIZE: usize = 1 << 24;

/// Frame tag: the bytes after the tag are plain postcard.
const TAG_RAW: u8 = 0x00;
/// Frame tag: the bytes after the tag are zstd-compressed postcard.
const TAG_ZSTD: u8 = 0x01;
31
32/// Serialize a value to postcard, then conditionally zstd-compress.
33pub fn encode<T: Serialize>(value: &T) -> Result<Vec<u8>, EncodeError> {
34    let raw = postcard::to_allocvec(value).map_err(EncodeError::Postcard)?;
35    if raw.len() > COMPRESS_THRESHOLD {
36        let compressed = zstd::encode_all(raw.as_slice(), ZSTD_LEVEL).map_err(EncodeError::Zstd)?;
37        let mut out = Vec::with_capacity(1 + compressed.len());
38        out.push(TAG_ZSTD);
39        out.extend_from_slice(&compressed);
40        Ok(out)
41    } else {
42        let mut out = Vec::with_capacity(1 + raw.len());
43        out.push(TAG_RAW);
44        out.extend_from_slice(&raw);
45        Ok(out)
46    }
47}
48
49/// Decode a wire frame: check tag byte, optionally decompress, then postcard-decode.
50pub fn decode<T: for<'de> Deserialize<'de>>(data: &[u8]) -> Result<T, DecodeError> {
51    if data.is_empty() {
52        return Err(DecodeError::EmptyFrame);
53    }
54    match data[0] {
55        TAG_RAW => postcard::from_bytes(&data[1..]).map_err(DecodeError::Postcard),
56        TAG_ZSTD => {
57            let decoder = zstd::stream::read::Decoder::new(&data[1..])
58                .map_err(|e| DecodeError::Zstd(e.to_string()))?;
59            let mut decompressed = Vec::with_capacity(data.len().min(MAX_DECOMPRESSED_SIZE));
60            decoder
61                .take(MAX_DECOMPRESSED_SIZE as u64 + 1)
62                .read_to_end(&mut decompressed)
63                .map_err(|e| DecodeError::Zstd(e.to_string()))?;
64            if decompressed.len() > MAX_DECOMPRESSED_SIZE {
65                return Err(DecodeError::DecompressedTooLarge);
66            }
67            postcard::from_bytes(&decompressed).map_err(DecodeError::Postcard)
68        }
69        tag => Err(DecodeError::UnknownTag(tag)),
70    }
71}
72
73#[derive(Debug)]
74pub enum DecodeError {
75    EmptyFrame,
76    UnknownTag(u8),
77    Postcard(postcard::Error),
78    Zstd(String),
79    DecompressedTooLarge,
80}
81
82impl fmt::Display for DecodeError {
83    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
84        match self {
85            Self::EmptyFrame => write!(f, "empty frame"),
86            Self::UnknownTag(tag) => write!(f, "unknown codec tag: 0x{tag:02x}"),
87            Self::Postcard(e) => write!(f, "postcard: {e}"),
88            Self::Zstd(e) => write!(f, "zstd: {e}"),
89            Self::DecompressedTooLarge => write!(
90                f,
91                "decompressed payload exceeds {} bytes",
92                MAX_DECOMPRESSED_SIZE
93            ),
94        }
95    }
96}
97
98impl Error for DecodeError {}
99
100#[derive(Debug)]
101pub enum EncodeError {
102    Postcard(postcard::Error),
103    Zstd(io::Error),
104}
105
106impl fmt::Display for EncodeError {
107    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
108        match self {
109            Self::Postcard(e) => write!(f, "postcard: {e}"),
110            Self::Zstd(e) => write!(f, "zstd: {e}"),
111        }
112    }
113}
114
115impl Error for EncodeError {}
116
117#[cfg(test)]
118mod tests {
119    use super::*;
120
121    #[test]
122    fn small_message_not_compressed() {
123        let data = vec![1u8, 2, 3];
124        let encoded = encode(&data).unwrap();
125        assert_eq!(encoded[0], TAG_RAW);
126        let decoded: Vec<u8> = decode(&encoded).unwrap();
127        assert_eq!(decoded, data);
128    }
129
130    #[test]
131    fn large_message_compressed() {
132        let data = vec![42u8; 1024];
133        let encoded = encode(&data).unwrap();
134        assert_eq!(encoded[0], TAG_ZSTD);
135        // Compressed should be smaller than raw postcard
136        let raw_len = postcard::to_allocvec(&data).unwrap().len();
137        assert!(encoded.len() < raw_len);
138        let decoded: Vec<u8> = decode(&encoded).unwrap();
139        assert_eq!(decoded, data);
140    }
141
142    #[test]
143    fn empty_frame_error() {
144        let result: Result<Vec<u8>, _> = decode(&[]);
145        assert!(result.is_err());
146    }
147
148    #[test]
149    fn unknown_tag_error() {
150        let result: Result<Vec<u8>, _> = decode(&[0xFF, 0x00]);
151        assert!(result.is_err());
152    }
153
154    #[test]
155    fn decompressed_too_large_rejected() {
156        let oversized: Vec<u8> = vec![0xAAu8; MAX_DECOMPRESSED_SIZE + 1];
157        let mut compressed = zstd::encode_all(oversized.as_slice(), ZSTD_LEVEL).unwrap();
158        compressed.insert(0, TAG_ZSTD);
159        let result: Result<Vec<u8>, _> = decode(&compressed);
160        assert!(
161            matches!(result, Err(DecodeError::DecompressedTooLarge)),
162            "expected DecompressedTooLarge, got: {:?}",
163            result.err()
164        );
165    }
166}