1use bytes::{BufMut, Bytes, BytesMut};
16use crc32fast::Hasher;
17use std::collections::VecDeque;
18use std::fs::{File, OpenOptions};
19use std::io::{self, BufWriter, Write};
20use std::path::{Path, PathBuf};
21use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
22use std::sync::Arc;
23use std::time::{Duration, Instant};
24use tokio::sync::{mpsc, oneshot, Mutex, Notify};
25
/// Fixed on-disk header size per record:
/// magic (4) + crc32 (4) + payload length (4) + record type (1) + flags (1).
const RECORD_HEADER_SIZE: usize = 14;

/// Magic number prefixing every record (ASCII "WALR").
const WAL_MAGIC: u32 = 0x57414C52;

/// Default window the group-commit worker waits to batch writers, in microseconds.
const DEFAULT_GROUP_COMMIT_WINDOW_US: u64 = 200;

/// Default cap on the serialized size of one flushed batch (4 MiB).
const DEFAULT_MAX_BATCH_SIZE: usize = 4 * 1024 * 1024;

/// Default bound on queued write requests (also the request-channel capacity).
const DEFAULT_MAX_PENDING_WRITES: usize = 1000;
/// Kind of payload a WAL record carries.
///
/// The discriminant is the single type byte stored in the record header.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum RecordType {
    /// Complete payload contained in a single record.
    Full = 0,
    /// First fragment of a payload spanning several records.
    First = 1,
    /// Interior fragment of a spanning payload.
    Middle = 2,
    /// Final fragment of a spanning payload.
    Last = 3,
    /// Checkpoint marker.
    Checkpoint = 4,
    /// Transaction begin marker.
    TxnBegin = 5,
    /// Transaction commit marker.
    TxnCommit = 6,
    /// Transaction abort marker.
    TxnAbort = 7,
}

impl TryFrom<u8> for RecordType {
    type Error = io::Error;

    /// Decodes the header type byte; unknown values are rejected with
    /// `InvalidData` so corruption surfaces during recovery.
    fn try_from(value: u8) -> Result<Self, Self::Error> {
        use RecordType::*;
        let decoded = match value {
            0 => Full,
            1 => First,
            2 => Middle,
            3 => Last,
            4 => Checkpoint,
            5 => TxnBegin,
            6 => TxnCommit,
            7 => TxnAbort,
            _ => {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidData,
                    "Invalid record type",
                ))
            }
        };
        Ok(decoded)
    }
}
83
/// Bit-set of per-record option flags stored in one header byte.
#[derive(Debug, Clone, Copy)]
pub struct RecordFlags(u8);

impl RecordFlags {
    /// No flags set.
    pub const NONE: Self = Self(0);
    /// Payload is compressed.
    pub const COMPRESSED: Self = Self(1 << 0);
    /// Payload is encrypted.
    pub const ENCRYPTED: Self = Self(1 << 1);
    /// Header carries a CRC32 over the payload.
    pub const HAS_CHECKSUM: Self = Self(1 << 2);

    /// True when every bit of `mask` is also set here (single-bit masks in practice).
    fn contains(self, mask: Self) -> bool {
        self.0 & mask.0 != 0
    }

    /// Whether the payload is compressed.
    pub fn is_compressed(&self) -> bool {
        self.contains(Self::COMPRESSED)
    }

    /// Whether the payload is encrypted.
    pub fn is_encrypted(&self) -> bool {
        self.contains(Self::ENCRYPTED)
    }

