parallel_processor/
fast_smart_bucket_sort.rs

1use crate::buckets::bucket_writer::BucketItemSerializer;
2use rand::{thread_rng, RngCore};
3use rayon::prelude::*;
4use std::cell::UnsafeCell;
5use std::cmp::min;
6use std::cmp::Ordering;
7use std::fmt::Debug;
8use std::io::{Read, Write};
9use std::slice::from_raw_parts_mut;
10use std::sync::atomic::AtomicUsize;
11use unchecked_index::{unchecked_index, UncheckedIndex};
12
/// Integer type used for bucket offsets and element counts throughout this module.
type IndexType = usize;
14
// #[repr(packed)]
/// Fixed-size record made of `LEN` raw bytes.
///
/// Equality and ordering are the derived ones, i.e. they compare the `data` arrays.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct SortedData<const LEN: usize> {
    /// The bytes of the record.
    pub data: [u8; LEN],
}

impl<const LEN: usize> SortedData<LEN> {
    /// Wraps `data` into a `SortedData` without copying or validating it.
    #[inline(always)]
    pub fn new(data: [u8; LEN]) -> Self {
        SortedData { data }
    }
}
27
28pub struct SortedDataSerializer<const LEN: usize>;
29impl<const LEN: usize> BucketItemSerializer for SortedDataSerializer<LEN> {
30    type InputElementType<'a> = SortedData<LEN>;
31    type ExtraData = ();
32    type ExtraDataBuffer = ();
33    type ReadBuffer = SortedData<LEN>;
34    type ReadType<'a> = &'a SortedData<LEN>;
35    type InitData = ();
36
37    type CheckpointData = ();
38
39    #[inline(always)]
40    fn new(_: ()) -> Self {
41        Self
42    }
43
44    #[inline(always)]
45    fn reset(&mut self) {}
46
47    #[inline(always)]
48    fn write_to(
49        &mut self,
50        element: &Self::InputElementType<'_>,
51        bucket: &mut Vec<u8>,
52        _: &Self::ExtraData,
53        _: &Self::ExtraDataBuffer,
54    ) {
55        bucket.write(element.data.as_slice()).unwrap();
56    }
57
58    #[inline(always)]
59    fn read_from<'a, S: Read>(
60        &mut self,
61        mut stream: S,
62        read_buffer: &'a mut Self::ReadBuffer,
63        _: &mut Self::ExtraDataBuffer,
64    ) -> Option<Self::ReadType<'a>> {
65        stream.read(read_buffer.data.as_mut_slice()).ok()?;
66        Some(read_buffer)
67    }
68
69    #[inline(always)]
70    fn get_size(&self, _: &Self::InputElementType<'_>, _: &()) -> usize {
71        LEN
72    }
73}
74
/// An ordered value from which an 8-bit digit can be extracted by shifting.
pub trait FastSortable: Ord {
    /// Returns `self` shifted right by `rhs` bits, truncated to its lowest 8 bits.
    fn get_shifted(&self, rhs: u8) -> u8;
}

// Implements `FastSortable` for every primitive integer type passed to it.
macro_rules! fast_sortable_impl {
    ($($int_type:ty),+ $(,)?) => {
        $(
            impl FastSortable for $int_type {
                #[inline(always)]
                fn get_shifted(&self, rhs: u8) -> u8 {
                    (*self >> rhs) as u8
                }
            }
        )+
    };
}

