use ferrolearn_core::error::FerroError;
use ferrolearn_core::traits::Transform;
use ndarray::{Array1, Array2};
use num_traits::Float;
/// Search strategy used by [`SequentialFeatureSelector`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Direction {
    /// Start from an empty feature set and add one feature per round.
    Forward,
    /// Start from the full feature set and remove one feature per round.
    Backward,
}
/// Unfitted configuration for greedy sequential feature selection.
///
/// Holds how many features to keep and which search direction to use.
/// Call `fit` to run the search and obtain a fitted selector.
#[derive(Clone, Debug)]
#[must_use]
pub struct SequentialFeatureSelector {
    /// Number of feature columns the search should end up with.
    n_features_to_select: usize,
    /// Whether features are greedily added or greedily removed.
    direction: Direction,
}
impl SequentialFeatureSelector {
pub fn new(n_features_to_select: usize, direction: Direction) -> Self {
Self {
n_features_to_select,
direction,
}
}
#[must_use]
pub fn n_features_to_select(&self) -> usize {
self.n_features_to_select
}
#[must_use]
pub fn direction(&self) -> Direction {
self.direction
}
pub fn fit<F: Float + Send + Sync + 'static>(
&self,
x: &Array2<F>,
y: &Array1<F>,
score_fn: impl Fn(&Array2<F>, &Array1<F>) -> Result<F, FerroError>,
) -> Result<FittedSequentialFeatureSelector<F>, FerroError> {
let n_features = x.ncols();
let n_samples = x.nrows();
if self.n_features_to_select == 0 {
return Err(FerroError::InvalidParameter {
name: "n_features_to_select".into(),
reason: "must be at least 1".into(),
});
}
if self.n_features_to_select > n_features {
return Err(FerroError::InvalidParameter {
name: "n_features_to_select".into(),
reason: format!(
"n_features_to_select ({}) exceeds number of features ({})",
self.n_features_to_select, n_features
),
});
}
if n_samples == 0 {
return Err(FerroError::InsufficientSamples {
required: 1,
actual: 0,
context: "SequentialFeatureSelector::fit".into(),
});
}
if y.len() != n_samples {
return Err(FerroError::ShapeMismatch {
expected: vec![n_samples],
actual: vec![y.len()],
context: "SequentialFeatureSelector::fit — y must match x rows".into(),
});
}
let selected_indices = match self.direction {
Direction::Forward => {
self.forward_search(x, y, n_features, &score_fn)?
}
Direction::Backward => {
self.backward_search(x, y, n_features, &score_fn)?
}
};
Ok(FittedSequentialFeatureSelector {
n_features_in: n_features,
selected_indices,
_marker: std::marker::PhantomData,
})
}
#[allow(clippy::type_complexity)]
fn forward_search<F: Float + Send + Sync + 'static>(
&self,
x: &Array2<F>,
y: &Array1<F>,
n_features: usize,
score_fn: &dyn Fn(&Array2<F>, &Array1<F>) -> Result<F, FerroError>,
) -> Result<Vec<usize>, FerroError> {
let mut selected: Vec<usize> = Vec::with_capacity(self.n_features_to_select);
let mut remaining: Vec<usize> = (0..n_features).collect();
for _ in 0..self.n_features_to_select {
let mut best_score = F::neg_infinity();
let mut best_feature = remaining[0];
for &candidate in &remaining {
let mut trial: Vec<usize> = selected.clone();
trial.push(candidate);
trial.sort_unstable();
let x_sub = select_columns(x, &trial);
let score = score_fn(&x_sub, y)?;
if score > best_score {
best_score = score;
best_feature = candidate;
}
}
selected.push(best_feature);
remaining.retain(|&f| f != best_feature);
}
selected.sort_unstable();
Ok(selected)
}
#[allow(clippy::type_complexity)]
fn backward_search<F: Float + Send + Sync + 'static>(
&self,
x: &Array2<F>,
y: &Array1<F>,
n_features: usize,
score_fn: &dyn Fn(&Array2<F>, &Array1<F>) -> Result<F, FerroError>,
) -> Result<Vec<usize>, FerroError> {
let mut remaining: Vec<usize> = (0..n_features).collect();
while remaining.len() > self.n_features_to_select {
let mut best_score = F::neg_infinity();
let mut worst_feature = remaining[0];
for &candidate in &remaining {
let trial: Vec<usize> = remaining
.iter()
.copied()
.filter(|&f| f != candidate)
.collect();
let x_sub = select_columns(x, &trial);
let score = score_fn(&x_sub, y)?;
if score > best_score {
best_score = score;
worst_feature = candidate;
}
}
remaining.retain(|&f| f != worst_feature);
}
remaining.sort_unstable();
Ok(remaining)
}
}
/// Result of a sequential feature search: the column indices that were kept.
///
/// `F` is the scalar type of the data the selector was fitted on; it is only
/// tracked through a marker and no values of `F` are stored.
#[derive(Clone, Debug)]
pub struct FittedSequentialFeatureSelector<F> {
    /// Number of columns of the matrix the selector was fitted on.
    n_features_in: usize,
    /// Column indices retained by the search.
    selected_indices: Vec<usize>,
    /// Ties the fitted selector to the scalar type without storing data.
    _marker: std::marker::PhantomData<F>,
}
impl<F: Float + Send + Sync + 'static> FittedSequentialFeatureSelector<F> {
    /// Returns the indices of the columns that were kept.
    #[must_use]
    pub fn selected_indices(&self) -> &[usize] {
        self.selected_indices.as_slice()
    }

    /// Returns how many columns were kept.
    #[must_use]
    pub fn n_features_selected(&self) -> usize {
        self.selected_indices().len()
    }
}
impl<F: Float + Send + Sync + 'static> Transform<Array2<F>>
    for FittedSequentialFeatureSelector<F>
{
    type Output = Array2<F>;
    type Error = FerroError;

    /// Returns a new matrix containing only the selected columns of `x`.
    ///
    /// # Errors
    ///
    /// Returns `ShapeMismatch` when `x` does not have the same number of
    /// columns as the matrix the selector was fitted on.
    fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
        let (n_rows, n_cols) = x.dim();
        if n_cols == self.n_features_in {
            return Ok(select_columns(x, &self.selected_indices));
        }
        Err(FerroError::ShapeMismatch {
            expected: vec![n_rows, self.n_features_in],
            actual: vec![n_rows, n_cols],
            context: "FittedSequentialFeatureSelector::transform".into(),
        })
    }
}
/// Builds a new matrix whose column `j` is a copy of column `indices[j]` of `x`.
///
/// The result has `x.nrows()` rows and `indices.len()` columns; an empty
/// `indices` slice yields a matrix with zero columns. An out-of-range index
/// panics as soon as an element of that column is read.
fn select_columns<F: Float>(x: &Array2<F>, indices: &[usize]) -> Array2<F> {
    Array2::from_shape_fn((x.nrows(), indices.len()), |(row, col)| {
        x[[row, indices[col]]]
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    /// Toy scorer: the sum of the per-column means of `x`; `y` is ignored.
    /// Columns with larger values therefore always score higher.
    fn mean_sum_score(x: &Array2<f64>, _y: &Array1<f64>) -> Result<f64, FerroError> {
        let column_means = x
            .columns()
            .into_iter()
            .map(|column| column.sum() / column.len() as f64);
        Ok(column_means.sum::<f64>())
    }

    /// Forward search for one feature picks the column with the largest mean.
    #[test]
    fn test_forward_selects_best() {
        let features = array![[1.0, 10.0, 0.1], [2.0, 20.0, 0.2], [3.0, 30.0, 0.3]];
        let targets = array![1.0, 2.0, 3.0];
        let selector = SequentialFeatureSelector::new(1, Direction::Forward);
        let model = selector.fit(&features, &targets, mean_sum_score).unwrap();
        assert_eq!(model.selected_indices(), &[1]);
    }

    /// Forward search for two features picks the two largest-mean columns.
    #[test]
    fn test_forward_select_two() {
        let features = array![[1.0, 10.0, 100.0], [2.0, 20.0, 200.0]];
        let targets = array![1.0, 2.0];
        let selector = SequentialFeatureSelector::new(2, Direction::Forward);
        let model = selector.fit(&features, &targets, mean_sum_score).unwrap();
        assert_eq!(model.n_features_selected(), 2);
        let kept = model.selected_indices();
        assert!(kept.contains(&1));
        assert!(kept.contains(&2));
    }

    /// Backward elimination down to one feature keeps the largest-mean column.
    #[test]
    fn test_backward_selects_best() {
        let features = array![[1.0, 10.0, 0.1], [2.0, 20.0, 0.2], [3.0, 30.0, 0.3]];
        let targets = array![1.0, 2.0, 3.0];
        let selector = SequentialFeatureSelector::new(1, Direction::Backward);
        let model = selector.fit(&features, &targets, mean_sum_score).unwrap();
        assert_eq!(model.selected_indices(), &[1]);
    }

    /// Backward elimination down to two features drops the smallest-mean column.
    #[test]
    fn test_backward_select_two() {
        let features = array![[1.0, 10.0, 100.0], [2.0, 20.0, 200.0]];
        let targets = array![1.0, 2.0];
        let selector = SequentialFeatureSelector::new(2, Direction::Backward);
        let model = selector.fit(&features, &targets, mean_sum_score).unwrap();
        assert_eq!(model.n_features_selected(), 2);
        assert_eq!(model.selected_indices(), &[1, 2]);
    }

    /// Asking for every column is allowed and keeps all of them.
    #[test]
    fn test_select_all_features() {
        let features = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let targets = array![1.0, 2.0];
        let selector = SequentialFeatureSelector::new(3, Direction::Forward);
        let model = selector.fit(&features, &targets, mean_sum_score).unwrap();
        assert_eq!(model.n_features_selected(), 3);
    }

    /// `transform` returns only the selected column, with values copied over.
    #[test]
    fn test_transform() {
        let features = array![[1.0, 10.0], [2.0, 20.0]];
        let targets = array![1.0, 2.0];
        let selector = SequentialFeatureSelector::new(1, Direction::Forward);
        let model = selector.fit(&features, &targets, mean_sum_score).unwrap();
        let reduced = model.transform(&features).unwrap();
        assert_eq!(reduced.ncols(), 1);
        assert_abs_diff_eq!(reduced[[0, 0]], 10.0, epsilon = 1e-15);
        assert_abs_diff_eq!(reduced[[1, 0]], 20.0, epsilon = 1e-15);
    }

    /// Requesting zero features is rejected.
    #[test]
    fn test_zero_features_error() {
        let features = array![[1.0, 2.0]];
        let targets = array![1.0];
        let selector = SequentialFeatureSelector::new(0, Direction::Forward);
        let outcome = selector.fit(&features, &targets, mean_sum_score);
        assert!(outcome.is_err());
    }

    /// Requesting more features than there are columns is rejected.
    #[test]
    fn test_too_many_features_error() {
        let features = array![[1.0, 2.0]];
        let targets = array![1.0];
        let selector = SequentialFeatureSelector::new(5, Direction::Forward);
        let outcome = selector.fit(&features, &targets, mean_sum_score);
        assert!(outcome.is_err());
    }

    /// A matrix without rows is rejected.
    #[test]
    fn test_zero_rows_error() {
        let features: Array2<f64> = Array2::zeros((0, 3));
        let targets: Array1<f64> = Array1::zeros(0);
        let selector = SequentialFeatureSelector::new(1, Direction::Forward);
        let outcome = selector.fit(&features, &targets, mean_sum_score);
        assert!(outcome.is_err());
    }

    /// A target vector shorter than the number of rows is rejected.
    #[test]
    fn test_y_length_mismatch() {
        let features = array![[1.0, 2.0], [3.0, 4.0]];
        let targets = array![1.0];
        let selector = SequentialFeatureSelector::new(1, Direction::Forward);
        let outcome = selector.fit(&features, &targets, mean_sum_score);
        assert!(outcome.is_err());
    }

    /// `transform` rejects input with a different column count than at fit time.
    #[test]
    fn test_shape_mismatch_on_transform() {
        let features = array![[1.0, 2.0], [3.0, 4.0]];
        let targets = array![1.0, 2.0];
        let selector = SequentialFeatureSelector::new(1, Direction::Forward);
        let model = selector.fit(&features, &targets, mean_sum_score).unwrap();
        let wrong_width = array![[1.0, 2.0, 3.0]];
        assert!(model.transform(&wrong_width).is_err());
    }

    /// An error from the scoring function surfaces as an error from `fit`.
    #[test]
    fn test_score_fn_error_propagated() {
        fn failing_score(_x: &Array2<f64>, _y: &Array1<f64>) -> Result<f64, FerroError> {
            Err(FerroError::NumericalInstability {
                message: "test error".into(),
            })
        }

        let features = array![[1.0, 2.0]];
        let targets = array![1.0];
        let selector = SequentialFeatureSelector::new(1, Direction::Forward);
        assert!(selector.fit(&features, &targets, failing_score).is_err());
    }

    /// Selected indices come back strictly ascending even when the greedy
    /// order of selection is not.
    #[test]
    fn test_indices_sorted() {
        let features = array![[100.0, 1.0, 10.0], [200.0, 2.0, 20.0]];
        let targets = array![1.0, 2.0];
        let selector = SequentialFeatureSelector::new(2, Direction::Forward);
        let model = selector.fit(&features, &targets, mean_sum_score).unwrap();
        let kept = model.selected_indices();
        assert!(kept.windows(2).all(|pair| pair[0] < pair[1]));
    }

    /// The accessors echo the constructor arguments.
    #[test]
    fn test_accessors() {
        let selector = SequentialFeatureSelector::new(2, Direction::Backward);
        assert_eq!(selector.n_features_to_select(), 2);
        assert_eq!(selector.direction(), Direction::Backward);
    }
}