1use std::io::{self, Read, Write};
7
8use crc32fast::Hasher;
9use thiserror::Error;
10
/// Four-byte magic that opens an AOF file (presumably the append-only log —
/// inferred from the name, confirm); [`read_header`] rejects any other prefix
/// when this is the expected magic.
pub const AOF_MAGIC: &[u8; 4] = &[b'E', b'A', b'O', b'F'];

/// Four-byte magic that opens a snapshot file (inferred from the name —
/// confirm); checked by [`read_header`] the same way as [`AOF_MAGIC`].
pub const SNAP_MAGIC: &[u8; 4] = &[b'E', b'S', b'N', b'P'];

/// Version byte stamped into every header produced by [`write_header`].
pub const FORMAT_VERSION: u8 = 2u8;

/// Version byte used for encrypted files. [`read_header`] only accepts it when
/// the crate is compiled with the `encryption` feature.
pub const FORMAT_VERSION_ENCRYPTED: u8 = 3u8;
27
/// Errors returned while reading or validating the on-disk formats handled by
/// this module.
///
/// A premature end of input is reported as [`FormatError::UnexpectedEof`],
/// separately from other I/O failures ([`FormatError::Io`]), so callers can
/// tell a truncated file apart from a failing reader.
#[derive(Debug, Error)]
pub enum FormatError {
    /// The input ended before a complete value could be read. The private
    /// `read_exact` helper maps `io::ErrorKind::UnexpectedEof` to this variant.
    #[error("unexpected end of file")]
    UnexpectedEof,

    /// The leading four bytes did not match the expected magic (raised by
    /// [`read_header`]).
    #[error("invalid magic bytes")]
    InvalidMagic,

    /// The header's version byte was 0 or newer than this build can read
    /// (raised by [`read_header`]).
    #[error("unsupported format version: {0}")]
    UnsupportedVersion(u8),

    /// A CRC-32 check failed. `expected` is the stored checksum and `actual`
    /// is the checksum computed over the bytes actually read (see
    /// [`verify_crc32_values`]).
    #[error("crc32 mismatch (expected {expected:#010x}, got {actual:#010x})")]
    ChecksumMismatch { expected: u32, actual: u32 },

    /// A record carried a tag byte the decoder does not recognise.
    /// NOTE(review): not constructed in this module — presumably raised by
    /// record decoders elsewhere; confirm.
    #[error("unknown record tag: {0}")]
    UnknownTag(u8),

    /// Input that was read successfully but violates a limit, e.g. a collection
    /// count or vector size above the configured maximum (see
    /// [`validate_collection_count`] and [`validate_vector_total`]).
    #[error("invalid data: {0}")]
    InvalidData(String),

    /// The file is marked as encrypted but the caller supplied no key.
    /// NOTE(review): not constructed in this module; confirm where it is raised.
    #[error("file is encrypted but no encryption key was provided")]
    EncryptionRequired,

    /// Decryption did not succeed; per the message, either the key is wrong or
    /// the data was modified.
    /// NOTE(review): not constructed in this module; confirm where it is raised.
    #[error("decryption failed (wrong key or tampered data)")]
    DecryptionFailed,

