use std::{fmt, io, sync::atomic::Ordering};
#[cfg(not(loom))]
use std::{cell::RefCell, path::Path};
use crate::{
commit::Commit,
config::{RecoveryPolicy, WalConfig},
error::{Result, WalError},
lsn::Lsn,
record::{self, HEADER_LEN},
store::{FileStore, WalStore},
sync::AtomicU64,
};
/// Wraps a value so that it starts on a 64-byte boundary.
///
/// Because the alignment is 64, the wrapper's size is also rounded up to a
/// multiple of 64 bytes, so the wrapped value occupies its own 64-byte slot.
/// It is used for the hot `tail` counter, presumably to keep appenders' writes
/// from sharing a cache line with neighbouring fields (64 bytes is assumed to
/// be the line size).
#[derive(Debug)]
#[repr(align(64))]
struct CacheAligned<T>(T);
/// A write-ahead log whose records are addressed by byte-offset [`Lsn`]s.
///
/// `S` is the backing store; it defaults to [`FileStore`].
pub struct Wal<S = FileStore> {
    /// Next free byte offset. Appends reserve their frame with an atomic
    /// `fetch_add` on this counter, which is why it sits in its own
    /// 64-byte-aligned slot.
    tail: CacheAligned<AtomicU64>,
    /// Receives `mark_written` / `mark_failed` notifications from appends and
    /// performs `sync_to` on behalf of `sync` / `append_and_sync`.
    commit: Commit,
    /// Storage holding the framed records.
    store: S,
    /// Largest payload `append` accepts. It is also used as the sanity bound on
    /// length fields when reading.
    max_record_size: u32,
    /// Decides how iterators react to a damaged record.
    recovery_policy: RecoveryPolicy,
}
#[cfg(not(loom))]
impl Wal<FileStore> {
    /// Opens the file-backed log at `path` with the default [`WalConfig`].
    ///
    /// # Errors
    ///
    /// Fails if the file store cannot be opened or if recovery fails.
    pub fn open(path: impl AsRef<Path>) -> Result<Self> {
        let defaults = WalConfig::new();
        Self::open_with(path, defaults)
    }

    /// Opens the file-backed log at `path` using `config`.
    ///
    /// # Errors
    ///
    /// Fails if the file store cannot be opened or if recovery fails.
    pub fn open_with(path: impl AsRef<Path>, config: WalConfig) -> Result<Self> {
        Self::with_store_and_config(FileStore::open(path)?, config)
    }
}
#[cfg(not(loom))]
impl Wal<crate::segment::SegmentedStore> {
    /// Opens a segmented log stored in directory `dir` with the default [`WalConfig`].
    ///
    /// `segment_size` is forwarded unchanged to `SegmentedStore::open`.
    ///
    /// # Errors
    ///
    /// Fails if the segmented store cannot be opened or if recovery fails.
    pub fn open_segmented(dir: impl AsRef<Path>, segment_size: u64) -> Result<Self> {
        let defaults = WalConfig::new();
        Self::open_segmented_with(dir, segment_size, defaults)
    }

