// lock_free/stack.rs

//! A lock-free stack implementation using Treiber's algorithm.

use std::mem::ManuallyDrop;
use std::ptr;
use std::sync::atomic::{AtomicPtr, AtomicUsize, Ordering};

/// One heap-allocated link of the stack.
struct Node<T> {
    /// The stored value. Wrapped in `ManuallyDrop` because `pop` moves it out with
    /// `ptr::read`; freeing the node afterwards must not drop the value a second time.
    data: ManuallyDrop<T>,
    /// The node below this one. Atomic because, once a node has been popped, this link is
    /// rewritten to thread the node onto the deferred-free list while other poppers that
    /// loaded the node earlier may still be reading it.
    next: AtomicPtr<Node<T>>,
}

/// A lock-free stack implementation using Treiber's algorithm.
///
/// # Memory reclamation
///
/// A node removed by [`pop`](Stack::pop) may still be referenced by other threads that
/// loaded it from `head` just before it was unlinked. Freeing it immediately (as a naive
/// Treiber stack does) is a use-after-free, and letting its address be reused by a later
/// `push` exposes the compare-and-swap to the ABA problem. This stack therefore counts the
/// threads currently inside `pop`:
///
/// * if the popping thread is alone, its node (and every previously deferred node) is freed
///   right away;
/// * otherwise the node is parked on a deferred-free list, which is drained by a later
///   `pop` (including one that finds the stack empty) once that thread finds itself alone.
///
/// Under sustained concurrent `pop` traffic, where the counter never falls back to one,
/// deferred nodes accumulate until contention subsides.
pub struct Stack<T> {
    /// Top of the stack, or null when the stack is empty.
    head: AtomicPtr<Node<T>>,
    /// Number of threads currently executing `pop`.
    threads_in_pop: AtomicUsize,
    /// Popped nodes (values already moved out) waiting until no thread can reference them.
    to_be_deleted: AtomicPtr<Node<T>>,
}

impl<T> Stack<T> {
    /// Creates a new empty stack.
    pub const fn new() -> Self {
        Stack {
            head: AtomicPtr::new(ptr::null_mut()),
            threads_in_pop: AtomicUsize::new(0),
            to_be_deleted: AtomicPtr::new(ptr::null_mut()),
        }
    }

    /// Pushes an element onto the stack.
    pub fn push(&self, data: T) {
        let new_node = Box::into_raw(Box::new(Node {
            data: ManuallyDrop::new(data),
            next: AtomicPtr::new(ptr::null_mut()),
        }));

        // `push` never dereferences the current head, so a relaxed read is enough; the
        // successful CAS is `Release` so poppers that acquire `head` see the node's contents.
        let mut head = self.head.load(Ordering::Relaxed);
        loop {
            // SAFETY: `new_node` is not yet reachable by any other thread.
            unsafe { (*new_node).next.store(head, Ordering::Relaxed) };
            match self.head.compare_exchange_weak(
                head,
                new_node,
                Ordering::Release,
                Ordering::Relaxed,
            ) {
                Ok(_) => return,
                Err(actual) => head = actual,
            }
        }
    }

    /// Pops an element from the stack, returning `None` if empty.
    pub fn pop(&self) -> Option<T> {
        // Register before touching `head`: while this thread is counted, no node it loads
        // from `head` will be freed. The counter and the `head` accesses below must be
        // SeqCst. This pairs with `try_reclaim`, where the unlinking thread writes `head`
        // and then reads the counter, while a popper writes the counter and then reads
        // `head`; weaker orderings would let each side miss the other's write.
        self.threads_in_pop.fetch_add(1, Ordering::SeqCst);

        let mut head = self.head.load(Ordering::SeqCst);
        while !head.is_null() {
            // SAFETY: `head` was loaded while this thread is counted in `threads_in_pop`, so
            // it cannot be freed before this thread leaves `try_reclaim`.
            let next = unsafe { (*head).next.load(Ordering::Acquire) };
            match self.head.compare_exchange_weak(
                head,
                next,
                Ordering::SeqCst,
                Ordering::SeqCst,
            ) {
                Ok(_) => break,
                Err(actual) => head = actual,
            }
        }

        let result = if head.is_null() {
            None
        } else {
            // SAFETY: our successful CAS unlinked `head`, so only this thread reads its data.
            // The value is read exactly once; `ManuallyDrop` keeps the node's later
            // deallocation from dropping it again.
            Some(unsafe { ManuallyDrop::into_inner(ptr::read(ptr::addr_of!((*head).data))) })
        };

        // Runs on the empty path too, so the counter is always released and any deferred
        // nodes get a chance to be freed.
        self.try_reclaim(head);
        result
    }

