1use crate::simd_compare::SIMDCompare;
2use crate::zero_copy::Line;
3use rayon::prelude::*;
4use std::cmp::Ordering;
7
/// Numeric line sorter that picks between insertion sort, an integer-key fast
/// path, and comparison sorting depending on the input.
///
/// The type only stores configuration, so it is cheap to copy around.
#[derive(Debug, Clone, Copy)]
pub struct RadixSort {
    /// When `true`, the rayon-based code paths are used where the sorter offers them.
    parallel: bool,
}
13
14impl RadixSort {
15 pub fn new(parallel: bool) -> Self {
16 Self { parallel }
17 }
18
19 pub fn sort_numeric_lines(&self, lines: &mut [Line]) {
21 if lines.len() < 1000 {
22 self.insertion_sort(lines);
24 return;
25 }
26
27 const VERY_LARGE_THRESHOLD: usize = 20_000_000; if lines.len() > VERY_LARGE_THRESHOLD {
31 self.sort_very_large_dataset(lines);
32 return;
33 }
34
35 if self.are_all_simple_integers(lines) {
37 if self.parallel && lines.len() > 10000 {
38 self.parallel_radix_sort_integers(lines);
39 } else {
40 self.sequential_radix_sort_integers(lines);
41 }
42 } else {
43 if self.parallel {
45 lines.par_sort_unstable_by(|a, b| a.compare_numeric(b));
46 } else {
47 lines.sort_unstable_by(|a, b| a.compare_numeric(b));
48 }
49 }
50 }
51
52 fn sort_very_large_dataset(&self, lines: &mut [Line]) {
54 if !self.parallel {
55 lines.sort_unstable_by(|a, b| a.compare_numeric(b));
57 return;
58 }
59
60 const CHUNK_SIZE: usize = 2_000_000; let num_chunks = (lines.len() + CHUNK_SIZE - 1) / CHUNK_SIZE;
62
63 lines.par_chunks_mut(CHUNK_SIZE).for_each(|chunk| {
65 if self.are_all_simple_integers(chunk) {
67 self.sequential_radix_sort_integers(chunk);
68 } else {
69 chunk.sort_unstable_by(|a, b| a.compare_numeric(b));
70 }
71 });
72
73 self.parallel_merge_chunks(lines, CHUNK_SIZE, num_chunks);
75 }
76
77 fn parallel_merge_chunks(&self, lines: &mut [Line], chunk_size: usize, num_chunks: usize) {
79 if num_chunks <= 1 {
80 return;
81 }
82
83 let mut current_chunk_size = chunk_size;
85 let mut remaining_chunks = num_chunks;
86
87 while remaining_chunks > 1 {
88 let pairs = remaining_chunks / 2;
90
91 for pair_idx in 0..pairs {
94 let chunk1_start = pair_idx * 2 * current_chunk_size;
95 let chunk2_start = chunk1_start + current_chunk_size;
96 let merge_end = ((pair_idx + 1) * 2 * current_chunk_size).min(lines.len());
97
98 if chunk2_start < lines.len() {
99 self.merge_two_sorted_ranges(
100 &mut lines[chunk1_start..merge_end],
101 current_chunk_size.min(merge_end - chunk1_start),
102 );
103 }
104 }
105
106 current_chunk_size *= 2;
108 remaining_chunks = (remaining_chunks + 1) / 2;
109 }
110 }
111
112 fn merge_two_sorted_ranges(&self, slice: &mut [Line], mid: usize) {
114 if mid >= slice.len() {
115 return;
116 }
117
118 let mut temp = Vec::with_capacity(slice.len());
120 let (left, right) = slice.split_at(mid);
121
122 let mut i = 0;
123 let mut j = 0;
124
125 while i < left.len() && j < right.len() {
127 if left[i].compare_numeric(&right[j]) != Ordering::Greater {
128 temp.push(left[i]);
129 i += 1;
130 } else {
131 temp.push(right[j]);
132 j += 1;
133 }
134 }
135
136 while i < left.len() {
138 temp.push(left[i]);
139 i += 1;
140 }
141 while j < right.len() {
142 temp.push(right[j]);
143 j += 1;
144 }
145
146 slice.copy_from_slice(&temp);
148 }
149
150 fn are_all_simple_integers(&self, lines: &[Line]) -> bool {
152 let sample_size = lines.len().min(100);
154 lines[..sample_size].iter().all(|line| unsafe {
155 let bytes = line.as_bytes();
156 self.is_simple_integer(bytes)
157 })
158 }
159
160 fn is_simple_integer(&self, bytes: &[u8]) -> bool {
162 if bytes.is_empty() {
163 return true;
164 }
165
166 let mut start = 0;
167 if bytes[0] == b'-' || bytes[0] == b'+' {
169 start = 1;
170 }
171
172 if start >= bytes.len() {
173 return false;
174 }
175
176 SIMDCompare::is_all_digits_simd(&bytes[start..])
178 }
179
180 fn parallel_radix_sort_integers(&self, lines: &mut [Line]) {
182 let mut values: Vec<(i64, usize)> = lines
184 .par_iter()
185 .enumerate()
186 .map(|(idx, line)| {
187 let value = unsafe {
188 let bytes = line.as_bytes();
189 self.parse_integer_fast(bytes)
190 };
191 (value, idx)
192 })
193 .collect();
194
195 self.parallel_radix_sort_pairs(&mut values);
197
198 let original_lines: Vec<Line> = lines.to_vec();
200 for (i, &(_, original_idx)) in values.iter().enumerate() {
201 lines[i] = original_lines[original_idx];
202 }
203 }
204
205 fn sequential_radix_sort_integers(&self, lines: &mut [Line]) {
207 let mut values: Vec<(i64, usize)> = lines
209 .iter()
210 .enumerate()
211 .map(|(idx, line)| {
212 let value = unsafe {
213 let bytes = line.as_bytes();
214 self.parse_integer_fast(bytes)
215 };
216 (value, idx)
217 })
218 .collect();
219
220 self.sequential_radix_sort_pairs(&mut values);
222
223 let original_lines: Vec<Line> = lines.to_vec();
225 for (i, &(_, original_idx)) in values.iter().enumerate() {
226 lines[i] = original_lines[original_idx];
227 }
228 }
229
230 fn parse_integer_fast(&self, bytes: &[u8]) -> i64 {
232 if bytes.is_empty() {
233 return 0;
234 }
235
236 let mut result: i64 = 0;
237 let mut start = 0;
238 let negative = if bytes[0] == b'-' {
239 start = 1;
240 true
241 } else if bytes[0] == b'+' {
242 start = 1;
243 false
244 } else {
245 false
246 };
247
248 for &byte in &bytes[start..] {
250 result = result * 10 + (byte - b'0') as i64;
251 }
252
253 if negative {
254 -result
255 } else {
256 result
257 }
258 }
259
260 fn parallel_radix_sort_pairs(&self, values: &mut [(i64, usize)]) {
262 #[allow(dead_code)]
263 const RADIX: usize = 256;
264 #[allow(dead_code)]
265 const MAX_BITS: usize = 64;
266
267 let (mut negatives, mut positives): (Vec<_>, Vec<_>) = values
269 .par_iter()
270 .cloned()
271 .partition(|(value, _)| *value < 0);
272
273 if !positives.is_empty() {
275 self.radix_sort_positive_parallel(&mut positives);
276 }
277
278 if !negatives.is_empty() {
280 negatives
282 .par_iter_mut()
283 .for_each(|(value, _)| *value = -*value);
284 self.radix_sort_positive_parallel(&mut negatives);
285 negatives.reverse();
287 negatives
288 .par_iter_mut()
289 .for_each(|(value, _)| *value = -*value);
290 }
291
292 for (idx, item) in negatives
294 .into_iter()
295 .chain(positives.into_iter())
296 .enumerate()
297 {
298 values[idx] = item;
299 }
300 }
301
302 fn sequential_radix_sort_pairs(&self, values: &mut [(i64, usize)]) {
304 values.sort_unstable_by_key(|(value, _)| *value);
306 }
307
308 fn radix_sort_positive_parallel(&self, values: &mut [(i64, usize)]) {
310 if values.is_empty() {
311 return;
312 }
313
314 const RADIX: usize = 256;
315 let mut temp = vec![(0i64, 0usize); values.len()];
316
317 let max_val = values.par_iter().map(|(v, _)| *v).max().unwrap_or(0);
319 let max_bits = if max_val == 0 {
320 1
321 } else {
322 64 - max_val.leading_zeros() as usize
323 };
324 let passes = (max_bits + 7) / 8; for pass in 0..passes {
327 let shift = pass * 8;
328 let mask = ((1u64 << 8) - 1) as i64;
329
330 let mut counts = vec![0usize; RADIX];
332 for (value, _) in values.iter() {
333 let digit = ((value >> shift) & mask) as usize;
334 counts[digit] += 1;
335 }
336
337 let mut positions = vec![0usize; RADIX];
339 for (i, _) in counts.iter().enumerate().skip(1) {
340 positions[i] = positions[i - 1] + counts[i - 1];
341 }
342
343 for &(value, idx) in values.iter() {
345 let digit = ((value >> shift) & mask) as usize;
346 temp[positions[digit]] = (value, idx);
347 positions[digit] += 1;
348 }
349
350 values.copy_from_slice(&temp);
352 }
353 }
354
355 fn insertion_sort(&self, lines: &mut [Line]) {
357 for i in 1..lines.len() {
358 let mut j = i;
359 while j > 0 && lines[j].compare_numeric(&lines[j - 1]) == Ordering::Less {
360 lines.swap(j, j - 1);
361 j -= 1;
362 }
363 }
364 }
365}
366
#[cfg(test)]
mod tests {
    use super::*;
    use crate::zero_copy::Line;