    /// Opens a segmented log stored in directory `dir` using `config`.
    ///
    /// # Errors
    ///
    /// Fails if the segmented store cannot be opened or if recovery fails.
    pub fn open_segmented_with(
        dir: impl AsRef<Path>,
        segment_size: u64,
        config: WalConfig,
    ) -> Result<Self> {
        Self::with_store_and_config(
            crate::segment::SegmentedStore::open(dir, segment_size)?,
            config,
        )
    }
}
impl<S: WalStore> Wal<S> {
pub fn with_store(store: S) -> Result<Self> {
Self::with_store_and_config(store, WalConfig::new())
}
pub fn with_store_and_config(store: S, config: WalConfig) -> Result<Self> {
let recovered = recover(&store, config.max_record_size())?;
Ok(Wal {
tail: CacheAligned(AtomicU64::new(recovered)),
store,
max_record_size: config.max_record_size(),
recovery_policy: config.recovery_policy(),
commit: Commit::new(recovered),
})
}
pub fn append(&self, record: &[u8]) -> Result<Lsn> {
let payload_len = record.len();
if payload_len > self.max_record_size as usize {
return Err(WalError::RecordTooLarge {
len: payload_len,
max: self.max_record_size,
});
}
let frame_len = record::framed_len(payload_len) as u64;
let start = self.tail.0.fetch_add(frame_len, Ordering::Relaxed);
let end = match start.checked_add(frame_len) {
Some(end) => end,
None => {
self.commit.mark_failed(start);
return Err(WalError::io(
"reserving a record offset",
io::Error::other("log size exceeds u64"),
));
}
};
match self.frame_and_write(start, record) {
Ok(()) => {
self.commit.mark_written(start, end);
Ok(Lsn::new(start))
}
Err(error) => {
self.commit.mark_failed(start);
Err(error)
}
}
}
pub fn sync(&self) -> Result<()> {
let target = self.tail.0.load(Ordering::Acquire);
if target == 0 {
return Ok(());
}
self.commit.sync_to(&self.store, target)
}
pub fn append_and_sync(&self, record: &[u8]) -> Result<Lsn> {
let lsn = self.append(record)?;
let end = lsn.get() + record::framed_len(record.len()) as u64;
self.commit.sync_to(&self.store, end)?;
Ok(lsn)
}
#[cfg(feature = "pack-io")]
pub fn append_typed<T: pack_io::Serialize + ?Sized>(&self, value: &T) -> Result<Lsn> {
let bytes = pack_io::encode(value).map_err(WalError::encoding)?;
self.append(&bytes)
}
pub fn iter(&self) -> Result<WalIter<'_, S>> {
let end = self.commit.committed();
Ok(WalIter {
wal: self,
offset: 0,
end,
done: false,
policy: self.recovery_policy,
})
}
pub fn iter_from(&self, from: Lsn) -> Result<WalIter<'_, S>> {
let end = self.commit.committed();
Ok(WalIter {
wal: self,
offset: from.get().min(end),
end,
done: false,
policy: self.recovery_policy,
})
}
pub fn truncate_after(&self, lsn: Lsn) -> Result<()> {
let start = lsn.get();
let mut header = [0u8; HEADER_LEN];
if self.store.read_at(start, &mut header)? < HEADER_LEN {
return Err(WalError::corruption(start, "no record at this LSN"));
}
let parsed = record::parse_header(&header);
if parsed.len > self.max_record_size {
return Err(WalError::corruption(start, "no valid record at this LSN"));
}
let payload_start = start
.checked_add(HEADER_LEN as u64)
.ok_or_else(|| WalError::corruption(start, "record offset overflow"))?;
let mut payload = vec![0u8; parsed.len as usize];
if self.store.read_at(payload_start, &mut payload)? < payload.len() {
return Err(WalError::corruption(start, "incomplete record at this LSN"));
}
if !record::verify(&header, &payload, parsed.crc) {
return Err(WalError::corruption(start, "no valid record at this LSN"));
}
let new_end = payload_start
.checked_add(u64::from(parsed.len))
.ok_or_else(|| WalError::corruption(start, "record offset overflow"))?;
self.store.truncate(new_end)?;
self.store.sync()?;
self.tail.0.store(new_end, Ordering::Release);
self.commit.reset(new_end);
Ok(())
}
#[must_use]
pub fn len(&self) -> u64 {
self.tail.0.load(Ordering::Acquire)
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.len() == 0
}
fn frame_and_write(&self, start: u64, record: &[u8]) -> Result<()> {
with_frame_buffer(|buf| {
record::encode(buf, record);
self.store.write_at(start, buf)
})
}
#[cfg(test)]
pub(crate) fn store(&self) -> &S {
&self.store
}
}
/// Runs `f` with this thread's reusable frame-encoding buffer.
///
/// The buffer is a thread-local, so its allocation is retained between calls on
/// the same thread. Bytes left by a previous call are still present when `f`
/// starts, so `f` must reset the buffer if it needs a clean slate.
///
/// # Panics
///
/// Panics if called re-entrantly from within `f` on the same thread, because the
/// buffer is already mutably borrowed.
#[cfg(not(loom))]
fn with_frame_buffer<R>(f: impl FnOnce(&mut Vec<u8>) -> R) -> R {
    thread_local! {
        static SCRATCH: RefCell<Vec<u8>> = const { RefCell::new(Vec::new()) };
    }
    SCRATCH.with_borrow_mut(f)
}

