Skip to main content

rivven_core/
wal.rs

1//! Group Commit Write-Ahead Log (WAL)
2//!
3//! High-performance WAL implementation with group commit optimization:
4//! - **Group Commit**: Batches multiple writes into single fsync (10-100x throughput)
5//! - **Pipelined Writes**: Overlaps I/O with compute
6//! - **Pre-allocated Files**: Reduces filesystem overhead
7//! - **CRC32 Checksums**: Data integrity verification
8//! - **Asynchronous Sync**: Non-blocking durability
9//!
10//! Based on techniques from:
11//! - MySQL InnoDB group commit
12//! - PostgreSQL WAL
13//! - RocksDB write batching
14
15use bytes::{BufMut, Bytes, BytesMut};
16use crc32fast::Hasher;
17use std::collections::VecDeque;
18use std::fs::{File, OpenOptions};
19use std::io::{self, BufWriter, Write};
20use std::path::{Path, PathBuf};
21use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
22use std::sync::Arc;
23use std::time::{Duration, Instant};
24use tokio::sync::{mpsc, oneshot, Mutex, Notify};
25
/// Size in bytes of the fixed header that precedes every record payload:
/// magic (4) + CRC (4) + payload length (4) + record type (1) + flags (1).
const RECORD_HEADER_SIZE: usize = 4 + 4 + 4 + 1 + 1;

/// Marker stored at the start of every record; the bytes spell "WALR" in ASCII.
const WAL_MAGIC: u32 = u32::from_be_bytes(*b"WALR");

/// Group commit window used by `WalConfig::default`, in microseconds.
const DEFAULT_GROUP_COMMIT_WINDOW_US: u64 = 200;

/// Batch size limit used by `WalConfig::default`, in bytes (4 MiB).
const DEFAULT_MAX_BATCH_SIZE: usize = 4 << 20;

/// Number of queued writes that forces a flush under `WalConfig::default`.
const DEFAULT_MAX_PENDING_WRITES: usize = 1_000;
40
/// Kind of a WAL record, stored as a single byte in the record header.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum RecordType {
    /// A record that carries its complete payload.
    Full = 0,
    /// Opening fragment of a record split across several entries.
    First = 1,
    /// Interior fragment of a split record.
    Middle = 2,
    /// Closing fragment of a split record.
    Last = 3,
    /// Checkpoint marker.
    Checkpoint = 4,
    /// Start of a transaction.
    TxnBegin = 5,
    /// Transaction commit.
    TxnCommit = 6,
    /// Transaction abort.
    TxnAbort = 7,
}

impl TryFrom<u8> for RecordType {
    type Error = io::Error;

    /// Decodes the on-disk tag byte.
    ///
    /// # Errors
    /// Returns `InvalidData` for any byte outside `0..=7`.
    fn try_from(value: u8) -> Result<Self, Self::Error> {
        // Indexed by discriminant, so the order must mirror the enum above.
        const BY_TAG: [RecordType; 8] = [
            RecordType::Full,
            RecordType::First,
            RecordType::Middle,
            RecordType::Last,
            RecordType::Checkpoint,
            RecordType::TxnBegin,
            RecordType::TxnCommit,
            RecordType::TxnAbort,
        ];

        BY_TAG
            .get(usize::from(value))
            .copied()
            .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "Invalid record type"))
    }
}
83
/// Bit set stored in the flags byte of a record header.
#[derive(Debug, Clone, Copy)]
pub struct RecordFlags(u8);

impl RecordFlags {
    /// No flags set.
    pub const NONE: Self = Self(0);
    /// Bit 0: payload is compressed.
    pub const COMPRESSED: Self = Self(0b001);
    /// Bit 1: payload is encrypted.
    pub const ENCRYPTED: Self = Self(0b010);
    /// Bit 2: record carries a checksum.
    pub const HAS_CHECKSUM: Self = Self(0b100);

    /// True when the `COMPRESSED` bit is set.
    pub fn is_compressed(&self) -> bool {
        self.intersects(Self::COMPRESSED)
    }

    /// True when the `ENCRYPTED` bit is set.
    pub fn is_encrypted(&self) -> bool {
        self.intersects(Self::ENCRYPTED)
    }

    /// True when the `HAS_CHECKSUM` bit is set.
    pub fn has_checksum(&self) -> bool {
        self.intersects(Self::HAS_CHECKSUM)
    }

