Skip to main content

ember_persistence/
format.rs

1//! Binary format helpers shared across AOF and snapshot files.
2//!
3//! Provides TLV-style encoding primitives, CRC32 checksums, and magic
4//! byte constants. All multi-byte integers are stored in little-endian.
5
6use std::io::{self, Read, Write};
7
8use crc32fast::Hasher;
9use thiserror::Error;
10
/// Magic bytes for the AOF file header.
///
/// The four ASCII bytes `EAOF`. Pass this to [`write_header`] when starting
/// an AOF file and to [`read_header`] when validating one.
pub const AOF_MAGIC: &[u8; 4] = b"EAOF";
13
/// Magic bytes for the snapshot file header.
///
/// The four ASCII bytes `ESNP`. Because it differs from [`AOF_MAGIC`],
/// [`read_header`] rejects a file of the other kind with
/// [`FormatError::InvalidMagic`].
pub const SNAP_MAGIC: &[u8; 4] = b"ESNP";
16
/// Current format version for both AOF and snapshot files.
///
/// [`write_header`] stamps this value right after the magic bytes, and
/// [`read_header`] accepts any non-zero version up to and including it.
///
/// v1: original format (strings only)
/// v2: type-tagged entries (string, list, sorted set)
pub const FORMAT_VERSION: u8 = 2;
22
/// Errors that can occur when reading or writing persistence formats.
///
/// The `Display` text of each variant comes from its `#[error(...)]`
/// attribute, and an [`io::Error`] converts into [`FormatError::Io`] through
/// the `#[from]` attribute, so `?` can be used directly on raw I/O calls.
#[derive(Debug, Error)]
pub enum FormatError {
    /// The reader ran out of bytes before a read of a known size completed.
    #[error("unexpected end of file")]
    UnexpectedEof,

    /// The first four bytes of the file did not match the expected magic.
    #[error("invalid magic bytes")]
    InvalidMagic,

    /// The header's version byte was zero or newer than [`FORMAT_VERSION`].
    /// Carries the version byte that was read.
    #[error("unsupported format version: {0}")]
    UnsupportedVersion(u8),

    /// A computed CRC32 did not match the stored one. `expected` is the
    /// stored checksum, `actual` is the one computed from the data.
    #[error("crc32 mismatch (expected {expected:#010x}, got {actual:#010x})")]
    ChecksumMismatch { expected: u32, actual: u32 },

    /// A record tag byte was not recognised. Not produced by the helpers in
    /// this file; presumably raised by the record decoders that use them.
    #[error("unknown record tag: {0}")]
    UnknownTag(u8),