/// Loom build: hands `f` a fresh, empty buffer on every call instead of a
/// thread-local one, presumably to keep std thread-locals out of loom's model.
/// TODO confirm.
#[cfg(loom)]
fn with_frame_buffer<R>(f: impl FnOnce(&mut Vec<u8>) -> R) -> R {
    f(&mut Vec::new())
}
/// Scans `store` from offset 0 and returns the end of its longest intact prefix.
///
/// A frame is intact when all of the following hold:
/// - its header and full payload can be read;
/// - its length does not exceed `max_record_size`;
/// - its checksum verifies.
///
/// If anything follows the intact prefix, the store is truncated to it. No
/// `sync` is issued here.
///
/// NOTE(review): a frame longer than `max_record_size` ends the prefix. Reopening
/// with a smaller maximum than the log was written with would therefore truncate
/// valid records. Confirm this is intended.
///
/// # Errors
///
/// Propagates errors from the store's `len`, `read_at` and `truncate`.
fn recover<S: WalStore>(store: &S, max_record_size: u32) -> Result<u64> {
    /// Returns the offset just past the intact frame at `at`, or `None` if no
    /// intact frame starts there.
    fn intact_record_end<S: WalStore>(
        store: &S,
        at: u64,
        max_record_size: u32,
    ) -> Result<Option<u64>> {
        let mut header = [0u8; HEADER_LEN];
        if store.read_at(at, &mut header)? < HEADER_LEN {
            return Ok(None);
        }
        let parsed = record::parse_header(&header);
        if parsed.len > max_record_size {
            return Ok(None);
        }
        let Some(payload_start) = at.checked_add(HEADER_LEN as u64) else {
            return Ok(None);
        };
        let mut payload = vec![0u8; parsed.len as usize];
        if store.read_at(payload_start, &mut payload)? < payload.len()
            || !record::verify(&header, &payload, parsed.crc)
        {
            return Ok(None);
        }
        Ok(payload_start.checked_add(u64::from(parsed.len)))
    }

    let physical = store.len()?;
    let mut valid_end = 0u64;
    while valid_end < physical {
        let Some(next) = intact_record_end(store, valid_end, max_record_size)? else {
            break;
        };
        valid_end = next;
    }
    if valid_end < physical {
        store.truncate(valid_end)?;
    }
    Ok(valid_end)
}
impl<S: WalStore> fmt::Debug for Wal<S> {
    /// Shows only the current tail offset as `len`. The store and commit state are
    /// elided and marked with `..`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let len = self.tail.0.load(Ordering::Relaxed);
        let mut out = f.debug_struct("Wal");
        out.field("len", &len);
        out.finish_non_exhaustive()
    }
}
/// A single record read back from the log: its LSN plus its payload bytes.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Record {
    /// Byte offset of the record's frame in the log.
    lsn: Lsn,
    /// The record payload.
    data: Vec<u8>,
}

impl Record {
    /// Returns the LSN (frame byte offset) this record was read from.
    #[must_use]
    pub fn lsn(&self) -> Lsn {
        self.lsn
    }

    /// Borrows the payload bytes.
    #[must_use]
    pub fn data(&self) -> &[u8] {
        &self.data
    }

    /// Returns the payload length in bytes.
    #[must_use]
    pub fn len(&self) -> usize {
        self.data.len()
    }