    /// Shared bit test behind the public predicates.
    fn intersects(&self, mask: Self) -> bool {
        (self.0 & mask.0) != 0
    }
}
106
107/// A WAL record
108#[derive(Debug, Clone)]
109pub struct WalRecord {
110    /// Log sequence number
111    pub lsn: u64,
112    /// Record type
113    pub record_type: RecordType,
114    /// Flags
115    pub flags: RecordFlags,
116    /// Record data
117    pub data: Bytes,
118}
119
120impl WalRecord {
121    /// Create a new full record
122    pub fn new(lsn: u64, data: Bytes) -> Self {
123        Self {
124            lsn,
125            record_type: RecordType::Full,
126            flags: RecordFlags::HAS_CHECKSUM,
127            data,
128        }
129    }
130
131    /// Serialize to bytes
132    pub fn to_bytes(&self) -> Bytes {
133        let mut buf = BytesMut::with_capacity(RECORD_HEADER_SIZE + self.data.len());
134
135        // Calculate CRC of data
136        let mut hasher = Hasher::new();
137        hasher.update(&self.data);
138        let crc = hasher.finalize();
139
140        // Write header
141        buf.put_u32(WAL_MAGIC);
142        buf.put_u32(crc);
143        buf.put_u32(self.data.len() as u32);
144        buf.put_u8(self.record_type as u8);
145        buf.put_u8(self.flags.0);
146
147        // Write data
148        buf.extend_from_slice(&self.data);
149
150        buf.freeze()
151    }
152
153    /// Parse from bytes
154    pub fn from_bytes(data: &[u8], lsn: u64) -> io::Result<Self> {
155        if data.len() < RECORD_HEADER_SIZE {
156            return Err(io::Error::new(
157                io::ErrorKind::InvalidData,
158                "Record too short",
159            ));
160        }
161
162        // Read header
163        let magic = u32::from_be_bytes([data[0], data[1], data[2], data[3]]);
164        if magic != WAL_MAGIC {
165            return Err(io::Error::new(
166                io::ErrorKind::InvalidData,
167                "Invalid magic number",
168            ));
169        }
170
171        let stored_crc = u32::from_be_bytes([data[4], data[5], data[6], data[7]]);
172        let data_len = u32::from_be_bytes([data[8], data[9], data[10], data[11]]) as usize;
173        let record_type = RecordType::try_from(data[12])?;
174        let flags = RecordFlags(data[13]);
175
176        if data.len() < RECORD_HEADER_SIZE + data_len {
177            return Err(io::Error::new(
178                io::ErrorKind::InvalidData,
179                "Incomplete record",
180            ));
181        }
182
183        let record_data =
184            Bytes::copy_from_slice(&data[RECORD_HEADER_SIZE..RECORD_HEADER_SIZE + data_len]);
185
186        // Verify CRC
187        let mut hasher = Hasher::new();
188        hasher.update(&record_data);
189        let computed_crc = hasher.finalize();
190
191        if computed_crc != stored_crc {
192            return Err(io::Error::new(io::ErrorKind::InvalidData, "CRC mismatch"));
193        }
194
195        Ok(Self {
196            lsn,
197            record_type,
198            flags,
199            data: record_data,
200        })
201    }
202
203    /// Get total serialized size
204    pub fn serialized_size(&self) -> usize {
205        RECORD_HEADER_SIZE + self.data.len()
206    }
207}
208
209/// Configuration for group commit WAL
210#[derive(Debug, Clone)]
211pub struct WalConfig {
212    /// Directory for WAL files
213    pub dir: PathBuf,
214    /// Group commit window (how long to wait for more writes)
215    pub group_commit_window: Duration,
216    /// Maximum batch size before forcing flush
217    pub max_batch_size: usize,
218    /// Maximum pending writes before forcing flush
219    pub max_pending_writes: usize,
220    /// Pre-allocate WAL files to this size
221    pub preallocate_size: u64,
222    /// Enable direct I/O (bypass OS cache)
223    pub direct_io: bool,
224    /// Sync mode
225    pub sync_mode: SyncMode,
226    /// Maximum WAL file size before rotation
227    pub max_file_size: u64,
228    /// Optional encryption for data at rest
229    #[cfg(feature = "encryption")]
230    pub encryptor: Option<std::sync::Arc<dyn crate::encryption::Encryptor>>,
231}
232
233impl Default for WalConfig {
234    fn default() -> Self {
235        Self {
236            dir: PathBuf::from("./wal"),
237            group_commit_window: Duration::from_micros(DEFAULT_GROUP_COMMIT_WINDOW_US),
238            max_batch_size: DEFAULT_MAX_BATCH_SIZE,
239            max_pending_writes: DEFAULT_MAX_PENDING_WRITES,
240            preallocate_size: 64 * 1024 * 1024, // 64 MB
241            direct_io: false,                   // Requires O_DIRECT support
242            sync_mode: SyncMode::Fsync,
243            max_file_size: 1024 * 1024 * 1024, // 1 GB
244            #[cfg(feature = "encryption")]
245            encryptor: None,
246        }
247    }
248}
249
250impl WalConfig {
251    /// High-throughput configuration (more batching, less frequent sync)
252    pub fn high_throughput() -> Self {
253        Self {
254            group_commit_window: Duration::from_micros(1000), // 1ms
255            max_batch_size: 16 * 1024 * 1024,                 // 16 MB
256            max_pending_writes: 5000,
257            ..Default::default()
258        }
259    }
260
261    /// Low-latency configuration (less batching, more frequent sync)
262    pub fn low_latency() -> Self {
263        Self {
264            group_commit_window: Duration::from_micros(50), // 50us
265            max_batch_size: 512 * 1024,                     // 512 KB
266            max_pending_writes: 100,
267            ..Default::default()
268        }
269    }
270
271    /// Durability-focused configuration
272    pub fn durable() -> Self {
273        Self {
274            sync_mode: SyncMode::FsyncData,
275            ..Default::default()
276        }
277    }
278}
279
/// How the WAL makes a written batch durable.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SyncMode {
    /// Never sync explicitly (fastest, least durable).
    None,
    /// fdatasync: flush file data but not metadata.
    FsyncData,
    /// fsync: flush file data and metadata.
    Fsync,
    /// O_DSYNC flag (sync on each write).
    Dsync,
}
292
293/// Write request for the WAL
294struct WriteRequest {
295    /// Data to write
296    data: Bytes,
297    /// Record type
298    record_type: RecordType,
299    /// Channel to send completion notification
300    completion: oneshot::Sender<WriteResult>,
301}
302
/// Outcome reported to the caller of a WAL write.
#[derive(Debug, Clone)]
pub struct WriteResult {
    /// LSN assigned to the record.
    pub lsn: u64,
    /// Serialized size of the record in bytes (header included).
    pub size: usize,
    /// True when the flush that carried this write contained more than one write.
    pub group_commit: bool,
    /// Number of writes in that flush.
    pub group_size: usize,
    /// Time between the first write of the group arriving and the flush starting.
    pub wait_time: Duration,
}
317
318/// Group commit WAL writer
319pub struct GroupCommitWal {
320    config: WalConfig,
321    /// Current file writer
322    writer: Mutex<WalWriter>,
323    /// Current LSN
324    current_lsn: AtomicU64,
325    /// Write request sender
326    write_tx: mpsc::Sender<WriteRequest>,
327    /// Shutdown flag
328    shutdown: AtomicBool,
329    /// Notify when new writes arrive
330    write_notify: Arc<Notify>,
331    /// Statistics
332    stats: Arc<WalStats>,
333}
334
335impl GroupCommitWal {
336    /// Create a new group commit WAL
337    pub async fn new(config: WalConfig) -> io::Result<Arc<Self>> {
338        std::fs::create_dir_all(&config.dir)?;
339
340        // Find the latest WAL file and LSN
341        let (current_file, current_lsn) = Self::recover_state(&config).await?;
342
343        let writer = WalWriter::new(current_file, config.clone())?;
344        let (write_tx, write_rx) = mpsc::channel(config.max_pending_writes);
345
346        let wal = Arc::new(Self {
347            config,
348            writer: Mutex::new(writer),
349            current_lsn: AtomicU64::new(current_lsn),
350            write_tx,
351            shutdown: AtomicBool::new(false),
352            write_notify: Arc::new(Notify::new()),
353            stats: Arc::new(WalStats::new()),
354        });
355
356        // Start background group commit worker
357        wal.clone().start_group_commit_worker(write_rx);
358
359        Ok(wal)
360    }
361
362    /// Recover state from existing WAL files
363    async fn recover_state(config: &WalConfig) -> io::Result<(PathBuf, u64)> {
364        let mut max_lsn = 0u64;
365        let mut latest_file = None;
366
367        if let Ok(entries) = std::fs::read_dir(&config.dir) {
368            for entry in entries.flatten() {
369                let path = entry.path();
370                if path.extension().is_some_and(|e| e == "wal") {
371                    if let Some(name) = path.file_stem() {
372                        if let Ok(lsn) = name.to_string_lossy().parse::<u64>() {
373                            if lsn >= max_lsn {
374                                max_lsn = lsn;
375                                latest_file = Some(path);
376                            }
377                        }
378                    }
379                }
380            }
381        }
382
383        // If we found a file, scan it to find the true max LSN
384        if let Some(ref file) = latest_file {
385            if let Ok(recovered_lsn) = Self::scan_wal_file(file).await {
386                max_lsn = recovered_lsn;
387            }
388        }
389
390        // Create new file if none exists
391        let file = latest_file.unwrap_or_else(|| config.dir.join(format!("{:020}.wal", 0)));
392
393        Ok((file, max_lsn))
394    }
395
396    /// Scan a WAL file to find the highest LSN
397    async fn scan_wal_file(path: &Path) -> io::Result<u64> {
398        let data = tokio::fs::read(path).await?;
399        let mut offset = 0;
400        let mut max_lsn = 0u64;
401
402        while offset + RECORD_HEADER_SIZE <= data.len() {
403            // Check magic
404            let magic = u32::from_be_bytes([
405                data[offset],
406                data[offset + 1],
407                data[offset + 2],
408                data[offset + 3],
409            ]);
410
411            if magic != WAL_MAGIC {
412                break;
413            }
414
415            let data_len = u32::from_be_bytes([
416                data[offset + 8],
417                data[offset + 9],
418                data[offset + 10],
419                data[offset + 11],
420            ]) as usize;
421
422            let record_size = RECORD_HEADER_SIZE + data_len;
423            if offset + record_size > data.len() {
424                break;
425            }
426
427            max_lsn += 1;
428            offset += record_size;
429        }
430
431        Ok(max_lsn)
432    }
433
434    /// Write a record to the WAL (async, batched)
435    pub async fn write(&self, data: Bytes) -> io::Result<WriteResult> {
436        self.write_with_type(data, RecordType::Full).await
437    }
438
439    /// Write a record with specific type
440    pub async fn write_with_type(
441        &self,
442        data: Bytes,
443        record_type: RecordType,
444    ) -> io::Result<WriteResult> {
445        if self.shutdown.load(Ordering::Acquire) {
446            return Err(io::Error::new(
447                io::ErrorKind::BrokenPipe,
448                "WAL is shut down",
449            ));
450        }
451
452        let (tx, rx) = oneshot::channel();
453
454        let request = WriteRequest {
455            data,
456            record_type,
457            completion: tx,
458        };
459
460        self.write_tx
461            .send(request)
462            .await
463            .map_err(|_| io::Error::new(io::ErrorKind::BrokenPipe, "WAL write channel closed"))?;
464
465        self.write_notify.notify_one();
466
467        rx.await
468            .map_err(|_| io::Error::new(io::ErrorKind::BrokenPipe, "WAL write cancelled"))
469    }
470
471    /// Write a batch of records atomically
472    pub async fn write_batch(&self, records: Vec<Bytes>) -> io::Result<Vec<WriteResult>> {
473        let mut results = Vec::with_capacity(records.len());
474        let mut receivers = Vec::with_capacity(records.len());
475
476        for data in records {
477            let (tx, rx) = oneshot::channel();
478
479            let request = WriteRequest {
480                data,
481                record_type: RecordType::Full,
482                completion: tx,
483            };
484
485            self.write_tx.send(request).await.map_err(|_| {
486                io::Error::new(io::ErrorKind::BrokenPipe, "WAL write channel closed")
487            })?;
488
489            receivers.push(rx);
490        }
491
492        self.write_notify.notify_one();
493
494        for rx in receivers {
495            let result = rx
496                .await
497                .map_err(|_| io::Error::new(io::ErrorKind::BrokenPipe, "WAL write cancelled"))?;
498            results.push(result);
499        }
500
501        Ok(results)
502    }
503
504    /// Sync the WAL to disk (force flush)
505    pub async fn sync(&self) -> io::Result<()> {
506        let mut writer = self.writer.lock().await;
507        writer.sync()
508    }
509
510    /// Get current LSN
511    pub fn current_lsn(&self) -> u64 {
512        self.current_lsn.load(Ordering::Acquire)
513    }
514
515    /// Get WAL statistics
516    pub fn stats(&self) -> WalStatsSnapshot {
517        WalStatsSnapshot {
518            writes_total: self.stats.writes_total.load(Ordering::Relaxed),
519            bytes_written: self.stats.bytes_written.load(Ordering::Relaxed),
520            syncs_total: self.stats.syncs_total.load(Ordering::Relaxed),
521            group_commits: self.stats.group_commits.load(Ordering::Relaxed),
522            avg_group_size: if self.stats.group_commits.load(Ordering::Relaxed) > 0 {
523                self.stats.writes_total.load(Ordering::Relaxed) as f64
524                    / self.stats.group_commits.load(Ordering::Relaxed) as f64
525            } else {
526                0.0
527            },
528            current_lsn: self.current_lsn.load(Ordering::Relaxed),
529        }
530    }
531
532    /// Shutdown the WAL
533    pub async fn shutdown(&self) -> io::Result<()> {
534        self.shutdown.store(true, Ordering::Release);
535        self.write_notify.notify_waiters();
536
537        // Final sync
538        let mut writer = self.writer.lock().await;
539        writer.sync()
540    }
541
542    /// Start the background group commit worker
543    fn start_group_commit_worker(self: Arc<Self>, mut rx: mpsc::Receiver<WriteRequest>) {
544        let wal = self.clone();
545
546        tokio::spawn(async move {
547            let mut pending: VecDeque<WriteRequest> = VecDeque::new();
548            let mut batch_buffer = BytesMut::with_capacity(wal.config.max_batch_size);
549            let mut group_start: Option<Instant> = None;
550
551            loop {
552                // Check shutdown early
553                if wal.shutdown.load(Ordering::Acquire) {
554                    // Drain any remaining messages from channel
555                    while let Ok(request) = rx.try_recv() {
556                        pending.push_back(request);
557                    }
558                    // Flush remaining and exit
559                    if !pending.is_empty() {
560                        wal.flush_batch(&mut pending, &mut batch_buffer, group_start.take())
561                            .await;
562                    }
563                    break;
564                }
565
566                // Wait for writes or timeout
567                let timeout = if pending.is_empty() {
568                    Duration::from_secs(60) // Long timeout when idle
569                } else {
570                    wal.config.group_commit_window
571                };
572
573                tokio::select! {
574                    biased;
575
576                    // Receive new write request (higher priority)
577                    Some(request) = rx.recv() => {
578                        if group_start.is_none() {
579                            group_start = Some(Instant::now());
580                        }
581                        pending.push_back(request);
582
583                        // Check if we should flush immediately
584                        let should_flush =
585                            pending.len() >= wal.config.max_pending_writes ||
586                            batch_buffer.len() >= wal.config.max_batch_size;
587
588                        if should_flush {
589                            wal.flush_batch(&mut pending, &mut batch_buffer, group_start.take()).await;
590                        }
591                    }
592
593                    // Wait for notification
594                    _ = wal.write_notify.notified() => {
595                        // Continue loop to check shutdown flag
596                    }
597
598                    // Timeout - flush whatever we have
599                    _ = tokio::time::sleep(timeout) => {
600                        if !pending.is_empty() {
601                            wal.flush_batch(&mut pending, &mut batch_buffer, group_start.take()).await;
602                        }
603                    }
604                }
605            }
606        });
607    }
608
609    /// Flush a batch of pending writes
610    async fn flush_batch(
611        &self,
612        pending: &mut VecDeque<WriteRequest>,
613        batch_buffer: &mut BytesMut,
614        group_start: Option<Instant>,
615    ) {
616        if pending.is_empty() {
617            return;
618        }
619
620        let wait_time = group_start.map(|s| s.elapsed()).unwrap_or(Duration::ZERO);
621        let group_size = pending.len();
622
623        // Build batch
624        batch_buffer.clear();
625        let mut lsns = Vec::with_capacity(group_size);
626        let mut sizes = Vec::with_capacity(group_size);
627
628        for request in pending.iter() {
629            let lsn = self.current_lsn.fetch_add(1, Ordering::AcqRel) + 1;
630            lsns.push(lsn);
631
632            // Optionally encrypt the data
633            #[cfg(feature = "encryption")]
634            let (data, is_encrypted) = if let Some(ref encryptor) = self.config.encryptor {
635                if encryptor.is_enabled() {
636                    match encryptor.encrypt(&request.data, lsn) {
637                        Ok(encrypted) => (Bytes::from(encrypted), true),
638                        Err(e) => {
639                            tracing::error!("Encryption failed for LSN {}: {:?}", lsn, e);
640                            (request.data.clone(), false)
641                        }
642                    }
643                } else {
644                    (request.data.clone(), false)
645                }
646            } else {
647                (request.data.clone(), false)
648            };
649
650            #[cfg(not(feature = "encryption"))]
651            let (data, is_encrypted) = (request.data.clone(), false);
652
653            let flags = if is_encrypted {
654                RecordFlags(RecordFlags::HAS_CHECKSUM.0 | RecordFlags::ENCRYPTED.0)
655            } else {
656                RecordFlags::HAS_CHECKSUM
657            };
658
659            let record = WalRecord {
660                lsn,
661                record_type: request.record_type,
662                flags,
663                data,
664            };
665
666            let record_bytes = record.to_bytes();
667            sizes.push(record_bytes.len());
668            batch_buffer.extend_from_slice(&record_bytes);
669        }
670
671        // Write batch to disk
672        let write_result = {
673            let mut writer = self.writer.lock().await;
674            writer.write_batch(batch_buffer)
675        };
676
677        // Update stats
678        self.stats
679            .writes_total
680            .fetch_add(group_size as u64, Ordering::Relaxed);
681        self.stats
682            .bytes_written
683            .fetch_add(batch_buffer.len() as u64, Ordering::Relaxed);
684        self.stats.group_commits.fetch_add(1, Ordering::Relaxed);
685        self.stats.syncs_total.fetch_add(1, Ordering::Relaxed);
686
687        // Send results to waiters
688        let group_commit = group_size > 1;
689
690        for (i, request) in pending.drain(..).enumerate() {
691            let result = match &write_result {
692                Ok(()) => WriteResult {
693                    lsn: lsns[i],
694                    size: sizes[i],
695                    group_commit,
696                    group_size,
697                    wait_time,
698                },
699                Err(_) => {
700                    // On error, still send a result but it will indicate failure
701                    WriteResult {
702                        lsn: 0,
703                        size: 0,
704                        group_commit: false,
705                        group_size: 0,
706                        wait_time,
707                    }
708                }
709            };
710
711            let _ = request.completion.send(result);
712        }
713    }
714}
715
716/// Low-level WAL file writer
717struct WalWriter {
718    file: BufWriter<File>,
719    path: PathBuf,
720    position: u64,
721    config: WalConfig,
722}
723
724impl WalWriter {
725    fn new(path: PathBuf, config: WalConfig) -> io::Result<Self> {
726        use std::io::{Seek, SeekFrom};
727
728        // Check if file exists and has valid data
729        let existing_len = std::fs::metadata(&path).map(|m| m.len()).unwrap_or(0);
730
731        let file = OpenOptions::new()
732            .create(true)
733            .read(true)
734            .write(true)
735            .truncate(false) // Preserve existing WAL data
736            .open(&path)?;
737
738        // Determine actual data length (scan for end of valid records)
739        let actual_position = if existing_len > 0 {
740            Self::find_actual_end(&file, existing_len)?
741        } else {
742            0
743        };
744
745        // Pre-allocate if this is a new file
746        if actual_position == 0 && config.preallocate_size > 0 {
747            file.set_len(config.preallocate_size)?;
748        }
749
750        // Seek to the actual write position
751        let mut writer = BufWriter::with_capacity(config.max_batch_size, file);
752        writer.seek(SeekFrom::Start(actual_position))?;
753
754        Ok(Self {
755            file: writer,
756            path,
757            position: actual_position,
758            config,
759        })
760    }
761
762    /// Find the actual end of valid data in the file
763    fn find_actual_end(file: &File, file_len: u64) -> io::Result<u64> {
764        use std::io::Read;
765
766        let mut position = 0u64;
767        let mut file = file.try_clone()?;
768
769        // Read in chunks to find valid record boundaries
770        while position + RECORD_HEADER_SIZE as u64 <= file_len {
771            let mut header = [0u8; RECORD_HEADER_SIZE];
772
773            use std::io::{Seek, SeekFrom};
774            file.seek(SeekFrom::Start(position))?;
775
776            if file.read_exact(&mut header).is_err() {
777                break;
778            }
779
780            // Check magic
781            let magic = u32::from_be_bytes([header[0], header[1], header[2], header[3]]);
782            if magic != WAL_MAGIC {
783                break;
784            }
785
786            let data_len =
787                u32::from_be_bytes([header[8], header[9], header[10], header[11]]) as u64;
788            let record_size = RECORD_HEADER_SIZE as u64 + data_len;
789
790            if position + record_size > file_len {
791                break;
792            }
793
794            position += record_size;
795        }
796
797        Ok(position)
798    }
799
800    fn write_batch(&mut self, data: &[u8]) -> io::Result<()> {
801        self.file.write_all(data)?;
802        self.file.flush()?;
803
804        // Sync based on mode
805        match self.config.sync_mode {
806            SyncMode::None => {}
807            SyncMode::FsyncData => {
808                self.file.get_ref().sync_data()?;
809            }
810            SyncMode::Fsync | SyncMode::Dsync => {
811                self.file.get_ref().sync_all()?;
812            }
813        }
814
815        self.position += data.len() as u64;
816        Ok(())
817    }
818
819    fn sync(&mut self) -> io::Result<()> {
820        self.file.flush()?;
821        self.file.get_ref().sync_all()
822    }
823
824    /// Get the path of this WAL file
825    #[allow(dead_code)]
826    fn path(&self) -> &std::path::Path {
827        &self.path
828    }
829}
830
/// Running counters maintained by the WAL.
struct WalStats {
    /// Individual writes flushed.
    writes_total: AtomicU64,
    /// Serialized bytes flushed.
    bytes_written: AtomicU64,
    /// Flushes counted as syncs.
    syncs_total: AtomicU64,
    /// Batches flushed.
    group_commits: AtomicU64,
}

