use crate::buckets::bucket_writer::BucketItemSerializer;
use rand::rng;
use rand::RngCore;
use rayon::prelude::*;
use std::cell::UnsafeCell;
use std::cmp::min;
use std::cmp::Ordering;
use std::fmt::Debug;
use std::io::{Read, Write};
use std::slice::from_raw_parts_mut;
use std::sync::atomic::AtomicUsize;
use unchecked_index::{unchecked_index, UncheckedIndex};

type IndexType = usize;

#[derive(Eq, PartialOrd, PartialEq, Ord, Copy, Clone, Debug)]
pub struct SortedData<const LEN: usize> {
    pub data: [u8; LEN],
}

impl<const LEN: usize> SortedData<LEN> {
    #[inline(always)]
    pub fn new(data: [u8; LEN]) -> Self {
        Self { data }
    }
}

impl<const LEN: usize> Default for SortedData<LEN> {
    fn default() -> Self {
        Self { data: [0; LEN] }
    }
}

pub struct SortedDataSerializer<const LEN: usize>;
impl<const LEN: usize> BucketItemSerializer for SortedDataSerializer<LEN> {
    type InputElementType<'a> = SortedData<LEN>;
    type ExtraData = ();
    type ExtraDataBuffer = ();
    type ReadBuffer = SortedData<LEN>;
    type ReadType<'a> = &'a SortedData<LEN>;
    type InitData = ();

    type CheckpointData = ();

    #[inline(always)]
    fn new(_: ()) -> Self {
        Self
    }

    #[inline(always)]
    fn reset(&mut self) {}

    #[inline(always)]
    fn write_to(
        &mut self,
        element: &Self::InputElementType<'_>,
        bucket: &mut Vec<u8>,
        _: &Self::ExtraData,
        _: &Self::ExtraDataBuffer,
    ) {
        // write_all guards against partial writes (Write::write may write fewer bytes).
        bucket.write_all(element.data.as_slice()).unwrap();
    }

    #[inline(always)]
    fn read_from<'a, S: Read>(
        &mut self,
        mut stream: S,
        read_buffer: &'a mut Self::ReadBuffer,
        _: &mut Self::ExtraDataBuffer,
    ) -> Option<Self::ReadType<'a>> {
        // read_exact is required here: a plain `read` can fill the buffer only partially
        // and returns Ok(0) at end of stream, which would yield a stale record instead of
        // signalling the end of the bucket.
        stream.read_exact(read_buffer.data.as_mut_slice()).ok()?;
        Some(read_buffer)
    }

    #[inline(always)]
    fn get_size(&self, _: &Self::InputElementType<'_>, _: &()) -> usize {
        LEN
    }
}
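
// Minimal usage sketch for the serializer above (illustrative only; the surrounding
// bucket infrastructure normally drives these calls):
//
//     let mut serializer = SortedDataSerializer::<16>::new(());
//     let mut bucket = Vec::new();
//     serializer.write_to(&SortedData::new([1u8; 16]), &mut bucket, &(), &());
//
//     let mut read_buffer = SortedData::<16>::default();
//     let item = serializer.read_from(bucket.as_slice(), &mut read_buffer, &mut ());
//     assert!(item.is_some());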

pub trait FastSortable: Ord {
    fn get_shifted(&self, rhs: u8) -> u8;
}

macro_rules! fast_sortable_impl {
    ($int_type:ty) => {
        impl FastSortable for $int_type {
            #[inline(always)]
            fn get_shifted(&self, rhs: u8) -> u8 {
                (*self >> rhs) as u8
            }
        }
    };
}

fast_sortable_impl!(u8);
fast_sortable_impl!(u16);
fast_sortable_impl!(u32);
fast_sortable_impl!(u64);
fast_sortable_impl!(u128);

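/// Key extraction used by the radix sorts below: `compare` orders two elements by their key,
/// `get_shifted` must return byte `(key >> rhs) as u8` of the same key, and `KEY_BITS` gives
/// the key width in bits. The two functions must stay consistent, since the byte passes drive
/// the radix partitioning while `compare` is used for the small-slice fallback.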
pub trait SortKey<T> {
    type KeyType: Ord;
    const KEY_BITS: usize;
    fn compare(left: &T, right: &T) -> Ordering;
    fn get_shifted(value: &T, rhs: u8) -> u8;
}

#[macro_export]
macro_rules! make_comparer {
    ($Name:ident, $type_name:ty, $key:ident: $key_type:ty) => {
        struct $Name;
        impl SortKey<$type_name> for $Name {
            type KeyType = $key_type;
            const KEY_BITS: usize = std::mem::size_of::<$key_type>() * 8;

            fn compare(left: &$type_name, right: &$type_name) -> std::cmp::Ordering {
                left.$key.cmp(&right.$key)
            }

            fn get_shifted(value: &$type_name, rhs: u8) -> u8 {
                (value.$key >> rhs) as u8
            }
        }
    };
}
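
// Example of generating a `SortKey` with the macro (`Record` / `RecordKey` are hypothetical
// names, shown only for illustration):
//
//     struct Record {
//         hash: u64,
//         payload: u32,
//     }
//     make_comparer!(RecordKey, Record, hash: u64);
//
//     // `RecordKey` can then be passed to the sort entry points below, e.g.:
//     // fast_smart_radix_sort::<Record, RecordKey, false>(records.as_mut_slice());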

const RADIX_SIZE_LOG: u8 = 8;
const RADIX_SIZE: usize = 1 << RADIX_SIZE_LOG;

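/// Sorts the concatenation of `striped_file`'s chunks into `dest_buffer` by key, in parallel:
/// each chunk is counted and locally sorted on the most significant key byte, the per-byte
/// regions are then scattered into `dest_buffer` using shared atomic offsets, and finally each
/// of the 256 destination buckets is sorted independently on the remaining key bytes.
/// `dest_buffer` must be at least as long as the total number of input elements.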
pub fn striped_parallel_smart_radix_sort<T: Ord + Send + Sync + Debug, F: SortKey<T>>(
    striped_file: &[&mut [T]],
    dest_buffer: &mut [T],
) {
    let num_threads = rayon::current_num_threads();
    let queue = crossbeam::queue::ArrayQueue::new(num_threads);

    let first_shift = F::KEY_BITS as u8 - RADIX_SIZE_LOG;

    for _ in 0..num_threads {
        queue.push([0; RADIX_SIZE + 1]).unwrap();
    }

    striped_file.par_iter().for_each(|chunk| {
        let mut counts = queue.pop().unwrap();
        for el in chunk.iter() {
            counts[(F::get_shifted(el, first_shift)) as usize + 1] += 1usize;
        }
        queue.push(counts).unwrap();
    });

    let mut counters = [0; RADIX_SIZE + 1];
    while let Some(counts) = queue.pop() {
        for i in 1..(RADIX_SIZE + 1) {
            counters[i] += counts[i];
        }
    }
    const ATOMIC_USIZE_ZERO: AtomicUsize = AtomicUsize::new(0);
    let offsets = [ATOMIC_USIZE_ZERO; RADIX_SIZE + 1];
    let mut offsets_reference = [0; RADIX_SIZE + 1];

    use std::sync::atomic::Ordering;
    for i in 1..(RADIX_SIZE + 1) {
        offsets_reference[i] = offsets[i - 1].load(Ordering::Relaxed) + counters[i];
        offsets[i].store(offsets_reference[i], Ordering::Relaxed);
    }

    let dest_buffer_addr = dest_buffer.as_mut_ptr() as usize;
    striped_file.par_iter().for_each(|chunk| {
        let dest_buffer_ptr = dest_buffer_addr as *mut T;

        let chunk_addr = chunk.as_ptr() as usize;
        let chunk_data_mut = unsafe { from_raw_parts_mut(chunk_addr as *mut T, chunk.len()) };

        let choffs = smart_radix_sort_::<T, F, false, true>(
            chunk_data_mut,
            F::KEY_BITS as u8 - RADIX_SIZE_LOG,
        );
        let mut offset = 0;
        for idx in 1..(RADIX_SIZE + 1) {
            let count = choffs[idx] - choffs[idx - 1];
            let dest_position = offsets[idx - 1].fetch_add(count, Ordering::Relaxed);

            unsafe {
                std::ptr::copy_nonoverlapping(
                    chunk.as_ptr().add(offset),
                    dest_buffer_ptr.add(dest_position),
                    count,
                );
            }

            offset += count;
        }
    });

    if F::KEY_BITS >= 16 {
        let offsets_reference = offsets_reference;
        (0..256usize).into_par_iter().for_each(|idx| {
            let dest_buffer_ptr = dest_buffer_addr as *mut T;

            let bucket_start = offsets_reference[idx];
            let bucket_len = offsets_reference[idx + 1] - bucket_start;

            let crt_slice =
                unsafe { from_raw_parts_mut(dest_buffer_ptr.add(bucket_start), bucket_len) };
            smart_radix_sort_::<T, F, false, false>(crt_slice, F::KEY_BITS as u8 - 16);
        });
    }
}

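/// Sorts `data` in place with an MSD radix sort on the bytes of `F`'s key, falling back to a
/// comparison sort for small sub-slices. When `PARALLEL` is true, the first (most significant)
/// byte pass and the per-bucket recursion are spread over the rayon thread pool.
///
/// A minimal usage sketch (assuming a `SortKey` impl such as one produced by `make_comparer!`;
/// `MyU32Key` is a placeholder name):
///
/// ```ignore
/// let mut values: Vec<u32> = vec![3, 1, 2];
/// fast_smart_radix_sort::<u32, MyU32Key, false>(values.as_mut_slice());
/// ```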
pub fn fast_smart_radix_sort<T: Sync + Send, F: SortKey<T>, const PARALLEL: bool>(data: &mut [T]) {
    smart_radix_sort_::<T, F, PARALLEL, false>(data, F::KEY_BITS as u8 - RADIX_SIZE_LOG);
}

pub fn fast_smart_radix_sort_by_value<T: Sync + Send, F: SortKey<T>, const PARALLEL: bool>(
    data: &mut [T],
) {
    smart_radix_sort_::<T, F, PARALLEL, false>(data, F::KEY_BITS as u8 - RADIX_SIZE_LOG);
}

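/// Core MSD radix sort. It keeps an explicit stack of `(range, shift)` work items instead of
/// recursing: each item is partitioned into 256 buckets on byte `shift` (a counting pass
/// followed by an in-place, American-flag style permutation), sub-ranges above a small length
/// threshold are pushed back with `shift - RADIX_SIZE_LOG`, and shorter ones are finished with
/// `sort_unstable_by`. Returns the bucket boundaries of the first pass, which is all that is
/// computed when `SINGLE_STEP` is true.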
fn smart_radix_sort_<
    T: Sync + Send,
    F: SortKey<T>,
    const PARALLEL: bool,
    const SINGLE_STEP: bool,
>(
    data: &mut [T],
    shift: u8,
) -> [IndexType; RADIX_SIZE + 1] {
    // +1 slot so the initial work item always fits, even when `shift` is 0 (single-byte keys).
    let mut stack = unsafe { unchecked_index(vec![(0..0, 0); shift as usize * RADIX_SIZE + 1]) };

    let mut stack_index = 1;
    stack[0] = (0..data.len(), shift);

    let mut ret_counts = [0; RADIX_SIZE + 1];

    let mut first = true;

    while stack_index > 0 {
        stack_index -= 1;
        let (range, shift) = stack[stack_index].clone();

        let mut data = unsafe { unchecked_index(&mut data[range.clone()]) };

        let mut counts: UncheckedIndex<[IndexType; RADIX_SIZE + 1]> =
            unsafe { unchecked_index([0; RADIX_SIZE + 1]) };
        let mut sums: UncheckedIndex<[IndexType; RADIX_SIZE + 1]>;

        {
            if PARALLEL {
                const ATOMIC_ZERO: AtomicUsize = AtomicUsize::new(0);
                let par_counts: UncheckedIndex<[AtomicUsize; RADIX_SIZE + 1]> =
                    unsafe { unchecked_index([ATOMIC_ZERO; RADIX_SIZE + 1]) };
                let num_threads = rayon::current_num_threads();
                let chunk_size = (data.len() + num_threads - 1) / num_threads;
                data.chunks(chunk_size).par_bridge().for_each(|chunk| {
                    let mut thread_counts = unsafe { unchecked_index([0; RADIX_SIZE + 1]) };

                    for el in chunk {
                        thread_counts[(F::get_shifted(el, shift)) as usize + 1] += 1;
                    }

                    for (p, t) in par_counts.iter().zip(thread_counts.iter()) {
                        p.fetch_add(*t, std::sync::atomic::Ordering::Relaxed);
                    }
                });

                for i in 1..(RADIX_SIZE + 1) {
                    counts[i] =
                        counts[i - 1] + par_counts[i].load(std::sync::atomic::Ordering::Relaxed);
                }
                sums = counts;

                let mut bucket_queues = Vec::with_capacity(RADIX_SIZE);
                for i in 0..RADIX_SIZE {
                    bucket_queues.push(crossbeam::channel::unbounded());

                    let range = sums[i]..counts[i + 1];
                    let range_steps = num_threads * 2;
                    let tot_range_len = range.len();
                    let subrange_len = (tot_range_len + range_steps - 1) / range_steps;

                    let mut start = range.start;
                    while start < range.end {
                        let end = min(start + subrange_len, range.end);
                        if start < end {
                            bucket_queues[i].0.send(start..end).unwrap();
                        }
                        start += subrange_len;
                    }
                }

                let data_ptr = data.as_mut_ptr() as usize;
                (0..num_threads).into_par_iter().for_each(|thread_index| {
                    let mut start_buckets = unsafe { unchecked_index([0; RADIX_SIZE]) };
                    let mut end_buckets = unsafe { unchecked_index([0; RADIX_SIZE]) };

                    let data = unsafe { from_raw_parts_mut(data_ptr as *mut T, data.len()) };

                    let get_bpart = || {
                        let start = rng().next_u32() as usize % RADIX_SIZE;
                        let mut res = None;
                        for i in 0..RADIX_SIZE {
                            let bucket_num = (i + start) % RADIX_SIZE;
                            if let Ok(val) = bucket_queues[bucket_num].1.try_recv() {
                                res = Some((bucket_num, val));
                                break;
                            }
                        }
                        res
                    };

                    let mut buckets_stack: Vec<_> = vec![];

                    while let Some((bidx, bpart)) = get_bpart() {
                        start_buckets[bidx] = bpart.start;
                        end_buckets[bidx] = bpart.end;
                        buckets_stack.push(bidx);

                        while let Some(bucket) = buckets_stack.pop() {
                            while start_buckets[bucket] < end_buckets[bucket] {
                                let val =
                                    (F::get_shifted(&data[start_buckets[bucket]], shift)) as usize;

                                while start_buckets[val] == end_buckets[val] {
                                    let next_bucket = match bucket_queues[val].1.try_recv() {
                                        Ok(val) => val,
                                        Err(_) => {
                                            if thread_index == num_threads - 1 {
                                                bucket_queues[val].1.recv().unwrap()
                                            } else {
                                                for i in 0..RADIX_SIZE {
                                                    if start_buckets[i] < end_buckets[i] {
                                                        bucket_queues[i]
                                                            .0
                                                            .send(start_buckets[i]..end_buckets[i])
                                                            .unwrap();
                                                    }
                                                }
                                                return;
                                            }
                                        }
                                    };
                                    start_buckets[val] = next_bucket.start;
                                    end_buckets[val] = next_bucket.end;
                                    buckets_stack.push(val);
                                }

                                data.swap(start_buckets[bucket], start_buckets[val]);
                                start_buckets[val] += 1;
                            }
                        }
                    }
                });
            } else {
                for el in data.iter() {
                    counts[(F::get_shifted(el, shift)) as usize + 1] += 1;
                }

                for i in 1..(RADIX_SIZE + 1) {
                    counts[i] += counts[i - 1];
                }
                sums = counts;

                for bucket in 0..RADIX_SIZE {
                    let end = counts[bucket + 1];
                    while sums[bucket] < end {
                        let val = (F::get_shifted(&data[sums[bucket]], shift)) as usize;
                        data.swap(sums[bucket], sums[val]);
                        sums[val] += 1;
                    }
                }
            }
        }

        if first {
            ret_counts = *counts;
            first = false;
        }

        struct UCWrapper<T> {
            uc: UnsafeCell<T>,
        }
        unsafe impl<T> Sync for UCWrapper<T> {}
        let data_ptr = UCWrapper {
            uc: UnsafeCell::new(data),
        };

        if !SINGLE_STEP && shift >= RADIX_SIZE_LOG {
            if PARALLEL && shift as usize == (F::KEY_BITS - RADIX_SIZE_LOG as usize) {
                let data_ptr = &data_ptr;
                (0..256usize)
                    .into_par_iter()
                    .filter(|x| (counts[(*x as usize) + 1] - counts[*x as usize]) > 1)
                    .for_each(|i| {
                        let mut data_ptr = unsafe { std::ptr::read(data_ptr.uc.get()) };
                        let slice = &mut data_ptr[counts[i] as usize..counts[i + 1] as usize];
                        smart_radix_sort_::<T, F, false, false>(slice, shift - RADIX_SIZE_LOG);
                    });
            } else {
                (0..RADIX_SIZE).into_iter().for_each(|i| {
                    let slice_len = counts[i + 1] - counts[i];
                    let mut data_ptr = unsafe { std::ptr::read(data_ptr.uc.get()) };

                    match slice_len {
                        2 => {
                            if F::compare(&data_ptr[counts[i]], &data_ptr[counts[i] + 1])
                                == Ordering::Greater
                            {
                                data_ptr.swap(counts[i], counts[i] + 1);
                            }
                        }
                        0 | 1 => return,

                        _ => {}
                    }

                    if slice_len < 192 {
                        let slice = &mut data_ptr[counts[i] as usize..counts[i + 1] as usize];
                        slice.sort_unstable_by(F::compare);
                        return;
                    }

                    stack[stack_index] = (
                        range.start + counts[i] as usize..range.start + counts[i + 1] as usize,
                        shift - RADIX_SIZE_LOG,
                    );
                    stack_index += 1;
                });
            }
        }
    }
    ret_counts
}

#[cfg(test)]
mod tests {
    use crate::fast_smart_bucket_sort::{fast_smart_radix_sort, SortKey};
    use rand::{rng, RngCore};
    use std::time::Instant;
    use voracious_radix_sort::RadixSort;

    #[derive(Debug, Ord, PartialOrd, Eq, PartialEq, Copy, Clone)]
    struct DataTypeStruct(u128, [u8; 32 - 16]);

    struct U64SortKey;
    impl SortKey<DataTypeStruct> for U64SortKey {
        type KeyType = u128;
        const KEY_BITS: usize = std::mem::size_of::<u128>() * 8;

        #[inline(always)]
        fn compare(left: &DataTypeStruct, right: &DataTypeStruct) -> std::cmp::Ordering {
            left.0.cmp(&right.0)
        }

        #[inline(always)]
        fn get_shifted(value: &DataTypeStruct, rhs: u8) -> u8 {
            (value.0 >> rhs) as u8
        }
    }

    #[test]
    #[ignore]
    fn parallel_sorting() {
        const ARRAY_SIZE: usize = 5_000_000_000;

        let mut vec = Vec::with_capacity(ARRAY_SIZE);

        let mut rng = rng();

        for _ in 0..ARRAY_SIZE {
            vec.push(rng.next_u32());
        }
        let mut vec2 = vec.clone();

        crate::log_info!("Starting...");
        let start = Instant::now();

        struct U16SortKey;
        impl SortKey<u32> for U16SortKey {
            type KeyType = u32;
            const KEY_BITS: usize = std::mem::size_of::<u32>() * 8;

            #[inline(always)]
            fn compare(left: &u32, right: &u32) -> std::cmp::Ordering {
                left.cmp(&right)
            }

            #[inline(always)]
            fn get_shifted(value: &u32, rhs: u8) -> u8 {
                (value >> rhs) as u8
            }
        }

        fast_smart_radix_sort::<_, U16SortKey, true>(vec.as_mut_slice());

        let end = start.elapsed();
        crate::log_info!("Total time: {:.2?}", end);

        crate::log_info!("Starting2...");
        let start = Instant::now();

        vec2.voracious_mt_sort(16);
        let end = start.elapsed();
        crate::log_info!("Total time 2: {:.2?}", end);
    }
}