// rivven_core/wal.rs
//! Group Commit Write-Ahead Log (WAL)
//!
//! High-performance WAL implementation with group commit optimization:
//! - **Group Commit**: Batches multiple writes into a single fsync (10-100x throughput)
//! - **Pipelined Writes**: Overlaps I/O with compute
//! - **Pre-allocated Files**: Reduces filesystem overhead
//! - **CRC32 Checksums**: Data integrity verification
//! - **Asynchronous Sync**: Non-blocking durability
//!
//! Based on techniques from:
//! - MySQL InnoDB group commit
//! - PostgreSQL WAL
//! - RocksDB write batching

use std::collections::VecDeque;
use std::fs::{File, OpenOptions};
use std::io::{self, BufWriter, Write};
use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};

use bytes::{BufMut, Bytes, BytesMut};
use crc32fast::Hasher;
use tokio::sync::{mpsc, oneshot, Mutex, Notify};
25
/// WAL record header size: magic(4) + crc(4) + len(4) + type(1) + flags(1) = 14 bytes
const RECORD_HEADER_SIZE: usize = 14;

/// Magic number prefixed to every WAL record: ASCII "WALR", written big-endian.
/// Readers stop scanning at the first offset that does not carry this value.
const WAL_MAGIC: u32 = 0x57414C52; // "WALR"

/// Default group commit window (microseconds): how long the worker waits for
/// additional writes to join a batch before flushing.
const DEFAULT_GROUP_COMMIT_WINDOW_US: u64 = 200;

/// Default max batch size (bytes) before a flush is forced.
const DEFAULT_MAX_BATCH_SIZE: usize = 4 * 1024 * 1024; // 4 MB

/// Default max pending writes before forcing a flush (also used as the
/// capacity of the write-request channel).
const DEFAULT_MAX_PENDING_WRITES: usize = 1000;
40
/// WAL record types.
///
/// The discriminant values are written to disk as a single byte (see
/// `WalRecord::to_bytes`), so they must never be renumbered.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum RecordType {
    /// Full record (single record contains complete data)
    Full = 0,
    /// First fragment of a large record
    First = 1,
    /// Middle fragment of a large record
    Middle = 2,
    /// Last fragment of a large record
    Last = 3,
    /// Checkpoint marker
    Checkpoint = 4,
    /// Transaction begin
    TxnBegin = 5,
    /// Transaction commit
    TxnCommit = 6,
    /// Transaction abort
    TxnAbort = 7,
}
62
63impl TryFrom<u8> for RecordType {
64    type Error = io::Error;
65
66    fn try_from(value: u8) -> Result<Self, Self::Error> {
67        match value {
68            0 => Ok(RecordType::Full),
69            1 => Ok(RecordType::First),
70            2 => Ok(RecordType::Middle),
71            3 => Ok(RecordType::Last),
72            4 => Ok(RecordType::Checkpoint),
73            5 => Ok(RecordType::TxnBegin),
74            6 => Ok(RecordType::TxnCommit),
75            7 => Ok(RecordType::TxnAbort),
76            _ => Err(io::Error::new(
77                io::ErrorKind::InvalidData,
78                "Invalid record type",
79            )),
80        }
81    }
82}
83
/// WAL record flags.
///
/// A thin bitmask newtype over `u8`. Individual bits mark whether a record's
/// payload is compressed, encrypted, and/or carries a CRC32 checksum.
#[derive(Debug, Clone, Copy)]
pub struct RecordFlags(u8);

impl RecordFlags {
    /// No flags set.
    pub const NONE: Self = Self(0);
    /// Payload is compressed.
    pub const COMPRESSED: Self = Self(1 << 0);
    /// Payload is encrypted.
    pub const ENCRYPTED: Self = Self(1 << 1);
    /// Record carries a CRC32 checksum.
    pub const HAS_CHECKSUM: Self = Self(1 << 2);

    /// True when the bit of `mask` is set in `self`.
    fn contains(self, mask: Self) -> bool {
        self.0 & mask.0 != 0
    }

    pub fn is_compressed(&self) -> bool {
        self.contains(Self::COMPRESSED)
    }

    pub fn is_encrypted(&self) -> bool {
        self.contains(Self::ENCRYPTED)
    }