    /// Whether a payload checksum is present.
    pub fn has_checksum(&self) -> bool {
        self.contains(Self::HAS_CHECKSUM)
    }
}
106
/// A single decoded WAL record: payload plus its header metadata.
#[derive(Debug, Clone)]
pub struct WalRecord {
    /// Log sequence number. Note: `to_bytes` does not serialize this field;
    /// it is supplied by the caller/decoder, not stored on disk.
    pub lsn: u64,
    /// Kind of payload this record carries.
    pub record_type: RecordType,
    /// Per-record option flags (checksum/compression/encryption bits).
    pub flags: RecordFlags,
    /// The record payload.
    pub data: Bytes,
}
119
120impl WalRecord {
121 pub fn new(lsn: u64, data: Bytes) -> Self {
123 Self {
124 lsn,
125 record_type: RecordType::Full,
126 flags: RecordFlags::HAS_CHECKSUM,
127 data,
128 }
129 }
130
131 pub fn to_bytes(&self) -> Bytes {
133 let mut buf = BytesMut::with_capacity(RECORD_HEADER_SIZE + self.data.len());
134
135 let mut hasher = Hasher::new();
137 hasher.update(&self.data);
138 let crc = hasher.finalize();
139
140 buf.put_u32(WAL_MAGIC);
142 buf.put_u32(crc);
143 buf.put_u32(self.data.len() as u32);
144 buf.put_u8(self.record_type as u8);
145 buf.put_u8(self.flags.0);
146
147 buf.extend_from_slice(&self.data);
149
150 buf.freeze()
151 }
152
153 pub fn from_bytes(data: &[u8], lsn: u64) -> io::Result<Self> {
155 if data.len() < RECORD_HEADER_SIZE {
156 return Err(io::Error::new(
157 io::ErrorKind::InvalidData,
158 "Record too short",
159 ));
160 }
161
162 let magic = u32::from_be_bytes([data[0], data[1], data[2], data[3]]);
164 if magic != WAL_MAGIC {
165 return Err(io::Error::new(
166 io::ErrorKind::InvalidData,
167 "Invalid magic number",
168 ));
169 }
170
171 let stored_crc = u32::from_be_bytes([data[4], data[5], data[6], data[7]]);
172 let data_len = u32::from_be_bytes([data[8], data[9], data[10], data[11]]) as usize;
173 let record_type = RecordType::try_from(data[12])?;
174 let flags = RecordFlags(data[13]);
175
176 if data.len() < RECORD_HEADER_SIZE + data_len {
177 return Err(io::Error::new(
178 io::ErrorKind::InvalidData,
179 "Incomplete record",
180 ));
181 }
182
183 let record_data =
184 Bytes::copy_from_slice(&data[RECORD_HEADER_SIZE..RECORD_HEADER_SIZE + data_len]);
185
186 let mut hasher = Hasher::new();
188 hasher.update(&record_data);
189 let computed_crc = hasher.finalize();
190
191 if computed_crc != stored_crc {
192 return Err(io::Error::new(io::ErrorKind::InvalidData, "CRC mismatch"));
193 }
194
195 Ok(Self {
196 lsn,
197 record_type,
198 flags,
199 data: record_data,
200 })
201 }
202
203 pub fn serialized_size(&self) -> usize {
205 RECORD_HEADER_SIZE + self.data.len()
206 }
207}
208
/// Tunables for [`GroupCommitWal`].
#[derive(Debug, Clone)]
pub struct WalConfig {
    /// Directory holding the `*.wal` segment files.
    pub dir: PathBuf,
    /// How long the worker waits for more writers before flushing a group.
    pub group_commit_window: Duration,
    /// Upper bound on the serialized size of one flushed batch, in bytes
    /// (also sized as the writer's buffer capacity).
    pub max_batch_size: usize,
    /// Bound on queued write requests (used as the request-channel capacity).
    pub max_pending_writes: usize,
    /// Size to preallocate for a fresh segment file; 0 disables preallocation.
    pub preallocate_size: u64,
    /// Request direct (unbuffered) I/O.
    /// NOTE(review): not consumed anywhere in this file — confirm whether a
    /// platform-specific open path elsewhere honors it.
    pub direct_io: bool,
    /// Durability level applied after each batch write.
    pub sync_mode: SyncMode,
    /// Rotate to a new segment once the current file reaches this size;
    /// 0 disables rotation.
    pub max_file_size: u64,
    /// Optional payload encryptor applied per record at append time.
    #[cfg(feature = "encryption")]
    pub encryptor: Option<std::sync::Arc<dyn crate::encryption::Encryptor>>,
}
232
233impl Default for WalConfig {
234 fn default() -> Self {
235 Self {
236 dir: PathBuf::from("./wal"),
237 group_commit_window: Duration::from_micros(DEFAULT_GROUP_COMMIT_WINDOW_US),
238 max_batch_size: DEFAULT_MAX_BATCH_SIZE,
239 max_pending_writes: DEFAULT_MAX_PENDING_WRITES,
240 preallocate_size: 64 * 1024 * 1024, direct_io: false, sync_mode: SyncMode::Fsync,
243 max_file_size: 1024 * 1024 * 1024, #[cfg(feature = "encryption")]
245 encryptor: None,
246 }
247 }
248}
249
250impl WalConfig {
251 pub fn high_throughput() -> Self {
253 Self {
254 group_commit_window: Duration::from_micros(1000), max_batch_size: 16 * 1024 * 1024, max_pending_writes: 5000,
257 ..Default::default()
258 }
259 }
260
261 pub fn low_latency() -> Self {
263 Self {
264 group_commit_window: Duration::from_micros(50), max_batch_size: 512 * 1024, max_pending_writes: 100,
267 ..Default::default()
268 }
269 }
270
271 pub fn durable() -> Self {
273 Self {
274 sync_mode: SyncMode::FsyncData,
275 ..Default::default()
276 }
277 }
278}
279
/// Durability level applied after each batch is written.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SyncMode {
    /// No explicit sync; rely on the OS to write back dirty pages.
    None,
    /// Sync file data only (`File::sync_data`), which may skip some metadata.
    FsyncData,
    /// Full sync of data and metadata (`File::sync_all`).
    Fsync,
    /// Treated identically to `Fsync` by the writer in this file.
    Dsync,
}
292
/// One queued append, handed from a writer task to the group-commit worker.
struct WriteRequest {
    /// Payload to append.
    data: Bytes,
    /// Record type to stamp on the payload.
    record_type: RecordType,
    /// Reports the outcome back to the awaiting writer; failures travel as a
    /// stringified error (one `io::Error` cannot be sent to every waiter).
    completion: oneshot::Sender<Result<WriteResult, String>>,
}
302
/// Outcome returned to a writer once its record's batch has been flushed.
#[derive(Debug, Clone)]
pub struct WriteResult {
    /// LSN assigned to the record.
    pub lsn: u64,
    /// Serialized size of the record (header + payload), in bytes.
    pub size: usize,
    /// True when the record was flushed together with at least one other write.
    pub group_commit: bool,
    /// Number of writes flushed in the same batch.
    pub group_size: usize,
    /// Time from the start of the group to its flush (same for all members).
    pub wait_time: Duration,
}
317
/// Write-ahead log with group commit: concurrent writers are funneled through
/// a channel to one background worker that batches, writes, and syncs.
pub struct GroupCommitWal {
    /// Configuration captured at construction.
    config: WalConfig,
    /// Underlying file writer; locked per flush and by `sync`/`shutdown`.
    writer: Mutex<WalWriter>,
    /// Last assigned LSN (the next record gets this value + 1).
    current_lsn: AtomicU64,
    /// Producer side of the request queue consumed by the worker.
    write_tx: mpsc::Sender<WriteRequest>,
    /// Set once to refuse new writes and stop the worker loop.
    shutdown: AtomicBool,
    /// Nudges the worker loop when new requests are queued.
    write_notify: Arc<Notify>,
    /// Monotonic counters exposed via [`GroupCommitWal::stats`].
    stats: Arc<WalStats>,
}
334
impl GroupCommitWal {
    /// Creates the WAL directory if needed, recovers the newest segment and
    /// last LSN, and spawns the background group-commit worker.
    ///
    /// # Errors
    /// Returns any I/O error from directory creation, recovery, or opening
    /// the segment file.
    pub async fn new(config: WalConfig) -> io::Result<Arc<Self>> {
        std::fs::create_dir_all(&config.dir)?;

        let (current_file, current_lsn) = Self::recover_state(&config).await?;

        let writer = WalWriter::new(current_file, config.clone())?;
        let (write_tx, write_rx) = mpsc::channel(config.max_pending_writes);

        let wal = Arc::new(Self {
            config,
            writer: Mutex::new(writer),
            current_lsn: AtomicU64::new(current_lsn),
            write_tx,
            shutdown: AtomicBool::new(false),
            write_notify: Arc::new(Notify::new()),
            stats: Arc::new(WalStats::new()),
        });

        // The worker owns the receiving end of the queue for the WAL's lifetime.
        wal.clone().start_group_commit_worker(write_rx);

        Ok(wal)
    }

