para_graph/algorithms/rust/
radix_sort.rs1use crate::algorithms::prefix_sum::pref_sum_par_cpu;
2use itertools::Itertools;
3use rayon::prelude::*;
4use std::cell::Cell;
5
/// A value that an LSD radix sort can order by an unsigned integer key.
///
/// Implementations in this module return `key / base % radix`: the digit of
/// the key at place value `base` when the key is written in base `radix`.
/// The sorting functions use the result directly as an index into a
/// `radix`-sized counting table, so it is expected to lie in `0..radix`.
pub trait RadixSortable: Clone {
    /// Returns the digit of this value's key at place value `base` in base `radix`.
    fn digit_of(&self, base: usize, radix: usize) -> usize;
}
9
10impl RadixSortable for usize {
11 fn digit_of(&self, base: usize, radix: usize) -> usize {
12 self / base % radix
13 }
14}
15
16impl<T: Clone> RadixSortable for (usize, T) {
17 fn digit_of(&self, base: usize, radix: usize) -> usize {
18 self.0 / base % radix
19 }
20}
21
22pub fn radix_sort_serial<T: RadixSortable>(arr: &mut [T]) {
23 let radix = arr.len().next_power_of_two();
24 let mut base = 1;
25 loop {
26 let mut counter = vec![0; radix];
27 for x in arr.iter() {
28 counter[x.digit_of(base, radix)] += 1;
29 }
30 for i in 1..radix {
31 counter[i] += counter[i - 1];
32 }
33 if counter[0] == arr.len() {
34 break;
35 }
36 for x in arr.to_owned().iter().rev() {
37 counter[x.digit_of(base, radix)] -= 1;
38 arr[counter[x.digit_of(base, radix)]] = x.clone();
39 }
40 base *= radix;
41 }
42}
43
44pub fn radix_sort_par_cpu<T>(arr: &mut [T])
45where
46 T: RadixSortable + Send + Sync,
47{
48 let chunks = rayon::current_num_threads();
49 let chunk_size = arr.len().div_ceil(chunks);
50 let radix = chunks;
51 let mut base = 1;
52
53 let mut digits = Vec::new();
54 let mut counters = vec![0usize; radix * chunks]
55 .chunks_exact(radix)
56 .map(|x| x.to_vec())
57 .collect_vec();
58 loop {
59 counters.iter_mut().for_each(|x| x.fill(0));
60
61 arr.par_iter()
62 .map(|x| x.digit_of(base, radix))
63 .collect_into_vec(&mut digits);
64
65 digits
66 .par_chunks(chunk_size)
67 .zip(counters.par_iter_mut())
68 .for_each(|(chunk, counter)| {
69 chunk.iter().for_each(|&x| {
70 counter[x] += 1;
71 });
72 });
73
74 let slice = &mut counters[..];
75 let slice_of_cells: &[Cell<_>] = Cell::from_mut(slice).as_slice_of_cells();
76 slice_of_cells.windows(2).for_each(|window| {
77 let prev = window[0].take();
78 let mut curr = window[1].take();
79 curr.par_iter_mut()
80 .zip(prev.par_iter())
81 .for_each(|(c, p)| *c += *p);
82 window[1].set(curr);
83 window[0].set(prev);
84 });
85 pref_sum_par_cpu(counters.last_mut().unwrap());
86 let (front, end) = counters.split_at_mut(chunks - 1);
87 let end = &mut end[0];
88 front.par_iter_mut().for_each(|counter| {
89 counter
90 .par_iter_mut()
91 .skip(1)
92 .zip(end.par_iter())
93 .for_each(|(c, e)| *c += *e);
94 });
95 if counters.last().unwrap()[0] == arr.len() {
96 break;
97 }
98
99 let idxs = digits
100 .par_chunks(chunk_size)
101 .zip(counters.par_iter_mut())
102 .flat_map(|(chunk, counter)| {
103 let aux = chunk
104 .iter()
105 .rev()
106 .map(|&x| {
107 counter[x] -= 1;
108 counter[x]
109 })
110 .collect_vec();
111 aux.into_iter().rev().collect_vec()
112 })
113 .collect::<Vec<_>>();
114 arr.to_owned().iter().enumerate().for_each(|(i, x)| {
115 arr[idxs[i]] = x.clone();
116 });
117 base *= radix;
118 }
119}
120
#[cfg(test)]
mod tests {
    use super::*;

    const LARGE_ARR: [usize; 40] = [
        963, 482, 145, 973, 281, 856, 724, 329, 920, 198, 29, 735, 503, 920, 74, 621, 415, 877,
        266, 253, 499, 782, 720, 481, 444, 96, 762, 901, 864, 679, 503, 3, 650, 718, 644, 380, 66,
        368, 192, 370,
    ];
    const ASCENDING: [usize; 7] = [1, 4, 24, 37, 64, 127, 201];
    const DESCENDING: [usize; 7] = [201, 127, 64, 37, 24, 4, 1];

    /// Sorts a copy of `input` with `sort` and asserts the result equals the
    /// standard library's sort of the same data.
    fn assert_sorts_like_std(input: &[usize], sort: fn(&mut [usize])) {
        let mut expected = input.to_vec();
        expected.sort();
        let mut actual = input.to_vec();
        sort(&mut actual);
        assert_eq!(actual, expected);
    }

    #[test]
    fn ascending_serial() {
        assert_sorts_like_std(&ASCENDING, radix_sort_serial::<usize>);
    }

    #[test]
    fn descending_serial() {
        assert_sorts_like_std(&DESCENDING, radix_sort_serial::<usize>);
    }

    #[test]
    fn large_random_serial() {
        assert_sorts_like_std(&LARGE_ARR, radix_sort_serial::<usize>);
    }

    #[test]
    fn ascending_cpu() {
        assert_sorts_like_std(&ASCENDING, radix_sort_par_cpu::<usize>);
    }

    #[test]
    fn descending_cpu() {
        assert_sorts_like_std(&DESCENDING, radix_sort_par_cpu::<usize>);
    }

    #[test]
    fn large_random_cpu() {
        assert_sorts_like_std(&LARGE_ARR, radix_sort_par_cpu::<usize>);
    }
}