    pub fn has_checksum(&self) -> bool {
        self.contains(Self::HAS_CHECKSUM)
    }
}
106
/// A WAL record.
///
/// Note: the LSN is *not* serialized to disk (`to_bytes` writes only the
/// header and payload); it is assigned by the writer in memory and
/// re-derived by position during recovery (`from_bytes` takes it as an
/// argument).
#[derive(Debug, Clone)]
pub struct WalRecord {
    /// Log sequence number
    pub lsn: u64,
    /// Record type
    pub record_type: RecordType,
    /// Flags
    pub flags: RecordFlags,
    /// Record data
    pub data: Bytes,
}
119
120impl WalRecord {
121    /// Create a new full record
122    pub fn new(lsn: u64, data: Bytes) -> Self {
123        Self {
124            lsn,
125            record_type: RecordType::Full,
126            flags: RecordFlags::HAS_CHECKSUM,
127            data,
128        }
129    }
130
131    /// Serialize to bytes
132    pub fn to_bytes(&self) -> Bytes {
133        let mut buf = BytesMut::with_capacity(RECORD_HEADER_SIZE + self.data.len());
134
135        // Calculate CRC of data
136        let mut hasher = Hasher::new();
137        hasher.update(&self.data);
138        let crc = hasher.finalize();
139
140        // Write header
141        buf.put_u32(WAL_MAGIC);
142        buf.put_u32(crc);
143        buf.put_u32(self.data.len() as u32);
144        buf.put_u8(self.record_type as u8);
145        buf.put_u8(self.flags.0);
146
147        // Write data
148        buf.extend_from_slice(&self.data);
149
150        buf.freeze()
151    }
152
153    /// Parse from bytes
154    pub fn from_bytes(data: &[u8], lsn: u64) -> io::Result<Self> {
155        if data.len() < RECORD_HEADER_SIZE {
156            return Err(io::Error::new(
157                io::ErrorKind::InvalidData,
158                "Record too short",
159            ));
160        }
161
162        // Read header
163        let magic = u32::from_be_bytes([data[0], data[1], data[2], data[3]]);
164        if magic != WAL_MAGIC {
165            return Err(io::Error::new(
166                io::ErrorKind::InvalidData,
167                "Invalid magic number",
168            ));
169        }
170
171        let stored_crc = u32::from_be_bytes([data[4], data[5], data[6], data[7]]);
172        let data_len = u32::from_be_bytes([data[8], data[9], data[10], data[11]]) as usize;
173        let record_type = RecordType::try_from(data[12])?;
174        let flags = RecordFlags(data[13]);
175
176        if data.len() < RECORD_HEADER_SIZE + data_len {
177            return Err(io::Error::new(
178                io::ErrorKind::InvalidData,
179                "Incomplete record",
180            ));
181        }
182
183        let record_data =
184            Bytes::copy_from_slice(&data[RECORD_HEADER_SIZE..RECORD_HEADER_SIZE + data_len]);
185
186        // Verify CRC
187        let mut hasher = Hasher::new();
188        hasher.update(&record_data);
189        let computed_crc = hasher.finalize();
190
191        if computed_crc != stored_crc {
192            return Err(io::Error::new(io::ErrorKind::InvalidData, "CRC mismatch"));
193        }
194
195        Ok(Self {
196            lsn,
197            record_type,
198            flags,
199            data: record_data,
200        })
201    }
202
203    /// Get total serialized size
204    pub fn serialized_size(&self) -> usize {
205        RECORD_HEADER_SIZE + self.data.len()
206    }
207}
208
/// Configuration for group commit WAL.
#[derive(Debug, Clone)]
pub struct WalConfig {
    /// Directory for WAL files
    pub dir: PathBuf,
    /// Group commit window (how long the worker waits for more writes
    /// to join a batch before flushing)
    pub group_commit_window: Duration,
    /// Maximum batch size (bytes) before forcing a flush
    pub max_batch_size: usize,
    /// Maximum pending writes before forcing a flush; also sizes the
    /// write-request channel
    pub max_pending_writes: usize,
    /// Pre-allocate WAL files to this size (0 disables pre-allocation)
    pub preallocate_size: u64,
    /// Enable direct I/O (bypass OS cache)
    /// NOTE(review): not consulted by `WalWriter` in this file — setting it
    /// currently appears to have no effect; confirm before relying on it.
    pub direct_io: bool,
    /// Sync mode applied after each batch write
    pub sync_mode: SyncMode,
    /// Maximum WAL file size before rotation (0 disables rotation)
    pub max_file_size: u64,
    /// Optional encryption for data at rest
    #[cfg(feature = "encryption")]
    pub encryptor: Option<std::sync::Arc<dyn crate::encryption::Encryptor>>,
}
232
impl Default for WalConfig {
    /// Balanced defaults: 200µs commit window, 4 MB batches, full fsync,
    /// 64 MB pre-allocation, 1 GB rotation threshold.
    fn default() -> Self {
        Self {
            dir: PathBuf::from("./wal"),
            group_commit_window: Duration::from_micros(DEFAULT_GROUP_COMMIT_WINDOW_US),
            max_batch_size: DEFAULT_MAX_BATCH_SIZE,
            max_pending_writes: DEFAULT_MAX_PENDING_WRITES,
            preallocate_size: 64 * 1024 * 1024, // 64 MB
            direct_io: false,                   // Requires O_DIRECT support
            sync_mode: SyncMode::Fsync,
            max_file_size: 1024 * 1024 * 1024, // 1 GB
            #[cfg(feature = "encryption")]
            encryptor: None,
        }
    }
}
249
250impl WalConfig {
251    /// High-throughput configuration (more batching, less frequent sync)
252    pub fn high_throughput() -> Self {
253        Self {
254            group_commit_window: Duration::from_micros(1000), // 1ms
255            max_batch_size: 16 * 1024 * 1024,                 // 16 MB
256            max_pending_writes: 5000,
257            ..Default::default()
258        }
259    }
260
261    /// Low-latency configuration (less batching, more frequent sync)
262    pub fn low_latency() -> Self {
263        Self {
264            group_commit_window: Duration::from_micros(50), // 50us
265            max_batch_size: 512 * 1024,                     // 512 KB
266            max_pending_writes: 100,
267            ..Default::default()
268        }
269    }
270
271    /// Durability-focused configuration
272    pub fn durable() -> Self {
273        Self {
274            sync_mode: SyncMode::FsyncData,
275            ..Default::default()
276        }
277    }
278}
279
/// Sync mode for WAL writes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SyncMode {
    /// No sync (fastest, least durable)
    None,
    /// fdatasync (syncs data but not metadata)
    FsyncData,
    /// Full fsync (syncs data and metadata)
    Fsync,
    /// O_DSYNC flag (sync on each write)
    ///
    /// NOTE(review): `WalWriter::write_batch` currently treats this the same
    /// as `Fsync` (calls `sync_all`); the file is never actually opened with
    /// O_DSYNC — confirm intended semantics.
    Dsync,
}
292
/// Write request for the WAL, sent from callers to the group-commit worker.
struct WriteRequest {
    /// Data to write (payload only; the header is added at serialization)
    data: Bytes,
    /// Record type
    record_type: RecordType,
    /// Channel to send completion notification (carries Result so disk errors propagate);
    /// the error side is a String because WriteResult/io::Error are not Clone
    completion: oneshot::Sender<Result<WriteResult, String>>,
}
302
/// Result of a write operation, delivered to the caller once the record's
/// batch has been written (and synced, per the configured mode).
#[derive(Debug, Clone)]
pub struct WriteResult {
    /// Assigned LSN
    pub lsn: u64,
    /// Size written (header + payload bytes)
    pub size: usize,
    /// Whether this write shared its batch with at least one other write
    pub group_commit: bool,
    /// Number of writes in the group
    pub group_size: usize,
    /// Time spent waiting for group commit (from first enqueue to flush)
    pub wait_time: Duration,
}
317
/// Group commit WAL writer.
///
/// Callers enqueue writes over `write_tx`; a single background worker
/// batches them and performs the actual file I/O through `writer`.
pub struct GroupCommitWal {
    config: WalConfig,
    /// Current file writer (tokio Mutex: held across await points by the worker)
    writer: Mutex<WalWriter>,
    /// Current LSN (last assigned; the next record gets current + 1)
    current_lsn: AtomicU64,
    /// Write request sender (channel capacity = max_pending_writes)
    write_tx: mpsc::Sender<WriteRequest>,
    /// Shutdown flag, checked by both callers and the worker
    shutdown: AtomicBool,
    /// Notify used to wake the worker when new writes arrive or on shutdown
    write_notify: Arc<Notify>,
    /// Statistics (shared with the worker)
    stats: Arc<WalStats>,
}
334
335impl GroupCommitWal {
336    /// Create a new group commit WAL
337    pub async fn new(config: WalConfig) -> io::Result<Arc<Self>> {
338        std::fs::create_dir_all(&config.dir)?;
339
340        // Find the latest WAL file and LSN
341        let (current_file, current_lsn) = Self::recover_state(&config).await?;
342
343        let writer = WalWriter::new(current_file, config.clone())?;
344        let (write_tx, write_rx) = mpsc::channel(config.max_pending_writes);
345
346        let wal = Arc::new(Self {
347            config,
348            writer: Mutex::new(writer),
349            current_lsn: AtomicU64::new(current_lsn),
350            write_tx,
351            shutdown: AtomicBool::new(false),
352            write_notify: Arc::new(Notify::new()),
353            stats: Arc::new(WalStats::new()),
354        });
355
356        // Start background group commit worker
357        wal.clone().start_group_commit_worker(write_rx);
358
359        Ok(wal)
360    }
361
362    /// Recover state from existing WAL files
363    async fn recover_state(config: &WalConfig) -> io::Result<(PathBuf, u64)> {
364        let mut max_lsn = 0u64;
365        let mut latest_file = None;
366
367        if let Ok(entries) = std::fs::read_dir(&config.dir) {
368            for entry in entries.flatten() {
369                let path = entry.path();
370                if path.extension().is_some_and(|e| e == "wal") {
371                    if let Some(name) = path.file_stem() {
372                        if let Ok(lsn) = name.to_string_lossy().parse::<u64>() {
373                            if lsn >= max_lsn {
374                                max_lsn = lsn;
375                                latest_file = Some(path);
376                            }
377                        }
378                    }
379                }
380            }
381        }
382
383        // If we found a file, scan it to find the true max LSN
384        if let Some(ref file) = latest_file {
385            if let Ok(recovered_lsn) = Self::scan_wal_file(file).await {
386                max_lsn = recovered_lsn;
387            }
388        }
389
390        // Create new file if none exists
391        let file = latest_file.unwrap_or_else(|| config.dir.join(format!("{:020}.wal", 0)));
392
393        Ok((file, max_lsn))
394    }
395
396    /// Scan a WAL file to find the highest LSN
397    async fn scan_wal_file(path: &Path) -> io::Result<u64> {
398        let data = tokio::fs::read(path).await?;
399        let mut offset = 0;
400        let mut max_lsn = 0u64;
401
402        while offset + RECORD_HEADER_SIZE <= data.len() {
403            // Check magic
404            let magic = u32::from_be_bytes([
405                data[offset],
406                data[offset + 1],
407                data[offset + 2],
408                data[offset + 3],
409            ]);
410
411            if magic != WAL_MAGIC {
412                break;
413            }
414
415            let data_len = u32::from_be_bytes([
416                data[offset + 8],
417                data[offset + 9],
418                data[offset + 10],
419                data[offset + 11],
420            ]) as usize;
421
422            let record_size = RECORD_HEADER_SIZE + data_len;
423            if offset + record_size > data.len() {
424                break;
425            }
426
427            max_lsn += 1;
428            offset += record_size;
429        }
430
431        Ok(max_lsn)
432    }
433
434    /// Write a record to the WAL (async, batched)
435    pub async fn write(&self, data: Bytes) -> io::Result<WriteResult> {
436        self.write_with_type(data, RecordType::Full).await
437    }
438
439    /// Write a record with specific type
440    pub async fn write_with_type(
441        &self,
442        data: Bytes,
443        record_type: RecordType,
444    ) -> io::Result<WriteResult> {
445        if self.shutdown.load(Ordering::Acquire) {
446            return Err(io::Error::new(
447                io::ErrorKind::BrokenPipe,
448                "WAL is shut down",
449            ));
450        }
451
452        let (tx, rx) = oneshot::channel();
453
454        let request = WriteRequest {
455            data,
456            record_type,
457            completion: tx,
458        };
459
460        self.write_tx
461            .send(request)
462            .await
463            .map_err(|_| io::Error::new(io::ErrorKind::BrokenPipe, "WAL write channel closed"))?;
464
465        self.write_notify.notify_one();
466
467        rx.await
468            .map_err(|_| io::Error::new(io::ErrorKind::BrokenPipe, "WAL write cancelled"))?
469            .map_err(io::Error::other)
470    }
471
472    /// Write a batch of records atomically
473    pub async fn write_batch(&self, records: Vec<Bytes>) -> io::Result<Vec<WriteResult>> {
474        let mut results = Vec::with_capacity(records.len());
475        let mut receivers = Vec::with_capacity(records.len());
476
477        for data in records {
478            let (tx, rx) = oneshot::channel();
479
480            let request = WriteRequest {
481                data,
482                record_type: RecordType::Full,
483                completion: tx,
484            };
485
486            self.write_tx.send(request).await.map_err(|_| {
487                io::Error::new(io::ErrorKind::BrokenPipe, "WAL write channel closed")
488            })?;
489
490            receivers.push(rx);
491        }
492
493        self.write_notify.notify_one();
494
495        for rx in receivers {
496            let result = rx
497                .await
498                .map_err(|_| io::Error::new(io::ErrorKind::BrokenPipe, "WAL write cancelled"))?
499                .map_err(io::Error::other)?;
500            results.push(result);
501        }
502
503        Ok(results)
504    }
505
506    /// Sync the WAL to disk (force flush)
507    pub async fn sync(&self) -> io::Result<()> {
508        let mut writer = self.writer.lock().await;
509        writer.sync()
510    }
511
512    /// Get current LSN
513    pub fn current_lsn(&self) -> u64 {
514        self.current_lsn.load(Ordering::Acquire)
515    }
516
517    /// Get WAL statistics
518    pub fn stats(&self) -> WalStatsSnapshot {
519        WalStatsSnapshot {
520            writes_total: self.stats.writes_total.load(Ordering::Relaxed),
521            bytes_written: self.stats.bytes_written.load(Ordering::Relaxed),
522            syncs_total: self.stats.syncs_total.load(Ordering::Relaxed),
523            group_commits: self.stats.group_commits.load(Ordering::Relaxed),
524            avg_group_size: if self.stats.group_commits.load(Ordering::Relaxed) > 0 {
525                self.stats.writes_total.load(Ordering::Relaxed) as f64
526                    / self.stats.group_commits.load(Ordering::Relaxed) as f64
527            } else {
528                0.0
529            },
530            current_lsn: self.current_lsn.load(Ordering::Relaxed),
531        }
532    }
533
534    /// Shutdown the WAL
535    pub async fn shutdown(&self) -> io::Result<()> {
536        self.shutdown.store(true, Ordering::Release);
537        self.write_notify.notify_waiters();
538
539        // Final sync
540        let mut writer = self.writer.lock().await;
541        writer.sync()
542    }
543
544    /// Start the background group commit worker
545    fn start_group_commit_worker(self: Arc<Self>, mut rx: mpsc::Receiver<WriteRequest>) {
546        let wal = self.clone();
547
548        tokio::spawn(async move {
549            let mut pending: VecDeque<WriteRequest> = VecDeque::new();
550            let mut batch_buffer = BytesMut::with_capacity(wal.config.max_batch_size);
551            let mut group_start: Option<Instant> = None;
552
553            loop {
554                // Check shutdown early
555                if wal.shutdown.load(Ordering::Acquire) {
556                    // Drain any remaining messages from channel
557                    while let Ok(request) = rx.try_recv() {
558                        pending.push_back(request);
559                    }
560                    // Flush remaining and exit
561                    if !pending.is_empty() {
562                        wal.flush_batch(&mut pending, &mut batch_buffer, group_start.take())
563                            .await;
564                    }
565                    break;
566                }
567
568                // Wait for writes or timeout
569                let timeout = if pending.is_empty() {
570                    Duration::from_secs(60) // Long timeout when idle
571                } else {
572                    wal.config.group_commit_window
573                };
574
575                tokio::select! {
576                    biased;
577
578                    // Receive new write request (higher priority)
579                    Some(request) = rx.recv() => {
580                        if group_start.is_none() {
581                            group_start = Some(Instant::now());
582                        }
583                        pending.push_back(request);
584
585                        // Check if we should flush immediately
586                        let should_flush =
587                            pending.len() >= wal.config.max_pending_writes ||
588                            batch_buffer.len() >= wal.config.max_batch_size;
589
590                        if should_flush {
591                            wal.flush_batch(&mut pending, &mut batch_buffer, group_start.take()).await;
592                        }
593                    }
594
595                    // Wait for notification
596                    _ = wal.write_notify.notified() => {
597                        // Continue loop to check shutdown flag
598                    }
599
600                    // Timeout - flush whatever we have
601                    _ = tokio::time::sleep(timeout) => {
602                        if !pending.is_empty() {
603                            wal.flush_batch(&mut pending, &mut batch_buffer, group_start.take()).await;
604                        }
605                    }
606                }
607            }
608        });
609    }
610
611    /// Flush a batch of pending writes
612    async fn flush_batch(
613        &self,
614        pending: &mut VecDeque<WriteRequest>,
615        batch_buffer: &mut BytesMut,
616        group_start: Option<Instant>,
617    ) {
618        if pending.is_empty() {
619            return;
620        }
621
622        let wait_time = group_start.map(|s| s.elapsed()).unwrap_or(Duration::ZERO);
623        let group_size = pending.len();
624
625        // Build batch
626        batch_buffer.clear();
627        let mut lsns = Vec::with_capacity(group_size);
628        let mut sizes = Vec::with_capacity(group_size);
629
630        for request in pending.iter() {
631            let lsn = self.current_lsn.fetch_add(1, Ordering::AcqRel) + 1;
632            lsns.push(lsn);
633
634            // Optionally encrypt the data
635            #[cfg(feature = "encryption")]
636            let (data, is_encrypted) = if let Some(ref encryptor) = self.config.encryptor {
637                if encryptor.is_enabled() {
638                    match encryptor.encrypt(&request.data, lsn) {
639                        Ok(encrypted) => (Bytes::from(encrypted), true),
640                        Err(e) => {
641                            tracing::error!("Encryption failed for LSN {}: {:?}", lsn, e);
642                            (request.data.clone(), false)
643                        }
644                    }
645                } else {
646                    (request.data.clone(), false)
647                }
648            } else {
649                (request.data.clone(), false)
650            };
651
652            #[cfg(not(feature = "encryption"))]
653            let (data, is_encrypted) = (request.data.clone(), false);
654
655            let flags = if is_encrypted {
656                RecordFlags(RecordFlags::HAS_CHECKSUM.0 | RecordFlags::ENCRYPTED.0)
657            } else {
658                RecordFlags::HAS_CHECKSUM
659            };
660
661            let record = WalRecord {
662                lsn,
663                record_type: request.record_type,
664                flags,
665                data,
666            };
667
668            let record_bytes = record.to_bytes();
669            sizes.push(record_bytes.len());
670            batch_buffer.extend_from_slice(&record_bytes);
671        }
672
673        // Write batch to disk and rotate if file exceeds max_file_size
674        let write_result = {
675            let mut writer = self.writer.lock().await;
676            let result = writer.write_batch(batch_buffer);
677            if result.is_ok() {
678                // Check rotation after successful write
679                let next_lsn = self.current_lsn.load(Ordering::Acquire) + 1;
680                if let Err(e) = writer.rotate_if_needed(next_lsn) {
681                    tracing::error!("WAL rotation failed: {e}");
682                    // Rotation failure is non-fatal — writes continue to current file
683                }
684            }
685            result
686        };
687
688        // Update stats
689        self.stats
690            .writes_total
691            .fetch_add(group_size as u64, Ordering::Relaxed);
692        self.stats
693            .bytes_written
694            .fetch_add(batch_buffer.len() as u64, Ordering::Relaxed);
695        self.stats.group_commits.fetch_add(1, Ordering::Relaxed);
696        self.stats.syncs_total.fetch_add(1, Ordering::Relaxed);
697
698        // Send results to waiters
699        let group_commit = group_size > 1;
700
701        for (i, request) in pending.drain(..).enumerate() {
702            let result = match &write_result {
703                Ok(()) => Ok(WriteResult {
704                    lsn: lsns[i],
705                    size: sizes[i],
706                    group_commit,
707                    group_size,
708                    wait_time,
709                }),
710                Err(e) => Err(format!("WAL write failed: {e}")),
711            };
712
713            let _ = request.completion.send(result);
714        }
715    }
716}
717
/// Low-level WAL file writer.
///
/// Owns the active segment's buffered handle and tracks the logical write
/// position (end of valid records), which may be less than the file length
/// when the file is pre-allocated.
struct WalWriter {
    // Buffer capacity is sized to config.max_batch_size
    file: BufWriter<File>,
    // Path of the active segment (used for logging during rotation)
    path: PathBuf,
    // Byte offset of the end of valid data, not the physical file length
    position: u64,
    config: WalConfig,
}
725
impl WalWriter {
    /// Open (or create) the WAL segment at `path` and position the writer at
    /// the end of the last valid record. Existing data is preserved; brand-new
    /// segments are pre-allocated to `config.preallocate_size`.
    fn new(path: PathBuf, config: WalConfig) -> io::Result<Self> {
        use std::io::{Seek, SeekFrom};

        // Check if file exists and has valid data
        let existing_len = std::fs::metadata(&path).map(|m| m.len()).unwrap_or(0);

        let file = OpenOptions::new()
            .create(true)
            .read(true)
            .write(true)
            .truncate(false) // Preserve existing WAL data
            .open(&path)?;

        // Determine actual data length (scan for end of valid records).
        // Pre-allocated zero padding fails the magic check, ending the scan.
        let actual_position = if existing_len > 0 {
            Self::find_actual_end(&file, existing_len)?
        } else {
            0
        };

        // Pre-allocate if this is a new file (no valid records yet)
        if actual_position == 0 && config.preallocate_size > 0 {
            file.set_len(config.preallocate_size)?;
        }

        // Seek to the actual write position so appends land after the last
        // valid record rather than at offset 0
        let mut writer = BufWriter::with_capacity(config.max_batch_size, file);
        writer.seek(SeekFrom::Start(actual_position))?;

        Ok(Self {
            file: writer,
            path,
            position: actual_position,
            config,
        })
    }

