use num_traits::ToPrimitive;
use rand::prelude::*;
use rand::rngs::StdRng;
use std::collections::HashSet;
use crate::utilities::{
breaks_to_classification, create_unique_val_mapping, to_vec_f64, unique_to_normal_breaks,
};
use crate::utilities::{Classification, UniqueVal};
/// Classifies `data` into at most `num_bins` bins using Jenks-style breaks.
///
/// The break values are obtained from [`get_jenks_breaks`] and then handed,
/// together with the original `data`, to `breaks_to_classification`, whose
/// result is returned unchanged.
///
/// # Arguments
///
/// * `num_bins` - The requested number of bins.
/// * `data` - The values to classify; each must be convertible via `ToPrimitive`.
pub fn get_jenks_classification<T: ToPrimitive>(num_bins: usize, data: &[T]) -> Classification {
    breaks_to_classification(&get_jenks_breaks(num_bins, data), data)
}
/// Approximates Jenks natural breaks for `data` by random sampling.
///
/// The data is converted to `f64` and sorted. Candidate sets of break
/// positions are then drawn at random (from a generator with a fixed seed, so
/// repeated calls on the same input give the same answer), each candidate is
/// scored with [`calc_gvf`], and the best-scoring candidate wins.
///
/// The number of candidates tried is `44_000_000 / data.len()`, clamped to the
/// range `10..=10_000`.
///
/// # Arguments
///
/// * `num_bins` - The requested number of bins. The effective number of bins
///   is capped at the number of entries produced by `create_unique_val_mapping`.
/// * `data` - The values to compute breaks for.
///
/// # Returns
///
/// A vector of `effective_bins - 1` break values taken from the sorted data.
/// An empty vector is returned when `data` is empty, when `num_bins` is zero,
/// or when only a single bin is possible.
///
/// # Panics
///
/// Panics if two values cannot be ordered during sorting (i.e. the converted
/// data contains NaN).
pub fn get_jenks_breaks<T: ToPrimitive>(num_bins: usize, data: &[T]) -> Vec<f64> {
    // Total amount of sampling work, shared out over the number of values.
    const SAMPLING_BUDGET: usize = 5000 * 2200 * 4;
    const MIN_PERMUTATIONS: usize = 10;
    const MAX_PERMUTATIONS: usize = 10_000;

    let data = to_vec_f64(data);
    let num_vals = data.len();

    // With no values or no bins there is nothing to split. Bail out before the
    // `usize` subtraction and the division by `num_vals` below can blow up.
    if num_vals == 0 || num_bins == 0 {
        return Vec::new();
    }

    let mut sorted_data: Vec<f64> = data.iter().copied().collect();
    sorted_data.sort_by(|a, b| {
        a.partial_cmp(b)
            .expect("cannot compute Jenks breaks for data containing NaN")
    });

    let mut unique_val_map: Vec<UniqueVal> = Vec::new();
    create_unique_val_mapping(&mut unique_val_map, &sorted_data);
    let num_unique_vals = unique_val_map.len();

    let true_num_bins = num_unique_vals.min(num_bins);
    if true_num_bins == 0 {
        // Defensive: an empty mapping leaves no positions to place breaks at.
        return Vec::new();
    }
    let num_breaks = true_num_bins - 1;

    let gssd = calc_gssd(&sorted_data);

    let mut rand_breaks: Vec<usize> = vec![0_usize; num_breaks];
    let mut best_breaks: Vec<usize> = vec![0_usize; num_breaks];
    let mut unique_rand_breaks: Vec<usize> = vec![0_usize; num_breaks];
    let mut max_gvf: f64 = 0.0;

    let permutations = (SAMPLING_BUDGET / num_vals).clamp(MIN_PERMUTATIONS, MAX_PERMUTATIONS);
    // NOTE(review): diagnostic output on stdout is kept for backward
    // compatibility; consider moving it behind a logging facade.
    println!("permutations: {}", permutations);

    // Fixed seed keeps the sampled breaks reproducible between runs.
    let mut pseudo_rng = StdRng::seed_from_u64(123456789);

    for _ in 0..permutations {
        pick_rand_breaks(&mut unique_rand_breaks, &num_unique_vals, &mut pseudo_rng);
        unique_to_normal_breaks(&unique_rand_breaks, &unique_val_map, &mut rand_breaks);
        let new_gvf: f64 = calc_gvf(&rand_breaks, &sorted_data, &gssd);
        if new_gvf > max_gvf {
            max_gvf = new_gvf;
            best_breaks[..rand_breaks.len()].copy_from_slice(&rand_breaks[..]);
        }
    }

    // Translate the winning break indices into the data values found there.
    let nat_breaks: Vec<f64> = best_breaks.iter().map(|&idx| sorted_data[idx]).collect();
    println!("Breaks: {:#?}", nat_breaks);
    nat_breaks
}
/// Fills `breaks` with distinct random values from `1..num_vals`, sorted in
/// ascending order.
///
/// # Arguments
///
/// * `breaks` - Output buffer; its length determines how many values are drawn.
/// * `num_vals` - Exclusive upper bound of the range values are drawn from.
/// * `rng` - Random number generator used for the draws.
///
/// If `breaks` needs more distinct values than the range `1..num_vals`
/// contains (including the case `num_vals == 0`), `breaks` is left untouched.
pub fn pick_rand_breaks(breaks: &mut Vec<usize>, num_vals: &usize, rng: &mut StdRng) {
    let num_breaks = breaks.len();
    let num_vals = *num_vals;

    // The range `1..num_vals` holds `num_vals - 1` candidates. The comparison
    // is written without the subtraction so `num_vals == 0` cannot underflow.
    if num_breaks >= num_vals {
        return;
    }

    // Keep drawing until enough distinct values were collected; the set
    // silently absorbs duplicates.
    let mut chosen = HashSet::with_capacity(num_breaks);
    while chosen.len() < num_breaks {
        chosen.insert(rng.gen_range(1..num_vals));
    }

    for (slot, &pick) in breaks.iter_mut().zip(chosen.iter()) {
        *slot = pick;
    }
    // Set iteration order is arbitrary, so sort to get ascending breaks.
    breaks.sort_unstable();
}
/// Computes the goodness of variance fit for a set of breaks.
///
/// `breaks` holds the start index (into `vals`) of every bin after the first,
/// so `n` breaks describe `n + 1` consecutive bins covering
/// `0..vals.len()`. For each bin the sum of squared deviations from the bin
/// mean is accumulated, and the result is `1 - (sum over bins) / gssd`.
///
/// # Arguments
///
/// * `breaks` - Bin boundary indices into `vals`.
/// * `vals` - The values being classified.
/// * `gssd` - Sum of squared deviations of all values from the global mean.
pub fn calc_gvf(breaks: &Vec<usize>, vals: &Vec<f64>, gssd: &f64) -> f64 {
    // Pair up the lower and upper index of each bin:
    // (0, b0), (b0, b1), ..., (b_last, vals.len()).
    let starts = std::iter::once(0).chain(breaks.iter().copied());
    let ends = breaks.iter().copied().chain(std::iter::once(vals.len()));

    let mut within_ssd = 0.0_f64;
    for (start, end) in starts.zip(ends) {
        let bin = || vals.iter().take(end).skip(start);

        let bin_sum = bin().fold(0.0_f64, |acc, v| acc + v);
        let bin_mean = bin_sum / (end - start) as f64;

        within_ssd += bin().fold(0.0_f64, |acc, v| {
            let deviation = v - bin_mean;
            acc + deviation * deviation
        });
    }

    1.0 - (within_ssd / gssd)
}
/// Computes the sum of squared deviations of `data` from its mean.
///
/// # Arguments
///
/// * `data` - The values to measure.
///
/// # Returns
///
/// The sum over all values of `(value - mean)^2`, or `0.0` when `data` is
/// empty.
pub fn calc_gssd(data: &Vec<f64>) -> f64 {
    // No values means no deviation; this also avoids dividing by zero below.
    if data.is_empty() {
        return 0.0;
    }

    let total = data.iter().fold(0.0_f64, |acc, &val| acc + val);
    let mean = total / data.len() as f64;

    data.iter().fold(0.0_f64, |acc, &val| {
        let deviation = val - mean;
        acc + deviation * deviation
    })
}