1#![allow(unsafe_code)]
7
8use super::{QueueError, QueueResult, QueueStats};
9use std::ptr;
10use std::sync::atomic::{AtomicPtr, AtomicUsize, Ordering};
11
/// One cell of the queue's singly linked list.
///
/// Cells built with [`Node::new`] carry `Some(value)`; cells built with
/// [`Node::stub`] carry `None`. In both cases `next` starts out null.
struct Node<T> {
    /// Payload, if any.
    value: Option<T>,
    /// Successor cell, or null while nothing has been linked after this one.
    next: AtomicPtr<Node<T>>,
}

impl<T> Node<T> {
    /// Heap-allocates an unlinked cell holding `value` and hands back the raw pointer.
    ///
    /// The caller takes ownership and must eventually release it with `Box::from_raw`.
    fn leak_boxed(value: Option<T>) -> *mut Node<T> {
        let next = AtomicPtr::new(ptr::null_mut());
        let cell = Box::new(Node { value, next });
        Box::into_raw(cell)
    }

    /// Allocates a cell carrying `value`.
    fn new(value: T) -> *mut Node<T> {
        Self::leak_boxed(Some(value))
    }

    /// Allocates an empty (sentinel) cell.
    fn stub() -> *mut Node<T> {
        Self::leak_boxed(None)
    }
}
33
34pub struct MpscQueue<T> {
40 head: AtomicPtr<Node<T>>,
42 tail: AtomicPtr<Node<T>>,
44 stats: QueueStats,
46 max_capacity: usize,
48 size: AtomicUsize,
50}
51
52impl<T> Default for MpscQueue<T> {
53 fn default() -> Self {
54 Self::new()
55 }
56}
57
58impl<T> MpscQueue<T> {
59 pub fn new() -> Self {
61 let stub = Node::<T>::stub();
62 Self {
63 head: AtomicPtr::new(stub),
64 tail: AtomicPtr::new(stub),
65 stats: QueueStats::new(),
66 max_capacity: 0,
67 size: AtomicUsize::new(0),
68 }
69 }
70
71 pub fn with_capacity(capacity: usize) -> Self {
73 let mut queue = Self::new();
74 queue.max_capacity = capacity;
75 queue
76 }
77
78 pub fn push(&self, value: T) -> QueueResult<()> {
80 if self.max_capacity > 0 {
82 let size = self.size.load(Ordering::Relaxed);
83 if size >= self.max_capacity {
84 self.stats.record_overflow();
85 return Err(QueueError::QueueFull);
86 }
87 }
88
89 let node = Node::new(value);
90 let mut tail = self.tail.load(Ordering::Acquire);
91
92 loop {
93 let next = unsafe { (*tail).next.load(Ordering::Acquire) };
94
95 if next.is_null() {
96 match unsafe {
98 (*tail).next.compare_exchange_weak(
99 ptr::null_mut(),
100 node,
101 Ordering::Release,
102 Ordering::Relaxed,
103 )
104 } {
105 Ok(_) => {
106 let _ = self.tail.compare_exchange(
108 tail,
109 node,
110 Ordering::Release,
111 Ordering::Relaxed,
112 );
113 self.size.fetch_add(1, Ordering::Relaxed);
114 self.stats.record_push(self.size());
115 return Ok(());
116 }
117 Err(new_next) => {
118 let _ = self.tail.compare_exchange(
120 tail,
121 new_next,
122 Ordering::Release,
123 Ordering::Relaxed,
124 );
125 tail = new_next;
126 }
127 }
128 } else {
129 let _ =
131 self.tail
132 .compare_exchange(tail, next, Ordering::Release, Ordering::Relaxed);
133 tail = next;
134 }
135 }
136 }
137
138 pub fn pop(&self) -> Option<T> {
140 loop {
141 let head = self.head.load(Ordering::Acquire);
142 let tail = self.tail.load(Ordering::Acquire);
143 let next = unsafe { (*head).next.load(Ordering::Acquire) };
144
145 if head == tail {
146 if next.is_null() {
147 return None;
148 }
149 let _ =
150 self.tail
151 .compare_exchange(tail, next, Ordering::Release, Ordering::Relaxed);
152 } else {
153 if next.is_null() {
154 continue;
155 }
156
157 if self
158 .head
159 .compare_exchange(head, next, Ordering::Release, Ordering::Relaxed)
160 .is_ok()
161 {
162 let value = unsafe { (*next).value.take() };
163 unsafe {
164 drop(Box::from_raw(head));
165 }
166 self.size.fetch_sub(1, Ordering::Relaxed);
167 self.stats.record_pop();
168 return value;
169 }
170 }
171 }
172 }
173
174 pub fn size(&self) -> usize {
176 self.size.load(Ordering::Relaxed)
177 }
178
179 pub fn capacity(&self) -> usize {
181 self.max_capacity
182 }
183
184 pub fn is_empty(&self) -> bool {
186 let head = self.head.load(Ordering::Acquire);
187 let tail = self.tail.load(Ordering::Acquire);
188 let next = unsafe { (*head).next.load(Ordering::Acquire) };
189
190 head == tail && next.is_null()
191 }
192}
193
194impl<T> Drop for MpscQueue<T> {
195 fn drop(&mut self) {
196 while self.pop().is_some() {}
197
198 let head = self.head.load(Ordering::Relaxed);
199 if !head.is_null() {
200 unsafe {
201 drop(Box::from_raw(head));
202 }
203 }
204 }
205}
206
207unsafe impl<T: Send> Send for MpscQueue<T> {}
208unsafe impl<T: Send> Sync for MpscQueue<T> {}
209
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::thread;

    #[test]
    fn test_mpsc_basic() {
        let queue = MpscQueue::new();

        queue.push(1).unwrap();
        queue.push(2).unwrap();
        queue.push(3).unwrap();

        assert_eq!(queue.pop(), Some(1));
        assert_eq!(queue.pop(), Some(2));
        assert_eq!(queue.pop(), Some(3));
        assert_eq!(queue.pop(), None);
    }

    #[test]
    fn test_mpsc_multiple_producers() {
        let queue = Arc::new(MpscQueue::new());
        let mut handles = vec![];

        for i in 0..4 {
            let queue = Arc::clone(&queue);
            handles.push(thread::spawn(move || {
                for j in 0..250 {
                    queue.push(i * 1000 + j).unwrap();
                }
            }));
        }

        for handle in handles {
            handle.join().unwrap();
        }

        let mut count = 0;
        while queue.pop().is_some() {
            count += 1;
        }

        assert_eq!(count, 1000);
    }

    /// A fresh queue is empty, and `size`/`is_empty` track a push/pop round trip.
    #[test]
    fn test_mpsc_empty_queue() {
        let queue: MpscQueue<i32> = MpscQueue::new();
        assert!(queue.is_empty());
        assert_eq!(queue.size(), 0);
        assert_eq!(queue.pop(), None);

        queue.push(7).unwrap();
        assert!(!queue.is_empty());
        assert_eq!(queue.size(), 1);

        assert_eq!(queue.pop(), Some(7));
        assert!(queue.is_empty());
        assert_eq!(queue.size(), 0);
    }

    /// A bounded queue rejects pushes at capacity and accepts them again after a pop.
    #[test]
    fn test_mpsc_capacity_limit() {
        let queue = MpscQueue::with_capacity(2);
        assert_eq!(queue.capacity(), 2);

        queue.push(1).unwrap();
        queue.push(2).unwrap();
        assert!(matches!(queue.push(3), Err(QueueError::QueueFull)));
        assert_eq!(queue.size(), 2);

        assert_eq!(queue.pop(), Some(1));
        queue.push(3).unwrap();
        assert_eq!(queue.pop(), Some(2));
        assert_eq!(queue.pop(), Some(3));
        assert_eq!(queue.pop(), None);
    }

    /// Dropping a non-empty queue drops the elements still inside it.
    #[test]
    fn test_mpsc_drop_releases_pending_items() {
        let tracker = Arc::new(());
        {
            let queue = MpscQueue::new();
            for _ in 0..3 {
                queue.push(Arc::clone(&tracker)).unwrap();
            }
            assert_eq!(Arc::strong_count(&tracker), 4);
        }
        assert_eq!(Arc::strong_count(&tracker), 1);
    }
}