1use crate::data::{FloatData, JaggedMatrix, Matrix};
2use crate::errors::ForustError;
3use crate::utils::{is_missing, map_bin, percentiles};
4
5fn percentiles_or_value<T>(v: &[T], sample_weight: &[T], pcts: &[T]) -> Vec<T>
12where
13 T: FloatData<T>,
14{
15 let mut v_u = v.to_owned();
16 v_u.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap());
17 v_u.dedup();
18 if v_u.len() <= pcts.len() + 1 {
19 v_u
20 } else {
21 percentiles(v, sample_weight, pcts)
22 }
23}
24
/// Output of binning a matrix, as populated by `bin_matrix`.
pub struct BinnedData<T> {
    /// Bin index of every element of the input matrix, in the same order as
    /// the input's flat data buffer.
    pub binned_data: Vec<u16>,
    /// Cut values used to assign bins, stored as one jagged column per input
    /// column.
    pub cuts: JaggedMatrix<T>,
    /// For each input column, the number of cut values stored for it
    /// (`bin_matrix` counts these after appending its `f64::MAX` upper bound
    /// and removing consecutive duplicates).
    pub nunique: Vec<usize>,
}
42
43fn bin_matrix_from_cuts<T: FloatData<T>>(
49 data: &Matrix<T>,
50 cuts: &JaggedMatrix<T>,
51 missing: &T,
52) -> Vec<u16> {
53 data.data
57 .iter()
58 .enumerate()
59 .map(|(i, v)| {
60 let col = i / data.rows;
61 map_bin(cuts.get_col(col), v, missing).unwrap()
64 })
65 .collect()
66}
67
68pub fn bin_matrix(
75 data: &Matrix<f64>,
76 sample_weight: &[f64],
77 nbins: u16,
78 missing: f64,
79) -> Result<BinnedData<f64>, ForustError> {
80 let mut pcts = Vec::new();
81 let nbins_ = f64::from_u16(nbins);
82 for i in 0..nbins {
83 let v = f64::from_u16(i) / nbins_;
84 pcts.push(v);
85 }
86
87 let mut cuts = JaggedMatrix::new();
90 let mut nunique = Vec::new();
91 for i in 0..data.cols {
92 let (no_miss, w): (Vec<f64>, Vec<f64>) = data
93 .get_col(i)
94 .iter()
95 .zip(sample_weight.iter())
96 .filter(|(v, _)| !is_missing(v, &missing))
99 .unzip();
100 assert_eq!(no_miss.len(), w.len());
101 let mut col_cuts = percentiles_or_value(&no_miss, &w, &pcts);
102 col_cuts.push(f64::MAX);
103 col_cuts.dedup();
104 nunique.push(col_cuts.len());
110 let l = col_cuts.len();
111 cuts.data.extend(col_cuts);
112 let e = match cuts.ends.last() {
113 Some(v) => v + l,
114 None => l,
115 };
116 cuts.ends.push(e);
117 cuts.cols = cuts.ends.len();
118 cuts.n_records = cuts.ends.iter().sum();
119 }
120
121 let binned_data = bin_matrix_from_cuts(data, &cuts, &missing);
122
123 Ok(BinnedData {
124 binned_data,
125 cuts,
126 nunique,
127 })
128}
129
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;

    /// For every column and every pair of adjacent cut values `[lower, upper)`,
    /// the number of raw values inside that interval must equal the number of
    /// elements assigned to the corresponding bin (bins are counted from 1).
    #[test]
    fn test_bin_data() {
        let file = fs::read_to_string("resources/contiguous_no_missing.csv")
            .expect("Something went wrong reading the file");
        let data_vec: Vec<f64> = file.lines().map(|x| x.parse::<f64>().unwrap()).collect();
        let data = Matrix::new(&data_vec, 891, 5);
        let sample_weight = vec![1.; data.rows];
        let b = bin_matrix(&data, &sample_weight, 50, f64::NAN).unwrap();
        let bdata = Matrix::new(&b.binned_data, data.rows, data.cols);

        for column in 0..data.cols {
            let bins = bdata.get_col(column);
            let values = data.get_col(column);
            let mut expected_bin: u16 = 1;
            for window in b.cuts.get_col(column).windows(2) {
                let (lower, upper) = (window[0], window[1]);
                let in_range = values
                    .iter()
                    .filter(|&&x| (lower <= x) && (x < upper))
                    .count();
                let in_bin = bins.iter().filter(|&&bin| bin == expected_bin).count();
                assert_eq!(in_range, in_bin);
                expected_bin += 1;
            }
        }
    }
}