    /// Returns `true` if the payload is empty.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.data.is_empty()
    }

    /// Consumes the record and returns its payload without copying.
    #[must_use]
    pub fn into_data(self) -> Vec<u8> {
        self.data
    }

    /// Deserializes the payload with `pack_io`.
    ///
    /// # Errors
    ///
    /// Returns an encoding error if the payload does not decode as `T`.
    #[cfg(feature = "pack-io")]
    pub fn decode<T: pack_io::Deserialize>(&self) -> Result<T> {
        pack_io::decode(&self.data).map_err(WalError::encoding)
    }
}
/// Outcome of decoding one frame at an iterator's current offset.
enum Step {
    /// No further frame could be read: the header or payload read came up short.
    End,
    /// A checksum-verified record, plus the offset of the frame that follows it.
    Record(Record, u64),
    /// A damaged frame.
    ///
    /// The offset is `Some(next)` when the length field was within bounds, so the
    /// following frame can still be located. It is `None` when the length was
    /// implausible.
    Damaged(WalError, Option<u64>),
}
/// Iterator over the records of a [`Wal`], yielding `Result<Record>`.
///
/// Created by [`Wal::iter`] and [`Wal::iter_from`]. It only visits offsets below
/// the committed length captured at creation, so records appended afterwards are
/// not seen.
pub struct WalIter<'a, S: WalStore = FileStore> {
    /// The log being read.
    wal: &'a Wal<S>,
    /// Recovery policy copied from the log when the iterator was created.
    policy: RecoveryPolicy,
    /// Offset of the next frame to decode.
    offset: u64,
    /// Exclusive upper bound: the committed length at creation time.
    end: u64,
    /// Set once iteration has stopped for good.
    done: bool,
}
impl<S: WalStore> WalIter<'_, S> {
fn step(&self) -> Result<Step> {
let mut header = [0u8; HEADER_LEN];
if self.wal.store.read_at(self.offset, &mut header)? < HEADER_LEN {
return Ok(Step::End);
}
let parsed = record::parse_header(&header);
if parsed.len > self.wal.max_record_size {
return Ok(Step::Damaged(
WalError::corruption(self.offset, "record length exceeds the maximum"),
None,
));
}
let payload_start = self
.offset
.checked_add(HEADER_LEN as u64)
.ok_or_else(|| WalError::corruption(self.offset, "record offset overflow"))?;
let mut payload = vec![0u8; parsed.len as usize];
if self.wal.store.read_at(payload_start, &mut payload)? < payload.len() {
return Ok(Step::End);
}
let next = payload_start
.checked_add(u64::from(parsed.len))
.ok_or_else(|| WalError::corruption(self.offset, "record offset overflow"))?;
if !record::verify(&header, &payload, parsed.crc) {
return Ok(Step::Damaged(
WalError::corruption(self.offset, "checksum mismatch"),
Some(next),
));
}
Ok(Step::Record(
Record {
lsn: Lsn::new(self.offset),
data: payload,
},
next,
))
}
}
impl<S: WalStore> Iterator for WalIter<'_, S> {
type Item = Result<Record>;
fn next(&mut self) -> Option<Self::Item> {
if self.done || self.offset >= self.end {
return None;
}
match self.step() {
Ok(Step::Record(record, next)) => {
self.offset = next;
Some(Ok(record))
}
Ok(Step::Damaged(error, Some(next)))
if self.policy == RecoveryPolicy::SkipBadRecords =>
{
self.offset = next;
Some(Err(error))
}
Ok(Step::Damaged(error, _)) => {
self.done = true;
Some(Err(error))
}
Ok(Step::End) => {
self.done = true;
None
}
Err(error) => {
self.done = true;
Some(Err(error))
}
}
}
}
impl<S: WalStore> fmt::Debug for WalIter<'_, S> {
    /// Shows the cursor state (`offset`, `end`, `done`).
    ///
    /// The borrowed log and the recovery policy are omitted. `finish_non_exhaustive`
    /// marks that with `..`, consistent with the `Wal` Debug output.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("WalIter")
            .field("offset", &self.offset)
            .field("end", &self.end)
            .field("done", &self.done)
            .finish_non_exhaustive()
    }
}
#[cfg(all(test, not(loom)))]
// Panicking on an unexpected error is the failure signal in these tests. Many
// calls are made only for their effect on the log, so their results (e.g.
// returned LSNs) are dropped on purpose.
#[allow(
    clippy::unwrap_used,
    clippy::expect_used,
    unused_must_use,
    unused_results
)]
mod tests {
    use std::sync::Arc;
    use std::thread;

    use super::*;
    use crate::store::MemStore;

    /// Collects every record payload in log order, panicking on any iteration error.
    fn drain(wal: &Wal<MemStore>) -> Vec<Vec<u8>> {
        wal.iter()
            .unwrap()
            .map(|r| r.unwrap().into_data())
            .collect()
    }

    /// Flips every bit of the byte at `offset` in `store`.
    fn corrupt_byte(store: &MemStore, offset: u64) {
        let mut byte = [0u8; 1];
        store.read_at(offset, &mut byte).unwrap();
        byte[0] ^= 0xFF;
        store.write_at(offset, &byte).unwrap();
    }

