Skip to main content

buffer/
model.rs

1use bytes::{Buf, BufMut, Bytes, BytesMut};
2use serde::{Deserialize, Serialize};
3
4use crate::error::{Error, Result};
5
/// Version tag written as the last field of every batch footer.
const FORMAT_VERSION: u16 = 1;

/// Width in bytes of the little-endian `u32` length prefix that precedes each entry.
const ENTRY_LEN_SIZE: usize = std::mem::size_of::<u32>();

// Footer layout, in write order: compression tag (u8), entry count (u32 LE),
// format version (u16 LE).

/// Width in bytes of the compression tag stored in the footer.
const COMPRESSION_TYPE_SIZE: usize = std::mem::size_of::<u8>();
/// Width in bytes of the entry count stored in the footer.
const ENTRIES_COUNT_SIZE: usize = std::mem::size_of::<u32>();
/// Width in bytes of the format version stored in the footer.
const VERSION_SIZE: usize = std::mem::size_of::<u16>();
/// Total width in bytes of the fixed-size footer appended to every batch.
const FOOTER_SIZE: usize = COMPRESSION_TYPE_SIZE + ENTRIES_COUNT_SIZE + VERSION_SIZE;

/// Zstandard level passed to the compressor when `CompressionType::Zstd` is selected.
const ZSTD_LEVEL: i32 = 3;
15
16/// Compression algorithm applied to the record block of a data batch.
17#[repr(u8)]
18#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
19#[serde(rename_all = "lowercase")]
20pub enum CompressionType {
21    #[default]
22    None = 0,
23    Zstd = 1,
24}
25
26impl TryFrom<u8> for CompressionType {
27    type Error = Error;
28
29    fn try_from(value: u8) -> Result<Self> {
30        match value {
31            0 => Ok(CompressionType::None),
32            1 => Ok(CompressionType::Zstd),
33            other => Err(Error::Serialization(format!(
34                "unsupported compression type: {other}"
35            ))),
36        }
37    }
38}
39
40pub(crate) fn encode_batch(entries: &[Bytes], compression: CompressionType) -> Result<Bytes> {
41    let data_size: usize = entries.iter().map(|e| ENTRY_LEN_SIZE + e.len()).sum();
42    let mut entry_buf = BytesMut::with_capacity(data_size);
43
44    for entry in entries {
45        debug_assert!(entry.len() <= u32::MAX as usize);
46        entry_buf.put_u32_le(entry.len() as u32);
47        entry_buf.put_slice(entry);
48    }
49
50    let compressed = match compression {
51        CompressionType::None => entry_buf.freeze(),
52        CompressionType::Zstd => {
53            let compressed = zstd::bulk::compress(&entry_buf, ZSTD_LEVEL)
54                .map_err(|e| Error::Serialization(format!("zstd compression failed: {e}")))?;
55            Bytes::from(compressed)
56        }
57    };
58
59    let mut buf = BytesMut::with_capacity(compressed.len() + FOOTER_SIZE);
60    buf.put_slice(&compressed);
61    buf.put_u8(compression as u8);
62
63    debug_assert!(entries.len() <= u32::MAX as usize);
64    buf.put_u32_le(entries.len() as u32);
65    buf.put_u16_le(FORMAT_VERSION);
66
67    Ok(buf.freeze())
68}
69
70pub(crate) fn decode_batch(mut data: Bytes) -> Result<Vec<Bytes>> {
71    if data.len() < FOOTER_SIZE {
72        return Err(Error::Serialization(
73            "batch too small for footer".to_string(),
74        ));
75    }
76
77    let footer_start = data.len() - FOOTER_SIZE;
78    let mut footer = data.split_off(footer_start);
79
80    let compression_type = CompressionType::try_from(footer.get_u8())?;
81    let record_count = footer.get_u32_le() as usize;
82    let version = footer.get_u16_le();
83
84    if version != FORMAT_VERSION {
85        return Err(Error::Serialization(format!(
86            "unsupported batch version: {}",
87            version
88        )));
89    }
90
91    let mut entry_data = match compression_type {
92        CompressionType::None => data,
93        CompressionType::Zstd => {
94            let decompressed = zstd::stream::decode_all(data.as_ref())
95                .map_err(|e| Error::Serialization(format!("zstd decompression failed: {e}")))?;
96            Bytes::from(decompressed)
97        }
98    };
99
100    let mut entries = Vec::with_capacity(record_count);
101    for _ in 0..record_count {
102        if entry_data.remaining() < ENTRY_LEN_SIZE {
103            return Err(Error::Serialization("truncated record length".to_string()));
104        }
105        let len = entry_data.get_u32_le() as usize;
106        if entry_data.remaining() < len {
107            return Err(Error::Serialization("truncated record data".to_string()));
108        }
109        entries.push(entry_data.split_to(len));
110    }
111
112    Ok(entries)
113}
114
#[cfg(test)]
mod tests {
    use super::*;

