Skip to main content

anofox_ml_preprocessing/
kbins_discretizer.rs

1use anofox_ml_core::{FitUnsupervised, Float, Result, RustMlError, Transform};
2use ndarray::Array2;
3
4/// Strategy for computing bin edges.
5#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
6pub enum BinStrategy {
7    /// All bins have equal width: `(max - min) / n_bins`.
8    Uniform,
9    /// All bins have approximately the same number of samples (quantile-based).
10    Quantile,
11}
12
13/// Encoding strategy for transformed output.
14#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
15pub enum EncodeStrategy {
16    /// Each value is replaced by its integer bin index (0-based).
17    Ordinal,
18    /// Each feature is expanded into `n_bins` binary columns (one-hot encoding).
19    Onehot,
20}
21
22/// Parameters for KBinsDiscretizer (unfitted state).
23///
24/// Bins continuous features into discrete intervals. Two binning strategies
25/// are supported: uniform-width and quantile-based. Output can be ordinal
26/// (bin indices) or one-hot encoded.
27#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
28pub struct KBinsDiscretizer {
29    /// Number of bins per feature.
30    pub n_bins: usize,
31    /// Strategy for computing bin edges.
32    pub strategy: BinStrategy,
33    /// Encoding strategy for the output.
34    pub encode: EncodeStrategy,
35}
36
37impl KBinsDiscretizer {
38    /// Create a new `KBinsDiscretizer` with defaults (5 bins, quantile strategy, ordinal encoding).
39    pub fn new() -> Self {
40        Self {
41            n_bins: 5,
42            strategy: BinStrategy::Quantile,
43            encode: EncodeStrategy::Ordinal,
44        }
45    }
46
47    /// Set the number of bins.
48    pub fn n_bins(mut self, n_bins: usize) -> Self {
49        self.n_bins = n_bins;
50        self
51    }
52
53    /// Set the binning strategy.
54    pub fn strategy(mut self, strategy: BinStrategy) -> Self {
55        self.strategy = strategy;
56        self
57    }
58
59    /// Set the encoding strategy.
60    pub fn encode(mut self, encode: EncodeStrategy) -> Self {
61        self.encode = encode;
62        self
63    }
64}
65
66impl Default for KBinsDiscretizer {
67    fn default() -> Self {
68        Self::new()
69    }
70}
71
72/// Fitted KBinsDiscretizer -- holds bin edges per feature.
73#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
74#[serde(bound(deserialize = "F: serde::de::DeserializeOwned"))]
75pub struct FittedKBinsDiscretizer<F: Float> {
76    /// Bin edges per feature. Each inner vec has `n_bins + 1` values.
77    bin_edges: Vec<Vec<F>>,
78    n_bins: usize,
79    encode: EncodeStrategy,
80}
81
82/// Compute a percentile from a sorted slice using linear interpolation.
83fn percentile_sorted<F: Float>(sorted: &[F], p: f64) -> F {
84    let n = sorted.len();
85    if n == 1 {
86        return sorted[0];
87    }
88    let idx = p * (n - 1) as f64;
89    let lo = idx.floor() as usize;
90    let hi = idx.ceil().min((n - 1) as f64) as usize;
91    if lo == hi {
92        sorted[lo]
93    } else {
94        let frac = F::from_f64(idx - lo as f64).unwrap();
95        sorted[lo] * (F::one() - frac) + sorted[hi] * frac
96    }
97}
98
99impl<F: Float> FitUnsupervised<F> for KBinsDiscretizer {
100    type Fitted = FittedKBinsDiscretizer<F>;
101
102    fn fit(&self, x: &Array2<F>) -> Result<Self::Fitted> {
103        if x.is_empty() {
104            return Err(RustMlError::EmptyInput("input array is empty".into()));
105        }
106        if self.n_bins < 2 {
107            return Err(RustMlError::InvalidParameter(
108                "n_bins must be at least 2".into(),
109            ));
110        }
111
112        let ncols = x.ncols();
113        let mut bin_edges = Vec::with_capacity(ncols);
114
115        for j in 0..ncols {
116            let mut col: Vec<F> = x.column(j).to_vec();
117            col.sort_by(|a, b| a.partial_cmp(b).unwrap());
118
119            let edges = match self.strategy {
120                BinStrategy::Uniform => {
121                    let min_val = col[0];
122                    let max_val = col[col.len() - 1];
123                    let range = max_val - min_val;
124                    let step = range / F::from_usize(self.n_bins).unwrap();
125                    let mut e = Vec::with_capacity(self.n_bins + 1);
126                    for i in 0..=self.n_bins {
127                        e.push(min_val + step * F::from_usize(i).unwrap());
128                    }
129                    e
130                }
131                BinStrategy::Quantile => {
132                    let mut e = Vec::with_capacity(self.n_bins + 1);
133                    for i in 0..=self.n_bins {
134                        let p = i as f64 / self.n_bins as f64;
135                        e.push(percentile_sorted(&col, p));
136                    }
137                    e
138                }
139            };
140
141            bin_edges.push(edges);
142        }
143
144        Ok(FittedKBinsDiscretizer {
145            bin_edges,
146            n_bins: self.n_bins,
147            encode: self.encode,
148        })
149    }
150}
151
152/// Find the bin index for a value given bin edges.
153/// Returns a 0-based bin index in [0, n_bins - 1].
154fn find_bin<F: Float>(val: F, edges: &[F], n_bins: usize) -> usize {
155    // Binary search: find the rightmost edge <= val
156    let mut lo = 0;
157    let mut hi = edges.len() - 1;
158
159    // Clamp to first/last bin for out-of-range values
160    if val <= edges[0] {
161        return 0;
162    }
163    if val >= edges[edges.len() - 1] {
164        return n_bins - 1;
165    }
166
167    while lo + 1 < hi {
168        let mid = (lo + hi) / 2;
169        if edges[mid] <= val {
170            lo = mid;
171        } else {
172            hi = mid;
173        }
174    }
175
176    // lo is the index of the left edge of the bin, bin index = lo
177    // Clamp to [0, n_bins - 1]
178    lo.min(n_bins - 1)
179}
180
181impl<F: Float> Transform<F> for FittedKBinsDiscretizer<F> {
182    fn transform(&self, x: &Array2<F>) -> Result<Array2<F>> {
183        let expected_cols = self.bin_edges.len();
184        if x.ncols() != expected_cols {
185            return Err(RustMlError::ShapeMismatch(format!(
186                "expected {} features, got {}",
187                expected_cols,
188                x.ncols()
189            )));
190        }
191
192        match self.encode {
193            EncodeStrategy::Ordinal => {
194                let mut result = Array2::<F>::zeros(x.raw_dim());
195                for i in 0..x.nrows() {
196                    for j in 0..x.ncols() {
197                        let bin = find_bin(x[[i, j]], &self.bin_edges[j], self.n_bins);
198                        result[[i, j]] = F::from_usize(bin).unwrap();
199                    }
200                }
201                Ok(result)
202            }
203            EncodeStrategy::Onehot => {
204                let out_cols = expected_cols * self.n_bins;
205                let mut result = Array2::<F>::zeros((x.nrows(), out_cols));
206                for i in 0..x.nrows() {
207                    for j in 0..x.ncols() {
208                        let bin = find_bin(x[[i, j]], &self.bin_edges[j], self.n_bins);
209                        let col_offset = j * self.n_bins + bin;
210                        result[[i, col_offset]] = F::one();
211                    }
212                }
213                Ok(result)
214            }
215        }
216    }
217}
218
219impl<F: Float> FittedKBinsDiscretizer<F> {
220    /// Return the bin edges per feature.
221    pub fn bin_edges(&self) -> &Vec<Vec<F>> {
222        &self.bin_edges
223    }
224
225    /// Return the number of bins.
226    pub fn n_bins(&self) -> usize {
227        self.n_bins
228    }
229}
230
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    /// Assert two edge lists match element-wise within a small tolerance.
    fn assert_edges_close(got: &[f64], want: &[f64]) {
        assert_eq!(got.len(), want.len(), "edge count differs");
        for (&g, &w) in got.iter().zip(want) {
            assert_abs_diff_eq!(g, w, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_uniform_ordinal() {
        let x = array![
            [0.0, 0.0],
            [2.5, 5.0],
            [5.0, 10.0],
            [7.5, 15.0],
            [10.0, 20.0],
        ];
        let kbd = KBinsDiscretizer::new()
            .n_bins(4)
            .strategy(BinStrategy::Uniform)
            .encode(EncodeStrategy::Ordinal);
        let fitted = FitUnsupervised::<f64>::fit(&kbd, &x).unwrap();

        // Uniform edges: col 0 -> [0, 2.5, 5, 7.5, 10], col 1 -> [0, 5, 10, 15, 20].
        assert_edges_close(&fitted.bin_edges()[0], &[0.0, 2.5, 5.0, 7.5, 10.0]);
        assert_edges_close(&fitted.bin_edges()[1], &[0.0, 5.0, 10.0, 15.0, 20.0]);

        // A value equal to an interior edge starts the next bin; the maximum is
        // clamped into the last bin.
        let transformed = fitted.transform(&x).unwrap();
        let expected = array![
            [0.0, 0.0],
            [1.0, 1.0],
            [2.0, 2.0],
            [3.0, 3.0],
            [3.0, 3.0],
        ];
        assert_eq!(transformed, expected);
    }

    #[test]
    fn test_quantile_ordinal() {
        let x = array![
            [1.0],
            [2.0],
            [3.0],
            [4.0],
            [5.0],
            [6.0],
            [7.0],
            [8.0],
            [9.0],
            [10.0],
        ];
        let kbd = KBinsDiscretizer::new()
            .n_bins(5)
            .strategy(BinStrategy::Quantile)
            .encode(EncodeStrategy::Ordinal);
        let fitted = FitUnsupervised::<f64>::fit(&kbd, &x).unwrap();

        // Percentiles at positions p * 9 for p = 0, 0.2, ..., 1.
        assert_edges_close(&fitted.bin_edges()[0], &[1.0, 2.8, 4.6, 6.4, 8.2, 10.0]);

        // Two samples per bin; the result is in range and monotone by construction.
        let transformed = fitted.transform(&x).unwrap();
        assert_eq!(
            transformed.column(0).to_vec(),
            vec![0.0, 0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0]
        );
    }

    #[test]
    fn test_onehot_encoding() {
        let x = array![[1.0], [3.0], [5.0], [7.0], [9.0]];
        let kbd = KBinsDiscretizer::new()
            .n_bins(3)
            .strategy(BinStrategy::Uniform)
            .encode(EncodeStrategy::Onehot);
        let fitted = FitUnsupervised::<f64>::fit(&kbd, &x).unwrap();
        let transformed = fitted.transform(&x).unwrap();

        // 1 feature * 3 bins; edges are about [1, 3.67, 6.33, 9].
        assert_eq!(transformed.ncols(), 3);
        let expected = array![
            [1.0, 0.0, 0.0],
            [1.0, 0.0, 0.0],
            [0.0, 1.0, 0.0],
            [0.0, 0.0, 1.0],
            [0.0, 0.0, 1.0],
        ];
        assert_eq!(transformed, expected);
    }

    #[test]
    fn test_onehot_multiple_features() {
        let x = array![[1.0, 10.0], [5.0, 50.0], [9.0, 90.0]];
        let kbd = KBinsDiscretizer::new()
            .n_bins(3)
            .strategy(BinStrategy::Uniform)
            .encode(EncodeStrategy::Onehot);
        let fitted = FitUnsupervised::<f64>::fit(&kbd, &x).unwrap();
        let transformed = fitted.transform(&x).unwrap();

        // 2 features * 3 bins; feature j occupies columns 3j..3j+3.
        assert_eq!(transformed.ncols(), 6);
        let expected = array![
            [1.0, 0.0, 0.0, 1.0, 0.0, 0.0],
            [0.0, 1.0, 0.0, 0.0, 1.0, 0.0],
            [0.0, 0.0, 1.0, 0.0, 0.0, 1.0],
        ];
        assert_eq!(transformed, expected);
    }

    #[test]
    fn test_empty_input() {
        let x: Array2<f64> = Array2::zeros((0, 0));
        let kbd = KBinsDiscretizer::default();
        assert!(FitUnsupervised::<f64>::fit(&kbd, &x).is_err());
    }

    #[test]
    fn test_invalid_n_bins() {
        let x = array![[1.0], [2.0], [3.0]];
        let kbd = KBinsDiscretizer::new().n_bins(1);
        assert!(FitUnsupervised::<f64>::fit(&kbd, &x).is_err());
    }

    #[test]
    fn test_shape_mismatch() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let kbd = KBinsDiscretizer::default();
        let fitted = FitUnsupervised::<f64>::fit(&kbd, &x).unwrap();

        let x_wrong = array![[1.0, 2.0, 3.0]];
        assert!(fitted.transform(&x_wrong).is_err());
    }

    #[test]
    fn test_out_of_range_values() {
        let x = array![[1.0], [2.0], [3.0], [4.0], [5.0]];
        let kbd = KBinsDiscretizer::new()
            .n_bins(3)
            .strategy(BinStrategy::Uniform)
            .encode(EncodeStrategy::Ordinal);
        let fitted = FitUnsupervised::<f64>::fit(&kbd, &x).unwrap();

        // Edges are about [1, 2.33, 3.67, 5].
        let x_test = array![[-10.0], [0.0], [3.0], [6.0], [100.0]];
        let transformed = fitted.transform(&x_test).unwrap();

        // Below-min values clamp to bin 0 and above-max values clamp to the last bin.
        assert_eq!(transformed.column(0).to_vec(), vec![0.0, 0.0, 1.0, 2.0, 2.0]);
    }

    #[test]
    fn test_constant_feature() {
        let x = array![[5.0], [5.0], [5.0], [5.0]];
        let kbd = KBinsDiscretizer::new()
            .n_bins(3)
            .strategy(BinStrategy::Uniform)
            .encode(EncodeStrategy::Ordinal);
        let fitted = FitUnsupervised::<f64>::fit(&kbd, &x).unwrap();
        let transformed = fitted.transform(&x).unwrap();

        // A zero range collapses every edge onto 5.0, so every sample lands in bin 0.
        assert_eq!(transformed, array![[0.0], [0.0], [0.0], [0.0]]);
    }

    #[test]
    fn test_f32() {
        let x = array![[1.0f32, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
        let kbd = KBinsDiscretizer::new()
            .n_bins(3)
            .strategy(BinStrategy::Quantile)
            .encode(EncodeStrategy::Ordinal);
        let fitted = FitUnsupervised::<f32>::fit(&kbd, &x).unwrap();
        let transformed = fitted.transform(&x).unwrap();

        for &v in transformed.iter() {
            assert!(v.is_finite());
            assert!(v >= 0.0 && v < 3.0);
        }
    }

    #[test]
    fn test_percentile_sorted_interpolates() {
        let data = [1.0, 2.0, 3.0, 4.0];
        assert_abs_diff_eq!(percentile_sorted(&data, 0.0), 1.0, epsilon = 1e-12);
        assert_abs_diff_eq!(percentile_sorted(&data, 0.5), 2.5, epsilon = 1e-12);
        assert_abs_diff_eq!(percentile_sorted(&data, 1.0), 4.0, epsilon = 1e-12);
        assert_abs_diff_eq!(percentile_sorted(&[7.0], 0.3), 7.0, epsilon = 1e-12);
    }

    #[test]
    fn test_find_bin_boundaries() {
        let edges = [0.0, 1.0, 2.0, 3.0];
        assert_eq!(find_bin(-1.0, &edges, 3), 0);
        assert_eq!(find_bin(0.0, &edges, 3), 0);
        assert_eq!(find_bin(0.5, &edges, 3), 0);
        assert_eq!(find_bin(1.0, &edges, 3), 1);
        assert_eq!(find_bin(2.5, &edges, 3), 2);
        assert_eq!(find_bin(3.0, &edges, 3), 2);
        assert_eq!(find_bin(42.0, &edges, 3), 2);
    }
}