use std::{
sync::atomic::{AtomicBool, AtomicPtr, AtomicUsize, Ordering::*},
sync::{Arc, Mutex},
thread,
};
/// One registered case of a blocking receive.
///
/// A waker claims the waiter by compare-exchanging the shared `selected` slot
/// from an "unselected" value to `case_id`, then unparks `thread`.
pub(crate) struct RecvWaiter {
    /// Value written into `selected` when this waiter is the one claimed.
    pub(crate) case_id: usize,
    /// Slot shared (via `Arc`) with the waiting operation; holds the claimed case id.
    pub(crate) selected: Arc<AtomicUsize>,
    /// Handle of the thread that created the waiter; unparked when it is woken.
    pub(crate) thread: thread::Thread,
}

impl RecvWaiter {
    /// Builds a waiter bound to the calling thread and wraps it in an `Arc` so it
    /// can be stored in a waiter list and kept by its guard at the same time.
    pub(crate) fn new(case_id: usize, selected: Arc<AtomicUsize>) -> Arc<Self> {
        let owner = thread::current();
        let waiter = Self {
            case_id,
            selected,
            thread: owner,
        };
        Arc::new(waiter)
    }
}
/// Shared, mutex-protected list of registered receive waiters.
pub(crate) type RecvWaiterList = Arc<Mutex<Vec<Arc<RecvWaiter>>>>;

/// Returns a fresh, empty [`RecvWaiterList`] that is not shared with anything yet.
pub(crate) fn new_recv_waiter_list() -> RecvWaiterList {
    // `Arc<Mutex<Vec<_>>>::default()` is `Arc::new(Mutex::new(Vec::new()))`.
    RecvWaiterList::default()
}
/// RAII registration of a [`RecvWaiter`] in a [`RecvWaiterList`].
///
/// Created by [`RecvWaiterGuard::register`]; dropping the guard removes the
/// entry it added from the list.
pub(crate) struct RecvWaiterGuard {
    waiter: Arc<RecvWaiter>,
    list: RecvWaiterList,
}

impl RecvWaiterGuard {
    /// Appends `waiter` to `list` and returns a guard that unregisters it on drop.
    ///
    /// # Panics
    /// Panics if the list's mutex is poisoned.
    #[must_use = "dropping the guard immediately unregisters the waiter"]
    pub(crate) fn register(waiter: Arc<RecvWaiter>, list: &RecvWaiterList) -> Self {
        list.lock().unwrap().push(Arc::clone(&waiter));
        RecvWaiterGuard {
            waiter,
            list: Arc::clone(list),
        }
    }
}

impl Drop for RecvWaiterGuard {
    /// Removes the single list entry this guard added.
    fn drop(&mut self) {
        // Recover from poisoning instead of unwrapping: panicking inside `drop`
        // while the thread is already unwinding aborts the process. The list
        // only holds `Arc`s, so its contents stay usable after a poisoning panic.
        let mut waiters = self
            .list
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        // Remove exactly one matching entry, so a waiter registered through
        // several guards stays listed until each of them is dropped. `remove`
        // (not `swap_remove`) keeps the remaining entries in list order.
        if let Some(index) = waiters.iter().position(|w| Arc::ptr_eq(w, &self.waiter)) {
            waiters.remove(index);
        }
    }
}
/// Claims and wakes the first listed receive waiter whose `selected` slot still
/// holds `unselected`.
///
/// The list is snapshotted under the lock, and the lock is released before any
/// waiter is touched. Waiters are tried in list order. Each try is a CAS of the
/// waiter's slot from `unselected` to its `case_id`. The first success unparks
/// that waiter's thread and stops the scan.
///
/// Returns `true` if some waiter was claimed, `false` if none could be.
pub(crate) fn wake_one_recv_waiter(list: &RecvWaiterList, unselected: usize) -> bool {
    let snapshot: Vec<Arc<RecvWaiter>> = list.lock().unwrap().clone();
    let claimed = snapshot.iter().find(|candidate| {
        candidate
            .selected
            .compare_exchange(unselected, candidate.case_id, SeqCst, SeqCst)
            .is_ok()
    });
    match claimed {
        Some(winner) => {
            winner.thread.unpark();
            true
        }
        None => false,
    }
}
/// Tries to claim every listed receive waiter and unparks all of them.
///
/// Works on a snapshot of the list taken under the lock. For each waiter, its
/// `selected` slot is CAS-ed from `unselected` to its `case_id`, and a failed CAS
/// is ignored. Its thread is then unparked regardless of the outcome, so waiters
/// whose slot was already claimed are woken as well.
pub(crate) fn wake_all_recv_waiters(list: &RecvWaiterList, unselected: usize) {
    let snapshot: Vec<Arc<RecvWaiter>> = list.lock().unwrap().clone();
    snapshot.iter().for_each(|entry| {
        let _ = entry
            .selected
            .compare_exchange(unselected, entry.case_id, SeqCst, SeqCst);
        entry.thread.unpark();
    });
}
/// Unparks every listed receive waiter whose `selected` slot still equals
/// [`UNSELECTED`], without claiming any of them.
///
/// Operates on a snapshot of the list taken under the lock. Slots are only read
/// (with `Acquire`), never written.
pub(crate) fn wake_all_unselected_recv_waiters(list: &RecvWaiterList) {
    let snapshot: Vec<Arc<RecvWaiter>> = list.lock().unwrap().clone();
    snapshot
        .iter()
        .filter(|entry| entry.selected.load(Acquire) == UNSELECTED)
        .for_each(|entry| entry.thread.unpark());
}
/// Sentinel held by a `selected` slot while no case has been claimed yet.
pub(crate) const UNSELECTED: usize = usize::MAX;