    /// Any other I/O failure. `#[from]` lets `?` convert an `io::Error`
    /// directly. [`read_bytes`] also uses this variant (with kind
    /// `InvalidData`) for an over-long length prefix.
    #[error("io error: {0}")]
    Io(#[from] io::Error),
}
58
59pub fn crc32(data: &[u8]) -> u32 {
61 let mut h = Hasher::new();
62 h.update(data);
63 h.finalize()
64}
65
/// Writes a single byte to `w`.
///
/// # Errors
/// Propagates any error from the underlying writer.
pub fn write_u8(w: &mut impl Write, val: u8) -> io::Result<()> {
    // View the value as a one-element slice rather than building an array.
    w.write_all(std::slice::from_ref(&val))
}
74
/// Writes `val` as two little-endian bytes.
///
/// # Errors
/// Propagates any error from the underlying writer.
pub fn write_u16(w: &mut impl Write, val: u16) -> io::Result<()> {
    w.write_all(&u16::to_le_bytes(val))
}
79
/// Writes `val` as four little-endian bytes.
///
/// # Errors
/// Propagates any error from the underlying writer.
pub fn write_u32(w: &mut impl Write, val: u32) -> io::Result<()> {
    let encoded: [u8; 4] = val.to_le_bytes();
    w.write_all(&encoded)
}
84
/// Writes `val` as eight little-endian (two's-complement) bytes.
///
/// # Errors
/// Propagates any error from the underlying writer.
pub fn write_i64(w: &mut impl Write, val: i64) -> io::Result<()> {
    let encoded: [u8; 8] = i64::to_le_bytes(val);
    w.write_all(&encoded)
}
89
/// Writes the IEEE-754 bit pattern of `val` as four little-endian bytes.
///
/// # Errors
/// Propagates any error from the underlying writer.
pub fn write_f32(w: &mut impl Write, val: f32) -> io::Result<()> {
    // Going through the raw bits is exactly what `f32::to_le_bytes` does.
    w.write_all(&val.to_bits().to_le_bytes())
}
94
/// Writes the IEEE-754 bit pattern of `val` as eight little-endian bytes.
///
/// # Errors
/// Propagates any error from the underlying writer.
pub fn write_f64(w: &mut impl Write, val: f64) -> io::Result<()> {
    // Going through the raw bits is exactly what `f64::to_le_bytes` does.
    w.write_all(&val.to_bits().to_le_bytes())
}
99
100pub fn write_len(w: &mut impl Write, len: usize) -> io::Result<()> {
102 let len = u32::try_from(len).map_err(|_| {
103 io::Error::new(
104 io::ErrorKind::InvalidInput,
105 format!("collection length {len} exceeds u32::MAX"),
106 )
107 })?;
108 write_u32(w, len)
109}
110
111pub fn write_bytes(w: &mut impl Write, data: &[u8]) -> io::Result<()> {
115 let len = u32::try_from(data.len()).map_err(|_| {
116 io::Error::new(
117 io::ErrorKind::InvalidInput,
118 format!("data length {} exceeds u32::MAX", data.len()),
119 )
120 })?;
121 write_u32(w, len)?;
122 w.write_all(data)
123}
124
125pub fn read_u8(r: &mut impl Read) -> Result<u8, FormatError> {
131 let mut buf = [0u8; 1];
132 read_exact(r, &mut buf)?;
133 Ok(buf[0])
134}
135
136pub fn read_u16(r: &mut impl Read) -> Result<u16, FormatError> {
138 let mut buf = [0u8; 2];
139 read_exact(r, &mut buf)?;
140 Ok(u16::from_le_bytes(buf))
141}
142
143pub fn read_u32(r: &mut impl Read) -> Result<u32, FormatError> {
145 let mut buf = [0u8; 4];
146 read_exact(r, &mut buf)?;
147 Ok(u32::from_le_bytes(buf))
148}
149
150pub fn read_i64(r: &mut impl Read) -> Result<i64, FormatError> {
152 let mut buf = [0u8; 8];
153 read_exact(r, &mut buf)?;
154 Ok(i64::from_le_bytes(buf))
155}
156
157pub fn read_f32(r: &mut impl Read) -> Result<f32, FormatError> {
159 let mut buf = [0u8; 4];
160 read_exact(r, &mut buf)?;
161 Ok(f32::from_le_bytes(buf))
162}
163
164pub fn read_f64(r: &mut impl Read) -> Result<f64, FormatError> {
166 let mut buf = [0u8; 8];
167 read_exact(r, &mut buf)?;
168 Ok(f64::from_le_bytes(buf))
169}
170
/// Largest payload, in bytes (512 MiB), that [`read_bytes`] accepts for a
/// single length-prefixed field; longer prefixes are rejected as corrupt.
pub const MAX_FIELD_LEN: usize = 512 << 20;
175
176pub fn read_bytes(r: &mut impl Read) -> Result<Vec<u8>, FormatError> {
181 let len = read_u32(r)? as usize;
182 if len > MAX_FIELD_LEN {
183 return Err(FormatError::Io(io::Error::new(
184 io::ErrorKind::InvalidData,
185 format!("field length {len} exceeds maximum of {MAX_FIELD_LEN}"),
186 )));
187 }
188 let mut buf = vec![0u8; len];
189 read_exact(r, &mut buf)?;
190 Ok(buf)
191}
192
193fn read_exact(r: &mut impl Read, buf: &mut [u8]) -> Result<(), FormatError> {
195 r.read_exact(buf).map_err(|e| {
196 if e.kind() == io::ErrorKind::UnexpectedEof {
197 FormatError::UnexpectedEof
198 } else {
199 FormatError::Io(e)
200 }
201 })
202}
203
204pub fn write_header(w: &mut impl Write, magic: &[u8; 4]) -> io::Result<()> {
206 w.write_all(magic)?;
207 write_u8(w, FORMAT_VERSION)
208}
209
210pub fn write_header_versioned(w: &mut impl Write, magic: &[u8; 4], version: u8) -> io::Result<()> {
212 w.write_all(magic)?;
213 write_u8(w, version)
214}
215
216#[cfg(feature = "encryption")]
221const MAX_READABLE_VERSION: u8 = FORMAT_VERSION_ENCRYPTED;
222#[cfg(not(feature = "encryption"))]
223const MAX_READABLE_VERSION: u8 = FORMAT_VERSION;
224
225pub fn read_header(r: &mut impl Read, expected_magic: &[u8; 4]) -> Result<u8, FormatError> {
228 let mut magic = [0u8; 4];
229 read_exact(r, &mut magic)?;
230 if &magic != expected_magic {
231 return Err(FormatError::InvalidMagic);
232 }
233 let version = read_u8(r)?;
234 if version == 0 || version > MAX_READABLE_VERSION {
235 return Err(FormatError::UnsupportedVersion(version));
236 }
237 Ok(version)
238}
239
240pub fn verify_crc32(data: &[u8], expected: u32) -> Result<(), FormatError> {
242 let actual = crc32(data);
243 verify_crc32_values(actual, expected)
244}
245
/// Clamps an element count taken from input to at most 65 536, so it can be
/// used as an initial capacity without trusting the count.
///
/// NOTE(review): presumably used to size `with_capacity` for collections read
/// from disk; callers still need their own validation of `count`.
pub fn capped_capacity(count: u32) -> usize {
    const CAP: usize = 65_536;
    let requested = count as usize;
    if requested > CAP {
        CAP
    } else {
        requested
    }
}
257
/// Largest element count that [`validate_collection_count`] accepts
/// (one hundred million).
pub const MAX_COLLECTION_COUNT: u32 = 100 * 1_000_000;
263
264pub fn validate_collection_count(count: u32, label: &str) -> Result<(), FormatError> {
267 if count > MAX_COLLECTION_COUNT {
268 return Err(FormatError::InvalidData(format!(
269 "{label} count {count} exceeds max {MAX_COLLECTION_COUNT}"
270 )));
271 }
272 Ok(())
273}
274
/// Largest vector dimensionality permitted in persisted data.
/// NOTE(review): not referenced in this chunk — presumably enforced by the
/// vector persistence code; confirm.
pub const MAX_PERSISTED_VECTOR_DIMS: u32 = 1 << 16;

/// Largest number of vectors permitted in persisted data.
/// NOTE(review): not referenced in this chunk — presumably enforced by the
/// vector persistence code; confirm.
pub const MAX_PERSISTED_VECTOR_COUNT: u32 = 10 * 1_000_000;

/// Upper bound on `dims * count` (total stored elements), enforced by
/// [`validate_vector_total`].
pub const MAX_PERSISTED_VECTOR_TOTAL_FLOATS: u64 = 1_000 * 1_000_000;
288
289pub fn validate_vector_total(dim: u32, count: u32) -> Result<(), FormatError> {
292 let total = dim as u64 * count as u64;
293 if total > MAX_PERSISTED_VECTOR_TOTAL_FLOATS {
294 return Err(FormatError::InvalidData(format!(
295 "vector total elements ({dim} dims x {count} vectors = {total}) \
296 exceeds max {MAX_PERSISTED_VECTOR_TOTAL_FLOATS}"
297 )));
298 }
299 Ok(())
300}
301
302pub fn verify_crc32_values(computed: u32, stored: u32) -> Result<(), FormatError> {
304 if computed != stored {
305 return Err(FormatError::ChecksumMismatch {
306 expected: stored,
307 actual: computed,
308 });
309 }
310 Ok(())
311}
312
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;

