use std::ptr;
use std::cell::UnsafeCell;
use std::sync::atomic::{AtomicPtr, AtomicUsize, Ordering};
use std::ops::{Deref, DerefMut};

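// `Aligner` is a zero-sized type with 64-byte alignment. Pairing a value with
// it in `CacheAligned` rounds the pair up to (an assumed) 64-byte cache line,
// so the producer and consumer halves of the queue below do not share a cache
// line and ping-pong it between cores on every push/pop.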
#[derive(Copy, Clone, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[repr(align(64))]
struct Aligner;

#[derive(Copy, Clone, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
struct CacheAligned<T>(pub T, pub Aligner);

impl<T> Deref for CacheAligned<T> {
    type Target = T;
    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl<T> DerefMut for CacheAligned<T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.0
    }
}

impl<T> CacheAligned<T> {
    fn new(t: T) -> Self {
        CacheAligned(t, Aligner)
    }
}

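/// A node in the singly-linked list backing the queue. Nodes are allocated by
/// the producer and, once consumed, can be recycled through the consumer's
/// node cache instead of being freed.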
struct Node<T> {
    value: Option<T>,         // `None` once popped, so the node can be reused
    cached: bool,             // whether this node goes into the node cache
    next: AtomicPtr<Node<T>>, // next node in the queue
}

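/// The single-producer single-consumer queue. This structure is not
/// cloneable, but it can be shared through an `Arc` provided that at most one
/// thread pushes and at most one thread pops at any point in time.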
pub struct Queue<T, ProducerAddition = (), ConsumerAddition = ()> {
    // consumer fields
    consumer: CacheAligned<Consumer<T, ConsumerAddition>>,

    // producer fields
    producer: CacheAligned<Producer<T, ProducerAddition>>,
}

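// State owned by the consumer. Only `tail_prev` is ever read by the producer;
// it marks the boundary up to which popped nodes may be reclaimed.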
struct Consumer<T, Addition> {
    tail: UnsafeCell<*mut Node<T>>, // where to pop from
    tail_prev: AtomicPtr<Node<T>>,  // where the producer can reclaim nodes from
    cache_bound: usize,             // maximum number of nodes to cache (0 = unbounded)
    cached_nodes: AtomicUsize,      // number of nodes marked as cacheable
    addition: Addition,
}

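// State owned by the producer. `tail_copy` is a possibly-stale snapshot of
// the consumer's `tail_prev`, refreshed only when the local recycling list
// looks empty, so most allocations avoid touching shared state at all.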
struct Producer<T, Addition> {
    head: UnsafeCell<*mut Node<T>>,      // where to push to
    first: UnsafeCell<*mut Node<T>>,     // where to get new nodes from
    tail_copy: UnsafeCell<*mut Node<T>>, // a cached copy of the consumer's tail_prev
    addition: Addition,
}

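// The queue hands `T` values from one thread to another, and both endpoints'
// additions are reachable through a shared `&Queue` (see `producer_addition`
// and `consumer_addition`), hence the `Send + Sync` bounds on the additions.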
unsafe impl<T: Send, P: Send + Sync, C: Send + Sync> Send for Queue<T, P, C> {}

unsafe impl<T: Send, P: Send + Sync, C: Send + Sync> Sync for Queue<T, P, C> {}

impl<T> Node<T> {
    fn new() -> *mut Node<T> {
        Box::into_raw(Box::new(Node {
            value: None,
            cached: false,
            next: AtomicPtr::new(ptr::null_mut::<Node<T>>()),
        }))
    }
}

impl<T, ProducerAddition, ConsumerAddition> Queue<T, ProducerAddition, ConsumerAddition> {
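    /// Creates a new queue. `bound` caps how many popped nodes the consumer
    /// will mark for recycling; a bound of 0 means the node cache is
    /// unbounded. The additions are extra per-endpoint state exposed through
    /// `producer_addition` and `consumer_addition`.
    ///
    /// This is unsafe because only one thread may push and only one thread
    /// may pop at any given time; nothing here enforces that statically.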
    pub unsafe fn with_additions(
        bound: usize,
        producer_addition: ProducerAddition,
        consumer_addition: ConsumerAddition,
    ) -> Self {
        // The queue starts with two linked nodes: `n2` is the empty sentinel
        // both endpoints point at, while `n1` is already retired; it is the
        // consumer's `tail_prev` and heads the producer's reuse list.
        let n1 = Node::new();
        let n2 = Node::new();
        (*n1).next.store(n2, Ordering::Relaxed);
        Queue {
            consumer: CacheAligned::new(Consumer {
                tail: UnsafeCell::new(n2),
                tail_prev: AtomicPtr::new(n1),
                cache_bound: bound,
                cached_nodes: AtomicUsize::new(0),
                addition: consumer_addition,
            }),
            producer: CacheAligned::new(Producer {
                head: UnsafeCell::new(n2),
                first: UnsafeCell::new(n1),
                tail_copy: UnsafeCell::new(n1),
                addition: producer_addition,
            }),
        }
    }

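    /// Pushes a new value onto this queue. For this to be safe, it must be
    /// externally guaranteed that there is only one pusher at a time (the
    /// contract taken on in `with_additions`).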
    pub fn push(&self, t: T) {
        unsafe {
            // Acquire a node (which either uses a cached one or allocates a
            // new one), then append it after the current `head` node.
            let n = self.alloc();
            assert!((*n).value.is_none());
            (*n).value = Some(t);
            (*n).next.store(ptr::null_mut(), Ordering::Relaxed);
            (**self.producer.head.get()).next.store(n, Ordering::Release);
            *self.producer.0.head.get() = n;
        }
    }

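    /// Fetches a node to push into: first from the producer's local recycling
    /// list, then (after refreshing `tail_copy` from the consumer) from newly
    /// released nodes, and only as a last resort from a fresh allocation.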
    unsafe fn alloc(&self) -> *mut Node<T> {
        // First try to see if we can consume the `first` node for our uses.
        if *self.producer.first.get() != *self.producer.tail_copy.get() {
            let ret = *self.producer.first.get();
            *self.producer.0.first.get() = (*ret).next.load(Ordering::Relaxed);
            return ret;
        }
        // If the above fails, then update our copy of the tail and try again.
        *self.producer.0.tail_copy.get() = self.consumer.tail_prev.load(Ordering::Acquire);
        if *self.producer.first.get() != *self.producer.tail_copy.get() {
            let ret = *self.producer.first.get();
            *self.producer.0.first.get() = (*ret).next.load(Ordering::Relaxed);
            return ret;
        }
        // If all of that fails, then we have to allocate a new node
        // (there's nothing in the node cache).
        Node::new()
    }

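    /// Attempts to pop a value from this queue. Remember that to use this
    /// type safely you must ensure that there is only one popper at a time.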
    pub fn pop(&self) -> Option<T> {
        unsafe {
            // The `tail` node is not actually a used node, but rather a
            // sentinel from where we should start popping from. Hence, look
            // at tail's next field and see if we can use it. If we do a pop,
            // the current tail node becomes a candidate for the node cache.
            let tail = *self.consumer.tail.get();
            let next = (*tail).next.load(Ordering::Acquire);
            if next.is_null() {
                return None;
            }
            assert!((*next).value.is_some());
            let ret = (*next).value.take();

            *self.consumer.0.tail.get() = next;
            if self.consumer.cache_bound == 0 {
                self.consumer.tail_prev.store(tail, Ordering::Release);
            } else {
                let cached_nodes = self.consumer.cached_nodes.load(Ordering::Relaxed);
                if cached_nodes < self.consumer.cache_bound && !(*tail).cached {
                    // Count the node we're about to cache; storing the
                    // unchanged value here would never advance the counter,
                    // and the cache bound would be ignored.
                    self.consumer.cached_nodes.store(cached_nodes + 1, Ordering::Relaxed);
                    (*tail).cached = true;
                }

                if (*tail).cached {
                    self.consumer.tail_prev.store(tail, Ordering::Release);
                } else {
                    // The cache is full: unlink `tail` from the recycling
                    // list; all references to it are now erased, so it can
                    // safely be freed.
                    (*self.consumer.tail_prev.load(Ordering::Relaxed))
                        .next
                        .store(next, Ordering::Relaxed);
                    let _: Box<Node<T>> = Box::from_raw(tail);
                }
            }
            ret
        }
    }

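    /// Attempts to peek at the head of the queue, returning `None` if the
    /// queue is currently empty. Handing out `&mut T` from `&self` is only
    /// sound because there is a single consumer.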
    pub fn peek(&self) -> Option<&mut T> {
        // This is essentially the same as `pop` with all the popping bits
        // stripped out.
        unsafe {
            let tail = *self.consumer.tail.get();
            let next = (*tail).next.load(Ordering::Acquire);
            if next.is_null() { None } else { (*next).value.as_mut() }
        }
    }

    pub fn producer_addition(&self) -> &ProducerAddition {
        &self.producer.addition
    }

    pub fn consumer_addition(&self) -> &ConsumerAddition {
        &self.consumer.addition
    }
}

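// Dropping the queue walks the list from `first` (the oldest node the
// producer knows about) through every live node; this frees cached nodes and
// drops any values that were pushed but never popped.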
impl<T, ProducerAddition, ConsumerAddition> Drop for Queue<T, ProducerAddition, ConsumerAddition> {
    fn drop(&mut self) {
        unsafe {
            let mut cur = *self.producer.first.get();
            while !cur.is_null() {
                let next = (*cur).next.load(Ordering::Relaxed);
                let _n: Box<Node<T>> = Box::from_raw(cur);
                cur = next;
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::Queue;
    use std::sync::Arc;
    use std::thread;
    use std::sync::mpsc::channel;

    #[test]
    fn smoke() {
        unsafe {
            let queue = Queue::with_additions(0, (), ());
            queue.push(1);
            queue.push(2);
            assert_eq!(queue.pop(), Some(1));
            assert_eq!(queue.pop(), Some(2));
            assert_eq!(queue.pop(), None);
            queue.push(3);
            queue.push(4);
            assert_eq!(queue.pop(), Some(3));
            assert_eq!(queue.pop(), Some(4));
            assert_eq!(queue.pop(), None);
        }
    }

    #[test]
    fn peek() {
        unsafe {
            let queue = Queue::with_additions(0, (), ());
            queue.push(vec![1]);

            // Ensure the borrow checker is happy with a mutable peek
            // followed by a pop.
            match queue.peek() {
                Some(vec) => {
                    assert_eq!(&*vec, &[1]);
                }
                None => unreachable!(),
            }

            match queue.pop() {
                Some(vec) => {
                    assert_eq!(&*vec, &[1]);
                }
                None => unreachable!(),
            }
        }
    }

    #[test]
    fn drop_full() {
        unsafe {
            let q: Queue<Box<_>> = Queue::with_additions(0, (), ());
            q.push(Box::new(1));
            q.push(Box::new(2));
        }
    }

    #[test]
    fn smoke_bound() {
        unsafe {
            // Same as `smoke`, but exercising a bounded node cache.
            let q = Queue::with_additions(1, (), ());
            q.push(1);
            q.push(2);
            assert_eq!(q.pop(), Some(1));
            assert_eq!(q.pop(), Some(2));
            assert_eq!(q.pop(), None);
            q.push(3);
            q.push(4);
            assert_eq!(q.pop(), Some(3));
            assert_eq!(q.pop(), Some(4));
            assert_eq!(q.pop(), None);
        }
    }

    #[test]
    fn stress() {
        unsafe {
            stress_bound(0);
            stress_bound(1);
        }

        unsafe fn stress_bound(bound: usize) {
            let q = Arc::new(Queue::with_additions(bound, (), ()));

            let (tx, rx) = channel();
            let q2 = q.clone();
            let _t = thread::spawn(move || {
                for _ in 0..100000 {
                    loop {
                        match q2.pop() {
                            Some(1) => break,
                            Some(_) => panic!(),
                            None => {}
                        }
                    }
                }
                tx.send(()).unwrap();
            });
            for _ in 0..100000 {
                q.push(1);
            }
            rx.recv().unwrap();
        }
    }
}