#![no_std]
#[cfg(any(test, feature = "std",),)]
extern crate std;
use core::{
ptr,
marker::{Send, Sync,},
sync::atomic::{self, AtomicPtr, AtomicBool, Ordering,},
};
/// A stack of waiters built on a single atomic pointer.
///
/// The wrapped pointer is the head of an intrusive singly linked list of
/// `SyncStackNode`s (null when the stack is empty). Nodes are pushed by
/// `SyncStack::park` and removed by `SyncStack::pop`.
pub struct SyncStack(AtomicPtr<SyncStackNode>,);
impl SyncStack {
    /// An empty stack (null head), usable as a `const`/`static` initializer.
    pub const INIT: Self = SyncStack(AtomicPtr::new(ptr::null_mut()));

    /// Returns an empty stack; equivalent to [`SyncStack::INIT`].
    #[inline]
    pub const fn new() -> Self {
        Self::INIT
    }

    /// Tries to push the caller onto the stack and, on success, blocks until a
    /// call to [`SyncStack::pop`] releases it.
    ///
    /// A `P` is created with `P::new()`, a node living in this stack frame is
    /// published as the new head, and the caller then loops on `P::park()`
    /// until the node is flagged as used.
    ///
    /// Returns `true` once the caller has been popped. Returns `false`
    /// immediately, without waiting, if the head changed between reading it
    /// and publishing the node; callers that must wait should retry.
    pub fn park<P>(&self) -> bool
    where
        P: Park,
    {
        let park = P::new();
        // The wake-up closure lives in this frame; the node only stores a
        // type-erased raw pointer to it so that `pop` does not need to know `P`.
        let mut unpark = move || park.unpark();
        let mut node = SyncStackNode {
            used: AtomicBool::new(false),
            unpark: &mut unpark,
            rest: self.0.load(Ordering::Relaxed),
        };

        // Publish the node. `Release` makes the node's fields visible to the
        // thread that later acquires the head pointer in `pop`.
        let published = self
            .0
            .compare_exchange(node.rest, &mut node, Ordering::AcqRel, Ordering::Acquire)
            .is_ok();
        if !published {
            return false;
        }

        // `P::park` may return without a matching `unpark`, so re-check the
        // flag every time it returns.
        while !node.used.load(Ordering::SeqCst) {
            P::park();
        }
        true
    }

    /// Removes the most recently pushed waiter and wakes it.
    ///
    /// Returns `true` if a waiter was removed and its unpark closure called,
    /// `false` if the stack was empty.
    ///
    /// NOTE(review): nodes live on the waiters' stacks, which leaves two
    /// hazards that this method cannot rule out on its own:
    /// * the head pointer is subject to ABA (a node can be popped and a new
    ///   node pushed at the same address between the load and the CAS);
    /// * once `used` is set the waiter may return from `park` while the
    ///   unpark closure is still being called here.
    /// Both need a redesign of the node ownership to fix properly.
    pub fn pop(&self) -> bool {
        let mut node_ptr = self.0.load(Ordering::Acquire);
        loop {
            if node_ptr.is_null() {
                return false;
            }

            // SAFETY (NOTE(review), see hazards above): assumes the node
            // observed as head is still alive. Only a shared reference is
            // created because the owning waiter reads `used` concurrently.
            let node = unsafe { &*node_ptr };
            let rest = node.rest;

            // The failure ordering must be `Acquire`: the pointer returned on
            // failure is dereferenced on the next iteration.
            match self
                .0
                .compare_exchange(node_ptr, rest, Ordering::AcqRel, Ordering::Acquire)
            {
                Err(current) => node_ptr = current,
                Ok(_) => {
                    // Copy the closure pointer out before releasing the
                    // waiter, so the node itself is not touched afterwards.
                    let unpark = node.unpark;
                    let claimed = node
                        .used
                        .compare_exchange(false, true, Ordering::Release, Ordering::Relaxed)
                        .is_ok();
                    if claimed {
                        atomic::fence(Ordering::SeqCst);
                        // SAFETY (NOTE(review)): the closure lives in the
                        // waiter's `park` frame; see the second hazard above.
                        unsafe {
                            (*unpark)();
                        }
                        return true;
                    }
                    // The node was already flagged by someone else; it is no
                    // longer ours to inspect, so start over from the head.
                    node_ptr = self.0.load(Ordering::Acquire);
                }
            }
        }
    }
}
/// One entry of a `SyncStack`, created inside `SyncStack::park` and linked
/// into the stack by raw pointer.
struct SyncStackNode {
/// Set to `true` by `SyncStack::pop` when it removes this node; polled by
/// the waiter in `SyncStack::park` to decide when to stop parking.
used: AtomicBool,
/// Type-erased pointer to the closure that calls `Park::unpark` for the
/// waiter that created this node.
unpark: *mut dyn FnMut(),
/// The head of the stack as observed when this node was created; becomes the
/// head again when this node is popped.
rest: *mut Self,
}
/// Blocking/wake-up primitive used by `SyncStack::park` and `SyncStack::pop`.
///
/// # Safety
///
/// NOTE(review): the contract implementors must uphold is not stated in this
/// file. Presumably `unpark` must cause a current or future `park` call made
/// by the context that called `new` to return — confirm with the author.
pub unsafe trait Park: 'static + Send + Sync {
/// Creates the handle for the caller; invoked by `SyncStack::park` before the
/// node is pushed.
fn new() -> Self;
/// Blocks the caller. Called in a loop by `SyncStack::park`, which re-checks
/// its wake-up flag after every return.
fn park();
/// Wakes the waiter this handle belongs to; invoked (through a closure) from
/// `SyncStack::pop`.
fn unpark(&self,);
}
/// `Park` backed by the standard library's thread parking.
#[cfg(any(test, feature = "std"))]
unsafe impl Park for std::thread::Thread {
    /// Handle of the calling thread.
    #[inline]
    fn new() -> Self {
        std::thread::current()
    }

