1use crate::error::{KernelError, KernelResult, WalErrorKind};
24use crate::kernel_api::PageId;
25use crate::transaction::TransactionId;
26use bytes::{BufMut, Bytes, BytesMut};
27use parking_lot::{Mutex, RwLock};
28use std::collections::HashMap;
29use std::fs::{File, OpenOptions};
30use std::io::{Read, Seek, SeekFrom, Write};
31use std::path::{Path, PathBuf};
32use std::sync::atomic::{AtomicU64, Ordering};
33
/// Position of a record in the write-ahead log.
///
/// A thin newtype over `u64`. The value `u64::MAX` is reserved as the
/// "no LSN" marker (see [`LogSequenceNumber::INVALID`]).
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct LogSequenceNumber(pub u64);

impl LogSequenceNumber {
    /// Marker value meaning "no LSN".
    pub const INVALID: Self = Self(u64::MAX);

    /// Wraps a raw `u64` as an LSN.
    pub fn new(value: u64) -> Self {
        LogSequenceNumber(value)
    }

    /// Returns the raw `u64` behind this LSN.
    pub fn value(&self) -> u64 {
        let Self(raw) = *self;
        raw
    }

    /// Returns `true` unless this LSN equals [`LogSequenceNumber::INVALID`].
    pub fn is_valid(&self) -> bool {
        *self != Self::INVALID
    }
}

