use anyhow::{anyhow, Result};
use std::path::Path;
use tokio::fs::File;
use tokio::io::{AsyncReadExt, AsyncSeekExt, SeekFrom};
#[derive(Debug, Clone)]
pub struct WalHeader {
pub magic: u32,
pub format_version: u32,
pub page_size: u32,
pub checkpoint_seq: u32,
pub salt1: u32,
pub salt2: u32,
pub checksum1: u32,
pub checksum2: u32,
}
#[derive(Debug, Clone)]
pub struct FrameHeader {
pub page_number: u32,
pub db_size: u32, pub salt1: u32,
pub salt2: u32,
pub checksum1: u32,
pub checksum2: u32,
}
pub const WAL_HEADER_SIZE: u64 = 32;
pub const FRAME_HEADER_SIZE: u64 = 24;
pub async fn read_header(path: &Path) -> Result<Option<WalHeader>> {
let mut file = match File::open(path).await {
Ok(f) => f,
Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Ok(None),
Err(e) => return Err(e.into()),
};
let metadata = file.metadata().await?;
if metadata.len() < WAL_HEADER_SIZE {
return Ok(None);
}
let mut buf = [0u8; 32];
file.read_exact(&mut buf).await?;
let magic = u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]);
if magic != 0x377F0682 && magic != 0x377F0683 {
return Err(anyhow!("Invalid WAL magic number: {:#x}", magic));
}
Ok(Some(WalHeader {
magic,
format_version: u32::from_be_bytes([buf[4], buf[5], buf[6], buf[7]]),
page_size: u32::from_be_bytes([buf[8], buf[9], buf[10], buf[11]]),
checkpoint_seq: u32::from_be_bytes([buf[12], buf[13], buf[14], buf[15]]),
salt1: u32::from_be_bytes([buf[16], buf[17], buf[18], buf[19]]),
salt2: u32::from_be_bytes([buf[20], buf[21], buf[22], buf[23]]),
checksum1: u32::from_be_bytes([buf[24], buf[25], buf[26], buf[27]]),
checksum2: u32::from_be_bytes([buf[28], buf[29], buf[30], buf[31]]),
}))
}
pub async fn read_frames_from(
path: &Path,
page_size: u32,
start_offset: u64,
) -> Result<(Vec<u8>, u64, usize)> {
let mut file = File::open(path).await?;
let file_size = file.metadata().await?.len();
let frame_size = FRAME_HEADER_SIZE + page_size as u64;
let start_pos = if start_offset == 0 {
WAL_HEADER_SIZE
} else {
start_offset
};
if start_pos >= file_size {
return Ok((Vec::new(), start_pos, 0));
}
file.seek(SeekFrom::Start(start_pos)).await?;
let available = file_size - start_pos;
let full_frames = available / frame_size;
if full_frames == 0 {
return Ok((Vec::new(), start_pos, 0));
}
let bytes_to_read = full_frames * frame_size;
let mut data = vec![0u8; bytes_to_read as usize];
file.read_exact(&mut data).await?;
let new_offset = start_pos + bytes_to_read;
Ok((data, new_offset, full_frames as usize))
}
pub async fn get_wal_size(path: &Path) -> Result<u64> {
match tokio::fs::metadata(path).await {
Ok(m) => Ok(m.len()),
Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(0),
Err(e) => Err(e.into()),
}
}
pub async fn read_wal(path: &Path) -> Result<Vec<u8>> {
match tokio::fs::read(path).await {
Ok(data) => Ok(data),
Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(Vec::new()),
Err(e) => Err(e.into()),
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::path::PathBuf;
#[tokio::test]
async fn test_read_header_nonexistent_file() {
let path = PathBuf::from("/tmp/nonexistent-wal-file.db-wal");
let result = read_header(&path).await.unwrap();
assert!(result.is_none());
}
#[tokio::test]
async fn test_read_header_empty_file() {
let path = PathBuf::from(format!("/tmp/walsync-test-{}.db-wal", uuid::Uuid::new_v4()));
tokio::fs::write(&path, &[]).await.unwrap();
let result = read_header(&path).await.unwrap();
assert!(result.is_none());
tokio::fs::remove_file(&path).await.ok();
}
#[tokio::test]
async fn test_read_header_too_small() {
let path = PathBuf::from(format!("/tmp/walsync-test-{}.db-wal", uuid::Uuid::new_v4()));
tokio::fs::write(&path, &[0u8; 20]).await.unwrap();
let result = read_header(&path).await.unwrap();
assert!(result.is_none());
tokio::fs::remove_file(&path).await.ok();
}
#[tokio::test]
async fn test_read_header_invalid_magic() {
let path = PathBuf::from(format!("/tmp/walsync-test-{}.db-wal", uuid::Uuid::new_v4()));
tokio::fs::write(&path, &[0u8; 32]).await.unwrap();
let result = read_header(&path).await;
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("Invalid WAL magic"));
tokio::fs::remove_file(&path).await.ok();
}
#[tokio::test]
async fn test_read_header_valid_magic_big_endian() {
let path = PathBuf::from(format!("/tmp/walsync-test-{}.db-wal", uuid::Uuid::new_v4()));
let mut header = [0u8; 32];
header[0..4].copy_from_slice(&0x377F0682u32.to_be_bytes()); header[4..8].copy_from_slice(&3007000u32.to_be_bytes()); header[8..12].copy_from_slice(&4096u32.to_be_bytes());
tokio::fs::write(&path, &header).await.unwrap();
let result = read_header(&path).await.unwrap().unwrap();
assert_eq!(result.magic, 0x377F0682);
assert_eq!(result.format_version, 3007000);
assert_eq!(result.page_size, 4096);
tokio::fs::remove_file(&path).await.ok();
}
#[tokio::test]
async fn test_read_header_valid_magic_little_endian() {
let path = PathBuf::from(format!("/tmp/walsync-test-{}.db-wal", uuid::Uuid::new_v4()));
let mut header = [0u8; 32];
header[0..4].copy_from_slice(&0x377F0683u32.to_be_bytes()); header[4..8].copy_from_slice(&3007000u32.to_be_bytes()); header[8..12].copy_from_slice(&4096u32.to_be_bytes());
tokio::fs::write(&path, &header).await.unwrap();
let result = read_header(&path).await.unwrap().unwrap();
assert_eq!(result.magic, 0x377F0683);
tokio::fs::remove_file(&path).await.ok();
}
#[tokio::test]
async fn test_get_wal_size_nonexistent() {
let path = PathBuf::from("/tmp/nonexistent-wal-file.db-wal");
let size = get_wal_size(&path).await.unwrap();
assert_eq!(size, 0);
}
#[tokio::test]
async fn test_get_wal_size_existing() {
let path = PathBuf::from(format!("/tmp/walsync-test-{}.db-wal", uuid::Uuid::new_v4()));
let data = vec![0u8; 1024];
tokio::fs::write(&path, &data).await.unwrap();
let size = get_wal_size(&path).await.unwrap();
assert_eq!(size, 1024);
tokio::fs::remove_file(&path).await.ok();
}
#[tokio::test]
async fn test_read_wal_nonexistent() {
let path = PathBuf::from("/tmp/nonexistent-wal-file.db-wal");
let data = read_wal(&path).await.unwrap();
assert!(data.is_empty());
}
#[tokio::test]
async fn test_read_wal_existing() {
let path = PathBuf::from(format!("/tmp/walsync-test-{}.db-wal", uuid::Uuid::new_v4()));
let expected = vec![1u8, 2, 3, 4, 5];
tokio::fs::write(&path, &expected).await.unwrap();
let data = read_wal(&path).await.unwrap();
assert_eq!(data, expected);
tokio::fs::remove_file(&path).await.ok();
}
#[tokio::test]
async fn test_read_frames_from_no_frames() {
let path = PathBuf::from(format!("/tmp/walsync-test-{}.db-wal", uuid::Uuid::new_v4()));
let mut header = [0u8; 32];
header[0..4].copy_from_slice(&0x377F0682u32.to_be_bytes());
header[8..12].copy_from_slice(&4096u32.to_be_bytes());
tokio::fs::write(&path, &header).await.unwrap();
let (frames, offset, count) = read_frames_from(&path, 4096, 0).await.unwrap();
assert!(frames.is_empty());
assert_eq!(offset, WAL_HEADER_SIZE);
assert_eq!(count, 0);
tokio::fs::remove_file(&path).await.ok();
}
#[tokio::test]
async fn test_read_frames_from_with_frames() {
let path = PathBuf::from(format!("/tmp/walsync-test-{}.db-wal", uuid::Uuid::new_v4()));
let page_size: u32 = 4096;
let frame_size = FRAME_HEADER_SIZE as usize + page_size as usize;
let mut data = vec![0u8; 32 + frame_size * 2];
data[0..4].copy_from_slice(&0x377F0682u32.to_be_bytes()); data[8..12].copy_from_slice(&page_size.to_be_bytes());
for i in 0..frame_size * 2 {
data[32 + i] = (i % 256) as u8;
}
tokio::fs::write(&path, &data).await.unwrap();
let (frames, offset, count) = read_frames_from(&path, page_size, 0).await.unwrap();
assert_eq!(count, 2);
assert_eq!(frames.len(), frame_size * 2);
assert_eq!(offset, WAL_HEADER_SIZE + (frame_size * 2) as u64);
tokio::fs::remove_file(&path).await.ok();
}
#[tokio::test]
async fn test_read_frames_from_with_offset() {
let path = PathBuf::from(format!("/tmp/walsync-test-{}.db-wal", uuid::Uuid::new_v4()));
let page_size: u32 = 4096;
let frame_size = FRAME_HEADER_SIZE as usize + page_size as usize;
let mut data = vec![0u8; 32 + frame_size * 3];
data[0..4].copy_from_slice(&0x377F0682u32.to_be_bytes());
data[8..12].copy_from_slice(&page_size.to_be_bytes());
tokio::fs::write(&path, &data).await.unwrap();
let start_offset = WAL_HEADER_SIZE + frame_size as u64;
let (frames, offset, count) = read_frames_from(&path, page_size, start_offset).await.unwrap();
assert_eq!(count, 2); assert_eq!(frames.len(), frame_size * 2);
assert_eq!(offset, start_offset + (frame_size * 2) as u64);
tokio::fs::remove_file(&path).await.ok();
}
#[tokio::test]
async fn test_read_frames_partial_frame_ignored() {
let path = PathBuf::from(format!("/tmp/walsync-test-{}.db-wal", uuid::Uuid::new_v4()));
let page_size: u32 = 4096;
let frame_size = FRAME_HEADER_SIZE as usize + page_size as usize;
let mut data = vec![0u8; 32 + frame_size + 100]; data[0..4].copy_from_slice(&0x377F0682u32.to_be_bytes());
data[8..12].copy_from_slice(&page_size.to_be_bytes());
tokio::fs::write(&path, &data).await.unwrap();
let (frames, _offset, count) = read_frames_from(&path, page_size, 0).await.unwrap();
assert_eq!(count, 1);
assert_eq!(frames.len(), frame_size);
tokio::fs::remove_file(&path).await.ok();
}
}