Skip to main content

ferrolearn_preprocess/
sequential_feature_selector.rs

1//! Sequential feature selection via forward or backward search.
2//!
3//! [`SequentialFeatureSelector`] greedily adds (forward) or removes (backward)
4//! features one at a time, evaluating each candidate subset with a
5//! user-supplied scoring callback.
6//!
7//! # Algorithm
8//!
9//! **Forward**: Start with an empty feature set. At each step, try adding
10//! each remaining feature, evaluate the score, and keep the addition that
11//! yields the highest score. Repeat until `n_features_to_select` features
12//! have been selected.
13//!
14//! **Backward**: Start with all features. At each step, try removing each
15//! remaining feature, evaluate the score, and keep the removal that yields
16//! the highest score. Repeat until `n_features_to_select` features remain.
17
18use ferrolearn_core::error::FerroError;
19use ferrolearn_core::traits::Transform;
20use ndarray::{Array1, Array2};
21use num_traits::Float;
22
23// ---------------------------------------------------------------------------
24// Direction
25// ---------------------------------------------------------------------------
26
/// Which way [`SequentialFeatureSelector`] walks through the feature space.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Direction {
    /// Begin with an empty feature set and greedily add one feature per step.
    Forward,
    /// Begin with every feature and greedily drop one feature per step.
    Backward,
}
35
36// ---------------------------------------------------------------------------
37// SequentialFeatureSelector (unfitted)
38// ---------------------------------------------------------------------------
39
40/// A greedy sequential feature selector.
41///
42/// # Examples
43///
44/// ```
45/// use ferrolearn_preprocess::sequential_feature_selector::{
46///     SequentialFeatureSelector, Direction,
47/// };
48/// use ndarray::{array, Array1, Array2};
49///
50/// let sfs = SequentialFeatureSelector::new(1, Direction::Forward);
51/// let x = array![[1.0, 10.0, 0.1],
52///                 [2.0, 20.0, 0.2],
53///                 [3.0, 30.0, 0.3]];
54/// let y = array![1.0, 2.0, 3.0];
55///
56/// // Score function: sum of selected column means (higher is better)
57/// let score_fn = |x_sub: &Array2<f64>, _y: &Array1<f64>| -> Result<f64, _> {
58///     let mean_sum: f64 = x_sub.columns().into_iter()
59///         .map(|c| c.sum() / c.len() as f64)
60///         .sum();
61///     Ok(mean_sum)
62/// };
63///
64/// let fitted = sfs.fit(&x, &y, score_fn).unwrap();
65/// assert_eq!(fitted.selected_indices().len(), 1);
66/// // Feature 1 (column means 10,20,30 → mean 20) should be selected
67/// assert_eq!(fitted.selected_indices(), &[1]);
68/// ```
69#[must_use]
70#[derive(Debug, Clone)]
71pub struct SequentialFeatureSelector {
72    /// Number of features to select.
73    n_features_to_select: usize,
74    /// Search direction (forward or backward).
75    direction: Direction,
76}
77
78impl SequentialFeatureSelector {
79    /// Create a new `SequentialFeatureSelector`.
80    ///
81    /// # Parameters
82    ///
83    /// - `n_features_to_select` — how many features to keep.
84    /// - `direction` — [`Direction::Forward`] or [`Direction::Backward`].
85    pub fn new(n_features_to_select: usize, direction: Direction) -> Self {
86        Self {
87            n_features_to_select,
88            direction,
89        }
90    }
91
92    /// Return the number of features to select.
93    #[must_use]
94    pub fn n_features_to_select(&self) -> usize {
95        self.n_features_to_select
96    }
97
98    /// Return the search direction.
99    #[must_use]
100    pub fn direction(&self) -> Direction {
101        self.direction
102    }
103
104    /// Fit the selector by evaluating feature subsets with a scoring function.
105    ///
106    /// # Parameters
107    ///
108    /// - `x` — the feature matrix (`n_samples x n_features`).
109    /// - `y` — the target vector (`n_samples`).
110    /// - `score_fn` — a callback `(&Array2<F>, &Array1<F>) -> Result<F, FerroError>`
111    ///   that evaluates the quality of a feature subset.
112    ///
113    /// # Errors
114    ///
115    /// - [`FerroError::InvalidParameter`] if `n_features_to_select` is zero or
116    ///   exceeds the number of features.
117    /// - [`FerroError::ShapeMismatch`] if `x` and `y` have mismatched lengths.
118    /// - Propagates errors from `score_fn`.
119    pub fn fit<F: Float + Send + Sync + 'static>(
120        &self,
121        x: &Array2<F>,
122        y: &Array1<F>,
123        score_fn: impl Fn(&Array2<F>, &Array1<F>) -> Result<F, FerroError>,
124    ) -> Result<FittedSequentialFeatureSelector<F>, FerroError> {
125        let n_features = x.ncols();
126        let n_samples = x.nrows();
127
128        if self.n_features_to_select == 0 {
129            return Err(FerroError::InvalidParameter {
130                name: "n_features_to_select".into(),
131                reason: "must be at least 1".into(),
132            });
133        }
134        if self.n_features_to_select > n_features {
135            return Err(FerroError::InvalidParameter {
136                name: "n_features_to_select".into(),
137                reason: format!(
138                    "n_features_to_select ({}) exceeds number of features ({})",
139                    self.n_features_to_select, n_features
140                ),
141            });
142        }
143        if n_samples == 0 {
144            return Err(FerroError::InsufficientSamples {
145                required: 1,
146                actual: 0,
147                context: "SequentialFeatureSelector::fit".into(),
148            });
149        }
150        if y.len() != n_samples {
151            return Err(FerroError::ShapeMismatch {
152                expected: vec![n_samples],
153                actual: vec![y.len()],
154                context: "SequentialFeatureSelector::fit — y must match x rows".into(),
155            });
156        }
157
158        let selected_indices = match self.direction {
159            Direction::Forward => self.forward_search(x, y, n_features, &score_fn)?,
160            Direction::Backward => self.backward_search(x, y, n_features, &score_fn)?,
161        };
162
163        Ok(FittedSequentialFeatureSelector {
164            n_features_in: n_features,
165            selected_indices,
166            _marker: std::marker::PhantomData,
167        })
168    }
169
170    /// Forward greedy search.
171    #[allow(clippy::type_complexity)]
172    fn forward_search<F: Float + Send + Sync + 'static>(
173        &self,
174        x: &Array2<F>,
175        y: &Array1<F>,
176        n_features: usize,
177        score_fn: &dyn Fn(&Array2<F>, &Array1<F>) -> Result<F, FerroError>,
178    ) -> Result<Vec<usize>, FerroError> {
179        let mut selected: Vec<usize> = Vec::with_capacity(self.n_features_to_select);
180        let mut remaining: Vec<usize> = (0..n_features).collect();
181
182        for _ in 0..self.n_features_to_select {
183            let mut best_score = F::neg_infinity();
184            let mut best_feature = remaining[0];
185
186            for &candidate in &remaining {
187                let mut trial: Vec<usize> = selected.clone();
188                trial.push(candidate);
189                trial.sort_unstable();
190                let x_sub = select_columns(x, &trial);
191                let score = score_fn(&x_sub, y)?;
192                if score > best_score {
193                    best_score = score;
194                    best_feature = candidate;
195                }
196            }
197
198            selected.push(best_feature);
199            remaining.retain(|&f| f != best_feature);
200        }
201
202        selected.sort_unstable();
203        Ok(selected)
204    }
205
206    /// Backward greedy search.
207    #[allow(clippy::type_complexity)]
208    fn backward_search<F: Float + Send + Sync + 'static>(
209        &self,
210        x: &Array2<F>,
211        y: &Array1<F>,
212        n_features: usize,
213        score_fn: &dyn Fn(&Array2<F>, &Array1<F>) -> Result<F, FerroError>,
214    ) -> Result<Vec<usize>, FerroError> {
215        let mut remaining: Vec<usize> = (0..n_features).collect();
216
217        while remaining.len() > self.n_features_to_select {
218            let mut best_score = F::neg_infinity();
219            let mut worst_feature = remaining[0];
220
221            for &candidate in &remaining {
222                // Try removing this feature
223                let trial: Vec<usize> = remaining
224                    .iter()
225                    .copied()
226                    .filter(|&f| f != candidate)
227                    .collect();
228                let x_sub = select_columns(x, &trial);
229                let score = score_fn(&x_sub, y)?;
230                if score > best_score {
231                    best_score = score;
232                    worst_feature = candidate;
233                }
234            }
235
236            remaining.retain(|&f| f != worst_feature);
237        }
238
239        remaining.sort_unstable();
240        Ok(remaining)
241    }
242}
243
244// ---------------------------------------------------------------------------
245// FittedSequentialFeatureSelector
246// ---------------------------------------------------------------------------
247
/// Result of a completed sequential feature search.
///
/// Produced by [`SequentialFeatureSelector::fit`]; it records which columns
/// were chosen and how many columns the training matrix had.
#[derive(Clone, Debug)]
pub struct FittedSequentialFeatureSelector<F> {
    /// Column count of the matrix the selector was fitted on.
    n_features_in: usize,
    /// Chosen column indices (sorted).
    selected_indices: Vec<usize>,
    /// Ties the fitted selector to the float type `F` without storing a value.
    _marker: std::marker::PhantomData<F>,
}
259
260impl<F: Float + Send + Sync + 'static> FittedSequentialFeatureSelector<F> {
261    /// Return the indices of the selected features.
262    #[must_use]
263    pub fn selected_indices(&self) -> &[usize] {
264        &self.selected_indices
265    }
266
267    /// Return the number of selected features.
268    #[must_use]
269    pub fn n_features_selected(&self) -> usize {
270        self.selected_indices.len()
271    }
272}
273
274impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedSequentialFeatureSelector<F> {
275    type Output = Array2<F>;
276    type Error = FerroError;
277
278    /// Return a matrix containing only the selected columns.
279    ///
280    /// # Errors
281    ///
282    /// Returns [`FerroError::ShapeMismatch`] if the number of columns differs
283    /// from the number of features seen during fitting.
284    fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
285        if x.ncols() != self.n_features_in {
286            return Err(FerroError::ShapeMismatch {
287                expected: vec![x.nrows(), self.n_features_in],
288                actual: vec![x.nrows(), x.ncols()],
289                context: "FittedSequentialFeatureSelector::transform".into(),
290            });
291        }
292        Ok(select_columns(x, &self.selected_indices))
293    }
294}
295
296/// Build a new `Array2<F>` containing only the columns listed in `indices`.
297fn select_columns<F: Float>(x: &Array2<F>, indices: &[usize]) -> Array2<F> {
298    let nrows = x.nrows();
299    let ncols = indices.len();
300    if ncols == 0 {
301        return Array2::zeros((nrows, 0));
302    }
303    let mut out = Array2::zeros((nrows, ncols));
304    for (new_j, &old_j) in indices.iter().enumerate() {
305        for i in 0..nrows {
306            out[[i, new_j]] = x[[i, old_j]];
307        }
308    }
309    out
310}
311
312// ---------------------------------------------------------------------------
313// Tests
314// ---------------------------------------------------------------------------
315
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    /// Score function: sum of column means (higher is better).
    ///
    /// The target vector is ignored, so the expected selections in these
    /// tests follow directly from the per-column means.
    fn mean_sum_score(x: &Array2<f64>, _y: &Array1<f64>) -> Result<f64, FerroError> {
        let score: f64 = x
            .columns()
            .into_iter()
            .map(|c| c.sum() / c.len() as f64)
            .sum();
        Ok(score)
    }

    #[test]
    fn test_forward_selects_best() {
        let sfs = SequentialFeatureSelector::new(1, Direction::Forward);
        let x = array![[1.0, 10.0, 0.1], [2.0, 20.0, 0.2], [3.0, 30.0, 0.3]];
        let y = array![1.0, 2.0, 3.0];
        let fitted = sfs.fit(&x, &y, mean_sum_score).unwrap();
        assert_eq!(fitted.selected_indices(), &[1]); // col 1 has highest mean
    }

    #[test]
    fn test_forward_select_two() {
        let sfs = SequentialFeatureSelector::new(2, Direction::Forward);
        let x = array![[1.0, 10.0, 100.0], [2.0, 20.0, 200.0]];
        let y = array![1.0, 2.0];
        let fitted = sfs.fit(&x, &y, mean_sum_score).unwrap();
        assert_eq!(fitted.n_features_selected(), 2);
        // Top 2 by mean: col 2 (150.0) and col 1 (15.0)
        assert!(fitted.selected_indices().contains(&1));
        assert!(fitted.selected_indices().contains(&2));
        // The result is reported in ascending index order.
        assert_eq!(fitted.selected_indices(), &[1, 2]);
    }

    #[test]
    fn test_backward_selects_best() {
        let sfs = SequentialFeatureSelector::new(1, Direction::Backward);
        let x = array![[1.0, 10.0, 0.1], [2.0, 20.0, 0.2], [3.0, 30.0, 0.3]];
        let y = array![1.0, 2.0, 3.0];
        let fitted = sfs.fit(&x, &y, mean_sum_score).unwrap();
        // Backward: remove 2 features. With sum-of-means score, removing
        // the smallest contributors yields col 1 remaining
        assert_eq!(fitted.selected_indices(), &[1]);
    }

    #[test]
    fn test_backward_select_two() {
        let sfs = SequentialFeatureSelector::new(2, Direction::Backward);
        let x = array![[1.0, 10.0, 100.0], [2.0, 20.0, 200.0]];
        let y = array![1.0, 2.0];
        let fitted = sfs.fit(&x, &y, mean_sum_score).unwrap();
        assert_eq!(fitted.n_features_selected(), 2);
        // Remove col 0 (lowest mean), keep 1 and 2
        assert_eq!(fitted.selected_indices(), &[1, 2]);
    }

    #[test]
    fn test_select_all_features() {
        let sfs = SequentialFeatureSelector::new(3, Direction::Forward);
        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let y = array![1.0, 2.0];
        let fitted = sfs.fit(&x, &y, mean_sum_score).unwrap();
        assert_eq!(fitted.n_features_selected(), 3);
        // Selecting everything must return every index, in ascending order.
        assert_eq!(fitted.selected_indices(), &[0, 1, 2]);
    }

    #[test]
    fn test_transform() {
        let sfs = SequentialFeatureSelector::new(1, Direction::Forward);
        let x = array![[1.0, 10.0], [2.0, 20.0]];
        let y = array![1.0, 2.0];
        let fitted = sfs.fit(&x, &y, mean_sum_score).unwrap();
        let out = fitted.transform(&x).unwrap();
        assert_eq!(out.nrows(), 2);
        assert_eq!(out.ncols(), 1);
        assert_abs_diff_eq!(out[[0, 0]], 10.0, epsilon = 1e-15);
        assert_abs_diff_eq!(out[[1, 0]], 20.0, epsilon = 1e-15);
    }

    // The error-path tests below pin the error variant, not just `is_err()`,
    // so a regression that returns the wrong kind of error is caught.

    #[test]
    fn test_zero_features_error() {
        let sfs = SequentialFeatureSelector::new(0, Direction::Forward);
        let x = array![[1.0, 2.0]];
        let y = array![1.0];
        assert!(matches!(
            sfs.fit(&x, &y, mean_sum_score),
            Err(FerroError::InvalidParameter { .. })
        ));
    }

    #[test]
    fn test_too_many_features_error() {
        let sfs = SequentialFeatureSelector::new(5, Direction::Forward);
        let x = array![[1.0, 2.0]];
        let y = array![1.0];
        assert!(matches!(
            sfs.fit(&x, &y, mean_sum_score),
            Err(FerroError::InvalidParameter { .. })
        ));
    }

    #[test]
    fn test_zero_rows_error() {
        let sfs = SequentialFeatureSelector::new(1, Direction::Forward);
        let x: Array2<f64> = Array2::zeros((0, 3));
        let y: Array1<f64> = Array1::zeros(0);
        assert!(matches!(
            sfs.fit(&x, &y, mean_sum_score),
            Err(FerroError::InsufficientSamples { .. })
        ));
    }

    #[test]
    fn test_y_length_mismatch() {
        let sfs = SequentialFeatureSelector::new(1, Direction::Forward);
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let y = array![1.0]; // wrong length
        assert!(matches!(
            sfs.fit(&x, &y, mean_sum_score),
            Err(FerroError::ShapeMismatch { .. })
        ));
    }

    #[test]
    fn test_shape_mismatch_on_transform() {
        let sfs = SequentialFeatureSelector::new(1, Direction::Forward);
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let y = array![1.0, 2.0];
        let fitted = sfs.fit(&x, &y, mean_sum_score).unwrap();
        let x_bad = array![[1.0, 2.0, 3.0]];
        assert!(matches!(
            fitted.transform(&x_bad),
            Err(FerroError::ShapeMismatch { .. })
        ));
    }

    #[test]
    fn test_score_fn_error_propagated() {
        let sfs = SequentialFeatureSelector::new(1, Direction::Forward);
        let x = array![[1.0, 2.0]];
        let y = array![1.0];
        let bad_fn = |_x: &Array2<f64>, _y: &Array1<f64>| -> Result<f64, FerroError> {
            Err(FerroError::NumericalInstability {
                message: "test error".into(),
            })
        };
        // The callback's own error must come back unchanged in kind.
        assert!(matches!(
            sfs.fit(&x, &y, bad_fn),
            Err(FerroError::NumericalInstability { .. })
        ));
    }

    #[test]
    fn test_indices_sorted() {
        let sfs = SequentialFeatureSelector::new(2, Direction::Forward);
        let x = array![[100.0, 1.0, 10.0], [200.0, 2.0, 20.0]];
        let y = array![1.0, 2.0];
        let fitted = sfs.fit(&x, &y, mean_sum_score).unwrap();
        let indices = fitted.selected_indices();
        assert!(indices.windows(2).all(|w| w[0] < w[1]));
        // Col 0 (mean 150) is picked first, then col 2 (mean 15).
        assert_eq!(indices, &[0, 2]);
    }

    #[test]
    fn test_accessors() {
        let sfs = SequentialFeatureSelector::new(2, Direction::Backward);
        assert_eq!(sfs.n_features_to_select(), 2);
        assert_eq!(sfs.direction(), Direction::Backward);
    }
}