lockfree/
stack.rs

1use owned_alloc::OwnedAlloc;
2use std::{
3    fmt,
4    iter::FromIterator,
5    mem::ManuallyDrop,
6    ptr::{null_mut, NonNull},
7    sync::atomic::{AtomicPtr, Ordering::*},
8};
9
10/// A lock-free stack. LIFO/FILO semanthics are fully respected.
11pub struct Stack<T> {
12    top: AtomicPtr<Node<T>>,
13    incin: SharedIncin<T>,
14}
15
16impl<T> Stack<T> {
17    /// Creates a new empty stack.
18    pub fn new() -> Self {
19        Self::with_incin(SharedIncin::new())
20    }
21
22    /// Creates an empty queue using the passed shared incinerator.
23    pub fn with_incin(incin: SharedIncin<T>) -> Self {
24        Self { top: AtomicPtr::new(null_mut()), incin }
25    }
26
27    /// Returns the shared incinerator used by this [`Stack`].
28    pub fn incin(&self) -> SharedIncin<T> {
29        self.incin.clone()
30    }
31
32    /// Creates an iterator over `T`s, based on [`pop`](Stack::pop) operation of
33    /// the [`Stack`].
34    pub fn pop_iter<'stack>(&'stack self) -> PopIter<'stack, T> {
35        PopIter { stack: self }
36    }
37
38    /// Pushes a new value onto the top of the stack.
39    pub fn push(&self, val: T) {
40        // Let's first create a node.
41        let mut target =
42            OwnedAlloc::new(Node::new(val, self.top.load(Acquire)));
43
44        loop {
45            // Let's try to publish our changes.
46            let new_top = target.raw().as_ptr();
47            match self.top.compare_exchange(
48                target.next,
49                new_top,
50                Release,
51                Relaxed,
52            ) {
53                Ok(_) => {
54                    // Let's be sure we do not deallocate the pointer.
55                    target.into_raw();
56                    break;
57                },
58
59                Err(ptr) => target.next = ptr,
60            }
61        }
62    }
63
64    /// Pops a single element from the top of the stack.
65    pub fn pop(&self) -> Option<T> {
66        // We need this because of ABA problem and use-after-free.
67        let pause = self.incin.inner.pause();
68        // First, let's load our top.
69        let mut top = self.top.load(Acquire);
70
71        loop {
72            // If top is null, we have nothing. Try operator (?) handles it.
73            let mut nnptr = NonNull::new(top)?;
74            // The replacement for top is its "next". This is only possible
75            // because of incinerator. Otherwise, we would face the "ABA
76            // problem".
77            //
78            // Note this dereferral is safe because we only delete nodes via
79            // incinerator and we have a pause now.
80            match self.top.compare_exchange(
81                top,
82                unsafe { nnptr.as_ref().next },
83                AcqRel,
84                Acquire,
85            ) {
86                Ok(_) => {
87                    // Done with an element. Let's first get the "val" to be
88                    // returned.
89                    //
90                    // This derreferal and read are safe since we drop the
91                    // node via incinerator and we never drop the inner value
92                    // when dropping the node in the incinerator.
93                    let val =
94                        unsafe { (&mut *nnptr.as_mut().val as *mut T).read() };
95                    // Safe because we already removed the node and we are
96                    // adding to the incinerator rather than
97                    // dropping it directly.
98                    pause.add_to_incin(unsafe { OwnedAlloc::from_raw(nnptr) });
99                    break Some(val);
100                },
101
102                Err(new_top) => top = new_top,
103            }
104        }
105    }
106
107    /// Pushes elements from the given iterable. Acts just like
108    /// [`Extend::extend`] but does not require mutability.
109    pub fn extend<I>(&self, iterable: I)
110    where
111        I: IntoIterator<Item = T>,
112    {
113        for elem in iterable {
114            self.push(elem);
115        }
116    }
117}
118
119impl<T> Default for Stack<T> {
120    fn default() -> Self {
121        Self::new()
122    }
123}
124
125impl<T> Drop for Stack<T> {
126    fn drop(&mut self) {
127        while let Some(_) = self.next() {}
128    }
129}
130
131impl<T> Iterator for Stack<T> {
132    type Item = T;
133
134    fn next(&mut self) -> Option<T> {
135        let top = self.top.get_mut();
136
137        NonNull::new(*top).map(|nnptr| {
138            // This is safe because we only store pointers allocated via
139            // `OwnedAlloc`. Also, we have exclusive access to this pointer.
140            let mut node = unsafe { OwnedAlloc::from_raw(nnptr) };
141            *top = node.next;
142            // This read is we never drop the inner value when dropping the
143            // node.
144            unsafe { (&mut *node.val as *mut T).read() }
145        })
146    }
147}
148
149impl<T> Extend<T> for Stack<T> {
150    fn extend<I>(&mut self, iterable: I)
151    where
152        I: IntoIterator<Item = T>,
153    {
154        (&*self).extend(iterable)
155    }
156}
157
158impl<T> FromIterator<T> for Stack<T> {
159    fn from_iter<I>(iterable: I) -> Self
160    where
161        I: IntoIterator<Item = T>,
162    {
163        let this = Self::new();
164        this.extend(iterable);
165        this
166    }
167}
168
169impl<T> fmt::Debug for Stack<T> {
170    fn fmt(&self, fmtr: &mut fmt::Formatter) -> fmt::Result {
171        write!(
172            fmtr,
173            "Stack {} top: {:?}, incin: {:?} {}",
174            '{', self.top, self.incin, '}'
175        )
176    }
177}
178
179unsafe impl<T> Send for Stack<T> where T: Send {}
180unsafe impl<T> Sync for Stack<T> where T: Send {}
181
182/// An iterator based on [`pop`](Stack::pop) operation of the [`Stack`].
183pub struct PopIter<'stack, T>
184where
185    T: 'stack,
186{
187    stack: &'stack Stack<T>,
188}
189
190impl<'stack, T> Iterator for PopIter<'stack, T> {
191    type Item = T;
192
193    fn next(&mut self) -> Option<Self::Item> {
194        self.stack.pop()
195    }
196}
197
198impl<'stack, T> fmt::Debug for PopIter<'stack, T> {
199    fn fmt(&self, fmtr: &mut fmt::Formatter) -> fmt::Result {
200        write!(fmtr, "PopIter {} stack: {:?} {}", '{', self.stack, '}')
201    }
202}
203
// Declares the `SharedIncin<T>` type used by `Stack`: this file constructs it
// with `SharedIncin::new()`, clones it, and pauses it through
// `incin.inner.pause()`.
// NOTE(review): `make_shared_incin!` is defined outside this file; presumably
// the string below is spliced into the generated documentation. Confirm
// against the macro definition.
make_shared_incin! {
    { "[`Stack`]" }
    pub SharedIncin<T> of OwnedAlloc<Node<T>>
}
208
209impl<T> fmt::Debug for SharedIncin<T> {
210    fn fmt(&self, fmtr: &mut fmt::Formatter) -> fmt::Result {
211        write!(fmtr, "SharedIncin {} inner: {:?} {}", '{', self.inner, '}')
212    }
213}
214
// A single heap-allocated link of the stack.
#[derive(Debug)]
struct Node<T> {
    // The stored value. `ManuallyDrop` keeps it from being dropped together
    // with the node, because `pop`/`next` move it out with a raw read first.
    val: ManuallyDrop<T>,
    // The node below this one; null for the bottom of the stack.
    next: *mut Self,
}
220
221impl<T> Node<T> {
222    fn new(val: T, next: *mut Node<T>) -> Self {
223        Self { val: ManuallyDrop::new(val), next }
224    }
225}
226
// Testing the safety of `unsafe` in this module is done with random operations
// via fuzzing; the tests below only cover functional behavior.
#[cfg(test)]
mod test {
    use super::*;
    use std::{sync::Arc, thread};

