1use crate::error::{AqlError, ErrorCode, Result};
2use serde::{Deserialize, Serialize};
3use std::fs::{File, OpenOptions};
4use std::io::{self, BufReader, BufWriter, Read, Seek, SeekFrom, Write};
5use std::path::PathBuf;
6use std::time::SystemTime;
7
8#[derive(Serialize, Deserialize)]
9pub struct LogEntry {
10 pub timestamp: u64,
11 pub operation: Operation,
12 pub key: String,
13 pub value: Option<Vec<u8>>,
14}
15
16#[derive(Serialize, Deserialize)]
17pub enum Operation {
18 Put,
19 Delete,
20 BeginTx,
21 CommitTx,
22 RollbackTx,
23}
24
/// Handle to an on-disk write-ahead log file.
#[derive(Debug)]
pub struct WriteAheadLog {
    /// Buffered writer over the log file. The underlying `File` is also used
    /// directly for reading, truncating and syncing.
    file: BufWriter<File>,
    /// Path of the log file that was opened; stored but not read anywhere in
    /// this block (hence the leading underscore).
    _path: PathBuf,
}
29
30impl WriteAheadLog {
31 pub fn new(path: &str) -> Result<Self> {
37 let path = PathBuf::from(path);
38 let wal_path = path.with_extension("wal");
39
40 let mut file = OpenOptions::new()
43 .create(true)
44 .read(true)
45 .write(true)
46 .open(&wal_path)?;
47
48 file.seek(SeekFrom::End(0))?;
50
51 Ok(Self {
52 file: BufWriter::new(file),
53 _path: wal_path,
54 })
55 }
56
57 pub fn append(&mut self, operation: Operation, key: &str, value: Option<&[u8]>) -> Result<()> {
58 let timestamp = SystemTime::now()
59 .duration_since(SystemTime::UNIX_EPOCH)
60 .map_err(|e| AqlError::new(ErrorCode::ProtocolError, e.to_string()))?
61 .as_secs();
62
63 let entry = LogEntry {
64 timestamp,
65 operation,
66 key: key.to_string(),
67 value: value.map(|v| v.to_vec()),
68 };
69
70 let serialized = bincode::serialize(&entry)
71 .map_err(|e| AqlError::new(ErrorCode::ProtocolError, e.to_string()))?;
72 let len = serialized.len() as u32;
73
74 self.file.write_all(&len.to_le_bytes())?;
75 self.file.write_all(&serialized)?;
76 self.file.flush()?;
77 self.file.get_mut().sync_all()?;
78
79 Ok(())
80 }
81
82 pub fn recover(&mut self) -> Result<Vec<LogEntry>> {
83 let mut file = self.file.get_ref();
84 file.seek(SeekFrom::Start(0))?;
85 let mut reader = BufReader::new(file);
86 let mut entries = Vec::new();
87
88 loop {
89 let mut len_bytes = [0u8; 4];
90 match reader.read_exact(&mut len_bytes) {
91 Ok(_) => {
92 let len = u32::from_le_bytes(len_bytes) as usize;
93 let mut buffer = vec![0u8; len];
94 reader.read_exact(&mut buffer)?;
95 let entry: LogEntry = bincode::deserialize(&buffer)
96 .map_err(|e| AqlError::new(ErrorCode::ProtocolError, e.to_string()))?;
97 entries.push(entry);
98 }
99 Err(ref e) if e.kind() == io::ErrorKind::UnexpectedEof => break,
100 Err(e) => return Err(e.into()),
101 }
102 }
103
104 Ok(entries)
105 }
106
107 pub fn truncate(&mut self) -> Result<()> {
108 let file = self.file.get_mut();
109 file.set_len(0)?;
110 file.sync_all()?;
111 file.seek(SeekFrom::Start(0))?;
112 Ok(())
113 }
114}
115
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Opens a log named `test.wal` inside `dir`.
    fn open_wal(dir: &std::path::Path) -> Result<WriteAheadLog> {
        let wal_file = dir.join("test.wal");
        WriteAheadLog::new(wal_file.to_str().unwrap())
    }

    /// Maps each entry to the name of its operation, preserving order.
    fn op_names(entries: &[LogEntry]) -> Vec<&'static str> {
        entries
            .iter()
            .map(|entry| match entry.operation {
                Operation::Put => "Put",
                Operation::Delete => "Delete",
                Operation::BeginTx => "BeginTx",
                Operation::CommitTx => "CommitTx",
                Operation::RollbackTx => "RollbackTx",
            })
            .collect()
    }

    /// Appended records come back from `recover` in order with their contents.
    #[test]
    fn test_log_operations() -> Result<()> {
        let scratch = tempdir()?;
        let mut log = open_wal(scratch.path())?;

        log.append(Operation::Put, "test_key", Some(b"test_value"))?;
        log.append(Operation::BeginTx, "", None)?;
        log.append(Operation::CommitTx, "", None)?;

        let recovered = log.recover()?;
        assert_eq!(op_names(&recovered), vec!["Put", "BeginTx", "CommitTx"]);

        let put = &recovered[0];
        assert_eq!(put.key, "test_key");
        assert_eq!(put.value.as_ref().unwrap(), b"test_value");

        Ok(())
    }

    /// After `truncate` the log contains no records.
    #[test]
    fn test_truncate() -> Result<()> {
        let scratch = tempdir()?;
        let mut log = open_wal(scratch.path())?;

        log.append(Operation::Put, "key1", Some(b"value1"))?;
        log.append(Operation::Put, "key2", Some(b"value2"))?;
        log.truncate()?;

        let recovered = log.recover()?;
        assert_eq!(recovered.len(), 0);

        Ok(())
    }

    /// A begin / put / put / commit sequence is recovered in the same order.
    #[test]
    fn test_transaction_operations() -> Result<()> {
        let scratch = tempdir()?;
        let mut log = open_wal(scratch.path())?;

        log.append(Operation::BeginTx, "", None)?;
        log.append(Operation::Put, "tx_key1", Some(b"tx_value1"))?;
        log.append(Operation::Put, "tx_key2", Some(b"tx_value2"))?;
        log.append(Operation::CommitTx, "", None)?;

        let recovered = log.recover()?;
        assert_eq!(
            op_names(&recovered),
            vec!["BeginTx", "Put", "Put", "CommitTx"]
        );

        Ok(())
    }

    /// A rollback marker is stored and recovered like any other record.
    #[test]
    fn test_rollback_operation() -> Result<()> {
        let scratch = tempdir()?;
        let mut log = open_wal(scratch.path())?;

        log.append(Operation::BeginTx, "", None)?;
        log.append(Operation::Put, "key_to_rollback", Some(b"value"))?;
        log.append(Operation::RollbackTx, "", None)?;

        let recovered = log.recover()?;
        assert_eq!(op_names(&recovered), vec!["BeginTx", "Put", "RollbackTx"]);

        Ok(())
    }

    /// A 1 MiB value survives the append / recover round trip.
    #[test]
    fn test_large_values() -> Result<()> {
        let scratch = tempdir()?;
        let mut log = open_wal(scratch.path())?;

        let payload = vec![0u8; 1024 * 1024];
        log.append(Operation::Put, "large_key", Some(&payload))?;

        let recovered = log.recover()?;
        assert_eq!(recovered.len(), 1);

        let stored = recovered[0].value.as_ref().unwrap();
        assert_eq!(stored.len(), payload.len());

        Ok(())
    }

    /// Opening a log inside a directory that does not exist fails.
    #[test]
    fn test_invalid_path() {
        let outcome = WriteAheadLog::new("/nonexistent/directory/test.wal");
        assert!(outcome.is_err());
    }
}