1use std::{
2 convert::TryInto,
3 io::{Read, Seek, SeekFrom, Write},
4};
5
6use crate::{
7 constants::{HEADER_SIZE, MAGIC, SPEC_MAJOR, SPEC_MINOR, WAL_OFFSET},
8 error::{MemvidError, Result},
9 types::Header,
10};
11
12const VERSION_OFFSET: usize = 4;
13const SPEC_BYTES_OFFSET: usize = 6;
14const FOOTER_OFFSET_POS: usize = 8;
15const WAL_OFFSET_POS: usize = 16;
16const WAL_SIZE_POS: usize = 24;
17const WAL_CHECKPOINT_POS: usize = 32;
18const WAL_SEQUENCE_POS: usize = 40;
19const TOC_CHECKSUM_POS: usize = 48;
20const TOC_CHECKSUM_END: usize = 80;
21const LEGACY_LOCK_REGION_START: usize = TOC_CHECKSUM_END;
23const LEGACY_LOCK_REGION_END: usize = LEGACY_LOCK_REGION_START + 60;
24const EXPECTED_VERSION: u16 = ((SPEC_MAJOR as u16) << 8) | SPEC_MINOR as u16;
25
26pub struct HeaderCodec;
28
29impl HeaderCodec {
30 pub fn write<W: Write + Seek>(mut writer: W, header: &Header) -> Result<()> {
32 let bytes = Self::encode(header)?;
33 writer.seek(SeekFrom::Start(0))?;
34 writer.write_all(&bytes)?;
35 Ok(())
36 }
37
38 pub fn read<R: Read + Write + Seek>(mut reader: R) -> Result<Header> {
42 let mut buf = [0u8; HEADER_SIZE];
43 reader.seek(SeekFrom::Start(0))?;
44 reader.read_exact(&mut buf)?;
45 if clear_legacy_lock_metadata(&mut buf) {
46 reader.seek(SeekFrom::Start(0))?;
47 reader.write_all(&buf)?;
48 reader.flush()?;
49 }
50 Self::decode(&buf)
51 }
52
53 pub fn encode(header: &Header) -> Result<[u8; HEADER_SIZE]> {
55 if header.magic != MAGIC {
56 return Err(MemvidError::InvalidHeader {
57 reason: "magic mismatch".into(),
58 });
59 }
60 if header.version != EXPECTED_VERSION {
61 return Err(MemvidError::InvalidHeader {
62 reason: "unsupported version".into(),
63 });
64 }
65 if header.wal_offset < WAL_OFFSET {
66 return Err(MemvidError::InvalidHeader {
67 reason: "wal_offset precedes data region".into(),
68 });
69 }
70 if header.wal_size == 0 {
71 return Err(MemvidError::InvalidHeader {
72 reason: "wal_size must be non-zero".into(),
73 });
74 }
75
76 let mut buf = [0u8; HEADER_SIZE];
77 buf[..MAGIC.len()].copy_from_slice(&header.magic);
78 buf[VERSION_OFFSET..VERSION_OFFSET + 2].copy_from_slice(&header.version.to_le_bytes());
79 buf[SPEC_BYTES_OFFSET] = SPEC_MAJOR;
80 buf[SPEC_BYTES_OFFSET + 1] = SPEC_MINOR;
81 buf[FOOTER_OFFSET_POS..FOOTER_OFFSET_POS + 8]
82 .copy_from_slice(&header.footer_offset.to_le_bytes());
83 buf[WAL_OFFSET_POS..WAL_OFFSET_POS + 8].copy_from_slice(&header.wal_offset.to_le_bytes());
84 buf[WAL_SIZE_POS..WAL_SIZE_POS + 8].copy_from_slice(&header.wal_size.to_le_bytes());
85 buf[WAL_CHECKPOINT_POS..WAL_CHECKPOINT_POS + 8]
86 .copy_from_slice(&header.wal_checkpoint_pos.to_le_bytes());
87 buf[WAL_SEQUENCE_POS..WAL_SEQUENCE_POS + 8]
88 .copy_from_slice(&header.wal_sequence.to_le_bytes());
89 buf[TOC_CHECKSUM_POS..TOC_CHECKSUM_END].copy_from_slice(&header.toc_checksum);
90 Ok(buf)
91 }
92
93 pub fn decode(bytes: &[u8; HEADER_SIZE]) -> Result<Header> {
95 let magic = bytes[..MAGIC.len()].try_into().unwrap();
96 if magic != MAGIC {
97 return Err(MemvidError::InvalidHeader {
98 reason: "magic mismatch".into(),
99 });
100 }
101
102 let version = u16::from_le_bytes(
103 bytes[VERSION_OFFSET..VERSION_OFFSET + 2]
104 .try_into()
105 .unwrap(),
106 );
107 if version != EXPECTED_VERSION {
108 return Err(MemvidError::InvalidHeader {
109 reason: "unsupported version".into(),
110 });
111 }
112
113 if bytes[SPEC_BYTES_OFFSET] != SPEC_MAJOR || bytes[SPEC_BYTES_OFFSET + 1] != SPEC_MINOR {
114 return Err(MemvidError::InvalidHeader {
115 reason: "spec byte mismatch".into(),
116 });
117 }
118
119 let footer_offset = u64::from_le_bytes(
120 bytes[FOOTER_OFFSET_POS..FOOTER_OFFSET_POS + 8]
121 .try_into()
122 .unwrap(),
123 );
124 let wal_offset = u64::from_le_bytes(
125 bytes[WAL_OFFSET_POS..WAL_OFFSET_POS + 8]
126 .try_into()
127 .unwrap(),
128 );
129 if wal_offset < WAL_OFFSET {
130 return Err(MemvidError::InvalidHeader {
131 reason: "wal_offset precedes data region".into(),
132 });
133 }
134 let wal_size =
135 u64::from_le_bytes(bytes[WAL_SIZE_POS..WAL_SIZE_POS + 8].try_into().unwrap());
136 if wal_size == 0 {
137 return Err(MemvidError::InvalidHeader {
138 reason: "wal_size must be non-zero".into(),
139 });
140 }
141 let wal_checkpoint_pos = u64::from_le_bytes(
142 bytes[WAL_CHECKPOINT_POS..WAL_CHECKPOINT_POS + 8]
143 .try_into()
144 .unwrap(),
145 );
146 let wal_sequence = u64::from_le_bytes(
147 bytes[WAL_SEQUENCE_POS..WAL_SEQUENCE_POS + 8]
148 .try_into()
149 .unwrap(),
150 );
151 let mut toc_checksum = [0u8; 32];
152 toc_checksum.copy_from_slice(&bytes[TOC_CHECKSUM_POS..TOC_CHECKSUM_END]);
153
154 Ok(Header {
155 magic,
156 version,
157 footer_offset,
158 wal_offset,
159 wal_size,
160 wal_checkpoint_pos,
161 wal_sequence,
162 toc_checksum,
163 })
164 }
165}
166
167fn clear_legacy_lock_metadata(buf: &mut [u8; HEADER_SIZE]) -> bool {
168 let region = &mut buf[LEGACY_LOCK_REGION_START..LEGACY_LOCK_REGION_END];
169 if region.iter().any(|byte| *byte != 0) {
170 region.fill(0);
171 true
172 } else {
173 false
174 }
175}
176
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;

    /// A header that passes every validation check.
    fn sample_header() -> Header {
        Header {
            magic: MAGIC,
            version: EXPECTED_VERSION,
            footer_offset: 1_048_576,
            wal_offset: WAL_OFFSET,
            wal_size: 4 * 1024 * 1024,
            wal_checkpoint_pos: 0,
            wal_sequence: 42,
            toc_checksum: [0xAB; 32],
        }
    }

    /// `matches!` only yields a bool; every caller must wrap this in `assert!`.
    fn is_invalid_header(err: &MemvidError) -> bool {
        matches!(err, MemvidError::InvalidHeader { .. })
    }

    #[test]
    fn roundtrip_encode_decode() {
        let header = sample_header();
        let encoded = HeaderCodec::encode(&header).expect("encode header");
        let decoded = HeaderCodec::decode(&encoded).expect("decode header");
        assert_eq!(decoded.magic, MAGIC);
        assert_eq!(decoded.version, EXPECTED_VERSION);
        assert_eq!(decoded.footer_offset, header.footer_offset);
        assert_eq!(decoded.wal_offset, header.wal_offset);
        assert_eq!(decoded.wal_size, header.wal_size);
        assert_eq!(decoded.wal_checkpoint_pos, header.wal_checkpoint_pos);
        assert_eq!(decoded.wal_sequence, header.wal_sequence);
        assert_eq!(decoded.toc_checksum, header.toc_checksum);
    }

    #[test]
    fn read_write_from_cursor() {
        let header = sample_header();
        let mut cursor = Cursor::new(vec![0u8; HEADER_SIZE]);
        HeaderCodec::write(&mut cursor, &header).expect("write header");
        cursor.set_position(0);
        let decoded = HeaderCodec::read(&mut cursor).expect("read header");
        assert_eq!(decoded.wal_size, header.wal_size);
        assert_eq!(decoded.wal_sequence, header.wal_sequence);
        assert_eq!(decoded.toc_checksum, header.toc_checksum);
    }

    #[test]
    fn clears_legacy_lock_metadata() {
        let header = sample_header();
        let mut encoded = HeaderCodec::encode(&header).expect("encode header");
        encoded[LEGACY_LOCK_REGION_START..LEGACY_LOCK_REGION_END].fill(0xAA);
        let mut cursor = Cursor::new(encoded.to_vec());
        let decoded =
            HeaderCodec::read(&mut cursor).expect("read header with legacy metadata");
        assert_eq!(decoded.wal_sequence, header.wal_sequence);
        let sanitized = cursor.into_inner();
        assert!(sanitized[LEGACY_LOCK_REGION_START..LEGACY_LOCK_REGION_END]
            .iter()
            .all(|byte| *byte == 0));
        // The scrub must not disturb the real header fields in front of the legacy region.
        assert_eq!(
            &sanitized[..LEGACY_LOCK_REGION_START],
            &encoded[..LEGACY_LOCK_REGION_START]
        );
    }

    #[test]
    fn reject_invalid_magic() {
        let mut header = sample_header();
        header.magic = *b"BAD!";
        let err = HeaderCodec::encode(&header).expect_err("should fail");
        assert!(is_invalid_header(&err));
    }

    #[test]
    fn reject_short_wal_size() {
        let mut header = sample_header();
        header.wal_size = 0;
        let err = HeaderCodec::encode(&header).expect_err("should fail");
        assert!(is_invalid_header(&err));
    }

    #[test]
    fn reject_decoding_with_bad_version() {
        let header = sample_header();
        let mut encoded = HeaderCodec::encode(&header).expect("encode header");
        encoded[VERSION_OFFSET] = 0xFF;
        let err = HeaderCodec::decode(&encoded).expect_err("decode should fail");
        assert!(is_invalid_header(&err));
    }

    #[test]
    fn reject_decoding_with_bad_magic() {
        let header = sample_header();
        let mut encoded = HeaderCodec::encode(&header).expect("encode header");
        encoded[0] ^= 0xFF;
        let err = HeaderCodec::decode(&encoded).expect_err("decode should fail");
        assert!(is_invalid_header(&err));
    }

    #[test]
    fn reject_decoding_with_spec_byte_mismatch() {
        let header = sample_header();
        let mut encoded = HeaderCodec::encode(&header).expect("encode header");
        encoded[SPEC_BYTES_OFFSET] = SPEC_MAJOR.wrapping_add(1);
        let err = HeaderCodec::decode(&encoded).expect_err("decode should fail");
        assert!(is_invalid_header(&err));
    }

    #[test]
    fn reject_decoding_with_zero_wal_size() {
        let header = sample_header();
        let mut encoded = HeaderCodec::encode(&header).expect("encode header");
        encoded[WAL_SIZE_POS..WAL_SIZE_POS + 8].fill(0);
        let err = HeaderCodec::decode(&encoded).expect_err("decode should fail");
        assert!(is_invalid_header(&err));
    }
}