1use ndarray::{Array, ArrayBase, Axis, Data, Dimension, NdIndex, RemoveAxis};
2
/// Strategy used to assign a rank to a run of tied (equal) values.
///
/// Ranks are 1-based and assigned in ascending order of value; the variants
/// only differ in which rank every member of a tie receives.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RankMethod {
    /// Every tied element receives the lowest rank of the tied run.
    Minimum,
    /// Every tied element receives the highest rank of the tied run.
    Maximum,
    /// Every tied element receives the mean rank of the tied run. Ranks are
    /// integers, so a fractional mean is truncated toward the lower rank.
    Average,
}
11
12pub trait RankExt<A, S, D>
13where
14 S: Data<Elem = A>,
15{
16 fn rank(&self, method: RankMethod) -> Array<usize, D>;
21
22 fn discretize(&self, method: RankMethod, buckets: usize) -> Array<usize, D>;
27}
28
29pub trait RankAxisExt<A, S, D>
30where
31 S: Data<Elem = A>,
32{
33 fn rank_axis(&self, axis: Axis, method: RankMethod) -> Array<usize, D>;
38
39 fn discretize_axis(&self, axis: Axis, method: RankMethod, buckets: usize) -> Array<usize, D>;
44}
45
46impl<A, S, D> RankExt<A, S, D> for ArrayBase<S, D>
47where
48 A: PartialOrd + Default,
49 S: Data<Elem = A>,
50 D: Dimension,
51 <D as Dimension>::Pattern: NdIndex<D>,
52{
53 fn rank(&self, method: RankMethod) -> Array<usize, D> {
54 let mut index_and_value = Vec::new();
55 for (index, element) in self.indexed_iter() {
56 if element.partial_cmp(&A::default()).is_none() {
57 continue;
58 }
59 index_and_value.push((index, element));
60 }
61 index_and_value.sort_unstable_by(|a, b| a.1.partial_cmp(b.1).unwrap());
62
63 let mut rank: usize = 1;
64 let mut index: usize = 0;
65
66 let mut ranks = Array::zeros(self.dim());
67 while index < index_and_value.len() {
68 let start_index = index;
69 let current_value = index_and_value.get(index).unwrap().1;
70 while index < index_and_value.len()
71 && index_and_value.get(index).unwrap().1 == current_value
72 {
73 index += 1;
74 }
75
76 let assign_rank = match method {
77 RankMethod::Minimum => rank,
78 RankMethod::Maximum => rank + index - start_index - 1,
79 RankMethod::Average => rank + (index - start_index - 1) / 2,
80 };
81 for (key, _) in index_and_value[start_index..index].iter() {
82 ranks[key.clone()] = assign_rank;
83 }
84 rank += index - start_index;
85 }
86
87 return ranks;
88 }
89
90 fn discretize(&self, method: RankMethod, buckets: usize) -> Array<usize, D> {
91 let mut ranks = self.rank(method);
92 if let Some(max_rank) = ranks.iter().reduce(|a, b| if *a > *b { a } else { b }) {
93 let ranks_per_bucket = *max_rank / buckets;
94
95 let (buckets, ranks_per_bucket) = if ranks_per_bucket == 0 {
97 (*max_rank, 1)
98 } else {
99 (buckets, ranks_per_bucket)
100 };
101
102 let remainder = *max_rank % buckets;
103
104 let mut rank_cut_points = Vec::new();
105 let mut low_rank: usize = 1;
106 for _ in 0..remainder {
109 let high_rank = low_rank + ranks_per_bucket;
114 rank_cut_points.push(low_rank);
115 low_rank = high_rank + 1;
116 }
117 for _ in remainder..buckets {
118 let high_rank = low_rank + ranks_per_bucket - 1;
119 rank_cut_points.push(low_rank);
120 low_rank = high_rank + 1;
121 }
122 ranks.map_inplace(|x| {
123 if *x == 0 {
124 return;
125 }
126 let mut bucket = 0;
127 for cut in rank_cut_points.iter() {
128 if *x >= *cut {
129 bucket += 1;
130 } else {
131 break;
132 }
133 }
134 *x = bucket;
135 });
136 }
137 ranks
138 }
139}
140
141impl<A, S, D> RankAxisExt<A, S, D> for ArrayBase<S, D>
142where
143 A: PartialOrd + Default,
144 S: Data<Elem = A>,
145 D: Dimension + RemoveAxis,
146 <D as Dimension>::Pattern: NdIndex<D>,
147 <D as Dimension>::Smaller: Dimension,
148 <<D as Dimension>::Smaller as Dimension>::Pattern: NdIndex<<D as Dimension>::Smaller>,
149{
150 fn rank_axis(&self, axis: Axis, method: RankMethod) -> Array<usize, D> {
151 let mut ranks = Array::zeros(self.dim());
152 for (i, subarray) in self.axis_iter(axis).enumerate() {
153 let ranked = subarray.rank(method);
154 ranked.assign_to(ranks.index_axis_mut(axis, i));
155 }
156 ranks
157 }
158
159 fn discretize_axis(&self, axis: Axis, method: RankMethod, buckets: usize) -> Array<usize, D> {
160 let mut ranks = Array::zeros(self.dim());
161 for (i, subarray) in self.axis_iter(axis).enumerate() {
162 let ranked = subarray.discretize(method, buckets);
163 ranked.assign_to(ranks.index_axis_mut(axis, i));
164 }
165 ranks
166 }
167}
168
/// Unit tests for whole-array and per-axis ranking and discretization.
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    #[test]
    fn rank_vector_no_ties() {
        let arr = array![4, 3, 2, 1];
        let ranks = arr.rank(RankMethod::Minimum);
        assert_eq!(ranks, array![4, 3, 2, 1]);
    }

    // NaN cannot be ordered, so it is skipped and keeps rank 0.
    #[test]
    fn rank_vector_missing_values() {
        let arr = array![4., 3., f64::NAN, 1.];
        let ranks = arr.rank(RankMethod::Minimum);
        assert_eq!(ranks, array![3, 2, 0, 1]);
    }

    #[test]
    fn rank_vector_ties_minimum() {
        let arr = array![4, 2, 2, 1];
        let ranks = arr.rank(RankMethod::Minimum);
        assert_eq!(ranks, array![4, 2, 2, 1]);
    }

    #[test]
    fn rank_vector_ties_maximum() {
        let arr = array![4, 2, 2, 1];
        let ranks = arr.rank(RankMethod::Maximum);
        assert_eq!(ranks, array![4, 3, 3, 1]);
    }

    #[test]
    fn rank_vector_ties_average() {
        let arr = array![4, 1, 1, 1];
        let ranks = arr.rank(RankMethod::Average);
        assert_eq!(ranks, array![4, 2, 2, 2]);
    }

    // Ranks are integers: two values tied for ranks 1 and 2 average to 1.5,
    // which is truncated to 1.
    #[test]
    fn rank_vector_ties_average_truncates() {
        let arr = array![3, 1, 1];
        let ranks = arr.rank(RankMethod::Average);
        assert_eq!(ranks, array![3, 1, 1]);
    }

    #[test]
    fn rank_matrix_full() {
        let arr = array![[6, 5, 4], [3, 2, 1]];
        let ranks = arr.rank(RankMethod::Minimum);
        assert_eq!(ranks, array![[6, 5, 4], [3, 2, 1]]);
    }

    #[test]
    fn rank_matrix_rows() {
        let arr = array![[6, 5, 4], [3, 2, 1]];
        let ranks = arr.rank_axis(Axis(0), RankMethod::Minimum);
        assert_eq!(ranks, array![[3, 2, 1], [3, 2, 1]]);
    }

    #[test]
    fn rank_matrix_cols() {
        let arr = array![[6, 5, 4], [3, 2, 1]];
        let ranks = arr.rank_axis(Axis(1), RankMethod::Minimum);
        assert_eq!(ranks, array![[2, 2, 2], [1, 1, 1]]);
    }

    #[test]
    fn discretize_matrix_full() {
        let arr = array![[6, 5, 4], [3, 2, 1]];
        let ranks = arr.discretize(RankMethod::Minimum, 3);
        assert_eq!(ranks, array![[3, 3, 2], [2, 1, 1]]);
    }

    // Seven ranks over three buckets: the lowest bucket takes the extra rank.
    #[test]
    fn discretize_vector_uneven_buckets() {
        let arr = array![7, 6, 5, 4, 3, 2, 1];
        let ranks = arr.discretize(RankMethod::Minimum, 3);
        assert_eq!(ranks, array![3, 3, 2, 2, 1, 1, 1]);
    }

    // Asking for more buckets than there are ranks gives each rank its own bucket.
    #[test]
    fn discretize_vector_more_buckets_than_ranks() {
        let arr = array![3, 2, 1];
        let ranks = arr.discretize(RankMethod::Minimum, 5);
        assert_eq!(ranks, array![3, 2, 1]);
    }

    #[test]
    fn discretize_matrix_rows() {
        let arr = array![[6, 5, 4], [3, 2, 1]];
        let ranks = arr.discretize_axis(Axis(0), RankMethod::Minimum, 2);
        assert_eq!(ranks, array![[2, 1, 1], [2, 1, 1]]);
    }

    #[test]
    fn discretize_matrix_cols() {
        let arr = array![[6, 5, 4], [3, 2, 1]];
        let ranks = arr.discretize_axis(Axis(1), RankMethod::Minimum, 2);
        assert_eq!(ranks, array![[2, 2, 2], [1, 1, 1]]);
    }

    // Missing values keep bucket 0 and do not count towards the bucket widths.
    #[test]
    fn discretize_matrix_with_missing_values() {
        let arr = array![[6., 5., f64::NAN], [3., f64::NAN, 1.]];
        let ranks = arr.discretize(RankMethod::Minimum, 2);
        assert_eq!(ranks, array![[2, 2, 0], [1, 0, 1]]);
    }
}