1use std::sync::Arc;
2
3use kaya_core::{KayaError, Lsn, Result, WalConfig};
4use kaya_io::{Disk, RelativePath};
5
6use crate::codec::DecodeRecordResult;
7use crate::writer::parse_segment_id;
8use crate::{decode_record, SegmentId, WalRecord, WalWarning};
9
/// A single WAL record accepted during recovery, together with where it was found.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RecoveredRecord {
    /// Segment the record was decoded from.
    pub segment_id: SegmentId,
    /// Byte offset of the start of the record within its segment file.
    pub offset: u64,
    /// The decoded record itself.
    pub record: WalRecord,
}
16
/// Outcome of scanning the WAL directory with `recover_wal`.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct WalRecoveryReport {
    /// Accepted records, in segment-id order and then in file order within each segment.
    pub records: Vec<RecoveredRecord>,
    /// LSN of the last accepted record, or `None` if no record was accepted.
    pub last_lsn: Option<Lsn>,
    /// Total bytes occupied by accepted records, summed over all scanned segments.
    pub valid_bytes: u64,
    /// Bytes removed from the tail of the segment where scanning stopped.
    pub truncated_bytes: u64,
    /// Non-fatal problems encountered during recovery (bad records, LSN gaps,
    /// truncation, ignored trailing segments).
    pub warnings: Vec<WalWarning>,
}
25
26pub async fn recover_wal<D: Disk>(config: WalConfig, disk: Arc<D>) -> Result<WalRecoveryReport> {
27 let wal_dir = RelativePath::new("wal")?;
28 let mut segments = disk
29 .list_dir(&wal_dir)
30 .await?
31 .into_iter()
32 .filter_map(|entry| parse_segment_id(&entry.path).map(|id| (SegmentId(id), entry.path)))
33 .collect::<Vec<_>>();
34 segments.sort_by_key(|(segment_id, _)| segment_id.0);
35
36 let mut report = WalRecoveryReport::default();
37 let mut expected_lsn = None::<u64>;
38 let mut stop_index = None::<usize>;
39
40 for (index, (segment_id, path)) in segments.iter().enumerate() {
41 let len = match disk.file_len(path).await {
42 Ok(len) => len,
43 Err(KayaError::NotFound) => 0,
44 Err(error) => return Err(error),
45 };
46 if len == 0 {
47 continue;
48 }
49
50 let mut bytes = vec![0_u8; len as usize];
51 let read = disk.read_at(path, 0, &mut bytes).await?;
52 bytes.truncate(read);
53
54 let mut offset = 0_usize;
55 while offset < bytes.len() {
56 match decode_record(&bytes[offset..], offset as u64, config.max_record_bytes) {
57 DecodeRecordResult::Complete { record, bytes_read } => {
58 if let Some(expected) = expected_lsn {
59 if record.lsn.get() != expected {
60 report.warnings.push(WalWarning::NonMonotonicLsn {
61 offset: offset as u64,
62 expected,
63 found: record.lsn.get(),
64 });
65 stop_index = Some(index);
66 break;
67 }
68 }
69 expected_lsn = Some(record.lsn.next().get());
70 report.last_lsn = Some(record.lsn);
71 report.records.push(RecoveredRecord {
72 segment_id: *segment_id,
73 offset: offset as u64,
74 record,
75 });
76 offset += bytes_read;
77 }
78 DecodeRecordResult::Incomplete { warning }
79 | DecodeRecordResult::Invalid { warning } => {
80 report.warnings.push(warning);
81 stop_index = Some(index);
82 break;
83 }
84 }
85 }
86
87 report.valid_bytes += offset as u64;
88
89 if stop_index == Some(index) {
90 let valid_len = offset as u64;
91 if len > valid_len {
92 disk.truncate(path, valid_len).await?;
93 let truncated = len - valid_len;
94 report.truncated_bytes += truncated;
95 report.warnings.push(WalWarning::TailTruncated {
96 path: path.as_str().to_owned(),
97 valid_len,
98 truncated_bytes: truncated,
99 });
100 }
101 break;
102 }
103 }
104
105 if let Some(index) = stop_index {
106 let ignored = segments.len().saturating_sub(index + 1);
107 if ignored > 0 {
108 report
109 .warnings
110 .push(WalWarning::TrailingSegmentsIgnored { count: ignored });
111 }
112 }
113
114 Ok(report)
115}