    /// Blocks the calling thread via `std::thread::park`.
    #[inline]
    fn park() {
        std::thread::park();
    }

    /// Forwards to the inherent `Thread::unpark`. Spelled as a path so it is
    /// obvious that this is not a recursive call to the trait method.
    #[inline]
    fn unpark(&self) {
        std::thread::Thread::unpark(self);
    }
}
#[cfg(test,)]
mod tests {
    use super::*;
    use std::{
        sync::{Arc, Mutex},
        thread::{self, Thread},
        time::Duration,
    };

    /// Stress test: many threads alternately park on and pop from one shared
    /// stack. The main thread waits until progress stalls, pops any waiter
    /// that is still parked, and finally checks that every thread finished.
    #[test]
    fn test_sync_stack_data_race() {
        static STACK: SyncStack = SyncStack::new();

        const THREADS_HALF: u64 = 1000;
        const CHAOS: u64 = 10;
        const CYCLES: u64 = 5;
        const THREADS: u64 = THREADS_HALF + THREADS_HALF;
        const SLEEP: u64 = 500;

        let pause = Duration::from_millis(SLEEP);
        let finished = Arc::new(Mutex::new(0));

        for _ in 0..THREADS_HALF {
            // Worker that parks first and pops afterwards.
            let counter = Arc::clone(&finished);
            thread::spawn(move || {
                for _ in 0..CYCLES {
                    while !STACK.park::<Thread>() {}
                    for _ in 0..CHAOS {
                        STACK.pop();
                    }
                }
                *counter.lock().unwrap() += 1;
            });

            // Worker that pops first and parks afterwards.
            let counter = Arc::clone(&finished);
            thread::spawn(move || {
                for _ in 0..CYCLES {
                    for _ in 0..CHAOS {
                        STACK.pop();
                    }
                    while !STACK.park::<Thread>() {}
                }
                *counter.lock().unwrap() += 1;
            });
        }

        thread::sleep(pause);
        loop {
            // Keep sleeping while the finished count is still changing and
            // not every thread is done.
            let mut last_seen = 0;
            loop {
                let current = *finished.lock().unwrap();
                let progressing = current != THREADS && current != last_seen;
                last_seen = current;
                if !progressing {
                    break;
                }
                thread::sleep(pause);
            }

            // Release one stuck waiter, if any; stop once the stack is empty.
            if !STACK.pop() {
                break;
            }
        }

        assert_eq!(*finished.lock().unwrap(), THREADS);
    }
}