use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::fs;
use std::io::Write;
use std::path::{Path, PathBuf};
/// One record of a shard's write-ahead log.
///
/// [`Wal`] persists each entry as a single JSON object on its own line.
#[derive(Debug, Serialize, Deserialize)]
pub enum WalEntry {
    /// Start marker of an add operation for `path`.
    AddBegin { path: String },
    /// End marker paired with [`WalEntry::AddBegin`].
    AddEnd,
    /// Start marker of a commit, carrying the shard state needed to undo it.
    CommitBegin {
        /// Contents that recovery writes back to `HEAD`; `None` makes recovery remove `HEAD`.
        head_backup: Option<String>,
        /// Bytes that recovery writes back to `index`.
        index_backup: Vec<u8>,
    },
    /// End marker paired with [`WalEntry::CommitBegin`]; the commit completed.
    CommitEnd,
}
/// Append-only write-ahead log stored at `<shard_dir>/wal.log`.
///
/// Each [`WalEntry`] is serialized as one JSON object followed by `'\n'`.
#[derive(Debug)]
pub struct Wal {
    /// Location of the log file (`<shard_dir>/wal.log`).
    path: PathBuf,
}

impl Wal {
    /// Creates a handle for the WAL belonging to `shard_dir`.
    ///
    /// Only computes the path; the filesystem is not touched.
    pub fn new(shard_dir: &Path) -> Self {
        Self {
            path: shard_dir.join("wal.log"),
        }
    }

    /// Returns `true` if the log file exists.
    ///
    /// Like [`Path::exists`], errors while checking are reported as `false`.
    pub fn exists(&self) -> bool {
        self.path.exists()
    }

    /// Appends `entry` as one JSON line and forces it to stable storage.
    ///
    /// The log file is created if it does not exist yet.
    ///
    /// # Errors
    /// Fails if `entry` cannot be serialized, or if the log cannot be opened,
    /// written, or synced.
    pub fn append(&self, entry: &WalEntry) -> Result<()> {
        // Build the whole record (body + newline) up front so it is emitted with a
        // single `write_all`, instead of separate writes for the body and the '\n'.
        let mut record = serde_json::to_string(entry)?;
        record.push('\n');

        let mut file = fs::OpenOptions::new()
            .create(true)
            .append(true)
            .open(&self.path)?;
        file.write_all(record.as_bytes())?;
        // `File::flush` is a no-op; a log record is only durable after an fsync,
        // which must happen before the caller mutates the state this entry guards.
        // NOTE(review): the parent directory is not fsynced, so the directory entry
        // of a freshly created log may still be lost on power failure.
        file.sync_data()?;
        Ok(())
    }

    /// Reads all entries in append order.
    ///
    /// A missing log yields an empty vector and blank lines are skipped. A final
    /// line that lacks its terminating newline *and* fails to parse is treated as
    /// a torn write from an interrupted [`Wal::append`] and dropped.
    ///
    /// # Errors
    /// Fails on I/O errors other than "not found", or on any other malformed line.
    pub fn read(&self) -> Result<Vec<WalEntry>> {
        let content = match fs::read_to_string(&self.path) {
            Ok(content) => content,
            Err(err) if err.kind() == std::io::ErrorKind::NotFound => return Ok(Vec::new()),
            Err(err) => return Err(err.into()),
        };

        let mut entries: Vec<WalEntry> = Vec::new();
        for raw in content.split_inclusive('\n') {
            // Only the last piece can lack '\n'; a completed append always writes it.
            let terminated = raw.ends_with('\n');
            let line = raw.trim();
            if line.is_empty() {
                continue;
            }
            match serde_json::from_str(line) {
                Ok(entry) => entries.push(entry),
                // Unterminated and unparsable: a torn tail, not real corruption.
                Err(_) if !terminated => break,
                Err(err) => return Err(err.into()),
            }
        }
        Ok(entries)
    }

