Skip to main content

apfsds_storage/
wal.rs

1use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
2use crc32fast::Hasher;
3use std::fs::{File, OpenOptions};
4use std::io::{self, Read, Seek, SeekFrom, Write};
5use std::path::{Path, PathBuf};
6use std::sync::{Arc, Mutex};
7
/// Byte length of the header that precedes every WAL payload: a big-endian CRC32
/// checksum (`u32`) followed by the big-endian payload length (`u64`).
// Not referenced elsewhere in this file; kept to document the on-disk record layout.
#[allow(dead_code)]
const HEADER_SIZE: usize = std::mem::size_of::<u32>() + std::mem::size_of::<u64>();
11
/// Write-ahead log persisted to a single file on disk.
///
/// The file handle lives behind an `Arc<Mutex<_>>`; every operation takes the lock for
/// its duration, so concurrent callers are serialized.
#[derive(Debug)]
pub struct Wal {
    /// Location of the log file on disk.
    // Not read by any method in this file; retained for diagnostics (e.g. `Debug` output).
    #[allow(dead_code)]
    path: PathBuf,
    /// Open read/write handle to the log file, shared behind a mutex.
    file: Arc<Mutex<File>>,
}
18
19impl Wal {
20    /// Open or create a WAL file
21    pub fn open(path: impl AsRef<Path>) -> io::Result<Self> {
22        let path = path.as_ref().to_path_buf();
23        if let Some(parent) = path.parent() {
24            std::fs::create_dir_all(parent)?;
25        }
26
27        let file = OpenOptions::new()
28            .read(true)
29            .write(true)
30            .create(true)
31            .open(&path)?;
32
33        Ok(Self {
34            path,
35            file: Arc::new(Mutex::new(file)),
36        })
37    }
38
39    /// Append an entry to the WAL
40    pub fn append(&self, data: &[u8]) -> io::Result<u64> {
41        let mut file = self.file.lock().unwrap();
42
43        // Calculate CRC32
44        let mut hasher = Hasher::new();
45        hasher.update(data);
46        let checksum = hasher.finalize();
47
48        // Write header: CRC32 (4 bytes) + Length (8 bytes)
49        file.write_u32::<BigEndian>(checksum)?;
50        file.write_u64::<BigEndian>(data.len() as u64)?;
51
52        // Write data
53        file.write_all(data)?;
54
55        Ok(file.stream_position()?)
56    }
57
58    /// Sync changes to disk
59    pub fn sync(&self) -> io::Result<()> {
60        let file = self.file.lock().unwrap();
61        file.sync_all()
62    }
63
64    /// Truncate the WAL to a specific size
65    pub fn truncate(&self, size: u64) -> io::Result<()> {
66        let file = self.file.lock().unwrap();
67        file.set_len(size)?;
68        file.sync_all()
69    }
70
71    /// Read all entries from the WAL
72    pub fn read_all(&self) -> io::Result<Vec<Vec<u8>>> {
73        let mut file = self.file.lock().unwrap();
74        file.seek(SeekFrom::Start(0))?;
75
76        let mut entries = Vec::new();
77        let mut buffer = Vec::new();
78
79        loop {
80            // Read header
81            let checksum = match file.read_u32::<BigEndian>() {
82                Ok(c) => c,
83                Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => break,
84                Err(e) => return Err(e),
85            };
86
87            let len = file.read_u64::<BigEndian>()?;
88
89            // Validate length sanity check (max 128MB per entry)
90            if len > 128 * 1024 * 1024 {
91                return Err(io::Error::new(
92                    io::ErrorKind::InvalidData,
93                    "Entry too large",
94                ));
95            }
96
97            // Read data
98            buffer.resize(len as usize, 0);
99            file.read_exact(&mut buffer)?;
100
101            // Verify checksum
102            let mut hasher = Hasher::new();
103            hasher.update(&buffer);
104            if hasher.finalize() != checksum {
105                return Err(io::Error::new(
106                    io::ErrorKind::InvalidData,
107                    "Checksum mismatch",
108                ));
109            }
110
111            entries.push(buffer.clone());
112        }
113
114        // Restore file position to end
115        file.seek(SeekFrom::End(0))?;
116
117        Ok(entries)
118    }
119}
120
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::NamedTempFile;

    /// Entries written and synced by one handle are read back, in order, by a fresh handle.
    #[test]
    fn test_wal_persistence() -> io::Result<()> {
        let temp_file = NamedTempFile::new()?;
        let path = temp_file.path().to_path_buf();

        // Write data
        {
            let wal = Wal::open(&path)?;
            wal.append(b"hello")?;
            wal.append(b"world")?;
            wal.sync()?;
        }

        // Read back
        {
            let wal = Wal::open(&path)?;
            let entries = wal.read_all()?;
            assert_eq!(entries.len(), 2);
            assert_eq!(entries[0], b"hello");
            assert_eq!(entries[1], b"world");
        }

        Ok(())
    }

    /// An empty log file yields no entries rather than an error.
    #[test]
    fn test_empty_wal_reads_nothing() -> io::Result<()> {
        let temp_file = NamedTempFile::new()?;
        let wal = Wal::open(temp_file.path())?;
        assert!(wal.read_all()?.is_empty());
        Ok(())
    }

    /// `append` reports the end offset of each record: 12 header bytes plus the payload.
    #[test]
    fn test_append_returns_end_offset() -> io::Result<()> {
        let temp_file = NamedTempFile::new()?;
        let wal = Wal::open(temp_file.path())?;
        assert_eq!(wal.append(b"hello")?, 17);
        assert_eq!(wal.append(b"")?, 29);
        Ok(())
    }

    /// Truncating back to zero discards every entry.
    #[test]
    fn test_truncate_to_zero_clears_log() -> io::Result<()> {
        let temp_file = NamedTempFile::new()?;
        let wal = Wal::open(temp_file.path())?;
        wal.append(b"hello")?;
        wal.truncate(0)?;
        assert!(wal.read_all()?.is_empty());
        Ok(())
    }

    /// A flipped payload byte is caught by the CRC check and reported as `InvalidData`.
    #[test]
    fn test_corrupted_payload_is_rejected() -> io::Result<()> {
        let temp_file = NamedTempFile::new()?;
        let path = temp_file.path().to_path_buf();

        {
            let wal = Wal::open(&path)?;
            wal.append(b"hello")?;
            wal.sync()?;
        }

        let mut bytes = std::fs::read(&path)?;
        let last = bytes.len() - 1;
        bytes[last] ^= 0xFF;
        std::fs::write(&path, &bytes)?;

        let wal = Wal::open(&path)?;
        let err = wal.read_all().unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
        Ok(())
    }
}