Skip to main content

ember_persistence/
format.rs

1//! Binary format helpers shared across AOF and snapshot files.
2//!
3//! Provides TLV-style encoding primitives, CRC32 checksums, and magic
4//! byte constants. All multi-byte integers are stored in little-endian.
5
6use std::io::{self, Read, Write};
7
8use crc32fast::Hasher;
9use thiserror::Error;
10
/// Four-byte signature that opens the header of an AOF file.
pub const AOF_MAGIC: &[u8; 4] = b"EAOF";

/// Four-byte signature that opens the header of a snapshot file.
pub const SNAP_MAGIC: &[u8; 4] = b"ESNP";
16
/// Version byte used for unencrypted files.
///
/// Version history:
/// - v1: original format (strings only)
/// - v2: type-tagged entries (string, list, sorted set, hash, set)
pub const FORMAT_VERSION: u8 = 2;

/// Version byte used for encrypted files.
///
/// - v3: per-record AES-256-GCM encryption (requires the `encryption` feature)
pub const FORMAT_VERSION_ENCRYPTED: u8 = 3;
27
28/// Errors that can occur when reading or writing persistence formats.
29#[derive(Debug, Error)]
30pub enum FormatError {
31    #[error("unexpected end of file")]
32    UnexpectedEof,
33
34    #[error("invalid magic bytes")]
35    InvalidMagic,
36
37    #[error("unsupported format version: {0}")]
38    UnsupportedVersion(u8),
39
40    #[error("crc32 mismatch (expected {expected:#010x}, got {actual:#010x})")]
41    ChecksumMismatch { expected: u32, actual: u32 },
42
43    #[error("unknown record tag: {0}")]
44    UnknownTag(u8),
45
46    #[error("invalid data: {0}")]
47    InvalidData(String),
48
49    #[error("file is encrypted but no encryption key was provided")]
50    EncryptionRequired,
51
52    #[error("decryption failed (wrong key or tampered data)")]
53    DecryptionFailed,
54
55    #[error("io error: {0}")]
56    Io(#[from] io::Error),
57}
58
59/// Computes a CRC32 checksum over a byte slice.
60pub fn crc32(data: &[u8]) -> u32 {
61    let mut h = Hasher::new();
62    h.update(data);
63    h.finalize()
64}
65
66// ---------------------------------------------------------------------------
67// write helpers
68// ---------------------------------------------------------------------------
69
/// Emits a single byte to `w`.
///
/// # Errors
///
/// Propagates any error returned by the writer.
pub fn write_u8(w: &mut impl Write, val: u8) -> io::Result<()> {
    // View the byte as a one-element slice so it can go through `write_all`.
    w.write_all(std::slice::from_ref(&val))
}
74
/// Emits `val` as two bytes, least-significant byte first.
///
/// # Errors
///
/// Propagates any error returned by the writer.
pub fn write_u16(w: &mut impl Write, val: u16) -> io::Result<()> {
    let encoded = val.to_le_bytes();
    w.write_all(&encoded)
}
79
/// Emits `val` as four bytes, least-significant byte first.
///
/// # Errors
///
/// Propagates any error returned by the writer.
pub fn write_u32(w: &mut impl Write, val: u32) -> io::Result<()> {
    let encoded: [u8; 4] = u32::to_le_bytes(val);
    w.write_all(&encoded[..])
}
84
/// Emits `val` as eight bytes in two's-complement, least-significant byte first.
///
/// # Errors
///
/// Propagates any error returned by the writer.
pub fn write_i64(w: &mut impl Write, val: i64) -> io::Result<()> {
    let encoded: [u8; 8] = i64::to_le_bytes(val);
    w.write_all(&encoded)
}
89
/// Emits the IEEE-754 bit pattern of `val` as four little-endian bytes.
///
/// # Errors
///
/// Propagates any error returned by the writer.
pub fn write_f32(w: &mut impl Write, val: f32) -> io::Result<()> {
    // Serialising the raw bits yields the same bytes as `f32::to_le_bytes`.
    let bits: u32 = val.to_bits();
    w.write_all(&bits.to_le_bytes())
}
94
/// Emits the IEEE-754 bit pattern of `val` as eight little-endian bytes.
///
/// # Errors
///
/// Propagates any error returned by the writer.
pub fn write_f64(w: &mut impl Write, val: f64) -> io::Result<()> {
    // Serialising the raw bits yields the same bytes as `f64::to_le_bytes`.
    let bits: u64 = val.to_bits();
    w.write_all(&bits.to_le_bytes())
}
99
100/// Writes a collection length as u32, returning an error if it exceeds `u32::MAX`.
101pub fn write_len(w: &mut impl Write, len: usize) -> io::Result<()> {
102    let len = u32::try_from(len).map_err(|_| {
103        io::Error::new(
104            io::ErrorKind::InvalidInput,
105            format!("collection length {len} exceeds u32::MAX"),
106        )
107    })?;
108    write_u32(w, len)
109}
110
111/// Writes a length-prefixed byte slice: `[len: u32][data]`.
112///
113/// Returns an error if the data length exceeds `u32::MAX`.
114pub fn write_bytes(w: &mut impl Write, data: &[u8]) -> io::Result<()> {
115    let len = u32::try_from(data.len()).map_err(|_| {
116        io::Error::new(
117            io::ErrorKind::InvalidInput,
118            format!("data length {} exceeds u32::MAX", data.len()),
119        )
120    })?;
121    write_u32(w, len)?;
122    w.write_all(data)
123}
124
125// ---------------------------------------------------------------------------
126// read helpers
127// ---------------------------------------------------------------------------
128
129/// Reads a `u8` from the reader.
130pub fn read_u8(r: &mut impl Read) -> Result<u8, FormatError> {
131    let mut buf = [0u8; 1];
132    read_exact(r, &mut buf)?;
133    Ok(buf[0])
134}
135
136/// Reads a `u16` in little-endian.
137pub fn read_u16(r: &mut impl Read) -> Result<u16, FormatError> {
138    let mut buf = [0u8; 2];
139    read_exact(r, &mut buf)?;
140    Ok(u16::from_le_bytes(buf))
141}
142
143/// Reads a `u32` in little-endian.
144pub fn read_u32(r: &mut impl Read) -> Result<u32, FormatError> {
145    let mut buf = [0u8; 4];
146    read_exact(r, &mut buf)?;
147    Ok(u32::from_le_bytes(buf))
148}
149
150/// Reads an `i64` in little-endian.
151pub fn read_i64(r: &mut impl Read) -> Result<i64, FormatError> {
152    let mut buf = [0u8; 8];
153    read_exact(r, &mut buf)?;
154    Ok(i64::from_le_bytes(buf))
155}
156
157/// Reads an `f32` in little-endian.
158pub fn read_f32(r: &mut impl Read) -> Result<f32, FormatError> {
159    let mut buf = [0u8; 4];
160    read_exact(r, &mut buf)?;
161    Ok(f32::from_le_bytes(buf))
162}
163
164/// Reads an `f64` in little-endian.
165pub fn read_f64(r: &mut impl Read) -> Result<f64, FormatError> {
166    let mut buf = [0u8; 8];
167    read_exact(r, &mut buf)?;
168    Ok(f64::from_le_bytes(buf))
169}
170
/// Upper bound, in bytes, on the buffer allocated for one length-prefixed
/// field (512 MiB).
///
/// Generous for any realistic key or value, while ensuring a corrupt or
/// malicious length prefix cannot trigger a multi-gigabyte allocation.
pub const MAX_FIELD_LEN: usize = 512 * 1024 * 1024;
175
176/// Reads a length-prefixed byte vector: `[len: u32][data]`.
177///
178/// Returns an error if the declared length exceeds [`MAX_FIELD_LEN`]
179/// to prevent unbounded allocations from corrupt data.
180pub fn read_bytes(r: &mut impl Read) -> Result<Vec<u8>, FormatError> {
181    let len = read_u32(r)? as usize;
182    if len > MAX_FIELD_LEN {
183        return Err(FormatError::Io(io::Error::new(
184            io::ErrorKind::InvalidData,
185            format!("field length {len} exceeds maximum of {MAX_FIELD_LEN}"),
186        )));
187    }
188    let mut buf = vec![0u8; len];
189    read_exact(r, &mut buf)?;
190    Ok(buf)
191}
192
193/// Reads exactly `buf.len()` bytes, returning `UnexpectedEof` on short read.
194fn read_exact(r: &mut impl Read, buf: &mut [u8]) -> Result<(), FormatError> {
195    r.read_exact(buf).map_err(|e| {
196        if e.kind() == io::ErrorKind::UnexpectedEof {
197            FormatError::UnexpectedEof
198        } else {
199            FormatError::Io(e)
200        }
201    })
202}
203
204/// Writes a file header: magic bytes + version byte.
205pub fn write_header(w: &mut impl Write, magic: &[u8; 4]) -> io::Result<()> {
206    w.write_all(magic)?;
207    write_u8(w, FORMAT_VERSION)
208}
209
210/// Writes a file header with an explicit version byte.
211pub fn write_header_versioned(w: &mut impl Write, magic: &[u8; 4], version: u8) -> io::Result<()> {
212    w.write_all(magic)?;
213    write_u8(w, version)
214}
215
216/// The maximum format version this build can read.
217///
218/// When the `encryption` feature is compiled in, v3 (encrypted) files
219/// are supported. Without the feature, only v1 and v2 are accepted.
220#[cfg(feature = "encryption")]
221const MAX_READABLE_VERSION: u8 = FORMAT_VERSION_ENCRYPTED;
222#[cfg(not(feature = "encryption"))]
223const MAX_READABLE_VERSION: u8 = FORMAT_VERSION;
224
225/// Reads and validates a file header. Returns an error if magic doesn't
226/// match or version is unsupported. Returns the format version.
227pub fn read_header(r: &mut impl Read, expected_magic: &[u8; 4]) -> Result<u8, FormatError> {
228    let mut magic = [0u8; 4];
229    read_exact(r, &mut magic)?;
230    if &magic != expected_magic {
231        return Err(FormatError::InvalidMagic);
232    }
233    let version = read_u8(r)?;
234    if version == 0 || version > MAX_READABLE_VERSION {
235        return Err(FormatError::UnsupportedVersion(version));
236    }
237    Ok(version)
238}
239
240/// Verifies that `data` matches the expected CRC32 checksum.
241pub fn verify_crc32(data: &[u8], expected: u32) -> Result<(), FormatError> {
242    let actual = crc32(data);
243    verify_crc32_values(actual, expected)
244}
245
/// Clamps a deserialized element count to a safe pre-allocation size.
///
/// The result is meant for an up-front `Vec::with_capacity` reservation;
/// the caller's loop still runs `count` times. Clamping to 65,536 keeps a
/// bogus `u32` count (up to ~4 billion) from reserving gigabytes before
/// any element has been read. The cap is large enough that realistic
/// collections never re-allocate during deserialization, yet a corrupt
/// count wastes only about 1 MB (65k × 16 bytes per element).
pub fn capped_capacity(count: u32) -> usize {
    const PREALLOC_LIMIT: u32 = 65_536;
    count.min(PREALLOC_LIMIT) as usize
}
257
/// Largest element count accepted for a persisted collection (list, set,
/// hash, sorted set).
///
/// Guards deserialization against corrupt count fields that would
/// otherwise drive unbounded iteration. 100M is far beyond any realistic
/// collection while still rejecting obviously corrupt `u32` values.
pub const MAX_COLLECTION_COUNT: u32 = 100_000_000;
263
264/// Validates that a deserialized collection count is within bounds.
265/// Returns `InvalidData` if the count exceeds `MAX_COLLECTION_COUNT`.
266pub fn validate_collection_count(count: u32, label: &str) -> Result<(), FormatError> {
267    if count > MAX_COLLECTION_COUNT {
268        return Err(FormatError::InvalidData(format!(
269            "{label} count {count} exceeds max {MAX_COLLECTION_COUNT}"
270        )));
271    }
272    Ok(())
273}
274
/// Largest vector dimensionality accepted from a persisted file.
///
/// Intended to mirror the protocol-layer cap (defined outside this file).
/// Records above this limit are rejected during deserialization so a
/// corrupt file cannot cause an out-of-memory condition.
pub const MAX_PERSISTED_VECTOR_DIMS: u32 = 65_536;

/// Largest number of elements accepted per persisted vector set.
///
/// Keeps a corrupt count field from driving an unbounded loop.
pub const MAX_PERSISTED_VECTOR_COUNT: u32 = 10_000_000;

/// Budget for the total number of `f32` values (dim * count) in one
/// vector deserialization.
///
/// At four bytes per float this caps the total allocation at roughly 4 GB.
/// Without it, a crafted file with 65536 dims x 10M vectors would attempt
/// about 2.6 TB.
pub const MAX_PERSISTED_VECTOR_TOTAL_FLOATS: u64 = 1_000_000_000;
288
289/// Validates that the total vector element budget (dim * count) is within
290/// bounds. Call after validating dim and count individually.
291pub fn validate_vector_total(dim: u32, count: u32) -> Result<(), FormatError> {
292    let total = dim as u64 * count as u64;
293    if total > MAX_PERSISTED_VECTOR_TOTAL_FLOATS {
294        return Err(FormatError::InvalidData(format!(
295            "vector total elements ({dim} dims x {count} vectors = {total}) \
296             exceeds max {MAX_PERSISTED_VECTOR_TOTAL_FLOATS}"
297        )));
298    }
299    Ok(())
300}
301
302/// Verifies that two CRC32 values match.
303pub fn verify_crc32_values(computed: u32, stored: u32) -> Result<(), FormatError> {
304    if computed != stored {
305        return Err(FormatError::ChecksumMismatch {
306            expected: stored,
307            actual: computed,
308        });
309    }
310    Ok(())
311}
312
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;

    // -----------------------------------------------------------------
    // fixed-width primitives
    // -----------------------------------------------------------------

    #[test]
    fn u8_round_trip() {
        let mut buf = Vec::new();
        write_u8(&mut buf, 42).unwrap();
        assert_eq!(read_u8(&mut Cursor::new(&buf)).unwrap(), 42);
    }

    #[test]
    fn u16_round_trip() {
        let mut buf = Vec::new();
        write_u16(&mut buf, 12345).unwrap();
        assert_eq!(read_u16(&mut Cursor::new(&buf)).unwrap(), 12345);
    }

    #[test]
    fn u32_round_trip() {
        let mut buf = Vec::new();
        write_u32(&mut buf, 0xDEAD_BEEF).unwrap();
        assert_eq!(read_u32(&mut Cursor::new(&buf)).unwrap(), 0xDEAD_BEEF);
    }

    #[test]
    fn i64_round_trip() {
        let mut buf = Vec::new();
        write_i64(&mut buf, -1).unwrap();
        assert_eq!(read_i64(&mut Cursor::new(&buf)).unwrap(), -1);

        let mut buf2 = Vec::new();
        write_i64(&mut buf2, i64::MAX).unwrap();
        assert_eq!(read_i64(&mut Cursor::new(&buf2)).unwrap(), i64::MAX);

        let mut buf3 = Vec::new();
        write_i64(&mut buf3, i64::MIN).unwrap();
        assert_eq!(read_i64(&mut Cursor::new(&buf3)).unwrap(), i64::MIN);
    }

    #[test]
    fn f32_round_trip() {
        // Compare bit patterns so that -0.0 and +0.0 are distinguished.
        for val in [0.0f32, -0.0, 1.5, f32::MAX, f32::INFINITY] {
            let mut buf: Vec<u8> = Vec::new();
            write_f32(&mut buf, val).unwrap();
            assert_eq!(buf.len(), 4);
            let got = read_f32(&mut Cursor::new(&buf)).unwrap();
            assert_eq!(got.to_bits(), val.to_bits());
        }
    }

    #[test]
    fn f64_round_trip() {
        for val in [0.0f64, -0.0, -0.25, f64::MIN, f64::NEG_INFINITY] {
            let mut buf: Vec<u8> = Vec::new();
            write_f64(&mut buf, val).unwrap();
            assert_eq!(buf.len(), 8);
            let got = read_f64(&mut Cursor::new(&buf)).unwrap();
            assert_eq!(got.to_bits(), val.to_bits());
        }
    }

    #[test]
    fn integers_are_encoded_little_endian() {
        // The on-disk byte order is part of the format contract, so pin it
        // independently of the read helpers.
        let mut buf: Vec<u8> = Vec::new();
        write_u16(&mut buf, 0x0102).unwrap();
        assert_eq!(buf, vec![0x02u8, 0x01]);

        let mut buf: Vec<u8> = Vec::new();
        write_u32(&mut buf, 0x0102_0304).unwrap();
        assert_eq!(buf, vec![0x04u8, 0x03, 0x02, 0x01]);

        let mut buf: Vec<u8> = Vec::new();
        write_i64(&mut buf, 1).unwrap();
        assert_eq!(buf, vec![1u8, 0, 0, 0, 0, 0, 0, 0]);
    }

    // -----------------------------------------------------------------
    // length-prefixed fields
    // -----------------------------------------------------------------

    #[test]
    fn write_len_encodes_u32_prefix() {
        let mut buf: Vec<u8> = Vec::new();
        write_len(&mut buf, 7).unwrap();
        assert_eq!(buf, vec![7u8, 0, 0, 0]);
        assert_eq!(read_u32(&mut Cursor::new(&buf)).unwrap(), 7);
    }

    // `usize` can only exceed `u32::MAX` on 64-bit targets.
    #[cfg(target_pointer_width = "64")]
    #[test]
    fn write_len_rejects_lengths_beyond_u32() {
        let mut buf: Vec<u8> = Vec::new();
        let err = write_len(&mut buf, u32::MAX as usize + 1).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidInput);
        // Nothing may be written when the length is rejected.
        assert!(buf.is_empty());
    }

    #[test]
    fn bytes_round_trip() {
        let mut buf = Vec::new();
        write_bytes(&mut buf, b"hello world").unwrap();
        assert_eq!(read_bytes(&mut Cursor::new(&buf)).unwrap(), b"hello world");
    }

    #[test]
    fn empty_bytes_round_trip() {
        let mut buf = Vec::new();
        write_bytes(&mut buf, b"").unwrap();
        assert_eq!(read_bytes(&mut Cursor::new(&buf)).unwrap(), b"");
    }

    #[test]
    fn write_bytes_layout_is_len_then_payload() {
        let mut buf: Vec<u8> = Vec::new();
        write_bytes(&mut buf, b"ab").unwrap();
        assert_eq!(buf, vec![2u8, 0, 0, 0, b'a', b'b']);
    }

    #[test]
    fn read_bytes_truncated_payload_returns_eof() {
        // Declares five bytes but supplies only two.
        let mut buf: Vec<u8> = Vec::new();
        write_u32(&mut buf, 5).unwrap();
        buf.extend_from_slice(b"ab");
        let err = read_bytes(&mut Cursor::new(&buf)).unwrap_err();
        assert!(matches!(err, FormatError::UnexpectedEof));
    }

    #[test]
    fn read_bytes_rejects_oversized_length() {
        // encode a length that exceeds MAX_FIELD_LEN
        let bogus_len = (MAX_FIELD_LEN as u32) + 1;
        let mut buf = Vec::new();
        write_u32(&mut buf, bogus_len).unwrap();
        let err = read_bytes(&mut Cursor::new(&buf)).unwrap_err();
        assert!(matches!(
            err,
            FormatError::Io(ref e) if e.kind() == std::io::ErrorKind::InvalidData
        ));
    }

    #[test]
    fn truncated_input_returns_eof() {
        let buf = [0u8; 2]; // too short for u32
        let err = read_u32(&mut Cursor::new(&buf)).unwrap_err();
        assert!(matches!(err, FormatError::UnexpectedEof));
    }

    #[test]
    fn empty_input_returns_eof() {
        let err = read_u8(&mut Cursor::new(&[])).unwrap_err();
        assert!(matches!(err, FormatError::UnexpectedEof));
    }

    // -----------------------------------------------------------------
    // headers
    // -----------------------------------------------------------------

    #[test]
    fn header_round_trip() {
        let mut buf = Vec::new();
        write_header(&mut buf, AOF_MAGIC).unwrap();
        assert_eq!(buf.len(), 5);
        let version = read_header(&mut Cursor::new(&buf), AOF_MAGIC).unwrap();
        assert_eq!(version, FORMAT_VERSION);
    }

    #[test]
    fn versioned_header_round_trip() {
        let mut buf: Vec<u8> = Vec::new();
        write_header_versioned(&mut buf, SNAP_MAGIC, 1).unwrap();
        assert_eq!(buf, vec![b'E', b'S', b'N', b'P', 1u8]);
        let version = read_header(&mut Cursor::new(&buf), SNAP_MAGIC).unwrap();
        assert_eq!(version, 1);
    }

    #[test]
    fn header_wrong_magic() {
        let mut buf = Vec::new();
        write_header(&mut buf, AOF_MAGIC).unwrap();
        let err = read_header(&mut Cursor::new(&buf), SNAP_MAGIC).unwrap_err();
        assert!(matches!(err, FormatError::InvalidMagic));
    }

    #[test]
    fn header_wrong_version() {
        let buf = vec![b'E', b'A', b'O', b'F', 99];
        let err = read_header(&mut Cursor::new(&buf), AOF_MAGIC).unwrap_err();
        assert!(matches!(err, FormatError::UnsupportedVersion(99)));
    }

    #[test]
    fn header_version_zero_is_rejected() {
        let buf = vec![b'E', b'S', b'N', b'P', 0];
        let err = read_header(&mut Cursor::new(&buf), SNAP_MAGIC).unwrap_err();
        assert!(matches!(err, FormatError::UnsupportedVersion(0)));
    }

    #[test]
    fn header_version_just_above_max_is_rejected() {
        let too_new = MAX_READABLE_VERSION + 1;
        let buf = vec![b'E', b'A', b'O', b'F', too_new];
        let err = read_header(&mut Cursor::new(&buf), AOF_MAGIC).unwrap_err();
        assert!(matches!(err, FormatError::UnsupportedVersion(v) if v == too_new));
    }

    #[test]
    fn truncated_header_returns_eof() {
        // Magic cut short.
        let err = read_header(&mut Cursor::new(&b"EA"[..]), AOF_MAGIC).unwrap_err();
        assert!(matches!(err, FormatError::UnexpectedEof));

        // Magic complete but version byte missing.
        let err = read_header(&mut Cursor::new(&b"EAOF"[..]), AOF_MAGIC).unwrap_err();
        assert!(matches!(err, FormatError::UnexpectedEof));
    }

    // -----------------------------------------------------------------
    // checksums
    // -----------------------------------------------------------------

    #[test]
    fn crc32_deterministic() {
        let a = crc32(b"test data");
        let b = crc32(b"test data");
        assert_eq!(a, b);
        assert_ne!(a, crc32(b"different data"));
    }

    #[test]
    fn crc32_matches_standard_check_value() {
        // Standard CRC-32 (IEEE) check value for the ASCII digits 1-9.
        assert_eq!(crc32(b"123456789"), 0xCBF4_3926);
        assert_eq!(crc32(b""), 0);
    }

    #[test]
    fn verify_crc32_pass() {
        let data = b"check me";
        let checksum = crc32(data);
        verify_crc32(data, checksum).unwrap();
    }

    #[test]
    fn verify_crc32_fail() {
        let err = verify_crc32(b"data", 0xBAD).unwrap_err();
        assert!(matches!(err, FormatError::ChecksumMismatch { .. }));
    }

    #[test]
    fn verify_crc32_values_reports_stored_as_expected() {
        verify_crc32_values(7, 7).unwrap();
        let err = verify_crc32_values(1, 2).unwrap_err();
        assert!(matches!(
            err,
            FormatError::ChecksumMismatch {
                expected: 2,
                actual: 1
            }
        ));
    }

    // -----------------------------------------------------------------
    // bounds helpers
    // -----------------------------------------------------------------

    #[test]
    fn capped_capacity_clamps_large_counts() {
        assert_eq!(capped_capacity(0), 0);
        assert_eq!(capped_capacity(1_000), 1_000);
        assert_eq!(capped_capacity(65_536), 65_536);
        assert_eq!(capped_capacity(65_537), 65_536);
        assert_eq!(capped_capacity(u32::MAX), 65_536);
    }

    #[test]
    fn collection_count_boundary() {
        validate_collection_count(0, "list").unwrap();
        validate_collection_count(MAX_COLLECTION_COUNT, "list").unwrap();
        let err = validate_collection_count(MAX_COLLECTION_COUNT + 1, "list").unwrap_err();
        assert!(matches!(err, FormatError::InvalidData(_)));
    }

    #[test]
    fn vector_total_boundary() {
        validate_vector_total(0, 0).unwrap();
        // Exactly at the budget: 1_000 * 1_000_000 == 1_000_000_000.
        validate_vector_total(1_000, 1_000_000).unwrap();
        let err = validate_vector_total(1_000, 1_000_001).unwrap_err();
        assert!(matches!(err, FormatError::InvalidData(_)));
    }

    #[test]
    fn vector_total_does_not_overflow() {
        // Individually valid maxima still blow the combined budget.
        let err = validate_vector_total(MAX_PERSISTED_VECTOR_DIMS, MAX_PERSISTED_VECTOR_COUNT)
            .unwrap_err();
        assert!(matches!(err, FormatError::InvalidData(_)));

        // The product is computed in u64, so even u32::MAX squared is safe.
        let err = validate_vector_total(u32::MAX, u32::MAX).unwrap_err();
        assert!(matches!(err, FormatError::InvalidData(_)));
    }
}