    /// Builds a new log from a snapshot of `wal`'s store, running recovery on it.
    fn reopen(wal: &Wal<MemStore>) -> Wal<MemStore> {
        Wal::with_store(MemStore::from_bytes(wal.store().snapshot())).unwrap()
    }

    #[test]
    fn test_stop_at_first_error_stops_at_corruption() {
        let wal = Wal::with_store(MemStore::new()).unwrap();
        wal.append(b"first").unwrap();
        let second = wal.append(b"second").unwrap();
        wal.append(b"third").unwrap();
        corrupt_byte(wal.store(), second.get() + HEADER_LEN as u64);
        let items: Vec<_> = wal.iter().unwrap().collect();
        assert_eq!(items.len(), 2);
        assert_eq!(items[0].as_ref().unwrap().data(), b"first");
        assert!(matches!(items[1], Err(WalError::Corruption { .. })));
    }

    #[test]
    fn test_skip_bad_records_continues_past_corruption() {
        let config = WalConfig::new().with_recovery_policy(RecoveryPolicy::SkipBadRecords);
        let wal = Wal::with_store_and_config(MemStore::new(), config).unwrap();
        wal.append(b"first").unwrap();
        let second = wal.append(b"second").unwrap();
        wal.append(b"third").unwrap();
        corrupt_byte(wal.store(), second.get() + HEADER_LEN as u64);
        let items: Vec<_> = wal.iter().unwrap().collect();
        assert_eq!(items.len(), 3);
        assert_eq!(items[0].as_ref().unwrap().data(), b"first");
        assert!(matches!(items[1], Err(WalError::Corruption { .. })));
        assert_eq!(items[2].as_ref().unwrap().data(), b"third");
    }

    #[test]
    fn test_skip_bad_records_still_stops_on_unreadable_length() {
        let config = WalConfig::new()
            .with_max_record_size(16)
            .with_recovery_policy(RecoveryPolicy::SkipBadRecords);
        let wal = Wal::with_store_and_config(MemStore::new(), config).unwrap();
        wal.append(b"ok").unwrap();
        let second = wal.append(b"victim").unwrap();
        corrupt_byte(wal.store(), second.get() + 4);
        let items: Vec<_> = wal.iter().unwrap().collect();
        assert_eq!(items.len(), 2);
        assert_eq!(items[0].as_ref().unwrap().data(), b"ok");
        assert!(matches!(items[1], Err(WalError::Corruption { .. })));
    }

    #[cfg(feature = "pack-io")]
    #[test]
    fn test_typed_record_roundtrip() {
        use pack_io::{Deserialize, Serialize};
        #[derive(Serialize, Deserialize, PartialEq, Debug)]
        struct Entry {
            id: u64,
            label: String,
        }
        let wal = Wal::with_store(MemStore::new()).unwrap();
        wal.append_typed(&Entry {
            id: 9,
            label: "nine".into(),
        })
        .unwrap();
        wal.append_typed(&Entry {
            id: 10,
            label: "ten".into(),
        })
        .unwrap();
        let decoded: Vec<Entry> = wal
            .iter()
            .unwrap()
            .map(|r| r.unwrap().decode().unwrap())
            .collect();
        assert_eq!(
            decoded[0],
            Entry {
                id: 9,
                label: "nine".into()
            }
        );
        assert_eq!(
            decoded[1],
            Entry {
                id: 10,
                label: "ten".into()
            }
        );
    }

    #[cfg(feature = "pack-io")]
    #[test]
    fn test_typed_decode_wrong_type_errors() {
        use pack_io::{Deserialize, Serialize};
        #[derive(Serialize)]
        struct Big {
            a: u64,
            b: u64,
            c: u64,
        }
        #[derive(Deserialize)]
        struct Small {
            _a: u8,
        }
        let wal = Wal::with_store(MemStore::new()).unwrap();
        wal.append_typed(&Big { a: 1, b: 2, c: 3 }).unwrap();
        let record = wal.iter().unwrap().next().unwrap().unwrap();
        let result: Result<Small> = record.decode();
        assert!(matches!(result, Err(WalError::Encoding { .. })));
    }