    /// Deletes the log file; succeeds if it is already absent.
    ///
    /// # Errors
    /// Fails if the file exists but cannot be removed.
    pub fn truncate(&self) -> Result<()> {
        match fs::remove_file(&self.path) {
            Ok(()) => Ok(()),
            Err(err) if err.kind() == std::io::ErrorKind::NotFound => Ok(()),
            Err(err) => Err(err.into()),
        }
    }
}
pub fn recover(shard_dir: &Path) -> Result<()> {
let wal = Wal::new(shard_dir);
if !wal.exists() {
return Ok(());
}
let entries = wal.read()?;
if entries.is_empty() {
wal.truncate()?;
return Ok(());
}
let has_commit_begin = entries
.iter()
.any(|e| matches!(e, WalEntry::CommitBegin { .. }));
let has_commit_end = entries.iter().any(|e| matches!(e, WalEntry::CommitEnd));
if has_commit_begin && !has_commit_end {
for entry in &entries {
if let WalEntry::CommitBegin {
head_backup,
index_backup,
} = entry
{
let head_path = shard_dir.join("HEAD");
match head_backup {
Some(head) => fs::write(&head_path, head)?,
None => {
let _ = fs::remove_file(&head_path);
}
}
fs::write(shard_dir.join("index"), index_backup)?;
}
}
eprintln!("Recovered from incomplete commit (rolled back)");
}
wal.truncate()?;
Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    #[test]
    fn test_wal_append_read_roundtrip() {
        let dir = tempdir().unwrap();
        let wal = Wal::new(dir.path());
        wal.append(&WalEntry::AddBegin {
            path: "f.txt".into(),
        })
        .unwrap();
        wal.append(&WalEntry::AddEnd).unwrap();
        let entries = wal.read().unwrap();
        assert_eq!(entries.len(), 2);
        assert!(matches!(entries[0], WalEntry::AddBegin { .. }));
        assert!(matches!(entries[1], WalEntry::AddEnd));
    }

    #[test]
    fn test_wal_read_empty() {
        let dir = tempdir().unwrap();
        let wal = Wal::new(dir.path());
        assert!(!wal.exists());
        let entries = wal.read().unwrap();
        assert!(entries.is_empty());
    }

    #[test]
    fn test_wal_truncate() {
        let dir = tempdir().unwrap();
        let wal = Wal::new(dir.path());
        wal.append(&WalEntry::AddEnd).unwrap();
        assert!(wal.exists());
        wal.truncate().unwrap();
        assert!(!wal.exists());
    }

    #[test]
    fn test_wal_commit_begin_end_roundtrip() {
        let dir = tempdir().unwrap();
        let wal = Wal::new(dir.path());
        wal.append(&WalEntry::CommitBegin {
            head_backup: Some("abc".into()),
            index_backup: b"index_data".to_vec(),
        })
        .unwrap();
        wal.append(&WalEntry::CommitEnd).unwrap();
        let entries = wal.read().unwrap();
        assert_eq!(entries.len(), 2);
        if let WalEntry::CommitBegin {
            head_backup,
            index_backup,
        } = &entries[0]
        {
            assert_eq!(head_backup.as_deref(), Some("abc"));
            assert_eq!(index_backup, b"index_data");
        } else {
            panic!("Expected CommitBegin");
        }
    }

    #[test]
    fn test_recover_no_wal() {
        let dir = tempdir().unwrap();
        recover(dir.path()).unwrap();
    }

    /// A WAL file that exists but contains no entries (only blank lines) is removed.
    #[test]
    fn test_recover_empty_wal() {
        let dir = tempdir().unwrap();
        let wal = Wal::new(dir.path());
        fs::write(&wal.path, "\n\n").unwrap();
        assert!(wal.exists());
        assert!(wal.read().unwrap().is_empty());
        recover(dir.path()).unwrap();
        assert!(!wal.exists());
    }

    /// A WAL holding only add entries is discarded without touching shard files.
    #[test]
    fn test_recover_add_only_wal() {
        let dir = tempdir().unwrap();
        let wal = Wal::new(dir.path());
        wal.append(&WalEntry::AddEnd).unwrap();
        recover(dir.path()).unwrap();
        assert!(!wal.exists());
    }

    #[test]
    fn test_recover_incomplete_commit() {
        let dir = tempdir().unwrap();
        let shard = dir.path();
        fs::write(shard.join("HEAD"), "ref: refs/heads/main").unwrap();
        fs::write(shard.join("index"), b"original_index").unwrap();
        let wal = Wal::new(shard);
        wal.append(&WalEntry::CommitBegin {
            head_backup: Some("ref: refs/heads/main".into()),
            index_backup: b"original_index".to_vec(),
        })
        .unwrap();
        fs::write(shard.join("HEAD"), "new_commit_id").unwrap();
        fs::write(shard.join("index"), b"new_index").unwrap();
        recover(shard).unwrap();
        assert_eq!(
            fs::read_to_string(shard.join("HEAD")).unwrap(),
            "ref: refs/heads/main"
        );
        assert_eq!(fs::read(shard.join("index")).unwrap(), b"original_index");
        assert!(!wal.exists());
    }

    /// Rolling back a commit whose backup has no HEAD removes the HEAD file.
    #[test]
    fn test_recover_incomplete_commit_without_head() {
        let dir = tempdir().unwrap();
        let shard = dir.path();
        let wal = Wal::new(shard);
        wal.append(&WalEntry::CommitBegin {
            head_backup: None,
            index_backup: b"original_index".to_vec(),
        })
        .unwrap();
        fs::write(shard.join("HEAD"), "new_commit_id").unwrap();
        fs::write(shard.join("index"), b"new_index").unwrap();
        recover(shard).unwrap();
        assert!(!shard.join("HEAD").exists());
        assert_eq!(fs::read(shard.join("index")).unwrap(), b"original_index");
        assert!(!wal.exists());
    }

    /// A commit closed by CommitEnd is kept as-is; only the WAL is discarded.
    #[test]
    fn test_recover_complete_commit_keeps_state() {
        let dir = tempdir().unwrap();
        let shard = dir.path();
        let wal = Wal::new(shard);
        wal.append(&WalEntry::CommitBegin {
            head_backup: Some("old_head".into()),
            index_backup: b"old_index".to_vec(),
        })
        .unwrap();
        fs::write(shard.join("HEAD"), "new_head").unwrap();
        fs::write(shard.join("index"), b"new_index").unwrap();
        wal.append(&WalEntry::CommitEnd).unwrap();
        recover(shard).unwrap();
        assert_eq!(fs::read_to_string(shard.join("HEAD")).unwrap(), "new_head");
        assert_eq!(fs::read(shard.join("index")).unwrap(), b"new_index");
        assert!(!wal.exists());
    }

    #[test]
    fn test_recover_incomplete_add() {
        let dir = tempdir().unwrap();
        let wal = Wal::new(dir.path());
        wal.append(&WalEntry::AddBegin {
            path: "f.txt".into(),
        })
        .unwrap();
        recover(dir.path()).unwrap();
        assert!(!wal.exists());
    }
}