use loom::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering};
use loom::sync::{Arc, RwLock};
use loom::thread;
/// Minimal model of a graph with a single-writer claim, used to drive the loom tests.
///
/// A writer claims exclusivity with `try_begin_write`, mutates `node_count` under its
/// lock, and publishes the change via `end_write`, which bumps `epoch` before dropping
/// the claim. `write_add_node` does not check the claim itself; callers are expected
/// to hold it.
struct LoomConcurrentGraph {
    /// Publication counter; incremented once per `end_write`.
    epoch: AtomicU64,
    /// Number of nodes added so far.
    node_count: RwLock<usize>,
    /// `true` while some caller holds the write claim.
    writing: AtomicBool,
}

impl LoomConcurrentGraph {
    /// Creates an empty graph: epoch 0, zero nodes, no active writer.
    fn new() -> Self {
        let epoch = AtomicU64::new(0);
        let node_count = RwLock::new(0);
        let writing = AtomicBool::new(false);
        Self {
            epoch,
            node_count,
            writing,
        }
    }

    /// Returns the current epoch (Acquire; pairs with the AcqRel bump in `end_write`).
    fn snapshot_epoch(&self) -> u64 {
        self.epoch.load(Ordering::Acquire)
    }

    /// Returns the node count as seen under a shared read lock.
    fn read_node_count(&self) -> usize {
        let guard = self.node_count.read().unwrap();
        *guard
    }

    /// Tries to claim the write flag; `true` only for the caller that flipped it
    /// from `false` to `true`.
    fn try_begin_write(&self) -> bool {
        matches!(
            self.writing
                .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire),
            Ok(_)
        )
    }

    /// Adds one node under the exclusive lock.
    fn write_add_node(&self) {
        *self.node_count.write().unwrap() += 1;
    }

    /// Publishes the write by bumping the epoch, then releases the write claim.
    fn end_write(&self) {
        // Bump before the Release store so the next writer to win the claim
        // (Acquire CAS) also observes the new epoch.
        self.epoch.fetch_add(1, Ordering::AcqRel);
        self.writing.store(false, Ordering::Release);
    }

    /// Reports whether the write claim is currently held.
    fn is_writing(&self) -> bool {
        self.writing.load(Ordering::Acquire)
    }
}
/// Lock-free counter model of a closable write queue, used to drive the loom tests.
///
/// Only counts are tracked, not payloads. The closed flag lives in the top bit of the
/// same word as the enqueue count. As a result, `close` and `enqueue` are ordered by one
/// atomic's modification order: an enqueue either lands before the close (and is
/// counted) or after it (and is rejected without ever touching the count).
struct LoomWriteQueue {
    /// `CLOSED_BIT | number of successful enqueues`.
    enqueue_state: AtomicU64,
    /// Number of successful dequeues; never exceeds the enqueue count.
    dequeued: AtomicU64,
}

impl LoomWriteQueue {
    /// Top bit of `enqueue_state`; set by `close` and never cleared.
    const CLOSED_BIT: u64 = 1 << 63;
    /// Mask selecting the enqueue count out of `enqueue_state`.
    const COUNT_MASK: u64 = Self::CLOSED_BIT - 1;

    /// Creates an open, empty queue.
    fn new() -> Self {
        Self {
            enqueue_state: AtomicU64::new(0),
            dequeued: AtomicU64::new(0),
        }
    }

    /// Records one enqueue. Returns `false`, and counts nothing, if the queue is closed.
    ///
    /// The closed check and the increment happen in a single CAS, so there is never a
    /// provisional increment that has to be rolled back. Such a rollback could race with
    /// `dequeue` and let it consume an item whose enqueue reported failure, leaving
    /// `dequeued > enqueued`.
    fn enqueue(&self) -> bool {
        let mut current = self.enqueue_state.load(Ordering::Acquire);
        loop {
            if current & Self::CLOSED_BIT != 0 {
                return false;
            }
            // Refuse rather than let the count carry into the closed bit.
            if current == Self::COUNT_MASK {
                return false;
            }
            match self.enqueue_state.compare_exchange_weak(
                current,
                current + 1,
                Ordering::AcqRel,
                Ordering::Acquire,
            ) {
                Ok(_) => return true,
                Err(actual) => current = actual,
            }
        }
    }

    /// Claims one enqueued item; returns `false` if none is available.
    fn dequeue(&self) -> bool {
        loop {
            // Load `dequeued` first. Any value seen there was reached by a thread that had
            // already observed at least that many enqueues, so the enqueue count loaded
            // afterwards is never smaller; "empty" is not caused by a stale count.
            let deq = self.dequeued.load(Ordering::Acquire);
            let enq = self.enqueued();
            if deq >= enq {
                return false;
            }
            if self
                .dequeued
                .compare_exchange_weak(deq, deq + 1, Ordering::AcqRel, Ordering::Relaxed)
                .is_ok()
            {
                return true;
            }
        }
    }