    /// Any other I/O failure. `read_bytes` also uses this variant (with
    /// `ErrorKind::InvalidData`) for a length prefix above `MAX_FIELD_LEN`.
    #[error("io error: {0}")]
    Io(#[from] io::Error),
}
44
45/// Computes a CRC32 checksum over a byte slice.
46pub fn crc32(data: &[u8]) -> u32 {
47    let mut h = Hasher::new();
48    h.update(data);
49    h.finalize()
50}
51
52// ---------------------------------------------------------------------------
53// write helpers
54// ---------------------------------------------------------------------------
55
/// Emits a single byte to `w`.
///
/// # Errors
///
/// Propagates any error reported by the underlying writer.
pub fn write_u8(w: &mut impl Write, val: u8) -> io::Result<()> {
    // View the byte as a one-element slice instead of building a temp array.
    w.write_all(std::slice::from_ref(&val))
}
60
/// Emits `val` to `w` as two bytes, least-significant byte first.
///
/// # Errors
///
/// Propagates any error reported by the underlying writer.
pub fn write_u16(w: &mut impl Write, val: u16) -> io::Result<()> {
    let encoded = val.to_le_bytes();
    w.write_all(&encoded)
}
65
/// Emits `val` to `w` as four bytes, least-significant byte first.
///
/// # Errors
///
/// Propagates any error reported by the underlying writer.
pub fn write_u32(w: &mut impl Write, val: u32) -> io::Result<()> {
    let encoded = val.to_le_bytes();
    w.write_all(&encoded)
}
70
/// Emits `val` to `w` as eight bytes (two's complement), least-significant
/// byte first.
///
/// # Errors
///
/// Propagates any error reported by the underlying writer.
pub fn write_i64(w: &mut impl Write, val: i64) -> io::Result<()> {
    let encoded = val.to_le_bytes();
    w.write_all(&encoded)
}
75
/// Emits `val` to `w` as the eight bytes of its IEEE-754 bit pattern,
/// least-significant byte first.
///
/// # Errors
///
/// Propagates any error reported by the underlying writer.
pub fn write_f64(w: &mut impl Write, val: f64) -> io::Result<()> {
    // Going through the raw bit pattern yields the same bytes as
    // `f64::to_le_bytes`.
    let encoded = val.to_bits().to_le_bytes();
    w.write_all(&encoded)
}
80
/// Writes a length-prefixed byte slice: `[len: u32][data]`.
///
/// The length prefix is stored in little-endian, like every other integer
/// in the format.
///
/// # Errors
///
/// Returns an `io::ErrorKind::InvalidInput` error, without writing anything,
/// if `data` is longer than `u32::MAX` bytes. A plain cast would silently
/// truncate the prefix and corrupt the record. Any error from the
/// underlying writer is propagated.
pub fn write_bytes(w: &mut impl Write, data: &[u8]) -> io::Result<()> {
    let len = u32::try_from(data.len()).map_err(|_| {
        io::Error::new(
            io::ErrorKind::InvalidInput,
            format!(
                "field length {} does not fit in a u32 length prefix",
                data.len()
            ),
        )
    })?;
    // Same encoding as `write_u32`: four little-endian bytes.
    w.write_all(&len.to_le_bytes())?;
    w.write_all(data)
}
86
87// ---------------------------------------------------------------------------
88// read helpers
89// ---------------------------------------------------------------------------
90
91/// Reads a `u8` from the reader.
92pub fn read_u8(r: &mut impl Read) -> Result<u8, FormatError> {
93    let mut buf = [0u8; 1];
94    read_exact(r, &mut buf)?;
95    Ok(buf[0])
96}
97
98/// Reads a `u16` in little-endian.
99pub fn read_u16(r: &mut impl Read) -> Result<u16, FormatError> {
100    let mut buf = [0u8; 2];
101    read_exact(r, &mut buf)?;
102    Ok(u16::from_le_bytes(buf))
103}
104
105/// Reads a `u32` in little-endian.
106pub fn read_u32(r: &mut impl Read) -> Result<u32, FormatError> {
107    let mut buf = [0u8; 4];
108    read_exact(r, &mut buf)?;
109    Ok(u32::from_le_bytes(buf))
110}
111
112/// Reads an `i64` in little-endian.
113pub fn read_i64(r: &mut impl Read) -> Result<i64, FormatError> {
114    let mut buf = [0u8; 8];
115    read_exact(r, &mut buf)?;
116    Ok(i64::from_le_bytes(buf))
117}
118
119/// Reads an `f64` in little-endian.
120pub fn read_f64(r: &mut impl Read) -> Result<f64, FormatError> {
121    let mut buf = [0u8; 8];
122    read_exact(r, &mut buf)?;
123    Ok(f64::from_le_bytes(buf))
124}
125
/// Maximum length we'll allocate when reading a length-prefixed field.
///
/// 512 MiB (536,870,912 bytes) is generous for any realistic key or value,
/// and the cap means a corrupt or malicious length prefix won't cause a
/// multi-gigabyte allocation. [`read_bytes`] rejects a declared length
/// strictly greater than this value; a length equal to it is accepted.
pub const MAX_FIELD_LEN: usize = 512 * 1024 * 1024;
130
131/// Reads a length-prefixed byte vector: `[len: u32][data]`.
132///
133/// Returns an error if the declared length exceeds [`MAX_FIELD_LEN`]
134/// to prevent unbounded allocations from corrupt data.
135pub fn read_bytes(r: &mut impl Read) -> Result<Vec<u8>, FormatError> {
136    let len = read_u32(r)? as usize;
137    if len > MAX_FIELD_LEN {
138        return Err(FormatError::Io(io::Error::new(
139            io::ErrorKind::InvalidData,
140            format!("field length {len} exceeds maximum of {MAX_FIELD_LEN}"),
141        )));
142    }
143    let mut buf = vec![0u8; len];
144    read_exact(r, &mut buf)?;
145    Ok(buf)
146}
147
148/// Reads exactly `buf.len()` bytes, returning `UnexpectedEof` on short read.
149fn read_exact(r: &mut impl Read, buf: &mut [u8]) -> Result<(), FormatError> {
150    r.read_exact(buf).map_err(|e| {
151        if e.kind() == io::ErrorKind::UnexpectedEof {
152            FormatError::UnexpectedEof
153        } else {
154            FormatError::Io(e)
155        }
156    })
157}
158
159/// Writes a file header: magic bytes + version byte.
160pub fn write_header(w: &mut impl Write, magic: &[u8; 4]) -> io::Result<()> {
161    w.write_all(magic)?;
162    write_u8(w, FORMAT_VERSION)
163}
164
165/// Reads and validates a file header. Returns an error if magic doesn't
166/// match or version is unsupported. Returns the format version.
167pub fn read_header(r: &mut impl Read, expected_magic: &[u8; 4]) -> Result<u8, FormatError> {
168    let mut magic = [0u8; 4];
169    read_exact(r, &mut magic)?;
170    if &magic != expected_magic {
171        return Err(FormatError::InvalidMagic);
172    }
173    let version = read_u8(r)?;
174    if version == 0 || version > FORMAT_VERSION {
175        return Err(FormatError::UnsupportedVersion(version));
176    }
177    Ok(version)
178}
179
180/// Verifies that `data` matches the expected CRC32 checksum.
181pub fn verify_crc32(data: &[u8], expected: u32) -> Result<(), FormatError> {
182    let actual = crc32(data);
183    verify_crc32_values(actual, expected)
184}
185
186/// Verifies that two CRC32 values match.
187pub fn verify_crc32_values(computed: u32, stored: u32) -> Result<(), FormatError> {
188    if computed != stored {
189        return Err(FormatError::ChecksumMismatch {
190            expected: stored,
191            actual: computed,
192        });
193    }
194    Ok(())
195}
196
/// Unit tests for the encoding primitives, header handling, and checksum
/// helpers. Every test runs against in-memory buffers.
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;

