use std::sync::atomic::{AtomicBool, AtomicPtr, AtomicU64, Ordering};
use std::sync::Arc;
use parking_lot::Mutex;
use tokio::sync::{oneshot, Semaphore};
use crate::batch::Batch;
use crate::error::{Error, Result};
use crate::stall::WriteStallController;
/// Number of slots in the pending-commit ring buffer.
///
/// Slot indices are derived by masking a counter with
/// `MAX_CONCURRENT_COMMITS - 1`, so this value must remain a power of two.
/// The commit semaphore is created with one permit fewer than this.
const MAX_CONCURRENT_COMMITS: usize = 8;
/// Bit offset of the upper 32-bit counter inside the packed `head_tail`
/// word; the lower 32 bits hold the other counter.
const DEQUEUE_BITS: u32 = 32;
/// Storage-side hooks that the commit pipeline invokes for every commit.
///
/// Implementations must be shareable across threads (`Send + Sync`) because
/// the pipeline keeps them behind an `Arc<dyn CommitEnv>`.
pub trait CommitEnv: Send + Sync + 'static {
/// First stage of a commit. The pipeline calls this while holding its
/// write mutex, after `seq_num` has been assigned as the batch's starting
/// sequence number. The returned batch is what the pipeline later passes
/// to [`CommitEnv::apply`].
// NOTE(review): presumably this appends to the WAL and `sync` requests a
// durable flush — confirm against the implementors.
fn write(&self, batch: &Batch, seq_num: u64, sync: bool) -> Result<Batch>;
/// Second stage of a commit, called outside the write mutex with the batch
/// returned by [`CommitEnv::write`]. An error is reported to the committer
/// as `Error::CommitFail`.
fn apply(&self, batch: &Batch) -> Result<()>;
/// Checked at the start of every commit; an `Err` aborts the commit before
/// any sequence number is assigned.
fn check_background_error(&self) -> Result<()>;
}
/// Bookkeeping for one in-flight commit while it sits in the pending queue.
struct CommitBatch {
    /// Starting sequence number of the batch; zero until `set_seq_num` runs.
    seq_num: AtomicU64,
    /// Number of sequence numbers the batch consumes.
    count: u32,
    /// Set once the apply stage has finished (successfully or not).
    applied: AtomicBool,
    /// One-shot sender used to report the final outcome; taken on first use.
    complete_tx: Mutex<Option<oneshot::Sender<Result<()>>>>,
}
impl CommitBatch {
    /// Creates the tracker for a batch of `count` entries and returns it
    /// together with the receiver that resolves when `complete` is called.
    fn new(count: u32) -> (Arc<Self>, oneshot::Receiver<Result<()>>) {
        let (sender, receiver) = oneshot::channel();
        let tracker = Self {
            seq_num: AtomicU64::new(0),
            count,
            applied: AtomicBool::new(false),
            complete_tx: Mutex::new(Some(sender)),
        };
        (Arc::new(tracker), receiver)
    }

    /// Records the starting sequence number assigned to this batch.
    fn set_seq_num(&self, seq: u64) {
        self.seq_num.store(seq, Ordering::Release);
    }

    /// Returns the starting sequence number (zero if none was assigned yet).
    fn get_seq_num(&self) -> u64 {
        self.seq_num.load(Ordering::Acquire)
    }

    /// Flags the batch as having finished its apply stage.
    fn mark_applied(&self) {
        self.applied.store(true, Ordering::Release);
    }

    /// Reports whether `mark_applied` has been called.
    fn is_applied(&self) -> bool {
        self.applied.load(Ordering::Acquire)
    }