    /// Find the actual end of valid data in the file by walking record
    /// headers from offset 0 until a bad magic, a short read, or a record
    /// that would run past `file_len`.
    ///
    /// NOTE(review): headers are not CRC-verified here, so a torn write with
    /// an intact header is treated as valid data — confirm this is acceptable
    /// for recovery.
    fn find_actual_end(file: &File, file_len: u64) -> io::Result<u64> {
        use std::io::Read;

        let mut position = 0u64;
        // Clone the handle so scanning does not disturb the caller's cursor
        let mut file = file.try_clone()?;

        // Read in chunks to find valid record boundaries
        while position + RECORD_HEADER_SIZE as u64 <= file_len {
            let mut header = [0u8; RECORD_HEADER_SIZE];

            use std::io::{Seek, SeekFrom};
            file.seek(SeekFrom::Start(position))?;

            if file.read_exact(&mut header).is_err() {
                break;
            }

            // Check magic (zero padding from pre-allocation fails this)
            let magic = u32::from_be_bytes([header[0], header[1], header[2], header[3]]);
            if magic != WAL_MAGIC {
                break;
            }

            // Payload length lives at header bytes 8..12 (after magic + crc)
            let data_len =
                u32::from_be_bytes([header[8], header[9], header[10], header[11]]) as u64;
            let record_size = RECORD_HEADER_SIZE as u64 + data_len;

            // A record extending past the file is a torn tail: stop here
            if position + record_size > file_len {
                break;
            }

            position += record_size;
        }

        Ok(position)
    }

