#![warn(missing_debug_implementations)]
use crate::{
loom::{
cell::UnsafeCell,
sync::atomic::{AtomicPtr, Ordering::*},
},
Linked,
};
use core::{
fmt,
marker::PhantomPinned,
ptr::{self, NonNull},
};
/// A stack of `T` elements whose head is stored in an `AtomicPtr`, so that
/// elements can be pushed, and the whole contents taken, through a shared
/// (`&self`) reference.
///
/// Elements are chained through the [`Links`] that each `T` exposes via its
/// `Linked` implementation. Elements still on the stack when it is dropped are
/// dropped with it.
pub struct TransferStack<T>
where
    T: Linked<Links<T>>,
{
    /// The most recently pushed element, or null when the stack is empty.
    head: AtomicPtr<T>,
}
/// A singly-linked stack of `T` elements that is modified through exclusive
/// (`&mut self`) access.
///
/// This is the type returned by `TransferStack::take_all`. It also implements
/// `Iterator`, yielding elements by popping them off the top.
pub struct Stack<T>
where
    T: Linked<Links<T>>,
{
    /// The element on top of the stack, or `None` when the stack is empty.
    pub(crate) head: Option<NonNull<T>>,
}
/// The link state an element stores in order to be a member of a [`Stack`] or
/// [`TransferStack`].
///
/// It holds the pointer to the element below this one. Because it contains a
/// `PhantomPinned`, any type embedding a `Links` is `!Unpin`.
pub struct Links<T> {
/// The next element down the stack; `None` when this element is the last one
/// or has not been pushed (a fresh `Links` starts out as `None`).
pub(crate) next: UnsafeCell<Option<NonNull<T>>>,
/// Marker that opts `Links` out of `Unpin`.
_unpin: PhantomPinned,
}
impl<T: Linked<Links<T>>> TransferStack<T> {
    /// Returns a new, empty `TransferStack`.
    #[cfg(not(loom))]
    #[must_use]
    pub const fn new() -> Self {
        let head = AtomicPtr::new(ptr::null_mut());
        Self { head }
    }

    /// Returns a new, empty `TransferStack`.
    ///
    /// This variant is compiled when `cfg(loom)` is set; unlike the non-loom
    /// variant it is not a `const fn`.
    #[cfg(loom)]
    #[must_use]
    pub fn new() -> Self {
        let head = AtomicPtr::new(ptr::null_mut());
        Self { head }
    }

    /// Pushes `element` onto the top of the stack.
    ///
    /// Identical to [`push_was_empty`](Self::push_was_empty), except that the
    /// "was empty" result is discarded.
    #[inline]
    pub fn push(&self, element: T::Handle) {
        let _was_empty = self.push_was_empty(element);
    }

    /// Pushes `element` onto the top of the stack and returns `true` if the
    /// stack was empty immediately before this push took effect.
    ///
    /// The handle is converted to a raw pointer with `Linked::into_ptr`, so
    /// the stack holds the element until it is taken out again. In debug
    /// builds this asserts that the element's `next` link is unset on entry.
    pub fn push_was_empty(&self, element: T::Handle) -> bool {
        let ptr = T::into_ptr(element);
        test_trace!(?ptr, "TransferStack::push");

        // SAFETY: `ptr` was produced by `into_ptr` just above. This relies on
        // the `Linked` implementation handing back a valid pointer whose links
        // nobody else is accessing (the trait's contract is defined outside
        // this file).
        let links = unsafe { T::links(ptr).as_mut() };
        debug_assert!(links.next.with(|next| unsafe { (*next).is_none() }));

        let mut head = self.head.load(Relaxed);
        loop {
            test_trace!(?ptr, ?head, "TransferStack::push");

            // Point the new element at the head we last observed *before*
            // attempting to publish it, so it is fully linked once visible.
            links.next.with_mut(|next| unsafe {
                *next = NonNull::new(head);
            });

            let exchanged = self
                .head
                .compare_exchange_weak(head, ptr.as_ptr(), AcqRel, Acquire);
            let old = match exchanged {
                Ok(old) => old,
                Err(actual) => {
                    // Lost the race (or failed spuriously): retry against the
                    // head that is actually there.
                    head = actual;
                    continue;
                }
            };

            let was_empty = old.is_null();
            test_trace!(?ptr, ?head, was_empty, "TransferStack::push -> pushed");
            return was_empty;
        }
    }

