Skip to main content

ferrolearn_preprocess/
kbins_discretizer.rs

1//! K-bins discretizer: bin continuous features into discrete intervals.
2//!
3//! [`KBinsDiscretizer`] transforms continuous features into discrete bins.
4//! Each feature is independently binned according to one of the following
5//! strategies:
6//!
7//! - [`BinStrategy::Uniform`] — equal-width bins.
8//! - [`BinStrategy::Quantile`] — bins with equal numbers of samples.
9//! - [`BinStrategy::KMeans`] — bins based on 1D k-means clustering.
10//!
11//! The output can be ordinal-encoded (integers 0..k-1) or one-hot encoded.
12//!
13//! ## REQ status
14//!
15//! Translation target: scikit-learn 1.5.2 `class KBinsDiscretizer`
16//! (`sklearn/preprocessing/_discretization.py:184`). Tracking: #1375. Each REQ
17//! is BINARY — SHIPPED (impl + non-test consumer + tests + green verification)
18//! or NOT-STARTED (with a concrete open blocker).
19//!
20//! | REQ | Scope | Status | Evidence / Blocker |
21//! |-----|-------|--------|--------------------|
22//! | REQ-1 | Uniform + Quantile bin EDGES + ordinal/onehot transform VALUES (non-degenerate features) | SHIPPED | [`KBinsDiscretizer`] `fit` — Uniform=`np.linspace` (`_discretization.py:271`), Quantile=`np.percentile` (`:276`); `assign_bin` ≡ `searchsorted(edges[1:-1], side="right")` (`:377`); oracle value tests in `tests/divergence_kbins_discretizer.rs`. Consumer: re-export `lib.rs:151` |
23//! | REQ-2 | KMeans bin edges/transform via faithful sklearn `KMeans` Lloyd | SHIPPED | `kmeans_1d` replicates sklearn `KMeans` Lloyd incl. mean-centering (`_kmeans.py:1486-1546`), `\|\|C\|\|²-2xC` assignment + lowest-index tie-break (`_k_means_lloyd.pyx:196-213`), empty-cluster RELOCATION (`_k_means_common.pyx:_relocate_empty_clusters_dense`), var-scaled `tol=mean(var)*1e-4` (`_tolerance`), strict/center-shift convergence, max_iter=300, deterministic uniform-center init (`_discretization.py:289-300`); matches sklearn bit-for-bit on well-separated + moderately-separated + empty-init-cluster data (km1 #2321, km2 #2322, 3 green oracle fixtures). RESIDUAL: ~0.1% of well-spread continuous data converges to a different valid Lloyd local optimum (BLAS-gemm vs scalar float tie-break) — honestly pinned `divergence_km3_blas_gemm_local_optimum` (#2321 follow-up) |
24//! | REQ-3 | Error/parameter contracts (n_samples<2, n_bins<2, transform ncols, unfitted) | SHIPPED (scoped) | [`KBinsDiscretizer::fit`]/[`FittedKBinsDiscretizer`] `transform`; in-module + divergence error tests |
25//! | REQ-4 | Constant feature → bin 0 + per-feature `n_bins_=1` (`col_min==col_max`) | SHIPPED | `fit` sets `bin_edges=[-inf,+inf]` + `n_bins_per_feature[j]=1`; `assign_bin` → bin 0 (mirrors `_discretization.py:262-268`); 3 oracle tests — was DIV-1 #1376, fixed |
26//! | REQ-5 | Small-bin removal (quantile/kmeans near-duplicate edge collapse → per-feature `n_bins_`) + onehot variable width | SHIPPED | `fit` collapse `gap > 1e-8` (mirrors `ediff1d > 1e-8` `:302-312`); `transform` onehot width = `sum(n_bins_per_feature)` cumulative offsets; oracle tests (quantile collapse, onehot variable width, threshold boundary) — was DIV-2 #1377, fixed |
27//! | REQ-6 | `subsample` (default 200000) + `random_state` resample for quantile/kmeans | NOT-STARTED | sklearn `_discretization.py:242-249` — blocker #1379 |
28//! | REQ-7 | `n_bins` as per-feature array + `_validate_n_bins` | NOT-STARTED | scalar only; sklearn `_discretization.py:329-352` — blocker #1380 |
29//! | REQ-8 | encode='onehot' SPARSE default + sklearn ctor defaults (encode=onehot, strategy=quantile) | NOT-STARTED | dense only, defaults Ordinal/Uniform; sklearn `_discretization.py:185,320` — blocker #1381 |
30//! | REQ-9 | `dtype` param + `sample_weight` (weighted percentile/kmeans) | NOT-STARTED | sklearn `_discretization.py:228,235,295` — blocker #1382 |
31//! | REQ-10 | `inverse_transform` | NOT-STARTED | sklearn `_discretization.py:393` — blocker #1383 |
32//! | REQ-11 | `get_feature_names_out` + `bin_edges_`/`n_bins_` attr names + `PipelineTransformer` impl | NOT-STARTED | absent — blocker #1384 |
33//! | REQ-12 | PyO3 binding | NOT-STARTED | no `ferrolearn-python` registration — blocker #1385 |
34//! | REQ-13 | ferray substrate | NOT-STARTED | dense `Array2` + `num_traits::Float` only — blocker #1386 |
35//! | REQ-14 | KMeans empty-cluster relocation (degenerate/duplicate-heavy data) | SHIPPED | `kmeans_1d` now RELOCATES empty clusters (translates `_relocate_empty_clusters_dense`, `_k_means_common.pyx:170-214`): farthest-point reassignment + heaviest-cluster fallback in `_average_centers`; km2 (#2322) + `green_kmeans_empty_cluster_relocation_k3`. Residual degenerate carve-out (more clusters than distinct samples / multi-empty argpartition tie order) folded into REQ-2's pinned BLAS-gemm residual — was carve-out #1378 |
36
37use ferrolearn_core::error::FerroError;
38use ferrolearn_core::traits::{Fit, FitTransform, Transform};
39use ndarray::Array2;
40use num_traits::Float;
41
42// ---------------------------------------------------------------------------
43// BinStrategy
44// ---------------------------------------------------------------------------
45
/// Strategy used to place the bin edges of each feature.
///
/// Only the placement of the edges differs between strategies; mapping a value
/// to its bin afterwards works the same way for all of them.
///
/// Derives `Hash` (alongside `Eq`) so strategies can be used as map or set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BinStrategy {
    /// Equal-width bins spanning `[min, max]` of the feature.
    Uniform,
    /// Equal-frequency bins whose edges are quantiles of the feature.
    Quantile,
    /// Bins whose edges are midpoints between sorted 1-D k-means centers.
    KMeans,
}
56
/// How a fitted discretizer reports which bin each value falls in.
///
/// Derives `Hash` (alongside `Eq`) so encodings can be used as map or set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BinEncoding {
    /// One output column per feature holding the bin index (`0..n_bins-1`).
    Ordinal,
    /// One output column per bin of each feature: `1` for the bin the value
    /// falls in, `0` elsewhere.
    OneHot,
}
65
66// ---------------------------------------------------------------------------
67// KBinsDiscretizer (unfitted)
68// ---------------------------------------------------------------------------
69
70/// An unfitted K-bins discretizer.
71///
72/// Calling [`Fit::fit`] computes the bin edges for each feature and returns a
73/// [`FittedKBinsDiscretizer`].
74///
75/// # Examples
76///
77/// ```
78/// use ferrolearn_preprocess::kbins_discretizer::{KBinsDiscretizer, BinStrategy, BinEncoding};
79/// use ferrolearn_core::traits::{Fit, Transform};
80/// use ndarray::array;
81///
82/// let disc = KBinsDiscretizer::<f64>::new(3, BinEncoding::Ordinal, BinStrategy::Uniform);
83/// let x = array![[0.0], [1.0], [2.0], [3.0], [4.0], [5.0]];
84/// let fitted = disc.fit(&x, &()).unwrap();
85/// let out = fitted.transform(&x).unwrap();
86/// // Values should be in {0.0, 1.0, 2.0}
87/// for v in out.iter() {
88///     assert!(*v >= 0.0 && *v < 3.0);
89/// }
90/// ```
91#[must_use]
92#[derive(Debug, Clone)]
93pub struct KBinsDiscretizer<F> {
94    /// Number of bins.
95    n_bins: usize,
96    /// Encoding method.
97    encode: BinEncoding,
98    /// Binning strategy.
99    strategy: BinStrategy,
100    _marker: std::marker::PhantomData<F>,
101}
102
103impl<F: Float + Send + Sync + 'static> KBinsDiscretizer<F> {
104    /// Create a new `KBinsDiscretizer`.
105    pub fn new(n_bins: usize, encode: BinEncoding, strategy: BinStrategy) -> Self {
106        Self {
107            n_bins,
108            encode,
109            strategy,
110            _marker: std::marker::PhantomData,
111        }
112    }
113
114    /// Return the number of bins.
115    #[must_use]
116    pub fn n_bins(&self) -> usize {
117        self.n_bins
118    }
119
120    /// Return the encoding method.
121    #[must_use]
122    pub fn encode(&self) -> BinEncoding {
123        self.encode
124    }
125
126    /// Return the binning strategy.
127    #[must_use]
128    pub fn strategy(&self) -> BinStrategy {
129        self.strategy
130    }
131}
132
133impl<F: Float + Send + Sync + 'static> Default for KBinsDiscretizer<F> {
134    fn default() -> Self {
135        Self::new(5, BinEncoding::Ordinal, BinStrategy::Uniform)
136    }
137}
138
139// ---------------------------------------------------------------------------
140// FittedKBinsDiscretizer
141// ---------------------------------------------------------------------------
142
143/// A fitted K-bins discretizer holding per-feature bin edges.
144///
145/// Created by calling [`Fit::fit`] on a [`KBinsDiscretizer`].
146#[derive(Debug, Clone)]
147pub struct FittedKBinsDiscretizer<F> {
148    /// Bin edges per feature. `bin_edges[j]` has `n_bins_per_feature[j] + 1`
149    /// edges (which may be fewer than `n_bins + 1` for constant or collapsed
150    /// features).
151    bin_edges: Vec<Vec<F>>,
152    /// Per-feature bin count (mirrors sklearn `n_bins_`). A constant feature
153    /// collapses to 1 bin; quantile/kmeans features may shrink when
154    /// near-duplicate edges are removed.
155    n_bins_per_feature: Vec<usize>,
156    /// Requested number of bins (the global `n_bins` argument).
157    n_bins: usize,
158    /// Encoding method.
159    encode: BinEncoding,
160}
161
162impl<F: Float + Send + Sync + 'static> FittedKBinsDiscretizer<F> {
163    /// Return the bin edges per feature.
164    #[must_use]
165    pub fn bin_edges(&self) -> &[Vec<F>] {
166        &self.bin_edges
167    }
168
169    /// Return the per-feature bin count (sklearn `n_bins_`).
170    #[must_use]
171    pub fn n_bins_per_feature(&self) -> &[usize] {
172        &self.n_bins_per_feature
173    }
174
175    /// Return the requested number of bins.
176    #[must_use]
177    pub fn n_bins(&self) -> usize {
178        self.n_bins
179    }
180
181    /// Return the encoding method.
182    #[must_use]
183    pub fn encode(&self) -> BinEncoding {
184        self.encode
185    }
186}
187
188// ---------------------------------------------------------------------------
189// Helpers
190// ---------------------------------------------------------------------------
191
192/// Assign a value to a bin index given sorted bin edges.
193fn assign_bin<F: Float>(value: F, edges: &[F]) -> usize {
194    let n_bins = edges.len() - 1;
195    if n_bins == 0 {
196        return 0;
197    }
198    // Binary search for the bin
199    for (i, edge) in edges.iter().enumerate().skip(1) {
200        if value < *edge {
201            return i - 1;
202        }
203    }
204    // Last bin for values >= last edge
205    n_bins - 1
206}
207
208/// 1-D k-means bin edges, faithfully replicating scikit-learn 1.5.2's
209/// `KBinsDiscretizer(strategy="kmeans")` path (`sklearn/preprocessing/_discretization.py:285-300`).
210///
211/// sklearn runs ONE `KMeans(n_clusters=n_bins, init=uniform-bin-centers, n_init=1)`
212/// Lloyd fit on the column, sorts the resulting centers, and builds
213/// `bin_edges = np.r_[col_min, (centers[1:]+centers[:-1])*0.5, col_max]`.
214///
215/// This reproduces the full Lloyd machinery used by `KMeans.fit`
216/// (`sklearn/cluster/_kmeans.py` + `_k_means_lloyd.pyx` + `_k_means_common.pyx`):
217///
218/// - **Mean-centering** (`_kmeans.py:1486-1493,1543-1546`): `KMeans.fit` subtracts
219///   `X.mean(axis=0)` from both the data and the init "for more accurate distance
220///   computations", runs Lloyd on the centered data, then adds the mean back to the
221///   centers. The distance argmin is computed via `||C||² - 2·x·C` which is NOT
222///   translation-invariant in floating point, so this shift is load-bearing for
223///   the converged local optimum (not just numerical hygiene).
224/// - **Assignment** (`_k_means_lloyd.pyx:196-213`): each point goes to the center
225///   minimizing `pairwise[j] = ||C_j||² - 2·x·C_j` (the `x²` term is dropped since it
226///   is constant per point); ties resolve to the LOWEST center index (strict `<`).
227/// - **Center update** (`_k_means_lloyd.pyx:215-218`, `_k_means_common.pyx:_average_centers`):
228///   each new center is the mean of its assigned points.
229/// - **Empty-cluster relocation** (`_k_means_common.pyx:_relocate_empty_clusters_dense`,
230///   `:170-214`): for each empty cluster, take the points FARTHEST from their own
231///   assigned center (largest squared distance, descending) and move one into the
232///   empty cluster; the donor loses it. Skipped when `max(distances) == 0` (more
233///   clusters than distinct samples). Any cluster still empty after relocation is
234///   placed at the location of the heaviest cluster (`_average_centers` else-branch).
235/// - **Convergence** (`_kmeans.py:704-755`): stop on strict convergence (no label
236///   changed) OR when `center_shift_total = Σ_j (C_new[j] - C_old[j])² <= tol`, with
237///   `tol = mean(var(column)) * 1e-4` (`_tolerance`, `_kmeans.py:286-294`,
238///   population variance), OR `max_iter = 300`.
239///
240/// All intermediate arithmetic is done in `f64` (matching numpy's float64 default)
241/// regardless of `F`, then converted back to `F` for the edges.
242fn kmeans_1d<F: Float>(values: &[F], n_bins: usize) -> Vec<F> {
243    let n = values.len();
244    // Column min/max in F (the outer edges; sklearn uses the un-centered col_min/col_max).
245    let min_v = values
246        .iter()
247        .copied()
248        .fold(F::infinity(), num_traits::Float::min);
249    let max_v = values
250        .iter()
251        .copied()
252        .fold(F::neg_infinity(), num_traits::Float::max);
253
254    if n == 0 || n_bins == 0 {
255        // Degenerate: fall back to a uniform partition over [min, max].
256        return (0..=n_bins)
257            .map(|i| {
258                min_v
259                    + (max_v - min_v) * F::from(i).unwrap_or_else(F::zero)
260                        / fdiv_or_one::<F>(n_bins)
261            })
262            .collect();
263    }
264
265    // Work entirely in f64 (numpy float64 default).
266    let col: Vec<f64> = values.iter().map(|&v| v.to_f64().unwrap_or(0.0)).collect();
267    let col_min = col.iter().copied().fold(f64::INFINITY, f64::min);
268    let col_max = col.iter().copied().fold(f64::NEG_INFINITY, f64::max);
269
270    // Variance-scaled tolerance: tol = mean(var(column)) * 1e-4 (population variance,
271    // ddof=0), == sklearn `_tolerance(X, 1e-4)` (`_kmeans.py:286-294`).
272    let mean_all: f64 = col.iter().sum::<f64>() / (n as f64);
273    let var: f64 = col
274        .iter()
275        .map(|&x| (x - mean_all) * (x - mean_all))
276        .sum::<f64>()
277        / (n as f64);
278    let tol = var * 1e-4;
279
280    // KMeans.fit mean-centers X and the init (`_kmeans.py:1486-1493`).
281    let x_mean = mean_all;
282    let xc: Vec<f64> = col.iter().map(|&x| x - x_mean).collect();
283
284    // Uniform-bin-centers init (`_discretization.py:289-290`), then shifted by -x_mean.
285    let mut centers: Vec<f64> = (0..n_bins)
286        .map(|i| {
287            let lo = col_min + (col_max - col_min) * (i as f64) / (n_bins as f64);
288            let hi = col_min + (col_max - col_min) * ((i + 1) as f64) / (n_bins as f64);
289            (lo + hi) * 0.5 - x_mean
290        })
291        .collect();
292
293    let mut labels = vec![usize::MAX; n];
294    let mut labels_old = vec![usize::MAX; n];
295    let max_iter = 300usize;
296
297    for _ in 0..max_iter {
298        let centers_old = centers.clone();
299
300        // --- Assignment: argmin_j (||C_j||² - 2·x·C_j), ties -> lowest index. ---
301        let csq: Vec<f64> = centers_old.iter().map(|&c| c * c).collect();
302        for i in 0..n {
303            let xi = xc[i];
304            let mut best_j = 0usize;
305            let mut best = csq[0] - 2.0 * xi * centers_old[0];
306            for j in 1..n_bins {
307                let d = csq[j] - 2.0 * xi * centers_old[j];
308                if d < best {
309                    best = d;
310                    best_j = j;
311                }
312            }
313            labels[i] = best_j;
314        }
315
316        // --- Accumulate per-cluster sum and weight (count). ---
317        let mut acc = vec![0.0f64; n_bins];
318        let mut wic = vec![0.0f64; n_bins];
319        for i in 0..n {
320            acc[labels[i]] += xc[i];
321            wic[labels[i]] += 1.0;
322        }
323
324        // --- Empty-cluster relocation (`_relocate_empty_clusters_dense`). ---
325        let empty: Vec<usize> = (0..n_bins).filter(|&j| wic[j] == 0.0).collect();
326        if !empty.is_empty() {
327            // distances[i] = (xc[i] - centers_old[labels[i]])²
328            let distances: Vec<f64> = (0..n)
329                .map(|i| {
330                    let d = xc[i] - centers_old[labels[i]];
331                    d * d
332                })
333                .collect();
334            let max_dist = distances.iter().copied().fold(0.0f64, f64::max);
335            if max_dist != 0.0 {
336                let n_empty = empty.len();
337                // far_from_centers: the n_empty points with the largest distance,
338                // in descending order. sklearn uses
339                // `np.argpartition(distances, -n_empty)[:-n_empty-1:-1]`
340                // (`_k_means_common.pyx:190`), whose introselect partition + reverse
341                // slice resolves equal-distance ties toward the HIGHEST original index
342                // (e.g. distances `[.04,.04,0,0,.04,.04]`, n_empty=2 -> far `[5, 4]`).
343                // Match that by breaking distance ties on DESCENDING index, so the
344                // duplicate-heavy `n_bins > n_distinct` relocation dumps the same
345                // points onto the empty clusters as sklearn (the centers then coincide
346                // and collapse under small-bin removal to sklearn's `n_bins_`).
347                let mut order: Vec<usize> = (0..n).collect();
348                order.sort_by(|&a, &b| {
349                    distances[b]
350                        .partial_cmp(&distances[a])
351                        .unwrap_or(std::cmp::Ordering::Equal)
352                        .then(b.cmp(&a))
353                });
354                for idx in 0..n_empty {
355                    let new_cluster = empty[idx];
356                    let far = order[idx];
357                    let old_cluster = labels[far];
358                    acc[old_cluster] -= xc[far];
359                    acc[new_cluster] = xc[far];
360                    wic[new_cluster] = 1.0;
361                    wic[old_cluster] -= 1.0;
362                }
363            }
364        }
365
366        // --- Average; clusters still empty -> location of the heaviest cluster
367        //     (`_average_centers` else-branch). ---
368        let mut argmax_w = 0usize;
369        for j in 1..n_bins {
370            if wic[j] > wic[argmax_w] {
371                argmax_w = j;
372            }
373        }
374        for j in 0..n_bins {
375            if wic[j] > 0.0 {
376                centers[j] = acc[j] / wic[j];
377            } else if wic[argmax_w] > 0.0 {
378                centers[j] = acc[argmax_w] / wic[argmax_w];
379            } else {
380                centers[j] = centers_old[j];
381            }
382        }
383
384        // --- Convergence (`_kmeans.py:724-739`). ---
385        if labels == labels_old {
386            // Strict convergence: no label changed.
387            break;
388        }
389        let center_shift_tot: f64 = (0..n_bins)
390            .map(|j| {
391                let d = centers[j] - centers_old[j];
392                d * d
393            })
394            .sum();
395        if center_shift_tot <= tol {
396            break;
397        }
398        labels_old.copy_from_slice(&labels);
399    }
400
401    // Shift centers back (`_kmeans.py:1546`) and sort (`_discretization.py:298`).
402    let mut centers_out: Vec<f64> = centers.iter().map(|&c| c + x_mean).collect();
403    centers_out.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
404
405    // edges = [col_min, midpoints.., col_max] (`_discretization.py:299-300`).
406    let mut edges = Vec::with_capacity(n_bins + 1);
407    edges.push(min_v);
408    for i in 0..n_bins.saturating_sub(1) {
409        let mid = (centers_out[i] + centers_out[i + 1]) * 0.5;
410        edges.push(F::from(mid).unwrap_or(min_v));
411    }
412    edges.push(max_v);
413
414    edges
415}
416
417/// `F::from(n)` as a divisor, falling back to `F::one()` when `n == 0` to avoid a
418/// division by zero in the degenerate fallback path.
419fn fdiv_or_one<F: Float>(n: usize) -> F {
420    if n == 0 {
421        F::one()
422    } else {
423        F::from(n).unwrap_or_else(F::one)
424    }
425}
426
427// ---------------------------------------------------------------------------
428// Trait implementations
429// ---------------------------------------------------------------------------
430
431impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, ()> for KBinsDiscretizer<F> {
432    type Fitted = FittedKBinsDiscretizer<F>;
433    type Error = FerroError;
434
435    /// Fit by computing bin edges for each feature.
436    ///
437    /// # Errors
438    ///
439    /// - [`FerroError::InsufficientSamples`] if the input has fewer than 2 rows.
440    /// - [`FerroError::InvalidParameter`] if `n_bins` < 2.
441    fn fit(&self, x: &Array2<F>, _y: &()) -> Result<FittedKBinsDiscretizer<F>, FerroError> {
442        let n_samples = x.nrows();
443        if n_samples < 2 {
444            return Err(FerroError::InsufficientSamples {
445                required: 2,
446                actual: n_samples,
447                context: "KBinsDiscretizer::fit".into(),
448            });
449        }
450        if self.n_bins < 2 {
451            return Err(FerroError::InvalidParameter {
452                name: "n_bins".into(),
453                reason: "n_bins must be at least 2".into(),
454            });
455        }
456
457        let n_features = x.ncols();
458        let mut bin_edges = Vec::with_capacity(n_features);
459        let mut n_bins_per_feature = Vec::with_capacity(n_features);
460
461        for j in 0..n_features {
462            let mut col_vals: Vec<F> = x.column(j).iter().copied().collect();
463            col_vals.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
464
465            let min_val = col_vals[0];
466            let max_val = col_vals[col_vals.len() - 1];
467
468            // Constant feature (sklearn :262-268): collapse to a single bin
469            // spanning [-inf, +inf] so transform maps every value to bin 0.
470            if min_val == max_val {
471                bin_edges.push(vec![F::neg_infinity(), F::infinity()]);
472                n_bins_per_feature.push(1);
473                continue;
474            }
475
476            let edges: Vec<F> = match self.strategy {
477                BinStrategy::Uniform => (0..=self.n_bins)
478                    .map(|i| {
479                        min_val
480                            + (max_val - min_val) * F::from(i).unwrap()
481                                / F::from(self.n_bins).unwrap()
482                    })
483                    .collect(),
484                BinStrategy::Quantile => {
485                    let n = col_vals.len();
486                    (0..=self.n_bins)
487                        .map(|i| {
488                            let frac = F::from(i).unwrap() / F::from(self.n_bins).unwrap();
489                            let pos = frac * F::from(n.saturating_sub(1)).unwrap();
490                            let lo = pos.floor().to_usize().unwrap_or(0).min(n - 1);
491                            let hi = pos.ceil().to_usize().unwrap_or(0).min(n - 1);
492                            let f = pos - F::from(lo).unwrap();
493                            col_vals[lo] * (F::one() - f) + col_vals[hi] * f
494                        })
495                        .collect()
496                }
497                BinStrategy::KMeans => kmeans_1d(&col_vals, self.n_bins),
498            };
499
500            // Small-bin removal for quantile and kmeans only (sklearn
501            // :302-312): keep the first edge, then keep each subsequent edge
502            // only if its gap to the previously kept edge exceeds 1e-8.
503            // Uniform is never collapsed.
504            match self.strategy {
505                BinStrategy::Quantile | BinStrategy::KMeans => {
506                    let tol = F::from(1e-8).unwrap_or_else(F::epsilon);
507                    let mut kept: Vec<F> = Vec::with_capacity(edges.len());
508                    for &edge in &edges {
509                        match kept.last() {
510                            None => kept.push(edge),
511                            Some(&last) if edge - last > tol => kept.push(edge),
512                            Some(_) => {}
513                        }
514                    }
515                    n_bins_per_feature.push(kept.len() - 1);
516                    bin_edges.push(kept);
517                }
518                BinStrategy::Uniform => {
519                    n_bins_per_feature.push(self.n_bins);
520                    bin_edges.push(edges);
521                }
522            }
523        }
524
525        Ok(FittedKBinsDiscretizer {
526            bin_edges,
527            n_bins_per_feature,
528            n_bins: self.n_bins,
529            encode: self.encode,
530        })
531    }
532}
533
534impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedKBinsDiscretizer<F> {
535    type Output = Array2<F>;
536    type Error = FerroError;
537
538    /// Discretize features into bin indices or one-hot vectors.
539    ///
540    /// # Errors
541    ///
542    /// Returns [`FerroError::ShapeMismatch`] if the number of columns differs
543    /// from the number of features seen during fitting.
544    fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
545        let n_features = self.bin_edges.len();
546        if x.ncols() != n_features {
547            return Err(FerroError::ShapeMismatch {
548                expected: vec![x.nrows(), n_features],
549                actual: vec![x.nrows(), x.ncols()],
550                context: "FittedKBinsDiscretizer::transform".into(),
551            });
552        }
553
554        let n_samples = x.nrows();
555
556        match self.encode {
557            BinEncoding::Ordinal => {
558                let mut out = Array2::zeros((n_samples, n_features));
559                for j in 0..n_features {
560                    let edges = &self.bin_edges[j];
561                    for i in 0..n_samples {
562                        let bin = assign_bin(x[[i, j]], edges);
563                        out[[i, j]] = F::from(bin).unwrap_or_else(F::zero);
564                    }
565                }
566                Ok(out)
567            }
568            BinEncoding::OneHot => {
569                // Output width is the sum of the per-feature bin counts, and
570                // feature `j`'s columns start at the cumulative sum of the
571                // preceding features' bin counts (sklearn one-hot over
572                // `n_bins_`).
573                let mut offsets = Vec::with_capacity(n_features + 1);
574                let mut acc = 0usize;
575                for &nb in &self.n_bins_per_feature {
576                    offsets.push(acc);
577                    acc += nb;
578                }
579                let n_out = acc;
580                let mut out = Array2::zeros((n_samples, n_out));
581                for j in 0..n_features {
582                    let edges = &self.bin_edges[j];
583                    let col_offset = offsets[j];
584                    for i in 0..n_samples {
585                        let bin = assign_bin(x[[i, j]], edges);
586                        out[[i, col_offset + bin]] = F::one();
587                    }
588                }
589                Ok(out)
590            }
591        }
592    }
593}
594
595/// Implement `Transform` on the unfitted discretizer.
596impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for KBinsDiscretizer<F> {
597    type Output = Array2<F>;
598    type Error = FerroError;
599
600    /// Always returns an error — must be fitted first.
601    fn transform(&self, _x: &Array2<F>) -> Result<Array2<F>, FerroError> {
602        Err(FerroError::InvalidParameter {
603            name: "KBinsDiscretizer".into(),
604            reason: "discretizer must be fitted before calling transform; use fit() first".into(),
605        })
606    }
607}
608
609impl<F: Float + Send + Sync + 'static> FitTransform<Array2<F>> for KBinsDiscretizer<F> {
610    type FitError = FerroError;
611
612    /// Fit and transform in one step.
613    ///
614    /// # Errors
615    ///
616    /// Returns an error if fitting fails.
617    fn fit_transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
618        let fitted = self.fit(x, &())?;
619        fitted.transform(x)
620    }
621}
622
623// ---------------------------------------------------------------------------
624// Tests
625// ---------------------------------------------------------------------------
626
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    /// Shorthand constructor for an `f64` discretizer.
    fn make(n_bins: usize, encode: BinEncoding, strategy: BinStrategy) -> KBinsDiscretizer<f64> {
        KBinsDiscretizer::new(n_bins, encode, strategy)
    }

    /// Fit `disc` on `x` and transform that same data, panicking on any error.
    fn fit_and_apply(disc: &KBinsDiscretizer<f64>, x: &Array2<f64>) -> Array2<f64> {
        let fitted = disc.fit(x, &()).unwrap();
        fitted.transform(x).unwrap()
    }

    /// True when every output value is a bin index in `[0, n_bins)`.
    fn all_in_bin_range(out: &Array2<f64>, n_bins: f64) -> bool {
        out.iter().all(|&v| v >= 0.0 && v < n_bins)
    }

    #[test]
    fn test_kbins_ordinal_uniform() {
        let x = array![[0.0], [1.0], [2.0], [3.0], [4.0], [5.0]];
        let out = fit_and_apply(&make(3, BinEncoding::Ordinal, BinStrategy::Uniform), &x);
        assert_eq!(out.ncols(), 1);
        // Minimum -> first bin; maximum -> last bin (index 2).
        assert_abs_diff_eq!(out[[0, 0]], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[5, 0]], 2.0, epsilon = 1e-10);
    }

    #[test]
    fn test_kbins_onehot_uniform() {
        let x = array![[0.0], [2.5], [5.0]];
        let out = fit_and_apply(&make(3, BinEncoding::OneHot, BinStrategy::Uniform), &x);
        // One feature with 3 bins -> 3 indicator columns.
        assert_eq!(out.ncols(), 3);
        // Exactly one indicator is set in every row.
        for row in out.outer_iter() {
            let row_sum: f64 = row.iter().sum();
            assert_abs_diff_eq!(row_sum, 1.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_kbins_quantile_strategy() {
        let x = array![[0.0], [1.0], [2.0], [3.0], [4.0], [5.0], [6.0], [7.0]];
        let out = fit_and_apply(&make(4, BinEncoding::Ordinal, BinStrategy::Quantile), &x);
        assert!(all_in_bin_range(&out, 4.0));
    }

    #[test]
    fn test_kbins_kmeans_strategy() {
        // Three tight, well-separated groups.
        let x = array![
            [0.0],
            [0.1],
            [0.2],
            [5.0],
            [5.1],
            [5.2],
            [10.0],
            [10.1],
            [10.2]
        ];
        let out = fit_and_apply(&make(3, BinEncoding::Ordinal, BinStrategy::KMeans), &x);
        assert!(all_in_bin_range(&out, 3.0));
    }

    #[test]
    fn test_kbins_multi_feature() {
        let x = array![[0.0, 10.0], [2.5, 15.0], [5.0, 20.0]];
        let out = fit_and_apply(&make(3, BinEncoding::Ordinal, BinStrategy::Uniform), &x);
        assert_eq!(out.ncols(), 2);
    }

    #[test]
    fn test_kbins_bin_edges() {
        let x = array![[0.0], [3.0], [6.0]];
        let fitted = make(3, BinEncoding::Ordinal, BinStrategy::Uniform)
            .fit(&x, &())
            .unwrap();
        let edges = &fitted.bin_edges()[0];
        // 3 bins need 4 edges: [0, 2, 4, 6].
        assert_eq!(edges.len(), 4);
        assert_abs_diff_eq!(edges[0], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(edges[3], 6.0, epsilon = 1e-10);
    }

    #[test]
    fn test_kbins_fit_transform() {
        let x = array![[0.0], [2.5], [5.0]];
        let out = make(3, BinEncoding::Ordinal, BinStrategy::Uniform)
            .fit_transform(&x)
            .unwrap();
        assert_eq!(out.ncols(), 1);
    }

    #[test]
    fn test_kbins_insufficient_samples_error() {
        let single_row = array![[1.0]];
        let disc = make(3, BinEncoding::Ordinal, BinStrategy::Uniform);
        assert!(disc.fit(&single_row, &()).is_err());
    }

    #[test]
    fn test_kbins_too_few_bins_error() {
        let x = array![[0.0], [1.0]];
        let disc = make(1, BinEncoding::Ordinal, BinStrategy::Uniform);
        assert!(disc.fit(&x, &()).is_err());
    }

    #[test]
    fn test_kbins_shape_mismatch_error() {
        let x_train = array![[0.0, 1.0], [2.0, 3.0]];
        let fitted = make(3, BinEncoding::Ordinal, BinStrategy::Uniform)
            .fit(&x_train, &())
            .unwrap();
        let three_cols = array![[1.0, 2.0, 3.0]];
        assert!(fitted.transform(&three_cols).is_err());
    }

    #[test]
    fn test_kbins_unfitted_error() {
        let x = array![[0.0]];
        let disc = make(3, BinEncoding::Ordinal, BinStrategy::Uniform);
        assert!(disc.transform(&x).is_err());
    }

    #[test]
    fn test_kbins_default() {
        let disc = KBinsDiscretizer::<f64>::default();
        assert_eq!(disc.n_bins(), 5);
        assert_eq!(disc.encode(), BinEncoding::Ordinal);
        assert_eq!(disc.strategy(), BinStrategy::Uniform);
    }

    #[test]
    fn test_kbins_ordinal_values_in_range() {
        let x = array![[0.0], [2.5], [5.0], [7.5], [10.0]];
        let out = fit_and_apply(&make(5, BinEncoding::Ordinal, BinStrategy::Uniform), &x);
        for &v in out.iter() {
            assert!(v >= 0.0 && v < 5.0, "Bin index {v} out of range");
        }
    }
}