1#![allow(unsafe_code)]
6
7use super::{QueueError, QueueResult, QueueStats};
8use std::ptr;
9use std::sync::atomic::{AtomicPtr, AtomicUsize, Ordering};
10
/// One heap-allocated cell of the queue's singly-linked list.
struct Node<T> {
    /// Pointer to the following cell; null until a successor has been attached.
    next: AtomicPtr<Node<T>>,
    /// The element carried by this cell; `None` for the payload-less dummy cell
    /// (and for a cell whose element has already been taken).
    value: Option<T>,
}
16
17impl<T> Node<T> {
18 fn new(value: T) -> *mut Node<T> {
19 Box::into_raw(Box::new(Node {
20 value: Some(value),
21 next: AtomicPtr::new(ptr::null_mut()),
22 }))
23 }
24
25 fn stub() -> *mut Node<T> {
26 Box::into_raw(Box::new(Node {
27 value: None,
28 next: AtomicPtr::new(ptr::null_mut()),
29 }))
30 }
31}
32
33pub struct MpscQueue<T> {
38 head: AtomicPtr<Node<T>>,
40 tail: AtomicPtr<Node<T>>,
42 stats: QueueStats,
44 max_capacity: usize,
46 size: AtomicUsize,
48}
49
50impl<T> Default for MpscQueue<T> {
51 fn default() -> Self {
52 Self::new()
53 }
54}
55
56impl<T> MpscQueue<T> {
57 pub fn new() -> Self {
59 let stub = Node::<T>::stub();
60 Self {
61 head: AtomicPtr::new(stub),
62 tail: AtomicPtr::new(stub),
63 stats: QueueStats::new(),
64 max_capacity: 0,
65 size: AtomicUsize::new(0),
66 }
67 }
68
69 pub fn with_capacity(capacity: usize) -> Self {
71 let mut queue = Self::new();
72 queue.max_capacity = capacity;
73 queue
74 }
75
76 pub fn push(&self, value: T) -> QueueResult<()> {
78 if self.max_capacity > 0 {
80 let size = self.size.load(Ordering::Relaxed);
81 if size >= self.max_capacity {
82 self.stats.record_overflow();
83 return Err(QueueError::QueueFull);
84 }
85 }
86
87 let node = Node::new(value);
88 let mut tail = self.tail.load(Ordering::Acquire);
89
90 loop {
91 let next = unsafe { (*tail).next.load(Ordering::Acquire) };
92
93 if next.is_null() {
94 match unsafe {
96 (*tail).next.compare_exchange_weak(
97 ptr::null_mut(),
98 node,
99 Ordering::Release,
100 Ordering::Relaxed,
101 )
102 } {
103 Ok(_) => {
104 let _ = self.tail.compare_exchange(
106 tail,
107 node,
108 Ordering::Release,
109 Ordering::Relaxed,
110 );
111 self.size.fetch_add(1, Ordering::Relaxed);
112 self.stats.record_push(self.size());
113 return Ok(());
114 }
115 Err(new_next) => {
116 let _ = self.tail.compare_exchange(
118 tail,
119 new_next,
120 Ordering::Release,
121 Ordering::Relaxed,
122 );
123 tail = new_next;
124 }
125 }
126 } else {
127 let _ =
129 self.tail
130 .compare_exchange(tail, next, Ordering::Release, Ordering::Relaxed);
131 tail = next;
132 }
133 }
134 }
135
136 pub fn pop(&self) -> Option<T> {
138 loop {
139 let head = self.head.load(Ordering::Acquire);
140 let tail = self.tail.load(Ordering::Acquire);
141 let next = unsafe { (*head).next.load(Ordering::Acquire) };
142
143 if head == tail {
144 if next.is_null() {
145 return None;
146 }
147 let _ =
148 self.tail
149 .compare_exchange(tail, next, Ordering::Release, Ordering::Relaxed);
150 } else {
151 if next.is_null() {
152 continue;
153 }
154
155 if self
156 .head
157 .compare_exchange(head, next, Ordering::Release, Ordering::Relaxed)
158 .is_ok()
159 {
160 let value = unsafe { (*next).value.take() };
161 unsafe {
162 drop(Box::from_raw(head));
163 }
164 self.size.fetch_sub(1, Ordering::Relaxed);
165 self.stats.record_pop();
166 return value;
167 }
168 }
169 }
170 }
171
172 pub fn size(&self) -> usize {
174 self.size.load(Ordering::Relaxed)
175 }
176
177 pub fn capacity(&self) -> usize {
179 self.max_capacity
180 }
181
182 pub fn is_empty(&self) -> bool {
184 let head = self.head.load(Ordering::Acquire);
185 let tail = self.tail.load(Ordering::Acquire);
186 let next = unsafe { (*head).next.load(Ordering::Acquire) };
187
188 head == tail && next.is_null()
189 }
190}
191
192impl<T> Drop for MpscQueue<T> {
193 fn drop(&mut self) {
194 while self.pop().is_some() {}
195
196 let head = self.head.load(Ordering::Relaxed);
197 if !head.is_null() {
198 unsafe {
199 drop(Box::from_raw(head));
200 }
201 }
202 }
203}
204
205unsafe impl<T: Send> Send for MpscQueue<T> {}
206unsafe impl<T: Send> Sync for MpscQueue<T> {}
207
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::thread;

    #[test]
    fn test_mpsc_basic() {
        let queue = MpscQueue::new();
        assert!(queue.is_empty());

        queue.push(1).unwrap();
        queue.push(2).unwrap();
        queue.push(3).unwrap();
        assert!(!queue.is_empty());
        assert_eq!(queue.size(), 3);

        assert_eq!(queue.pop(), Some(1));
        assert_eq!(queue.pop(), Some(2));
        assert_eq!(queue.pop(), Some(3));
        assert_eq!(queue.pop(), None);
        assert!(queue.is_empty());
        assert_eq!(queue.size(), 0);
    }

    #[test]
    fn test_mpsc_multiple_producers() {
        const PRODUCERS: usize = 4;
        const PER_PRODUCER: usize = 250;

        let queue = Arc::new(MpscQueue::<usize>::new());
        let handles: Vec<_> = (0..PRODUCERS)
            .map(|producer| {
                let queue = Arc::clone(&queue);
                thread::spawn(move || {
                    for seq in 0..PER_PRODUCER {
                        queue.push(producer * 1000 + seq).unwrap();
                    }
                })
            })
            .collect();
        for handle in handles {
            handle.join().unwrap();
        }

        // Every element must come out exactly once, and each producer's elements must come
        // out in the order that producer pushed them.
        let mut next_seq = [0usize; PRODUCERS];
        let mut count = 0;
        while let Some(v) = queue.pop() {
            let (producer, seq) = (v / 1000, v % 1000);
            assert_eq!(seq, next_seq[producer], "producer {} out of order", producer);
            next_seq[producer] += 1;
            count += 1;
        }

        assert_eq!(count, PRODUCERS * PER_PRODUCER);
        assert!(next_seq.iter().all(|&n| n == PER_PRODUCER));
        assert!(queue.is_empty());
    }

    #[test]
    fn test_mpsc_bounded_capacity() {
        let queue = MpscQueue::with_capacity(2);
        assert_eq!(queue.capacity(), 2);

        queue.push(1).unwrap();
        queue.push(2).unwrap();
        assert!(matches!(queue.push(3), Err(QueueError::QueueFull)));
        assert_eq!(queue.size(), 2);

        // Popping frees a slot for another push.
        assert_eq!(queue.pop(), Some(1));
        queue.push(3).unwrap();
        assert_eq!(queue.pop(), Some(2));
        assert_eq!(queue.pop(), Some(3));
        assert_eq!(queue.pop(), None);
    }

    #[test]
    fn test_mpsc_drop_releases_pending_values() {
        let tracker = Arc::new(());
        {
            let queue = MpscQueue::new();
            queue.push(Arc::clone(&tracker)).unwrap();
            queue.push(Arc::clone(&tracker)).unwrap();
            assert_eq!(Arc::strong_count(&tracker), 3);
        }
        // Dropping the queue must drop the values that were never popped.
        assert_eq!(Arc::strong_count(&tracker), 1);
    }
}