    /// Takes every element currently on the stack, leaving it empty.
    ///
    /// The head pointer is swapped with null in a single atomic operation and
    /// the chain it pointed to is returned as a [`Stack`].
    #[must_use]
    pub fn take_all(&self) -> Stack<T> {
        let taken = self.head.swap(ptr::null_mut(), AcqRel);
        Stack {
            head: NonNull::new(taken),
        }
    }
}
impl<T: Linked<Links<T>>> Drop for TransferStack<T> {
    /// Drops every element that is still on the stack.
    fn drop(&mut self) {
        // `take_all` hands back the remaining chain as a `Stack`, which is an
        // iterator that pops each element as an owning handle; dropping each
        // handle releases the element.
        self.take_all().for_each(drop);
    }
}
impl<T: Linked<Links<T>>> fmt::Debug for TransferStack<T> {
    /// Formats the stack as a struct with its `head` pointer; the elements
    /// themselves are not visited.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Destructure exhaustively so that adding a field without updating
        // this impl fails to compile.
        let Self { head } = self;

        let mut out = f.debug_struct("TransferStack");
        out.field("head", head);
        out.finish()
    }
}
impl<T> Default for TransferStack<T>
where
T: Linked<Links<T>>,
{
fn default() -> Self {
Self::new()
}
}
impl<T> Stack<T>
where
T: Linked<Links<T>>,
{
#[must_use]
pub const fn new() -> Self {
Self { head: None }
}
pub fn push(&mut self, element: T::Handle) {
let ptr = T::into_ptr(element);
test_trace!(?ptr, ?self.head, "Stack::push");
unsafe {
let links = T::links(ptr).as_mut();
links.next.with_mut(|next| {
debug_assert!((*next).is_none());
*next = self.head.replace(ptr);
})
}
}
#[must_use]
pub fn pop(&mut self) -> Option<T::Handle> {
test_trace!(?self.head, "Stack::pop");
let head = self.head.take()?;
unsafe {
self.head = T::links(head).as_mut().next.with_mut(|next| (*next).take());
test_trace!(?self.head, "Stack::pop -> popped");
Some(T::from_ptr(head))
}
}
#[must_use]
pub fn take_all(&mut self) -> Self {
Self {
head: self.head.take(),
}
}
#[inline]
#[must_use]
pub fn is_empty(&self) -> bool {
self.head.is_none()
}
}
impl<T: Linked<Links<T>>> Drop for Stack<T> {
    /// Drops every element that is still on the stack.
    fn drop(&mut self) {
        // Pop until empty; each pop yields an owning handle, and dropping
        // the handle releases the element.
        while let Some(entry) = self.pop() {
            drop(entry);
        }
    }
}
impl<T: Linked<Links<T>>> fmt::Debug for Stack<T> {
    /// Formats the stack as a struct with its `head` pointer; the elements
    /// themselves are not visited.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Destructure exhaustively so that adding a field without updating
        // this impl fails to compile.
        let Self { head } = self;

        let mut out = f.debug_struct("Stack");
        out.field("head", head);
        out.finish()
    }
}
impl<T> Iterator for Stack<T>
where
T: Linked<Links<T>>,
{
type Item = T::Handle;
fn next(&mut self) -> Option<Self::Item> {
self.pop()
}
}
impl<T> Default for Stack<T>
where
T: Linked<Links<T>>,
{
fn default() -> Self {
Self::new()
}
}
// A `Stack` stores only a raw pointer to its top element, which is not `Send`
// by default. This impl asserts that moving the stack to another thread is
// sound whenever `T` itself is `Send`.
// NOTE(review): this relies on the stack being the only owner of the elements
// linked into it — confirm against the `Linked` contract.
unsafe impl<T> Send for Stack<T> where T: Send + Linked<Links<T>> {}
// A `Stack` stores only a raw pointer to its top element, which is not `Sync`
// by default. This impl asserts that sharing `&Stack<T>` between threads is
// sound whenever `T` itself is `Sync`; every mutating method on `Stack` takes
// `&mut self`.
unsafe impl<T> Sync for Stack<T> where T: Sync + Linked<Links<T>> {}
impl<T> Links<T> {
    /// Returns new links with no `next` element set.
    #[cfg(not(loom))]
    #[must_use]
    pub const fn new() -> Self {
        let next = UnsafeCell::new(None);
        Self {
            next,
            _unpin: PhantomPinned,
        }
    }

