rdst/utils/
sort_utils.rs

1use crate::RadixKey;
2#[cfg(feature = "multi-threaded")]
3use rayon::prelude::*;
4#[cfg(feature = "multi-threaded")]
5use std::sync::mpsc::channel;
6
/// Computes the exclusive prefix sums of a 256-entry histogram.
///
/// Entry `i` of the result holds the sum of `counts[0..i]`, so entry `0` is
/// always `0` and `counts[255]` never contributes to any returned entry.
#[inline]
pub fn get_prefix_sums(counts: &[usize; 256]) -> [usize; 256] {
    let mut offsets = [0usize; 256];
    let mut acc = 0usize;

    // Walk both arrays in lockstep: store the total seen so far, then extend it.
    for (slot, count) in offsets.iter_mut().zip(counts.iter()) {
        *slot = acc;
        acc += count;
    }

    offsets
}
19
/// Derives the end offset of every one of the 256 buckets.
///
/// For the first 255 buckets the end offset is simply the next entry of
/// `prefix_sums`. The final bucket has no successor, so its end is its own
/// prefix sum plus its own count.
#[inline]
pub fn get_end_offsets(counts: &[usize; 256], prefix_sums: &[usize; 256]) -> [usize; 256] {
    let mut ends = [0usize; 256];

    let (head, tail) = ends.split_at_mut(255);
    head.copy_from_slice(&prefix_sums[1..]);
    tail[0] = prefix_sums[255] + counts[255];

    ends
}
29
/// Counts how often each byte value occurs at `level` across `bucket`,
/// delegating to `par_get_counts_with_ends`.
///
/// Returns the 256-entry histogram and a flag that is `true` when the bytes
/// at `level` were found to be in non-decreasing order. An empty bucket
/// yields an all-zero histogram and `true`.
///
/// Only compiled for test/bench/tuning builds that also enable the
/// `multi-threaded` feature, because the helper this delegates to only
/// exists under that feature.
#[inline]
#[cfg(all(feature = "multi-threaded", any(test, bench, tuning)))]
pub fn par_get_counts<T>(bucket: &[T], level: usize) -> ([usize; 256], bool)
where
    T: RadixKey + Sized + Send + Sync,
{
    if bucket.is_empty() {
        return ([0usize; 256], true);
    }

    let (counts, sorted, _, _) = par_get_counts_with_ends(bucket, level);
    (counts, sorted)
}
43
/// Parallel variant of `get_counts_with_ends`.
///
/// Buckets shorter than 400,000 items are handed straight to the
/// single-threaded `get_counts_with_ends`. Larger buckets are split into
/// chunks sized so that each rayon thread receives roughly eight of them;
/// every chunk is counted independently and the per-chunk results are merged.
///
/// Returns `(counts, already_sorted, first, last)` where `counts` is the
/// merged 256-entry histogram, `already_sorted` is `true` only if every chunk
/// was sorted and no chunk starts with a byte lower than the one its
/// predecessor ended on, `first` is the first chunk's reported start byte and
/// `last` is the final chunk's reported end byte.
#[inline]
#[cfg(feature = "multi-threaded")]
pub fn par_get_counts_with_ends<T>(bucket: &[T], level: usize) -> ([usize; 256], bool, u8, u8)
where
    T: RadixKey + Sized + Send + Sync,
{
    #[cfg(feature = "work_profiles")]
    println!("({}) PAR_COUNT", level);

    if bucket.len() < 400_000 {
        return get_counts_with_ends(bucket, level);
    }

    let chunk_size = (bucket.len() / rayon::current_num_threads() / 8) + 1;
    let par_chunks = bucket.par_chunks(chunk_size);
    let chunk_count = par_chunks.len();
    let (sender, receiver) = channel();

    // Each worker tags its result with the chunk index, since messages arrive
    // in completion order rather than slice order.
    par_chunks
        .enumerate()
        .for_each_with(sender, |sender, (idx, chunk)| {
            let (counts, sorted, first, last) = get_counts_with_ends(chunk, level);
            sender.send((idx, counts, sorted, first, last)).unwrap();
        });

    let mut totals = [0usize; 256];
    let mut every_chunk_sorted = true;
    let mut ends = vec![(0u8, 0u8); chunk_count];

    for _ in 0..chunk_count {
        let (idx, counts, sorted, first, last) = receiver.recv().unwrap();

        every_chunk_sorted &= sorted;
        ends[idx] = (first, last);

        totals
            .iter_mut()
            .zip(counts.iter())
            .for_each(|(total, count)| *total += *count);
    }

    // Individually sorted chunks only make a sorted bucket if every chunk
    // starts at or above the byte the previous chunk ended on.
    let already_sorted =
        every_chunk_sorted && ends.windows(2).all(|pair| pair[0].1 <= pair[1].0);

    let first = ends[0].0;
    let last = ends[ends.len() - 1].1;

    (totals, already_sorted, first, last)
}
107
108#[inline]
109pub fn get_counts_with_ends<T>(bucket: &[T], level: usize) -> ([usize; 256], bool, u8, u8)
110where
111    T: RadixKey,
112{
113    #[cfg(feature = "work_profiles")]
114    println!("({}) COUNT", level);
115
116    let mut already_sorted = true;
117    let mut continue_from = bucket.len();
118    let mut counts_1 = [0usize; 256];
119    let mut last = 0usize;
120
121    for (i, item) in bucket.iter().enumerate() {
122        let b = item.get_level(level) as usize;
123        counts_1[b] += 1;
124
125        if b < last {
126            continue_from = i + 1;
127            already_sorted = false;
128            break;
129        }
130
131        last = b;
132    }
133
134    if continue_from == bucket.len() {
135        return (
136            counts_1,
137            already_sorted,
138            bucket[0].get_level(level),
139            last as u8,
140        );
141    }
142
143    let mut counts_2 = [0usize; 256];
144    let mut counts_3 = [0usize; 256];
145    let mut counts_4 = [0usize; 256];
146    let chunks = bucket[continue_from..].chunks_exact(4);
147    let rem = chunks.remainder();
148
149    chunks.into_iter().for_each(|chunk| {
150        let a = chunk[0].get_level(level) as usize;
151        let b = chunk[1].get_level(level) as usize;
152        let c = chunk[2].get_level(level) as usize;
153        let d = chunk[3].get_level(level) as usize;
154
155        counts_1[a] += 1;
156        counts_2[b] += 1;
157        counts_3[c] += 1;
158        counts_4[d] += 1;
159    });
160
161    rem.iter().for_each(|v| {
162        let b = v.get_level(level) as usize;
163        counts_1[b] += 1;
164    });
165
166    for i in 0..256 {
167        counts_1[i] += counts_2[i];
168        counts_1[i] += counts_3[i];
169        counts_1[i] += counts_4[i];
170    }
171
172    let b_first = bucket.first().unwrap().get_level(level);
173    let b_last = bucket.last().unwrap().get_level(level);
174
175    (counts_1, already_sorted, b_first, b_last)
176}
177
178#[inline]
179pub fn get_counts<T>(bucket: &[T], level: usize) -> ([usize; 256], bool)
180where
181    T: RadixKey,
182{
183    if bucket.is_empty() {
184        return ([0usize; 256], true);
185    }
186
187    let (counts, sorted, _, _) = get_counts_with_ends(bucket, level);
188
189    (counts, sorted)
190}
191
/// Allocates a `Vec<T>` whose length is set to `len` without initializing
/// any of its elements.
///
/// The lint is silenced deliberately: skipping initialization is the whole
/// point of this helper, since it is much faster than `resize`, `to_vec` and
/// similar.
///
/// NOTE(review): soundness depends on every caller writing all `len` slots
/// before reading any of them (and before the vector is dropped, for types
/// with drop glue). Nothing in this function can enforce that - confirm at
/// each call site.
#[allow(clippy::uninit_vec)]
#[inline]
pub fn get_tmp_bucket<T>(len: usize) -> Vec<T> {
    let mut scratch: Vec<T> = Vec::with_capacity(len);

    // SAFETY: capacity is at least `len` thanks to `with_capacity` above. The
    // elements stay uninitialized; callers are expected to overwrite every
    // slot when placing values into the returned bucket.
    unsafe { scratch.set_len(len) };

    scratch
}
206
/// Ceiling division: the smallest integer `q` such that `q * b >= a`.
///
/// Built from the quotient and remainder rather than `(a + b - 1) / b`, so
/// the intermediate value cannot overflow for large `a`.
///
/// # Panics
///
/// Panics if `b` is zero.
#[inline]
pub const fn cdiv(a: usize, b: usize) -> usize {
    a / b + (a % b != 0) as usize
}
211
212#[inline]
213pub fn get_tile_counts<T>(bucket: &[T], tile_size: usize, level: usize) -> (Vec<[usize; 256]>, bool)
214where
215    T: RadixKey + Copy + Sized + Send + Sync,
216{
217    #[cfg(feature = "work_profiles")]
218    println!("({}) TILE_COUNT", level);
219
220    #[cfg(feature = "multi-threaded")]
221    let tiles: Vec<([usize; 256], bool, u8, u8)> = bucket
222        .par_chunks(tile_size)
223        .map(|chunk| par_get_counts_with_ends(chunk, level))
224        .collect();
225
226    #[cfg(not(feature = "multi-threaded"))]
227    let tiles: Vec<([usize; 256], bool, u8, u8)> = bucket
228        .chunks(tile_size)
229        .map(|chunk| get_counts_with_ends(chunk, level))
230        .collect();
231
232    let mut all_sorted = true;
233
234    if tiles.len() == 1 {
235        // If there is only one tile, we already have a flag for if it is sorted
236        all_sorted = tiles[0].1;
237    } else {
238        // Check if any of the tiles, or any of the tile boundaries are unsorted
239        for tile in tiles.windows(2) {
240            if !tile[0].1 || !tile[1].1 || tile[1].2 < tile[0].3 {
241                all_sorted = false;
242                break;
243            }
244        }
245    }
246
247    (tiles.into_iter().map(|v| v.0).collect(), all_sorted)
248}
249
/// Sums a list of per-tile histograms into a single 256-entry histogram.
///
/// An empty `tile_counts` yields an all-zero histogram rather than panicking,
/// which matters because `get_tile_counts` returns no tiles for an empty
/// bucket.
#[inline]
pub fn aggregate_tile_counts(tile_counts: &[[usize; 256]]) -> [usize; 256] {
    let mut out = [0usize; 256];

    for tile in tile_counts {
        for (total, count) in out.iter_mut().zip(tile.iter()) {
            *total += *count;
        }
    }

    out
}
261
/// Reports whether at most one of the 256 counts is non-zero.
///
/// An all-zero histogram also counts as homogenous.
#[inline]
pub fn is_homogenous_bucket(counts: &[usize; 256]) -> bool {
    // Stop looking as soon as a second occupied entry turns up.
    counts.iter().filter(|&&c| c > 0).take(2).count() < 2
}
277
#[cfg(test)]
mod tests {
    use crate::utils::get_tile_counts;

    /// Runs `get_tile_counts` at level 0 and returns only the sorted flag.
    fn sorted_flag(data: &[u8], tile_size: usize) -> bool {
        let (_counts, already_sorted) = get_tile_counts(data, tile_size, 0);
        already_sorted
    }

    // The whole input fits in one tile, so the tile's own flag decides.
    #[test]
    pub fn test_get_tile_counts_correctly_marks_already_sorted_single_tile() {
        assert!(!sorted_flag(&[0, 5, 2, 3, 1], 5));
        assert!(sorted_flag(&[0, 0, 1, 1, 2], 5));
    }

    // With several tiles, the seams between tiles must be checked as well.
    #[test]
    pub fn test_get_tile_counts_correctly_marks_already_sorted_multiple_tiles() {
        assert!(!sorted_flag(&[0, 5, 2, 3, 1], 2));
        assert!(sorted_flag(&[0, 0, 1, 1, 2], 2));
    }
}