Skip to main content

ferrolearn_preprocess/
rfe.rs

//! Recursive Feature Elimination (RFE) and RFE with Cross-Validation (RFECV).
//!
//! [`RFE`] recursively removes the least-important features, ranking features
//! by the elimination step in which they are dropped. Importance is taken from
//! an importance vector that the user supplies directly at construction.
//!
//! [`RFECV`] extends RFE by using user-supplied cross-validation scores to
//! find the optimal number of features to retain.
//!
//! Because `ferrolearn-preprocess` cannot depend on estimator crates (to avoid
//! circular dependencies), these implementations accept feature importance
//! vectors directly rather than wrapping fitted estimators.
13
14use ferrolearn_core::error::FerroError;
15use ferrolearn_core::traits::Transform;
16use ndarray::{Array1, Array2};
17use num_traits::Float;
18
19/// Build a new `Array2<F>` containing only the columns listed in `indices`.
20fn select_columns<F: Float>(x: &Array2<F>, indices: &[usize]) -> Array2<F> {
21    let nrows = x.nrows();
22    let ncols = indices.len();
23    if ncols == 0 {
24        return Array2::zeros((nrows, 0));
25    }
26    let mut out = Array2::zeros((nrows, ncols));
27    for (new_j, &old_j) in indices.iter().enumerate() {
28        for i in 0..nrows {
29            out[[i, new_j]] = x[[i, old_j]];
30        }
31    }
32    out
33}
34
35// ===========================================================================
36// RFE
37// ===========================================================================
38
/// Recursive Feature Elimination over a fixed importance vector.
///
/// Every feature starts as a candidate. The `step` least-important candidates
/// are dropped round after round until exactly `n_features_to_select` are
/// left. The importance vector passed to the constructor decides the order in
/// which features are dropped, and therefore the final ranking.
///
/// # Examples
///
/// ```
/// use ferrolearn_preprocess::rfe::RFE;
/// use ferrolearn_core::traits::Transform;
/// use ndarray::array;
///
/// // Feature importances: feature 0 is most important, feature 2 least
/// let importances = array![0.6, 0.3, 0.1];
/// let rfe = RFE::<f64>::new(&importances, 1, 1).unwrap();
/// assert_eq!(rfe.support(), &[true, false, false]);
/// assert_eq!(rfe.ranking(), &[1, 2, 3]);
/// let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
/// let out = rfe.transform(&x).unwrap();
/// assert_eq!(out.ncols(), 1);
/// ```
#[derive(Debug, Clone)]
pub struct RFE<F> {
    /// Per-feature rank: `ranking[j] == 1` means feature `j` was kept; larger
    /// values mean the feature was dropped in an earlier round.
    ranking: Vec<usize>,
    /// Per-feature mask: `support[j]` is `true` exactly when feature `j` was kept.
    support: Vec<bool>,
    /// Column indices of the kept features, in ascending order.
    selected_indices: Vec<usize>,
    /// Number of features the selector was built for; `transform` requires
    /// its input to have this many columns.
    n_features_in: usize,
    /// Ties the selector to the float type `F` without storing a value of it.
    _marker: std::marker::PhantomData<F>,
}
73
74impl<F: Float + Send + Sync + 'static> RFE<F> {
75    /// Create a new `RFE` from pre-computed feature importances.
76    ///
77    /// # Parameters
78    ///
79    /// - `importances` — per-feature importance scores (higher = more important).
80    /// - `n_features_to_select` — number of features to keep.
81    /// - `step` — number of features to remove per iteration.
82    ///
83    /// # Errors
84    ///
85    /// - [`FerroError::InvalidParameter`] if `importances` is empty, `step` is
86    ///   zero, or `n_features_to_select` exceeds the number of features.
87    pub fn new(
88        importances: &Array1<F>,
89        n_features_to_select: usize,
90        step: usize,
91    ) -> Result<Self, FerroError> {
92        let n_features = importances.len();
93        if n_features == 0 {
94            return Err(FerroError::InvalidParameter {
95                name: "importances".into(),
96                reason: "importance vector must not be empty".into(),
97            });
98        }
99        if step == 0 {
100            return Err(FerroError::InvalidParameter {
101                name: "step".into(),
102                reason: "step must be at least 1".into(),
103            });
104        }
105        if n_features_to_select == 0 || n_features_to_select > n_features {
106            return Err(FerroError::InvalidParameter {
107                name: "n_features_to_select".into(),
108                reason: format!(
109                    "n_features_to_select ({n_features_to_select}) must be in [1, {n_features}]"
110                ),
111            });
112        }
113
114        // Simulate the elimination process.
115        // Track which round each feature is eliminated in; features removed in
116        // the same step share the same rank. Selected features get rank 1,
117        // features removed in the last elimination round get rank 2, etc.
118        let mut ranking = vec![0usize; n_features];
119        let mut remaining: Vec<usize> = (0..n_features).collect();
120        let mut elimination_rounds: Vec<Vec<usize>> = Vec::new();
121
122        // Working copy of importances
123        let imp: Vec<F> = importances.iter().copied().collect();
124
125        while remaining.len() > n_features_to_select {
126            // Sort remaining by importance (ascending)
127            remaining.sort_by(|&a, &b| {
128                imp[a]
129                    .partial_cmp(&imp[b])
130                    .unwrap_or(std::cmp::Ordering::Equal)
131            });
132
133            // Remove the `step` least important features
134            let n_to_remove = step.min(remaining.len() - n_features_to_select);
135            let removed: Vec<usize> = remaining[..n_to_remove].to_vec();
136            elimination_rounds.push(removed);
137            remaining = remaining[n_to_remove..].to_vec();
138        }
139
140        // Assign ranks: selected features get rank 1, features removed in the
141        // last round get rank 2, second-to-last round gets rank 3, etc.
142        for &idx in &remaining {
143            ranking[idx] = 1;
144        }
145        for (round_idx, round) in elimination_rounds.iter().rev().enumerate() {
146            let rank = round_idx + 2;
147            for &idx in round {
148                ranking[idx] = rank;
149            }
150        }
151
152        let support: Vec<bool> = ranking.iter().map(|&r| r == 1).collect();
153        let mut selected_indices: Vec<usize> = remaining;
154        selected_indices.sort_unstable();
155
156        Ok(Self {
157            ranking,
158            support,
159            selected_indices,
160            n_features_in: n_features,
161            _marker: std::marker::PhantomData,
162        })
163    }
164
165    /// Return the feature ranking (1 = best, higher = eliminated earlier).
166    #[must_use]
167    pub fn ranking(&self) -> &[usize] {
168        &self.ranking
169    }
170
171    /// Return the boolean support mask.
172    #[must_use]
173    pub fn support(&self) -> &[bool] {
174        &self.support
175    }
176
177    /// Return the indices of the selected features.
178    #[must_use]
179    pub fn selected_indices(&self) -> &[usize] {
180        &self.selected_indices
181    }
182
183    /// Return the number of selected features.
184    #[must_use]
185    pub fn n_features_selected(&self) -> usize {
186        self.selected_indices.len()
187    }
188}
189
190impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for RFE<F> {
191    type Output = Array2<F>;
192    type Error = FerroError;
193
194    /// Return a matrix containing only the selected features.
195    ///
196    /// # Errors
197    ///
198    /// Returns [`FerroError::ShapeMismatch`] if the number of columns differs
199    /// from the number of features used at construction.
200    fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
201        if x.ncols() != self.n_features_in {
202            return Err(FerroError::ShapeMismatch {
203                expected: vec![x.nrows(), self.n_features_in],
204                actual: vec![x.nrows(), x.ncols()],
205                context: "RFE::transform".into(),
206            });
207        }
208        Ok(select_columns(x, &self.selected_indices))
209    }
210}
211
212// ===========================================================================
213// RFECV
214// ===========================================================================
215
216/// Recursive Feature Elimination with Cross-Validation.
217///
218/// Like [`RFE`], but uses cross-validation scores to determine the optimal
219/// number of features. The user supplies a vector of per-feature-count CV
220/// scores (e.g., from running RFE with different `n_features_to_select`
221/// values), and RFECV picks the number that maximises the score.
222///
223/// # Examples
224///
225/// ```
226/// use ferrolearn_preprocess::rfe::RFECV;
227/// use ferrolearn_core::traits::Transform;
228/// use ndarray::array;
229///
230/// let importances = array![0.5, 0.3, 0.2];
231/// // CV scores for selecting 1, 2, 3 features:
232/// let cv_scores = vec![0.85, 0.95, 0.90];
233/// let rfecv = RFECV::<f64>::new(&importances, &cv_scores, 1).unwrap();
234/// // Best is 2 features (score 0.95)
235/// assert_eq!(rfecv.n_features_selected(), 2);
236/// let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
237/// let out = rfecv.transform(&x).unwrap();
238/// assert_eq!(out.ncols(), 2);
239/// ```
240#[derive(Debug, Clone)]
241pub struct RFECV<F> {
242    /// The underlying RFE with the optimal number of features.
243    rfe: RFE<F>,
244    /// CV scores for each number of features (1..=n_features).
245    cv_scores: Vec<f64>,
246    /// The optimal number of features (1-indexed).
247    optimal_n_features: usize,
248}
249
250impl<F: Float + Send + Sync + 'static> RFECV<F> {
251    /// Create a new `RFECV` from pre-computed importances and CV scores.
252    ///
253    /// # Parameters
254    ///
255    /// - `importances` — per-feature importance scores.
256    /// - `cv_scores` — CV score for each possible number of features
257    ///   (index 0 = 1 feature, index 1 = 2 features, ...).
258    /// - `step` — features removed per iteration.
259    ///
260    /// # Errors
261    ///
262    /// - [`FerroError::InvalidParameter`] if inputs are empty or mismatched.
263    pub fn new(
264        importances: &Array1<F>,
265        cv_scores: &[f64],
266        step: usize,
267    ) -> Result<Self, FerroError> {
268        let n_features = importances.len();
269        if n_features == 0 {
270            return Err(FerroError::InvalidParameter {
271                name: "importances".into(),
272                reason: "importance vector must not be empty".into(),
273            });
274        }
275        if cv_scores.len() != n_features {
276            return Err(FerroError::InvalidParameter {
277                name: "cv_scores".into(),
278                reason: format!(
279                    "cv_scores length ({}) must equal number of features ({})",
280                    cv_scores.len(),
281                    n_features
282                ),
283            });
284        }
285
286        // Find the optimal number of features (1-indexed)
287        let mut best_idx = 0;
288        let mut best_score = f64::NEG_INFINITY;
289        for (i, &score) in cv_scores.iter().enumerate() {
290            if score > best_score {
291                best_score = score;
292                best_idx = i;
293            }
294        }
295        let optimal_n_features = best_idx + 1;
296
297        let rfe = RFE::new(importances, optimal_n_features, step)?;
298
299        Ok(Self {
300            rfe,
301            cv_scores: cv_scores.to_vec(),
302            optimal_n_features,
303        })
304    }
305
306    /// Return the CV scores.
307    #[must_use]
308    pub fn cv_scores(&self) -> &[f64] {
309        &self.cv_scores
310    }
311
312    /// Return the optimal number of features.
313    #[must_use]
314    pub fn optimal_n_features(&self) -> usize {
315        self.optimal_n_features
316    }
317
318    /// Return the number of selected features.
319    #[must_use]
320    pub fn n_features_selected(&self) -> usize {
321        self.rfe.n_features_selected()
322    }
323
324    /// Return the feature ranking.
325    #[must_use]
326    pub fn ranking(&self) -> &[usize] {
327        self.rfe.ranking()
328    }
329
330    /// Return the boolean support mask.
331    #[must_use]
332    pub fn support(&self) -> &[bool] {
333        self.rfe.support()
334    }
335
336    /// Return the indices of the selected features.
337    #[must_use]
338    pub fn selected_indices(&self) -> &[usize] {
339        self.rfe.selected_indices()
340    }
341}
342
343impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for RFECV<F> {
344    type Output = Array2<F>;
345    type Error = FerroError;
346
347    /// Return a matrix containing only the optimally selected features.
348    ///
349    /// # Errors
350    ///
351    /// Returns [`FerroError::ShapeMismatch`] if column count does not match.
352    fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
353        self.rfe.transform(x)
354    }
355}
356
357// ---------------------------------------------------------------------------
358// Tests
359// ---------------------------------------------------------------------------
360
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    // ========================================================================
    // RFE tests
    // ========================================================================

    /// Keeping one of three features with step 1 drops feature 2 first and
    /// feature 1 second, so the ranks follow the importance order.
    #[test]
    fn test_rfe_basic_ranking() {
        let importances = array![0.6, 0.3, 0.1];
        let selector = RFE::<f64>::new(&importances, 1, 1).unwrap();
        assert_eq!(selector.ranking().to_vec(), vec![1, 2, 3]);
        assert_eq!(selector.support().to_vec(), vec![true, false, false]);
        assert_eq!(selector.selected_indices().to_vec(), vec![0]);
    }

    /// Keeping two of three features drops only the weakest one.
    #[test]
    fn test_rfe_select_two() {
        let importances = array![0.5, 0.3, 0.2];
        let selector = RFE::<f64>::new(&importances, 2, 1).unwrap();
        assert_eq!(selector.n_features_selected(), 2);
        let ranks = selector.ranking();
        // Feature 2 (0.2) is the only one eliminated.
        assert_eq!(ranks[2], 2);
        assert_eq!(ranks[0], 1);
        assert_eq!(ranks[1], 1);
    }

    /// A step of two removes both weak features in a single round.
    #[test]
    fn test_rfe_step_two() {
        let importances = array![0.5, 0.3, 0.2, 0.1];
        let selector = RFE::<f64>::new(&importances, 2, 2).unwrap();
        assert_eq!(selector.n_features_selected(), 2);
        let mask = selector.support();
        assert!(mask[0]);
        assert!(mask[1]);
        assert!(!mask[2]);
        assert!(!mask[3]);
    }

    /// `transform` returns only the column of the kept feature.
    #[test]
    fn test_rfe_transform() {
        let importances = array![0.6, 0.3, 0.1];
        let selector = RFE::<f64>::new(&importances, 1, 1).unwrap();
        let data = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let reduced = selector.transform(&data).unwrap();
        assert_eq!(reduced.ncols(), 1);
        // Feature 0 is the one that survives.
        assert_abs_diff_eq!(reduced[[0, 0]], 1.0, epsilon = 1e-15);
        assert_abs_diff_eq!(reduced[[1, 0]], 4.0, epsilon = 1e-15);
    }

    /// Asking for every feature eliminates nothing.
    #[test]
    fn test_rfe_all_features_selected() {
        let importances = array![0.5, 0.3, 0.2];
        let selector = RFE::<f64>::new(&importances, 3, 1).unwrap();
        assert_eq!(selector.n_features_selected(), 3);
        assert!(selector.support().iter().all(|&kept| kept));
    }

    /// An empty importance vector is rejected.
    #[test]
    fn test_rfe_empty_importances_error() {
        let importances: Array1<f64> = Array1::zeros(0);
        let result = RFE::<f64>::new(&importances, 1, 1);
        assert!(result.is_err());
    }

    /// A step of zero is rejected.
    #[test]
    fn test_rfe_zero_step_error() {
        let importances = array![0.5, 0.3];
        let result = RFE::<f64>::new(&importances, 1, 0);
        assert!(result.is_err());
    }

    /// Asking for more features than exist is rejected.
    #[test]
    fn test_rfe_n_features_too_large_error() {
        let importances = array![0.5, 0.3];
        let result = RFE::<f64>::new(&importances, 5, 1);
        assert!(result.is_err());
    }

    /// Asking for zero features is rejected.
    #[test]
    fn test_rfe_n_features_zero_error() {
        let importances = array![0.5, 0.3];
        let result = RFE::<f64>::new(&importances, 0, 1);
        assert!(result.is_err());
    }

    /// `transform` refuses input with the wrong number of columns.
    #[test]
    fn test_rfe_shape_mismatch_error() {
        let importances = array![0.5, 0.3];
        let selector = RFE::<f64>::new(&importances, 1, 1).unwrap();
        let wrong_width = array![[1.0, 2.0, 3.0]];
        assert!(selector.transform(&wrong_width).is_err());
    }

    // ========================================================================
    // RFECV tests
    // ========================================================================

    /// The feature count with the highest CV score is the one kept.
    #[test]
    fn test_rfecv_selects_optimal() {
        let importances = array![0.5, 0.3, 0.2];
        // The best score belongs to keeping two features.
        let scores = vec![0.85, 0.95, 0.90];
        let selector = RFECV::<f64>::new(&importances, &scores, 1).unwrap();
        assert_eq!(selector.optimal_n_features(), 2);
        assert_eq!(selector.n_features_selected(), 2);
    }

    /// `transform` keeps as many columns as the optimal feature count.
    #[test]
    fn test_rfecv_transform() {
        let importances = array![0.5, 0.3, 0.2];
        let scores = vec![0.85, 0.95, 0.90];
        let selector = RFECV::<f64>::new(&importances, &scores, 1).unwrap();
        let data = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let reduced = selector.transform(&data).unwrap();
        assert_eq!(reduced.ncols(), 2);
    }

    /// The stored scores are returned unchanged.
    #[test]
    fn test_rfecv_cv_scores_accessor() {
        let importances = array![0.5, 0.3];
        let scores = vec![0.9, 0.8];
        let selector = RFECV::<f64>::new(&importances, &scores, 1).unwrap();
        assert_eq!(selector.cv_scores(), &[0.9, 0.8]);
        // Keeping one feature scores best (0.9).
        assert_eq!(selector.optimal_n_features(), 1);
    }

    /// A score vector of the wrong length is rejected.
    #[test]
    fn test_rfecv_mismatched_scores_error() {
        let importances = array![0.5, 0.3, 0.2];
        let too_few_scores = vec![0.85, 0.95];
        let result = RFECV::<f64>::new(&importances, &too_few_scores, 1);
        assert!(result.is_err());
    }

    /// An empty importance vector is rejected.
    #[test]
    fn test_rfecv_empty_importances_error() {
        let importances: Array1<f64> = Array1::zeros(0);
        let scores: Vec<f64> = Vec::new();
        let result = RFECV::<f64>::new(&importances, &scores, 1);
        assert!(result.is_err());
    }

    /// The support mask marks exactly as many features as were kept.
    #[test]
    fn test_rfecv_ranking_and_support() {
        let importances = array![0.5, 0.3, 0.2];
        let scores = vec![0.80, 0.95, 0.90];
        let selector = RFECV::<f64>::new(&importances, &scores, 1).unwrap();
        assert_eq!(selector.n_features_selected(), 2);
        let kept = selector.support().iter().filter(|&&flag| flag).count();
        assert_eq!(kept, 2);
    }
}