// queen_io/plus/mpmc_queue.rs
use std::cell::UnsafeCell;
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering::{Acquire, Relaxed, Release};
use std::sync::Arc;

/// One slot of the ring buffer.
///
/// `sequence` drives the slot's life cycle (Vyukov's bounded MPMC scheme),
/// relative to the queue position `pos` that maps onto this slot:
/// * `sequence == pos`            — empty, may be filled by the producer that
///   claims enqueue position `pos`;
/// * `sequence == pos + 1`        — full, may be emptied by the consumer that
///   claims dequeue position `pos`;
/// * `sequence == pos + capacity` — emptied again, ready for the next lap.
///
/// All arithmetic on positions and sequences is modular (wrapping).
///
/// No manual `Send`/`Sync` impls are needed: `AtomicUsize` is `Send + Sync`,
/// so the auto traits already give `Node<T>: Send` iff `T: Send` and
/// `Node<T>: Sync` iff `T: Sync`.
struct Node<T> {
    sequence: AtomicUsize,
    value: Option<T>,
}

/// Shared state of the bounded multi-producer/multi-consumer queue.
///
/// `#[repr(C)]` pins the fields in declaration order. Without it the compiler
/// is free to reorder fields, and the padding arrays would no longer be
/// guaranteed to sit between `enqueue_pos` and `dequeue_pos`. The two
/// counters could then share a cache line (false sharing between producers
/// and consumers). The 64-byte pads presumably target a 64-byte cache line.
#[repr(C)]
struct State<T> {
    _pad0: [u8; 64],
    /// The ring buffer; its length is always a power of two (>= 2).
    buffer: Vec<UnsafeCell<Node<T>>>,
    /// `buffer.len() - 1`; `pos & mask` maps a position to a slot index.
    mask: usize,
    _pad1: [u8; 64],
    /// Next position a producer will try to claim.
    enqueue_pos: AtomicUsize,
    _pad2: [u8; 64],
    /// Next position a consumer will try to claim.
    dequeue_pos: AtomicUsize,
    _pad3: [u8; 64],
}

// SAFETY: the `UnsafeCell`s make `State` `!Sync` by default. Through a shared
// `&State<T>` threads can only move `T` values in (`push`) and out (`pop`).
// Access to each slot's `value` is serialized by the `sequence` protocol:
// a slot is written only by the producer that won the CAS on `enqueue_pos`
// for it, and read only by the consumer that won the CAS on `dequeue_pos`
// after observing the producer's `Release` store with `Acquire`. No `&T` is
// ever handed out, so `T: Send` suffices for both traits (the same bound
// `std::sync::Mutex<T>` uses).
unsafe impl<T: Send> Send for State<T> {}
unsafe impl<T: Send> Sync for State<T> {}

/// A bounded multi-producer/multi-consumer FIFO queue built on atomics
/// (no mutexes). Cloning a `Queue` yields another handle to the same buffer.
pub struct Queue<T> {
    state: Arc<State<T>>,
}

impl<T: Send> State<T> {
    /// Creates an empty state able to hold `capacity` elements, rounded up to
    /// the next power of two (minimum 2) so positions map to slots via a mask.
    ///
    /// # Panics
    /// If the rounded capacity does not fit in a `usize`.
    fn with_capacity(capacity: usize) -> State<T> {
        let capacity = capacity
            .max(2)
            .checked_next_power_of_two()
            .expect("queue capacity overflows usize");
        // Slot `i` starts out empty for position `i`.
        let buffer = (0..capacity)
            .map(|i| {
                UnsafeCell::new(Node {
                    sequence: AtomicUsize::new(i),
                    value: None,
                })
            })
            .collect();
        State {
            _pad0: [0; 64],
            buffer,
            mask: capacity - 1,
            _pad1: [0; 64],
            enqueue_pos: AtomicUsize::new(0),
            _pad2: [0; 64],
            dequeue_pos: AtomicUsize::new(0),
            _pad3: [0; 64],
        }
    }

    /// Appends `value` at the tail, or hands it back as `Err(value)` if the
    /// queue is full.
    fn push(&self, value: T) -> Result<(), T> {
        let mut pos = self.enqueue_pos.load(Relaxed);
        loop {
            let node = &self.buffer[pos & self.mask];
            // SAFETY: `sequence` is atomic; reading it through the cell is
            // fine while other threads operate on the same slot.
            let seq = unsafe { (*node.get()).sequence.load(Acquire) };
            // Positions wrap around `usize`, so compare them modularly; the
            // true distance is bounded by the capacity, so the signed
            // reinterpretation is exact.
            let diff = seq.wrapping_sub(pos) as isize;

            if diff == 0 {
                // Slot is empty for `pos`: try to claim the position.
                match self.enqueue_pos.compare_exchange_weak(
                    pos,
                    pos.wrapping_add(1),
                    Relaxed,
                    Relaxed,
                ) {
                    Ok(_) => {
                        // SAFETY: winning the CAS gives this thread exclusive
                        // access to the slot's `value` until the `Release`
                        // store below publishes it to consumers.
                        unsafe {
                            (*node.get()).value = Some(value);
                            (*node.get()).sequence.store(pos.wrapping_add(1), Release);
                        }
                        return Ok(());
                    }
                    // Lost the race (or a spurious failure): retry from the
                    // position that is actually current.
                    Err(current) => pos = current,
                }
            } else if diff < 0 {
                // The slot still holds a value from the previous lap: full.
                return Err(value);
            } else {
                // Another producer already filled `pos`; our snapshot is stale.
                pos = self.enqueue_pos.load(Relaxed);
            }
        }
    }