/// Node of the intrusive lock-free stack of select waiters.
///
/// Nodes are heap-allocated by [`SelectWaiter::alloc`] and passed around as raw
/// pointers. The wake/drain functions reclaim them with `Box::from_raw`.
pub(crate) struct SelectWaiter {
    /// Value written into `selected` if this node is the one claimed.
    pub(crate) case_id: usize,
    /// Slot shared with the waiting select; compared by `Arc` identity when aborting.
    pub(crate) selected: Arc<AtomicUsize>,
    /// Thread to unpark when this node is woken.
    pub(crate) thread: thread::Thread,
    /// Next node further down the stack (null at the bottom).
    pub(crate) next: AtomicPtr<SelectWaiter>,
    /// Set once the waiter has withdrawn; wakers skip such nodes.
    pub(crate) aborted: AtomicBool,
}

impl SelectWaiter {
    /// Heap-allocates an unlinked, non-aborted node for the calling thread and
    /// leaks it as a raw pointer.
    ///
    /// Ownership passes to whoever later reclaims the node with `Box::from_raw`.
    pub(crate) fn alloc(case_id: usize, selected: Arc<AtomicUsize>) -> *mut SelectWaiter {
        let node = Box::new(Self {
            case_id,
            selected,
            thread: thread::current(),
            next: AtomicPtr::default(),
            aborted: AtomicBool::default(),
        });
        Box::into_raw(node)
    }
}
/// Publishes the node `ptr` on top of the lock-free stack `stack`.
///
/// Treiber-style push: link the node to the currently observed head, then CAS
/// the head from that value to `ptr`. If another thread won the race, retry with
/// the head value the failed CAS observed.
///
/// The caller must pass a valid node (as produced by `SelectWaiter::alloc`) that
/// is not yet reachable from any stack; ownership of the node moves into the stack.
pub(crate) fn push_select_waiter(ptr: *mut SelectWaiter, stack: &Arc<AtomicPtr<SelectWaiter>>) {
    let mut expected = stack.load(Acquire);
    loop {
        // SAFETY: per the contract above, `ptr` is valid and not yet published,
        // so no other thread can observe this write before the CAS releases it.
        unsafe { (*ptr).next.store(expected, Relaxed) };
        match stack.compare_exchange(expected, ptr, AcqRel, Acquire) {
            Ok(_) => break,
            Err(observed) => expected = observed,
        }
    }
}
/// Marks as aborted every node on `stack` whose `selected` slot is the very same
/// `Arc` as `selected` (identity, not value). Wakers then skip and free those
/// nodes instead of claiming them.
///
/// The stack is only walked, never modified: marked nodes stay linked.
///
/// NOTE(review): this walk dereferences nodes that are still published on the
/// stack, while `wake_select_one`, `wake_select_all` and `drain_select_waiters`
/// detach the whole chain with `swap` and free it with `Box::from_raw`. Unless
/// callers serialize this function against those (not visible here), a
/// concurrent wake/drain could free a node this loop is reading. Confirm the
/// caller-side synchronization.
pub(crate) fn abort_select_waiters(
    stack: &Arc<AtomicPtr<SelectWaiter>>,
    selected: &Arc<AtomicUsize>,
) {
    let mut cursor = stack.load(Acquire);
    // SAFETY: assumes (see NOTE above) that no reclaimer frees the chain while
    // it is being walked; `as_ref` yields `None` at the null bottom of the stack.
    while let Some(node) = unsafe { cursor.as_ref() } {
        if Arc::ptr_eq(&node.selected, selected) {
            node.aborted.store(true, Release);
        }
        cursor = node.next.load(Acquire);
    }
}
pub(crate) fn wake_select_one(stack: &Arc<AtomicPtr<SelectWaiter>>) -> bool {
let head = stack.swap(std::ptr::null_mut(), AcqRel);
let mut current = head;
let mut winner_found = false;
while !current.is_null() {
let node_box = unsafe { Box::from_raw(current) };
current = node_box.next.load(Acquire);
if node_box.aborted.load(Acquire) {
continue;
}
if !winner_found {
if node_box
.selected
.compare_exchange(UNSELECTED, node_box.case_id, SeqCst, SeqCst)
.is_ok()
{
node_box.thread.unpark();
winner_found = true;
continue;
}
}
}
winner_found
}
/// Detaches every node from `stack`, tries to claim each live one, and frees them all.
///
/// Aborted nodes are freed without being touched. For every other node, its
/// `selected` slot is CAS-ed from [`UNSELECTED`] to its `case_id` (a failed CAS is
/// ignored). Its thread is then unparked regardless of the outcome.
pub(crate) fn wake_select_all(stack: &Arc<AtomicPtr<SelectWaiter>>) {
    let mut cursor = stack.swap(std::ptr::null_mut(), AcqRel);
    loop {
        if cursor.is_null() {
            break;
        }
        // SAFETY: the `swap` above detached this chain, making this call its
        // reclaimer; assumes every node came from `SelectWaiter::alloc`.
        let owned = unsafe { Box::from_raw(cursor) };
        cursor = owned.next.load(Acquire);
        if !owned.aborted.load(Acquire) {
            let _ = owned
                .selected
                .compare_exchange(UNSELECTED, owned.case_id, SeqCst, SeqCst);
            owned.thread.unpark();
        }
    }
}
/// Detaches every node from `stack` and frees them all without waking anyone.
///
/// Every node, aborted or not, is reclaimed with `Box::from_raw`. No `selected`
/// slot is touched and no thread is unparked.
///
/// NOTE(review): a waiter still parked on a freed node can no longer be reached
/// through this stack. Presumably this is only called once no waiter can still
/// be waiting (e.g. on teardown); confirm with callers.
pub(crate) fn drain_select_waiters(stack: &Arc<AtomicPtr<SelectWaiter>>) {
    let mut pending = stack.swap(std::ptr::null_mut(), AcqRel);
    while let Some(node) = std::ptr::NonNull::new(pending) {
        // SAFETY: the `swap` above detached this chain, making this call its
        // reclaimer; assumes every node came from `SelectWaiter::alloc`.
        let reclaimed = unsafe { Box::from_raw(node.as_ptr()) };
        pending = reclaimed.next.load(Acquire);
        drop(reclaimed);
    }
}
#[cfg(test)]
mod tests {
    // Everything under test comes from the parent module through this glob.
    // That includes `UNSELECTED`, `Arc` and the atomic types/orderings. The
    // tests therefore use the real sentinel value instead of a private copy
    // that could drift from it.
    use super::*;