    /// Picks the newest `*.wal` segment (largest numeric file stem) and scans
    /// it to recover the last LSN. With no segments present, a fresh
    /// `00000000000000000000.wal` path is returned with LSN 0.
    ///
    /// NOTE(review): `max_lsn` first holds the file-stem number and is then
    /// replaced by `scan_wal_file`'s per-file record count, which restarts at
    /// zero for each segment — confirm this is still correct once rotation
    /// has produced segments with non-zero stems.
    async fn recover_state(config: &WalConfig) -> io::Result<(PathBuf, u64)> {
        let mut max_lsn = 0u64;
        let mut latest_file = None;

        if let Ok(entries) = std::fs::read_dir(&config.dir) {
            for entry in entries.flatten() {
                let path = entry.path();
                if path.extension().is_some_and(|e| e == "wal") {
                    if let Some(name) = path.file_stem() {
                        if let Ok(lsn) = name.to_string_lossy().parse::<u64>() {
                            if lsn >= max_lsn {
                                max_lsn = lsn;
                                latest_file = Some(path);
                            }
                        }
                    }
                }
            }
        }

        // Prefer the record count actually found in the file over the stem.
        if let Some(ref file) = latest_file {
            if let Ok(recovered_lsn) = Self::scan_wal_file(file).await {
                max_lsn = recovered_lsn;
            }
        }

        let file = latest_file.unwrap_or_else(|| config.dir.join(format!("{:020}.wal", 0)));

        Ok((file, max_lsn))
    }

    /// Counts intact records in one segment by walking record headers.
    /// The scan stops at the first bad magic (e.g. a zeroed preallocated
    /// tail) or truncated record; payload CRCs are not verified here.
    async fn scan_wal_file(path: &Path) -> io::Result<u64> {
        let data = tokio::fs::read(path).await?;
        let mut offset = 0;
        let mut max_lsn = 0u64;

        while offset + RECORD_HEADER_SIZE <= data.len() {
            let magic = u32::from_be_bytes([
                data[offset],
                data[offset + 1],
                data[offset + 2],
                data[offset + 3],
            ]);

            if magic != WAL_MAGIC {
                break;
            }

            // Payload length lives at header bytes 8..12 (big-endian).
            let data_len = u32::from_be_bytes([
                data[offset + 8],
                data[offset + 9],
                data[offset + 10],
                data[offset + 11],
            ]) as usize;

            let record_size = RECORD_HEADER_SIZE + data_len;
            if offset + record_size > data.len() {
                break;
            }

            max_lsn += 1;
            offset += record_size;
        }

        Ok(max_lsn)
    }

    /// Appends one `Full` record; resolves once the record's batch has been
    /// written and synced per the configured [`SyncMode`].
    pub async fn write(&self, data: Bytes) -> io::Result<WriteResult> {
        self.write_with_type(data, RecordType::Full).await
    }

    /// Appends one record with an explicit type and waits for its group
    /// commit to complete.
    ///
    /// # Errors
    /// `BrokenPipe` when the WAL is shut down or the worker has exited;
    /// otherwise the flush error, stringified by the worker.
    pub async fn write_with_type(
        &self,
        data: Bytes,
        record_type: RecordType,
    ) -> io::Result<WriteResult> {
        if self.shutdown.load(Ordering::Acquire) {
            return Err(io::Error::new(
                io::ErrorKind::BrokenPipe,
                "WAL is shut down",
            ));
        }

        let (tx, rx) = oneshot::channel();

        let request = WriteRequest {
            data,
            record_type,
            completion: tx,
        };

        self.write_tx
            .send(request)
            .await
            .map_err(|_| io::Error::new(io::ErrorKind::BrokenPipe, "WAL write channel closed"))?;

        // Wake the worker so a lone write is not stuck behind the idle timeout.
        self.write_notify.notify_one();

        rx.await
            .map_err(|_| io::Error::new(io::ErrorKind::BrokenPipe, "WAL write cancelled"))?
            .map_err(io::Error::other)
    }

