use loom::sync::Arc;
use loom::sync::atomic::{AtomicUsize, Ordering};
use loom::thread;
/// Shared accounting state for a buffer, modeled with loom atomics so the
/// tests below can exhaustively explore thread interleavings.
#[derive(Debug)]
struct LoomSharedBufferState {
    // Bytes/ops that have been committed to the buffer.
    committed_bytes: AtomicUsize,
    committed_ops: AtomicUsize,
    // Bytes/ops reserved (in flight) but not yet committed or aborted.
    reserved_bytes: AtomicUsize,
    reserved_ops: AtomicUsize,
    // Count of live guards; `try_reset_to_zero` refuses to reset while this
    // is non-zero.
    active_guards: AtomicUsize,
}
impl LoomSharedBufferState {
    /// Creates a state with all counters at zero.
    fn new() -> Self {
        Self {
            committed_bytes: AtomicUsize::new(0),
            committed_ops: AtomicUsize::new(0),
            reserved_bytes: AtomicUsize::new(0),
            reserved_ops: AtomicUsize::new(0),
            active_guards: AtomicUsize::new(0),
        }
    }

    /// Current committed byte count (Acquire load).
    fn committed_bytes(&self) -> usize {
        self.committed_bytes.load(Ordering::Acquire)
    }

    /// Current reserved byte count (Acquire load).
    fn reserved_bytes(&self) -> usize {
        self.reserved_bytes.load(Ordering::Acquire)
    }

    /// Number of currently active guards (Acquire load).
    fn active_guards(&self) -> usize {
        self.active_guards.load(Ordering::Acquire)
    }

    /// Committed + reserved bytes. Note this is two separate loads, so the
    /// sum is not a single atomic snapshot of the state.
    #[allow(dead_code)]
    fn total_bytes(&self) -> usize {
        self.committed_bytes() + self.reserved_bytes()
    }

    /// Registers one guard.
    fn increment_active_guards(&self) {
        self.active_guards.fetch_add(1, Ordering::AcqRel);
    }

    /// Releases one guard. Panics if no guard was active — the underflow is
    /// detected from the previous value returned by `fetch_sub`.
    fn decrement_active_guards(&self) {
        let prev = self.active_guards.fetch_sub(1, Ordering::AcqRel);
        assert!(prev > 0, "active_guards underflow: decrement from 0");
    }

    /// Adds to the reserved counters unconditionally (no limit check; see
    /// the free function `try_reserve_with_limit` for the bounded variant).
    fn add_reserved(&self, bytes: usize, ops: usize) {
        self.reserved_bytes.fetch_add(bytes, Ordering::AcqRel);
        self.reserved_ops.fetch_add(ops, Ordering::AcqRel);
    }

    /// Removes `bytes`/`ops` from the reserved counters (e.g. an aborted
    /// reservation). Underflow is asserted after the fact, using the
    /// previous values returned by the `fetch_sub`s.
    fn sub_reserved(&self, bytes: usize, ops: usize) {
        let prev_bytes = self.reserved_bytes.fetch_sub(bytes, Ordering::AcqRel);
        let prev_ops = self.reserved_ops.fetch_sub(ops, Ordering::AcqRel);
        assert!(
            prev_bytes >= bytes,
            "reserved_bytes underflow: {} < {}",
            prev_bytes,
            bytes
        );
        assert!(
            prev_ops >= ops,
            "reserved_ops underflow: {} < {}",
            prev_ops,
            ops
        );
    }

