1use crate::{WAL_FILE_VERSION, WAL_FILE_VERSION_V2};
12use std::io::{self, Read};
13
/// Smallest page-write payload, in bytes, for which the encoder tries zstd compression.
///
/// Page writes shorter than this are always stored as plain `PageWrite` records.
pub const MAIN_WAL_DEFAULT_COMPRESS_THRESHOLD: usize = 256;
15
/// Tag byte that opens every main-WAL record and selects how the rest of it is laid out.
///
/// The discriminant is written verbatim into the WAL (`record_type as u8`), so existing values
/// must keep their numbers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum MainWalRecordType {
    Begin = 1,
    Commit = 2,
    Rollback = 3,
    /// Uncompressed page write.
    PageWrite = 4,
    Checkpoint = 5,
    /// Page write with a compressed payload; decodes to [`MainWalRecordFrame::PageWrite`].
    PageWriteCompressed = 6,
    TxCommitBatch = 7,
    FullPageImage = 8,
    VectorInsert = 9,
}

impl MainWalRecordType {
    /// Maps a tag byte read from the WAL back to its record type.
    ///
    /// Returns `None` when `value` is not the discriminant of any known variant.
    pub fn from_u8(value: u8) -> Option<Self> {
        // Every known variant; matching on the discriminant itself keeps the numbering in one
        // place (the enum definition above).
        const KNOWN: [MainWalRecordType; 9] = [
            MainWalRecordType::Begin,
            MainWalRecordType::Commit,
            MainWalRecordType::Rollback,
            MainWalRecordType::PageWrite,
            MainWalRecordType::Checkpoint,
            MainWalRecordType::PageWriteCompressed,
            MainWalRecordType::TxCommitBatch,
            MainWalRecordType::FullPageImage,
            MainWalRecordType::VectorInsert,
        ];
        KNOWN.iter().copied().find(|&kind| kind as u8 == value)
    }
}
46
/// Compression algorithm byte stored in a `PageWriteCompressed` record.
///
/// The encoder currently only emits [`MainWalCompression::Zstd`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum MainWalCompression {
    /// Payload stored as-is.
    None = 0,
    /// Payload compressed with zstd.
    Zstd = 1,
}

