Skip to main content

aegis_storage/
wal.rs

1//! Aegis WAL - Write-Ahead Logging
2//!
3//! Write-ahead log for durability and crash recovery. All modifications are
4//! logged before being applied to data pages, ensuring ACID durability
5//! guarantees even in the face of system failures.
6//!
7//! Key Features:
8//! - Sequential write optimization for high throughput
9//! - Log sequence numbers (LSN) for ordering and recovery
10//! - Checkpoint support for recovery time optimization
11//! - Segment-based log file management
12//!
13//! @version 0.1.0
14//! @author AutomataNexus Development Team
15
16use aegis_common::{AegisError, Lsn, PageId, Result, TransactionId};
17use bytes::{Buf, BufMut, Bytes, BytesMut};
18use parking_lot::Mutex;
19use serde::{Deserialize, Serialize};
20use std::collections::{HashMap, HashSet, VecDeque};
21use std::fs::{File, OpenOptions};
22use std::io::{BufReader, BufWriter, Read, Write};
23use std::path::PathBuf;
24use std::sync::atomic::{AtomicU64, Ordering};
25
26// =============================================================================
27// Constants
28// =============================================================================
29
/// Maximum size in bytes of one WAL segment file (64 MiB). `append` rotates to
/// a fresh segment before a write would push the active segment past this.
pub const WAL_SEGMENT_SIZE: u64 = 64 << 20;
/// Size in bytes of the fixed record header that precedes the variable data:
/// lsn(8) + prev_lsn(8) + tx_id(8) + type(1) + has_page(1) + padding(2) + page_id(8) + data_len(4) = 40
pub const WAL_RECORD_HEADER_SIZE: usize = 8 + 8 + 8 + 1 + 1 + 2 + 8 + 4;
34
35// =============================================================================
36// Log Record Types
37// =============================================================================
38
39/// Type of WAL record.
40#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
41#[repr(u8)]
42pub enum LogRecordType {
43    Begin = 1,
44    Commit = 2,
45    Abort = 3,
46    Insert = 4,
47    Update = 5,
48    Delete = 6,
49    Checkpoint = 7,
50    CompensationLogRecord = 8,
51}
52
53impl From<u8> for LogRecordType {
54    fn from(value: u8) -> Self {
55        match value {
56            1 => LogRecordType::Begin,
57            2 => LogRecordType::Commit,
58            3 => LogRecordType::Abort,
59            4 => LogRecordType::Insert,
60            5 => LogRecordType::Update,
61            6 => LogRecordType::Delete,
62            7 => LogRecordType::Checkpoint,
63            8 => LogRecordType::CompensationLogRecord,
64            _ => LogRecordType::Begin,
65        }
66    }
67}
68
69// =============================================================================
70// Log Record
71// =============================================================================
72
73/// A single record in the write-ahead log.
74#[derive(Debug, Clone)]
75pub struct LogRecord {
76    pub lsn: Lsn,
77    pub prev_lsn: Option<Lsn>,
78    pub tx_id: TransactionId,
79    pub record_type: LogRecordType,
80    pub page_id: Option<PageId>,
81    pub data: Bytes,
82}
83
84impl LogRecord {
85    /// Create a new transaction begin record.
86    pub fn begin(lsn: Lsn, tx_id: TransactionId) -> Self {
87        Self {
88            lsn,
89            prev_lsn: None,
90            tx_id,
91            record_type: LogRecordType::Begin,
92            page_id: None,
93            data: Bytes::new(),
94        }
95    }
96
97    /// Create a new transaction commit record.
98    pub fn commit(lsn: Lsn, prev_lsn: Lsn, tx_id: TransactionId) -> Self {
99        Self {
100            lsn,
101            prev_lsn: Some(prev_lsn),
102            tx_id,
103            record_type: LogRecordType::Commit,
104            page_id: None,
105            data: Bytes::new(),
106        }
107    }
108
109    /// Create a new transaction abort record.
110    pub fn abort(lsn: Lsn, prev_lsn: Lsn, tx_id: TransactionId) -> Self {
111        Self {
112            lsn,
113            prev_lsn: Some(prev_lsn),
114            tx_id,
115            record_type: LogRecordType::Abort,
116            page_id: None,
117            data: Bytes::new(),
118        }
119    }
120
121    /// Create a data modification record.
122    pub fn data_record(
123        lsn: Lsn,
124        prev_lsn: Option<Lsn>,
125        tx_id: TransactionId,
126        record_type: LogRecordType,
127        page_id: PageId,
128        data: Bytes,
129    ) -> Self {
130        Self {
131            lsn,
132            prev_lsn,
133            tx_id,
134            record_type,
135            page_id: Some(page_id),
136            data,
137        }
138    }
139
140    /// Serialize the record to bytes.
141    pub fn to_bytes(&self) -> Bytes {
142        // header(40) + data + checksum(4)
143        let mut buf = BytesMut::with_capacity(WAL_RECORD_HEADER_SIZE + self.data.len() + 4);
144
145        buf.put_u64_le(self.lsn.0);
146        buf.put_u64_le(self.prev_lsn.map_or(0, |l| l.0));
147        buf.put_u64_le(self.tx_id.0);
148        buf.put_u8(self.record_type as u8);
149        buf.put_u8(self.page_id.is_some() as u8);
150        buf.put_u16_le(0); // padding
151        buf.put_u64_le(self.page_id.map_or(0, |p| p.0));
152        buf.put_u32_le(self.data.len() as u32);
153        buf.put(self.data.clone());
154
155        let checksum = crc32fast::hash(&buf);
156        buf.put_u32_le(checksum);
157
158        buf.freeze()
159    }
160
161    /// Deserialize a record from bytes.
162    pub fn from_bytes(data: &[u8]) -> Result<Self> {
163        // Minimum size: header(40) + checksum(4)
164        if data.len() < WAL_RECORD_HEADER_SIZE + 4 {
165            return Err(AegisError::Corruption("Log record too small".to_string()));
166        }
167
168        let mut buf = data;
169        let lsn = Lsn(buf.get_u64_le());
170        let prev_lsn_raw = buf.get_u64_le();
171        let prev_lsn = if prev_lsn_raw == 0 {
172            None
173        } else {
174            Some(Lsn(prev_lsn_raw))
175        };
176        let tx_id = TransactionId(buf.get_u64_le());
177        let record_type = LogRecordType::from(buf.get_u8());
178        let has_page_id = buf.get_u8() != 0;
179        let _padding = buf.get_u16_le();
180        let page_id_raw = buf.get_u64_le();
181        let page_id = if has_page_id {
182            Some(PageId(page_id_raw))
183        } else {
184            None
185        };
186        let data_len = buf.get_u32_le() as usize;
187
188        if buf.remaining() < data_len + 4 {
189            return Err(AegisError::Corruption(
190                "Log record data truncated".to_string(),
191            ));
192        }
193
194        let record_data = Bytes::copy_from_slice(&buf[..data_len]);
195        buf.advance(data_len);
196
197        let stored_checksum = buf.get_u32_le();
198        let computed_checksum = crc32fast::hash(&data[..data.len() - 4]);
199
200        if stored_checksum != computed_checksum {
201            return Err(AegisError::Corruption(
202                "Log record checksum mismatch".to_string(),
203            ));
204        }
205
206        Ok(Self {
207            lsn,
208            prev_lsn,
209            tx_id,
210            record_type,
211            page_id,
212            data: record_data,
213        })
214    }
215}
216
217// =============================================================================
218// Write-Ahead Log
219// =============================================================================
220
221/// Write-ahead log for durability with segment rotation and crash recovery.
222pub struct WriteAheadLog {
223    wal_dir: PathBuf,
224    current_lsn: AtomicU64,
225    flushed_lsn: AtomicU64,
226    /// LSN of the last checkpoint
227    checkpoint_lsn: AtomicU64,
228    buffer: Mutex<WalBuffer>,
229    sync_on_commit: bool,
230}
231
232struct WalBuffer {
233    records: VecDeque<LogRecord>,
234    size: usize,
235    writer: Option<BufWriter<File>>,
236    segment_offset: u64,
237    /// Current segment number
238    current_segment: u64,
239}
240
241/// Result of WAL recovery.
242#[derive(Debug)]
243pub struct RecoveryResult {
244    /// Records that need to be redone (committed transactions)
245    pub redo_records: Vec<LogRecord>,
246    /// Transaction IDs that were in-progress (need rollback)
247    pub incomplete_transactions: HashSet<TransactionId>,
248    /// The highest LSN found during recovery
249    pub max_lsn: Lsn,
250    /// Number of records processed
251    pub records_processed: usize,
252    /// Number of segments scanned
253    pub segments_scanned: usize,
254}
255
256impl Default for RecoveryResult {
257    fn default() -> Self {
258        Self {
259            redo_records: Vec::new(),
260            incomplete_transactions: HashSet::new(),
261            max_lsn: Lsn(0),
262            records_processed: 0,
263            segments_scanned: 0,
264        }
265    }
266}
267
268/// Checkpoint data stored in checkpoint records.
269#[derive(Debug, Clone, Serialize, Deserialize)]
270pub struct CheckpointData {
271    /// Active transaction IDs at checkpoint time
272    pub active_transactions: Vec<TransactionId>,
273    /// Dirty page IDs that need to be flushed
274    pub dirty_pages: Vec<PageId>,
275}
276
277impl WriteAheadLog {
278    /// Create a new WAL in the specified directory.
279    pub fn new(wal_dir: PathBuf, sync_on_commit: bool) -> Result<Self> {
280        std::fs::create_dir_all(&wal_dir)?;
281
282        // Find the highest segment number
283        let current_segment = Self::find_latest_segment(&wal_dir)?;
284
285        let wal = Self {
286            wal_dir,
287            current_lsn: AtomicU64::new(1),
288            flushed_lsn: AtomicU64::new(0),
289            checkpoint_lsn: AtomicU64::new(0),
290            buffer: Mutex::new(WalBuffer {
291                records: VecDeque::new(),
292                size: 0,
293                writer: None,
294                segment_offset: 0,
295                current_segment,
296            }),
297            sync_on_commit,
298        };
299
300        wal.open_segment(current_segment)?;
301        Ok(wal)
302    }
303
304    /// Create a WAL and perform recovery from existing log files.
305    pub fn open_and_recover(
306        wal_dir: PathBuf,
307        sync_on_commit: bool,
308    ) -> Result<(Self, RecoveryResult)> {
309        std::fs::create_dir_all(&wal_dir)?;
310
311        // Perform recovery first
312        let recovery = Self::recover_from_directory(&wal_dir)?;
313
314        // Create WAL starting from recovered LSN
315        let current_segment = Self::find_latest_segment(&wal_dir)?;
316        let next_lsn = recovery.max_lsn.0.saturating_add(1).max(1);
317
318        let wal = Self {
319            wal_dir,
320            current_lsn: AtomicU64::new(next_lsn),
321            flushed_lsn: AtomicU64::new(recovery.max_lsn.0),
322            checkpoint_lsn: AtomicU64::new(0),
323            buffer: Mutex::new(WalBuffer {
324                records: VecDeque::new(),
325                size: 0,
326                writer: None,
327                segment_offset: 0,
328                current_segment,
329            }),
330            sync_on_commit,
331        };
332
333        wal.open_segment(current_segment)?;
334        Ok((wal, recovery))
335    }
336
337    /// Find the latest segment number in the WAL directory.
338    fn find_latest_segment(wal_dir: &PathBuf) -> Result<u64> {
339        let mut max_segment = 0u64;
340
341        if let Ok(entries) = std::fs::read_dir(wal_dir) {
342            for entry in entries.flatten() {
343                let path = entry.path();
344                if let Some(name) = path.file_name().and_then(|n| n.to_str()) {
345                    if let Some(num_str) = name
346                        .strip_prefix("wal_")
347                        .and_then(|s| s.strip_suffix(".log"))
348                    {
349                        if let Ok(num) = num_str.parse::<u64>() {
350                            max_segment = max_segment.max(num);
351                        }
352                    }
353                }
354            }
355        }
356
357        Ok(max_segment)
358    }
359
360    /// Allocate the next LSN.
361    pub fn next_lsn(&self) -> Lsn {
362        Lsn(self.current_lsn.fetch_add(1, Ordering::SeqCst))
363    }
364
365    /// Get the current LSN.
366    pub fn current_lsn(&self) -> Lsn {
367        Lsn(self.current_lsn.load(Ordering::SeqCst))
368    }
369
370    /// Get the flushed LSN.
371    pub fn flushed_lsn(&self) -> Lsn {
372        Lsn(self.flushed_lsn.load(Ordering::SeqCst))
373    }
374
375    /// Get the checkpoint LSN.
376    pub fn checkpoint_lsn(&self) -> Lsn {
377        Lsn(self.checkpoint_lsn.load(Ordering::SeqCst))
378    }
379
380    /// Append a log record, rotating segment if needed.
381    pub fn append(&self, record: LogRecord) -> Result<Lsn> {
382        let lsn = record.lsn;
383        let data = record.to_bytes();
384        let data_len = data.len() as u64;
385
386        let mut buffer = self.buffer.lock();
387
388        // Check if we need to rotate to a new segment
389        if buffer.segment_offset + data_len > WAL_SEGMENT_SIZE {
390            drop(buffer);
391            self.rotate_segment()?;
392            buffer = self.buffer.lock();
393        }
394
395        buffer.records.push_back(record);
396        buffer.size += data.len();
397
398        if let Some(ref mut writer) = buffer.writer {
399            writer.write_all(&data)?;
400            buffer.segment_offset += data_len;
401        }
402
403        Ok(lsn)
404    }
405
406    /// Rotate to a new WAL segment.
407    fn rotate_segment(&self) -> Result<()> {
408        let mut buffer = self.buffer.lock();
409
410        // Flush current segment
411        if let Some(ref mut writer) = buffer.writer {
412            writer.flush()?;
413            if self.sync_on_commit {
414                writer.get_ref().sync_all()?;
415            }
416        }
417
418        // Rename current segment to numbered segment
419        let old_path = self.wal_dir.join("wal_current.log");
420        let new_segment = buffer.current_segment + 1;
421        let new_path = self
422            .wal_dir
423            .join(format!("wal_{:016}.log", buffer.current_segment));
424
425        if old_path.exists() {
426            std::fs::rename(&old_path, &new_path)?;
427        }
428
429        // Open new segment
430        buffer.current_segment = new_segment;
431        let file = OpenOptions::new()
432            .create(true)
433            .write(true)
434            .truncate(true)
435            .open(&old_path)?;
436
437        buffer.writer = Some(BufWriter::new(file));
438        buffer.segment_offset = 0;
439
440        tracing::info!("Rotated WAL to segment {}", new_segment);
441        Ok(())
442    }
443
444    /// Flush all buffered records to disk.
445    pub fn flush(&self) -> Result<Lsn> {
446        let mut buffer = self.buffer.lock();
447
448        if let Some(ref mut writer) = buffer.writer {
449            writer.flush()?;
450
451            if self.sync_on_commit {
452                writer.get_ref().sync_all()?;
453            }
454        }
455
456        let flushed = self.current_lsn.load(Ordering::SeqCst) - 1;
457        self.flushed_lsn.store(flushed, Ordering::SeqCst);
458        buffer.records.clear();
459        buffer.size = 0;
460
461        Ok(Lsn(flushed))
462    }
463
464    /// Log a transaction begin.
465    pub fn log_begin(&self, tx_id: TransactionId) -> Result<Lsn> {
466        let lsn = self.next_lsn();
467        let record = LogRecord::begin(lsn, tx_id);
468        self.append(record)
469    }
470
471    /// Log a transaction commit.
472    pub fn log_commit(&self, tx_id: TransactionId, prev_lsn: Lsn) -> Result<Lsn> {
473        let lsn = self.next_lsn();
474        let record = LogRecord::commit(lsn, prev_lsn, tx_id);
475        self.append(record)?;
476
477        if self.sync_on_commit {
478            self.flush()?;
479        }
480
481        Ok(lsn)
482    }
483
484    /// Log a transaction abort.
485    pub fn log_abort(&self, tx_id: TransactionId, prev_lsn: Lsn) -> Result<Lsn> {
486        let lsn = self.next_lsn();
487        let record = LogRecord::abort(lsn, prev_lsn, tx_id);
488        self.append(record)
489    }
490
491    /// Log an insert operation.
492    pub fn log_insert(
493        &self,
494        tx_id: TransactionId,
495        prev_lsn: Option<Lsn>,
496        page_id: PageId,
497        data: Bytes,
498    ) -> Result<Lsn> {
499        let lsn = self.next_lsn();
500        let record =
501            LogRecord::data_record(lsn, prev_lsn, tx_id, LogRecordType::Insert, page_id, data);
502        self.append(record)
503    }
504
505    /// Log an update operation.
506    pub fn log_update(
507        &self,
508        tx_id: TransactionId,
509        prev_lsn: Option<Lsn>,
510        page_id: PageId,
511        data: Bytes,
512    ) -> Result<Lsn> {
513        let lsn = self.next_lsn();
514        let record =
515            LogRecord::data_record(lsn, prev_lsn, tx_id, LogRecordType::Update, page_id, data);
516        self.append(record)
517    }
518
519    /// Log a delete operation.
520    pub fn log_delete(
521        &self,
522        tx_id: TransactionId,
523        prev_lsn: Option<Lsn>,
524        page_id: PageId,
525        data: Bytes,
526    ) -> Result<Lsn> {
527        let lsn = self.next_lsn();
528        let record =
529            LogRecord::data_record(lsn, prev_lsn, tx_id, LogRecordType::Delete, page_id, data);
530        self.append(record)
531    }
532
533    /// Log a checkpoint with active transaction and dirty page information.
534    pub fn log_checkpoint(
535        &self,
536        active_transactions: Vec<TransactionId>,
537        dirty_pages: Vec<PageId>,
538    ) -> Result<Lsn> {
539        let lsn = self.next_lsn();
540        let checkpoint_data = CheckpointData {
541            active_transactions,
542            dirty_pages,
543        };
544        let data = serde_json::to_vec(&checkpoint_data)
545            .map_err(|e| AegisError::Internal(format!("Failed to serialize checkpoint: {}", e)))?;
546
547        let record = LogRecord {
548            lsn,
549            prev_lsn: None,
550            tx_id: TransactionId(0),
551            record_type: LogRecordType::Checkpoint,
552            page_id: None,
553            data: Bytes::from(data),
554        };
555
556        self.append(record)?;
557        self.flush()?;
558
559        // Update checkpoint LSN
560        self.checkpoint_lsn.store(lsn.0, Ordering::SeqCst);
561
562        tracing::info!("Checkpoint created at LSN {}", lsn.0);
563        Ok(lsn)
564    }
565
566    /// Truncate WAL segments older than the checkpoint LSN.
567    pub fn truncate_before_checkpoint(&self) -> Result<usize> {
568        let checkpoint = self.checkpoint_lsn.load(Ordering::SeqCst);
569        if checkpoint == 0 {
570            return Ok(0);
571        }
572
573        let mut removed = 0;
574        let buffer = self.buffer.lock();
575        let current_segment = buffer.current_segment;
576        drop(buffer);
577
578        // Find and remove old segments
579        if let Ok(entries) = std::fs::read_dir(&self.wal_dir) {
580            for entry in entries.flatten() {
581                let path = entry.path();
582                if let Some(name) = path.file_name().and_then(|n| n.to_str()) {
583                    if let Some(num_str) = name
584                        .strip_prefix("wal_")
585                        .and_then(|s| s.strip_suffix(".log"))
586                    {
587                        if let Ok(num) = num_str.parse::<u64>() {
588                            // Keep at least the current segment and one before
589                            if num + 2 < current_segment && std::fs::remove_file(&path).is_ok() {
590                                removed += 1;
591                                tracing::debug!("Removed old WAL segment: {}", name);
592                            }
593                        }
594                    }
595                }
596            }
597        }
598
599        Ok(removed)
600    }
601
602    /// Recover from WAL directory, returning records needed for redo/undo.
603    fn recover_from_directory(wal_dir: &PathBuf) -> Result<RecoveryResult> {
604        let mut result = RecoveryResult::default();
605        let mut tx_status: HashMap<TransactionId, bool> = HashMap::new(); // true = committed
606        let mut tx_records: HashMap<TransactionId, Vec<LogRecord>> = HashMap::new();
607
608        // Collect all segment files
609        let mut segments: Vec<PathBuf> = Vec::new();
610        if let Ok(entries) = std::fs::read_dir(wal_dir) {
611            for entry in entries.flatten() {
612                let path = entry.path();
613                if let Some(name) = path.file_name().and_then(|n| n.to_str()) {
614                    if name.starts_with("wal_") && name.ends_with(".log") {
615                        segments.push(path);
616                    }
617                }
618            }
619        }
620
621        // Sort segments by number
622        segments.sort();
623
624        // Process each segment
625        for segment_path in &segments {
626            result.segments_scanned += 1;
627            let records = Self::read_segment(segment_path)?;
628
629            for record in records {
630                result.records_processed += 1;
631                result.max_lsn = result.max_lsn.max(record.lsn);
632
633                match record.record_type {
634                    LogRecordType::Begin => {
635                        tx_status.insert(record.tx_id, false);
636                        tx_records.insert(record.tx_id, Vec::new());
637                    }
638                    LogRecordType::Commit => {
639                        tx_status.insert(record.tx_id, true);
640                    }
641                    LogRecordType::Abort => {
642                        // Mark as complete but not committed - remove from tracking
643                        tx_status.remove(&record.tx_id);
644                        tx_records.remove(&record.tx_id);
645                    }
646                    LogRecordType::Insert | LogRecordType::Update | LogRecordType::Delete => {
647                        if let Some(records) = tx_records.get_mut(&record.tx_id) {
648                            records.push(record.clone());
649                        }
650                    }
651                    LogRecordType::Checkpoint => {
652                        // Parse checkpoint data to find active transactions
653                        if let Ok(checkpoint) =
654                            serde_json::from_slice::<CheckpointData>(&record.data)
655                        {
656                            for tx_id in checkpoint.active_transactions {
657                                tx_status.entry(tx_id).or_insert(false);
658                            }
659                        }
660                    }
661                    LogRecordType::CompensationLogRecord => {
662                        // CLRs are used during undo - skip for now
663                    }
664                }
665            }
666        }
667
668        // Collect redo records for committed transactions
669        for (tx_id, committed) in &tx_status {
670            if *committed {
671                if let Some(records) = tx_records.remove(tx_id) {
672                    result.redo_records.extend(records);
673                }
674            } else {
675                result.incomplete_transactions.insert(*tx_id);
676            }
677        }
678
679        // Sort redo records by LSN
680        result.redo_records.sort_by_key(|r| r.lsn);
681
682        tracing::info!(
683            "WAL recovery: {} records processed, {} redo records, {} incomplete transactions",
684            result.records_processed,
685            result.redo_records.len(),
686            result.incomplete_transactions.len()
687        );
688
689        Ok(result)
690    }
691
692    /// Read all records from a segment file.
693    fn read_segment(path: &PathBuf) -> Result<Vec<LogRecord>> {
694        let mut file = BufReader::new(File::open(path)?);
695        let mut records = Vec::new();
696        let mut buffer = Vec::new();
697
698        // Read the entire file
699        file.read_to_end(&mut buffer)?;
700
701        let mut offset = 0;
702        while offset < buffer.len() {
703            // Need at least header + checksum
704            if buffer.len() - offset < WAL_RECORD_HEADER_SIZE + 4 {
705                break;
706            }
707
708            // Read data length from header (at offset 36: after lsn, prev_lsn, tx_id, type, has_page, padding, page_id)
709            let data_len_offset = offset + 36;
710            if data_len_offset + 4 > buffer.len() {
711                break;
712            }
713
714            let data_len = u32::from_le_bytes([
715                buffer[data_len_offset],
716                buffer[data_len_offset + 1],
717                buffer[data_len_offset + 2],
718                buffer[data_len_offset + 3],
719            ]) as usize;
720
721            let total_record_len = WAL_RECORD_HEADER_SIZE + data_len + 4; // header + data + checksum
722
723            if offset + total_record_len > buffer.len() {
724                break;
725            }
726
727            match LogRecord::from_bytes(&buffer[offset..offset + total_record_len]) {
728                Ok(record) => {
729                    records.push(record);
730                    offset += total_record_len;
731                }
732                Err(e) => {
733                    tracing::warn!("Failed to parse WAL record at offset {}: {}", offset, e);
734                    break;
735                }
736            }
737        }
738
739        Ok(records)
740    }
741
742    fn open_segment(&self, segment_num: u64) -> Result<()> {
743        let segment_path = self.wal_dir.join("wal_current.log");
744        let file = OpenOptions::new()
745            .create(true)
746            .append(true)
747            .open(&segment_path)?;
748
749        // Get current file size
750        let metadata = file.metadata()?;
751        let offset = metadata.len();
752
753        let mut buffer = self.buffer.lock();
754        buffer.writer = Some(BufWriter::new(file));
755        buffer.segment_offset = offset;
756        buffer.current_segment = segment_num;
757
758        Ok(())
759    }
760}
761
762// =============================================================================
763// Tests
764// =============================================================================
765
766#[cfg(test)]
767mod tests {
768    use super::*;
769
770    #[test]
771    fn test_log_record_roundtrip() {
772        let record = LogRecord::begin(Lsn(1), TransactionId(100));
773        let bytes = record.to_bytes();
774        let restored = LogRecord::from_bytes(&bytes).expect("failed to deserialize log record");
775
776        assert_eq!(restored.lsn, Lsn(1));
777        assert_eq!(restored.tx_id, TransactionId(100));
778        assert_eq!(restored.record_type, LogRecordType::Begin);
779    }
780
781    #[test]
782    fn test_log_record_with_data() {
783        let data = Bytes::from("test data");
784        let record = LogRecord::data_record(
785            Lsn(5),
786            Some(Lsn(4)),
787            TransactionId(100),
788            LogRecordType::Insert,
789            PageId(42),
790            data.clone(),
791        );
792
793        let bytes = record.to_bytes();
794        let restored =
795            LogRecord::from_bytes(&bytes).expect("failed to deserialize log record with data");
796
797        assert_eq!(restored.lsn, Lsn(5));
798        assert_eq!(restored.prev_lsn, Some(Lsn(4)));
799        assert_eq!(restored.page_id, Some(PageId(42)));
800        assert_eq!(restored.data, data);
801    }
802
803    #[test]
804    fn test_wal_operations() {
805        let temp_dir = tempfile::tempdir().expect("failed to create temp directory");
806        let wal =
807            WriteAheadLog::new(temp_dir.path().to_path_buf(), false).expect("failed to create WAL");
808
809        let tx_id = TransactionId(1);
810        let begin_lsn = wal.log_begin(tx_id).expect("failed to log begin");
811        assert_eq!(begin_lsn, Lsn(1));
812
813        let insert_lsn = wal
814            .log_insert(tx_id, Some(begin_lsn), PageId(1), Bytes::from("data"))
815            .expect("failed to log insert");
816        assert_eq!(insert_lsn, Lsn(2));
817
818        let commit_lsn = wal
819            .log_commit(tx_id, insert_lsn)
820            .expect("failed to log commit");
821        assert_eq!(commit_lsn, Lsn(3));
822    }
823
824    #[test]
825    fn test_wal_recovery_committed_transaction() {
826        let temp_dir = tempfile::tempdir().expect("failed to create temp directory");
827        let wal_dir = temp_dir.path().to_path_buf();
828
829        // Create WAL and write a committed transaction
830        {
831            let wal = WriteAheadLog::new(wal_dir.clone(), true).expect("failed to create WAL");
832            let tx_id = TransactionId(1);
833
834            let begin_lsn = wal.log_begin(tx_id).expect("failed to log begin");
835            let insert_lsn = wal
836                .log_insert(tx_id, Some(begin_lsn), PageId(1), Bytes::from("test data"))
837                .expect("failed to log insert");
838            wal.log_commit(tx_id, insert_lsn)
839                .expect("failed to log commit");
840        }
841
842        // Recover and verify
843        let (wal, recovery) =
844            WriteAheadLog::open_and_recover(wal_dir, true).expect("failed to recover WAL");
845        assert_eq!(recovery.records_processed, 3);
846        assert_eq!(recovery.redo_records.len(), 1); // One insert record
847        assert!(recovery.incomplete_transactions.is_empty());
848        assert_eq!(recovery.max_lsn, Lsn(3));
849
850        // Verify we can continue writing
851        let next_lsn = wal.next_lsn();
852        assert_eq!(next_lsn, Lsn(4));
853    }
854
855    #[test]
856    fn test_wal_recovery_incomplete_transaction() {
857        let temp_dir = tempfile::tempdir().expect("failed to create temp directory");
858        let wal_dir = temp_dir.path().to_path_buf();
859
860        // Create WAL with an incomplete transaction
861        {
862            let wal = WriteAheadLog::new(wal_dir.clone(), true).expect("failed to create WAL");
863            let tx_id = TransactionId(1);
864
865            wal.log_begin(tx_id).expect("failed to log begin");
866            wal.log_insert(tx_id, None, PageId(1), Bytes::from("uncommitted"))
867                .expect("failed to log insert");
868            wal.flush().expect("failed to flush WAL");
869            // No commit - transaction is incomplete
870        }
871
872        // Recover and verify
873        let (_wal, recovery) =
874            WriteAheadLog::open_and_recover(wal_dir, true).expect("failed to recover WAL");
875        assert_eq!(recovery.records_processed, 2);
876        assert!(recovery.redo_records.is_empty()); // No redo for uncommitted
877        assert!(recovery.incomplete_transactions.contains(&TransactionId(1)));
878    }
879
880    #[test]
881    fn test_wal_checkpoint() {
882        let temp_dir = tempfile::tempdir().expect("failed to create temp directory");
883        let wal =
884            WriteAheadLog::new(temp_dir.path().to_path_buf(), true).expect("failed to create WAL");
885
886        // Write some transactions
887        let tx1 = TransactionId(1);
888        let begin1 = wal.log_begin(tx1).expect("failed to log begin");
889        wal.log_insert(tx1, Some(begin1), PageId(1), Bytes::from("data1"))
890            .expect("failed to log insert");
891
892        // Create checkpoint
893        let checkpoint_lsn = wal
894            .log_checkpoint(vec![tx1], vec![PageId(1)])
895            .expect("failed to log checkpoint");
896
897        assert!(checkpoint_lsn.0 > 0);
898        assert_eq!(wal.checkpoint_lsn(), checkpoint_lsn);
899    }
900
901    #[test]
902    fn test_wal_recovery_with_checkpoint() {
903        let temp_dir = tempfile::tempdir().expect("failed to create temp directory");
904        let wal_dir = temp_dir.path().to_path_buf();
905
906        // Create WAL with checkpoint
907        {
908            let wal = WriteAheadLog::new(wal_dir.clone(), true).expect("failed to create WAL");
909
910            // Transaction 1 - committed before checkpoint
911            let tx1 = TransactionId(1);
912            let begin1 = wal.log_begin(tx1).expect("failed to log begin for tx1");
913            let insert1 = wal
914                .log_insert(tx1, Some(begin1), PageId(1), Bytes::from("data1"))
915                .expect("failed to log insert for tx1");
916            wal.log_commit(tx1, insert1)
917                .expect("failed to log commit for tx1");
918
919            // Checkpoint
920            wal.log_checkpoint(vec![], vec![])
921                .expect("failed to log checkpoint");
922
923            // Transaction 2 - committed after checkpoint
924            let tx2 = TransactionId(2);
925            let begin2 = wal.log_begin(tx2).expect("failed to log begin for tx2");
926            let insert2 = wal
927                .log_insert(tx2, Some(begin2), PageId(2), Bytes::from("data2"))
928                .expect("failed to log insert for tx2");
929            wal.log_commit(tx2, insert2)
930                .expect("failed to log commit for tx2");
931        }
932
933        // Recover
934        let (_wal, recovery) =
935            WriteAheadLog::open_and_recover(wal_dir, true).expect("failed to recover WAL");
936
937        // Both transactions should have redo records
938        assert_eq!(recovery.redo_records.len(), 2);
939        assert!(recovery.incomplete_transactions.is_empty());
940    }
941
942    #[test]
943    fn test_wal_multiple_transactions() {
944        let temp_dir = tempfile::tempdir().expect("failed to create temp directory");
945        let wal_dir = temp_dir.path().to_path_buf();
946
947        {
948            let wal = WriteAheadLog::new(wal_dir.clone(), true).expect("failed to create WAL");
949
950            // Transaction 1 - committed
951            let tx1 = TransactionId(1);
952            let begin1 = wal.log_begin(tx1).expect("failed to log begin for tx1");
953            let insert1 = wal
954                .log_insert(tx1, Some(begin1), PageId(1), Bytes::from("tx1"))
955                .expect("failed to log insert for tx1");
956            wal.log_commit(tx1, insert1)
957                .expect("failed to log commit for tx1");
958
959            // Transaction 2 - aborted
960            let tx2 = TransactionId(2);
961            let begin2 = wal.log_begin(tx2).expect("failed to log begin for tx2");
962            let insert2 = wal
963                .log_insert(tx2, Some(begin2), PageId(2), Bytes::from("tx2"))
964                .expect("failed to log insert for tx2");
965            wal.log_abort(tx2, insert2)
966                .expect("failed to log abort for tx2");
967
968            // Transaction 3 - committed
969            let tx3 = TransactionId(3);
970            let begin3 = wal.log_begin(tx3).expect("failed to log begin for tx3");
971            let insert3 = wal
972                .log_insert(tx3, Some(begin3), PageId(3), Bytes::from("tx3"))
973                .expect("failed to log insert for tx3");
974            wal.log_commit(tx3, insert3)
975                .expect("failed to log commit for tx3");
976
977            wal.flush().expect("failed to flush WAL");
978        }
979
980        let (_wal, recovery) =
981            WriteAheadLog::open_and_recover(wal_dir, true).expect("failed to recover WAL");
982
983        // Only tx1 and tx3 should be in redo (tx2 was aborted)
984        assert_eq!(recovery.redo_records.len(), 2);
985        assert!(recovery.incomplete_transactions.is_empty());
986
987        // Verify the redo records are from correct transactions
988        let tx_ids: std::collections::HashSet<_> =
989            recovery.redo_records.iter().map(|r| r.tx_id).collect();
990        assert!(tx_ids.contains(&TransactionId(1)));
991        assert!(tx_ids.contains(&TransactionId(3)));
992        assert!(!tx_ids.contains(&TransactionId(2)));
993    }
994}