impl WalStats {
    /// Creates a set of counters that all start at zero.
    fn new() -> Self {
        let zero = || AtomicU64::new(0);
        Self {
            writes_total: zero(),
            bytes_written: zero(),
            syncs_total: zero(),
            group_commits: zero(),
        }
    }
}
849
/// Plain-value copy of the WAL counters, as returned by `GroupCommitWal::stats`.
#[derive(Debug, Clone)]
pub struct WalStatsSnapshot {
    /// Individual writes flushed.
    pub writes_total: u64,
    /// Serialized bytes flushed.
    pub bytes_written: u64,
    /// Flushes counted as syncs.
    pub syncs_total: u64,
    /// Batches flushed.
    pub group_commits: u64,
    /// `writes_total / group_commits`, or 0.0 when no batch has been flushed.
    pub avg_group_size: f64,
    /// Most recently assigned LSN at the time of the snapshot.
    pub current_lsn: u64,
}
859
/// Sequential reader over a single WAL file, for recovery and replication.
pub struct WalReader {
    /// File to read from.
    path: PathBuf,
    /// Byte offset of the next record to read.
    position: u64,
    /// Decryptor for records flagged as encrypted, if one was supplied.
    #[cfg(feature = "encryption")]
    encryptor: Option<std::sync::Arc<dyn crate::encryption::Encryptor>>,
}
867
868impl WalReader {
869    /// Open a WAL file for reading
870    pub fn open(path: PathBuf) -> io::Result<Self> {
871        Ok(Self {
872            path,
873            position: 0,
874            #[cfg(feature = "encryption")]
875            encryptor: None,
876        })
877    }
878
879    /// Open a WAL file for reading with optional encryption
880    #[cfg(feature = "encryption")]
881    pub fn open_with_encryption(
882        path: PathBuf,
883        encryptor: Option<std::sync::Arc<dyn crate::encryption::Encryptor>>,
884    ) -> io::Result<Self> {
885        Ok(Self {
886            path,
887            position: 0,
888            encryptor,
889        })
890    }
891
892    /// Decrypt record data if needed
893    #[cfg(feature = "encryption")]
894    fn decrypt_record_data(&self, record: &mut WalRecord) -> io::Result<()> {
895        if record.flags.is_encrypted() {
896            if let Some(ref encryptor) = self.encryptor {
897                let decrypted = encryptor
898                    .decrypt(&record.data, record.lsn)
899                    .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e.to_string()))?;
900                record.data = Bytes::from(decrypted);
901            } else {
902                return Err(io::Error::new(
903                    io::ErrorKind::InvalidData,
904                    "Record is encrypted but no encryptor provided",
905                ));
906            }
907        }
908        Ok(())
909    }
910
911    /// Read all records from current position
912    pub async fn read_all(&mut self) -> io::Result<Vec<WalRecord>> {
913        let data = tokio::fs::read(&self.path).await?;
914        let mut records = Vec::new();
915        let mut lsn = 0u64;
916
917        while self.position + RECORD_HEADER_SIZE as u64 <= data.len() as u64 {
918            let offset = self.position as usize;
919
920            // Check magic
921            let magic = u32::from_be_bytes([
922                data[offset],
923                data[offset + 1],
924                data[offset + 2],
925                data[offset + 3],
926            ]);
927
928            if magic != WAL_MAGIC {
929                break;
930            }
931
932            let data_len = u32::from_be_bytes([
933                data[offset + 8],
934                data[offset + 9],
935                data[offset + 10],
936                data[offset + 11],
937            ]) as usize;
938
939            let record_size = RECORD_HEADER_SIZE + data_len;
940
941            if offset + record_size > data.len() {
942                break;
943            }
944
945            lsn += 1;
946
947            match WalRecord::from_bytes(&data[offset..offset + record_size], lsn) {
948                #[cfg(feature = "encryption")]
949                Ok(mut record) => {
950                    // Decrypt if needed
951                    self.decrypt_record_data(&mut record)?;
952
953                    records.push(record);
954                    self.position += record_size as u64;
955                }
956                #[cfg(not(feature = "encryption"))]
957                Ok(record) => {
958                    records.push(record);
959                    self.position += record_size as u64;
960                }
961                Err(_) => break,
962            }
963        }
964
965        Ok(records)
966    }
967
968    /// Seek to a specific LSN
969    pub async fn seek_to_lsn(&mut self, target_lsn: u64) -> io::Result<()> {
970        let data = tokio::fs::read(&self.path).await?;
971        let mut position = 0usize;
972        let mut current_lsn = 0u64;
973
974        while position + RECORD_HEADER_SIZE <= data.len() {
975            let magic = u32::from_be_bytes([
976                data[position],
977                data[position + 1],
978                data[position + 2],
979                data[position + 3],
980            ]);
981
982            if magic != WAL_MAGIC {
983                break;
984            }
985
986            let data_len = u32::from_be_bytes([
987                data[position + 8],
988                data[position + 9],
989                data[position + 10],
990                data[position + 11],
991            ]) as usize;
992
993            let record_size = RECORD_HEADER_SIZE + data_len;
994            current_lsn += 1;
995
996            if current_lsn >= target_lsn {
997                self.position = position as u64;
998                return Ok(());
999            }
1000
1001            position += record_size;
1002        }
1003
1004        Err(io::Error::new(
1005            io::ErrorKind::NotFound,
1006            format!("LSN {} not found", target_lsn),
1007        ))
1008    }
1009}
1010
1011#[cfg(test)]
1012mod tests {
1013    use super::*;
1014    use tempfile::TempDir;
1015
1016    #[test]
1017    fn test_wal_record_serialization() {
1018        let data = Bytes::from("test data");
1019        let record = WalRecord::new(1, data.clone());
1020
1021        let serialized = record.to_bytes();
1022        assert!(serialized.len() >= RECORD_HEADER_SIZE + data.len());
1023
1024        let parsed = WalRecord::from_bytes(&serialized, 1).unwrap();
1025        assert_eq!(parsed.lsn, 1);
1026        assert_eq!(parsed.data, data);
1027    }
1028
1029    #[test]
1030    fn test_wal_record_crc() {
1031        let data = Bytes::from("test data");
1032        let record = WalRecord::new(1, data);
1033        let mut serialized = record.to_bytes().to_vec();
1034
1035        // Corrupt the data
1036        serialized[RECORD_HEADER_SIZE] ^= 0xFF;
1037
1038        // Should fail CRC check
1039        assert!(WalRecord::from_bytes(&serialized, 1).is_err());
1040    }
1041
1042    #[tokio::test]
1043    async fn test_group_commit_wal_single_write() {
1044        let temp_dir = TempDir::new().unwrap();
1045        let config = WalConfig {
1046            dir: temp_dir.path().to_path_buf(),
1047            group_commit_window: Duration::from_micros(100),
1048            ..Default::default()
1049        };
1050
1051        let wal = GroupCommitWal::new(config).await.unwrap();
1052
1053        let result = wal.write(Bytes::from("test data")).await.unwrap();
1054        assert_eq!(result.lsn, 1);
1055        assert!(result.size > 0);
1056
1057        let stats = wal.stats();
1058        assert_eq!(stats.writes_total, 1);
1059
1060        wal.shutdown().await.unwrap();
1061    }
1062
1063    #[tokio::test]
1064    async fn test_group_commit_wal_batch() {
1065        let temp_dir = TempDir::new().unwrap();
1066        let config = WalConfig {
1067            dir: temp_dir.path().to_path_buf(),
1068            group_commit_window: Duration::from_millis(10),
1069            ..Default::default()
1070        };
1071
1072        let wal = GroupCommitWal::new(config).await.unwrap();
1073
1074        let records: Vec<Bytes> = (0..10)
1075            .map(|i| Bytes::from(format!("record {}", i)))
1076            .collect();
1077
1078        let results = wal.write_batch(records).await.unwrap();
1079
1080        assert_eq!(results.len(), 10);
1081        for (i, result) in results.iter().enumerate() {
1082            assert_eq!(result.lsn, (i + 1) as u64);
1083        }
1084
1085        let stats = wal.stats();
1086        assert_eq!(stats.writes_total, 10);
1087
1088        wal.shutdown().await.unwrap();
1089    }
1090
1091    #[tokio::test]
1092    async fn test_group_commit_batching() {
1093        let temp_dir = TempDir::new().unwrap();
1094        let config = WalConfig {
1095            dir: temp_dir.path().to_path_buf(),
1096            group_commit_window: Duration::from_millis(50),
1097            max_pending_writes: 100,
1098            ..Default::default()
1099        };
1100
1101        let wal = Arc::new(GroupCommitWal::new(config).await.unwrap());
1102
1103        // Spawn multiple writers concurrently
1104        let mut handles = vec![];
1105        for i in 0..20 {
1106            let wal_clone = wal.clone();
1107            handles.push(tokio::spawn(async move {
1108                wal_clone
1109                    .write(Bytes::from(format!("concurrent write {}", i)))
1110                    .await
1111            }));
1112        }
1113
1114        // Wait for all writes
1115        for handle in handles {
1116            let result = handle.await.unwrap().unwrap();
1117            assert!(result.lsn > 0);
1118        }
1119
1120        let stats = wal.stats();
1121        assert_eq!(stats.writes_total, 20);
1122        // With batching, we should have fewer syncs than writes
1123        assert!(stats.group_commits <= stats.writes_total);
1124
1125        wal.shutdown().await.unwrap();
1126    }
1127
1128    #[tokio::test]
1129    async fn test_wal_reader() {
1130        let temp_dir = TempDir::new().unwrap();
1131        let config = WalConfig {
1132            dir: temp_dir.path().to_path_buf(),
1133            group_commit_window: Duration::from_micros(100),
1134            sync_mode: SyncMode::Fsync,
1135            max_pending_writes: 10,
1136            ..Default::default()
1137        };
1138
1139        let wal = GroupCommitWal::new(config.clone()).await.unwrap();
1140
1141        // Write some records and wait for each to complete
1142        for i in 0..5 {
1143            let result = wal
1144                .write(Bytes::from(format!("record {}", i)))
1145                .await
1146                .unwrap();
1147            assert!(result.lsn > 0, "Expected valid LSN for record {}", i);
1148        }
1149
1150        // Force a sync to ensure data is on disk
1151        wal.sync().await.unwrap();
1152
1153        // Give a small delay to ensure background worker has flushed
1154        tokio::time::sleep(Duration::from_millis(100)).await;
1155
1156        wal.shutdown().await.unwrap();
1157
1158        // Find the WAL file
1159        let entries: Vec<_> = std::fs::read_dir(&config.dir)
1160            .unwrap()
1161            .filter_map(|e| e.ok())
1162            .filter(|e| e.path().extension().is_some_and(|ext| ext == "wal"))
1163            .collect();
1164
1165        assert!(!entries.is_empty(), "No WAL files found");
1166
1167        let wal_file = entries[0].path();
1168        let file_size = std::fs::metadata(&wal_file).unwrap().len();
1169        assert!(file_size > 0, "WAL file is empty");
1170
1171        let mut reader = WalReader::open(wal_file).unwrap();
1172        let records = reader.read_all().await.unwrap();
1173
1174        assert_eq!(
1175            records.len(),
1176            5,
1177            "Expected 5 records, got {} (file size: {})",
1178            records.len(),
1179            file_size
1180        );
1181        for (i, record) in records.iter().enumerate() {
1182            let expected = format!("record {}", i);
1183            assert_eq!(record.data, Bytes::from(expected));
1184        }
1185    }
1186
1187    #[test]
1188    fn test_record_flags() {
1189        let flags = RecordFlags::COMPRESSED;
1190        assert!(flags.is_compressed());
1191        assert!(!flags.is_encrypted());
1192
1193        let flags = RecordFlags(RecordFlags::COMPRESSED.0 | RecordFlags::ENCRYPTED.0);
1194        assert!(flags.is_compressed());
1195        assert!(flags.is_encrypted());
1196    }
1197}