1use std::{
2 env,
3 error::Error,
4 ffi::OsString,
5 fs::{File, OpenOptions},
6 io::{BufReader, BufWriter, Read, Write},
7 path::{Path, PathBuf},
8 sync::Arc,
9};
10
11use bincode::{Decode, Encode};
12use tokio::{
13 fs,
14 sync::{mpsc, RwLock, Semaphore},
15};
16
17use crate::value::Entry;
18
19use super::config::PersistenceConfig;
20
/// Four-byte magic number written at the very start of every WAL file ("AKVW").
const WAL_MAGIC: &[u8; 4] = &[b'A', b'K', b'V', b'W'];
/// Format version, written as a little-endian `u32` right after [`WAL_MAGIC`].
const WAL_VERSION: u32 = 1;
23
24#[derive(Debug, Clone, Decode, Encode)]
25pub enum WalEntry {
26 Set { db: String, key: String, entry: Entry },
27 Delete { db: String, key: String },
28 Expire { db: String, key: String, expires_at: u64 },
29}
30
31impl WalEntry {
32 pub fn serialize(&self) -> Result<Vec<u8>, Box<dyn Error>> {
33 let config = bincode::config::standard()
34 .with_variable_int_encoding()
35 .with_little_endian();
36 Ok(bincode::encode_to_vec(self, config)?)
37 }
38
39 pub fn deserialize(bytes: &[u8]) -> Result<Self, Box<dyn Error>> {
40 let config = bincode::config::standard()
41 .with_variable_int_encoding()
42 .with_little_endian();
43 Ok(bincode::decode_from_slice(bytes, config)?.0)
44 }
45}
46
47#[derive(Clone)]
48pub struct WalManager {
49 config: PersistenceConfig,
50 wal_dir: PathBuf,
51 current_wal: Arc<RwLock<Option<WalFile>>>,
52 entry_sender: mpsc::Sender<WalEntry>,
53 write_semaphore: Arc<Semaphore>,
54}
55
/// An open WAL file together with bookkeeping about what was written to it.
struct WalFile {
    /// Buffered handle that framed records are appended through.
    writer: BufWriter<File>,
    /// On-disk location of this file.
    // The lint is silenced because not every code path reads `path`; it is
    // informational and used to identify the active file.
    #[allow(dead_code)]
    path: PathBuf,
    /// Bytes written so far, including the 8-byte header.
    size: usize,
    /// Number of records successfully written.
    entry_count: usize,
}
63
64impl WalManager {
65 pub async fn new(config: PersistenceConfig) -> Result<Self, Box<dyn Error>> {
66 let home_dir = env::var_os("HOME")
67 .or_else(|| env::var_os("USERPROFILE"))
68 .unwrap_or(OsString::from("./"));
69 let wal_dir = Path::new(&home_dir).join(".ahriknow/ahrikv/wal");
70 fs::create_dir_all(&wal_dir).await?;
71
72 let (entry_sender, entry_receiver) = mpsc::channel(config.buffer_max_entries);
73 let write_semaphore = Arc::new(Semaphore::new(1));
74
75 let manager = Self {
76 config: config.clone(),
77 wal_dir,
78 current_wal: Arc::new(RwLock::new(None)),
79 entry_sender,
80 write_semaphore,
81 };
82
83 manager.start_background_writer(entry_receiver).await?;
84
85 Ok(manager)
86 }
87
88 async fn start_background_writer(
89 &self,
90 mut entry_receiver: mpsc::Receiver<WalEntry>,
91 ) -> Result<(), Box<dyn Error>> {
92 let config = self.config.clone();
93 let wal_dir = self.wal_dir.clone();
94 let current_wal = Arc::clone(&self.current_wal);
95 let write_semaphore = Arc::clone(&self.write_semaphore);
96
97 tokio::spawn(async move {
98 let mut buffer = Vec::with_capacity(config.buffer_max_entries);
99 let mut buffer_size = 0;
100 let mut last_flush = tokio::time::Instant::now();
101
102 while let Some(entry) = entry_receiver.recv().await {
103 let serialized = match entry.serialize() {
104 Ok(data) => data,
105 Err(e) => {
106 eprintln!("Failed to serialize WAL entry: {}", e);
107 continue;
108 }
109 };
110
111 let entry_size = serialized.len();
112 if buffer_size + entry_size > config.buffer_max_size
113 || buffer.len() >= config.buffer_max_entries
114 {
115 Self::flush_entries(
116 &wal_dir,
117 ¤t_wal,
118 &write_semaphore,
119 &config,
120 &mut buffer,
121 &mut buffer_size,
122 )
123 .await;
124 last_flush = tokio::time::Instant::now();
125 }
126
127 buffer.push(entry);
128 buffer_size += entry_size;
129
130 if last_flush.elapsed() >= config.flush_interval {
131 Self::flush_entries(
132 &wal_dir,
133 ¤t_wal,
134 &write_semaphore,
135 &config,
136 &mut buffer,
137 &mut buffer_size,
138 )
139 .await;
140 last_flush = tokio::time::Instant::now();
141 }
142 }
143
144 if !buffer.is_empty() {
145 Self::flush_entries(
146 &wal_dir,
147 ¤t_wal,
148 &write_semaphore,
149 &config,
150 &mut buffer,
151 &mut buffer_size,
152 )
153 .await;
154 }
155 });
156
157 Ok(())
158 }
159
160 async fn flush_entries(
161 wal_dir: &Path,
162 current_wal: &Arc<RwLock<Option<WalFile>>>,
163 write_semaphore: &Arc<Semaphore>,
164 config: &PersistenceConfig,
165 buffer: &mut Vec<WalEntry>,
166 buffer_size: &mut usize,
167 ) {
168 if buffer.is_empty() {
169 return;
170 }
171
172 let _permit = write_semaphore.acquire().await.unwrap();
173
174 let entries: Vec<WalEntry> = buffer.drain(..).collect();
175 *buffer_size = 0;
176
177 let mut wal_guard = current_wal.write().await;
178
179 if wal_guard.is_none() || wal_guard.as_ref().unwrap().size > config.wal_max_size {
180 if let Some(mut old_wal) = wal_guard.take() {
181 let _ = old_wal.writer.flush();
182 }
183 *wal_guard = Self::create_new_wal(wal_dir).await;
184 }
185
186 if let Some(wal) = wal_guard.as_mut() {
187 for entry in &entries {
188 if let Ok(data) = entry.serialize() {
189 let len = data.len() as u32;
190 if wal.writer.write_all(&len.to_le_bytes()).is_ok()
191 && wal.writer.write_all(&data).is_ok()
192 {
193 wal.size += 4 + data.len();
194 wal.entry_count += 1;
195 }
196 }
197 }
198 let _ = wal.writer.flush();
199 }
200 }
201
202 async fn create_new_wal(wal_dir: &Path) -> Option<WalFile> {
203 let timestamp = std::time::SystemTime::now()
204 .duration_since(std::time::UNIX_EPOCH)
205 .unwrap()
206 .as_secs();
207 let filename = format!("wal_{}.akv", timestamp);
208 let path = wal_dir.join(filename);
209
210 let file = match OpenOptions::new()
211 .create(true)
212 .write(true)
213 .truncate(true)
214 .open(&path)
215 {
216 Ok(f) => f,
217 Err(e) => {
218 eprintln!("Failed to create WAL file: {}", e);
219 return None;
220 }
221 };
222
223 let mut writer = BufWriter::new(file);
224
225 if writer.write_all(WAL_MAGIC).is_ok()
226 && writer.write_all(&WAL_VERSION.to_le_bytes()).is_ok()
227 {
228 Some(WalFile {
229 path,
230 writer,
231 size: 8,
232 entry_count: 0,
233 })
234 } else {
235 None
236 }
237 }
238
239 pub async fn append(&self, entry: WalEntry) -> Result<(), Box<dyn Error>> {
240 self.entry_sender.send(entry).await?;
241 Ok(())
242 }
243
244 pub async fn append_batch(&self, entries: Vec<WalEntry>) -> Result<(), Box<dyn Error>> {
245 for entry in entries {
246 self.entry_sender.send(entry).await?;
247 }
248 Ok(())
249 }
250
251 pub async fn recover(&self) -> Result<Vec<WalEntry>, Box<dyn Error>> {
252 let mut entries = Vec::new();
253 let mut wal_files: Vec<(u64, PathBuf)> = Vec::new();
254
255 let mut dir = fs::read_dir(&self.wal_dir).await?;
256 while let Some(entry) = dir.next_entry().await? {
257 let path = entry.path();
258 if path.extension().map_or(false, |e| e == "akv") {
259 if let Some(filename) = path.file_stem().and_then(|s| s.to_str()) {
260 if filename.starts_with("wal_") {
261 if let Ok(timestamp) = filename[4..].parse::<u64>() {
262 wal_files.push((timestamp, path));
263 }
264 }
265 }
266 }
267 }
268
269 wal_files.sort_by_key(|(ts, _)| *ts);
270
271 for (_, path) in wal_files {
272 if let Ok(file_entries) = Self::read_wal_file(&path).await {
273 entries.extend(file_entries);
274 }
275 }
276
277 Ok(entries)
278 }
279
280 async fn read_wal_file(path: &Path) -> Result<Vec<WalEntry>, Box<dyn Error>> {
281 let file = File::open(path)?;
282 let mut reader = BufReader::new(file);
283
284 let mut magic = [0u8; 4];
285 reader.read_exact(&mut magic)?;
286 if &magic != WAL_MAGIC {
287 return Ok(Vec::new());
288 }
289
290 let mut version_bytes = [0u8; 4];
291 reader.read_exact(&mut version_bytes)?;
292 let _version = u32::from_le_bytes(version_bytes);
293
294 let mut entries = Vec::new();
295
296 loop {
297 let mut len_bytes = [0u8; 4];
298 match reader.read_exact(&mut len_bytes) {
299 Ok(_) => {}
300 Err(_) if entries.is_empty() => return Ok(entries),
301 Err(_) => break,
302 }
303
304 let len = u32::from_le_bytes(len_bytes) as usize;
305 let mut data = vec![0u8; len];
306 if reader.read_exact(&mut data).is_err() {
307 break;
308 }
309
310 if let Ok(entry) = WalEntry::deserialize(&data) {
311 entries.push(entry);
312 }
313 }
314
315 Ok(entries)
316 }
317
318 pub async fn cleanup_old_wals(&self, snapshot_timestamp: u64) -> Result<(), Box<dyn Error>> {
319 let mut dir = fs::read_dir(&self.wal_dir).await?;
320 let mut wal_files: Vec<(u64, PathBuf)> = Vec::new();
321
322 while let Some(entry) = dir.next_entry().await? {
323 let path = entry.path();
324 if path.extension().map_or(false, |e| e == "akv") {
325 if let Some(filename) = path.file_stem().and_then(|s| s.to_str()) {
326 if filename.starts_with("wal_") {
327 if let Ok(timestamp) = filename[4..].parse::<u64>() {
328 wal_files.push((timestamp, path));
329 }
330 }
331 }
332 }
333 }
334
335 wal_files.sort_by_key(|(ts, _)| *ts);
336
337 let to_remove = wal_files
338 .into_iter()
339 .filter(|(ts, _)| *ts < snapshot_timestamp)
340 .take(self.config.wal_rotation_count);
341
342 for (_, path) in to_remove {
343 let _ = fs::remove_file(path).await;
344 }
345
346 Ok(())
347 }
348
349 pub async fn get_entry_count(&self) -> usize {
350 let wal_guard = self.current_wal.read().await;
351 wal_guard.as_ref().map(|w| w.entry_count).unwrap_or(0)
352 }
353}
354
#[cfg(test)]
mod tests {
    use super::*;
    use crate::value::Value;

    /// Round-trips a `Set` entry through serialize/deserialize and checks its keys.
    #[test]
    fn test_wal_entry_serialize() {
        let original = WalEntry::Set {
            db: String::from("default"),
            key: String::from("test"),
            entry: Entry {
                value: Value::String(String::from("value")),
                expires_at: None,
            },
        };

        let bytes = original.serialize().expect("serialization should succeed");
        let decoded = WalEntry::deserialize(&bytes).expect("deserialization should succeed");

        if let WalEntry::Set { db, key, .. } = decoded {
            assert_eq!(db, "default");
            assert_eq!(key, "test");
        } else {
            panic!("Unexpected entry type");
        }
    }
}