use std::io;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Condvar, Mutex};
use super::writer::WalWriter;
/// Coordination state guarded by `GroupCommit::state`.
struct GroupCommitState {
    /// `true` while one caller holds group-commit leadership and is draining/syncing the WAL.
    in_progress: bool,
}
/// Coalesces concurrent "make the WAL durable up to LSN X" requests so that only one
/// caller at a time (the leader) drains and syncs the WAL on behalf of everyone waiting.
pub struct GroupCommit {
    /// Leadership flag; only read or written while this mutex is held.
    state: Mutex<GroupCommitState>,
    /// Signalled when leadership is released so waiters can re-check their target.
    cond: Condvar,
    /// Highest LSN published as durable; read lock-free on the fast path.
    flushed_lsn: AtomicU64,
}
/// Token held by the current group-commit leader for the duration of its flush.
///
/// It must be bound to a named variable (not `_`) so that it lives until the end of
/// the leader's scope; its `Drop` impl hands leadership back.
struct LeadershipGuard<'gc> {
    /// Coordinator whose leadership this guard represents.
    group_commit: &'gc GroupCommit,
}
impl GroupCommit {
    /// Creates a coordinator that starts out treating `initial_durable_lsn` as durable.
    ///
    /// The tests pass `WalWriter::durable_lsn()` here; presumably production callers do
    /// the same — TODO confirm.
    pub fn new(initial_durable_lsn: u64) -> Self {
        Self {
            flushed_lsn: AtomicU64::new(initial_durable_lsn),
            state: Mutex::new(GroupCommitState { in_progress: false }),
            cond: Condvar::new(),
        }
    }

    /// Returns the highest LSN this coordinator has published as durable.
    ///
    /// The value never decreases: the leader publishes with `fetch_max`.
    pub fn flushed_lsn(&self) -> u64 {
        self.flushed_lsn.load(Ordering::Acquire)
    }

    /// Blocks until the published durable LSN is at least `target`.
    ///
    /// If `flushed_lsn() >= target` already, this returns immediately. Otherwise at most
    /// one caller at a time becomes the *leader*. The leader:
    ///
    /// 1. drains the WAL with `WalWriter::drain_for_group_sync`,
    /// 2. calls `sync_all` on the returned handle with the WAL mutex released,
    /// 3. records the result with `WalWriter::mark_durable`, and
    /// 4. publishes the drained LSN.
    ///
    /// Other callers wait on the condition variable. When they wake, each either finds
    /// its target covered or, once leadership is free, becomes the next leader.
    ///
    /// If the `state` or `wal` mutex is poisoned, the guard is recovered instead of
    /// panicking.
    ///
    /// # Errors
    ///
    /// Propagates errors from `drain_for_group_sync` and from `sync_all`. Leadership is
    /// released on every exit path, so after a failed flush a waiting caller can retry.
    pub fn commit_at_least(&self, target: u64, wal: &Mutex<WalWriter>) -> io::Result<()> {
        // Fast path: an earlier flush already covers `target`; no locking needed.
        if self.flushed_lsn() >= target {
            return Ok(());
        }

        // `wait_while` re-evaluates the predicate after every wakeup (spurious or not).
        // So when it returns, either no flush is in progress or the published LSN
        // already covers `target`.
        let state = Self::lock_unpoisoned(&self.state);
        let mut state = self
            .cond
            .wait_while(state, |s| s.in_progress && self.flushed_lsn() < target)
            .unwrap_or_else(|p| p.into_inner());
        if self.flushed_lsn() >= target {
            return Ok(());
        }
        // `flushed_lsn` only grows, so reaching this point means no leader was active
        // when the predicate was last evaluated. We still hold the state lock, so that
        // is still true.
        debug_assert!(!state.in_progress);
        state.in_progress = true;
        drop(state);

        // From here on we are the leader. The guard releases leadership and wakes
        // waiters on every exit path, including the `?` returns below.
        let _leader = LeadershipGuard { group_commit: self };

        // Hold the WAL mutex only while draining, so other threads can lock the WAL
        // while `sync_all` runs.
        let (drained_lsn, sync_handle) = {
            let mut wal_guard = Self::lock_unpoisoned(wal);
            wal_guard.drain_for_group_sync()?
        };
        sync_handle.sync_all()?;
        Self::lock_unpoisoned(wal).mark_durable(drained_lsn);

        // NOTE(review): this does not check that `drained_lsn >= target`. It assumes
        // `target` came from the writer (e.g. `current_lsn()`), so the drain always
        // covers it — confirm with callers.
        //
        // `fetch_max` keeps the published LSN monotonic even if a drain reports less
        // than what is already published. The early return above relies on that.
        // Release pairs with the Acquire loads so readers that see the new LSN also see
        // the writes made before publishing.
        self.flushed_lsn.fetch_max(drained_lsn, Ordering::Release);
        Ok(())
    }

