use std::path::{Path, PathBuf};
use std::time::{Duration, SystemTime};
use crate::core::reactor::Reactor;
/// Numeric identifier of a WAL segment.
///
/// The id is rendered into the segment file name with the `ir.wal.{:04}`
/// pattern, and segments are replayed in ascending id order.
pub type WalSegmentId = u64;
/// Errors reported by the write-ahead log.
#[derive(Debug)]
pub enum WalError {
    /// An underlying I/O operation failed.
    Io(std::io::Error),
    /// The caller supplied a value the log cannot accept; the string explains why.
    InvalidInput(String),
}

impl std::fmt::Display for WalError {
    /// Writes a short human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        match self {
            Self::Io(err) => write!(f, "WAL I/O error: {}", err),
            Self::InvalidInput(reason) => write!(f, "invalid WAL input: {}", reason),
        }
    }
}

impl std::error::Error for WalError {
    /// Exposes the wrapped I/O error as the cause; `InvalidInput` has no cause.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::Io(err) => Some(err),
            Self::InvalidInput(_) => None,
        }
    }
}
/// Result alias whose error type is [`WalError`].
///
/// Note that within this module it shadows the prelude's `Result`.
pub type Result<T> = std::result::Result<T, WalError>;
/// Controls when an append triggers `Wal::sync`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum WalSyncPolicy {
// Sync after every append call.
Always,
// Sync after an append once at least this much time has passed since the
// recorded `last_sync_at`. A zero interval syncs on every append. The very
// first append only records the current time and does not sync.
Interval(Duration),
// Appends never sync; the caller is expected to call `Wal::sync` itself.
Manual,
}
impl From<std::io::Error> for WalError {
fn from(err: std::io::Error) -> Self {
Self::Io(err)
}
}
/// A segmented write-ahead log stored as `ir.wal.NNNN` files inside one directory.
///
/// Records are appended to the current segment as a little-endian `u32`
/// length prefix followed by the payload; a new segment is started when the
/// next record would push the current one past `max_segment_bytes`.
pub struct Wal {
// Directory that holds the segment files.
dir: PathBuf,
// Upper bound on the bytes accounted to one segment (must be >= 64 at open).
max_segment_bytes: u64,
// Id of the segment that appends currently go to.
current_segment_id: WalSegmentId,
// Bytes accounted to the current segment: seeded from the file length at
// open (0 if it cannot be read) and increased by every flushed append.
current_segment_size: u64,
// Decides whether an append is followed by a sync.
sync_policy: WalSyncPolicy,
// Reactor time recorded by the last `sync`; under `Interval` the first
// append also stores the current time here without syncing.
last_sync_at: Option<SystemTime>,
// Provides the file-system operations and the clock used by the log.
reactor: std::sync::Arc<dyn Reactor + Send + Sync>,
}
impl Wal {
    /// Sync interval applied by [`Wal::open`] and [`Wal::open_with_reactor`].
    const DEFAULT_SYNC_INTERVAL: Duration = Duration::from_millis(1000);
    /// Smallest `max_segment_bytes` accepted when opening a log.
    const MIN_SEGMENT_BYTES: u64 = 64;
    /// Width in bytes of the little-endian `u32` length prefix of each record.
    const LEN_PREFIX_BYTES: usize = 4;
    /// Prefix shared by every segment file name.
    const SEGMENT_PREFIX: &'static str = "ir.wal.";

    /// Opens (or creates) the log in `dir` using the system reactor and the
    /// default interval sync policy.
    ///
    /// # Errors
    /// See [`Wal::open_with_reactor_and_policy`].
    pub fn open(dir: PathBuf, max_segment_bytes: u64) -> Result<Self> {
        Self::open_with_reactor(
            dir,
            max_segment_bytes,
            std::sync::Arc::new(crate::core::reactor::SystemReactor),
        )
    }