    /// Write one serialized batch, flush the buffer, and sync per the
    /// configured mode. `position` is only advanced on full success, so a
    /// failed sync leaves it pointing at the last durable boundary.
    ///
    /// NOTE(review): on a sync failure the OS file cursor has still advanced
    /// past `position`, so a subsequent write would not overwrite the
    /// unsynced bytes — confirm the intended retry semantics.
    fn write_batch(&mut self, data: &[u8]) -> io::Result<()> {
        self.file.write_all(data)?;
        self.file.flush()?;

        // Sync based on mode
        match self.config.sync_mode {
            SyncMode::None => {}
            SyncMode::FsyncData => {
                // fdatasync: data only, metadata may lag
                self.file.get_ref().sync_data()?;
            }
            SyncMode::Fsync | SyncMode::Dsync => {
                // Dsync falls back to a full fsync here because the file is
                // not opened with O_DSYNC
                self.file.get_ref().sync_all()?;
            }
        }

        self.position += data.len() as u64;
        Ok(())
    }

    /// Rotate the WAL file if the current file exceeds max_file_size.
    ///
    /// Flushes and syncs the current file, trims its pre-allocated tail down
    /// to the valid data length, then creates a new pre-allocated segment
    /// named after the given LSN. Returns `true` if rotation occurred.
    fn rotate_if_needed(&mut self, next_lsn: u64) -> io::Result<bool> {
        // max_file_size == 0 disables rotation entirely
        if self.config.max_file_size == 0 || self.position < self.config.max_file_size {
            return Ok(false);
        }

        // Sync and close current file
        self.file.flush()?;
        self.file.get_ref().sync_all()?;

        // Truncate preallocated space to actual data length
        if self.position < self.file.get_ref().metadata()?.len() {
            self.file.get_ref().set_len(self.position)?;
        }

        // Create new WAL file named after the next LSN (zero-padded so
        // lexicographic order matches numeric order)
        let new_path = self.config.dir.join(format!("{:020}.wal", next_lsn));

        tracing::info!(
            old_file = %self.path.display(),
            new_file = %new_path.display(),
            old_size = self.position,
            max_size = self.config.max_file_size,
            "Rotating WAL file"
        );

        let file = OpenOptions::new()
            .create(true)
            .read(true)
            .write(true)
            .truncate(false)
            .open(&new_path)?;

        // Pre-allocate new file
        if self.config.preallocate_size > 0 {
            file.set_len(self.config.preallocate_size)?;
        }

        // Replacing self.file drops the old BufWriter (its buffer was just
        // flushed above); the fresh file's cursor starts at offset 0
        self.file = BufWriter::with_capacity(self.config.max_batch_size, file);
        self.path = new_path;
        self.position = 0;

        Ok(true)
    }