    /// Returns new links with no `next` element set.
    ///
    /// This variant is compiled when `cfg(loom)` is set; unlike the non-loom
    /// variant it is not a `const fn`.
    #[cfg(loom)]
    #[must_use]
    pub fn new() -> Self {
        let next = UnsafeCell::new(None);
        Self {
            next,
            _unpin: PhantomPinned,
        }
    }
}
// `Links` holds a raw `next` pointer and so is not `Send` automatically; this
// impl asserts it may move between threads whenever `T` is `Send`.
// NOTE(review): soundness presumably rests on the `next` pointer only being
// touched by the stack that currently holds the element — confirm.
unsafe impl<T> Send for Links<T> where T: Send {}
// `Links` wraps its `next` pointer in an `UnsafeCell` and so is not `Sync`
// automatically; this impl asserts it may be shared between threads whenever
// `T` is `Sync`.
// NOTE(review): soundness presumably rests on the `next` pointer only being
// touched by the stack that currently holds the element — confirm.
unsafe impl<T> Sync for Links<T> where T: Sync {}
impl<T> fmt::Debug for Links<T> {
    /// Writes a fixed placeholder string; the `next` pointer is not read.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        const PLACEHOLDER: &str = "transfer_stack::Links { ... }";
        f.write_str(PLACEHOLDER)
    }
}
impl<T> Default for Links<T> {
fn default() -> Self {
Self::new()
}
}
/// Concurrency and ownership tests for `TransferStack` and `Stack`.
///
/// Every test body runs inside `loom::model` from the `crate::loom` shim.
/// NOTE(review): presumably this explores thread interleavings when built with
/// `cfg(loom)` and runs the closure once otherwise — confirm in `crate::loom`.
#[cfg(test)]
mod loom {
use super::*;
use crate::loom::{
self,
sync::{
atomic::{AtomicUsize, Ordering},
Arc,
},
thread,
};
use test_util::Entry;
/// Two threads each push `PUSHES` entries while the main thread repeatedly
/// drains the stack; every pushed value must be observed exactly once.
#[test]
fn multithreaded_push() {
const PUSHES: i32 = 2;
loom::model(|| {
let stack = Arc::new(TransferStack::new());
// Number of pusher threads that have not finished yet.
let threads = Arc::new(AtomicUsize::new(2));
let thread1 = thread::spawn({
let stack = stack.clone();
let threads = threads.clone();
move || {
Entry::push_all(&stack, 1, PUSHES);
threads.fetch_sub(1, Ordering::Relaxed);
}
});
let thread2 = thread::spawn({
let stack = stack.clone();
let threads = threads.clone();
move || {
Entry::push_all(&stack, 2, PUSHES);
threads.fetch_sub(1, Ordering::Relaxed);
}
});
let mut seen = Vec::new();
// Keep draining until both pushers have reported completion.
loop {
seen.extend(stack.take_all().map(|entry| entry.val));
if threads.load(Ordering::Relaxed) == 0 {
break;
}
thread::yield_now();
}
// Entries pushed after the last drain inside the loop, but before the
// counter was seen at zero, are collected here.
seen.extend(stack.take_all().map(|entry| entry.val));
seen.sort();
// `push_all` pushes `thread * 10 + i` for `i` in `0..PUSHES`.
assert_eq!(seen, vec![10, 11, 20, 21]);
thread1.join().unwrap();
thread2.join().unwrap();
})
}
/// Two threads push while a third thread and the main thread each call
/// `take_all` concurrently; after all threads are joined the main thread
/// drains once more. The union of the three drains must contain every pushed
/// value exactly once.
#[test]
fn multithreaded_pop() {
const PUSHES: i32 = 2;
loom::model(|| {
let stack = Arc::new(TransferStack::new());
let thread1 = thread::spawn({
let stack = stack.clone();
move || Entry::push_all(&stack, 1, PUSHES)
});
let thread2 = thread::spawn({
let stack = stack.clone();
move || Entry::push_all(&stack, 2, PUSHES)
});
let thread3 = thread::spawn({
let stack = stack.clone();
move || stack.take_all().map(|entry| entry.val).collect::<Vec<_>>()
});
let seen_thread0 = stack.take_all().map(|entry| entry.val).collect::<Vec<_>>();
let seen_thread3 = thread3.join().unwrap();
thread1.join().unwrap();
thread2.join().unwrap();
// Both pushers are done, so whatever neither earlier drain saw is still on
// the stack.
let seen_thread0_final = stack.take_all().map(|entry| entry.val).collect::<Vec<_>>();
// `dbg!` prints each partial result to stderr before it is merged.
let mut all = dbg!(seen_thread0);
all.extend(dbg!(seen_thread3));
all.extend(dbg!(seen_thread0_final));
all.sort();
assert_eq!(all, vec![10, 11, 20, 21]);
})
}
/// Drops the main thread's handle to the stack while the pusher threads may
/// still be running, without ever draining it.
///
/// The test contains no assertions. NOTE(review): leaks are presumably
/// detected through the `alloc::Track` stored in each `Entry` — confirm.
#[test]
fn doesnt_leak() {
const PUSHES: i32 = 2;
loom::model(|| {
let stack = Arc::new(TransferStack::new());
let thread1 = thread::spawn({
let stack = stack.clone();
move || Entry::push_all(&stack, 1, PUSHES)
});
let thread2 = thread::spawn({
let stack = stack.clone();
move || Entry::push_all(&stack, 2, PUSHES)
});
tracing::info!("dropping stack");
drop(stack);
thread1.join().unwrap();
thread2.join().unwrap();
})
}
/// Joins both pushers, takes everything with `take_all`, then drops the stack
/// handle before dropping the taken `Stack`.
///
/// The test contains no assertions. NOTE(review): leaks are presumably
/// detected through the `alloc::Track` stored in each `Entry` — confirm.
#[test]
fn take_all_doesnt_leak() {
const PUSHES: i32 = 2;
loom::model(|| {
let stack = Arc::new(TransferStack::new());
let thread1 = thread::spawn({
let stack = stack.clone();
move || Entry::push_all(&stack, 1, PUSHES)
});
let thread2 = thread::spawn({
let stack = stack.clone();
move || Entry::push_all(&stack, 2, PUSHES)
});
thread1.join().unwrap();
thread2.join().unwrap();
let take_all = stack.take_all();
tracing::info!("dropping stack");
drop(stack);
tracing::info!("dropping take_all");
drop(take_all);
})
}
/// Like `take_all_doesnt_leak`, but `take_all` is called *before* the pushers
/// are joined, so it races with their pushes; entries pushed afterwards remain
/// on the `TransferStack` until it is dropped.
#[test]
fn take_all_doesnt_leak_racy() {
const PUSHES: i32 = 2;
loom::model(|| {
let stack = Arc::new(TransferStack::new());
let thread1 = thread::spawn({
let stack = stack.clone();
move || Entry::push_all(&stack, 1, PUSHES)
});
let thread2 = thread::spawn({
let stack = stack.clone();
move || Entry::push_all(&stack, 2, PUSHES)
});
let take_all = stack.take_all();
thread1.join().unwrap();
thread2.join().unwrap();
tracing::info!("dropping stack");
drop(stack);
tracing::info!("dropping take_all");
drop(take_all);
})
}
/// Single-threaded check of `Stack` ordering: elements come back in reverse
/// push order, and the source stack can be pushed to again after `take_all`.
#[test]
fn unsync() {
loom::model(|| {
let mut stack = Stack::<Entry>::new();
stack.push(Entry::new(1));
stack.push(Entry::new(2));
stack.push(Entry::new(3));
let mut take_all = stack.take_all();
// Expect 3, 2, 1 from the taken stack while pushing 13, 12, 11 onto the
// now-empty original.
for i in (1..=3).rev() {
assert_eq!(take_all.next().unwrap().val, i);
stack.push(Entry::new(10 + i));
}
// 11 was pushed last, so it comes out first.
let mut i = 11;
for entry in stack.take_all() {
assert_eq!(entry.val, i);
i += 1;
}
})
}
/// Pushes three entries onto a `Stack` and lets it go out of scope without
/// popping them.
///
/// The test contains no assertions. NOTE(review): leaks are presumably
/// detected through the `alloc::Track` stored in each `Entry` — confirm.
#[test]
fn unsync_doesnt_leak() {
loom::model(|| {
let mut stack = Stack::<Entry>::new();
stack.push(Entry::new(1));
stack.push(Entry::new(2));
stack.push(Entry::new(3));
})
}
}
/// Compile-time checks that the public types are `Send` and `Sync` for a
/// `Send + Sync` element type.
#[cfg(test)]
mod test {
    use super::{test_util::Entry, *};
    use crate::util::assert_send_sync;