    /// Enqueues several `Full` records, then awaits all completions in order.
    ///
    /// NOTE(review): unlike `write_with_type`, this does not check the
    /// `shutdown` flag up front; it relies on the channel send failing once
    /// the worker exits — confirm this asymmetry is intended.
    pub async fn write_batch(&self, records: Vec<Bytes>) -> io::Result<Vec<WriteResult>> {
        let mut results = Vec::with_capacity(records.len());
        let mut receivers = Vec::with_capacity(records.len());

        for data in records {
            let (tx, rx) = oneshot::channel();

            let request = WriteRequest {
                data,
                record_type: RecordType::Full,
                completion: tx,
            };

            self.write_tx.send(request).await.map_err(|_| {
                io::Error::new(io::ErrorKind::BrokenPipe, "WAL write channel closed")
            })?;

            receivers.push(rx);
        }

        self.write_notify.notify_one();

        for rx in receivers {
            let result = rx
                .await
                .map_err(|_| io::Error::new(io::ErrorKind::BrokenPipe, "WAL write cancelled"))?
                .map_err(io::Error::other)?;
            results.push(result);
        }

        Ok(results)
    }

    /// Forces a flush-and-fsync of the underlying file, regardless of the
    /// configured [`SyncMode`].
    pub async fn sync(&self) -> io::Result<()> {
        let mut writer = self.writer.lock().await;
        writer.sync()
    }

    /// Last LSN assigned so far (0 before the first write).
    pub fn current_lsn(&self) -> u64 {
        self.current_lsn.load(Ordering::Acquire)
    }

    /// Snapshot of the WAL counters.
    ///
    /// Counters are loaded individually, so `avg_group_size` may be computed
    /// from slightly inconsistent values under concurrency; informational only.
    pub fn stats(&self) -> WalStatsSnapshot {
        WalStatsSnapshot {
            writes_total: self.stats.writes_total.load(Ordering::Relaxed),
            bytes_written: self.stats.bytes_written.load(Ordering::Relaxed),
            syncs_total: self.stats.syncs_total.load(Ordering::Relaxed),
            group_commits: self.stats.group_commits.load(Ordering::Relaxed),
            avg_group_size: if self.stats.group_commits.load(Ordering::Relaxed) > 0 {
                self.stats.writes_total.load(Ordering::Relaxed) as f64
                    / self.stats.group_commits.load(Ordering::Relaxed) as f64
            } else {
                0.0
            },
            current_lsn: self.current_lsn.load(Ordering::Relaxed),
        }
    }

    /// Stops accepting new writes and fsyncs the file. The worker drains any
    /// already-queued requests (flushing them) before exiting its loop.
    pub async fn shutdown(&self) -> io::Result<()> {
        self.shutdown.store(true, Ordering::Release);
        self.write_notify.notify_waiters();

        let mut writer = self.writer.lock().await;
        writer.sync()
    }

    /// Spawns the background task that collects requests into groups and
    /// flushes them when the group fills up or the commit window elapses.
    fn start_group_commit_worker(self: Arc<Self>, mut rx: mpsc::Receiver<WriteRequest>) {
        let wal = self.clone();

        tokio::spawn(async move {
            let mut pending: VecDeque<WriteRequest> = VecDeque::new();
            let mut batch_buffer = BytesMut::with_capacity(wal.config.max_batch_size);
            // Start time of the oldest request in the current group.
            let mut group_start: Option<Instant> = None;

            loop {
                if wal.shutdown.load(Ordering::Acquire) {
                    // Drain whatever is already queued so no writer hangs.
                    while let Ok(request) = rx.try_recv() {
                        pending.push_back(request);
                    }
                    if !pending.is_empty() {
                        wal.flush_batch(&mut pending, &mut batch_buffer, group_start.take())
                            .await;
                    }
                    break;
                }

                // Idle: sleep long (notify wakes us). With work pending: wait
                // only for the group-commit window.
                let timeout = if pending.is_empty() {
                    Duration::from_secs(60)
                } else {
                    wal.config.group_commit_window
                };

                tokio::select! {
                    biased;

                    Some(request) = rx.recv() => {
                        if group_start.is_none() {
                            group_start = Some(Instant::now());
                        }
                        pending.push_back(request);

                        // NOTE(review): `batch_buffer` is only filled inside
                        // `flush_batch`, so its length here still reflects the
                        // *previous* batch, not the pending one — confirm the
                        // size-based trigger does what was intended.
                        let should_flush =
                            pending.len() >= wal.config.max_pending_writes ||
                            batch_buffer.len() >= wal.config.max_batch_size;

                        if should_flush {
                            wal.flush_batch(&mut pending, &mut batch_buffer, group_start.take()).await;
                        }
                    }

                    _ = wal.write_notify.notified() => {
                        // Just wakes the loop so the shutdown flag and the
                        // timeout choice above are re-evaluated.
                    }

                    _ = tokio::time::sleep(timeout) => {
                        // Commit window elapsed: flush whatever accumulated.
                        if !pending.is_empty() {
                            wal.flush_batch(&mut pending, &mut batch_buffer, group_start.take()).await;
                        }
                    }
                }
            }
        });
    }

