1use crate::{Error, Result, Record, Lsn};
2use bytes::{BytesMut, BufMut};
3use parking_lot::Mutex;
4use std::fs::{File, OpenOptions};
5use std::io::{Read, Write, Seek, SeekFrom};
6use std::path::Path;
7use std::sync::Arc;
8
9const WAL_HEADER_SIZE: usize = 16;
10const WAL_MAGIC: u32 = 0x57414C00; const RECORD_HEADER_SIZE: usize = 12; pub struct Wal {
17 inner: Arc<Mutex<WalInner>>,
18}
19
20struct WalInner {
21 file: File,
22 next_lsn: Lsn,
23 pending: Vec<Record>,
24}
25
26impl Wal {
27 pub fn create(path: impl AsRef<Path>) -> Result<Self> {
29 let mut file = OpenOptions::new()
30 .read(true)
31 .write(true)
32 .create_new(true)
33 .open(path)?;
34
35 let mut header = BytesMut::with_capacity(WAL_HEADER_SIZE);
37 header.put_u32(WAL_MAGIC); header.put_u32_le(1); header.put_u64_le(0); file.write_all(&header)?;
41 file.sync_all()?;
42
43 Ok(Self {
44 inner: Arc::new(Mutex::new(WalInner {
45 file,
46 next_lsn: 1,
47 pending: Vec::new(),
48 })),
49 })
50 }
51
52 pub fn open(path: impl AsRef<Path>) -> Result<Self> {
54 let mut file = OpenOptions::new()
55 .read(true)
56 .write(true)
57 .open(path)?;
58
59 let mut header = [0u8; WAL_HEADER_SIZE];
61 file.read_exact(&mut header)?;
62 let magic = u32::from_be_bytes([header[0], header[1], header[2], header[3]]);
63 if magic != WAL_MAGIC {
64 return Err(Error::Corruption("Invalid WAL magic".to_string()));
65 }
66
67 file.seek(SeekFrom::Start(WAL_HEADER_SIZE as u64))?;
69 let mut max_lsn = 0u64;
70
71 loop {
72 let mut rec_header = [0u8; RECORD_HEADER_SIZE];
73 match file.read_exact(&mut rec_header) {
74 Ok(_) => {
75 let lsn = u64::from_le_bytes([
76 rec_header[0], rec_header[1], rec_header[2], rec_header[3],
77 rec_header[4], rec_header[5], rec_header[6], rec_header[7],
78 ]);
79 let len = u32::from_le_bytes([
80 rec_header[8], rec_header[9], rec_header[10], rec_header[11],
81 ]) as u64;
82
83 max_lsn = max_lsn.max(lsn);
84
85 file.seek(SeekFrom::Current(len as i64 + 4))?;
87 }
88 Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => break,
89 Err(e) => return Err(e.into()),
90 }
91 }
92
93 Ok(Self {
94 inner: Arc::new(Mutex::new(WalInner {
95 file,
96 next_lsn: max_lsn + 1,
97 pending: Vec::new(),
98 })),
99 })
100 }
101
102 pub fn append(&self, record: Record) -> Result<Lsn> {
104 let mut inner = self.inner.lock();
105 let lsn = inner.next_lsn;
106 inner.next_lsn += 1;
107 inner.pending.push(record);
108 Ok(lsn)
109 }
110
111 pub fn flush(&self) -> Result<()> {
113 let mut inner = self.inner.lock();
114 if inner.pending.is_empty() {
115 return Ok(());
116 }
117
118 inner.file.seek(SeekFrom::End(0))?;
120
121 let mut full_buf = BytesMut::new();
123 let base_lsn = inner.next_lsn - inner.pending.len() as u64;
124
125 for (i, record) in inner.pending.iter().enumerate() {
126 let lsn = base_lsn + i as u64;
127
128 let data = bincode::serialize(record)
129 .map_err(|e| Error::Internal(format!("Serialize error: {}", e)))?;
130 let crc = crc32fast::hash(&data);
131
132 full_buf.put_u64_le(lsn);
133 full_buf.put_u32_le(data.len() as u32);
134 full_buf.put_slice(&data);
135 full_buf.put_u32_le(crc);
136 }
137
138 inner.file.write_all(&full_buf)?;
140
141 inner.file.sync_all()?;
142 inner.pending.clear();
143
144 Ok(())
145 }
146
147 pub fn read_all(&self) -> Result<Vec<(Lsn, Record)>> {
149 let inner = self.inner.lock();
150 let mut file = inner.file.try_clone()?;
151 drop(inner);
152
153 file.seek(SeekFrom::Start(WAL_HEADER_SIZE as u64))?;
154
155 let mut records = Vec::new();
156 loop {
157 let mut rec_header = [0u8; RECORD_HEADER_SIZE];
158 match file.read_exact(&mut rec_header) {
159 Ok(_) => {
160 let lsn = u64::from_le_bytes([
161 rec_header[0], rec_header[1], rec_header[2], rec_header[3],
162 rec_header[4], rec_header[5], rec_header[6], rec_header[7],
163 ]);
164 let len = u32::from_le_bytes([
165 rec_header[8], rec_header[9], rec_header[10], rec_header[11],
166 ]) as usize;
167
168 let mut data = vec![0u8; len];
169 file.read_exact(&mut data)?;
170
171 let mut crc_bytes = [0u8; 4];
172 file.read_exact(&mut crc_bytes)?;
173 let expected_crc = u32::from_le_bytes(crc_bytes);
174 let actual_crc = crc32fast::hash(&data);
175
176 if expected_crc != actual_crc {
177 return Err(Error::ChecksumMismatch);
178 }
179
180 let record: Record = bincode::deserialize(&data)
181 .map_err(|e| Error::Corruption(format!("Deserialize error: {}", e)))?;
182
183 records.push((lsn, record));
184 }
185 Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => break,
186 Err(e) => return Err(e.into()),
187 }
188 }
189
190 Ok(records)
191 }
192
193 pub fn next_lsn(&self) -> Lsn {
194 self.inner.lock().next_lsn
195 }
196}
197
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Key, Value};
    use tempfile::TempDir;
    use std::collections::HashMap;

    // A single appended record gets LSN 1 and can be read back after a flush.
    #[test]
    fn test_wal_create_and_write() {
        let dir = TempDir::new().unwrap();
        let wal = Wal::create(dir.path().join("wal.log")).unwrap();

        let attrs: HashMap<_, _> =
            std::iter::once(("name".to_string(), Value::string("Alice"))).collect();
        let record = Record::put(Key::new(b"user#123".to_vec()), attrs, 1);

        assert_eq!(wal.append(record).unwrap(), 1);

        wal.flush().unwrap();

        let stored = wal.read_all().unwrap();
        assert_eq!(stored.len(), 1);
        assert_eq!(stored[0].0, 1);
    }

    // Reopening a log with one flushed record resumes at LSN 2 and still
    // exposes that record.
    #[test]
    fn test_wal_reopen() {
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("wal.log");

        {
            let writer = Wal::create(&path).unwrap();
            let record = Record::put(Key::new(b"test".to_vec()), HashMap::new(), 1);
            writer.append(record).unwrap();
            writer.flush().unwrap();
        }

        let reopened = Wal::open(&path).unwrap();
        assert_eq!(reopened.next_lsn(), 2);

        let stored = reopened.read_all().unwrap();
        assert_eq!(stored.len(), 1);
    }

    // Ten appends followed by one flush are all written out together.
    #[test]
    fn test_wal_group_commit() {
        let dir = TempDir::new().unwrap();
        let wal = Wal::create(dir.path().join("wal.log")).unwrap();

        for n in 0..10 {
            let key = Key::new(format!("key{}", n).into_bytes());
            wal.append(Record::put(key, HashMap::new(), n)).unwrap();
        }

        wal.flush().unwrap();

        let stored = wal.read_all().unwrap();
        assert_eq!(stored.len(), 10);
    }
}