fast_sortable_impl!(u8, u16, u32, u64, u128);
95
/// Describes how values of type `T` are keyed for the radix sorts of this module.
pub trait SortKey<T> {
    /// Type of the key extracted from `T`.
    type KeyType: Ord;
    /// Width of the key in bits; the sort functions in this file derive their first shift
    /// amount from it.
    const KEY_BITS: usize;
    /// Compares two values; used as the comparator when small buckets are sorted directly.
    fn compare(left: &T, right: &T) -> Ordering;
    /// Returns the 8-bit digit of `value` obtained by shifting its key right by `rhs` bits.
    fn get_shifted(value: &T, rhs: u8) -> u8;
}
102
/// Declares a unit struct `$Name` implementing `SortKey<$type_name>` on the field `$key`.
///
/// `$key_type` is the type of that field: it must support `>>` by a `u8` and an `as u8`
/// cast. `KEY_BITS` is the size of `$key_type` in bits. `SortKey` must be in scope where
/// the macro is invoked.
#[macro_export]
macro_rules! make_comparer {
    ($Name:ident, $type_name:ty, $key:ident: $key_type:ty) => {
        struct $Name;
        impl SortKey<$type_name> for $Name {
            type KeyType = $key_type;
            const KEY_BITS: usize = 8 * ::std::mem::size_of::<$key_type>();

            fn compare(left: &$type_name, right: &$type_name) -> ::std::cmp::Ordering {
                ::std::cmp::Ord::cmp(&left.$key, &right.$key)
            }

            fn get_shifted(value: &$type_name, rhs: u8) -> u8 {
                let shifted = value.$key >> rhs;
                shifted as u8
            }
        }
    };
}
121
/// Number of key bits consumed by each radix pass.
const RADIX_SIZE_LOG: u8 = 8;
/// Number of buckets of each radix pass (`2^RADIX_SIZE_LOG`).
const RADIX_SIZE: usize = 1 << RADIX_SIZE_LOG;
124
125// pub fn striped_parallel_smart_radix_sort_memfile<
126//     T: Ord + Send + Sync + Debug + 'static,
127//     F: SortKey<T>,
128// >(
129//     mem_file: FileReader,
130//     dest_buffer: &mut Vec<T>,
131// ) -> usize {
132//     let chunks: Vec<_> = unsafe { mem_file.get_typed_chunks_mut::<T>().collect() };
133//     let tot_entries = chunks.iter().map(|x| x.len()).sum();
134//
135//     dest_buffer.clear();
136//     dest_buffer.reserve(tot_entries);
137//     unsafe { dest_buffer.set_len(tot_entries) };
138//
139//     striped_parallel_smart_radix_sort::<T, F>(chunks.as_slice(), dest_buffer.as_mut_slice());
140//
141//     assert_eq!(dest_buffer.len(), chunks.iter().map(|x| x.len()).sum());
142//     assert_eq!(dest_buffer.len(), mem_file.len() / size_of::<T>());
143//
144//     drop(mem_file);
145//     tot_entries
146// }
147
148pub fn striped_parallel_smart_radix_sort<T: Ord + Send + Sync + Debug, F: SortKey<T>>(
149    striped_file: &[&mut [T]],
150    dest_buffer: &mut [T],
151) {
152    let num_threads = rayon::current_num_threads();
153    let queue = crossbeam::queue::ArrayQueue::new(num_threads);
154
155    let first_shift = F::KEY_BITS as u8 - RADIX_SIZE_LOG;
156
157    for _ in 0..num_threads {
158        queue.push([0; RADIX_SIZE + 1]).unwrap();
159    }
160
161    striped_file.par_iter().for_each(|chunk| {
162        let mut counts = queue.pop().unwrap();
163        for el in chunk.iter() {
164            counts[(F::get_shifted(el, first_shift)) as usize + 1] += 1usize;
165        }
166        queue.push(counts).unwrap();
167    });
168
169    let mut counters = [0; RADIX_SIZE + 1];
170    while let Some(counts) = queue.pop() {
171        for i in 1..(RADIX_SIZE + 1) {
172            counters[i] += counts[i];
173        }
174    }
175    const ATOMIC_USIZE_ZERO: AtomicUsize = AtomicUsize::new(0);
176    let offsets = [ATOMIC_USIZE_ZERO; RADIX_SIZE + 1];
177    let mut offsets_reference = [0; RADIX_SIZE + 1];
178
179    use std::sync::atomic::Ordering;
180    for i in 1..(RADIX_SIZE + 1) {
181        offsets_reference[i] = offsets[i - 1].load(Ordering::Relaxed) + counters[i];
182        offsets[i].store(offsets_reference[i], Ordering::Relaxed);
183    }
184
185    let dest_buffer_addr = dest_buffer.as_mut_ptr() as usize;
186    striped_file.par_iter().for_each(|chunk| {
187        let dest_buffer_ptr = dest_buffer_addr as *mut T;
188
189        let chunk_addr = chunk.as_ptr() as usize;
190        let chunk_data_mut = unsafe { from_raw_parts_mut(chunk_addr as *mut T, chunk.len()) };
191
192        let choffs = smart_radix_sort_::<T, F, false, true>(
193            chunk_data_mut,
194            F::KEY_BITS as u8 - RADIX_SIZE_LOG,
195        );
196        let mut offset = 0;
197        for idx in 1..(RADIX_SIZE + 1) {
198            let count = choffs[idx] - choffs[idx - 1];
199            let dest_position = offsets[idx - 1].fetch_add(count, Ordering::Relaxed);
200
201            unsafe {
202                std::ptr::copy_nonoverlapping(
203                    chunk.as_ptr().add(offset),
204                    dest_buffer_ptr.add(dest_position),
205                    count,
206                );
207            }
208
209            offset += count;
210        }
211    });
212
213    if F::KEY_BITS >= 16 {
214        let offsets_reference = offsets_reference;
215        (0..256usize).into_par_iter().for_each(|idx| {
216            let dest_buffer_ptr = dest_buffer_addr as *mut T;
217
218            let bucket_start = offsets_reference[idx];
219            let bucket_len = offsets_reference[idx + 1] - bucket_start;
220
221            let crt_slice =
222                unsafe { from_raw_parts_mut(dest_buffer_ptr.add(bucket_start), bucket_len) };
223            smart_radix_sort_::<T, F, false, false>(crt_slice, F::KEY_BITS as u8 - 16);
224        });
225    }
226}
227
228pub fn fast_smart_radix_sort<T: Sync + Send, F: SortKey<T>, const PARALLEL: bool>(data: &mut [T]) {
229    smart_radix_sort_::<T, F, PARALLEL, false>(data, F::KEY_BITS as u8 - RADIX_SIZE_LOG);
230}
231
232pub fn fast_smart_radix_sort_by_value<T: Sync + Send, F: SortKey<T>, const PARALLEL: bool>(
233    data: &mut [T],
234) {
235    smart_radix_sort_::<T, F, PARALLEL, false>(data, F::KEY_BITS as u8 - RADIX_SIZE_LOG);
236}
237
238fn smart_radix_sort_<
239    T: Sync + Send,
240    F: SortKey<T>,
241    const PARALLEL: bool,
242    const SINGLE_STEP: bool,
243>(
244    data: &mut [T],
245    shift: u8,
246) -> [IndexType; RADIX_SIZE + 1] {
247    let mut stack = unsafe { unchecked_index(vec![(0..0, 0); shift as usize * RADIX_SIZE]) };
248
249    let mut stack_index = 1;
250    stack[0] = (0..data.len(), shift);
251
252    let mut ret_counts = [0; RADIX_SIZE + 1];
253
254    let mut first = true;
255
256    while stack_index > 0 {
257        stack_index -= 1;
258        let (range, shift) = stack[stack_index].clone();
259
260        let mut data = unsafe { unchecked_index(&mut data[range.clone()]) };
261
262        let mut counts: UncheckedIndex<[IndexType; RADIX_SIZE + 1]> =
263            unsafe { unchecked_index([0; RADIX_SIZE + 1]) };
264        let mut sums: UncheckedIndex<[IndexType; RADIX_SIZE + 1]>;
265
266        {
267            if PARALLEL {
268                const ATOMIC_ZERO: AtomicUsize = AtomicUsize::new(0);
269                let par_counts: UncheckedIndex<[AtomicUsize; RADIX_SIZE + 1]> =
270                    unsafe { unchecked_index([ATOMIC_ZERO; RADIX_SIZE + 1]) };
271                let num_threads = rayon::current_num_threads();
272                let chunk_size = (data.len() + num_threads - 1) / num_threads;
273                data.chunks(chunk_size).par_bridge().for_each(|chunk| {
274                    let mut thread_counts = unsafe { unchecked_index([0; RADIX_SIZE + 1]) };
275
276                    for el in chunk {
277                        thread_counts[(F::get_shifted(el, shift)) as usize + 1] += 1;
278                    }
279
280                    for (p, t) in par_counts.iter().zip(thread_counts.iter()) {
281                        p.fetch_add(*t, std::sync::atomic::Ordering::Relaxed);
282                    }
283                });
284
285                for i in 1..(RADIX_SIZE + 1) {
286                    counts[i] =
287                        counts[i - 1] + par_counts[i].load(std::sync::atomic::Ordering::Relaxed);
288                }
289                sums = counts;
290
291                let mut bucket_queues = Vec::with_capacity(RADIX_SIZE);
292                for i in 0..RADIX_SIZE {
293                    bucket_queues.push(crossbeam::channel::unbounded());
294
295                    let range = sums[i]..counts[i + 1];
296                    let range_steps = num_threads * 2;
297                    let tot_range_len = range.len();
298                    let subrange_len = (tot_range_len + range_steps - 1) / range_steps;
299
300                    let mut start = range.start;
301                    while start < range.end {
302                        let end = min(start + subrange_len, range.end);
303                        if start < end {
304                            bucket_queues[i].0.send(start..end).unwrap();
305                        }
306                        start += subrange_len;
307                    }
308                }
309
310                let data_ptr = data.as_mut_ptr() as usize;
311                (0..num_threads).into_par_iter().for_each(|thread_index| {
312                    let mut start_buckets = unsafe { unchecked_index([0; RADIX_SIZE]) };
313                    let mut end_buckets = unsafe { unchecked_index([0; RADIX_SIZE]) };
314
315                    let data = unsafe { from_raw_parts_mut(data_ptr as *mut T, data.len()) };
316
317                    let get_bpart = || {
318                        let start = thread_rng().next_u32() as usize % RADIX_SIZE;
319                        let mut res = None;
320                        for i in 0..RADIX_SIZE {
321                            let bucket_num = (i + start) % RADIX_SIZE;
322                            if let Ok(val) = bucket_queues[bucket_num].1.try_recv() {
323                                res = Some((bucket_num, val));
324                                break;
325                            }
326                        }
327                        res
328                    };
329
330                    let mut buckets_stack: Vec<_> = vec![];
331
332                    while let Some((bidx, bpart)) = get_bpart() {
333                        start_buckets[bidx] = bpart.start;
334                        end_buckets[bidx] = bpart.end;
335                        buckets_stack.push(bidx);
336
337                        while let Some(bucket) = buckets_stack.pop() {
338                            while start_buckets[bucket] < end_buckets[bucket] {
339                                let val =
340                                    (F::get_shifted(&data[start_buckets[bucket]], shift)) as usize;
341
342                                while start_buckets[val] == end_buckets[val] {
343                                    let next_bucket = match bucket_queues[val].1.try_recv() {
344                                        Ok(val) => val,
345                                        Err(_) => {
346                                            // Final thread
347                                            if thread_index == num_threads - 1 {
348                                                bucket_queues[val].1.recv().unwrap()
349                                            } else {
350                                                // Non final thread, exit and let the final thread finish the computation
351                                                for i in 0..RADIX_SIZE {
352                                                    if start_buckets[i] < end_buckets[i] {
353                                                        bucket_queues[i]
354                                                            .0
355                                                            .send(start_buckets[i]..end_buckets[i])
356                                                            .unwrap();
357                                                    }
358                                                }
359                                                return;
360                                            }
361                                        }
362                                    };
363                                    start_buckets[val] = next_bucket.start;
364                                    end_buckets[val] = next_bucket.end;
365                                    buckets_stack.push(val);
366                                }
367
368                                data.swap(start_buckets[bucket], start_buckets[val]);
369                                start_buckets[val] += 1;
370                            }
371                        }
372                    }
373                });
374            } else {
375                for el in data.iter() {
376                    counts[(F::get_shifted(el, shift)) as usize + 1] += 1;
377                }
378
379                for i in 1..(RADIX_SIZE + 1) {
380                    counts[i] += counts[i - 1];
381                }
382                sums = counts;
383
384                for bucket in 0..RADIX_SIZE {
385                    let end = counts[bucket + 1];
386                    while sums[bucket] < end {
387                        let val = (F::get_shifted(&data[sums[bucket]], shift)) as usize;
388                        data.swap(sums[bucket], sums[val]);
389                        sums[val] += 1;
390                    }
391                }
392            }
393        }
394
395        if first {
396            ret_counts = *counts;
397            first = false;
398        }
399
400        struct UCWrapper<T> {
401            uc: UnsafeCell<T>,
402        }
403        unsafe impl<T> Sync for UCWrapper<T> {}
404        let data_ptr = UCWrapper {
405            uc: UnsafeCell::new(data),
406        };
407
408        if !SINGLE_STEP && shift >= RADIX_SIZE_LOG {
409            if PARALLEL && shift as usize == (F::KEY_BITS - RADIX_SIZE_LOG as usize) {
410                let data_ptr = &data_ptr;
411                (0..256usize)
412                    .into_par_iter()
413                    .filter(|x| (counts[(*x as usize) + 1] - counts[*x as usize]) > 1)
414                    .for_each(|i| {
415                        let mut data_ptr = unsafe { std::ptr::read(data_ptr.uc.get()) };
416                        let slice = &mut data_ptr[counts[i] as usize..counts[i + 1] as usize];
417                        smart_radix_sort_::<T, F, false, false>(slice, shift - RADIX_SIZE_LOG);
418                    });
419            } else {
420                (0..RADIX_SIZE).into_iter().for_each(|i| {
421                    let slice_len = counts[i + 1] - counts[i];
422                    let mut data_ptr = unsafe { std::ptr::read(data_ptr.uc.get()) };
423
424                    match slice_len {
425                        2 => {
426                            if F::compare(&data_ptr[counts[i]], &data_ptr[counts[i] + 1])
427                                == Ordering::Greater
428                            {
429                                data_ptr.swap(counts[i], counts[i] + 1);
430                            }
431                        }
432                        0 | 1 => return,
433
434                        _ => {}
435                    }
436
437                    if slice_len < 192 {
438                        let slice = &mut data_ptr[counts[i] as usize..counts[i + 1] as usize];
439                        slice.sort_unstable_by(F::compare);
440                        return;
441                    }
442
443                    stack[stack_index] = (
444                        range.start + counts[i] as usize..range.start + counts[i + 1] as usize,
445                        shift - RADIX_SIZE_LOG,
446                    );
447                    stack_index += 1;
448                });
449            }
450        }
451    }
452    ret_counts
453}
454
#[cfg(test)]
mod tests {
    use crate::fast_smart_bucket_sort::{fast_smart_radix_sort, SortKey};
    use rand::{thread_rng, RngCore};
    use std::time::Instant;
    use voracious_radix_sort::RadixSort;

