#![cfg(loom)]
use loom::sync::atomic::{fence, AtomicU64, Ordering};
use loom::sync::Arc;
use loom::thread;
/// Number of stamp slots in the modelled ring.
const CAPACITY: u64 = 4;

/// Bit mask that turns a sequence number into a slot index via `seq & MASK`.
const MASK: u64 = CAPACITY - 1;

// `seq & MASK` is only equivalent to `seq % CAPACITY` when `CAPACITY` is a
// power of two. Fail the build instead of silently aliasing slots if someone
// changes the capacity to a value that breaks that assumption.
const _: () = assert!(CAPACITY.is_power_of_two());
/// Atomics-only model of a ring with several publishers, used by the loom
/// tests in this file. It carries no payload: only the bookkeeping words.
struct RingModel {
    /// Next sequence number to hand out; `publish` claims one with `fetch_add`.
    next_seq: AtomicU64,
    /// Sequence the cursor has been advanced to. Starts at the `u64::MAX`
    /// sentinel, which the tests treat as "never advanced".
    cursor: AtomicU64,
    /// One stamp per slot, indexed by `seq & MASK`. A writer stores
    /// `seq * 2 + 1` (odd) first and `seq * 2 + 2` (even) when it is done.
    stamps: [AtomicU64; CAPACITY as usize],
}
impl RingModel {
fn new() -> Self {
RingModel {
next_seq: AtomicU64::new(0),
cursor: AtomicU64::new(u64::MAX), stamps: core::array::from_fn(|_| AtomicU64::new(0)),
}
}
fn slot_write(&self, seq: u64) {
let idx = (seq & MASK) as usize;
let writing = seq * 2 + 1;
let done = seq * 2 + 2;
self.stamps[idx].store(writing, Ordering::Relaxed);
fence(Ordering::Release);
self.stamps[idx].store(done, Ordering::Release);
}
fn stamp_load(&self, seq: u64) -> u64 {
let idx = (seq & MASK) as usize;
self.stamps[idx].load(Ordering::Acquire)
}
fn advance_cursor(&self, seq: u64) {
let expected_cursor = if seq == 0 { u64::MAX } else { seq - 1 };
if self
.cursor
.compare_exchange(expected_cursor, seq, Ordering::Release, Ordering::Relaxed)
.is_ok()
{
self.catch_up_cursor(seq);
return;
}
if seq > 0 {
let pred_done = (seq - 1) * 2 + 2;
while self.stamp_load(seq - 1) < pred_done {
loom::thread::yield_now(); }
}
let _ = self.cursor.compare_exchange(
expected_cursor,
seq,
Ordering::Release,
Ordering::Relaxed,
);
if self.cursor.load(Ordering::Relaxed) == seq {
self.catch_up_cursor(seq);
}
}
fn catch_up_cursor(&self, mut seq: u64) {
loop {
let next = seq + 1;
if next >= self.next_seq.load(Ordering::Acquire) {
break;
}
let done_stamp = next * 2 + 2;
if self.stamp_load(next) < done_stamp {
break;
}
if self
.cursor
.compare_exchange(seq, next, Ordering::Release, Ordering::Relaxed)
.is_err()
{
break;
}
seq = next;
}
}
fn publish(&self) -> u64 {
let seq = self.next_seq.fetch_add(1, Ordering::AcqRel);
self.slot_write(seq);
self.advance_cursor(seq);
seq
}
}
/// Two concurrent publishers: after both have been joined the cursor must
/// have left the sentinel, and every sequence up to and including the cursor
/// must carry at least its "done" stamp.
#[test]
fn two_producers_no_gap_invariant() {
    loom::model(|| {
        let ring = Arc::new(RingModel::new());
        let (r1, r2) = (ring.clone(), ring.clone());

        let producers = [
            thread::spawn(move || r1.publish()),
            thread::spawn(move || r2.publish()),
        ];
        for producer in producers {
            producer.join().unwrap();
        }

        let final_cursor = ring.cursor.load(Ordering::Acquire);
        assert_ne!(
            final_cursor,
            u64::MAX,
            "cursor was never advanced from sentinel"
        );

        (0..=final_cursor).for_each(|seq| {
            let done = seq * 2 + 2;
            let stamp = ring.stamp_load(seq);
            assert!(
                stamp >= done,
                "NoGap violated: cursor={final_cursor} but seq {seq} stamp={stamp} (need >= {done})"
            );
        });
    });
}
/// After two publishers have been joined, exactly two sequences have been
/// claimed and each of their stamps is an even value at or above the
/// sequence's "done" stamp.
#[test]
fn stamps_all_committed_after_completion() {
    loom::model(|| {
        let ring = Arc::new(RingModel::new());
        let (r1, r2) = (ring.clone(), ring.clone());

        let first = thread::spawn(move || r1.publish());
        let second = thread::spawn(move || r2.publish());
        first.join().unwrap();
        second.join().unwrap();

        let claimed = ring.next_seq.load(Ordering::Acquire);
        assert_eq!(claimed, 2, "should have claimed 2 sequences");

        for seq in 0..claimed {
            let (stamp, done) = (ring.stamp_load(seq), seq * 2 + 2);
            assert!(
                stamp >= done,
                "seq {seq} stamp should be >= {done} (committed), got {stamp}"
            );
            // An odd stamp is the "writing" marker, i.e. an unfinished write.
            let is_even = (stamp & 1) == 0;
            assert!(
                is_even,
                "seq {seq} stamp is odd ({stamp}), indicating incomplete write"
            );
        }
    });
}
/// Once both publishers have finished, every sequence up to the cursor must
/// be readable, i.e. stamped "done". A cursor still at the sentinel is
/// tolerated here and simply skips the check.
#[test]
fn consumer_safety_at_quiescence() {
    loom::model(|| {
        let ring = Arc::new(RingModel::new());
        let (r1, r2) = (ring.clone(), ring.clone());

        let workers = [
            thread::spawn(move || r1.publish()),
            thread::spawn(move || r2.publish()),
        ];
        for worker in workers {
            worker.join().unwrap();
        }

        let cursor_val = ring.cursor.load(Ordering::Acquire);
        if cursor_val == u64::MAX {
            // Nothing published from the cursor's point of view.
            return;
        }

        for seq in 0..=cursor_val {
            let required = seq * 2 + 2;
            let observed = ring.stamp_load(seq);
            assert!(
                observed >= required,
                "cursor={cursor_val} but seq {seq} stamp={stamp} (expected >= {done_stamp})",
                stamp = observed,
                done_stamp = required
            );
        }
    });
}
/// Progress check: after two completed publishes the cursor must no longer
/// sit at the `u64::MAX` sentinel.
#[test]
fn minimum_cursor_progress() {
    loom::model(|| {
        let ring = Arc::new(RingModel::new());
        let (r1, r2) = (ring.clone(), ring.clone());

        let handles = [
            thread::spawn(move || r1.publish()),
            thread::spawn(move || r2.publish()),
        ];
        for handle in handles {
            handle.join().unwrap();
        }

        let final_cursor = ring.cursor.load(Ordering::Acquire);
        assert_ne!(
            final_cursor,
            u64::MAX,
            "cursor stuck at sentinel after 2 publishes"
        );
    });
}
/// Two concurrent publishers must be handed the sequences 0 and 1, one each.
#[test]
fn sequence_numbers_are_unique() {
    loom::model(|| {
        let ring = Arc::new(RingModel::new());
        let (r1, r2) = (ring.clone(), ring.clone());

        let first = thread::spawn(move || r1.publish());
        let second = thread::spawn(move || r2.publish());
        let s1 = first.join().unwrap();
        let s2 = second.join().unwrap();

        assert_ne!(s1, s2, "sequences must be distinct");
        let both_in_range = [s1, s2].iter().all(|&s| s <= 1);
        assert!(both_in_range, "sequences must be 0 or 1");
        assert_eq!(s1 + s2, 1, "sequences must be {{0, 1}}, got {s1} and {s2}");
    });
}
/// Like the plain two-producer test, but the first producer yields between
/// claiming its sequence and stamping the slot. The cursor must still leave
/// the sentinel, and no sequence at or below it may be unstamped.
#[test]
fn catch_up_with_delayed_writer() {
    loom::model(|| {
        let ring = Arc::new(RingModel::new());
        let (r1, r2) = (ring.clone(), ring.clone());

        // Producer that claims first and writes late.
        let delayed = thread::spawn(move || {
            let claimed = r1.next_seq.fetch_add(1, Ordering::AcqRel);
            loom::thread::yield_now();
            r1.slot_write(claimed);
            r1.advance_cursor(claimed);
            claimed
        });
        // Producer that runs the publish steps back to back.
        let prompt = thread::spawn(move || {
            let claimed = r2.next_seq.fetch_add(1, Ordering::AcqRel);
            r2.slot_write(claimed);
            r2.advance_cursor(claimed);
            claimed
        });

        let s1 = delayed.join().unwrap();
        let s2 = prompt.join().unwrap();
        assert_ne!(s1, s2);

        let final_cursor = ring.cursor.load(Ordering::Acquire);
        assert_ne!(final_cursor, u64::MAX, "cursor stuck at sentinel");

        (0..=final_cursor).for_each(|seq| {
            let stamp = ring.stamp_load(seq);
            let done = seq * 2 + 2;
            assert!(
                stamp >= done,
                "NoGap violated: cursor={final_cursor}, seq {seq} stamp={stamp} (need >= {done})"
            );
        });
    });
}
/// A consumer thread samples the cursor once while two publishers are still
/// running. If the cursor has left the sentinel, the slot it points at must
/// already be stamped "done".
///
/// Marked `#[ignore]`, so it only runs when ignored tests are requested.
#[test]
#[ignore]
fn consumer_single_check_during_publish() {
    loom::model(|| {
        let ring = Arc::new(RingModel::new());
        let (r1, r2, rc) = (ring.clone(), ring.clone(), ring.clone());

        let producer_a = thread::spawn(move || r1.publish());
        let producer_b = thread::spawn(move || r2.publish());
        let consumer = thread::spawn(move || {
            let cursor_val = rc.cursor.load(Ordering::Acquire);
            if cursor_val == u64::MAX {
                // Nothing visible yet; nothing to verify.
                return;
            }
            let done = cursor_val * 2 + 2;
            let stamp = rc.stamp_load(cursor_val);
            assert!(
                stamp >= done,
                "cursor={cursor_val} but its stamp={stamp} (expected >= {done})"
            );
        });

        producer_a.join().unwrap();
        producer_b.join().unwrap();
        consumer.join().unwrap();
    });
}