// Crate-wide lint policy: default and pedantic clippy lints, cargo metadata
// checks and the documentation lints are raised to warnings.
// Clippy tool lints take a single `clippy::` prefix; a doubled
// `clippy::clippy::` path does not name any real lint.
#![warn(
    clippy::all,
    clippy::pedantic,
    clippy::cargo_common_metadata,
    missing_crate_level_docs,
    missing_debug_implementations,
    missing_doc_code_examples,
    missing_docs
)]
#![allow(clippy::must_use_candidate)]
9
use std::cell::UnsafeCell;
use std::collections::BinaryHeap;
use std::fmt::{Debug, Formatter, Result as FmtResult};
use std::iter::FromIterator;
use std::sync::atomic::{spin_loop_hint, AtomicBool, Ordering::{Acquire, Relaxed, Release}};

use rand::prelude::*;
use rand_distr::Uniform;
use ref_thread_local::{ref_thread_local, RefThreadLocal};
18
19ref_thread_local! {
20 static managed PRNG: SmallRng = SmallRng::from_entropy();
21}
22
23pub struct MilkPQ<T: Ord> {
25 queues: Box<[Queue<T>]>,
26 dist: Uniform<usize>,
27}
28
29impl<T: Ord + Clone> Clone for MilkPQ<T> {
30 fn clone(&self) -> Self {
31 MilkPQ { queues: self.queues.clone(), dist: self.dist }
32 }
33
34 fn clone_from(&mut self, source: &Self) {
35 self.queues.clone_from(&source.queues);
36 self.dist = source.dist;
37 }
38}
39
40impl<T: Ord> FromIterator<T> for MilkPQ<T> {
41 fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self {
42 let iter = iter.into_iter();
43 let pq = MilkPQ::with_capacity(iter.size_hint().0);
44
45 for t in iter {
46 pq.push(t);
47 }
48
49 pq
50 }
51}
52
53impl<T: Ord> From<MilkPQ<T>> for Vec<T> {
54 fn from(pq: MilkPQ<T>) -> Self {
55 let mut vec = Vec::new();
56
57 for pq in pq.queues.into_vec() {
58 vec.extend(pq);
59 }
60
61 vec
62 }
63}
64
65impl<T: Ord> IntoIterator for MilkPQ<T> {
66 type Item = T;
67 type IntoIter = std::vec::IntoIter<T>;
68
69 fn into_iter(self) -> Self::IntoIter {
70 Vec::into_iter(self.into())
71 }
72}
73
74impl<T: Ord> Extend<T> for MilkPQ<T> {
75 fn extend<I: IntoIterator<Item = T>>(&mut self, iter: I) {
76 self.extend_ref(iter);
77 }
78}
79
80impl<T: Ord> Default for MilkPQ<T> {
81 fn default() -> Self {
82 MilkPQ::new()
83 }
84}
85
86impl<T: Ord + Debug> Debug for MilkPQ<T> {
87 fn fmt(&self, f: &mut Formatter) -> FmtResult {
88 f.debug_list().entries(self.queues.as_ref()).finish()
89 }
90}
91
92impl<T: Ord> MilkPQ<T> {
93 pub fn new() -> Self {
95 Self::with_queues(num_cpus::get() * 4)
96 }
97
98 pub fn with_capacity(cap: usize) -> Self {
100 Self::with_capacity_and_queues(cap, num_cpus::get() * 4)
101 }
102
103 pub fn with_queues(limit: usize) -> Self {
105 let queues = std::iter::repeat_with(|| Queue::new(BinaryHeap::new()))
106 .take(limit)
107 .collect::<Vec<_>>()
108 .into_boxed_slice();
109
110 MilkPQ { queues, dist: Uniform::new(0, limit) }
111 }
112
113 pub fn with_capacity_and_queues(cap: usize, limit: usize) -> Self {
118 let queues = std::iter::repeat_with(|| Queue::new(BinaryHeap::with_capacity(cap)))
119 .take(limit)
120 .collect::<Vec<_>>()
121 .into_boxed_slice();
122
123 MilkPQ { queues, dist: Uniform::new(0, limit) }
124 }
125
126 pub fn push(&self, mut t: T) {
128 let mut i = PRNG.borrow_mut().sample(self.dist);
129
130 while let Err(t2) = self.queues[i].try_push(t) {
131 t = t2;
132 i = PRNG.borrow_mut().sample(self.dist);
133 spin_loop_hint();
134 }
135 }
136
137 pub fn pop(&self) -> Option<T> {
144 let mut i = PRNG.borrow_mut().sample(self.dist);
145 let mut t;
146
147 while {t = self.queues[i].try_pop(); t.is_err()} {
148 i = PRNG.borrow_mut().sample(self.dist);
149 spin_loop_hint();
150 }
151
152 t.unwrap()
153 }
154
155 pub fn strong_pop(&self) -> Option<T> {
162 let mut t;
163
164 for queue in self.queues.as_ref() {
165 while {t = queue.try_pop(); t.is_err()} {
166 spin_loop_hint();
167 }
168
169 let t = t.unwrap();
170 if t.is_some() {
171 return t;
172 }
173 }
174
175 None
176 }
177
178 pub fn into_sorted_vec(self) -> Vec<T> {
180 let mut vec = Vec::from(self);
181 vec.sort_unstable_by(|l, r| l.cmp(r).reverse());
182 vec
183 }
184
185 pub fn clear(&self) {
187 for queue in self.queues.as_ref() {
188 queue.clear();
189 }
190 }
191
192 pub fn drain(&mut self) -> Vec<T> {
194 let mut vec = Vec::new();
195
196 for queue in self.queues.as_mut() {
197 vec.extend(queue.take())
198 }
199
200 vec
201 }
202
203 pub fn extend_ref<I: IntoIterator<Item = T>>(&self, iter: I) {
207 for t in iter {
208 self.push(t);
209 }
210 }
211}
212
/// One internal heap of a `MilkPQ`, guarded by a hand-rolled spin lock.
///
/// `cas_lock` is `true` while some thread has exclusive access to `pq`. The
/// lock is taken by a compare-and-swap from `false` to `true` and released by
/// storing `false`.
///
/// NOTE(review): the lock is released by an explicit store after each
/// operation. A panic while it is held (for example from `T`'s `Ord`, `Clone`,
/// `Debug` or `Drop`) therefore leaves the queue locked for good.
struct Queue<T: Ord> {
    /// Set while a lock holder is using the heap.
    cas_lock: AtomicBool,
    /// The heap itself. It is meant to be touched only while `cas_lock` is
    /// held, or through `&mut self`.
    pq: UnsafeCell<BinaryHeap<T>>,
}
217
218unsafe impl<T: Ord + Send> Send for Queue<T> {}
219unsafe impl<T: Ord + Sync> Sync for Queue<T> {}
220
221impl<T: Ord> IntoIterator for Queue<T> {
222 type Item = T;
223 type IntoIter = std::collections::binary_heap::IntoIter<T>;
224
225 fn into_iter(self) -> Self::IntoIter {
226 self.pq.into_inner().into_iter()
227 }
228}
229
230impl<T: Ord + Clone> Clone for Queue<T> {
231 fn clone(&self) -> Self {
232 while self.cas_lock.compare_exchange_weak(false, true, Release, Relaxed).is_err() {
233 spin_loop_hint();
234 }
235
236 let pq = UnsafeCell::new(unsafe { self.pq.get().as_ref() }.unwrap().clone());
237 let cas_lock = AtomicBool::new(false);
238 self.cas_lock.store(false, Release);
239 Queue { pq, cas_lock }
240 }
241
242 fn clone_from(&mut self, source: &Self) {
243 while source.cas_lock.compare_exchange_weak(false, true, Release, Relaxed).is_err() {
244 spin_loop_hint();
245 }
246
247 unsafe { self.pq.get().as_mut() }
248 .unwrap()
249 .clone_from(unsafe { source.pq.get().as_ref() }.unwrap());
250
251 source.cas_lock.store(false, Release);
252 }
253}
254
255impl<T: Ord + Debug> Debug for Queue<T> {
256 fn fmt(&self, f: &mut Formatter) -> FmtResult {
257 while self.cas_lock.compare_exchange_weak(false, true, Release, Relaxed).is_err() {
258 spin_loop_hint();
259 }
260
261 let fmt = unsafe { self.pq.get().as_ref() }.unwrap().fmt(f);
262 self.cas_lock.store(false, Release);
263 fmt
264 }
265}
266
267impl<T: Ord> Queue<T> {
268 fn new(pq: BinaryHeap<T>) -> Self {
269 Queue {
270 pq: UnsafeCell::new(pq),
271 cas_lock: AtomicBool::new(false),
272 }
273 }
274
275 #[must_use = "must check if CAS failed"]
276 fn try_push(&self, t: T) -> Result<(), T> {
277 match self.cas_lock.compare_exchange_weak(false, true, Release, Relaxed) {
278 Ok(_) => {
279 unsafe { self.pq.get().as_mut() }.unwrap().push(t);
280 self.cas_lock.store(false, Release);
281 Ok(())
282 }
283 Err(_) => Err(t),
284 }
285 }
286
287 #[must_use = "must check if CAS failed"]
288 fn try_pop(&self) -> Result<Option<T>, ()> {
289 match self.cas_lock.compare_exchange_weak(false, true, Release, Relaxed) {
290 Ok(_) => {
291 let r = unsafe { self.pq.get().as_mut() }.unwrap().pop();
292 self.cas_lock.store(false, Release);
293 Ok(r)
294 }
295 Err(_) => Err(()),
296 }
297 }
298
299 fn clear(&self) {
300 while self.cas_lock.compare_exchange_weak(false, true, Release, Relaxed).is_err() {
301 spin_loop_hint();
302 }
303
304 unsafe { self.pq.get().as_mut() }.unwrap().clear();
305 self.cas_lock.store(false, Release);
306 }
307
308 fn take(&mut self) -> BinaryHeap<T> {
309 let pq = unsafe { self.pq.get().as_mut() }.unwrap();
310 let new = BinaryHeap::with_capacity(pq.capacity());
311 std::mem::replace(pq, new)
312 }
313}
314
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::Ordering;

    /// Length of a queue's heap, read without taking the lock.
    fn heap_len<T: Ord>(q: &Queue<T>) -> usize {
        // SAFETY: these tests are single-threaded and hold no other reference
        // to the heap at the call sites.
        unsafe { q.pq.get().as_ref() }.unwrap().len()
    }

    #[test]
    fn try_push() {
        let q = Queue::new(BinaryHeap::new());
        assert_eq!(heap_len(&q), 0);
        assert_eq!(q.try_push(1), Ok(()));
        assert_eq!(heap_len(&q), 1);

        // While the lock is held the push fails and the value comes back.
        q.cas_lock.store(true, Ordering::Release);
        assert_eq!(q.try_push(2), Err(2));
        assert_eq!(heap_len(&q), 1);

        q.cas_lock.store(false, Ordering::Release);
        assert_eq!(q.try_push(2), Ok(()));
        assert_eq!(heap_len(&q), 2);
    }

    #[test]
    fn try_pop() {
        let q = Queue::new(BinaryHeap::from(vec![1, 2]));
        assert_eq!(heap_len(&q), 2);
        assert_eq!(q.try_pop(), Ok(Some(2)));
        assert_eq!(heap_len(&q), 1);

        // While the lock is held the pop fails without touching the heap.
        q.cas_lock.store(true, Ordering::Release);
        assert_eq!(q.try_pop(), Err(()));
        assert_eq!(heap_len(&q), 1);

        q.cas_lock.store(false, Ordering::Release);
        assert_eq!(q.try_pop(), Ok(Some(1)));
        assert_eq!(heap_len(&q), 0);
        assert_eq!(q.try_pop(), Ok(None));
        assert_eq!(heap_len(&q), 0);
    }

    #[test]
    fn take() {
        let heap = BinaryHeap::from(vec![1, 2, 0]);
        let mut q = Queue::new(heap.clone());
        let taken = q.take();
        assert_eq!(heap.into_sorted_vec(), taken.into_sorted_vec());
        assert_eq!(heap_len(&q), 0);
    }

    #[test]
    fn queue_clear() {
        let q = Queue::new(BinaryHeap::from(vec![1, 2]));
        q.clear();
        assert_eq!(heap_len(&q), 0);
    }

    #[test]
    fn into_sorted_vec() {
        let q = MilkPQ::new();
        let mut input = (0..100).collect::<Vec<_>>();
        input.shuffle(&mut *PRNG.borrow_mut());
        q.extend_ref(input);

        let expected = (0..100).rev().collect::<Vec<_>>();
        assert_eq!(q.into_sorted_vec(), expected);
    }

    #[test]
    fn strong_pop() {
        let q = MilkPQ::new();
        q.push(1);
        q.push(2);
        for _ in 0..2 {
            assert!(q.strong_pop().is_some());
        }
        assert!(q.strong_pop().is_none());
    }
}