1use std::{
2 env,
3 error::Error,
4 ffi::OsString,
5 fs::{File, OpenOptions},
6 io::{BufReader, BufWriter, Read, Write},
7 path::{Path, PathBuf},
8 sync::Arc,
9};
10
11use bincode::{Decode, Encode};
12use tokio::{
13 fs,
14 sync::{mpsc, RwLock, Semaphore},
15};
16
17use crate::message::{MessageBox, MessageHistory, MessageStatus};
18
19use super::config::PersistenceConfig;
20
/// Four-byte marker written at the very start of each WAL file by
/// `create_new_wal`; `read_wal_file` yields no entries for a file whose first
/// four bytes differ from it.
const WAL_MAGIC: &[u8; 4] = b"AMQW";
/// Format version stored as a little-endian `u32` directly after the magic
/// marker. The reader currently parses this field but does not act on it.
const WAL_VERSION: u32 = 1;
23
/// A single record stored in the write-ahead log.
///
/// Entries are encoded with bincode (see `serialize` / `deserialize`) and
/// written to WAL files as length-prefixed frames.
#[derive(Debug, Clone, Decode, Encode)]
pub enum WalEntry {
    /// Carries a complete `MessageBox`.
    AddMessage(MessageBox),
    /// Carries a new `MessageStatus` for `id` (presumably the id of a
    /// previously added message; confirm against the code that replays entries).
    UpdateStatus { id: u64, status: MessageStatus },
    /// Carries one `MessageHistory` record for `id` (presumably a message id;
    /// confirm against the code that replays entries).
    AddHistory { id: u64, history: MessageHistory },
    /// Carries the `u64` id of the message to remove.
    RemoveMessage(u64),
}
31
32impl WalEntry {
33 pub fn serialize(&self) -> Result<Vec<u8>, Box<dyn Error>> {
34 let config = bincode::config::standard()
35 .with_variable_int_encoding()
36 .with_little_endian();
37 Ok(bincode::encode_to_vec(self, config)?)
38 }
39
40 pub fn deserialize(bytes: &[u8]) -> Result<Self, Box<dyn Error>> {
41 let config = bincode::config::standard()
42 .with_variable_int_encoding()
43 .with_little_endian();
44 Ok(bincode::decode_from_slice(bytes, config)?.0)
45 }
46}
47
/// Handle to the write-ahead log.
///
/// Entries passed to `append` / `append_batch` are sent over a channel to a
/// background task (spawned from `new`) that buffers them and writes them to
/// WAL files in `wal_dir`. Cloning the manager clones the sender and the
/// `Arc`s, so every clone refers to the same background writer and WAL state.
#[derive(Clone)]
pub struct WalManager {
    /// Persistence settings; this file reads `buffer_max_entries`,
    /// `buffer_max_size`, `flush_interval`, `wal_max_size` and
    /// `wal_rotation_count` from it.
    config: PersistenceConfig,
    /// Directory in which WAL files are created, scanned and removed.
    wal_dir: PathBuf,
    /// The WAL file currently being written, or `None` until the first flush
    /// creates one.
    current_wal: Arc<RwLock<Option<WalFile>>>,
    /// Sending half of the channel that feeds the background writer task.
    entry_sender: mpsc::Sender<WalEntry>,
    /// Semaphore created with a single permit and acquired for the duration of
    /// each flush.
    write_semaphore: Arc<Semaphore>,
}
56
/// State of one open WAL file.
struct WalFile {
    // Stored when the file is created but not read anywhere in this file,
    // hence the lint suppression.
    #[allow(dead_code)]
    path: PathBuf,
    /// Buffered writer over the open file.
    writer: BufWriter<File>,
    /// Bytes handed to `writer` so far: starts at 8 for the header (magic plus
    /// version) and grows by `4 + payload length` per entry.
    size: usize,
    /// Number of entries written to this file so far.
    entry_count: usize,
}
64
65impl WalManager {
66 pub async fn new(config: PersistenceConfig) -> Result<Self, Box<dyn Error>> {
67 let home_dir = env::var_os("HOME")
68 .or_else(|| env::var_os("USERPROFILE"))
69 .unwrap_or(OsString::from("./"));
70 let wal_dir = Path::new(&home_dir).join(".ahriknow/ahrimq/wal");
71 fs::create_dir_all(&wal_dir).await?;
72
73 let (entry_sender, entry_receiver) = mpsc::channel(config.buffer_max_entries);
74 let write_semaphore = Arc::new(Semaphore::new(1));
75
76 let manager = Self {
77 config: config.clone(),
78 wal_dir,
79 current_wal: Arc::new(RwLock::new(None)),
80 entry_sender,
81 write_semaphore,
82 };
83
84 manager.start_background_writer(entry_receiver).await?;
85
86 Ok(manager)
87 }
88
89 async fn start_background_writer(
90 &self,
91 mut entry_receiver: mpsc::Receiver<WalEntry>,
92 ) -> Result<(), Box<dyn Error>> {
93 let config = self.config.clone();
94 let wal_dir = self.wal_dir.clone();
95 let current_wal = Arc::clone(&self.current_wal);
96 let write_semaphore = Arc::clone(&self.write_semaphore);
97
98 tokio::spawn(async move {
99 let mut buffer = Vec::with_capacity(config.buffer_max_entries);
100 let mut buffer_size = 0;
101 let mut last_flush = tokio::time::Instant::now();
102
103 while let Some(entry) = entry_receiver.recv().await {
104 let serialized = match entry.serialize() {
105 Ok(data) => data,
106 Err(e) => {
107 eprintln!("Failed to serialize WAL entry: {}", e);
108 continue;
109 }
110 };
111
112 let entry_size = serialized.len();
113 if buffer_size + entry_size > config.buffer_max_size
114 || buffer.len() >= config.buffer_max_entries
115 {
116 Self::flush_entries(
117 &wal_dir,
118 ¤t_wal,
119 &write_semaphore,
120 &config,
121 &mut buffer,
122 &mut buffer_size,
123 )
124 .await;
125 last_flush = tokio::time::Instant::now();
126 }
127
128 buffer.push(entry);
129 buffer_size += entry_size;
130
131 if last_flush.elapsed() >= config.flush_interval {
132 Self::flush_entries(
133 &wal_dir,
134 ¤t_wal,
135 &write_semaphore,
136 &config,
137 &mut buffer,
138 &mut buffer_size,
139 )
140 .await;
141 last_flush = tokio::time::Instant::now();
142 }
143 }
144
145 if !buffer.is_empty() {
146 Self::flush_entries(
147 &wal_dir,
148 ¤t_wal,
149 &write_semaphore,
150 &config,
151 &mut buffer,
152 &mut buffer_size,
153 )
154 .await;
155 }
156 });
157
158 Ok(())
159 }
160
161 async fn flush_entries(
162 wal_dir: &Path,
163 current_wal: &Arc<RwLock<Option<WalFile>>>,
164 write_semaphore: &Arc<Semaphore>,
165 config: &PersistenceConfig,
166 buffer: &mut Vec<WalEntry>,
167 buffer_size: &mut usize,
168 ) {
169 if buffer.is_empty() {
170 return;
171 }
172
173 let _permit = write_semaphore.acquire().await.unwrap();
174
175 let entries: Vec<WalEntry> = buffer.drain(..).collect();
176 *buffer_size = 0;
177
178 let mut wal_guard = current_wal.write().await;
179
180 if wal_guard.is_none() || wal_guard.as_ref().unwrap().size > config.wal_max_size {
181 if let Some(mut old_wal) = wal_guard.take() {
182 let _ = old_wal.writer.flush();
183 }
184 *wal_guard = Self::create_new_wal(wal_dir).await;
185 }
186
187 if let Some(wal) = wal_guard.as_mut() {
188 for entry in &entries {
189 if let Ok(data) = entry.serialize() {
190 let len = data.len() as u32;
191 if wal.writer.write_all(&len.to_le_bytes()).is_ok()
192 && wal.writer.write_all(&data).is_ok()
193 {
194 wal.size += 4 + data.len();
195 wal.entry_count += 1;
196 }
197 }
198 }
199 let _ = wal.writer.flush();
200 }
201 }
202
203 async fn create_new_wal(wal_dir: &Path) -> Option<WalFile> {
204 let timestamp = std::time::SystemTime::now()
205 .duration_since(std::time::UNIX_EPOCH)
206 .unwrap()
207 .as_secs();
208 let filename = format!("wal_{}.amq", timestamp);
209 let path = wal_dir.join(filename);
210
211 let file = match OpenOptions::new()
212 .create(true)
213 .write(true)
214 .truncate(true)
215 .open(&path)
216 {
217 Ok(f) => f,
218 Err(e) => {
219 eprintln!("Failed to create WAL file: {}", e);
220 return None;
221 }
222 };
223
224 let mut writer = BufWriter::new(file);
225
226 if writer.write_all(WAL_MAGIC).is_ok()
227 && writer.write_all(&WAL_VERSION.to_le_bytes()).is_ok()
228 {
229 Some(WalFile {
230 path,
231 writer,
232 size: 8,
233 entry_count: 0,
234 })
235 } else {
236 None
237 }
238 }
239
240 pub async fn append(&self, entry: WalEntry) -> Result<(), Box<dyn Error>> {
241 self.entry_sender.send(entry).await?;
242 Ok(())
243 }
244
245 pub async fn append_batch(&self, entries: Vec<WalEntry>) -> Result<(), Box<dyn Error>> {
246 for entry in entries {
247 self.entry_sender.send(entry).await?;
248 }
249 Ok(())
250 }
251
252 pub async fn recover(&self) -> Result<Vec<WalEntry>, Box<dyn Error>> {
253 let mut entries = Vec::new();
254 let mut wal_files: Vec<(u64, PathBuf)> = Vec::new();
255
256 let mut dir = fs::read_dir(&self.wal_dir).await?;
257 while let Some(entry) = dir.next_entry().await? {
258 let path = entry.path();
259 if path.extension().map_or(false, |e| e == "amq") {
260 if let Some(filename) = path.file_stem().and_then(|s| s.to_str()) {
261 if filename.starts_with("wal_") {
262 if let Ok(timestamp) = filename[4..].parse::<u64>() {
263 wal_files.push((timestamp, path));
264 }
265 }
266 }
267 }
268 }
269
270 wal_files.sort_by_key(|(ts, _)| *ts);
271
272 for (_, path) in wal_files {
273 if let Ok(file_entries) = Self::read_wal_file(&path).await {
274 entries.extend(file_entries);
275 }
276 }
277
278 Ok(entries)
279 }
280
281 async fn read_wal_file(path: &Path) -> Result<Vec<WalEntry>, Box<dyn Error>> {
282 let file = File::open(path)?;
283 let mut reader = BufReader::new(file);
284
285 let mut magic = [0u8; 4];
286 reader.read_exact(&mut magic)?;
287 if &magic != WAL_MAGIC {
288 return Ok(Vec::new());
289 }
290
291 let mut version_bytes = [0u8; 4];
292 reader.read_exact(&mut version_bytes)?;
293 let _version = u32::from_le_bytes(version_bytes);
294
295 let mut entries = Vec::new();
296
297 loop {
298 let mut len_bytes = [0u8; 4];
299 match reader.read_exact(&mut len_bytes) {
300 Ok(_) => {}
301 Err(_) if entries.is_empty() => return Ok(entries),
302 Err(_) => break,
303 }
304
305 let len = u32::from_le_bytes(len_bytes) as usize;
306 let mut data = vec![0u8; len];
307 if reader.read_exact(&mut data).is_err() {
308 break;
309 }
310
311 if let Ok(entry) = WalEntry::deserialize(&data) {
312 entries.push(entry);
313 }
314 }
315
316 Ok(entries)
317 }
318
319 pub async fn cleanup_old_wals(&self, snapshot_timestamp: u64) -> Result<(), Box<dyn Error>> {
320 let mut dir = fs::read_dir(&self.wal_dir).await?;
321 let mut wal_files: Vec<(u64, PathBuf)> = Vec::new();
322
323 while let Some(entry) = dir.next_entry().await? {
324 let path = entry.path();
325 if path.extension().map_or(false, |e| e == "amq") {
326 if let Some(filename) = path.file_stem().and_then(|s| s.to_str()) {
327 if filename.starts_with("wal_") {
328 if let Ok(timestamp) = filename[4..].parse::<u64>() {
329 wal_files.push((timestamp, path));
330 }
331 }
332 }
333 }
334 }
335
336 wal_files.sort_by_key(|(ts, _)| *ts);
337
338 let to_remove = wal_files
339 .into_iter()
340 .filter(|(ts, _)| *ts < snapshot_timestamp)
341 .take(self.config.wal_rotation_count);
342
343 for (_, path) in to_remove {
344 let _ = fs::remove_file(path).await;
345 }
346
347 Ok(())
348 }
349
350 pub async fn get_entry_count(&self) -> usize {
351 let wal_guard = self.current_wal.read().await;
352 wal_guard.as_ref().map(|w| w.entry_count).unwrap_or(0)
353 }
354}
355
#[cfg(test)]
mod tests {
    use super::*;
    use crate::message::{Message, ReqMsgPing};

    /// An `AddMessage` entry survives a serialize/deserialize round trip with
    /// its variant and message id intact.
    #[test]
    fn test_wal_entry_serialize() {
        let message_box = MessageBox {
            id: 1,
            status: MessageStatus::New,
            timestamp: 0,
            message: Message::ReqPing(ReqMsgPing {}),
            history: vec![],
        };

        let encoded = WalEntry::AddMessage(message_box).serialize().unwrap();
        let decoded = WalEntry::deserialize(&encoded).unwrap();

        if let WalEntry::AddMessage(msg) = decoded {
            assert_eq!(msg.id, 1);
        } else {
            panic!("Unexpected entry type");
        }
    }
}