akv/persistence/
wal.rs

1use std::{
2    env,
3    error::Error,
4    ffi::OsString,
5    fs::{File, OpenOptions},
6    io::{BufReader, BufWriter, Read, Write},
7    path::{Path, PathBuf},
8    sync::Arc,
9};
10
11use bincode::{Decode, Encode};
12use tokio::{
13    fs,
14    sync::{mpsc, RwLock, Semaphore},
15};
16
17use crate::value::Entry;
18
19use super::config::PersistenceConfig;
20
21const WAL_MAGIC: &[u8; 4] = b"AKVW";
22const WAL_VERSION: u32 = 1;
23
24#[derive(Debug, Clone, Decode, Encode)]
25pub enum WalEntry {
26    Set { db: String, key: String, entry: Entry },
27    Delete { db: String, key: String },
28    Expire { db: String, key: String, expires_at: u64 },
29}
30
31impl WalEntry {
32    pub fn serialize(&self) -> Result<Vec<u8>, Box<dyn Error>> {
33        let config = bincode::config::standard()
34            .with_variable_int_encoding()
35            .with_little_endian();
36        Ok(bincode::encode_to_vec(self, config)?)
37    }
38
39    pub fn deserialize(bytes: &[u8]) -> Result<Self, Box<dyn Error>> {
40        let config = bincode::config::standard()
41            .with_variable_int_encoding()
42            .with_little_endian();
43        Ok(bincode::decode_from_slice(bytes, config)?.0)
44    }
45}
46
47#[derive(Clone)]
48pub struct WalManager {
49    config: PersistenceConfig,
50    wal_dir: PathBuf,
51    current_wal: Arc<RwLock<Option<WalFile>>>,
52    entry_sender: mpsc::Sender<WalEntry>,
53    write_semaphore: Arc<Semaphore>,
54}
55
/// An open WAL file together with its running counters.
struct WalFile {
    /// Location of the file on disk; kept even though the code visible in this
    /// module never reads it back, hence the lint allowance.
    #[allow(dead_code)]
    path: PathBuf,
    /// Buffered writer over the underlying file.
    writer: BufWriter<File>,
    /// Bytes accounted for so far (8-byte header plus framed records).
    size: usize,
    /// Number of records successfully appended to this file.
    entry_count: usize,
}
63
64impl WalManager {
65    pub async fn new(config: PersistenceConfig) -> Result<Self, Box<dyn Error>> {
66        let home_dir = env::var_os("HOME")
67            .or_else(|| env::var_os("USERPROFILE"))
68            .unwrap_or(OsString::from("./"));
69        let wal_dir = Path::new(&home_dir).join(".ahriknow/ahrikv/wal");
70        fs::create_dir_all(&wal_dir).await?;
71
72        let (entry_sender, entry_receiver) = mpsc::channel(config.buffer_max_entries);
73        let write_semaphore = Arc::new(Semaphore::new(1));
74
75        let manager = Self {
76            config: config.clone(),
77            wal_dir,
78            current_wal: Arc::new(RwLock::new(None)),
79            entry_sender,
80            write_semaphore,
81        };
82
83        manager.start_background_writer(entry_receiver).await?;
84
85        Ok(manager)
86    }
87
88    async fn start_background_writer(
89        &self,
90        mut entry_receiver: mpsc::Receiver<WalEntry>,
91    ) -> Result<(), Box<dyn Error>> {
92        let config = self.config.clone();
93        let wal_dir = self.wal_dir.clone();
94        let current_wal = Arc::clone(&self.current_wal);
95        let write_semaphore = Arc::clone(&self.write_semaphore);
96
97        tokio::spawn(async move {
98            let mut buffer = Vec::with_capacity(config.buffer_max_entries);
99            let mut buffer_size = 0;
100            let mut last_flush = tokio::time::Instant::now();
101
102            while let Some(entry) = entry_receiver.recv().await {
103                let serialized = match entry.serialize() {
104                    Ok(data) => data,
105                    Err(e) => {
106                        eprintln!("Failed to serialize WAL entry: {}", e);
107                        continue;
108                    }
109                };
110
111                let entry_size = serialized.len();
112                if buffer_size + entry_size > config.buffer_max_size
113                    || buffer.len() >= config.buffer_max_entries
114                {
115                    Self::flush_entries(
116                        &wal_dir,
117                        &current_wal,
118                        &write_semaphore,
119                        &config,
120                        &mut buffer,
121                        &mut buffer_size,
122                    )
123                    .await;
124                    last_flush = tokio::time::Instant::now();
125                }
126
127                buffer.push(entry);
128                buffer_size += entry_size;
129
130                if last_flush.elapsed() >= config.flush_interval {
131                    Self::flush_entries(
132                        &wal_dir,
133                        &current_wal,
134                        &write_semaphore,
135                        &config,
136                        &mut buffer,
137                        &mut buffer_size,
138                    )
139                    .await;
140                    last_flush = tokio::time::Instant::now();
141                }
142            }
143
144            if !buffer.is_empty() {
145                Self::flush_entries(
146                    &wal_dir,
147                    &current_wal,
148                    &write_semaphore,
149                    &config,
150                    &mut buffer,
151                    &mut buffer_size,
152                )
153                .await;
154            }
155        });
156
157        Ok(())
158    }
159
160    async fn flush_entries(
161        wal_dir: &Path,
162        current_wal: &Arc<RwLock<Option<WalFile>>>,
163        write_semaphore: &Arc<Semaphore>,
164        config: &PersistenceConfig,
165        buffer: &mut Vec<WalEntry>,
166        buffer_size: &mut usize,
167    ) {
168        if buffer.is_empty() {
169            return;
170        }
171
172        let _permit = write_semaphore.acquire().await.unwrap();
173
174        let entries: Vec<WalEntry> = buffer.drain(..).collect();
175        *buffer_size = 0;
176
177        let mut wal_guard = current_wal.write().await;
178
179        if wal_guard.is_none() || wal_guard.as_ref().unwrap().size > config.wal_max_size {
180            if let Some(mut old_wal) = wal_guard.take() {
181                let _ = old_wal.writer.flush();
182            }
183            *wal_guard = Self::create_new_wal(wal_dir).await;
184        }
185
186        if let Some(wal) = wal_guard.as_mut() {
187            for entry in &entries {
188                if let Ok(data) = entry.serialize() {
189                    let len = data.len() as u32;
190                    if wal.writer.write_all(&len.to_le_bytes()).is_ok()
191                        && wal.writer.write_all(&data).is_ok()
192                    {
193                        wal.size += 4 + data.len();
194                        wal.entry_count += 1;
195                    }
196                }
197            }
198            let _ = wal.writer.flush();
199        }
200    }
201
202    async fn create_new_wal(wal_dir: &Path) -> Option<WalFile> {
203        let timestamp = std::time::SystemTime::now()
204            .duration_since(std::time::UNIX_EPOCH)
205            .unwrap()
206            .as_secs();
207        let filename = format!("wal_{}.akv", timestamp);
208        let path = wal_dir.join(filename);
209
210        let file = match OpenOptions::new()
211            .create(true)
212            .write(true)
213            .truncate(true)
214            .open(&path)
215        {
216            Ok(f) => f,
217            Err(e) => {
218                eprintln!("Failed to create WAL file: {}", e);
219                return None;
220            }
221        };
222
223        let mut writer = BufWriter::new(file);
224
225        if writer.write_all(WAL_MAGIC).is_ok()
226            && writer.write_all(&WAL_VERSION.to_le_bytes()).is_ok()
227        {
228            Some(WalFile {
229                path,
230                writer,
231                size: 8,
232                entry_count: 0,
233            })
234        } else {
235            None
236        }
237    }
238
239    pub async fn append(&self, entry: WalEntry) -> Result<(), Box<dyn Error>> {
240        self.entry_sender.send(entry).await?;
241        Ok(())
242    }
243
244    pub async fn append_batch(&self, entries: Vec<WalEntry>) -> Result<(), Box<dyn Error>> {
245        for entry in entries {
246            self.entry_sender.send(entry).await?;
247        }
248        Ok(())
249    }
250
251    pub async fn recover(&self) -> Result<Vec<WalEntry>, Box<dyn Error>> {
252        let mut entries = Vec::new();
253        let mut wal_files: Vec<(u64, PathBuf)> = Vec::new();
254
255        let mut dir = fs::read_dir(&self.wal_dir).await?;
256        while let Some(entry) = dir.next_entry().await? {
257            let path = entry.path();
258            if path.extension().map_or(false, |e| e == "akv") {
259                if let Some(filename) = path.file_stem().and_then(|s| s.to_str()) {
260                    if filename.starts_with("wal_") {
261                        if let Ok(timestamp) = filename[4..].parse::<u64>() {
262                            wal_files.push((timestamp, path));
263                        }
264                    }
265                }
266            }
267        }
268
269        wal_files.sort_by_key(|(ts, _)| *ts);
270
271        for (_, path) in wal_files {
272            if let Ok(file_entries) = Self::read_wal_file(&path).await {
273                entries.extend(file_entries);
274            }
275        }
276
277        Ok(entries)
278    }
279
280    async fn read_wal_file(path: &Path) -> Result<Vec<WalEntry>, Box<dyn Error>> {
281        let file = File::open(path)?;
282        let mut reader = BufReader::new(file);
283
284        let mut magic = [0u8; 4];
285        reader.read_exact(&mut magic)?;
286        if &magic != WAL_MAGIC {
287            return Ok(Vec::new());
288        }
289
290        let mut version_bytes = [0u8; 4];
291        reader.read_exact(&mut version_bytes)?;
292        let _version = u32::from_le_bytes(version_bytes);
293
294        let mut entries = Vec::new();
295
296        loop {
297            let mut len_bytes = [0u8; 4];
298            match reader.read_exact(&mut len_bytes) {
299                Ok(_) => {}
300                Err(_) if entries.is_empty() => return Ok(entries),
301                Err(_) => break,
302            }
303
304            let len = u32::from_le_bytes(len_bytes) as usize;
305            let mut data = vec![0u8; len];
306            if reader.read_exact(&mut data).is_err() {
307                break;
308            }
309
310            if let Ok(entry) = WalEntry::deserialize(&data) {
311                entries.push(entry);
312            }
313        }
314
315        Ok(entries)
316    }
317
318    pub async fn cleanup_old_wals(&self, snapshot_timestamp: u64) -> Result<(), Box<dyn Error>> {
319        let mut dir = fs::read_dir(&self.wal_dir).await?;
320        let mut wal_files: Vec<(u64, PathBuf)> = Vec::new();
321
322        while let Some(entry) = dir.next_entry().await? {
323            let path = entry.path();
324            if path.extension().map_or(false, |e| e == "akv") {
325                if let Some(filename) = path.file_stem().and_then(|s| s.to_str()) {
326                    if filename.starts_with("wal_") {
327                        if let Ok(timestamp) = filename[4..].parse::<u64>() {
328                            wal_files.push((timestamp, path));
329                        }
330                    }
331                }
332            }
333        }
334
335        wal_files.sort_by_key(|(ts, _)| *ts);
336
337        let to_remove = wal_files
338            .into_iter()
339            .filter(|(ts, _)| *ts < snapshot_timestamp)
340            .take(self.config.wal_rotation_count);
341
342        for (_, path) in to_remove {
343            let _ = fs::remove_file(path).await;
344        }
345
346        Ok(())
347    }
348
349    pub async fn get_entry_count(&self) -> usize {
350        let wal_guard = self.current_wal.read().await;
351        wal_guard.as_ref().map(|w| w.entry_count).unwrap_or(0)
352    }
353}
354
#[cfg(test)]
mod tests {
    use super::*;
    use crate::value::Value;

    /// A `Set` entry survives a serialize/deserialize round trip with its
    /// database and key intact.
    #[test]
    fn test_wal_entry_serialize() {
        let original = WalEntry::Set {
            db: "default".to_string(),
            key: "test".to_string(),
            entry: Entry {
                value: Value::String("value".to_string()),
                expires_at: None,
            },
        };

        let bytes = original.serialize().unwrap();
        let round_tripped = WalEntry::deserialize(&bytes).unwrap();

        if let WalEntry::Set { db, key, .. } = round_tripped {
            assert_eq!(db, "default");
            assert_eq!(key, "test");
        } else {
            panic!("Unexpected entry type");
        }
    }
}
382}