// parallel_processor/fast_smart_bucket_sort.rs

1use crate::buckets::bucket_writer::BucketItemSerializer;
2use rand::rng;
3use rand::RngCore;
4use rayon::prelude::*;
5use std::cell::UnsafeCell;
6use std::cmp::min;
7use std::cmp::Ordering;
8use std::fmt::Debug;
9use std::io::{Read, Write};
10use std::slice::from_raw_parts_mut;
11use std::sync::atomic::AtomicUsize;
12use unchecked_index::{unchecked_index, UncheckedIndex};
13
/// Integer type used for bucket counts and bucket start offsets in the radix-sort helpers.
type IndexType = usize;
15
// #[repr(packed)]
/// Fixed-size record of `LEN` raw bytes.
///
/// The derived `PartialOrd`/`Ord` compare `data` lexicographically.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct SortedData<const LEN: usize> {
    pub data: [u8; LEN],
}

impl<const LEN: usize> SortedData<LEN> {
    /// Wraps the given bytes.
    #[inline(always)]
    pub fn new(data: [u8; LEN]) -> Self {
        SortedData { data }
    }
}

/// The all-zero record. Written by hand because `[u8; LEN]` only implements
/// `Default` for small fixed lengths, not for an arbitrary const `LEN`.
impl<const LEN: usize> Default for SortedData<LEN> {
    fn default() -> Self {
        Self::new([0u8; LEN])
    }
}
34
35pub struct SortedDataSerializer<const LEN: usize>;
36impl<const LEN: usize> BucketItemSerializer for SortedDataSerializer<LEN> {
37    type InputElementType<'a> = SortedData<LEN>;
38    type ExtraData = ();
39    type ExtraDataBuffer = ();
40    type ReadBuffer = SortedData<LEN>;
41    type ReadType<'a> = &'a SortedData<LEN>;
42    type InitData = ();
43
44    type CheckpointData = ();
45
46    #[inline(always)]
47    fn new(_: ()) -> Self {
48        Self
49    }
50
51    #[inline(always)]
52    fn reset(&mut self) {}
53
54    #[inline(always)]
55    fn write_to(
56        &mut self,
57        element: &Self::InputElementType<'_>,
58        bucket: &mut Vec<u8>,
59        _: &Self::ExtraData,
60        _: &Self::ExtraDataBuffer,
61    ) {
62        bucket.write(element.data.as_slice()).unwrap();
63    }
64
65    #[inline(always)]
66    fn read_from<'a, S: Read>(
67        &mut self,
68        mut stream: S,
69        read_buffer: &'a mut Self::ReadBuffer,
70        _: &mut Self::ExtraDataBuffer,
71    ) -> Option<Self::ReadType<'a>> {
72        stream.read(read_buffer.data.as_mut_slice()).ok()?;
73        Some(read_buffer)
74    }
75
76    #[inline(always)]
77    fn get_size(&self, _: &Self::InputElementType<'_>, _: &()) -> usize {
78        LEN
79    }
80}
81
/// Ordered values that can expose one byte of themselves for radix bucketing.
pub trait FastSortable: Ord {
    /// Returns `self >> rhs` truncated to its low byte.
    fn get_shifted(&self, rhs: u8) -> u8;
}

/// Implements [`FastSortable`] for each listed unsigned integer type.
macro_rules! fast_sortable_impl {
    ($($int_type:ty),+ $(,)?) => {
        $(
            impl FastSortable for $int_type {
                #[inline(always)]
                fn get_shifted(&self, rhs: u8) -> u8 {
                    let shifted = *self >> rhs;
                    shifted as u8
                }
            }
        )+
    };
}

fast_sortable_impl!(u8, u16, u32, u64, u128);
102
/// Describes how to radix-sort values of type `T` by an integer key.
///
/// `get_shifted` yields one byte of the key at a time (the radix digit), while
/// `compare` orders two values by the whole key and is used as the
/// comparison-sort fallback.
pub trait SortKey<T> {
    /// Integer type of the key.
    type KeyType: Ord;

    /// Width of the key in bits. The radix entry points start from the byte at
    /// bit offset `KEY_BITS - 8`.
    const KEY_BITS: usize;

    /// Orders `left` and `right` by their keys.
    fn compare(left: &T, right: &T) -> Ordering;