impl MainWalCompression {
    /// Maps the stored algorithm byte back to a variant, or `None` for an unknown byte.
    fn from_u8(value: u8) -> Option<Self> {
        if value == Self::None as u8 {
            Some(Self::None)
        } else if value == Self::Zstd as u8 {
            Some(Self::Zstd)
        } else {
            None
        }
    }
}
63
/// A decoded main-WAL record, independent of how it is stored on disk.
#[derive(Debug, Clone, PartialEq)]
pub enum MainWalRecordFrame {
    /// Begin marker carrying a transaction id.
    Begin { tx_id: u64 },
    /// Commit marker carrying a transaction id.
    Commit { tx_id: u64 },
    /// Rollback marker carrying a transaction id.
    Rollback { tx_id: u64 },
    /// Bytes written to page `page_id`.
    ///
    /// Large payloads may be stored as a compressed record; decoding such a record yields this
    /// variant with the original bytes.
    PageWrite { tx_id: u64, page_id: u32, data: Vec<u8> },
    /// A batch of opaque action payloads, each stored with its own length prefix.
    TxCommitBatch { tx_id: u64, actions: Vec<Vec<u8>> },
    /// Image of page `page_id` tagged with a checkpoint epoch.
    FullPageImage { tx_id: u64, page_id: u32, ckpt_epoch: u64, data: Vec<u8> },
    /// A vector for `entity_id` in `collection`; components are stored as little-endian `f32`.
    VectorInsert { collection: String, entity_id: u64, vector: Vec<f32> },
    /// Checkpoint marker carrying an `lsn` value.
    Checkpoint { lsn: u64 },
}
99
100pub fn encode_main_wal_record_frame(frame: &MainWalRecordFrame, term: u64) -> io::Result<Vec<u8>> {
101 let mut out = Vec::new();
102 encode_main_wal_record_frame_into(frame, term, &mut out)?;
103 Ok(out)
104}
105
106pub fn encode_main_wal_record_frame_into(
107 frame: &MainWalRecordFrame,
108 term: u64,
109 out: &mut Vec<u8>,
110) -> io::Result<()> {
111 let start = out.len();
112 match frame {
113 MainWalRecordFrame::Begin { tx_id } => {
114 write_type_and_term(out, MainWalRecordType::Begin, term);
115 out.extend_from_slice(&tx_id.to_le_bytes());
116 }
117 MainWalRecordFrame::Commit { tx_id } => {
118 write_type_and_term(out, MainWalRecordType::Commit, term);
119 out.extend_from_slice(&tx_id.to_le_bytes());
120 }
121 MainWalRecordFrame::Rollback { tx_id } => {
122 write_type_and_term(out, MainWalRecordType::Rollback, term);
123 out.extend_from_slice(&tx_id.to_le_bytes());
124 }
125 MainWalRecordFrame::PageWrite {
126 tx_id,
127 page_id,
128 data,
129 } => {
130 if data.len() >= MAIN_WAL_DEFAULT_COMPRESS_THRESHOLD {
131 if let Ok(compressed) = zstd::bulk::compress(data.as_slice(), 3) {
132 if compressed.len() < data.len() {
133 write_type_and_term(out, MainWalRecordType::PageWriteCompressed, term);
134 out.extend_from_slice(&tx_id.to_le_bytes());
135 out.extend_from_slice(&page_id.to_le_bytes());
136 out.push(MainWalCompression::Zstd as u8);
137 write_u32_len(out, data.len(), "main wal original page length")?;
138 write_u32_len(out, compressed.len(), "main wal compressed page length")?;
139 out.extend_from_slice(&compressed);
140 append_crc(out, start);
141 return Ok(());
142 }
143 }
144 }
145
146 write_type_and_term(out, MainWalRecordType::PageWrite, term);
147 out.extend_from_slice(&tx_id.to_le_bytes());
148 out.extend_from_slice(&page_id.to_le_bytes());
149 write_u32_len(out, data.len(), "main wal page length")?;
150 out.extend_from_slice(data);
151 }
152 MainWalRecordFrame::TxCommitBatch { tx_id, actions } => {
153 write_type_and_term(out, MainWalRecordType::TxCommitBatch, term);
154 out.extend_from_slice(&tx_id.to_le_bytes());
155 write_u32_len(out, actions.len(), "main wal action count")?;
156 for action in actions {
157 write_u32_len(out, action.len(), "main wal action length")?;
158 out.extend_from_slice(action);
159 }
160 }
161 MainWalRecordFrame::FullPageImage {
162 tx_id,
163 page_id,
164 ckpt_epoch,
165 data,
166 } => {
167 write_type_and_term(out, MainWalRecordType::FullPageImage, term);
168 out.extend_from_slice(&tx_id.to_le_bytes());
169 out.extend_from_slice(&page_id.to_le_bytes());
170 out.extend_from_slice(&ckpt_epoch.to_le_bytes());
171 write_u32_len(out, data.len(), "main wal full-page image length")?;
172 out.extend_from_slice(data);
173 }
174 MainWalRecordFrame::VectorInsert {
175 collection,
176 entity_id,
177 vector,
178 } => {
179 write_type_and_term(out, MainWalRecordType::VectorInsert, term);
180 write_u32_len(out, collection.len(), "main wal collection name length")?;
181 out.extend_from_slice(collection.as_bytes());
182 out.extend_from_slice(&entity_id.to_le_bytes());
183 write_u32_len(out, vector.len(), "main wal vector length")?;
184 for value in vector {
185 out.extend_from_slice(&value.to_le_bytes());
186 }
187 }
188 MainWalRecordFrame::Checkpoint { lsn } => {
189 write_type_and_term(out, MainWalRecordType::Checkpoint, term);
190 out.extend_from_slice(&lsn.to_le_bytes());
191 }
192 }
193
194 append_crc(out, start);
195 Ok(())
196}
197
198pub fn decode_main_wal_record_frame<R: Read>(
199 reader: &mut R,
200 format_version: u8,
201 default_term: u64,
202) -> io::Result<Option<(u64, MainWalRecordFrame)>> {
203 let mut checksum_bytes = Vec::new();
204 let mut type_buf = [0u8; 1];
205 match reader.read_exact(&mut type_buf) {
206 Ok(()) => checksum_bytes.extend_from_slice(&type_buf),
207 Err(err) if err.kind() == io::ErrorKind::UnexpectedEof => return Ok(None),
208 Err(err) => return Err(err),
209 }
210
211 let record_type = MainWalRecordType::from_u8(type_buf[0])
212 .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "Invalid record type"))?;
213
214 let term = match format_version {
215 WAL_FILE_VERSION => read_u64_tracked(reader, &mut checksum_bytes)?,
216 WAL_FILE_VERSION_V2 => default_term,
217 _ => {
218 return Err(io::Error::new(
219 io::ErrorKind::InvalidData,
220 format!("Unsupported WAL version: {format_version}"),
221 ));
222 }
223 };
224
225 let frame = match record_type {
226 MainWalRecordType::Begin => MainWalRecordFrame::Begin {
227 tx_id: read_u64_tracked(reader, &mut checksum_bytes)?,
228 },
229 MainWalRecordType::Commit => MainWalRecordFrame::Commit {
230 tx_id: read_u64_tracked(reader, &mut checksum_bytes)?,
231 },
232 MainWalRecordType::Rollback => MainWalRecordFrame::Rollback {
233 tx_id: read_u64_tracked(reader, &mut checksum_bytes)?,
234 },
235 MainWalRecordType::PageWrite => {
236 let tx_id = read_u64_tracked(reader, &mut checksum_bytes)?;
237 let page_id = read_u32_tracked(reader, &mut checksum_bytes)?;
238 let data = read_bytes_tracked(reader, &mut checksum_bytes)?;
239 MainWalRecordFrame::PageWrite {
240 tx_id,
241 page_id,
242 data,
243 }
244 }
245 MainWalRecordType::PageWriteCompressed => {
246 let tx_id = read_u64_tracked(reader, &mut checksum_bytes)?;
247 let page_id = read_u32_tracked(reader, &mut checksum_bytes)?;
248 let compression = read_compression_tracked(reader, &mut checksum_bytes)?;
249 let original_len = read_u32_tracked(reader, &mut checksum_bytes)? as usize;
250 let compressed = read_bytes_tracked(reader, &mut checksum_bytes)?;
251 let data = match compression {
252 MainWalCompression::Zstd => {
253 let mut out = vec![0u8; original_len];
254 zstd::bulk::decompress_to_buffer(&compressed, &mut out).map_err(|err| {
255 io::Error::new(
256 io::ErrorKind::InvalidData,
257 format!("WAL zstd decompress failed: {err}"),
258 )
259 })?;
260 out
261 }
262 MainWalCompression::None => compressed,
263 };
264 MainWalRecordFrame::PageWrite {
265 tx_id,
266 page_id,
267 data,
268 }
269 }
270 MainWalRecordType::TxCommitBatch => {
271 let tx_id = read_u64_tracked(reader, &mut checksum_bytes)?;
272 let count = read_u32_tracked(reader, &mut checksum_bytes)? as usize;
273 let mut actions = Vec::with_capacity(count);
274 for _ in 0..count {
275 actions.push(read_bytes_tracked(reader, &mut checksum_bytes)?);
276 }
277 MainWalRecordFrame::TxCommitBatch { tx_id, actions }
278 }
279 MainWalRecordType::FullPageImage => {
280 let tx_id = read_u64_tracked(reader, &mut checksum_bytes)?;
281 let page_id = read_u32_tracked(reader, &mut checksum_bytes)?;
282 let ckpt_epoch = read_u64_tracked(reader, &mut checksum_bytes)?;
283 let data = read_bytes_tracked(reader, &mut checksum_bytes)?;
284 MainWalRecordFrame::FullPageImage {
285 tx_id,
286 page_id,
287 ckpt_epoch,
288 data,
289 }
290 }
291 MainWalRecordType::VectorInsert => {
292 let collection = String::from_utf8(read_bytes_tracked(reader, &mut checksum_bytes)?)
293 .map_err(|err| {
294 io::Error::new(
295 io::ErrorKind::InvalidData,
296 format!("invalid collection utf8: {err}"),
297 )
298 })?;
299 let entity_id = read_u64_tracked(reader, &mut checksum_bytes)?;
300 let count = read_u32_tracked(reader, &mut checksum_bytes)? as usize;
301 let mut vector = Vec::with_capacity(count);
302 for _ in 0..count {
303 vector.push(f32::from_le_bytes(read_array_tracked(
304 reader,
305 &mut checksum_bytes,
306 )?));
307 }
308 MainWalRecordFrame::VectorInsert {
309 collection,
310 entity_id,
311 vector,
312 }
313 }
314 MainWalRecordType::Checkpoint => MainWalRecordFrame::Checkpoint {
315 lsn: read_u64_tracked(reader, &mut checksum_bytes)?,
316 },
317 };
318
319 let stored_crc = read_u32_untracked(reader)?;
320 if crc32(&checksum_bytes) != stored_crc {
321 return Err(io::Error::new(
322 io::ErrorKind::InvalidData,
323 "WAL record checksum mismatch",
324 ));
325 }
326
327 Ok(Some((term, frame)))
328}
329
330fn write_type_and_term(out: &mut Vec<u8>, record_type: MainWalRecordType, term: u64) {
331 out.push(record_type as u8);
332 out.extend_from_slice(&term.to_le_bytes());
333}
334
/// Appends `len` as a little-endian `u32`.
///
/// # Errors
///
/// Returns `InvalidInput` with `label` as the message when `len` exceeds `u32::MAX`; `out` is
/// left untouched in that case.
fn write_u32_len(out: &mut Vec<u8>, len: usize, label: &'static str) -> io::Result<()> {
    match u32::try_from(len) {
        Ok(encoded) => {
            out.extend_from_slice(&encoded.to_le_bytes());
            Ok(())
        }
        Err(_) => Err(io::Error::new(io::ErrorKind::InvalidInput, label)),
    }
}
340
341fn append_crc(out: &mut Vec<u8>, start: usize) {
342 let checksum = crc32(&out[start..]);
343 out.extend_from_slice(&checksum.to_le_bytes());
344}
345
346fn crc32(bytes: &[u8]) -> u32 {
347 let mut hasher = crc32fast::Hasher::new();
348 hasher.update(bytes);
349 hasher.finalize()
350}
351
352fn read_compression_tracked<R: Read>(
353 reader: &mut R,
354 checksum_bytes: &mut Vec<u8>,
355) -> io::Result<MainWalCompression> {
356 let value = read_array_tracked::<_, 1>(reader, checksum_bytes)?[0];
357 MainWalCompression::from_u8(value).ok_or_else(|| {
358 io::Error::new(
359 io::ErrorKind::InvalidData,
360 format!("Unknown WAL compression algorithm: {value}"),
361 )
362 })
363}
364
365fn read_bytes_tracked<R: Read>(
366 reader: &mut R,
367 checksum_bytes: &mut Vec<u8>,
368) -> io::Result<Vec<u8>> {
369 let len = read_u32_tracked(reader, checksum_bytes)? as usize;
370 let mut bytes = vec![0u8; len];
371 reader.read_exact(&mut bytes)?;
372 checksum_bytes.extend_from_slice(&bytes);
373 Ok(bytes)
374}
375
376fn read_u64_tracked<R: Read>(reader: &mut R, checksum_bytes: &mut Vec<u8>) -> io::Result<u64> {
377 Ok(u64::from_le_bytes(read_array_tracked(
378 reader,
379 checksum_bytes,
380 )?))
381}
382
383fn read_u32_tracked<R: Read>(reader: &mut R, checksum_bytes: &mut Vec<u8>) -> io::Result<u32> {
384 Ok(u32::from_le_bytes(read_array_tracked(
385 reader,
386 checksum_bytes,
387 )?))
388}
389
/// Reads exactly `N` bytes and, only on success, appends them to `checksum_bytes`.
///
/// # Errors
///
/// Propagates the error from `read_exact` (e.g. `UnexpectedEof` on a short read), leaving
/// `checksum_bytes` unchanged.
fn read_array_tracked<R: Read, const N: usize>(
    reader: &mut R,
    checksum_bytes: &mut Vec<u8>,
) -> io::Result<[u8; N]> {
    let mut buf = [0u8; N];
    reader.read_exact(&mut buf).map(|()| {
        checksum_bytes.extend_from_slice(&buf);
        buf
    })
}
399
/// Reads a little-endian `u32` without recording it for the checksum (used for the stored CRC
/// itself).
fn read_u32_untracked<R: Read>(reader: &mut R) -> io::Result<u32> {
    let mut raw = [0u8; std::mem::size_of::<u32>()];
    reader.read_exact(&mut raw).map(|()| u32::from_le_bytes(raw))
}
405
#[cfg(test)]
mod tests {
    // Unit tests for the main-WAL record codec: round trips, legacy format, corruption and
    // truncation handling.
    use super::*;
    use std::io::{self, Cursor};