    /// Three short entries shared by the basic round-trip tests.
    fn sample_entries() -> Vec<Bytes> {
        vec![
            Bytes::from("hello"),
            Bytes::from("world"),
            Bytes::from("foo"),
        ]
    }

    /// Encodes `entries` with `compression` and decodes the result again.
    fn roundtrip(entries: &[Bytes], compression: CompressionType) -> Vec<Bytes> {
        let encoded = encode_batch(entries, compression).unwrap();
        decode_batch(encoded).unwrap()
    }

    /// Builds a raw batch: `body` followed by a hand-written footer.
    fn batch_with_footer(body: &[u8], compression_tag: u8, count: u32, version: u16) -> Bytes {
        let mut raw = BytesMut::new();
        raw.put_slice(body);
        raw.put_u8(compression_tag);
        raw.put_u32_le(count);
        raw.put_u16_le(version);
        raw.freeze()
    }

    #[test]
    fn should_roundtrip_batch() {
        let original = sample_entries();
        assert_eq!(roundtrip(&original, CompressionType::None), original);
    }

    #[test]
    fn should_roundtrip_empty_batch() {
        let nothing: Vec<Bytes> = Vec::new();
        let encoded = encode_batch(&nothing, CompressionType::None).unwrap();
        // An uncompressed empty batch is just the footer.
        assert_eq!(encoded.len(), FOOTER_SIZE);
        assert!(decode_batch(encoded).unwrap().is_empty());
    }

    #[test]
    fn should_roundtrip_empty_record() {
        let original = vec![Bytes::new()];
        assert_eq!(roundtrip(&original, CompressionType::None), original);
    }

    #[test]
    fn should_reject_truncated_data() {
        let encoded = encode_batch(&[Bytes::from("hello")], CompressionType::None).unwrap();
        // Drop the footer plus the last payload byte, then re-attach a footer
        // that still claims one record.
        let kept = encoded.len() - FOOTER_SIZE - 1;
        let tampered = batch_with_footer(
            &encoded[..kept],
            CompressionType::None as u8,
            1,
            FORMAT_VERSION,
        );
        assert!(decode_batch(tampered).is_err());
    }

    #[test]
    fn should_reject_unsupported_version() {
        let footer_only = batch_with_footer(&[], 0, 0, 99);
        assert!(decode_batch(footer_only).is_err());
    }

    #[test]
    fn should_roundtrip_batch_with_zstd() {
        let original = sample_entries();
        assert_eq!(roundtrip(&original, CompressionType::Zstd), original);
    }

    #[test]
    fn should_roundtrip_empty_batch_with_zstd() {
        let nothing: Vec<Bytes> = Vec::new();
        assert!(roundtrip(&nothing, CompressionType::Zstd).is_empty());
    }

    #[test]
    fn should_roundtrip_large_batch_with_zstd() {
        let mut original: Vec<Bytes> = Vec::with_capacity(1000);
        for i in 0..1000 {
            original.push(Bytes::from(format!("entry-{:04}", i)));
        }
        assert_eq!(roundtrip(&original, CompressionType::Zstd), original);
    }

    #[test]
    fn should_compress_smaller_than_uncompressed_for_repetitive_data() {
        let repetitive = vec![Bytes::from("repeated-data-that-compresses-well"); 100];
        let uncompressed = encode_batch(&repetitive, CompressionType::None).unwrap();
        let compressed = encode_batch(&repetitive, CompressionType::Zstd).unwrap();
        assert!(
            compressed.len() < uncompressed.len(),
            "compressed ({}) should be smaller than uncompressed ({})",
            compressed.len(),
            uncompressed.len()
        );
    }

    #[test]
    fn should_reject_unsupported_compression_type() {
        // 0xFF is not a known compression tag.
        let footer_only = batch_with_footer(&[], 0xFF, 0, FORMAT_VERSION);
        let result = decode_batch(footer_only);
        assert!(
            matches!(result, Err(Error::Serialization(msg)) if msg.contains("unsupported compression type"))
        );
    }
}