use owned_alloc::OwnedAlloc;
use std::{
fmt,
iter::FromIterator,
mem::ManuallyDrop,
ptr::{null_mut, NonNull},
sync::atomic::{AtomicPtr, Ordering::*},
};
/// A stack whose elements live in heap-allocated nodes linked through raw
/// pointers, with the head of the list stored in an atomic pointer.
///
/// `push` and `pop` take `&self` and update the head with compare-and-exchange
/// loops.
pub struct Stack<T> {
/// Pointer to the most recently pushed node; null when the stack is empty.
top: AtomicPtr<Node<T>>,
/// Incinerator handle that popped nodes are handed to (see `pop`).
incin: SharedIncin<T>,
}
// Inherent API of the stack. Every method here works through `&self`.
impl<T> Stack<T> {
/// Creates an empty stack backed by a newly created `SharedIncin`.
pub fn new() -> Self {
Self::with_incin(SharedIncin::new())
}
/// Creates an empty stack that uses the given incinerator handle.
pub fn with_incin(incin: SharedIncin<T>) -> Self {
Self { top: AtomicPtr::new(null_mut()), incin }
}
/// Returns a clone of the incinerator handle used by this stack.
pub fn incin(&self) -> SharedIncin<T> {
self.incin.clone()
}
/// Returns an iterator that pops one element from this stack on every call
/// to `next`.
pub fn pop_iter<'stack>(&'stack self) -> PopIter<'stack, T> {
PopIter { stack: self }
}
/// Pushes `val` on top of the stack.
pub fn push(&self, val: T) {
// Allocate the node up front, optimistically linked to the current top.
let mut target =
OwnedAlloc::new(Node::new(val, self.top.load(Acquire)));
loop {
let new_top = target.raw().as_ptr();
// Publish the node only if `top` still equals the `next` recorded in it.
match self.top.compare_exchange(
target.next,
new_top,
Release,
Relaxed,
) {
Ok(_) => {
// The stack now references the node, so `target` gives up ownership.
// NOTE(review): presumably `into_raw` is what keeps the allocation
// from being freed when `target` goes out of scope — confirm against
// the `owned_alloc` documentation.
target.into_raw();
break;
},
// `top` no longer matched; relink to the observed top and retry.
Err(ptr) => target.next = ptr,
}
}
}
/// Removes the top element and returns it, or returns `None` when the stack
/// is observed to be empty.
pub fn pop(&self) -> Option<T> {
// Hold an incinerator pause for the whole operation.
// NOTE(review): the incinerator is defined outside this file; presumably
// the pause keeps nodes that other threads pass to `add_to_incin` alive
// while `top` is dereferenced below — confirm.
let pause = self.incin.inner.pause();
let mut top = self.top.load(Acquire);
loop {
// A null `top` means there is nothing to pop; `?` returns `None`.
let mut nnptr = NonNull::new(top)?;
match self.top.compare_exchange(
top,
// Proposed new top: the successor of the node being removed.
unsafe { nnptr.as_ref().next },
AcqRel,
Acquire,
) {
Ok(_) => {
// The node is unlinked. Move the value out bitwise; `val` is a
// `ManuallyDrop`, so dropping the node later does not drop it again.
let val =
unsafe { (&mut *nnptr.as_mut().val as *mut T).read() };
// Hand the node's allocation to the incinerator rather than dropping
// it here.
pause.add_to_incin(unsafe { OwnedAlloc::from_raw(nnptr) });
break Some(val);
},
// `top` changed in the meantime; retry with the value the CAS observed.
Err(new_top) => top = new_top,
}
}
}
/// Pushes every element of `iterable`, in iteration order, so the last
/// element yielded is the last one pushed.
pub fn extend<I>(&self, iterable: I)
where
I: IntoIterator<Item = T>,
{
for elem in iterable {
self.push(elem);
}
}
}
impl<T> Default for Stack<T> {
fn default() -> Self {
Self::new()
}
}
impl<T> Drop for Stack<T> {
    /// Drains the stack through `Iterator::next`, which takes back every
    /// remaining node and yields its value so that the value is dropped here.
    fn drop(&mut self) {
        while self.next().is_some() {}
    }
}
impl<T> Iterator for Stack<T> {
    type Item = T;