    #[test]
    fn test_append_assigns_byte_offset_lsns() {
        let wal = Wal::with_store(MemStore::new()).unwrap();
        let a = wal.append(b"abc").unwrap();
        let b = wal.append(b"de").unwrap();
        assert_eq!(a.get(), 0);
        assert_eq!(b.get(), 11);
    }

    #[test]
    fn test_iter_reports_record_lsns() {
        let wal = Wal::with_store(MemStore::new()).unwrap();
        let a = wal.append(b"abc").unwrap();
        let b = wal.append(b"de").unwrap();
        let lsns: Vec<Lsn> = wal.iter().unwrap().map(|r| r.unwrap().lsn()).collect();
        assert_eq!(lsns, vec![a, b]);
    }

    #[test]
    fn test_iter_reads_back_all_records_in_order() {
        let wal = Wal::with_store(MemStore::new()).unwrap();
        wal.append(b"one").unwrap();
        wal.append(b"two").unwrap();
        wal.append(b"three").unwrap();
        assert_eq!(
            drain(&wal),
            vec![b"one".to_vec(), b"two".to_vec(), b"three".to_vec()]
        );
    }

    #[test]
    fn test_empty_log_iterates_to_nothing() {
        let wal = Wal::with_store(MemStore::new()).unwrap();
        assert!(wal.is_empty());
        assert_eq!(drain(&wal).len(), 0);
    }

    #[test]
    fn test_empty_record_roundtrips() {
        let wal = Wal::with_store(MemStore::new()).unwrap();
        wal.append(b"").unwrap();
        assert_eq!(drain(&wal), vec![Vec::<u8>::new()]);
    }

    #[test]
    fn test_record_too_large_is_rejected() {
        let config = WalConfig::new().with_max_record_size(4);
        let wal = Wal::with_store_and_config(MemStore::new(), config).unwrap();
        wal.append(b"ok").unwrap();
        let err = wal.append(b"too long").unwrap_err();
        assert!(matches!(err, WalError::RecordTooLarge { len: 8, max: 4 }));
        assert_eq!(drain(&wal), vec![b"ok".to_vec()]);
    }

    #[test]
    fn test_reopen_recovers_records() {
        let wal = Wal::with_store(MemStore::new()).unwrap();
        wal.append(b"first").unwrap();
        wal.append(b"second").unwrap();
        wal.sync().unwrap();
        let reopened = reopen(&wal);
        assert_eq!(
            drain(&reopened),
            vec![b"first".to_vec(), b"second".to_vec()]
        );
        assert_eq!(reopened.append(b"third").unwrap().get(), 27);
    }

    #[test]
    fn test_recovery_truncates_torn_tail() {
        let wal = Wal::with_store(MemStore::new()).unwrap();
        wal.append(b"good record").unwrap();
        let clean_len = wal.len();
        wal.store().write_at(clean_len, &[0xAB; 5]).unwrap();
        let reopened = reopen(&wal);
        assert_eq!(drain(&reopened), vec![b"good record".to_vec()]);
        assert_eq!(reopened.len(), clean_len);
    }

    #[test]
    fn test_corrupt_record_surfaces_error_then_stops() {
        let wal = Wal::with_store(MemStore::new()).unwrap();
        wal.append(b"intact").unwrap();
        let second = wal.append(b"victim").unwrap();
        corrupt_byte(wal.store(), second.get() + HEADER_LEN as u64);
        let mut iter = wal.iter().unwrap();
        assert_eq!(iter.next().unwrap().unwrap().data(), b"intact");
        assert!(matches!(
            iter.next().unwrap(),
            Err(WalError::Corruption { .. })
        ));
        assert!(iter.next().is_none());
    }

    #[test]
    fn test_append_and_sync_is_durable() {
        let wal = Wal::with_store(MemStore::new()).unwrap();
        let lsn = wal.append_and_sync(b"committed").unwrap();
        assert_eq!(lsn.get(), 0);
        assert_eq!(drain(&wal), vec![b"committed".to_vec()]);
        // Recovery from the store image must also find the record. This checks the
        // persisted bytes, not whether MemStore models fsync itself.
        assert_eq!(drain(&reopen(&wal)), vec![b"committed".to_vec()]);
    }