    /// Serializes every pending request into one contiguous buffer, performs
    /// a single write + sync for the whole group, then resolves each waiter
    /// with its own [`WriteResult`] (or the shared error, stringified).
    async fn flush_batch(
        &self,
        pending: &mut VecDeque<WriteRequest>,
        batch_buffer: &mut BytesMut,
        group_start: Option<Instant>,
    ) {
        if pending.is_empty() {
            return;
        }

        let wait_time = group_start.map(|s| s.elapsed()).unwrap_or(Duration::ZERO);
        let group_size = pending.len();

        // Buffer still holds the previous batch's bytes until cleared here.
        batch_buffer.clear();
        let mut lsns = Vec::with_capacity(group_size);
        let mut sizes = Vec::with_capacity(group_size);

        for request in pending.iter() {
            // fetch_add returns the previous value; this record's LSN is +1.
            let lsn = self.current_lsn.fetch_add(1, Ordering::AcqRel) + 1;
            lsns.push(lsn);

            // Encrypt the payload when an enabled encryptor is configured;
            // on encryption failure the record is written as plaintext.
            #[cfg(feature = "encryption")]
            let (data, is_encrypted) = if let Some(ref encryptor) = self.config.encryptor {
                if encryptor.is_enabled() {
                    match encryptor.encrypt(&request.data, lsn) {
                        Ok(encrypted) => (Bytes::from(encrypted), true),
                        Err(e) => {
                            tracing::error!("Encryption failed for LSN {}: {:?}", lsn, e);
                            (request.data.clone(), false)
                        }
                    }
                } else {
                    (request.data.clone(), false)
                }
            } else {
                (request.data.clone(), false)
            };

            #[cfg(not(feature = "encryption"))]
            let (data, is_encrypted) = (request.data.clone(), false);

            let flags = if is_encrypted {
                RecordFlags(RecordFlags::HAS_CHECKSUM.0 | RecordFlags::ENCRYPTED.0)
            } else {
                RecordFlags::HAS_CHECKSUM
            };

            let record = WalRecord {
                lsn,
                record_type: request.record_type,
                flags,
                data,
            };

            let record_bytes = record.to_bytes();
            sizes.push(record_bytes.len());
            batch_buffer.extend_from_slice(&record_bytes);
        }

        // One write + sync for the whole group; rotate afterwards if the
        // segment crossed its size limit.
        let write_result = {
            let mut writer = self.writer.lock().await;
            let result = writer.write_batch(batch_buffer);
            if result.is_ok() {
                let next_lsn = self.current_lsn.load(Ordering::Acquire) + 1;
                if let Err(e) = writer.rotate_if_needed(next_lsn) {
                    // Rotation failure is logged, not surfaced: the batch
                    // itself was already written and synced successfully.
                    tracing::error!("WAL rotation failed: {e}");
                }
            }
            result
        };

        // Counters are bumped whether or not the write succeeded.
        self.stats
            .writes_total
            .fetch_add(group_size as u64, Ordering::Relaxed);
        self.stats
            .bytes_written
            .fetch_add(batch_buffer.len() as u64, Ordering::Relaxed);
        self.stats.group_commits.fetch_add(1, Ordering::Relaxed);
        self.stats.syncs_total.fetch_add(1, Ordering::Relaxed);

        let group_commit = group_size > 1;

        // Resolve every waiter; a send error just means the caller went away.
        for (i, request) in pending.drain(..).enumerate() {
            let result = match &write_result {
                Ok(()) => Ok(WriteResult {
                    lsn: lsns[i],
                    size: sizes[i],
                    group_commit,
                    group_size,
                    wait_time,
                }),
                Err(e) => Err(format!("WAL write failed: {e}")),
            };

            let _ = request.completion.send(result);
        }
    }
}
717
/// Synchronous append-only writer for the current WAL segment file.
struct WalWriter {
    /// Buffered handle to the current segment.
    file: BufWriter<File>,
    /// Path of the current segment (used for rotation logging).
    path: PathBuf,
    /// Byte offset of the logical end of the log within the segment; may be
    /// smaller than the physical file length when the file is preallocated.
    position: u64,
    /// Owning copy of the WAL configuration.
    config: WalConfig,
}
725
impl WalWriter {
    /// Opens (or creates) a segment file without truncating it, finds the
    /// logical end of any existing log, preallocates fresh files, and
    /// positions the buffered writer at the append point.
    fn new(path: PathBuf, config: WalConfig) -> io::Result<Self> {
        use std::io::{Seek, SeekFrom};

        let existing_len = std::fs::metadata(&path).map(|m| m.len()).unwrap_or(0);

        // Never truncate: an existing segment may hold records to preserve.
        let file = OpenOptions::new()
            .create(true)
            .read(true)
            .write(true)
            .truncate(false)
            .open(&path)?;

        // Scan record headers to find where valid data ends (the physical
        // length may include a preallocated, zeroed tail).
        let actual_position = if existing_len > 0 {
            Self::find_actual_end(&file, existing_len)?
        } else {
            0
        };

        // Preallocate only files with no recovered records.
        // NOTE(review): if a non-empty file scans to position 0 (e.g. a
        // corrupt first header), set_len may shrink it — confirm acceptable.
        if actual_position == 0 && config.preallocate_size > 0 {
            file.set_len(config.preallocate_size)?;
        }

        let mut writer = BufWriter::with_capacity(config.max_batch_size, file);
        writer.seek(SeekFrom::Start(actual_position))?;

        Ok(Self {
            file: writer,
            path,
            position: actual_position,
            config,
        })
    }

    /// Walks record headers from offset 0 and returns the first offset that
    /// does not start a complete, well-formed record (valid magic and fully
    /// present payload). Payload CRCs are not verified; one seek + read per
    /// record.
    fn find_actual_end(file: &File, file_len: u64) -> io::Result<u64> {
        use std::io::Read;

        let mut position = 0u64;
        // Clone the handle so scanning does not disturb the caller's cursor.
        let mut file = file.try_clone()?;

        while position + RECORD_HEADER_SIZE as u64 <= file_len {
            let mut header = [0u8; RECORD_HEADER_SIZE];

            use std::io::{Seek, SeekFrom};
            file.seek(SeekFrom::Start(position))?;

            if file.read_exact(&mut header).is_err() {
                break;
            }

            let magic = u32::from_be_bytes([header[0], header[1], header[2], header[3]]);
            if magic != WAL_MAGIC {
                break;
            }

            // Payload length lives at header bytes 8..12 (big-endian).
            let data_len =
                u32::from_be_bytes([header[8], header[9], header[10], header[11]]) as u64;
            let record_size = RECORD_HEADER_SIZE as u64 + data_len;

            // Truncated trailing record: stop before it.
            if position + record_size > file_len {
                break;
            }

            position += record_size;
        }

        Ok(position)
    }