    /// Removes the value at the head, or returns `None` if the queue is empty.
    fn pop(&self) -> Option<T> {
        let mut pos = self.dequeue_pos.load(Relaxed);
        loop {
            let node = &self.buffer[pos & self.mask];
            // SAFETY: atomic read, see `push`.
            let seq = unsafe { (*node.get()).sequence.load(Acquire) };
            let diff = seq.wrapping_sub(pos.wrapping_add(1)) as isize;

            if diff == 0 {
                // Slot holds the value written at `pos`: try to claim it.
                match self.dequeue_pos.compare_exchange_weak(
                    pos,
                    pos.wrapping_add(1),
                    Relaxed,
                    Relaxed,
                ) {
                    Ok(_) => {
                        // SAFETY: winning the CAS gives this thread exclusive
                        // access to the slot's `value`; the `Release` store
                        // then hands the slot to the producer of the next lap.
                        unsafe {
                            let value = (*node.get()).value.take();
                            (*node.get())
                                .sequence
                                .store(pos.wrapping_add(self.mask + 1), Release);
                            return value;
                        }
                    }
                    Err(current) => pos = current,
                }
            } else if diff < 0 {
                // The slot has not been filled for `pos` yet: empty.
                return None;
            } else {
                // Another consumer already took `pos`; our snapshot is stale.
                pos = self.dequeue_pos.load(Relaxed);
            }
        }
    }
}

123impl<T: Send> Queue<T> {
124 pub fn with_capacity(capacity: usize) -> Queue<T> {
125 Queue { state: Arc::new(State::with_capacity(capacity)) }
126 }
127
128 pub fn push(&self, value: T) -> Result<(), T> {
129 self.state.push(value)
130 }
131
132 pub fn pop(&self) -> Option<T> {
133 self.state.pop()
134 }
135}
136
137impl<T: Send> Clone for Queue<T> {
138 fn clone(&self) -> Queue<T> {
139 Queue { state: self.state.clone() }
140 }
141}
142
#[cfg(test)]
mod tests {
    use std::thread;

    use super::Queue;

    #[test]
    fn pop_on_empty_queue_returns_none() {
        let q: Queue<u32> = Queue::with_capacity(4);
        assert_eq!(None, q.pop());
    }

    #[test]
    fn push_on_full_queue_hands_value_back() {
        // Capacity 2 is used as-is (already a power of two).
        let q = Queue::with_capacity(2);
        assert_eq!(q.push(1), Ok(()));
        assert_eq!(q.push(2), Ok(()));
        assert_eq!(q.push(3), Err(3));
        assert_eq!(q.pop(), Some(1));
        assert_eq!(q.pop(), Some(2));
        assert_eq!(q.pop(), None);
    }

    /// `nthreads` producers each push `nmsgs` values while `nthreads`
    /// consumers each pop exactly `nmsgs`; every value must arrive once.
    #[test]
    fn test() {
        let nthreads = 8;
        let nmsgs = 1000;
        let q = Queue::with_capacity(nthreads * nmsgs);
        assert_eq!(None, q.pop());

        let producers: Vec<_> = (0..nthreads)
            .map(|_| {
                let q = q.clone();
                thread::spawn(move || {
                    for i in 0..nmsgs {
                        assert!(q.push(i).is_ok());
                    }
                })
            })
            .collect();

        let consumers: Vec<_> = (0..nthreads)
            .map(|_| {
                let q = q.clone();
                thread::spawn(move || {
                    let mut count = 0;
                    let mut sum = 0;
                    while count < nmsgs {
                        match q.pop() {
                            Some(v) => {
                                count += 1;
                                sum += v;
                            }
                            None => thread::yield_now(),
                        }
                    }
                    (count, sum)
                })
            })
            .collect();

        // Join producers first: if one panicked (e.g. a push failed), the
        // consumers could never finish, so fail here instead of hanging.
        for producer in producers {
            producer.join().expect("producer thread panicked");
        }
        let mut total = 0;
        for consumer in consumers {
            let (count, sum) = consumer.join().expect("consumer thread panicked");
            assert_eq!(nmsgs, count);
            total += sum;
        }
        // Each producer pushed 0..nmsgs exactly once.
        assert_eq!(total, nthreads * nmsgs * (nmsgs - 1) / 2);
        assert_eq!(None, q.pop());
    }
}