use crate::error::{AmateRSError, ErrorContext, Result};
use crate::types::{CipherBlob, Key};
use std::fs::{File, OpenOptions};
use std::io::{BufReader, BufWriter, Read, Write};
use std::path::{Path, PathBuf};
/// Counters describing the outcome of a WAL recovery pass
/// (see `Wal::recover_with_stats`).
#[derive(Debug, Clone, Default)]
pub struct RecoveryStats {
    /// Entries that were read, decoded, and passed checksum verification.
    pub entries_recovered: u64,
    /// Read attempts that failed (bad framing, checksum mismatch, truncated data, ...).
    pub entries_corrupted: u64,
    /// Encoded size of all recovered entries, counting each entry's 4-byte length prefix.
    pub bytes_recovered: u64,
}
/// Kind of mutation recorded by a WAL entry.
///
/// The explicit discriminants are the tag byte written into each encoded entry,
/// so they must never be renumbered.
#[derive(Clone, Debug, PartialEq)]
pub enum WalEntryType {
    /// Records a write of a value under a key.
    Put = 1,
    /// Records the deletion of a key.
    Delete = 2,
}
/// A single logged mutation together with its integrity checksum.
#[derive(Debug, Clone, PartialEq)]
pub struct WalEntry {
    /// Monotonically assigned sequence number of this mutation.
    pub sequence: u64,
    /// Whether this entry records a put or a delete.
    pub entry_type: WalEntryType,
    /// Key being written or deleted.
    pub key: Key,
    /// Value for a put; `None` for a delete.
    pub value: Option<CipherBlob>,
    /// CRC32 over sequence, type tag, key bytes and value bytes.
    pub checksum: u32,
}
impl WalEntry {
    /// Magic number written at the start of every encoded entry.
    const MAGIC: u32 = 0x57414C;

    /// Size of an encoded entry with an empty key and no value:
    /// magic (4) + sequence (8) + type (1) + key length (4) + value length (4) + checksum (4).
    const MIN_ENCODED_LEN: usize = 4 + 8 + 1 + 4 + 4 + 4;

    /// Creates a `Put` entry for `key` -> `value` with its checksum already computed.
    pub fn put(sequence: u64, key: Key, value: CipherBlob) -> Self {
        let mut entry = Self {
            sequence,
            entry_type: WalEntryType::Put,
            key,
            value: Some(value),
            checksum: 0,
        };
        entry.checksum = entry.calculate_checksum();
        entry
    }

    /// Creates a `Delete` entry for `key` with its checksum already computed.
    pub fn delete(sequence: u64, key: Key) -> Self {
        let mut entry = Self {
            sequence,
            entry_type: WalEntryType::Delete,
            key,
            value: None,
            checksum: 0,
        };
        entry.checksum = entry.calculate_checksum();
        entry
    }

    /// On-disk tag byte for this entry's type.
    ///
    /// Casting the variant path (rather than the field) keeps this independent of
    /// whether `WalEntryType` is `Copy`.
    fn type_tag(&self) -> u8 {
        match self.entry_type {
            WalEntryType::Put => WalEntryType::Put as u8,
            WalEntryType::Delete => WalEntryType::Delete as u8,
        }
    }

    /// CRC32 over the sequence number, type tag, key bytes and (if present) value bytes.
    fn calculate_checksum(&self) -> u32 {
        let mut hasher = crc32fast::Hasher::new();
        hasher.update(&self.sequence.to_le_bytes());
        hasher.update(&[self.type_tag()]);
        hasher.update(self.key.as_bytes());
        if let Some(ref value) = self.value {
            hasher.update(value.as_bytes());
        }
        hasher.finalize()
    }

    /// Recomputes the checksum and compares it with the stored one.
    ///
    /// # Errors
    /// Returns `StorageIntegrity` when the checksums differ.
    pub fn verify_checksum(&self) -> Result<()> {
        let calculated = self.calculate_checksum();
        if calculated == self.checksum {
            Ok(())
        } else {
            Err(AmateRSError::StorageIntegrity(ErrorContext::new(format!(
                "WAL entry checksum mismatch: expected {}, got {}",
                self.checksum, calculated
            ))))
        }
    }

    /// Serializes the entry (all integers little-endian):
    /// `magic u32 | sequence u64 | type u8 | key_len u32 | key | value_len u32 | value | checksum u32`.
    ///
    /// A missing value is written as `value_len == 0`.
    pub fn encode(&self) -> Vec<u8> {
        let value_hint = self.value.as_ref().map_or(0, |v| v.as_bytes().len());
        let mut bytes =
            Vec::with_capacity(Self::MIN_ENCODED_LEN + self.key.as_bytes().len() + value_hint);
        bytes.extend_from_slice(&Self::MAGIC.to_le_bytes());
        bytes.extend_from_slice(&self.sequence.to_le_bytes());
        bytes.push(self.type_tag());
        // Lengths are stored as u32; entries this large are rejected by the WAL long before
        // 4 GiB, so the cast does not truncate in practice.
        bytes.extend_from_slice(&(self.key.len() as u32).to_le_bytes());
        bytes.extend_from_slice(self.key.as_bytes());
        if let Some(ref value) = self.value {
            bytes.extend_from_slice(&(value.len() as u32).to_le_bytes());
            bytes.extend_from_slice(value.as_bytes());
        } else {
            bytes.extend_from_slice(&0u32.to_le_bytes());
        }
        bytes.extend_from_slice(&self.checksum.to_le_bytes());
        bytes
    }

    /// Parses an entry produced by [`WalEntry::encode`] and verifies its checksum.
    ///
    /// Every length read from `bytes` is bounds-checked, so corrupted input yields an
    /// error instead of a panic (recovery relies on errors to skip damaged entries).
    ///
    /// # Errors
    /// Returns `SerializationError` for short/truncated input, a bad magic number or an
    /// unknown entry type, and `StorageIntegrity` on checksum mismatch.
    pub fn decode(bytes: &[u8]) -> Result<Self> {
        if bytes.len() < Self::MIN_ENCODED_LEN {
            return Err(AmateRSError::SerializationError(ErrorContext::new(
                "WAL entry too short",
            )));
        }
        let mut offset = 0;
        if Self::read_u32(bytes, &mut offset, "magic number")? != Self::MAGIC {
            return Err(AmateRSError::SerializationError(ErrorContext::new(
                "Invalid WAL entry magic number",
            )));
        }
        let sequence = Self::read_u64(bytes, &mut offset, "sequence")?;
        let entry_type = match Self::take_field(bytes, &mut offset, 1, "entry type")?[0] {
            1 => WalEntryType::Put,
            2 => WalEntryType::Delete,
            _ => {
                return Err(AmateRSError::SerializationError(ErrorContext::new(
                    "Invalid WAL entry type",
                )));
            }
        };
        let key_len = Self::read_u32(bytes, &mut offset, "key length")? as usize;
        let key = Key::from_slice(Self::take_field(bytes, &mut offset, key_len, "key")?);
        let value_len = Self::read_u32(bytes, &mut offset, "value length")? as usize;
        let value_bytes = Self::take_field(bytes, &mut offset, value_len, "value")?;
        // `encode` writes a zero length both for a Delete (no value) and for a Put of an
        // empty blob; use the entry type so an empty Put round-trips as `Some(empty)`.
        let value = if value_len > 0 || entry_type == WalEntryType::Put {
            Some(CipherBlob::new(value_bytes.to_vec()))
        } else {
            None
        };
        let checksum = Self::read_u32(bytes, &mut offset, "checksum")?;
        let entry = Self {
            sequence,
            entry_type,
            key,
            value,
            checksum,
        };
        entry.verify_checksum()?;
        Ok(entry)
    }