    #[test]
    fn u8_round_trip() {
        let mut buf = Vec::new();
        write_u8(&mut buf, 42).unwrap();
        assert_eq!(read_u8(&mut Cursor::new(&buf)).unwrap(), 42);
    }

    #[test]
    fn u16_round_trip() {
        let mut buf = Vec::new();
        write_u16(&mut buf, 12345).unwrap();
        assert_eq!(read_u16(&mut Cursor::new(&buf)).unwrap(), 12345);
    }

    #[test]
    fn u32_round_trip() {
        let mut buf = Vec::new();
        write_u32(&mut buf, 0xDEAD_BEEF).unwrap();
        assert_eq!(read_u32(&mut Cursor::new(&buf)).unwrap(), 0xDEAD_BEEF);
    }

    #[test]
    fn i64_round_trip() {
        let mut buf = Vec::new();
        write_i64(&mut buf, -1).unwrap();
        assert_eq!(read_i64(&mut Cursor::new(&buf)).unwrap(), -1);

        let mut buf2 = Vec::new();
        write_i64(&mut buf2, i64::MAX).unwrap();
        assert_eq!(read_i64(&mut Cursor::new(&buf2)).unwrap(), i64::MAX);
    }

    #[test]
    fn f64_round_trip() {
        for val in [0.0, -1.5, f64::MAX, f64::MIN_POSITIVE, f64::INFINITY] {
            let mut buf = Vec::new();
            write_f64(&mut buf, val).unwrap();
            assert_eq!(buf.len(), 8);
            assert_eq!(read_f64(&mut Cursor::new(&buf)).unwrap(), val);
        }

        // NaN never compares equal to itself, so check it separately.
        let mut buf = Vec::new();
        write_f64(&mut buf, f64::NAN).unwrap();
        assert!(read_f64(&mut Cursor::new(&buf)).unwrap().is_nan());
    }

    #[test]
    fn integers_are_little_endian() {
        let mut buf = Vec::new();
        write_u16(&mut buf, 0x1234).unwrap();
        write_u32(&mut buf, 0x0102_0304).unwrap();
        assert_eq!(buf, [0x34, 0x12, 0x04, 0x03, 0x02, 0x01]);
    }

    #[test]
    fn bytes_round_trip() {
        let mut buf = Vec::new();
        write_bytes(&mut buf, b"hello world").unwrap();
        assert_eq!(read_bytes(&mut Cursor::new(&buf)).unwrap(), b"hello world");
    }

