coco/
stack.rs

1//! A lock-free stack.
2//!
3//! This is an implementation of the Treiber stack, one of the simplest lock-free data structures.
4
5use std::ptr;
6use std::sync::atomic::Ordering::{AcqRel, Acquire, Relaxed};
7
8use epoch::{self, Atomic, Owned};
9
10/// A single node in a stack.
11struct Node<T> {
12    /// The payload.
13    value: T,
14    /// The next node in the stack.
15    next: Atomic<Node<T>>,
16}
17
18/// A lock-free stack.
19///
20/// It can be used with multiple producers and multiple consumers at the same time.
21pub struct Stack<T> {
22    head: Atomic<Node<T>>,
23}
24
25unsafe impl<T: Send> Send for Stack<T> {}
26unsafe impl<T: Send> Sync for Stack<T> {}
27
28impl<T> Stack<T> {
29    /// Returns a new, empty stack.
30    ///
31    /// # Examples
32    ///
33    /// ```
34    /// use coco::Stack;
35    ///
36    /// let s = Stack::<i32>::new();
37    /// ```
38    pub fn new() -> Self {
39        Stack { head: Atomic::null() }
40    }
41
42    /// Returns `true` if the stack is empty.
43    ///
44    /// # Examples
45    ///
46    /// ```
47    /// use coco::Stack;
48    ///
49    /// let s = Stack::new();
50    /// assert!(s.is_empty());
51    /// s.push("hello");
52    /// assert!(!s.is_empty());
53    /// ```
54    pub fn is_empty(&self) -> bool {
55        epoch::pin(|scope| self.head.load(Acquire, scope).is_null())
56    }
57
58    /// Pushes a new value onto the stack.
59    ///
60    /// # Examples
61    ///
62    /// ```
63    /// use coco::Stack;
64    ///
65    /// let s = Stack::new();
66    /// s.push(1);
67    /// s.push(2);
68    /// ```
69    pub fn push(&self, value: T) {
70        let mut node = Owned::new(Node {
71            value: value,
72            next: Atomic::null(),
73        });
74
75        epoch::pin(|scope| {
76            let mut head = self.head.load(Acquire, scope);
77            loop {
78                node.next.store(head, Relaxed);
79                match self.head.compare_and_swap_weak_owned(head, node, AcqRel, scope) {
80                    Ok(_) => break,
81                    Err((h, n)) => {
82                        head = h;
83                        node = n;
84                    }
85                }
86            }
87        })
88    }
89
90    /// Attempts to pop an value from the stack.
91    ///
92    /// Returns `None` if the stack is empty.
93    ///
94    /// # Examples
95    ///
96    /// ```
97    /// use coco::Stack;
98    ///
99    /// let s = Stack::new();
100    /// s.push(1);
101    /// s.push(2);
102    /// assert_eq!(s.pop(), Some(2));
103    /// assert_eq!(s.pop(), Some(1));
104    /// assert_eq!(s.pop(), None);
105    /// ```
106    pub fn pop(&self) -> Option<T> {
107        epoch::pin(|scope| {
108            let mut head = self.head.load(Acquire, scope);
109            loop {
110                match unsafe { head.as_ref() } {
111                    Some(h) => {
112                        let next = h.next.load(Acquire, scope);
113                        match self.head.compare_and_swap_weak(head, next, AcqRel, scope) {
114                            Ok(()) => unsafe {
115                                scope.defer_free(head);
116                                return Some(ptr::read(&h.value));
117                            },
118                            Err(h) => head = h,
119                        }
120                    }
121                    None => return None,
122                }
123            }
124        })
125    }
126}
127
128impl<T> Drop for Stack<T> {
129    fn drop(&mut self) {
130        // Destruct all nodes in the stack.
131        unsafe {
132            epoch::unprotected(|scope| {
133                let mut curr = self.head.load(Relaxed, scope).as_raw();
134                while !curr.is_null() {
135                    let next = (*curr).next.load(Relaxed, scope).as_raw();
136                    drop(Box::from_raw(curr as *mut Node<T>));
137                    curr = next;
138                }
139            })
140        }
141    }
142}
143
#[cfg(test)]
mod tests {
    extern crate rand;

    use std::sync::Arc;
    use std::sync::atomic::AtomicUsize;
    use std::sync::atomic::Ordering::SeqCst;
    use std::thread;

    use super::Stack;
    use self::rand::Rng;

