Skip to main content

memvid_core/io/
header.rs

1use std::{
2    convert::TryInto,
3    io::{Read, Seek, SeekFrom, Write},
4};
5
6use crate::{
7    constants::{HEADER_SIZE, MAGIC, SPEC_MAJOR, SPEC_MINOR, WAL_OFFSET},
8    error::{MemvidError, Result},
9    types::Header,
10};
11
12const VERSION_OFFSET: usize = 4;
13const SPEC_BYTES_OFFSET: usize = 6;
14const FOOTER_OFFSET_POS: usize = 8;
15const WAL_OFFSET_POS: usize = 16;
16const WAL_SIZE_POS: usize = 24;
17const WAL_CHECKPOINT_POS: usize = 32;
18const WAL_SEQUENCE_POS: usize = 40;
19const TOC_CHECKSUM_POS: usize = 48;
20const TOC_CHECKSUM_END: usize = 80;
21// Legacy lock metadata occupied bytes 80..140 within the header padding.
22const LEGACY_LOCK_REGION_START: usize = TOC_CHECKSUM_END;
23const LEGACY_LOCK_REGION_END: usize = LEGACY_LOCK_REGION_START + 60;
24const EXPECTED_VERSION: u16 = ((SPEC_MAJOR as u16) << 8) | SPEC_MINOR as u16;
25
26/// Deterministic encoder/decoder for the fixed-size header region.
27pub struct HeaderCodec;
28
29impl HeaderCodec {
30    /// Writes the header back to the beginning of the file, zero-filling any unused bytes.
31    pub fn write<W: Write + Seek>(mut writer: W, header: &Header) -> Result<()> {
32        let bytes = Self::encode(header)?;
33        writer.seek(SeekFrom::Start(0))?;
34        writer.write_all(&bytes)?;
35        Ok(())
36    }
37
38    /// Reads and decodes the header from the start of the file. If legacy lock metadata is present
39    /// in the reserved padding, it is cleared in-place before decoding to maintain forward
40    /// compatibility with older files.
41    pub fn read<R: Read + Write + Seek>(mut reader: R) -> Result<Header> {
42        let mut buf = [0u8; HEADER_SIZE];
43        reader.seek(SeekFrom::Start(0))?;
44        reader.read_exact(&mut buf)?;
45        if clear_legacy_lock_metadata(&mut buf) {
46            reader.seek(SeekFrom::Start(0))?;
47            reader.write_all(&buf)?;
48            reader.flush()?;
49        }
50        Self::decode(&buf)
51    }
52
53    /// Encodes a header into the canonical 4 KB byte representation.
54    pub fn encode(header: &Header) -> Result<[u8; HEADER_SIZE]> {
55        if header.magic != MAGIC {
56            return Err(MemvidError::InvalidHeader {
57                reason: "magic mismatch".into(),
58            });
59        }
60        if header.version != EXPECTED_VERSION {
61            return Err(MemvidError::InvalidHeader {
62                reason: "unsupported version".into(),
63            });
64        }
65        if header.wal_offset < WAL_OFFSET {
66            return Err(MemvidError::InvalidHeader {
67                reason: "wal_offset precedes data region".into(),
68            });
69        }
70        if header.wal_size == 0 {
71            return Err(MemvidError::InvalidHeader {
72                reason: "wal_size must be non-zero".into(),
73            });
74        }
75
76        let mut buf = [0u8; HEADER_SIZE];
77        buf[..MAGIC.len()].copy_from_slice(&header.magic);
78        buf[VERSION_OFFSET..VERSION_OFFSET + 2].copy_from_slice(&header.version.to_le_bytes());
79        buf[SPEC_BYTES_OFFSET] = SPEC_MAJOR;
80        buf[SPEC_BYTES_OFFSET + 1] = SPEC_MINOR;
81        buf[FOOTER_OFFSET_POS..FOOTER_OFFSET_POS + 8]
82            .copy_from_slice(&header.footer_offset.to_le_bytes());
83        buf[WAL_OFFSET_POS..WAL_OFFSET_POS + 8].copy_from_slice(&header.wal_offset.to_le_bytes());
84        buf[WAL_SIZE_POS..WAL_SIZE_POS + 8].copy_from_slice(&header.wal_size.to_le_bytes());
85        buf[WAL_CHECKPOINT_POS..WAL_CHECKPOINT_POS + 8]
86            .copy_from_slice(&header.wal_checkpoint_pos.to_le_bytes());
87        buf[WAL_SEQUENCE_POS..WAL_SEQUENCE_POS + 8]
88            .copy_from_slice(&header.wal_sequence.to_le_bytes());
89        buf[TOC_CHECKSUM_POS..TOC_CHECKSUM_END].copy_from_slice(&header.toc_checksum);
90        Ok(buf)
91    }
92
93    /// Decodes the canonical header bytes into a strongly typed struct after validation.
94    pub fn decode(bytes: &[u8; HEADER_SIZE]) -> Result<Header> {
95        // Extract fixed-size arrays from the header buffer
96        // All indices are compile-time constants, so these slices are guaranteed to fit
97        let magic: [u8; 4] = extract_array(bytes, 0)?;
98        if magic != MAGIC {
99            return Err(MemvidError::InvalidHeader {
100                reason: "magic mismatch".into(),
101            });
102        }
103
104        let version = u16::from_le_bytes(extract_array(bytes, VERSION_OFFSET)?);
105        if version != EXPECTED_VERSION {
106            return Err(MemvidError::InvalidHeader {
107                reason: "unsupported version".into(),
108            });
109        }
110
111        if bytes[SPEC_BYTES_OFFSET] != SPEC_MAJOR || bytes[SPEC_BYTES_OFFSET + 1] != SPEC_MINOR {
112            return Err(MemvidError::InvalidHeader {
113                reason: "spec byte mismatch".into(),
114            });
115        }
116
117        let footer_offset = u64::from_le_bytes(extract_array(bytes, FOOTER_OFFSET_POS)?);
118        let wal_offset = u64::from_le_bytes(extract_array(bytes, WAL_OFFSET_POS)?);
119        if wal_offset < WAL_OFFSET {
120            return Err(MemvidError::InvalidHeader {
121                reason: "wal_offset precedes data region".into(),
122            });
123        }
124        let wal_size = u64::from_le_bytes(extract_array(bytes, WAL_SIZE_POS)?);
125        if wal_size == 0 {
126            return Err(MemvidError::InvalidHeader {
127                reason: "wal_size must be non-zero".into(),
128            });
129        }
130        let wal_checkpoint_pos = u64::from_le_bytes(extract_array(bytes, WAL_CHECKPOINT_POS)?);
131        let wal_sequence = u64::from_le_bytes(extract_array(bytes, WAL_SEQUENCE_POS)?);
132        let toc_checksum: [u8; 32] = extract_array(bytes, TOC_CHECKSUM_POS)?;
133
134        Ok(Header {
135            magic,
136            version,
137            footer_offset,
138            wal_offset,
139            wal_size,
140            wal_checkpoint_pos,
141            wal_sequence,
142            toc_checksum,
143        })
144    }
145}
146
147/// Extracts a fixed-size array from a byte slice at the given offset.
148/// Returns an error if the slice is too short (should never happen with valid headers).
149#[inline]
150fn extract_array<const N: usize>(bytes: &[u8], offset: usize) -> Result<[u8; N]> {
151    bytes
152        .get(offset..offset + N)
153        .and_then(|s| s.try_into().ok())
154        .ok_or_else(|| MemvidError::InvalidHeader {
155            reason: "header truncated".into(),
156        })
157}
158
159fn clear_legacy_lock_metadata(buf: &mut [u8; HEADER_SIZE]) -> bool {
160    let region = &mut buf[LEGACY_LOCK_REGION_START..LEGACY_LOCK_REGION_END];
161    if region.iter().any(|byte| *byte != 0) {
162        region.fill(0);
163        true
164    } else {
165        false
166    }
167}
168
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;

    /// A header that passes every validation rule in `HeaderCodec`.
    fn sample_header() -> Header {
        Header {
            magic: MAGIC,
            version: EXPECTED_VERSION,
            footer_offset: 1_048_576,
            wal_offset: WAL_OFFSET,
            wal_size: 4 * 1024 * 1024,
            wal_checkpoint_pos: 0,
            wal_sequence: 42,
            toc_checksum: [0xAB; 32],
        }
    }

    #[test]
    fn roundtrip_encode_decode() {
        let header = sample_header();
        let encoded = HeaderCodec::encode(&header).expect("encode header");
        let decoded = HeaderCodec::decode(&encoded).expect("decode header");
        assert_eq!(decoded.magic, MAGIC);
        assert_eq!(decoded.version, EXPECTED_VERSION);
        assert_eq!(decoded.footer_offset, header.footer_offset);
        assert_eq!(decoded.wal_offset, WAL_OFFSET);
        assert_eq!(decoded.wal_size, header.wal_size);
        assert_eq!(decoded.wal_checkpoint_pos, header.wal_checkpoint_pos);
        assert_eq!(decoded.wal_sequence, header.wal_sequence);
        assert_eq!(decoded.toc_checksum, header.toc_checksum);
    }

    #[test]
    fn read_write_from_cursor() {
        let header = sample_header();
        let mut cursor = Cursor::new(vec![0u8; HEADER_SIZE]);
        HeaderCodec::write(&mut cursor, &header).expect("write header");
        cursor.set_position(0);
        let decoded = HeaderCodec::read(&mut cursor).expect("read header");
        assert_eq!(decoded.footer_offset, header.footer_offset);
        assert_eq!(decoded.wal_size, header.wal_size);
        assert_eq!(decoded.wal_sequence, header.wal_sequence);
        assert_eq!(decoded.toc_checksum, header.toc_checksum);
    }

    #[test]
    fn clears_legacy_lock_metadata() {
        let header = sample_header();
        let mut encoded = HeaderCodec::encode(&header).expect("encode header");
        encoded[LEGACY_LOCK_REGION_START..LEGACY_LOCK_REGION_END].fill(0xAA);
        let mut cursor = Cursor::new(encoded.to_vec());
        let decoded =
            HeaderCodec::read(&mut cursor).expect("read header with legacy metadata");
        assert_eq!(decoded.wal_sequence, header.wal_sequence);
        let sanitized = cursor.into_inner();
        assert!(
            sanitized[LEGACY_LOCK_REGION_START..LEGACY_LOCK_REGION_END]
                .iter()
                .all(|byte| *byte == 0)
        );
    }

    #[test]
    fn reject_invalid_magic() {
        let mut header = sample_header();
        header.magic = *b"BAD!";
        let err = HeaderCodec::encode(&header).expect_err("should fail");
        // `matches!` only yields a bool; it must be asserted to actually check the variant.
        assert!(matches!(err, MemvidError::InvalidHeader { .. }));
    }

    #[test]
    fn reject_short_wal_size() {
        let mut header = sample_header();
        header.wal_size = 0;
        let err = HeaderCodec::encode(&header).expect_err("should fail");
        assert!(matches!(err, MemvidError::InvalidHeader { .. }));
    }

    #[test]
    fn reject_decoding_with_bad_version() {
        let header = sample_header();
        let mut encoded = HeaderCodec::encode(&header).expect("encode header");
        encoded[VERSION_OFFSET] = 0xFF;
        let err = HeaderCodec::decode(&encoded).expect_err("decode should fail");
        assert!(matches!(err, MemvidError::InvalidHeader { .. }));
    }

    #[test]
    fn reject_decoding_with_bad_spec_bytes() {
        let header = sample_header();
        let mut encoded = HeaderCodec::encode(&header).expect("encode header");
        encoded[SPEC_BYTES_OFFSET] ^= 0xFF;
        let err = HeaderCodec::decode(&encoded).expect_err("decode should fail");
        assert!(matches!(err, MemvidError::InvalidHeader { .. }));
    }

    #[test]
    fn reject_decoding_with_zero_wal_size() {
        let header = sample_header();
        let mut encoded = HeaderCodec::encode(&header).expect("encode header");
        encoded[WAL_SIZE_POS..WAL_SIZE_POS + 8].fill(0);
        let err = HeaderCodec::decode(&encoded).expect_err("decode should fail");
        assert!(matches!(err, MemvidError::InvalidHeader { .. }));
    }
}