    /// Returns true if the stack is empty.
    ///
    /// With concurrent pushers or poppers the answer may already be stale when it returns.
    pub fn is_empty(&self) -> bool {
        self.head.load(Ordering::Acquire).is_null()
    }

    /// Leaves `pop`, freeing `node` (just unlinked by this thread; may be null) and any
    /// deferred nodes as soon as no other thread can still hold a pointer to them.
    fn try_reclaim(&self, node: *mut Node<T>) {
        if self.threads_in_pop.load(Ordering::SeqCst) == 1 {
            // This thread is the only one in `pop`: claim the whole deferred list.
            let pending = self.to_be_deleted.swap(ptr::null_mut(), Ordering::SeqCst);
            if self.threads_in_pop.fetch_sub(1, Ordering::SeqCst) == 1 {
                // Still alone after claiming it. Any thread that loaded one of these nodes
                // entered `pop` before that node was unlinked and would still be counted,
                // so none exists.
                // SAFETY: `pending` is a chain this thread now owns exclusively.
                unsafe { Self::free_chain(pending) };
            } else if !pending.is_null() {
                // Another thread entered meanwhile and may reference a node that was
                // deferred after our first check; hand the list back.
                // SAFETY: `pending` is a well-formed chain this thread took ownership of.
                unsafe { self.defer_chain(pending) };
            }
            if !node.is_null() {
                // SAFETY: `node` was unlinked before we saw the counter at 1, so no other
                // thread holds it; its value was already moved out in `pop`.
                drop(unsafe { Box::from_raw(node) });
            }
        } else {
            if !node.is_null() {
                // Other poppers may still be reading `node`; park it instead of freeing it.
                // SAFETY: `node` is unlinked from `head` and owned by this thread.
                unsafe { self.defer_range(node, node) };
            }
            self.threads_in_pop.fetch_sub(1, Ordering::SeqCst);
        }
    }

    /// Pushes the already-linked chain `first ..= last` onto the deferred-free list.
    ///
    /// # Safety
    ///
    /// The chain must be unlinked from `head` and not on `to_be_deleted`, the caller must own
    /// it, and `last` must be reachable from `first` through `next`.
    unsafe fn defer_range(&self, first: *mut Node<T>, last: *mut Node<T>) {
        let mut current = self.to_be_deleted.load(Ordering::SeqCst);
        loop {
            // SAFETY: the caller owns `last`, so it is still allocated.
            unsafe { (*last).next.store(current, Ordering::Relaxed) };
            match self.to_be_deleted.compare_exchange_weak(
                current,
                first,
                Ordering::SeqCst,
                Ordering::SeqCst,
            ) {
                Ok(_) => return,
                Err(actual) => current = actual,
            }
        }
    }

    /// Pushes a whole owned, null-terminated chain onto the deferred-free list.
    ///
    /// # Safety
    ///
    /// Same requirements as [`defer_range`](Self::defer_range); `first` must be non-null.
    unsafe fn defer_chain(&self, first: *mut Node<T>) {
        let mut last = first;
        loop {
            // SAFETY: every node of the owned chain is still allocated.
            let next = unsafe { (*last).next.load(Ordering::Relaxed) };
            if next.is_null() {
                break;
            }
            last = next;
        }
        // SAFETY: forwarded from the caller; `last` is the tail of the chain.
        unsafe { self.defer_range(first, last) };
    }

