1use crate::error::{AqlError, ErrorCode, Result};
2use serde::{Deserialize, Serialize};
3use std::fs::{File, OpenOptions};
4use std::io::{self, BufReader, BufWriter, Read, Seek, SeekFrom, Write};
5use std::path::PathBuf;
6use std::time::SystemTime;
7
8#[derive(Serialize, Deserialize, Debug, Clone)]
9pub struct LogEntry {
10 pub timestamp: u128, pub operation: Operation,
12 pub key: String,
13 pub value: Option<Vec<u8>>,
14}
15
16#[derive(Serialize)]
18struct LogEntryRef<'a> {
19 pub timestamp: u128,
20 pub operation: &'a Operation,
21 pub key: &'a str,
22 pub value: Option<&'a [u8]>,
23}
24
25#[derive(Serialize, Deserialize, Debug, Clone)]
26pub enum Operation {
27 Put,
28 Delete,
29 BeginTx,
30 CommitTx,
31 RollbackTx,
32}
33
/// Append-only log of length-prefixed, bincode-encoded entries kept in a `.wal` file.
pub struct WriteAheadLog {
    /// Buffered handle to the `.wal` file; `recover` and `truncate` also use the
    /// underlying `File` directly.
    file: BufWriter<File>,
    /// Resolved location of the `.wal` file (not read anywhere in this module).
    _path: PathBuf,
}
38
39impl WriteAheadLog {
40 pub fn new(path: &str) -> Result<Self> {
46 let path = PathBuf::from(path);
47 let wal_path = path.with_extension("wal");
48
49 let mut file = OpenOptions::new()
52 .create(true)
53 .read(true)
54 .write(true)
55 .open(&wal_path)?;
56
57 file.seek(SeekFrom::End(0))?;
59
60 Ok(Self {
61 file: BufWriter::new(file),
62 _path: wal_path,
63 })
64 }
65
66 pub fn append(&mut self, operation: Operation, key: &str, value: Option<&[u8]>) -> Result<()> {
67 let timestamp = SystemTime::now()
68 .duration_since(SystemTime::UNIX_EPOCH)
69 .map_err(|e| AqlError::new(ErrorCode::ProtocolError, e.to_string()))?
70 .as_micros(); let entry = LogEntryRef {
73 timestamp,
74 operation: &operation,
75 key,
76 value,
77 };
78
79 let serialized = bincode::serialize(&entry)
80 .map_err(|e| AqlError::new(ErrorCode::ProtocolError, e.to_string()))?;
81 let len = serialized.len() as u32;
82
83 self.file.write_all(&len.to_le_bytes())?;
84 self.file.write_all(&serialized)?;
85
86 self.file.flush()?;
88
89 Ok(())
90 }
91
92 pub fn sync(&mut self) -> Result<()> {
94 self.file.flush()?;
95 self.file.get_mut().sync_all()?;
96 Ok(())
97 }
98
99 pub fn recover(&mut self) -> Result<Vec<LogEntry>> {
100 let mut file = self.file.get_ref();
101 file.seek(SeekFrom::Start(0))?;
102 let mut reader = BufReader::new(file);
103 let mut entries = Vec::new();
104
105 loop {
106 let mut len_bytes = [0u8; 4];
107 match reader.read_exact(&mut len_bytes) {
108 Ok(_) => {
109 let len = u32::from_le_bytes(len_bytes) as usize;
110 let mut buffer = vec![0u8; len];
111 reader.read_exact(&mut buffer)?;
112 let entry: LogEntry = bincode::deserialize(&buffer)
113 .map_err(|e| AqlError::new(ErrorCode::ProtocolError, e.to_string()))?;
114 entries.push(entry);
115 }
116 Err(ref e) if e.kind() == io::ErrorKind::UnexpectedEof => break,
117 Err(e) => return Err(e.into()),
118 }
119 }
120
121 Ok(entries)
122 }
123
124 pub fn truncate(&mut self) -> Result<()> {
125 let file = self.file.get_mut();
126 file.set_len(0)?;
127 file.sync_all()?;
128 file.seek(SeekFrom::Start(0))?;
129 Ok(())
130 }
131}
132
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::{tempdir, TempDir};

    /// Opens (or reopens) the `test.wal` log inside `dir`.
    fn open_wal_in(dir: &TempDir) -> Result<WriteAheadLog> {
        let log_path = dir.path().join("test.wal");
        WriteAheadLog::new(log_path.to_str().expect("temp path is valid UTF-8"))
    }

    /// Opens a fresh log in a new temporary directory.
    ///
    /// The `TempDir` is returned so the caller keeps it alive for the whole test;
    /// dropping it deletes the directory.
    fn open_temp_wal() -> Result<(TempDir, WriteAheadLog)> {
        let dir = tempdir()?;
        let wal = open_wal_in(&dir)?;
        Ok((dir, wal))
    }

    #[test]
    fn test_log_operations() -> Result<()> {
        let (_dir, mut wal) = open_temp_wal()?;

        wal.append(Operation::Put, "test_key", Some(b"test_value"))?;
        wal.append(Operation::BeginTx, "", None)?;
        wal.append(Operation::CommitTx, "", None)?;

        let entries = wal.recover()?;
        assert_eq!(entries.len(), 3);

        assert!(matches!(entries[0].operation, Operation::Put));
        assert_eq!(entries[0].key, "test_key");
        assert_eq!(entries[0].value.as_ref().unwrap(), b"test_value");

        assert!(matches!(entries[1].operation, Operation::BeginTx));
        assert!(matches!(entries[2].operation, Operation::CommitTx));

        Ok(())
    }

    #[test]
    fn test_truncate() -> Result<()> {
        let (_dir, mut wal) = open_temp_wal()?;

        wal.append(Operation::Put, "key1", Some(b"value1"))?;
        wal.append(Operation::Put, "key2", Some(b"value2"))?;

        wal.truncate()?;

        let entries = wal.recover()?;
        assert_eq!(entries.len(), 0);

        Ok(())
    }

    #[test]
    fn test_transaction_operations() -> Result<()> {
        let (_dir, mut wal) = open_temp_wal()?;

        wal.append(Operation::BeginTx, "", None)?;
        wal.append(Operation::Put, "tx_key1", Some(b"tx_value1"))?;
        wal.append(Operation::Put, "tx_key2", Some(b"tx_value2"))?;
        wal.append(Operation::CommitTx, "", None)?;

        let entries = wal.recover()?;
        assert_eq!(entries.len(), 4);

        assert!(matches!(entries[0].operation, Operation::BeginTx));
        assert!(matches!(entries[1].operation, Operation::Put));
        assert!(matches!(entries[2].operation, Operation::Put));
        assert!(matches!(entries[3].operation, Operation::CommitTx));

        Ok(())
    }

    #[test]
    fn test_rollback_operation() -> Result<()> {
        let (_dir, mut wal) = open_temp_wal()?;

        wal.append(Operation::BeginTx, "", None)?;
        wal.append(Operation::Put, "key_to_rollback", Some(b"value"))?;
        wal.append(Operation::RollbackTx, "", None)?;

        let entries = wal.recover()?;
        assert_eq!(entries.len(), 3);

        assert!(matches!(entries[0].operation, Operation::BeginTx));
        assert!(matches!(entries[1].operation, Operation::Put));
        assert!(matches!(entries[2].operation, Operation::RollbackTx));

        Ok(())
    }

    #[test]
    fn test_large_values() -> Result<()> {
        let (_dir, mut wal) = open_temp_wal()?;

        let large_value = vec![0u8; 1024 * 1024];
        wal.append(Operation::Put, "large_key", Some(&large_value))?;

        let entries = wal.recover()?;
        assert_eq!(entries.len(), 1);
        assert_eq!(entries[0].value.as_ref().unwrap().len(), large_value.len());
        // Compare contents too, without dumping a megabyte on failure.
        assert!(entries[0].value.as_deref() == Some(large_value.as_slice()));

        Ok(())
    }

    #[test]
    fn test_reopen_preserves_entries() -> Result<()> {
        let dir = tempdir()?;
        {
            let mut wal = open_wal_in(&dir)?;
            wal.append(Operation::Put, "persisted", Some(b"v1"))?;
            wal.sync()?;
        }

        // A fresh handle must append after the entries already on disk.
        let mut wal = open_wal_in(&dir)?;
        wal.append(Operation::Delete, "persisted", None)?;

        let entries = wal.recover()?;
        assert_eq!(entries.len(), 2);
        assert!(matches!(entries[0].operation, Operation::Put));
        assert_eq!(entries[0].value.as_deref(), Some(&b"v1"[..]));
        assert!(matches!(entries[1].operation, Operation::Delete));
        assert!(entries[1].value.is_none());

        Ok(())
    }

    #[test]
    fn test_append_after_recover_and_truncate() -> Result<()> {
        let (_dir, mut wal) = open_temp_wal()?;

        wal.append(Operation::Put, "first", Some(b"1"))?;
        assert_eq!(wal.recover()?.len(), 1);

        // Recovery must not disturb where the next entry is written.
        wal.append(Operation::Put, "second", Some(b"2"))?;
        let entries = wal.recover()?;
        assert_eq!(entries.len(), 2);
        assert_eq!(entries[1].key, "second");

        wal.truncate()?;
        wal.append(Operation::Put, "third", Some(b"3"))?;
        let entries = wal.recover()?;
        assert_eq!(entries.len(), 1);
        assert_eq!(entries[0].key, "third");

        Ok(())
    }

    #[test]
    fn test_invalid_path() {
        let result = WriteAheadLog::new("/nonexistent/directory/test.wal");
        assert!(result.is_err());
    }
}