    /// A freshly created stack has nothing to pop.
    #[test]
    fn on_empty_first_pop_is_none() {
        let stack = Stack::<usize>::new();
        assert!(stack.pop().is_none());
    }

    /// After popping everything that was pushed, the stack is empty again.
    #[test]
    fn on_empty_last_pop_is_none() {
        let stack = Stack::new();
        stack.push(3);
        stack.push(1234);
        assert_eq!(stack.pop(), Some(1234));
        assert_eq!(stack.pop(), Some(3));
        assert!(stack.pop().is_none());
    }

    /// Elements come out in reverse insertion order, down to the very first
    /// one pushed, and then the stack reports empty.
    #[test]
    fn order() {
        let stack = Stack::new();
        stack.push(4);
        stack.push(3);
        stack.push(5);
        stack.push(6);
        assert_eq!(stack.pop(), Some(6));
        assert_eq!(stack.pop(), Some(5));
        assert_eq!(stack.pop(), Some(3));
        assert_eq!(stack.pop(), Some(4));
        assert_eq!(stack.pop(), None);
    }

    /// Concurrent pushes (with an occasional pop) must neither lose nor
    /// corrupt elements.
    #[test]
    fn no_data_corruption() {
        const NTHREAD: usize = 20;
        const NITER: usize = 800;
        const NMOD: usize = 55;

        let stack = Arc::new(Stack::new());
        let mut handles = Vec::with_capacity(NTHREAD);

        for i in 0 .. NTHREAD {
            let stack = stack.clone();
            handles.push(thread::spawn(move || {
                for j in 0 .. NITER {
                    // Every thread pushes values from its own disjoint range.
                    let val = (i * NITER) + j;
                    stack.push(val);
                    // Every NMOD-th value also triggers a pop.
                    if (val + 1) % NMOD == 0 {
                        if let Some(val) = stack.pop() {
                            assert!(val < NITER * NTHREAD);
                        }
                    }
                }
            }));
        }

        for handle in handles {
            handle.join().expect("thread failed");
        }

        // Total pushes minus the pops performed by the worker threads.
        let expected = NITER * NTHREAD - NITER * NTHREAD / NMOD;
        let mut res = 0;
        while let Some(val) = stack.pop() {
            assert!(val < NITER * NTHREAD);
            res += 1;
        }

        assert_eq!(res, expected);
    }
}
299}