    /// Flush the buffer and fsync data + metadata.
    fn sync(&mut self) -> io::Result<()> {
        self.file.flush()?;
        self.file.get_ref().sync_all()
    }

    /// Get the path of this WAL file
    #[allow(dead_code)]
    fn path(&self) -> &std::path::Path {
        &self.path
    }
}
880
/// WAL statistics.
///
/// Monotonic counters shared between the group-commit worker and readers of
/// `GroupCommitWal::stats`; all accesses use relaxed atomic ordering.
#[derive(Default)]
struct WalStats {
    writes_total: AtomicU64,
    bytes_written: AtomicU64,
    syncs_total: AtomicU64,
    group_commits: AtomicU64,
}

impl WalStats {
    /// All counters start at zero.
    fn new() -> Self {
        Self::default()
    }
}
899
/// Point-in-time snapshot of WAL statistics, as returned by
/// `GroupCommitWal::stats`.
#[derive(Debug, Clone)]
pub struct WalStatsSnapshot {
    /// Total records written
    pub writes_total: u64,
    /// Total serialized bytes written (headers + payloads)
    pub bytes_written: u64,
    /// Total sync operations performed
    pub syncs_total: u64,
    /// Number of flushed batches
    pub group_commits: u64,
    /// writes_total / group_commits (0.0 when no commits yet)
    pub avg_group_size: f64,
    /// Last assigned LSN at snapshot time
    pub current_lsn: u64,
}
909
/// WAL reader for recovery and replication.
///
/// Holds only a path and a byte offset; file contents are read on demand.
pub struct WalReader {
    path: PathBuf,
    // Byte offset into the file where the next read begins
    position: u64,
    // Used to decrypt records whose ENCRYPTED flag is set
    #[cfg(feature = "encryption")]
    encryptor: Option<std::sync::Arc<dyn crate::encryption::Encryptor>>,
}
917
918impl WalReader {
    /// Open a WAL file for reading.
    ///
    /// Only records the path — the file is read lazily. The read position
    /// starts at 0 and no encryptor is configured. The `io::Result` return
    /// is kept for API stability; this constructor itself never fails.
    pub fn open(path: PathBuf) -> io::Result<Self> {
        Ok(Self {
            path,
            position: 0,
            #[cfg(feature = "encryption")]
            encryptor: None,
        })
    }
928
    /// Open a WAL file for reading with optional encryption.
    ///
    /// `encryptor` is used to decrypt records flagged as encrypted; pass
    /// `None` to read plaintext-only WALs. Like `open`, this never fails.
    #[cfg(feature = "encryption")]
    pub fn open_with_encryption(
        path: PathBuf,
        encryptor: Option<std::sync::Arc<dyn crate::encryption::Encryptor>>,
    ) -> io::Result<Self> {
        Ok(Self {
            path,
            position: 0,
            encryptor,
        })
    }
941
    /// Decrypt the record's payload in place if its ENCRYPTED flag is set.
    ///
    /// Plaintext records pass through untouched. Returns `InvalidData` when
    /// the record is encrypted but no encryptor was configured, or when
    /// decryption itself fails. The record's LSN is passed to the encryptor
    /// (it was also used as context at encryption time in `flush_batch`).
    #[cfg(feature = "encryption")]
    fn decrypt_record_data(&self, record: &mut WalRecord) -> io::Result<()> {
        if record.flags.is_encrypted() {
            if let Some(ref encryptor) = self.encryptor {
                let decrypted = encryptor
                    .decrypt(&record.data, record.lsn)
                    .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e.to_string()))?;
                record.data = Bytes::from(decrypted);
            } else {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidData,
                    "Record is encrypted but no encryptor provided",
                ));
            }
        }
        Ok(())
    }
