1use std::{
2 convert::TryInto,
3 io::{Read, Seek, SeekFrom, Write},
4};
5
6use crate::{
7 constants::{HEADER_SIZE, MAGIC, SPEC_MAJOR, SPEC_MINOR, WAL_OFFSET},
8 error::{MemvidError, Result},
9 types::Header,
10};
11
12const VERSION_OFFSET: usize = 4;
13const SPEC_BYTES_OFFSET: usize = 6;
14const FOOTER_OFFSET_POS: usize = 8;
15const WAL_OFFSET_POS: usize = 16;
16const WAL_SIZE_POS: usize = 24;
17const WAL_CHECKPOINT_POS: usize = 32;
18const WAL_SEQUENCE_POS: usize = 40;
19const TOC_CHECKSUM_POS: usize = 48;
20const TOC_CHECKSUM_END: usize = 80;
21const LEGACY_LOCK_REGION_START: usize = TOC_CHECKSUM_END;
23const LEGACY_LOCK_REGION_END: usize = LEGACY_LOCK_REGION_START + 60;
24const EXPECTED_VERSION: u16 = ((SPEC_MAJOR as u16) << 8) | SPEC_MINOR as u16;
25
26pub struct HeaderCodec;
28
29impl HeaderCodec {
30 pub fn write<W: Write + Seek>(mut writer: W, header: &Header) -> Result<()> {
32 let bytes = Self::encode(header)?;
33 writer.seek(SeekFrom::Start(0))?;
34 writer.write_all(&bytes)?;
35 Ok(())
36 }
37
38 pub fn read<R: Read + Write + Seek>(mut reader: R) -> Result<Header> {
42 let mut buf = [0u8; HEADER_SIZE];
43 reader.seek(SeekFrom::Start(0))?;
44 reader.read_exact(&mut buf)?;
45 if clear_legacy_lock_metadata(&mut buf) {
46 reader.seek(SeekFrom::Start(0))?;
47 reader.write_all(&buf)?;
48 reader.flush()?;
49 }
50 Self::decode(&buf)
51 }
52
53 pub fn encode(header: &Header) -> Result<[u8; HEADER_SIZE]> {
55 if header.magic != MAGIC {
56 return Err(MemvidError::InvalidHeader {
57 reason: "magic mismatch".into(),
58 });
59 }
60 if header.version != EXPECTED_VERSION {
61 return Err(MemvidError::InvalidHeader {
62 reason: "unsupported version".into(),
63 });
64 }
65 if header.wal_offset < WAL_OFFSET {
66 return Err(MemvidError::InvalidHeader {
67 reason: "wal_offset precedes data region".into(),
68 });
69 }
70 if header.wal_size == 0 {
71 return Err(MemvidError::InvalidHeader {
72 reason: "wal_size must be non-zero".into(),
73 });
74 }
75
76 let mut buf = [0u8; HEADER_SIZE];
77 buf[..MAGIC.len()].copy_from_slice(&header.magic);
78 buf[VERSION_OFFSET..VERSION_OFFSET + 2].copy_from_slice(&header.version.to_le_bytes());
79 buf[SPEC_BYTES_OFFSET] = SPEC_MAJOR;
80 buf[SPEC_BYTES_OFFSET + 1] = SPEC_MINOR;
81 buf[FOOTER_OFFSET_POS..FOOTER_OFFSET_POS + 8]
82 .copy_from_slice(&header.footer_offset.to_le_bytes());
83 buf[WAL_OFFSET_POS..WAL_OFFSET_POS + 8].copy_from_slice(&header.wal_offset.to_le_bytes());
84 buf[WAL_SIZE_POS..WAL_SIZE_POS + 8].copy_from_slice(&header.wal_size.to_le_bytes());
85 buf[WAL_CHECKPOINT_POS..WAL_CHECKPOINT_POS + 8]
86 .copy_from_slice(&header.wal_checkpoint_pos.to_le_bytes());
87 buf[WAL_SEQUENCE_POS..WAL_SEQUENCE_POS + 8]
88 .copy_from_slice(&header.wal_sequence.to_le_bytes());
89 buf[TOC_CHECKSUM_POS..TOC_CHECKSUM_END].copy_from_slice(&header.toc_checksum);
90 Ok(buf)
91 }
92
93 pub fn decode(bytes: &[u8; HEADER_SIZE]) -> Result<Header> {
95 let magic: [u8; 4] = extract_array(bytes, 0)?;
98 if magic != MAGIC {
99 return Err(MemvidError::InvalidHeader {
100 reason: "magic mismatch".into(),
101 });
102 }
103
104 let version = u16::from_le_bytes(extract_array(bytes, VERSION_OFFSET)?);
105 if version != EXPECTED_VERSION {
106 return Err(MemvidError::InvalidHeader {
107 reason: "unsupported version".into(),
108 });
109 }
110
111 if bytes[SPEC_BYTES_OFFSET] != SPEC_MAJOR || bytes[SPEC_BYTES_OFFSET + 1] != SPEC_MINOR {
112 return Err(MemvidError::InvalidHeader {
113 reason: "spec byte mismatch".into(),
114 });
115 }
116
117 let footer_offset = u64::from_le_bytes(extract_array(bytes, FOOTER_OFFSET_POS)?);
118 let wal_offset = u64::from_le_bytes(extract_array(bytes, WAL_OFFSET_POS)?);
119 if wal_offset < WAL_OFFSET {
120 return Err(MemvidError::InvalidHeader {
121 reason: "wal_offset precedes data region".into(),
122 });
123 }
124 let wal_size = u64::from_le_bytes(extract_array(bytes, WAL_SIZE_POS)?);
125 if wal_size == 0 {
126 return Err(MemvidError::InvalidHeader {
127 reason: "wal_size must be non-zero".into(),
128 });
129 }
130 let wal_checkpoint_pos = u64::from_le_bytes(extract_array(bytes, WAL_CHECKPOINT_POS)?);
131 let wal_sequence = u64::from_le_bytes(extract_array(bytes, WAL_SEQUENCE_POS)?);
132 let toc_checksum: [u8; 32] = extract_array(bytes, TOC_CHECKSUM_POS)?;
133
134 Ok(Header {
135 magic,
136 version,
137 footer_offset,
138 wal_offset,
139 wal_size,
140 wal_checkpoint_pos,
141 wal_sequence,
142 toc_checksum,
143 })
144 }
145}
146
147#[inline]
150fn extract_array<const N: usize>(bytes: &[u8], offset: usize) -> Result<[u8; N]> {
151 bytes
152 .get(offset..offset + N)
153 .and_then(|s| s.try_into().ok())
154 .ok_or_else(|| MemvidError::InvalidHeader {
155 reason: "header truncated".into(),
156 })
157}
158
159fn clear_legacy_lock_metadata(buf: &mut [u8; HEADER_SIZE]) -> bool {
160 let region = &mut buf[LEGACY_LOCK_REGION_START..LEGACY_LOCK_REGION_END];
161 if region.iter().any(|byte| *byte != 0) {
162 region.fill(0);
163 true
164 } else {
165 false
166 }
167}
168
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;

    /// A header that passes every check performed by `encode` and `decode`.
    fn sample_header() -> Header {
        Header {
            magic: MAGIC,
            version: EXPECTED_VERSION,
            footer_offset: 1_048_576,
            wal_offset: WAL_OFFSET,
            wal_size: 4 * 1024 * 1024,
            wal_checkpoint_pos: 0,
            wal_sequence: 42,
            toc_checksum: [0xAB; 32],
        }
    }

    #[test]
    fn roundtrip_encode_decode() {
        let header = sample_header();
        let encoded = HeaderCodec::encode(&header).expect("encode header");
        let decoded = HeaderCodec::decode(&encoded).expect("decode header");
        assert_eq!(decoded.magic, MAGIC);
        assert_eq!(decoded.version, EXPECTED_VERSION);
        assert_eq!(decoded.footer_offset, header.footer_offset);
        assert_eq!(decoded.wal_offset, WAL_OFFSET);
        assert_eq!(decoded.wal_size, header.wal_size);
        assert_eq!(decoded.wal_checkpoint_pos, header.wal_checkpoint_pos);
        assert_eq!(decoded.wal_sequence, header.wal_sequence);
        assert_eq!(decoded.toc_checksum, header.toc_checksum);
    }

    #[test]
    fn read_write_from_cursor() {
        let header = sample_header();
        let mut cursor = Cursor::new(vec![0u8; HEADER_SIZE]);
        HeaderCodec::write(&mut cursor, &header).expect("write header");
        cursor.set_position(0);
        let decoded = HeaderCodec::read(&mut cursor).expect("read header");
        assert_eq!(decoded.wal_size, header.wal_size);
        assert_eq!(decoded.wal_sequence, header.wal_sequence);
    }

    #[test]
    fn clears_legacy_lock_metadata() {
        let header = sample_header();
        let mut encoded = HeaderCodec::encode(&header).expect("encode header");
        encoded[LEGACY_LOCK_REGION_START..LEGACY_LOCK_REGION_END].fill(0xAA);
        let mut cursor = Cursor::new(encoded.to_vec());
        HeaderCodec::read(&mut cursor).expect("read header with legacy metadata");
        let sanitized = cursor.into_inner();
        assert!(
            sanitized[LEGACY_LOCK_REGION_START..LEGACY_LOCK_REGION_END]
                .iter()
                .all(|byte| *byte == 0)
        );
    }

    // `matches!` only yields a bool; it must be wrapped in `assert!` for the
    // test to actually check the error variant.

    #[test]
    fn reject_invalid_magic() {
        let mut header = sample_header();
        header.magic = *b"BAD!";
        let err = HeaderCodec::encode(&header).expect_err("should fail");
        assert!(matches!(err, MemvidError::InvalidHeader { .. }));
    }

    #[test]
    fn reject_short_wal_size() {
        let mut header = sample_header();
        header.wal_size = 0;
        let err = HeaderCodec::encode(&header).expect_err("should fail");
        assert!(matches!(err, MemvidError::InvalidHeader { .. }));
    }

    #[test]
    fn reject_decoding_with_bad_version() {
        let header = sample_header();
        let mut encoded = HeaderCodec::encode(&header).expect("encode header");
        encoded[VERSION_OFFSET] = 0xFF;
        let err = HeaderCodec::decode(&encoded).expect_err("decode should fail");
        assert!(matches!(err, MemvidError::InvalidHeader { .. }));
    }
}