    // Record type and key used by the (currently disabled) `sorting_test` below.
    #[derive(Debug, Ord, PartialOrd, Eq, PartialEq, Copy, Clone)]
    struct DataTypeStruct(u128, [u8; (32 - 16)]);

    struct U64SortKey;
    impl SortKey<DataTypeStruct> for U64SortKey {
        type KeyType = u128;
        const KEY_BITS: usize = std::mem::size_of::<u128>() * 8;

        #[inline(always)]
        fn compare(left: &DataTypeStruct, right: &DataTypeStruct) -> std::cmp::Ordering {
            left.0.cmp(&right.0)
        }

        #[inline(always)]
        fn get_shifted(value: &DataTypeStruct, rhs: u8) -> u8 {
            (value.0 >> rhs) as u8
        }
    }

    /// Benchmark of the parallel radix sort against `voracious_mt_sort` on random `u32`s.
    ///
    /// Ignored by default because of the size of the input. Besides timing both sorts, it
    /// checks that the radix sort output is sorted and equal to the reference result.
    #[test]
    #[ignore]
    fn parallel_sorting() {
        const ARRAY_SIZE: usize = 5000000000;

        let mut vec = Vec::with_capacity(ARRAY_SIZE);

        let mut rng = thread_rng();

        for _ in 0..ARRAY_SIZE {
            vec.push(rng.next_u32());
        }
        let mut vec2 = vec.clone();

        crate::log_info!("Starting...");
        let start = Instant::now();

        struct U16SortKey;
        impl SortKey<u32> for U16SortKey {
            type KeyType = u32;
            const KEY_BITS: usize = std::mem::size_of::<u32>() * 8;

            #[inline(always)]
            fn compare(left: &u32, right: &u32) -> std::cmp::Ordering {
                left.cmp(right)
            }

            #[inline(always)]
            fn get_shifted(value: &u32, rhs: u8) -> u8 {
                (value >> rhs) as u8
            }
        }

        fast_smart_radix_sort::<_, U16SortKey, true>(vec.as_mut_slice());

        let end = start.elapsed();
        crate::log_info!("Total time: {:.2?}", end);

        // Validate outside of the timed section.
        assert!(
            vec.windows(2).all(|pair| pair[0] <= pair[1]),
            "fast_smart_radix_sort produced an unsorted output"
        );

        crate::log_info!("Starting2...");
        let start = Instant::now();

        vec2.voracious_mt_sort(16);
        let end = start.elapsed();
        crate::log_info!("Total time 2: {:.2?}", end);

        // `assert!` instead of `assert_eq!` so a failure does not print both huge vectors.
        assert!(
            vec == vec2,
            "fast_smart_radix_sort and voracious_mt_sort produced different results"
        );
    }

    // #[test]
    // fn sorting_test() {
    //     let mut data = vec![DataTypeStruct(0, [0; 32 - 16]); VEC_SIZE];
    //
    //     data.par_iter_mut()
    //         .enumerate()
    //         .for_each(|(i, x)| *x = DataTypeStruct(thread_rng().gen(), [2; 32 - 16]));
    //
    //     crate::log_info!("Started sorting...");
    //     let start = Instant::now();
    //     fast_smart_radix_sort::<_, U64SortKey, true>(data.as_mut_slice());
    //     crate::log_info!("Done sorting => {:.2?}!", start.elapsed());
    //     assert!(data.is_sorted_by(|a, b| {
    //         Some(match a.cmp(b) {
    //             Ordering::Less => Ordering::Less,
    //             Ordering::Equal => Ordering::Equal,
    //             Ordering::Greater => {
    //                 panic!("{:?} > {:?}!", a, b);
    //             }
    //         })
    //     }));
    // }
}