    /// Expected to return the key of `value` shifted right by `rhs` bits,
    /// truncated to its low byte.
    fn get_shifted(value: &T, rhs: u8) -> u8;
}
109
/// Declares a unit struct `$name` implementing `SortKey` for `$item_ty`, keyed
/// on the integer field `$field` of type `$field_ty`.
///
/// `SortKey` must be in scope at the call site.
///
/// Example: `make_comparer!(ById, Record, id: u64);`
#[macro_export]
macro_rules! make_comparer {
    ($name:ident, $item_ty:ty, $field:ident: $field_ty:ty) => {
        struct $name;

        impl SortKey<$item_ty> for $name {
            type KeyType = $field_ty;
            const KEY_BITS: usize = 8 * std::mem::size_of::<$field_ty>();

            fn get_shifted(value: &$item_ty, rhs: u8) -> u8 {
                let shifted = value.$field >> rhs;
                shifted as u8
            }

            fn compare(left: &$item_ty, right: &$item_ty) -> std::cmp::Ordering {
                Ord::cmp(&left.$field, &right.$field)
            }
        }
    };
}
128
/// Number of key bits consumed per radix pass.
const RADIX_SIZE_LOG: u8 = 8;
/// Number of buckets per radix pass; derived from `RADIX_SIZE_LOG` so the two cannot drift apart.
const RADIX_SIZE: usize = 1 << RADIX_SIZE_LOG;
131
// pub fn striped_parallel_smart_radix_sort_memfile<
//     T: Ord + Send + Sync + Debug + 'static,
//     F: SortKey<T>,
// >(
//     mem_file: FileReader,
//     dest_buffer: &mut Vec<T>,
// ) -> usize {
//     let chunks: Vec<_> = unsafe { mem_file.get_typed_chunks_mut::<T>().collect() };
//     let tot_entries = chunks.iter().map(|x| x.len()).sum();
//
//     dest_buffer.clear();
//     dest_buffer.reserve(tot_entries);
//     unsafe { dest_buffer.set_len(tot_entries) };
//
//     striped_parallel_smart_radix_sort::<T, F>(chunks.as_slice(), dest_buffer.as_mut_slice());
//
//     assert_eq!(dest_buffer.len(), chunks.iter().map(|x| x.len()).sum());
//     assert_eq!(dest_buffer.len(), mem_file.len() / size_of::<T>());
//
//     drop(mem_file);
//     tot_entries
// }
154
155pub fn striped_parallel_smart_radix_sort<T: Ord + Send + Sync + Debug, F: SortKey<T>>(
156    striped_file: &[&mut [T]],
157    dest_buffer: &mut [T],
158) {
159    let num_threads = rayon::current_num_threads();
160    let queue = crossbeam::queue::ArrayQueue::new(num_threads);
161
162    let first_shift = F::KEY_BITS as u8 - RADIX_SIZE_LOG;
163
164    for _ in 0..num_threads {
165        queue.push([0; RADIX_SIZE + 1]).unwrap();
166    }
167
168    striped_file.par_iter().for_each(|chunk| {
169        let mut counts = queue.pop().unwrap();
170        for el in chunk.iter() {
171            counts[(F::get_shifted(el, first_shift)) as usize + 1] += 1usize;
172        }
173        queue.push(counts).unwrap();
174    });
175
176    let mut counters = [0; RADIX_SIZE + 1];
177    while let Some(counts) = queue.pop() {
178        for i in 1..(RADIX_SIZE + 1) {
179            counters[i] += counts[i];
180        }
181    }
182    const ATOMIC_USIZE_ZERO: AtomicUsize = AtomicUsize::new(0);
183    let offsets = [ATOMIC_USIZE_ZERO; RADIX_SIZE + 1];
184    let mut offsets_reference = [0; RADIX_SIZE + 1];
185
186    use std::sync::atomic::Ordering;
187    for i in 1..(RADIX_SIZE + 1) {
188        offsets_reference[i] = offsets[i - 1].load(Ordering::Relaxed) + counters[i];
189        offsets[i].store(offsets_reference[i], Ordering::Relaxed);
190    }
191
192    let dest_buffer_addr = dest_buffer.as_mut_ptr() as usize;
193    striped_file.par_iter().for_each(|chunk| {
194        let dest_buffer_ptr = dest_buffer_addr as *mut T;
195
196        let chunk_addr = chunk.as_ptr() as usize;
197        let chunk_data_mut = unsafe { from_raw_parts_mut(chunk_addr as *mut T, chunk.len()) };
198
199        let choffs = smart_radix_sort_::<T, F, false, true>(
200            chunk_data_mut,
201            F::KEY_BITS as u8 - RADIX_SIZE_LOG,
202        );
203        let mut offset = 0;
204        for idx in 1..(RADIX_SIZE + 1) {
205            let count = choffs[idx] - choffs[idx - 1];
206            let dest_position = offsets[idx - 1].fetch_add(count, Ordering::Relaxed);
207
208            unsafe {
209                std::ptr::copy_nonoverlapping(
210                    chunk.as_ptr().add(offset),
211                    dest_buffer_ptr.add(dest_position),
212                    count,
213                );
214            }
215
216            offset += count;
217        }
218    });
219
220    if F::KEY_BITS >= 16 {
221        let offsets_reference = offsets_reference;
222        (0..256usize).into_par_iter().for_each(|idx| {
223            let dest_buffer_ptr = dest_buffer_addr as *mut T;
224
225            let bucket_start = offsets_reference[idx];
226            let bucket_len = offsets_reference[idx + 1] - bucket_start;
227
228            let crt_slice =
229                unsafe { from_raw_parts_mut(dest_buffer_ptr.add(bucket_start), bucket_len) };
230            smart_radix_sort_::<T, F, false, false>(crt_slice, F::KEY_BITS as u8 - 16);
231        });
232    }
233}
234
235pub fn fast_smart_radix_sort<T: Sync + Send, F: SortKey<T>, const PARALLEL: bool>(data: &mut [T]) {
236    smart_radix_sort_::<T, F, PARALLEL, false>(data, F::KEY_BITS as u8 - RADIX_SIZE_LOG);
237}
238
239pub fn fast_smart_radix_sort_by_value<T: Sync + Send, F: SortKey<T>, const PARALLEL: bool>(
240    data: &mut [T],
241) {
242    smart_radix_sort_::<T, F, PARALLEL, false>(data, F::KEY_BITS as u8 - RADIX_SIZE_LOG);
243}
244
245fn smart_radix_sort_<
246    T: Sync + Send,
247    F: SortKey<T>,
248    const PARALLEL: bool,
249    const SINGLE_STEP: bool,
250>(
251    data: &mut [T],
252    shift: u8,
253) -> [IndexType; RADIX_SIZE + 1] {
254    let mut stack = unsafe { unchecked_index(vec![(0..0, 0); shift as usize * RADIX_SIZE]) };
255
256    let mut stack_index = 1;
257    stack[0] = (0..data.len(), shift);
258
259    let mut ret_counts = [0; RADIX_SIZE + 1];
260
261    let mut first = true;
262
263    while stack_index > 0 {
264        stack_index -= 1;
265        let (range, shift) = stack[stack_index].clone();
266
267        let mut data = unsafe { unchecked_index(&mut data[range.clone()]) };
268
269        let mut counts: UncheckedIndex<[IndexType; RADIX_SIZE + 1]> =
270            unsafe { unchecked_index([0; RADIX_SIZE + 1]) };
271        let mut sums: UncheckedIndex<[IndexType; RADIX_SIZE + 1]>;
272
273        {
274            if PARALLEL {
275                const ATOMIC_ZERO: AtomicUsize = AtomicUsize::new(0);
276                let par_counts: UncheckedIndex<[AtomicUsize; RADIX_SIZE + 1]> =
277                    unsafe { unchecked_index([ATOMIC_ZERO; RADIX_SIZE + 1]) };
278                let num_threads = rayon::current_num_threads();
279                let chunk_size = (data.len() + num_threads - 1) / num_threads;
280                data.chunks(chunk_size).par_bridge().for_each(|chunk| {
281                    let mut thread_counts = unsafe { unchecked_index([0; RADIX_SIZE + 1]) };
282
283                    for el in chunk {
284                        thread_counts[(F::get_shifted(el, shift)) as usize + 1] += 1;
285                    }
286
287                    for (p, t) in par_counts.iter().zip(thread_counts.iter()) {
288                        p.fetch_add(*t, std::sync::atomic::Ordering::Relaxed);
289                    }
290                });
291
292                for i in 1..(RADIX_SIZE + 1) {
293                    counts[i] =
294                        counts[i - 1] + par_counts[i].load(std::sync::atomic::Ordering::Relaxed);
295                }
296                sums = counts;
297
298                let mut bucket_queues = Vec::with_capacity(RADIX_SIZE);
299                for i in 0..RADIX_SIZE {
300                    bucket_queues.push(crossbeam::channel::unbounded());
301
302                    let range = sums[i]..counts[i + 1];
303                    let range_steps = num_threads * 2;
304                    let tot_range_len = range.len();
305                    let subrange_len = (tot_range_len + range_steps - 1) / range_steps;
306
307                    let mut start = range.start;
308                    while start < range.end {
309                        let end = min(start + subrange_len, range.end);
310                        if start < end {
311                            bucket_queues[i].0.send(start..end).unwrap();
312                        }
313                        start += subrange_len;
314                    }
315                }
316
317                let data_ptr = data.as_mut_ptr() as usize;
318                (0..num_threads).into_par_iter().for_each(|thread_index| {
319                    let mut start_buckets = unsafe { unchecked_index([0; RADIX_SIZE]) };
320                    let mut end_buckets = unsafe { unchecked_index([0; RADIX_SIZE]) };
321
322                    let data = unsafe { from_raw_parts_mut(data_ptr as *mut T, data.len()) };
323
324                    let get_bpart = || {
325                        let start = rng().next_u32() as usize % RADIX_SIZE;
326                        let mut res = None;
327                        for i in 0..RADIX_SIZE {
328                            let bucket_num = (i + start) % RADIX_SIZE;
329                            if let Ok(val) = bucket_queues[bucket_num].1.try_recv() {
330                                res = Some((bucket_num, val));
331                                break;
332                            }
333                        }
334                        res
335                    };
336
337                    let mut buckets_stack: Vec<_> = vec![];
338
339                    while let Some((bidx, bpart)) = get_bpart() {
340                        start_buckets[bidx] = bpart.start;
341                        end_buckets[bidx] = bpart.end;
342                        buckets_stack.push(bidx);
343
344                        while let Some(bucket) = buckets_stack.pop() {
345                            while start_buckets[bucket] < end_buckets[bucket] {
346                                let val =
347                                    (F::get_shifted(&data[start_buckets[bucket]], shift)) as usize;
348
349                                while start_buckets[val] == end_buckets[val] {
350                                    let next_bucket = match bucket_queues[val].1.try_recv() {
351                                        Ok(val) => val,
352                                        Err(_) => {
353                                            // Final thread
354                                            if thread_index == num_threads - 1 {
355                                                bucket_queues[val].1.recv().unwrap()
356                                            } else {
357                                                // Non final thread, exit and let the final thread finish the computation
358                                                for i in 0..RADIX_SIZE {
359                                                    if start_buckets[i] < end_buckets[i] {
360                                                        bucket_queues[i]
361                                                            .0
362                                                            .send(start_buckets[i]..end_buckets[i])
363                                                            .unwrap();
364                                                    }
365                                                }
366                                                return;
367                                            }
368                                        }
369                                    };
370                                    start_buckets[val] = next_bucket.start;
371                                    end_buckets[val] = next_bucket.end;
372                                    buckets_stack.push(val);
373                                }
374
375                                data.swap(start_buckets[bucket], start_buckets[val]);
376                                start_buckets[val] += 1;
377                            }
378                        }
379                    }
380                });
381            } else {
382                for el in data.iter() {
383                    counts[(F::get_shifted(el, shift)) as usize + 1] += 1;
384                }
385
386                for i in 1..(RADIX_SIZE + 1) {
387                    counts[i] += counts[i - 1];
388                }
389                sums = counts;
390
391                for bucket in 0..RADIX_SIZE {
392                    let end = counts[bucket + 1];
393                    while sums[bucket] < end {
394                        let val = (F::get_shifted(&data[sums[bucket]], shift)) as usize;
395                        data.swap(sums[bucket], sums[val]);
396                        sums[val] += 1;
397                    }
398                }
399            }
400        }
401
402        if first {
403            ret_counts = *counts;
404            first = false;
405        }
406
407        struct UCWrapper<T> {
408            uc: UnsafeCell<T>,
409        }
410        unsafe impl<T> Sync for UCWrapper<T> {}
411        let data_ptr = UCWrapper {
412            uc: UnsafeCell::new(data),
413        };
414
415        if !SINGLE_STEP && shift >= RADIX_SIZE_LOG {
416            if PARALLEL && shift as usize == (F::KEY_BITS - RADIX_SIZE_LOG as usize) {
417                let data_ptr = &data_ptr;
418                (0..256usize)
419                    .into_par_iter()
420                    .filter(|x| (counts[(*x as usize) + 1] - counts[*x as usize]) > 1)
421                    .for_each(|i| {
422                        let mut data_ptr = unsafe { std::ptr::read(data_ptr.uc.get()) };
423                        let slice = &mut data_ptr[counts[i] as usize..counts[i + 1] as usize];
424                        smart_radix_sort_::<T, F, false, false>(slice, shift - RADIX_SIZE_LOG);
425                    });
426            } else {
427                (0..RADIX_SIZE).into_iter().for_each(|i| {
428                    let slice_len = counts[i + 1] - counts[i];
429                    let mut data_ptr = unsafe { std::ptr::read(data_ptr.uc.get()) };
430
431                    match slice_len {
432                        2 => {
433                            if F::compare(&data_ptr[counts[i]], &data_ptr[counts[i] + 1])
434                                == Ordering::Greater
435                            {
436                                data_ptr.swap(counts[i], counts[i] + 1);
437                            }
438                        }
439                        0 | 1 => return,
440
441                        _ => {}
442                    }
443
444                    if slice_len < 192 {
445                        let slice = &mut data_ptr[counts[i] as usize..counts[i + 1] as usize];
446                        slice.sort_unstable_by(F::compare);
447                        return;
448                    }
449
450                    stack[stack_index] = (
451                        range.start + counts[i] as usize..range.start + counts[i + 1] as usize,
452                        shift - RADIX_SIZE_LOG,
453                    );
454                    stack_index += 1;
455                });
456            }
457        }
458    }
459    ret_counts
460}
461
#[cfg(test)]
mod tests {
    use crate::fast_smart_bucket_sort::{fast_smart_radix_sort, SortKey};
    use rand::{rng, RngCore};
    use std::time::Instant;
    use voracious_radix_sort::RadixSort;

