// memvid_core/io/header.rs

1use std::{
2    convert::TryInto,
3    io::{Read, Seek, SeekFrom, Write},
4};
5
6use crate::{
7    constants::{HEADER_SIZE, MAGIC, SPEC_MAJOR, SPEC_MINOR, WAL_OFFSET},
8    error::{MemvidError, Result},
9    types::Header,
10};
11
12const VERSION_OFFSET: usize = 4;
13const SPEC_BYTES_OFFSET: usize = 6;
14const FOOTER_OFFSET_POS: usize = 8;
15const WAL_OFFSET_POS: usize = 16;
16const WAL_SIZE_POS: usize = 24;
17const WAL_CHECKPOINT_POS: usize = 32;
18const WAL_SEQUENCE_POS: usize = 40;
19const TOC_CHECKSUM_POS: usize = 48;
20const TOC_CHECKSUM_END: usize = 80;
21// Legacy lock metadata occupied bytes 80..140 within the header padding.
22const LEGACY_LOCK_REGION_START: usize = TOC_CHECKSUM_END;
23const LEGACY_LOCK_REGION_END: usize = LEGACY_LOCK_REGION_START + 60;
24const EXPECTED_VERSION: u16 = ((SPEC_MAJOR as u16) << 8) | SPEC_MINOR as u16;
25
26/// Deterministic encoder/decoder for the fixed-size header region.
27pub struct HeaderCodec;
28
29impl HeaderCodec {
30    /// Writes the header back to the beginning of the file, zero-filling any unused bytes.
31    pub fn write<W: Write + Seek>(mut writer: W, header: &Header) -> Result<()> {
32        let bytes = Self::encode(header)?;
33        writer.seek(SeekFrom::Start(0))?;
34        writer.write_all(&bytes)?;
35        Ok(())
36    }
37
38    /// Reads and decodes the header from the start of the file. If legacy lock metadata is present
39    /// in the reserved padding, it is cleared in-place before decoding to maintain forward
40    /// compatibility with older files.
41    pub fn read<R: Read + Write + Seek>(mut reader: R) -> Result<Header> {
42        let mut buf = [0u8; HEADER_SIZE];
43        reader.seek(SeekFrom::Start(0))?;
44        reader.read_exact(&mut buf)?;
45        if clear_legacy_lock_metadata(&mut buf) {
46            reader.seek(SeekFrom::Start(0))?;
47            reader.write_all(&buf)?;
48            reader.flush()?;
49        }
50        Self::decode(&buf)
51    }
52
53    /// Encodes a header into the canonical 4 KB byte representation.
54    pub fn encode(header: &Header) -> Result<[u8; HEADER_SIZE]> {
55        if header.magic != MAGIC {
56            return Err(MemvidError::InvalidHeader {
57                reason: "magic mismatch".into(),
58            });
59        }
60        if header.version != EXPECTED_VERSION {
61            return Err(MemvidError::InvalidHeader {
62                reason: "unsupported version".into(),
63            });
64        }
65        if header.wal_offset < WAL_OFFSET {
66            return Err(MemvidError::InvalidHeader {
67                reason: "wal_offset precedes data region".into(),
68            });
69        }
70        if header.wal_size == 0 {
71            return Err(MemvidError::InvalidHeader {
72                reason: "wal_size must be non-zero".into(),
73            });
74        }
75
76        let mut buf = [0u8; HEADER_SIZE];
77        buf[..MAGIC.len()].copy_from_slice(&header.magic);
78        buf[VERSION_OFFSET..VERSION_OFFSET + 2].copy_from_slice(&header.version.to_le_bytes());
79        buf[SPEC_BYTES_OFFSET] = SPEC_MAJOR;
80        buf[SPEC_BYTES_OFFSET + 1] = SPEC_MINOR;
81        buf[FOOTER_OFFSET_POS..FOOTER_OFFSET_POS + 8]
82            .copy_from_slice(&header.footer_offset.to_le_bytes());
83        buf[WAL_OFFSET_POS..WAL_OFFSET_POS + 8].copy_from_slice(&header.wal_offset.to_le_bytes());
84        buf[WAL_SIZE_POS..WAL_SIZE_POS + 8].copy_from_slice(&header.wal_size.to_le_bytes());
85        buf[WAL_CHECKPOINT_POS..WAL_CHECKPOINT_POS + 8]
86            .copy_from_slice(&header.wal_checkpoint_pos.to_le_bytes());
87        buf[WAL_SEQUENCE_POS..WAL_SEQUENCE_POS + 8]
88            .copy_from_slice(&header.wal_sequence.to_le_bytes());
89        buf[TOC_CHECKSUM_POS..TOC_CHECKSUM_END].copy_from_slice(&header.toc_checksum);
90        Ok(buf)
91    }
92
93    /// Decodes the canonical header bytes into a strongly typed struct after validation.
94    pub fn decode(bytes: &[u8; HEADER_SIZE]) -> Result<Header> {
95        let magic = bytes[..MAGIC.len()].try_into().unwrap();
96        if magic != MAGIC {
97            return Err(MemvidError::InvalidHeader {
98                reason: "magic mismatch".into(),
99            });
100        }
101
102        let version = u16::from_le_bytes(
103            bytes[VERSION_OFFSET..VERSION_OFFSET + 2]
104                .try_into()
105                .unwrap(),
106        );
107        if version != EXPECTED_VERSION {
108            return Err(MemvidError::InvalidHeader {
109                reason: "unsupported version".into(),
110            });
111        }
112
113        if bytes[SPEC_BYTES_OFFSET] != SPEC_MAJOR || bytes[SPEC_BYTES_OFFSET + 1] != SPEC_MINOR {
114            return Err(MemvidError::InvalidHeader {
115                reason: "spec byte mismatch".into(),
116            });
117        }
118
119        let footer_offset = u64::from_le_bytes(
120            bytes[FOOTER_OFFSET_POS..FOOTER_OFFSET_POS + 8]
121                .try_into()
122                .unwrap(),
123        );
124        let wal_offset = u64::from_le_bytes(
125            bytes[WAL_OFFSET_POS..WAL_OFFSET_POS + 8]
126                .try_into()
127                .unwrap(),
128        );
129        if wal_offset < WAL_OFFSET {
130            return Err(MemvidError::InvalidHeader {
131                reason: "wal_offset precedes data region".into(),
132            });
133        }
134        let wal_size =
135            u64::from_le_bytes(bytes[WAL_SIZE_POS..WAL_SIZE_POS + 8].try_into().unwrap());
136        if wal_size == 0 {
137            return Err(MemvidError::InvalidHeader {
138                reason: "wal_size must be non-zero".into(),
139            });
140        }
141        let wal_checkpoint_pos = u64::from_le_bytes(
142            bytes[WAL_CHECKPOINT_POS..WAL_CHECKPOINT_POS + 8]
143                .try_into()
144                .unwrap(),
145        );
146        let wal_sequence = u64::from_le_bytes(
147            bytes[WAL_SEQUENCE_POS..WAL_SEQUENCE_POS + 8]
148                .try_into()
149                .unwrap(),
150        );
151        let mut toc_checksum = [0u8; 32];
152        toc_checksum.copy_from_slice(&bytes[TOC_CHECKSUM_POS..TOC_CHECKSUM_END]);
153
154        Ok(Header {
155            magic,
156            version,
157            footer_offset,
158            wal_offset,
159            wal_size,
160            wal_checkpoint_pos,
161            wal_sequence,
162            toc_checksum,
163        })
164    }
165}
166
167fn clear_legacy_lock_metadata(buf: &mut [u8; HEADER_SIZE]) -> bool {
168    let region = &mut buf[LEGACY_LOCK_REGION_START..LEGACY_LOCK_REGION_END];
169    if region.iter().any(|byte| *byte != 0) {
170        region.fill(0);
171        true
172    } else {
173        false
174    }
175}
176
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;

    fn sample_header() -> Header {
        Header {
            magic: MAGIC,
            version: EXPECTED_VERSION,
            footer_offset: 1_048_576,
            wal_offset: WAL_OFFSET,
            wal_size: 4 * 1024 * 1024,
            wal_checkpoint_pos: 0,
            wal_sequence: 42,
            toc_checksum: [0xAB; 32],
        }
    }

    /// Fails the test unless `err` is `MemvidError::InvalidHeader`.
    ///
    /// A bare `matches!(..)` statement only evaluates to a `bool` that is then discarded, so it
    /// must be wrapped in `assert!` to actually check anything.
    fn assert_invalid_header(err: MemvidError) {
        assert!(
            matches!(err, MemvidError::InvalidHeader { .. }),
            "expected InvalidHeader, got {:?}",
            err
        );
    }

    #[test]
    fn roundtrip_encode_decode() {
        let header = sample_header();
        let encoded = HeaderCodec::encode(&header).expect("encode header");
        let decoded = HeaderCodec::decode(&encoded).expect("decode header");
        assert_eq!(decoded.magic, MAGIC);
        assert_eq!(decoded.version, EXPECTED_VERSION);
        assert_eq!(decoded.footer_offset, header.footer_offset);
        assert_eq!(decoded.wal_offset, WAL_OFFSET);
        assert_eq!(decoded.wal_size, header.wal_size);
        assert_eq!(decoded.wal_checkpoint_pos, header.wal_checkpoint_pos);
        assert_eq!(decoded.wal_sequence, header.wal_sequence);
        assert_eq!(decoded.toc_checksum, header.toc_checksum);
    }

    #[test]
    fn encode_zero_fills_padding() {
        let encoded = HeaderCodec::encode(&sample_header()).expect("encode header");
        assert!(encoded[TOC_CHECKSUM_END..].iter().all(|byte| *byte == 0));
    }

    #[test]
    fn read_write_from_cursor() {
        let header = sample_header();
        let mut cursor = Cursor::new(vec![0u8; HEADER_SIZE]);
        HeaderCodec::write(&mut cursor, &header).expect("write header");
        cursor.set_position(0);
        let decoded = HeaderCodec::read(&mut cursor).expect("read header");
        assert_eq!(decoded.wal_size, header.wal_size);
        assert_eq!(decoded.wal_sequence, header.wal_sequence);
    }

    #[test]
    fn clears_legacy_lock_metadata() {
        let header = sample_header();
        let mut encoded = HeaderCodec::encode(&header).expect("encode header");
        encoded[LEGACY_LOCK_REGION_START..LEGACY_LOCK_REGION_END].fill(0xAA);
        let mut cursor = Cursor::new(encoded.to_vec());
        HeaderCodec::read(&mut cursor).expect("read header with legacy metadata");
        let sanitized = cursor.into_inner();
        assert!(
            sanitized[LEGACY_LOCK_REGION_START..LEGACY_LOCK_REGION_END]
                .iter()
                .all(|byte| *byte == 0)
        );
    }

    #[test]
    fn reject_invalid_magic() {
        let mut header = sample_header();
        header.magic = *b"BAD!";
        assert_invalid_header(HeaderCodec::encode(&header).expect_err("should fail"));
    }

    #[test]
    fn reject_zero_wal_size() {
        let mut header = sample_header();
        header.wal_size = 0;
        assert_invalid_header(HeaderCodec::encode(&header).expect_err("should fail"));
    }

    #[test]
    fn reject_decoding_with_bad_version() {
        let header = sample_header();
        let mut encoded = HeaderCodec::encode(&header).expect("encode header");
        encoded[VERSION_OFFSET] = 0xFF;
        assert_invalid_header(HeaderCodec::decode(&encoded).expect_err("decode should fail"));
    }

    #[test]
    fn reject_decoding_with_spec_byte_mismatch() {
        let mut encoded = HeaderCodec::encode(&sample_header()).expect("encode header");
        // Only the spec bytes change, so the version check passes and this check is reached.
        encoded[SPEC_BYTES_OFFSET + 1] ^= 0xFF;
        assert_invalid_header(HeaderCodec::decode(&encoded).expect_err("decode should fail"));
    }

    #[test]
    fn reject_decoding_with_zero_wal_size() {
        let mut encoded = HeaderCodec::encode(&sample_header()).expect("encode header");
        encoded[WAL_SIZE_POS..WAL_SIZE_POS + 8].fill(0);
        assert_invalid_header(HeaderCodec::decode(&encoded).expect_err("decode should fail"));
    }
}