    /// Delivers `result` to the waiting receiver.
    ///
    /// Only the first call sends anything; later calls find the sender
    /// already taken and do nothing. A dropped receiver is ignored.
    fn complete(&self, result: Result<()>) {
        if let Some(sender) = self.complete_tx.lock().take() {
            let _ = sender.send(result);
        }
    }
}
struct CommitQueue {
head_tail: AtomicU64,
slots: [AtomicPtr<CommitBatch>; MAX_CONCURRENT_COMMITS],
}
impl CommitQueue {
fn new() -> Self {
Self {
head_tail: AtomicU64::new(0),
slots: std::array::from_fn(|_| AtomicPtr::new(std::ptr::null_mut())),
}
}
fn unpack(&self, ptrs: u64) -> (u32, u32) {
let head = (ptrs >> DEQUEUE_BITS) as u32;
let tail = ptrs as u32;
(head, tail)
}
fn pack(&self, head: u32, tail: u32) -> u64 {
((head as u64) << DEQUEUE_BITS) | (tail as u64)
}
fn enqueue(&self, batch: Arc<CommitBatch>) {
let ptrs = self.head_tail.load(Ordering::Acquire);
let (head, tail) = self.unpack(ptrs);
if tail.wrapping_add(MAX_CONCURRENT_COMMITS as u32) == head {
panic!("commit queue overflow - should not be reached");
}
let slot_idx = (head & (MAX_CONCURRENT_COMMITS as u32 - 1)) as usize;
let slot = &self.slots[slot_idx];
while !slot.load(Ordering::Acquire).is_null() {
std::hint::spin_loop();
}
let batch_ptr = Arc::into_raw(batch);
slot.store(batch_ptr as *mut CommitBatch, Ordering::Release);
self.head_tail.fetch_add(1 << DEQUEUE_BITS, Ordering::Release);
}
fn dequeue_applied(&self) -> Option<Arc<CommitBatch>> {
loop {
let ptrs = self.head_tail.load(Ordering::Acquire);
let (head, tail) = self.unpack(ptrs);
if tail == head {
return None;
}
let slot_idx = (tail & (MAX_CONCURRENT_COMMITS as u32 - 1)) as usize;
let slot = &self.slots[slot_idx];
let batch_ptr = slot.load(Ordering::Acquire);
if batch_ptr.is_null() {
return None;
}
let is_applied = unsafe { (*batch_ptr).is_applied() };
if !is_applied {
return None;
}
let new_ptrs = self.pack(head, tail.wrapping_add(1));
if self
.head_tail
.compare_exchange_weak(ptrs, new_ptrs, Ordering::Release, Ordering::Relaxed)
.is_ok()
{
slot.store(std::ptr::null_mut(), Ordering::Release);
let batch = unsafe { Arc::from_raw(batch_ptr) };
return Some(batch);
}
}
}
}
/// Coordinates commits: assigns sequence numbers, runs the write and apply
/// stages of a `CommitEnv`, and publishes the visible sequence number in
/// enqueue order.
pub(crate) struct CommitPipeline {
// Storage hooks invoked for the write and apply stages.
env: Arc<dyn CommitEnv>,
// Next sequence number to hand out; starts at 1 and advances by each
// batch's entry count.
log_seq_num: AtomicU64,
// Shared counter raised to the last sequence number of each batch as it is
// dequeued from `pending`.
visible_seq_num: Arc<AtomicU64>,
// Held for the whole prepare stage (sequence assignment, enqueue, write).
write_mutex: Mutex<()>,
// In-flight batches in the order their sequence numbers were assigned.
pending: CommitQueue,
// Limits commits in flight to `MAX_CONCURRENT_COMMITS - 1`.
commit_sem: Arc<Semaphore>,
// Once set, new commits are rejected with `Error::PipelineStall`.
shutdown: AtomicBool,
// Awaited before a commit acquires its semaphore permit.
write_stall: Arc<WriteStallController>,
}
impl CommitPipeline {
    /// Builds a pipeline around `env`.
    ///
    /// Sequence numbering starts at 1, and `visible_seq_num` is the shared
    /// counter the pipeline raises as batches are published.
    pub(crate) fn new(
        env: Arc<dyn CommitEnv>,
        visible_seq_num: Arc<AtomicU64>,
        write_stall: Arc<WriteStallController>,
    ) -> Arc<Self> {
        Arc::new(Self {
            env,
            log_seq_num: AtomicU64::new(1),
            visible_seq_num,
            write_mutex: Mutex::new(()),
            pending: CommitQueue::new(),
            commit_sem: Arc::new(Semaphore::new(MAX_CONCURRENT_COMMITS - 1)),
            shutdown: AtomicBool::new(false),
            write_stall,
        })
    }

    /// Resets the counters so that `seq_num` is the latest visible sequence
    /// number and `seq_num + 1` is the next one handed out. A value of zero
    /// is ignored.
    pub(crate) fn set_seq_num(&self, seq_num: u64) {
        if seq_num == 0 {
            return;
        }
        self.visible_seq_num.store(seq_num, Ordering::Release);
        self.log_seq_num.store(seq_num + 1, Ordering::Release);
    }