    /// `TransferStack<Entry>` must be both `Send` and `Sync`.
    #[test]
    fn stack_is_send_sync() {
        assert_send_sync::<TransferStack<Entry>>();
    }

    /// `Links<Entry>` must be both `Send` and `Sync`.
    #[test]
    fn links_are_send_sync() {
        assert_send_sync::<Links<Entry>>();
    }
}
/// Shared fixtures for the stack tests.
#[cfg(test)]
pub(crate) mod test_util {
    use super::*;
    use crate::loom::alloc;
    use core::pin::Pin;

    /// A heap-allocated test element carrying an `i32` payload.
    ///
    /// Comparisons and ordering look at `val` only.
    #[pin_project::pin_project]
    pub(crate) struct Entry {
        /// Stack links for this entry.
        #[pin]
        links: Links<Entry>,
        /// The payload the tests assert on.
        pub(crate) val: i32,
        // NOTE(review): presumably lets loom track this allocation for leak
        // detection — confirm in `crate::loom::alloc`.
        track: alloc::Track<()>,
    }

    impl PartialEq for Entry {
        fn eq(&self, other: &Self) -> bool {
            self.val == other.val
        }
    }

    impl Eq for Entry {}

    impl PartialOrd for Entry {
        fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
            Some(self.cmp(other))
        }
    }

    impl Ord for Entry {
        fn cmp(&self, other: &Self) -> core::cmp::Ordering {
            i32::cmp(&self.val, &other.val)
        }
    }

    unsafe impl Linked<Links<Self>> for Entry {
        type Handle = Pin<Box<Entry>>;

        /// Leaks the pinned box and returns a pointer to the entry.
        fn into_ptr(handle: Pin<Box<Entry>>) -> NonNull<Self> {
            // SAFETY: the box is only unwrapped to obtain its address; the
            // entry stays at the same heap location.
            let boxed = unsafe { Pin::into_inner_unchecked(handle) };
            NonNull::from(Box::leak(boxed))
        }

        /// Rebuilds the pinned box from a pointer.
        ///
        /// The pointer is expected to have come from `into_ptr`.
        unsafe fn from_ptr(ptr: NonNull<Self>) -> Self::Handle {
            let boxed = Box::from_raw(ptr.as_ptr());
            Pin::new_unchecked(boxed)
        }

        /// Returns a pointer to the `links` field of `target` without
        /// creating an intermediate reference.
        unsafe fn links(target: NonNull<Self>) -> NonNull<Links<Self>> {
            // A field address derived from a non-null pointer is itself
            // non-null, so `new_unchecked` is used.
            NonNull::new_unchecked(ptr::addr_of_mut!((*target.as_ptr()).links))
        }
    }

    impl Entry {
        /// Allocates a new pinned entry holding `val`, with fresh links.
        pub(crate) fn new(val: i32) -> Pin<Box<Entry>> {
            let entry = Entry {
                links: Links::new(),
                val,
                track: alloc::Track::new(()),
            };
            Box::pin(entry)
        }

        /// Pushes `n` new entries onto `stack`, with values `thread * 10 + i`
        /// for `i` in `0..n`, in increasing order of `i`.
        pub(super) fn push_all(stack: &TransferStack<Self>, thread: i32, n: i32) {
            for i in 0..n {
                stack.push(Self::new((thread * 10) + i));
            }
        }
    }
}