    /// Unsigned integers given out of order come back in ascending numeric order.
    #[test]
    fn test_radix_sort_simple_integers() {
        let sorter = RadixSort::new(false);
        let mut lines = vec![
            Line::new(b"456"),
            Line::new(b"123"),
            Line::new(b"1"),
            Line::new(b"789"),
        ];

        sorter.sort_numeric_lines(&mut lines);

        // NOTE(review): `as_bytes` is unsafe; the lines here are built from
        // `'static` byte literals, which presumably satisfies its contract.
        unsafe {
            assert_eq!(lines[0].as_bytes(), b"1");
            assert_eq!(lines[1].as_bytes(), b"123");
            assert_eq!(lines[2].as_bytes(), b"456");
            assert_eq!(lines[3].as_bytes(), b"789");
        }
    }

    /// Negative values sort before positive ones, most negative first.
    #[test]
    fn test_negative_numbers() {
        let sorter = RadixSort::new(false);
        let mut lines = vec![
            Line::new(b"456"),
            Line::new(b"-123"),
            Line::new(b"1"),
            Line::new(b"-789"),
        ];

        sorter.sort_numeric_lines(&mut lines);

        // NOTE(review): `as_bytes` is unsafe; the lines here are built from
        // `'static` byte literals, which presumably satisfies its contract.
        unsafe {
            assert_eq!(lines[0].as_bytes(), b"-789");
            assert_eq!(lines[1].as_bytes(), b"-123");
            assert_eq!(lines[2].as_bytes(), b"1");
            assert_eq!(lines[3].as_bytes(), b"456");
        }
    }
}