1use crate::error::{AmateRSError, ErrorContext, Result};
7use crate::types::{CipherBlob, Key};
8use std::fs::{File, OpenOptions};
9use std::io::{BufReader, BufWriter, Read, Write};
10use std::path::{Path, PathBuf};
11
/// Counters describing the outcome of a WAL recovery pass.
#[derive(Clone, Debug, Default)]
pub struct RecoveryStats {
    /// Number of entries that were read back and decoded successfully.
    pub entries_recovered: u64,
    /// Number of entries that were skipped as corrupted.
    pub entries_corrupted: u64,
    /// Total encoded size, in bytes, of the recovered entries
    /// (each entry's 4-byte length prefix included).
    pub bytes_recovered: u64,
}
22
/// Kind of operation recorded by a WAL entry.
///
/// The explicit discriminants are the byte values written to, and matched
/// against when reading from, the encoded entry.
#[derive(Clone, Debug, PartialEq)]
pub enum WalEntryType {
    /// A key/value write (encoded as `1`).
    Put = 1,
    /// A key removal (encoded as `2`).
    Delete = 2,
}
29
30#[derive(Debug, Clone, PartialEq)]
32pub struct WalEntry {
33 pub sequence: u64,
35 pub entry_type: WalEntryType,
37 pub key: Key,
39 pub value: Option<CipherBlob>,
41 pub checksum: u32,
43}
44
45impl WalEntry {
46 pub fn put(sequence: u64, key: Key, value: CipherBlob) -> Self {
48 let mut entry = Self {
49 sequence,
50 entry_type: WalEntryType::Put,
51 key,
52 value: Some(value),
53 checksum: 0,
54 };
55 entry.checksum = entry.calculate_checksum();
56 entry
57 }
58
59 pub fn delete(sequence: u64, key: Key) -> Self {
61 let mut entry = Self {
62 sequence,
63 entry_type: WalEntryType::Delete,
64 key,
65 value: None,
66 checksum: 0,
67 };
68 entry.checksum = entry.calculate_checksum();
69 entry
70 }
71
72 fn calculate_checksum(&self) -> u32 {
74 let mut hasher = crc32fast::Hasher::new();
75
76 hasher.update(&self.sequence.to_le_bytes());
78
79 hasher.update(&[self.entry_type.clone() as u8]);
81
82 hasher.update(self.key.as_bytes());
84
85 if let Some(ref value) = self.value {
87 hasher.update(value.as_bytes());
88 }
89
90 hasher.finalize()
91 }
92
93 pub fn verify_checksum(&self) -> Result<()> {
95 let calculated = self.calculate_checksum();
96 if calculated == self.checksum {
97 Ok(())
98 } else {
99 Err(AmateRSError::StorageIntegrity(ErrorContext::new(format!(
100 "WAL entry checksum mismatch: expected {}, got {}",
101 self.checksum, calculated
102 ))))
103 }
104 }
105
106 pub fn encode(&self) -> Vec<u8> {
108 let mut bytes = Vec::new();
109
110 bytes.extend_from_slice(&0x57414Cu32.to_le_bytes());
112
113 bytes.extend_from_slice(&self.sequence.to_le_bytes());
115
116 bytes.push(self.entry_type.clone() as u8);
118
119 bytes.extend_from_slice(&(self.key.len() as u32).to_le_bytes());
121 bytes.extend_from_slice(self.key.as_bytes());
122
123 if let Some(ref value) = self.value {
125 bytes.extend_from_slice(&(value.len() as u32).to_le_bytes());
126 bytes.extend_from_slice(value.as_bytes());
127 } else {
128 bytes.extend_from_slice(&0u32.to_le_bytes());
129 }
130
131 bytes.extend_from_slice(&self.checksum.to_le_bytes());
133
134 bytes
135 }
136
137 pub fn decode(bytes: &[u8]) -> Result<Self> {
139 if bytes.len() < 17 {
140 return Err(AmateRSError::SerializationError(ErrorContext::new(
142 "WAL entry too short",
143 )));
144 }
145
146 let mut offset = 0;
147
148 let magic = u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]);
150 if magic != 0x57414C {
151 return Err(AmateRSError::SerializationError(ErrorContext::new(
152 "Invalid WAL entry magic number",
153 )));
154 }
155 offset += 4;
156
157 let sequence = u64::from_le_bytes(bytes[offset..offset + 8].try_into().map_err(|_| {
159 AmateRSError::SerializationError(ErrorContext::new("Failed to read sequence"))
160 })?);
161 offset += 8;
162
163 let entry_type = match bytes[offset] {
165 1 => WalEntryType::Put,
166 2 => WalEntryType::Delete,
167 _ => {
168 return Err(AmateRSError::SerializationError(ErrorContext::new(
169 "Invalid WAL entry type",
170 )));
171 }
172 };
173 offset += 1;
174
175 let key_len = u32::from_le_bytes(bytes[offset..offset + 4].try_into().map_err(|_| {
177 AmateRSError::SerializationError(ErrorContext::new("Failed to read key length"))
178 })?) as usize;
179 offset += 4;
180
181 let key_bytes = &bytes[offset..offset + key_len];
182 let key = Key::from_slice(key_bytes);
183 offset += key_len;
184
185 let value_len = u32::from_le_bytes(bytes[offset..offset + 4].try_into().map_err(|_| {
187 AmateRSError::SerializationError(ErrorContext::new("Failed to read value length"))
188 })?) as usize;
189 offset += 4;
190
191 let value = if value_len > 0 {
192 let value_bytes = &bytes[offset..offset + value_len];
193 Some(CipherBlob::new(value_bytes.to_vec()))
194 } else {
195 None
196 };
197 offset += value_len;
198
199 let checksum = u32::from_le_bytes(bytes[offset..offset + 4].try_into().map_err(|_| {
201 AmateRSError::SerializationError(ErrorContext::new("Failed to read checksum"))
202 })?);
203
204 let entry = Self {
205 sequence,
206 entry_type,
207 key,
208 value,
209 checksum,
210 };
211
212 entry.verify_checksum()?;
214
215 Ok(entry)
216 }
217}
218
/// Settings that control where WAL segment files live and how they are managed.
#[derive(Clone, Debug)]
pub struct WalConfig {
    /// Directory that holds the WAL segment files.
    pub wal_dir: PathBuf,
    /// Size in bytes at which the current segment is rotated.
    pub max_file_size: u64,
    /// Maximum number of segment files kept when old files are cleaned up.
    pub max_wal_files: usize,
    /// When `true`, the buffered writer is flushed after every entry.
    pub sync_on_write: bool,
}
231
232impl Default for WalConfig {
233 fn default() -> Self {
234 Self {
235 wal_dir: PathBuf::from("./wal"),
236 max_file_size: 64 * 1024 * 1024, max_wal_files: 10,
238 sync_on_write: true,
239 }
240 }
241}
242
243pub struct Wal {
245 config: WalConfig,
247 current_path: PathBuf,
249 writer: BufWriter<File>,
251 sequence: u64,
253 current_file_size: u64,
255 current_file_number: u64,
257}
258
259impl Wal {
260 pub fn create(path: impl AsRef<Path>) -> Result<Self> {
262 let path = path.as_ref().to_path_buf();
263 let parent = path.parent().ok_or_else(|| {
264 AmateRSError::IoError(ErrorContext::new("WAL path has no parent directory"))
265 })?;
266
267 let config = WalConfig {
268 wal_dir: parent.to_path_buf(),
269 ..Default::default()
270 };
271
272 Self::with_config(config)
273 }
274
275 pub fn with_config(config: WalConfig) -> Result<Self> {
277 std::fs::create_dir_all(&config.wal_dir).map_err(|e| {
279 AmateRSError::IoError(ErrorContext::new(format!(
280 "Failed to create WAL directory: {}",
281 e
282 )))
283 })?;
284
285 let (file_number, sequence) = Self::find_latest_wal(&config)?;
287
288 let current_path = Self::wal_file_path(&config.wal_dir, file_number);
289
290 let file = OpenOptions::new()
291 .create(true)
292 .append(true)
293 .open(¤t_path)
294 .map_err(|e| {
295 AmateRSError::IoError(ErrorContext::new(format!("Failed to open WAL: {}", e)))
296 })?;
297
298 let current_file_size = file
299 .metadata()
300 .map_err(|e| {
301 AmateRSError::IoError(ErrorContext::new(format!(
302 "Failed to get WAL file size: {}",
303 e
304 )))
305 })?
306 .len();
307
308 Ok(Self {
309 config,
310 current_path,
311 writer: BufWriter::new(file),
312 sequence,
313 current_file_size,
314 current_file_number: file_number,
315 })
316 }
317
318 fn find_latest_wal(config: &WalConfig) -> Result<(u64, u64)> {
320 let mut max_file_number = 0u64;
321 let mut max_sequence = 0u64;
322
323 if config.wal_dir.exists() {
324 let wal_file_numbers = Self::list_wal_file_numbers(&config.wal_dir)?;
325
326 if let Some(&last) = wal_file_numbers.last() {
327 max_file_number = last;
328 }
329
330 for file_num in &wal_file_numbers {
332 let file_path = Self::wal_file_path(&config.wal_dir, *file_num);
333 if let Ok(mut reader) = WalReader::open(&file_path) {
334 loop {
335 match reader.read_entry() {
336 Ok(Some(entry)) => {
337 if entry.sequence >= max_sequence {
338 max_sequence = entry.sequence + 1;
339 }
340 }
341 Ok(None) => break,
342 Err(_) => {
343 tracing::warn!(
344 "Corrupted entry found in WAL file {} during startup",
345 file_path.display()
346 );
347 continue;
348 }
349 }
350 }
351 }
352 }
353 }
354
355 Ok((max_file_number, max_sequence))
356 }
357
358 fn wal_file_path(wal_dir: &Path, file_number: u64) -> PathBuf {
360 wal_dir.join(format!("wal_{:08}.log", file_number))
361 }
362
363 fn list_wal_file_numbers(wal_dir: &Path) -> Result<Vec<u64>> {
365 let entries = std::fs::read_dir(wal_dir).map_err(|e| {
366 AmateRSError::IoError(ErrorContext::new(format!(
367 "Failed to read WAL directory: {}",
368 e
369 )))
370 })?;
371
372 let mut numbers = Vec::new();
373 for entry in entries {
374 let entry = entry.map_err(|e| {
375 AmateRSError::IoError(ErrorContext::new(format!(
376 "Failed to read directory entry: {}",
377 e
378 )))
379 })?;
380 let file_name = entry.file_name();
381 let name = file_name.to_string_lossy();
382 if name.starts_with("wal_") && name.ends_with(".log") {
383 if let Ok(number) = name[4..name.len() - 4].parse::<u64>() {
384 numbers.push(number);
385 }
386 }
387 }
388 numbers.sort_unstable();
389 Ok(numbers)
390 }
391
392 pub fn put(&mut self, key: Key, value: CipherBlob) -> Result<u64> {
394 let sequence = self.sequence;
395 self.sequence += 1;
396
397 let entry = WalEntry::put(sequence, key, value);
398 self.write_entry(&entry)?;
399
400 Ok(sequence)
401 }
402
403 pub fn delete(&mut self, key: Key) -> Result<u64> {
405 let sequence = self.sequence;
406 self.sequence += 1;
407
408 let entry = WalEntry::delete(sequence, key);
409 self.write_entry(&entry)?;
410
411 Ok(sequence)
412 }
413
414 fn write_entry(&mut self, entry: &WalEntry) -> Result<()> {
416 let bytes = entry.encode();
417
418 let len = bytes.len() as u32;
420 self.writer.write_all(&len.to_le_bytes()).map_err(|e| {
421 AmateRSError::IoError(ErrorContext::new(format!(
422 "Failed to write WAL entry: {}",
423 e
424 )))
425 })?;
426
427 self.writer.write_all(&bytes).map_err(|e| {
429 AmateRSError::IoError(ErrorContext::new(format!(
430 "Failed to write WAL entry: {}",
431 e
432 )))
433 })?;
434
435 let entry_size = (4 + bytes.len()) as u64; self.current_file_size += entry_size;
438
439 if self.config.sync_on_write {
441 self.writer.flush().map_err(|e| {
442 AmateRSError::IoError(ErrorContext::new(format!("Failed to flush WAL: {}", e)))
443 })?;
444 }
445
446 if self.current_file_size >= self.config.max_file_size {
448 self.rotate()?;
449 }
450
451 Ok(())
452 }
453
454 pub fn rotate(&mut self) -> Result<()> {
456 self.flush()?;
458
459 self.current_file_number += 1;
461
462 let new_path = Self::wal_file_path(&self.config.wal_dir, self.current_file_number);
464
465 let file = OpenOptions::new()
466 .create(true)
467 .append(true)
468 .open(&new_path)
469 .map_err(|e| {
470 AmateRSError::IoError(ErrorContext::new(format!(
471 "Failed to create new WAL file: {}",
472 e
473 )))
474 })?;
475
476 self.current_path = new_path;
477 self.writer = BufWriter::new(file);
478 self.current_file_size = 0;
479
480 self.cleanup_old_wal_files()?;
482
483 Ok(())
484 }
485
486 fn cleanup_old_wal_files(&self) -> Result<()> {
488 let wal_files = Self::list_wal_file_numbers(&self.config.wal_dir)?;
489
490 if wal_files.len() > self.config.max_wal_files {
491 let files_to_delete = wal_files.len() - self.config.max_wal_files;
492
493 for &file_number in wal_files.iter().take(files_to_delete) {
494 let file_path = Self::wal_file_path(&self.config.wal_dir, file_number);
495 std::fs::remove_file(&file_path).map_err(|e| {
496 AmateRSError::IoError(ErrorContext::new(format!(
497 "Failed to delete old WAL file: {}",
498 e
499 )))
500 })?;
501 }
502 }
503
504 Ok(())
505 }
506
507 pub fn cleanup(&self) -> Result<()> {
509 self.cleanup_old_wal_files()
510 }
511
512 pub fn current_file_size(&self) -> u64 {
514 self.current_file_size
515 }
516
517 pub fn current_file_number(&self) -> u64 {
519 self.current_file_number
520 }
521
522 pub fn flush(&mut self) -> Result<()> {
524 self.writer.flush().map_err(|e| {
525 AmateRSError::IoError(ErrorContext::new(format!("Failed to flush WAL: {}", e)))
526 })?;
527
528 self.writer.get_ref().sync_all().map_err(|e| {
529 AmateRSError::IoError(ErrorContext::new(format!("Failed to sync WAL: {}", e)))
530 })?;
531
532 Ok(())
533 }
534
535 pub fn sequence(&self) -> u64 {
537 self.sequence
538 }
539
540 pub fn path(&self) -> &Path {
542 &self.current_path
543 }
544
545 pub fn recover(wal_dir: impl AsRef<Path>) -> Result<(Vec<WalEntry>, u64)> {
553 let wal_dir = wal_dir.as_ref();
554
555 if !wal_dir.exists() {
556 return Ok((Vec::new(), 0));
557 }
558
559 let wal_files = Self::list_wal_file_numbers(wal_dir)?;
560
561 let mut all_entries = Vec::new();
562 let mut max_sequence = 0u64;
563
564 for file_number in wal_files {
565 let file_path = Self::wal_file_path(wal_dir, file_number);
566 let mut reader = WalReader::open(&file_path)?;
567
568 loop {
569 match reader.read_entry() {
570 Ok(Some(entry)) => {
571 if entry.sequence > max_sequence {
572 max_sequence = entry.sequence;
573 }
574 all_entries.push(entry);
575 }
576 Ok(None) => break,
577 Err(e) => {
578 tracing::warn!(
579 "Skipping corrupted entry in {}: {}",
580 file_path.display(),
581 e
582 );
583 continue;
584 }
585 }
586 }
587 }
588
589 Ok((all_entries, max_sequence))
590 }
591
592 pub fn current_size(&self) -> u64 {
594 self.current_file_size
595 }
596
597 pub fn total_wal_size(&self) -> Result<u64> {
599 let wal_files = Self::list_wal_file_numbers(&self.config.wal_dir)?;
600 let mut total_size = 0u64;
601
602 for file_number in wal_files {
603 let file_path = Self::wal_file_path(&self.config.wal_dir, file_number);
604 let metadata = std::fs::metadata(&file_path).map_err(|e| {
605 AmateRSError::IoError(ErrorContext::new(format!(
606 "Failed to read WAL file metadata: {}",
607 e
608 )))
609 })?;
610 total_size += metadata.len();
611 }
612
613 Ok(total_size)
614 }
615
616 pub fn truncate_before(&mut self, sequence: u64) -> Result<u64> {
623 self.flush()?;
624
625 let all_files = Self::list_wal_file_numbers(&self.config.wal_dir)?;
626 let wal_files: Vec<u64> = all_files
628 .into_iter()
629 .filter(|&n| n != self.current_file_number)
630 .collect();
631
632 let mut files_truncated = 0u64;
633
634 for file_number in wal_files {
635 let file_path = Self::wal_file_path(&self.config.wal_dir, file_number);
636
637 let mut file_max_seq = 0u64;
639 if let Ok(mut reader) = WalReader::open(&file_path) {
640 loop {
641 match reader.read_entry() {
642 Ok(Some(entry)) => {
643 if entry.sequence > file_max_seq {
644 file_max_seq = entry.sequence;
645 }
646 }
647 Ok(None) => break,
648 Err(_) => continue,
649 }
650 }
651 }
652
653 if file_max_seq <= sequence {
655 std::fs::remove_file(&file_path).map_err(|e| {
656 AmateRSError::IoError(ErrorContext::new(format!(
657 "Failed to remove WAL file {}: {}",
658 file_path.display(),
659 e
660 )))
661 })?;
662 files_truncated += 1;
663 }
664 }
665
666 Ok(files_truncated)
667 }
668
669 pub fn recover_with_stats(
674 wal_dir: impl AsRef<Path>,
675 ) -> Result<(Vec<WalEntry>, u64, RecoveryStats)> {
676 let wal_dir = wal_dir.as_ref();
677 let mut stats = RecoveryStats::default();
678
679 if !wal_dir.exists() {
680 return Ok((Vec::new(), 0, stats));
681 }
682
683 let wal_files = Self::list_wal_file_numbers(wal_dir)?;
684
685 let mut all_entries = Vec::new();
686 let mut max_sequence = 0u64;
687
688 for file_number in wal_files {
689 let file_path = Self::wal_file_path(wal_dir, file_number);
690 let mut reader = WalReader::open(&file_path)?;
691
692 loop {
693 match reader.read_entry() {
694 Ok(Some(entry)) => {
695 let entry_bytes = entry.encode().len() as u64 + 4; stats.bytes_recovered += entry_bytes;
697 stats.entries_recovered += 1;
698 if entry.sequence > max_sequence {
699 max_sequence = entry.sequence;
700 }
701 all_entries.push(entry);
702 }
703 Ok(None) => break,
704 Err(e) => {
705 stats.entries_corrupted += 1;
706 tracing::warn!(
707 "Skipping corrupted entry in {}: {}",
708 file_path.display(),
709 e
710 );
711 continue;
712 }
713 }
714 }
715 }
716
717 Ok((all_entries, max_sequence, stats))
718 }
719
720 pub fn replay_to_memtable(
727 wal_dir: impl AsRef<Path>,
728 memtable: &crate::storage::memtable::Memtable,
729 ) -> Result<u64> {
730 let (entries, max_sequence) = Self::recover(wal_dir)?;
731
732 for entry in entries {
733 match entry.entry_type {
734 WalEntryType::Put => {
735 if let Some(value) = entry.value {
736 memtable.put(entry.key, value)?;
737 }
738 }
739 WalEntryType::Delete => {
740 memtable.delete(entry.key)?;
741 }
742 }
743 }
744
745 Ok(max_sequence)
746 }
747}
748
/// Sequential reader over a single WAL segment file.
pub struct WalReader {
    /// Buffered handle to the segment being read.
    reader: BufReader<File>,
}
753
754impl WalReader {
755 pub fn open(path: impl AsRef<Path>) -> Result<Self> {
757 let file = File::open(path.as_ref()).map_err(|e| {
758 AmateRSError::IoError(ErrorContext::new(format!("Failed to open WAL file: {}", e)))
759 })?;
760
761 Ok(Self {
762 reader: BufReader::new(file),
763 })
764 }
765
766 pub fn read_entry(&mut self) -> Result<Option<WalEntry>> {
773 let mut len_bytes = [0u8; 4];
775 match self.reader.read_exact(&mut len_bytes) {
776 Ok(()) => {}
777 Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
778 return Ok(None);
780 }
781 Err(e) => {
782 return Err(AmateRSError::IoError(ErrorContext::new(format!(
783 "Failed to read WAL entry length: {}",
784 e
785 ))));
786 }
787 }
788
789 let len = u32::from_le_bytes(len_bytes) as usize;
790
791 if len > 100 * 1024 * 1024 {
793 return Err(AmateRSError::SerializationError(ErrorContext::new(
794 format!("WAL entry too large: {} bytes", len),
795 )));
796 }
797
798 let mut entry_bytes = vec![0u8; len];
800 match self.reader.read_exact(&mut entry_bytes) {
801 Ok(()) => {}
802 Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
803 return Err(AmateRSError::SerializationError(ErrorContext::new(
805 "Incomplete WAL entry (truncated file)",
806 )));
807 }
808 Err(e) => {
809 return Err(AmateRSError::IoError(ErrorContext::new(format!(
810 "Failed to read WAL entry: {}",
811 e
812 ))));
813 }
814 }
815
816 let entry = WalEntry::decode(&entry_bytes)?;
818
819 Ok(Some(entry))
820 }
821}
822
// Unit tests live in the sibling file `wal_tests.rs` and are compiled only
// for test builds.
#[cfg(test)]
#[path = "wal_tests.rs"]
mod tests;