    /// Returns some format version the decoder does not support.
    fn unsupported_version() -> u8 {
        (0..=u8::MAX)
            .find(|v| *v != WAL_FILE_VERSION && *v != WAL_FILE_VERSION_V2)
            .expect("u8 has more than two values")
    }

    #[test]
    fn main_wal_record_types_are_stable() {
        assert_eq!(
            MainWalRecordType::from_u8(1),
            Some(MainWalRecordType::Begin)
        );
        assert_eq!(
            MainWalRecordType::from_u8(9),
            Some(MainWalRecordType::VectorInsert)
        );
        assert_eq!(MainWalRecordType::from_u8(10), None);
    }

    #[test]
    fn main_wal_records_round_trip_current_format() {
        let frames = vec![
            MainWalRecordFrame::Begin { tx_id: 1 },
            MainWalRecordFrame::Commit { tx_id: 2 },
            MainWalRecordFrame::Rollback { tx_id: 3 },
            MainWalRecordFrame::Checkpoint { lsn: 4 },
            MainWalRecordFrame::PageWrite {
                tx_id: 5,
                page_id: 6,
                data: vec![1, 2, 3],
            },
            MainWalRecordFrame::TxCommitBatch {
                tx_id: 7,
                actions: vec![b"insert".to_vec(), b"update".to_vec()],
            },
            MainWalRecordFrame::FullPageImage {
                tx_id: 8,
                page_id: 9,
                ckpt_epoch: 10,
                data: vec![0xAA; 128],
            },
            MainWalRecordFrame::VectorInsert {
                collection: "vectors".into(),
                entity_id: 11,
                vector: vec![1.0, -0.5, 0.25],
            },
        ];

        for frame in frames {
            let encoded = encode_main_wal_record_frame(&frame, 42).unwrap();
            let mut cursor = Cursor::new(encoded);
            let (term, decoded) = decode_main_wal_record_frame(&mut cursor, WAL_FILE_VERSION, 0)
                .unwrap()
                .unwrap();
            assert_eq!(term, 42);
            assert_eq!(decoded, frame);
        }
    }

