Skip to main content

ferrolearn_preprocess/
feature_selection.rs

1//! Feature selection transformers.
2//!
3//! This module provides three feature selection strategies:
4//!
5//! - [`VarianceThreshold`] — remove features whose variance falls below a
6//!   configurable threshold (default 0.0 removes zero-variance features).
7//! - [`SelectKBest`] — keep the *K* features with the highest ANOVA F-scores
8//!   computed against a class label vector.
9//! - [`SelectFromModel`] — keep features whose importance weight (provided by
10//!   a previously fitted model) exceeds a configurable threshold.
11//!
12//! All three implement the standard ferrolearn `Fit` / `Transform` pattern
13//! and integrate with the dynamic [`ferrolearn_core::pipeline::Pipeline`].
14
15use ferrolearn_core::error::FerroError;
16use ferrolearn_core::pipeline::{FittedPipelineTransformer, PipelineTransformer};
17use ferrolearn_core::traits::{Fit, Transform};
18use ndarray::{Array1, Array2};
19use num_traits::Float;
20
21// ---------------------------------------------------------------------------
22// Shared helper: collect selected columns
23// ---------------------------------------------------------------------------
24
25/// Build a new `Array2<F>` containing only the columns listed in `indices`.
26///
27/// Columns are emitted in the order they appear in `indices`.
28fn select_columns<F: Float>(x: &Array2<F>, indices: &[usize]) -> Array2<F> {
29    let nrows = x.nrows();
30    let ncols = indices.len();
31    if ncols == 0 {
32        // Return empty matrix with correct row count
33        return Array2::zeros((nrows, 0));
34    }
35    let mut out = Array2::zeros((nrows, ncols));
36    for (new_j, &old_j) in indices.iter().enumerate() {
37        for i in 0..nrows {
38            out[[i, new_j]] = x[[i, old_j]];
39        }
40    }
41    out
42}
43
44// ===========================================================================
45// VarianceThreshold
46// ===========================================================================
47
/// An unfitted variance-threshold feature selector.
///
/// During fitting the population variance (divisor `n`) of every column is
/// computed.  Columns whose variance is *less than or equal to* the configured
/// threshold are discarded during transformation; only columns whose variance
/// is strictly greater than the threshold are kept.
///
/// NaN values are **not** handled specially: a column containing NaN has a
/// NaN variance, which never compares greater than the threshold, so that
/// column is discarded.  Use an imputer upstream if needed.
///
/// The default threshold is `0.0`, which removes features with zero variance
/// (i.e. constant columns).
///
/// # Examples
///
/// ```
/// use ferrolearn_preprocess::feature_selection::VarianceThreshold;
/// use ferrolearn_core::traits::{Fit, Transform};
/// use ndarray::array;
///
/// let sel = VarianceThreshold::<f64>::new(0.0);
/// // Column 1 is constant — will be removed
/// let x = array![[1.0, 7.0], [2.0, 7.0], [3.0, 7.0]];
/// let fitted = sel.fit(&x, &()).unwrap();
/// let out = fitted.transform(&x).unwrap();
/// assert_eq!(out.ncols(), 1);
/// ```
#[derive(Debug, Clone)]
pub struct VarianceThreshold<F> {
    /// Variance cut-off: features whose variance is less than or equal to
    /// this value are removed.
    threshold: F,
}
77
78impl<F: Float + Send + Sync + 'static> VarianceThreshold<F> {
79    /// Create a new `VarianceThreshold` with the given threshold.
80    ///
81    /// Pass `F::zero()` (the default) to remove only constant features.
82    ///
83    /// # Errors
84    ///
85    /// Returns [`FerroError::InvalidParameter`] if `threshold` is negative.
86    pub fn new(threshold: F) -> Self {
87        Self { threshold }
88    }
89
90    /// Return the variance threshold.
91    #[must_use]
92    pub fn threshold(&self) -> F {
93        self.threshold
94    }
95}
96
97impl<F: Float + Send + Sync + 'static> Default for VarianceThreshold<F> {
98    fn default() -> Self {
99        Self::new(F::zero())
100    }
101}
102
103// ---------------------------------------------------------------------------
104// FittedVarianceThreshold
105// ---------------------------------------------------------------------------
106
107/// A fitted variance-threshold selector holding the selected column indices
108/// and the per-column variances observed during fitting.
109///
110/// Created by calling [`Fit::fit`] on a [`VarianceThreshold`].
111#[derive(Debug, Clone)]
112pub struct FittedVarianceThreshold<F> {
113    /// Column indices (into the *original* feature matrix) that were selected.
114    selected_indices: Vec<usize>,
115    /// Per-column population variances computed during fitting.
116    variances: Array1<F>,
117}
118
119impl<F: Float + Send + Sync + 'static> FittedVarianceThreshold<F> {
120    /// Return the indices of the selected columns.
121    #[must_use]
122    pub fn selected_indices(&self) -> &[usize] {
123        &self.selected_indices
124    }
125
126    /// Return the per-column variances computed during fitting.
127    #[must_use]
128    pub fn variances(&self) -> &Array1<F> {
129        &self.variances
130    }
131}
132
133// ---------------------------------------------------------------------------
134// Trait implementations — VarianceThreshold
135// ---------------------------------------------------------------------------
136
137impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, ()> for VarianceThreshold<F> {
138    type Fitted = FittedVarianceThreshold<F>;
139    type Error = FerroError;
140
141    /// Fit by computing per-column population variances.
142    ///
143    /// # Errors
144    ///
145    /// Returns [`FerroError::InsufficientSamples`] if the input has zero rows.
146    /// Returns [`FerroError::InvalidParameter`] if the threshold is negative.
147    fn fit(&self, x: &Array2<F>, _y: &()) -> Result<FittedVarianceThreshold<F>, FerroError> {
148        if self.threshold < F::zero() {
149            return Err(FerroError::InvalidParameter {
150                name: "threshold".into(),
151                reason: "variance threshold must be non-negative".into(),
152            });
153        }
154        let n_samples = x.nrows();
155        if n_samples == 0 {
156            return Err(FerroError::InsufficientSamples {
157                required: 1,
158                actual: 0,
159                context: "VarianceThreshold::fit".into(),
160            });
161        }
162
163        let n = F::from(n_samples).unwrap_or(F::one());
164        let n_features = x.ncols();
165        let mut variances = Array1::zeros(n_features);
166        let mut selected_indices = Vec::new();
167
168        for j in 0..n_features {
169            let col = x.column(j);
170            let mean = col.iter().copied().fold(F::zero(), |acc, v| acc + v) / n;
171            let var = col
172                .iter()
173                .copied()
174                .map(|v| (v - mean) * (v - mean))
175                .fold(F::zero(), |acc, v| acc + v)
176                / n;
177            variances[j] = var;
178            if var > self.threshold {
179                selected_indices.push(j);
180            }
181        }
182
183        Ok(FittedVarianceThreshold {
184            selected_indices,
185            variances,
186        })
187    }
188}
189
190impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedVarianceThreshold<F> {
191    type Output = Array2<F>;
192    type Error = FerroError;
193
194    /// Return a matrix containing only the selected (high-variance) columns.
195    ///
196    /// # Errors
197    ///
198    /// Returns [`FerroError::ShapeMismatch`] if the number of columns differs
199    /// from the number of features seen during fitting.
200    fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
201        let n_original = self.variances.len();
202        if x.ncols() != n_original {
203            return Err(FerroError::ShapeMismatch {
204                expected: vec![x.nrows(), n_original],
205                actual: vec![x.nrows(), x.ncols()],
206                context: "FittedVarianceThreshold::transform".into(),
207            });
208        }
209        Ok(select_columns(x, &self.selected_indices))
210    }
211}
212
213// ---------------------------------------------------------------------------
214// Pipeline integration — VarianceThreshold (f64 specialisation)
215// ---------------------------------------------------------------------------
216
217impl PipelineTransformer for VarianceThreshold<f64> {
218    /// Fit using the pipeline interface; `y` is ignored.
219    ///
220    /// # Errors
221    ///
222    /// Propagates errors from [`Fit::fit`].
223    fn fit_pipeline(
224        &self,
225        x: &Array2<f64>,
226        _y: &Array1<f64>,
227    ) -> Result<Box<dyn FittedPipelineTransformer>, FerroError> {
228        let fitted = self.fit(x, &())?;
229        Ok(Box::new(fitted))
230    }
231}
232
233impl FittedPipelineTransformer for FittedVarianceThreshold<f64> {
234    /// Transform using the pipeline interface.
235    ///
236    /// # Errors
237    ///
238    /// Propagates errors from [`Transform::transform`].
239    fn transform_pipeline(&self, x: &Array2<f64>) -> Result<Array2<f64>, FerroError> {
240        self.transform(x)
241    }
242}
243
244// ===========================================================================
245// SelectKBest
246// ===========================================================================
247
/// The scoring functions that [`SelectKBest`] can rank features with.
///
/// ANOVA F-value scoring is the only variant at present.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ScoreFunc {
    /// ANOVA F-value: the between-class variance divided by the within-class
    /// variance.
    ///
    /// Comparable to `f_classif` in scikit-learn.
    FClassif,
}
258
259/// An unfitted K-best feature selector.
260///
261/// Requires class labels (`Array1<usize>`) at fit time to compute per-feature
262/// ANOVA F-scores.  The top *K* features (by score) are retained.
263///
264/// # Examples
265///
266/// ```
267/// use ferrolearn_preprocess::feature_selection::{SelectKBest, ScoreFunc};
268/// use ferrolearn_core::traits::{Fit, Transform};
269/// use ndarray::{array, Array1};
270///
271/// let sel = SelectKBest::<f64>::new(1, ScoreFunc::FClassif);
272/// let x = array![[1.0, 10.0], [1.0, 20.0], [2.0, 10.0], [2.0, 20.0]];
273/// let y: Array1<usize> = array![0, 0, 1, 1];
274/// let fitted = sel.fit(&x, &y).unwrap();
275/// let out = fitted.transform(&x).unwrap();
276/// assert_eq!(out.ncols(), 1);
277/// ```
278#[derive(Debug, Clone)]
279pub struct SelectKBest<F> {
280    /// Number of top-scoring features to keep.
281    k: usize,
282    /// The scoring function to use.
283    score_func: ScoreFunc,
284    _marker: std::marker::PhantomData<F>,
285}
286
287impl<F: Float + Send + Sync + 'static> SelectKBest<F> {
288    /// Create a new `SelectKBest` selector.
289    ///
290    /// # Parameters
291    ///
292    /// - `k` — the number of features to retain.
293    /// - `score_func` — the scoring function; currently only
294    ///   [`ScoreFunc::FClassif`] is available.
295    #[must_use]
296    pub fn new(k: usize, score_func: ScoreFunc) -> Self {
297        Self {
298            k,
299            score_func,
300            _marker: std::marker::PhantomData,
301        }
302    }
303
304    /// Return *k*.
305    #[must_use]
306    pub fn k(&self) -> usize {
307        self.k
308    }
309
310    /// Return the score function.
311    #[must_use]
312    pub fn score_func(&self) -> ScoreFunc {
313        self.score_func
314    }
315}
316
317// ---------------------------------------------------------------------------
318// FittedSelectKBest
319// ---------------------------------------------------------------------------
320
321/// A fitted K-best selector holding per-feature scores and selected indices.
322///
323/// Created by calling [`Fit::fit`] on a [`SelectKBest`].
324#[derive(Debug, Clone)]
325pub struct FittedSelectKBest<F> {
326    /// The original number of features (used for shape checking on transform).
327    n_features_in: usize,
328    /// Per-feature ANOVA F-scores computed during fitting.
329    scores: Array1<F>,
330    /// Indices of the selected columns, sorted in original column order.
331    selected_indices: Vec<usize>,
332}
333
334impl<F: Float + Send + Sync + 'static> FittedSelectKBest<F> {
335    /// Return the per-feature F-scores computed during fitting.
336    #[must_use]
337    pub fn scores(&self) -> &Array1<F> {
338        &self.scores
339    }
340
341    /// Return the indices of the selected columns.
342    #[must_use]
343    pub fn selected_indices(&self) -> &[usize] {
344        &self.selected_indices
345    }
346}
347
348// ---------------------------------------------------------------------------
349// ANOVA F-value helper
350// ---------------------------------------------------------------------------
351
352/// Compute per-feature ANOVA F-scores given a feature matrix `x` and integer
353/// class labels `y`.
354///
355/// For each feature column the F-statistic is:
356///
357/// ```text
358/// F = (between-class variance / (n_classes - 1))
359///   / (within-class variance  / (n_samples - n_classes))
360/// ```
361///
362/// Features for which the within-class variance is zero (perfectly separable)
363/// get an F-score of `F::infinity()`.  Features that have zero between-class
364/// variance get a score of `F::zero()`.
365fn anova_f_scores<F: Float>(x: &Array2<F>, y: &Array1<usize>) -> Vec<F> {
366    let n_samples = x.nrows();
367    let n_features = x.ncols();
368
369    // Collect unique classes and build per-class row-index lists.
370    let mut class_indices: std::collections::HashMap<usize, Vec<usize>> =
371        std::collections::HashMap::new();
372    for (i, &label) in y.iter().enumerate() {
373        class_indices.entry(label).or_default().push(i);
374    }
375    let n_classes = class_indices.len();
376
377    let mut scores = Vec::with_capacity(n_features);
378
379    for j in 0..n_features {
380        let col = x.column(j);
381
382        // Overall mean of this feature
383        let grand_mean =
384            col.iter().copied().fold(F::zero(), |acc, v| acc + v) / F::from(n_samples).unwrap();
385
386        // Between-class sum of squares: sum_k n_k * (mean_k - grand_mean)^2
387        let mut ss_between = F::zero();
388        // Within-class sum of squares: sum_k sum_{i in k} (x_i - mean_k)^2
389        let mut ss_within = F::zero();
390
391        for rows in class_indices.values() {
392            let n_k = F::from(rows.len()).unwrap();
393            let class_mean = rows
394                .iter()
395                .map(|&i| col[i])
396                .fold(F::zero(), |acc, v| acc + v)
397                / n_k;
398            let diff = class_mean - grand_mean;
399            ss_between = ss_between + n_k * diff * diff;
400            for &i in rows {
401                let d = col[i] - class_mean;
402                ss_within = ss_within + d * d;
403            }
404        }
405
406        let df_between = F::from(n_classes.saturating_sub(1)).unwrap();
407        let df_within = F::from(n_samples.saturating_sub(n_classes)).unwrap();
408
409        let f = if df_between == F::zero() || df_within == F::zero() {
410            F::zero()
411        } else {
412            let ms_between = ss_between / df_between;
413            let ms_within = ss_within / df_within;
414            if ms_within == F::zero() {
415                F::infinity()
416            } else {
417                ms_between / ms_within
418            }
419        };
420
421        scores.push(f);
422    }
423
424    scores
425}
426
427// ---------------------------------------------------------------------------
428// Trait implementations — SelectKBest
429// ---------------------------------------------------------------------------
430
431impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array1<usize>> for SelectKBest<F> {
432    type Fitted = FittedSelectKBest<F>;
433    type Error = FerroError;
434
435    /// Fit by computing per-feature ANOVA F-scores against the class labels.
436    ///
437    /// # Errors
438    ///
439    /// - [`FerroError::InsufficientSamples`] if the input has zero rows.
440    /// - [`FerroError::InvalidParameter`] if `k` exceeds the number of features.
441    /// - [`FerroError::ShapeMismatch`] if `x` and `y` have different numbers
442    ///   of rows.
443    fn fit(&self, x: &Array2<F>, y: &Array1<usize>) -> Result<FittedSelectKBest<F>, FerroError> {
444        let n_samples = x.nrows();
445        if n_samples == 0 {
446            return Err(FerroError::InsufficientSamples {
447                required: 1,
448                actual: 0,
449                context: "SelectKBest::fit".into(),
450            });
451        }
452        if y.len() != n_samples {
453            return Err(FerroError::ShapeMismatch {
454                expected: vec![n_samples],
455                actual: vec![y.len()],
456                context: "SelectKBest::fit — y must have the same length as x has rows".into(),
457            });
458        }
459        let n_features = x.ncols();
460        if self.k > n_features {
461            return Err(FerroError::InvalidParameter {
462                name: "k".into(),
463                reason: format!(
464                    "k ({}) cannot exceed the number of features ({})",
465                    self.k, n_features
466                ),
467            });
468        }
469
470        let raw_scores = match self.score_func {
471            ScoreFunc::FClassif => anova_f_scores(x, y),
472        };
473
474        let scores = Array1::from_vec(raw_scores.clone());
475
476        // Determine the top-k indices (stable: break ties by preferring the
477        // lower column index so results are deterministic).
478        let mut ranked: Vec<usize> = (0..n_features).collect();
479        ranked.sort_by(|&a, &b| {
480            raw_scores[b]
481                .partial_cmp(&raw_scores[a])
482                .unwrap_or(std::cmp::Ordering::Equal)
483                // Tie: keep lower column index
484                .then(a.cmp(&b))
485        });
486
487        let mut selected_indices: Vec<usize> = ranked[..self.k].to_vec();
488        // Return in original column order for a stable output layout
489        selected_indices.sort_unstable();
490
491        Ok(FittedSelectKBest {
492            n_features_in: n_features,
493            scores,
494            selected_indices,
495        })
496    }
497}
498
499impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedSelectKBest<F> {
500    type Output = Array2<F>;
501    type Error = FerroError;
502
503    /// Return a matrix containing only the K selected columns.
504    ///
505    /// # Errors
506    ///
507    /// Returns [`FerroError::ShapeMismatch`] if the number of columns differs
508    /// from the number of features seen during fitting.
509    fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
510        if x.ncols() != self.n_features_in {
511            return Err(FerroError::ShapeMismatch {
512                expected: vec![x.nrows(), self.n_features_in],
513                actual: vec![x.nrows(), x.ncols()],
514                context: "FittedSelectKBest::transform".into(),
515            });
516        }
517        Ok(select_columns(x, &self.selected_indices))
518    }
519}
520
521// ---------------------------------------------------------------------------
522// Pipeline integration — SelectKBest (f64 specialisation)
523//
524// NOTE: The pipeline interface uses a *fixed* `y = Array1<f64>`, so we cannot
525// use the actual class-label vector from the pipeline.  We therefore refit
526// using the pipeline `y` converted to `usize` labels by rounding.
527// ---------------------------------------------------------------------------
528
529impl PipelineTransformer for SelectKBest<f64> {
530    /// Fit using the pipeline interface.
531    ///
532    /// The continuous `y` labels are rounded to `usize` class indices.
533    ///
534    /// # Errors
535    ///
536    /// Propagates errors from [`Fit::fit`].
537    fn fit_pipeline(
538        &self,
539        x: &Array2<f64>,
540        y: &Array1<f64>,
541    ) -> Result<Box<dyn FittedPipelineTransformer>, FerroError> {
542        let y_usize: Array1<usize> = y.mapv(|v| v.round() as usize);
543        let fitted = self.fit(x, &y_usize)?;
544        Ok(Box::new(fitted))
545    }
546}
547
548impl FittedPipelineTransformer for FittedSelectKBest<f64> {
549    /// Transform using the pipeline interface.
550    ///
551    /// # Errors
552    ///
553    /// Propagates errors from [`Transform::transform`].
554    fn transform_pipeline(&self, x: &Array2<f64>) -> Result<Array2<f64>, FerroError> {
555        self.transform(x)
556    }
557}
558
559// ===========================================================================
560// SelectFromModel
561// ===========================================================================
562
563/// A feature selector driven by external feature-importance weights.
564///
565/// The importance vector is typically obtained from a fitted model (e.g. a
566/// decision-tree model's `feature_importances_` field).  Features whose
567/// importance is *strictly greater than or equal to* the threshold are kept.
568///
569/// The default threshold is the **mean importance** of all features.
570///
571/// # Examples
572///
573/// ```
574/// use ferrolearn_preprocess::feature_selection::SelectFromModel;
575/// use ferrolearn_core::traits::Transform;
576/// use ndarray::array;
577///
578/// let importances = array![0.1, 0.5, 0.4];
579/// let sel = SelectFromModel::<f64>::new_from_importances(&importances, None).unwrap();
580/// let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
581/// let out = sel.transform(&x).unwrap();
582/// // Mean importance = (0.1+0.5+0.4)/3 ≈ 0.333; columns 1 and 2 are kept
583/// assert_eq!(out.ncols(), 2);
584/// ```
585#[derive(Debug, Clone)]
586pub struct SelectFromModel<F> {
587    /// Importance weight for each feature.
588    importances: Array1<F>,
589    /// The threshold: features with importance >= threshold are kept.
590    threshold: F,
591    /// Indices of selected features (original column order).
592    selected_indices: Vec<usize>,
593}
594
595impl<F: Float + Send + Sync + 'static> SelectFromModel<F> {
596    /// Create a `SelectFromModel` from a pre-computed importance vector.
597    ///
598    /// # Parameters
599    ///
600    /// - `importances` — one importance weight per feature.
601    /// - `threshold` — optional explicit threshold; if `None` the mean
602    ///   importance is used.
603    ///
604    /// # Errors
605    ///
606    /// Returns [`FerroError::InvalidParameter`] if `importances` is empty.
607    pub fn new_from_importances(
608        importances: &Array1<F>,
609        threshold: Option<F>,
610    ) -> Result<Self, FerroError> {
611        let n = importances.len();
612        if n == 0 {
613            return Err(FerroError::InvalidParameter {
614                name: "importances".into(),
615                reason: "importance vector must not be empty".into(),
616            });
617        }
618
619        let thr = threshold.unwrap_or_else(|| {
620            importances
621                .iter()
622                .copied()
623                .fold(F::zero(), |acc, v| acc + v)
624                / F::from(n).unwrap_or(F::one())
625        });
626
627        let selected_indices: Vec<usize> = importances
628            .iter()
629            .enumerate()
630            .filter(|&(_, &imp)| imp >= thr)
631            .map(|(j, _)| j)
632            .collect();
633
634        Ok(Self {
635            importances: importances.clone(),
636            threshold: thr,
637            selected_indices,
638        })
639    }
640
641    /// Return the threshold used to select features.
642    #[must_use]
643    pub fn threshold(&self) -> F {
644        self.threshold
645    }
646
647    /// Return the importance vector supplied at construction time.
648    #[must_use]
649    pub fn importances(&self) -> &Array1<F> {
650        &self.importances
651    }
652
653    /// Return the indices of the selected columns.
654    #[must_use]
655    pub fn selected_indices(&self) -> &[usize] {
656        &self.selected_indices
657    }
658}
659
660// ---------------------------------------------------------------------------
661// Trait implementation — SelectFromModel
662// ---------------------------------------------------------------------------
663
664impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for SelectFromModel<F> {
665    type Output = Array2<F>;
666    type Error = FerroError;
667
668    /// Return a matrix containing only the columns whose importance exceeds
669    /// the threshold.
670    ///
671    /// # Errors
672    ///
673    /// Returns [`FerroError::ShapeMismatch`] if the number of columns differs
674    /// from the length of the importance vector supplied at construction.
675    fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
676        let n_features = self.importances.len();
677        if x.ncols() != n_features {
678            return Err(FerroError::ShapeMismatch {
679                expected: vec![x.nrows(), n_features],
680                actual: vec![x.nrows(), x.ncols()],
681                context: "SelectFromModel::transform".into(),
682            });
683        }
684        Ok(select_columns(x, &self.selected_indices))
685    }
686}
687
688// ---------------------------------------------------------------------------
689// Pipeline integration — SelectFromModel (f64 specialisation)
690//
691// `SelectFromModel` is already "fitted" (importance weights are provided at
692// construction time), so `fit_pipeline` merely boxes `self.clone()`.
693// ---------------------------------------------------------------------------
694
695impl PipelineTransformer for SelectFromModel<f64> {
696    /// Clone the selector and box it as a fitted pipeline transformer.
697    ///
698    /// # Errors
699    ///
700    /// This implementation never fails.
701    fn fit_pipeline(
702        &self,
703        _x: &Array2<f64>,
704        _y: &Array1<f64>,
705    ) -> Result<Box<dyn FittedPipelineTransformer>, FerroError> {
706        Ok(Box::new(self.clone()))
707    }
708}
709
710impl FittedPipelineTransformer for SelectFromModel<f64> {
711    /// Transform using the pipeline interface.
712    ///
713    /// # Errors
714    ///
715    /// Propagates errors from [`Transform::transform`].
716    fn transform_pipeline(&self, x: &Array2<f64>) -> Result<Array2<f64>, FerroError> {
717        self.transform(x)
718    }
719}
720
721// ===========================================================================
722// Tests
723// ===========================================================================
724
725#[cfg(test)]
726mod tests {
727    use super::*;
728    use approx::assert_abs_diff_eq;
729    use ndarray::array;
730
731    // ========================================================================
732    // VarianceThreshold tests
733    // ========================================================================
734
735    #[test]
736    fn test_variance_threshold_removes_constant_column() {
737        let sel = VarianceThreshold::<f64>::new(0.0);
738        // Column 1 is constant (all 7.0)
739        let x = array![[1.0, 7.0], [2.0, 7.0], [3.0, 7.0]];
740        let fitted = sel.fit(&x, &()).unwrap();
741        assert_eq!(fitted.selected_indices(), &[0usize]);
742        let out = fitted.transform(&x).unwrap();
743        assert_eq!(out.ncols(), 1);
744        // Column 0 values preserved
745        assert_abs_diff_eq!(out[[0, 0]], 1.0, epsilon = 1e-15);
746        assert_abs_diff_eq!(out[[1, 0]], 2.0, epsilon = 1e-15);
747    }
748
749    #[test]
750    fn test_variance_threshold_keeps_all_when_above() {
751        let sel = VarianceThreshold::<f64>::new(0.0);
752        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
753        let fitted = sel.fit(&x, &()).unwrap();
754        assert_eq!(fitted.selected_indices().len(), 2);
755        let out = fitted.transform(&x).unwrap();
756        assert_eq!(out.ncols(), 2);
757    }
758
759    #[test]
760    fn test_variance_threshold_custom_threshold() {
761        let sel = VarianceThreshold::<f64>::new(1.5);
762        // Column 0: values [1,2,3], variance = 2/3 ≈ 0.667 → removed
763        // Column 1: values [10,20,30], variance = 200/3 ≈ 66.7 → kept
764        let x = array![[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]];
765        let fitted = sel.fit(&x, &()).unwrap();
766        assert_eq!(fitted.selected_indices(), &[1usize]);
767        let out = fitted.transform(&x).unwrap();
768        assert_eq!(out.ncols(), 1);
769    }
770
771    #[test]
772    fn test_variance_threshold_stores_variances() {
773        let sel = VarianceThreshold::<f64>::default();
774        let x = array![[0.0], [0.0], [0.0]]; // constant → var = 0
775        let fitted = sel.fit(&x, &()).unwrap();
776        assert_abs_diff_eq!(fitted.variances()[0], 0.0, epsilon = 1e-15);
777    }
778
779    #[test]
780    fn test_variance_threshold_zero_rows_error() {
781        let sel = VarianceThreshold::<f64>::new(0.0);
782        let x: Array2<f64> = Array2::zeros((0, 2));
783        assert!(sel.fit(&x, &()).is_err());
784    }
785
786    #[test]
787    fn test_variance_threshold_negative_threshold_error() {
788        let sel = VarianceThreshold::<f64>::new(-0.1);
789        let x = array![[1.0], [2.0]];
790        assert!(sel.fit(&x, &()).is_err());
791    }
792
793    #[test]
794    fn test_variance_threshold_shape_mismatch_on_transform() {
795        let sel = VarianceThreshold::<f64>::new(0.0);
796        let x_train = array![[1.0, 2.0], [3.0, 4.0]];
797        let fitted = sel.fit(&x_train, &()).unwrap();
798        let x_bad = array![[1.0, 2.0, 3.0]];
799        assert!(fitted.transform(&x_bad).is_err());
800    }
801
802    #[test]
803    fn test_variance_threshold_all_constant_columns() {
804        let sel = VarianceThreshold::<f64>::new(0.0);
805        let x = array![[5.0, 3.0], [5.0, 3.0], [5.0, 3.0]];
806        let fitted = sel.fit(&x, &()).unwrap();
807        // Both columns are constant: both removed
808        assert_eq!(fitted.selected_indices().len(), 0);
809        let out = fitted.transform(&x).unwrap();
810        assert_eq!(out.ncols(), 0);
811        assert_eq!(out.nrows(), 3);
812    }
813
814    #[test]
815    fn test_variance_threshold_pipeline_integration() {
816        use ferrolearn_core::pipeline::PipelineTransformer;
817        let sel = VarianceThreshold::<f64>::new(0.0);
818        let x = array![[1.0, 7.0], [2.0, 7.0], [3.0, 7.0]];
819        let y = ndarray::array![0.0, 1.0, 0.0];
820        let fitted_box = sel.fit_pipeline(&x, &y).unwrap();
821        let out = fitted_box.transform_pipeline(&x).unwrap();
822        assert_eq!(out.ncols(), 1);
823    }
824
825    #[test]
826    fn test_variance_threshold_f32() {
827        let sel = VarianceThreshold::<f32>::new(0.0f32);
828        let x: Array2<f32> = array![[1.0f32, 5.0], [2.0, 5.0], [3.0, 5.0]];
829        let fitted = sel.fit(&x, &()).unwrap();
830        assert_eq!(fitted.selected_indices(), &[0usize]);
831    }
832
833    // ========================================================================
834    // SelectKBest tests
835    // ========================================================================
836
837    #[test]
838    fn test_select_k_best_selects_highest_scoring_feature() {
839        // Column 0 is 1.0 for class 0 and 10.0 for class 1; column 1 takes the
840        // same values (5.0 and 6.0) in both classes, so it cannot tell them apart.
841        let features = array![[1.0, 5.0], [1.0, 6.0], [10.0, 5.0], [10.0, 6.0]];
842        let labels: Array1<usize> = array![0, 0, 1, 1];
843        let selector = SelectKBest::<f64>::new(1, ScoreFunc::FClassif);
844        let fitted = selector.fit(&features, &labels).unwrap();
845        assert_eq!(fitted.selected_indices(), &[0_usize]);
846        // With k = 1 the transformed matrix has a single column.
847        assert_eq!(fitted.transform(&features).unwrap().ncols(), 1);
848    }
849
850    #[test]
851    fn test_select_k_best_k_equals_n_features_keeps_all() {
852        // k matches the number of columns, so both features are expected back.
853        let features = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
854        let labels: Array1<usize> = array![0, 1, 0];
855        let selector = SelectKBest::<f64>::new(2, ScoreFunc::FClassif);
856        let fitted = selector.fit(&features, &labels).unwrap();
857        assert_eq!(fitted.selected_indices().len(), 2);
858        assert_eq!(fitted.transform(&features).unwrap().ncols(), 2);
859    }
860
861    #[test]
862    fn test_select_k_best_scores_stored() {
863        // Two input columns, k = 1: the fitted selector must still expose 2 scores.
864        let data = array![[1.0, 2.0], [1.0, 4.0]];
865        let labels: Array1<usize> = array![0, 1];
866        let fitted = SelectKBest::<f64>::new(1, ScoreFunc::FClassif).fit(&data, &labels).unwrap();
867        assert_eq!(fitted.scores().len(), 2);
868    }
869
870    #[test]
871    fn test_select_k_best_zero_rows_error() {
872        // Fitting on a 0 x 3 matrix with an empty label vector must be rejected.
873        let empty_x = Array2::<f64>::zeros((0, 3));
874        let empty_y = Array1::<usize>::zeros(0);
875        assert!(SelectKBest::<f64>::new(1, ScoreFunc::FClassif).fit(&empty_x, &empty_y).is_err());
876    }
877
878    #[test]
879    fn test_select_k_best_k_exceeds_n_features_error() {
880        // Asking for 5 features when the matrix has only 2 columns must fail.
881        let features = array![[1.0, 2.0], [3.0, 4.0]];
882        let labels: Array1<usize> = array![0, 1];
883        assert!(SelectKBest::<f64>::new(5, ScoreFunc::FClassif).fit(&features, &labels).is_err());
884    }
885
886    #[test]
887    fn test_select_k_best_y_length_mismatch_error() {
888        let selector = SelectKBest::<f64>::new(1, ScoreFunc::FClassif);
889        let samples = array![[1.0, 2.0], [3.0, 4.0]];
890        let labels: Array1<usize> = array![0]; // one label for two samples
891        assert!(selector.fit(&samples, &labels).is_err());
892    }
893
894    #[test]
895    fn test_select_k_best_shape_mismatch_on_transform() {
896        let train = array![[1.0, 2.0], [3.0, 4.0]];
897        let labels: Array1<usize> = array![0, 1];
898        let selector = SelectKBest::<f64>::new(1, ScoreFunc::FClassif);
899        let fitted = selector.fit(&train, &labels).unwrap();
900        // Fitted on 2 columns, so a 3-column matrix must be refused.
901        assert!(fitted.transform(&array![[1.0, 2.0, 3.0]]).is_err());
902    }
903
904    #[test]
905    fn test_select_k_best_selected_indices_in_column_order() {
906        // k = 2 keeps both columns; the check below requires the reported
907        // indices to be strictly ascending (each one smaller than the next).
908        let features = array![[1.0, 100.0], [2.0, 200.0]];
909        let labels: Array1<usize> = array![0, 1];
910        let selector = SelectKBest::<f64>::new(2, ScoreFunc::FClassif);
911        let fitted = selector.fit(&features, &labels).unwrap();
912        assert!(fitted.selected_indices().windows(2).all(|pair| pair[0] < pair[1]));
913    }
914
915    #[test]
916    fn test_select_k_best_pipeline_integration() {
917        use ferrolearn_core::pipeline::PipelineTransformer;
918        // The pipeline entry point is handed float labels here (0.0 / 1.0).
919        let features = array![[1.0, 5.0], [1.0, 6.0], [10.0, 5.0], [10.0, 6.0]];
920        let float_labels = ndarray::array![0.0, 0.0, 1.0, 1.0];
921        let selector = SelectKBest::<f64>::new(1, ScoreFunc::FClassif);
922        let stage = selector.fit_pipeline(&features, &float_labels).unwrap();
923        assert_eq!(stage.transform_pipeline(&features).unwrap().ncols(), 1);
924    }
925
926    #[test]
927    fn test_select_k_best_f_score_zero_within_class_variance() {
928        // Each class holds one repeated value (0.0 or 10.0), so the within-class
929        // variance is zero and the feature's F-score is expected to be infinite.
930        let data = array![[0.0], [0.0], [10.0], [10.0]];
931        let labels: Array1<usize> = array![0, 0, 1, 1];
932        let fitted = SelectKBest::<f64>::new(1, ScoreFunc::FClassif).fit(&data, &labels).unwrap();
933        assert!(fitted.scores()[0].is_infinite());
934    }
935
936    // ========================================================================
937    // SelectFromModel tests
938    // ========================================================================
939
940    #[test]
941    fn test_select_from_model_mean_threshold() {
942        // No explicit threshold is passed. The mean of the weights is
943        // (0.1 + 0.5 + 0.4) / 3 ≈ 0.333, which 0.5 and 0.4 reach but 0.1 does not.
944        let weights = array![0.1, 0.5, 0.4];
945        let selector = SelectFromModel::<f64>::new_from_importances(&weights, None).unwrap();
946        assert_eq!(selector.selected_indices(), &[1_usize, 2]);
947        let kept = selector.transform(&array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]).unwrap();
948        assert_eq!(kept.ncols(), 2);
949        // Row 0 of the output holds columns 1 and 2 of the first input row.
950        assert_abs_diff_eq!(kept[[0, 0]], 2.0, epsilon = 1e-15);
951        assert_abs_diff_eq!(kept[[0, 1]], 3.0, epsilon = 1e-15);
952    }
953
954    #[test]
955    fn test_select_from_model_explicit_threshold() {
956        // A cut-off of 0.45 sits above 0.1 and 0.4, so only index 1 (0.5) passes.
957        let weights = array![0.1, 0.5, 0.4];
958        let selector = SelectFromModel::<f64>::new_from_importances(&weights, Some(0.45)).unwrap();
959        assert_eq!(selector.selected_indices(), &[1_usize]);
960        // The lone output column carries the middle value of the input row.
961        let kept = selector.transform(&array![[1.0, 2.0, 3.0]]).unwrap();
962        assert_eq!(kept.ncols(), 1);
963        assert_abs_diff_eq!(kept[[0, 0]], 2.0, epsilon = 1e-15);
964    }
965
966    #[test]
967    fn test_select_from_model_all_selected_when_threshold_zero() {
968        let weights = array![0.1, 0.2, 0.3]; // every weight is above the 0.0 cut-off
969        let selector = SelectFromModel::<f64>::new_from_importances(&weights, Some(0.0)).unwrap();
970        assert_eq!(selector.selected_indices().len(), weights.len());
971    }
972
973    #[test]
974    fn test_select_from_model_none_selected_when_threshold_high() {
975        // All weights are below 1.0, so nothing is selected and no column survives.
976        let weights = array![0.1, 0.2, 0.3];
977        let selector = SelectFromModel::<f64>::new_from_importances(&weights, Some(1.0)).unwrap();
978        assert!(selector.selected_indices().is_empty());
979        let kept = selector.transform(&array![[1.0, 2.0, 3.0]]).unwrap();
980        assert_eq!(kept.ncols(), 0);
981    }
982
983    #[test]
984    fn test_select_from_model_empty_importances_error() {
985        let no_weights = Array1::<f64>::zeros(0); // zero-length importance vector
986        assert!(SelectFromModel::<f64>::new_from_importances(&no_weights, None).is_err());
987    }
988
989    #[test]
990    fn test_select_from_model_shape_mismatch_on_transform() {
991        // Two importance weights were supplied, so a three-column input must fail.
992        let weights = array![0.3, 0.7];
993        let selector = SelectFromModel::<f64>::new_from_importances(&weights, None).unwrap();
994        assert!(selector.transform(&array![[1.0, 2.0, 3.0]]).is_err());
995    }
996
997    #[test]
998    fn test_select_from_model_threshold_accessor() {
999        let weights = array![0.3, 0.7];
1000        let selector = SelectFromModel::<f64>::new_from_importances(&weights, Some(0.5)).unwrap();
1001        assert_abs_diff_eq!(selector.threshold(), 0.5, epsilon = 1e-15); // the explicit cut-off
1002    }
1003
1004    #[test]
1005    fn test_select_from_model_pipeline_integration() {
1006        use ferrolearn_core::pipeline::PipelineTransformer;
1007        // Weights 0.1 and 0.9 average to 0.5, a bar that only index 1 clears.
1008        let weights = array![0.1, 0.9];
1009        let selector = SelectFromModel::<f64>::new_from_importances(&weights, None).unwrap();
1010        let features = array![[1.0, 2.0], [3.0, 4.0]];
1011        let targets = ndarray::array![0.0, 1.0];
1012        let stage = selector.fit_pipeline(&features, &targets).unwrap();
1013        let kept = stage.transform_pipeline(&features).unwrap();
1014        assert_eq!(kept.ncols(), 1);
1015        assert_abs_diff_eq!(kept[[0, 0]], 2.0, epsilon = 1e-15);
1016    }
1017
1018    #[test]
1019    fn test_select_from_model_importances_accessor() {
1020        let weights = array![0.2, 0.8]; // expected back unchanged from `importances()`
1021        let selector = SelectFromModel::<f64>::new_from_importances(&weights, None).unwrap();
1022        assert_abs_diff_eq!(selector.importances()[0], 0.2, epsilon = 1e-15);
1023        assert_abs_diff_eq!(selector.importances()[1], 0.8, epsilon = 1e-15);
1024    }
1025}