Skip to main content

agentic_memory/v3/
recovery.rs

1//! Crash recovery and write-ahead logging.
2
3use super::block::Block;
4use std::fs::{File, OpenOptions};
5use std::io::{BufReader, Read, Seek, SeekFrom, Write};
6use std::path::{Path, PathBuf};
7
/// Append-only write-ahead log used for crash recovery.
///
/// Entries are appended to a single file and can later be read back with
/// `recover` to replay writes after a crash.
pub struct WriteAheadLog {
    /// Location of the WAL file on disk (re-opened read-only when scanning).
    path: PathBuf,
    /// Read/write handle used for appending and truncating.
    file: File,
    /// Sequence number that will be stamped on the next appended entry.
    sequence: u64,
}
14
15impl WriteAheadLog {
16    /// Open or create WAL
17    pub fn open(dir: &Path) -> Result<Self, std::io::Error> {
18        std::fs::create_dir_all(dir)?;
19        let path = dir.join("memory.wal");
20
21        let file = OpenOptions::new()
22            .read(true)
23            .write(true)
24            .create(true)
25            .truncate(false)
26            .open(&path)?;
27
28        let mut wal = Self {
29            path,
30            file,
31            sequence: 0,
32        };
33
34        // Recover sequence number
35        wal.recover_sequence()?;
36
37        Ok(wal)
38    }
39
40    /// Write entry to WAL (before main log)
41    pub fn write(&mut self, block: &Block) -> Result<(), std::io::Error> {
42        let data = serde_json::to_vec(block)?;
43
44        // Write: sequence (8) + length (4) + data + checksum (4)
45        let checksum = crc32fast::hash(&data);
46
47        self.file.seek(SeekFrom::End(0))?;
48        self.file.write_all(&self.sequence.to_le_bytes())?;
49        self.file.write_all(&(data.len() as u32).to_le_bytes())?;
50        self.file.write_all(&data)?;
51        self.file.write_all(&checksum.to_le_bytes())?;
52        self.file.sync_all()?;
53
54        self.sequence += 1;
55        Ok(())
56    }
57
58    /// Mark entry as committed (can be garbage collected)
59    pub fn commit(&mut self, _sequence: u64) -> Result<(), std::io::Error> {
60        // In a full implementation, we'd track committed entries
61        // and periodically truncate the WAL
62        Ok(())
63    }
64
65    /// Recover uncommitted entries after crash (handles WAL corruption gracefully)
66    pub fn recover(&self) -> Result<Vec<Block>, std::io::Error> {
67        let mut entries = Vec::new();
68        let mut skipped = 0u32;
69
70        let file = match OpenOptions::new().read(true).open(&self.path) {
71            Ok(f) => f,
72            Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Ok(vec![]),
73            Err(e) => return Err(e),
74        };
75
76        let file_len = file.metadata()?.len();
77        if file_len == 0 {
78            return Ok(vec![]);
79        }
80
81        let mut reader = BufReader::new(file);
82
83        loop {
84            let pos = reader.stream_position().unwrap_or(file_len);
85            if pos >= file_len {
86                break;
87            }
88
89            // Read sequence
90            let mut seq_buf = [0u8; 8];
91            match reader.read_exact(&mut seq_buf) {
92                Ok(_) => {}
93                Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => break,
94                Err(e) => return Err(e),
95            }
96
97            // Read length
98            let mut len_buf = [0u8; 4];
99            match reader.read_exact(&mut len_buf) {
100                Ok(_) => {}
101                Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => break,
102                Err(e) => return Err(e),
103            }
104            let len = u32::from_le_bytes(len_buf) as usize;
105
106            // Sanity check length (max 100MB per entry)
107            if len > 100 * 1024 * 1024 {
108                log::warn!(
109                    "WAL entry at position {} has unreasonable length {}, skipping",
110                    pos,
111                    len
112                );
113                skipped += 1;
114                // Try to find next valid entry by scanning forward
115                if self.try_skip_to_next_entry(&mut reader, file_len).is_err() {
116                    break;
117                }
118                continue;
119            }
120
121            // Read data
122            let mut data = vec![0u8; len];
123            match reader.read_exact(&mut data) {
124                Ok(_) => {}
125                Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => break,
126                Err(e) => return Err(e),
127            }
128
129            // Read and verify checksum
130            let mut checksum_buf = [0u8; 4];
131            match reader.read_exact(&mut checksum_buf) {
132                Ok(_) => {}
133                Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => break,
134                Err(e) => return Err(e),
135            }
136            let stored_checksum = u32::from_le_bytes(checksum_buf);
137            let computed_checksum = crc32fast::hash(&data);
138
139            if stored_checksum == computed_checksum {
140                if let Ok(block) = serde_json::from_slice::<Block>(&data) {
141                    if block.verify() {
142                        entries.push(block);
143                    } else {
144                        log::warn!(
145                            "WAL entry at position {} failed block verification, skipping",
146                            pos
147                        );
148                        skipped += 1;
149                    }
150                } else {
151                    log::warn!(
152                        "WAL entry at position {} failed deserialization, skipping",
153                        pos
154                    );
155                    skipped += 1;
156                }
157            } else {
158                log::warn!(
159                    "WAL checksum mismatch at position {} (stored={:#x}, computed={:#x}), skipping",
160                    pos,
161                    stored_checksum,
162                    computed_checksum
163                );
164                skipped += 1;
165            }
166        }
167
168        if skipped > 0 {
169            log::warn!(
170                "WAL recovery skipped {} corrupt entries, recovered {}",
171                skipped,
172                entries.len()
173            );
174        }
175
176        Ok(entries)
177    }
178
179    /// Try to find next valid WAL entry after corruption
180    fn try_skip_to_next_entry(
181        &self,
182        reader: &mut BufReader<File>,
183        file_len: u64,
184    ) -> Result<(), std::io::Error> {
185        // Scan byte-by-byte looking for a reasonable sequence number
186        let mut byte = [0u8; 1];
187        let scan_limit = 1024; // Don't scan more than 1KB ahead
188        let mut scanned = 0;
189
190        while scanned < scan_limit {
191            let pos = reader.stream_position()?;
192            if pos + 16 >= file_len {
193                return Err(std::io::Error::new(
194                    std::io::ErrorKind::UnexpectedEof,
195                    "End of WAL",
196                ));
197            }
198
199            match reader.read_exact(&mut byte) {
200                Ok(_) => scanned += 1,
201                Err(_) => {
202                    return Err(std::io::Error::new(
203                        std::io::ErrorKind::UnexpectedEof,
204                        "End of WAL",
205                    ))
206                }
207            }
208
209            // Try reading a sequence number at current position
210            let current_pos = reader.stream_position()?;
211            if current_pos + 12 < file_len {
212                // Peek at potential entry
213                let mut peek_seq = [0u8; 8];
214                let mut peek_len = [0u8; 4];
215                if reader.read_exact(&mut peek_seq).is_ok()
216                    && reader.read_exact(&mut peek_len).is_ok()
217                {
218                    let seq = u64::from_le_bytes(peek_seq);
219                    let len = u32::from_le_bytes(peek_len) as usize;
220                    // Reasonable sequence number and length?
221                    if seq < 1_000_000_000 && len > 0 && len < 100 * 1024 * 1024 {
222                        // Seek back to start of this potential entry
223                        // (current_pos is always >= 1 here since we read a byte,
224                        // but use saturating_sub as safety)
225                        reader.seek(SeekFrom::Start(current_pos.saturating_sub(1)))?;
226                        return Ok(());
227                    }
228                }
229                // Seek back to continue scanning
230                reader.seek(SeekFrom::Start(current_pos))?;
231            }
232        }
233
234        Err(std::io::Error::other("Could not find next valid WAL entry"))
235    }
236
237    /// Clear WAL after successful checkpoint
238    pub fn clear(&mut self) -> Result<(), std::io::Error> {
239        self.file.set_len(0)?;
240        self.file.seek(SeekFrom::Start(0))?;
241        self.sequence = 0;
242        Ok(())
243    }
244
245    fn recover_sequence(&mut self) -> Result<(), std::io::Error> {
246        let metadata = self.file.metadata()?;
247        if metadata.len() == 0 {
248            return Ok(());
249        }
250
251        let file = OpenOptions::new().read(true).open(&self.path)?;
252        let mut reader = BufReader::new(file);
253        let mut max_seq = 0u64;
254
255        loop {
256            let mut seq_buf = [0u8; 8];
257            match reader.read_exact(&mut seq_buf) {
258                Ok(_) => {
259                    let seq = u64::from_le_bytes(seq_buf);
260                    // Sanity check: reject unreasonable sequence numbers
261                    if seq > 1_000_000_000 {
262                        break;
263                    }
264                    max_seq = max_seq.max(seq);
265
266                    // Skip rest of entry
267                    let mut len_buf = [0u8; 4];
268                    if reader.read_exact(&mut len_buf).is_err() {
269                        break;
270                    }
271                    let len = u32::from_le_bytes(len_buf) as usize;
272
273                    // Sanity check: reject unreasonable lengths (max 100MB per entry)
274                    if len > 100 * 1024 * 1024 {
275                        break;
276                    }
277
278                    let mut skip = vec![0u8; len + 4]; // data + checksum
279                    if reader.read_exact(&mut skip).is_err() {
280                        break;
281                    }
282                }
283                Err(_) => break,
284            }
285        }
286
287        self.sequence = if metadata.len() > 0 {
288            max_seq.saturating_add(1)
289        } else {
290            0
291        };
292        Ok(())
293    }
294}
295
296/// Recovery manager wrapping WAL
297pub struct RecoveryManager {
298    wal: WriteAheadLog,
299}
300
301impl RecoveryManager {
302    pub fn new(data_dir: &Path) -> Result<Self, std::io::Error> {
303        Ok(Self {
304            wal: WriteAheadLog::open(data_dir)?,
305        })
306    }
307
308    /// Call before writing to main log
309    pub fn pre_write(&mut self, block: &Block) -> Result<(), std::io::Error> {
310        self.wal.write(block)
311    }
312
313    /// Call after successful write to main log
314    pub fn post_write(&mut self, sequence: u64) -> Result<(), std::io::Error> {
315        self.wal.commit(sequence)
316    }
317
318    /// Recover any uncommitted writes
319    pub fn recover(&self) -> Result<Vec<Block>, std::io::Error> {
320        self.wal.recover()
321    }
322
323    /// Checkpoint: clear WAL
324    pub fn checkpoint(&mut self) -> Result<(), std::io::Error> {
325        self.wal.clear()
326    }
327}