    #[test]
    fn main_wal_record_accepts_legacy_v2_without_term() {
        let mut encoded = Vec::new();
        encoded.push(MainWalRecordType::Begin as u8);
        encoded.extend_from_slice(&42u64.to_le_bytes());
        let checksum = crc32(&encoded);
        encoded.extend_from_slice(&checksum.to_le_bytes());

        let mut cursor = Cursor::new(encoded);
        let (term, frame) = decode_main_wal_record_frame(&mut cursor, WAL_FILE_VERSION_V2, 99)
            .unwrap()
            .unwrap();
        assert_eq!(term, 99);
        assert_eq!(frame, MainWalRecordFrame::Begin { tx_id: 42 });
    }

    #[test]
    fn main_wal_record_detects_checksum_mismatch() {
        let frame = MainWalRecordFrame::Begin { tx_id: 42 };
        let mut encoded = encode_main_wal_record_frame(&frame, 1).unwrap();
        let last = encoded.len() - 1;
        encoded[last] ^= 0xFF;

        let mut cursor = Cursor::new(encoded);
        assert_eq!(
            decode_main_wal_record_frame(&mut cursor, WAL_FILE_VERSION, 0)
                .unwrap_err()
                .to_string(),
            "WAL record checksum mismatch"
        );
    }

