//! Group-commit write-ahead log (WAL): records carry CRC32 checksums on
//! disk, and concurrent writers are batched so a single fsync covers many
//! logical commits.

use bytes::{BufMut, Bytes, BytesMut};
use crc32fast::Hasher;
use std::collections::VecDeque;
use std::fs::{File, OpenOptions};
use std::io::{self, BufWriter, Write};
use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::{mpsc, oneshot, Mutex, Notify};

/// Fixed record header size: magic (4) + CRC32 (4) + length (4) + type (1) + flags (1).
const RECORD_HEADER_SIZE: usize = 14;

/// Magic number prefixing every record ("WALR" in ASCII).
const WAL_MAGIC: u32 = 0x57414C52;

/// Default group-commit window in microseconds.
const DEFAULT_GROUP_COMMIT_WINDOW_US: u64 = 200;

/// Default maximum batch size flushed in one write (4 MiB).
const DEFAULT_MAX_BATCH_SIZE: usize = 4 * 1024 * 1024;

/// Default bound on queued write requests.
const DEFAULT_MAX_PENDING_WRITES: usize = 1000;

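// On-disk record layout produced by `WalRecord::to_bytes` (integers big-endian):
//
//   offset 0..4    u32   magic    WAL_MAGIC (0x57414C52)
//   offset 4..8    u32   crc      CRC32 of the payload bytes only
//   offset 8..12   u32   length   payload length in bytes
//   offset 12      u8    type     RecordType discriminant
//   offset 13      u8    flags    RecordFlags bitset
//   offset 14..    [u8]  payload  `length` bytes of record data
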
/// Type tag stored in byte 12 of the record header.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum RecordType {
    /// A complete record stored as a single entry.
    Full = 0,
    /// First fragment of a record spanning multiple entries.
    First = 1,
    /// Middle fragment of a spanning record.
    Middle = 2,
    /// Final fragment of a spanning record.
    Last = 3,
    /// Checkpoint marker.
    Checkpoint = 4,
    /// Transaction begin marker.
    TxnBegin = 5,
    /// Transaction commit marker.
    TxnCommit = 6,
    /// Transaction abort marker.
    TxnAbort = 7,
}

impl TryFrom<u8> for RecordType {
    type Error = io::Error;

    fn try_from(value: u8) -> Result<Self, Self::Error> {
        match value {
            0 => Ok(RecordType::Full),
            1 => Ok(RecordType::First),
            2 => Ok(RecordType::Middle),
            3 => Ok(RecordType::Last),
            4 => Ok(RecordType::Checkpoint),
            5 => Ok(RecordType::TxnBegin),
            6 => Ok(RecordType::TxnCommit),
            7 => Ok(RecordType::TxnAbort),
            _ => Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "Invalid record type",
            )),
        }
    }
}

/// Per-record flag bits stored in byte 13 of the header.
#[derive(Debug, Clone, Copy)]
pub struct RecordFlags(u8);

impl RecordFlags {
    pub const NONE: Self = Self(0);
    pub const COMPRESSED: Self = Self(1 << 0);
    pub const ENCRYPTED: Self = Self(1 << 1);
    pub const HAS_CHECKSUM: Self = Self(1 << 2);

    pub fn is_compressed(&self) -> bool {
        self.0 & Self::COMPRESSED.0 != 0
    }

    pub fn is_encrypted(&self) -> bool {
        self.0 & Self::ENCRYPTED.0 != 0
    }

    pub fn has_checksum(&self) -> bool {
        self.0 & Self::HAS_CHECKSUM.0 != 0
    }
}

/// A single WAL record: a payload plus its log sequence number, type, and flags.
#[derive(Debug, Clone)]
pub struct WalRecord {
    /// Log sequence number assigned when the record was written.
    pub lsn: u64,
    /// Kind of record (full, fragment, or marker).
    pub record_type: RecordType,
    /// Flag bits (compression, encryption, checksum).
    pub flags: RecordFlags,
    /// Record payload.
    pub data: Bytes,
}

impl WalRecord {
    /// Creates a `Full` record with checksumming enabled.
    pub fn new(lsn: u64, data: Bytes) -> Self {
        Self {
            lsn,
            record_type: RecordType::Full,
            flags: RecordFlags::HAS_CHECKSUM,
            data,
        }
    }

    /// Serializes the record into its on-disk representation (header + payload).
    pub fn to_bytes(&self) -> Bytes {
        let mut buf = BytesMut::with_capacity(RECORD_HEADER_SIZE + self.data.len());

        // The CRC covers the payload only; the header is validated via the magic number.
        let mut hasher = Hasher::new();
        hasher.update(&self.data);
        let crc = hasher.finalize();

        buf.put_u32(WAL_MAGIC);
        buf.put_u32(crc);
        buf.put_u32(self.data.len() as u32);
        buf.put_u8(self.record_type as u8);
        buf.put_u8(self.flags.0);

        buf.extend_from_slice(&self.data);

        buf.freeze()
    }

    /// Parses a record from `data`, validating the magic number and payload CRC.
    /// The LSN is supplied by the caller since it is not stored on disk.
    pub fn from_bytes(data: &[u8], lsn: u64) -> io::Result<Self> {
        if data.len() < RECORD_HEADER_SIZE {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "Record too short",
            ));
        }

        let magic = u32::from_be_bytes([data[0], data[1], data[2], data[3]]);
        if magic != WAL_MAGIC {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "Invalid magic number",
            ));
        }

        let stored_crc = u32::from_be_bytes([data[4], data[5], data[6], data[7]]);
        let data_len = u32::from_be_bytes([data[8], data[9], data[10], data[11]]) as usize;
        let record_type = RecordType::try_from(data[12])?;
        let flags = RecordFlags(data[13]);

        if data.len() < RECORD_HEADER_SIZE + data_len {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "Incomplete record",
            ));
        }

        let record_data =
            Bytes::copy_from_slice(&data[RECORD_HEADER_SIZE..RECORD_HEADER_SIZE + data_len]);

        let mut hasher = Hasher::new();
        hasher.update(&record_data);
        let computed_crc = hasher.finalize();

        if computed_crc != stored_crc {
            return Err(io::Error::new(io::ErrorKind::InvalidData, "CRC mismatch"));
        }

        Ok(Self {
            lsn,
            record_type,
            flags,
            data: record_data,
        })
    }

    /// Size of the record as serialized on disk.
    pub fn serialized_size(&self) -> usize {
        RECORD_HEADER_SIZE + self.data.len()
    }
}

/// Configuration for the group-commit WAL.
#[derive(Debug, Clone)]
pub struct WalConfig {
    /// Directory where WAL files are stored.
    pub dir: PathBuf,
    /// How long the worker waits for more writes before flushing a group.
    pub group_commit_window: Duration,
    /// Maximum number of bytes flushed in a single batch.
    pub max_batch_size: usize,
    /// Maximum number of queued write requests.
    pub max_pending_writes: usize,
    /// Size to preallocate for new WAL files (0 disables preallocation).
    pub preallocate_size: u64,
    /// Whether to use direct I/O.
    pub direct_io: bool,
    /// Durability mode applied after each batch write.
    pub sync_mode: SyncMode,
    /// Maximum WAL file size before rotation.
    pub max_file_size: u64,
    /// Optional encryptor applied to record payloads.
    #[cfg(feature = "encryption")]
    pub encryptor: Option<std::sync::Arc<dyn crate::encryption::Encryptor>>,
}

impl Default for WalConfig {
    fn default() -> Self {
        Self {
            dir: PathBuf::from("./wal"),
            group_commit_window: Duration::from_micros(DEFAULT_GROUP_COMMIT_WINDOW_US),
            max_batch_size: DEFAULT_MAX_BATCH_SIZE,
            max_pending_writes: DEFAULT_MAX_PENDING_WRITES,
            preallocate_size: 64 * 1024 * 1024, // 64 MiB
            direct_io: false,
            sync_mode: SyncMode::Fsync,
            max_file_size: 1024 * 1024 * 1024, // 1 GiB
            #[cfg(feature = "encryption")]
            encryptor: None,
        }
    }
}

impl WalConfig {
    /// Preset favoring throughput: a longer commit window and larger batches.
    pub fn high_throughput() -> Self {
        Self {
            group_commit_window: Duration::from_micros(1000),
            max_batch_size: 16 * 1024 * 1024,
            max_pending_writes: 5000,
            ..Default::default()
        }
    }

    /// Preset favoring latency: a short commit window and small batches.
    pub fn low_latency() -> Self {
        Self {
            group_commit_window: Duration::from_micros(50),
            max_batch_size: 512 * 1024,
            max_pending_writes: 100,
            ..Default::default()
        }
    }

    /// Preset that syncs file data (not metadata) on every batch.
    pub fn durable() -> Self {
        Self {
            sync_mode: SyncMode::FsyncData,
            ..Default::default()
        }
    }
}

/// Durability mode applied after each batch is written.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SyncMode {
    /// No explicit sync; rely on the OS to flush.
    None,
    /// Sync file data only (`fdatasync`-style).
    FsyncData,
    /// Full sync of data and metadata (`fsync`).
    Fsync,
    /// Synchronous writes; treated as a full sync by this writer.
    Dsync,
}

/// A queued write awaiting group commit.
struct WriteRequest {
    data: Bytes,
    record_type: RecordType,
    completion: oneshot::Sender<WriteResult>,
}

/// Outcome of a single write, reported once its group has been flushed.
#[derive(Debug, Clone)]
pub struct WriteResult {
    /// LSN assigned to the record (0 indicates the batch write failed).
    pub lsn: u64,
    /// Serialized size of the record in bytes.
    pub size: usize,
    /// Whether the record was flushed together with other writes.
    pub group_commit: bool,
    /// Number of writes in the flushed group.
    pub group_size: usize,
    /// Time the request waited before its group was flushed.
    pub wait_time: Duration,
}

/// Write-ahead log with group commit: concurrent writers are queued, and a
/// background worker flushes them in batches so a single fsync can cover
/// many logical writes.
pub struct GroupCommitWal {
    config: WalConfig,
    writer: Mutex<WalWriter>,
    current_lsn: AtomicU64,
    write_tx: mpsc::Sender<WriteRequest>,
    shutdown: AtomicBool,
    write_notify: Arc<Notify>,
    stats: Arc<WalStats>,
}

impl GroupCommitWal {
    /// Opens (or creates) the WAL in `config.dir`, recovers the last LSN from
    /// any existing file, and starts the background group-commit worker.
    pub async fn new(config: WalConfig) -> io::Result<Arc<Self>> {
        std::fs::create_dir_all(&config.dir)?;

        let (current_file, current_lsn) = Self::recover_state(&config).await?;

        let writer = WalWriter::new(current_file, config.clone())?;
        let (write_tx, write_rx) = mpsc::channel(config.max_pending_writes);

        let wal = Arc::new(Self {
            config,
            writer: Mutex::new(writer),
            current_lsn: AtomicU64::new(current_lsn),
            write_tx,
            shutdown: AtomicBool::new(false),
            write_notify: Arc::new(Notify::new()),
            stats: Arc::new(WalStats::new()),
        });

        wal.clone().start_group_commit_worker(write_rx);

        Ok(wal)
    }

    /// Finds the most recent `.wal` file and recovers the current LSN by
    /// counting its valid records (the scan result supersedes any number
    /// parsed from the filename).
    async fn recover_state(config: &WalConfig) -> io::Result<(PathBuf, u64)> {
        let mut max_lsn = 0u64;
        let mut latest_file = None;

        if let Ok(entries) = std::fs::read_dir(&config.dir) {
            for entry in entries.flatten() {
                let path = entry.path();
                if path.extension().is_some_and(|e| e == "wal") {
                    if let Some(name) = path.file_stem() {
                        if let Ok(lsn) = name.to_string_lossy().parse::<u64>() {
                            if lsn >= max_lsn {
                                max_lsn = lsn;
                                latest_file = Some(path);
                            }
                        }
                    }
                }
            }
        }

        if let Some(ref file) = latest_file {
            if let Ok(recovered_lsn) = Self::scan_wal_file(file).await {
                max_lsn = recovered_lsn;
            }
        }

        let file = latest_file.unwrap_or_else(|| config.dir.join(format!("{:020}.wal", 0)));

        Ok((file, max_lsn))
    }

    /// Scans a WAL file and returns the number of valid records, stopping at
    /// the first missing magic number or truncated record.
    async fn scan_wal_file(path: &Path) -> io::Result<u64> {
        let data = tokio::fs::read(path).await?;
        let mut offset = 0;
        let mut max_lsn = 0u64;

        while offset + RECORD_HEADER_SIZE <= data.len() {
            let magic = u32::from_be_bytes([
                data[offset],
                data[offset + 1],
                data[offset + 2],
                data[offset + 3],
            ]);

            if magic != WAL_MAGIC {
                break;
            }

            let data_len = u32::from_be_bytes([
                data[offset + 8],
                data[offset + 9],
                data[offset + 10],
                data[offset + 11],
            ]) as usize;

            let record_size = RECORD_HEADER_SIZE + data_len;
            if offset + record_size > data.len() {
                break;
            }

            max_lsn += 1;
            offset += record_size;
        }

        Ok(max_lsn)
    }

    /// Writes a single `Full` record and waits for it to be durably flushed.
    pub async fn write(&self, data: Bytes) -> io::Result<WriteResult> {
        self.write_with_type(data, RecordType::Full).await
    }

    /// Writes a record of the given type and waits for its group to flush.
    pub async fn write_with_type(
        &self,
        data: Bytes,
        record_type: RecordType,
    ) -> io::Result<WriteResult> {
        if self.shutdown.load(Ordering::Acquire) {
            return Err(io::Error::new(
                io::ErrorKind::BrokenPipe,
                "WAL is shut down",
            ));
        }

        let (tx, rx) = oneshot::channel();

        let request = WriteRequest {
            data,
            record_type,
            completion: tx,
        };

        self.write_tx
            .send(request)
            .await
            .map_err(|_| io::Error::new(io::ErrorKind::BrokenPipe, "WAL write channel closed"))?;

        self.write_notify.notify_one();

        rx.await
            .map_err(|_| io::Error::new(io::ErrorKind::BrokenPipe, "WAL write cancelled"))
    }

    /// Enqueues several records at once and waits for all of them; they are
    /// likely (but not guaranteed) to be flushed in the same group.
    pub async fn write_batch(&self, records: Vec<Bytes>) -> io::Result<Vec<WriteResult>> {
        let mut results = Vec::with_capacity(records.len());
        let mut receivers = Vec::with_capacity(records.len());

        for data in records {
            let (tx, rx) = oneshot::channel();

            let request = WriteRequest {
                data,
                record_type: RecordType::Full,
                completion: tx,
            };

            self.write_tx.send(request).await.map_err(|_| {
                io::Error::new(io::ErrorKind::BrokenPipe, "WAL write channel closed")
            })?;

            receivers.push(rx);
        }

        self.write_notify.notify_one();

        for rx in receivers {
            let result = rx
                .await
                .map_err(|_| io::Error::new(io::ErrorKind::BrokenPipe, "WAL write cancelled"))?;
            results.push(result);
        }

        Ok(results)
    }

    /// Forces any buffered data to disk.
    pub async fn sync(&self) -> io::Result<()> {
        let mut writer = self.writer.lock().await;
        writer.sync()
    }

    /// Returns the most recently assigned LSN.
    pub fn current_lsn(&self) -> u64 {
        self.current_lsn.load(Ordering::Acquire)
    }

    /// Returns a point-in-time snapshot of write statistics.
    pub fn stats(&self) -> WalStatsSnapshot {
        let writes_total = self.stats.writes_total.load(Ordering::Relaxed);
        let group_commits = self.stats.group_commits.load(Ordering::Relaxed);
        WalStatsSnapshot {
            writes_total,
            bytes_written: self.stats.bytes_written.load(Ordering::Relaxed),
            syncs_total: self.stats.syncs_total.load(Ordering::Relaxed),
            group_commits,
            avg_group_size: if group_commits > 0 {
                writes_total as f64 / group_commits as f64
            } else {
                0.0
            },
            current_lsn: self.current_lsn.load(Ordering::Relaxed),
        }
    }

    /// Signals shutdown, wakes the worker so it can drain pending writes, and
    /// syncs the file.
    pub async fn shutdown(&self) -> io::Result<()> {
        self.shutdown.store(true, Ordering::Release);
        self.write_notify.notify_waiters();

        let mut writer = self.writer.lock().await;
        writer.sync()
    }

    /// Spawns the background worker that drains the request channel and
    /// flushes groups of writes.
    fn start_group_commit_worker(self: Arc<Self>, mut rx: mpsc::Receiver<WriteRequest>) {
        let wal = self.clone();

        tokio::spawn(async move {
            let mut pending: VecDeque<WriteRequest> = VecDeque::new();
            let mut pending_bytes = 0usize;
            let mut batch_buffer = BytesMut::with_capacity(wal.config.max_batch_size);
            let mut group_start: Option<Instant> = None;

            loop {
                if wal.shutdown.load(Ordering::Acquire) {
                    // Drain anything still queued so no caller is left hanging.
                    while let Ok(request) = rx.try_recv() {
                        pending.push_back(request);
                    }
                    if !pending.is_empty() {
                        wal.flush_batch(&mut pending, &mut batch_buffer, group_start.take())
                            .await;
                    }
                    break;
                }

                // With nothing pending there is no group to age out, so park on
                // a long timeout; otherwise wait out the commit window.
                let timeout = if pending.is_empty() {
                    Duration::from_secs(60)
                } else {
                    wal.config.group_commit_window
                };

                tokio::select! {
                    biased;

                    Some(request) = rx.recv() => {
                        if group_start.is_none() {
                            group_start = Some(Instant::now());
                        }
                        pending_bytes += RECORD_HEADER_SIZE + request.data.len();
                        pending.push_back(request);

                        // Flush early once the group is full by count or by
                        // estimated serialized size.
                        let should_flush =
                            pending.len() >= wal.config.max_pending_writes ||
                            pending_bytes >= wal.config.max_batch_size;

                        if should_flush {
                            wal.flush_batch(&mut pending, &mut batch_buffer, group_start.take()).await;
                            pending_bytes = 0;
                        }
                    }

                    _ = wal.write_notify.notified() => {
                        // Woken by a writer or shutdown; loop to re-check state.
                    }

                    _ = tokio::time::sleep(timeout) => {
                        if !pending.is_empty() {
                            wal.flush_batch(&mut pending, &mut batch_buffer, group_start.take()).await;
                            pending_bytes = 0;
                        }
                    }
                }
            }
        });
    }

    /// Serializes every pending request into one buffer, writes and syncs it
    /// once, then completes all waiters with their individual results.
    async fn flush_batch(
        &self,
        pending: &mut VecDeque<WriteRequest>,
        batch_buffer: &mut BytesMut,
        group_start: Option<Instant>,
    ) {
        if pending.is_empty() {
            return;
        }

        let wait_time = group_start.map(|s| s.elapsed()).unwrap_or(Duration::ZERO);
        let group_size = pending.len();

        batch_buffer.clear();
        let mut lsns = Vec::with_capacity(group_size);
        let mut sizes = Vec::with_capacity(group_size);

        for request in pending.iter() {
            // fetch_add returns the previous value, so LSNs start at 1.
            let lsn = self.current_lsn.fetch_add(1, Ordering::AcqRel) + 1;
            lsns.push(lsn);

            // Encrypt the payload if an encryptor is configured; on failure,
            // fall back to plaintext and leave the ENCRYPTED flag clear.
            #[cfg(feature = "encryption")]
            let (data, is_encrypted) = if let Some(ref encryptor) = self.config.encryptor {
                if encryptor.is_enabled() {
                    match encryptor.encrypt(&request.data, lsn) {
                        Ok(encrypted) => (Bytes::from(encrypted), true),
                        Err(e) => {
                            tracing::error!("Encryption failed for LSN {}: {:?}", lsn, e);
                            (request.data.clone(), false)
                        }
                    }
                } else {
                    (request.data.clone(), false)
                }
            } else {
                (request.data.clone(), false)
            };

            #[cfg(not(feature = "encryption"))]
            let (data, is_encrypted) = (request.data.clone(), false);

            let flags = if is_encrypted {
                RecordFlags(RecordFlags::HAS_CHECKSUM.0 | RecordFlags::ENCRYPTED.0)
            } else {
                RecordFlags::HAS_CHECKSUM
            };

            let record = WalRecord {
                lsn,
                record_type: request.record_type,
                flags,
                data,
            };

            let record_bytes = record.to_bytes();
            sizes.push(record_bytes.len());
            batch_buffer.extend_from_slice(&record_bytes);
        }

        // One write plus one sync covers the entire group.
        let write_result = {
            let mut writer = self.writer.lock().await;
            writer.write_batch(batch_buffer)
        };

        self.stats
            .writes_total
            .fetch_add(group_size as u64, Ordering::Relaxed);
        self.stats
            .bytes_written
            .fetch_add(batch_buffer.len() as u64, Ordering::Relaxed);
        self.stats.group_commits.fetch_add(1, Ordering::Relaxed);
        self.stats.syncs_total.fetch_add(1, Ordering::Relaxed);

        let group_commit = group_size > 1;

        for (i, request) in pending.drain(..).enumerate() {
            let result = match &write_result {
                Ok(()) => WriteResult {
                    lsn: lsns[i],
                    size: sizes[i],
                    group_commit,
                    group_size,
                    wait_time,
                },
                // An LSN of 0 signals to the waiter that the batch write failed.
                Err(_) => WriteResult {
                    lsn: 0,
                    size: 0,
                    group_commit: false,
                    group_size: 0,
                    wait_time,
                },
            };

            let _ = request.completion.send(result);
        }
    }
}

/// Synchronous file writer that owns the current WAL file.
struct WalWriter {
    file: BufWriter<File>,
    path: PathBuf,
    position: u64,
    config: WalConfig,
}

impl WalWriter {
    /// Opens (or creates) the file at `path`, positions new writes just past
    /// the last valid record, and preallocates space for fresh files when
    /// configured.
    fn new(path: PathBuf, config: WalConfig) -> io::Result<Self> {
        use std::io::{Seek, SeekFrom};

        let existing_len = std::fs::metadata(&path).map(|m| m.len()).unwrap_or(0);

        let file = OpenOptions::new()
            .create(true)
            .read(true)
            .write(true)
            .truncate(false)
            .open(&path)?;

        let actual_position = if existing_len > 0 {
            Self::find_actual_end(&file, existing_len)?
        } else {
            0
        };

        if actual_position == 0 && config.preallocate_size > 0 {
            file.set_len(config.preallocate_size)?;
        }

        let mut writer = BufWriter::with_capacity(config.max_batch_size, file);
        writer.seek(SeekFrom::Start(actual_position))?;

        Ok(Self {
            file: writer,
            path,
            position: actual_position,
            config,
        })
    }

    /// Walks the record headers from the start of the file and returns the
    /// offset just past the last complete record. Only the magic number and
    /// length are checked here; payload CRCs are validated at read time.
    fn find_actual_end(file: &File, file_len: u64) -> io::Result<u64> {
        use std::io::{Read, Seek, SeekFrom};

        let mut position = 0u64;
        let mut file = file.try_clone()?;

        while position + RECORD_HEADER_SIZE as u64 <= file_len {
            let mut header = [0u8; RECORD_HEADER_SIZE];

            file.seek(SeekFrom::Start(position))?;

            if file.read_exact(&mut header).is_err() {
                break;
            }

            let magic = u32::from_be_bytes([header[0], header[1], header[2], header[3]]);
            if magic != WAL_MAGIC {
                break;
            }

            let data_len =
                u32::from_be_bytes([header[8], header[9], header[10], header[11]]) as u64;
            let record_size = RECORD_HEADER_SIZE as u64 + data_len;

            if position + record_size > file_len {
                break;
            }

            position += record_size;
        }

        Ok(position)
    }

    /// Writes a serialized batch, flushes the buffer, and applies the
    /// configured sync mode.
    fn write_batch(&mut self, data: &[u8]) -> io::Result<()> {
        self.file.write_all(data)?;
        self.file.flush()?;

        match self.config.sync_mode {
            SyncMode::None => {}
            SyncMode::FsyncData => {
                self.file.get_ref().sync_data()?;
            }
            // Dsync is handled as a full sync by this writer.
            SyncMode::Fsync | SyncMode::Dsync => {
                self.file.get_ref().sync_all()?;
            }
        }

        self.position += data.len() as u64;
        Ok(())
    }

    /// Flushes buffered data and syncs file contents and metadata.
    fn sync(&mut self) -> io::Result<()> {
        self.file.flush()?;
        self.file.get_ref().sync_all()
    }

    #[allow(dead_code)]
    fn path(&self) -> &std::path::Path {
        &self.path
    }
}

/// Internal write counters, updated by the group-commit worker.
struct WalStats {
    writes_total: AtomicU64,
    bytes_written: AtomicU64,
    syncs_total: AtomicU64,
    group_commits: AtomicU64,
}

impl WalStats {
    fn new() -> Self {
        Self {
            writes_total: AtomicU64::new(0),
            bytes_written: AtomicU64::new(0),
            syncs_total: AtomicU64::new(0),
            group_commits: AtomicU64::new(0),
        }
    }
}

/// Point-in-time view of WAL statistics.
#[derive(Debug, Clone)]
pub struct WalStatsSnapshot {
    pub writes_total: u64,
    pub bytes_written: u64,
    pub syncs_total: u64,
    pub group_commits: u64,
    /// Average number of writes per group commit.
    pub avg_group_size: f64,
    pub current_lsn: u64,
}

/// Sequential reader for a single WAL file.
pub struct WalReader {
    path: PathBuf,
    position: u64,
    #[cfg(feature = "encryption")]
    encryptor: Option<std::sync::Arc<dyn crate::encryption::Encryptor>>,
}

impl WalReader {
    /// Opens a reader positioned at the start of the file.
    pub fn open(path: PathBuf) -> io::Result<Self> {
        Ok(Self {
            path,
            position: 0,
            #[cfg(feature = "encryption")]
            encryptor: None,
        })
    }

    /// Opens a reader that can decrypt encrypted records.
    #[cfg(feature = "encryption")]
    pub fn open_with_encryption(
        path: PathBuf,
        encryptor: Option<std::sync::Arc<dyn crate::encryption::Encryptor>>,
    ) -> io::Result<Self> {
        Ok(Self {
            path,
            position: 0,
            encryptor,
        })
    }

    /// Decrypts a record's payload in place if its ENCRYPTED flag is set.
    #[cfg(feature = "encryption")]
    fn decrypt_record_data(&self, record: &mut WalRecord) -> io::Result<()> {
        if record.flags.is_encrypted() {
            if let Some(ref encryptor) = self.encryptor {
                let decrypted = encryptor
                    .decrypt(&record.data, record.lsn)
                    .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e.to_string()))?;
                record.data = Bytes::from(decrypted);
            } else {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidData,
                    "Record is encrypted but no encryptor provided",
                ));
            }
        }
        Ok(())
    }

    /// Reads every valid record from the current position to the end of the
    /// file, stopping at the first corrupt or truncated record. LSNs are
    /// assigned by counting records from the start of this scan, since they
    /// are not stored on disk.
    pub async fn read_all(&mut self) -> io::Result<Vec<WalRecord>> {
        let data = tokio::fs::read(&self.path).await?;
        let mut records = Vec::new();
        let mut lsn = 0u64;

        while self.position + RECORD_HEADER_SIZE as u64 <= data.len() as u64 {
            let offset = self.position as usize;

            let magic = u32::from_be_bytes([
                data[offset],
                data[offset + 1],
                data[offset + 2],
                data[offset + 3],
            ]);

            if magic != WAL_MAGIC {
                break;
            }

            let data_len = u32::from_be_bytes([
                data[offset + 8],
                data[offset + 9],
                data[offset + 10],
                data[offset + 11],
            ]) as usize;

            let record_size = RECORD_HEADER_SIZE + data_len;

            if offset + record_size > data.len() {
                break;
            }

            lsn += 1;

            match WalRecord::from_bytes(&data[offset..offset + record_size], lsn) {
                #[cfg(feature = "encryption")]
                Ok(mut record) => {
                    self.decrypt_record_data(&mut record)?;

                    records.push(record);
                    self.position += record_size as u64;
                }
                #[cfg(not(feature = "encryption"))]
                Ok(record) => {
                    records.push(record);
                    self.position += record_size as u64;
                }
                Err(_) => break,
            }
        }

        Ok(records)
    }

    /// Positions the reader at the record with `target_lsn`, counting records
    /// from the start of the file.
    pub async fn seek_to_lsn(&mut self, target_lsn: u64) -> io::Result<()> {
        let data = tokio::fs::read(&self.path).await?;
        let mut position = 0usize;
        let mut current_lsn = 0u64;

        while position + RECORD_HEADER_SIZE <= data.len() {
            let magic = u32::from_be_bytes([
                data[position],
                data[position + 1],
                data[position + 2],
                data[position + 3],
            ]);

            if magic != WAL_MAGIC {
                break;
            }

            let data_len = u32::from_be_bytes([
                data[position + 8],
                data[position + 9],
                data[position + 10],
                data[position + 11],
            ]) as usize;

            let record_size = RECORD_HEADER_SIZE + data_len;
            current_lsn += 1;

            if current_lsn >= target_lsn {
                self.position = position as u64;
                return Ok(());
            }

            position += record_size;
        }

        Err(io::Error::new(
            io::ErrorKind::NotFound,
            format!("LSN {} not found", target_lsn),
        ))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    #[test]
    fn test_wal_record_serialization() {
        let data = Bytes::from("test data");
        let record = WalRecord::new(1, data.clone());

        let serialized = record.to_bytes();
        assert!(serialized.len() >= RECORD_HEADER_SIZE + data.len());

        let parsed = WalRecord::from_bytes(&serialized, 1).unwrap();
        assert_eq!(parsed.lsn, 1);
        assert_eq!(parsed.data, data);
    }

1029 #[test]
1030 fn test_wal_record_crc() {
1031 let data = Bytes::from("test data");
1032 let record = WalRecord::new(1, data);
1033 let mut serialized = record.to_bytes().to_vec();
1034
1035 serialized[RECORD_HEADER_SIZE] ^= 0xFF;
1037
1038 assert!(WalRecord::from_bytes(&serialized, 1).is_err());
1040 }
1041
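    // Sketch: record types should round-trip through the on-disk format, so
    // transaction markers written with `write_with_type` come back intact.
    // Assumes the initial file name `{:020}.wal` of 0, as created by
    // `recover_state` for an empty directory.
    #[tokio::test]
    async fn test_record_type_roundtrip() {
        let temp_dir = TempDir::new().unwrap();
        let config = WalConfig {
            dir: temp_dir.path().to_path_buf(),
            preallocate_size: 0,
            ..Default::default()
        };

        let wal = GroupCommitWal::new(config.clone()).await.unwrap();
        wal.write_with_type(Bytes::from("begin"), RecordType::TxnBegin)
            .await
            .unwrap();
        wal.write_with_type(Bytes::from("commit"), RecordType::TxnCommit)
            .await
            .unwrap();
        wal.shutdown().await.unwrap();

        let path = config.dir.join(format!("{:020}.wal", 0));
        let mut reader = WalReader::open(path).unwrap();
        let records = reader.read_all().await.unwrap();

        assert_eq!(records.len(), 2);
        assert_eq!(records[0].record_type, RecordType::TxnBegin);
        assert_eq!(records[1].record_type, RecordType::TxnCommit);
    }
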
1042 #[tokio::test]
1043 async fn test_group_commit_wal_single_write() {
1044 let temp_dir = TempDir::new().unwrap();
1045 let config = WalConfig {
1046 dir: temp_dir.path().to_path_buf(),
1047 group_commit_window: Duration::from_micros(100),
1048 ..Default::default()
1049 };
1050
1051 let wal = GroupCommitWal::new(config).await.unwrap();
1052
1053 let result = wal.write(Bytes::from("test data")).await.unwrap();
1054 assert_eq!(result.lsn, 1);
1055 assert!(result.size > 0);
1056
1057 let stats = wal.stats();
1058 assert_eq!(stats.writes_total, 1);
1059
1060 wal.shutdown().await.unwrap();
1061 }
1062
1063 #[tokio::test]
1064 async fn test_group_commit_wal_batch() {
1065 let temp_dir = TempDir::new().unwrap();
1066 let config = WalConfig {
1067 dir: temp_dir.path().to_path_buf(),
1068 group_commit_window: Duration::from_millis(10),
1069 ..Default::default()
1070 };
1071
1072 let wal = GroupCommitWal::new(config).await.unwrap();
1073
1074 let records: Vec<Bytes> = (0..10)
1075 .map(|i| Bytes::from(format!("record {}", i)))
1076 .collect();
1077
1078 let results = wal.write_batch(records).await.unwrap();
1079
1080 assert_eq!(results.len(), 10);
1081 for (i, result) in results.iter().enumerate() {
1082 assert_eq!(result.lsn, (i + 1) as u64);
1083 }
1084
1085 let stats = wal.stats();
1086 assert_eq!(stats.writes_total, 10);
1087
1088 wal.shutdown().await.unwrap();
1089 }
1090
1091 #[tokio::test]
1092 async fn test_group_commit_batching() {
1093 let temp_dir = TempDir::new().unwrap();
1094 let config = WalConfig {
1095 dir: temp_dir.path().to_path_buf(),
1096 group_commit_window: Duration::from_millis(50),
1097 max_pending_writes: 100,
1098 ..Default::default()
1099 };
1100
1101 let wal = Arc::new(GroupCommitWal::new(config).await.unwrap());
1102
1103 let mut handles = vec![];
1105 for i in 0..20 {
1106 let wal_clone = wal.clone();
1107 handles.push(tokio::spawn(async move {
1108 wal_clone
1109 .write(Bytes::from(format!("concurrent write {}", i)))
1110 .await
1111 }));
1112 }
1113
1114 for handle in handles {
1116 let result = handle.await.unwrap().unwrap();
1117 assert!(result.lsn > 0);
1118 }
1119
1120 let stats = wal.stats();
1121 assert_eq!(stats.writes_total, 20);
1122 assert!(stats.group_commits <= stats.writes_total);
1124
1125 wal.shutdown().await.unwrap();
1126 }
1127
1128 #[tokio::test]
1129 async fn test_wal_reader() {
1130 let temp_dir = TempDir::new().unwrap();
1131 let config = WalConfig {
1132 dir: temp_dir.path().to_path_buf(),
1133 group_commit_window: Duration::from_micros(100),
1134 sync_mode: SyncMode::Fsync,
1135 max_pending_writes: 10,
1136 ..Default::default()
1137 };
1138
1139 let wal = GroupCommitWal::new(config.clone()).await.unwrap();
1140
1141 for i in 0..5 {
1143 let result = wal
1144 .write(Bytes::from(format!("record {}", i)))
1145 .await
1146 .unwrap();
1147 assert!(result.lsn > 0, "Expected valid LSN for record {}", i);
1148 }
1149
1150 wal.sync().await.unwrap();
1152
1153 tokio::time::sleep(Duration::from_millis(100)).await;
1155
1156 wal.shutdown().await.unwrap();
1157
1158 let entries: Vec<_> = std::fs::read_dir(&config.dir)
1160 .unwrap()
1161 .filter_map(|e| e.ok())
1162 .filter(|e| e.path().extension().is_some_and(|ext| ext == "wal"))
1163 .collect();
1164
1165 assert!(!entries.is_empty(), "No WAL files found");
1166
1167 let wal_file = entries[0].path();
1168 let file_size = std::fs::metadata(&wal_file).unwrap().len();
1169 assert!(file_size > 0, "WAL file is empty");
1170
1171 let mut reader = WalReader::open(wal_file).unwrap();
1172 let records = reader.read_all().await.unwrap();
1173
1174 assert_eq!(
1175 records.len(),
1176 5,
1177 "Expected 5 records, got {} (file size: {})",
1178 records.len(),
1179 file_size
1180 );
1181 for (i, record) in records.iter().enumerate() {
1182 let expected = format!("record {}", i);
1183 assert_eq!(record.data, Bytes::from(expected));
1184 }
1185 }
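
    // Sketch: the reader should return the records written before a corrupted
    // entry and stop scanning there, mirroring the CRC handling in
    // `WalRecord::from_bytes`. Records are written with `to_bytes` directly
    // to keep the test independent of the async write path.
    #[tokio::test]
    async fn test_wal_reader_stops_at_corruption() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("00000000000000000000.wal");

        let mut bytes = Vec::new();
        for i in 1..=3u64 {
            let record = WalRecord::new(i, Bytes::from(format!("r{}", i)));
            bytes.extend_from_slice(&record.to_bytes());
        }
        // Flip the final payload byte so the last record fails its CRC check.
        let last = bytes.len() - 1;
        bytes[last] ^= 0xFF;
        std::fs::write(&path, &bytes).unwrap();

        let mut reader = WalReader::open(path).unwrap();
        let records = reader.read_all().await.unwrap();

        // Only the two intact records should be returned.
        assert_eq!(records.len(), 2);
        assert_eq!(records[0].data, Bytes::from("r1"));
        assert_eq!(records[1].data, Bytes::from("r2"));
    }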

    #[test]
    fn test_record_flags() {
        let flags = RecordFlags::COMPRESSED;
        assert!(flags.is_compressed());
        assert!(!flags.is_encrypted());

        let flags = RecordFlags(RecordFlags::COMPRESSED.0 | RecordFlags::ENCRYPTED.0);
        assert!(flags.is_compressed());
        assert!(flags.is_encrypted());
    }
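
    // Sketch of expected crash-recovery behavior: records written before a
    // shutdown are rediscovered on reopen and new LSNs continue from there.
    // Uses `preallocate_size: 0` so the file holds exactly the written bytes.
    #[tokio::test]
    async fn test_wal_recovery_resumes_lsn() {
        let temp_dir = TempDir::new().unwrap();
        let config = WalConfig {
            dir: temp_dir.path().to_path_buf(),
            preallocate_size: 0,
            ..Default::default()
        };

        let wal = GroupCommitWal::new(config.clone()).await.unwrap();
        for i in 0..3 {
            wal.write(Bytes::from(format!("record {}", i)))
                .await
                .unwrap();
        }
        wal.shutdown().await.unwrap();
        drop(wal);

        // Reopen over the same directory; recovery scans the existing file.
        let wal = GroupCommitWal::new(config).await.unwrap();
        assert_eq!(wal.current_lsn(), 3);

        let result = wal.write(Bytes::from("record 3")).await.unwrap();
        assert_eq!(result.lsn, 4);
        wal.shutdown().await.unwrap();
    }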
}