    /// Closes the queue: once the closed bit is set, `enqueue` returns `false`.
    /// Calling it again has no further effect.
    fn close(&self) {
        self.enqueue_state
            .fetch_or(Self::CLOSED_BIT, Ordering::Release);
    }

    /// Number of successful enqueues (the closed bit is masked off).
    fn enqueued(&self) -> u64 {
        self.enqueue_state.load(Ordering::Acquire) & Self::COUNT_MASK
    }

    /// Number of successful dequeues.
    fn dequeued(&self) -> u64 {
        self.dequeued.load(Ordering::Acquire)
    }

    /// Items enqueued but not yet dequeued.
    fn in_flight(&self) -> u64 {
        // Same read order as `dequeue`, so the difference should not go negative.
        // Saturating as a defensive floor.
        let deq = self.dequeued();
        let enq = self.enqueued();
        enq.saturating_sub(deq)
    }
}
/// CP5: enqueues issued from two producer threads (two from one, one from the
/// other) must all be reflected in the final enqueue count.
#[test]
fn test_cp5_single_writer_serialization() {
    loom::model(|| {
        let queue = Arc::new(LoomWriteQueue::new());
        let (producer_a, producer_b) = (Arc::clone(&queue), Arc::clone(&queue));

        let handle_a = thread::spawn(move || {
            for _ in 0..2 {
                producer_a.enqueue();
            }
        });
        let handle_b = thread::spawn(move || {
            producer_b.enqueue();
        });

        for handle in [handle_a, handle_b] {
            handle.join().unwrap();
        }
        assert_eq!(queue.enqueued(), 3, "All enqueues should be counted");
    });
}
/// CP5: two threads drawing tickets from a shared `fetch_add` counter must get
/// distinct tickets, both drawn from {0, 1}.
#[test]
fn test_cp5_operation_ordering() {
    loom::model(|| {
        let counter = Arc::new(AtomicU64::new(0));
        let (counter_a, counter_b) = (Arc::clone(&counter), Arc::clone(&counter));
        // u64::MAX is a sentinel for "no ticket recorded".
        let ticket_a = Arc::new(AtomicU64::new(u64::MAX));
        let ticket_b = Arc::new(AtomicU64::new(u64::MAX));
        let (slot_a, slot_b) = (Arc::clone(&ticket_a), Arc::clone(&ticket_b));

        let handle_a = thread::spawn(move || {
            let ticket = counter_a.fetch_add(1, Ordering::AcqRel);
            slot_a.store(ticket, Ordering::Relaxed);
        });
        let handle_b = thread::spawn(move || {
            let ticket = counter_b.fetch_add(1, Ordering::AcqRel);
            slot_b.store(ticket, Ordering::Relaxed);
        });
        handle_a.join().unwrap();
        handle_b.join().unwrap();

        let n1 = ticket_a.load(Ordering::Relaxed);
        let n2 = ticket_b.load(Ordering::Relaxed);
        assert_ne!(n1, n2, "Sequence numbers should be unique");
        assert!(
            [n1, n2].iter().all(|&n| n < 2),
            "Sequences should be 0 or 1"
        );
    });
}
/// CP6: a reader that loads the epoch and then the node count while one writer
/// commits must never see the new epoch paired with the old count.
#[test]
fn test_cp6_epoch_consistency() {
    loom::model(|| {
        let graph = Arc::new(LoomConcurrentGraph::new());
        let (writer_graph, reader_graph) = (Arc::clone(&graph), Arc::clone(&graph));
        // MAX sentinels mean "the reader never recorded a value".
        let observed_epoch = Arc::new(AtomicU64::new(u64::MAX));
        let observed_count = Arc::new(AtomicUsize::new(usize::MAX));
        let (epoch_slot, count_slot) = (
            Arc::clone(&observed_epoch),
            Arc::clone(&observed_count),
        );

        let writer = thread::spawn(move || {
            if !writer_graph.try_begin_write() {
                return;
            }
            writer_graph.write_add_node();
            writer_graph.end_write();
        });
        let reader = thread::spawn(move || {
            epoch_slot.store(reader_graph.snapshot_epoch(), Ordering::Relaxed);
            count_slot.store(reader_graph.read_node_count(), Ordering::Relaxed);
        });
        writer.join().unwrap();
        reader.join().unwrap();

        let epoch = observed_epoch.load(Ordering::Relaxed);
        let count = observed_count.load(Ordering::Relaxed);
        assert!(
            epoch != 1 || count != 0,
            "Read-skew detected: epoch {} but count {}",
            epoch,
            count
        );
        assert!(epoch <= 1, "Epoch should be 0 or 1, got {}", epoch);
    });
}
/// CP6: two readers racing a single writer on a pre-populated graph must each
/// observe a consistent (epoch, node_count) pair.
///
/// Setup publishes epoch 1 with 2 nodes. The racing writer moves the graph to epoch 2
/// with 3 nodes. Readers load the epoch before the count, so epoch 2 with only 2 nodes
/// would be read skew.
#[test]
fn test_cp6_concurrent_readers() {
    /// Spawns a reader that snapshots the epoch and then the node count, returning
    /// both through its join handle.
    fn spawn_reader(graph: Arc<LoomConcurrentGraph>) -> thread::JoinHandle<(u64, usize)> {
        thread::spawn(move || {
            let epoch = graph.snapshot_epoch();
            let count = graph.read_node_count();
            (epoch, count)
        })
    }

    loom::model(|| {
        let graph = Arc::new(LoomConcurrentGraph::new());

        // Single-threaded setup: publish epoch 1 with 2 nodes.
        assert!(
            graph.try_begin_write(),
            "Setup write should acquire the writer flag"
        );
        graph.write_add_node();
        graph.write_add_node();
        graph.end_write();

        let writer_graph = Arc::clone(&graph);
        let writer = thread::spawn(move || {
            if writer_graph.try_begin_write() {
                writer_graph.write_add_node();
                writer_graph.end_write();
            }
        });
        let reader1 = spawn_reader(Arc::clone(&graph));
        let reader2 = spawn_reader(Arc::clone(&graph));

        writer.join().unwrap();
        let (r1_epoch, r1_count) = reader1.join().unwrap();
        let (r2_epoch, r2_count) = reader2.join().unwrap();

        for (reader, epoch, count) in [(1, r1_epoch, r1_count), (2, r2_epoch, r2_count)] {
            assert!(
                epoch == 1 || epoch == 2,
                "Reader {} epoch should be 1 or 2, got {}",
                reader,
                epoch
            );
            assert!(
                count == 2 || count == 3,
                "Reader {} should see valid count, got {}",
                reader,
                count
            );
            assert!(
                !(epoch == 2 && count == 2),
                "Reader {} read-skew: epoch {} but count {}",
                reader,
                epoch,
                count
            );
        }

        // Readers never touch the writer flag, so the writer is uncontended and
        // always commits.
        assert_eq!(graph.snapshot_epoch(), 2, "Writer should have published epoch 2");
        assert_eq!(graph.read_node_count(), 3, "Writer should have added a third node");
    });
}
/// Two threads race `try_begin_write`. The winner must be alone in the critical
/// section, at least one attempt must commit, and every commit must be published.
#[test]
fn test_exclusive_writer_access() {
    /// Attempts one guarded write. Records entry and exit in `active`, and sets
    /// `overlap` if another writer was already inside. Returns whether it wrote.
    fn attempt_write(
        graph: &LoomConcurrentGraph,
        active: &AtomicUsize,
        overlap: &AtomicBool,
    ) -> bool {
        if !graph.try_begin_write() {
            return false;
        }
        if active.fetch_add(1, Ordering::AcqRel) > 0 {
            overlap.store(true, Ordering::Release);
        }
        graph.write_add_node();
        active.fetch_sub(1, Ordering::AcqRel);
        graph.end_write();
        true
    }

    loom::model(|| {
        let graph = Arc::new(LoomConcurrentGraph::new());
        let active_writers = Arc::new(AtomicUsize::new(0));
        let overlap_detected = Arc::new(AtomicBool::new(false));

        let spawn_writer = || {
            let graph = Arc::clone(&graph);
            let active = Arc::clone(&active_writers);
            let overlap = Arc::clone(&overlap_detected);
            thread::spawn(move || attempt_write(&graph, &active, &overlap))
        };
        let t1 = spawn_writer();
        let t2 = spawn_writer();
        let wrote1 = t1.join().unwrap();
        let wrote2 = t2.join().unwrap();
        let writes = usize::from(wrote1) + usize::from(wrote2);

        assert!(
            !overlap_detected.load(Ordering::Acquire),
            "Mutual exclusion violated: concurrent writers detected"
        );
        // The first CAS on the writer flag always sees `false`, so at least one
        // attempt commits; with two threads, at most two can.
        assert!(
            (1..=2).contains(&writes),
            "Between 1 and 2 writes should succeed, got {}",
            writes
        );
        assert_eq!(
            graph.snapshot_epoch(),
            writes as u64,
            "Each successful write should bump the epoch exactly once"
        );
        assert_eq!(
            graph.read_node_count(),
            writes,
            "Each successful write should add exactly one node"
        );
        assert!(!graph.is_writing(), "Writing should be complete");
    });
}
/// A reader taking two successive epoch snapshots while a writer commits must
/// never see the epoch go backwards.
#[test]
fn test_epoch_monotonicity() {
    loom::model(|| {
        let graph = Arc::new(LoomConcurrentGraph::new());
        let graph1 = Arc::clone(&graph);
        let graph2 = Arc::clone(&graph);
        let t1 = thread::spawn(move || {
            if graph1.try_begin_write() {
                graph1.write_add_node();
                graph1.end_write();
            }
        });
        // Return the snapshots through the join handle. Joining already synchronizes,
        // so no shared container or extra lock needs to be modeled.
        let t2 = thread::spawn(move || {
            let first = graph2.snapshot_epoch();
            let second = graph2.snapshot_epoch();
            [first, second]
        });
        t1.join().unwrap();
        let epochs = t2.join().unwrap();
        assert!(
            epochs.windows(2).all(|pair| pair[1] >= pair[0]),
            "Epochs should be monotonically non-decreasing"
        );
    });
}
/// A producer thread enqueues twice; the main thread then dequeues three times.
/// Exactly two dequeues succeed and the counters balance out.
#[test]
fn test_queue_counter_tracking() {
    loom::model(|| {
        let queue = Arc::new(LoomWriteQueue::new());
        let producer = {
            let handle_queue = Arc::clone(&queue);
            thread::spawn(move || {
                (0..2).for_each(|_| {
                    handle_queue.enqueue();
                });
            })
        };
        producer.join().unwrap();

        // Drain: two items are available, the third attempt must find the queue empty.
        assert!(queue.dequeue(), "First dequeue should succeed");
        assert!(queue.dequeue(), "Second dequeue should succeed");
        assert!(!queue.dequeue(), "Third dequeue should fail (empty)");

        assert_eq!(queue.enqueued(), 2, "Enqueued count should be 2");
        assert_eq!(queue.dequeued(), 2, "Dequeued count should be 2");
        assert_eq!(queue.in_flight(), 0, "In-flight should be 0");
    });
}
/// Closing races a concurrent enqueue. Whatever the interleaving, the enqueue
/// count must agree with the enqueue results, enqueues after the close must fail,
/// and the closed queue must still drain exactly the items it accepted.
#[test]
fn test_queue_closure() {
    loom::model(|| {
        let queue = Arc::new(LoomWriteQueue::new());
        let q1 = Arc::clone(&queue);
        let q2 = Arc::clone(&queue);
        let t1 = thread::spawn(move || {
            // Only this thread closes, and it does so afterwards, so this cannot be rejected.
            assert!(q1.enqueue(), "Enqueue before close should succeed");
            q1.close();
        });
        let t2 = thread::spawn(move || q2.enqueue());
        t1.join().unwrap();
        let t2_result = t2.join().unwrap();

        assert!(!queue.enqueue(), "Enqueue after close should fail");
        let total = queue.enqueued();
        if t2_result {
            assert_eq!(total, 2, "If t2 succeeded, total should be 2");
        } else {
            assert_eq!(total, 1, "If t2 failed, total should be 1");
        }

        // Closing stops producers, not consumers: every accepted item stays dequeueable.
        for _ in 0..total {
            assert!(
                queue.dequeue(),
                "Accepted items should remain dequeueable after close"
            );
        }
        assert!(!queue.dequeue(), "Closed queue should be empty after draining");
        assert_eq!(queue.in_flight(), 0, "In-flight should be 0 after draining");
    });
}
/// Two consumers race to drain a queue pre-loaded with two items; both must succeed.
#[test]
fn test_concurrent_enqueue_dequeue() {
    loom::model(|| {
        let queue = Arc::new(LoomWriteQueue::new());
        for _ in 0..2 {
            queue.enqueue();
        }
        let (consumer_a, consumer_b) = (Arc::clone(&queue), Arc::clone(&queue));
        let successes = Arc::new(AtomicUsize::new(0));
        let (tally_a, tally_b) = (Arc::clone(&successes), Arc::clone(&successes));

        let handle_a = thread::spawn(move || {
            if consumer_a.dequeue() {
                tally_a.fetch_add(1, Ordering::Relaxed);
            }
        });
        let handle_b = thread::spawn(move || {
            if consumer_b.dequeue() {
                tally_b.fetch_add(1, Ordering::Relaxed);
            }
        });
        handle_a.join().unwrap();
        handle_b.join().unwrap();

        let total_dequeued = successes.load(Ordering::Relaxed);
        assert_eq!(total_dequeued, 2, "Both dequeues should succeed");
    });
}