    /// Moves a reservation into the committed counters: releases the full
    /// reserved amount and commits the (possibly smaller) actual amount.
    ///
    /// The reserved decrement and committed increment are separate atomic
    /// ops, so a concurrent reader may briefly observe the bytes in neither
    /// counter (total temporarily undercounts).
    fn transfer_reserved_to_committed(
        &self,
        reserved_bytes: usize,
        reserved_ops: usize,
        actual_bytes: usize,
        actual_ops: usize,
    ) {
        // A commit may shrink but never grow relative to its reservation.
        assert!(
            actual_bytes <= reserved_bytes,
            "actual_bytes {} exceeds reserved_bytes {}",
            actual_bytes,
            reserved_bytes
        );
        assert!(
            actual_ops <= reserved_ops,
            "actual_ops {} exceeds reserved_ops {}",
            actual_ops,
            reserved_ops
        );
        let prev_bytes = self
            .reserved_bytes
            .fetch_sub(reserved_bytes, Ordering::AcqRel);
        let prev_ops = self.reserved_ops.fetch_sub(reserved_ops, Ordering::AcqRel);
        assert!(
            prev_bytes >= reserved_bytes,
            "reserved_bytes underflow in transfer"
        );
        assert!(
            prev_ops >= reserved_ops,
            "reserved_ops underflow in transfer"
        );
        self.committed_bytes
            .fetch_add(actual_bytes, Ordering::AcqRel);
        self.committed_ops.fetch_add(actual_ops, Ordering::AcqRel);
    }

    /// Removes `bytes`/`ops` from the committed counters (e.g. after
    /// compaction). Panics on underflow, detected from the previous values.
    fn sub_committed(&self, bytes: usize, ops: usize) {
        let prev_bytes = self.committed_bytes.fetch_sub(bytes, Ordering::AcqRel);
        let prev_ops = self.committed_ops.fetch_sub(ops, Ordering::AcqRel);
        assert!(
            prev_bytes >= bytes,
            "committed_bytes underflow: {} < {}",
            prev_bytes,
            bytes
        );
        assert!(
            prev_ops >= ops,
            "committed_ops underflow: {} < {}",
            prev_ops,
            ops
        );
    }

