use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
use crc32fast::Hasher;
use std::fs::{File, OpenOptions};
use std::io::{self, Read, Seek, SeekFrom, Write};
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex};
/// Size in bytes of the fixed prefix written before each record payload:
/// a `u32` checksum followed by a `u64` payload length (4 + 8 = 12).
// Not referenced by the code in this module at present, hence the lint allowance.
#[allow(dead_code)]
const HEADER_SIZE: usize = std::mem::size_of::<u32>() + std::mem::size_of::<u64>();
/// Handle to a write-ahead log stored in a single file.
///
/// All file access goes through the internal mutex, so the methods on this
/// type take `&self`.
#[derive(Debug)]
pub struct Wal {
    /// Location of the log file as given to `Wal::open`.
    #[allow(dead_code)]
    path: PathBuf,
    /// The open log file, guarded by a mutex so each operation has exclusive
    /// use of the file cursor.
    file: Arc<Mutex<File>>,
}
impl Wal {
pub fn open(path: impl AsRef<Path>) -> io::Result<Self> {
let path = path.as_ref().to_path_buf();
if let Some(parent) = path.parent() {
std::fs::create_dir_all(parent)?;
}
let file = OpenOptions::new()
.read(true)
.write(true)
.create(true)
.open(&path)?;
Ok(Self {
path,
file: Arc::new(Mutex::new(file)),
})
}
pub fn append(&self, data: &[u8]) -> io::Result<u64> {
let mut file = self.file.lock().unwrap();
let mut hasher = Hasher::new();
hasher.update(data);
let checksum = hasher.finalize();
file.write_u32::<BigEndian>(checksum)?;
file.write_u64::<BigEndian>(data.len() as u64)?;
file.write_all(data)?;
Ok(file.stream_position()?)
}
pub fn sync(&self) -> io::Result<()> {
let file = self.file.lock().unwrap();
file.sync_all()
}
pub fn truncate(&self, size: u64) -> io::Result<()> {
let file = self.file.lock().unwrap();
file.set_len(size)?;
file.sync_all()
}
pub fn read_all(&self) -> io::Result<Vec<Vec<u8>>> {
let mut file = self.file.lock().unwrap();
file.seek(SeekFrom::Start(0))?;
let mut entries = Vec::new();
let mut buffer = Vec::new();
loop {
let checksum = match file.read_u32::<BigEndian>() {
Ok(c) => c,
Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => break,
Err(e) => return Err(e),
};
let len = file.read_u64::<BigEndian>()?;
if len > 128 * 1024 * 1024 {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
"Entry too large",
));
}
buffer.resize(len as usize, 0);
file.read_exact(&mut buffer)?;
let mut hasher = Hasher::new();
hasher.update(&buffer);
if hasher.finalize() != checksum {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
"Checksum mismatch",
));
}
entries.push(buffer.clone());
}
file.seek(SeekFrom::End(0))?;
Ok(entries)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::NamedTempFile;

    /// Records appended and synced through one handle must come back, in
    /// order, from a second handle opened on the same path.
    #[test]
    fn test_wal_persistence() -> io::Result<()> {
        let scratch = NamedTempFile::new()?;
        let wal_path = scratch.path().to_path_buf();

        // First handle: write two records, sync, then close.
        let writer = Wal::open(&wal_path)?;
        let payloads: [&[u8]; 2] = [b"hello", b"world"];
        for payload in payloads {
            writer.append(payload)?;
        }
        writer.sync()?;
        drop(writer);

        // Second handle: everything written above must be readable.
        let reader = Wal::open(&wal_path)?;
        let recovered = reader.read_all()?;
        assert_eq!(recovered.len(), 2);
        assert_eq!(recovered[0], b"hello");
        assert_eq!(recovered[1], b"world");
        drop(reader);

        Ok(())
    }
}