    /// Commits `batch` and resolves once it has been published.
    ///
    /// An empty batch succeeds immediately without consuming a sequence
    /// number.
    ///
    /// # Errors
    ///
    /// * `Error::PipelineStall` if the pipeline is shut down, the semaphore
    ///   is closed, or the completion channel is dropped.
    /// * Whatever `check_background_error`, the write-stall check, or the
    ///   write stage returns.
    /// * `Error::CommitFail` if the apply stage fails.
    pub(crate) async fn commit(&self, mut batch: Batch, sync: bool) -> Result<()> {
        if self.shutdown.load(Ordering::Acquire) {
            return Err(Error::PipelineStall);
        }
        self.env.check_background_error()?;
        if batch.is_empty() {
            return Ok(());
        }
        self.write_stall.check().await?;
        let _permit = self.commit_sem.acquire().await.map_err(|_| Error::PipelineStall)?;
        let (commit_batch, complete_rx) = CommitBatch::new(batch.count());

        let processed_batch = match self.prepare(&mut batch, Arc::clone(&commit_batch), sync) {
            Ok(processed) => processed,
            Err(err) => {
                // `prepare` enqueues the batch before the write stage runs,
                // so a failed write leaves it in the queue. Mark it applied
                // and publish, exactly as for an apply failure; otherwise it
                // would block every later batch from being dequeued. Nobody
                // waits on the completion channel once we return.
                commit_batch.mark_applied();
                self.publish();
                return Err(err);
            }
        };

        let apply_result = self.env.apply(&processed_batch);
        // Mark applied even on failure so the queue can move past this batch.
        commit_batch.mark_applied();
        if let Err(e) = apply_result {
            let err = Error::CommitFail(e.to_string());
            commit_batch.complete(Err(err.clone()));
            self.publish();
            return Err(err);
        }

        self.publish();
        // Another committer may publish this batch; wait for that signal.
        complete_rx.await.map_err(|_| Error::PipelineStall)?
    }

    /// Serialized first stage: assigns the batch's sequence numbers, enqueues
    /// its tracker, and runs the write stage, all under the write mutex so
    /// queue order matches sequence order.
    ///
    /// On error the tracker is already enqueued; the caller is responsible
    /// for marking it applied.
    fn prepare(
        &self,
        batch: &mut Batch,
        commit_batch: Arc<CommitBatch>,
        sync: bool,
    ) -> Result<Batch> {
        let _guard = self.write_mutex.lock();
        let count = u64::from(batch.count());
        let seq_num = self.log_seq_num.fetch_add(count, Ordering::SeqCst);
        commit_batch.set_seq_num(seq_num);
        batch.set_starting_seq_num(seq_num);
        self.pending.enqueue(commit_batch);
        self.env.write(batch, seq_num, sync)
    }

    /// Drains every applied batch from the front of the queue, raising the
    /// visible sequence number to each batch's last sequence number and
    /// signalling its committer.
    fn publish(&self) {
        while let Some(batch) = self.pending.dequeue_applied() {
            let new_visible = batch.get_seq_num() + u64::from(batch.count) - 1;
            // Never moves the counter backwards.
            self.visible_seq_num.fetch_max(new_visible, Ordering::AcqRel);
            // No-op if the committer already reported a failure.
            batch.complete(Ok(()));
        }
    }

    /// Returns the highest sequence number published so far.
    pub(crate) fn get_visible_seq_num(&self) -> u64 {
        self.visible_seq_num.load(Ordering::Acquire)
    }