impl std::fmt::Display for LogSequenceNumber {
    /// Formats the LSN as `LSN(<raw value>)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let raw = self.value();
        write!(f, "LSN({})", raw)
    }
}
63
/// Kind of a write-ahead log record.
///
/// The explicit discriminants are the one-byte type tags stored in the
/// serialized record, so they must not be renumbered.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[repr(u8)]
pub enum WalRecordType {
    /// Transaction begin.
    Begin = 1,
    /// Transaction commit.
    Commit = 2,
    /// Transaction abort.
    Abort = 3,
    /// Page update.
    Update = 4,
    /// Insert.
    Insert = 5,
    /// Delete.
    Delete = 6,
    /// Compensation log record.
    Clr = 7,
    /// Start of a checkpoint.
    CheckpointBegin = 8,
    /// End of a checkpoint.
    CheckpointEnd = 9,
    /// Page allocation.
    AllocPage = 10,
    /// Page release.
    FreePage = 11,
}
99
100impl WalRecordType {
101 pub fn to_canonical(self) -> Option<sochdb_core::txn::WalRecordType> {
104 use sochdb_core::txn::WalRecordType as C;
105 match self {
106 Self::Begin => Some(C::TxnBegin),
107 Self::Commit => Some(C::TxnCommit),
108 Self::Abort => Some(C::TxnAbort),
109 Self::Update => Some(C::PageUpdate),
110 Self::Insert => Some(C::Data),
111 Self::Delete => Some(C::Delete),
112 Self::Clr => Some(C::CompensationLogRecord),
113 Self::CheckpointBegin => Some(C::Checkpoint),
114 Self::CheckpointEnd => Some(C::CheckpointEnd),
115 Self::AllocPage | Self::FreePage => None,
116 }
117 }
118
119 pub fn from_canonical(rt: sochdb_core::txn::WalRecordType) -> Option<Self> {
121 use sochdb_core::txn::WalRecordType as C;
122 match rt {
123 C::TxnBegin => Some(Self::Begin),
124 C::TxnCommit => Some(Self::Commit),
125 C::TxnAbort => Some(Self::Abort),
126 C::PageUpdate => Some(Self::Update),
127 C::Data => Some(Self::Insert),
128 C::Delete => Some(Self::Delete),
129 C::CompensationLogRecord => Some(Self::Clr),
130 C::Checkpoint => Some(Self::CheckpointBegin),
131 C::CheckpointEnd => Some(Self::CheckpointEnd),
132 _ => None,
133 }
134 }
135}
136
137impl TryFrom<u8> for WalRecordType {
138 type Error = KernelError;
139
140 fn try_from(value: u8) -> Result<Self, Self::Error> {
141 match value {
142 1 => Ok(Self::Begin),
143 2 => Ok(Self::Commit),
144 3 => Ok(Self::Abort),
145 4 => Ok(Self::Update),
146 5 => Ok(Self::Insert),
147 6 => Ok(Self::Delete),
148 7 => Ok(Self::Clr),
149 8 => Ok(Self::CheckpointBegin),
150 9 => Ok(Self::CheckpointEnd),
151 10 => Ok(Self::AllocPage),
152 11 => Ok(Self::FreePage),
153 _ => Err(KernelError::Wal {
154 kind: WalErrorKind::Corrupted,
155 }),
156 }
157 }
158}
159
160#[derive(Debug, Clone)]
162pub struct WalRecord {
163 pub lsn: LogSequenceNumber,
165 pub prev_lsn: LogSequenceNumber,
167 pub txn_id: TransactionId,
169 pub record_type: WalRecordType,
171 pub page_id: Option<PageId>,
173 pub redo_data: Bytes,
175 pub undo_data: Bytes,
177 pub checksum: u32,
179}
180
181impl WalRecord {
182 const HEADER_SIZE: usize = 45;
184
185 pub fn new(
187 lsn: LogSequenceNumber,
188 prev_lsn: LogSequenceNumber,
189 txn_id: TransactionId,
190 record_type: WalRecordType,
191 page_id: Option<PageId>,
192 redo_data: Bytes,
193 undo_data: Bytes,
194 ) -> Self {
195 let mut record = Self {
196 lsn,
197 prev_lsn,
198 txn_id,
199 record_type,
200 page_id,
201 redo_data,
202 undo_data,
203 checksum: 0,
204 };
205 record.checksum = record.compute_checksum();
206 record
207 }
208
209 pub fn serialize(&self) -> Bytes {
211 let mut buf = BytesMut::with_capacity(
212 Self::HEADER_SIZE + self.redo_data.len() + self.undo_data.len(),
213 );
214
215 buf.put_u64_le(self.lsn.0);
216 buf.put_u64_le(self.prev_lsn.0);
217 buf.put_u64_le(self.txn_id);
218 buf.put_u8(self.record_type as u8);
219 buf.put_u64_le(self.page_id.unwrap_or(0));
220 buf.put_u32_le(self.redo_data.len() as u32);
221 buf.put_u32_le(self.undo_data.len() as u32);
222 buf.put_slice(&self.redo_data);
223 buf.put_slice(&self.undo_data);
224 buf.put_u32_le(self.checksum);
225
226 buf.freeze()
227 }
228
229 pub fn deserialize(data: &[u8]) -> KernelResult<Self> {
231 if data.len() < Self::HEADER_SIZE {
232 return Err(KernelError::Wal {
233 kind: WalErrorKind::Corrupted,
234 });
235 }
236
237 let lsn = LogSequenceNumber(u64::from_le_bytes(data[0..8].try_into().unwrap()));
238 let prev_lsn = LogSequenceNumber(u64::from_le_bytes(data[8..16].try_into().unwrap()));
239 let txn_id = u64::from_le_bytes(data[16..24].try_into().unwrap());
240 let record_type = WalRecordType::try_from(data[24])?;
241 let page_id_raw = u64::from_le_bytes(data[25..33].try_into().unwrap());
242 let page_id = if page_id_raw == 0 {
243 None
244 } else {
245 Some(page_id_raw)
246 };
247 let redo_len = u32::from_le_bytes(data[33..37].try_into().unwrap()) as usize;
248 let undo_len = u32::from_le_bytes(data[37..41].try_into().unwrap()) as usize;
249
250 let expected_len = Self::HEADER_SIZE + redo_len + undo_len;
251 if data.len() < expected_len {
252 return Err(KernelError::Wal {
253 kind: WalErrorKind::Corrupted,
254 });
255 }
256
257 let redo_start = 41;
258 let redo_data = Bytes::copy_from_slice(&data[redo_start..redo_start + redo_len]);
259 let undo_start = redo_start + redo_len;
260 let undo_data = Bytes::copy_from_slice(&data[undo_start..undo_start + undo_len]);
261 let checksum_start = undo_start + undo_len;
262 let checksum =
263 u32::from_le_bytes(data[checksum_start..checksum_start + 4].try_into().unwrap());
264
265 let record = Self {
266 lsn,
267 prev_lsn,
268 txn_id,
269 record_type,
270 page_id,
271 redo_data,
272 undo_data,
273 checksum,
274 };
275
276 let computed = record.compute_checksum();
278 if computed != checksum {
279 return Err(KernelError::Wal {
280 kind: WalErrorKind::ChecksumMismatch {
281 expected: checksum,
282 actual: computed,
283 },
284 });
285 }
286
287 Ok(record)
288 }
289
290 fn compute_checksum(&self) -> u32 {
292 let mut hasher = crc32fast::Hasher::new();
293 hasher.update(&self.lsn.0.to_le_bytes());
294 hasher.update(&self.prev_lsn.0.to_le_bytes());
295 hasher.update(&self.txn_id.to_le_bytes());
296 hasher.update(&[self.record_type as u8]);
297 hasher.update(&self.page_id.unwrap_or(0).to_le_bytes());
298 hasher.update(&self.redo_data);
299 hasher.update(&self.undo_data);
300 hasher.finalize()
301 }
302
303 pub fn size(&self) -> usize {
305 Self::HEADER_SIZE + self.redo_data.len() + self.undo_data.len()
306 }
307}
308
309pub struct WalManager {
313 path: PathBuf,
315 file: Mutex<File>,
317 next_lsn: AtomicU64,
319 durable_lsn: AtomicU64,
321 txn_last_lsn: RwLock<HashMap<TransactionId, LogSequenceNumber>>,
323 checkpoint_lsn: AtomicU64,
325 write_buffer: Mutex<BytesMut>,
327 buffer_threshold: usize,
329}
330
331impl WalManager {
332 const DEFAULT_BUFFER_THRESHOLD: usize = 64 * 1024;
334
335 pub fn open(path: impl AsRef<Path>) -> KernelResult<Self> {
337 let path = path.as_ref().to_path_buf();
338
339 let file = OpenOptions::new()
340 .read(true)
341 .write(true)
342 .create(true)
343 .truncate(false)
344 .open(&path)?;
345
346 let file_len = file.metadata()?.len();
347 let next_lsn = file_len;
349
350 Ok(Self {
351 path,
352 file: Mutex::new(file),
353 next_lsn: AtomicU64::new(next_lsn),
354 durable_lsn: AtomicU64::new(if file_len > 0 { file_len } else { 0 }),
355 txn_last_lsn: RwLock::new(HashMap::new()),
356 checkpoint_lsn: AtomicU64::new(0),
357 write_buffer: Mutex::new(BytesMut::with_capacity(Self::DEFAULT_BUFFER_THRESHOLD)),
358 buffer_threshold: Self::DEFAULT_BUFFER_THRESHOLD,
359 })
360 }
361
362 pub fn append(&self, record: &mut WalRecord) -> KernelResult<LogSequenceNumber> {
366 let lsn = LogSequenceNumber(
368 self.next_lsn
369 .fetch_add(record.size() as u64, Ordering::SeqCst),
370 );
371 record.lsn = lsn;
372
373 if let Some(&prev) = self.txn_last_lsn.read().get(&record.txn_id) {
375 record.prev_lsn = prev;
376 }
377
378 record.checksum = record.compute_checksum();
380
381 let data = record.serialize();
383
384 let mut buffer = self.write_buffer.lock();
386 buffer.extend_from_slice(&data);
387
388 self.txn_last_lsn.write().insert(record.txn_id, lsn);
390
391 if buffer.len() >= self.buffer_threshold {
393 drop(buffer);
394 self.flush()?;
395 }
396
397 Ok(lsn)
398 }
399
400 pub fn flush(&self) -> KernelResult<()> {
402 let mut buffer = self.write_buffer.lock();
403 if buffer.is_empty() {
404 return Ok(());
405 }
406
407 let data = buffer.split().freeze();
408 let mut file = self.file.lock();
409
410 file.seek(SeekFrom::End(0))?;
412 file.write_all(&data)?;
413
414 Ok(())
415 }
416
417 pub fn sync(&self) -> KernelResult<LogSequenceNumber> {
419 self.flush()?;
421
422 let file = self.file.lock();
424 file.sync_all()?;
425
426 let current_lsn = self.next_lsn.load(Ordering::SeqCst);
428 self.durable_lsn.store(current_lsn, Ordering::SeqCst);
429
430 Ok(LogSequenceNumber(current_lsn))
431 }
432
433 pub fn durable_lsn(&self) -> LogSequenceNumber {
435 LogSequenceNumber(self.durable_lsn.load(Ordering::SeqCst))
436 }
437
438 pub fn next_lsn(&self) -> LogSequenceNumber {
440 LogSequenceNumber(self.next_lsn.load(Ordering::SeqCst))
441 }
442
443 pub fn log_begin(&self, txn_id: TransactionId) -> KernelResult<LogSequenceNumber> {
445 let mut record = WalRecord::new(
446 LogSequenceNumber::INVALID,
447 LogSequenceNumber::INVALID,
448 txn_id,
449 WalRecordType::Begin,
450 None,
451 Bytes::new(),
452 Bytes::new(),
453 );
454 self.append(&mut record)
455 }
456
457 pub fn log_commit(&self, txn_id: TransactionId) -> KernelResult<LogSequenceNumber> {
459 let prev_lsn = self
460 .txn_last_lsn
461 .read()
462 .get(&txn_id)
463 .copied()
464 .unwrap_or(LogSequenceNumber::INVALID);
465 let mut record = WalRecord::new(
466 LogSequenceNumber::INVALID,
467 prev_lsn,
468 txn_id,
469 WalRecordType::Commit,
470 None,
471 Bytes::new(),
472 Bytes::new(),
473 );
474 let lsn = self.append(&mut record)?;
475
476 self.sync()?;
478
479 self.txn_last_lsn.write().remove(&txn_id);
481
482 Ok(lsn)
483 }
484
485 pub fn log_abort(&self, txn_id: TransactionId) -> KernelResult<LogSequenceNumber> {
487 let prev_lsn = self
488 .txn_last_lsn
489 .read()
490 .get(&txn_id)
491 .copied()
492 .unwrap_or(LogSequenceNumber::INVALID);
493 let mut record = WalRecord::new(
494 LogSequenceNumber::INVALID,
495 prev_lsn,
496 txn_id,
497 WalRecordType::Abort,
498 None,
499 Bytes::new(),
500 Bytes::new(),
501 );
502 let lsn = self.append(&mut record)?;
503
504 self.txn_last_lsn.write().remove(&txn_id);
506
507 Ok(lsn)
508 }
509
510 pub fn log_update(
512 &self,
513 txn_id: TransactionId,
514 page_id: PageId,
515 redo_data: Bytes,
516 undo_data: Bytes,
517 ) -> KernelResult<LogSequenceNumber> {
518 let prev_lsn = self
519 .txn_last_lsn
520 .read()
521 .get(&txn_id)
522 .copied()
523 .unwrap_or(LogSequenceNumber::INVALID);
524 let mut record = WalRecord::new(
525 LogSequenceNumber::INVALID,
526 prev_lsn,
527 txn_id,
528 WalRecordType::Update,
529 Some(page_id),
530 redo_data,
531 undo_data,
532 );
533 self.append(&mut record)
534 }
535
536 pub fn log_checkpoint_begin(&self) -> KernelResult<LogSequenceNumber> {
538 let mut record = WalRecord::new(
539 LogSequenceNumber::INVALID,
540 LogSequenceNumber::INVALID,
541 0, WalRecordType::CheckpointBegin,
543 None,
544 Bytes::new(),
545 Bytes::new(),
546 );
547 self.append(&mut record)
548 }
549
550 pub fn log_checkpoint_end(
552 &self,
553 active_txns: &[TransactionId],
554 ) -> KernelResult<LogSequenceNumber> {
555 let mut redo_data = BytesMut::with_capacity(active_txns.len() * 8);
557 for &txn_id in active_txns {
558 redo_data.put_u64_le(txn_id);
559 }
560
561 let mut record = WalRecord::new(
562 LogSequenceNumber::INVALID,
563 LogSequenceNumber::INVALID,
564 0, WalRecordType::CheckpointEnd,
566 None,
567 redo_data.freeze(),
568 Bytes::new(),
569 );
570 let lsn = self.append(&mut record)?;
571
572 self.sync()?;
574
575 self.checkpoint_lsn.store(lsn.0, Ordering::SeqCst);
577
578 Ok(lsn)
579 }
580
581 pub fn checkpoint_lsn(&self) -> Option<LogSequenceNumber> {
583 let lsn = self.checkpoint_lsn.load(Ordering::SeqCst);
584 if lsn == 0 {
585 None
586 } else {
587 Some(LogSequenceNumber(lsn))
588 }
589 }
590
591 pub fn read_from(&self, start_lsn: LogSequenceNumber) -> KernelResult<Vec<WalRecord>> {
593 self.flush()?;
595
596 let mut file = self.file.lock();
597 let file_len = file.metadata()?.len();
598
599 if start_lsn.0 >= file_len {
600 return Ok(Vec::new());
601 }
602
603 file.seek(SeekFrom::Start(start_lsn.0))?;
604
605 let mut buffer = vec![0u8; (file_len - start_lsn.0) as usize];
606 file.read_exact(&mut buffer)?;
607
608 let mut records = Vec::new();
609 let mut offset = 0;
610
611 while offset < buffer.len() {
612 match WalRecord::deserialize(&buffer[offset..]) {
613 Ok(record) => {
614 let size = record.size();
615 records.push(record);
616 offset += size;
617 }
618 Err(_) => {
619 break;
621 }
622 }
623 }
624
625 Ok(records)
626 }
627
628 pub fn path(&self) -> &Path {
630 &self.path
631 }
632
633 pub fn truncate_before(&self, _lsn: LogSequenceNumber) -> KernelResult<()> {
635 Ok(())
638 }
639}
640
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// A record survives a serialize/deserialize round trip field by field.
    #[test]
    fn test_wal_record_serialize_deserialize() {
        let original = WalRecord::new(
            LogSequenceNumber(100),
            LogSequenceNumber(50),
            1,
            WalRecordType::Update,
            Some(42),
            Bytes::from_static(b"redo data"),
            Bytes::from_static(b"undo data"),
        );

        let decoded = WalRecord::deserialize(&original.serialize()).unwrap();

        assert_eq!(original.lsn, decoded.lsn);
        assert_eq!(original.prev_lsn, decoded.prev_lsn);
        assert_eq!(original.txn_id, decoded.txn_id);
        assert_eq!(original.record_type, decoded.record_type);
        assert_eq!(original.page_id, decoded.page_id);
        assert_eq!(original.redo_data, decoded.redo_data);
        assert_eq!(original.undo_data, decoded.undo_data);
    }

    /// Appended records get increasing LSNs and `sync` covers all of them.
    #[test]
    fn test_wal_manager_append_sync() {
        let scratch = tempdir().unwrap();
        let wal = WalManager::open(scratch.path().join("test.wal")).unwrap();

        let begin_lsn = wal.log_begin(1).unwrap();
        assert!(begin_lsn.is_valid());

        let redo = Bytes::from_static(b"new value");
        let undo = Bytes::from_static(b"old value");
        let update_lsn = wal.log_update(1, 100, redo, undo).unwrap();
        assert!(update_lsn > begin_lsn);

        let durable = wal.sync().unwrap();
        assert!(durable >= update_lsn);
    }

    /// Records written by one manager can be read back by a fresh one.
    #[test]
    fn test_wal_recovery() {
        let scratch = tempdir().unwrap();
        let wal_path = scratch.path().join("test.wal");

        // First session: log a complete transaction, then drop the manager.
        let writer = WalManager::open(&wal_path).unwrap();
        let first_lsn = writer.log_begin(1).unwrap();
        writer
            .log_update(1, 100, Bytes::from_static(b"data"), Bytes::new())
            .unwrap();
        writer.log_commit(1).unwrap();
        drop(writer);

        // Second session: reopen and replay from the first record.
        let reader = WalManager::open(&wal_path).unwrap();
        let records = reader.read_from(first_lsn).unwrap();

        assert!(
            records.len() >= 3,
            "Expected at least 3 records, got {}",
            records.len()
        );
        let expected_types = [
            WalRecordType::Begin,
            WalRecordType::Update,
            WalRecordType::Commit,
        ];
        for (record, expected) in records.iter().zip(expected_types) {
            assert_eq!(record.record_type, expected);
        }
    }
}