use std::sync::atomic::{AtomicU8, AtomicU64, Ordering};
use fsqlite_types::CommitSeq;
use fsqlite_types::sync_primitives::{Instant, Mutex};
pub const MAX_COMMIT_THREADS: usize = 64;
const SLOT_EMPTY: u8 = 0;
const SLOT_PENDING: u8 = 1;
const SLOT_DONE: u8 = 2;
const SPIN_BEFORE_YIELD: u32 = 1024;
static COMMIT_COMBINE_BATCHES: AtomicU64 = AtomicU64::new(0);
static COMMIT_COMBINE_OPS: AtomicU64 = AtomicU64::new(0);
static COMMIT_COMBINE_BATCH_SIZE_SUM: AtomicU64 = AtomicU64::new(0);
static COMMIT_COMBINE_BATCH_SIZE_MAX: AtomicU64 = AtomicU64::new(0);
static COMMIT_COMBINE_WAIT_NS_TOTAL: AtomicU64 = AtomicU64::new(0);
static COMMIT_COMBINE_WAIT_NS_MAX: AtomicU64 = AtomicU64::new(0);
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize)]
pub struct CommitCombineMetrics {
pub batches_total: u64,
pub ops_total: u64,
pub batch_size_sum: u64,
pub batch_size_max: u64,
pub wait_ns_total: u64,
pub wait_ns_max: u64,
}
#[must_use]
pub fn commit_combine_metrics() -> CommitCombineMetrics {
CommitCombineMetrics {
batches_total: COMMIT_COMBINE_BATCHES.load(Ordering::Relaxed),
ops_total: COMMIT_COMBINE_OPS.load(Ordering::Relaxed),
batch_size_sum: COMMIT_COMBINE_BATCH_SIZE_SUM.load(Ordering::Relaxed),
batch_size_max: COMMIT_COMBINE_BATCH_SIZE_MAX.load(Ordering::Relaxed),
wait_ns_total: COMMIT_COMBINE_WAIT_NS_TOTAL.load(Ordering::Relaxed),
wait_ns_max: COMMIT_COMBINE_WAIT_NS_MAX.load(Ordering::Relaxed),
}
}
pub fn reset_commit_combine_metrics() {
COMMIT_COMBINE_BATCHES.store(0, Ordering::Relaxed);
COMMIT_COMBINE_OPS.store(0, Ordering::Relaxed);
COMMIT_COMBINE_BATCH_SIZE_SUM.store(0, Ordering::Relaxed);
COMMIT_COMBINE_BATCH_SIZE_MAX.store(0, Ordering::Relaxed);
COMMIT_COMBINE_WAIT_NS_TOTAL.store(0, Ordering::Relaxed);
COMMIT_COMBINE_WAIT_NS_MAX.store(0, Ordering::Relaxed);
}
fn update_max(metric: &AtomicU64, val: u64) {
let mut prev = metric.load(Ordering::Relaxed);
while val > prev {
match metric.compare_exchange_weak(prev, val, Ordering::Relaxed, Ordering::Relaxed) {
Ok(_) => break,
Err(actual) => prev = actual,
}
}
}
#[repr(align(64))]
struct CommitSlot {
state: AtomicU8,
_pad1: [u8; 7],
result: AtomicU64,
_pad2: [u8; 48],
}
impl CommitSlot {
const fn new() -> Self {
Self {
state: AtomicU8::new(SLOT_EMPTY),
_pad1: [0; 7],
result: AtomicU64::new(0),
_pad2: [0; 48],
}
}
}
pub struct CommitSequenceCombiner {
next_commit_seq: AtomicU64,
slots: [CommitSlot; MAX_COMMIT_THREADS],
owners: [AtomicU64; MAX_COMMIT_THREADS],
combiner_lock: Mutex<()>,
}
impl CommitSequenceCombiner {
pub fn new(initial_commit_seq: u64) -> Self {
Self {
next_commit_seq: AtomicU64::new(initial_commit_seq),
slots: std::array::from_fn(|_| CommitSlot::new()),
owners: std::array::from_fn(|_| AtomicU64::new(0)),
combiner_lock: Mutex::new(()),
}
}
pub fn register(&self) -> Option<CommitCombineHandle<'_>> {
let tid = thread_id_hash();
for i in 0..MAX_COMMIT_THREADS {
if self.owners[i]
.compare_exchange(0, tid, Ordering::AcqRel, Ordering::Relaxed)
.is_ok()
{
return Some(CommitCombineHandle {
combiner: self,
slot: i,
});
}
}
None
}
#[must_use]
pub fn next_seq(&self) -> u64 {
self.next_commit_seq.load(Ordering::Relaxed)
}
#[must_use]
pub fn active_threads(&self) -> usize {
self.owners
.iter()
.filter(|o| o.load(Ordering::Relaxed) != 0)
.count()
}
fn combine_locked(&self) {
let mut pending_count = 0u64;
let mut pending_slots = [false; MAX_COMMIT_THREADS];
for (slot, is_pending) in self.slots.iter().zip(pending_slots.iter_mut()) {
let state = slot.state.load(Ordering::Acquire);
if state == SLOT_PENDING {
*is_pending = true;
pending_count += 1;
}
}
if pending_count == 0 {
return;
}
let base_seq = self
.next_commit_seq
.fetch_add(pending_count, Ordering::AcqRel);
let mut assigned = 0u64;
for (slot, is_pending) in self.slots.iter().zip(pending_slots.iter()) {
if *is_pending {
let seq = base_seq + assigned;
assigned += 1;
slot.result.store(seq, Ordering::Release);
slot.state.store(SLOT_DONE, Ordering::Release);
}
}
debug_assert_eq!(assigned, pending_count);
COMMIT_COMBINE_BATCHES.fetch_add(1, Ordering::Relaxed);
COMMIT_COMBINE_OPS.fetch_add(pending_count, Ordering::Relaxed);
COMMIT_COMBINE_BATCH_SIZE_SUM.fetch_add(pending_count, Ordering::Relaxed);
update_max(&COMMIT_COMBINE_BATCH_SIZE_MAX, pending_count);
tracing::debug!(
target: "fsqlite.commit_combine",
batch_size = pending_count,
base_seq,
"commit_combine_batch"
);
}
}
#[allow(clippy::missing_fields_in_debug)]
impl std::fmt::Debug for CommitSequenceCombiner {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("CommitSequenceCombiner")
.field("next_seq", &self.next_seq())
.field("active_threads", &self.active_threads())
.finish_non_exhaustive()
}
}
pub struct CommitCombineHandle<'a> {
combiner: &'a CommitSequenceCombiner,
slot: usize,
}
impl CommitCombineHandle<'_> {
pub fn alloc_commit_seq(&self) -> CommitSeq {
let start = Instant::now();
self.combiner.slots[self.slot]
.state
.store(SLOT_PENDING, Ordering::Release);
if let Some(_guard) = self.combiner.combiner_lock.try_lock() {
self.combiner.combine_locked();
}
let mut spins = 0u32;
loop {
let state = self.combiner.slots[self.slot].state.load(Ordering::Acquire);
if state == SLOT_DONE {
let seq = self.combiner.slots[self.slot]
.result
.load(Ordering::Acquire);
self.combiner.slots[self.slot]
.state
.store(SLOT_EMPTY, Ordering::Release);
#[allow(clippy::cast_possible_truncation)]
let elapsed_ns = start.elapsed().as_nanos() as u64;
COMMIT_COMBINE_WAIT_NS_TOTAL.fetch_add(elapsed_ns, Ordering::Relaxed);
update_max(&COMMIT_COMBINE_WAIT_NS_MAX, elapsed_ns);
return CommitSeq::new(seq);
}
spins += 1;
if spins < SPIN_BEFORE_YIELD {
std::hint::spin_loop();
} else {
if let Some(_guard) = self.combiner.combiner_lock.try_lock() {
self.combiner.combine_locked();
} else {
std::thread::yield_now();
}
spins = 0;
}
}
}
#[must_use]
pub fn slot(&self) -> usize {
self.slot
}
}
impl Drop for CommitCombineHandle<'_> {
fn drop(&mut self) {
self.combiner.slots[self.slot]
.state
.store(SLOT_EMPTY, Ordering::Release);
self.combiner.owners[self.slot].store(0, Ordering::Release);
}
}
fn thread_id_hash() -> u64 {
let t = std::thread::current().id();
let s = format!("{t:?}");
let mut h = 1u64;
for b in s.bytes() {
h = h.wrapping_mul(31).wrapping_add(u64::from(b));
}
if h == 0 { 1 } else { h }
}
#[cfg(test)]
mod tests {
use super::*;
use std::sync::{Arc, Barrier};
use std::thread;
#[test]
fn test_combiner_single_thread() {
let combiner = CommitSequenceCombiner::new(100);
let handle = combiner.register().unwrap();
let seq1 = handle.alloc_commit_seq();
assert_eq!(seq1.get(), 100);
let seq2 = handle.alloc_commit_seq();
assert_eq!(seq2.get(), 101);
let seq3 = handle.alloc_commit_seq();
assert_eq!(seq3.get(), 102);
drop(handle);
assert_eq!(combiner.next_seq(), 103);
}
#[test]
fn test_combiner_8t_all_commits_succeed() {
let combiner = Arc::new(CommitSequenceCombiner::new(1000));
let barrier = Arc::new(Barrier::new(8));
let mut handles = Vec::new();
for _ in 0..8 {
let c = Arc::clone(&combiner);
let b = Arc::clone(&barrier);
handles.push(thread::spawn(move || {
let h = c.register().unwrap();
b.wait();
let mut seqs = Vec::new();
for _ in 0..100 {
seqs.push(h.alloc_commit_seq().get());
}
drop(h);
seqs
}));
}
let mut all_seqs = Vec::new();
for h in handles {
all_seqs.extend(h.join().unwrap());
}
all_seqs.sort();
let unique_count = all_seqs.len();
all_seqs.dedup();
assert_eq!(
all_seqs.len(),
unique_count,
"all commit sequences must be unique"
);
assert_eq!(all_seqs.len(), 800);
assert!(all_seqs.iter().all(|&s| s >= 1000 && s < 1800));
assert_eq!(combiner.next_seq(), 1800);
}
#[test]
fn test_combiner_16t_throughput() {
let combiner = Arc::new(CommitSequenceCombiner::new(0));
let barrier = Arc::new(Barrier::new(16));
let mut handles = Vec::new();
let start = Instant::now();
for _ in 0..16 {
let c = Arc::clone(&combiner);
let b = Arc::clone(&barrier);
handles.push(thread::spawn(move || {
let h = c.register().unwrap();
b.wait();
for _ in 0..1000 {
h.alloc_commit_seq();
}
drop(h);
}));
}
for h in handles {
h.join().unwrap();
}
let elapsed = start.elapsed();
assert_eq!(combiner.next_seq(), 16000);
assert!(
elapsed.as_millis() < 1000,
"16000 commits took {}ms, expected < 1000ms",
elapsed.as_millis()
);
}
#[test]
fn test_combiner_cache_line_padding() {
assert_eq!(
std::mem::align_of::<CommitSlot>(),
64,
"CommitSlot must be 64-byte aligned"
);
assert_eq!(
std::mem::size_of::<CommitSlot>(),
64,
"CommitSlot must be exactly 64 bytes"
);
}
#[test]
fn test_combiner_batch_size_varies() {
let combiner = Arc::new(CommitSequenceCombiner::new(0));
{
let h = combiner.register().unwrap();
h.alloc_commit_seq();
drop(h);
}
assert_eq!(combiner.next_seq(), 1);
{
let barrier = Arc::new(Barrier::new(4));
let mut handles = Vec::new();
for _ in 0..4 {
let c = Arc::clone(&combiner);
let b = Arc::clone(&barrier);
handles.push(thread::spawn(move || {
let h = c.register().unwrap();
b.wait();
h.alloc_commit_seq();
drop(h);
}));
}
for h in handles {
h.join().unwrap();
}
}
assert_eq!(combiner.next_seq(), 5);
}
#[test]
fn test_combiner_fairness() {
let combiner = Arc::new(CommitSequenceCombiner::new(0));
let barrier = Arc::new(Barrier::new(8));
let mut handles = Vec::new();
for tid in 0..8u64 {
let c = Arc::clone(&combiner);
let b = Arc::clone(&barrier);
handles.push(thread::spawn(move || {
let h = c.register().unwrap();
b.wait();
let start = Instant::now();
let mut max_wait_ns = 0u64;
for _ in 0..50 {
let op_start = Instant::now();
h.alloc_commit_seq();
#[allow(clippy::cast_possible_truncation)]
let wait = op_start.elapsed().as_nanos() as u64;
max_wait_ns = max_wait_ns.max(wait);
}
let total = start.elapsed();
drop(h);
(tid, max_wait_ns, total)
}));
}
for h in handles {
let (tid, max_wait_ns, total) = h.join().unwrap();
assert!(
max_wait_ns < 10_000_000,
"thread {tid} max wait {max_wait_ns}ns > 10ms"
);
assert!(
total.as_millis() < 500,
"thread {tid} total time {}ms > 500ms",
total.as_millis()
);
}
}
}