1use std::error::Error;
14use std::fmt;
15use std::io::{self, Read};
16
17use serde::{Deserialize, Serialize};
18
/// Postcard payloads strictly longer than this many bytes are candidates for zstd compression
/// in `encode`; anything at or below it is always framed raw.
const COMPRESS_THRESHOLD: usize = 1 << 8;

/// zstd compression level passed to `zstd::encode_all` (3 is also zstd's own default level).
const ZSTD_LEVEL: i32 = 3;

/// Upper bound (16 MiB) on how many bytes `decode` will inflate a zstd payload to before
/// rejecting the frame, so a small malicious frame cannot force a huge allocation.
const MAX_DECOMPRESSED_SIZE: usize = 16 << 20;

/// Frame tag: the bytes after the tag are the postcard encoding, uncompressed.
const TAG_RAW: u8 = 0x00;
/// Frame tag: the bytes after the tag are a zstd stream that inflates to the postcard encoding.
const TAG_ZSTD: u8 = 0x01;
31
32pub fn encode<T: Serialize>(value: &T) -> Result<Vec<u8>, EncodeError> {
34 let raw = postcard::to_allocvec(value).map_err(EncodeError::Postcard)?;
35 if raw.len() > COMPRESS_THRESHOLD {
36 let compressed = zstd::encode_all(raw.as_slice(), ZSTD_LEVEL).map_err(EncodeError::Zstd)?;
37 let mut out = Vec::with_capacity(1 + compressed.len());
38 out.push(TAG_ZSTD);
39 out.extend_from_slice(&compressed);
40 Ok(out)
41 } else {
42 let mut out = Vec::with_capacity(1 + raw.len());
43 out.push(TAG_RAW);
44 out.extend_from_slice(&raw);
45 Ok(out)
46 }
47}
48
49pub fn decode<T: for<'de> Deserialize<'de>>(data: &[u8]) -> Result<T, DecodeError> {
51 if data.is_empty() {
52 return Err(DecodeError::EmptyFrame);
53 }
54 match data[0] {
55 TAG_RAW => postcard::from_bytes(&data[1..]).map_err(DecodeError::Postcard),
56 TAG_ZSTD => {
57 let decoder = zstd::stream::read::Decoder::new(&data[1..])
58 .map_err(|e| DecodeError::Zstd(e.to_string()))?;
59 let mut decompressed = Vec::with_capacity(data.len().min(MAX_DECOMPRESSED_SIZE));
60 decoder
61 .take(MAX_DECOMPRESSED_SIZE as u64 + 1)
62 .read_to_end(&mut decompressed)
63 .map_err(|e| DecodeError::Zstd(e.to_string()))?;
64 if decompressed.len() > MAX_DECOMPRESSED_SIZE {
65 return Err(DecodeError::DecompressedTooLarge);
66 }
67 postcard::from_bytes(&decompressed).map_err(DecodeError::Postcard)
68 }
69 tag => Err(DecodeError::UnknownTag(tag)),
70 }
71}
72
73#[derive(Debug)]
74pub enum DecodeError {
75 EmptyFrame,
76 UnknownTag(u8),
77 Postcard(postcard::Error),
78 Zstd(String),
79 DecompressedTooLarge,
80}
81
82impl fmt::Display for DecodeError {
83 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
84 match self {
85 Self::EmptyFrame => write!(f, "empty frame"),
86 Self::UnknownTag(tag) => write!(f, "unknown codec tag: 0x{tag:02x}"),
87 Self::Postcard(e) => write!(f, "postcard: {e}"),
88 Self::Zstd(e) => write!(f, "zstd: {e}"),
89 Self::DecompressedTooLarge => write!(
90 f,
91 "decompressed payload exceeds {} bytes",
92 MAX_DECOMPRESSED_SIZE
93 ),
94 }
95 }
96}
97
98impl Error for DecodeError {}
99
100#[derive(Debug)]
101pub enum EncodeError {
102 Postcard(postcard::Error),
103 Zstd(io::Error),
104}
105
106impl fmt::Display for EncodeError {
107 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
108 match self {
109 Self::Postcard(e) => write!(f, "postcard: {e}"),
110 Self::Zstd(e) => write!(f, "zstd: {e}"),
111 }
112 }
113}
114
115impl Error for EncodeError {}
116
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn small_message_not_compressed() {
        let data = vec![1u8, 2, 3];
        let encoded = encode(&data).unwrap();
        assert_eq!(encoded[0], TAG_RAW);
        let decoded: Vec<u8> = decode(&encoded).unwrap();
        assert_eq!(decoded, data);
    }

    #[test]
    fn large_message_compressed() {
        let data = vec![42u8; 1024];
        let encoded = encode(&data).unwrap();
        assert_eq!(encoded[0], TAG_ZSTD);
        // A run of identical bytes is highly compressible, so the whole frame (tag included)
        // must be smaller than the uncompressed postcard encoding.
        let raw_len = postcard::to_allocvec(&data).unwrap().len();
        assert!(encoded.len() < raw_len);
        let decoded: Vec<u8> = decode(&encoded).unwrap();
        assert_eq!(decoded, data);
    }

    #[test]
    fn compression_threshold_boundary() {
        // 254 elements -> 2-byte varint length + 254 bytes = exactly COMPRESS_THRESHOLD.
        let at_limit = vec![0u8; 254];
        assert_eq!(postcard::to_allocvec(&at_limit).unwrap().len(), COMPRESS_THRESHOLD);
        let encoded = encode(&at_limit).unwrap();
        assert_eq!(encoded[0], TAG_RAW);
        assert_eq!(decode::<Vec<u8>>(&encoded).unwrap(), at_limit);

        // One more byte crosses the threshold; zeros compress well, so zstd is used.
        let over_limit = vec![0u8; 255];
        let encoded = encode(&over_limit).unwrap();
        assert_eq!(encoded[0], TAG_ZSTD);
        assert_eq!(decode::<Vec<u8>>(&encoded).unwrap(), over_limit);
    }

    #[test]
    fn empty_frame_error() {
        let result: Result<Vec<u8>, _> = decode(&[]);
        assert!(
            matches!(result, Err(DecodeError::EmptyFrame)),
            "expected EmptyFrame, got: {result:?}"
        );
    }

    #[test]
    fn unknown_tag_error() {
        let result: Result<Vec<u8>, _> = decode(&[0xFF, 0x00]);
        assert!(
            matches!(result, Err(DecodeError::UnknownTag(0xFF))),
            "expected UnknownTag(0xFF), got: {result:?}"
        );
    }

    #[test]
    fn corrupt_zstd_payload_rejected() {
        // These bytes do not start with the zstd magic number, so the stream read must fail
        // and surface as a zstd error rather than a panic or a postcard error.
        let result: Result<Vec<u8>, _> =
            decode(&[TAG_ZSTD, 0xde, 0xad, 0xbe, 0xef, 0x00, 0x00, 0x00, 0x00]);
        assert!(
            matches!(result, Err(DecodeError::Zstd(_))),
            "expected Zstd error, got: {result:?}"
        );
    }

    #[test]
    fn decompressed_too_large_rejected() {
        let oversized: Vec<u8> = vec![0xAAu8; MAX_DECOMPRESSED_SIZE + 1];
        let mut compressed = zstd::encode_all(oversized.as_slice(), ZSTD_LEVEL).unwrap();
        compressed.insert(0, TAG_ZSTD);
        let result: Result<Vec<u8>, _> = decode(&compressed);
        assert!(
            matches!(result, Err(DecodeError::DecompressedTooLarge)),
            "expected DecompressedTooLarge, got: {result:?}"
        );
    }
}