use std::mem::ManuallyDrop;
use std::ptr;
use std::sync::atomic::{AtomicPtr, AtomicUsize, Ordering};
5
/// A heap-allocated link in a [`Stack`].
struct Node<T> {
    /// The pushed value. Wrapped in `ManuallyDrop` because `pop` moves it out
    /// while the node's memory may only be freed later.
    data: ManuallyDrop<T>,
    /// The node below this one (null at the bottom). Atomic because a thread
    /// inside `pop` may still read it while another thread re-links the same,
    /// already-unlinked node into the deferred-free list.
    next: AtomicPtr<Node<T>>,
}

/// A lock-free LIFO stack (Treiber stack) that can be shared between threads.
///
/// Memory reclamation uses a "threads in pop" counter: a popped node is freed
/// immediately only when no other thread is inside [`Stack::pop`], so nobody can
/// still be dereferencing it. Otherwise the node is parked on a deferred-free list
/// that is emptied later by a `pop` that runs alone, or by `Drop`. This prevents
/// the use-after-free and ABA hazards of freeing a node while a concurrent `pop`
/// still holds a pointer to it. Under sustained concurrent pops the deferred list
/// can keep growing until a `pop` finds itself alone.
pub struct Stack<T> {
    /// Top of the stack, or null when empty.
    head: AtomicPtr<Node<T>>,
    /// Number of threads currently executing `pop`.
    threads_in_pop: AtomicUsize,
    /// Unlinked nodes whose payload was already moved out but whose memory could
    /// not yet be freed safely; chained through `Node::next`, null-terminated.
    to_be_deleted: AtomicPtr<Node<T>>,
}

impl<T> Stack<T> {
    /// Creates an empty stack.
    pub fn new() -> Self {
        Stack {
            head: AtomicPtr::new(ptr::null_mut()),
            threads_in_pop: AtomicUsize::new(0),
            to_be_deleted: AtomicPtr::new(ptr::null_mut()),
        }
    }

    /// Pushes `data` onto the top of the stack.
    pub fn push(&self, data: T) {
        let new_node = Box::into_raw(Box::new(Node {
            data: ManuallyDrop::new(data),
            next: AtomicPtr::new(ptr::null_mut()),
        }));

        // `push` never dereferences the current head, so a relaxed read is enough.
        let mut head = self.head.load(Ordering::Relaxed);
        loop {
            // SAFETY: `new_node` is a live allocation that no other thread can
            // observe until the CAS below succeeds.
            unsafe { (*new_node).next.store(head, Ordering::Relaxed) };

            // Release on success publishes the node's contents to whichever
            // thread later acquires it from `head`.
            match self.head.compare_exchange_weak(
                head,
                new_node,
                Ordering::Release,
                Ordering::Relaxed,
            ) {
                Ok(_) => return,
                Err(current) => head = current,
            }
        }
    }

    /// Removes and returns the top element, or `None` if the stack is empty.
    pub fn pop(&self) -> Option<T> {
        // Register before touching any node, so that concurrent pops defer freeing
        // nodes we might still dereference. All of `pop` uses SeqCst: the
        // reclamation argument needs one total order between this increment, our
        // reads of `head`, and other threads' checks of the counter.
        self.threads_in_pop.fetch_add(1, Ordering::SeqCst);

        let mut head = self.head.load(Ordering::SeqCst);
        while !head.is_null() {
            // SAFETY: `head` was read from `self.head` after we registered in
            // `threads_in_pop`, so no thread frees it until we deregister.
            let next = unsafe { (*head).next.load(Ordering::Relaxed) };
            match self.head.compare_exchange_weak(
                head,
                next,
                Ordering::SeqCst,
                Ordering::SeqCst,
            ) {
                Ok(_) => {
                    // SAFETY: winning the CAS makes this thread the only one that
                    // ever touches this node's payload; others only read `next`.
                    let data = unsafe { ManuallyDrop::take(&mut (*head).data) };
                    // SAFETY: `head` was unlinked by this thread, its payload has
                    // been moved out, and we are still counted in `threads_in_pop`.
                    unsafe { self.try_reclaim(head) };
                    return Some(data);
                }
                Err(current) => head = current,
            }
        }

        self.threads_in_pop.fetch_sub(1, Ordering::SeqCst);
        None
    }