    /// Deallocates every node of a null-terminated chain whose values were already moved out.
    ///
    /// # Safety
    ///
    /// The caller must own the chain exclusively, and no other thread may access any of its
    /// nodes afterwards.
    unsafe fn free_chain(mut node: *mut Node<T>) {
        while !node.is_null() {
            // SAFETY: exclusive ownership is guaranteed by the caller.
            let boxed = unsafe { Box::from_raw(node) };
            node = boxed.next.load(Ordering::Relaxed);
            // `boxed` is dropped here; `data` is `ManuallyDrop`, so no value is dropped twice.
        }
    }
}

80impl<T> Default for Stack<T> {
81    fn default() -> Self {
82        Self::new()
83    }
84}
85
86impl<T> Drop for Stack<T> {
87    fn drop(&mut self) {
88        while self.pop().is_some() {}
89    }
90}
91
92unsafe impl<T: Send> Send for Stack<T> {}
93unsafe impl<T: Send> Sync for Stack<T> {}
94
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use std::thread;

    #[test]
    fn test_single_thread() {
        let stack = Stack::new();

        assert!(stack.is_empty());
        assert_eq!(stack.pop(), None);

        stack.push(1);
        stack.push(2);
        stack.push(3);

        assert!(!stack.is_empty());
        assert_eq!(stack.pop(), Some(3));
        assert_eq!(stack.pop(), Some(2));
        assert_eq!(stack.pop(), Some(1));
        assert_eq!(stack.pop(), None);
        assert!(stack.is_empty());
    }

    /// Concurrent pushes followed by a single-threaded drain lose nothing.
    #[test]
    fn test_concurrent_push_pop() {
        let stack = Arc::new(Stack::new());
        let num_threads = 8;
        let operations_per_thread = 10000;

        let mut handles = vec![];

        for i in 0..num_threads {
            let stack_clone = Arc::clone(&stack);
            let handle = thread::spawn(move || {
                for j in 0..operations_per_thread {
                    stack_clone.push(i * operations_per_thread + j);
                }
            });
            handles.push(handle);
        }

        for handle in handles {
            handle.join().unwrap();
        }

        let mut count = 0;
        while stack.pop().is_some() {
            count += 1;
        }

        assert_eq!(count, num_threads * operations_per_thread);
    }

    /// Interleaved concurrent pushes and pops (the path where node reclamation races
    /// occur) must neither lose nor duplicate any element.
    #[test]
    fn test_concurrent_mixed_push_pop() {
        let stack = Arc::new(Stack::new());
        let num_threads: usize = 8;
        let operations_per_thread: usize = 10_000;

        let handles: Vec<_> = (0..num_threads)
            .map(|i| {
                let stack = Arc::clone(&stack);
                thread::spawn(move || {
                    let mut seen = Vec::with_capacity(operations_per_thread);
                    for j in 0..operations_per_thread {
                        stack.push(i * operations_per_thread + j);
                        if let Some(value) = stack.pop() {
                            seen.push(value);
                        }
                    }
                    seen
                })
            })
            .collect();

        let mut all: Vec<usize> = handles
            .into_iter()
            .flat_map(|handle| handle.join().unwrap())
            .collect();
        while let Some(value) = stack.pop() {
            all.push(value);
        }

        all.sort_unstable();
        let expected: Vec<usize> = (0..num_threads * operations_per_thread).collect();
        assert_eq!(all, expected);
        assert!(stack.is_empty());
    }

    /// Every element is dropped exactly once: popped ones by the caller, the rest when the
    /// stack itself is dropped.
    #[test]
    fn test_drop_releases_remaining_elements() {
        static DROPS: AtomicUsize = AtomicUsize::new(0);

        struct Tracked;
        impl Drop for Tracked {
            fn drop(&mut self) {
                DROPS.fetch_add(1, Ordering::SeqCst);
            }
        }

        let stack = Stack::new();
        for _ in 0..5 {
            stack.push(Tracked);
        }
        drop(stack.pop());
        drop(stack.pop());
        assert_eq!(DROPS.load(Ordering::SeqCst), 2);

        drop(stack);
        assert_eq!(DROPS.load(Ordering::SeqCst), 5);
    }
}