960
961    /// Read all records from current position
962    pub async fn read_all(&mut self) -> io::Result<Vec<WalRecord>> {
963        let data = tokio::fs::read(&self.path).await?;
964        let mut records = Vec::new();
965        let mut lsn = 0u64;
966
967        while self.position + RECORD_HEADER_SIZE as u64 <= data.len() as u64 {
968            let offset = self.position as usize;
969
970            // Check magic
971            let magic = u32::from_be_bytes([
972                data[offset],
973                data[offset + 1],
974                data[offset + 2],
975                data[offset + 3],
976            ]);
977
978            if magic != WAL_MAGIC {
979                break;
980            }
981
982            let data_len = u32::from_be_bytes([
983                data[offset + 8],
984                data[offset + 9],
985                data[offset + 10],
986                data[offset + 11],
987            ]) as usize;
988
989            let record_size = RECORD_HEADER_SIZE + data_len;
990
991            if offset + record_size > data.len() {
992                break;
993            }
994
995            lsn += 1;
996
997            match WalRecord::from_bytes(&data[offset..offset + record_size], lsn) {
998                #[cfg(feature = "encryption")]
999                Ok(mut record) => {
1000                    // Decrypt if needed
1001                    self.decrypt_record_data(&mut record)?;
1002
1003                    records.push(record);
1004                    self.position += record_size as u64;
1005                }
1006                #[cfg(not(feature = "encryption"))]
1007                Ok(record) => {
1008                    records.push(record);
1009                    self.position += record_size as u64;
1010                }
1011                Err(_) => break,
1012            }
1013        }
1014
1015        Ok(records)
1016    }
1017
1018    /// Seek to a specific LSN
1019    pub async fn seek_to_lsn(&mut self, target_lsn: u64) -> io::Result<()> {
1020        let data = tokio::fs::read(&self.path).await?;
1021        let mut position = 0usize;
1022        let mut current_lsn = 0u64;
1023
1024        while position + RECORD_HEADER_SIZE <= data.len() {
1025            let magic = u32::from_be_bytes([
1026                data[position],
1027                data[position + 1],
1028                data[position + 2],
1029                data[position + 3],
1030            ]);
1031
1032            if magic != WAL_MAGIC {
1033                break;
1034            }
1035
1036            let data_len = u32::from_be_bytes([
1037                data[position + 8],
1038                data[position + 9],
1039                data[position + 10],
1040                data[position + 11],
1041            ]) as usize;
1042
1043            let record_size = RECORD_HEADER_SIZE + data_len;
1044            current_lsn += 1;
1045
1046            if current_lsn >= target_lsn {
1047                self.position = position as u64;
1048                return Ok(());
1049            }
1050
1051            position += record_size;
1052        }
1053
1054        Err(io::Error::new(
1055            io::ErrorKind::NotFound,
1056            format!("LSN {} not found", target_lsn),
1057        ))
1058    }
1059}
1060
#[cfg(test)]
mod tests {
    //! Tests covering record serialization/CRC, the group-commit WAL write
    //! paths, file rotation, and recovery via `WalReader`.