    /// Returns `bytes[*offset..*offset + len]` and advances `offset` past it, or a
    /// `SerializationError` naming `field` if the buffer is too short.
    fn take_field<'a>(
        bytes: &'a [u8],
        offset: &mut usize,
        len: usize,
        field: &str,
    ) -> Result<&'a [u8]> {
        let end = offset
            .checked_add(len)
            .filter(|&end| end <= bytes.len())
            .ok_or_else(|| {
                AmateRSError::SerializationError(ErrorContext::new(format!(
                    "Failed to read {}",
                    field
                )))
            })?;
        let field_bytes = &bytes[*offset..end];
        *offset = end;
        Ok(field_bytes)
    }

    /// Reads a little-endian `u32` field at `offset` (bounds-checked).
    fn read_u32(bytes: &[u8], offset: &mut usize, field: &str) -> Result<u32> {
        let mut raw = [0u8; 4];
        raw.copy_from_slice(Self::take_field(bytes, offset, raw.len(), field)?);
        Ok(u32::from_le_bytes(raw))
    }

    /// Reads a little-endian `u64` field at `offset` (bounds-checked).
    fn read_u64(bytes: &[u8], offset: &mut usize, field: &str) -> Result<u64> {
        let mut raw = [0u8; 8];
        raw.copy_from_slice(Self::take_field(bytes, offset, raw.len(), field)?);
        Ok(u64::from_le_bytes(raw))
    }
}
/// Tunables for a `Wal` instance.
#[derive(Debug, Clone)]
pub struct WalConfig {
    /// Directory that holds the WAL segment files.
    pub wal_dir: PathBuf,
    /// Segment size in bytes at which the writer rotates to a new segment.
    pub max_file_size: u64,
    /// Number of segment files retained when old segments are cleaned up.
    pub max_wal_files: usize,
    /// Whether each append is flushed immediately instead of staying buffered.
    pub sync_on_write: bool,
}

impl Default for WalConfig {
    /// `./wal`, 64 MiB segments, 10 retained segments, flush on every write.
    fn default() -> Self {
        const SIXTY_FOUR_MIB: u64 = 64 << 20;
        WalConfig {
            wal_dir: "./wal".into(),
            max_file_size: SIXTY_FOUR_MIB,
            max_wal_files: 10,
            sync_on_write: true,
        }
    }
}
/// Append-only write-ahead log split into numbered segment files.
pub struct Wal {
    /// Settings the log was opened with.
    config: WalConfig,
    /// Path of the segment currently being appended to.
    current_path: PathBuf,
    /// Buffered writer over the current segment.
    writer: BufWriter<File>,
    /// Sequence number that will be assigned to the next entry.
    sequence: u64,
    /// Bytes written to the current segment so far (including buffered bytes).
    current_file_size: u64,
    /// Number of the current segment file.
    current_file_number: u64,
}
impl Wal {
    /// Largest encoded entry accepted by `write_entry`.
    ///
    /// Must not exceed the frame limit enforced by `WalReader::read_entry` (100 MiB):
    /// a larger entry could be written and acknowledged but never read back.
    const MAX_ENTRY_SIZE: usize = 100 * 1024 * 1024;

    /// Opens a WAL in the parent directory of `path` with default settings.
    ///
    /// Only the directory part of `path` is used; segment files inside it are named
    /// `wal_NNNNNNNN.log`. A bare file name places the WAL in the current directory.
    ///
    /// # Errors
    /// Returns `IoError` if `path` has no parent or the WAL cannot be opened.
    pub fn create(path: impl AsRef<Path>) -> Result<Self> {
        let path = path.as_ref();
        let parent = path.parent().ok_or_else(|| {
            AmateRSError::IoError(ErrorContext::new("WAL path has no parent directory"))
        })?;
        // `Path::parent` yields "" for a bare file name; an empty directory path makes
        // `read_dir` fail, which hides existing segments on restart and breaks cleanup.
        let wal_dir = if parent.as_os_str().is_empty() {
            Path::new(".")
        } else {
            parent
        };
        let config = WalConfig {
            wal_dir: wal_dir.to_path_buf(),
            ..Default::default()
        };
        Self::with_config(config)
    }

    /// Opens (or creates) the WAL described by `config`.
    ///
    /// Creates `config.wal_dir` if needed, scans existing segments to find the next
    /// sequence number, and appends to the newest segment, unless that segment contains
    /// an unreadable entry (e.g. a torn write left by a crash). In that case a fresh
    /// segment is started so new entries are not written behind bytes that break framing.
    ///
    /// # Errors
    /// Returns `IoError` if the directory or a segment cannot be created, listed or opened.
    pub fn with_config(config: WalConfig) -> Result<Self> {
        std::fs::create_dir_all(&config.wal_dir).map_err(|e| {
            AmateRSError::IoError(ErrorContext::new(format!(
                "Failed to create WAL directory: {}",
                e
            )))
        })?;
        let (file_number, sequence) = Self::find_latest_wal(&config)?;
        let current_path = Self::wal_file_path(&config.wal_dir, file_number);
        let file = OpenOptions::new()
            .create(true)
            .append(true)
            .open(&current_path)
            .map_err(|e| {
                AmateRSError::IoError(ErrorContext::new(format!("Failed to open WAL: {}", e)))
            })?;
        let current_file_size = file
            .metadata()
            .map_err(|e| {
                AmateRSError::IoError(ErrorContext::new(format!(
                    "Failed to get WAL file size: {}",
                    e
                )))
            })?
            .len();
        Ok(Self {
            config,
            current_path,
            writer: BufWriter::new(file),
            sequence,
            current_file_size,
            current_file_number: file_number,
        })
    }

    /// Scans existing segments and returns `(segment number to open, next sequence)`.
    ///
    /// The next sequence is one past the highest sequence seen in any readable entry
    /// (0 when there are none). The segment is the newest existing one, or the one after
    /// it when the newest contains an unreadable entry.
    ///
    /// # Errors
    /// Propagates failures to list or open segments: silently skipping a segment could
    /// reuse sequence numbers it already contains.
    fn find_latest_wal(config: &WalConfig) -> Result<(u64, u64)> {
        if !config.wal_dir.exists() {
            return Ok((0, 0));
        }
        let file_numbers = Self::list_wal_file_numbers(&config.wal_dir)?;
        let latest = match file_numbers.last() {
            Some(&latest) => latest,
            None => return Ok((0, 0)),
        };
        let mut next_sequence = 0u64;
        let mut latest_is_damaged = false;
        for &file_number in &file_numbers {
            let file_path = Self::wal_file_path(&config.wal_dir, file_number);
            let failed_reads = Self::scan_wal_file(&file_path, |entry| {
                next_sequence = next_sequence.max(entry.sequence.saturating_add(1));
            })?;
            if file_number == latest && failed_reads > 0 {
                latest_is_damaged = true;
            }
        }
        if latest_is_damaged {
            tracing::warn!(
                "WAL file {} contains unreadable entries; starting a new WAL file",
                Self::wal_file_path(&config.wal_dir, latest).display()
            );
            return Ok((latest + 1, next_sequence));
        }
        Ok((latest, next_sequence))
    }

