use std::sync::Mutex;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use crate::error::{Result, WalError};
use crate::writer::WalWriter;
/// A single WAL record queued for inclusion in the next group-commit batch.
#[derive(Debug)]
pub struct PendingWrite {
    /// Record type tag (e.g. a `RecordType` discriminant cast to `u16`).
    pub record_type: u16,
    /// Tenant identifier the record belongs to.
    pub tenant_id: u32,
    /// Virtual shard identifier within the tenant.
    pub vshard_id: u16,
    /// Serialized record body, appended verbatim to the WAL.
    pub payload: Vec<u8>,
}
/// Outcome of a `GroupCommitter::submit` call.
#[derive(Debug, Clone, Copy)]
pub struct CommitResult {
    /// LSN assigned to the caller's record when this thread led its own
    /// batch, or the committer's durable watermark when the record was
    /// flushed as part of another thread's batch.
    pub lsn: u64,
    /// Whether the record is known to have been fsynced to stable storage.
    pub durable: bool,
}
/// Batches concurrent WAL writes so that many submitters share a single
/// fsync ("group commit").
pub struct GroupCommitter {
    /// Writes queued since the last batch was drained by a leader.
    pending: Mutex<Vec<PendingWrite>>,
    /// Highest LSN confirmed durable by a successful `sync()`.
    durable_lsn: AtomicU64,
    /// Set when a batch's fsync failed; consulted by threads whose write may
    /// have been carried by that batch.
    last_commit_failed: AtomicBool,
    /// Serializes batch leaders: the holder drains `pending` and performs the
    /// append + fsync for the whole batch.
    commit_lock: Mutex<()>,
}
impl GroupCommitter {
pub fn new() -> Self {
Self {
pending: Mutex::new(Vec::with_capacity(1024)),
durable_lsn: AtomicU64::new(0),
last_commit_failed: AtomicBool::new(false),
commit_lock: Mutex::new(()),
}
}
pub fn submit(&self, writer: &Mutex<WalWriter>, write: PendingWrite) -> Result<CommitResult> {
{
let mut pending = self.pending.lock().map_err(|_| WalError::LockPoisoned {
context: "pending queue",
})?;
pending.push(write);
}
let _commit_guard = self
.commit_lock
.lock()
.map_err(|_| WalError::LockPoisoned {
context: "commit lock",
})?;
let batch: Vec<PendingWrite> = {
let mut pending = self.pending.lock().map_err(|_| WalError::LockPoisoned {
context: "pending queue (drain)",
})?;
std::mem::take(&mut *pending)
};
if batch.is_empty() {
if self.last_commit_failed.load(Ordering::Acquire) {
return Err(WalError::Io(std::io::Error::other(
"WAL fsync failed in previous group commit batch",
)));
}
let lsn = self.durable_lsn.load(Ordering::Acquire);
return Ok(CommitResult { lsn, durable: true });
}
let mut wal = writer.lock().map_err(|_| WalError::LockPoisoned {
context: "WAL writer",
})?;
let mut last_lsn = 0;
for w in &batch {
last_lsn = wal.append(w.record_type, w.tenant_id, w.vshard_id, &w.payload)?;
}
let sync_result = wal.sync();
drop(wal);
match sync_result {
Ok(()) => {
self.last_commit_failed.store(false, Ordering::Release);
self.durable_lsn.store(last_lsn, Ordering::Release);
Ok(CommitResult {
lsn: last_lsn,
durable: true,
})
}
Err(e) => {
self.last_commit_failed.store(true, Ordering::Release);
Err(e)
}
}
}
pub fn durable_lsn(&self) -> u64 {
self.durable_lsn.load(Ordering::Acquire)
}
}
impl Default for GroupCommitter {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::reader::WalReader;
    use crate::record::RecordType;
    use std::sync::Arc;
    use std::thread;

    /// One submitter: the write leads its own batch, is reported durable at
    /// LSN 1, and is readable back from the WAL file.
    #[test]
    fn single_thread_group_commit() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("test.wal");
        // Buffered (non-direct) I/O keeps the test independent of filesystem
        // O_DIRECT support.
        let writer = Mutex::new(WalWriter::open_without_direct_io(&path).unwrap());
        let gc = GroupCommitter::new();
        let result = gc
            .submit(
                &writer,
                PendingWrite {
                    record_type: RecordType::Put as u16,
                    tenant_id: 1,
                    vshard_id: 0,
                    payload: b"hello".to_vec(),
                },
            )
            .unwrap();
        assert!(result.durable);
        // First record in a fresh WAL gets LSN 1, and the durable watermark
        // advances to match.
        assert_eq!(result.lsn, 1);
        assert_eq!(gc.durable_lsn(), 1);
        // Round-trip: re-open the file and verify the payload survived.
        let reader = WalReader::open(&path).unwrap();
        let records: Vec<_> = reader
            .records()
            .collect::<crate::error::Result<_>>()
            .unwrap();
        assert_eq!(records.len(), 1);
        assert_eq!(records[0].payload, b"hello");
    }

    /// Ten threads submit concurrently; batching may coalesce them, but every
    /// submit must report durable with a positive LSN and all ten records
    /// must land in the WAL.
    #[test]
    fn concurrent_group_commit() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("test.wal");
        let writer = Arc::new(Mutex::new(
            WalWriter::open_without_direct_io(&path).unwrap(),
        ));
        let gc = Arc::new(GroupCommitter::new());
        let mut handles = Vec::new();
        for i in 0..10 {
            let w = Arc::clone(&writer);
            let g = Arc::clone(&gc);
            handles.push(thread::spawn(move || {
                let payload = format!("record-{i}");
                let result = g
                    .submit(
                        &w,
                        PendingWrite {
                            record_type: RecordType::Put as u16,
                            tenant_id: 1,
                            vshard_id: 0,
                            payload: payload.into_bytes(),
                        },
                    )
                    .unwrap();
                assert!(result.durable);
                result.lsn
            }));
        }
        // A follower may observe the batch leader's watermark rather than its
        // own record's exact LSN, so only positivity is asserted per thread.
        let lsns: Vec<u64> = handles.into_iter().map(|h| h.join().unwrap()).collect();
        assert!(lsns.iter().all(|l| *l > 0));
        // All ten records must be recoverable regardless of how they batched.
        let reader = WalReader::open(&path).unwrap();
        let records: Vec<_> = reader
            .records()
            .collect::<crate::error::Result<_>>()
            .unwrap();
        assert_eq!(records.len(), 10);
    }
}