use std::collections::BTreeMap;
use crate::{
error::{Result, WalError},
store::WalStore,
sync::{Condvar, Mutex, MutexGuard},
};
/// Tracks how far the log has been written with no gaps (`committed`) and how
/// far it is treated as durable (`durable`). Callers can block in
/// [`Commit::sync_to`] until a target offset is durable.
#[derive(Debug)]
pub(crate) struct Commit {
/// All bookkeeping, guarded by a single mutex.
state: Mutex<State>,
/// Wakes threads blocked in `sync_to` after a write completes, a write fails,
/// a sync finishes, or the tracker is reset.
cond: Condvar,
}
/// Mutable bookkeeping that lives behind [`Commit`]'s mutex.
#[derive(Debug)]
struct State {
/// Offset up to which every byte has been reported written, with no gaps.
committed: u64,
/// Offset treated as durable. It is set by `new`/`reset` and raised to the
/// flushed `committed` value after a successful `store.sync()`.
durable: u64,
/// Regions reported written ahead of `committed`, keyed by start offset with
/// the end offset as value. They are drained once the gap before them fills.
pending: BTreeMap<u64, u64>,
/// True while one caller of `sync_to` runs `store.sync()` with the lock
/// released; other callers wait instead of starting a second sync.
syncing: bool,
/// Number of threads currently blocked on the condvar in `sync_to`.
waiters: usize,
/// Lowest start offset of any region reported as failed, if any. `sync_to`
/// rejects targets past it until `reset` clears it.
poison: Option<u64>,
}
impl Commit {
pub(crate) fn new(recovered: u64) -> Self {
Commit {
state: Mutex::new(State {
committed: recovered,
durable: recovered,
pending: BTreeMap::new(),
syncing: false,
waiters: 0,
poison: None,
}),
cond: Condvar::new(),
}
}
fn lock(&self) -> MutexGuard<'_, State> {
match self.state.lock() {
Ok(guard) => guard,
Err(poisoned) => poisoned.into_inner(),
}
}
pub(crate) fn mark_written(&self, start: u64, end: u64) {
let mut state = self.lock();
state.insert_region(start, end);
if state.waiters > 0 {
self.cond.notify_all();
}
}
pub(crate) fn mark_failed(&self, start: u64) {
let mut state = self.lock();
state.poison = Some(state.poison.map_or(start, |p| p.min(start)));
if state.waiters > 0 {
self.cond.notify_all();
}
}
pub(crate) fn committed(&self) -> u64 {
self.lock().committed
}
pub(crate) fn reset(&self, offset: u64) {
let mut state = self.lock();
state.committed = offset;
state.durable = offset;
state.pending.clear();
state.poison = None;
self.cond.notify_all();
}
pub(crate) fn sync_to<S: WalStore>(&self, store: &S, target: u64) -> Result<()> {
let mut state = self.lock();
loop {
if let Some(poison) = state.poison {
if target > poison {
return Err(WalError::corruption(
poison,
"a record write failed; the log is truncated at this offset",
));
}
}
if state.durable >= target {
return Ok(());
}
if state.committed >= target && !state.syncing {
state.syncing = true;
let flush_to = state.committed;
drop(state);
let result = store.sync();
state = self.lock();
state.syncing = false;
match result {
Ok(()) => {
if state.durable < flush_to {
state.durable = flush_to;
}
self.cond.notify_all();
}
Err(error) => {
self.cond.notify_all();
return Err(error);
}
}
} else {
state.waiters += 1;
state = match self.cond.wait(state) {
Ok(guard) => guard,
Err(poisoned) => poisoned.into_inner(),
};
state.waiters -= 1;
}
}
}
}
impl State {
    /// Records that bytes `[start, end)` have been written.
    ///
    /// A region that begins exactly at `committed` extends the frontier. The
    /// frontier then absorbs, one after another, every pending region that
    /// now lines up with it. Any other region is parked in `pending`, keyed by
    /// its start offset, until the gap in front of it closes.
    fn insert_region(&mut self, start: u64, end: u64) {
        if start != self.committed {
            // Not contiguous with the frontier yet; park it.
            let _ = self.pending.insert(start, end);
            return;
        }
        let mut frontier = end;
        while let Some(parked_end) = self.pending.remove(&frontier) {
            frontier = parked_end;
        }
        self.committed = frontier;
    }
}
#[cfg(all(test, not(loom)))]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;
    use crate::store::MemStore;

    // Writes that finish in log order move the committed frontier at once.
    #[test]
    fn test_in_order_completion_advances_committed() {
        let commit = Commit::new(0);
        for (start, end) in [(0, 10), (10, 25)] {
            commit.mark_written(start, end);
            assert_eq!(commit.committed(), end);
        }
    }

    // A write that finishes early is parked until the gap before it closes.
    #[test]
    fn test_out_of_order_completion_holds_until_gap_fills() {
        let commit = Commit::new(0);
        commit.mark_written(10, 25);
        assert_eq!(commit.committed(), 0);
        commit.mark_written(0, 10);
        assert_eq!(commit.committed(), 25);
    }

    // Syncing a committed target succeeds, and asking again also succeeds.
    #[test]
    fn test_sync_to_covered_target_is_durable() {
        let store = MemStore::new();
        let commit = Commit::new(0);
        commit.mark_written(0, 32);
        for _ in 0..2 {
            commit.sync_to(&store, 32).unwrap();
        }
    }

    // A target at the failure offset is still reachable; one byte past it is not.
    #[test]
    fn test_sync_past_poison_errors() {
        let store = MemStore::new();
        let commit = Commit::new(0);
        commit.mark_written(0, 10);
        commit.mark_failed(10);
        commit.sync_to(&store, 10).unwrap();
        let outcome = commit.sync_to(&store, 11);
        assert!(matches!(outcome, Err(WalError::Corruption { .. })));
    }
}