    /// Appends a pre-serialized batch, flushes the buffer, applies the
    /// configured durability mode, then advances `position`.
    fn write_batch(&mut self, data: &[u8]) -> io::Result<()> {
        self.file.write_all(data)?;
        self.file.flush()?;

        match self.config.sync_mode {
            SyncMode::None => {}
            SyncMode::FsyncData => {
                // Data-only sync; may skip some metadata updates.
                self.file.get_ref().sync_data()?;
            }
            SyncMode::Fsync | SyncMode::Dsync => {
                self.file.get_ref().sync_all()?;
            }
        }

        self.position += data.len() as u64;
        Ok(())
    }

    /// Switches to a fresh segment named after `next_lsn` once the current
    /// one has reached `max_file_size` (0 disables rotation). Returns whether
    /// a rotation happened.
    fn rotate_if_needed(&mut self, next_lsn: u64) -> io::Result<bool> {
        if self.config.max_file_size == 0 || self.position < self.config.max_file_size {
            return Ok(false);
        }

        // Make the old segment durable before abandoning it.
        self.file.flush()?;
        self.file.get_ref().sync_all()?;

        // Trim any preallocated zero tail so the old file ends at real data.
        if self.position < self.file.get_ref().metadata()?.len() {
            self.file.get_ref().set_len(self.position)?;
        }

        let new_path = self.config.dir.join(format!("{:020}.wal", next_lsn));

        tracing::info!(
            old_file = %self.path.display(),
            new_file = %new_path.display(),
            old_size = self.position,
            max_size = self.config.max_file_size,
            "Rotating WAL file"
        );

        let file = OpenOptions::new()
            .create(true)
            .read(true)
            .write(true)
            .truncate(false)
            .open(&new_path)?;

        if self.config.preallocate_size > 0 {
            file.set_len(self.config.preallocate_size)?;
        }

        self.file = BufWriter::with_capacity(self.config.max_batch_size, file);
        self.path = new_path;
        self.position = 0;

        Ok(true)
    }

    /// Flushes buffered bytes and fsyncs data + metadata.
    fn sync(&mut self) -> io::Result<()> {
        self.file.flush()?;
        self.file.get_ref().sync_all()
    }

    /// Path of the segment currently being written.
    #[allow(dead_code)]
    fn path(&self) -> &std::path::Path {
        &self.path
    }
}
880
/// Internal monotonic counters behind [`GroupCommitWal::stats`].
struct WalStats {
    /// Records appended (one per write request).
    writes_total: AtomicU64,
    /// Bytes handed to the writer, record headers included.
    bytes_written: AtomicU64,
    /// Sync operations performed.
    syncs_total: AtomicU64,
    /// Flush batches performed.
    group_commits: AtomicU64,
}

impl WalStats {
    /// All counters start at zero.
    fn new() -> Self {
        WalStats {
            writes_total: AtomicU64::default(),
            bytes_written: AtomicU64::default(),
            syncs_total: AtomicU64::default(),
            group_commits: AtomicU64::default(),
        }
    }
}
899
/// Point-in-time copy of the WAL counters, returned by [`GroupCommitWal::stats`].
#[derive(Debug, Clone)]
pub struct WalStatsSnapshot {
    /// Total records appended.
    pub writes_total: u64,
    /// Total bytes written, record headers included.
    pub bytes_written: u64,
    /// Total sync operations.
    pub syncs_total: u64,
    /// Total flush batches.
    pub group_commits: u64,
    /// `writes_total / group_commits`; 0.0 before the first flushed batch.
    pub avg_group_size: f64,
    /// Last assigned LSN at snapshot time.
    pub current_lsn: u64,
}
909
/// Sequential reader over a single WAL segment file.
pub struct WalReader {
    /// Segment file to read.
    path: PathBuf,
    /// Byte offset where the next `read_all` resumes.
    position: u64,
    /// Optional decryptor for records flagged as encrypted.
    #[cfg(feature = "encryption")]
    encryptor: Option<std::sync::Arc<dyn crate::encryption::Encryptor>>,
}
917
impl WalReader {
    /// Opens a reader positioned at the start of the given segment file.
    /// The file itself is not touched until `read_all`/`seek_to_lsn` runs.
    pub fn open(path: PathBuf) -> io::Result<Self> {
        Ok(Self {
            path,
            position: 0,
            #[cfg(feature = "encryption")]
            encryptor: None,
        })
    }

    /// Opens a reader with an optional decryptor for encrypted records.
    #[cfg(feature = "encryption")]
    pub fn open_with_encryption(
        path: PathBuf,
        encryptor: Option<std::sync::Arc<dyn crate::encryption::Encryptor>>,
    ) -> io::Result<Self> {
        Ok(Self {
            path,
            position: 0,
            encryptor,
        })
    }