    /// Removes and returns the top element using exclusive access, so no
    /// atomic read-modify-write and no incinerator pause is involved.
    fn next(&mut self) -> Option<T> {
        let head = self.top.get_mut();
        // An empty stack has a null head; `?` turns that into `None`.
        let nnptr = NonNull::new(*head)?;
        // Non-null nodes reachable from `top` were allocated in `push` through
        // `OwnedAlloc` and released with `into_raw`; `&mut self` means nobody
        // else can be touching them while ownership is taken back here.
        let mut node = unsafe { OwnedAlloc::from_raw(nnptr) };
        *head = node.next;
        // `val` is a `ManuallyDrop`, so the value moved out here is not dropped
        // again when `node` goes out of scope at the end of this function.
        let value = unsafe { (&mut *node.val as *mut T).read() };
        Some(value)
    }
}
impl<T> Extend<T> for Stack<T> {
    /// Pushes every element of `iterable` by delegating to the inherent
    /// `Stack::extend`, which only needs a shared reference.
    fn extend<I>(&mut self, iterable: I)
    where
        I: IntoIterator<Item = T>,
    {
        // Reborrow as shared first: calling `extend` directly on `&mut Self`
        // would resolve to this trait method again and recurse forever.
        let shared: &Self = self;
        shared.extend(iterable)
    }
}
impl<T> FromIterator<T> for Stack<T> {
fn from_iter<I>(iterable: I) -> Self
where
I: IntoIterator<Item = T>,
{
let this = Self::new();
this.extend(iterable);
this
}
}
impl<T> fmt::Debug for Stack<T> {
fn fmt(&self, fmtr: &mut fmt::Formatter) -> fmt::Result {
write!(
fmtr,
"Stack {} top: {:?}, incin: {:?} {}",
'{', self.top, self.incin, '}'
)
}
}
// Thread-safety markers for `Stack<T>`; both require only `T: Send`.
// NOTE(review): within this file elements are only ever moved out by value
// (`pop`, `Iterator::next`) and never exposed as `&T`, which is presumably why
// `Sync` does not ask for `T: Sync` — confirm that the `SharedIncin<T>` field
// upholds the same reasoning.
unsafe impl<T: Send> Send for Stack<T> {}
unsafe impl<T: Send> Sync for Stack<T> {}
/// Iterator returned by [`Stack::pop_iter`]; it borrows the stack and pops one
/// element from it each time `next` is called.
pub struct PopIter<'stack, T: 'stack> {
    /// The stack being popped from.
    stack: &'stack Stack<T>,
}
impl<'stack, T> Iterator for PopIter<'stack, T> {
type Item = T;
fn next(&mut self) -> Option<Self::Item> {
self.stack.pop()
}
}
impl<'stack, T> fmt::Debug for PopIter<'stack, T> {
    /// Writes `PopIter { stack: .. }` using the stack's own `Debug` output.
    fn fmt(&self, fmtr: &mut fmt::Formatter) -> fmt::Result {
        let stack = self.stack;
        // `{{` and `}}` are the escapes for literal braces.
        write!(fmtr, "PopIter {{ stack: {:?} }}", stack)
    }
}
// Declares the public `SharedIncin<T>` type whose payload is
// `OwnedAlloc<Node<T>>`, i.e. the node allocations that `Stack::pop` hands
// over. Elsewhere in this file the type is used through `new()`, `clone()`
// and its `inner` field.
// NOTE(review): `make_shared_incin!` is defined outside this file; the string
// in the first braces looks like a documentation fragment naming the
// structure the incinerator belongs to — confirm against the macro definition.
make_shared_incin! {
{ "[`Stack`]" }
pub SharedIncin<T> of OwnedAlloc<Node<T>>
}
impl<T> fmt::Debug for SharedIncin<T> {
    /// Writes `SharedIncin { inner: .. }` using the `Debug` output of the
    /// wrapped `inner` field.
    fn fmt(&self, fmtr: &mut fmt::Formatter) -> fmt::Result {
        let inner = &self.inner;
        // `{{` and `}}` are the escapes for literal braces.
        write!(fmtr, "SharedIncin {{ inner: {:?} }}", inner)
    }
}
/// One heap-allocated link of the stack's singly linked list.
#[derive(Debug)]
struct Node<T> {
/// The stored element. It is wrapped in `ManuallyDrop` because the stack
/// moves the value out with a raw read before the node is released, so
/// dropping the node must not drop the value a second time.
val: ManuallyDrop<T>,
/// The node below this one; null for the bottom of the stack.
next: *mut Node<T>,
}
impl<T> Node<T> {
fn new(val: T, next: *mut Node<T>) -> Self {
Self { val: ManuallyDrop::new(val), next }
}
}
#[cfg(test)]
mod test {
    use super::*;
    use std::{sync::Arc, thread};

    /// A freshly created stack has nothing to pop.
    #[test]
    fn on_empty_first_pop_is_none() {
        let stack: Stack<usize> = Stack::new();
        assert_eq!(stack.pop(), None);
    }

    /// Once every pushed element has been popped, the stack reports empty.
    #[test]
    fn on_empty_last_pop_is_none() {
        let stack = Stack::new();
        stack.push(3);
        stack.push(1234);
        let _ = stack.pop();
        let _ = stack.pop();
        assert_eq!(stack.pop(), None);
    }

    /// Elements come back in last-in, first-out order.
    #[test]
    fn order() {
        let stack = Stack::new();
        for elem in vec![4, 3, 5, 6] {
            stack.push(elem);
        }
        for expected in vec![6, 5, 3] {
            assert_eq!(stack.pop(), Some(expected));
        }
    }

    /// Concurrent pushes and pops never produce a value that was not pushed,
    /// and the number of surviving elements matches pushes minus pops.
    #[test]
    fn no_data_corruption() {
        const NTHREAD: usize = 20;
        const NITER: usize = 800;
        const NMOD: usize = 55;
        const TOTAL: usize = NITER * NTHREAD;

        let stack = Arc::new(Stack::new());

        // Each worker pushes its own disjoint range of values and pops once
        // whenever `val + 1` is a multiple of `NMOD`.
        let workers: Vec<_> = (0 .. NTHREAD)
            .map(|tid| {
                let stack = Arc::clone(&stack);
                thread::spawn(move || {
                    let first = tid * NITER;
                    for val in first .. first + NITER {
                        stack.push(val);
                        if (val + 1) % NMOD != 0 {
                            continue;
                        }
                        if let Some(popped) = stack.pop() {
                            assert!(popped < TOTAL);
                        }
                    }
                })
            })
            .collect();

        for worker in workers {
            worker.join().expect("thread failed");
        }

        let expected = TOTAL - TOTAL / NMOD;
        let mut remaining = 0;
        for val in stack.pop_iter() {
            assert!(val < TOTAL);
            remaining += 1;
        }
        assert_eq!(remaining, expected);
    }
}