    use super::*;
    use tempfile::TempDir;

    // Round-trip a record through to_bytes()/from_bytes(); the serialized
    // form must carry at least RECORD_HEADER_SIZE bytes of header overhead.
    #[test]
    fn test_wal_record_serialization() {
        let data = Bytes::from("test data");
        let record = WalRecord::new(1, data.clone());

        let serialized = record.to_bytes();
        assert!(serialized.len() >= RECORD_HEADER_SIZE + data.len());

        let parsed = WalRecord::from_bytes(&serialized, 1).unwrap();
        assert_eq!(parsed.lsn, 1);
        assert_eq!(parsed.data, data);
    }

    // Flipping a single payload byte must be detected by the CRC32 check.
    #[test]
    fn test_wal_record_crc() {
        let data = Bytes::from("test data");
        let record = WalRecord::new(1, data);
        let mut serialized = record.to_bytes().to_vec();

        // Corrupt the first payload byte (just past the header)
        serialized[RECORD_HEADER_SIZE] ^= 0xFF;

        // Should fail CRC check
        assert!(WalRecord::from_bytes(&serialized, 1).is_err());
    }

    // A single write gets LSN 1 and is counted in the stats.
    #[tokio::test]
    async fn test_group_commit_wal_single_write() {
        let temp_dir = TempDir::new().unwrap();
        let config = WalConfig {
            dir: temp_dir.path().to_path_buf(),
            group_commit_window: Duration::from_micros(100),
            ..Default::default()
        };

        let wal = GroupCommitWal::new(config).await.unwrap();

        let result = wal.write(Bytes::from("test data")).await.unwrap();
        assert_eq!(result.lsn, 1);
        assert!(result.size > 0);

        let stats = wal.stats();
        assert_eq!(stats.writes_total, 1);

        wal.shutdown().await.unwrap();
    }

    // A batch write assigns consecutive LSNs starting at 1.
    #[tokio::test]
    async fn test_group_commit_wal_batch() {
        let temp_dir = TempDir::new().unwrap();
        let config = WalConfig {
            dir: temp_dir.path().to_path_buf(),
            group_commit_window: Duration::from_millis(10),
            ..Default::default()
        };

        let wal = GroupCommitWal::new(config).await.unwrap();

        let records: Vec<Bytes> = (0..10)
            .map(|i| Bytes::from(format!("record {}", i)))
            .collect();

        let results = wal.write_batch(records).await.unwrap();

        assert_eq!(results.len(), 10);
        for (i, result) in results.iter().enumerate() {
            assert_eq!(result.lsn, (i + 1) as u64);
        }

        let stats = wal.stats();
        assert_eq!(stats.writes_total, 10);

        wal.shutdown().await.unwrap();
    }