    #[test]
    fn main_wal_record_compresses_and_decompresses_page_writes() {
        let frame = MainWalRecordFrame::PageWrite {
            tx_id: 7,
            page_id: 3,
            data: vec![0xAB; 1024],
        };
        let encoded = encode_main_wal_record_frame(&frame, 1).unwrap();
        assert_eq!(encoded[0], MainWalRecordType::PageWriteCompressed as u8);

        let mut cursor = Cursor::new(encoded);
        let (_, decoded) = decode_main_wal_record_frame(&mut cursor, WAL_FILE_VERSION, 0)
            .unwrap()
            .unwrap();
        assert_eq!(decoded, frame);
    }

    #[test]
    fn main_wal_record_keeps_small_page_writes_uncompressed() {
        let frame = MainWalRecordFrame::PageWrite {
            tx_id: 1,
            page_id: 1,
            data: vec![0xAB; MAIN_WAL_DEFAULT_COMPRESS_THRESHOLD - 1],
        };
        let encoded = encode_main_wal_record_frame(&frame, 1).unwrap();
        assert_eq!(encoded[0], MainWalRecordType::PageWrite as u8);

        let mut cursor = Cursor::new(encoded);
        let (_, decoded) = decode_main_wal_record_frame(&mut cursor, WAL_FILE_VERSION, 0)
            .unwrap()
            .unwrap();
        assert_eq!(decoded, frame);
    }