    /// Returns `true` if the stack currently has no elements.
    ///
    /// Under concurrent use the answer may be stale by the time it is read.
    pub fn is_empty(&self) -> bool {
        self.head.load(Ordering::Acquire).is_null()
    }

    /// Frees `node` now if no other thread is in `pop`; otherwise parks it on the
    /// deferred-free list. Always deregisters the caller from `threads_in_pop`.
    ///
    /// # Safety
    /// `node` must have been unlinked from `head` by the caller, its payload must
    /// already have been moved out, and the caller must currently be counted in
    /// `threads_in_pop`.
    unsafe fn try_reclaim(&self, node: *mut Node<T>) {
        if self.threads_in_pop.load(Ordering::SeqCst) == 1 {
            // No other thread is in `pop`, so none can hold `node`. While we are
            // alone, also try to empty the deferred list.
            let pending = self.to_be_deleted.swap(ptr::null_mut(), Ordering::SeqCst);
            if self.threads_in_pop.fetch_sub(1, Ordering::SeqCst) == 1 {
                // Still alone after claiming the list: nobody can reference it.
                // SAFETY: claimed exclusively by the swap; payloads already taken.
                unsafe { Self::free_chain(pending) };
            } else if !pending.is_null() {
                // Another thread entered `pop` meanwhile and may hold a node from
                // the claimed list; hand it back for a later attempt.
                // SAFETY: the claimed list is owned by us until re-published.
                unsafe { self.defer_chain(pending) };
            }
            // SAFETY: `node` became unreachable before we observed a count of 1,
            // and threads entering `pop` later cannot obtain it.
            unsafe { drop(Box::from_raw(node)) };
        } else {
            // Another thread may still be reading `node.next`; defer the free.
            // Chain before deregistering so a lone popper cannot miss it.
            // SAFETY: `node` is owned by us and unlinked from the stack.
            unsafe { self.defer_range(node, node) };
            self.threads_in_pop.fetch_sub(1, Ordering::SeqCst);
        }
    }

    /// Re-publishes a null-terminated chain (linked via `next`) onto the
    /// deferred-free list.
    ///
    /// # Safety
    /// The caller must exclusively own every node in the chain: unlinked from
    /// `head`, payloads already moved out, and not on `to_be_deleted`.
    unsafe fn defer_chain(&self, first: *mut Node<T>) {
        let mut last = first;
        loop {
            // SAFETY: nodes in an owned chain stay allocated while we walk it.
            let next = unsafe { (*last).next.load(Ordering::Relaxed) };
            if next.is_null() {
                break;
            }
            last = next;
        }
        // SAFETY: forwarded from the caller's contract.
        unsafe { self.defer_range(first, last) };
    }

    /// Prepends the owned run `first..=last` to the deferred-free list.
    ///
    /// # Safety
    /// The caller must exclusively own every node from `first` to `last`
    /// (following `next`), with payloads already moved out.
    unsafe fn defer_range(&self, first: *mut Node<T>, last: *mut Node<T>) {
        let mut current = self.to_be_deleted.load(Ordering::SeqCst);
        loop {
            // SAFETY: we own `last`; other threads at most perform atomic reads
            // of its `next` field.
            unsafe { (*last).next.store(current, Ordering::Relaxed) };
            match self.to_be_deleted.compare_exchange_weak(
                current,
                first,
                Ordering::SeqCst,
                Ordering::SeqCst,
            ) {
                Ok(_) => return,
                Err(actual) => current = actual,
            }
        }
    }

    /// Frees every node of a null-terminated chain without dropping payloads.
    ///
    /// # Safety
    /// No other thread may reference any node in the chain, and every payload
    /// must already have been moved out.
    unsafe fn free_chain(mut node: *mut Node<T>) {
        while !node.is_null() {
            // SAFETY: guaranteed exclusive by the caller; dropping a `Node` does
            // not drop its `ManuallyDrop` payload.
            let boxed = unsafe { Box::from_raw(node) };
            node = boxed.next.load(Ordering::Relaxed);
        }
    }
}