    #[test]
    fn test_iter_from_seeks_to_lsn() {
        let wal = Wal::with_store(MemStore::new()).unwrap();
        wal.append(b"a").unwrap();
        let b = wal.append(b"b").unwrap();
        wal.append(b"c").unwrap();
        let got: Vec<Vec<u8>> = wal
            .iter_from(b)
            .unwrap()
            .map(|r| r.unwrap().into_data())
            .collect();
        assert_eq!(got, vec![b"b".to_vec(), b"c".to_vec()]);
    }

    #[test]
    fn test_iter_from_past_end_is_empty() {
        let wal = Wal::with_store(MemStore::new()).unwrap();
        wal.append(b"a").unwrap();
        assert_eq!(wal.iter_from(Lsn::new(9_999)).unwrap().count(), 0);
    }

    #[test]
    fn test_truncate_after_drops_later_records() {
        let wal = Wal::with_store(MemStore::new()).unwrap();
        wal.append(b"first").unwrap();
        let keep = wal.append(b"second").unwrap();
        wal.append(b"third").unwrap();
        wal.append(b"fourth").unwrap();
        wal.truncate_after(keep).unwrap();
        assert_eq!(drain(&wal), vec![b"first".to_vec(), b"second".to_vec()]);
        assert_eq!(wal.len(), 27);
        assert_eq!(wal.append(b"new").unwrap().get(), 27);
        assert_eq!(
            drain(&wal),
            vec![b"first".to_vec(), b"second".to_vec(), b"new".to_vec()]
        );
    }

    #[test]
    fn test_truncate_after_keeping_last_record_is_a_no_op() {
        let wal = Wal::with_store(MemStore::new()).unwrap();
        wal.append(b"first").unwrap();
        let last = wal.append(b"second").unwrap();
        let before = wal.len();
        wal.truncate_after(last).unwrap();
        assert_eq!(wal.len(), before);
        assert_eq!(drain(&wal), vec![b"first".to_vec(), b"second".to_vec()]);
    }

    #[test]
    fn test_truncate_after_invalid_lsn_errors() {
        let config = WalConfig::new().with_max_record_size(64);
        let wal = Wal::with_store_and_config(MemStore::new(), config).unwrap();
        wal.append(b"only record").unwrap();
        let err = wal.truncate_after(Lsn::new(3)).unwrap_err();
        assert!(matches!(err, WalError::Corruption { .. }));
    }

    #[test]
    fn test_concurrent_appends_no_overlap_all_recovered() {
        const THREADS: usize = 8;
        const PER_THREAD: usize = 200;
        let wal = Arc::new(Wal::with_store(MemStore::new()).unwrap());
        let mut handles = Vec::new();
        for t in 0..THREADS {
            let wal = Arc::clone(&wal);
            handles.push(thread::spawn(move || {
                let mut lsns = Vec::with_capacity(PER_THREAD);
                for i in 0..PER_THREAD {
                    let payload = format!("t{t}-r{i}").into_bytes();
                    lsns.push(wal.append(&payload).unwrap().get());
                }
                lsns
            }));
        }
        let mut all_lsns = Vec::new();
        for h in handles {
            all_lsns.extend(h.join().unwrap());
        }
        wal.sync().unwrap();
        let mut sorted = all_lsns.clone();
        sorted.sort_unstable();
        sorted.dedup();
        assert_eq!(sorted.len(), THREADS * PER_THREAD);
        let records = drain(&wal);
        assert_eq!(records.len(), THREADS * PER_THREAD);
        let reopened = reopen(&wal);
        assert_eq!(reopened.iter().unwrap().count(), THREADS * PER_THREAD);
    }

    #[test]
    fn test_concurrent_append_and_sync_all_durable() {
        const THREADS: usize = 8;
        const PER_THREAD: usize = 50;
        let wal = Arc::new(Wal::with_store(MemStore::new()).unwrap());
        let mut handles = Vec::new();
        for t in 0..THREADS {
            let wal = Arc::clone(&wal);
            handles.push(thread::spawn(move || {
                for i in 0..PER_THREAD {
                    wal.append_and_sync(format!("{t}:{i}").as_bytes()).unwrap();
                }
            }));
        }
        for h in handles {
            h.join().unwrap();
        }
        assert_eq!(drain(&wal).len(), THREADS * PER_THREAD);
        assert_eq!(reopen(&wal).iter().unwrap().count(), THREADS * PER_THREAD);
    }
}