    #[test]
    fn empty_bytes_round_trip() {
        let mut buf = Vec::new();
        write_bytes(&mut buf, b"").unwrap();
        assert_eq!(read_bytes(&mut Cursor::new(&buf)).unwrap(), b"");
    }

    #[test]
    fn header_round_trip() {
        let mut buf = Vec::new();
        write_header(&mut buf, AOF_MAGIC).unwrap();
        let version = read_header(&mut Cursor::new(&buf), AOF_MAGIC).unwrap();
        assert_eq!(version, FORMAT_VERSION);
    }

    #[test]
    fn header_wrong_magic() {
        let mut buf = Vec::new();
        write_header(&mut buf, AOF_MAGIC).unwrap();
        let err = read_header(&mut Cursor::new(&buf), SNAP_MAGIC).unwrap_err();
        assert!(matches!(err, FormatError::InvalidMagic));
    }

    #[test]
    fn header_wrong_version() {
        let buf = vec![b'E', b'A', b'O', b'F', 99];
        let err = read_header(&mut Cursor::new(&buf), AOF_MAGIC).unwrap_err();
        assert!(matches!(err, FormatError::UnsupportedVersion(99)));
    }

    #[test]
    fn header_rejects_version_zero() {
        let buf = vec![b'E', b'A', b'O', b'F', 0];
        let err = read_header(&mut Cursor::new(&buf), AOF_MAGIC).unwrap_err();
        assert!(matches!(err, FormatError::UnsupportedVersion(0)));
    }

    #[test]
    fn header_accepts_older_version() {
        let buf = vec![b'E', b'S', b'N', b'P', 1];
        let version = read_header(&mut Cursor::new(&buf), SNAP_MAGIC).unwrap();
        assert_eq!(version, 1);
    }

    #[test]
    fn header_truncated_returns_eof() {
        // magic present, version byte missing
        let buf = vec![b'E', b'A', b'O', b'F'];
        let err = read_header(&mut Cursor::new(&buf), AOF_MAGIC).unwrap_err();
        assert!(matches!(err, FormatError::UnexpectedEof));
    }

    #[test]
    fn crc32_deterministic() {
        let a = crc32(b"test data");
        let b = crc32(b"test data");
        assert_eq!(a, b);
        assert_ne!(a, crc32(b"different data"));
    }

    #[test]
    fn verify_crc32_pass() {
        let data = b"check me";
        let checksum = crc32(data);
        verify_crc32(data, checksum).unwrap();
    }

    #[test]
    fn verify_crc32_fail() {
        let err = verify_crc32(b"data", 0xBAD).unwrap_err();
        match err {
            FormatError::ChecksumMismatch { expected, actual } => {
                // `expected` is the stored value, `actual` the computed one.
                assert_eq!(expected, 0xBAD);
                assert_eq!(actual, crc32(b"data"));
            }
            other => panic!("expected ChecksumMismatch, got {other:?}"),
        }
    }

    #[test]
    fn truncated_input_returns_eof() {
        let buf = [0u8; 2]; // too short for u32
        let err = read_u32(&mut Cursor::new(&buf)).unwrap_err();
        assert!(matches!(err, FormatError::UnexpectedEof));
    }

    #[test]
    fn empty_input_returns_eof() {
        let err = read_u8(&mut Cursor::new(&[])).unwrap_err();
        assert!(matches!(err, FormatError::UnexpectedEof));
    }

    #[test]
    fn read_bytes_truncated_payload_returns_eof() {
        // declare 10 bytes but supply only 3
        let mut buf = Vec::new();
        write_u32(&mut buf, 10).unwrap();
        buf.extend_from_slice(b"abc");
        let err = read_bytes(&mut Cursor::new(&buf)).unwrap_err();
        assert!(matches!(err, FormatError::UnexpectedEof));
    }

    #[test]
    fn read_bytes_rejects_oversized_length() {
        // encode a length that exceeds MAX_FIELD_LEN
        let bogus_len = (MAX_FIELD_LEN as u32) + 1;
        let mut buf = Vec::new();
        write_u32(&mut buf, bogus_len).unwrap();
        let err = read_bytes(&mut Cursor::new(&buf)).unwrap_err();
        assert!(matches!(err, FormatError::Io(_)));
    }
}
309        write_u32(&mut buf, bogus_len).unwrap();
310        let err = read_bytes(&mut Cursor::new(&buf)).unwrap_err();
311        assert!(matches!(err, FormatError::Io(_)));
312    }
313}