1use crate::error::{AuroraError, Result};
2use serde::{Deserialize, Serialize};
3use std::fs::{File, OpenOptions};
4use std::io::{self, BufReader, BufWriter, Read, Seek, SeekFrom, Write};
5use std::path::PathBuf;
6use std::time::SystemTime;
7
8#[derive(Serialize, Deserialize)]
9pub struct LogEntry {
10 timestamp: u64,
11 operation: Operation,
12 key: String,
13 value: Option<Vec<u8>>,
14}
15
16#[derive(Serialize, Deserialize)]
17pub enum Operation {
18 Put,
19 Delete,
20 BeginTx,
21 CommitTx,
22 RollbackTx,
23}
24
/// An append-only log of length-prefixed [`LogEntry`] records backed by a file.
#[derive(Debug)]
pub struct WriteAheadLog {
    /// Buffered handle to the log file, opened for both reading and appending.
    file: BufWriter<File>,
    /// Location of the log file; retained but not read by this module.
    _path: PathBuf,
}
29
30impl WriteAheadLog {
31 pub fn new(path: &str) -> Result<Self> {
32 let path = PathBuf::from(path);
33 let wal_path = path.with_extension("wal");
34
35 let file = BufWriter::new(
36 OpenOptions::new()
37 .create(true)
38 .read(true)
39 .write(true)
40 .append(true)
41 .open(&wal_path)?
42 );
43
44 Ok(Self {
45 file,
46 _path: wal_path,
47 })
48 }
49
50 pub fn append(&mut self, operation: Operation, key: &str, value: Option<&[u8]>) -> Result<()> {
51 let timestamp = SystemTime::now()
52 .duration_since(SystemTime::UNIX_EPOCH)
53 .map_err(|e| AuroraError::Protocol(e.to_string()))?
54 .as_secs();
55
56 let entry = LogEntry {
57 timestamp,
58 operation,
59 key: key.to_string(),
60 value: value.map(|v| v.to_vec()),
61 };
62
63 let serialized = bincode::serialize(&entry)
64 .map_err(|e| AuroraError::Protocol(e.to_string()))?;
65 let len = serialized.len() as u32;
66
67 self.file.write_all(&len.to_le_bytes())?;
68 self.file.write_all(&serialized)?;
69 self.file.flush()?;
70
71 Ok(())
72 }
73
74 pub fn recover(&mut self) -> Result<Vec<LogEntry>> {
75 let mut file = self.file.get_ref();
76 file.seek(SeekFrom::Start(0))?;
77 let mut reader = BufReader::new(file);
78 let mut entries = Vec::new();
79
80 loop {
81 let mut len_bytes = [0u8; 4];
82 match reader.read_exact(&mut len_bytes) {
83 Ok(_) => {
84 let len = u32::from_le_bytes(len_bytes) as usize;
85 let mut buffer = vec![0u8; len];
86 reader.read_exact(&mut buffer)?;
87 let entry: LogEntry = bincode::deserialize(&buffer)
88 .map_err(|e| AuroraError::Protocol(e.to_string()))?;
89 entries.push(entry);
90 }
91 Err(ref e) if e.kind() == io::ErrorKind::UnexpectedEof => break,
92 Err(e) => return Err(e.into()),
93 }
94 }
95
96 Ok(entries)
97 }
98
99 pub fn truncate(&mut self) -> Result<()> {
100 let file = self.file.get_mut();
101 file.set_len(0)?;
102 file.sync_all()?;
103 file.seek(SeekFrom::Start(0))?;
104 Ok(())
105 }
106}
107
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::{tempdir, TempDir};

    /// Opens a brand-new WAL inside a fresh temporary directory.
    ///
    /// The `TempDir` is handed back to the caller because dropping it deletes
    /// the directory; callers bind it to a named `_dir` to keep it alive.
    fn fresh_wal() -> Result<(TempDir, WriteAheadLog)> {
        let dir = tempdir()?;
        let wal_path = dir.path().join("test.wal");
        let wal = WriteAheadLog::new(wal_path.to_str().expect("temp path is valid UTF-8"))?;
        Ok((dir, wal))
    }

    #[test]
    fn test_log_operations() -> Result<()> {
        let (_dir, mut wal) = fresh_wal()?;

        wal.append(Operation::Put, "test_key", Some(b"test_value"))?;
        wal.append(Operation::BeginTx, "", None)?;
        wal.append(Operation::CommitTx, "", None)?;

        let entries = wal.recover()?;
        assert_eq!(entries.len(), 3);

        let first = &entries[0];
        assert!(matches!(first.operation, Operation::Put));
        assert_eq!(first.key, "test_key");
        assert_eq!(first.value.as_deref(), Some(&b"test_value"[..]));

        assert!(matches!(
            &entries[1..],
            [
                LogEntry { operation: Operation::BeginTx, .. },
                LogEntry { operation: Operation::CommitTx, .. },
            ]
        ));

        Ok(())
    }

    #[test]
    fn test_truncate() -> Result<()> {
        let (_dir, mut wal) = fresh_wal()?;

        wal.append(Operation::Put, "key1", Some(b"value1"))?;
        wal.append(Operation::Put, "key2", Some(b"value2"))?;
        wal.truncate()?;

        assert!(wal.recover()?.is_empty());
        Ok(())
    }

    #[test]
    fn test_transaction_operations() -> Result<()> {
        let (_dir, mut wal) = fresh_wal()?;

        wal.append(Operation::BeginTx, "", None)?;
        wal.append(Operation::Put, "tx_key1", Some(b"tx_value1"))?;
        wal.append(Operation::Put, "tx_key2", Some(b"tx_value2"))?;
        wal.append(Operation::CommitTx, "", None)?;

        let entries = wal.recover()?;
        assert_eq!(entries.len(), 4);
        assert!(matches!(
            entries.as_slice(),
            [
                LogEntry { operation: Operation::BeginTx, .. },
                LogEntry { operation: Operation::Put, .. },
                LogEntry { operation: Operation::Put, .. },
                LogEntry { operation: Operation::CommitTx, .. },
            ]
        ));

        Ok(())
    }

    #[test]
    fn test_rollback_operation() -> Result<()> {
        let (_dir, mut wal) = fresh_wal()?;

        wal.append(Operation::BeginTx, "", None)?;
        wal.append(Operation::Put, "key_to_rollback", Some(b"value"))?;
        wal.append(Operation::RollbackTx, "", None)?;

        let entries = wal.recover()?;
        assert_eq!(entries.len(), 3);
        assert!(matches!(
            entries.as_slice(),
            [
                LogEntry { operation: Operation::BeginTx, .. },
                LogEntry { operation: Operation::Put, .. },
                LogEntry { operation: Operation::RollbackTx, .. },
            ]
        ));

        Ok(())
    }

    #[test]
    fn test_large_values() -> Result<()> {
        let (_dir, mut wal) = fresh_wal()?;

        let large_value = vec![0u8; 1024 * 1024];
        wal.append(Operation::Put, "large_key", Some(&large_value))?;

        let entries = wal.recover()?;
        assert_eq!(entries.len(), 1);
        assert_eq!(
            entries[0].value.as_ref().map(Vec::len),
            Some(large_value.len())
        );

        Ok(())
    }

    #[test]
    fn test_invalid_path() {
        assert!(WriteAheadLog::new("/nonexistent/directory/test.wal").is_err());
    }
}
223}