    /// Makes all subsequent `commit` calls fail with `Error::PipelineStall`.
    pub(crate) fn shutdown(&self) {
        self.shutdown.store(true, Ordering::Release);
    }
}
/// Dropping the pipeline flips the shutdown flag.
impl Drop for CommitPipeline {
    fn drop(&mut self) {
        Self::shutdown(self);
    }
}
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use test_log::test;

    use super::*;
    use crate::InternalKeyKind;

    /// Fresh visible-sequence counter starting at zero.
    fn test_visible_seq_num() -> Arc<AtomicU64> {
        Arc::new(AtomicU64::new(0))
    }

    /// Stall-count provider that always reports zero for both counters.
    struct MockStallProvider;

    impl crate::stall::WriteStallCountProvider for MockStallProvider {
        fn get_stall_counts(&self) -> crate::stall::StallCounts {
            crate::stall::StallCounts {
                immutable_memtables: 0,
                l0_files: 0,
            }
        }
    }

    /// Write-stall controller backed by [`MockStallProvider`].
    fn test_write_stall() -> Arc<crate::stall::WriteStallController> {
        let provider: Arc<dyn crate::stall::WriteStallCountProvider> = Arc::new(MockStallProvider);
        let thresholds = crate::stall::StallThresholds {
            memtable_limit: 2,
            l0_file_limit: 12,
        };
        Arc::new(crate::stall::WriteStallController::new(provider, thresholds))
    }

    /// Pipeline over `env` with a zeroed visible counter and the mock stall
    /// controller.
    fn new_pipeline(env: impl CommitEnv) -> Arc<CommitPipeline> {
        CommitPipeline::new(Arc::new(env), test_visible_seq_num(), test_write_stall())
    }

    /// Batch containing a single `Set` record.
    fn set_batch(key: Vec<u8>, value: Vec<u8>, timestamp: u64) -> Batch {
        let mut batch = Batch::new(0);
        batch.add_record(InternalKeyKind::Set, key, Some(value), timestamp).unwrap();
        batch
    }

    /// Copies every entry of `source` into a new batch starting at `seq_num`;
    /// shared by the mock `write` implementations.
    fn copy_batch(source: &Batch, seq_num: u64) -> Result<Batch> {
        let mut copy = Batch::new(seq_num);
        for entry in source.entries() {
            copy.add_record(entry.kind, entry.key.clone(), entry.value.clone(), entry.timestamp)?;
        }
        Ok(copy)
    }

    /// Busy-waits for at least `duration`.
    fn spin_for(duration: Duration) {
        let started = std::time::Instant::now();
        while started.elapsed() < duration {
            std::hint::spin_loop();
        }
    }

    /// Environment whose stages always succeed.
    struct MockEnv;

    impl CommitEnv for MockEnv {
        fn write(&self, batch: &Batch, seq_num: u64, _sync: bool) -> Result<Batch> {
            copy_batch(batch, seq_num)
        }

        fn apply(&self, _batch: &Batch) -> Result<()> {
            Ok(())
        }

        fn check_background_error(&self) -> Result<()> {
            Ok(())
        }
    }

    #[test(tokio::test)]
    async fn test_single_commit() {
        let pipeline = new_pipeline(MockEnv);
        let batch = set_batch(b"key1".to_vec(), b"value1".to_vec(), 0);

        let result = pipeline.commit(batch, false).await;
        assert!(result.is_ok(), "Single commit failed: {result:?}");

        let visible = pipeline.get_visible_seq_num();
        assert_eq!(
            visible, 1,
            "Expected visible=1 after one commit with count=1 (highest seq num used)"
        );
        pipeline.shutdown();
    }

    #[test(tokio::test(flavor = "multi_thread", worker_threads = 4))]
    async fn test_sequential_commits() {
        let pipeline = new_pipeline(MockEnv);

        for i in 0..5 {
            let batch = set_batch(format!("key{i}").into_bytes(), vec![1, 2, 3], i);
            let result = pipeline.commit(batch, false).await;
            assert!(result.is_ok(), "Sequential commit {i} failed: {result:?}");
        }

        assert_eq!(pipeline.get_visible_seq_num(), 5);
        pipeline.shutdown();
    }

    #[test(tokio::test(flavor = "multi_thread", worker_threads = 4))]
    async fn test_concurrent_commits() {
        let pipeline = new_pipeline(MockEnv);

        let handles: Vec<_> = (0..10)
            .map(|i| {
                let pipeline = Arc::clone(&pipeline);
                tokio::spawn(async move {
                    let batch = set_batch(format!("key{i}").into_bytes(), vec![1, 2, 3], i);
                    pipeline.commit(batch, false).await
                })
            })
            .collect();

        for (i, handle) in handles.into_iter().enumerate() {
            let result = handle.await.unwrap();
            assert!(result.is_ok(), "Commit {i} failed: {result:?}");
        }

        // Give publication up to five seconds to catch up before asserting.
        let started = std::time::Instant::now();
        while pipeline.get_visible_seq_num() < 10 && started.elapsed() < Duration::from_secs(5) {
            tokio::time::sleep(Duration::from_millis(10)).await;
        }
        assert_eq!(pipeline.get_visible_seq_num(), 10, "Not all batches were published");
        pipeline.shutdown();
    }

    /// Environment that burns a little CPU time in both stages.
    struct DelayedMockEnv;

    impl CommitEnv for DelayedMockEnv {
        fn write(&self, batch: &Batch, seq_num: u64, _sync: bool) -> Result<Batch> {
            spin_for(Duration::from_micros(100));
            copy_batch(batch, seq_num)
        }

        fn apply(&self, _batch: &Batch) -> Result<()> {
            spin_for(Duration::from_micros(50));
            Ok(())
        }

        fn check_background_error(&self) -> Result<()> {
            Ok(())
        }
    }

    #[test(tokio::test(flavor = "multi_thread"))]
    async fn test_concurrent_commits_with_delays() {
        let pipeline = new_pipeline(DelayedMockEnv);

        let handles: Vec<_> = (0..5)
            .map(|i| {
                let pipeline = Arc::clone(&pipeline);
                tokio::spawn(async move {
                    let batch = set_batch(format!("key{i}").into_bytes(), vec![1, 2, 3], i);
                    pipeline.commit(batch, false).await
                })
            })
            .collect();

        for handle in handles {
            assert!(handle.await.unwrap().is_ok());
        }

        assert_eq!(pipeline.get_visible_seq_num(), 5);
        pipeline.shutdown();
    }

    /// Environment whose apply stage fails every time.
    struct AlwaysFailApplyEnv;

    impl CommitEnv for AlwaysFailApplyEnv {
        fn write(&self, batch: &Batch, seq_num: u64, _sync: bool) -> Result<Batch> {
            copy_batch(batch, seq_num)
        }

        fn apply(&self, _batch: &Batch) -> Result<()> {
            Err(Error::CommitFail("simulated apply failure".into()))
        }

        fn check_background_error(&self) -> Result<()> {
            Ok(())
        }
    }

    /// More failing commits than queue slots must not overflow the queue.
    #[test(tokio::test)]
    async fn test_queue_overflow_all_fail() {
        let pipeline = new_pipeline(AlwaysFailApplyEnv);

        for i in 0..20 {
            let batch = set_batch(format!("key{i}").into_bytes(), b"value".to_vec(), 0);
            let result = pipeline.commit(batch, false).await;
            assert!(result.is_err(), "Expected error at iteration {i}");
        }

        pipeline.shutdown();
    }

    /// Environment whose apply stage fails for the first `fail_until` calls
    /// and succeeds afterwards.
    struct FailNTimesEnv {
        call_count: std::sync::atomic::AtomicUsize,
        fail_until: usize,
    }

    impl FailNTimesEnv {
        fn new(fail_count: usize) -> Self {
            Self {
                call_count: std::sync::atomic::AtomicUsize::new(0),
                fail_until: fail_count,
            }
        }
    }

    impl CommitEnv for FailNTimesEnv {
        fn write(&self, batch: &Batch, seq_num: u64, _sync: bool) -> Result<Batch> {
            copy_batch(batch, seq_num)
        }

        fn apply(&self, _batch: &Batch) -> Result<()> {
            let call_num = self.call_count.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
            if call_num >= self.fail_until {
                Ok(())
            } else {
                Err(Error::CommitFail("simulated failure".into()))
            }
        }

        fn check_background_error(&self) -> Result<()> {
            Ok(())
        }
    }

    /// Failed commits still consume sequence numbers, so the visible counter
    /// ends at 20 after ten failures followed by ten successes.
    #[test(tokio::test)]
    async fn test_queue_overflow_partial_fail() {
        let fail_count = 10;
        let pipeline = new_pipeline(FailNTimesEnv::new(fail_count));

        for i in 0..20 {
            let batch = set_batch(format!("key{i}").into_bytes(), b"value".to_vec(), i as u64);
            let result = pipeline.commit(batch, false).await;
            if i < fail_count {
                assert!(result.is_err(), "Expected error at iteration {i}");
            } else {
                assert!(result.is_ok(), "Expected success at iteration {i}, got {result:?}");
            }
        }

        let visible = pipeline.get_visible_seq_num();
        assert_eq!(visible, 20, "Expected visible_seq_num=20");
        pipeline.shutdown();
    }
}