Skip to main content

anofox_ml_preprocessing/
rfe.rs

1//! Recursive Feature Elimination (RFE).
2//!
3//! Mirrors `sklearn.feature_selection.RFE` with a callback-based API: the
4//! caller provides a function that, given `(X, y)`, returns per-feature
5//! importances (e.g. `|coef_|` for linear models or `feature_importances_`
6//! for trees). RFE repeatedly drops the `step` least-important features
7//! until `n_features_to_select` remain.
8
9use anofox_ml_core::{Result, RustMlError};
10use ndarray::{Array1, Array2};
11
/// Callback that scores features: given a design matrix `X` and a target
/// `y`, it returns one importance value per column of `X`.
///
/// The elimination loops in this module order features by the absolute value
/// of the returned importances, so signed quantities such as linear
/// coefficients can be returned as-is.
pub type ImportanceFn = dyn Fn(&Array2<f64>, &Array1<f64>) -> Result<Array1<f64>> + Send + Sync;
13
14pub struct Rfe {
15    pub n_features_to_select: usize,
16    pub step: usize,
17    importance: Box<ImportanceFn>,
18}
19
20impl Rfe {
21    pub fn new<F>(n_features_to_select: usize, importance_fn: F) -> Self
22    where
23        F: Fn(&Array2<f64>, &Array1<f64>) -> Result<Array1<f64>> + Send + Sync + 'static,
24    {
25        Self {
26            n_features_to_select,
27            step: 1,
28            importance: Box::new(importance_fn),
29        }
30    }
31
32    pub fn with_step(mut self, step: usize) -> Self {
33        self.step = step;
34        self
35    }
36}
37
38pub struct FittedRfe {
39    /// Boolean mask of selected features, length = original n_features.
40    pub support: Vec<bool>,
41    /// Ranking of features (1 = selected; higher = dropped earlier).
42    pub ranking: Vec<usize>,
43}
44
45impl FittedRfe {
46    pub fn transform(&self, x: &Array2<f64>) -> Array2<f64> {
47        let cols: Vec<usize> = self
48            .support
49            .iter()
50            .enumerate()
51            .filter(|(_, &b)| b)
52            .map(|(i, _)| i)
53            .collect();
54        let mut out = Array2::<f64>::zeros((x.nrows(), cols.len()));
55        for (k, &c) in cols.iter().enumerate() {
56            for i in 0..x.nrows() {
57                out[[i, k]] = x[[i, c]];
58            }
59        }
60        out
61    }
62}
63
64fn select_cols(x: &Array2<f64>, cols: &[usize]) -> Array2<f64> {
65    let mut out = Array2::<f64>::zeros((x.nrows(), cols.len()));
66    for (k, &c) in cols.iter().enumerate() {
67        for i in 0..x.nrows() {
68            out[[i, k]] = x[[i, c]];
69        }
70    }
71    out
72}
73
74impl Rfe {
75    pub fn fit(&self, x: &Array2<f64>, y: &Array1<f64>) -> Result<FittedRfe> {
76        if x.nrows() != y.len() {
77            return Err(RustMlError::ShapeMismatch(format!(
78                "X has {} rows but y has {}",
79                x.nrows(),
80                y.len()
81            )));
82        }
83        let d = x.ncols();
84        if self.n_features_to_select == 0 || self.n_features_to_select > d {
85            return Err(RustMlError::InvalidParameter(format!(
86                "n_features_to_select must be in 1..={}",
87                d
88            )));
89        }
90        let mut active: Vec<usize> = (0..d).collect();
91        let mut ranking = vec![0usize; d];
92
93        while active.len() > self.n_features_to_select {
94            let sub = select_cols(x, &active);
95            let imp = (self.importance)(&sub, y)?;
96            if imp.len() != active.len() {
97                return Err(RustMlError::InvalidParameter(
98                    "importance function returned wrong length".into(),
99                ));
100            }
101            // Sort active features by ascending importance; drop step
102            // least-important without going below the target.
103            let n_drop = self.step.min(active.len() - self.n_features_to_select);
104            let mut order: Vec<usize> = (0..active.len()).collect();
105            order.sort_by(|&a, &b| imp[a].abs().partial_cmp(&imp[b].abs()).unwrap());
106            let to_drop: Vec<usize> = order[..n_drop].iter().map(|&i| active[i]).collect();
107            for &j in &to_drop {
108                ranking[j] = active.len(); // ranking number set at time of drop
109            }
110            active.retain(|i| !to_drop.contains(i));
111        }
112        for &j in &active {
113            ranking[j] = 1;
114        }
115        let mut support = vec![false; d];
116        for &j in &active {
117            support[j] = true;
118        }
119        Ok(FittedRfe { support, ranking })
120    }
121}
122
123// ---------------------------------------------------------------------------
124// RFECV — CV-aware wrapper around Rfe that auto-selects n_features_to_select.
125// ---------------------------------------------------------------------------
126
/// Recursive Feature Elimination with Cross-Validated selection of the
/// optimal number of features.
///
/// Mirrors `sklearn.feature_selection.RFECV`. For each candidate
/// `n_features_to_select` in `min..=n_features`, runs `Rfe` on each k-fold
/// split, scores on the held-out fold, averages — picks the size with the
/// highest mean CV score, then runs RFE on the full data to that size.
pub struct Rfecv {
    /// Smallest candidate feature count evaluated during the search.
    pub min_features_to_select: usize,
    /// Number of features removed per elimination round.
    pub step: usize,
    /// Requested number of cross-validation folds; `fit` caps it at the
    /// number of rows.
    pub cv_folds: usize,
    /// Callback producing per-feature importances for the elimination loop.
    importance: Box<ImportanceFn>,
    /// Callback that scores the selected columns of a held-out fold.
    score: Box<ScoringFn>,
}
141
142impl Rfecv {
143    pub fn new<I, S>(min_features_to_select: usize, importance_fn: I, score: S) -> Self
144    where
145        I: Fn(&Array2<f64>, &Array1<f64>) -> Result<Array1<f64>> + Send + Sync + 'static,
146        S: Fn(&Array2<f64>, &Array1<f64>) -> Result<f64> + Send + Sync + 'static,
147    {
148        Self {
149            min_features_to_select,
150            step: 1,
151            cv_folds: 5,
152            importance: Box::new(importance_fn),
153            score: Box::new(score),
154        }
155    }
156    pub fn with_cv_folds(mut self, k: usize) -> Self {
157        self.cv_folds = k;
158        self
159    }
160    pub fn with_step(mut self, step: usize) -> Self {
161        self.step = step;
162        self
163    }
164
165    pub fn fit(&self, x: &Array2<f64>, y: &Array1<f64>) -> Result<FittedRfecv> {
166        let n = x.nrows();
167        let d = x.ncols();
168        let k = self.cv_folds.min(n);
169        let folds = kfold(n, k);
170
171        // For each candidate size, do CV.
172        let mut mean_scores = Vec::with_capacity(d - self.min_features_to_select + 1);
173        let mut sizes = Vec::new();
174        for size in self.min_features_to_select..=d {
175            let mut scores = Vec::with_capacity(k);
176            for (train_idx, test_idx) in &folds {
177                let x_train = select_rows(x, train_idx);
178                let y_train = select_elements(y, train_idx);
179                let x_test = select_rows(x, test_idx);
180                let y_test = select_elements(y, test_idx);
181
182                // Inline RFE on this fold using our owned importance closure.
183                let support = run_rfe(
184                    &x_train,
185                    &y_train,
186                    size,
187                    self.step,
188                    self.importance.as_ref(),
189                )?;
190                let x_test_sel = select_cols_mask(&x_test, &support);
191                let s = (self.score)(&x_test_sel, &y_test)?;
192                scores.push(s);
193            }
194            let mean = scores.iter().sum::<f64>() / scores.len() as f64;
195            mean_scores.push(mean);
196            sizes.push(size);
197        }
198
199        let (best_i, _) = mean_scores
200            .iter()
201            .enumerate()
202            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
203            .unwrap();
204        let best_size = sizes[best_i];
205
206        let support = run_rfe(x, y, best_size, self.step, self.importance.as_ref())?;
207        let mut ranking = vec![0usize; d];
208        for (i, &b) in support.iter().enumerate() {
209            ranking[i] = if b { 1 } else { 2 };
210        }
211        Ok(FittedRfecv {
212            n_features_selected: best_size,
213            cv_scores: mean_scores,
214            sizes,
215            inner: FittedRfe { support, ranking },
216        })
217    }
218}
219
220/// Core RFE elimination loop, factored out so Rfecv can call it without
221/// constructing a fresh `Rfe` (which would conflict with closure lifetimes).
222fn run_rfe(
223    x: &Array2<f64>,
224    y: &Array1<f64>,
225    n_features_to_select: usize,
226    step: usize,
227    importance: &ImportanceFn,
228) -> Result<Vec<bool>> {
229    let d = x.ncols();
230    let mut active: Vec<usize> = (0..d).collect();
231    while active.len() > n_features_to_select {
232        let sub = select_cols(x, &active);
233        let imp = importance(&sub, y)?;
234        let n_drop = step.min(active.len() - n_features_to_select);
235        let mut order: Vec<usize> = (0..active.len()).collect();
236        order.sort_by(|&a, &b| imp[a].abs().partial_cmp(&imp[b].abs()).unwrap());
237        let to_drop: Vec<usize> = order[..n_drop].iter().map(|&i| active[i]).collect();
238        active.retain(|i| !to_drop.contains(i));
239    }
240    let mut support = vec![false; d];
241    for &j in &active {
242        support[j] = true;
243    }
244    Ok(support)
245}
246
247fn select_cols_mask(x: &Array2<f64>, mask: &[bool]) -> Array2<f64> {
248    let cols: Vec<usize> = mask
249        .iter()
250        .enumerate()
251        .filter(|(_, &b)| b)
252        .map(|(i, _)| i)
253        .collect();
254    let mut out = Array2::<f64>::zeros((x.nrows(), cols.len()));
255    for (k, &c) in cols.iter().enumerate() {
256        for i in 0..x.nrows() {
257            out[[i, k]] = x[[i, c]];
258        }
259    }
260    out
261}
262
/// Result of `Rfecv::fit`: the chosen feature count, the per-size
/// cross-validation scores, and the final feature mask.
pub struct FittedRfecv {
    /// Feature count with the best mean cross-validation score.
    pub n_features_selected: usize,
    /// Mean cross-validation score of each candidate size, in the same
    /// order as `sizes`.
    pub cv_scores: Vec<f64>,
    /// Candidate feature counts that were evaluated.
    pub sizes: Vec<usize>,
    /// Support mask and ranking from the final elimination on the full
    /// data (rank 1 = selected, rank 2 = not selected).
    pub inner: FittedRfe,
}

