Skip to main content

ferrolearn_preprocess/
feature_selection.rs

1//! Feature selection transformers.
2//!
3//! This module provides three feature selection strategies:
4//!
//! - [`VarianceThreshold`] — remove features whose variance does not exceed a
//!   configurable threshold (default 0.0 removes zero-variance features).
7//! - [`SelectKBest`] — keep the *K* features with the highest ANOVA F-scores
8//!   computed against a class label vector.
//! - [`SelectFromModel`] — keep features whose importance weight (provided by
//!   a previously fitted model) is at least a configurable threshold.
11//!
12//! All three implement the standard ferrolearn `Fit` / `Transform` pattern
13//! and integrate with the dynamic [`ferrolearn_core::pipeline::Pipeline`].
14//!
15//! ## REQ status
16//!
17//! Translation target: scikit-learn 1.5.2 `VarianceThreshold`
18//! (`sklearn/feature_selection/_variance_threshold.py`) + `SelectKBest` +
19//! `GenericUnivariateSelect` (`sklearn/feature_selection/_univariate_selection.py`).
20//! Tracking: #1424. Each REQ is BINARY — SHIPPED (impl + non-test consumer +
21//! tests + green verification) or NOT-STARTED (with a concrete open blocker).
22//! The basic [`SelectFromModel`] here duplicates the richer
23//! `select_from_model.rs::SelectFromModelExt` — its parity is covered by
24//! `.design/preprocess/select_from_model.md` (REQ-9).
25//!
26//! | REQ | Scope | Status | Evidence / Blocker |
27//! |-----|-------|--------|--------------------|
28//! | REQ-1 | [`VarianceThreshold`] mask (`var > threshold`, strict) + population variances | SHIPPED | `fit` Welford population variance matches `np.nanvar` + `_variance_threshold.py:133` on the common case; oracle tests in `tests/divergence_feature_selection.rs`. Consumer: re-export `lib.rs:110` + `PipelineTransformer` |
29//! | REQ-2 | [`SelectKBest`] ANOVA F-score VALUES (finite / non-constant) | SHIPPED | `anova_f_scores` matches `f_oneway` `_univariate_selection.py:43-117`; oracle score tests (tol 1e-9) |
30//! | REQ-3 | [`SelectKBest`] top-k SELECTION (tie-break + constant-feature + k-boundary) | SHIPPED | matches sklearn `mask[argsort(scores, mergesort)[-k:]]` `:794` (ties → higher index) + `_clean_nans` `:24-33` (constant feature → NaN → finfo.min → ranks last) + `k>n_features` clamp-keep-all `:774-779`; constant `anova_f_scores` now NaN (was +inf), verified across multi-tie/multi-constant/k∈{0,all,>n}/mixed/f32 (21 oracle tests — was DIV-A/B #1425 + DIV-C #1426, fixed) |
31//! | REQ-4 | Error/parameter contracts (VarianceThreshold threshold<0/zero-rows; SelectKBest zero-rows/y-mismatch) | SHIPPED (scoped) | per-fn guards; divergence error tests |
32//! | REQ-5a | VarianceThreshold `np.nanvar` NaN-handling + "no feature meets threshold" ValueError | SHIPPED | `fit` skips NaN in the Welford pass (population var over FINITE values, ddof=0, all-NaN col → NaN), matching `np.nanvar` (`_variance_threshold.py:112`, `force_all_finite="allow-nan"` `:103`); raises `InvalidParameter("No feature in X meets the variance threshold {:.5}" [+ " (X contains only one sample)"])` when no `var > threshold`, matching `:121-126`. Oracle: `X=[[1,7],[2,7],[NaN,7]]` → `variances_=[0.25,0.0]`, support `[0]`; all-constant / single-sample → ValueError. Consumer: re-export `lib.rs` + `PipelineTransformer`. Tests: `tests/divergence_variance_threshold_2349.rs` (#2350 #2351). |
33//! | REQ-5b | VarianceThreshold `threshold==0` peak-to-peak guard (`min(var, ptp)`) | NOT-STARTED | sklearn `_variance_threshold.py:113-120` — blocker #1427 (ptp-guard only; nanvar+ValueError shipped as REQ-5a) |
34//! | REQ-6 | SelectKBest `k='all'` string + pluggable `score_func` (chi2/f_regression/mutual_info) + general `_clean_nans` | NOT-STARTED | `usize` k + `FClassif` only; sklearn `_univariate_selection.py:770-795` — blocker #1428 |
35//! | REQ-7 | `GenericUnivariateSelect` (mode meta-selector) | NOT-STARTED | absent (route parity_op); sklearn `_univariate_selection.py:1054` — blocker #1429 |
36//! | REQ-8 | `SelectorMixin` surface (`get_support`/`inverse_transform`/`get_feature_names_out`) + `scores_`/`pvalues_`/`n_features_in_` attrs | NOT-STARTED | sklearn `_univariate_selection.py:526` — blocker #1430 |
37//! | REQ-9 | Basic `SelectFromModel` here duplicates `SelectFromModelExt` | NOT-STARTED | tech-debt; parity in `.design/preprocess/select_from_model.md` — blocker #1431 |
38//! | REQ-10 | PyO3 binding | NOT-STARTED | no `ferrolearn-python` registration — blocker #1432 |
39//! | REQ-11 | ferray substrate | NOT-STARTED | dense `Array1`/`Array2` + `num_traits::Float` only — blocker #1433 |
40
41use ferrolearn_core::error::FerroError;
42use ferrolearn_core::pipeline::{FittedPipelineTransformer, PipelineTransformer};
43use ferrolearn_core::traits::{Fit, Transform};
44use ndarray::{Array1, Array2};
45use num_traits::Float;
46
47// ---------------------------------------------------------------------------
48// Shared helper: collect selected columns
49// ---------------------------------------------------------------------------
50
51/// Build a new `Array2<F>` containing only the columns listed in `indices`.
52///
53/// Columns are emitted in the order they appear in `indices`.
54fn select_columns<F: Float>(x: &Array2<F>, indices: &[usize]) -> Array2<F> {
55    let nrows = x.nrows();
56    let ncols = indices.len();
57    if ncols == 0 {
58        // Return empty matrix with correct row count
59        return Array2::zeros((nrows, 0));
60    }
61    let mut out = Array2::zeros((nrows, ncols));
62    for (new_j, &old_j) in indices.iter().enumerate() {
63        for i in 0..nrows {
64            out[[i, new_j]] = x[[i, old_j]];
65        }
66    }
67    out
68}
69
70// ===========================================================================
71// VarianceThreshold
72// ===========================================================================
73
74/// An unfitted variance-threshold feature selector.
75///
76/// During fitting the population variance of every column is computed (NaN
77/// values are treated as zero — use an imputer upstream if needed).  Columns
78/// whose variance is *less than or equal to* the configured threshold are
79/// discarded during transformation.
80///
81/// The default threshold is `0.0`, which removes features with exactly zero
82/// variance (i.e. constant columns).
83///
84/// # Examples
85///
86/// ```
87/// use ferrolearn_preprocess::feature_selection::VarianceThreshold;
88/// use ferrolearn_core::traits::{Fit, Transform};
89/// use ndarray::array;
90///
91/// let sel = VarianceThreshold::<f64>::new(0.0);
92/// // Column 1 is constant — will be removed
93/// let x = array![[1.0, 7.0], [2.0, 7.0], [3.0, 7.0]];
94/// let fitted = sel.fit(&x, &()).unwrap();
95/// let out = fitted.transform(&x).unwrap();
96/// assert_eq!(out.ncols(), 1);
97/// ```
98#[derive(Debug, Clone)]
99pub struct VarianceThreshold<F> {
100    /// Features with variance strictly less than this threshold are removed.
101    threshold: F,
102}
103
104impl<F: Float + Send + Sync + 'static> VarianceThreshold<F> {
105    /// Create a new `VarianceThreshold` with the given threshold.
106    ///
107    /// Pass `F::zero()` (the default) to remove only constant features.
108    ///
109    /// # Errors
110    ///
111    /// Returns [`FerroError::InvalidParameter`] if `threshold` is negative.
112    pub fn new(threshold: F) -> Self {
113        Self { threshold }
114    }
115
116    /// Return the variance threshold.
117    #[must_use]
118    pub fn threshold(&self) -> F {
119        self.threshold
120    }
121}
122
123impl<F: Float + Send + Sync + 'static> Default for VarianceThreshold<F> {
124    fn default() -> Self {
125        Self::new(F::zero())
126    }
127}
128
129// ---------------------------------------------------------------------------
130// FittedVarianceThreshold
131// ---------------------------------------------------------------------------
132
133/// A fitted variance-threshold selector holding the selected column indices
134/// and the per-column variances observed during fitting.
135///
136/// Created by calling [`Fit::fit`] on a [`VarianceThreshold`].
137#[derive(Debug, Clone)]
138pub struct FittedVarianceThreshold<F> {
139    /// Column indices (into the *original* feature matrix) that were selected.
140    selected_indices: Vec<usize>,
141    /// Per-column population variances computed during fitting.
142    variances: Array1<F>,
143}
144
145impl<F: Float + Send + Sync + 'static> FittedVarianceThreshold<F> {
146    /// Return the indices of the selected columns.
147    #[must_use]
148    pub fn selected_indices(&self) -> &[usize] {
149        &self.selected_indices
150    }
151
152    /// Return the per-column variances computed during fitting.
153    #[must_use]
154    pub fn variances(&self) -> &Array1<F> {
155        &self.variances
156    }
157}
158
159// ---------------------------------------------------------------------------
160// Trait implementations — VarianceThreshold
161// ---------------------------------------------------------------------------
162
163impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, ()> for VarianceThreshold<F> {
164    type Fitted = FittedVarianceThreshold<F>;
165    type Error = FerroError;
166
167    /// Fit by computing per-column population variances.
168    ///
169    /// # Errors
170    ///
171    /// Returns [`FerroError::InsufficientSamples`] if the input has zero rows.
172    /// Returns [`FerroError::InvalidParameter`] if the threshold is negative.
173    fn fit(&self, x: &Array2<F>, _y: &()) -> Result<FittedVarianceThreshold<F>, FerroError> {
174        if self.threshold < F::zero() {
175            return Err(FerroError::InvalidParameter {
176                name: "threshold".into(),
177                reason: "variance threshold must be non-negative".into(),
178            });
179        }
180        let n_samples = x.nrows();
181        if n_samples == 0 {
182            return Err(FerroError::InsufficientSamples {
183                required: 1,
184                actual: 0,
185                context: "VarianceThreshold::fit".into(),
186            });
187        }
188
189        let n_features = x.ncols();
190        let mut variances = Array1::zeros(n_features);
191        let mut selected_indices = Vec::new();
192
193        for j in 0..n_features {
194            let col = x.column(j);
195            // sklearn computes `variances_ = np.nanvar(X, axis=0)`
196            // (`_variance_threshold.py:112`, `force_all_finite="allow-nan"`
197            // `:103`): NaN entries are IGNORED and the population variance
198            // (ddof=0) is taken over the FINITE values only. We replicate
199            // np.nanvar with a NaN-skipping Welford pass — numerically stable,
200            // yields *exactly* zero for constant columns (defeating the
201            // ~1e-34 FP noise the naive `Σ(v-mean)²/n` accumulates), and a
202            // column whose finite values are e.g. {1, 2} has nanvar 0.25 even
203            // though it contains a NaN. An ALL-NaN column has zero finite
204            // values → nanvar is NaN (np.nanvar of an all-NaN slice → NaN with
205            // a "Degrees of freedom <= 0" warning), and `NaN > threshold` is
206            // false so the column is dropped (matching `_get_support_mask`
207            // `variances_ > threshold` `:133`).
208            let mut mean = F::zero();
209            let mut m2 = F::zero();
210            let mut count = F::zero();
211            for &v in col.iter() {
212                if v.is_nan() {
213                    continue;
214                }
215                count = count + F::one();
216                let delta = v - mean;
217                mean = mean + delta / count;
218                let delta2 = v - mean;
219                m2 = m2 + delta * delta2;
220            }
221            // np.nanvar over zero finite values → NaN (ddof=0, empty slice).
222            let var = if count == F::zero() {
223                F::nan()
224            } else {
225                m2 / count
226            };
227            variances[j] = var;
228            if var > self.threshold {
229                selected_indices.push(j);
230            }
231        }
232
233        // sklearn raises when NO feature meets the threshold
234        // (`_variance_threshold.py:121-126`):
235        //   if np.all(~np.isfinite(variances_) | (variances_ <= threshold)):
236        //       msg = "No feature in X meets the variance threshold {0:.5f}"
237        //       if X.shape[0] == 1: msg += " (X contains only one sample)"
238        //       raise ValueError(msg.format(threshold))
239        // A column counts as "not meeting" when its variance is non-finite
240        // (NaN/inf) OR `<= threshold`. The single-sample case falls in here
241        // naturally (every finite variance is 0 ≤ threshold). `selected_indices`
242        // already holds exactly the columns with `variance > threshold`, so an
243        // empty selection is precisely the raise condition.
244        if selected_indices.is_empty() {
245            let threshold = self.threshold.to_f64().unwrap_or(0.0);
246            let mut reason = format!("No feature in X meets the variance threshold {threshold:.5}");
247            if n_samples == 1 {
248                reason.push_str(" (X contains only one sample)");
249            }
250            return Err(FerroError::InvalidParameter {
251                name: "threshold".into(),
252                reason,
253            });
254        }
255
256        Ok(FittedVarianceThreshold {
257            selected_indices,
258            variances,
259        })
260    }
261}
262
263impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedVarianceThreshold<F> {
264    type Output = Array2<F>;
265    type Error = FerroError;
266
267    /// Return a matrix containing only the selected (high-variance) columns.
268    ///
269    /// # Errors
270    ///
271    /// Returns [`FerroError::ShapeMismatch`] if the number of columns differs
272    /// from the number of features seen during fitting.
273    fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
274        let n_original = self.variances.len();
275        if x.ncols() != n_original {
276            return Err(FerroError::ShapeMismatch {
277                expected: vec![x.nrows(), n_original],
278                actual: vec![x.nrows(), x.ncols()],
279                context: "FittedVarianceThreshold::transform".into(),
280            });
281        }
282        Ok(select_columns(x, &self.selected_indices))
283    }
284}
285
286// ---------------------------------------------------------------------------
287// Pipeline integration — VarianceThreshold (generic)
288// ---------------------------------------------------------------------------
289
290impl<F: Float + Send + Sync + 'static> PipelineTransformer<F> for VarianceThreshold<F> {
291    /// Fit using the pipeline interface; `y` is ignored.
292    ///
293    /// # Errors
294    ///
295    /// Propagates errors from [`Fit::fit`].
296    fn fit_pipeline(
297        &self,
298        x: &Array2<F>,
299        _y: &Array1<F>,
300    ) -> Result<Box<dyn FittedPipelineTransformer<F>>, FerroError> {
301        let fitted = self.fit(x, &())?;
302        Ok(Box::new(fitted))
303    }
304}
305
306impl<F: Float + Send + Sync + 'static> FittedPipelineTransformer<F> for FittedVarianceThreshold<F> {
307    /// Transform using the pipeline interface.
308    ///
309    /// # Errors
310    ///
311    /// Propagates errors from [`Transform::transform`].
312    fn transform_pipeline(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
313        self.transform(x)
314    }
315}
316
317// ===========================================================================
318// SelectKBest
319// ===========================================================================
320
/// Which per-feature statistic [`SelectKBest`] ranks features by.
///
/// ANOVA F-value scoring is the only statistic implemented so far.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ScoreFunc {
    /// One-way ANOVA F-value: between-class mean square divided by
    /// within-class mean square.
    ///
    /// This is the counterpart of scikit-learn's `f_classif`.
    FClassif,
}
331
332/// An unfitted K-best feature selector.
333///
334/// Requires class labels (`Array1<usize>`) at fit time to compute per-feature
335/// ANOVA F-scores.  The top *K* features (by score) are retained.
336///
337/// # Examples
338///
339/// ```
340/// use ferrolearn_preprocess::feature_selection::{SelectKBest, ScoreFunc};
341/// use ferrolearn_core::traits::{Fit, Transform};
342/// use ndarray::{array, Array1};
343///
344/// let sel = SelectKBest::<f64>::new(1, ScoreFunc::FClassif);
345/// let x = array![[1.0, 10.0], [1.0, 20.0], [2.0, 10.0], [2.0, 20.0]];
346/// let y: Array1<usize> = array![0, 0, 1, 1];
347/// let fitted = sel.fit(&x, &y).unwrap();
348/// let out = fitted.transform(&x).unwrap();
349/// assert_eq!(out.ncols(), 1);
350/// ```
351#[derive(Debug, Clone)]
352pub struct SelectKBest<F> {
353    /// Number of top-scoring features to keep.
354    k: usize,
355    /// The scoring function to use.
356    score_func: ScoreFunc,
357    _marker: std::marker::PhantomData<F>,
358}
359
360impl<F: Float + Send + Sync + 'static> SelectKBest<F> {
361    /// Create a new `SelectKBest` selector.
362    ///
363    /// # Parameters
364    ///
365    /// - `k` — the number of features to retain.
366    /// - `score_func` — the scoring function; currently only
367    ///   [`ScoreFunc::FClassif`] is available.
368    #[must_use]
369    pub fn new(k: usize, score_func: ScoreFunc) -> Self {
370        Self {
371            k,
372            score_func,
373            _marker: std::marker::PhantomData,
374        }
375    }
376
377    /// Return *k*.
378    #[must_use]
379    pub fn k(&self) -> usize {
380        self.k
381    }
382
383    /// Return the score function.
384    #[must_use]
385    pub fn score_func(&self) -> ScoreFunc {
386        self.score_func
387    }
388}
389
390// ---------------------------------------------------------------------------
391// FittedSelectKBest
392// ---------------------------------------------------------------------------
393
394/// A fitted K-best selector holding per-feature scores and selected indices.
395///
396/// Created by calling [`Fit::fit`] on a [`SelectKBest`].
397#[derive(Debug, Clone)]
398pub struct FittedSelectKBest<F> {
399    /// The original number of features (used for shape checking on transform).
400    n_features_in: usize,
401    /// Per-feature ANOVA F-scores computed during fitting.
402    scores: Array1<F>,
403    /// Indices of the selected columns, sorted in original column order.
404    selected_indices: Vec<usize>,
405}
406
407impl<F: Float + Send + Sync + 'static> FittedSelectKBest<F> {
408    /// Return the per-feature F-scores computed during fitting.
409    #[must_use]
410    pub fn scores(&self) -> &Array1<F> {
411        &self.scores
412    }
413
414    /// Return the indices of the selected columns.
415    #[must_use]
416    pub fn selected_indices(&self) -> &[usize] {
417        &self.selected_indices
418    }
419}
420
421// ---------------------------------------------------------------------------
422// ANOVA F-value helper
423// ---------------------------------------------------------------------------
424
425/// Compute per-feature ANOVA F-scores given a feature matrix `x` and integer
426/// class labels `y`.
427///
428/// For each feature column the F-statistic is:
429///
430/// ```text
431/// F = (between-class variance / (n_classes - 1))
432///   / (within-class variance  / (n_samples - n_classes))
433/// ```
434///
435/// Features for which the within-class variance is zero (perfectly separable)
436/// get an F-score of `F::infinity()`.  Features that have zero between-class
437/// variance get a score of `F::zero()`.
438fn anova_f_scores<F: Float>(x: &Array2<F>, y: &Array1<usize>) -> Vec<F> {
439    let n_samples = x.nrows();
440    let n_features = x.ncols();
441
442    // Collect unique classes and build per-class row-index lists.
443    let mut class_indices: std::collections::HashMap<usize, Vec<usize>> =
444        std::collections::HashMap::new();
445    for (i, &label) in y.iter().enumerate() {
446        class_indices.entry(label).or_default().push(i);
447    }
448    let n_classes = class_indices.len();
449
450    let mut scores = Vec::with_capacity(n_features);
451
452    for j in 0..n_features {
453        let col = x.column(j);
454
455        // Overall mean of this feature
456        let grand_mean =
457            col.iter().copied().fold(F::zero(), |acc, v| acc + v) / F::from(n_samples).unwrap();
458
459        // Between-class sum of squares: sum_k n_k * (mean_k - grand_mean)^2
460        let mut ss_between = F::zero();
461        // Within-class sum of squares: sum_k sum_{i in k} (x_i - mean_k)^2
462        let mut ss_within = F::zero();
463
464        for rows in class_indices.values() {
465            let n_k = F::from(rows.len()).unwrap();
466            let class_mean = rows
467                .iter()
468                .map(|&i| col[i])
469                .fold(F::zero(), |acc, v| acc + v)
470                / n_k;
471            let diff = class_mean - grand_mean;
472            ss_between = ss_between + n_k * diff * diff;
473            for &i in rows {
474                let d = col[i] - class_mean;
475                ss_within = ss_within + d * d;
476            }
477        }
478
479        let df_between = F::from(n_classes.saturating_sub(1)).unwrap();
480        let df_within = F::from(n_samples.saturating_sub(n_classes)).unwrap();
481
482        let f = if df_between == F::zero() || df_within == F::zero() {
483            F::zero()
484        } else {
485            let ms_between = ss_between / df_between;
486            let ms_within = ss_within / df_within;
487            if ms_within == F::zero() {
488                // sklearn `f_oneway` computes msb/msw. When msw == 0 the result
489                // is +inf for a genuine perfect separator (msb > 0) but 0/0 =
490                // NaN for a CONSTANT feature (msb == 0). NaN later flows through
491                // `_clean_nans` -> `finfo.min`, ranking the constant feature
492                // LAST; a bare +inf would wrongly rank it FIRST.
493                if ms_between > F::zero() {
494                    F::infinity()
495                } else {
496                    F::nan()
497                }
498            } else {
499                ms_between / ms_within
500            }
501        };
502
503        scores.push(f);
504    }
505
506    scores
507}
508
509// ---------------------------------------------------------------------------
510// Trait implementations — SelectKBest
511// ---------------------------------------------------------------------------
512
513impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array1<usize>> for SelectKBest<F> {
514    type Fitted = FittedSelectKBest<F>;
515    type Error = FerroError;
516
517    /// Fit by computing per-feature ANOVA F-scores against the class labels.
518    ///
519    /// # Errors
520    ///
521    /// - [`FerroError::InsufficientSamples`] if the input has zero rows.
522    /// - [`FerroError::ShapeMismatch`] if `x` and `y` have different numbers
523    ///   of rows.
524    ///
525    /// When `k` exceeds the number of features, all features are selected
526    /// (matching scikit-learn `_check_params`, `_univariate_selection.py:774-779`,
527    /// which warns rather than raising).
528    fn fit(&self, x: &Array2<F>, y: &Array1<usize>) -> Result<FittedSelectKBest<F>, FerroError> {
529        let n_samples = x.nrows();
530        if n_samples == 0 {
531            return Err(FerroError::InsufficientSamples {
532                required: 1,
533                actual: 0,
534                context: "SelectKBest::fit".into(),
535            });
536        }
537        if y.len() != n_samples {
538            return Err(FerroError::ShapeMismatch {
539                expected: vec![n_samples],
540                actual: vec![y.len()],
541                context: "SelectKBest::fit — y must have the same length as x has rows".into(),
542            });
543        }
544        let n_features = x.ncols();
545        // sklearn `_check_params` (`_univariate_selection.py:774-779`) only WARNS
546        // (does NOT raise) when k > n_features and keeps ALL features. Clamp the
547        // effective k to the feature count so `idx[n_features - k_eff..]` selects
548        // everything in that case (matching warn+keep-all), and is unchanged when
549        // k <= n_features.
550        let k_eff = self.k.min(n_features);
551
552        let raw_scores = match self.score_func {
553            ScoreFunc::FClassif => anova_f_scores(x, y),
554        };
555
556        let scores = Array1::from_vec(raw_scores.clone());
557
558        // Replicate sklearn `_get_support_mask`:
559        //   scores = _clean_nans(scores_); argsort(kind="mergesort")[-k:]
560        // `_clean_nans` maps NaN -> finfo.min (most-negative finite, NOT -inf),
561        // so constant features rank LAST; a genuine +inf perfect separator stays
562        // highest. The ASCENDING + STABLE sort means a k-boundary tie keeps the
563        // HIGHER column index (it appears later, so it lands in the last-k slice).
564        let cleaned: Vec<F> = raw_scores
565            .iter()
566            .map(|&v| if v.is_nan() { F::min_value() } else { v })
567            .collect();
568        let mut idx: Vec<usize> = (0..n_features).collect();
569        idx.sort_by(|&a, &b| {
570            cleaned[a]
571                .partial_cmp(&cleaned[b])
572                .unwrap_or(core::cmp::Ordering::Equal)
573        });
574
575        let mut selected_indices: Vec<usize> = idx[n_features - k_eff..].to_vec();
576        // Return in original column order for a stable output layout
577        selected_indices.sort_unstable();
578
579        Ok(FittedSelectKBest {
580            n_features_in: n_features,
581            scores,
582            selected_indices,
583        })
584    }
585}
586
587impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedSelectKBest<F> {
588    type Output = Array2<F>;
589    type Error = FerroError;
590
591    /// Return a matrix containing only the K selected columns.
592    ///
593    /// # Errors
594    ///
595    /// Returns [`FerroError::ShapeMismatch`] if the number of columns differs
596    /// from the number of features seen during fitting.
597    fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
598        if x.ncols() != self.n_features_in {
599            return Err(FerroError::ShapeMismatch {
600                expected: vec![x.nrows(), self.n_features_in],
601                actual: vec![x.nrows(), x.ncols()],
602                context: "FittedSelectKBest::transform".into(),
603            });
604        }
605        Ok(select_columns(x, &self.selected_indices))
606    }
607}
608
609// ---------------------------------------------------------------------------
610// Pipeline integration — SelectKBest (generic)
611//
// NOTE: The pipeline interface passes `y` as a float vector (`Array1<F>`)
// rather than as integer class labels, so `fit_pipeline` derives `usize`
// class labels from it by rounding before delegating to `Fit::fit`.
615// ---------------------------------------------------------------------------
616
617impl<F: Float + Send + Sync + 'static> PipelineTransformer<F> for SelectKBest<F> {
618    /// Fit using the pipeline interface.
619    ///
620    /// The continuous `y` labels are rounded to `usize` class indices.
621    ///
622    /// # Errors
623    ///
624    /// Propagates errors from [`Fit::fit`].
625    fn fit_pipeline(
626        &self,
627        x: &Array2<F>,
628        y: &Array1<F>,
629    ) -> Result<Box<dyn FittedPipelineTransformer<F>>, FerroError> {
630        let y_usize: Array1<usize> = y.mapv(|v| v.round().to_usize().unwrap_or(0));
631        let fitted = self.fit(x, &y_usize)?;
632        Ok(Box::new(fitted))
633    }
634}
635
636impl<F: Float + Send + Sync + 'static> FittedPipelineTransformer<F> for FittedSelectKBest<F> {
637    /// Transform using the pipeline interface.
638    ///
639    /// # Errors
640    ///
641    /// Propagates errors from [`Transform::transform`].
642    fn transform_pipeline(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
643        self.transform(x)
644    }
645}
646
647// ===========================================================================
648// SelectFromModel
649// ===========================================================================
650
651/// A feature selector driven by external feature-importance weights.
652///
653/// The importance vector is typically obtained from a fitted model (e.g. a
654/// decision-tree model's `feature_importances_` field).  Features whose
655/// importance is *strictly greater than or equal to* the threshold are kept.
656///
657/// The default threshold is the **mean importance** of all features.
658///
659/// # Examples
660///
661/// ```
662/// use ferrolearn_preprocess::feature_selection::SelectFromModel;
663/// use ferrolearn_core::traits::Transform;
664/// use ndarray::array;
665///
666/// let importances = array![0.1, 0.5, 0.4];
667/// let sel = SelectFromModel::<f64>::new_from_importances(&importances, None).unwrap();
668/// let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
669/// let out = sel.transform(&x).unwrap();
670/// // Mean importance = (0.1+0.5+0.4)/3 ≈ 0.333; columns 1 and 2 are kept
671/// assert_eq!(out.ncols(), 2);
672/// ```
673#[derive(Debug, Clone)]
674pub struct SelectFromModel<F> {
675    /// Importance weight for each feature.
676    importances: Array1<F>,
677    /// The threshold: features with importance >= threshold are kept.
678    threshold: F,
679    /// Indices of selected features (original column order).
680    selected_indices: Vec<usize>,
681}
682
683impl<F: Float + Send + Sync + 'static> SelectFromModel<F> {
684    /// Create a `SelectFromModel` from a pre-computed importance vector.
685    ///
686    /// # Parameters
687    ///
688    /// - `importances` — one importance weight per feature.
689    /// - `threshold` — optional explicit threshold; if `None` the mean
690    ///   importance is used.
691    ///
692    /// # Errors
693    ///
694    /// Returns [`FerroError::InvalidParameter`] if `importances` is empty.
695    pub fn new_from_importances(
696        importances: &Array1<F>,
697        threshold: Option<F>,
698    ) -> Result<Self, FerroError> {
699        let n = importances.len();
700        if n == 0 {
701            return Err(FerroError::InvalidParameter {
702                name: "importances".into(),
703                reason: "importance vector must not be empty".into(),
704            });
705        }
706
707        let thr = threshold.unwrap_or_else(|| {
708            importances
709                .iter()
710                .copied()
711                .fold(F::zero(), |acc, v| acc + v)
712                / F::from(n).unwrap_or_else(F::one)
713        });
714
715        let selected_indices: Vec<usize> = importances
716            .iter()
717            .enumerate()
718            .filter(|&(_, &imp)| imp >= thr)
719            .map(|(j, _)| j)
720            .collect();
721
722        Ok(Self {
723            importances: importances.clone(),
724            threshold: thr,
725            selected_indices,
726        })
727    }
728
729    /// Return the threshold used to select features.
730    #[must_use]
731    pub fn threshold(&self) -> F {
732        self.threshold
733    }
734
735    /// Return the importance vector supplied at construction time.
736    #[must_use]
737    pub fn importances(&self) -> &Array1<F> {
738        &self.importances
739    }
740
741    /// Return the indices of the selected columns.
742    #[must_use]
743    pub fn selected_indices(&self) -> &[usize] {
744        &self.selected_indices
745    }
746}
747
748// ---------------------------------------------------------------------------
749// Trait implementation — SelectFromModel
750// ---------------------------------------------------------------------------
751
752impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for SelectFromModel<F> {
753    type Output = Array2<F>;
754    type Error = FerroError;
755
756    /// Return a matrix containing only the columns whose importance exceeds
757    /// the threshold.
758    ///
759    /// # Errors
760    ///
761    /// Returns [`FerroError::ShapeMismatch`] if the number of columns differs
762    /// from the length of the importance vector supplied at construction.
763    fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
764        let n_features = self.importances.len();
765        if x.ncols() != n_features {
766            return Err(FerroError::ShapeMismatch {
767                expected: vec![x.nrows(), n_features],
768                actual: vec![x.nrows(), x.ncols()],
769                context: "SelectFromModel::transform".into(),
770            });
771        }
772        Ok(select_columns(x, &self.selected_indices))
773    }
774}
775
776// ---------------------------------------------------------------------------
777// Pipeline integration — SelectFromModel (generic)
778//
779// `SelectFromModel` is already "fitted" (importance weights are provided at
780// construction time), so `fit_pipeline` merely boxes `self.clone()`.
781// ---------------------------------------------------------------------------
782
783impl<F: Float + Send + Sync + 'static> PipelineTransformer<F> for SelectFromModel<F> {
784    /// Clone the selector and box it as a fitted pipeline transformer.
785    ///
786    /// # Errors
787    ///
788    /// This implementation never fails.
789    fn fit_pipeline(
790        &self,
791        _x: &Array2<F>,
792        _y: &Array1<F>,
793    ) -> Result<Box<dyn FittedPipelineTransformer<F>>, FerroError> {
794        Ok(Box::new(self.clone()))
795    }
796}
797
798impl<F: Float + Send + Sync + 'static> FittedPipelineTransformer<F> for SelectFromModel<F> {
799    /// Transform using the pipeline interface.
800    ///
801    /// # Errors
802    ///
803    /// Propagates errors from [`Transform::transform`].
804    fn transform_pipeline(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
805        self.transform(x)
806    }
807}
808
809// ===========================================================================
810// Tests
811// ===========================================================================
812
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    // ========================================================================
    // VarianceThreshold tests
    // ========================================================================

    #[test]
    fn test_variance_threshold_removes_constant_column() {
        let sel = VarianceThreshold::<f64>::new(0.0);
        // Column 1 is constant (all 7.0)
        let x = array![[1.0, 7.0], [2.0, 7.0], [3.0, 7.0]];
        let fitted = sel.fit(&x, &()).unwrap();
        assert_eq!(fitted.selected_indices(), &[0usize]);
        let out = fitted.transform(&x).unwrap();
        assert_eq!(out.ncols(), 1);
        // Column 0 values preserved
        assert_abs_diff_eq!(out[[0, 0]], 1.0, epsilon = 1e-15);
        assert_abs_diff_eq!(out[[1, 0]], 2.0, epsilon = 1e-15);
    }

    #[test]
    fn test_variance_threshold_keeps_all_when_above() {
        let sel = VarianceThreshold::<f64>::new(0.0);
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let fitted = sel.fit(&x, &()).unwrap();
        assert_eq!(fitted.selected_indices().len(), 2);
        let out = fitted.transform(&x).unwrap();
        assert_eq!(out.ncols(), 2);
    }

    #[test]
    fn test_variance_threshold_custom_threshold() {
        let sel = VarianceThreshold::<f64>::new(1.5);
        // Column 0: values [1,2,3], variance = 2/3 ≈ 0.667 → removed
        // Column 1: values [10,20,30], variance = 200/3 ≈ 66.7 → kept
        let x = array![[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]];
        let fitted = sel.fit(&x, &()).unwrap();
        assert_eq!(fitted.selected_indices(), &[1usize]);
        let out = fitted.transform(&x).unwrap();
        assert_eq!(out.ncols(), 1);
    }

    #[test]
    fn test_variance_threshold_stores_variances() {
        // A single constant column has variance 0, which does not exceed the
        // default threshold 0.0 — sklearn raises ValueError("No feature in X
        // meets the variance threshold 0.00000")
        // (`_variance_threshold.py:121-126`). Use a NON-constant column to
        // exercise the variance-storage path on a fit that succeeds.
        let sel = VarianceThreshold::<f64>::default();
        let x = array![[0.0], [2.0], [4.0]]; // var = 8/3 ~ 2.667 > 0 -> kept
        let fitted = sel.fit(&x, &());
        assert!(fitted.is_ok(), "non-constant column should fit");
        if let Ok(f) = fitted {
            assert_abs_diff_eq!(f.variances()[0], 8.0 / 3.0, epsilon = 1e-15);
        }
    }

    #[test]
    fn test_variance_threshold_zero_rows_error() {
        let sel = VarianceThreshold::<f64>::new(0.0);
        let x: Array2<f64> = Array2::zeros((0, 2));
        assert!(sel.fit(&x, &()).is_err());
    }

    #[test]
    fn test_variance_threshold_negative_threshold_error() {
        let sel = VarianceThreshold::<f64>::new(-0.1);
        let x = array![[1.0], [2.0]];
        assert!(sel.fit(&x, &()).is_err());
    }

    #[test]
    fn test_variance_threshold_shape_mismatch_on_transform() {
        let sel = VarianceThreshold::<f64>::new(0.0);
        let x_train = array![[1.0, 2.0], [3.0, 4.0]];
        let fitted = sel.fit(&x_train, &()).unwrap();
        let x_bad = array![[1.0, 2.0, 3.0]];
        assert!(fitted.transform(&x_bad).is_err());
    }

    #[test]
    fn test_variance_threshold_all_constant_columns() {
        // Both columns are constant → no feature meets the threshold → sklearn
        // raises ValueError("No feature in X meets the variance threshold
        // 0.00000") (`_variance_threshold.py:121-126`). ferrolearn surfaces it
        // as InvalidParameter (maps to ValueError at the Python boundary).
        let sel = VarianceThreshold::<f64>::new(0.0);
        let x = array![[5.0, 3.0], [5.0, 3.0], [5.0, 3.0]];
        let err = sel.fit(&x, &());
        let reason = match &err {
            Err(FerroError::InvalidParameter { reason, .. }) => Some(reason.clone()),
            _ => None,
        };
        assert_eq!(
            reason.as_deref(),
            Some("No feature in X meets the variance threshold 0.00000"),
            "all-constant X must raise sklearn's ValueError, got {err:?}"
        );
    }

    #[test]
    fn test_variance_threshold_pipeline_integration() {
        use ferrolearn_core::pipeline::PipelineTransformer;
        let sel = VarianceThreshold::<f64>::new(0.0);
        let x = array![[1.0, 7.0], [2.0, 7.0], [3.0, 7.0]];
        let y = ndarray::array![0.0, 1.0, 0.0];
        let fitted_box = sel.fit_pipeline(&x, &y).unwrap();
        let out = fitted_box.transform_pipeline(&x).unwrap();
        assert_eq!(out.ncols(), 1);
    }

    #[test]
    fn test_variance_threshold_f32() {
        let sel = VarianceThreshold::<f32>::new(0.0f32);
        let x: Array2<f32> = array![[1.0f32, 5.0], [2.0, 5.0], [3.0, 5.0]];
        let fitted = sel.fit(&x, &()).unwrap();
        assert_eq!(fitted.selected_indices(), &[0usize]);
    }

    // ========================================================================
    // SelectKBest tests
    // ========================================================================

    #[test]
    fn test_select_k_best_selects_highest_scoring_feature() {
        // Feature 0 separates classes perfectly; feature 1 does not.
        let sel = SelectKBest::<f64>::new(1, ScoreFunc::FClassif);
        let x = array![[1.0, 5.0], [1.0, 6.0], [10.0, 5.0], [10.0, 6.0]];
        let y: Array1<usize> = array![0, 0, 1, 1];
        let fitted = sel.fit(&x, &y).unwrap();
        // Column 0 should be selected (high F-score)
        assert_eq!(fitted.selected_indices(), &[0usize]);
        let out = fitted.transform(&x).unwrap();
        assert_eq!(out.ncols(), 1);
    }

    #[test]
    fn test_select_k_best_k_equals_n_features_keeps_all() {
        let sel = SelectKBest::<f64>::new(2, ScoreFunc::FClassif);
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let y: Array1<usize> = array![0, 1, 0];
        let fitted = sel.fit(&x, &y).unwrap();
        assert_eq!(fitted.selected_indices().len(), 2);
        let out = fitted.transform(&x).unwrap();
        assert_eq!(out.ncols(), 2);
    }

    #[test]
    fn test_select_k_best_scores_stored() {
        let sel = SelectKBest::<f64>::new(1, ScoreFunc::FClassif);
        let x = array![[1.0, 2.0], [1.0, 4.0]];
        let y: Array1<usize> = array![0, 1];
        let fitted = sel.fit(&x, &y).unwrap();
        assert_eq!(fitted.scores().len(), 2);
    }

    #[test]
    fn test_select_k_best_zero_rows_error() {
        let sel = SelectKBest::<f64>::new(1, ScoreFunc::FClassif);
        let x: Array2<f64> = Array2::zeros((0, 3));
        let y: Array1<usize> = Array1::zeros(0);
        assert!(sel.fit(&x, &y).is_err());
    }

    #[test]
    fn test_select_k_best_k_exceeds_n_features_keeps_all() {
        // sklearn `_check_params` (`_univariate_selection.py:774-779`) only WARNS
        // (does NOT raise) when k > n_features and keeps ALL features.
        let sel = SelectKBest::<f64>::new(5, ScoreFunc::FClassif);
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let y: Array1<usize> = array![0, 1];
        let fitted = sel.fit(&x, &y);
        assert!(
            fitted.is_ok(),
            "k>n_features must keep all features, not error: {fitted:?}"
        );
        if let Ok(f) = fitted {
            assert_eq!(f.selected_indices().len(), x.ncols());
            assert_eq!(f.selected_indices(), &[0usize, 1][..]);
        }
    }

    #[test]
    fn test_select_k_best_y_length_mismatch_error() {
        let sel = SelectKBest::<f64>::new(1, ScoreFunc::FClassif);
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let y: Array1<usize> = array![0]; // wrong length
        assert!(sel.fit(&x, &y).is_err());
    }

    #[test]
    fn test_select_k_best_shape_mismatch_on_transform() {
        let sel = SelectKBest::<f64>::new(1, ScoreFunc::FClassif);
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let y: Array1<usize> = array![0, 1];
        let fitted = sel.fit(&x, &y).unwrap();
        let x_bad = array![[1.0, 2.0, 3.0]];
        assert!(fitted.transform(&x_bad).is_err());
    }

    #[test]
    fn test_select_k_best_selected_indices_in_column_order() {
        // Both features selected — indices should be [0, 1] not reversed
        let sel = SelectKBest::<f64>::new(2, ScoreFunc::FClassif);
        let x = array![[1.0, 100.0], [2.0, 200.0]];
        let y: Array1<usize> = array![0, 1];
        let fitted = sel.fit(&x, &y).unwrap();
        let indices = fitted.selected_indices();
        assert!(indices.windows(2).all(|w| w[0] < w[1]));
    }

    #[test]
    fn test_select_k_best_pipeline_integration() {
        use ferrolearn_core::pipeline::PipelineTransformer;
        let sel = SelectKBest::<f64>::new(1, ScoreFunc::FClassif);
        let x = array![[1.0, 5.0], [1.0, 6.0], [10.0, 5.0], [10.0, 6.0]];
        let y = ndarray::array![0.0, 0.0, 1.0, 1.0];
        let fitted_box = sel.fit_pipeline(&x, &y).unwrap();
        let out = fitted_box.transform_pipeline(&x).unwrap();
        assert_eq!(out.ncols(), 1);
    }

    #[test]
    fn test_select_k_best_f_score_zero_within_class_variance() {
        // Perfectly separating feature → F should be infinity
        let sel = SelectKBest::<f64>::new(1, ScoreFunc::FClassif);
        let x = array![[0.0], [0.0], [10.0], [10.0]];
        let y: Array1<usize> = array![0, 0, 1, 1];
        let fitted = sel.fit(&x, &y).unwrap();
        assert!(fitted.scores()[0].is_infinite());
    }

    // ========================================================================
    // SelectFromModel tests
    // ========================================================================

    #[test]
    fn test_select_from_model_mean_threshold() {
        // Mean importance = (0.1 + 0.5 + 0.4) / 3 ≈ 0.333
        // Features 1 (0.5) and 2 (0.4) are >= threshold; feature 0 (0.1) is not
        let importances = array![0.1, 0.5, 0.4];
        let sel = SelectFromModel::<f64>::new_from_importances(&importances, None).unwrap();
        assert_eq!(sel.selected_indices(), &[1usize, 2]);
        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let out = sel.transform(&x).unwrap();
        assert_eq!(out.ncols(), 2);
        assert_abs_diff_eq!(out[[0, 0]], 2.0, epsilon = 1e-15);
        assert_abs_diff_eq!(out[[0, 1]], 3.0, epsilon = 1e-15);
    }

    #[test]
    fn test_select_from_model_explicit_threshold() {
        let importances = array![0.1, 0.5, 0.4];
        // Only feature 1 (0.5 >= 0.45) is selected
        let sel = SelectFromModel::<f64>::new_from_importances(&importances, Some(0.45)).unwrap();
        assert_eq!(sel.selected_indices(), &[1usize]);
        let x = array![[1.0, 2.0, 3.0]];
        let out = sel.transform(&x).unwrap();
        assert_eq!(out.ncols(), 1);
        assert_abs_diff_eq!(out[[0, 0]], 2.0, epsilon = 1e-15);
    }

    #[test]
    fn test_select_from_model_threshold_is_inclusive() {
        // Selection uses `>=`: a feature whose importance equals the threshold
        // exactly is KEPT (0.5 >= 0.5). sklearn likewise only drops features
        // with `scores < threshold`. All values are exactly representable in
        // binary floating point, so the boundary comparison is exact.
        let importances = array![0.25, 0.5, 0.75];
        let sel = SelectFromModel::<f64>::new_from_importances(&importances, Some(0.5)).unwrap();
        assert_eq!(sel.selected_indices(), &[1usize, 2]);
        let x = array![[1.0, 2.0, 3.0]];
        let out = sel.transform(&x).unwrap();
        assert_eq!(out.ncols(), 2);
        assert_abs_diff_eq!(out[[0, 0]], 2.0, epsilon = 1e-15);
        assert_abs_diff_eq!(out[[0, 1]], 3.0, epsilon = 1e-15);
    }

    #[test]
    fn test_select_from_model_all_selected_when_threshold_zero() {
        let importances = array![0.1, 0.2, 0.3];
        let sel = SelectFromModel::<f64>::new_from_importances(&importances, Some(0.0)).unwrap();
        assert_eq!(sel.selected_indices().len(), 3);
    }

    #[test]
    fn test_select_from_model_none_selected_when_threshold_high() {
        let importances = array![0.1, 0.2, 0.3];
        let sel = SelectFromModel::<f64>::new_from_importances(&importances, Some(1.0)).unwrap();
        assert_eq!(sel.selected_indices().len(), 0);
        let x = array![[1.0, 2.0, 3.0]];
        let out = sel.transform(&x).unwrap();
        assert_eq!(out.ncols(), 0);
    }

    #[test]
    fn test_select_from_model_empty_importances_error() {
        let importances: Array1<f64> = Array1::zeros(0);
        assert!(SelectFromModel::<f64>::new_from_importances(&importances, None).is_err());
    }

    #[test]
    fn test_select_from_model_shape_mismatch_on_transform() {
        let importances = array![0.3, 0.7];
        let sel = SelectFromModel::<f64>::new_from_importances(&importances, None).unwrap();
        let x_bad = array![[1.0, 2.0, 3.0]]; // 3 cols, but 2 features expected
        assert!(sel.transform(&x_bad).is_err());
    }

    #[test]
    fn test_select_from_model_threshold_accessor() {
        let importances = array![0.3, 0.7];
        let sel = SelectFromModel::<f64>::new_from_importances(&importances, Some(0.5)).unwrap();
        assert_abs_diff_eq!(sel.threshold(), 0.5, epsilon = 1e-15);
    }

    #[test]
    fn test_select_from_model_pipeline_integration() {
        use ferrolearn_core::pipeline::PipelineTransformer;
        let importances = array![0.1, 0.9];
        let sel = SelectFromModel::<f64>::new_from_importances(&importances, None).unwrap();
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let y = ndarray::array![0.0, 1.0];
        let fitted_box = sel.fit_pipeline(&x, &y).unwrap();
        let out = fitted_box.transform_pipeline(&x).unwrap();
        // Mean importance = 0.5; only feature 1 (0.9 >= 0.5) kept
        assert_eq!(out.ncols(), 1);
        assert_abs_diff_eq!(out[[0, 0]], 2.0, epsilon = 1e-15);
    }

    #[test]
    fn test_select_from_model_importances_accessor() {
        let importances = array![0.2, 0.8];
        let sel = SelectFromModel::<f64>::new_from_importances(&importances, None).unwrap();
        assert_abs_diff_eq!(sel.importances()[0], 0.2, epsilon = 1e-15);
        assert_abs_diff_eq!(sel.importances()[1], 0.8, epsilon = 1e-15);
    }
}