#![allow(unused_unsafe)]
use kovan::{Atomic, RetiredNode, pin, retire};
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::thread;
use std::time::Duration;
/// Test node whose `Drop` impl raises a shared flag, letting tests observe
/// exactly when the reclamation scheme actually frees the node.
///
/// NOTE(review): `#[repr(C)]` with `retired` first presumably exists so the
/// kovan machinery can treat a `*mut EraTestNode` as a `*mut RetiredNode`
/// (header at offset 0) — TODO confirm against kovan's `retire` contract.
#[repr(C)]
struct EraTestNode {
    // Intrusive reclamation header; kept as the first field (see note above).
    retired: RetiredNode,
    // Payload the tests read to tell the old node from its replacement.
    value: u64,
    // Flipped to `true` by `Drop`; asserted on by the UAF test.
    freed: Arc<AtomicBool>,
}
impl EraTestNode {
    /// Heap-allocates a node carrying `value` and the shared `freed` flag,
    /// returning an owning raw pointer. The caller is responsible for
    /// eventually handing the pointer to the reclamation scheme.
    fn new(value: u64, freed: Arc<AtomicBool>) -> *mut Self {
        let node = Self {
            retired: RetiredNode::new(),
            value,
            freed,
        };
        Box::into_raw(Box::new(node))
    }
}
impl Drop for EraTestNode {
    // Record that this node was actually destroyed. The UAF test asserts the
    // flag is still `false` while a reader guard should be protecting the
    // node, so a premature free is observable as a test failure.
    fn drop(&mut self) {
        self.freed.store(true, Ordering::SeqCst);
    }
}
/// Minimal retirable node used purely to generate pin/retire traffic
/// (see `advance_epoch_by`); it carries no payload.
///
/// NOTE(review): `#[repr(C)]` with the `RetiredNode` header first mirrors
/// `EraTestNode` — presumably required by kovan's `retire`; confirm there.
#[repr(C)]
struct DummyNode {
    // Intrusive reclamation header; the struct's only field.
    retired: RetiredNode,
}
impl DummyNode {
    /// Allocates a throwaway node and returns an owning raw pointer,
    /// intended to be retired immediately by the caller.
    fn new() -> *mut Self {
        let boxed = Box::new(Self {
            retired: RetiredNode::new(),
        });
        Box::into_raw(boxed)
    }
}
/// Drives the global epoch forward by generating retire traffic from
/// several threads: each of 4 worker threads performs `n * 128 / 4`
/// pin + retire cycles on freshly allocated dummy nodes.
fn advance_epoch_by(n: usize) {
    const WORKERS: usize = 4;
    let cycles_per_worker = n * 128 / WORKERS;
    let handles: Vec<_> = (0..WORKERS)
        .map(|_| {
            thread::spawn(move || {
                for _ in 0..cycles_per_worker {
                    let _guard = pin();
                    // SAFETY: the node was just allocated and is not shared,
                    // so handing its only pointer to `retire` is sound.
                    unsafe { retire(DummyNode::new()) };
                }
            })
        })
        .collect();
    for handle in handles {
        handle.join().unwrap();
    }
}
/// A reader pins once, then keeps its guard alive across several epoch
/// advances and a pointer swap. Any node the reader could have observed —
/// including one born in a *later* epoch — must not be freed until the
/// reader unpins.
#[test]
#[cfg_attr(miri, ignore)]
fn test_protect_prevents_uaf_across_epochs() {
    let freed_a = Arc::new(AtomicBool::new(false));
    let freed_b = Arc::new(AtomicBool::new(false));
    let shared = Arc::new(Atomic::new(EraTestNode::new(1, freed_a.clone())));
    let shared1 = shared.clone();
    let freed_b1 = freed_b.clone();
    let reader_started = Arc::new(AtomicBool::new(false));
    let reader_started1 = reader_started.clone();
    let writer_done = Arc::new(AtomicBool::new(false));
    let writer_done1 = writer_done.clone();
    let reader = thread::spawn(move || {
        let guard = pin();
        let _ptr1 = shared1.load(Ordering::Acquire, &guard);
        reader_started1.store(true, Ordering::Release);
        while !writer_done1.load(Ordering::Acquire) {
            thread::yield_now();
        }
        // Re-load through the *same* guard: the replacement node must stay
        // protected even though it was installed after we pinned.
        let ptr2 = shared1.load(Ordering::Acquire, &guard);
        if let Some(node) = unsafe { ptr2.as_ref() } {
            assert_eq!(node.value, 2, "should see the new node");
        }
        // Give background reclamation a window in which it could
        // (incorrectly) free the node while our guard is still pinned.
        thread::sleep(Duration::from_millis(50));
        assert!(
            !freed_b1.load(Ordering::SeqCst),
            "Node born in later epoch was freed while guard was held! (UAF)"
        );
        drop(guard);
    });
    while !reader_started.load(Ordering::Acquire) {
        thread::yield_now();
    }
    advance_epoch_by(4);
    let new_ptr = EraTestNode::new(2, freed_b.clone());
    let swap_guard = pin();
    let old = shared.swap(
        unsafe { kovan::Shared::from_raw(new_ptr) },
        Ordering::AcqRel,
        &swap_guard,
    );
    if !old.is_null() {
        // SAFETY: `old` was unlinked by the swap above; no new reader can
        // acquire it, so it may be handed to the reclamation scheme.
        unsafe { retire(old.as_raw()) };
    }
    // BUG FIX: drop the writer's guard *before* signalling the reader.
    // Previously `swap_guard` stayed pinned until the end of the test
    // (across `reader.join()`), so this thread's own pin blocked reclamation
    // and the reader's "not freed while my guard is held" assertion passed
    // trivially — the test never exercised the reader-guard protection it
    // claims to verify.
    drop(swap_guard);
    writer_done.store(true, Ordering::Release);
    reader.join().unwrap();
}
/// Stress test: 4 reader threads continuously pin/load/unpin while a writer
/// swaps 2000 nodes through the shared pointer, retiring each displaced
/// node. Verifies that reclamation makes progress (some nodes are freed)
/// and that nothing is freed more than once.
#[test]
#[cfg_attr(miri, ignore)]
fn test_era_updates_across_many_loads() {
    let drops = Arc::new(AtomicUsize::new(0));
    let shared: Arc<Atomic<CountedNode>> = Arc::new(Atomic::null());
    let shared1 = shared.clone();
    let drops1 = drops.clone();
    let stop = Arc::new(AtomicBool::new(false));
    let stop1 = stop.clone();
    let readers_ready = Arc::new(AtomicUsize::new(0));
    let readers_ready1 = readers_ready.clone();
    let mut readers = vec![];
    for _ in 0..4 {
        let shared2 = shared.clone();
        let stop2 = stop.clone();
        let ready = readers_ready.clone();
        readers.push(thread::spawn(move || {
            let mut loads = 0u64;
            ready.fetch_add(1, Ordering::Release);
            while !stop2.load(Ordering::Relaxed) {
                let guard = pin();
                let ptr = shared2.load(Ordering::Acquire, &guard);
                if let Some(node) = unsafe { ptr.as_ref() } {
                    // black_box keeps the read from being optimized away.
                    let _ = std::hint::black_box(node.value);
                }
                drop(guard);
                loads += 1;
            }
            loads
        }));
    }
    let writer = thread::spawn(move || {
        // Don't start churning until every reader is actually spinning.
        while readers_ready1.load(Ordering::Acquire) < 4 {
            core::hint::spin_loop();
        }
        for i in 0..2000 {
            let node = Box::into_raw(Box::new(CountedNode {
                retired: RetiredNode::new(),
                value: i,
                drop_count: drops1.clone(),
            }));
            let guard = pin();
            let old = shared1.swap(
                unsafe { kovan::Shared::from_raw(node) },
                Ordering::AcqRel,
                &guard,
            );
            if !old.is_null() {
                // SAFETY: `old` was unlinked by the swap; retiring hands it
                // to the reclamation scheme for deferred destruction.
                unsafe { retire(old.as_raw()) };
            }
        }
        stop1.store(true, Ordering::Release);
    });
    writer.join().unwrap();
    let total_loads: u64 = readers.into_iter().map(|h| h.join().unwrap()).sum();
    assert!(total_loads > 0, "readers should have done some loads");
    let dropped = drops.load(Ordering::SeqCst);
    assert!(dropped > 0, "some nodes should be freed");
    // BUG FIX: also bound the drop count from above. The 2000 swaps retire
    // at most 1999 nodes (the first swap displaces null, and the final node
    // is never retired), so a larger count would indicate a double free /
    // double drop — previously this went undetected.
    assert!(
        dropped <= 1999,
        "at most 1999 nodes can be freed, got {}",
        dropped
    );
}
/// Test node whose `Drop` impl increments a shared counter, letting the
/// stress test verify how many nodes the reclamation scheme destroyed.
///
/// NOTE(review): `#[repr(C)]` with `retired` first mirrors the other test
/// nodes — presumably a layout requirement of kovan's `retire`; confirm.
#[repr(C)]
struct CountedNode {
    // Intrusive reclamation header; kept as the first field (see note above).
    retired: RetiredNode,
    // Payload read (via black_box) by the reader threads.
    value: usize,
    // Shared counter bumped once per destroyed node.
    drop_count: Arc<AtomicUsize>,
}
impl Drop for CountedNode {
    // Count each destruction exactly once; the stress test asserts on the
    // total to confirm reclamation ran (and, with an upper bound, that no
    // node was freed twice).
    fn drop(&mut self) {
        self.drop_count.fetch_add(1, Ordering::SeqCst);
    }
}