// queen_io/plus/mpmc_queue.rs

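// Bounded multi-producer multi-consumer (MPMC) queue, following the algorithm described at
// http://www.1024cores.net/home/lock-free-algorithms/queues/bounded-mpmc-queue
// This implementation was adapted from an older version of the Rust standard library.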
3
4use std::sync::Arc;
5use std::cell::UnsafeCell;
6
7use std::sync::atomic::AtomicUsize;
8use std::sync::atomic::Ordering::{Relaxed, Release, Acquire};
9
/// One slot of the ring buffer.
///
/// `sequence` tells producers and consumers which position the slot is ready
/// for: a producer at position `pos` may fill it when `sequence == pos`, and a
/// consumer at position `pos` may empty it when `sequence == pos + 1`.
/// `value` holds the element while the slot is occupied.
///
/// `Node` derives `Send`/`Sync` automatically from its fields, so no manual
/// (unsafe) impls are needed.
struct Node<T> {
    sequence: AtomicUsize,
    value: Option<T>,
}

/// Shared state behind every `Queue` handle.
///
/// `#[repr(C)]` keeps the fields in declaration order, so the `_pad*` arrays
/// really sit between the frequently written atomics. With the default
/// representation the compiler is free to reorder fields (and would group the
/// byte arrays together), leaving `enqueue_pos` and `dequeue_pos` adjacent and
/// subject to false sharing.
#[repr(C)]
struct State<T> {
    _pad0: [u8; 64],
    buffer: Vec<UnsafeCell<Node<T>>>,
    mask: usize,
    _pad1: [u8; 64],
    enqueue_pos: AtomicUsize,
    _pad2: [u8; 64],
    dequeue_pos: AtomicUsize,
    _pad3: [u8; 64],
}

// SAFETY: `State` owns the `T` values stored in its slots, so moving it to
// another thread only requires `T: Send`.
unsafe impl<T: Send> Send for State<T> {}
// SAFETY: a shared `&State<T>` only allows pushing and popping, i.e. moving
// `T` values in and out; access to each slot's `value` is serialized by the
// `sequence` protocol and no `&T` is ever handed out. `T: Send` is therefore
// sufficient (the same bound crossbeam's `ArrayQueue` uses).
unsafe impl<T: Send> Sync for State<T> {}

/// A bounded, lock-free multi-producer multi-consumer queue.
///
/// Cloning a `Queue` yields another handle to the same shared buffer.
pub struct Queue<T> {
    state: Arc<State<T>>,
}
35
36impl<T: Send> State<T> {
37    fn with_capacity(capacity: usize) -> State<T> {
38        let capacity = if capacity < 2 || (capacity & (capacity - 1)) != 0 {
39            if capacity < 2 {
40                2
41            } else {
42                // use next power of 2 as capacity
43                capacity.next_power_of_two()
44            }
45        } else {
46            capacity
47        };
48        let buffer = (0..capacity)
49            .map(|i| {
50                UnsafeCell::new(Node {
51                    sequence: AtomicUsize::new(i),
52                    value: None,
53                })
54            })
55            .collect::<Vec<_>>();
56        State {
57            _pad0: [0; 64],
58            buffer,
59            mask: capacity - 1,
60            _pad1: [0; 64],
61            enqueue_pos: AtomicUsize::new(0),
62            _pad2: [0; 64],
63            dequeue_pos: AtomicUsize::new(0),
64            _pad3: [0; 64],
65        }
66    }
67
68    fn push(&self, value: T) -> Result<(), T> {
69        let mask = self.mask;
70        let mut pos = self.enqueue_pos.load(Relaxed);
71        loop {
72            let node = &self.buffer[pos & mask];
73            let seq = unsafe { (*node.get()).sequence.load(Acquire) };
74            let diff: isize = seq as isize - pos as isize;
75
76            if diff == 0 {
77                let enqueue_pos = self.enqueue_pos.compare_and_swap(pos, pos + 1, Relaxed);
78                if enqueue_pos == pos {
79                    unsafe {
80                        (*node.get()).value = Some(value);
81                        (*node.get()).sequence.store(pos + 1, Release);
82                    }
83                    break;
84                } else {
85                    pos = enqueue_pos;
86                }
87            } else if diff < 0 {
88                return Err(value);
89            } else {
90                pos = self.enqueue_pos.load(Relaxed);
91            }
92        }
93        Ok(())
94    }
95
96    fn pop(&self) -> Option<T> {
97        let mask = self.mask;
98        let mut pos = self.dequeue_pos.load(Relaxed);
99        loop {
100            let node = &self.buffer[pos & mask];
101            let seq = unsafe { (*node.get()).sequence.load(Acquire) };
102            let diff: isize = seq as isize - (pos + 1) as isize;
103            if diff == 0 {
104                let dequeue_pos = self.dequeue_pos.compare_and_swap(pos, pos + 1, Relaxed);
105                if dequeue_pos == pos {
106                    unsafe {
107                        let value = (*node.get()).value.take();
108                        (*node.get()).sequence.store(pos + mask + 1, Release);
109                        return value;
110                    }
111                } else {
112                    pos = dequeue_pos;
113                }
114            } else if diff < 0 {
115                return None;
116            } else {
117                pos = self.dequeue_pos.load(Relaxed);
118            }
119        }
120    }
121}
122
123impl<T: Send> Queue<T> {
124    pub fn with_capacity(capacity: usize) -> Queue<T> {
125        Queue { state: Arc::new(State::with_capacity(capacity)) }
126    }
127
128    pub fn push(&self, value: T) -> Result<(), T> {
129        self.state.push(value)
130    }
131
132    pub fn pop(&self) -> Option<T> {
133        self.state.pop()
134    }
135}
136
137impl<T: Send> Clone for Queue<T> {
138    fn clone(&self) -> Queue<T> {
139        Queue { state: self.state.clone() }
140    }
141}
142
#[cfg(test)]
mod tests {
    use std::thread;
    use super::Queue;

    /// A new queue is empty, rounds its capacity up to a power of two, rejects
    /// pushes once full, and yields elements in FIFO order.
    #[test]
    fn single_thread_fifo_and_capacity() {
        let q = Queue::with_capacity(3); // rounded up to 4
        assert_eq!(None, q.pop());
        for i in 0..4 {
            assert!(q.push(i).is_ok());
        }
        assert_eq!(Err(4), q.push(4));
        for i in 0..4 {
            assert_eq!(Some(i), q.pop());
        }
        assert_eq!(None, q.pop());
    }

    /// Slots keep working after the positions have lapped the buffer many times.
    #[test]
    fn reuses_slots_many_times() {
        let q = Queue::with_capacity(2);
        for i in 0..1000 {
            assert!(q.push(i).is_ok());
            assert_eq!(Some(i), q.pop());
        }
        assert_eq!(None, q.pop());
    }

    /// Several producers and consumers run concurrently; every pushed element
    /// is popped exactly once.
    ///
    /// Threads are joined (rather than signalling over channels) so a panic in
    /// any worker fails the test instead of leaving the main thread blocked.
    #[test]
    fn concurrent_producers_and_consumers() {
        let nthreads = 8;
        let nmsgs = 1000;
        let q = Queue::with_capacity(nthreads * nmsgs);
        assert_eq!(None, q.pop());

        let producers: Vec<_> = (0..nthreads)
            .map(|_| {
                let q = q.clone();
                thread::spawn(move || {
                    for i in 0..nmsgs {
                        assert!(q.push(i).is_ok());
                    }
                })
            })
            .collect();

        let consumers: Vec<_> = (0..nthreads)
            .map(|_| {
                let q = q.clone();
                thread::spawn(move || {
                    let mut received = 0;
                    while received < nmsgs {
                        if q.pop().is_some() {
                            received += 1;
                        }
                    }
                    received
                })
            })
            .collect();

        for producer in producers {
            producer.join().unwrap();
        }
        for consumer in consumers {
            assert_eq!(nmsgs, consumer.join().unwrap());
        }
        assert_eq!(None, q.pop());
    }
}