use crate::buckets::bucket_writer::BucketItemSerializer;
use rand::{thread_rng, RngCore};
use rayon::prelude::*;
use std::cell::UnsafeCell;
use std::cmp::min;
use std::cmp::Ordering;
use std::fmt::Debug;
use std::io::{Read, Write};
use std::slice::from_raw_parts_mut;
use std::sync::atomic::AtomicUsize;
use unchecked_index::{unchecked_index, UncheckedIndex};

type IndexType = usize;

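/// Fixed-size record stored in a sort bucket; the derived `Ord` compares the raw
/// bytes lexicographically, so sorting `SortedData` values sorts by byte content.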
#[derive(Eq, PartialOrd, PartialEq, Ord, Copy, Clone, Debug)]
pub struct SortedData<const LEN: usize> {
    pub data: [u8; LEN],
}

impl<const LEN: usize> SortedData<LEN> {
    #[inline(always)]
    pub fn new(data: [u8; LEN]) -> Self {
        Self { data }
    }
}

pub struct SortedDataSerializer<const LEN: usize>;
impl<const LEN: usize> BucketItemSerializer for SortedDataSerializer<LEN> {
    type InputElementType<'a> = SortedData<LEN>;
    type ExtraData = ();
    type ExtraDataBuffer = ();
    type ReadBuffer = SortedData<LEN>;
    type ReadType<'a> = &'a SortedData<LEN>;
    type InitData = ();

    type CheckpointData = ();

    #[inline(always)]
    fn new(_: ()) -> Self {
        Self
    }

    #[inline(always)]
    fn reset(&mut self) {}

    #[inline(always)]
    fn write_to(
        &mut self,
        element: &Self::InputElementType<'_>,
        bucket: &mut Vec<u8>,
        _: &Self::ExtraData,
        _: &Self::ExtraDataBuffer,
    ) {
        bucket.write_all(element.data.as_slice()).unwrap();
    }

    #[inline(always)]
    fn read_from<'a, S: Read>(
        &mut self,
        mut stream: S,
        read_buffer: &'a mut Self::ReadBuffer,
        _: &mut Self::ExtraDataBuffer,
    ) -> Option<Self::ReadType<'a>> {
        // read_exact ensures the whole fixed-size record is read; a short read or
        // EOF yields None instead of returning stale buffer contents.
        stream.read_exact(read_buffer.data.as_mut_slice()).ok()?;
        Some(read_buffer)
    }

    #[inline(always)]
    fn get_size(&self, _: &Self::InputElementType<'_>, _: &()) -> usize {
        LEN
    }
}

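/// Byte-extraction helper for plain integer keys: `get_shifted(rhs)` returns the
/// byte of the value starting at bit offset `rhs`, the same operation that
/// [`SortKey::get_shifted`] must provide for the radix passes below.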
pub trait FastSortable: Ord {
    fn get_shifted(&self, rhs: u8) -> u8;
}

macro_rules! fast_sortable_impl {
    ($int_type:ty) => {
        impl FastSortable for $int_type {
            #[inline(always)]
            fn get_shifted(&self, rhs: u8) -> u8 {
                (*self >> rhs) as u8
            }
        }
    };
}

fast_sortable_impl!(u8);
fast_sortable_impl!(u16);
fast_sortable_impl!(u32);
fast_sortable_impl!(u64);
fast_sortable_impl!(u128);

pub trait SortKey<T> {
    type KeyType: Ord;
    const KEY_BITS: usize;
    fn compare(left: &T, right: &T) -> Ordering;
    fn get_shifted(value: &T, rhs: u8) -> u8;
}

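/// Declares a unit struct `$Name` implementing [`SortKey`] for `$type_name`,
/// keyed on the integer field `$key` of type `$key_type`.
///
/// A minimal usage sketch (the struct and field names here are hypothetical):
///
/// ```ignore
/// struct Entry {
///     hash: u64,
///     payload: u32,
/// }
/// make_comparer!(EntryByHash, Entry, hash: u64);
/// // EntryByHash can now be used as the `F` parameter of the sort functions below.
/// ```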
#[macro_export]
macro_rules! make_comparer {
    ($Name:ident, $type_name:ty, $key:ident: $key_type:ty) => {
        struct $Name;
        impl SortKey<$type_name> for $Name {
            type KeyType = $key_type;
            const KEY_BITS: usize = std::mem::size_of::<$key_type>() * 8;

            fn compare(left: &$type_name, right: &$type_name) -> std::cmp::Ordering {
                left.$key.cmp(&right.$key)
            }

            fn get_shifted(value: &$type_name, rhs: u8) -> u8 {
                (value.$key >> rhs) as u8
            }
        }
    };
}

const RADIX_SIZE_LOG: u8 = 8;
const RADIX_SIZE: usize = 1 << RADIX_SIZE_LOG;

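/// Sorts a "striped" dataset: `striped_file` holds the chunks of one logical
/// array, and the fully sorted result is written into `dest_buffer` (which must be
/// at least as long as the sum of the chunk lengths).
///
/// The sort proceeds in three phases:
/// 1. count, in parallel, how many elements of each chunk fall into each of the
///    256 buckets selected by the most significant key byte;
/// 2. partially sort every chunk by that byte and scatter each per-chunk bucket
///    into its global position in `dest_buffer` using atomic cursors;
/// 3. recursively radix-sort the 256 buckets of `dest_buffer` in parallel.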
pub fn striped_parallel_smart_radix_sort<T: Ord + Send + Sync + Debug, F: SortKey<T>>(
    striped_file: &[&mut [T]],
    dest_buffer: &mut [T],
) {
    let num_threads = rayon::current_num_threads();
    let queue = crossbeam::queue::ArrayQueue::new(num_threads);

    let first_shift = F::KEY_BITS as u8 - RADIX_SIZE_LOG;

    for _ in 0..num_threads {
        queue.push([0; RADIX_SIZE + 1]).unwrap();
    }

    striped_file.par_iter().for_each(|chunk| {
        let mut counts = queue.pop().unwrap();
        for el in chunk.iter() {
            counts[(F::get_shifted(el, first_shift)) as usize + 1] += 1usize;
        }
        queue.push(counts).unwrap();
    });

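    // Merge the pooled histograms and build exclusive prefix sums: after the loop
    // below, `offsets[i]` / `offsets_reference[i]` hold the start position of
    // bucket `i` inside `dest_buffer`.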
    let mut counters = [0; RADIX_SIZE + 1];
    while let Some(counts) = queue.pop() {
        for i in 1..(RADIX_SIZE + 1) {
            counters[i] += counts[i];
        }
    }
    const ATOMIC_USIZE_ZERO: AtomicUsize = AtomicUsize::new(0);
    let offsets = [ATOMIC_USIZE_ZERO; RADIX_SIZE + 1];
    let mut offsets_reference = [0; RADIX_SIZE + 1];

    use std::sync::atomic::Ordering;
    for i in 1..(RADIX_SIZE + 1) {
        offsets_reference[i] = offsets[i - 1].load(Ordering::Relaxed) + counters[i];
        offsets[i].store(offsets_reference[i], Ordering::Relaxed);
    }

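    // Phase 2: partially sort each chunk by its most significant key byte (the
    // single-step call returns the per-chunk bucket boundaries), then copy every
    // per-chunk bucket into its reserved slot of `dest_buffer`.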
    let dest_buffer_addr = dest_buffer.as_mut_ptr() as usize;
    striped_file.par_iter().for_each(|chunk| {
        let dest_buffer_ptr = dest_buffer_addr as *mut T;

        let chunk_addr = chunk.as_ptr() as usize;
        let chunk_data_mut = unsafe { from_raw_parts_mut(chunk_addr as *mut T, chunk.len()) };

        let choffs = smart_radix_sort_::<T, F, false, true>(
            chunk_data_mut,
            F::KEY_BITS as u8 - RADIX_SIZE_LOG,
        );
        let mut offset = 0;
        for idx in 1..(RADIX_SIZE + 1) {
            let count = choffs[idx] - choffs[idx - 1];
            let dest_position = offsets[idx - 1].fetch_add(count, Ordering::Relaxed);

            unsafe {
                std::ptr::copy_nonoverlapping(
                    chunk.as_ptr().add(offset),
                    dest_buffer_ptr.add(dest_position),
                    count,
                );
            }

            offset += count;
        }
    });

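    // Phase 3: every top-level bucket of `dest_buffer` is now contiguous but
    // internally unsorted; sort the 256 buckets in parallel on the remaining key
    // bytes (skipped for keys narrower than 16 bits).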
    if F::KEY_BITS >= 16 {
        let offsets_reference = offsets_reference;
        (0..256usize).into_par_iter().for_each(|idx| {
            let dest_buffer_ptr = dest_buffer_addr as *mut T;

            let bucket_start = offsets_reference[idx];
            let bucket_len = offsets_reference[idx + 1] - bucket_start;

            let crt_slice =
                unsafe { from_raw_parts_mut(dest_buffer_ptr.add(bucket_start), bucket_len) };
            smart_radix_sort_::<T, F, false, false>(crt_slice, F::KEY_BITS as u8 - 16);
        });
    }
}

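/// Sorts `data` in place with an MSD radix sort on the key described by `F`,
/// optionally parallelized with rayon when `PARALLEL` is true.
///
/// A minimal usage sketch with a hand-written key, mirroring the `u32` key used in
/// the tests below:
///
/// ```ignore
/// struct U32Key;
/// impl SortKey<u32> for U32Key {
///     type KeyType = u32;
///     const KEY_BITS: usize = std::mem::size_of::<u32>() * 8;
///     fn compare(left: &u32, right: &u32) -> std::cmp::Ordering {
///         left.cmp(right)
///     }
///     fn get_shifted(value: &u32, rhs: u8) -> u8 {
///         (value >> rhs) as u8
///     }
/// }
///
/// let mut values = vec![42u32, 7, 1_000_000, 0];
/// fast_smart_radix_sort::<_, U32Key, false>(values.as_mut_slice());
/// assert!(values.windows(2).all(|w| w[0] <= w[1]));
/// ```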
pub fn fast_smart_radix_sort<T: Sync + Send, F: SortKey<T>, const PARALLEL: bool>(data: &mut [T]) {
    smart_radix_sort_::<T, F, PARALLEL, false>(data, F::KEY_BITS as u8 - RADIX_SIZE_LOG);
}

pub fn fast_smart_radix_sort_by_value<T: Sync + Send, F: SortKey<T>, const PARALLEL: bool>(
    data: &mut [T],
) {
    smart_radix_sort_::<T, F, PARALLEL, false>(data, F::KEY_BITS as u8 - RADIX_SIZE_LOG);
}

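/// Core in-place MSD radix sort ("American flag" style), driven by an explicit
/// stack of `(range, shift)` work items instead of recursion.
///
/// For every work item it histograms the key byte at bit offset `shift`, permutes
/// the elements of the range into their 256 buckets, and then handles each bucket:
/// buckets shorter than 192 elements are finished with a comparison sort, larger
/// ones are pushed back onto the stack with `shift` lowered by one byte (or, for
/// the first pass of a `PARALLEL` sort, recursed into on separate rayon tasks).
///
/// * `PARALLEL`: the first counting and permutation pass is split across rayon
///   threads using per-bucket work queues.
/// * `SINGLE_STEP`: stop after a single pass; no buckets are refined.
///
/// The returned array holds the bucket boundaries (cumulative counts) of the
/// first pass.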
fn smart_radix_sort_<
    T: Sync + Send,
    F: SortKey<T>,
    const PARALLEL: bool,
    const SINGLE_STEP: bool,
>(
    data: &mut [T],
    shift: u8,
) -> [IndexType; RADIX_SIZE + 1] {
    // The explicit work stack is generously oversized; the `+ 1` guarantees room
    // for the initial item even when `shift` is 0 (single-byte keys).
    let mut stack =
        unsafe { unchecked_index(vec![(0..0, 0); shift as usize * RADIX_SIZE + 1]) };

    let mut stack_index = 1;
    stack[0] = (0..data.len(), shift);

    let mut ret_counts = [0; RADIX_SIZE + 1];

    let mut first = true;

    while stack_index > 0 {
        stack_index -= 1;
        let (range, shift) = stack[stack_index].clone();

        let mut data = unsafe { unchecked_index(&mut data[range.clone()]) };

        let mut counts: UncheckedIndex<[IndexType; RADIX_SIZE + 1]> =
            unsafe { unchecked_index([0; RADIX_SIZE + 1]) };
        let mut sums: UncheckedIndex<[IndexType; RADIX_SIZE + 1]>;

        {
            if PARALLEL {
                const ATOMIC_ZERO: AtomicUsize = AtomicUsize::new(0);
                let par_counts: UncheckedIndex<[AtomicUsize; RADIX_SIZE + 1]> =
                    unsafe { unchecked_index([ATOMIC_ZERO; RADIX_SIZE + 1]) };
                let num_threads = rayon::current_num_threads();
                let chunk_size = (data.len() + num_threads - 1) / num_threads;
                data.chunks(chunk_size).par_bridge().for_each(|chunk| {
                    let mut thread_counts = unsafe { unchecked_index([0; RADIX_SIZE + 1]) };

                    for el in chunk {
                        thread_counts[(F::get_shifted(el, shift)) as usize + 1] += 1;
                    }

                    for (p, t) in par_counts.iter().zip(thread_counts.iter()) {
                        p.fetch_add(*t, std::sync::atomic::Ordering::Relaxed);
                    }
                });

                for i in 1..(RADIX_SIZE + 1) {
                    counts[i] =
                        counts[i - 1] + par_counts[i].load(std::sync::atomic::Ordering::Relaxed);
                }
                sums = counts;

                let mut bucket_queues = Vec::with_capacity(RADIX_SIZE);
                for i in 0..RADIX_SIZE {
                    bucket_queues.push(crossbeam::channel::unbounded());

                    let range = sums[i]..counts[i + 1];
                    let range_steps = num_threads * 2;
                    let tot_range_len = range.len();
                    let subrange_len = (tot_range_len + range_steps - 1) / range_steps;

                    let mut start = range.start;
                    while start < range.end {
                        let end = min(start + subrange_len, range.end);
                        if start < end {
                            bucket_queues[i].0.send(start..end).unwrap();
                        }
                        start += subrange_len;
                    }
                }

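                // Permutation phase: each worker pulls sub-ranges of buckets from
                // the per-bucket queues (starting at a random bucket to reduce
                // contention) and moves every element it scans into the bucket
                // selected by its key byte, grabbing a new sub-range of the target
                // bucket whenever its current one is exhausted. The last thread
                // blocks on `recv` so all remaining ranges get drained; the other
                // threads hand their unfinished ranges back and return.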
                let data_ptr = data.as_mut_ptr() as usize;
                (0..num_threads).into_par_iter().for_each(|thread_index| {
                    let mut start_buckets = unsafe { unchecked_index([0; RADIX_SIZE]) };
                    let mut end_buckets = unsafe { unchecked_index([0; RADIX_SIZE]) };

                    let data = unsafe { from_raw_parts_mut(data_ptr as *mut T, data.len()) };

                    let get_bpart = || {
                        let start = thread_rng().next_u32() as usize % RADIX_SIZE;
                        let mut res = None;
                        for i in 0..RADIX_SIZE {
                            let bucket_num = (i + start) % RADIX_SIZE;
                            if let Ok(val) = bucket_queues[bucket_num].1.try_recv() {
                                res = Some((bucket_num, val));
                                break;
                            }
                        }
                        res
                    };

                    let mut buckets_stack: Vec<_> = vec![];

                    while let Some((bidx, bpart)) = get_bpart() {
                        start_buckets[bidx] = bpart.start;
                        end_buckets[bidx] = bpart.end;
                        buckets_stack.push(bidx);

                        while let Some(bucket) = buckets_stack.pop() {
                            while start_buckets[bucket] < end_buckets[bucket] {
                                let val =
                                    (F::get_shifted(&data[start_buckets[bucket]], shift)) as usize;

                                while start_buckets[val] == end_buckets[val] {
                                    let next_bucket = match bucket_queues[val].1.try_recv() {
                                        Ok(val) => val,
                                        Err(_) => {
                                            if thread_index == num_threads - 1 {
                                                bucket_queues[val].1.recv().unwrap()
                                            } else {
                                                for i in 0..RADIX_SIZE {
                                                    if start_buckets[i] < end_buckets[i] {
                                                        bucket_queues[i]
                                                            .0
                                                            .send(start_buckets[i]..end_buckets[i])
                                                            .unwrap();
                                                    }
                                                }
                                                return;
                                            }
                                        }
                                    };
                                    start_buckets[val] = next_bucket.start;
                                    end_buckets[val] = next_bucket.end;
                                    buckets_stack.push(val);
                                }

                                data.swap(start_buckets[bucket], start_buckets[val]);
                                start_buckets[val] += 1;
                            }
                        }
                    }
                });
            } else {
                for el in data.iter() {
                    counts[(F::get_shifted(el, shift)) as usize + 1] += 1;
                }

                for i in 1..(RADIX_SIZE + 1) {
                    counts[i] += counts[i - 1];
                }
                sums = counts;

                for bucket in 0..RADIX_SIZE {
                    let end = counts[bucket + 1];
                    while sums[bucket] < end {
                        let val = (F::get_shifted(&data[sums[bucket]], shift)) as usize;
                        data.swap(sums[bucket], sums[val]);
                        sums[val] += 1;
                    }
                }
            }
        }

        if first {
            ret_counts = *counts;
            first = false;
        }

        struct UCWrapper<T> {
            uc: UnsafeCell<T>,
        }
        unsafe impl<T> Sync for UCWrapper<T> {}
        let data_ptr = UCWrapper {
            uc: UnsafeCell::new(data),
        };

        if !SINGLE_STEP && shift >= RADIX_SIZE_LOG {
            if PARALLEL && shift as usize == (F::KEY_BITS - RADIX_SIZE_LOG as usize) {
                let data_ptr = &data_ptr;
                (0..256usize)
                    .into_par_iter()
                    .filter(|x| (counts[*x + 1] - counts[*x]) > 1)
                    .for_each(|i| {
                        let mut data_ptr = unsafe { std::ptr::read(data_ptr.uc.get()) };
                        let slice = &mut data_ptr[counts[i]..counts[i + 1]];
                        smart_radix_sort_::<T, F, false, false>(slice, shift - RADIX_SIZE_LOG);
                    });
            } else {
                (0..RADIX_SIZE).for_each(|i| {
                    let slice_len = counts[i + 1] - counts[i];
                    let mut data_ptr = unsafe { std::ptr::read(data_ptr.uc.get()) };

                    match slice_len {
                        0 | 1 => return,
                        2 => {
                            if F::compare(&data_ptr[counts[i]], &data_ptr[counts[i] + 1])
                                == Ordering::Greater
                            {
                                data_ptr.swap(counts[i], counts[i] + 1);
                            }
                            // A two-element bucket is fully handled here.
                            return;
                        }
                        _ => {}
                    }

                    if slice_len < 192 {
                        let slice = &mut data_ptr[counts[i]..counts[i + 1]];
                        slice.sort_unstable_by(F::compare);
                        return;
                    }

                    stack[stack_index] = (
                        range.start + counts[i]..range.start + counts[i + 1],
                        shift - RADIX_SIZE_LOG,
                    );
                    stack_index += 1;
                });
            }
        }
    }
    ret_counts
}

#[cfg(test)]
mod tests {
    use crate::fast_smart_bucket_sort::{fast_smart_radix_sort, SortKey};
    use rand::{thread_rng, RngCore};
    use std::time::Instant;
    use voracious_radix_sort::RadixSort;

    #[derive(Debug, Ord, PartialOrd, Eq, PartialEq, Copy, Clone)]
    struct DataTypeStruct(u128, [u8; (32 - 16)]);

    struct U64SortKey;
    impl SortKey<DataTypeStruct> for U64SortKey {
        type KeyType = u128;
        const KEY_BITS: usize = std::mem::size_of::<u128>() * 8;

        #[inline(always)]
        fn compare(left: &DataTypeStruct, right: &DataTypeStruct) -> std::cmp::Ordering {
            left.0.cmp(&right.0)
        }

        #[inline(always)]
        fn get_shifted(value: &DataTypeStruct, rhs: u8) -> u8 {
            (value.0 >> rhs) as u8
        }
    }

    #[test]
    #[ignore]
    fn parallel_sorting() {
        const ARRAY_SIZE: usize = 5_000_000_000;

        let mut vec = Vec::with_capacity(ARRAY_SIZE);

        let mut rng = thread_rng();

        for _ in 0..ARRAY_SIZE {
            vec.push(rng.next_u32());
        }
        let mut vec2 = vec.clone();

        crate::log_info!("Starting...");
        let start = Instant::now();

        struct U16SortKey;
        impl SortKey<u32> for U16SortKey {
            type KeyType = u32;
            const KEY_BITS: usize = std::mem::size_of::<u32>() * 8;

            #[inline(always)]
            fn compare(left: &u32, right: &u32) -> std::cmp::Ordering {
                left.cmp(right)
            }

            #[inline(always)]
            fn get_shifted(value: &u32, rhs: u8) -> u8 {
                (value >> rhs) as u8
            }
        }

        fast_smart_radix_sort::<_, U16SortKey, true>(vec.as_mut_slice());

        let end = start.elapsed();
        crate::log_info!("Total time: {:.2?}", end);

        crate::log_info!("Starting2...");
        let start = Instant::now();

        vec2.voracious_mt_sort(16);
        let end = start.elapsed();
        crate::log_info!("Total time 2: {:.2?}", end);
    }
}