    #[test]
    fn main_wal_record_returns_none_at_clean_eof() {
        let mut cursor = Cursor::new(Vec::<u8>::new());
        assert!(decode_main_wal_record_frame(&mut cursor, WAL_FILE_VERSION, 0)
            .unwrap()
            .is_none());
    }

    #[test]
    fn main_wal_record_rejects_unknown_record_type() {
        let mut cursor = Cursor::new(vec![0xFFu8]);
        let err = decode_main_wal_record_frame(&mut cursor, WAL_FILE_VERSION, 0).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
        assert_eq!(err.to_string(), "Invalid record type");
    }

    #[test]
    fn main_wal_record_rejects_unsupported_version() {
        let encoded =
            encode_main_wal_record_frame(&MainWalRecordFrame::Begin { tx_id: 1 }, 1).unwrap();
        let mut cursor = Cursor::new(encoded);
        let err =
            decode_main_wal_record_frame(&mut cursor, unsupported_version(), 0).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
    }

    #[test]
    fn main_wal_record_reports_truncated_frame() {
        let mut encoded =
            encode_main_wal_record_frame(&MainWalRecordFrame::Commit { tx_id: 9 }, 3).unwrap();
        let torn_len = encoded.len() - 2;
        encoded.truncate(torn_len);

        let mut cursor = Cursor::new(encoded);
        let err = decode_main_wal_record_frame(&mut cursor, WAL_FILE_VERSION, 0).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof);
    }

    #[test]
    fn main_wal_records_decode_back_to_back() {
        let frames = vec![
            MainWalRecordFrame::Begin { tx_id: 1 },
            MainWalRecordFrame::PageWrite {
                tx_id: 1,
                page_id: 2,
                data: vec![0x11; 512],
            },
            MainWalRecordFrame::Commit { tx_id: 1 },
        ];
        let mut log = Vec::new();
        for (term, frame) in frames.iter().enumerate() {
            encode_main_wal_record_frame_into(frame, term as u64, &mut log).unwrap();
        }

        let mut cursor = Cursor::new(log);
        for (term, frame) in frames.iter().enumerate() {
            let (decoded_term, decoded) =
                decode_main_wal_record_frame(&mut cursor, WAL_FILE_VERSION, 0)
                    .unwrap()
                    .unwrap();
            assert_eq!(decoded_term, term as u64);
            assert_eq!(&decoded, frame);
        }
        assert!(decode_main_wal_record_frame(&mut cursor, WAL_FILE_VERSION, 0)
            .unwrap()
            .is_none());
    }
}
512}