    /// Zeroes every byte/op counter iff no guard is active at the moment of
    /// the check; returns whether the reset was performed.
    ///
    /// NOTE(review): this is check-then-act, not atomic — a guard taken
    /// between the `active_guards` load and the stores would not stop the
    /// reset. The tests below only race it against guard *releases*, so the
    /// window is not exercised; confirm before relying on it elsewhere.
    fn try_reset_to_zero(&self) -> bool {
        let active = self.active_guards.load(Ordering::Acquire);
        if active > 0 {
            return false;
        }
        self.committed_bytes.store(0, Ordering::Release);
        self.committed_ops.store(0, Ordering::Release);
        self.reserved_bytes.store(0, Ordering::Release);
        self.reserved_ops.store(0, Ordering::Release);
        true
    }
}
/// Attempts to reserve `bytes`/`ops` without letting reserved + committed
/// bytes exceed `max_bytes`. Returns `true` if the reservation was taken,
/// `false` if it would breach the limit.
///
/// Concurrent reservations are serialized by the CAS loop on
/// `reserved_bytes`: a successful CAS proves the limit check was evaluated
/// against the reserved value that was actually replaced.
///
/// NOTE(review): `committed_bytes` is a separate snapshot taken before the
/// CAS — a commit racing in between (which bumps committed without the CAS
/// observing it) could let the total briefly exceed `max_bytes`. The tests
/// in this file only race reservations against each other, so that window
/// is never exercised; confirm acceptability before production use.
fn try_reserve_with_limit(
    state: &LoomSharedBufferState,
    max_bytes: usize,
    bytes: usize,
    ops: usize,
) -> bool {
    loop {
        let current_reserved = state.reserved_bytes.load(Ordering::Acquire);
        let current_committed = state.committed_bytes.load(Ordering::Acquire);
        let total = current_reserved + current_committed;
        if total + bytes > max_bytes {
            return false;
        }
        // `compare_exchange_weak` may fail spuriously; the loop retries
        // (re-running the limit check) until it succeeds or the limit trips.
        if state
            .reserved_bytes
            .compare_exchange_weak(
                current_reserved,
                current_reserved + bytes,
                Ordering::AcqRel,
                Ordering::Relaxed,
            )
            .is_ok()
        {
            state.reserved_ops.fetch_add(ops, Ordering::AcqRel);
            return true;
        }
    }
}
/// Spawns a loom thread that makes one limited reservation attempt of
/// `bytes`/`ops` against `max_bytes`, bumping `success_count` iff the
/// reservation lands.
fn spawn_reservation_thread(
    state: Arc<LoomSharedBufferState>,
    max_bytes: usize,
    bytes: usize,
    ops: usize,
    success_count: Arc<AtomicUsize>,
) -> loom::thread::JoinHandle<()> {
    thread::spawn(move || {
        let reserved = try_reserve_with_limit(&state, max_bytes, bytes, ops);
        if reserved {
            success_count.fetch_add(1, Ordering::Relaxed);
        }
    })
}
/// Joins every handle, propagating any thread panic via `unwrap`.
///
/// Generalized from the original hard-coded `[JoinHandle; 2]` array to any
/// iterable of handles (arrays of any length, `Vec`, iterators). Existing
/// `join_threads([t1, t2])` call sites compile unchanged, since fixed-size
/// arrays implement `IntoIterator` by value.
fn join_threads(threads: impl IntoIterator<Item = loom::thread::JoinHandle<()>>) {
    for handle in threads {
        handle.join().unwrap();
    }
}
/// CP16: two guarded reserve→commit sequences racing must end with zero
/// active guards, zero reserved bytes, and committed bytes equal to the sum
/// of both actual commits (80 + 150 = 230).
#[test]
fn test_cp16_reserve_commit_race() {
    loom::model(|| {
        let shared = Arc::new(LoomSharedBufferState::new());
        let for_t1 = Arc::clone(&shared);
        let for_t2 = Arc::clone(&shared);

        let t1 = thread::spawn(move || {
            for_t1.increment_active_guards();
            for_t1.add_reserved(100, 1);
            // Commit fewer bytes (80) than were reserved (100).
            for_t1.transfer_reserved_to_committed(100, 1, 80, 1);
            for_t1.decrement_active_guards();
        });
        let t2 = thread::spawn(move || {
            for_t2.increment_active_guards();
            for_t2.add_reserved(200, 2);
            // Commit fewer bytes (150) than were reserved (200).
            for_t2.transfer_reserved_to_committed(200, 2, 150, 2);
            for_t2.decrement_active_guards();
        });

        join_threads([t1, t2]);

        assert_eq!(shared.active_guards(), 0);
        assert_eq!(shared.reserved_bytes(), 0);
        assert_eq!(shared.committed_bytes(), 230);
    });
}
/// CP17: two 100-byte reservation attempts race against a 150-byte limit.
///
/// The CAS loop in `try_reserve_with_limit` serializes reservations, so in
/// every interleaving exactly one attempt succeeds: the first successful CAS
/// leaves 100 reserved, after which the other thread's limit check
/// (100 + 100 > 150) must fail. The original assertions (`total <= max`,
/// `successes <= 1`) were weaker than this invariant — they would pass even
/// if *both* reservations failed — so they are strengthened to exact checks.
#[test]
fn test_cp17_concurrent_reservation_limits() {
    loom::model(|| {
        let state = Arc::new(LoomSharedBufferState::new());
        let max_bytes: usize = 150;
        let success_count = Arc::new(AtomicUsize::new(0));
        let t1 = spawn_reservation_thread(
            Arc::clone(&state),
            max_bytes,
            100,
            1,
            Arc::clone(&success_count),
        );
        let t2 = spawn_reservation_thread(
            Arc::clone(&state),
            max_bytes,
            100,
            1,
            Arc::clone(&success_count),
        );
        join_threads([t1, t2]);
        let total = state.reserved_bytes();
        let successes = success_count.load(Ordering::Relaxed);
        assert_eq!(
            successes, 1,
            "Exactly one reservation should succeed, got {}",
            successes
        );
        assert_eq!(
            total, 100,
            "Reserved total should equal the single winning reservation, got {} (max {})",
            total, max_bytes
        );
    });
}
/// CP18: two guards acquired up front, released from two racing threads —
/// the underflow assert inside `decrement_active_guards` must never fire and
/// the final guard count must be zero.
#[test]
fn test_cp18_guard_underflow_protection() {
    loom::model(|| {
        let shared = Arc::new(LoomSharedBufferState::new());
        let for_t1 = Arc::clone(&shared);
        let for_t2 = Arc::clone(&shared);

        // Take both guards before any thread starts releasing.
        shared.increment_active_guards();
        shared.increment_active_guards();

        let t1 = thread::spawn(move || for_t1.decrement_active_guards());
        let t2 = thread::spawn(move || for_t2.decrement_active_guards());
        join_threads([t1, t2]);

        assert_eq!(shared.active_guards(), 0);
    });
}
/// CP19: a reset attempt races against the release of the only guard.
///
/// t1's reset may or may not succeed depending on interleaving; t2's reset
/// runs after t2 itself releases the sole guard, so it always succeeds.
///
/// The original test collected `t1_reset_succeeded` but never asserted
/// anything about it, and never checked the final counters — added both:
/// the reserved accounting must end at zero in every interleaving, and t1's
/// racing reset can fire at most once.
#[test]
fn test_cp19_reset_with_active_guards() {
    loom::model(|| {
        let state = Arc::new(LoomSharedBufferState::new());
        state.add_reserved(100, 1);
        state.increment_active_guards();
        let state1 = Arc::clone(&state);
        let state2 = Arc::clone(&state);
        let t1_reset_succeeded = Arc::new(AtomicUsize::new(0));
        let reset1 = Arc::clone(&t1_reset_succeeded);
        let t1 = thread::spawn(move || {
            if state1.try_reset_to_zero() {
                reset1.fetch_add(1, Ordering::Relaxed);
            }
        });
        let t2 = thread::spawn(move || {
            state2.decrement_active_guards();
            let _ = state2.try_reset_to_zero();
        });
        join_threads([t1, t2]);
        assert_eq!(state.active_guards(), 0, "Guard should be released");
        // t2's reset follows its own guard release and no thread takes a new
        // guard, so at least that reset succeeds — counters must be zeroed.
        assert_eq!(state.reserved_bytes(), 0, "Reserved bytes should be reset");
        assert!(
            t1_reset_succeeded.load(Ordering::Relaxed) <= 1,
            "t1 attempts at most one reset"
        );
    });
}
/// Two threads concurrently abort halves of a 200-byte reservation; the
/// underflow asserts in `sub_reserved` must hold and the reserved count must
/// land exactly at zero.
#[test]
fn test_concurrent_reserve_abort() {
    loom::model(|| {
        let shared = Arc::new(LoomSharedBufferState::new());
        shared.add_reserved(200, 2);

        let for_t1 = Arc::clone(&shared);
        let for_t2 = Arc::clone(&shared);
        let t1 = thread::spawn(move || for_t1.sub_reserved(100, 1));
        let t2 = thread::spawn(move || for_t2.sub_reserved(100, 1));
        join_threads([t1, t2]);

        assert_eq!(shared.reserved_bytes(), 0);
    });
}
/// Two threads concurrently compact halves of a 200-byte committed region;
/// the underflow asserts in `sub_committed` must hold and the committed
/// count must land exactly at zero.
#[test]
fn test_concurrent_commit_compaction() {
    loom::model(|| {
        let shared = Arc::new(LoomSharedBufferState::new());
        // Reserve and fully commit 200 bytes / 2 ops before racing.
        shared.add_reserved(200, 2);
        shared.transfer_reserved_to_committed(200, 2, 200, 2);

        let for_t1 = Arc::clone(&shared);
        let for_t2 = Arc::clone(&shared);
        let t1 = thread::spawn(move || for_t1.sub_committed(100, 1));
        let t2 = thread::spawn(move || for_t2.sub_committed(100, 1));
        join_threads([t1, t2]);

        assert_eq!(shared.committed_bytes(), 0);
    });
}