1use crate::RadixKey;
2#[cfg(feature = "multi-threaded")]
3use rayon::prelude::*;
4#[cfg(feature = "multi-threaded")]
5use std::sync::mpsc::channel;
6
/// Turns per-bucket counts into exclusive prefix sums.
///
/// Entry `i` of the result is the sum of `counts[0..i]`, so entry 0 is always
/// zero and the count of bucket 255 never contributes to any entry.
#[inline]
pub fn get_prefix_sums(counts: &[usize; 256]) -> [usize; 256] {
    let mut offsets = [0usize; 256];
    let mut next_start = 0usize;

    // Walk both arrays in lockstep: store where this bucket starts, then
    // advance the running start by this bucket's size.
    for (slot, count) in offsets.iter_mut().zip(counts.iter()) {
        *slot = next_start;
        next_start += *count;
    }

    offsets
}
19
/// Computes the exclusive end offset of every bucket.
///
/// For buckets `0..255` the end offset is simply the start offset of the next
/// bucket (`prefix_sums[i + 1]`). The final bucket has no successor, so its
/// end is its own start plus its own count.
#[inline]
pub fn get_end_offsets(counts: &[usize; 256], prefix_sums: &[usize; 256]) -> [usize; 256] {
    let mut ends = [0usize; 256];

    // Split off the last slot so the first 255 can be filled with one copy.
    let (head, tail) = ends.split_at_mut(255);
    head.copy_from_slice(&prefix_sums[1..]);
    tail[0] = counts[255] + prefix_sums[255];

    ends
}
29
/// Parallel counterpart of `get_counts`, compiled only for the `test`, `bench`
/// and `tuning` configurations.
///
/// Returns `(counts, already_sorted)` for the bytes of `bucket` at `level`.
/// An empty bucket short-circuits to all-zero counts and `true` without
/// calling into the counting code.
///
/// Delegates to `par_get_counts_with_ends`, which is itself only compiled when
/// the `multi-threaded` feature is enabled.
#[inline]
#[cfg(any(test, bench, tuning))]
pub fn par_get_counts<T>(bucket: &[T], level: usize) -> ([usize; 256], bool)
where
    T: RadixKey + Sized + Send + Sync,
{
    if bucket.is_empty() {
        ([0usize; 256], true)
    } else {
        // The end bytes are only needed when stitching tiles together.
        let (counts, already_sorted, _first, _last) = par_get_counts_with_ends(bucket, level);
        (counts, already_sorted)
    }
}
43
/// Counts the bytes of `bucket` at `level` using the rayon thread pool.
///
/// Returns `(counts, already_sorted, first, last)` with the same meaning as
/// `get_counts_with_ends`. Buckets shorter than 400 000 items are handed
/// straight to `get_counts_with_ends` on the calling thread.
///
/// Larger buckets are cut into chunks that are counted independently; the
/// per-chunk results are then summed, and the bucket is reported as sorted
/// only when every chunk is sorted and no chunk starts with a byte smaller
/// than the byte the previous chunk ended with.
#[inline]
#[cfg(feature = "multi-threaded")]
pub fn par_get_counts_with_ends<T>(bucket: &[T], level: usize) -> ([usize; 256], bool, u8, u8)
where
    T: RadixKey + Sized + Send + Sync,
{
    #[cfg(feature = "work_profiles")]
    println!("({}) PAR_COUNT", level);

    if bucket.len() < 400_000 {
        return get_counts_with_ends(bucket, level);
    }

    // Aim for several chunks per thread; the `+ 1` keeps the chunk size non-zero.
    let threads = rayon::current_num_threads();
    let chunk_divisor = 8;
    let chunk_size = (bucket.len() / threads / chunk_divisor) + 1;
    let chunks = bucket.par_chunks(chunk_size);
    let len = chunks.len();
    let (tx, rx) = channel();

    // Each chunk reports its result together with its index, because the
    // messages can arrive in any order.
    chunks.enumerate().for_each_with(tx, |tx, (i, chunk)| {
        let (counts, sorted, first, last) = get_counts_with_ends(chunk, level);
        tx.send((i, counts, sorted, first, last)).unwrap();
    });

    let mut msb_counts = [0usize; 256];
    let mut already_sorted = true;
    let mut boundaries = vec![(0u8, 0u8); len];

    for _ in 0..len {
        let (idx, chunk_counts, chunk_sorted, start, end) = rx.recv().unwrap();

        already_sorted &= chunk_sorted;
        boundaries[idx] = (start, end);

        for (total, count) in msb_counts.iter_mut().zip(chunk_counts.iter()) {
            *total += *count;
        }
    }

    // Sorted chunks only make a sorted bucket if adjacent chunks also join up
    // in order: each chunk's first byte must not be below its predecessor's last.
    let already_sorted = already_sorted && boundaries.windows(2).all(|w| w[0].1 <= w[1].0);

    (
        msb_counts,
        already_sorted,
        boundaries[0].0,
        boundaries[len - 1].1,
    )
}
107
108#[inline]
109pub fn get_counts_with_ends<T>(bucket: &[T], level: usize) -> ([usize; 256], bool, u8, u8)
110where
111 T: RadixKey,
112{
113 #[cfg(feature = "work_profiles")]
114 println!("({}) COUNT", level);
115
116 let mut already_sorted = true;
117 let mut continue_from = bucket.len();
118 let mut counts_1 = [0usize; 256];
119 let mut last = 0usize;
120
121 for (i, item) in bucket.iter().enumerate() {
122 let b = item.get_level(level) as usize;
123 counts_1[b] += 1;
124
125 if b < last {
126 continue_from = i + 1;
127 already_sorted = false;
128 break;
129 }
130
131 last = b;
132 }
133
134 if continue_from == bucket.len() {
135 return (
136 counts_1,
137 already_sorted,
138 bucket[0].get_level(level),
139 last as u8,
140 );
141 }
142
143 let mut counts_2 = [0usize; 256];
144 let mut counts_3 = [0usize; 256];
145 let mut counts_4 = [0usize; 256];
146 let chunks = bucket[continue_from..].chunks_exact(4);
147 let rem = chunks.remainder();
148
149 chunks.into_iter().for_each(|chunk| {
150 let a = chunk[0].get_level(level) as usize;
151 let b = chunk[1].get_level(level) as usize;
152 let c = chunk[2].get_level(level) as usize;
153 let d = chunk[3].get_level(level) as usize;
154
155 counts_1[a] += 1;
156 counts_2[b] += 1;
157 counts_3[c] += 1;
158 counts_4[d] += 1;
159 });
160
161 rem.iter().for_each(|v| {
162 let b = v.get_level(level) as usize;
163 counts_1[b] += 1;
164 });
165
166 for i in 0..256 {
167 counts_1[i] += counts_2[i];
168 counts_1[i] += counts_3[i];
169 counts_1[i] += counts_4[i];
170 }
171
172 let b_first = bucket.first().unwrap().get_level(level);
173 let b_last = bucket.last().unwrap().get_level(level);
174
175 (counts_1, already_sorted, b_first, b_last)
176}
177
178#[inline]
179pub fn get_counts<T>(bucket: &[T], level: usize) -> ([usize; 256], bool)
180where
181 T: RadixKey,
182{
183 if bucket.is_empty() {
184 return ([0usize; 256], true);
185 }
186
187 let (counts, sorted, _, _) = get_counts_with_ends(bucket, level);
188
189 (counts, sorted)
190}
191
/// Allocates a scratch vector that reports `len` elements without
/// initializing any of them.
///
/// The function never writes to the returned elements, so their contents are
/// whatever the fresh allocation happens to hold.
// The `uninit_vec` lint is silenced because the elements are deliberately
// left uninitialized here.
#[allow(clippy::uninit_vec)]
#[inline]
pub fn get_tmp_bucket<T>(len: usize) -> Vec<T> {
    let mut scratch: Vec<T> = Vec::with_capacity(len);

    // SAFETY: `with_capacity(len)` reserves room for at least `len` elements,
    // so the new length never exceeds the capacity. The elements themselves
    // remain uninitialized, so every slot must be written before it is read.
    // NOTE(review): presumably callers only use this with plain-data keys that
    // are fully overwritten before being read — confirm.
    unsafe { scratch.set_len(len) };

    scratch
}
206
/// Ceiling division: the smallest `q` such that `q * b >= a`.
///
/// The result is built from the quotient and the remainder instead of
/// `(a + b - 1) / b`, so it cannot overflow even when `a` is close to
/// `usize::MAX`.
///
/// # Panics
///
/// Panics if `b` is zero.
#[inline]
pub const fn cdiv(a: usize, b: usize) -> usize {
    // Round up by one exactly when the division leaves a remainder.
    a / b + (a % b != 0) as usize
}
211
212#[inline]
213pub fn get_tile_counts<T>(bucket: &[T], tile_size: usize, level: usize) -> (Vec<[usize; 256]>, bool)
214where
215 T: RadixKey + Copy + Sized + Send + Sync,
216{
217 #[cfg(feature = "work_profiles")]
218 println!("({}) TILE_COUNT", level);
219
220 #[cfg(feature = "multi-threaded")]
221 let tiles: Vec<([usize; 256], bool, u8, u8)> = bucket
222 .par_chunks(tile_size)
223 .map(|chunk| par_get_counts_with_ends(chunk, level))
224 .collect();
225
226 #[cfg(not(feature = "multi-threaded"))]
227 let tiles: Vec<([usize; 256], bool, u8, u8)> = bucket
228 .chunks(tile_size)
229 .map(|chunk| get_counts_with_ends(chunk, level))
230 .collect();
231
232 let mut all_sorted = true;
233
234 if tiles.len() == 1 {
235 all_sorted = tiles[0].1;
237 } else {
238 for tile in tiles.windows(2) {
240 if !tile[0].1 || !tile[1].1 || tile[1].2 < tile[0].3 {
241 all_sorted = false;
242 break;
243 }
244 }
245 }
246
247 (tiles.into_iter().map(|v| v.0).collect(), all_sorted)
248}
249
/// Sums per-tile count tables into a single table.
///
/// Entry `b` of the result is the sum of entry `b` over every tile. An empty
/// `tile_counts` slice yields an all-zero table.
#[inline]
pub fn aggregate_tile_counts(tile_counts: &[[usize; 256]]) -> [usize; 256] {
    // Start from zero rather than from the first tile so that an empty slice
    // is handled without indexing out of bounds.
    let mut out = [0usize; 256];

    for tile in tile_counts {
        for (total, count) in out.iter_mut().zip(tile.iter()) {
            *total += *count;
        }
    }

    out
}
261
/// Reports whether at most one entry of `counts` is non-zero.
///
/// An all-zero table is therefore also considered homogenous.
#[inline]
pub fn is_homogenous_bucket(counts: &[usize; 256]) -> bool {
    // Look for a second non-zero entry; the search stops as soon as one is found.
    counts.iter().filter(|&&count| count > 0).nth(1).is_none()
}
277
#[cfg(test)]
mod tests {
    use crate::utils::get_tile_counts;

    // Both tests feed the same two inputs at level 0 and differ only in the
    // tile size: 5 keeps everything in one tile, 2 spreads it over three.

    #[test]
    pub fn test_get_tile_counts_correctly_marks_already_sorted_single_tile() {
        let unsorted: Vec<u8> = vec![0, 5, 2, 3, 1];
        let (_, already_sorted) = get_tile_counts(&unsorted, 5, 0);
        assert!(!already_sorted);

        let sorted: Vec<u8> = vec![0, 0, 1, 1, 2];
        let (_, already_sorted) = get_tile_counts(&sorted, 5, 0);
        assert!(already_sorted);
    }

    #[test]
    pub fn test_get_tile_counts_correctly_marks_already_sorted_multiple_tiles() {
        let unsorted: Vec<u8> = vec![0, 5, 2, 3, 1];
        let (_, already_sorted) = get_tile_counts(&unsorted, 2, 0);
        assert!(!already_sorted);

        let sorted: Vec<u8> = vec![0, 0, 1, 1, 2];
        let (_, already_sorted) = get_tile_counts(&sorted, 2, 0);
        assert!(already_sorted);
    }
}