impl FittedRfecv {
    /// Keeps only the selected columns of `x` by delegating to
    /// `FittedRfe::transform` on `inner`.
    pub fn transform(&self, x: &Array2<f64>) -> Array2<f64> {
        self.inner.transform(x)
    }
}
275
/// Splits the row indices `0..n` into `k` contiguous, unshuffled folds.
///
/// Returns one `(train, test)` pair per fold. Test ranges are consecutive
/// and together cover `0..n`; the first `n % k` folds hold one extra row.
/// `train` is every index outside the fold's test range. When `k > n` the
/// trailing folds have an empty test set.
///
/// Returns an empty vector when `k == 0` (there are no folds to build)
/// instead of dividing by zero.
fn kfold(n: usize, k: usize) -> Vec<(Vec<usize>, Vec<usize>)> {
    if k == 0 {
        return Vec::new();
    }
    let base = n / k;
    let extra = n % k;
    let mut folds = Vec::with_capacity(k);
    let mut start = 0;
    for f in 0..k {
        // The remainder is spread over the leading folds, one row each.
        let end = start + base + usize::from(f < extra);
        let test: Vec<usize> = (start..end).collect();
        let train: Vec<usize> = (0..start).chain(end..n).collect();
        folds.push((train, test));
        start = end;
    }
    folds
}
290
291fn select_rows(x: &Array2<f64>, idx: &[usize]) -> Array2<f64> {
292    let mut out = Array2::<f64>::zeros((idx.len(), x.ncols()));
293    for (k, &i) in idx.iter().enumerate() {
294        for j in 0..x.ncols() {
295            out[[k, j]] = x[[i, j]];
296        }
297    }
298    out
299}
300
301fn select_elements(y: &Array1<f64>, idx: &[usize]) -> Array1<f64> {
302    Array1::from_vec(idx.iter().map(|&i| y[i]).collect())
303}
304
305// ---------------------------------------------------------------------------
// SequentialFeatureSelector (forward direction only; any cross-validation is up to the caller-supplied scoring function)
307// ---------------------------------------------------------------------------
308
/// Callback that evaluates a candidate feature subset: given a matrix
/// restricted to the candidate columns and a target vector, it returns a
/// single score. The selectors in this module treat higher scores as better.
pub type ScoringFn = dyn Fn(&Array2<f64>, &Array1<f64>) -> Result<f64> + Send + Sync;
310
311pub struct SequentialFeatureSelector {
312    pub n_features_to_select: usize,
313    score: Box<ScoringFn>,
314}
315
316impl SequentialFeatureSelector {
317    pub fn new<F>(n_features_to_select: usize, score: F) -> Self
318    where
319        F: Fn(&Array2<f64>, &Array1<f64>) -> Result<f64> + Send + Sync + 'static,
320    {
321        Self {
322            n_features_to_select,
323            score: Box::new(score),
324        }
325    }
326}
327
328pub struct FittedSequentialFeatureSelector {
329    pub support: Vec<bool>,
330}
331
332impl FittedSequentialFeatureSelector {
333    pub fn transform(&self, x: &Array2<f64>) -> Array2<f64> {
334        let cols: Vec<usize> = self
335            .support
336            .iter()
337            .enumerate()
338            .filter(|(_, &b)| b)
339            .map(|(i, _)| i)
340            .collect();
341        let mut out = Array2::<f64>::zeros((x.nrows(), cols.len()));
342        for (k, &c) in cols.iter().enumerate() {
343            for i in 0..x.nrows() {
344                out[[i, k]] = x[[i, c]];
345            }
346        }
347        out
348    }
349}
350
351impl SequentialFeatureSelector {
352    pub fn fit(&self, x: &Array2<f64>, y: &Array1<f64>) -> Result<FittedSequentialFeatureSelector> {
353        let d = x.ncols();
354        if self.n_features_to_select == 0 || self.n_features_to_select > d {
355            return Err(RustMlError::InvalidParameter("invalid k".into()));
356        }
357        let mut selected: Vec<usize> = Vec::with_capacity(self.n_features_to_select);
358        let mut remaining: Vec<usize> = (0..d).collect();
359
360        while selected.len() < self.n_features_to_select {
361            let mut best_score = f64::NEG_INFINITY;
362            let mut best_j = remaining[0];
363            for (i, &j) in remaining.iter().enumerate() {
364                let mut cols = selected.clone();
365                cols.push(j);
366                let sub = select_cols(x, &cols);
367                let s = (self.score)(&sub, y)?;
368                if s > best_score {
369                    best_score = s;
370                    best_j = j;
371                    let _ = i; // silence
372                }
373            }
374            selected.push(best_j);
375            remaining.retain(|&j| j != best_j);
376        }
377
378        let mut support = vec![false; d];
379        for &j in &selected {
380            support[j] = true;
381        }
382        Ok(FittedSequentialFeatureSelector { support })
383    }
384}
385
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Test importance callback: the absolute Pearson correlation of each
    /// column of `x` with `y`.
    fn dummy_importance(x: &Array2<f64>, y: &Array1<f64>) -> Result<Array1<f64>> {
        // Use absolute correlation of each column with y.
        let n = x.nrows() as f64;
        let mut out = Array1::<f64>::zeros(x.ncols());
        let y_mean = y.sum() / n;
        for j in 0..x.ncols() {
            let m = x.column(j).sum() / n;
            let mut num = 0.0;
            let mut sx = 0.0;
            let mut sy = 0.0;
            for i in 0..x.nrows() {
                let dx = x[[i, j]] - m;
                let dy = y[i] - y_mean;
                num += dx * dy;
                sx += dx * dx;
                sy += dy * dy;
            }
            // Floor the denominator so a constant column yields 0, not NaN.
            let den = (sx * sy).sqrt().max(1e-12);
            out[j] = (num / den).abs();
        }
        Ok(out)
    }

    /// RFE with a target of two features must keep exactly the two columns
    /// that generate `y`.
    #[test]
    fn test_rfe_keeps_correlated_features() {
        // y = 3x0 + 0*x1 + 2*x2 + 0*x3
        let n = 40;
        let mut xv = Vec::new();
        let mut yv = Vec::new();
        for i in 0..n {
            let x0 = (i as f64) - 20.0;
            let x1 = ((i * 11 % 13) as f64) - 6.0;
            let x2 = ((i * 7 % 17) as f64) - 8.0;
            let x3 = ((i * 5 % 11) as f64) - 5.0;
            xv.extend([x0, x1, x2, x3]);
            yv.push(3.0 * x0 + 2.0 * x2);
        }
        let x = Array2::from_shape_vec((n, 4), xv).unwrap();
        let y = Array1::from_vec(yv);

        let rfe = Rfe::new(2, dummy_importance);
        let fitted = rfe.fit(&x, &y).unwrap();
        assert!(fitted.support[0]);
        assert!(fitted.support[2]);
        assert!(!fitted.support[1]);
        assert!(!fitted.support[3]);
        // Only here so the `array` import above counts as used.
        let _ = array![1.0_f64];
    }

    /// RFECV over sizes 1..=4 must keep both informative columns and report
    /// one mean score per candidate size.
    #[test]
    fn test_rfecv_finds_2_informative_features() {
        // Same generating formula as the RFE test (with more rows); RFECV
        // has to choose the feature count itself.
        let n = 60;
        let mut xv = Vec::new();
        let mut yv = Vec::new();
        for i in 0..n {
            let x0 = (i as f64) - 30.0;
            let x1 = ((i * 11 % 13) as f64) - 6.0;
            let x2 = ((i * 7 % 17) as f64) - 8.0;
            let x3 = ((i * 5 % 11) as f64) - 5.0;
            xv.extend([x0, x1, x2, x3]);
            yv.push(3.0 * x0 + 2.0 * x2);
        }
        let x = Array2::from_shape_vec((n, 4), xv).unwrap();
        let y = Array1::from_vec(yv);

        // Scoring: 1 - (rss / tss) on test data using a simple least-squares
        // refit on selected features.
        let score_fn = |xs: &Array2<f64>, ys: &Array1<f64>| -> Result<f64> {
            // Center y, then closed-form OLS.
            let n = xs.nrows() as f64;
            let y_mean = ys.sum() / n.max(1.0);
            let yc = ys.mapv(|v| v - y_mean);
            // OLS via normal equations: build X^T X and X^T y.
            let m = xs.ncols();
            let mut xtx = Array2::<f64>::zeros((m, m));
            let mut xty = Array1::<f64>::zeros(m);
            for i in 0..m {
                for j in 0..m {
                    let mut s = 0.0;
                    for k in 0..xs.nrows() {
                        s += xs[[k, i]] * xs[[k, j]];
                    }
                    xtx[[i, j]] = s;
                }
                xtx[[i, i]] += 1e-9; // tiny ridge for stability
                let mut s = 0.0;
                for k in 0..xs.nrows() {
                    s += xs[[k, i]] * yc[k];
                }
                xty[i] = s;
            }
            // Solve via Gaussian elimination (no pivoting): forward pass.
            let mut a = xtx.clone();
            let mut b = xty.clone();
            for col in 0..m {
                let pv = a[[col, col]];
                // Skip a (near-)zero pivot rather than divide by it.
                if pv.abs() < 1e-14 {
                    continue;
                }
                for r in (col + 1)..m {
                    let f = a[[r, col]] / pv;
                    for c in col..m {
                        a[[r, c]] -= f * a[[col, c]];
                    }
                    b[r] -= f * b[col];
                }
            }
            // Back substitution; coefficients with a degenerate pivot stay 0.
            let mut beta = Array1::<f64>::zeros(m);
            for r in (0..m).rev() {
                let mut s = b[r];
                for c in (r + 1)..m {
                    s -= a[[r, c]] * beta[c];
                }
                let pv = a[[r, r]];
                if pv.abs() > 1e-14 {
                    beta[r] = s / pv;
                }
            }
            // Predictions add back the mean removed from y before the fit.
            let mut pred = Array1::<f64>::zeros(xs.nrows());
            for i in 0..xs.nrows() {
                let mut p = y_mean;
                for j in 0..m {
                    p += xs[[i, j]] * beta[j];
                }
                pred[i] = p;
            }
            let rss: f64 = pred
                .iter()
                .zip(ys.iter())
                .map(|(p, t)| (p - t).powi(2))
                .sum();
            let tss: f64 = ys.iter().map(|t| (t - y_mean).powi(2)).sum();
            // Floor tss so a constant target does not divide by zero.
            Ok(1.0 - rss / tss.max(1e-12))
        };

        let rfecv = Rfecv::new(1, dummy_importance, score_fn).with_cv_folds(3);
        let fitted = rfecv.fit(&x, &y).unwrap();
        // Without noise penalty, more features often "wins" on test R²; just
        // assert ≥2 features were kept and both informative features are in.
        assert!(fitted.n_features_selected >= 2);
        assert!(fitted.inner.support[0]);
        assert!(fitted.inner.support[2]);
        assert_eq!(fitted.cv_scores.len(), 4); // size 1..=4
    }
}