    /// Decrypts `record.data` in place when the record's ENCRYPTED flag is
    /// set; errors if no encryptor was supplied.
    #[cfg(feature = "encryption")]
    fn decrypt_record_data(&self, record: &mut WalRecord) -> io::Result<()> {
        if record.flags.is_encrypted() {
            if let Some(ref encryptor) = self.encryptor {
                let decrypted = encryptor
                    .decrypt(&record.data, record.lsn)
                    .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e.to_string()))?;
                record.data = Bytes::from(decrypted);
            } else {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidData,
                    "Record is encrypted but no encryptor provided",
                ));
            }
        }
        Ok(())
    }

    /// Reads every intact record from the current position to the first
    /// corruption/truncation point (or EOF), advancing `self.position`.
    ///
    /// NOTE(review): LSNs are renumbered starting from 1 on each call — they
    /// are ordinals within this scan (the LSN is not stored on disk), not
    /// positions recovered from the writer. Confirm callers expect this,
    /// especially after `seek_to_lsn`.
    pub async fn read_all(&mut self) -> io::Result<Vec<WalRecord>> {
        let data = tokio::fs::read(&self.path).await?;
        let mut records = Vec::new();
        let mut lsn = 0u64;

        while self.position + RECORD_HEADER_SIZE as u64 <= data.len() as u64 {
            let offset = self.position as usize;

            let magic = u32::from_be_bytes([
                data[offset],
                data[offset + 1],
                data[offset + 2],
                data[offset + 3],
            ]);

            // Zeroed preallocated tail or corruption: stop scanning.
            if magic != WAL_MAGIC {
                break;
            }

            // Payload length lives at header bytes 8..12 (big-endian).
            let data_len = u32::from_be_bytes([
                data[offset + 8],
                data[offset + 9],
                data[offset + 10],
                data[offset + 11],
            ]) as usize;

            let record_size = RECORD_HEADER_SIZE + data_len;

            if offset + record_size > data.len() {
                break;
            }

            lsn += 1;

            match WalRecord::from_bytes(&data[offset..offset + record_size], lsn) {
                #[cfg(feature = "encryption")]
                Ok(mut record) => {
                    self.decrypt_record_data(&mut record)?;

                    records.push(record);
                    self.position += record_size as u64;
                }
                #[cfg(not(feature = "encryption"))]
                Ok(record) => {
                    records.push(record);
                    self.position += record_size as u64;
                }
                // A parse/CRC failure ends the scan without erroring.
                Err(_) => break,
            }
        }

        Ok(records)
    }

    /// Positions the reader at the record whose ordinal (1-based, counted
    /// from the start of the file) equals `target_lsn`.
    ///
    /// # Errors
    /// `NotFound` when the file ends (or magic breaks) before `target_lsn`
    /// records have been seen.
    pub async fn seek_to_lsn(&mut self, target_lsn: u64) -> io::Result<()> {
        let data = tokio::fs::read(&self.path).await?;
        let mut position = 0usize;
        let mut current_lsn = 0u64;

        while position + RECORD_HEADER_SIZE <= data.len() {
            let magic = u32::from_be_bytes([
                data[position],
                data[position + 1],
                data[position + 2],
                data[position + 3],
            ]);

            if magic != WAL_MAGIC {
                break;
            }

            let data_len = u32::from_be_bytes([
                data[position + 8],
                data[position + 9],
                data[position + 10],
                data[position + 11],
            ]) as usize;

            let record_size = RECORD_HEADER_SIZE + data_len;
            current_lsn += 1;

            // First ordinal >= target: park at the start of this record.
            if current_lsn >= target_lsn {
                self.position = position as u64;
                return Ok(());
            }

            position += record_size;
        }

        Err(io::Error::new(
            io::ErrorKind::NotFound,
            format!("LSN {} not found", target_lsn),
        ))
    }
}
1060
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Round-trips a record through `to_bytes`/`from_bytes`.
    #[test]
    fn test_wal_record_serialization() {
        let data = Bytes::from("test data");
        let record = WalRecord::new(1, data.clone());

        let serialized = record.to_bytes();
        assert!(serialized.len() >= RECORD_HEADER_SIZE + data.len());

        let parsed = WalRecord::from_bytes(&serialized, 1).unwrap();
        assert_eq!(parsed.lsn, 1);
        assert_eq!(parsed.data, data);
    }

    /// A single flipped payload byte must fail the CRC check on decode.
    #[test]
    fn test_wal_record_crc() {
        let data = Bytes::from("test data");
        let record = WalRecord::new(1, data);
        let mut serialized = record.to_bytes().to_vec();

        // Corrupt the first payload byte (just past the header).
        serialized[RECORD_HEADER_SIZE] ^= 0xFF;

        assert!(WalRecord::from_bytes(&serialized, 1).is_err());
    }

    /// A lone write gets LSN 1 and is counted in the stats.
    #[tokio::test]
    async fn test_group_commit_wal_single_write() {
        let temp_dir = TempDir::new().unwrap();
        let config = WalConfig {
            dir: temp_dir.path().to_path_buf(),
            group_commit_window: Duration::from_micros(100),
            ..Default::default()
        };

        let wal = GroupCommitWal::new(config).await.unwrap();

        let result = wal.write(Bytes::from("test data")).await.unwrap();
        assert_eq!(result.lsn, 1);
        assert!(result.size > 0);

        let stats = wal.stats();
        assert_eq!(stats.writes_total, 1);

        wal.shutdown().await.unwrap();
    }

    /// `write_batch` assigns consecutive LSNs in submission order.
    #[tokio::test]
    async fn test_group_commit_wal_batch() {
        let temp_dir = TempDir::new().unwrap();
        let config = WalConfig {
            dir: temp_dir.path().to_path_buf(),
            group_commit_window: Duration::from_millis(10),
            ..Default::default()
        };

        let wal = GroupCommitWal::new(config).await.unwrap();

        let records: Vec<Bytes> = (0..10)
            .map(|i| Bytes::from(format!("record {}", i)))
            .collect();

        let results = wal.write_batch(records).await.unwrap();

        assert_eq!(results.len(), 10);
        for (i, result) in results.iter().enumerate() {
            assert_eq!(result.lsn, (i + 1) as u64);
        }

        let stats = wal.stats();
        assert_eq!(stats.writes_total, 10);

        wal.shutdown().await.unwrap();
    }

    /// Concurrent writers all complete, and the number of flush batches can
    /// never exceed the number of writes.
    #[tokio::test]
    async fn test_group_commit_batching() {
        let temp_dir = TempDir::new().unwrap();
        let config = WalConfig {
            dir: temp_dir.path().to_path_buf(),
            group_commit_window: Duration::from_millis(50),
            max_pending_writes: 100,
            ..Default::default()
        };

        let wal = Arc::new(GroupCommitWal::new(config).await.unwrap());

        // Spawn 20 writers that race into the same commit windows.
        let mut handles = vec![];
        for i in 0..20 {
            let wal_clone = wal.clone();
            handles.push(tokio::spawn(async move {
                wal_clone
                    .write(Bytes::from(format!("concurrent write {}", i)))
                    .await
            }));
        }

        for handle in handles {
            let result = handle.await.unwrap().unwrap();
            assert!(result.lsn > 0);
        }

        let stats = wal.stats();
        assert_eq!(stats.writes_total, 20);
        assert!(stats.group_commits <= stats.writes_total);

        wal.shutdown().await.unwrap();
    }

    /// Records written through the WAL can be read back with `WalReader`.
    #[tokio::test]
    async fn test_wal_reader() {
        let temp_dir = TempDir::new().unwrap();
        let config = WalConfig {
            dir: temp_dir.path().to_path_buf(),
            group_commit_window: Duration::from_micros(100),
            sync_mode: SyncMode::Fsync,
            max_pending_writes: 10,
            ..Default::default()
        };

        let wal = GroupCommitWal::new(config.clone()).await.unwrap();

        // Sequential awaited writes so the on-disk order matches `i`.
        for i in 0..5 {
            let result = wal
                .write(Bytes::from(format!("record {}", i)))
                .await
                .unwrap();
            assert!(result.lsn > 0, "Expected valid LSN for record {}", i);
        }

        wal.sync().await.unwrap();

        // Give the group-commit worker time to settle before shutdown.
        tokio::time::sleep(Duration::from_millis(100)).await;

        wal.shutdown().await.unwrap();

        // Locate the segment file the WAL produced.
        let entries: Vec<_> = std::fs::read_dir(&config.dir)
            .unwrap()
            .filter_map(|e| e.ok())
            .filter(|e| e.path().extension().is_some_and(|ext| ext == "wal"))
            .collect();

        assert!(!entries.is_empty(), "No WAL files found");

        let wal_file = entries[0].path();
        let file_size = std::fs::metadata(&wal_file).unwrap().len();
        assert!(file_size > 0, "WAL file is empty");

        let mut reader = WalReader::open(wal_file).unwrap();
        let records = reader.read_all().await.unwrap();

        assert_eq!(
            records.len(),
            5,
            "Expected 5 records, got {} (file size: {})",
            records.len(),
            file_size
        );
        for (i, record) in records.iter().enumerate() {
            let expected = format!("record {}", i);
            assert_eq!(record.data, Bytes::from(expected));
        }
    }

    /// Flag accessors reflect individual and combined bits.
    #[test]
    fn test_record_flags() {
        let flags = RecordFlags::COMPRESSED;
        assert!(flags.is_compressed());
        assert!(!flags.is_encrypted());

        let flags = RecordFlags(RecordFlags::COMPRESSED.0 | RecordFlags::ENCRYPTED.0);
        assert!(flags.is_compressed());
        assert!(flags.is_encrypted());
    }

    /// A tiny `max_file_size` forces rotation into multiple segment files.
    #[tokio::test]
    async fn test_wal_rotation() {
        let temp_dir = TempDir::new().unwrap();
        let config = WalConfig {
            dir: temp_dir.path().to_path_buf(),
            group_commit_window: Duration::from_micros(50),
            max_file_size: 200,
            preallocate_size: 0,
            ..Default::default()
        };

        let wal = GroupCommitWal::new(config).await.unwrap();

        // Each record is ~20 bytes of payload; 10 of them cross the 200-byte limit.
        for i in 0..10 {
            let data = format!("rotation-record-{:04}", i);
            let result = wal.write(Bytes::from(data)).await.unwrap();
            assert!(result.lsn > 0);
        }

        wal.sync().await.unwrap();
        tokio::time::sleep(Duration::from_millis(100)).await;
        wal.shutdown().await.unwrap();

        let wal_files: Vec<_> = std::fs::read_dir(temp_dir.path())
            .unwrap()
            .filter_map(|e| e.ok())
            .filter(|e| e.path().extension().is_some_and(|ext| ext == "wal"))
            .collect();

        assert!(
            wal_files.len() > 1,
            "Expected multiple WAL files after rotation, got {}",
            wal_files.len()
        );
    }

    /// `rotate_if_needed` fires only past `max_file_size` and names the new
    /// segment after the supplied LSN.
    #[test]
    fn test_wal_writer_rotate_if_needed() {
        let temp_dir = TempDir::new().unwrap();
        let config = WalConfig {
            dir: temp_dir.path().to_path_buf(),
            max_file_size: 100,
            preallocate_size: 0,
            ..Default::default()
        };

        let path = temp_dir.path().join("00000000000000000000.wal");
        let mut writer = WalWriter::new(path.clone(), config).unwrap();

        // Write past the 100-byte limit so rotation becomes due.
        writer.write_batch(&[0u8; 150]).unwrap();
        assert_eq!(writer.position, 150);

        let rotated = writer.rotate_if_needed(42).unwrap();
        assert!(rotated, "Expected rotation to occur");
        assert_eq!(writer.position, 0);
        assert_ne!(writer.path, path);
        assert!(writer
            .path
            .to_str()
            .unwrap()
            .contains("00000000000000000042"));

        // Fresh segment is below the limit: no second rotation.
        let rotated = writer.rotate_if_needed(43).unwrap();
        assert!(!rotated, "Expected no rotation when under max_file_size");
    }
}