1use serde::{Deserialize, Serialize};
42use std::path::PathBuf;
43use thiserror::Error;
44use tokio::fs::{self, OpenOptions};
45use tokio::io::{AsyncReadExt, AsyncWriteExt};
46
47#[derive(Debug, Error)]
49pub enum WalError {
50 #[error("IO error: {0}")]
51 Io(#[from] std::io::Error),
52
53 #[error("Serialization error: {0}")]
54 Serialization(String),
55
56 #[error("Deserialization error: {0}")]
57 Deserialization(String),
58
59 #[error("Corrupted WAL entry at sequence {0}")]
60 CorruptedEntry(u64),
61
62 #[error("Invalid WAL format")]
63 InvalidFormat,
64}
65
66#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
68pub enum Operation {
69 WriteChunk {
71 cid: String,
72 chunk_index: u64,
73 data: Vec<u8>,
74 },
75 DeleteChunk { cid: String, chunk_index: u64 },
77 PinContent { cid: String, chunk_count: u64 },
79 UnpinContent { cid: String },
81 UpdateMetadata { cid: String, metadata: Vec<u8> },
83 Checkpoint { sequence: u64 },
85}
86
87#[derive(Debug, Clone, Serialize, Deserialize)]
89pub struct LogEntry {
90 pub sequence: u64,
92 pub operation: Operation,
94 pub timestamp_ms: i64,
96}
97
98impl LogEntry {
99 #[must_use]
101 pub fn new(sequence: u64, operation: Operation) -> Self {
102 let timestamp_ms = std::time::SystemTime::now()
103 .duration_since(std::time::UNIX_EPOCH)
104 .unwrap_or_default()
105 .as_millis() as i64;
106
107 Self {
108 sequence,
109 operation,
110 timestamp_ms,
111 }
112 }
113
114 #[must_use]
116 #[inline]
117 pub const fn sequence(&self) -> u64 {
118 self.sequence
119 }
120
121 #[must_use]
123 #[inline]
124 pub const fn operation(&self) -> &Operation {
125 &self.operation
126 }
127
128 fn to_bytes(&self) -> Result<Vec<u8>, WalError> {
130 let data = crate::serde_helpers::encode(self)
131 .map_err(|e| WalError::Serialization(e.to_string()))?;
132
133 let len = data.len() as u32;
135 let mut result = Vec::with_capacity(4 + data.len());
136 result.extend_from_slice(&len.to_le_bytes());
137 result.extend_from_slice(&data);
138
139 Ok(result)
140 }
141
142 fn from_bytes(bytes: &[u8]) -> Result<Self, WalError> {
144 crate::serde_helpers::decode(bytes).map_err(|e| WalError::Deserialization(e.to_string()))
145 }
146}
147
/// A write-ahead log stored in a single file.
pub struct WriteAheadLog {
    /// Location of the log file on disk.
    log_path: PathBuf,
    /// Sequence number that will be given to the next logged operation.
    next_sequence: u64,
    /// Sequence number of the most recent checkpoint (`0` when there is none).
    checkpoint_sequence: u64,
}
154
155impl WriteAheadLog {
156 pub async fn new(log_path: PathBuf) -> Result<Self, WalError> {
158 if let Some(parent) = log_path.parent() {
160 fs::create_dir_all(parent).await?;
161 }
162
163 let mut wal = Self {
164 log_path,
165 next_sequence: 1,
166 checkpoint_sequence: 0,
167 };
168
169 if wal.log_path.exists() {
171 let entries = wal.replay().await?;
172 if let Some(last_entry) = entries.last() {
173 wal.next_sequence = last_entry.sequence + 1;
174
175 for entry in entries.iter().rev() {
177 if let Operation::Checkpoint { sequence } = entry.operation {
178 wal.checkpoint_sequence = sequence;
179 break;
180 }
181 }
182 }
183 }
184
185 Ok(wal)
186 }
187
188 pub async fn append(&mut self, entry: &LogEntry) -> Result<(), WalError> {
190 let bytes = entry.to_bytes()?;
191
192 let mut file = OpenOptions::new()
193 .create(true)
194 .append(true)
195 .open(&self.log_path)
196 .await?;
197
198 file.write_all(&bytes).await?;
199 file.sync_all().await?; self.next_sequence = self.next_sequence.max(entry.sequence + 1);
202
203 Ok(())
204 }
205
206 pub async fn log_operation(&mut self, operation: Operation) -> Result<u64, WalError> {
208 let sequence = self.next_sequence;
209 let entry = LogEntry::new(sequence, operation);
210 self.append(&entry).await?;
211 Ok(sequence)
212 }
213
214 pub async fn replay(&self) -> Result<Vec<LogEntry>, WalError> {
218 if !self.log_path.exists() {
219 return Ok(Vec::new());
220 }
221
222 let mut file = fs::File::open(&self.log_path).await?;
223 let mut entries = Vec::new();
224
225 loop {
226 let mut len_bytes = [0u8; 4];
228 match file.read_exact(&mut len_bytes).await {
229 Ok(_) => {}
230 Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => break,
231 Err(e) => return Err(WalError::Io(e)),
232 }
233
234 let len = u32::from_le_bytes(len_bytes) as usize;
235
236 let mut data = vec![0u8; len];
238 file.read_exact(&mut data).await?;
239
240 let entry = LogEntry::from_bytes(&data)?;
242 entries.push(entry);
243 }
244
245 Ok(entries)
246 }
247
248 pub async fn truncate(&mut self, up_to_sequence: u64) -> Result<(), WalError> {
252 let entries = self.replay().await?;
253 let remaining: Vec<LogEntry> = entries
254 .into_iter()
255 .filter(|e| e.sequence > up_to_sequence)
256 .collect();
257
258 if self.log_path.exists() {
260 fs::remove_file(&self.log_path).await?;
261 }
262
263 for entry in &remaining {
264 self.append(entry).await?;
265 }
266
267 self.checkpoint_sequence = up_to_sequence;
268
269 Ok(())
270 }
271
272 pub async fn checkpoint(&mut self) -> Result<u64, WalError> {
274 let sequence = self.next_sequence;
275 let operation = Operation::Checkpoint { sequence };
276 self.log_operation(operation).await?;
277 self.checkpoint_sequence = sequence;
278 Ok(sequence)
279 }
280
281 pub async fn entries_since_checkpoint(&self) -> Result<Vec<LogEntry>, WalError> {
283 let all_entries = self.replay().await?;
284 Ok(all_entries
285 .into_iter()
286 .filter(|e| e.sequence > self.checkpoint_sequence)
287 .collect())
288 }
289
290 #[must_use]
292 #[inline]
293 pub const fn next_sequence(&self) -> u64 {
294 self.next_sequence
295 }
296
297 #[must_use]
299 #[inline]
300 pub const fn checkpoint_sequence(&self) -> u64 {
301 self.checkpoint_sequence
302 }
303
304 pub async fn clear(&mut self) -> Result<(), WalError> {
306 if self.log_path.exists() {
307 fs::remove_file(&self.log_path).await?;
308 }
309 self.next_sequence = 1;
310 self.checkpoint_sequence = 0;
311 Ok(())
312 }
313}
314
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Opens a WAL at `test.wal` inside a fresh temporary directory. The
    /// `TempDir` guard is returned so the directory outlives the test body.
    async fn open_temp_wal() -> (TempDir, WriteAheadLog) {
        let dir = TempDir::new().unwrap();
        let wal = WriteAheadLog::new(dir.path().join("test.wal"))
            .await
            .unwrap();
        (dir, wal)
    }

    /// Shorthand for an `Operation::PinContent`.
    fn pin(cid: &str, chunk_count: u64) -> Operation {
        Operation::PinContent {
            cid: cid.to_string(),
            chunk_count,
        }
    }

    #[tokio::test]
    async fn test_wal_creation() {
        let (_dir, wal) = open_temp_wal().await;

        assert_eq!(wal.next_sequence(), 1);
        assert_eq!(wal.checkpoint_sequence(), 0);
    }

    #[tokio::test]
    async fn test_wal_append_and_replay() {
        let (_dir, mut wal) = open_temp_wal().await;

        let ops = [
            Operation::WriteChunk {
                cid: "QmTest1".to_string(),
                chunk_index: 0,
                data: vec![1, 2, 3],
            },
            Operation::WriteChunk {
                cid: "QmTest2".to_string(),
                chunk_index: 1,
                data: vec![4, 5, 6],
            },
        ];
        for op in &ops {
            wal.log_operation(op.clone()).await.unwrap();
        }

        // Entries come back in write order with consecutive sequences.
        let replayed = wal.replay().await.unwrap();
        assert_eq!(replayed.len(), 2);
        assert_eq!(replayed[0].sequence, 1);
        assert_eq!(replayed[1].sequence, 2);
        assert_eq!(replayed[0].operation, ops[0]);
        assert_eq!(replayed[1].operation, ops[1]);
    }

    #[tokio::test]
    async fn test_wal_checkpoint() {
        let (_dir, mut wal) = open_temp_wal().await;

        wal.log_operation(pin("QmTest", 5)).await.unwrap();

        // The checkpoint entry itself takes sequence 2.
        let checkpoint_seq = wal.checkpoint().await.unwrap();
        assert_eq!(checkpoint_seq, 2);
        assert_eq!(wal.checkpoint_sequence(), 2);
    }

    #[tokio::test]
    async fn test_wal_truncate() {
        let (_dir, mut wal) = open_temp_wal().await;

        for idx in 0..5u64 {
            let op = Operation::WriteChunk {
                cid: format!("QmTest{}", idx),
                chunk_index: idx,
                data: vec![idx as u8],
            };
            wal.log_operation(op).await.unwrap();
        }

        // Sequences 1..=3 are dropped; 4 and 5 survive.
        wal.truncate(3).await.unwrap();

        let survivors = wal.replay().await.unwrap();
        assert_eq!(survivors.len(), 2);
        assert_eq!(survivors[0].sequence, 4);
        assert_eq!(survivors[1].sequence, 5);
    }

    #[tokio::test]
    async fn test_wal_entries_since_checkpoint() {
        let (_dir, mut wal) = open_temp_wal().await;

        wal.log_operation(pin("QmTest1", 1)).await.unwrap();
        wal.log_operation(pin("QmTest2", 2)).await.unwrap();

        // The checkpoint occupies sequence 3.
        wal.checkpoint().await.unwrap();

        wal.log_operation(pin("QmTest3", 3)).await.unwrap();

        let newer = wal.entries_since_checkpoint().await.unwrap();
        assert_eq!(newer.len(), 1);
        assert_eq!(newer[0].sequence, 4);
    }

    #[tokio::test]
    async fn test_wal_persistence() {
        let dir = TempDir::new().unwrap();
        let log_path = dir.path().join("test.wal");

        // Write one entry, then drop the first handle.
        {
            let mut first = WriteAheadLog::new(log_path.clone()).await.unwrap();
            first.log_operation(pin("QmPersist", 10)).await.unwrap();
        }

        // Reopening restores the sequence counter from the file.
        let reopened = WriteAheadLog::new(log_path).await.unwrap();
        assert_eq!(reopened.next_sequence(), 2);

        let entries = reopened.replay().await.unwrap();
        assert_eq!(entries.len(), 1);
    }

    #[tokio::test]
    async fn test_wal_clear() {
        let (_dir, mut wal) = open_temp_wal().await;

        for idx in 0..3u64 {
            let op = Operation::DeleteChunk {
                cid: format!("QmTest{}", idx),
                chunk_index: idx,
            };
            wal.log_operation(op).await.unwrap();
        }

        wal.clear().await.unwrap();

        let entries = wal.replay().await.unwrap();
        assert_eq!(entries.len(), 0);
        assert_eq!(wal.next_sequence(), 1);
    }
}