classify/
jenks.rs

1use num_traits::ToPrimitive;
2use rand::prelude::*;
3use rand::rngs::StdRng;
4
5use std::collections::HashSet;
6
7use crate::utilities::{
8    breaks_to_classification, create_unique_val_mapping, to_vec_f64, unique_to_normal_breaks,
9};
10use crate::utilities::{Classification, UniqueVal};
11
12/// Returns a Classification object following the Jenks Natural Breaks algorithm given the desired number of bins and one-dimensional data
13///
14/// # Arguments
15///
16/// * `num_bins` - An integer (usize) representing the desired number of bins
17/// * `data` - A reference to a collection of unsorted data points to generate a Classification for
18///      
19/// # Edge Cases
20///
21/// * Inputting large u64/i64 data (near their max values) will result in loss of precision because data is being cast to f64
22/// * The maximum number of bins generated by this algorithm is the number of unique values in the dataset
23///
24/// # Examples
25///
26/// ```
27/// use classify::get_jenks_classification;
28/// use classify::{Classification, Bin};
29/// use rand::prelude::*;
30/// use rand::rngs::StdRng;
31///
32/// let data: Vec<usize> = vec![1, 2, 4, 5, 7, 8];
33/// let num_bins = 3;
34///
35/// let result: Classification = get_jenks_classification(num_bins, &data);
36/// let expected: Classification = vec![
37///     Bin{bin_start: 1.0, bin_end: 4.0, count: 2},
38///     Bin{bin_start: 4.0, bin_end: 7.0, count: 2},
39///     Bin{bin_start: 7.0, bin_end: 8.0, count: 2}
40/// ];
41///
42/// assert!(result == expected);
43/// ```
44pub fn get_jenks_classification<T: ToPrimitive>(num_bins: usize, data: &[T]) -> Classification {
45    let breaks: Vec<f64> = get_jenks_breaks(num_bins, data);
46    breaks_to_classification(&breaks, data)
47}
48
49/// Returns a vector of breaks generated through the Jenks Natural Breaks algorithm given the desired number of bins and a dataset
50///
51/// # Arguments
52///
53/// * `num_bins` - The desired number of bins
54/// * `data` - A reference to a collection of unsorted data points to generate breaks for
55///
56/// # Edge Cases
57///
58/// * Inputting large u64/i64 data (near their max values) will result in loss of precision because data is being cast to f64
59/// * The maximum number of bins generated by this algorithm is the number of unique values in the dataset
60///
61/// # Examples
62///
63/// ```
64/// use classify::get_jenks_breaks;
65/// use rand::prelude::*;
66/// use rand::rngs::StdRng;
67///
68/// let data: Vec<i8> = vec![1, 2, 4, 5, 7, 8];
69/// let num_bins = 3;
70///
71/// let result: Vec<f64> = get_jenks_breaks(num_bins, &data);
72///
73/// assert_eq!(result, vec![4.0, 7.0]);
74/// ```
75pub fn get_jenks_breaks<T: ToPrimitive>(num_bins: usize, data: &[T]) -> Vec<f64> {
76    let data = to_vec_f64(data);
77
78    let num_vals = data.len();
79
80    let mut sorted_data: Vec<f64> = vec![];
81    for item in data.iter().take(num_vals) {
82        sorted_data.push(*item);
83    }
84    sorted_data.sort_by(|a, b| a.partial_cmp(b).unwrap());
85
86    let mut unique_val_map: Vec<UniqueVal> = vec![];
87    create_unique_val_mapping(&mut unique_val_map, &sorted_data);
88
89    let num_unique_vals = unique_val_map.len();
90    let true_num_bins = std::cmp::min(num_unique_vals, num_bins);
91
92    let gssd = calc_gssd(&sorted_data);
93
94    let mut rand_breaks: Vec<usize> = vec![0_usize; true_num_bins - 1];
95    let mut best_breaks: Vec<usize> = vec![0_usize; true_num_bins - 1];
96    let mut unique_rand_breaks: Vec<usize> = vec![0_usize; true_num_bins - 1];
97
98    let mut max_gvf: f64 = 0.0;
99
100    let c = 5000 * 2200 * 4;
101    let mut permutations = c / num_vals;
102    if permutations < 10 {
103        permutations = 10
104    }
105    if permutations > 10000 {
106        permutations = 10000
107    }
108    println!("permutations: {}", permutations);
109
110    let mut pseudo_rng = StdRng::seed_from_u64(123456789);
111
112    for _ in 0..permutations {
113        pick_rand_breaks(&mut unique_rand_breaks, &num_unique_vals, &mut pseudo_rng);
114        unique_to_normal_breaks(&unique_rand_breaks, &unique_val_map, &mut rand_breaks);
115        let new_gvf: f64 = calc_gvf(&rand_breaks, &sorted_data, &gssd);
116        if new_gvf > max_gvf {
117            max_gvf = new_gvf;
118            best_breaks[..rand_breaks.len()].copy_from_slice(&rand_breaks[..]);
119        }
120    }
121
122    let mut nat_breaks: Vec<f64> = vec![];
123    nat_breaks.resize(best_breaks.len(), 0.0);
124    for i in 0..best_breaks.len() {
125        nat_breaks[i] = sorted_data[best_breaks[i]];
126    }
127    println!("Breaks: {:#?}", nat_breaks);
128
129    nat_breaks
130}
131
132/// Populates a vector with a set of breaks as unique random integers that are valid indices within the dataset given the number of data points and an RNG
133///
134/// # Arguments
135///
136/// * `breaks` - A mutable reference to an empty vector of breaks whose length is taken to be the desired number of breaks
137/// * `num_vals` - A reference to the number of data points
138/// * `rng` - A mutable reference to a seedable random number generator (RNG) from the "rand" crate
139pub fn pick_rand_breaks(breaks: &mut Vec<usize>, num_vals: &usize, rng: &mut StdRng) {
140    let num_breaks = breaks.len();
141    if num_breaks > num_vals - 1 {
142        return;
143    }
144
145    let mut set = HashSet::new();
146    while set.len() < num_breaks {
147        set.insert(rng.gen_range(1..*num_vals));
148    }
149    let mut set_iter = set.iter();
150    for item in breaks.iter_mut().take(set_iter.len()) {
151        *item = *set_iter.next().unwrap();
152    }
153    breaks.sort_unstable();
154}
155
/// Computes the goodness of variance fit (GVF) of a set of breaks over a dataset
///
/// Every break is the index at which a new bin starts, so the bins cover
/// `0..breaks[0]`, `breaks[0]..breaks[1]`, and so on up to `vals.len()`.
/// The result is `1 - (sum of the per-bin squared deviations / gssd)`.
///
/// # Arguments
///
/// * `breaks` - A reference to a vector (usize) of break indices (sorted, ascending)
/// * `vals` - A reference to a vector (f64) of data points (sorted, ascending)
/// * `gssd` - A reference to the global sum of squared deviations (GSSD)
pub fn calc_gvf(breaks: &Vec<usize>, vals: &Vec<f64>, gssd: &f64) -> f64 {
    // The first bin starts at index 0 and the last bin ends at `vals.len()`;
    // every break closes one bin and opens the next.
    let bin_starts = std::iter::once(0).chain(breaks.iter().copied());
    let bin_ends = breaks.iter().copied().chain(std::iter::once(vals.len()));

    let mut total_ssd = 0.0_f64;
    for (start, end) in bin_starts.zip(bin_ends) {
        let bin_sum = vals
            .iter()
            .take(end)
            .skip(start)
            .fold(0.0_f64, |sum, value| sum + value);
        let bin_mean = bin_sum / (end - start) as f64;

        let bin_ssd = vals
            .iter()
            .take(end)
            .skip(start)
            .map(|value| (value - bin_mean) * (value - bin_mean))
            .fold(0.0_f64, |acc, squared| acc + squared);

        total_ssd += bin_ssd;
    }

    1.0 - (total_ssd / gssd)
}
188
/// Calculates global sum of squared deviations (GSSD) for a particular dataset
///
/// The GSSD is the sum over all data points of the squared difference between
/// the point and the mean of the dataset. An empty dataset yields `0.0`.
///
/// # Arguments
///
/// * `data` - A reference to a vector (f64) of data points
pub fn calc_gssd(data: &Vec<f64>) -> f64 {
    // No points means no deviation; returning early also avoids a 0 / 0 mean.
    if data.is_empty() {
        return 0.0;
    }

    let mean = data.iter().sum::<f64>() / data.len() as f64;

    data.iter()
        .map(|value| (value - mean) * (value - mean))
        .sum()
}