// agentic_memory/v3/recovery.rs

use std::fs::{File, OpenOptions};
use std::io::{BufReader, Read, Seek, SeekFrom, Write};
use std::path::{Path, PathBuf};

use super::block::Block;

/// Append-only, checksummed write-ahead log stored as `memory.wal` inside its directory.
#[derive(Debug)]
pub struct WriteAheadLog {
    /// Location of the log file; reopened read-only by the recovery scans.
    path: PathBuf,
    /// Read/write handle used for appending entries and truncating the log.
    file: File,
    /// Sequence number the next appended entry will carry.
    sequence: u64,
}

15impl WriteAheadLog {
16 pub fn open(dir: &Path) -> Result<Self, std::io::Error> {
18 std::fs::create_dir_all(dir)?;
19 let path = dir.join("memory.wal");
20
21 let file = OpenOptions::new()
22 .read(true)
23 .write(true)
24 .create(true)
25 .truncate(false)
26 .open(&path)?;
27
28 let mut wal = Self {
29 path,
30 file,
31 sequence: 0,
32 };
33
34 wal.recover_sequence()?;
36
37 Ok(wal)
38 }
39
40 pub fn write(&mut self, block: &Block) -> Result<(), std::io::Error> {
42 let data = serde_json::to_vec(block)?;
43
44 let checksum = crc32fast::hash(&data);
46
47 self.file.seek(SeekFrom::End(0))?;
48 self.file.write_all(&self.sequence.to_le_bytes())?;
49 self.file.write_all(&(data.len() as u32).to_le_bytes())?;
50 self.file.write_all(&data)?;
51 self.file.write_all(&checksum.to_le_bytes())?;
52 self.file.sync_all()?;
53
54 self.sequence += 1;
55 Ok(())
56 }
57
58 pub fn commit(&mut self, _sequence: u64) -> Result<(), std::io::Error> {
60 Ok(())
63 }
64
65 pub fn recover(&self) -> Result<Vec<Block>, std::io::Error> {
67 let mut entries = Vec::new();
68 let mut skipped = 0u32;
69
70 let file = match OpenOptions::new().read(true).open(&self.path) {
71 Ok(f) => f,
72 Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Ok(vec![]),
73 Err(e) => return Err(e),
74 };
75
76 let file_len = file.metadata()?.len();
77 if file_len == 0 {
78 return Ok(vec![]);
79 }
80
81 let mut reader = BufReader::new(file);
82
83 loop {
84 let pos = reader.stream_position().unwrap_or(file_len);
85 if pos >= file_len {
86 break;
87 }
88
89 let mut seq_buf = [0u8; 8];
91 match reader.read_exact(&mut seq_buf) {
92 Ok(_) => {}
93 Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => break,
94 Err(e) => return Err(e),
95 }
96
97 let mut len_buf = [0u8; 4];
99 match reader.read_exact(&mut len_buf) {
100 Ok(_) => {}
101 Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => break,
102 Err(e) => return Err(e),
103 }
104 let len = u32::from_le_bytes(len_buf) as usize;
105
106 if len > 100 * 1024 * 1024 {
108 log::warn!(
109 "WAL entry at position {} has unreasonable length {}, skipping",
110 pos,
111 len
112 );
113 skipped += 1;
114 if self.try_skip_to_next_entry(&mut reader, file_len).is_err() {
116 break;
117 }
118 continue;
119 }
120
121 let mut data = vec![0u8; len];
123 match reader.read_exact(&mut data) {
124 Ok(_) => {}
125 Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => break,
126 Err(e) => return Err(e),
127 }
128
129 let mut checksum_buf = [0u8; 4];
131 match reader.read_exact(&mut checksum_buf) {
132 Ok(_) => {}
133 Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => break,
134 Err(e) => return Err(e),
135 }
136 let stored_checksum = u32::from_le_bytes(checksum_buf);
137 let computed_checksum = crc32fast::hash(&data);
138
139 if stored_checksum == computed_checksum {
140 if let Ok(block) = serde_json::from_slice::<Block>(&data) {
141 if block.verify() {
142 entries.push(block);
143 } else {
144 log::warn!(
145 "WAL entry at position {} failed block verification, skipping",
146 pos
147 );
148 skipped += 1;
149 }
150 } else {
151 log::warn!(
152 "WAL entry at position {} failed deserialization, skipping",
153 pos
154 );
155 skipped += 1;
156 }
157 } else {
158 log::warn!(
159 "WAL checksum mismatch at position {} (stored={:#x}, computed={:#x}), skipping",
160 pos,
161 stored_checksum,
162 computed_checksum
163 );
164 skipped += 1;
165 }
166 }
167
168 if skipped > 0 {
169 log::warn!(
170 "WAL recovery skipped {} corrupt entries, recovered {}",
171 skipped,
172 entries.len()
173 );
174 }
175
176 Ok(entries)
177 }
178
179 fn try_skip_to_next_entry(
181 &self,
182 reader: &mut BufReader<File>,
183 file_len: u64,
184 ) -> Result<(), std::io::Error> {
185 let mut byte = [0u8; 1];
187 let scan_limit = 1024; let mut scanned = 0;
189
190 while scanned < scan_limit {
191 let pos = reader.stream_position()?;
192 if pos + 16 >= file_len {
193 return Err(std::io::Error::new(
194 std::io::ErrorKind::UnexpectedEof,
195 "End of WAL",
196 ));
197 }
198
199 match reader.read_exact(&mut byte) {
200 Ok(_) => scanned += 1,
201 Err(_) => {
202 return Err(std::io::Error::new(
203 std::io::ErrorKind::UnexpectedEof,
204 "End of WAL",
205 ))
206 }
207 }
208
209 let current_pos = reader.stream_position()?;
211 if current_pos + 12 < file_len {
212 let mut peek_seq = [0u8; 8];
214 let mut peek_len = [0u8; 4];
215 if reader.read_exact(&mut peek_seq).is_ok()
216 && reader.read_exact(&mut peek_len).is_ok()
217 {
218 let seq = u64::from_le_bytes(peek_seq);
219 let len = u32::from_le_bytes(peek_len) as usize;
220 if seq < 1_000_000_000 && len > 0 && len < 100 * 1024 * 1024 {
222 reader.seek(SeekFrom::Start(current_pos.saturating_sub(1)))?;
226 return Ok(());
227 }
228 }
229 reader.seek(SeekFrom::Start(current_pos))?;
231 }
232 }
233
234 Err(std::io::Error::other("Could not find next valid WAL entry"))
235 }
236
237 pub fn clear(&mut self) -> Result<(), std::io::Error> {
239 self.file.set_len(0)?;
240 self.file.seek(SeekFrom::Start(0))?;
241 self.sequence = 0;
242 Ok(())
243 }
244
245 fn recover_sequence(&mut self) -> Result<(), std::io::Error> {
246 let metadata = self.file.metadata()?;
247 if metadata.len() == 0 {
248 return Ok(());
249 }
250
251 let file = OpenOptions::new().read(true).open(&self.path)?;
252 let mut reader = BufReader::new(file);
253 let mut max_seq = 0u64;
254
255 loop {
256 let mut seq_buf = [0u8; 8];
257 match reader.read_exact(&mut seq_buf) {
258 Ok(_) => {
259 let seq = u64::from_le_bytes(seq_buf);
260 if seq > 1_000_000_000 {
262 break;
263 }
264 max_seq = max_seq.max(seq);
265
266 let mut len_buf = [0u8; 4];
268 if reader.read_exact(&mut len_buf).is_err() {
269 break;
270 }
271 let len = u32::from_le_bytes(len_buf) as usize;
272
273 if len > 100 * 1024 * 1024 {
275 break;
276 }
277
278 let mut skip = vec![0u8; len + 4]; if reader.read_exact(&mut skip).is_err() {
280 break;
281 }
282 }
283 Err(_) => break,
284 }
285 }
286
287 self.sequence = if metadata.len() > 0 {
288 max_seq.saturating_add(1)
289 } else {
290 0
291 };
292 Ok(())
293 }
294}
295
296pub struct RecoveryManager {
298 wal: WriteAheadLog,
299}
300
301impl RecoveryManager {
302 pub fn new(data_dir: &Path) -> Result<Self, std::io::Error> {
303 Ok(Self {
304 wal: WriteAheadLog::open(data_dir)?,
305 })
306 }
307
308 pub fn pre_write(&mut self, block: &Block) -> Result<(), std::io::Error> {
310 self.wal.write(block)
311 }
312
313 pub fn post_write(&mut self, sequence: u64) -> Result<(), std::io::Error> {
315 self.wal.commit(sequence)
316 }
317
318 pub fn recover(&self) -> Result<Vec<Block>, std::io::Error> {
320 self.wal.recover()
321 }
322
323 pub fn checkpoint(&mut self) -> Result<(), std::io::Error> {
325 self.wal.clear()
326 }
327}