    /// Path of segment `file_number` inside `wal_dir` (`wal_NNNNNNNN.log`).
    fn wal_file_path(wal_dir: &Path, file_number: u64) -> PathBuf {
        wal_dir.join(format!("wal_{:08}.log", file_number))
    }

    /// Lists the numbers of all segment files in `wal_dir`, sorted ascending.
    ///
    /// Only names that `wal_file_path` itself would produce are accepted, so every
    /// returned number maps back to an existing file (rejects e.g. `wal_7.log`).
    fn list_wal_file_numbers(wal_dir: &Path) -> Result<Vec<u64>> {
        let entries = std::fs::read_dir(wal_dir).map_err(|e| {
            AmateRSError::IoError(ErrorContext::new(format!(
                "Failed to read WAL directory: {}",
                e
            )))
        })?;
        let mut numbers = Vec::new();
        for entry in entries {
            let entry = entry.map_err(|e| {
                AmateRSError::IoError(ErrorContext::new(format!(
                    "Failed to read directory entry: {}",
                    e
                )))
            })?;
            let file_name = entry.file_name();
            let name = file_name.to_string_lossy();
            let number = name
                .strip_prefix("wal_")
                .and_then(|rest| rest.strip_suffix(".log"))
                .and_then(|digits| digits.parse::<u64>().ok());
            if let Some(number) = number {
                let canonical = Self::wal_file_path(wal_dir, number);
                if canonical.file_name() == Some(file_name.as_os_str()) {
                    numbers.push(number);
                }
            }
        }
        numbers.sort_unstable();
        Ok(numbers)
    }

    /// Reads every entry of one segment, passing each valid entry to `on_entry`.
    ///
    /// Entries that fail to decode are logged and skipped. An I/O error ends the scan of
    /// this segment, because retrying a read that keeps failing would never make progress.
    ///
    /// Returns the number of failed reads (skipped entries plus a terminating I/O error).
    ///
    /// # Errors
    /// Returns `IoError` if the segment cannot be opened.
    fn scan_wal_file(file_path: &Path, mut on_entry: impl FnMut(WalEntry)) -> Result<u64> {
        let mut reader = WalReader::open(file_path)?;
        let mut failed_reads = 0u64;
        loop {
            match reader.read_entry() {
                Ok(Some(entry)) => on_entry(entry),
                Ok(None) => break,
                Err(e @ AmateRSError::IoError(_)) => {
                    failed_reads += 1;
                    tracing::warn!(
                        "Stopping read of {} after I/O error: {}",
                        file_path.display(),
                        e
                    );
                    break;
                }
                Err(e) => {
                    failed_reads += 1;
                    tracing::warn!(
                        "Skipping corrupted entry in {}: {}",
                        file_path.display(),
                        e
                    );
                }
            }
        }
        Ok(failed_reads)
    }

    /// Appends a put of `key` -> `value` and returns the sequence number assigned to it.
    ///
    /// The sequence number is consumed even if the write fails, so it is never reused.
    pub fn put(&mut self, key: Key, value: CipherBlob) -> Result<u64> {
        let sequence = self.sequence;
        self.sequence += 1;
        let entry = WalEntry::put(sequence, key, value);
        self.write_entry(&entry)?;
        Ok(sequence)
    }

    /// Appends a delete of `key` and returns the sequence number assigned to it.
    ///
    /// The sequence number is consumed even if the write fails, so it is never reused.
    pub fn delete(&mut self, key: Key) -> Result<u64> {
        let sequence = self.sequence;
        self.sequence += 1;
        let entry = WalEntry::delete(sequence, key);
        self.write_entry(&entry)?;
        Ok(sequence)
    }

    /// Writes one length-prefixed entry, optionally syncing, and rotates when full.
    ///
    /// # Errors
    /// Returns `SerializationError` for entries larger than `MAX_ENTRY_SIZE` and
    /// `IoError` on write/flush/sync/rotation failure.
    fn write_entry(&mut self, entry: &WalEntry) -> Result<()> {
        let bytes = entry.encode();
        // Reject up front what the reader would refuse later; this bound also makes the
        // u32 length cast below lossless.
        if bytes.len() > Self::MAX_ENTRY_SIZE {
            return Err(AmateRSError::SerializationError(ErrorContext::new(format!(
                "WAL entry too large: {} bytes",
                bytes.len()
            ))));
        }
        let len = bytes.len() as u32;
        self.writer.write_all(&len.to_le_bytes()).map_err(|e| {
            AmateRSError::IoError(ErrorContext::new(format!(
                "Failed to write WAL entry: {}",
                e
            )))
        })?;
        self.writer.write_all(&bytes).map_err(|e| {
            AmateRSError::IoError(ErrorContext::new(format!(
                "Failed to write WAL entry: {}",
                e
            )))
        })?;
        self.current_file_size += (4 + bytes.len()) as u64;
        if self.config.sync_on_write {
            self.writer.flush().map_err(|e| {
                AmateRSError::IoError(ErrorContext::new(format!("Failed to flush WAL: {}", e)))
            })?;
            // Flushing only hands data to the OS; without an fsync an acknowledged
            // entry can still be lost on power failure.
            self.writer.get_ref().sync_data().map_err(|e| {
                AmateRSError::IoError(ErrorContext::new(format!("Failed to sync WAL: {}", e)))
            })?;
        }
        if self.current_file_size >= self.config.max_file_size {
            self.rotate()?;
        }
        Ok(())
    }

    /// Flushes and syncs the current segment, switches to the next numbered segment,
    /// then deletes segments beyond `max_wal_files`.
    pub fn rotate(&mut self) -> Result<()> {
        self.flush()?;
        self.current_file_number += 1;
        let new_path = Self::wal_file_path(&self.config.wal_dir, self.current_file_number);
        let file = OpenOptions::new()
            .create(true)
            .append(true)
            .open(&new_path)
            .map_err(|e| {
                AmateRSError::IoError(ErrorContext::new(format!(
                    "Failed to create new WAL file: {}",
                    e
                )))
            })?;
        self.current_path = new_path;
        self.writer = BufWriter::new(file);
        self.current_file_size = 0;
        self.cleanup_old_wal_files()?;
        Ok(())
    }

    /// Deletes the oldest segments so at most `max_wal_files` remain.
    ///
    /// The active segment is never deleted, even when `max_wal_files` is 0; removing it
    /// would make every later append go to an unlinked file.
    fn cleanup_old_wal_files(&self) -> Result<()> {
        let wal_files = Self::list_wal_file_numbers(&self.config.wal_dir)?;
        let excess = wal_files.len().saturating_sub(self.config.max_wal_files);
        for &file_number in wal_files.iter().take(excess) {
            if file_number == self.current_file_number {
                continue;
            }
            let file_path = Self::wal_file_path(&self.config.wal_dir, file_number);
            std::fs::remove_file(&file_path).map_err(|e| {
                AmateRSError::IoError(ErrorContext::new(format!(
                    "Failed to delete old WAL file: {}",
                    e
                )))
            })?;
        }
        Ok(())
    }

    /// Runs segment cleanup on demand (see `rotate`).
    pub fn cleanup(&self) -> Result<()> {
        self.cleanup_old_wal_files()
    }

    /// Bytes written to the current segment, including still-buffered bytes.
    pub fn current_file_size(&self) -> u64 {
        self.current_file_size
    }

    /// Number of the segment currently being appended to.
    pub fn current_file_number(&self) -> u64 {
        self.current_file_number
    }

