amq/persistence/
wal.rs

1use std::{
2    env,
3    error::Error,
4    ffi::OsString,
5    fs::{File, OpenOptions},
6    io::{BufReader, BufWriter, Read, Write},
7    path::{Path, PathBuf},
8    sync::Arc,
9};
10
11use bincode::{Decode, Encode};
12use tokio::{
13    fs,
14    sync::{mpsc, RwLock, Semaphore},
15};
16
17use crate::message::{MessageBox, MessageHistory, MessageStatus};
18
19use super::config::PersistenceConfig;
20
/// Four-byte marker written at the start of every WAL file.
const WAL_MAGIC: &[u8; 4] = &[b'A', b'M', b'Q', b'W'];
/// On-disk format version, stored as a little-endian `u32` right after [`WAL_MAGIC`].
const WAL_VERSION: u32 = 1;
23
24#[derive(Debug, Clone, Decode, Encode)]
25pub enum WalEntry {
26    AddMessage(MessageBox),
27    UpdateStatus { id: u64, status: MessageStatus },
28    AddHistory { id: u64, history: MessageHistory },
29    RemoveMessage(u64),
30}
31
32impl WalEntry {
33    pub fn serialize(&self) -> Result<Vec<u8>, Box<dyn Error>> {
34        let config = bincode::config::standard()
35            .with_variable_int_encoding()
36            .with_little_endian();
37        Ok(bincode::encode_to_vec(self, config)?)
38    }
39
40    pub fn deserialize(bytes: &[u8]) -> Result<Self, Box<dyn Error>> {
41        let config = bincode::config::standard()
42            .with_variable_int_encoding()
43            .with_little_endian();
44        Ok(bincode::decode_from_slice(bytes, config)?.0)
45    }
46}
47
48#[derive(Clone)]
49pub struct WalManager {
50    config: PersistenceConfig,
51    wal_dir: PathBuf,
52    current_wal: Arc<RwLock<Option<WalFile>>>,
53    entry_sender: mpsc::Sender<WalEntry>,
54    write_semaphore: Arc<Semaphore>,
55}
56
/// An open WAL file plus the bookkeeping used to decide when to rotate it.
struct WalFile {
    /// Location of the file on disk.
    // Not every version of this module reads the field; the allow keeps it warning-free.
    #[allow(dead_code)]
    path: PathBuf,
    /// Buffered writer that appends length-prefixed records.
    writer: BufWriter<File>,
    /// Bytes written so far, including the 8-byte header; compared with `wal_max_size`.
    size: usize,
    /// Number of records written to this file.
    entry_count: usize,
}
64
65impl WalManager {
66    pub async fn new(config: PersistenceConfig) -> Result<Self, Box<dyn Error>> {
67        let home_dir = env::var_os("HOME")
68            .or_else(|| env::var_os("USERPROFILE"))
69            .unwrap_or(OsString::from("./"));
70        let wal_dir = Path::new(&home_dir).join(".ahriknow/ahrimq/wal");
71        fs::create_dir_all(&wal_dir).await?;
72
73        let (entry_sender, entry_receiver) = mpsc::channel(config.buffer_max_entries);
74        let write_semaphore = Arc::new(Semaphore::new(1));
75
76        let manager = Self {
77            config: config.clone(),
78            wal_dir,
79            current_wal: Arc::new(RwLock::new(None)),
80            entry_sender,
81            write_semaphore,
82        };
83
84        manager.start_background_writer(entry_receiver).await?;
85
86        Ok(manager)
87    }
88
89    async fn start_background_writer(
90        &self,
91        mut entry_receiver: mpsc::Receiver<WalEntry>,
92    ) -> Result<(), Box<dyn Error>> {
93        let config = self.config.clone();
94        let wal_dir = self.wal_dir.clone();
95        let current_wal = Arc::clone(&self.current_wal);
96        let write_semaphore = Arc::clone(&self.write_semaphore);
97
98        tokio::spawn(async move {
99            let mut buffer = Vec::with_capacity(config.buffer_max_entries);
100            let mut buffer_size = 0;
101            let mut last_flush = tokio::time::Instant::now();
102
103            while let Some(entry) = entry_receiver.recv().await {
104                let serialized = match entry.serialize() {
105                    Ok(data) => data,
106                    Err(e) => {
107                        eprintln!("Failed to serialize WAL entry: {}", e);
108                        continue;
109                    }
110                };
111
112                let entry_size = serialized.len();
113                if buffer_size + entry_size > config.buffer_max_size
114                    || buffer.len() >= config.buffer_max_entries
115                {
116                    Self::flush_entries(
117                        &wal_dir,
118                        &current_wal,
119                        &write_semaphore,
120                        &config,
121                        &mut buffer,
122                        &mut buffer_size,
123                    )
124                    .await;
125                    last_flush = tokio::time::Instant::now();
126                }
127
128                buffer.push(entry);
129                buffer_size += entry_size;
130
131                if last_flush.elapsed() >= config.flush_interval {
132                    Self::flush_entries(
133                        &wal_dir,
134                        &current_wal,
135                        &write_semaphore,
136                        &config,
137                        &mut buffer,
138                        &mut buffer_size,
139                    )
140                    .await;
141                    last_flush = tokio::time::Instant::now();
142                }
143            }
144
145            if !buffer.is_empty() {
146                Self::flush_entries(
147                    &wal_dir,
148                    &current_wal,
149                    &write_semaphore,
150                    &config,
151                    &mut buffer,
152                    &mut buffer_size,
153                )
154                .await;
155            }
156        });
157
158        Ok(())
159    }
160
161    async fn flush_entries(
162        wal_dir: &Path,
163        current_wal: &Arc<RwLock<Option<WalFile>>>,
164        write_semaphore: &Arc<Semaphore>,
165        config: &PersistenceConfig,
166        buffer: &mut Vec<WalEntry>,
167        buffer_size: &mut usize,
168    ) {
169        if buffer.is_empty() {
170            return;
171        }
172
173        let _permit = write_semaphore.acquire().await.unwrap();
174
175        let entries: Vec<WalEntry> = buffer.drain(..).collect();
176        *buffer_size = 0;
177
178        let mut wal_guard = current_wal.write().await;
179
180        if wal_guard.is_none() || wal_guard.as_ref().unwrap().size > config.wal_max_size {
181            if let Some(mut old_wal) = wal_guard.take() {
182                let _ = old_wal.writer.flush();
183            }
184            *wal_guard = Self::create_new_wal(wal_dir).await;
185        }
186
187        if let Some(wal) = wal_guard.as_mut() {
188            for entry in &entries {
189                if let Ok(data) = entry.serialize() {
190                    let len = data.len() as u32;
191                    if wal.writer.write_all(&len.to_le_bytes()).is_ok()
192                        && wal.writer.write_all(&data).is_ok()
193                    {
194                        wal.size += 4 + data.len();
195                        wal.entry_count += 1;
196                    }
197                }
198            }
199            let _ = wal.writer.flush();
200        }
201    }
202
203    async fn create_new_wal(wal_dir: &Path) -> Option<WalFile> {
204        let timestamp = std::time::SystemTime::now()
205            .duration_since(std::time::UNIX_EPOCH)
206            .unwrap()
207            .as_secs();
208        let filename = format!("wal_{}.amq", timestamp);
209        let path = wal_dir.join(filename);
210
211        let file = match OpenOptions::new()
212            .create(true)
213            .write(true)
214            .truncate(true)
215            .open(&path)
216        {
217            Ok(f) => f,
218            Err(e) => {
219                eprintln!("Failed to create WAL file: {}", e);
220                return None;
221            }
222        };
223
224        let mut writer = BufWriter::new(file);
225
226        if writer.write_all(WAL_MAGIC).is_ok()
227            && writer.write_all(&WAL_VERSION.to_le_bytes()).is_ok()
228        {
229            Some(WalFile {
230                path,
231                writer,
232                size: 8,
233                entry_count: 0,
234            })
235        } else {
236            None
237        }
238    }
239
240    pub async fn append(&self, entry: WalEntry) -> Result<(), Box<dyn Error>> {
241        self.entry_sender.send(entry).await?;
242        Ok(())
243    }
244
245    pub async fn append_batch(&self, entries: Vec<WalEntry>) -> Result<(), Box<dyn Error>> {
246        for entry in entries {
247            self.entry_sender.send(entry).await?;
248        }
249        Ok(())
250    }
251
252    pub async fn recover(&self) -> Result<Vec<WalEntry>, Box<dyn Error>> {
253        let mut entries = Vec::new();
254        let mut wal_files: Vec<(u64, PathBuf)> = Vec::new();
255
256        let mut dir = fs::read_dir(&self.wal_dir).await?;
257        while let Some(entry) = dir.next_entry().await? {
258            let path = entry.path();
259            if path.extension().map_or(false, |e| e == "amq") {
260                if let Some(filename) = path.file_stem().and_then(|s| s.to_str()) {
261                    if filename.starts_with("wal_") {
262                        if let Ok(timestamp) = filename[4..].parse::<u64>() {
263                            wal_files.push((timestamp, path));
264                        }
265                    }
266                }
267            }
268        }
269
270        wal_files.sort_by_key(|(ts, _)| *ts);
271
272        for (_, path) in wal_files {
273            if let Ok(file_entries) = Self::read_wal_file(&path).await {
274                entries.extend(file_entries);
275            }
276        }
277
278        Ok(entries)
279    }
280
281    async fn read_wal_file(path: &Path) -> Result<Vec<WalEntry>, Box<dyn Error>> {
282        let file = File::open(path)?;
283        let mut reader = BufReader::new(file);
284
285        let mut magic = [0u8; 4];
286        reader.read_exact(&mut magic)?;
287        if &magic != WAL_MAGIC {
288            return Ok(Vec::new());
289        }
290
291        let mut version_bytes = [0u8; 4];
292        reader.read_exact(&mut version_bytes)?;
293        let _version = u32::from_le_bytes(version_bytes);
294
295        let mut entries = Vec::new();
296
297        loop {
298            let mut len_bytes = [0u8; 4];
299            match reader.read_exact(&mut len_bytes) {
300                Ok(_) => {}
301                Err(_) if entries.is_empty() => return Ok(entries),
302                Err(_) => break,
303            }
304
305            let len = u32::from_le_bytes(len_bytes) as usize;
306            let mut data = vec![0u8; len];
307            if reader.read_exact(&mut data).is_err() {
308                break;
309            }
310
311            if let Ok(entry) = WalEntry::deserialize(&data) {
312                entries.push(entry);
313            }
314        }
315
316        Ok(entries)
317    }
318
319    pub async fn cleanup_old_wals(&self, snapshot_timestamp: u64) -> Result<(), Box<dyn Error>> {
320        let mut dir = fs::read_dir(&self.wal_dir).await?;
321        let mut wal_files: Vec<(u64, PathBuf)> = Vec::new();
322
323        while let Some(entry) = dir.next_entry().await? {
324            let path = entry.path();
325            if path.extension().map_or(false, |e| e == "amq") {
326                if let Some(filename) = path.file_stem().and_then(|s| s.to_str()) {
327                    if filename.starts_with("wal_") {
328                        if let Ok(timestamp) = filename[4..].parse::<u64>() {
329                            wal_files.push((timestamp, path));
330                        }
331                    }
332                }
333            }
334        }
335
336        wal_files.sort_by_key(|(ts, _)| *ts);
337
338        let to_remove = wal_files
339            .into_iter()
340            .filter(|(ts, _)| *ts < snapshot_timestamp)
341            .take(self.config.wal_rotation_count);
342
343        for (_, path) in to_remove {
344            let _ = fs::remove_file(path).await;
345        }
346
347        Ok(())
348    }
349
350    pub async fn get_entry_count(&self) -> usize {
351        let wal_guard = self.current_wal.read().await;
352        wal_guard.as_ref().map(|w| w.entry_count).unwrap_or(0)
353    }
354}
355
#[cfg(test)]
mod tests {
    use super::*;

    /// Round-trips an `AddMessage` entry through `serialize`/`deserialize` and checks that
    /// the variant and the message id survive.
    #[test]
    fn test_wal_entry_serialize() {
        let original = WalEntry::AddMessage(MessageBox {
            id: 1,
            status: MessageStatus::New,
            timestamp: 0,
            message: crate::message::Message::ReqPing(crate::message::ReqMsgPing {}),
            history: vec![],
        });

        let bytes = original.serialize().unwrap();
        let decoded = WalEntry::deserialize(&bytes).unwrap();

        if let WalEntry::AddMessage(msg) = decoded {
            assert_eq!(msg.id, 1);
        } else {
            panic!("Unexpected entry type");
        }
    }
}