    /// Opens (or creates) the log in `dir` on top of `reactor`, using the
    /// default interval sync policy.
    ///
    /// # Errors
    /// See [`Wal::open_with_reactor_and_policy`].
    pub fn open_with_reactor(
        dir: PathBuf,
        max_segment_bytes: u64,
        reactor: std::sync::Arc<dyn Reactor + Send + Sync>,
    ) -> Result<Self> {
        let policy = WalSyncPolicy::Interval(Self::DEFAULT_SYNC_INTERVAL);
        Self::open_with_reactor_and_policy(dir, max_segment_bytes, policy, reactor)
    }

    /// Opens (or creates) the log in `dir` with an explicit sync policy.
    ///
    /// The directory is created if needed. Appending resumes in the segment
    /// with the highest id found in `dir` (segment 0 when there is none), and
    /// the size of that segment is taken from its current file length.
    ///
    /// # Errors
    /// * [`WalError::InvalidInput`] when `max_segment_bytes` is below 64.
    /// * [`WalError::Io`] when the directory cannot be created.
    pub fn open_with_reactor_and_policy(
        dir: PathBuf,
        max_segment_bytes: u64,
        sync_policy: WalSyncPolicy,
        reactor: std::sync::Arc<dyn Reactor + Send + Sync>,
    ) -> Result<Self> {
        if max_segment_bytes < Self::MIN_SEGMENT_BYTES {
            return Err(WalError::InvalidInput(
                "max_segment_bytes must be >= 64".to_string(),
            ));
        }
        reactor.create_dir_all(&dir)?;
        let current_segment_id = Self::find_last_segment_id(&dir, reactor.as_ref()).unwrap_or(0);
        // Best effort: a segment whose length cannot be read (typically because
        // it does not exist yet) is treated as empty.
        let current_segment_size = reactor
            .metadata_len(&Self::segment_path_in(&dir, current_segment_id))
            .unwrap_or(0);
        Ok(Self {
            dir,
            max_segment_bytes,
            current_segment_id,
            current_segment_size,
            sync_policy,
            last_sync_at: None,
            reactor,
        })
    }

    /// Appends a single record; equivalent to a one-element batch.
    ///
    /// Returns the number of bytes written (payload plus length prefix).
    ///
    /// # Errors
    /// See [`Wal::append_records_batch`].
    pub fn append_record(&mut self, record: &[u8]) -> Result<u64> {
        self.append_records_batch(std::iter::once(record))
    }