impl<T> Default for Stack<T> {
    /// Equivalent to [`Stack::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl<T> Drop for Stack<T> {
    /// Drops every element still on the stack and frees all deferred nodes.
    fn drop(&mut self) {
        // `&mut self` guarantees no concurrent push/pop, so plain traversal is fine.
        let mut node = *self.head.get_mut();
        while !node.is_null() {
            // SAFETY: nodes still on the stack are uniquely owned here and their
            // payloads have not been moved out.
            let Node { data, next } = *unsafe { Box::from_raw(node) };
            node = next.into_inner();
            drop(ManuallyDrop::into_inner(data));
        }
        // SAFETY: deferred nodes are unreachable by any thread now, and `pop`
        // already moved their payloads out.
        unsafe { Self::free_chain(*self.to_be_deleted.get_mut()) };
    }
}
91
92unsafe impl<T: Send> Send for Stack<T> {}
93unsafe impl<T: Send> Sync for Stack<T> {}
94
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::thread;

    #[test]
    fn test_single_thread() {
        let stack = Stack::new();

        assert!(stack.is_empty());
        assert_eq!(stack.pop(), None);

        stack.push(1);
        stack.push(2);
        stack.push(3);

        assert!(!stack.is_empty());
        assert_eq!(stack.pop(), Some(3));
        assert_eq!(stack.pop(), Some(2));
        assert_eq!(stack.pop(), Some(1));
        assert_eq!(stack.pop(), None);
        assert!(stack.is_empty());
    }

    /// Concurrent pushes only; the stack is drained on the main thread afterwards.
    #[test]
    fn test_concurrent_push_pop() {
        let stack = Arc::new(Stack::new());
        let num_threads = 8;
        let operations_per_thread = 10000;

        let mut handles = vec![];

        for i in 0..num_threads {
            let stack_clone = Arc::clone(&stack);
            let handle = thread::spawn(move || {
                for j in 0..operations_per_thread {
                    stack_clone.push(i * operations_per_thread + j);
                }
            });
            handles.push(handle);
        }

        for handle in handles {
            handle.join().unwrap();
        }

        let mut count = 0;
        while stack.pop().is_some() {
            count += 1;
        }

        assert_eq!(count, num_threads * operations_per_thread);
    }

    /// Pushes and pops race on every thread, so pops contend with each other and
    /// a node can be unlinked while another `pop` still holds a pointer to it.
    #[test]
    fn test_concurrent_interleaved_push_pop() {
        let stack = Arc::new(Stack::new());
        let num_threads = 8;
        let operations_per_thread = 5000;

        let handles: Vec<_> = (0..num_threads)
            .map(|i| {
                let stack_clone = Arc::clone(&stack);
                thread::spawn(move || {
                    let mut popped = Vec::new();
                    for j in 0..operations_per_thread {
                        stack_clone.push(i * operations_per_thread + j);
                        if let Some(value) = stack_clone.pop() {
                            popped.push(value);
                        }
                    }
                    popped
                })
            })
            .collect();

        let mut all: Vec<usize> = handles
            .into_iter()
            .flat_map(|handle| handle.join().unwrap())
            .collect();
        while let Some(value) = stack.pop() {
            all.push(value);
        }

        // Every pushed value must come out exactly once.
        all.sort_unstable();
        let expected: Vec<usize> = (0..num_threads * operations_per_thread).collect();
        assert_eq!(all, expected);
    }

    /// Dropping a non-empty stack must drop the values still stored in it.
    #[test]
    fn test_drop_releases_remaining_items() {
        let item = Arc::new(());
        {
            let stack = Stack::new();
            for _ in 0..3 {
                stack.push(Arc::clone(&item));
            }
            drop(stack.pop());
            assert_eq!(Arc::strong_count(&item), 3);
        }
        assert_eq!(Arc::strong_count(&item), 1);
    }
}