    #[test]
    fn u8_round_trip() {
        let mut buf = Vec::new();
        write_u8(&mut buf, 42).unwrap();
        assert_eq!(read_u8(&mut Cursor::new(&buf)).unwrap(), 42);
    }

    #[test]
    fn u16_round_trip() {
        let mut buf = Vec::new();
        write_u16(&mut buf, 12345).unwrap();
        assert_eq!(read_u16(&mut Cursor::new(&buf)).unwrap(), 12345);
    }

    #[test]
    fn u32_round_trip() {
        let mut buf = Vec::new();
        write_u32(&mut buf, 0xDEAD_BEEF).unwrap();
        assert_eq!(read_u32(&mut Cursor::new(&buf)).unwrap(), 0xDEAD_BEEF);
    }

    #[test]
    fn i64_round_trip() {
        let mut buf = Vec::new();
        write_i64(&mut buf, -1).unwrap();
        assert_eq!(read_i64(&mut Cursor::new(&buf)).unwrap(), -1);

        let mut buf2 = Vec::new();
        write_i64(&mut buf2, i64::MAX).unwrap();
        assert_eq!(read_i64(&mut Cursor::new(&buf2)).unwrap(), i64::MAX);
    }

    #[test]
    fn f32_round_trip() {
        let mut buf: Vec<u8> = Vec::new();
        write_f32(&mut buf, -1.5).unwrap();
        let back = read_f32(&mut Cursor::new(&buf)).unwrap();
        // Compare bit patterns so the check is exact (and lint-clean).
        assert_eq!(back.to_bits(), (-1.5f32).to_bits());
    }

