Skip to main content

ferrolearn_preprocess/
kbins_discretizer.rs

1//! K-bins discretizer: bin continuous features into discrete intervals.
2//!
3//! [`KBinsDiscretizer`] transforms continuous features into discrete bins.
4//! Each feature is independently binned according to one of the following
5//! strategies:
6//!
7//! - [`BinStrategy::Uniform`] — equal-width bins.
8//! - [`BinStrategy::Quantile`] — bins with equal numbers of samples.
9//! - [`BinStrategy::KMeans`] — bins based on 1D k-means clustering.
10//!
11//! The output can be ordinal-encoded (integers 0..k-1) or one-hot encoded.
12
13use ferrolearn_core::error::FerroError;
14use ferrolearn_core::traits::{Fit, FitTransform, Transform};
15use ndarray::Array2;
16use num_traits::Float;
17
18// ---------------------------------------------------------------------------
19// BinStrategy
20// ---------------------------------------------------------------------------
21
/// Strategy for computing the bin edges of a single feature.
///
/// Every strategy produces `n_bins + 1` edges per feature; the first edge is
/// the smallest fitted value and the last edge is the largest.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BinStrategy {
    /// Equal-width bins: edges are evenly spaced between the feature's
    /// minimum and maximum.
    Uniform,
    /// Equal-frequency bins: edges sit at evenly spaced quantiles of the
    /// sorted training values, linearly interpolated between neighbours.
    Quantile,
    /// Bins based on 1D k-means clustering: interior edges are the midpoints
    /// between adjacent sorted centroids. Falls back to equal-width edges
    /// when there are no more samples than bins.
    KMeans,
}
32
/// Encoding method for the transformed output.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BinEncoding {
    /// Ordinal encoding: the output keeps one column per feature and each
    /// value is replaced by its bin index (0..n_bins-1), stored as a float.
    Ordinal,
    /// One-hot encoding: each feature expands to `n_bins` columns, and the
    /// column of the bin a value falls into is set to one.
    OneHot,
}
41
42// ---------------------------------------------------------------------------
43// KBinsDiscretizer (unfitted)
44// ---------------------------------------------------------------------------
45
/// An unfitted K-bins discretizer.
///
/// This type only stores configuration. Calling [`Fit::fit`] computes the bin
/// edges for each feature and returns a [`FittedKBinsDiscretizer`], which is
/// the type that actually discretizes data.
///
/// # Examples
///
/// ```
/// use ferrolearn_preprocess::kbins_discretizer::{KBinsDiscretizer, BinStrategy, BinEncoding};
/// use ferrolearn_core::traits::{Fit, Transform};
/// use ndarray::array;
///
/// let disc = KBinsDiscretizer::<f64>::new(3, BinEncoding::Ordinal, BinStrategy::Uniform);
/// let x = array![[0.0], [1.0], [2.0], [3.0], [4.0], [5.0]];
/// let fitted = disc.fit(&x, &()).unwrap();
/// let out = fitted.transform(&x).unwrap();
/// // Values should be in {0.0, 1.0, 2.0}
/// for v in out.iter() {
///     assert!(*v >= 0.0 && *v < 3.0);
/// }
/// ```
#[must_use]
#[derive(Debug, Clone)]
pub struct KBinsDiscretizer<F> {
    /// Number of bins per feature (validated when fitting, not on construction).
    n_bins: usize,
    /// Encoding method applied to the transformed output.
    encode: BinEncoding,
    /// Binning strategy used to place the bin edges.
    strategy: BinStrategy,
    /// Zero-sized marker: the discretizer is generic over the float type `F`
    /// without storing any value of that type.
    _marker: std::marker::PhantomData<F>,
}
78
79impl<F: Float + Send + Sync + 'static> KBinsDiscretizer<F> {
80    /// Create a new `KBinsDiscretizer`.
81    pub fn new(n_bins: usize, encode: BinEncoding, strategy: BinStrategy) -> Self {
82        Self {
83            n_bins,
84            encode,
85            strategy,
86            _marker: std::marker::PhantomData,
87        }
88    }
89
90    /// Return the number of bins.
91    #[must_use]
92    pub fn n_bins(&self) -> usize {
93        self.n_bins
94    }
95
96    /// Return the encoding method.
97    #[must_use]
98    pub fn encode(&self) -> BinEncoding {
99        self.encode
100    }
101
102    /// Return the binning strategy.
103    #[must_use]
104    pub fn strategy(&self) -> BinStrategy {
105        self.strategy
106    }
107}
108
109impl<F: Float + Send + Sync + 'static> Default for KBinsDiscretizer<F> {
110    fn default() -> Self {
111        Self::new(5, BinEncoding::Ordinal, BinStrategy::Uniform)
112    }
113}
114
115// ---------------------------------------------------------------------------
116// FittedKBinsDiscretizer
117// ---------------------------------------------------------------------------
118
/// A fitted K-bins discretizer holding per-feature bin edges.
///
/// Created by calling [`Fit::fit`] on a [`KBinsDiscretizer`]. The number of
/// entries in `bin_edges` is the number of features seen during fitting.
#[derive(Debug, Clone)]
pub struct FittedKBinsDiscretizer<F> {
    /// Bin edges per feature. `bin_edges[j]` has `n_bins + 1` edges, running
    /// from the smallest to the largest fitted value of feature `j`.
    bin_edges: Vec<Vec<F>>,
    /// Number of bins per feature.
    n_bins: usize,
    /// Encoding method used by `transform`.
    encode: BinEncoding,
}
131
132impl<F: Float + Send + Sync + 'static> FittedKBinsDiscretizer<F> {
133    /// Return the bin edges per feature.
134    #[must_use]
135    pub fn bin_edges(&self) -> &[Vec<F>] {
136        &self.bin_edges
137    }
138
139    /// Return the number of bins.
140    #[must_use]
141    pub fn n_bins(&self) -> usize {
142        self.n_bins
143    }
144
145    /// Return the encoding method.
146    #[must_use]
147    pub fn encode(&self) -> BinEncoding {
148        self.encode
149    }
150}
151
152// ---------------------------------------------------------------------------
153// Helpers
154// ---------------------------------------------------------------------------
155
/// Assign a value to a bin index given bin edges sorted in ascending order.
///
/// `edges` holds the `n_bins + 1` boundaries of one feature. Bin `i` covers
/// the half-open interval `[edges[i], edges[i + 1])`, with two clamping rules
/// so that every input maps to a valid index in `0..n_bins`:
///
/// - values below the first edge are placed in bin `0`;
/// - values at or above the last edge are placed in the last bin.
///
/// A value that is not less than any edge (including an unordered value such
/// as NaN) ends up in the last bin.
///
/// If `edges` has fewer than two entries there is no bin to choose from and
/// `0` is returned instead of underflowing on `len() - 1`.
///
/// Only an ordering on `F` is needed, so the bound is `PartialOrd`; every
/// float type used by the callers satisfies it.
fn assign_bin<F: PartialOrd>(value: F, edges: &[F]) -> usize {
    let n_bins = edges.len().saturating_sub(1);
    if n_bins == 0 {
        return 0;
    }
    // Only the interior edges decide the bin: the outermost edges merely
    // clamp. Counting the interior edges that `value` is not below gives the
    // bin index directly, and the count can never exceed `n_bins - 1`.
    // The predicate is "not less than" (rather than `edge <= value`) so that
    // unordered values keep landing in the last bin.
    edges[1..n_bins]
        .partition_point(|edge| value.partial_cmp(edge) != Some(std::cmp::Ordering::Less))
}
171
172/// Simple 1D k-means to find bin edges.
173fn kmeans_1d<F: Float>(values: &[F], n_bins: usize, max_iter: usize) -> Vec<F> {
174    let n = values.len();
175    if n <= n_bins || n_bins == 0 {
176        // Fallback to uniform
177        let min_v = values
178            .iter()
179            .copied()
180            .fold(F::infinity(), num_traits::Float::min);
181        let max_v = values
182            .iter()
183            .copied()
184            .fold(F::neg_infinity(), num_traits::Float::max);
185        return (0..=n_bins)
186            .map(|i| min_v + (max_v - min_v) * F::from(i).unwrap() / F::from(n_bins).unwrap())
187            .collect();
188    }
189
190    // Initialize centroids using uniform spacing
191    let min_v = values
192        .iter()
193        .copied()
194        .fold(F::infinity(), num_traits::Float::min);
195    let max_v = values
196        .iter()
197        .copied()
198        .fold(F::neg_infinity(), num_traits::Float::max);
199
200    let mut centroids: Vec<F> = (0..n_bins)
201        .map(|i| {
202            min_v
203                + (max_v - min_v) * (F::from(i).unwrap() + F::from(0.5).unwrap())
204                    / F::from(n_bins).unwrap()
205        })
206        .collect();
207
208    for _ in 0..max_iter {
209        // Assign each value to nearest centroid
210        let mut sums = vec![F::zero(); n_bins];
211        let mut counts = vec![0usize; n_bins];
212
213        for &v in values {
214            let mut best_c = 0;
215            let mut best_dist = F::infinity();
216            for (c, &centroid) in centroids.iter().enumerate() {
217                let d = (v - centroid).abs();
218                if d < best_dist {
219                    best_dist = d;
220                    best_c = c;
221                }
222            }
223            sums[best_c] = sums[best_c] + v;
224            counts[best_c] += 1;
225        }
226
227        // Update centroids
228        let mut converged = true;
229        for c in 0..n_bins {
230            if counts[c] > 0 {
231                let new_centroid = sums[c] / F::from(counts[c]).unwrap();
232                if (new_centroid - centroids[c]).abs() > F::from(1e-10).unwrap_or_else(F::epsilon) {
233                    converged = false;
234                }
235                centroids[c] = new_centroid;
236            }
237        }
238        if converged {
239            break;
240        }
241    }
242
243    // Sort centroids and compute edges as midpoints
244    centroids.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
245
246    let mut edges = Vec::with_capacity(n_bins + 1);
247    edges.push(min_v);
248    for i in 0..n_bins - 1 {
249        let mid = (centroids[i] + centroids[i + 1]) / (F::one() + F::one());
250        edges.push(mid);
251    }
252    edges.push(max_v);
253
254    edges
255}
256
257// ---------------------------------------------------------------------------
258// Trait implementations
259// ---------------------------------------------------------------------------
260
261impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, ()> for KBinsDiscretizer<F> {
262    type Fitted = FittedKBinsDiscretizer<F>;
263    type Error = FerroError;
264
265    /// Fit by computing bin edges for each feature.
266    ///
267    /// # Errors
268    ///
269    /// - [`FerroError::InsufficientSamples`] if the input has fewer than 2 rows.
270    /// - [`FerroError::InvalidParameter`] if `n_bins` < 2.
271    fn fit(&self, x: &Array2<F>, _y: &()) -> Result<FittedKBinsDiscretizer<F>, FerroError> {
272        let n_samples = x.nrows();
273        if n_samples < 2 {
274            return Err(FerroError::InsufficientSamples {
275                required: 2,
276                actual: n_samples,
277                context: "KBinsDiscretizer::fit".into(),
278            });
279        }
280        if self.n_bins < 2 {
281            return Err(FerroError::InvalidParameter {
282                name: "n_bins".into(),
283                reason: "n_bins must be at least 2".into(),
284            });
285        }
286
287        let n_features = x.ncols();
288        let mut bin_edges = Vec::with_capacity(n_features);
289
290        for j in 0..n_features {
291            let mut col_vals: Vec<F> = x.column(j).iter().copied().collect();
292            col_vals.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
293
294            let min_val = col_vals[0];
295            let max_val = col_vals[col_vals.len() - 1];
296
297            let edges = match self.strategy {
298                BinStrategy::Uniform => (0..=self.n_bins)
299                    .map(|i| {
300                        min_val
301                            + (max_val - min_val) * F::from(i).unwrap()
302                                / F::from(self.n_bins).unwrap()
303                    })
304                    .collect(),
305                BinStrategy::Quantile => {
306                    let n = col_vals.len();
307                    (0..=self.n_bins)
308                        .map(|i| {
309                            let frac = F::from(i).unwrap() / F::from(self.n_bins).unwrap();
310                            let pos = frac * F::from(n.saturating_sub(1)).unwrap();
311                            let lo = pos.floor().to_usize().unwrap_or(0).min(n - 1);
312                            let hi = pos.ceil().to_usize().unwrap_or(0).min(n - 1);
313                            let f = pos - F::from(lo).unwrap();
314                            col_vals[lo] * (F::one() - f) + col_vals[hi] * f
315                        })
316                        .collect()
317                }
318                BinStrategy::KMeans => kmeans_1d(&col_vals, self.n_bins, 100),
319            };
320
321            bin_edges.push(edges);
322        }
323
324        Ok(FittedKBinsDiscretizer {
325            bin_edges,
326            n_bins: self.n_bins,
327            encode: self.encode,
328        })
329    }
330}
331
332impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedKBinsDiscretizer<F> {
333    type Output = Array2<F>;
334    type Error = FerroError;
335
336    /// Discretize features into bin indices or one-hot vectors.
337    ///
338    /// # Errors
339    ///
340    /// Returns [`FerroError::ShapeMismatch`] if the number of columns differs
341    /// from the number of features seen during fitting.
342    fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
343        let n_features = self.bin_edges.len();
344        if x.ncols() != n_features {
345            return Err(FerroError::ShapeMismatch {
346                expected: vec![x.nrows(), n_features],
347                actual: vec![x.nrows(), x.ncols()],
348                context: "FittedKBinsDiscretizer::transform".into(),
349            });
350        }
351
352        let n_samples = x.nrows();
353
354        match self.encode {
355            BinEncoding::Ordinal => {
356                let mut out = Array2::zeros((n_samples, n_features));
357                for j in 0..n_features {
358                    let edges = &self.bin_edges[j];
359                    for i in 0..n_samples {
360                        let bin = assign_bin(x[[i, j]], edges);
361                        out[[i, j]] = F::from(bin).unwrap_or_else(F::zero);
362                    }
363                }
364                Ok(out)
365            }
366            BinEncoding::OneHot => {
367                let n_out = n_features * self.n_bins;
368                let mut out = Array2::zeros((n_samples, n_out));
369                for j in 0..n_features {
370                    let edges = &self.bin_edges[j];
371                    let col_offset = j * self.n_bins;
372                    for i in 0..n_samples {
373                        let bin = assign_bin(x[[i, j]], edges);
374                        out[[i, col_offset + bin]] = F::one();
375                    }
376                }
377                Ok(out)
378            }
379        }
380    }
381}
382
383/// Implement `Transform` on the unfitted discretizer.
384impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for KBinsDiscretizer<F> {
385    type Output = Array2<F>;
386    type Error = FerroError;
387
388    /// Always returns an error — must be fitted first.
389    fn transform(&self, _x: &Array2<F>) -> Result<Array2<F>, FerroError> {
390        Err(FerroError::InvalidParameter {
391            name: "KBinsDiscretizer".into(),
392            reason: "discretizer must be fitted before calling transform; use fit() first".into(),
393        })
394    }
395}
396
397impl<F: Float + Send + Sync + 'static> FitTransform<Array2<F>> for KBinsDiscretizer<F> {
398    type FitError = FerroError;
399
400    /// Fit and transform in one step.
401    ///
402    /// # Errors
403    ///
404    /// Returns an error if fitting fails.
405    fn fit_transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
406        let fitted = self.fit(x, &())?;
407        fitted.transform(x)
408    }
409}
410
411// ---------------------------------------------------------------------------
412// Tests
413// ---------------------------------------------------------------------------
414
// Unit tests covering each binning strategy, both encodings, the accessor
// methods, and the error paths of `fit` and `transform`.
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    // Uniform + ordinal: the minimum lands in the first bin and the maximum
    // is clamped into the last bin.
    #[test]
    fn test_kbins_ordinal_uniform() {
        let disc = KBinsDiscretizer::<f64>::new(3, BinEncoding::Ordinal, BinStrategy::Uniform);
        let x = array![[0.0], [1.0], [2.0], [3.0], [4.0], [5.0]];
        let fitted = disc.fit(&x, &()).unwrap();
        let out = fitted.transform(&x).unwrap();
        assert_eq!(out.ncols(), 1);
        // Check bin assignments
        assert_abs_diff_eq!(out[[0, 0]], 0.0, epsilon = 1e-10); // 0.0 → bin 0
        assert_abs_diff_eq!(out[[5, 0]], 2.0, epsilon = 1e-10); // 5.0 → bin 2 (last)
    }

    // One-hot output has `n_bins` columns per feature and exactly one active
    // column per row.
    #[test]
    fn test_kbins_onehot_uniform() {
        let disc = KBinsDiscretizer::<f64>::new(3, BinEncoding::OneHot, BinStrategy::Uniform);
        let x = array![[0.0], [2.5], [5.0]];
        let fitted = disc.fit(&x, &()).unwrap();
        let out = fitted.transform(&x).unwrap();
        // 3 bins → 3 columns per feature
        assert_eq!(out.ncols(), 3);
        // Each row should have exactly one 1.0
        for i in 0..out.nrows() {
            let row_sum: f64 = out.row(i).iter().sum();
            assert_abs_diff_eq!(row_sum, 1.0, epsilon = 1e-10);
        }
    }

    // Quantile strategy only produces indices inside 0..n_bins.
    #[test]
    fn test_kbins_quantile_strategy() {
        let disc = KBinsDiscretizer::<f64>::new(4, BinEncoding::Ordinal, BinStrategy::Quantile);
        let x = array![[0.0], [1.0], [2.0], [3.0], [4.0], [5.0], [6.0], [7.0]];
        let fitted = disc.fit(&x, &()).unwrap();
        let out = fitted.transform(&x).unwrap();
        // All values should be valid bin indices
        for v in &out {
            assert!(*v >= 0.0 && *v < 4.0);
        }
    }

    // K-means strategy on three well-separated groups only produces indices
    // inside 0..n_bins.
    #[test]
    fn test_kbins_kmeans_strategy() {
        let disc = KBinsDiscretizer::<f64>::new(3, BinEncoding::Ordinal, BinStrategy::KMeans);
        let x = array![
            [0.0],
            [0.1],
            [0.2],
            [5.0],
            [5.1],
            [5.2],
            [10.0],
            [10.1],
            [10.2]
        ];
        let fitted = disc.fit(&x, &()).unwrap();
        let out = fitted.transform(&x).unwrap();
        // Values should be valid bin indices
        for v in &out {
            assert!(*v >= 0.0 && *v < 3.0);
        }
    }

    // Ordinal output keeps one column per input feature.
    #[test]
    fn test_kbins_multi_feature() {
        let disc = KBinsDiscretizer::<f64>::new(3, BinEncoding::Ordinal, BinStrategy::Uniform);
        let x = array![[0.0, 10.0], [2.5, 15.0], [5.0, 20.0]];
        let fitted = disc.fit(&x, &()).unwrap();
        let out = fitted.transform(&x).unwrap();
        assert_eq!(out.ncols(), 2);
    }

    // Uniform edges: `n_bins + 1` of them, spanning the data minimum to the
    // data maximum.
    #[test]
    fn test_kbins_bin_edges() {
        let disc = KBinsDiscretizer::<f64>::new(3, BinEncoding::Ordinal, BinStrategy::Uniform);
        let x = array![[0.0], [3.0], [6.0]];
        let fitted = disc.fit(&x, &()).unwrap();
        let edges = &fitted.bin_edges()[0];
        // 4 edges for 3 bins: [0, 2, 4, 6]
        assert_eq!(edges.len(), 4);
        assert_abs_diff_eq!(edges[0], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(edges[3], 6.0, epsilon = 1e-10);
    }

    // `fit_transform` succeeds and yields one ordinal column for one feature.
    #[test]
    fn test_kbins_fit_transform() {
        let disc = KBinsDiscretizer::<f64>::new(3, BinEncoding::Ordinal, BinStrategy::Uniform);
        let x = array![[0.0], [2.5], [5.0]];
        let out = disc.fit_transform(&x).unwrap();
        assert_eq!(out.ncols(), 1);
    }

    // Fitting on a single row is rejected.
    #[test]
    fn test_kbins_insufficient_samples_error() {
        let disc = KBinsDiscretizer::<f64>::new(3, BinEncoding::Ordinal, BinStrategy::Uniform);
        let x = array![[1.0]];
        assert!(disc.fit(&x, &()).is_err());
    }

    // Fitting with fewer than two bins is rejected.
    #[test]
    fn test_kbins_too_few_bins_error() {
        let disc = KBinsDiscretizer::<f64>::new(1, BinEncoding::Ordinal, BinStrategy::Uniform);
        let x = array![[0.0], [1.0]];
        assert!(disc.fit(&x, &()).is_err());
    }

    // Transforming data with a different column count than the training data
    // is rejected.
    #[test]
    fn test_kbins_shape_mismatch_error() {
        let disc = KBinsDiscretizer::<f64>::new(3, BinEncoding::Ordinal, BinStrategy::Uniform);
        let x_train = array![[0.0, 1.0], [2.0, 3.0]];
        let fitted = disc.fit(&x_train, &()).unwrap();
        let x_bad = array![[1.0, 2.0, 3.0]];
        assert!(fitted.transform(&x_bad).is_err());
    }

    // Calling `transform` on the unfitted discretizer always fails.
    #[test]
    fn test_kbins_unfitted_error() {
        let disc = KBinsDiscretizer::<f64>::new(3, BinEncoding::Ordinal, BinStrategy::Uniform);
        let x = array![[0.0]];
        assert!(disc.transform(&x).is_err());
    }

    // Defaults: 5 bins, ordinal encoding, uniform strategy.
    #[test]
    fn test_kbins_default() {
        let disc = KBinsDiscretizer::<f64>::default();
        assert_eq!(disc.n_bins(), 5);
        assert_eq!(disc.encode(), BinEncoding::Ordinal);
        assert_eq!(disc.strategy(), BinStrategy::Uniform);
    }

    // Every ordinal output value is a valid bin index, including the maximum
    // training value.
    #[test]
    fn test_kbins_ordinal_values_in_range() {
        let disc = KBinsDiscretizer::<f64>::new(5, BinEncoding::Ordinal, BinStrategy::Uniform);
        let x = array![[0.0], [2.5], [5.0], [7.5], [10.0]];
        let fitted = disc.fit(&x, &()).unwrap();
        let out = fitted.transform(&x).unwrap();
        for v in &out {
            assert!(*v >= 0.0 && *v < 5.0, "Bin index {v} out of range");
        }
    }
}