//! Group-commit coordinator for concurrent WAL appends: writers reserve an
//! LSN range, enqueue their encoded bytes, and one "leader" thread drains the
//! queue, appends the contiguous prefix, and fsyncs once for the whole batch.
use std::io;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::time::{Duration, Instant};
use crossbeam_queue::SegQueue;
use parking_lot::{Condvar, Mutex};
use super::writer::WalWriter;
/// How long the leader keeps spinning for a gap-filling entry before giving
/// up on extending the current batch (50 µs).
const MAX_LEADER_SPIN_NS: u64 = 50_000;
/// Initial capacity of the leader's per-drain scratch buffers.
const LEADER_PENDING_CAPACITY: usize = 64;
/// Coordinates concurrent WAL appends. Writers reserve an LSN range and
/// enqueue their bytes; whichever caller wins the leader flag drains the
/// queue, writes the prefix contiguous with `written_lsn`, and syncs it.
pub struct WalAppendCoordinator {
    /// Enqueued `(start_lsn, encoded_bytes)` blobs awaiting the leader.
    queue: SegQueue<(u64, Vec<u8>)>,
    /// Next LSN to hand out; advanced atomically by `reserve_and_enqueue`.
    next_lsn: AtomicU64,
    /// End of the contiguous prefix already appended to the WAL file.
    written_lsn: AtomicU64,
    /// End of the prefix known to be fsynced to disk.
    durable_lsn: AtomicU64,
    /// Set while some thread is acting as the drain-and-sync leader.
    leader_in_progress: AtomicBool,
    /// Lock/condvar pair that followers block on while waiting for durability.
    wait_lock: Mutex<()>,
    wait_cond: Condvar,
}
impl WalAppendCoordinator {
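    /// Creates a coordinator whose counters start at the WAL's current append
    /// position (`initial_lsn`) and last known durable position.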
pub fn new(initial_lsn: u64, initial_durable_lsn: u64) -> Self {
Self {
queue: SegQueue::new(),
next_lsn: AtomicU64::new(initial_lsn),
written_lsn: AtomicU64::new(initial_lsn),
durable_lsn: AtomicU64::new(initial_durable_lsn),
leader_in_progress: AtomicBool::new(false),
wait_lock: Mutex::new(()),
wait_cond: Condvar::new(),
}
}
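    /// Highest LSN known to be appended and fsynced.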
pub fn durable_lsn(&self) -> u64 {
self.durable_lsn.load(Ordering::Acquire)
}
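    /// LSN that the next reservation will start at.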
pub fn next_lsn(&self) -> u64 {
self.next_lsn.load(Ordering::Acquire)
}
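    /// Reserves an LSN range for `bytes`, enqueues the blob for the next
    /// leader to write, and returns the end LSN of the reservation, which is
    /// the target to pass to `commit_at_least`.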
pub fn reserve_and_enqueue(&self, bytes: Vec<u8>) -> u64 {
let len = bytes.len() as u64;
let lsn = self.next_lsn.fetch_add(len, Ordering::AcqRel);
self.queue.push((lsn, bytes));
lsn + len
}
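    /// Blocks until the WAL is durable at least through `target`. The caller
    /// either wins the leader flag and drains/syncs the queue itself, or
    /// waits for the current leader to publish progress and retries. An I/O
    /// error from a drain this thread led is returned to the caller.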
pub fn commit_at_least(&self, target: u64, wal: &Mutex<WalWriter>) -> io::Result<()> {
loop {
if self.durable_lsn.load(Ordering::Acquire) >= target {
return Ok(());
}
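            // Try to become the leader for this round.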
if self
.leader_in_progress
.compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
.is_ok()
{
let pre_durable = self.durable_lsn.load(Ordering::Acquire);
let result = self.drive_drain(wal);
let post_durable = self.durable_lsn.load(Ordering::Acquire);
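                // Hand back leadership before waking waiters so another
                // thread can take over immediately if more work remains.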
self.leader_in_progress.store(false, Ordering::Release);
{
let _g = self.wait_lock.lock();
self.wait_cond.notify_all();
}
result?;
if post_durable >= target {
return Ok(());
}
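                // The drain made no progress (likely a gap left by a writer
                // that reserved but has not pushed yet); wait briefly instead
                // of spinning hot.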
if post_durable == pre_durable {
let mut guard = self.wait_lock.lock();
if self.durable_lsn.load(Ordering::Acquire) >= target {
return Ok(());
}
self.wait_cond
.wait_for(&mut guard, Duration::from_micros(50));
}
continue;
}
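            // Someone else is leading; wait for it to publish progress.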
let mut guard = self.wait_lock.lock();
if self.durable_lsn.load(Ordering::Acquire) >= target {
return Ok(());
}
self.wait_cond
.wait_for(&mut guard, Duration::from_millis(1));
}
}
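    /// Leader path: pops everything queued, appends the run contiguous with
    /// `written_lsn` to the WAL, fsyncs once, and publishes the new
    /// written/durable LSNs. Entries beyond a gap are re-queued for a later
    /// leader pass; an I/O error is returned to the committing caller with
    /// the counters left untouched.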
fn drive_drain(&self, wal: &Mutex<WalWriter>) -> io::Result<()> {
let mut cursor = self.written_lsn.load(Ordering::Acquire);
let mut writeable: Vec<(u64, Vec<u8>)> = Vec::with_capacity(LEADER_PENDING_CAPACITY);
let spin_deadline = Instant::now() + Duration::from_nanos(MAX_LEADER_SPIN_NS);
loop {
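            // Grab everything currently queued; entries may arrive out of
            // LSN order, since reservation and push are separate steps.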
let mut pending: Vec<(u64, Vec<u8>)> = Vec::with_capacity(LEADER_PENDING_CAPACITY);
while let Some(entry) = self.queue.pop() {
pending.push(entry);
}
if pending.is_empty() {
break;
}
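            // Order by start LSN so the contiguous prefix can be peeled off.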
pending.sort_by_key(|(lsn, _)| *lsn);
            // Skip stale entries that start below the write cursor; these
            // should only arise from duplicates or a racing `reset`.
            let mut idx = 0;
            while idx < pending.len() && pending[idx].0 < cursor {
                idx += 1;
            }
            // Peel off the run that is contiguous with the write cursor.
            while idx < pending.len() && pending[idx].0 == cursor {
                let (lsn, bytes) = std::mem::take(&mut pending[idx]);
                cursor = lsn + bytes.len() as u64;
                writeable.push((lsn, bytes));
                idx += 1;
            }
            // Hand the non-contiguous tail back to the queue so a later
            // leader can retry it once the gap has been filled.
for (lsn, bytes) in pending.drain(idx..) {
if !bytes.is_empty() {
self.queue.push((lsn, bytes));
}
}
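            // Nothing contiguous yet: spin briefly in case the gap-filling
            // entry is about to be pushed, otherwise give up for this round.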
if writeable.is_empty() && Instant::now() < spin_deadline {
std::thread::yield_now();
continue;
}
break;
}
if writeable.is_empty() {
return Ok(());
}
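        // Append the whole batch under the writer lock and pay for a single
        // fsync covering all of it.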
let target_lsn = {
let mut wal_guard = wal.lock();
for (_lsn, bytes) in &writeable {
wal_guard.append_bytes(bytes)?;
}
wal_guard.sync()?;
wal_guard.current_lsn()
};
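        // Publish progress. Only the leader stores these (reset aside), so a
        // load-compare-store is enough to keep `durable_lsn` monotonic.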
self.written_lsn.store(target_lsn, Ordering::Release);
let prev = self.durable_lsn.load(Ordering::Acquire);
if target_lsn > prev {
self.durable_lsn.store(target_lsn, Ordering::Release);
}
Ok(())
}
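    /// Discards everything still queued, rewinds all counters to `next_lsn`,
    /// and wakes any waiters. Not safe against concurrent
    /// `reserve_and_enqueue` calls, which could re-populate the queue
    /// mid-reset.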
pub fn reset(&self, next_lsn: u64) {
while self.queue.pop().is_some() {}
self.next_lsn.store(next_lsn, Ordering::Release);
self.written_lsn.store(next_lsn, Ordering::Release);
self.durable_lsn.store(next_lsn, Ordering::Release);
let _g = self.wait_lock.lock();
self.wait_cond.notify_all();
}
}
impl std::fmt::Debug for WalAppendCoordinator {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("WalAppendCoordinator")
.field("next_lsn", &self.next_lsn.load(Ordering::Acquire))
.field("written_lsn", &self.written_lsn.load(Ordering::Acquire))
.field("durable_lsn", &self.durable_lsn.load(Ordering::Acquire))
.field(
"leader_in_progress",
&self.leader_in_progress.load(Ordering::Acquire),
)
.finish()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::storage::wal::reader::WalReader;
use crate::storage::wal::record::WalRecord;
use crate::storage::wal::writer::WalWriter;
use parking_lot::Mutex as PlMutex;
use std::path::PathBuf;
use std::sync::Arc;
use std::thread;
use std::time::{SystemTime, UNIX_EPOCH};
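    /// Deletes the backing test file when dropped.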
struct FileGuard {
path: PathBuf,
}
impl Drop for FileGuard {
fn drop(&mut self) {
let _ = std::fs::remove_file(&self.path);
}
}
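    /// Builds a unique temp path for a throwaway WAL file.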
fn temp_wal(name: &str) -> (FileGuard, PathBuf) {
let nanos = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_nanos();
let path = std::env::temp_dir().join(format!(
"rb_wal_coord_{}_{}_{}.wal",
name,
std::process::id(),
nanos
));
let _ = std::fs::remove_file(&path);
(FileGuard { path: path.clone() }, path)
}
#[test]
fn single_writer_round_trip() {
let (_g, path) = temp_wal("single");
let wal = WalWriter::open(&path).unwrap();
let initial = wal.current_lsn();
let durable = wal.durable_lsn();
let wal = Arc::new(PlMutex::new(wal));
let coord = WalAppendCoordinator::new(initial, durable);
let mut blob = Vec::new();
blob.extend_from_slice(&WalRecord::Begin { tx_id: 1 }.encode());
blob.extend_from_slice(&WalRecord::Commit { tx_id: 1 }.encode());
let target = coord.reserve_and_enqueue(blob);
coord.commit_at_least(target, &wal).unwrap();
assert!(coord.durable_lsn() >= target);
let recs: Vec<_> = WalReader::open(&path)
.unwrap()
.iter()
.map(|r| r.unwrap().1)
.collect();
assert_eq!(recs.len(), 2);
}
#[test]
fn concurrent_writers_no_gaps_lsn_ordered() {
let (_g, path) = temp_wal("concurrent_no_gaps");
let wal = WalWriter::open(&path).unwrap();
let initial = wal.current_lsn();
let durable = wal.durable_lsn();
let wal = Arc::new(PlMutex::new(wal));
let coord = Arc::new(WalAppendCoordinator::new(initial, durable));
const WRITERS: u64 = 16;
const PER_WRITER: u64 = 50;
let mut handles = Vec::new();
for tx_base in 0..WRITERS {
let wal_c = Arc::clone(&wal);
let coord_c = Arc::clone(&coord);
handles.push(thread::spawn(move || {
for i in 0..PER_WRITER {
let tx_id = tx_base * 1000 + i;
let mut blob = Vec::new();
blob.extend_from_slice(&WalRecord::Begin { tx_id }.encode());
blob.extend_from_slice(&WalRecord::Commit { tx_id }.encode());
let target = coord_c.reserve_and_enqueue(blob);
coord_c.commit_at_least(target, &wal_c).unwrap();
}
}));
}
for h in handles {
h.join().unwrap();
}
let recs: Vec<_> = WalReader::open(&path)
.unwrap()
.iter()
.map(|r| r.unwrap())
.collect();
assert_eq!(recs.len() as u64, WRITERS * PER_WRITER * 2);
for w in recs.windows(2) {
assert!(w[1].0 > w[0].0, "LSNs must be strictly increasing");
}
for chunk in recs.chunks_exact(2) {
match (&chunk[0].1, &chunk[1].1) {
(WalRecord::Begin { tx_id: a }, WalRecord::Commit { tx_id: b }) => {
assert_eq!(a, b, "Begin/Commit pair tx_id mismatch");
}
other => panic!("unexpected record pair: {other:?}"),
}
}
}
#[test]
fn reserved_lsn_matches_on_disk_offset() {
let (_g, path) = temp_wal("lsn_offset");
let wal = WalWriter::open(&path).unwrap();
let initial = wal.current_lsn();
let durable = wal.durable_lsn();
let wal = Arc::new(PlMutex::new(wal));
let coord = WalAppendCoordinator::new(initial, durable);
let blob = WalRecord::Begin { tx_id: 99 }.encode();
let blob_len = blob.len() as u64;
let target = coord.reserve_and_enqueue(blob);
assert_eq!(target, initial + blob_len);
coord.commit_at_least(target, &wal).unwrap();
let recs: Vec<_> = WalReader::open(&path)
.unwrap()
.iter()
.map(|r| r.unwrap())
.collect();
assert_eq!(recs[0].0, initial);
}
#[test]
fn reset_clears_queue_and_resets_counters() {
let (_g, path) = temp_wal("reset");
let wal = WalWriter::open(&path).unwrap();
let initial = wal.current_lsn();
let wal = Arc::new(PlMutex::new(wal));
let coord = WalAppendCoordinator::new(initial, initial);
let _ = coord.reserve_and_enqueue(vec![1, 2, 3]);
let _ = coord.reserve_and_enqueue(vec![4, 5, 6]);
assert!(coord.next_lsn() > initial);
coord.reset(initial);
assert_eq!(coord.next_lsn(), initial);
assert_eq!(coord.durable_lsn(), initial);
let target = coord.reserve_and_enqueue(WalRecord::Begin { tx_id: 7 }.encode());
coord.commit_at_least(target, &wal).unwrap();
assert_eq!(coord.durable_lsn(), target);
}
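    // Simulates a writer that reserved an LSN range but "crashed" before
    // pushing its bytes: the leader must not write anything past the gap.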
#[test]
fn writer_crash_between_reserve_and_push_keeps_file_clean() {
let (_g, path) = temp_wal("writer_crash");
let wal = WalWriter::open(&path).unwrap();
let initial = wal.current_lsn();
let wal = Arc::new(PlMutex::new(wal));
let coord = Arc::new(WalAppendCoordinator::new(initial, initial));
let stuck_len = 10u64;
let _stuck_lsn = coord.next_lsn.fetch_add(stuck_len, Ordering::AcqRel);
let acquired = coord
.leader_in_progress
.compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
.is_ok();
assert!(acquired, "test must own the leader flag");
let _ = coord.drive_drain(&wal);
coord.leader_in_progress.store(false, Ordering::Release);
assert_eq!(coord.durable_lsn(), initial);
let on_disk_len = std::fs::metadata(&path).unwrap().len();
assert_eq!(on_disk_len, initial);
}
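    // A reserves and commits, B "crashes" after reserving (leaving a gap),
    // and C reserves past the gap. Only A's bytes may reach the file.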
#[test]
fn writer_a_succeeds_when_b_crashes_before_c() {
let (_g, path) = temp_wal("abc_crash");
let wal = WalWriter::open(&path).unwrap();
let initial = wal.current_lsn();
let wal = Arc::new(PlMutex::new(wal));
let coord = Arc::new(WalAppendCoordinator::new(initial, initial));
let blob_a = WalRecord::Begin { tx_id: 1 }.encode();
let len_a = blob_a.len() as u64;
let target_a = coord.reserve_and_enqueue(blob_a);
let stuck_len = 13u64;
let _stuck_lsn = coord.next_lsn.fetch_add(stuck_len, Ordering::AcqRel);
let blob_c = WalRecord::Begin { tx_id: 3 }.encode();
let _target_c = coord.reserve_and_enqueue(blob_c);
coord.commit_at_least(target_a, &wal).unwrap();
assert_eq!(coord.durable_lsn(), initial + len_a);
let on_disk_len = std::fs::metadata(&path).unwrap().len();
assert_eq!(on_disk_len, initial + len_a);
}
}