Skip to main content

ferrolearn_preprocess/
rfe.rs

1//! Recursive Feature Elimination (RFE) and RFE with Cross-Validation (RFECV).
2//!
3//! [`RFE`] recursively removes the least-important features, ranking features
4//! by their importance at each elimination step. The importance is determined by
5//! an external importance vector that the user supplies via a callback.
6//!
7//! [`RFECV`] extends RFE by using cross-validation to find the optimal number
8//! of features to retain.
9//!
10//! Because `ferrolearn-preprocess` cannot depend on estimator crates (to avoid
11//! circular dependencies), these implementations accept feature importance
12//! vectors directly rather than wrapping fitted estimators.
13
14use ferrolearn_core::error::FerroError;
15use ferrolearn_core::traits::Transform;
16use ndarray::{Array1, Array2};
17use num_traits::Float;
18
19/// Build a new `Array2<F>` containing only the columns listed in `indices`.
20fn select_columns<F: Float>(x: &Array2<F>, indices: &[usize]) -> Array2<F> {
21    let nrows = x.nrows();
22    let ncols = indices.len();
23    if ncols == 0 {
24        return Array2::zeros((nrows, 0));
25    }
26    let mut out = Array2::zeros((nrows, ncols));
27    for (new_j, &old_j) in indices.iter().enumerate() {
28        for i in 0..nrows {
29            out[[i, new_j]] = x[[i, old_j]];
30        }
31    }
32    out
33}
34
35// ===========================================================================
36// RFE
37// ===========================================================================
38
/// Recursive Feature Elimination driven by a fixed importance vector.
///
/// Elimination starts from the full feature set. Each round discards the
/// `step` features with the lowest importance, until exactly
/// `n_features_to_select` remain. Features dropped in the same round share a
/// rank, and every surviving feature has rank 1.
///
/// # Examples
///
/// ```
/// use ferrolearn_preprocess::rfe::RFE;
/// use ferrolearn_core::traits::Transform;
/// use ndarray::array;
///
/// // Feature 0 carries the most importance, feature 2 the least.
/// let importances = array![0.6, 0.3, 0.1];
/// let rfe = RFE::<f64>::new(&importances, 1, 1).unwrap();
/// assert_eq!(rfe.support(), &[true, false, false]);
/// assert_eq!(rfe.ranking(), &[1, 2, 3]);
///
/// let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
/// let reduced = rfe.transform(&x).unwrap();
/// assert_eq!(reduced.ncols(), 1);
/// ```
#[derive(Debug, Clone)]
pub struct RFE<F> {
    /// Per-feature rank. `ranking[j] == 1` means feature `j` survived; larger
    /// values mean it was discarded in an earlier round.
    ranking: Vec<usize>,
    /// Selection mask: `support[j]` is `true` exactly when `ranking[j] == 1`.
    support: Vec<bool>,
    /// Column indices of the surviving features, in ascending order.
    selected_indices: Vec<usize>,
    /// Number of columns the importance vector described. Inputs to
    /// `transform` must have exactly this many columns.
    n_features_in: usize,
    /// Ties the type to the float type `F` without storing any `F` values.
    _marker: std::marker::PhantomData<F>,
}
73
74impl<F: Float + Send + Sync + 'static> RFE<F> {
75    /// Create a new `RFE` from pre-computed feature importances.
76    ///
77    /// # Parameters
78    ///
79    /// - `importances` — per-feature importance scores (higher = more important).
80    /// - `n_features_to_select` — number of features to keep.
81    /// - `step` — number of features to remove per iteration.
82    ///
83    /// # Errors
84    ///
85    /// - [`FerroError::InvalidParameter`] if `importances` is empty, `step` is
86    ///   zero, or `n_features_to_select` exceeds the number of features.
87    pub fn new(
88        importances: &Array1<F>,
89        n_features_to_select: usize,
90        step: usize,
91    ) -> Result<Self, FerroError> {
92        let n_features = importances.len();
93        if n_features == 0 {
94            return Err(FerroError::InvalidParameter {
95                name: "importances".into(),
96                reason: "importance vector must not be empty".into(),
97            });
98        }
99        if step == 0 {
100            return Err(FerroError::InvalidParameter {
101                name: "step".into(),
102                reason: "step must be at least 1".into(),
103            });
104        }
105        if n_features_to_select == 0 || n_features_to_select > n_features {
106            return Err(FerroError::InvalidParameter {
107                name: "n_features_to_select".into(),
108                reason: format!(
109                    "n_features_to_select ({}) must be in [1, {}]",
110                    n_features_to_select, n_features
111                ),
112            });
113        }
114
115        // Simulate the elimination process.
116        // Track which round each feature is eliminated in; features removed in
117        // the same step share the same rank. Selected features get rank 1,
118        // features removed in the last elimination round get rank 2, etc.
119        let mut ranking = vec![0usize; n_features];
120        let mut remaining: Vec<usize> = (0..n_features).collect();
121        let mut elimination_rounds: Vec<Vec<usize>> = Vec::new();
122
123        // Working copy of importances
124        let imp: Vec<F> = importances.iter().copied().collect();
125
126        while remaining.len() > n_features_to_select {
127            // Sort remaining by importance (ascending)
128            remaining.sort_by(|&a, &b| {
129                imp[a]
130                    .partial_cmp(&imp[b])
131                    .unwrap_or(std::cmp::Ordering::Equal)
132            });
133
134            // Remove the `step` least important features
135            let n_to_remove = step.min(remaining.len() - n_features_to_select);
136            let removed: Vec<usize> = remaining[..n_to_remove].to_vec();
137            elimination_rounds.push(removed);
138            remaining = remaining[n_to_remove..].to_vec();
139        }
140
141        // Assign ranks: selected features get rank 1, features removed in the
142        // last round get rank 2, second-to-last round gets rank 3, etc.
143        for &idx in &remaining {
144            ranking[idx] = 1;
145        }
146        for (round_idx, round) in elimination_rounds.iter().rev().enumerate() {
147            let rank = round_idx + 2;
148            for &idx in round {
149                ranking[idx] = rank;
150            }
151        }
152
153        let support: Vec<bool> = ranking.iter().map(|&r| r == 1).collect();
154        let mut selected_indices: Vec<usize> = remaining;
155        selected_indices.sort_unstable();
156
157        Ok(Self {
158            ranking,
159            support,
160            selected_indices,
161            n_features_in: n_features,
162            _marker: std::marker::PhantomData,
163        })
164    }
165
166    /// Return the feature ranking (1 = best, higher = eliminated earlier).
167    #[must_use]
168    pub fn ranking(&self) -> &[usize] {
169        &self.ranking
170    }
171
172    /// Return the boolean support mask.
173    #[must_use]
174    pub fn support(&self) -> &[bool] {
175        &self.support
176    }
177
178    /// Return the indices of the selected features.
179    #[must_use]
180    pub fn selected_indices(&self) -> &[usize] {
181        &self.selected_indices
182    }
183
184    /// Return the number of selected features.
185    #[must_use]
186    pub fn n_features_selected(&self) -> usize {
187        self.selected_indices.len()
188    }
189}
190
191impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for RFE<F> {
192    type Output = Array2<F>;
193    type Error = FerroError;
194
195    /// Return a matrix containing only the selected features.
196    ///
197    /// # Errors
198    ///
199    /// Returns [`FerroError::ShapeMismatch`] if the number of columns differs
200    /// from the number of features used at construction.
201    fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
202        if x.ncols() != self.n_features_in {
203            return Err(FerroError::ShapeMismatch {
204                expected: vec![x.nrows(), self.n_features_in],
205                actual: vec![x.nrows(), x.ncols()],
206                context: "RFE::transform".into(),
207            });
208        }
209        Ok(select_columns(x, &self.selected_indices))
210    }
211}
212
213// ===========================================================================
214// RFECV
215// ===========================================================================
216
217/// Recursive Feature Elimination with Cross-Validation.
218///
219/// Like [`RFE`], but uses cross-validation scores to determine the optimal
220/// number of features. The user supplies a vector of per-feature-count CV
221/// scores (e.g., from running RFE with different `n_features_to_select`
222/// values), and RFECV picks the number that maximises the score.
223///
224/// # Examples
225///
226/// ```
227/// use ferrolearn_preprocess::rfe::RFECV;
228/// use ferrolearn_core::traits::Transform;
229/// use ndarray::array;
230///
231/// let importances = array![0.5, 0.3, 0.2];
232/// // CV scores for selecting 1, 2, 3 features:
233/// let cv_scores = vec![0.85, 0.95, 0.90];
234/// let rfecv = RFECV::<f64>::new(&importances, &cv_scores, 1).unwrap();
235/// // Best is 2 features (score 0.95)
236/// assert_eq!(rfecv.n_features_selected(), 2);
237/// let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
238/// let out = rfecv.transform(&x).unwrap();
239/// assert_eq!(out.ncols(), 2);
240/// ```
241#[derive(Debug, Clone)]
242pub struct RFECV<F> {
243    /// The underlying RFE with the optimal number of features.
244    rfe: RFE<F>,
245    /// CV scores for each number of features (1..=n_features).
246    cv_scores: Vec<f64>,
247    /// The optimal number of features (1-indexed).
248    optimal_n_features: usize,
249}
250
251impl<F: Float + Send + Sync + 'static> RFECV<F> {
252    /// Create a new `RFECV` from pre-computed importances and CV scores.
253    ///
254    /// # Parameters
255    ///
256    /// - `importances` — per-feature importance scores.
257    /// - `cv_scores` — CV score for each possible number of features
258    ///   (index 0 = 1 feature, index 1 = 2 features, ...).
259    /// - `step` — features removed per iteration.
260    ///
261    /// # Errors
262    ///
263    /// - [`FerroError::InvalidParameter`] if inputs are empty or mismatched.
264    pub fn new(
265        importances: &Array1<F>,
266        cv_scores: &[f64],
267        step: usize,
268    ) -> Result<Self, FerroError> {
269        let n_features = importances.len();
270        if n_features == 0 {
271            return Err(FerroError::InvalidParameter {
272                name: "importances".into(),
273                reason: "importance vector must not be empty".into(),
274            });
275        }
276        if cv_scores.len() != n_features {
277            return Err(FerroError::InvalidParameter {
278                name: "cv_scores".into(),
279                reason: format!(
280                    "cv_scores length ({}) must equal number of features ({})",
281                    cv_scores.len(),
282                    n_features
283                ),
284            });
285        }
286
287        // Find the optimal number of features (1-indexed)
288        let mut best_idx = 0;
289        let mut best_score = f64::NEG_INFINITY;
290        for (i, &score) in cv_scores.iter().enumerate() {
291            if score > best_score {
292                best_score = score;
293                best_idx = i;
294            }
295        }
296        let optimal_n_features = best_idx + 1;
297
298        let rfe = RFE::new(importances, optimal_n_features, step)?;
299
300        Ok(Self {
301            rfe,
302            cv_scores: cv_scores.to_vec(),
303            optimal_n_features,
304        })
305    }
306
307    /// Return the CV scores.
308    #[must_use]
309    pub fn cv_scores(&self) -> &[f64] {
310        &self.cv_scores
311    }
312
313    /// Return the optimal number of features.
314    #[must_use]
315    pub fn optimal_n_features(&self) -> usize {
316        self.optimal_n_features
317    }
318
319    /// Return the number of selected features.
320    #[must_use]
321    pub fn n_features_selected(&self) -> usize {
322        self.rfe.n_features_selected()
323    }
324
325    /// Return the feature ranking.
326    #[must_use]
327    pub fn ranking(&self) -> &[usize] {
328        self.rfe.ranking()
329    }
330
331    /// Return the boolean support mask.
332    #[must_use]
333    pub fn support(&self) -> &[bool] {
334        self.rfe.support()
335    }
336
337    /// Return the indices of the selected features.
338    #[must_use]
339    pub fn selected_indices(&self) -> &[usize] {
340        self.rfe.selected_indices()
341    }
342}
343
344impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for RFECV<F> {
345    type Output = Array2<F>;
346    type Error = FerroError;
347
348    /// Return a matrix containing only the optimally selected features.
349    ///
350    /// # Errors
351    ///
352    /// Returns [`FerroError::ShapeMismatch`] if column count does not match.
353    fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
354        self.rfe.transform(x)
355    }
356}
357
358// ---------------------------------------------------------------------------
359// Tests
360// ---------------------------------------------------------------------------
361
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    // ========================================================================
    // RFE tests
    // ========================================================================

    #[test]
    fn test_rfe_basic_ranking() {
        // Importances: [0.6, 0.3, 0.1]
        // Select 1 feature, step 1
        // Round 1: remove feature 2 (lowest 0.1) → remaining [0, 1]
        // Round 2: remove feature 1 (lowest 0.3) → remaining [0]
        let imp = array![0.6, 0.3, 0.1];
        let rfe = RFE::<f64>::new(&imp, 1, 1).unwrap();
        assert_eq!(rfe.ranking(), &[1, 2, 3]);
        assert_eq!(rfe.support(), &[true, false, false]);
        assert_eq!(rfe.selected_indices(), &[0]);
    }

    #[test]
    fn test_rfe_select_two() {
        let imp = array![0.5, 0.3, 0.2];
        let rfe = RFE::<f64>::new(&imp, 2, 1).unwrap();
        assert_eq!(rfe.n_features_selected(), 2);
        // Feature 2 (0.2) should be eliminated first
        assert_eq!(rfe.ranking()[2], 2); // eliminated in round 1
        assert_eq!(rfe.ranking()[0], 1);
        assert_eq!(rfe.ranking()[1], 1);
    }

    #[test]
    fn test_rfe_step_two() {
        let imp = array![0.5, 0.3, 0.2, 0.1];
        // Select 2, step 2: remove 2 features at once
        let rfe = RFE::<f64>::new(&imp, 2, 2).unwrap();
        assert_eq!(rfe.n_features_selected(), 2);
        assert!(rfe.support()[0]);
        assert!(rfe.support()[1]);
        assert!(!rfe.support()[2]);
        assert!(!rfe.support()[3]);
    }

    #[test]
    fn test_rfe_transform() {
        let imp = array![0.6, 0.3, 0.1];
        let rfe = RFE::<f64>::new(&imp, 1, 1).unwrap();
        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let out = rfe.transform(&x).unwrap();
        assert_eq!(out.ncols(), 1);
        // Feature 0 is selected
        assert_abs_diff_eq!(out[[0, 0]], 1.0, epsilon = 1e-15);
        assert_abs_diff_eq!(out[[1, 0]], 4.0, epsilon = 1e-15);
    }

    #[test]
    fn test_rfe_all_features_selected() {
        let imp = array![0.5, 0.3, 0.2];
        let rfe = RFE::<f64>::new(&imp, 3, 1).unwrap();
        assert_eq!(rfe.n_features_selected(), 3);
        assert!(rfe.support().iter().all(|&s| s));
    }

    #[test]
    fn test_rfe_empty_importances_error() {
        let imp: Array1<f64> = Array1::zeros(0);
        assert!(RFE::<f64>::new(&imp, 1, 1).is_err());
    }

    #[test]
    fn test_rfe_zero_step_error() {
        let imp = array![0.5, 0.3];
        assert!(RFE::<f64>::new(&imp, 1, 0).is_err());
    }

    #[test]
    fn test_rfe_n_features_too_large_error() {
        let imp = array![0.5, 0.3];
        assert!(RFE::<f64>::new(&imp, 5, 1).is_err());
    }

    #[test]
    fn test_rfe_n_features_zero_error() {
        let imp = array![0.5, 0.3];
        assert!(RFE::<f64>::new(&imp, 0, 1).is_err());
    }

    #[test]
    fn test_rfe_shape_mismatch_error() {
        let imp = array![0.5, 0.3];
        let rfe = RFE::<f64>::new(&imp, 1, 1).unwrap();
        let x_bad = array![[1.0, 2.0, 3.0]];
        assert!(rfe.transform(&x_bad).is_err());
    }

    #[test]
    fn test_rfe_step_two_ranking() {
        // Features 3 and 2 are removed together in a single round, so they share rank 2.
        let imp = array![0.5, 0.3, 0.2, 0.1];
        let rfe = RFE::<f64>::new(&imp, 2, 2).unwrap();
        assert_eq!(rfe.ranking(), &[1, 1, 2, 2]);
    }

    #[test]
    fn test_rfe_uneven_final_step() {
        // Three features must go with step 2. Round 1 removes {3, 2}; the final
        // round removes only {1}, which therefore receives the best losing rank.
        let imp = array![0.4, 0.3, 0.2, 0.1];
        let rfe = RFE::<f64>::new(&imp, 1, 2).unwrap();
        assert_eq!(rfe.ranking(), &[1, 2, 3, 3]);
        assert_eq!(rfe.selected_indices(), &[0]);
    }

    #[test]
    fn test_rfe_step_larger_than_needed() {
        // Only one feature has to go, even though step allows ten.
        let imp = array![0.5, 0.3, 0.2];
        let rfe = RFE::<f64>::new(&imp, 2, 10).unwrap();
        assert_eq!(rfe.ranking(), &[1, 1, 2]);
    }

    #[test]
    fn test_rfe_ties_eliminate_lower_index_first() {
        let imp = array![0.5, 0.5, 0.5];
        let rfe = RFE::<f64>::new(&imp, 1, 1).unwrap();
        assert_eq!(rfe.ranking(), &[3, 2, 1]);
        assert_eq!(rfe.selected_indices(), &[2]);
    }

    #[test]
    fn test_rfe_transform_preserves_column_order() {
        let imp = array![0.1, 0.9, 0.5];
        let rfe = RFE::<f64>::new(&imp, 2, 1).unwrap();
        assert_eq!(rfe.selected_indices(), &[1, 2]);
        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let out = rfe.transform(&x).unwrap();
        assert_eq!(out, array![[2.0, 3.0], [5.0, 6.0]]);
    }

    // ========================================================================
    // RFECV tests
    // ========================================================================

    #[test]
    fn test_rfecv_selects_optimal() {
        let imp = array![0.5, 0.3, 0.2];
        // Best CV score at 2 features
        let cv_scores = vec![0.85, 0.95, 0.90];
        let rfecv = RFECV::<f64>::new(&imp, &cv_scores, 1).unwrap();
        assert_eq!(rfecv.optimal_n_features(), 2);
        assert_eq!(rfecv.n_features_selected(), 2);
    }

    #[test]
    fn test_rfecv_transform() {
        let imp = array![0.5, 0.3, 0.2];
        let cv_scores = vec![0.85, 0.95, 0.90];
        let rfecv = RFECV::<f64>::new(&imp, &cv_scores, 1).unwrap();
        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let out = rfecv.transform(&x).unwrap();
        assert_eq!(out.ncols(), 2);
    }

    #[test]
    fn test_rfecv_cv_scores_accessor() {
        let imp = array![0.5, 0.3];
        let cv_scores = vec![0.9, 0.8];
        let rfecv = RFECV::<f64>::new(&imp, &cv_scores, 1).unwrap();
        assert_eq!(rfecv.cv_scores(), &[0.9, 0.8]);
        // Best is 1 feature (score 0.9)
        assert_eq!(rfecv.optimal_n_features(), 1);
    }

    #[test]
    fn test_rfecv_mismatched_scores_error() {
        let imp = array![0.5, 0.3, 0.2];
        let cv_scores = vec![0.85, 0.95]; // wrong length
        assert!(RFECV::<f64>::new(&imp, &cv_scores, 1).is_err());
    }

    #[test]
    fn test_rfecv_empty_importances_error() {
        let imp: Array1<f64> = Array1::zeros(0);
        let cv_scores: Vec<f64> = vec![];
        assert!(RFECV::<f64>::new(&imp, &cv_scores, 1).is_err());
    }

    #[test]
    fn test_rfecv_ranking_and_support() {
        let imp = array![0.5, 0.3, 0.2];
        let cv_scores = vec![0.80, 0.95, 0.90];
        let rfecv = RFECV::<f64>::new(&imp, &cv_scores, 1).unwrap();
        assert_eq!(rfecv.n_features_selected(), 2);
        let support = rfecv.support();
        assert_eq!(support.iter().filter(|&&s| s).count(), 2);
    }

    #[test]
    fn test_rfecv_ties_prefer_fewer_features() {
        let imp = array![0.5, 0.3, 0.2];
        let cv_scores = vec![0.9, 0.9, 0.8];
        let rfecv = RFECV::<f64>::new(&imp, &cv_scores, 1).unwrap();
        assert_eq!(rfecv.optimal_n_features(), 1);
    }

    #[test]
    fn test_rfecv_skips_nan_scores() {
        let imp = array![0.5, 0.3, 0.2];
        let cv_scores = vec![f64::NAN, 0.7, 0.9];
        let rfecv = RFECV::<f64>::new(&imp, &cv_scores, 1).unwrap();
        assert_eq!(rfecv.optimal_n_features(), 3);
    }

    #[test]
    fn test_rfecv_zero_step_error() {
        let imp = array![0.5, 0.3];
        let cv_scores = vec![0.9, 0.8];
        assert!(RFECV::<f64>::new(&imp, &cv_scores, 0).is_err());
    }
}
517}