use serde::{Deserialize, Serialize};
use std::path::PathBuf;
use thiserror::Error;
use tokio::fs::{self, OpenOptions};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
/// Errors reported by the write-ahead log and its entry encoding.
#[derive(Debug, Error)]
pub enum WalError {
    /// A filesystem operation failed.
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),

    /// An entry could not be encoded; carries the encoder's message.
    #[error("Serialization error: {0}")]
    Serialization(String),

    /// Stored bytes could not be decoded into an entry; carries the decoder's message.
    #[error("Deserialization error: {0}")]
    Deserialization(String),

    /// The log does not follow the expected WAL format.
    #[error("Invalid WAL format")]
    InvalidFormat,

    /// The entry with the given sequence number is corrupted.
    #[error("Corrupted WAL entry at sequence {0}")]
    CorruptedEntry(u64),
}
/// A single mutation recorded in the write-ahead log.
///
/// NOTE(review): values are persisted through `crate::serde_helpers::encode`. If that
/// is an index-based format (e.g. bincode), then reordering, inserting or removing
/// variants or fields changes the on-disk format and breaks replay of existing logs.
/// Confirm before editing.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum Operation {
/// Store `data` as chunk number `chunk_index` of the content identified by `cid`.
/// (`cid` is presumably a content identifier; the tests use IPFS-style values such
/// as `QmTest1` — TODO confirm.)
WriteChunk {
cid: String,
chunk_index: u64,
data: Vec<u8>,
},
/// Remove chunk number `chunk_index` of the content identified by `cid`.
DeleteChunk { cid: String, chunk_index: u64 },
/// Pin the content identified by `cid`, which consists of `chunk_count` chunks.
PinContent { cid: String, chunk_count: u64 },
/// Unpin the content identified by `cid`.
UnpinContent { cid: String },
/// Replace the metadata of `cid` with the opaque `metadata` bytes.
UpdateMetadata { cid: String, metadata: Vec<u8> },
/// Checkpoint marker. `WriteAheadLog::checkpoint` writes it with `sequence` equal to
/// the marker entry's own sequence number, and `WriteAheadLog::new` restores
/// `checkpoint_sequence` from the last such marker in the log.
Checkpoint { sequence: u64 },
}
/// One record of the write-ahead log: an [`Operation`] tagged with its sequence
/// number and creation time.
///
/// NOTE(review): serialized via `crate::serde_helpers::encode`. If that format is
/// positional, field order is part of the on-disk layout — confirm before reordering.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LogEntry {
/// Sequence number of this entry within the log.
pub sequence: u64,
/// The recorded mutation.
pub operation: Operation,
/// Creation time in milliseconds since the Unix epoch, as set by `LogEntry::new`.
pub timestamp_ms: i64,
}
impl LogEntry {
    /// Creates a log entry for `operation` with the given `sequence` number.
    ///
    /// The entry is stamped with the current wall-clock time in milliseconds since
    /// the Unix epoch. A system clock set before the epoch yields 0. A value too
    /// large for `i64` saturates at `i64::MAX` instead of wrapping negative.
    #[must_use]
    pub fn new(sequence: u64, operation: Operation) -> Self {
        let timestamp_ms = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map_or(0, |elapsed| {
                i64::try_from(elapsed.as_millis()).unwrap_or(i64::MAX)
            });
        Self {
            sequence,
            operation,
            timestamp_ms,
        }
    }

    /// Returns this entry's sequence number.
    #[must_use]
    #[inline]
    pub const fn sequence(&self) -> u64 {
        self.sequence
    }

    /// Returns the operation recorded by this entry.
    #[must_use]
    #[inline]
    pub const fn operation(&self) -> &Operation {
        &self.operation
    }

    /// Encodes the entry as an on-disk frame.
    ///
    /// The frame is a 4-byte little-endian length prefix followed by the payload
    /// produced by `crate::serde_helpers::encode`.
    ///
    /// # Errors
    ///
    /// Returns [`WalError::Serialization`] in two cases:
    /// - the encoder fails;
    /// - the encoded payload is longer than `u32::MAX` bytes, so the length prefix
    ///   cannot describe it.
    fn to_bytes(&self) -> Result<Vec<u8>, WalError> {
        let data = crate::serde_helpers::encode(self)
            .map_err(|e| WalError::Serialization(e.to_string()))?;
        // An `as u32` cast would silently truncate the prefix for oversized payloads.
        // A reader would then split the log at the wrong offsets, so reject them instead.
        let len = u32::try_from(data.len()).map_err(|_| {
            WalError::Serialization(format!(
                "encoded entry is {} bytes; the WAL frame limit is {} bytes",
                data.len(),
                u32::MAX
            ))
        })?;
        let mut frame = Vec::with_capacity(4 + data.len());
        frame.extend_from_slice(&len.to_le_bytes());
        frame.extend_from_slice(&data);
        Ok(frame)
    }