    #[test]
    fn f64_round_trip() {
        let mut buf: Vec<u8> = Vec::new();
        write_f64(&mut buf, std::f64::consts::PI).unwrap();
        let back = read_f64(&mut Cursor::new(&buf)).unwrap();
        assert_eq!(back.to_bits(), std::f64::consts::PI.to_bits());
    }

    #[test]
    fn integers_are_little_endian() {
        let mut buf: Vec<u8> = Vec::new();
        write_u16(&mut buf, 0x0102).unwrap();
        write_u32(&mut buf, 0x0304_0506).unwrap();
        assert_eq!(buf, [0x02, 0x01, 0x06, 0x05, 0x04, 0x03]);
    }

    #[test]
    fn len_prefix_round_trip() {
        let mut buf: Vec<u8> = Vec::new();
        write_len(&mut buf, 7).unwrap();
        assert_eq!(read_u32(&mut Cursor::new(&buf)).unwrap(), 7);
    }

    #[cfg(target_pointer_width = "64")]
    #[test]
    fn write_len_rejects_lengths_above_u32_max() {
        let mut buf: Vec<u8> = Vec::new();
        let err = write_len(&mut buf, u32::MAX as usize + 1).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidInput);
        assert!(buf.is_empty());
    }

    #[test]
    fn bytes_round_trip() {
        let mut buf = Vec::new();
        write_bytes(&mut buf, b"hello world").unwrap();
        assert_eq!(read_bytes(&mut Cursor::new(&buf)).unwrap(), b"hello world");
    }

    #[test]
    fn empty_bytes_round_trip() {
        let mut buf = Vec::new();
        write_bytes(&mut buf, b"").unwrap();
        assert_eq!(read_bytes(&mut Cursor::new(&buf)).unwrap(), b"");
    }

    #[test]
    fn read_bytes_consumes_only_its_field() {
        let mut buf: Vec<u8> = Vec::new();
        write_bytes(&mut buf, b"ab").unwrap();
        write_u8(&mut buf, 0xFF).unwrap();
        let mut cur = Cursor::new(&buf);
        assert_eq!(read_bytes(&mut cur).unwrap(), b"ab");
        assert_eq!(read_u8(&mut cur).unwrap(), 0xFF);
    }

    #[test]
    fn read_bytes_truncated_payload_returns_eof() {
        let mut buf: Vec<u8> = Vec::new();
        write_u32(&mut buf, 10).unwrap();
        buf.extend_from_slice(b"abc");
        let err = read_bytes(&mut Cursor::new(&buf)).unwrap_err();
        assert!(matches!(err, FormatError::UnexpectedEof));
    }

    #[test]
    fn header_round_trip() {
        let mut buf = Vec::new();
        write_header(&mut buf, AOF_MAGIC).unwrap();
        let version = read_header(&mut Cursor::new(&buf), AOF_MAGIC).unwrap();
        assert_eq!(version, FORMAT_VERSION);
    }

    #[test]
    fn header_versioned_round_trip() {
        let mut buf: Vec<u8> = Vec::new();
        write_header_versioned(&mut buf, SNAP_MAGIC, 1).unwrap();
        assert_eq!(read_header(&mut Cursor::new(&buf), SNAP_MAGIC).unwrap(), 1);
    }

    #[test]
    fn header_wrong_magic() {
        let mut buf = Vec::new();
        write_header(&mut buf, AOF_MAGIC).unwrap();
        let err = read_header(&mut Cursor::new(&buf), SNAP_MAGIC).unwrap_err();
        assert!(matches!(err, FormatError::InvalidMagic));
    }

    #[test]
    fn header_wrong_version() {
        let buf = vec![b'E', b'A', b'O', b'F', 99];
        let err = read_header(&mut Cursor::new(&buf), AOF_MAGIC).unwrap_err();
        assert!(matches!(err, FormatError::UnsupportedVersion(99)));
    }

    #[test]
    fn header_version_zero_rejected() {
        let buf = vec![b'E', b'A', b'O', b'F', 0];
        let err = read_header(&mut Cursor::new(&buf), AOF_MAGIC).unwrap_err();
        assert!(matches!(err, FormatError::UnsupportedVersion(0)));
    }

    #[test]
    fn truncated_header_returns_eof() {
        let buf = [b'E', b'A'];
        let err = read_header(&mut Cursor::new(&buf), AOF_MAGIC).unwrap_err();
        assert!(matches!(err, FormatError::UnexpectedEof));
    }

    #[test]
    fn crc32_deterministic() {
        let a = crc32(b"test data");
        let b = crc32(b"test data");
        assert_eq!(a, b);
        assert_ne!(a, crc32(b"different data"));
    }

    #[test]
    fn crc32_check_value() {
        // Standard CRC-32 check value for the ASCII string "123456789".
        assert_eq!(crc32(b"123456789"), 0xCBF4_3926);
    }

    #[test]
    fn verify_crc32_pass() {
        let data = b"check me";
        let checksum = crc32(data);
        verify_crc32(data, checksum).unwrap();
    }

    #[test]
    fn verify_crc32_fail() {
        let err = verify_crc32(b"data", 0xBAD).unwrap_err();
        assert!(matches!(err, FormatError::ChecksumMismatch { .. }));
    }

    #[test]
    fn verify_crc32_values_reports_stored_and_computed() {
        let err = verify_crc32_values(1, 2).unwrap_err();
        assert!(matches!(
            err,
            FormatError::ChecksumMismatch {
                expected: 2,
                actual: 1
            }
        ));
    }

    #[test]
    fn truncated_input_returns_eof() {
        let buf = [0u8; 2];
        let err = read_u32(&mut Cursor::new(&buf)).unwrap_err();
        assert!(matches!(err, FormatError::UnexpectedEof));
    }

    #[test]
    fn empty_input_returns_eof() {
        let err = read_u8(&mut Cursor::new(&[])).unwrap_err();
        assert!(matches!(err, FormatError::UnexpectedEof));
    }

    #[test]
    fn read_bytes_rejects_oversized_length() {
        let bogus_len = (MAX_FIELD_LEN as u32) + 1;
        let mut buf = Vec::new();
        write_u32(&mut buf, bogus_len).unwrap();
        let err = read_bytes(&mut Cursor::new(&buf)).unwrap_err();
        assert!(matches!(err, FormatError::Io(_)));
    }

    #[test]
    fn collection_count_limit() {
        validate_collection_count(MAX_COLLECTION_COUNT, "items").unwrap();
        let err = validate_collection_count(MAX_COLLECTION_COUNT + 1, "items").unwrap_err();
        assert!(matches!(err, FormatError::InvalidData(_)));
    }

    #[test]
    fn vector_total_limit() {
        // Exactly at the limit is allowed.
        validate_vector_total(1_000, 1_000_000).unwrap();
        let err =
            validate_vector_total(MAX_PERSISTED_VECTOR_DIMS, MAX_PERSISTED_VECTOR_COUNT)
                .unwrap_err();
        assert!(matches!(err, FormatError::InvalidData(_)));
        // The u64 product must not overflow even for extreme inputs.
        assert!(validate_vector_total(u32::MAX, u32::MAX).is_err());
    }

    #[test]
    fn capped_capacity_clamps() {
        assert_eq!(capped_capacity(0), 0);
        assert_eq!(capped_capacity(1_000), 1_000);
        assert_eq!(capped_capacity(u32::MAX), 65_536);
    }
}