Skip to main content

ferrolearn_preprocess/
kbins_discretizer.rs

1//! K-bins discretizer: bin continuous features into discrete intervals.
2//!
3//! [`KBinsDiscretizer`] transforms continuous features into discrete bins.
4//! Each feature is independently binned according to one of the following
5//! strategies:
6//!
7//! - [`BinStrategy::Uniform`] — equal-width bins.
8//! - [`BinStrategy::Quantile`] — bins with equal numbers of samples.
9//! - [`BinStrategy::KMeans`] — bins based on 1D k-means clustering.
10//!
11//! The output can be ordinal-encoded (integers 0..k-1) or one-hot encoded.
12
13use ferrolearn_core::error::FerroError;
14use ferrolearn_core::traits::{Fit, FitTransform, Transform};
15use ndarray::Array2;
16use num_traits::Float;
17
18// ---------------------------------------------------------------------------
19// BinStrategy
20// ---------------------------------------------------------------------------
21
/// How the bin boundaries of a feature are chosen during fitting.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum BinStrategy {
    /// Every bin spans the same width between the feature's minimum and maximum.
    Uniform,
    /// Boundaries are evenly spaced quantiles, so bins hold similar sample counts.
    Quantile,
    /// Boundaries are derived from a one-dimensional k-means clustering.
    KMeans,
}
32
/// Layout of the discretized output.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum BinEncoding {
    /// One output column per feature holding the bin index (0..n_bins-1).
    Ordinal,
    /// One binary indicator column per bin of every feature.
    OneHot,
}
41
42// ---------------------------------------------------------------------------
43// KBinsDiscretizer (unfitted)
44// ---------------------------------------------------------------------------
45
46/// An unfitted K-bins discretizer.
47///
48/// Calling [`Fit::fit`] computes the bin edges for each feature and returns a
49/// [`FittedKBinsDiscretizer`].
50///
51/// # Examples
52///
53/// ```
54/// use ferrolearn_preprocess::kbins_discretizer::{KBinsDiscretizer, BinStrategy, BinEncoding};
55/// use ferrolearn_core::traits::{Fit, Transform};
56/// use ndarray::array;
57///
58/// let disc = KBinsDiscretizer::<f64>::new(3, BinEncoding::Ordinal, BinStrategy::Uniform);
59/// let x = array![[0.0], [1.0], [2.0], [3.0], [4.0], [5.0]];
60/// let fitted = disc.fit(&x, &()).unwrap();
61/// let out = fitted.transform(&x).unwrap();
62/// // Values should be in {0.0, 1.0, 2.0}
63/// for v in out.iter() {
64///     assert!(*v >= 0.0 && *v < 3.0);
65/// }
66/// ```
67#[must_use]
68#[derive(Debug, Clone)]
69pub struct KBinsDiscretizer<F> {
70    /// Number of bins.
71    n_bins: usize,
72    /// Encoding method.
73    encode: BinEncoding,
74    /// Binning strategy.
75    strategy: BinStrategy,
76    _marker: std::marker::PhantomData<F>,
77}
78
79impl<F: Float + Send + Sync + 'static> KBinsDiscretizer<F> {
80    /// Create a new `KBinsDiscretizer`.
81    pub fn new(n_bins: usize, encode: BinEncoding, strategy: BinStrategy) -> Self {
82        Self {
83            n_bins,
84            encode,
85            strategy,
86            _marker: std::marker::PhantomData,
87        }
88    }
89
90    /// Return the number of bins.
91    #[must_use]
92    pub fn n_bins(&self) -> usize {
93        self.n_bins
94    }
95
96    /// Return the encoding method.
97    #[must_use]
98    pub fn encode(&self) -> BinEncoding {
99        self.encode
100    }
101
102    /// Return the binning strategy.
103    #[must_use]
104    pub fn strategy(&self) -> BinStrategy {
105        self.strategy
106    }
107}
108
109impl<F: Float + Send + Sync + 'static> Default for KBinsDiscretizer<F> {
110    fn default() -> Self {
111        Self::new(5, BinEncoding::Ordinal, BinStrategy::Uniform)
112    }
113}
114
115// ---------------------------------------------------------------------------
116// FittedKBinsDiscretizer
117// ---------------------------------------------------------------------------
118
119/// A fitted K-bins discretizer holding per-feature bin edges.
120///
121/// Created by calling [`Fit::fit`] on a [`KBinsDiscretizer`].
122#[derive(Debug, Clone)]
123pub struct FittedKBinsDiscretizer<F> {
124    /// Bin edges per feature. `bin_edges[j]` has `n_bins + 1` edges.
125    bin_edges: Vec<Vec<F>>,
126    /// Number of bins.
127    n_bins: usize,
128    /// Encoding method.
129    encode: BinEncoding,
130}
131
132impl<F: Float + Send + Sync + 'static> FittedKBinsDiscretizer<F> {
133    /// Return the bin edges per feature.
134    #[must_use]
135    pub fn bin_edges(&self) -> &[Vec<F>] {
136        &self.bin_edges
137    }
138
139    /// Return the number of bins.
140    #[must_use]
141    pub fn n_bins(&self) -> usize {
142        self.n_bins
143    }
144
145    /// Return the encoding method.
146    #[must_use]
147    pub fn encode(&self) -> BinEncoding {
148        self.encode
149    }
150}
151
152// ---------------------------------------------------------------------------
153// Helpers
154// ---------------------------------------------------------------------------
155
156/// Assign a value to a bin index given sorted bin edges.
157fn assign_bin<F: Float>(value: F, edges: &[F]) -> usize {
158    let n_bins = edges.len() - 1;
159    if n_bins == 0 {
160        return 0;
161    }
162    // Binary search for the bin
163    for (i, edge) in edges.iter().enumerate().skip(1) {
164        if value < *edge {
165            return i - 1;
166        }
167    }
168    // Last bin for values >= last edge
169    n_bins - 1
170}
171
172/// Simple 1D k-means to find bin edges.
173fn kmeans_1d<F: Float>(values: &[F], n_bins: usize, max_iter: usize) -> Vec<F> {
174    let n = values.len();
175    if n <= n_bins || n_bins == 0 {
176        // Fallback to uniform
177        let min_v = values.iter().copied().fold(F::infinity(), |a, b| a.min(b));
178        let max_v = values
179            .iter()
180            .copied()
181            .fold(F::neg_infinity(), |a, b| a.max(b));
182        return (0..=n_bins)
183            .map(|i| min_v + (max_v - min_v) * F::from(i).unwrap() / F::from(n_bins).unwrap())
184            .collect();
185    }
186
187    // Initialize centroids using uniform spacing
188    let min_v = values.iter().copied().fold(F::infinity(), |a, b| a.min(b));
189    let max_v = values
190        .iter()
191        .copied()
192        .fold(F::neg_infinity(), |a, b| a.max(b));
193
194    let mut centroids: Vec<F> = (0..n_bins)
195        .map(|i| {
196            min_v
197                + (max_v - min_v) * (F::from(i).unwrap() + F::from(0.5).unwrap())
198                    / F::from(n_bins).unwrap()
199        })
200        .collect();
201
202    for _ in 0..max_iter {
203        // Assign each value to nearest centroid
204        let mut sums = vec![F::zero(); n_bins];
205        let mut counts = vec![0usize; n_bins];
206
207        for &v in values {
208            let mut best_c = 0;
209            let mut best_dist = F::infinity();
210            for (c, &centroid) in centroids.iter().enumerate() {
211                let d = (v - centroid).abs();
212                if d < best_dist {
213                    best_dist = d;
214                    best_c = c;
215                }
216            }
217            sums[best_c] = sums[best_c] + v;
218            counts[best_c] += 1;
219        }
220
221        // Update centroids
222        let mut converged = true;
223        for c in 0..n_bins {
224            if counts[c] > 0 {
225                let new_centroid = sums[c] / F::from(counts[c]).unwrap();
226                if (new_centroid - centroids[c]).abs() > F::from(1e-10).unwrap_or(F::epsilon()) {
227                    converged = false;
228                }
229                centroids[c] = new_centroid;
230            }
231        }
232        if converged {
233            break;
234        }
235    }
236
237    // Sort centroids and compute edges as midpoints
238    centroids.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
239
240    let mut edges = Vec::with_capacity(n_bins + 1);
241    edges.push(min_v);
242    for i in 0..n_bins - 1 {
243        let mid = (centroids[i] + centroids[i + 1]) / (F::one() + F::one());
244        edges.push(mid);
245    }
246    edges.push(max_v);
247
248    edges
249}
250
251// ---------------------------------------------------------------------------
252// Trait implementations
253// ---------------------------------------------------------------------------
254
255impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, ()> for KBinsDiscretizer<F> {
256    type Fitted = FittedKBinsDiscretizer<F>;
257    type Error = FerroError;
258
259    /// Fit by computing bin edges for each feature.
260    ///
261    /// # Errors
262    ///
263    /// - [`FerroError::InsufficientSamples`] if the input has fewer than 2 rows.
264    /// - [`FerroError::InvalidParameter`] if `n_bins` < 2.
265    fn fit(&self, x: &Array2<F>, _y: &()) -> Result<FittedKBinsDiscretizer<F>, FerroError> {
266        let n_samples = x.nrows();
267        if n_samples < 2 {
268            return Err(FerroError::InsufficientSamples {
269                required: 2,
270                actual: n_samples,
271                context: "KBinsDiscretizer::fit".into(),
272            });
273        }
274        if self.n_bins < 2 {
275            return Err(FerroError::InvalidParameter {
276                name: "n_bins".into(),
277                reason: "n_bins must be at least 2".into(),
278            });
279        }
280
281        let n_features = x.ncols();
282        let mut bin_edges = Vec::with_capacity(n_features);
283
284        for j in 0..n_features {
285            let mut col_vals: Vec<F> = x.column(j).iter().copied().collect();
286            col_vals.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
287
288            let min_val = col_vals[0];
289            let max_val = col_vals[col_vals.len() - 1];
290
291            let edges = match self.strategy {
292                BinStrategy::Uniform => (0..=self.n_bins)
293                    .map(|i| {
294                        min_val
295                            + (max_val - min_val) * F::from(i).unwrap()
296                                / F::from(self.n_bins).unwrap()
297                    })
298                    .collect(),
299                BinStrategy::Quantile => {
300                    let n = col_vals.len();
301                    (0..=self.n_bins)
302                        .map(|i| {
303                            let frac = F::from(i).unwrap() / F::from(self.n_bins).unwrap();
304                            let pos = frac * F::from(n.saturating_sub(1)).unwrap();
305                            let lo = pos.floor().to_usize().unwrap_or(0).min(n - 1);
306                            let hi = pos.ceil().to_usize().unwrap_or(0).min(n - 1);
307                            let f = pos - F::from(lo).unwrap();
308                            col_vals[lo] * (F::one() - f) + col_vals[hi] * f
309                        })
310                        .collect()
311                }
312                BinStrategy::KMeans => kmeans_1d(&col_vals, self.n_bins, 100),
313            };
314
315            bin_edges.push(edges);
316        }
317
318        Ok(FittedKBinsDiscretizer {
319            bin_edges,
320            n_bins: self.n_bins,
321            encode: self.encode,
322        })
323    }
324}
325
326impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedKBinsDiscretizer<F> {
327    type Output = Array2<F>;
328    type Error = FerroError;
329
330    /// Discretize features into bin indices or one-hot vectors.
331    ///
332    /// # Errors
333    ///
334    /// Returns [`FerroError::ShapeMismatch`] if the number of columns differs
335    /// from the number of features seen during fitting.
336    fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
337        let n_features = self.bin_edges.len();
338        if x.ncols() != n_features {
339            return Err(FerroError::ShapeMismatch {
340                expected: vec![x.nrows(), n_features],
341                actual: vec![x.nrows(), x.ncols()],
342                context: "FittedKBinsDiscretizer::transform".into(),
343            });
344        }
345
346        let n_samples = x.nrows();
347
348        match self.encode {
349            BinEncoding::Ordinal => {
350                let mut out = Array2::zeros((n_samples, n_features));
351                for j in 0..n_features {
352                    let edges = &self.bin_edges[j];
353                    for i in 0..n_samples {
354                        let bin = assign_bin(x[[i, j]], edges);
355                        out[[i, j]] = F::from(bin).unwrap_or(F::zero());
356                    }
357                }
358                Ok(out)
359            }
360            BinEncoding::OneHot => {
361                let n_out = n_features * self.n_bins;
362                let mut out = Array2::zeros((n_samples, n_out));
363                for j in 0..n_features {
364                    let edges = &self.bin_edges[j];
365                    let col_offset = j * self.n_bins;
366                    for i in 0..n_samples {
367                        let bin = assign_bin(x[[i, j]], edges);
368                        out[[i, col_offset + bin]] = F::one();
369                    }
370                }
371                Ok(out)
372            }
373        }
374    }
375}
376
377/// Implement `Transform` on the unfitted discretizer.
378impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for KBinsDiscretizer<F> {
379    type Output = Array2<F>;
380    type Error = FerroError;
381
382    /// Always returns an error — must be fitted first.
383    fn transform(&self, _x: &Array2<F>) -> Result<Array2<F>, FerroError> {
384        Err(FerroError::InvalidParameter {
385            name: "KBinsDiscretizer".into(),
386            reason: "discretizer must be fitted before calling transform; use fit() first".into(),
387        })
388    }
389}
390
391impl<F: Float + Send + Sync + 'static> FitTransform<Array2<F>> for KBinsDiscretizer<F> {
392    type FitError = FerroError;
393
394    /// Fit and transform in one step.
395    ///
396    /// # Errors
397    ///
398    /// Returns an error if fitting fails.
399    fn fit_transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
400        let fitted = self.fit(x, &())?;
401        fitted.transform(x)
402    }
403}
404
405// ---------------------------------------------------------------------------
406// Tests
407// ---------------------------------------------------------------------------
408
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    /// Shorthand for an ordinal-encoding discretizer over `f64`.
    fn ordinal(n_bins: usize, strategy: BinStrategy) -> KBinsDiscretizer<f64> {
        KBinsDiscretizer::new(n_bins, BinEncoding::Ordinal, strategy)
    }

    #[test]
    fn test_kbins_ordinal_uniform() {
        let disc = ordinal(3, BinStrategy::Uniform);
        let x = array![[0.0], [1.0], [2.0], [3.0], [4.0], [5.0]];
        let fitted = disc.fit(&x, &()).unwrap();
        let out = fitted.transform(&x).unwrap();
        // Edges are [0, 5/3, 10/3, 5]; the maximum belongs to the last bin.
        assert_eq!(out, array![[0.0], [0.0], [1.0], [1.0], [2.0], [2.0]]);
    }

    #[test]
    fn test_kbins_onehot_uniform() {
        let disc = KBinsDiscretizer::<f64>::new(3, BinEncoding::OneHot, BinStrategy::Uniform);
        let x = array![[0.0], [2.5], [5.0]];
        let fitted = disc.fit(&x, &()).unwrap();
        let out = fitted.transform(&x).unwrap();
        // 3 bins → 3 columns per feature, one sample per bin.
        assert_eq!(
            out,
            array![[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
        );
    }

    #[test]
    fn test_kbins_onehot_multi_feature_layout() {
        let disc = KBinsDiscretizer::<f64>::new(3, BinEncoding::OneHot, BinStrategy::Uniform);
        let x = array![[0.0, 10.0], [2.5, 15.0], [5.0, 20.0]];
        let out = disc.fit(&x, &()).unwrap().transform(&x).unwrap();
        // Columns 0..3 encode feature 0, columns 3..6 encode feature 1.
        assert_eq!(
            out,
            array![
                [1.0, 0.0, 0.0, 1.0, 0.0, 0.0],
                [0.0, 1.0, 0.0, 0.0, 1.0, 0.0],
                [0.0, 0.0, 1.0, 0.0, 0.0, 1.0]
            ]
        );
    }

    #[test]
    fn test_kbins_quantile_strategy() {
        let disc = ordinal(4, BinStrategy::Quantile);
        let x = array![[0.0], [1.0], [2.0], [3.0], [4.0], [5.0], [6.0], [7.0]];
        let fitted = disc.fit(&x, &()).unwrap();

        // Quartiles of 0..=7 with linear interpolation.
        let expected_edges = [0.0, 1.75, 3.5, 5.25, 7.0];
        let edges = &fitted.bin_edges()[0];
        assert_eq!(edges.len(), expected_edges.len());
        for (edge, expected) in edges.iter().zip(expected_edges.iter()) {
            assert_abs_diff_eq!(*edge, *expected, epsilon = 1e-10);
        }

        // Two samples per bin.
        let out = fitted.transform(&x).unwrap();
        assert_eq!(
            out,
            array![[0.0], [0.0], [1.0], [1.0], [2.0], [2.0], [3.0], [3.0]]
        );
    }

    #[test]
    fn test_kbins_kmeans_strategy() {
        let disc = ordinal(3, BinStrategy::KMeans);
        let x = array![
            [0.0],
            [0.1],
            [0.2],
            [5.0],
            [5.1],
            [5.2],
            [10.0],
            [10.1],
            [10.2]
        ];
        let fitted = disc.fit(&x, &()).unwrap();

        // Centroids converge to 0.1, 5.1 and 10.1; interior edges are their midpoints.
        let expected_edges = [0.0, 2.6, 7.6, 10.2];
        let edges = &fitted.bin_edges()[0];
        assert_eq!(edges.len(), expected_edges.len());
        for (edge, expected) in edges.iter().zip(expected_edges.iter()) {
            assert_abs_diff_eq!(*edge, *expected, epsilon = 1e-9);
        }

        // Each tight cluster gets its own bin.
        let out = fitted.transform(&x).unwrap();
        assert_eq!(
            out,
            array![[0.0], [0.0], [0.0], [1.0], [1.0], [1.0], [2.0], [2.0], [2.0]]
        );
    }

    #[test]
    fn test_kbins_multi_feature() {
        let disc = ordinal(3, BinStrategy::Uniform);
        let x = array![[0.0, 10.0], [2.5, 15.0], [5.0, 20.0]];
        let fitted = disc.fit(&x, &()).unwrap();
        assert_eq!(fitted.bin_edges().len(), 2);
        let out = fitted.transform(&x).unwrap();
        // Features are binned independently on their own ranges.
        assert_eq!(out, array![[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]]);
    }

    #[test]
    fn test_kbins_bin_edges() {
        let disc = ordinal(3, BinStrategy::Uniform);
        let x = array![[0.0], [3.0], [6.0]];
        let fitted = disc.fit(&x, &()).unwrap();
        let edges = &fitted.bin_edges()[0];
        // 4 edges for 3 bins: [0, 2, 4, 6]
        assert_eq!(edges.len(), 4);
        assert_abs_diff_eq!(edges[0], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(edges[1], 2.0, epsilon = 1e-10);
        assert_abs_diff_eq!(edges[2], 4.0, epsilon = 1e-10);
        assert_abs_diff_eq!(edges[3], 6.0, epsilon = 1e-10);
    }

    #[test]
    fn test_kbins_out_of_range_values_are_clamped() {
        let disc = ordinal(3, BinStrategy::Uniform);
        let x_train = array![[0.0], [1.0], [2.0], [3.0], [4.0], [5.0]];
        let fitted = disc.fit(&x_train, &()).unwrap();
        // Values outside the fitted range fall into the first / last bin.
        let out = fitted.transform(&array![[-10.0], [100.0]]).unwrap();
        assert_eq!(out, array![[0.0], [2.0]]);
    }

    #[test]
    fn test_kbins_fit_transform() {
        let disc = ordinal(3, BinStrategy::Uniform);
        let x = array![[0.0], [2.5], [5.0]];
        let out = disc.fit_transform(&x).unwrap();
        assert_eq!(out, array![[0.0], [1.0], [2.0]]);
        // Must agree with the two-step path.
        let two_step = disc.fit(&x, &()).unwrap().transform(&x).unwrap();
        assert_eq!(out, two_step);
    }

    #[test]
    fn test_kbins_fitted_accessors() {
        let disc = KBinsDiscretizer::<f64>::new(4, BinEncoding::OneHot, BinStrategy::Quantile);
        let x = array![[0.0], [1.0], [2.0], [3.0], [4.0]];
        let fitted = disc.fit(&x, &()).unwrap();
        assert_eq!(fitted.n_bins(), 4);
        assert_eq!(fitted.encode(), BinEncoding::OneHot);
        assert_eq!(fitted.bin_edges()[0].len(), 5);
    }

    #[test]
    fn test_kbins_insufficient_samples_error() {
        let disc = ordinal(3, BinStrategy::Uniform);
        let x = array![[1.0]];
        assert!(matches!(
            disc.fit(&x, &()),
            Err(FerroError::InsufficientSamples { .. })
        ));
    }

    #[test]
    fn test_kbins_too_few_bins_error() {
        let disc = ordinal(1, BinStrategy::Uniform);
        let x = array![[0.0], [1.0]];
        assert!(matches!(
            disc.fit(&x, &()),
            Err(FerroError::InvalidParameter { .. })
        ));
    }

    #[test]
    fn test_kbins_shape_mismatch_error() {
        let disc = ordinal(3, BinStrategy::Uniform);
        let x_train = array![[0.0, 1.0], [2.0, 3.0]];
        let fitted = disc.fit(&x_train, &()).unwrap();
        let x_bad = array![[1.0, 2.0, 3.0]];
        assert!(matches!(
            fitted.transform(&x_bad),
            Err(FerroError::ShapeMismatch { .. })
        ));
    }

    #[test]
    fn test_kbins_unfitted_error() {
        let disc = ordinal(3, BinStrategy::Uniform);
        let x = array![[0.0]];
        assert!(matches!(
            disc.transform(&x),
            Err(FerroError::InvalidParameter { .. })
        ));
    }

    #[test]
    fn test_kbins_default() {
        let disc = KBinsDiscretizer::<f64>::default();
        assert_eq!(disc.n_bins(), 5);
        assert_eq!(disc.encode(), BinEncoding::Ordinal);
        assert_eq!(disc.strategy(), BinStrategy::Uniform);
    }

    #[test]
    fn test_kbins_ordinal_values_in_range() {
        let disc = ordinal(5, BinStrategy::Uniform);
        let x = array![[0.0], [2.5], [5.0], [7.5], [10.0]];
        let fitted = disc.fit(&x, &()).unwrap();
        let out = fitted.transform(&x).unwrap();
        for v in out.iter() {
            assert!(*v >= 0.0 && *v < 5.0, "Bin index {} out of range", v);
        }
        // Edges are [0, 2, 4, 6, 8, 10], so each sample lands in its own bin.
        assert_eq!(out, array![[0.0], [1.0], [2.0], [3.0], [4.0]]);
    }
}