    /// Decodes an entry from a frame payload (the bytes after the length prefix).
    ///
    /// # Errors
    ///
    /// Returns [`WalError::Deserialization`] if `crate::serde_helpers::decode` fails.
    fn from_bytes(bytes: &[u8]) -> Result<Self, WalError> {
        crate::serde_helpers::decode(bytes).map_err(|e| WalError::Deserialization(e.to_string()))
    }
}
/// File-backed, append-only write-ahead log of [`LogEntry`] records.
///
/// On disk, each record is a 4-byte little-endian length prefix followed by the
/// encoded entry.
#[derive(Debug)]
pub struct WriteAheadLog {
    /// Location of the log file.
    log_path: PathBuf,
    /// Sequence number that the next `log_operation` call will assign.
    next_sequence: u64,
    /// Sequence covered by the latest checkpoint or truncation; 0 initially.
    checkpoint_sequence: u64,
}
impl WriteAheadLog {
pub async fn new(log_path: PathBuf) -> Result<Self, WalError> {
if let Some(parent) = log_path.parent() {
fs::create_dir_all(parent).await?;
}
let mut wal = Self {
log_path,
next_sequence: 1,
checkpoint_sequence: 0,
};
if wal.log_path.exists() {
let entries = wal.replay().await?;
if let Some(last_entry) = entries.last() {
wal.next_sequence = last_entry.sequence + 1;
for entry in entries.iter().rev() {
if let Operation::Checkpoint { sequence } = entry.operation {
wal.checkpoint_sequence = sequence;
break;
}
}
}
}
Ok(wal)
}
pub async fn append(&mut self, entry: &LogEntry) -> Result<(), WalError> {
let bytes = entry.to_bytes()?;
let mut file = OpenOptions::new()
.create(true)
.append(true)
.open(&self.log_path)
.await?;
file.write_all(&bytes).await?;
file.sync_all().await?;
self.next_sequence = self.next_sequence.max(entry.sequence + 1);
Ok(())
}
pub async fn log_operation(&mut self, operation: Operation) -> Result<u64, WalError> {
let sequence = self.next_sequence;
let entry = LogEntry::new(sequence, operation);
self.append(&entry).await?;
Ok(sequence)
}
pub async fn replay(&self) -> Result<Vec<LogEntry>, WalError> {
if !self.log_path.exists() {
return Ok(Vec::new());
}
let mut file = fs::File::open(&self.log_path).await?;
let mut entries = Vec::new();
loop {
let mut len_bytes = [0u8; 4];
match file.read_exact(&mut len_bytes).await {
Ok(_) => {}
Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => break,
Err(e) => return Err(WalError::Io(e)),
}
let len = u32::from_le_bytes(len_bytes) as usize;
let mut data = vec![0u8; len];
file.read_exact(&mut data).await?;
let entry = LogEntry::from_bytes(&data)?;
entries.push(entry);
}
Ok(entries)
}
pub async fn truncate(&mut self, up_to_sequence: u64) -> Result<(), WalError> {
let entries = self.replay().await?;
let remaining: Vec<LogEntry> = entries
.into_iter()
.filter(|e| e.sequence > up_to_sequence)
.collect();
if self.log_path.exists() {
fs::remove_file(&self.log_path).await?;
}
for entry in &remaining {
self.append(entry).await?;
}
self.checkpoint_sequence = up_to_sequence;
Ok(())
}
pub async fn checkpoint(&mut self) -> Result<u64, WalError> {
let sequence = self.next_sequence;
let operation = Operation::Checkpoint { sequence };
self.log_operation(operation).await?;
self.checkpoint_sequence = sequence;
Ok(sequence)
}
pub async fn entries_since_checkpoint(&self) -> Result<Vec<LogEntry>, WalError> {
let all_entries = self.replay().await?;
Ok(all_entries
.into_iter()
.filter(|e| e.sequence > self.checkpoint_sequence)
.collect())
}
#[must_use]
#[inline]
pub const fn next_sequence(&self) -> u64 {
self.next_sequence
}
#[must_use]
#[inline]
pub const fn checkpoint_sequence(&self) -> u64 {
self.checkpoint_sequence
}
pub async fn clear(&mut self) -> Result<(), WalError> {
if self.log_path.exists() {
fs::remove_file(&self.log_path).await?;
}
self.next_sequence = 1;
self.checkpoint_sequence = 0;
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Returns a fresh temporary directory and the path of a `test.wal` inside it.
    ///
    /// The caller must keep the `TempDir` alive; dropping it deletes the directory.
    fn scratch() -> (TempDir, PathBuf) {
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("test.wal");
        (dir, path)
    }

    /// Shorthand for building a `PinContent` operation.
    fn pin(cid: &str, chunk_count: u64) -> Operation {
        Operation::PinContent {
            cid: cid.to_string(),
            chunk_count,
        }
    }

    #[tokio::test]
    async fn test_wal_creation() {
        let (_dir, path) = scratch();
        let wal = WriteAheadLog::new(path).await.unwrap();
        assert_eq!((wal.next_sequence(), wal.checkpoint_sequence()), (1, 0));
    }

    #[tokio::test]
    async fn test_wal_append_and_replay() {
        let (_dir, path) = scratch();
        let mut wal = WriteAheadLog::new(path).await.unwrap();
        let ops = vec![
            Operation::WriteChunk {
                cid: "QmTest1".to_string(),
                chunk_index: 0,
                data: vec![1, 2, 3],
            },
            Operation::WriteChunk {
                cid: "QmTest2".to_string(),
                chunk_index: 1,
                data: vec![4, 5, 6],
            },
        ];
        for op in &ops {
            wal.log_operation(op.clone()).await.unwrap();
        }
        let replayed = wal.replay().await.unwrap();
        assert_eq!(replayed.len(), 2);
        for (idx, (entry, expected_op)) in replayed.iter().zip(&ops).enumerate() {
            assert_eq!(entry.sequence, idx as u64 + 1);
            assert_eq!(&entry.operation, expected_op);
        }
    }

    #[tokio::test]
    async fn test_wal_checkpoint() {
        let (_dir, path) = scratch();
        let mut wal = WriteAheadLog::new(path).await.unwrap();
        wal.log_operation(pin("QmTest", 5)).await.unwrap();
        let seq = wal.checkpoint().await.unwrap();
        assert_eq!(seq, 2);
        assert_eq!(wal.checkpoint_sequence(), 2);
    }

    #[tokio::test]
    async fn test_wal_truncate() {
        let (_dir, path) = scratch();
        let mut wal = WriteAheadLog::new(path).await.unwrap();
        for n in 0u64..5 {
            let op = Operation::WriteChunk {
                cid: format!("QmTest{}", n),
                chunk_index: n,
                data: vec![n as u8],
            };
            wal.log_operation(op).await.unwrap();
        }
        wal.truncate(3).await.unwrap();
        let sequences: Vec<u64> = wal
            .replay()
            .await
            .unwrap()
            .iter()
            .map(|entry| entry.sequence)
            .collect();
        assert_eq!(sequences, [4, 5]);
    }

    #[tokio::test]
    async fn test_wal_entries_since_checkpoint() {
        let (_dir, path) = scratch();
        let mut wal = WriteAheadLog::new(path).await.unwrap();
        for (cid, count) in [("QmTest1", 1), ("QmTest2", 2)] {
            wal.log_operation(pin(cid, count)).await.unwrap();
        }
        wal.checkpoint().await.unwrap();
        wal.log_operation(pin("QmTest3", 3)).await.unwrap();
        let pending: Vec<u64> = wal
            .entries_since_checkpoint()
            .await
            .unwrap()
            .iter()
            .map(|entry| entry.sequence)
            .collect();
        assert_eq!(pending, [4]);
    }

    #[tokio::test]
    async fn test_wal_persistence() {
        let (_dir, path) = scratch();
        {
            let mut writer = WriteAheadLog::new(path.clone()).await.unwrap();
            writer.log_operation(pin("QmPersist", 10)).await.unwrap();
        }
        let reopened = WriteAheadLog::new(path).await.unwrap();
        assert_eq!(reopened.next_sequence(), 2);
        assert_eq!(reopened.replay().await.unwrap().len(), 1);
    }

    #[tokio::test]
    async fn test_wal_clear() {
        let (_dir, path) = scratch();
        let mut wal = WriteAheadLog::new(path).await.unwrap();
        for n in 0u64..3 {
            let op = Operation::DeleteChunk {
                cid: format!("QmTest{}", n),
                chunk_index: n,
            };
            wal.log_operation(op).await.unwrap();
        }
        wal.clear().await.unwrap();
        assert!(wal.replay().await.unwrap().is_empty());
        assert_eq!(wal.next_sequence(), 1);
    }
}