// para_graph/algorithms/rust/radix_sort.rs

1use crate::algorithms::prefix_sum::pref_sum_par_cpu;
2use itertools::Itertools;
3use rayon::prelude::*;
4use std::cell::Cell;
5
/// A value that can be ordered by a least-significant-digit radix sort.
///
/// Implementors expose a non-negative integer key one digit at a time. The sorting
/// routines in this module expect `digit_of(base, radix)` to behave like
/// `key / base % radix`, i.e. to return the digit of the key at place value `base` when
/// the key is written in base `radix`. The returned digit must be `< radix`.
pub trait RadixSortable
where
    Self: Clone,
{
    /// Returns the digit of this value's key at place value `base` in base `radix`.
    fn digit_of(&self, base: usize, radix: usize) -> usize;
}
9
10impl RadixSortable for usize {
11    fn digit_of(&self, base: usize, radix: usize) -> usize {
12        self / base % radix
13    }
14}
15
16impl<T: Clone> RadixSortable for (usize, T) {
17    fn digit_of(&self, base: usize, radix: usize) -> usize {
18        self.0 / base % radix
19    }
20}
21
22pub fn radix_sort_serial<T: RadixSortable>(arr: &mut [T]) {
23    let radix = arr.len().next_power_of_two();
24    let mut base = 1;
25    loop {
26        let mut counter = vec![0; radix];
27        for x in arr.iter() {
28            counter[x.digit_of(base, radix)] += 1;
29        }
30        for i in 1..radix {
31            counter[i] += counter[i - 1];
32        }
33        if counter[0] == arr.len() {
34            break;
35        }
36        for x in arr.to_owned().iter().rev() {
37            counter[x.digit_of(base, radix)] -= 1;
38            arr[counter[x.digit_of(base, radix)]] = x.clone();
39        }
40        base *= radix;
41    }
42}
43
44pub fn radix_sort_par_cpu<T>(arr: &mut [T])
45where
46    T: RadixSortable + Send + Sync,
47{
48    let chunks = rayon::current_num_threads();
49    let chunk_size = arr.len().div_ceil(chunks);
50    let radix = chunks;
51    let mut base = 1;
52
53    let mut digits = Vec::new();
54    let mut counters = vec![0usize; radix * chunks]
55        .chunks_exact(radix)
56        .map(|x| x.to_vec())
57        .collect_vec();
58    loop {
59        counters.iter_mut().for_each(|x| x.fill(0));
60
61        arr.par_iter()
62            .map(|x| x.digit_of(base, radix))
63            .collect_into_vec(&mut digits);
64
65        digits
66            .par_chunks(chunk_size)
67            .zip(counters.par_iter_mut())
68            .for_each(|(chunk, counter)| {
69                chunk.iter().for_each(|&x| {
70                    counter[x] += 1;
71                });
72            });
73
74        let slice = &mut counters[..];
75        let slice_of_cells: &[Cell<_>] = Cell::from_mut(slice).as_slice_of_cells();
76        slice_of_cells.windows(2).for_each(|window| {
77            let prev = window[0].take();
78            let mut curr = window[1].take();
79            curr.par_iter_mut()
80                .zip(prev.par_iter())
81                .for_each(|(c, p)| *c += *p);
82            window[1].set(curr);
83            window[0].set(prev);
84        });
85        pref_sum_par_cpu(counters.last_mut().unwrap());
86        let (front, end) = counters.split_at_mut(chunks - 1);
87        let end = &mut end[0];
88        front.par_iter_mut().for_each(|counter| {
89            counter
90                .par_iter_mut()
91                .skip(1)
92                .zip(end.par_iter())
93                .for_each(|(c, e)| *c += *e);
94        });
95        if counters.last().unwrap()[0] == arr.len() {
96            break;
97        }
98
99        let idxs = digits
100            .par_chunks(chunk_size)
101            .zip(counters.par_iter_mut())
102            .flat_map(|(chunk, counter)| {
103                let aux = chunk
104                    .iter()
105                    .rev()
106                    .map(|&x| {
107                        counter[x] -= 1;
108                        counter[x]
109                    })
110                    .collect_vec();
111                aux.into_iter().rev().collect_vec()
112            })
113            .collect::<Vec<_>>();
114        arr.to_owned().iter().enumerate().for_each(|(i, x)| {
115            arr[idxs[i]] = x.clone();
116        });
117        base *= radix;
118    }
119}
120
#[cfg(test)]
mod tests {
    use super::*;

    const LARGE_ARR: [usize; 40] = [
        963, 482, 145, 973, 281, 856, 724, 329, 920, 198, 29, 735, 503, 920, 74, 621, 415, 877,
        266, 253, 499, 782, 720, 481, 444, 96, 762, 901, 864, 679, 503, 3, 650, 718, 644, 380, 66,
        368, 192, 370,
    ];
    const ASCENDING: [usize; 7] = [1, 4, 24, 37, 64, 127, 201];
    const DESCENDING: [usize; 7] = [201, 127, 64, 37, 24, 4, 1];

    /// Runs `sort` on a copy of `input` and checks the result against the std sort.
    fn check_against_std(sort: fn(&mut [usize]), input: &[usize]) {
        let mut expected = input.to_vec();
        expected.sort();
        let mut actual = input.to_vec();
        sort(&mut actual);
        assert_eq!(actual, expected);
    }

    #[test]
    fn ascending_serial() {
        check_against_std(radix_sort_serial, &ASCENDING);
    }

    #[test]
    fn descending_serial() {
        check_against_std(radix_sort_serial, &DESCENDING);
    }

    #[test]
    fn large_random_serial() {
        check_against_std(radix_sort_serial, &LARGE_ARR);
    }

    #[test]
    fn ascending_cpu() {
        check_against_std(radix_sort_par_cpu, &ASCENDING);
    }

    #[test]
    fn descending_cpu() {
        check_against_std(radix_sort_par_cpu, &DESCENDING);
    }

    #[test]
    fn large_random_cpu() {
        check_against_std(radix_sort_par_cpu, &LARGE_ARR);
    }
}