    /// Clears the leadership flag and wakes every waiter so each can re-check its target.
    fn release_leadership(&self) {
        // The temporary guard is dropped at the end of this statement, so waiters are
        // notified after the lock is released.
        Self::lock_unpoisoned(&self.state).in_progress = false;
        self.cond.notify_all();
    }

    /// Locks `mutex`, taking the guard out of the `PoisonError` if a previous holder
    /// panicked.
    fn lock_unpoisoned<T>(mutex: &Mutex<T>) -> std::sync::MutexGuard<'_, T> {
        mutex.lock().unwrap_or_else(|p| p.into_inner())
    }
}
impl std::fmt::Debug for GroupCommit {
    /// Shows only the published durable LSN; the mutex-guarded leadership flag is
    /// omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let published = self.flushed_lsn.load(Ordering::Acquire);
        let mut out = f.debug_struct("GroupCommit");
        out.field("flushed_lsn", &published);
        out.finish()
    }
}
impl Drop for LeadershipGuard<'_> {
fn drop(&mut self) {
self.group_commit.release_leadership();
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::wal::record::WalRecord;
    use crate::storage::wal::writer::WalWriter;
    use std::path::PathBuf;
    use std::sync::mpsc;
    use std::sync::Arc;
    use std::thread;
    use std::time::{Duration, SystemTime, UNIX_EPOCH};

    /// Deletes the temporary WAL file at `path` when dropped.
    struct FileGuard {
        path: PathBuf,
    }

    impl Drop for FileGuard {
        fn drop(&mut self) {
            // Ignore errors so cleanup never masks the test's own outcome.
            let _ = std::fs::remove_file(&self.path);
        }
    }

    /// Builds a unique WAL path in the temp dir and returns it with a cleanup guard.
    ///
    /// The path is made unique from the test name, the process id and a
    /// nanosecond timestamp.
    fn temp_wal(name: &str) -> (FileGuard, PathBuf) {
        let nanos = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_nanos();
        let path = std::env::temp_dir().join(format!(
            "rb_group_commit_{}_{}_{}.wal",
            name,
            std::process::id(),
            nanos
        ));
        // Remove any leftover file at this path before the test opens it.
        let _ = std::fs::remove_file(&path);
        (FileGuard { path: path.clone() }, path)
    }

    /// A target equal to the initial durable LSN is satisfied by the lock-free fast path
    /// and leaves `flushed_lsn` unchanged.
    #[test]
    fn fast_path_when_already_flushed() {
        let (_g, path) = temp_wal("fast_path");
        let wal = Mutex::new(WalWriter::open(&path).unwrap());
        let initial = wal.lock().unwrap().durable_lsn();
        let gc = GroupCommit::new(initial);
        gc.commit_at_least(initial, &wal).unwrap();
        assert_eq!(gc.flushed_lsn(), initial);
    }

    /// A lone caller becomes leader and advances `flushed_lsn` past its own records.
    #[test]
    fn single_writer_advances_flushed_lsn() {
        let (_g, path) = temp_wal("single_writer");
        let wal = Mutex::new(WalWriter::open(&path).unwrap());
        let initial = wal.lock().unwrap().durable_lsn();
        let gc = GroupCommit::new(initial);
        let target = {
            let mut w = wal.lock().unwrap();
            w.append(&WalRecord::Begin { tx_id: 1 }).unwrap();
            w.append(&WalRecord::Commit { tx_id: 1 }).unwrap();
            w.current_lsn()
        };
        assert!(target > initial);
        gc.commit_at_least(target, &wal).unwrap();
        assert!(gc.flushed_lsn() >= target);
    }

    /// Committing a lower target after a higher one must not move `flushed_lsn` backwards.
    #[test]
    fn flushed_lsn_is_monotonic() {
        let (_g, path) = temp_wal("monotonic");
        let wal = Mutex::new(WalWriter::open(&path).unwrap());
        let initial = wal.lock().unwrap().durable_lsn();
        let gc = GroupCommit::new(initial);
        let lo = {
            let mut w = wal.lock().unwrap();
            w.append(&WalRecord::Begin { tx_id: 1 }).unwrap();
            w.current_lsn()
        };
        gc.commit_at_least(lo, &wal).unwrap();
        let after_lo = gc.flushed_lsn();
        let hi = {
            let mut w = wal.lock().unwrap();
            w.append(&WalRecord::Commit { tx_id: 1 }).unwrap();
            w.current_lsn()
        };
        gc.commit_at_least(hi, &wal).unwrap();
        let after_hi = gc.flushed_lsn();
        assert!(after_hi >= after_lo);
        // Re-committing the older target takes the fast path and changes nothing.
        gc.commit_at_least(lo, &wal).unwrap();
        assert_eq!(gc.flushed_lsn(), after_hi);
    }

    /// Two threads append and commit 10 transactions each through one shared coordinator.
    ///
    /// Checks that every commit succeeds and the final durable LSN is published.
    #[test]
    fn concurrent_writers_coalesce_through_one_coordinator() {
        let (_g, path) = temp_wal("two_writers");
        let wal = Arc::new(Mutex::new(WalWriter::open(&path).unwrap()));
        let initial = wal.lock().unwrap().durable_lsn();
        let gc = Arc::new(GroupCommit::new(initial));
        let mut handles = Vec::new();
        for tx in 0..2u64 {
            let wal_c = Arc::clone(&wal);
            let gc_c = Arc::clone(&gc);
            handles.push(thread::spawn(move || -> io::Result<()> {
                for i in 0..10u64 {
                    let target = {
                        let mut w = wal_c.lock().unwrap();
                        w.append(&WalRecord::Begin {
                            tx_id: tx * 100 + i,
                        })?;
                        w.append(&WalRecord::Commit {
                            tx_id: tx * 100 + i,
                        })?;
                        w.current_lsn()
                    };
                    gc_c.commit_at_least(target, &wal_c)?;
                }
                Ok(())
            }));
        }
        for h in handles {
            h.join().unwrap().unwrap();
        }
        let final_durable = wal.lock().unwrap().durable_lsn();
        assert!(gc.flushed_lsn() >= final_durable);
        // NOTE(review): `8 + 520` presumably encodes a file-header size plus the encoded
        // size of the 40 records appended above — confirm against the WAL encoding.
        assert!(final_durable >= 8 + 520);
    }

    /// Checks that a waiter retries as leader when a finished flush did not reach its target.
    ///
    /// The test fakes an in-flight leader by setting `in_progress` directly. It then
    /// releases leadership without advancing `flushed_lsn`. The waiter must see that its
    /// target is still unmet, become leader itself, and flush. The 50 ms sleep only gives
    /// the waiter a chance to block first. If it has not blocked yet, it simply becomes
    /// leader directly.
    #[test]
    fn waiter_retries_after_wakeup_if_previous_flush_was_too_small() {
        let (_g, path) = temp_wal("waiter_retry");
        let wal = Arc::new(Mutex::new(WalWriter::open(&path).unwrap()));
        let initial = wal.lock().unwrap().durable_lsn();
        let gc = Arc::new(GroupCommit::new(initial));
        let target = {
            let mut w = wal.lock().unwrap();
            w.append(&WalRecord::Begin { tx_id: 1 }).unwrap();
            w.append(&WalRecord::Commit { tx_id: 1 }).unwrap();
            w.current_lsn()
        };
        // Pretend another leader is mid-flush.
        {
            let mut state = gc.state.lock().unwrap();
            state.in_progress = true;
        }
        let (done_tx, done_rx) = mpsc::channel();
        let wal_c = Arc::clone(&wal);
        let gc_c = Arc::clone(&gc);
        let waiter = thread::spawn(move || {
            let result = gc_c.commit_at_least(target, &wal_c);
            let _ = done_tx.send(result);
        });
        thread::sleep(Duration::from_millis(50));
        // Release the fake leadership without publishing any new LSN.
        {
            let mut state = gc.state.lock().unwrap_or_else(|p| p.into_inner());
            state.in_progress = false;
        }
        gc.cond.notify_all();
        // Bounded wait so a lost wakeup fails the test instead of hanging it.
        done_rx
            .recv_timeout(Duration::from_secs(2))
            .expect("waiter should retry as the next leader")
            .unwrap();
        waiter.join().unwrap();
        assert!(gc.flushed_lsn() >= target);
        assert!(wal.lock().unwrap().durable_lsn() >= target);
    }

    /// Eight threads each commit 50 transactions.
    ///
    /// All commits must complete (no deadlock), and at the end every appended byte must
    /// be durable.
    #[test]
    fn high_concurrency_eight_writers_no_deadlock() {
        let (_g, path) = temp_wal("eight_writers");
        let wal = Arc::new(Mutex::new(WalWriter::open(&path).unwrap()));
        let initial = wal.lock().unwrap().durable_lsn();
        let gc = Arc::new(GroupCommit::new(initial));
        let mut handles = Vec::new();
        for tx in 0..8u64 {
            let wal_c = Arc::clone(&wal);
            let gc_c = Arc::clone(&gc);
            handles.push(thread::spawn(move || -> io::Result<()> {
                for i in 0..50u64 {
                    let target = {
                        let mut w = wal_c.lock().unwrap();
                        w.append(&WalRecord::Begin {
                            tx_id: tx * 1000 + i,
                        })?;
                        w.append(&WalRecord::Commit {
                            tx_id: tx * 1000 + i,
                        })?;
                        w.current_lsn()
                    };
                    gc_c.commit_at_least(target, &wal_c)?;
                }
                Ok(())
            }));
        }
        for h in handles {
            h.join().unwrap().unwrap();
        }
        let current = wal.lock().unwrap().current_lsn();
        let durable = wal.lock().unwrap().durable_lsn();
        assert_eq!(durable, current, "every appended byte must be durable");
        assert!(gc.flushed_lsn() >= current);
    }

    /// Poisoning the coordinator's state mutex must not prevent later commits.
    ///
    /// The mutex is poisoned by panicking while holding its lock inside `catch_unwind`.
    #[test]
    fn writers_recover_from_poisoned_state() {
        let (_g, path) = temp_wal("poison_recovery");
        let wal = Arc::new(Mutex::new(WalWriter::open(&path).unwrap()));
        let initial = wal.lock().unwrap().durable_lsn();
        let gc = Arc::new(GroupCommit::new(initial));
        let gc_c = Arc::clone(&gc);
        let _ = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            let _state = gc_c.state.lock().unwrap();
            panic!("intentional poison");
        }));
        let target = {
            let mut w = wal.lock().unwrap();
            w.append(&WalRecord::Begin { tx_id: 1 }).unwrap();
            w.current_lsn()
        };
        gc.commit_at_least(target, &wal).unwrap();
        assert!(gc.flushed_lsn() >= target);
    }
}