1use std::error::Error;
14use std::fmt;
15use std::io::{self, Read};
16
17use serde::{Deserialize, Serialize};
18
/// CBOR payloads longer than this many bytes are run through zstd by `encode`;
/// payloads of this size or smaller are always framed raw.
const COMPRESS_THRESHOLD: usize = 256;

/// zstd compression level handed to `zstd::encode_all` when compressing.
const ZSTD_LEVEL: i32 = 3;

/// Largest decompressed payload (16 MiB) that `decode` will accept, so a tiny
/// compressed frame cannot be inflated into an arbitrarily large buffer.
const MAX_DECOMPRESSED_SIZE: usize = 16 << 20;

/// Frame tag: the bytes after the tag are plain CBOR.
const TAG_RAW: u8 = 0x00;
/// Frame tag: the bytes after the tag are a zstd stream wrapping CBOR.
const TAG_ZSTD: u8 = 0x01;
31
32pub fn encode<T: Serialize>(value: &T) -> Result<Vec<u8>, EncodeError> {
34 let cbor = serde_cbor_2::to_vec(value).map_err(EncodeError::Cbor)?;
35 if cbor.len() > COMPRESS_THRESHOLD {
36 let compressed =
37 zstd::encode_all(cbor.as_slice(), ZSTD_LEVEL).map_err(EncodeError::Zstd)?;
38 let mut out = Vec::with_capacity(1 + compressed.len());
39 out.push(TAG_ZSTD);
40 out.extend_from_slice(&compressed);
41 Ok(out)
42 } else {
43 let mut out = Vec::with_capacity(1 + cbor.len());
44 out.push(TAG_RAW);
45 out.extend_from_slice(&cbor);
46 Ok(out)
47 }
48}
49
50pub fn decode<T: for<'de> Deserialize<'de>>(data: &[u8]) -> Result<T, DecodeError> {
52 if data.is_empty() {
53 return Err(DecodeError::EmptyFrame);
54 }
55 match data[0] {
56 TAG_RAW => serde_cbor_2::from_slice(&data[1..]).map_err(DecodeError::Cbor),
57 TAG_ZSTD => {
58 let decoder = zstd::stream::read::Decoder::new(&data[1..])
59 .map_err(|e| DecodeError::Zstd(e.to_string()))?;
60 let mut decompressed = Vec::with_capacity(data.len().min(MAX_DECOMPRESSED_SIZE));
61 decoder
62 .take(MAX_DECOMPRESSED_SIZE as u64 + 1)
63 .read_to_end(&mut decompressed)
64 .map_err(|e| DecodeError::Zstd(e.to_string()))?;
65 if decompressed.len() > MAX_DECOMPRESSED_SIZE {
66 return Err(DecodeError::DecompressedTooLarge);
67 }
68 serde_cbor_2::from_slice(&decompressed).map_err(DecodeError::Cbor)
69 }
70 tag => Err(DecodeError::UnknownTag(tag)),
71 }
72}
73
74#[derive(Debug)]
75pub enum DecodeError {
76 EmptyFrame,
77 UnknownTag(u8),
78 Cbor(serde_cbor_2::Error),
79 Zstd(String),
80 DecompressedTooLarge,
81}
82
83impl fmt::Display for DecodeError {
84 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
85 match self {
86 Self::EmptyFrame => write!(f, "empty frame"),
87 Self::UnknownTag(tag) => write!(f, "unknown codec tag: 0x{tag:02x}"),
88 Self::Cbor(e) => write!(f, "cbor: {e}"),
89 Self::Zstd(e) => write!(f, "zstd: {e}"),
90 Self::DecompressedTooLarge => write!(
91 f,
92 "decompressed payload exceeds {} bytes",
93 MAX_DECOMPRESSED_SIZE
94 ),
95 }
96 }
97}
98
99impl Error for DecodeError {}
100
101#[derive(Debug)]
102pub enum EncodeError {
103 Cbor(serde_cbor_2::Error),
104 Zstd(io::Error),
105}
106
107impl fmt::Display for EncodeError {
108 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
109 match self {
110 Self::Cbor(e) => write!(f, "cbor: {e}"),
111 Self::Zstd(e) => write!(f, "zstd: {e}"),
112 }
113 }
114}
115
116impl Error for EncodeError {}
117
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn small_message_not_compressed() {
        let data = vec![1u8, 2, 3];
        let encoded = encode(&data).unwrap();
        assert_eq!(encoded[0], TAG_RAW);
        let decoded: Vec<u8> = decode(&encoded).unwrap();
        assert_eq!(decoded, data);
    }

    #[test]
    fn large_message_compressed() {
        let data = vec![42u8; 1024];
        let encoded = encode(&data).unwrap();
        assert_eq!(encoded[0], TAG_ZSTD);
        let cbor_len = serde_cbor_2::to_vec(&data).unwrap().len();
        // Highly repetitive input must actually shrink when compressed.
        assert!(encoded.len() < cbor_len);
        let decoded: Vec<u8> = decode(&encoded).unwrap();
        assert_eq!(decoded, data);
    }

    #[test]
    fn empty_frame_error() {
        let result: Result<Vec<u8>, _> = decode(&[]);
        assert!(
            matches!(result, Err(DecodeError::EmptyFrame)),
            "expected EmptyFrame, got: {:?}",
            result.err()
        );
    }

    #[test]
    fn unknown_tag_error() {
        let result: Result<Vec<u8>, _> = decode(&[0xFF, 0x00]);
        assert!(
            matches!(result, Err(DecodeError::UnknownTag(0xFF))),
            "expected UnknownTag(0xFF), got: {:?}",
            result.err()
        );
    }

    #[test]
    fn corrupt_zstd_payload_rejected() {
        // Bytes after the tag are not a zstd magic number.
        let frame = [TAG_ZSTD, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF];
        let result: Result<Vec<u8>, _> = decode(&frame);
        assert!(
            matches!(result, Err(DecodeError::Zstd(_))),
            "expected Zstd error, got: {:?}",
            result.err()
        );
    }

    #[test]
    fn corrupt_raw_cbor_rejected() {
        // 0xFF is a CBOR "break" code, invalid as a top-level item.
        let result: Result<Vec<u8>, _> = decode(&[TAG_RAW, 0xFF]);
        assert!(
            matches!(result, Err(DecodeError::Cbor(_))),
            "expected Cbor error, got: {:?}",
            result.err()
        );
    }

    #[test]
    fn decode_error_display() {
        assert_eq!(DecodeError::EmptyFrame.to_string(), "empty frame");
        assert_eq!(
            DecodeError::UnknownTag(0xAB).to_string(),
            "unknown codec tag: 0xab"
        );
        assert_eq!(
            DecodeError::DecompressedTooLarge.to_string(),
            format!("decompressed payload exceeds {MAX_DECOMPRESSED_SIZE} bytes")
        );
    }

    #[test]
    fn decompressed_too_large_rejected() {
        let oversized: Vec<u8> = vec![0xAAu8; MAX_DECOMPRESSED_SIZE + 1];
        let mut compressed = zstd::encode_all(oversized.as_slice(), ZSTD_LEVEL).unwrap();
        compressed.insert(0, TAG_ZSTD);
        let result: Result<Vec<u8>, _> = decode(&compressed);
        assert!(
            matches!(result, Err(DecodeError::DecompressedTooLarge)),
            "expected DecompressedTooLarge, got: {:?}",
            result.err()
        );
    }
}