    #[test]
    fn test_select_waiter_push_and_wake_one() {
        let stack = Arc::new(AtomicPtr::new(std::ptr::null_mut()));
        let selected = Arc::new(AtomicUsize::new(UNSELECTED));
        let ptr = SelectWaiter::alloc(3, Arc::clone(&selected));
        push_select_waiter(ptr, &stack);
        assert!(wake_select_one(&stack));
        assert!(stack.load(Acquire).is_null());
        assert_eq!(selected.load(Acquire), 3);
    }

    #[test]
    fn test_select_waiter_abort_skips() {
        let stack = Arc::new(AtomicPtr::new(std::ptr::null_mut()));
        let selected = Arc::new(AtomicUsize::new(UNSELECTED));
        let ptr = SelectWaiter::alloc(7, Arc::clone(&selected));
        push_select_waiter(ptr, &stack);
        abort_select_waiters(&stack, &selected);
        assert!(!wake_select_one(&stack));
        assert_eq!(selected.load(Acquire), UNSELECTED);
    }

    #[test]
    fn test_select_wake_all_frees_nodes() {
        let stack = Arc::new(AtomicPtr::new(std::ptr::null_mut()));
        let sel1 = Arc::new(AtomicUsize::new(UNSELECTED));
        let sel2 = Arc::new(AtomicUsize::new(UNSELECTED));
        push_select_waiter(SelectWaiter::alloc(1, Arc::clone(&sel1)), &stack);
        push_select_waiter(SelectWaiter::alloc(2, Arc::clone(&sel2)), &stack);
        wake_select_all(&stack);
        assert!(stack.load(Acquire).is_null());
    }

