1use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
2use crc32fast::Hasher;
3use std::fs::{File, OpenOptions};
4use std::io::{self, Read, Seek, SeekFrom, Write};
5use std::path::{Path, PathBuf};
6use std::sync::{Arc, Mutex};
7
/// Size in bytes of the fixed header that precedes every record payload:
/// a `u32` CRC-32 checksum followed by a `u64` payload length.
// `dead_code` is allowed so the constant can stay as a description of the
// on-disk layout even when no code path references it.
#[allow(dead_code)]
const HEADER_SIZE: usize = std::mem::size_of::<u32>() + std::mem::size_of::<u64>();
11
/// Handle to a write-ahead log stored in a single file.
///
/// The underlying [`File`] sits behind an `Arc<Mutex<_>>`, so every operation
/// on the log takes the lock before touching the file.
#[derive(Debug)]
pub struct Wal {
    /// Location of the log file on disk. Retained for diagnostics; nothing
    /// reads it after the file has been opened, hence the lint allowance.
    #[allow(dead_code)]
    path: PathBuf,
    /// The open log file, guarded by a mutex.
    file: Arc<Mutex<File>>,
}
18
19impl Wal {
20 pub fn open(path: impl AsRef<Path>) -> io::Result<Self> {
22 let path = path.as_ref().to_path_buf();
23 if let Some(parent) = path.parent() {
24 std::fs::create_dir_all(parent)?;
25 }
26
27 let file = OpenOptions::new()
28 .read(true)
29 .write(true)
30 .create(true)
31 .open(&path)?;
32
33 Ok(Self {
34 path,
35 file: Arc::new(Mutex::new(file)),
36 })
37 }
38
39 pub fn append(&self, data: &[u8]) -> io::Result<u64> {
41 let mut file = self.file.lock().unwrap();
42
43 let mut hasher = Hasher::new();
45 hasher.update(data);
46 let checksum = hasher.finalize();
47
48 file.write_u32::<BigEndian>(checksum)?;
50 file.write_u64::<BigEndian>(data.len() as u64)?;
51
52 file.write_all(data)?;
54
55 Ok(file.stream_position()?)
56 }
57
58 pub fn sync(&self) -> io::Result<()> {
60 let file = self.file.lock().unwrap();
61 file.sync_all()
62 }
63
64 pub fn truncate(&self, size: u64) -> io::Result<()> {
66 let file = self.file.lock().unwrap();
67 file.set_len(size)?;
68 file.sync_all()
69 }
70
71 pub fn read_all(&self) -> io::Result<Vec<Vec<u8>>> {
73 let mut file = self.file.lock().unwrap();
74 file.seek(SeekFrom::Start(0))?;
75
76 let mut entries = Vec::new();
77 let mut buffer = Vec::new();
78
79 loop {
80 let checksum = match file.read_u32::<BigEndian>() {
82 Ok(c) => c,
83 Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => break,
84 Err(e) => return Err(e),
85 };
86
87 let len = file.read_u64::<BigEndian>()?;
88
89 if len > 128 * 1024 * 1024 {
91 return Err(io::Error::new(
92 io::ErrorKind::InvalidData,
93 "Entry too large",
94 ));
95 }
96
97 buffer.resize(len as usize, 0);
99 file.read_exact(&mut buffer)?;
100
101 let mut hasher = Hasher::new();
103 hasher.update(&buffer);
104 if hasher.finalize() != checksum {
105 return Err(io::Error::new(
106 io::ErrorKind::InvalidData,
107 "Checksum mismatch",
108 ));
109 }
110
111 entries.push(buffer.clone());
112 }
113
114 file.seek(SeekFrom::End(0))?;
116
117 Ok(entries)
118 }
119}
120
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::NamedTempFile;

    /// Records appended through one handle must come back, in order, through
    /// a fresh handle opened on the same path.
    #[test]
    fn test_wal_persistence() -> io::Result<()> {
        // Hold on to the guard so the temporary file outlives both handles.
        let scratch = NamedTempFile::new()?;
        let log_path = scratch.path().to_path_buf();

        // Write phase: append two records, sync, then close the handle.
        let writer = Wal::open(&log_path)?;
        for payload in [&b"hello"[..], &b"world"[..]] {
            writer.append(payload)?;
        }
        writer.sync()?;
        drop(writer);

        // Read phase: a new handle must see exactly what was written.
        let reader = Wal::open(&log_path)?;
        let recovered = reader.read_all()?;
        assert_eq!(recovered.len(), 2);
        assert_eq!(recovered[0], b"hello");
        assert_eq!(recovered[1], b"world");

        Ok(())
    }
}