1use bytes::{Buf, BufMut, Bytes, BytesMut};
2use serde::{Deserialize, Serialize};
3
4use crate::error::{Error, Result};
5
/// Byte width of the little-endian `u32` length prefix written before every entry.
const ENTRY_LEN_SIZE: usize = std::mem::size_of::<u32>();

/// Batch format version written into the footer and required when decoding.
const FORMAT_VERSION: u16 = 1;

// Footer field widths, listed in the order the fields are written after the payload:
// compression type (`u8`), entry count (`u32` LE), format version (`u16` LE).
const COMPRESSION_TYPE_SIZE: usize = std::mem::size_of::<u8>();
const ENTRIES_COUNT_SIZE: usize = std::mem::size_of::<u32>();
const VERSION_SIZE: usize = std::mem::size_of::<u16>();

/// Total length of the fixed-size footer appended after the (optionally compressed) payload.
const FOOTER_SIZE: usize = COMPRESSION_TYPE_SIZE + ENTRIES_COUNT_SIZE + VERSION_SIZE;

/// Level handed to `zstd::bulk::compress` when encoding with `CompressionType::Zstd`.
const ZSTD_LEVEL: i32 = 3;
15
/// Compression codec applied to the entry payload of a batch.
///
/// The discriminant (`as u8`) is written as the first byte of the batch footer and mapped
/// back by `TryFrom<u8>`, so the explicit values below are part of the encoded batch format
/// and must not be renumbered. Through serde, variants (de)serialize as the lowercase
/// strings `"none"` and `"zstd"`.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum CompressionType {
    /// Entry payload is stored as-is; this is the `Default`.
    #[default]
    None = 0,
    /// Entry payload is compressed as a single zstd frame.
    Zstd = 1,
}
25
26impl TryFrom<u8> for CompressionType {
27 type Error = Error;
28
29 fn try_from(value: u8) -> Result<Self> {
30 match value {
31 0 => Ok(CompressionType::None),
32 1 => Ok(CompressionType::Zstd),
33 other => Err(Error::Serialization(format!(
34 "unsupported compression type: {other}"
35 ))),
36 }
37 }
38}
39
40pub(crate) fn encode_batch(entries: &[Bytes], compression: CompressionType) -> Result<Bytes> {
41 let data_size: usize = entries.iter().map(|e| ENTRY_LEN_SIZE + e.len()).sum();
42 let mut entry_buf = BytesMut::with_capacity(data_size);
43
44 for entry in entries {
45 debug_assert!(entry.len() <= u32::MAX as usize);
46 entry_buf.put_u32_le(entry.len() as u32);
47 entry_buf.put_slice(entry);
48 }
49
50 let compressed = match compression {
51 CompressionType::None => entry_buf.freeze(),
52 CompressionType::Zstd => {
53 let compressed = zstd::bulk::compress(&entry_buf, ZSTD_LEVEL)
54 .map_err(|e| Error::Serialization(format!("zstd compression failed: {e}")))?;
55 Bytes::from(compressed)
56 }
57 };
58
59 let mut buf = BytesMut::with_capacity(compressed.len() + FOOTER_SIZE);
60 buf.put_slice(&compressed);
61 buf.put_u8(compression as u8);
62
63 debug_assert!(entries.len() <= u32::MAX as usize);
64 buf.put_u32_le(entries.len() as u32);
65 buf.put_u16_le(FORMAT_VERSION);
66
67 Ok(buf.freeze())
68}
69
70pub(crate) fn decode_batch(mut data: Bytes) -> Result<Vec<Bytes>> {
71 if data.len() < FOOTER_SIZE {
72 return Err(Error::Serialization(
73 "batch too small for footer".to_string(),
74 ));
75 }
76
77 let footer_start = data.len() - FOOTER_SIZE;
78 let mut footer = data.split_off(footer_start);
79
80 let compression_type = CompressionType::try_from(footer.get_u8())?;
81 let record_count = footer.get_u32_le() as usize;
82 let version = footer.get_u16_le();
83
84 if version != FORMAT_VERSION {
85 return Err(Error::Serialization(format!(
86 "unsupported batch version: {}",
87 version
88 )));
89 }
90
91 let mut entry_data = match compression_type {
92 CompressionType::None => data,
93 CompressionType::Zstd => {
94 let decompressed = zstd::stream::decode_all(data.as_ref())
95 .map_err(|e| Error::Serialization(format!("zstd decompression failed: {e}")))?;
96 Bytes::from(decompressed)
97 }
98 };
99
100 let mut entries = Vec::with_capacity(record_count);
101 for _ in 0..record_count {
102 if entry_data.remaining() < ENTRY_LEN_SIZE {
103 return Err(Error::Serialization("truncated record length".to_string()));
104 }
105 let len = entry_data.get_u32_le() as usize;
106 if entry_data.remaining() < len {
107 return Err(Error::Serialization("truncated record data".to_string()));
108 }
109 entries.push(entry_data.split_to(len));
110 }
111
112 Ok(entries)
113}
114
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `result` is an `Error::Serialization` and returns its message, so tests
    /// can check *why* decoding failed rather than merely *that* it failed.
    fn serialization_error(result: Result<Vec<Bytes>>) -> String {
        match result {
            Err(Error::Serialization(msg)) => msg,
            other => panic!("expected Error::Serialization, got {other:?}"),
        }
    }

    #[test]
    fn should_roundtrip_batch() {
        let entries = vec![
            Bytes::from("hello"),
            Bytes::from("world"),
            Bytes::from("foo"),
        ];
        let encoded = encode_batch(&entries, CompressionType::None).unwrap();
        let decoded = decode_batch(encoded).unwrap();
        assert_eq!(decoded, entries);
    }

    #[test]
    fn should_roundtrip_empty_batch() {
        let entries: Vec<Bytes> = vec![];
        let encoded = encode_batch(&entries, CompressionType::None).unwrap();
        assert_eq!(encoded.len(), FOOTER_SIZE);
        let decoded = decode_batch(encoded).unwrap();
        assert!(decoded.is_empty());
    }

    #[test]
    fn should_roundtrip_empty_record() {
        let entries = vec![Bytes::new()];
        let encoded = encode_batch(&entries, CompressionType::None).unwrap();
        let decoded = decode_batch(encoded).unwrap();
        assert_eq!(decoded, entries);
    }

    #[test]
    fn should_write_footer_after_payload() {
        let entries = vec![Bytes::from("a"), Bytes::from("bc")];
        let encoded = encode_batch(&entries, CompressionType::Zstd).unwrap();
        let mut footer = encoded.slice(encoded.len() - FOOTER_SIZE..);
        assert_eq!(footer.get_u8(), CompressionType::Zstd as u8);
        assert_eq!(footer.get_u32_le(), 2);
        assert_eq!(footer.get_u16_le(), FORMAT_VERSION);
    }

    #[test]
    fn should_reject_batch_shorter_than_footer() {
        let result = decode_batch(Bytes::from(vec![0u8; FOOTER_SIZE - 1]));
        let msg = serialization_error(result);
        assert!(msg.contains("batch too small for footer"), "{msg}");
    }

    #[test]
    fn should_reject_truncated_record_length() {
        let mut buf = BytesMut::new();
        // Only two of the four length-prefix bytes are present.
        buf.put_u16_le(5);
        buf.put_u8(CompressionType::None as u8);
        buf.put_u32_le(1);
        buf.put_u16_le(FORMAT_VERSION);
        let msg = serialization_error(decode_batch(buf.freeze()));
        assert!(msg.contains("truncated record length"), "{msg}");
    }

    #[test]
    fn should_reject_truncated_data() {
        let entries = vec![Bytes::from("hello")];
        let mut encoded = BytesMut::from(
            encode_batch(&entries, CompressionType::None)
                .unwrap()
                .as_ref(),
        );
        // Drop the footer plus the last payload byte, then re-append a valid footer.
        encoded.truncate(encoded.len() - FOOTER_SIZE - 1);
        encoded.put_u8(CompressionType::None as u8);
        encoded.put_u32_le(1);
        encoded.put_u16_le(FORMAT_VERSION);
        let msg = serialization_error(decode_batch(encoded.freeze()));
        assert!(msg.contains("truncated record data"), "{msg}");
    }

    #[test]
    fn should_reject_unsupported_version() {
        let mut buf = BytesMut::new();
        buf.put_u8(0);
        buf.put_u32_le(0);
        buf.put_u16_le(99);
        let msg = serialization_error(decode_batch(buf.freeze()));
        assert!(msg.contains("unsupported batch version: 99"), "{msg}");
    }

    #[test]
    fn should_roundtrip_batch_with_zstd() {
        let entries = vec![
            Bytes::from("hello"),
            Bytes::from("world"),
            Bytes::from("foo"),
        ];
        let encoded = encode_batch(&entries, CompressionType::Zstd).unwrap();
        let decoded = decode_batch(encoded).unwrap();
        assert_eq!(decoded, entries);
    }

    #[test]
    fn should_roundtrip_empty_batch_with_zstd() {
        let entries: Vec<Bytes> = vec![];
        let encoded = encode_batch(&entries, CompressionType::Zstd).unwrap();
        let decoded = decode_batch(encoded).unwrap();
        assert!(decoded.is_empty());
    }

    #[test]
    fn should_roundtrip_empty_records_with_zstd() {
        let entries = vec![Bytes::new(), Bytes::from("x"), Bytes::new()];
        let encoded = encode_batch(&entries, CompressionType::Zstd).unwrap();
        let decoded = decode_batch(encoded).unwrap();
        assert_eq!(decoded, entries);
    }

    #[test]
    fn should_roundtrip_large_batch_with_zstd() {
        let entries: Vec<Bytes> = (0..1000)
            .map(|i| Bytes::from(format!("entry-{:04}", i)))
            .collect();
        let encoded = encode_batch(&entries, CompressionType::Zstd).unwrap();
        let decoded = decode_batch(encoded).unwrap();
        assert_eq!(decoded, entries);
    }

    #[test]
    fn should_compress_smaller_than_uncompressed_for_repetitive_data() {
        let entries: Vec<Bytes> = (0..100)
            .map(|_| Bytes::from("repeated-data-that-compresses-well"))
            .collect();
        let uncompressed = encode_batch(&entries, CompressionType::None).unwrap();
        let compressed = encode_batch(&entries, CompressionType::Zstd).unwrap();
        assert!(
            compressed.len() < uncompressed.len(),
            "compressed ({}) should be smaller than uncompressed ({})",
            compressed.len(),
            uncompressed.len()
        );
    }

    #[test]
    fn should_reject_unsupported_compression_type() {
        let mut buf = BytesMut::new();
        buf.put_u8(0xFF);
        buf.put_u32_le(0);
        buf.put_u16_le(FORMAT_VERSION);
        let msg = serialization_error(decode_batch(buf.freeze()));
        assert!(msg.contains("unsupported compression type: 255"), "{msg}");
    }
}