    /// Flushes buffered entries and fsyncs the current segment (`sync_all`).
    pub fn flush(&mut self) -> Result<()> {
        self.writer.flush().map_err(|e| {
            AmateRSError::IoError(ErrorContext::new(format!("Failed to flush WAL: {}", e)))
        })?;
        self.writer.get_ref().sync_all().map_err(|e| {
            AmateRSError::IoError(ErrorContext::new(format!("Failed to sync WAL: {}", e)))
        })?;
        Ok(())
    }

    /// Sequence number that the next appended entry will receive.
    pub fn sequence(&self) -> u64 {
        self.sequence
    }

    /// Path of the current segment file.
    pub fn path(&self) -> &Path {
        &self.current_path
    }

    /// Reads all readable entries from every segment in `wal_dir`, oldest segment first.
    ///
    /// Returns the entries and the highest sequence number seen (0 if none, or if the
    /// directory does not exist). Corrupted entries are logged and skipped.
    pub fn recover(wal_dir: impl AsRef<Path>) -> Result<(Vec<WalEntry>, u64)> {
        let (entries, max_sequence, _stats) = Self::recover_with_stats(wal_dir)?;
        Ok((entries, max_sequence))
    }

    /// Same as `current_file_size`.
    pub fn current_size(&self) -> u64 {
        self.current_file_size
    }

    /// Sum of the on-disk sizes of all segment files (unflushed bytes are not counted).
    pub fn total_wal_size(&self) -> Result<u64> {
        let wal_files = Self::list_wal_file_numbers(&self.config.wal_dir)?;
        let mut total_size = 0u64;
        for file_number in wal_files {
            let file_path = Self::wal_file_path(&self.config.wal_dir, file_number);
            let metadata = std::fs::metadata(&file_path).map_err(|e| {
                AmateRSError::IoError(ErrorContext::new(format!(
                    "Failed to read WAL file metadata: {}",
                    e
                )))
            })?;
            total_size += metadata.len();
        }
        Ok(total_size)
    }

    /// Deletes every non-active segment whose highest readable sequence is `<= sequence`.
    ///
    /// Returns the number of segments removed. A segment with no readable entries counts
    /// as having maximum sequence 0.
    ///
    /// # Errors
    /// Propagates failures to open or remove a segment, so an unreadable segment is
    /// reported rather than treated as empty and deleted.
    pub fn truncate_before(&mut self, sequence: u64) -> Result<u64> {
        self.flush()?;
        let mut files_truncated = 0u64;
        for file_number in Self::list_wal_file_numbers(&self.config.wal_dir)? {
            if file_number == self.current_file_number {
                continue;
            }
            let file_path = Self::wal_file_path(&self.config.wal_dir, file_number);
            let mut file_max_seq = 0u64;
            Self::scan_wal_file(&file_path, |entry| {
                file_max_seq = file_max_seq.max(entry.sequence);
            })?;
            if file_max_seq <= sequence {
                std::fs::remove_file(&file_path).map_err(|e| {
                    AmateRSError::IoError(ErrorContext::new(format!(
                        "Failed to remove WAL file {}: {}",
                        file_path.display(),
                        e
                    )))
                })?;
                files_truncated += 1;
            }
        }
        Ok(files_truncated)
    }

    /// Like `recover`, additionally returning counts of recovered and corrupted entries.
    ///
    /// `bytes_recovered` counts each recovered entry's encoded size plus its 4-byte
    /// length prefix.
    pub fn recover_with_stats(
        wal_dir: impl AsRef<Path>,
    ) -> Result<(Vec<WalEntry>, u64, RecoveryStats)> {
        let wal_dir = wal_dir.as_ref();
        let mut stats = RecoveryStats::default();
        if !wal_dir.exists() {
            return Ok((Vec::new(), 0, stats));
        }
        let mut all_entries = Vec::new();
        let mut max_sequence = 0u64;
        for file_number in Self::list_wal_file_numbers(wal_dir)? {
            let file_path = Self::wal_file_path(wal_dir, file_number);
            let failed_reads = Self::scan_wal_file(&file_path, |entry| {
                stats.bytes_recovered += entry.encode().len() as u64 + 4;
                stats.entries_recovered += 1;
                max_sequence = max_sequence.max(entry.sequence);
                all_entries.push(entry);
            })?;
            stats.entries_corrupted += failed_reads;
        }
        Ok((all_entries, max_sequence, stats))
    }

    /// Recovers all entries from `wal_dir` and applies them to `memtable` in log order.
    ///
    /// Returns the highest sequence number recovered.
    pub fn replay_to_memtable(
        wal_dir: impl AsRef<Path>,
        memtable: &crate::storage::memtable::Memtable,
    ) -> Result<u64> {
        let (entries, max_sequence) = Self::recover(wal_dir)?;
        for entry in entries {
            match entry.entry_type {
                WalEntryType::Put => {
                    // A Put of an empty blob is encoded with value length 0 and may decode
                    // as `None`; replay it as an empty blob instead of dropping the write.
                    let value = entry
                        .value
                        .unwrap_or_else(|| CipherBlob::new(Vec::new()));
                    memtable.put(entry.key, value)?;
                }
                WalEntryType::Delete => {
                    memtable.delete(entry.key)?;
                }
            }
        }
        Ok(max_sequence)
    }
}
/// Sequential reader over a single WAL segment file.
pub struct WalReader {
    /// Buffered handle positioned at the next length-prefixed entry.
    reader: BufReader<File>,
}
impl WalReader {
    /// Largest frame payload `read_entry` accepts; a larger length prefix is treated as
    /// corruption rather than trusted.
    const MAX_ENTRY_SIZE: usize = 100 * 1024 * 1024;

    /// Opens the segment file at `path` for reading from its start.
    ///
    /// # Errors
    /// Returns `IoError` if the file cannot be opened.
    pub fn open(path: impl AsRef<Path>) -> Result<Self> {
        let file = File::open(path.as_ref()).map_err(|e| {
            AmateRSError::IoError(ErrorContext::new(format!("Failed to open WAL file: {}", e)))
        })?;
        Ok(Self {
            reader: BufReader::new(file),
        })
    }