    // A single push followed by pops drains the stack.
    #[test]
    fn smoke() {
        let stack = Stack::new();
        stack.push(1);
        assert_eq!(stack.pop(), Some(1));
        assert_eq!(stack.pop(), None);
    }

    // Interleaved pushes and pops observe LIFO order.
    #[test]
    fn push_pop() {
        let stack = Stack::new();
        for v in 1..4 {
            stack.push(v);
        }
        assert_eq!(stack.pop(), Some(3));
        stack.push(4);
        for expected in &[Some(4), Some(2), Some(1), None] {
            assert_eq!(stack.pop(), *expected);
        }
        stack.push(5);
        assert_eq!(stack.pop(), Some(5));
        assert_eq!(stack.pop(), None);
    }

    // `is_empty` tracks pushes and pops.
    #[test]
    fn is_empty() {
        let stack = Stack::new();
        assert!(stack.is_empty());

        for n in 0..3 {
            stack.push(n);
            assert!(!stack.is_empty());
        }

        for _ in 0..3 {
            assert!(!stack.is_empty());
            stack.pop();
        }

        assert!(stack.is_empty());
        stack.push(3);
        assert!(!stack.is_empty());
        stack.pop();
        assert!(stack.is_empty());
    }

    // Concurrent random pushes/pops. Afterwards, each thread's surviving values must come
    // out in decreasing order, and the element count must balance.
    #[test]
    fn stress() {
        const THREADS: usize = 8;

        let stack = Arc::new(Stack::new());
        let count = Arc::new(AtomicUsize::new(0));

        let mut handles = Vec::with_capacity(THREADS);
        for t in 0..THREADS {
            let stack = Arc::clone(&stack);
            let count = Arc::clone(&count);

            handles.push(thread::spawn(move || {
                let mut rng = rand::thread_rng();
                for i in 0..100_000 {
                    if rng.gen_range(0, t + 1) == 0 {
                        if stack.pop().is_some() {
                            count.fetch_sub(1, SeqCst);
                        }
                    } else {
                        // Values encode their producer: `value % THREADS == t`.
                        stack.push(t + THREADS * i);
                        count.fetch_add(1, SeqCst);
                    }
                }
            }));
        }

        for handle in handles {
            handle.join().unwrap();
        }

        let mut prev = [::std::usize::MAX; THREADS];

        while !stack.is_empty() {
            let x = stack.pop().unwrap();
            let owner = x % THREADS;

            assert!(prev[owner] > x);
            prev[owner] = x;

            count.fetch_sub(1, SeqCst);
        }
        assert_eq!(count.load(SeqCst), 0);
    }

    // Every element is dropped exactly once: either when popped or when the stack is dropped.
    #[test]
    fn destructors() {
        // Bumps the shared counter on drop so destructor calls can be counted.
        struct Elem((), Arc<AtomicUsize>);

        impl Drop for Elem {
            fn drop(&mut self) {
                self.1.fetch_add(1, SeqCst);
            }
        }

        const THREADS: usize = 8;

        let stack = Arc::new(Stack::new());
        let remaining = Arc::new(AtomicUsize::new(0));
        let popped = Arc::new(AtomicUsize::new(0));
        let dropped = Arc::new(AtomicUsize::new(0));

        let mut handles = Vec::with_capacity(THREADS);
        for t in 0..THREADS {
            let stack = Arc::clone(&stack);
            let remaining = Arc::clone(&remaining);
            let popped = Arc::clone(&popped);
            let dropped = Arc::clone(&dropped);

            handles.push(thread::spawn(move || {
                let mut rng = rand::thread_rng();
                for _ in 0..100_000 {
                    if rng.gen_range(0, t + 1) == 0 {
                        if stack.pop().is_some() {
                            remaining.fetch_sub(1, SeqCst);
                            popped.fetch_add(1, SeqCst);
                        }
                    } else {
                        stack.push(Elem((), Arc::clone(&dropped)));
                        remaining.fetch_add(1, SeqCst);
                    }
                }
            }));
        }

        for handle in handles {
            handle.join().unwrap();
        }

        // Popped elements were dropped immediately; nothing else has been dropped yet.
        assert_eq!(dropped.load(SeqCst), popped.load(SeqCst));
        // Dropping the last `Arc` drops the stack, which must drop exactly what is left in it.
        drop(stack);
        assert_eq!(dropped.load(SeqCst), popped.load(SeqCst) + remaining.load(SeqCst));
    }
}