1#![allow(unsafe_code)]
7
8use super::{QueueError, QueueResult, QueueStats};
9use std::ptr;
10use std::sync::atomic::{AtomicPtr, AtomicUsize, Ordering};
11
/// Heap-allocated link of the queue's singly linked list, handled through raw pointers.
struct Node<T> {
    /// Payload of this node; `None` for a stub node and after the value has been taken out.
    value: Option<T>,
    /// The following node in the list, or null while nothing has been linked after this one.
    next: AtomicPtr<Node<T>>,
}

impl<T> Node<T> {
    /// Boxes a node carrying `payload` with a null `next` and releases ownership of the
    /// allocation as a raw pointer. The caller must eventually reclaim it with
    /// `Box::from_raw`.
    fn alloc(payload: Option<T>) -> *mut Node<T> {
        let boxed = Box::new(Node {
            value: payload,
            next: AtomicPtr::new(ptr::null_mut()),
        });
        Box::into_raw(boxed)
    }

    /// Allocates a node that holds `value`.
    fn new(value: T) -> *mut Node<T> {
        Self::alloc(Some(value))
    }

    /// Allocates a value-less placeholder node.
    fn stub() -> *mut Node<T> {
        Self::alloc(None)
    }
}
33
/// Linked-list queue intended for multiple producers and a single consumer.
///
/// The list always contains at least one node: `new` installs a value-less stub, and the
/// node `head` points at never carries a live value — the first queued element is stored
/// in the node after it.
pub struct MpscQueue<T> {
    /// Consumer end: the placeholder node that precedes the first queued element.
    head: AtomicPtr<Node<T>>,
    /// Producer end: the node that `push` links new nodes after.
    tail: AtomicPtr<Node<T>>,
    /// Counters updated on every push, pop and rejected push.
    stats: QueueStats,
    /// Maximum number of elements; `0` disables the limit.
    max_capacity: usize,
    /// Running count of queued elements.
    size: AtomicUsize,
}
51
52impl<T> MpscQueue<T> {
53 pub fn new() -> Self {
55 let stub = Node::<T>::stub();
56 Self {
57 head: AtomicPtr::new(stub),
58 tail: AtomicPtr::new(stub),
59 stats: QueueStats::new(),
60 max_capacity: 0,
61 size: AtomicUsize::new(0),
62 }
63 }
64
65 pub fn with_capacity(capacity: usize) -> Self {
67 let mut queue = Self::new();
68 queue.max_capacity = capacity;
69 queue
70 }
71
72 pub fn push(&self, value: T) -> QueueResult<()> {
74 if self.max_capacity > 0 {
76 let size = self.size.load(Ordering::Relaxed);
77 if size >= self.max_capacity {
78 self.stats.record_overflow();
79 return Err(QueueError::QueueFull);
80 }
81 }
82
83 let node = Node::new(value);
84 let mut tail = self.tail.load(Ordering::Acquire);
85
86 loop {
87 let next = unsafe { (*tail).next.load(Ordering::Acquire) };
88
89 if next.is_null() {
90 match unsafe {
92 (*tail).next.compare_exchange_weak(
93 ptr::null_mut(),
94 node,
95 Ordering::Release,
96 Ordering::Relaxed,
97 )
98 } {
99 Ok(_) => {
100 let _ = self.tail.compare_exchange(
102 tail,
103 node,
104 Ordering::Release,
105 Ordering::Relaxed,
106 );
107 self.size.fetch_add(1, Ordering::Relaxed);
108 self.stats.record_push(self.size());
109 return Ok(());
110 }
111 Err(new_next) => {
112 let _ = self.tail.compare_exchange(
114 tail,
115 new_next,
116 Ordering::Release,
117 Ordering::Relaxed,
118 );
119 tail = new_next;
120 }
121 }
122 } else {
123 let _ =
125 self.tail
126 .compare_exchange(tail, next, Ordering::Release, Ordering::Relaxed);
127 tail = next;
128 }
129 }
130 }
131
132 pub fn pop(&self) -> Option<T> {
134 loop {
135 let head = self.head.load(Ordering::Acquire);
136 let tail = self.tail.load(Ordering::Acquire);
137 let next = unsafe { (*head).next.load(Ordering::Acquire) };
138
139 if head == tail {
140 if next.is_null() {
141 return None;
142 }
143 let _ =
144 self.tail
145 .compare_exchange(tail, next, Ordering::Release, Ordering::Relaxed);
146 } else {
147 if next.is_null() {
148 continue;
149 }
150
151 if self
152 .head
153 .compare_exchange(head, next, Ordering::Release, Ordering::Relaxed)
154 .is_ok()
155 {
156 let value = unsafe { (*next).value.take() };
157 unsafe {
158 drop(Box::from_raw(head));
159 }
160 self.size.fetch_sub(1, Ordering::Relaxed);
161 self.stats.record_pop();
162 return value;
163 }
164 }
165 }
166 }
167
168 pub fn size(&self) -> usize {
170 self.size.load(Ordering::Relaxed)
171 }
172
173 pub fn capacity(&self) -> usize {
175 self.max_capacity
176 }
177
178 pub fn is_empty(&self) -> bool {
180 let head = self.head.load(Ordering::Acquire);
181 let tail = self.tail.load(Ordering::Acquire);
182 let next = unsafe { (*head).next.load(Ordering::Acquire) };
183
184 head == tail && next.is_null()
185 }
186}
187
188impl<T> Drop for MpscQueue<T> {
189 fn drop(&mut self) {
190 while let Some(_) = self.pop() {}
191
192 let head = self.head.load(Ordering::Relaxed);
193 if !head.is_null() {
194 unsafe {
195 drop(Box::from_raw(head));
196 }
197 }
198 }
199}
200
// SAFETY (as relied on by these impls): a `T` is moved in by the pushing thread and moved
// out by the popping thread, hence the `T: Send` bound; the list pointers and the size
// counter are only accessed through atomics.
// NOTE(review): these impls also vouch for the `stats: QueueStats` field, whose
// thread-safety is not visible in this file — confirm it may be shared across threads.
// NOTE(review): `pop` takes `&self`, so `Sync` lets safe code call it from several threads
// at once, while it frees the node it unlinks — confirm callers keep to a single consumer.
unsafe impl<T: Send> Send for MpscQueue<T> {}
unsafe impl<T: Send> Sync for MpscQueue<T> {}
203
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    /// Elements pushed from one thread come back in insertion order, followed by `None`.
    #[test]
    fn test_mpsc_basic() {
        let q = MpscQueue::new();

        for item in 1..=3 {
            q.push(item).unwrap();
        }

        for expected in 1..=3 {
            assert_eq!(q.pop(), Some(expected));
        }
        assert_eq!(q.pop(), None);
    }

    /// Four producers push 250 items each; after they finish, exactly 1000 items pop out.
    #[test]
    fn test_mpsc_multiple_producers() {
        let shared = std::sync::Arc::new(MpscQueue::new());

        let mut workers = Vec::new();
        for producer in 0..4 {
            let q = shared.clone();
            let worker = thread::spawn(move || {
                for seq in 0..250 {
                    q.push(producer * 1000 + seq).unwrap();
                }
            });
            workers.push(worker);
        }

        for worker in workers {
            worker.join().unwrap();
        }

        let mut drained = 0;
        loop {
            match shared.pop() {
                Some(_) => drained += 1,
                None => break,
            }
        }

        assert_eq!(drained, 1000);
    }
}