    // Concurrent writers should all succeed, and group commit should batch
    // them so commits never exceed the write count.
    #[tokio::test]
    async fn test_group_commit_batching() {
        let temp_dir = TempDir::new().unwrap();
        let config = WalConfig {
            dir: temp_dir.path().to_path_buf(),
            group_commit_window: Duration::from_millis(50),
            max_pending_writes: 100,
            ..Default::default()
        };

        let wal = Arc::new(GroupCommitWal::new(config).await.unwrap());

        // Spawn multiple writers concurrently
        let mut handles = vec![];
        for i in 0..20 {
            let wal_clone = wal.clone();
            handles.push(tokio::spawn(async move {
                wal_clone
                    .write(Bytes::from(format!("concurrent write {}", i)))
                    .await
            }));
        }

        // Wait for all writes
        for handle in handles {
            let result = handle.await.unwrap().unwrap();
            assert!(result.lsn > 0);
        }

        let stats = wal.stats();
        assert_eq!(stats.writes_total, 20);
        // With batching, group commits can never exceed the write count
        // (the exact ratio depends on timing, so only <= is asserted).
        assert!(stats.group_commits <= stats.writes_total);

        wal.shutdown().await.unwrap();
    }

    // End-to-end: write records, sync, shut down, then recover them from
    // the segment file with WalReader.
    #[tokio::test]
    async fn test_wal_reader() {
        let temp_dir = TempDir::new().unwrap();
        let config = WalConfig {
            dir: temp_dir.path().to_path_buf(),
            group_commit_window: Duration::from_micros(100),
            sync_mode: SyncMode::Fsync,
            max_pending_writes: 10,
            ..Default::default()
        };

        let wal = GroupCommitWal::new(config.clone()).await.unwrap();

        // Write some records and wait for each to complete
        for i in 0..5 {
            let result = wal
                .write(Bytes::from(format!("record {}", i)))
                .await
                .unwrap();
            assert!(result.lsn > 0, "Expected valid LSN for record {}", i);
        }

        // Force a sync to ensure data is on disk
        wal.sync().await.unwrap();

        // Give a small delay to ensure background worker has flushed
        tokio::time::sleep(Duration::from_millis(100)).await;

        wal.shutdown().await.unwrap();

        // Find the WAL file
        let entries: Vec<_> = std::fs::read_dir(&config.dir)
            .unwrap()
            .filter_map(|e| e.ok())
            .filter(|e| e.path().extension().is_some_and(|ext| ext == "wal"))
            .collect();

        assert!(!entries.is_empty(), "No WAL files found");

        let wal_file = entries[0].path();
        let file_size = std::fs::metadata(&wal_file).unwrap().len();
        assert!(file_size > 0, "WAL file is empty");

        let mut reader = WalReader::open(wal_file).unwrap();
        let records = reader.read_all().await.unwrap();

        assert_eq!(
            records.len(),
            5,
            "Expected 5 records, got {} (file size: {})",
            records.len(),
            file_size
        );
        for (i, record) in records.iter().enumerate() {
            let expected = format!("record {}", i);
            assert_eq!(record.data, Bytes::from(expected));
        }
    }

    // Flag bits are independent and can be combined with bitwise OR.
    #[test]
    fn test_record_flags() {
        let flags = RecordFlags::COMPRESSED;
        assert!(flags.is_compressed());
        assert!(!flags.is_encrypted());

        let flags = RecordFlags(RecordFlags::COMPRESSED.0 | RecordFlags::ENCRYPTED.0);
        assert!(flags.is_compressed());
        assert!(flags.is_encrypted());
    }

    // Writing past a tiny max_file_size must produce multiple segments.
    #[tokio::test]
    async fn test_wal_rotation() {
        let temp_dir = TempDir::new().unwrap();
        let config = WalConfig {
            dir: temp_dir.path().to_path_buf(),
            group_commit_window: Duration::from_micros(50),
            max_file_size: 200, // Very small to trigger rotation quickly
            preallocate_size: 0,
            ..Default::default()
        };

        let wal = GroupCommitWal::new(config).await.unwrap();

        // Write enough data to trigger rotation
        for i in 0..10 {
            let data = format!("rotation-record-{:04}", i);
            let result = wal.write(Bytes::from(data)).await.unwrap();
            assert!(result.lsn > 0);
        }

        wal.sync().await.unwrap();
        tokio::time::sleep(Duration::from_millis(100)).await;
        wal.shutdown().await.unwrap();

        // Check that multiple WAL files were created (rotation occurred)
        let wal_files: Vec<_> = std::fs::read_dir(temp_dir.path())
            .unwrap()
            .filter_map(|e| e.ok())
            .filter(|e| e.path().extension().is_some_and(|ext| ext == "wal"))
            .collect();

        assert!(
            wal_files.len() > 1,
            "Expected multiple WAL files after rotation, got {}",
            wal_files.len()
        );
    }

    // rotate_if_needed: rotates exactly once when past max_file_size, names
    // the new segment after the supplied LSN, and resets the position.
    #[test]
    fn test_wal_writer_rotate_if_needed() {
        let temp_dir = TempDir::new().unwrap();
        let config = WalConfig {
            dir: temp_dir.path().to_path_buf(),
            max_file_size: 100,
            preallocate_size: 0,
            ..Default::default()
        };

        let path = temp_dir.path().join("00000000000000000000.wal");
        let mut writer = WalWriter::new(path.clone(), config).unwrap();

        // Write some data to push past max_file_size
        writer.write_batch(&[0u8; 150]).unwrap();
        assert_eq!(writer.position, 150);

        // Should rotate
        let rotated = writer.rotate_if_needed(42).unwrap();
        assert!(rotated, "Expected rotation to occur");
        assert_eq!(writer.position, 0);
        assert_ne!(writer.path, path);
        assert!(writer
            .path
            .to_str()
            .unwrap()
            .contains("00000000000000000042"));

        // Should not rotate again immediately
        let rotated = writer.rotate_if_needed(43).unwrap();
        assert!(!rotated, "Expected no rotation when under max_file_size");
    }
}