    /// 32-byte test record: a `u128` sort key followed by 16 payload bytes.
    #[derive(Debug, Ord, PartialOrd, Eq, PartialEq, Copy, Clone)]
    struct DataTypeStruct(u128, [u8; 32 - 16]);

    /// Sort key over the `u128` field of `DataTypeStruct`.
    struct U128SortKey;
    impl SortKey<DataTypeStruct> for U128SortKey {
        type KeyType = u128;
        const KEY_BITS: usize = std::mem::size_of::<u128>() * 8;

        #[inline(always)]
        fn compare(left: &DataTypeStruct, right: &DataTypeStruct) -> std::cmp::Ordering {
            left.0.cmp(&right.0)
        }

        #[inline(always)]
        fn get_shifted(value: &DataTypeStruct, rhs: u8) -> u8 {
            (value.0 >> rhs) as u8
        }
    }

    /// Benchmark: parallel radix sort of 5 billion random `u32`s versus
    /// `voracious_mt_sort`. Holds two 20 GB copies, hence `#[ignore]`.
    #[test]
    #[ignore]
    fn parallel_sorting() {
        const ARRAY_SIZE: usize = 5_000_000_000;

        let mut vec = Vec::with_capacity(ARRAY_SIZE);

        let mut rng = rng();

        for _ in 0..ARRAY_SIZE {
            vec.push(rng.next_u32());
        }
        let mut vec2 = vec.clone();

        crate::log_info!("Starting...");
        let start = Instant::now();

        struct U32SortKey;
        impl SortKey<u32> for U32SortKey {
            type KeyType = u32;
            const KEY_BITS: usize = std::mem::size_of::<u32>() * 8;

            #[inline(always)]
            fn compare(left: &u32, right: &u32) -> std::cmp::Ordering {
                left.cmp(right)
            }

            #[inline(always)]
            fn get_shifted(value: &u32, rhs: u8) -> u8 {
                (value >> rhs) as u8
            }
        }

        fast_smart_radix_sort::<_, U32SortKey, true>(vec.as_mut_slice());

        let end = start.elapsed();
        crate::log_info!("Total time: {:.2?}", end);

        crate::log_info!("Starting2...");
        let start = Instant::now();

        vec2.voracious_mt_sort(16);
        let end = start.elapsed();
        crate::log_info!("Total time 2: {:.2?}", end);

        // Both sorts must agree, otherwise the timings are meaningless.
        assert_eq!(vec, vec2);
    }

    /// Sequential radix sort of random 128-bit keys yields a non-decreasing sequence.
    #[test]
    fn sorting_test() {
        const VEC_SIZE: usize = 100_000;

        let mut rng = rng();
        let mut data: Vec<DataTypeStruct> = (0..VEC_SIZE)
            .map(|_| {
                let key = (u128::from(rng.next_u64()) << 64) | u128::from(rng.next_u64());
                DataTypeStruct(key, [2; 32 - 16])
            })
            .collect();

        crate::log_info!("Started sorting...");
        let start = Instant::now();
        fast_smart_radix_sort::<_, U128SortKey, false>(data.as_mut_slice());
        crate::log_info!("Done sorting => {:.2?}!", start.elapsed());

        for pair in data.windows(2) {
            assert!(pair[0] <= pair[1], "{:?} > {:?}!", pair[0], pair[1]);
        }
    }
}