Skip to main content

ferrolearn_preprocess/
feature_selection.rs

1//! Feature selection transformers.
2//!
3//! This module provides three feature selection strategies:
4//!
5//! - [`VarianceThreshold`] — remove features whose variance falls below a
6//!   configurable threshold (default 0.0 removes zero-variance features).
7//! - [`SelectKBest`] — keep the *K* features with the highest ANOVA F-scores
8//!   computed against a class label vector.
9//! - [`SelectFromModel`] — keep features whose importance weight (provided by
10//!   a previously fitted model) exceeds a configurable threshold.
11//!
12//! All three implement the standard ferrolearn `Fit` / `Transform` pattern
13//! and integrate with the dynamic [`ferrolearn_core::pipeline::Pipeline`].
14
15use ferrolearn_core::error::FerroError;
16use ferrolearn_core::pipeline::{FittedPipelineTransformer, PipelineTransformer};
17use ferrolearn_core::traits::{Fit, Transform};
18use ndarray::{Array1, Array2};
19use num_traits::Float;
20
21// ---------------------------------------------------------------------------
22// Shared helper: collect selected columns
23// ---------------------------------------------------------------------------
24
25/// Build a new `Array2<F>` containing only the columns listed in `indices`.
26///
27/// Columns are emitted in the order they appear in `indices`.
28fn select_columns<F: Float>(x: &Array2<F>, indices: &[usize]) -> Array2<F> {
29    let nrows = x.nrows();
30    let ncols = indices.len();
31    if ncols == 0 {
32        // Return empty matrix with correct row count
33        return Array2::zeros((nrows, 0));
34    }
35    let mut out = Array2::zeros((nrows, ncols));
36    for (new_j, &old_j) in indices.iter().enumerate() {
37        for i in 0..nrows {
38            out[[i, new_j]] = x[[i, old_j]];
39        }
40    }
41    out
42}
43
44// ===========================================================================
45// VarianceThreshold
46// ===========================================================================
47
/// An unfitted variance-threshold feature selector.
///
/// During fitting the population variance of every column is computed.
/// Columns whose variance is *less than or equal to* the configured threshold
/// are discarded during transformation; only columns with a variance strictly
/// greater than the threshold are kept.
///
/// NaN values are **not** imputed: a single NaN makes its column's computed
/// variance NaN, and because NaN never compares greater than the threshold,
/// such a column is always discarded. Use an imputer upstream if that is not
/// wanted.
///
/// The default threshold is `0.0`, which removes features with exactly zero
/// variance (i.e. constant columns).
///
/// # Examples
///
/// ```
/// use ferrolearn_preprocess::feature_selection::VarianceThreshold;
/// use ferrolearn_core::traits::{Fit, Transform};
/// use ndarray::array;
///
/// let sel = VarianceThreshold::<f64>::new(0.0);
/// // Column 1 is constant — will be removed
/// let x = array![[1.0, 7.0], [2.0, 7.0], [3.0, 7.0]];
/// let fitted = sel.fit(&x, &()).unwrap();
/// let out = fitted.transform(&x).unwrap();
/// assert_eq!(out.ncols(), 1);
/// ```
#[derive(Debug, Clone)]
pub struct VarianceThreshold<F> {
    /// Features whose variance is less than or equal to this threshold are
    /// removed; only features with a strictly greater variance are kept.
    threshold: F,
}
77
78impl<F: Float + Send + Sync + 'static> VarianceThreshold<F> {
79    /// Create a new `VarianceThreshold` with the given threshold.
80    ///
81    /// Pass `F::zero()` (the default) to remove only constant features.
82    ///
83    /// # Errors
84    ///
85    /// Returns [`FerroError::InvalidParameter`] if `threshold` is negative.
86    pub fn new(threshold: F) -> Self {
87        Self { threshold }
88    }
89
90    /// Return the variance threshold.
91    #[must_use]
92    pub fn threshold(&self) -> F {
93        self.threshold
94    }
95}
96
97impl<F: Float + Send + Sync + 'static> Default for VarianceThreshold<F> {
98    fn default() -> Self {
99        Self::new(F::zero())
100    }
101}
102
103// ---------------------------------------------------------------------------
104// FittedVarianceThreshold
105// ---------------------------------------------------------------------------
106
107/// A fitted variance-threshold selector holding the selected column indices
108/// and the per-column variances observed during fitting.
109///
110/// Created by calling [`Fit::fit`] on a [`VarianceThreshold`].
111#[derive(Debug, Clone)]
112pub struct FittedVarianceThreshold<F> {
113    /// Column indices (into the *original* feature matrix) that were selected.
114    selected_indices: Vec<usize>,
115    /// Per-column population variances computed during fitting.
116    variances: Array1<F>,
117}
118
119impl<F: Float + Send + Sync + 'static> FittedVarianceThreshold<F> {
120    /// Return the indices of the selected columns.
121    #[must_use]
122    pub fn selected_indices(&self) -> &[usize] {
123        &self.selected_indices
124    }
125
126    /// Return the per-column variances computed during fitting.
127    #[must_use]
128    pub fn variances(&self) -> &Array1<F> {
129        &self.variances
130    }
131}
132
133// ---------------------------------------------------------------------------
134// Trait implementations — VarianceThreshold
135// ---------------------------------------------------------------------------
136
137impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, ()> for VarianceThreshold<F> {
138    type Fitted = FittedVarianceThreshold<F>;
139    type Error = FerroError;
140
141    /// Fit by computing per-column population variances.
142    ///
143    /// # Errors
144    ///
145    /// Returns [`FerroError::InsufficientSamples`] if the input has zero rows.
146    /// Returns [`FerroError::InvalidParameter`] if the threshold is negative.
147    fn fit(&self, x: &Array2<F>, _y: &()) -> Result<FittedVarianceThreshold<F>, FerroError> {
148        if self.threshold < F::zero() {
149            return Err(FerroError::InvalidParameter {
150                name: "threshold".into(),
151                reason: "variance threshold must be non-negative".into(),
152            });
153        }
154        let n_samples = x.nrows();
155        if n_samples == 0 {
156            return Err(FerroError::InsufficientSamples {
157                required: 1,
158                actual: 0,
159                context: "VarianceThreshold::fit".into(),
160            });
161        }
162
163        let n = F::from(n_samples).unwrap_or_else(F::one);
164        let n_features = x.ncols();
165        let mut variances = Array1::zeros(n_features);
166        let mut selected_indices = Vec::new();
167
168        for j in 0..n_features {
169            let col = x.column(j);
170            // Welford's online algorithm — numerically stable and yields
171            // *exactly* zero for constant columns, matching numpy/sklearn.
172            // The naive `sum((v-mean)^2) / n` accumulates FP error during
173            // the `sum / n` step and produces ~1e-34 noise on constant
174            // columns, defeating the `threshold=0.0 ⇒ drop constants`
175            // contract that sklearn ships.
176            let mut mean = F::zero();
177            let mut m2 = F::zero();
178            let mut count = F::zero();
179            for &v in col.iter() {
180                count = count + F::one();
181                let delta = v - mean;
182                mean = mean + delta / count;
183                let delta2 = v - mean;
184                m2 = m2 + delta * delta2;
185            }
186            let var = m2 / n;
187            variances[j] = var;
188            if var > self.threshold {
189                selected_indices.push(j);
190            }
191        }
192
193        Ok(FittedVarianceThreshold {
194            selected_indices,
195            variances,
196        })
197    }
198}
199
200impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedVarianceThreshold<F> {
201    type Output = Array2<F>;
202    type Error = FerroError;
203
204    /// Return a matrix containing only the selected (high-variance) columns.
205    ///
206    /// # Errors
207    ///
208    /// Returns [`FerroError::ShapeMismatch`] if the number of columns differs
209    /// from the number of features seen during fitting.
210    fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
211        let n_original = self.variances.len();
212        if x.ncols() != n_original {
213            return Err(FerroError::ShapeMismatch {
214                expected: vec![x.nrows(), n_original],
215                actual: vec![x.nrows(), x.ncols()],
216                context: "FittedVarianceThreshold::transform".into(),
217            });
218        }
219        Ok(select_columns(x, &self.selected_indices))
220    }
221}
222
223// ---------------------------------------------------------------------------
224// Pipeline integration — VarianceThreshold (generic)
225// ---------------------------------------------------------------------------
226
227impl<F: Float + Send + Sync + 'static> PipelineTransformer<F> for VarianceThreshold<F> {
228    /// Fit using the pipeline interface; `y` is ignored.
229    ///
230    /// # Errors
231    ///
232    /// Propagates errors from [`Fit::fit`].
233    fn fit_pipeline(
234        &self,
235        x: &Array2<F>,
236        _y: &Array1<F>,
237    ) -> Result<Box<dyn FittedPipelineTransformer<F>>, FerroError> {
238        let fitted = self.fit(x, &())?;
239        Ok(Box::new(fitted))
240    }
241}
242
243impl<F: Float + Send + Sync + 'static> FittedPipelineTransformer<F> for FittedVarianceThreshold<F> {
244    /// Transform using the pipeline interface.
245    ///
246    /// # Errors
247    ///
248    /// Propagates errors from [`Transform::transform`].
249    fn transform_pipeline(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
250        self.transform(x)
251    }
252}
253
254// ===========================================================================
255// SelectKBest
256// ===========================================================================
257
/// How [`SelectKBest`] scores each feature column.
///
/// ANOVA F-value scoring is the only strategy at present.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ScoreFunc {
    /// One-way ANOVA F-statistic: between-class mean square divided by the
    /// within-class mean square (the counterpart of scikit-learn's
    /// `f_classif`).
    FClassif,
}
268
269/// An unfitted K-best feature selector.
270///
271/// Requires class labels (`Array1<usize>`) at fit time to compute per-feature
272/// ANOVA F-scores.  The top *K* features (by score) are retained.
273///
274/// # Examples
275///
276/// ```
277/// use ferrolearn_preprocess::feature_selection::{SelectKBest, ScoreFunc};
278/// use ferrolearn_core::traits::{Fit, Transform};
279/// use ndarray::{array, Array1};
280///
281/// let sel = SelectKBest::<f64>::new(1, ScoreFunc::FClassif);
282/// let x = array![[1.0, 10.0], [1.0, 20.0], [2.0, 10.0], [2.0, 20.0]];
283/// let y: Array1<usize> = array![0, 0, 1, 1];
284/// let fitted = sel.fit(&x, &y).unwrap();
285/// let out = fitted.transform(&x).unwrap();
286/// assert_eq!(out.ncols(), 1);
287/// ```
288#[derive(Debug, Clone)]
289pub struct SelectKBest<F> {
290    /// Number of top-scoring features to keep.
291    k: usize,
292    /// The scoring function to use.
293    score_func: ScoreFunc,
294    _marker: std::marker::PhantomData<F>,
295}
296
297impl<F: Float + Send + Sync + 'static> SelectKBest<F> {
298    /// Create a new `SelectKBest` selector.
299    ///
300    /// # Parameters
301    ///
302    /// - `k` — the number of features to retain.
303    /// - `score_func` — the scoring function; currently only
304    ///   [`ScoreFunc::FClassif`] is available.
305    #[must_use]
306    pub fn new(k: usize, score_func: ScoreFunc) -> Self {
307        Self {
308            k,
309            score_func,
310            _marker: std::marker::PhantomData,
311        }
312    }
313
314    /// Return *k*.
315    #[must_use]
316    pub fn k(&self) -> usize {
317        self.k
318    }
319
320    /// Return the score function.
321    #[must_use]
322    pub fn score_func(&self) -> ScoreFunc {
323        self.score_func
324    }
325}
326
327// ---------------------------------------------------------------------------
328// FittedSelectKBest
329// ---------------------------------------------------------------------------
330
331/// A fitted K-best selector holding per-feature scores and selected indices.
332///
333/// Created by calling [`Fit::fit`] on a [`SelectKBest`].
334#[derive(Debug, Clone)]
335pub struct FittedSelectKBest<F> {
336    /// The original number of features (used for shape checking on transform).
337    n_features_in: usize,
338    /// Per-feature ANOVA F-scores computed during fitting.
339    scores: Array1<F>,
340    /// Indices of the selected columns, sorted in original column order.
341    selected_indices: Vec<usize>,
342}
343
344impl<F: Float + Send + Sync + 'static> FittedSelectKBest<F> {
345    /// Return the per-feature F-scores computed during fitting.
346    #[must_use]
347    pub fn scores(&self) -> &Array1<F> {
348        &self.scores
349    }
350
351    /// Return the indices of the selected columns.
352    #[must_use]
353    pub fn selected_indices(&self) -> &[usize] {
354        &self.selected_indices
355    }
356}
357
358// ---------------------------------------------------------------------------
359// ANOVA F-value helper
360// ---------------------------------------------------------------------------
361
362/// Compute per-feature ANOVA F-scores given a feature matrix `x` and integer
363/// class labels `y`.
364///
365/// For each feature column the F-statistic is:
366///
367/// ```text
368/// F = (between-class variance / (n_classes - 1))
369///   / (within-class variance  / (n_samples - n_classes))
370/// ```
371///
372/// Features for which the within-class variance is zero (perfectly separable)
373/// get an F-score of `F::infinity()`.  Features that have zero between-class
374/// variance get a score of `F::zero()`.
375fn anova_f_scores<F: Float>(x: &Array2<F>, y: &Array1<usize>) -> Vec<F> {
376    let n_samples = x.nrows();
377    let n_features = x.ncols();
378
379    // Collect unique classes and build per-class row-index lists.
380    let mut class_indices: std::collections::HashMap<usize, Vec<usize>> =
381        std::collections::HashMap::new();
382    for (i, &label) in y.iter().enumerate() {
383        class_indices.entry(label).or_default().push(i);
384    }
385    let n_classes = class_indices.len();
386
387    let mut scores = Vec::with_capacity(n_features);
388
389    for j in 0..n_features {
390        let col = x.column(j);
391
392        // Overall mean of this feature
393        let grand_mean =
394            col.iter().copied().fold(F::zero(), |acc, v| acc + v) / F::from(n_samples).unwrap();
395
396        // Between-class sum of squares: sum_k n_k * (mean_k - grand_mean)^2
397        let mut ss_between = F::zero();
398        // Within-class sum of squares: sum_k sum_{i in k} (x_i - mean_k)^2
399        let mut ss_within = F::zero();
400
401        for rows in class_indices.values() {
402            let n_k = F::from(rows.len()).unwrap();
403            let class_mean = rows
404                .iter()
405                .map(|&i| col[i])
406                .fold(F::zero(), |acc, v| acc + v)
407                / n_k;
408            let diff = class_mean - grand_mean;
409            ss_between = ss_between + n_k * diff * diff;
410            for &i in rows {
411                let d = col[i] - class_mean;
412                ss_within = ss_within + d * d;
413            }
414        }
415
416        let df_between = F::from(n_classes.saturating_sub(1)).unwrap();
417        let df_within = F::from(n_samples.saturating_sub(n_classes)).unwrap();
418
419        let f = if df_between == F::zero() || df_within == F::zero() {
420            F::zero()
421        } else {
422            let ms_between = ss_between / df_between;
423            let ms_within = ss_within / df_within;
424            if ms_within == F::zero() {
425                F::infinity()
426            } else {
427                ms_between / ms_within
428            }
429        };
430
431        scores.push(f);
432    }
433
434    scores
435}
436
437// ---------------------------------------------------------------------------
438// Trait implementations — SelectKBest
439// ---------------------------------------------------------------------------
440
441impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array1<usize>> for SelectKBest<F> {
442    type Fitted = FittedSelectKBest<F>;
443    type Error = FerroError;
444
445    /// Fit by computing per-feature ANOVA F-scores against the class labels.
446    ///
447    /// # Errors
448    ///
449    /// - [`FerroError::InsufficientSamples`] if the input has zero rows.
450    /// - [`FerroError::InvalidParameter`] if `k` exceeds the number of features.
451    /// - [`FerroError::ShapeMismatch`] if `x` and `y` have different numbers
452    ///   of rows.
453    fn fit(&self, x: &Array2<F>, y: &Array1<usize>) -> Result<FittedSelectKBest<F>, FerroError> {
454        let n_samples = x.nrows();
455        if n_samples == 0 {
456            return Err(FerroError::InsufficientSamples {
457                required: 1,
458                actual: 0,
459                context: "SelectKBest::fit".into(),
460            });
461        }
462        if y.len() != n_samples {
463            return Err(FerroError::ShapeMismatch {
464                expected: vec![n_samples],
465                actual: vec![y.len()],
466                context: "SelectKBest::fit — y must have the same length as x has rows".into(),
467            });
468        }
469        let n_features = x.ncols();
470        if self.k > n_features {
471            return Err(FerroError::InvalidParameter {
472                name: "k".into(),
473                reason: format!(
474                    "k ({}) cannot exceed the number of features ({})",
475                    self.k, n_features
476                ),
477            });
478        }
479
480        let raw_scores = match self.score_func {
481            ScoreFunc::FClassif => anova_f_scores(x, y),
482        };
483
484        let scores = Array1::from_vec(raw_scores.clone());
485
486        // Determine the top-k indices (stable: break ties by preferring the
487        // lower column index so results are deterministic).
488        let mut ranked: Vec<usize> = (0..n_features).collect();
489        ranked.sort_by(|&a, &b| {
490            raw_scores[b]
491                .partial_cmp(&raw_scores[a])
492                .unwrap_or(std::cmp::Ordering::Equal)
493                // Tie: keep lower column index
494                .then(a.cmp(&b))
495        });
496
497        let mut selected_indices: Vec<usize> = ranked[..self.k].to_vec();
498        // Return in original column order for a stable output layout
499        selected_indices.sort_unstable();
500
501        Ok(FittedSelectKBest {
502            n_features_in: n_features,
503            scores,
504            selected_indices,
505        })
506    }
507}
508
509impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedSelectKBest<F> {
510    type Output = Array2<F>;
511    type Error = FerroError;
512
513    /// Return a matrix containing only the K selected columns.
514    ///
515    /// # Errors
516    ///
517    /// Returns [`FerroError::ShapeMismatch`] if the number of columns differs
518    /// from the number of features seen during fitting.
519    fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
520        if x.ncols() != self.n_features_in {
521            return Err(FerroError::ShapeMismatch {
522                expected: vec![x.nrows(), self.n_features_in],
523                actual: vec![x.nrows(), x.ncols()],
524                context: "FittedSelectKBest::transform".into(),
525            });
526        }
527        Ok(select_columns(x, &self.selected_indices))
528    }
529}
530
531// ---------------------------------------------------------------------------
532// Pipeline integration — SelectKBest (generic)
533//
534// NOTE: The pipeline interface uses a *fixed* `y = Array1<f64>`, so we cannot
535// use the actual class-label vector from the pipeline.  We therefore refit
536// using the pipeline `y` converted to `usize` labels by rounding.
537// ---------------------------------------------------------------------------
538
539impl<F: Float + Send + Sync + 'static> PipelineTransformer<F> for SelectKBest<F> {
540    /// Fit using the pipeline interface.
541    ///
542    /// The continuous `y` labels are rounded to `usize` class indices.
543    ///
544    /// # Errors
545    ///
546    /// Propagates errors from [`Fit::fit`].
547    fn fit_pipeline(
548        &self,
549        x: &Array2<F>,
550        y: &Array1<F>,
551    ) -> Result<Box<dyn FittedPipelineTransformer<F>>, FerroError> {
552        let y_usize: Array1<usize> = y.mapv(|v| v.round().to_usize().unwrap_or(0));
553        let fitted = self.fit(x, &y_usize)?;
554        Ok(Box::new(fitted))
555    }
556}
557
558impl<F: Float + Send + Sync + 'static> FittedPipelineTransformer<F> for FittedSelectKBest<F> {
559    /// Transform using the pipeline interface.
560    ///
561    /// # Errors
562    ///
563    /// Propagates errors from [`Transform::transform`].
564    fn transform_pipeline(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
565        self.transform(x)
566    }
567}
568
569// ===========================================================================
570// SelectFromModel
571// ===========================================================================
572
573/// A feature selector driven by external feature-importance weights.
574///
575/// The importance vector is typically obtained from a fitted model (e.g. a
576/// decision-tree model's `feature_importances_` field).  Features whose
577/// importance is *strictly greater than or equal to* the threshold are kept.
578///
579/// The default threshold is the **mean importance** of all features.
580///
581/// # Examples
582///
583/// ```
584/// use ferrolearn_preprocess::feature_selection::SelectFromModel;
585/// use ferrolearn_core::traits::Transform;
586/// use ndarray::array;
587///
588/// let importances = array![0.1, 0.5, 0.4];
589/// let sel = SelectFromModel::<f64>::new_from_importances(&importances, None).unwrap();
590/// let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
591/// let out = sel.transform(&x).unwrap();
592/// // Mean importance = (0.1+0.5+0.4)/3 ≈ 0.333; columns 1 and 2 are kept
593/// assert_eq!(out.ncols(), 2);
594/// ```
595#[derive(Debug, Clone)]
596pub struct SelectFromModel<F> {
597    /// Importance weight for each feature.
598    importances: Array1<F>,
599    /// The threshold: features with importance >= threshold are kept.
600    threshold: F,
601    /// Indices of selected features (original column order).
602    selected_indices: Vec<usize>,
603}
604
605impl<F: Float + Send + Sync + 'static> SelectFromModel<F> {
606    /// Create a `SelectFromModel` from a pre-computed importance vector.
607    ///
608    /// # Parameters
609    ///
610    /// - `importances` — one importance weight per feature.
611    /// - `threshold` — optional explicit threshold; if `None` the mean
612    ///   importance is used.
613    ///
614    /// # Errors
615    ///
616    /// Returns [`FerroError::InvalidParameter`] if `importances` is empty.
617    pub fn new_from_importances(
618        importances: &Array1<F>,
619        threshold: Option<F>,
620    ) -> Result<Self, FerroError> {
621        let n = importances.len();
622        if n == 0 {
623            return Err(FerroError::InvalidParameter {
624                name: "importances".into(),
625                reason: "importance vector must not be empty".into(),
626            });
627        }
628
629        let thr = threshold.unwrap_or_else(|| {
630            importances
631                .iter()
632                .copied()
633                .fold(F::zero(), |acc, v| acc + v)
634                / F::from(n).unwrap_or_else(F::one)
635        });
636
637        let selected_indices: Vec<usize> = importances
638            .iter()
639            .enumerate()
640            .filter(|&(_, &imp)| imp >= thr)
641            .map(|(j, _)| j)
642            .collect();
643
644        Ok(Self {
645            importances: importances.clone(),
646            threshold: thr,
647            selected_indices,
648        })
649    }
650
651    /// Return the threshold used to select features.
652    #[must_use]
653    pub fn threshold(&self) -> F {
654        self.threshold
655    }
656
657    /// Return the importance vector supplied at construction time.
658    #[must_use]
659    pub fn importances(&self) -> &Array1<F> {
660        &self.importances
661    }
662
663    /// Return the indices of the selected columns.
664    #[must_use]
665    pub fn selected_indices(&self) -> &[usize] {
666        &self.selected_indices
667    }
668}
669
670// ---------------------------------------------------------------------------
671// Trait implementation — SelectFromModel
672// ---------------------------------------------------------------------------
673
674impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for SelectFromModel<F> {
675    type Output = Array2<F>;
676    type Error = FerroError;
677
678    /// Return a matrix containing only the columns whose importance exceeds
679    /// the threshold.
680    ///
681    /// # Errors
682    ///
683    /// Returns [`FerroError::ShapeMismatch`] if the number of columns differs
684    /// from the length of the importance vector supplied at construction.
685    fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
686        let n_features = self.importances.len();
687        if x.ncols() != n_features {
688            return Err(FerroError::ShapeMismatch {
689                expected: vec![x.nrows(), n_features],
690                actual: vec![x.nrows(), x.ncols()],
691                context: "SelectFromModel::transform".into(),
692            });
693        }
694        Ok(select_columns(x, &self.selected_indices))
695    }
696}
697
698// ---------------------------------------------------------------------------
699// Pipeline integration — SelectFromModel (generic)
700//
701// `SelectFromModel` is already "fitted" (importance weights are provided at
702// construction time), so `fit_pipeline` merely boxes `self.clone()`.
703// ---------------------------------------------------------------------------
704
705impl<F: Float + Send + Sync + 'static> PipelineTransformer<F> for SelectFromModel<F> {
706    /// Clone the selector and box it as a fitted pipeline transformer.
707    ///
708    /// # Errors
709    ///
710    /// This implementation never fails.
711    fn fit_pipeline(
712        &self,
713        _x: &Array2<F>,
714        _y: &Array1<F>,
715    ) -> Result<Box<dyn FittedPipelineTransformer<F>>, FerroError> {
716        Ok(Box::new(self.clone()))
717    }
718}
719
720impl<F: Float + Send + Sync + 'static> FittedPipelineTransformer<F> for SelectFromModel<F> {
721    /// Transform using the pipeline interface.
722    ///
723    /// # Errors
724    ///
725    /// Propagates errors from [`Transform::transform`].
726    fn transform_pipeline(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
727        self.transform(x)
728    }
729}
730
731// ===========================================================================
732// Tests
733// ===========================================================================
734
735#[cfg(test)]
736mod tests {
737    use super::*;
738    use approx::assert_abs_diff_eq;
739    use ndarray::array;
740
741    // ========================================================================
742    // VarianceThreshold tests
743    // ========================================================================
744
745    #[test]
746    fn test_variance_threshold_removes_constant_column() {
747        let sel = VarianceThreshold::<f64>::new(0.0);
748        // Column 1 is constant (all 7.0)
749        let x = array![[1.0, 7.0], [2.0, 7.0], [3.0, 7.0]];
750        let fitted = sel.fit(&x, &()).unwrap();
751        assert_eq!(fitted.selected_indices(), &[0usize]);
752        let out = fitted.transform(&x).unwrap();
753        assert_eq!(out.ncols(), 1);
754        // Column 0 values preserved
755        assert_abs_diff_eq!(out[[0, 0]], 1.0, epsilon = 1e-15);
756        assert_abs_diff_eq!(out[[1, 0]], 2.0, epsilon = 1e-15);
757    }
758
759    #[test]
760    fn test_variance_threshold_keeps_all_when_above() {
761        let sel = VarianceThreshold::<f64>::new(0.0);
762        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
763        let fitted = sel.fit(&x, &()).unwrap();
764        assert_eq!(fitted.selected_indices().len(), 2);
765        let out = fitted.transform(&x).unwrap();
766        assert_eq!(out.ncols(), 2);
767    }
768
769    #[test]
770    fn test_variance_threshold_custom_threshold() {
771        let sel = VarianceThreshold::<f64>::new(1.5);
772        // Column 0: values [1,2,3], variance = 2/3 ≈ 0.667 → removed
773        // Column 1: values [10,20,30], variance = 200/3 ≈ 66.7 → kept
774        let x = array![[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]];
775        let fitted = sel.fit(&x, &()).unwrap();
776        assert_eq!(fitted.selected_indices(), &[1usize]);
777        let out = fitted.transform(&x).unwrap();
778        assert_eq!(out.ncols(), 1);
779    }
780
781    #[test]
782    fn test_variance_threshold_stores_variances() {
783        let sel = VarianceThreshold::<f64>::default();
784        let x = array![[0.0], [0.0], [0.0]]; // constant → var = 0
785        let fitted = sel.fit(&x, &()).unwrap();
786        assert_abs_diff_eq!(fitted.variances()[0], 0.0, epsilon = 1e-15);
787    }
788
789    #[test]
790    fn test_variance_threshold_zero_rows_error() {
791        let sel = VarianceThreshold::<f64>::new(0.0);
792        let x: Array2<f64> = Array2::zeros((0, 2));
793        assert!(sel.fit(&x, &()).is_err());
794    }
795
796    #[test]
797    fn test_variance_threshold_negative_threshold_error() {
798        let sel = VarianceThreshold::<f64>::new(-0.1);
799        let x = array![[1.0], [2.0]];
800        assert!(sel.fit(&x, &()).is_err());
801    }
802
803    #[test]
804    fn test_variance_threshold_shape_mismatch_on_transform() {
805        let sel = VarianceThreshold::<f64>::new(0.0);
806        let x_train = array![[1.0, 2.0], [3.0, 4.0]];
807        let fitted = sel.fit(&x_train, &()).unwrap();
808        let x_bad = array![[1.0, 2.0, 3.0]];
809        assert!(fitted.transform(&x_bad).is_err());
810    }
811
812    #[test]
813    fn test_variance_threshold_all_constant_columns() {
814        let sel = VarianceThreshold::<f64>::new(0.0);
815        let x = array![[5.0, 3.0], [5.0, 3.0], [5.0, 3.0]];
816        let fitted = sel.fit(&x, &()).unwrap();
817        // Both columns are constant: both removed
818        assert_eq!(fitted.selected_indices().len(), 0);
819        let out = fitted.transform(&x).unwrap();
820        assert_eq!(out.ncols(), 0);
821        assert_eq!(out.nrows(), 3);
822    }
823
824    #[test]
825    fn test_variance_threshold_pipeline_integration() {
826        use ferrolearn_core::pipeline::PipelineTransformer;
827        let sel = VarianceThreshold::<f64>::new(0.0);
828        let x = array![[1.0, 7.0], [2.0, 7.0], [3.0, 7.0]];
829        let y = ndarray::array![0.0, 1.0, 0.0];
830        let fitted_box = sel.fit_pipeline(&x, &y).unwrap();
831        let out = fitted_box.transform_pipeline(&x).unwrap();
832        assert_eq!(out.ncols(), 1);
833    }
834
835    #[test]
836    fn test_variance_threshold_f32() {
837        let sel = VarianceThreshold::<f32>::new(0.0f32);
838        let x: Array2<f32> = array![[1.0f32, 5.0], [2.0, 5.0], [3.0, 5.0]];
839        let fitted = sel.fit(&x, &()).unwrap();
840        assert_eq!(fitted.selected_indices(), &[0usize]);
841    }
842
843    // ========================================================================
844    // SelectKBest tests
845    // ========================================================================
846
847    #[test]
848    fn test_select_k_best_selects_highest_scoring_feature() {
849        // Feature 0 separates classes perfectly; feature 1 does not.
850        let sel = SelectKBest::<f64>::new(1, ScoreFunc::FClassif);
851        let x = array![[1.0, 5.0], [1.0, 6.0], [10.0, 5.0], [10.0, 6.0]];
852        let y: Array1<usize> = array![0, 0, 1, 1];
853        let fitted = sel.fit(&x, &y).unwrap();
854        // Column 0 should be selected (high F-score)
855        assert_eq!(fitted.selected_indices(), &[0usize]);
856        let out = fitted.transform(&x).unwrap();
857        assert_eq!(out.ncols(), 1);
858    }
859
860    #[test]
861    fn test_select_k_best_k_equals_n_features_keeps_all() {
862        let sel = SelectKBest::<f64>::new(2, ScoreFunc::FClassif);
863        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
864        let y: Array1<usize> = array![0, 1, 0];
865        let fitted = sel.fit(&x, &y).unwrap();
866        assert_eq!(fitted.selected_indices().len(), 2);
867        let out = fitted.transform(&x).unwrap();
868        assert_eq!(out.ncols(), 2);
869    }
870
871    #[test]
872    fn test_select_k_best_scores_stored() {
873        let sel = SelectKBest::<f64>::new(1, ScoreFunc::FClassif);
874        let x = array![[1.0, 2.0], [1.0, 4.0]];
875        let y: Array1<usize> = array![0, 1];
876        let fitted = sel.fit(&x, &y).unwrap();
877        assert_eq!(fitted.scores().len(), 2);
878    }
879
880    #[test]
881    fn test_select_k_best_zero_rows_error() {
882        let sel = SelectKBest::<f64>::new(1, ScoreFunc::FClassif);
883        let x: Array2<f64> = Array2::zeros((0, 3));
884        let y: Array1<usize> = Array1::zeros(0);
885        assert!(sel.fit(&x, &y).is_err());
886    }
887
888    #[test]
889    fn test_select_k_best_k_exceeds_n_features_error() {
890        let sel = SelectKBest::<f64>::new(5, ScoreFunc::FClassif);
891        let x = array![[1.0, 2.0], [3.0, 4.0]];
892        let y: Array1<usize> = array![0, 1];
893        assert!(sel.fit(&x, &y).is_err());
894    }
895
896    #[test]
897    fn test_select_k_best_y_length_mismatch_error() {
898        let sel = SelectKBest::<f64>::new(1, ScoreFunc::FClassif);
899        let x = array![[1.0, 2.0], [3.0, 4.0]];
900        let y: Array1<usize> = array![0]; // wrong length
901        assert!(sel.fit(&x, &y).is_err());
902    }
903
904    #[test]
905    fn test_select_k_best_shape_mismatch_on_transform() {
906        let sel = SelectKBest::<f64>::new(1, ScoreFunc::FClassif);
907        let x = array![[1.0, 2.0], [3.0, 4.0]];
908        let y: Array1<usize> = array![0, 1];
909        let fitted = sel.fit(&x, &y).unwrap();
910        let x_bad = array![[1.0, 2.0, 3.0]];
911        assert!(fitted.transform(&x_bad).is_err());
912    }
913
914    #[test]
915    fn test_select_k_best_selected_indices_in_column_order() {
916        // Both features selected — indices should be [0, 1] not reversed
917        let sel = SelectKBest::<f64>::new(2, ScoreFunc::FClassif);
918        let x = array![[1.0, 100.0], [2.0, 200.0]];
919        let y: Array1<usize> = array![0, 1];
920        let fitted = sel.fit(&x, &y).unwrap();
921        let indices = fitted.selected_indices();
922        assert!(indices.windows(2).all(|w| w[0] < w[1]));
923    }
924
925    #[test]
926    fn test_select_k_best_pipeline_integration() {
927        use ferrolearn_core::pipeline::PipelineTransformer;
928        let sel = SelectKBest::<f64>::new(1, ScoreFunc::FClassif);
929        let x = array![[1.0, 5.0], [1.0, 6.0], [10.0, 5.0], [10.0, 6.0]];
930        let y = ndarray::array![0.0, 0.0, 1.0, 1.0];
931        let fitted_box = sel.fit_pipeline(&x, &y).unwrap();
932        let out = fitted_box.transform_pipeline(&x).unwrap();
933        assert_eq!(out.ncols(), 1);
934    }
935
936    #[test]
937    fn test_select_k_best_f_score_zero_within_class_variance() {
938        // Perfectly separating feature → F should be infinity
939        let sel = SelectKBest::<f64>::new(1, ScoreFunc::FClassif);
940        let x = array![[0.0], [0.0], [10.0], [10.0]];
941        let y: Array1<usize> = array![0, 0, 1, 1];
942        let fitted = sel.fit(&x, &y).unwrap();
943        assert!(fitted.scores()[0].is_infinite());
944    }
945
946    // ========================================================================
947    // SelectFromModel tests
948    // ========================================================================
949
950    #[test]
951    fn test_select_from_model_mean_threshold() {
952        // Mean importance = (0.1 + 0.5 + 0.4) / 3 ≈ 0.333
953        // Features 1 (0.5) and 2 (0.4) are >= threshold; feature 0 (0.1) is not
954        let importances = array![0.1, 0.5, 0.4];
955        let sel = SelectFromModel::<f64>::new_from_importances(&importances, None).unwrap();
956        assert_eq!(sel.selected_indices(), &[1usize, 2]);
957        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
958        let out = sel.transform(&x).unwrap();
959        assert_eq!(out.ncols(), 2);
960        assert_abs_diff_eq!(out[[0, 0]], 2.0, epsilon = 1e-15);
961        assert_abs_diff_eq!(out[[0, 1]], 3.0, epsilon = 1e-15);
962    }
963
964    #[test]
965    fn test_select_from_model_explicit_threshold() {
966        let importances = array![0.1, 0.5, 0.4];
967        // Only feature 1 (0.5 >= 0.45) is selected
968        let sel = SelectFromModel::<f64>::new_from_importances(&importances, Some(0.45)).unwrap();
969        assert_eq!(sel.selected_indices(), &[1usize]);
970        let x = array![[1.0, 2.0, 3.0]];
971        let out = sel.transform(&x).unwrap();
972        assert_eq!(out.ncols(), 1);
973        assert_abs_diff_eq!(out[[0, 0]], 2.0, epsilon = 1e-15);
974    }
975
976    #[test]
977    fn test_select_from_model_all_selected_when_threshold_zero() {
978        let importances = array![0.1, 0.2, 0.3];
979        let sel = SelectFromModel::<f64>::new_from_importances(&importances, Some(0.0)).unwrap();
980        assert_eq!(sel.selected_indices().len(), 3);
981    }
982
983    #[test]
984    fn test_select_from_model_none_selected_when_threshold_high() {
985        let importances = array![0.1, 0.2, 0.3];
986        let sel = SelectFromModel::<f64>::new_from_importances(&importances, Some(1.0)).unwrap();
987        assert_eq!(sel.selected_indices().len(), 0);
988        let x = array![[1.0, 2.0, 3.0]];
989        let out = sel.transform(&x).unwrap();
990        assert_eq!(out.ncols(), 0);
991    }
992
993    #[test]
994    fn test_select_from_model_empty_importances_error() {
995        let importances: Array1<f64> = Array1::zeros(0);
996        assert!(SelectFromModel::<f64>::new_from_importances(&importances, None).is_err());
997    }
998
999    #[test]
1000    fn test_select_from_model_shape_mismatch_on_transform() {
1001        let importances = array![0.3, 0.7];
1002        let sel = SelectFromModel::<f64>::new_from_importances(&importances, None).unwrap();
1003        let x_bad = array![[1.0, 2.0, 3.0]]; // 3 cols, but 2 features expected
1004        assert!(sel.transform(&x_bad).is_err());
1005    }
1006
1007    #[test]
1008    fn test_select_from_model_threshold_accessor() {
1009        let importances = array![0.3, 0.7];
1010        let sel = SelectFromModel::<f64>::new_from_importances(&importances, Some(0.5)).unwrap();
1011        assert_abs_diff_eq!(sel.threshold(), 0.5, epsilon = 1e-15);
1012    }
1013
1014    #[test]
1015    fn test_select_from_model_pipeline_integration() {
1016        use ferrolearn_core::pipeline::PipelineTransformer;
1017        let importances = array![0.1, 0.9];
1018        let sel = SelectFromModel::<f64>::new_from_importances(&importances, None).unwrap();
1019        let x = array![[1.0, 2.0], [3.0, 4.0]];
1020        let y = ndarray::array![0.0, 1.0];
1021        let fitted_box = sel.fit_pipeline(&x, &y).unwrap();
1022        let out = fitted_box.transform_pipeline(&x).unwrap();
1023        // Mean importance = 0.5; only feature 1 (0.9 >= 0.5) kept
1024        assert_eq!(out.ncols(), 1);
1025        assert_abs_diff_eq!(out[[0, 0]], 2.0, epsilon = 1e-15);
1026    }
1027
1028    #[test]
1029    fn test_select_from_model_importances_accessor() {
1030        let importances = array![0.2, 0.8];
1031        let sel = SelectFromModel::<f64>::new_from_importances(&importances, None).unwrap();
1032        assert_abs_diff_eq!(sel.importances()[0], 0.2, epsilon = 1e-15);
1033        assert_abs_diff_eq!(sel.importances()[1], 0.8, epsilon = 1e-15);
1034    }
1035}