    /// Appends every record of `records`, in order, and returns the total
    /// number of encoded bytes (payloads plus length prefixes).
    ///
    /// Each record is stored as a little-endian `u32` length followed by the
    /// payload. Records destined for the same segment are written with one
    /// `append_file` call. When the next record would not fit into the current
    /// segment, the buffered bytes are flushed and a new segment is started.
    /// After the batch, the sync policy decides whether to sync.
    ///
    /// The whole batch is validated before anything is written, so an invalid
    /// record never leaves part of the batch on disk. An I/O failure in the
    /// middle of a batch can still leave earlier records written.
    ///
    /// # Errors
    /// * [`WalError::InvalidInput`] when a record (with its prefix) exceeds
    ///   `max_segment_bytes`, or when its length does not fit in a `u32`.
    /// * [`WalError::Io`] when writing or syncing fails.
    pub fn append_records_batch<'a, I>(&mut self, records: I) -> Result<u64>
    where
        I: IntoIterator<Item = &'a [u8]>,
    {
        // Pass 1: validate everything up front so a bad record cannot cause a
        // partially written batch.
        let mut batch: Vec<(&'a [u8], u64)> = Vec::new();
        let mut total_encoded = 0u64;
        for record in records {
            let encoded_len = self.checked_encoded_len(record)?;
            total_encoded += encoded_len;
            batch.push((record, encoded_len));
        }

        // Pass 2: encode into `pending`, flushing and rotating whenever the
        // next record would overflow the current segment.
        let mut pending: Vec<u8> = Vec::new();
        for (record, encoded_len) in batch {
            let buffered = pending.len() as u64;
            if self.current_segment_size + buffered + encoded_len > self.max_segment_bytes {
                self.flush_pending(&mut pending)?;
                self.seal_and_rotate()?;
            }
            // The cast cannot truncate: `checked_encoded_len` rejected any
            // record whose length exceeds `u32::MAX`.
            pending.extend_from_slice(&(record.len() as u32).to_le_bytes());
            pending.extend_from_slice(record);
        }
        self.flush_pending(&mut pending)?;
        self.maybe_sync_after_append()?;
        Ok(total_encoded)
    }

    /// Flushes the current segment file to stable storage and records the
    /// reactor's current time as the last sync time.
    ///
    /// A segment file that does not exist yet is not an error; the sync time
    /// is still recorded. The file is opened with write access because
    /// flushing through a read-only handle is rejected on Windows.
    ///
    /// # Errors
    /// [`WalError::Io`] when the segment exists but cannot be opened or synced.
    pub fn sync(&mut self) -> Result<()> {
        let path = self.segment_path(self.current_segment_id);
        match std::fs::OpenOptions::new().write(true).open(&path) {
            Ok(file) => file.sync_all()?,
            // Nothing has been written to this segment yet.
            Err(err) if err.kind() == std::io::ErrorKind::NotFound => {}
            Err(err) => return Err(WalError::Io(err)),
        }
        self.last_sync_at = Some(self.reactor.now());
        Ok(())
    }

    /// Reads every segment in ascending id order and returns the payloads of
    /// all complete records.
    ///
    /// Within a segment, decoding stops at the first record whose length
    /// prefix or payload is cut short; the remaining bytes of that segment are
    /// ignored and replay continues with the next segment.
    ///
    /// # Errors
    /// [`WalError::Io`] when a listed segment cannot be read.
    pub fn replay(&self) -> Result<Vec<Vec<u8>>> {
        let mut segments = Self::list_segments(&self.dir, self.reactor.as_ref());
        segments.sort_unstable();

        let mut records = Vec::new();
        for segment_id in segments {
            let buffer = self.reactor.read_file(&self.segment_path(segment_id))?;
            let mut cursor = 0usize;
            // `cursor` never exceeds `buffer.len()`, so this cannot underflow.
            while buffer.len() - cursor >= Self::LEN_PREFIX_BYTES {
                let prefix = [
                    buffer[cursor],
                    buffer[cursor + 1],
                    buffer[cursor + 2],
                    buffer[cursor + 3],
                ];
                let record_len = u32::from_le_bytes(prefix) as usize;
                let start = cursor + Self::LEN_PREFIX_BYTES;
                // Checked so a huge length prefix cannot wrap on 32-bit targets.
                let end = match start.checked_add(record_len) {
                    Some(end) if end <= buffer.len() => end,
                    _ => break,
                };
                records.push(buffer[start..end].to_vec());
                cursor = end;
            }
        }
        Ok(records)
    }

    /// Returns the encoded size of `record` (prefix plus payload) after
    /// checking that it fits both the segment limit and the `u32` prefix.
    fn checked_encoded_len(&self, record: &[u8]) -> Result<u64> {
        let payload_len = record.len() as u64;
        let encoded_len = Self::LEN_PREFIX_BYTES as u64 + payload_len;
        if encoded_len > self.max_segment_bytes {
            return Err(WalError::InvalidInput(
                "record larger than max segment size".to_string(),
            ));
        }
        if payload_len > u64::from(u32::MAX) {
            return Err(WalError::InvalidInput(
                "record length does not fit in the u32 length prefix".to_string(),
            ));
        }
        Ok(encoded_len)
    }

    /// Appends the buffered bytes to the current segment, accounts for them
    /// in `current_segment_size`, and empties the buffer. No-op when empty.
    fn flush_pending(&mut self, pending: &mut Vec<u8>) -> Result<()> {
        if pending.is_empty() {
            return Ok(());
        }
        let path = self.segment_path(self.current_segment_id);
        self.reactor.append_file(&path, &*pending)?;
        self.current_segment_size += pending.len() as u64;
        pending.clear();
        Ok(())
    }

    /// Closes the current segment and makes the next id the current one.
    ///
    /// Unless the policy is `Manual`, the outgoing segment is synced first:
    /// `sync` only ever touches the current segment, so this is the last
    /// chance to make its tail durable.
    fn seal_and_rotate(&mut self) -> Result<()> {
        if self.sync_policy != WalSyncPolicy::Manual {
            self.sync()?;
        }
        self.current_segment_id += 1;
        self.current_segment_size = 0;
        Ok(())
    }

    /// File name of the segment with the given id (`ir.wal.NNNN`).
    fn segment_file_name(segment_id: WalSegmentId) -> String {
        format!("{}{:04}", Self::SEGMENT_PREFIX, segment_id)
    }

    /// Path of the segment with the given id inside `dir`.
    fn segment_path_in(dir: &Path, segment_id: WalSegmentId) -> PathBuf {
        dir.join(Self::segment_file_name(segment_id))
    }

    /// Path of the segment with the given id inside this log's directory.
    fn segment_path(&self, segment_id: WalSegmentId) -> PathBuf {
        Self::segment_path_in(&self.dir, segment_id)
    }

    /// Parses a segment id out of `file_name`.
    ///
    /// Only names that `segment_file_name` would produce are accepted, so a
    /// stray file such as `ir.wal.7` is not mistaken for segment `ir.wal.0007`.
    fn parse_segment_file_name(file_name: &str) -> Option<WalSegmentId> {
        let id = file_name
            .strip_prefix(Self::SEGMENT_PREFIX)?
            .parse::<WalSegmentId>()
            .ok()?;
        if Self::segment_file_name(id) == file_name {
            Some(id)
        } else {
            None
        }
    }

    /// Lists the ids of all segment files in `dir`, in directory order.
    ///
    /// Best effort: a directory that cannot be read yields an empty list, and
    /// entries whose names are not valid UTF-8 segment names are skipped.
    fn list_segments(dir: &Path, reactor: &dyn Reactor) -> Vec<WalSegmentId> {
        let mut segments = Vec::new();
        let entries = match reactor.read_dir(dir) {
            Ok(entries) => entries,
            Err(_) => return segments,
        };
        for entry in entries {
            let parsed = entry
                .file_name()
                .and_then(|name| name.to_str())
                .and_then(|name| Self::parse_segment_file_name(name));
            if let Some(id) = parsed {
                segments.push(id);
            }
        }
        segments
    }

    /// Highest segment id present in `dir`, or `None` when there is none.
    fn find_last_segment_id(dir: &Path, reactor: &dyn Reactor) -> Option<WalSegmentId> {
        Self::list_segments(dir, reactor).into_iter().max()
    }

    /// Applies the sync policy after an append.
    ///
    /// * `Always` syncs, `Manual` does nothing.
    /// * `Interval` syncs when the interval is zero or when at least the
    ///   interval has elapsed since `last_sync_at`. On the first append
    ///   (`last_sync_at` is `None`) it only records the current time, which
    ///   starts the interval clock without syncing.
    fn maybe_sync_after_append(&mut self) -> Result<()> {
        let interval = match self.sync_policy {
            WalSyncPolicy::Always => return self.sync(),
            WalSyncPolicy::Manual => return Ok(()),
            WalSyncPolicy::Interval(interval) => interval,
        };
        let now = self.reactor.now();
        if interval.is_zero() {
            return self.sync();
        }
        match self.last_sync_at {
            Some(last) => {
                // A clock that moved backwards counts as "no time elapsed".
                let elapsed = now.duration_since(last).unwrap_or_default();
                if elapsed >= interval {
                    self.sync()
                } else {
                    Ok(())
                }
            }
            None => {
                self.last_sync_at = Some(now);
                Ok(())
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::reactor::{DeterministicReactor, SystemReactor};
    use std::fs;
    use std::fs::OpenOptions;
    use std::io::Write;
    use std::io::{Seek, SeekFrom};
    use std::sync::Arc;

    // Builds a shared deterministic reactor whose clock starts at the Unix
    // epoch. A macro is used so the seed literal keeps whatever integer type
    // the constructor expects.
    macro_rules! epoch_reactor {
        ($seed:expr) => {
            Arc::new(DeterministicReactor::new(SystemTime::UNIX_EPOCH, $seed))
        };
    }

    // Creates a fresh directory under the system temp dir; the process id and
    // a nanosecond timestamp keep concurrent test runs apart.
    fn temp_dir(prefix: &str) -> PathBuf {
        let nanos = SystemTime::now()
            .duration_since(SystemTime::UNIX_EPOCH)
            .unwrap()
            .as_nanos();
        let name = format!("{}_{}_{}", prefix, std::process::id(), nanos);
        let path = std::env::temp_dir().join(name);
        fs::create_dir_all(&path).expect("temp dir create");
        path
    }

    // Records come back from replay in the order they were appended.
    #[test]
    fn wal_append_and_replay_round_trip() {
        let root = temp_dir("wal_round_trip");
        let mut wal = Wal::open_with_reactor(root.clone(), 256, epoch_reactor!(7)).unwrap();

        wal.append_record(b"first").unwrap();
        wal.append_record(b"second").unwrap();

        let expected = vec![b"first".to_vec(), b"second".to_vec()];
        assert_eq!(wal.replay().unwrap(), expected);
        fs::remove_dir_all(root).ok();
    }

    // Two 44-byte encoded records cannot share a 64-byte segment.
    #[test]
    fn wal_rotates_segments() {
        let root = temp_dir("wal_rotate");
        let mut wal = Wal::open_with_reactor(root.clone(), 64, epoch_reactor!(7)).unwrap();

        let payload = [b'x'; 40];
        for _ in 0..2 {
            wal.append_record(&payload).unwrap();
        }

        let segment_count = Wal::list_segments(&root, &SystemReactor).len();
        assert!(segment_count >= 2);
        fs::remove_dir_all(root).ok();
    }

    // A record whose payload is cut short at the end of a segment is skipped.
    #[test]
    fn wal_replay_ignores_partial_last_record() {
        let root = temp_dir("wal_partial");
        let mut wal = Wal::open_with_reactor(root.clone(), 256, epoch_reactor!(7)).unwrap();
        wal.append_record(b"alpha").unwrap();

        // Length prefix announces 5 bytes, but only 2 follow.
        let torn_record: [u8; 6] = [0x05, 0x00, 0x00, 0x00, b'b', b'e'];
        let segment = wal.segment_path(wal.current_segment_id);
        let mut handle = OpenOptions::new()
            .read(true)
            .write(true)
            .open(&segment)
            .unwrap();
        handle.seek(SeekFrom::End(0)).unwrap();
        handle.write_all(&torn_record).unwrap();
        handle.flush().unwrap();

        assert_eq!(wal.replay().unwrap(), vec![b"alpha".to_vec()]);
        fs::remove_dir_all(root).ok();
    }

    // Under `Manual`, only an explicit `sync` records a sync time.
    #[test]
    fn wal_manual_policy_does_not_auto_sync() {
        let root = temp_dir("wal_manual_sync");
        let mut wal = Wal::open_with_reactor_and_policy(
            root.clone(),
            256,
            WalSyncPolicy::Manual,
            epoch_reactor!(9),
        )
        .unwrap();

        wal.append_record(b"alpha").unwrap();
        assert!(wal.last_sync_at.is_none());

        wal.sync().unwrap();
        assert!(wal.last_sync_at.is_some());
        fs::remove_dir_all(root).ok();
    }

    // Under `Interval`, the recorded time only moves once the interval elapsed.
    #[test]
    fn wal_interval_policy_syncs_after_elapsed_interval() {
        let root = temp_dir("wal_interval_sync");
        let clock = epoch_reactor!(11);
        let policy = WalSyncPolicy::Interval(Duration::from_millis(10));
        let mut wal =
            Wal::open_with_reactor_and_policy(root.clone(), 256, policy, clock.clone()).unwrap();

        wal.append_record(b"first").unwrap();
        assert!(wal.last_sync_at.is_some());
        let baseline = wal.last_sync_at;

        wal.append_record(b"second").unwrap();
        assert_eq!(wal.last_sync_at, baseline);

        clock.sleep(Duration::from_millis(20));
        wal.append_record(b"third").unwrap();
        assert!(wal.last_sync_at > baseline);
        fs::remove_dir_all(root).ok();
    }
}