    #[test]
    fn test_drain_select_waiters_no_leak() {
        let stack = Arc::new(AtomicPtr::new(std::ptr::null_mut()));
        let selected = Arc::new(AtomicUsize::new(UNSELECTED));
        push_select_waiter(SelectWaiter::alloc(0, Arc::clone(&selected)), &stack);
        push_select_waiter(SelectWaiter::alloc(1, Arc::clone(&selected)), &stack);
        drain_select_waiters(&stack);
        assert!(stack.load(Acquire).is_null());
    }

    #[test]
    fn test_abort_only_matching_selected() {
        let stack = Arc::new(AtomicPtr::new(std::ptr::null_mut()));
        let sel_a = Arc::new(AtomicUsize::new(UNSELECTED));
        let sel_b = Arc::new(AtomicUsize::new(UNSELECTED));
        push_select_waiter(SelectWaiter::alloc(0, Arc::clone(&sel_a)), &stack);
        push_select_waiter(SelectWaiter::alloc(1, Arc::clone(&sel_b)), &stack);
        abort_select_waiters(&stack, &sel_a);
        assert!(wake_select_one(&stack));
        assert_eq!(sel_b.load(Acquire), 1);
        assert_eq!(sel_a.load(Acquire), UNSELECTED);
    }

    #[test]
    fn test_recv_guard_registers_and_unregisters() {
        let list = new_recv_waiter_list();
        let selected = Arc::new(AtomicUsize::new(UNSELECTED));
        let guard = RecvWaiterGuard::register(RecvWaiter::new(4, Arc::clone(&selected)), &list);
        assert_eq!(list.lock().unwrap().len(), 1);
        drop(guard);
        assert!(list.lock().unwrap().is_empty());
    }

    #[test]
    fn test_wake_one_recv_waiter_claims_first_unselected() {
        let list = new_recv_waiter_list();
        let taken = Arc::new(AtomicUsize::new(0));
        let open = Arc::new(AtomicUsize::new(UNSELECTED));
        let _taken_guard =
            RecvWaiterGuard::register(RecvWaiter::new(1, Arc::clone(&taken)), &list);
        let _open_guard = RecvWaiterGuard::register(RecvWaiter::new(2, Arc::clone(&open)), &list);
        assert!(wake_one_recv_waiter(&list, UNSELECTED));
        assert_eq!(taken.load(Acquire), 0);
        assert_eq!(open.load(Acquire), 2);
        assert!(!wake_one_recv_waiter(&list, UNSELECTED));
    }

    #[test]
    fn test_wake_all_recv_waiters_claims_every_open_slot() {
        let list = new_recv_waiter_list();
        let open_a = Arc::new(AtomicUsize::new(UNSELECTED));
        let open_b = Arc::new(AtomicUsize::new(UNSELECTED));
        let _guard_a = RecvWaiterGuard::register(RecvWaiter::new(5, Arc::clone(&open_a)), &list);
        let _guard_b = RecvWaiterGuard::register(RecvWaiter::new(6, Arc::clone(&open_b)), &list);
        wake_all_recv_waiters(&list, UNSELECTED);
        assert_eq!(open_a.load(Acquire), 5);
        assert_eq!(open_b.load(Acquire), 6);
    }
}