1use core::{ptr, sync::atomic};
32
33#[derive(Debug)]
49pub struct MPSCQueue<T> {
50 head: atomic::AtomicPtr<Node<T>>,
51}
52
53impl<T> MPSCQueue<T> {
54 #[inline(always)]
56 pub fn is_empty(&self) -> bool {
57 self.head.load(atomic::Ordering::Acquire).is_null()
58 }
59
60 #[inline(always)]
78 pub fn push(&self, value: T) {
79 let mut head = self.head.load(atomic::Ordering::Relaxed);
80 let node = Node::new(value);
81
82 loop {
83 unsafe { (*node).next = head };
84
85 match self
86 .head
87 .compare_exchange_weak(head, node, atomic::Ordering::AcqRel, atomic::Ordering::Relaxed)
88 {
89 Ok(_) => return,
90 Err(h) => head = h,
91 }
92 }
93 }
94
95 #[inline(always)]
119 pub fn drain(&self) -> Vec<T> {
120 let mut out = Vec::new();
121 let mut node = self.head.swap(ptr::null_mut(), atomic::Ordering::AcqRel);
122
123 while !node.is_null() {
124 let boxed = unsafe { Box::from_raw(node) };
125 node = boxed.next;
126 out.push(boxed.value);
127 }
128
129 out
130 }
131}
132
133impl<T> Default for MPSCQueue<T> {
134 fn default() -> Self {
135 Self {
136 head: atomic::AtomicPtr::new(ptr::null_mut()),
137 }
138 }
139}
140
141impl<T> Drop for MPSCQueue<T> {
142 fn drop(&mut self) {
143 let mut node = self.head.swap(ptr::null_mut(), atomic::Ordering::Relaxed);
144 while !node.is_null() {
145 unsafe {
146 let boxed = Box::from_raw(node);
147 node = boxed.next;
148 }
149 }
150 }
151}
152
/// One heap-allocated link of the queue's singly-linked list.
struct Node<T> {
    /// The value carried by this link.
    value: T,
    /// The next (older) node, or null if this is the last node in the list.
    next: *mut Node<T>,
}
157
158impl<T> Node<T> {
159 fn new(value: T) -> *mut Self {
160 Box::into_raw(Box::new(Self {
161 next: ptr::null_mut(),
162 value,
163 }))
164 }
165}
166
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
    use std::sync::{Arc, Barrier};
    use std::thread;

    /// Checks that `batch` holds, for every producer id in `0..producers`,
    /// exactly the values `0..per_producer`, with each producer's values
    /// newest-first.
    ///
    /// Every producer pushes its values in increasing order, and `drain`
    /// returns the most recently pushed value first. Any lost, duplicated, or
    /// reordered element therefore shows up as a mismatch here, which a plain
    /// length check would miss.
    fn assert_complete_lifo_batch(
        batch: &[(usize, usize)],
        producers: usize,
        per_producer: usize,
    ) {
        assert_eq!(batch.len(), producers * per_producer);

        let mut by_producer: Vec<Vec<usize>> = vec![Vec::new(); producers];
        for &(producer, value) in batch {
            by_producer[producer].push(value);
        }

        let expected: Vec<usize> = (0..per_producer).rev().collect();
        for (producer, values) in by_producer.iter().enumerate() {
            // `assert!` rather than `assert_eq!` to avoid dumping huge vectors.
            assert!(
                *values == expected,
                "producer {} lost, duplicated, or reordered values",
                producer
            );
        }
    }

    mod basics {
        use super::*;

        #[test]
        fn ok_push_drain_single() {
            let q = MPSCQueue::default();
            q.push(1usize);

            let batch = q.drain();
            assert_eq!(batch, vec![1]);
        }

        #[test]
        fn ok_push_drain_multiple() {
            let q = MPSCQueue::default();

            q.push(1);
            q.push(2);
            q.push(3);

            let batch = q.drain();
            assert_eq!(batch.len(), 3);
            assert_eq!(batch, vec![3, 2, 1]);
        }

        #[test]
        fn ok_drain_empty_when_queue_empty() {
            let q: MPSCQueue<usize> = MPSCQueue::default();
            let batch = q.drain();
            assert!(batch.is_empty());
        }
    }

    mod empty {
        use super::*;

        #[test]
        fn ok_is_empty_true_on_init() {
            let q: MPSCQueue<usize> = MPSCQueue::default();
            assert!(q.is_empty());
        }

        #[test]
        fn ok_is_empty_false_after_push() {
            let q = MPSCQueue::default();
            q.push(1);
            assert!(!q.is_empty());
        }

        #[test]
        fn ok_is_empty_true_after_drain() {
            let q = MPSCQueue::default();

            q.push(1);
            q.push(2);

            let _ = q.drain();
            assert!(q.is_empty());
        }
    }

    mod cycles {
        use super::*;

        #[test]
        fn ok_single_push_drain_cycles() {
            let q = MPSCQueue::default();
            for i in 0..0x400 {
                q.push(i);
                let batch = q.drain();

                assert_eq!(batch.len(), 1);
                assert_eq!(batch[0], i);
            }
        }

        #[test]
        fn ok_multi_push_drain_cycles() {
            let q = MPSCQueue::default();
            let expected: Vec<usize> = (0..0x0A_usize).rev().collect();

            for _ in 0..0x200 {
                for i in 0..0x0A_usize {
                    q.push(i);
                }

                // Check contents, not just the count: each batch must be the
                // full set of values in LIFO order.
                let batch = q.drain();
                assert_eq!(batch, expected);
                assert!(q.is_empty());
            }
        }
    }

    mod ownership {
        use super::*;

        /// Increments a shared counter when dropped, so tests can observe how
        /// many queued values were released.
        struct DropCounter(Arc<AtomicUsize>);

        impl Drop for DropCounter {
            fn drop(&mut self) {
                self.0.fetch_add(1, Ordering::SeqCst);
            }
        }

        #[test]
        fn ok_drop_releases_undrained_values() {
            let drops = Arc::new(AtomicUsize::new(0));
            let q = MPSCQueue::default();
            for _ in 0..5 {
                q.push(DropCounter(Arc::clone(&drops)));
            }

            assert_eq!(drops.load(Ordering::SeqCst), 0);
            drop(q);
            assert_eq!(drops.load(Ordering::SeqCst), 5);
        }

        #[test]
        fn ok_drained_values_owned_by_caller() {
            let drops = Arc::new(AtomicUsize::new(0));
            let q = MPSCQueue::default();
            for _ in 0..3 {
                q.push(DropCounter(Arc::clone(&drops)));
            }

            let batch = q.drain();
            drop(q);
            // The queue no longer owns the drained values.
            assert_eq!(drops.load(Ordering::SeqCst), 0);

            drop(batch);
            assert_eq!(drops.load(Ordering::SeqCst), 3);
        }

        #[test]
        fn ok_queue_is_send_and_sync_for_send_values() {
            fn assert_send_sync<Q: Send + Sync>() {}
            assert_send_sync::<MPSCQueue<usize>>();
        }
    }

    mod concurrency {
        use super::*;

        const THREADS: usize = 0x0A;
        const ITERS: usize = 0x2000;

        #[test]
        fn ok_multi_tx_push() {
            let q = Arc::new(MPSCQueue::default());

            // Tag every value with its producer id so the result can be checked
            // per producer, not just counted.
            let handles: Vec<_> = (0..THREADS)
                .map(|producer| {
                    let q = Arc::clone(&q);
                    thread::spawn(move || {
                        for i in 0..ITERS {
                            q.push((producer, i));
                        }
                    })
                })
                .collect();

            for h in handles {
                h.join().unwrap();
            }

            let batch = q.drain();
            assert_complete_lifo_batch(&batch, THREADS, ITERS);
            assert!(q.is_empty());
        }

        #[test]
        fn ok_multi_tx_push_high_contention() {
            const PRODUCERS: usize = THREADS * 2;
            const PER_PRODUCER: usize = ITERS * 2;

            let q = Arc::new(MPSCQueue::default());
            let barrier = Arc::new(Barrier::new(PRODUCERS));

            let handles: Vec<_> = (0..PRODUCERS)
                .map(|producer| {
                    let q = Arc::clone(&q);
                    let barrier = Arc::clone(&barrier);
                    thread::spawn(move || {
                        // Release every producer at once to maximise CAS contention.
                        barrier.wait();

                        for i in 0..PER_PRODUCER {
                            q.push((producer, i));
                        }
                    })
                })
                .collect();

            for h in handles {
                h.join().unwrap();
            }

            let batch = q.drain();
            assert_complete_lifo_batch(&batch, PRODUCERS, PER_PRODUCER);
            assert!(q.is_empty());
        }

        #[test]
        fn ok_multi_tx_push_drain() {
            const TOTAL: usize = 0x8000;

            let q = Arc::new(MPSCQueue::default());
            let done = Arc::new(AtomicBool::new(false));

            let producer = {
                let q = Arc::clone(&q);
                let done = Arc::clone(&done);
                thread::spawn(move || {
                    for i in 0..TOTAL {
                        q.push(i);
                    }
                    done.store(true, Ordering::Release);
                })
            };

            let consumer = {
                let q = Arc::clone(&q);
                let done = Arc::clone(&done);
                thread::spawn(move || {
                    let mut seen = Vec::with_capacity(TOTAL);
                    loop {
                        // Read the flag *before* draining. If the producer had
                        // already finished, this drain observes every remaining
                        // push, so a lost element fails the assertion below
                        // instead of spinning forever.
                        let finished = done.load(Ordering::Acquire);
                        let batch = q.drain();

                        // A single producer pushes increasing values, so each
                        // LIFO batch must be strictly decreasing.
                        assert!(batch.windows(2).all(|w| w[0] > w[1]));
                        seen.extend(batch);

                        if finished {
                            break seen;
                        }
                    }
                })
            };

            producer.join().unwrap();
            let mut seen = consumer.join().unwrap();

            seen.sort_unstable();
            assert_eq!(seen, (0..TOTAL).collect::<Vec<_>>());
        }
    }
}