    /// Reads the next length-prefixed entry.
    ///
    /// Returns `Ok(None)` only at a clean end of file (no bytes left). After an error
    /// the caller may keep calling; the stream position has advanced past what was read.
    ///
    /// # Errors
    /// - `SerializationError` for a torn length prefix, an oversized frame, a truncated
    ///   frame body, or a body that fails to decode.
    /// - `StorageIntegrity` on checksum mismatch.
    /// - `IoError` for underlying read failures.
    pub fn read_entry(&mut self) -> Result<Option<WalEntry>> {
        // Read the prefix manually so a clean EOF (0 bytes) can be told apart from a
        // prefix torn by a crash (1-3 bytes). The latter is corruption, not end-of-log.
        let mut len_bytes = [0u8; 4];
        let mut filled = 0;
        while filled < len_bytes.len() {
            match self.reader.read(&mut len_bytes[filled..]) {
                Ok(0) => break,
                Ok(n) => filled += n,
                Err(e) if e.kind() == std::io::ErrorKind::Interrupted => {}
                Err(e) => {
                    return Err(AmateRSError::IoError(ErrorContext::new(format!(
                        "Failed to read WAL entry length: {}",
                        e
                    ))));
                }
            }
        }
        if filled == 0 {
            return Ok(None);
        }
        if filled < len_bytes.len() {
            return Err(AmateRSError::SerializationError(ErrorContext::new(
                "Incomplete WAL entry length (truncated file)",
            )));
        }
        let len = u32::from_le_bytes(len_bytes) as usize;
        if len > Self::MAX_ENTRY_SIZE {
            return Err(AmateRSError::SerializationError(ErrorContext::new(
                format!("WAL entry too large: {} bytes", len),
            )));
        }
        // Read through `take` instead of pre-allocating `len` bytes, so a corrupted length
        // near the limit only allocates as much as the file actually contains.
        let mut entry_bytes = Vec::with_capacity(len.min(64 * 1024));
        self.reader
            .by_ref()
            .take(len as u64)
            .read_to_end(&mut entry_bytes)
            .map_err(|e| {
                AmateRSError::IoError(ErrorContext::new(format!(
                    "Failed to read WAL entry: {}",
                    e
                )))
            })?;
        if entry_bytes.len() < len {
            return Err(AmateRSError::SerializationError(ErrorContext::new(
                "Incomplete WAL entry (truncated file)",
            )));
        }
        let entry = WalEntry::decode(&entry_bytes)?;
        Ok(Some(entry))
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::storage::Memtable;
use std::fs;
use tempfile::tempdir;
#[test]
fn test_wal_entry_encode_decode() -> Result<()> {
let key = Key::from_str("test_key");
let value = CipherBlob::new(vec![1, 2, 3, 4, 5]);
let entry = WalEntry::put(42, key.clone(), value.clone());
let bytes = entry.encode();
let decoded = WalEntry::decode(&bytes)?;
assert_eq!(decoded.sequence, 42);
assert_eq!(decoded.entry_type, WalEntryType::Put);
assert_eq!(decoded.key, key);
assert_eq!(decoded.value, Some(value));
Ok(())
}
#[test]
fn test_wal_delete_entry() -> Result<()> {
let key = Key::from_str("delete_me");
let entry = WalEntry::delete(99, key.clone());
let bytes = entry.encode();
let decoded = WalEntry::decode(&bytes)?;
assert_eq!(decoded.sequence, 99);
assert_eq!(decoded.entry_type, WalEntryType::Delete);
assert_eq!(decoded.key, key);
assert_eq!(decoded.value, None);
Ok(())
}
#[test]
fn test_wal_checksum_verification() -> Result<()> {
let key = Key::from_str("test");
let value = CipherBlob::new(vec![1, 2, 3]);
let entry = WalEntry::put(1, key, value);
entry.verify_checksum()?;
let mut corrupted = entry.clone();
corrupted.checksum = 0;
assert!(corrupted.verify_checksum().is_err());
Ok(())
}
#[test]
fn test_wal_basic_operations() -> Result<()> {
let temp_dir = tempdir().map_err(|e| {
AmateRSError::IoError(ErrorContext::new(format!(
"Failed to create temp dir: {}",
e
)))
})?;
let wal_path = temp_dir.path().join("test.wal");
let mut wal = Wal::create(&wal_path)?;
let seq1 = wal.put(Key::from_str("key1"), CipherBlob::new(vec![1, 2, 3]))?;
let seq2 = wal.put(Key::from_str("key2"), CipherBlob::new(vec![4, 5, 6]))?;
let seq3 = wal.delete(Key::from_str("key1"))?;
assert_eq!(seq1, 0);
assert_eq!(seq2, 1);
assert_eq!(seq3, 2);
wal.flush()?;
assert!(wal.path().exists());
Ok(())
}
#[test]
fn test_wal_sequence_increment() -> Result<()> {
let temp_dir = tempdir().map_err(|e| {
AmateRSError::IoError(ErrorContext::new(format!(
"Failed to create temp dir: {}",
e
)))
})?;
let wal_path = temp_dir.path().join("test_seq.wal");
let mut wal = Wal::create(&wal_path)?;
assert_eq!(wal.sequence(), 0);
wal.put(Key::from_str("key"), CipherBlob::new(vec![1]))?;
assert_eq!(wal.sequence(), 1);
wal.delete(Key::from_str("key"))?;
assert_eq!(wal.sequence(), 2);
Ok(())
}
#[test]
fn test_wal_entry_large_value() -> Result<()> {
let key = Key::from_str("large");
let large_value = CipherBlob::new(vec![0u8; 10_000]);
let entry = WalEntry::put(1, key.clone(), large_value.clone());
let bytes = entry.encode();
let decoded = WalEntry::decode(&bytes)?;
assert_eq!(decoded.key, key);
assert_eq!(decoded.value, Some(large_value));
Ok(())
}
#[test]
fn test_wal_rotation() -> Result<()> {
use std::env;
let temp_dir = env::temp_dir().join("test_wal_rotation");
std::fs::create_dir_all(&temp_dir).ok();
let config = WalConfig {
wal_dir: temp_dir.clone(),
max_file_size: 1024, sync_on_write: false, ..Default::default()
};
let mut wal = Wal::with_config(config)?;
let initial_file_number = wal.current_file_number();
for i in 0..20 {
wal.put(
Key::from_str(&format!("key_{}", i)),
CipherBlob::new(vec![i as u8; 100]),
)?;
}
assert!(wal.current_file_number() > initial_file_number);
assert!(wal.path().exists());
std::fs::remove_dir_all(&temp_dir).ok();
Ok(())
}
#[test]
fn test_wal_cleanup() -> Result<()> {
use std::env;
let temp_dir = env::temp_dir().join("test_wal_cleanup");
std::fs::create_dir_all(&temp_dir).ok();
let config = WalConfig {
wal_dir: temp_dir.clone(),
max_file_size: 512, max_wal_files: 3, sync_on_write: false,
};
let mut wal = Wal::with_config(config)?;
for i in 0..100 {
wal.put(
Key::from_str(&format!("key_{}", i)),
CipherBlob::new(vec![i as u8; 100]),
)?;
}
let wal_file_count = std::fs::read_dir(&temp_dir)?
.filter_map(|e| e.ok())
.filter(|e| {
e.file_name().to_string_lossy().starts_with("wal_")
&& e.file_name().to_string_lossy().ends_with(".log")
})
.count();
assert!(wal_file_count <= 3);
std::fs::remove_dir_all(&temp_dir).ok();
Ok(())
}
#[test]
fn test_wal_manual_cleanup() -> Result<()> {
use std::env;
let temp_dir = env::temp_dir().join("test_wal_manual_cleanup");
std::fs::create_dir_all(&temp_dir).ok();
let config = WalConfig {
wal_dir: temp_dir.clone(),
max_file_size: 512,
max_wal_files: 5,
sync_on_write: false,
};
let mut wal = Wal::with_config(config)?;
for i in 0..80 {
wal.put(
Key::from_str(&format!("key_{}", i)),
CipherBlob::new(vec![i as u8; 100]),
)?;
}
wal.cleanup()?;
let wal_file_count = std::fs::read_dir(&temp_dir)?
.filter_map(|e| e.ok())
.filter(|e| {
e.file_name().to_string_lossy().starts_with("wal_")
&& e.file_name().to_string_lossy().ends_with(".log")
})
.count();
assert!(wal_file_count <= 5);
std::fs::remove_dir_all(&temp_dir).ok();
Ok(())
}
#[test]
fn test_wal_recovery_basic() -> Result<()> {
use std::env;
let temp_dir = env::temp_dir().join("test_wal_recovery_basic");
std::fs::create_dir_all(&temp_dir).ok();
{
let config = WalConfig {
wal_dir: temp_dir.clone(),
sync_on_write: true,
..Default::default()
};
let mut wal = Wal::with_config(config)?;
wal.put(Key::from_str("key1"), CipherBlob::new(vec![1, 2, 3]))?;
wal.put(Key::from_str("key2"), CipherBlob::new(vec![4, 5, 6]))?;
wal.delete(Key::from_str("key1"))?;
wal.put(Key::from_str("key3"), CipherBlob::new(vec![7, 8, 9]))?;
wal.flush()?;
}
let (entries, max_sequence) = Wal::recover(&temp_dir)?;
assert_eq!(entries.len(), 4);
assert_eq!(max_sequence, 3);
assert_eq!(entries[0].key, Key::from_str("key1"));
assert_eq!(entries[0].entry_type, WalEntryType::Put);
assert_eq!(entries[0].value, Some(CipherBlob::new(vec![1, 2, 3])));
assert_eq!(entries[1].key, Key::from_str("key2"));
assert_eq!(entries[1].entry_type, WalEntryType::Put);
assert_eq!(entries[2].key, Key::from_str("key1"));
assert_eq!(entries[2].entry_type, WalEntryType::Delete);
assert_eq!(entries[2].value, None);
assert_eq!(entries[3].key, Key::from_str("key3"));
assert_eq!(entries[3].entry_type, WalEntryType::Put);
std::fs::remove_dir_all(&temp_dir).ok();
Ok(())
}
#[test]
fn test_wal_recovery_multiple_files() -> Result<()> {
use std::env;
let temp_dir = env::temp_dir().join("test_wal_recovery_multiple");
std::fs::create_dir_all(&temp_dir).ok();
{
let config = WalConfig {
wal_dir: temp_dir.clone(),
max_file_size: 512, sync_on_write: true,
..Default::default()
};
let mut wal = Wal::with_config(config)?;
for i in 0..20 {
wal.put(
Key::from_str(&format!("key_{}", i)),
CipherBlob::new(vec![i as u8; 100]),
)?;
}
wal.flush()?;
}
let (entries, max_sequence) = Wal::recover(&temp_dir)?;
assert_eq!(entries.len(), 20);
assert_eq!(max_sequence, 19);
for (i, entry) in entries.iter().enumerate() {
assert_eq!(entry.sequence, i as u64);
assert_eq!(entry.key, Key::from_str(&format!("key_{}", i)));
}
std::fs::remove_dir_all(&temp_dir).ok();
Ok(())
}
#[test]
fn test_wal_recovery_empty_directory() -> Result<()> {
use std::env;
let temp_dir = env::temp_dir().join("test_wal_recovery_empty");
std::fs::create_dir_all(&temp_dir).ok();
let (entries, max_sequence) = Wal::recover(&temp_dir)?;
assert_eq!(entries.len(), 0);
assert_eq!(max_sequence, 0);
std::fs::remove_dir_all(&temp_dir).ok();
Ok(())
}
#[test]
fn test_wal_recovery_nonexistent_directory() -> Result<()> {
use std::env;
let temp_dir = env::temp_dir().join("nonexistent_wal_dir_12345");
let (entries, max_sequence) = Wal::recover(&temp_dir)?;
assert_eq!(entries.len(), 0);
assert_eq!(max_sequence, 0);
Ok(())
}
#[test]
fn test_wal_replay_to_memtable() -> Result<()> {
use std::env;
let temp_dir = env::temp_dir().join("test_wal_replay_memtable");
std::fs::create_dir_all(&temp_dir).ok();
{
let config = WalConfig {
wal_dir: temp_dir.clone(),
sync_on_write: true,
..Default::default()
};
let mut wal = Wal::with_config(config)?;
wal.put(Key::from_str("key1"), CipherBlob::new(vec![1, 2, 3]))?;
wal.put(Key::from_str("key2"), CipherBlob::new(vec![4, 5, 6]))?;
wal.delete(Key::from_str("key1"))?;
wal.put(Key::from_str("key3"), CipherBlob::new(vec![7, 8, 9]))?;
wal.flush()?;
}
let memtable = Memtable::new();
let max_sequence = Wal::replay_to_memtable(&temp_dir, &memtable)?;
assert_eq!(max_sequence, 3);
assert_eq!(memtable.get(&Key::from_str("key1"))?, None); assert_eq!(
memtable.get(&Key::from_str("key2"))?,
Some(CipherBlob::new(vec![4, 5, 6]))
);
assert_eq!(
memtable.get(&Key::from_str("key3"))?,
Some(CipherBlob::new(vec![7, 8, 9]))
);
std::fs::remove_dir_all(&temp_dir).ok();
Ok(())
}
#[test]
fn test_wal_reader_basic() -> Result<()> {
use std::env;
let temp_dir = env::temp_dir().join("test_wal_reader_basic");
std::fs::create_dir_all(&temp_dir).ok();
let wal_file = temp_dir.join("test.wal");
{
let mut wal = Wal::create(&wal_file)?;
wal.put(Key::from_str("key1"), CipherBlob::new(vec![1, 2, 3]))?;
wal.put(Key::from_str("key2"), CipherBlob::new(vec![4, 5, 6]))?;
wal.flush()?;
}
let wal_file_actual = temp_dir.join("wal_00000000.log");
let mut reader = WalReader::open(&wal_file_actual)?;
let entry1 = reader.read_entry()?.expect("Should have entry 1");
assert_eq!(entry1.sequence, 0);
assert_eq!(entry1.key, Key::from_str("key1"));
let entry2 = reader.read_entry()?.expect("Should have entry 2");
assert_eq!(entry2.sequence, 1);
assert_eq!(entry2.key, Key::from_str("key2"));
let entry3 = reader.read_entry()?;
assert_eq!(entry3, None);
std::fs::remove_dir_all(&temp_dir).ok();
Ok(())
}
#[test]
fn test_wal_recovery_with_truncated_file() -> Result<()> {
use std::env;
use std::io::Write as IoWrite;
let temp_dir = env::temp_dir().join("test_wal_recovery_truncated");
std::fs::create_dir_all(&temp_dir).ok();
let wal_file = temp_dir.join("wal_00000000.log");
{
let mut wal = Wal::create(&wal_file)?;
wal.put(Key::from_str("key1"), CipherBlob::new(vec![1, 2, 3]))?;
wal.put(Key::from_str("key2"), CipherBlob::new(vec![4, 5, 6]))?;
wal.flush()?;
let mut file = OpenOptions::new().append(true).open(&wal_file)?;
let incomplete_len = 1234u32;
file.write_all(&incomplete_len.to_le_bytes())?;
file.flush()?;
}
let (entries, _) = Wal::recover(&temp_dir)?;
assert_eq!(entries.len(), 2);
assert_eq!(entries[0].key, Key::from_str("key1"));
assert_eq!(entries[1].key, Key::from_str("key2"));
std::fs::remove_dir_all(&temp_dir).ok();
Ok(())
}
#[test]
fn test_wal_sequence_recovery_after_crash() -> Result<()> {
use std::env;
let temp_dir = env::temp_dir().join("test_wal_seq_recovery_crash");
std::fs::remove_dir_all(&temp_dir).ok();
std::fs::create_dir_all(&temp_dir).ok();
{
let config = WalConfig {
wal_dir: temp_dir.clone(),
sync_on_write: true,
..Default::default()
};
let mut wal = Wal::with_config(config)?;
wal.put(Key::from_str("a"), CipherBlob::new(vec![1]))?;
wal.put(Key::from_str("b"), CipherBlob::new(vec![2]))?;
wal.put(Key::from_str("c"), CipherBlob::new(vec![3]))?;
wal.put(Key::from_str("d"), CipherBlob::new(vec![4]))?;
wal.put(Key::from_str("e"), CipherBlob::new(vec![5]))?;
wal.flush()?;
}
{
let config = WalConfig {
wal_dir: temp_dir.clone(),
sync_on_write: true,
..Default::default()
};
let mut wal = Wal::with_config(config)?;
assert_eq!(wal.sequence(), 5);
let seq = wal.put(Key::from_str("f"), CipherBlob::new(vec![6]))?;
assert_eq!(seq, 5);
let seq = wal.put(Key::from_str("g"), CipherBlob::new(vec![7]))?;
assert_eq!(seq, 6);
wal.flush()?;
}
let (entries, max_sequence) = Wal::recover(&temp_dir)?;
assert_eq!(entries.len(), 7);
assert_eq!(max_sequence, 6);
std::fs::remove_dir_all(&temp_dir).ok();
Ok(())
}
#[test]
fn test_wal_corruption_detection_and_partial_recovery() -> Result<()> {
use std::env;
use std::io::Write as IoWrite;
let temp_dir = env::temp_dir().join("test_wal_corruption_detect");
std::fs::remove_dir_all(&temp_dir).ok();
std::fs::create_dir_all(&temp_dir).ok();
let wal_file = temp_dir.join("wal_00000000.log");
{
let config = WalConfig {
wal_dir: temp_dir.clone(),
sync_on_write: true,
..Default::default()
};
let mut wal = Wal::with_config(config)?;
wal.put(Key::from_str("key1"), CipherBlob::new(vec![1, 2, 3]))?;
wal.put(Key::from_str("key2"), CipherBlob::new(vec![4, 5, 6]))?;
wal.put(Key::from_str("key3"), CipherBlob::new(vec![7, 8, 9]))?;
wal.flush()?;
}
{
let data = std::fs::read(&wal_file).map_err(|e| {
AmateRSError::IoError(ErrorContext::new(format!("Failed to read WAL: {}", e)))
})?;
let mut corrupted_data = data.clone();
let first_entry_len = u32::from_le_bytes([data[0], data[1], data[2], data[3]]) as usize;
let second_entry_start = 4 + first_entry_len;
let corrupt_offset = second_entry_start + 4 + 10; if corrupt_offset < corrupted_data.len() {
corrupted_data[corrupt_offset] ^= 0xFF;
}
let mut file = File::create(&wal_file).map_err(|e| {
AmateRSError::IoError(ErrorContext::new(format!("Failed to create file: {}", e)))
})?;
file.write_all(&corrupted_data).map_err(|e| {
AmateRSError::IoError(ErrorContext::new(format!("Failed to write file: {}", e)))
})?;
file.flush().map_err(|e| {
AmateRSError::IoError(ErrorContext::new(format!("Failed to flush file: {}", e)))
})?;
}
let (entries, _max_seq, stats) = Wal::recover_with_stats(&temp_dir)?;
assert_eq!(stats.entries_corrupted, 1);
assert_eq!(stats.entries_recovered, entries.len() as u64);
assert!(stats.bytes_recovered > 0);
assert!(entries.len() >= 2);
std::fs::remove_dir_all(&temp_dir).ok();
Ok(())
}
#[test]
fn test_wal_truncate_before() -> Result<()> {
use std::env;
let temp_dir = env::temp_dir().join("test_wal_truncate_before");
std::fs::remove_dir_all(&temp_dir).ok();
std::fs::create_dir_all(&temp_dir).ok();
let config = WalConfig {
wal_dir: temp_dir.clone(),
max_file_size: 512, max_wal_files: 100, sync_on_write: true,
};
let mut wal = Wal::with_config(config)?;
for i in 0..30 {
wal.put(
Key::from_str(&format!("key_{}", i)),
CipherBlob::new(vec![i as u8; 100]),
)?;
}
wal.flush()?;
let file_count_before = std::fs::read_dir(&temp_dir)
.map_err(|e| {
AmateRSError::IoError(ErrorContext::new(format!("Failed to read dir: {}", e)))
})?
.filter_map(|e| e.ok())
.filter(|e| {
let name = e.file_name().to_string_lossy().to_string();
name.starts_with("wal_") && name.ends_with(".log")
})
.count();
assert!(file_count_before > 1, "Should have multiple WAL files");
let truncated = wal.truncate_before(10)?;
assert!(truncated > 0, "Should have truncated at least one file");
let (remaining_entries, _) = Wal::recover(&temp_dir)?;
let has_high_seq = remaining_entries.iter().any(|e| e.sequence > 10);
assert!(has_high_seq, "Should still have entries with sequence > 10");
std::fs::remove_dir_all(&temp_dir).ok();
Ok(())
}
#[test]
fn test_wal_size_tracking() -> Result<()> {
use std::env;
let temp_dir = env::temp_dir().join("test_wal_size_tracking");
std::fs::remove_dir_all(&temp_dir).ok();
std::fs::create_dir_all(&temp_dir).ok();
let config = WalConfig {
wal_dir: temp_dir.clone(),
sync_on_write: true,
..Default::default()
};
let mut wal = Wal::with_config(config)?;
assert_eq!(wal.current_size(), 0);
wal.put(Key::from_str("key1"), CipherBlob::new(vec![1, 2, 3]))?;
let size_after_one = wal.current_size();
assert!(size_after_one > 0, "Size should increase after writing");
wal.put(Key::from_str("key2"), CipherBlob::new(vec![4, 5, 6]))?;
let size_after_two = wal.current_size();
assert!(
size_after_two > size_after_one,
"Size should increase with more entries"
);
wal.flush()?;
let total = wal.total_wal_size()?;
assert_eq!(total, size_after_two);
std::fs::remove_dir_all(&temp_dir).ok();
Ok(())
}
#[test]
fn test_wal_total_size_multiple_files() -> Result<()> {
use std::env;
let temp_dir = env::temp_dir().join("test_wal_total_size_multi");
std::fs::remove_dir_all(&temp_dir).ok();
std::fs::create_dir_all(&temp_dir).ok();
let config = WalConfig {
wal_dir: temp_dir.clone(),
max_file_size: 512,
max_wal_files: 100,
sync_on_write: true,
};
let mut wal = Wal::with_config(config)?;
for i in 0..20 {
wal.put(
Key::from_str(&format!("key_{}", i)),
CipherBlob::new(vec![i as u8; 100]),
)?;
}
wal.flush()?;
let total = wal.total_wal_size()?;
assert!(total > 0, "Total WAL size should be positive");
if wal.current_file_number() > 0 {
assert!(
total >= wal.current_size(),
"Total size should be >= current file size"
);
}
std::fs::remove_dir_all(&temp_dir).ok();
Ok(())
}
#[test]
fn test_wal_empty_recovery() -> Result<()> {
use std::env;
let temp_dir = env::temp_dir().join("test_wal_empty_recovery");
std::fs::remove_dir_all(&temp_dir).ok();
std::fs::create_dir_all(&temp_dir).ok();
{
let config = WalConfig {
wal_dir: temp_dir.clone(),
sync_on_write: true,
..Default::default()
};
let wal = Wal::with_config(config)?;
drop(wal);
}
let (entries, max_seq, stats) = Wal::recover_with_stats(&temp_dir)?;
assert_eq!(entries.len(), 0);
assert_eq!(max_seq, 0);
assert_eq!(stats.entries_recovered, 0);
assert_eq!(stats.entries_corrupted, 0);
std::fs::remove_dir_all(&temp_dir).ok();
Ok(())
}
#[test]
fn test_wal_single_entry_recovery() -> Result<()> {
use std::env;
let temp_dir = env::temp_dir().join("test_wal_single_entry_recovery");
std::fs::remove_dir_all(&temp_dir).ok();
std::fs::create_dir_all(&temp_dir).ok();
{
let config = WalConfig {
wal_dir: temp_dir.clone(),
sync_on_write: true,
..Default::default()
};
let mut wal = Wal::with_config(config)?;
wal.put(Key::from_str("only_key"), CipherBlob::new(vec![42]))?;
wal.flush()?;
}
let (entries, max_seq, stats) = Wal::recover_with_stats(&temp_dir)?;
assert_eq!(entries.len(), 1);
assert_eq!(max_seq, 0);
assert_eq!(stats.entries_recovered, 1);
assert_eq!(stats.entries_corrupted, 0);
assert!(stats.bytes_recovered > 0);
assert_eq!(entries[0].key, Key::from_str("only_key"));
std::fs::remove_dir_all(&temp_dir).ok();
Ok(())
}
#[test]
fn test_wal_large_recovery() -> Result<()> {
// Writes 500 entries with a small max_file_size (4096) and checks that all
// of them are recovered, none are corrupted, and sequences run 0, 1, 2, ...
const ENTRY_COUNT: usize = 500;
let dir = std::env::temp_dir().join("test_wal_large_recovery");
let _ = std::fs::remove_dir_all(&dir);
let _ = std::fs::create_dir_all(&dir);
{
let mut writer = Wal::with_config(WalConfig {
wal_dir: dir.clone(),
max_file_size: 4096,
max_wal_files: 1000,
sync_on_write: false,
})?;
for n in 0..ENTRY_COUNT {
let key = Key::from_str(&format!("large_key_{:05}", n));
let payload = CipherBlob::new(vec![(n % 256) as u8; 50]);
writer.put(key, payload)?;
}
writer.flush()?;
}
let (recovered, highest_seq, stats) = Wal::recover_with_stats(&dir)?;
assert_eq!(recovered.len(), ENTRY_COUNT);
assert_eq!(highest_seq, (ENTRY_COUNT - 1) as u64);
assert_eq!(stats.entries_recovered, ENTRY_COUNT as u64);
assert_eq!(stats.entries_corrupted, 0);
assert!(stats.bytes_recovered > 0);
// Recovered entries must be in write order with contiguous sequences.
for (expected_seq, entry) in (0u64..).zip(recovered.iter()) {
assert_eq!(entry.sequence, expected_seq);
}
let _ = std::fs::remove_dir_all(&dir);
Ok(())
}
#[test]
fn test_wal_truncate_keeps_current_file() -> Result<()> {
// Truncating with the largest possible sequence must still leave the file
// the WAL is currently writing to on disk.
let dir = std::env::temp_dir().join("test_wal_truncate_keeps_current");
let _ = std::fs::remove_dir_all(&dir);
let _ = std::fs::create_dir_all(&dir);
let mut wal = Wal::with_config(WalConfig {
wal_dir: dir.clone(),
max_file_size: 512,
max_wal_files: 100,
sync_on_write: true,
})?;
for n in 0..30 {
let key = Key::from_str(&format!("key_{}", n));
wal.put(key, CipherBlob::new(vec![n as u8; 100]))?;
}
wal.flush()?;
let active = wal.current_file_number();
wal.truncate_before(u64::MAX)?;
let active_path = Wal::wal_file_path(&dir, active);
assert!(
active_path.exists(),
"Current active WAL file should not be removed"
);
let _ = std::fs::remove_dir_all(&dir);
Ok(())
}
#[test]
fn test_wal_sequence_recovery_across_rotations() -> Result<()> {
use std::env;
// Writes enough entries to spill over several small WAL files, closes the
// WAL, reopens it, and checks the sequence counter resumes where it stopped.
let temp_dir = env::temp_dir().join("test_wal_seq_recovery_rotation");
std::fs::remove_dir_all(&temp_dir).ok();
std::fs::create_dir_all(&temp_dir).ok();
// Both sessions must use the same settings; a tiny max_file_size forces
// rotation with only a couple dozen entries.
let make_config = || WalConfig {
wal_dir: temp_dir.clone(),
max_file_size: 512,
max_wal_files: 100,
sync_on_write: true,
};
let entries_written;
{
let mut wal = Wal::with_config(make_config())?;
for i in 0..25 {
wal.put(
Key::from_str(&format!("rkey_{}", i)),
CipherBlob::new(vec![i as u8; 80]),
)?;
}
wal.flush()?;
// Without these guards the final assert_eq! could pass vacuously,
// e.g. if the sequence never advanced and recovery never restored it.
// This also makes sure the scenario the test is named after
// (recovery spanning a rotation) actually happened.
assert!(
wal.current_file_number() > 0,
"Writes should have rotated into more than one WAL file"
);
entries_written = wal.sequence();
assert!(
entries_written > 0,
"Sequence should advance after writing entries"
);
}
{
let wal = Wal::with_config(make_config())?;
assert_eq!(
wal.sequence(),
entries_written,
"Sequence should continue from where it left off"
);
}
std::fs::remove_dir_all(&temp_dir).ok();
Ok(())
}
#[test]
fn test_wal_recovery_stats_with_corruption() -> Result<()> {
use std::env;
use std::io::Write as IoWrite;
// Two valid entries followed by an appended junk record: recovery must
// return both good entries and count at least one corrupted one.
let temp_dir = env::temp_dir().join("test_wal_recovery_stats_corrupt");
std::fs::remove_dir_all(&temp_dir).ok();
std::fs::create_dir_all(&temp_dir).ok();
let wal_file;
{
let config = WalConfig {
wal_dir: temp_dir.clone(),
sync_on_write: true,
..Default::default()
};
let mut wal = Wal::with_config(config)?;
wal.put(Key::from_str("s1"), CipherBlob::new(vec![10]))?;
wal.put(Key::from_str("s2"), CipherBlob::new(vec![20]))?;
wal.flush()?;
// Resolve the file holding these entries through the WAL's own naming
// helper instead of hard-coding "wal_00000000.log", so this test keeps
// corrupting the right file if the naming or numbering scheme changes.
wal_file = Wal::wal_file_path(&temp_dir, wal.current_file_number());
}
// Map an io::Error into the crate error, preserving the "<what>: <err>"
// message shape.
let io_err = |what: &str, e: std::io::Error| {
AmateRSError::IoError(ErrorContext::new(format!("{}: {}", what, e)))
};
{
let mut file = OpenOptions::new()
.append(true)
.open(&wal_file)
.map_err(|e| io_err("Failed to open for corruption", e))?;
// A fake 4-byte length prefix followed by junk bytes, which recovery is
// expected to flag as corrupted.
let fake_len = 30u32;
file.write_all(&fake_len.to_le_bytes())
.map_err(|e| io_err("write error", e))?;
file.write_all(&[0xDE; 30])
.map_err(|e| io_err("write error", e))?;
file.flush().map_err(|e| io_err("flush error", e))?;
}
let (_entries, _max_seq, stats) = Wal::recover_with_stats(&temp_dir)?;
assert_eq!(stats.entries_recovered, 2);
assert!(
stats.entries_corrupted >= 1,
"Should detect at least one corrupted entry"
);
std::fs::remove_dir_all(&temp_dir).ok();
Ok(())
}
}