1use ferrolearn_core::error::FerroError;
43use ferrolearn_core::traits::Transform;
44use ndarray::{Array1, Array2};
45use num_traits::Float;
46
/// Search direction used by [`SequentialFeatureSelector`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Direction {
    /// Start from an empty feature set and greedily add one feature per step.
    Forward,
    /// Start from the full feature set and greedily drop one feature per step.
    Backward,
}
59
60#[must_use]
94#[derive(Debug, Clone)]
95pub struct SequentialFeatureSelector {
96 n_features_to_select: usize,
98 direction: Direction,
100}
101
102impl SequentialFeatureSelector {
103 pub fn new(n_features_to_select: usize, direction: Direction) -> Self {
110 Self {
111 n_features_to_select,
112 direction,
113 }
114 }
115
116 #[must_use]
118 pub fn n_features_to_select(&self) -> usize {
119 self.n_features_to_select
120 }
121
122 #[must_use]
124 pub fn direction(&self) -> Direction {
125 self.direction
126 }
127
128 pub fn fit<F: Float + Send + Sync + 'static>(
144 &self,
145 x: &Array2<F>,
146 y: &Array1<F>,
147 score_fn: impl Fn(&Array2<F>, &Array1<F>) -> Result<F, FerroError>,
148 ) -> Result<FittedSequentialFeatureSelector<F>, FerroError> {
149 let n_features = x.ncols();
150 let n_samples = x.nrows();
151
152 if n_features < 2 {
158 return Err(FerroError::InvalidParameter {
159 name: "x".into(),
160 reason: format!(
161 "Found array with {n_features} feature(s) while a minimum of 2 is required by SequentialFeatureSelector"
162 ),
163 });
164 }
165
166 if self.n_features_to_select == 0 {
167 return Err(FerroError::InvalidParameter {
168 name: "n_features_to_select".into(),
169 reason: "must be at least 1".into(),
170 });
171 }
172 if self.n_features_to_select >= n_features {
178 return Err(FerroError::InvalidParameter {
179 name: "n_features_to_select".into(),
180 reason: format!(
181 "n_features_to_select ({}) must be < number of features ({})",
182 self.n_features_to_select, n_features
183 ),
184 });
185 }
186 if n_samples == 0 {
187 return Err(FerroError::InsufficientSamples {
188 required: 1,
189 actual: 0,
190 context: "SequentialFeatureSelector::fit".into(),
191 });
192 }
193 if y.len() != n_samples {
194 return Err(FerroError::ShapeMismatch {
195 expected: vec![n_samples],
196 actual: vec![y.len()],
197 context: "SequentialFeatureSelector::fit — y must match x rows".into(),
198 });
199 }
200
201 let selected_indices = match self.direction {
202 Direction::Forward => self.forward_search(x, y, n_features, &score_fn)?,
203 Direction::Backward => self.backward_search(x, y, n_features, &score_fn)?,
204 };
205
206 Ok(FittedSequentialFeatureSelector {
207 n_features_in: n_features,
208 selected_indices,
209 _marker: std::marker::PhantomData,
210 })
211 }
212
213 #[allow(clippy::type_complexity)]
215 fn forward_search<F: Float + Send + Sync + 'static>(
216 &self,
217 x: &Array2<F>,
218 y: &Array1<F>,
219 n_features: usize,
220 score_fn: &dyn Fn(&Array2<F>, &Array1<F>) -> Result<F, FerroError>,
221 ) -> Result<Vec<usize>, FerroError> {
222 let mut selected: Vec<usize> = Vec::with_capacity(self.n_features_to_select);
223 let mut remaining: Vec<usize> = (0..n_features).collect();
224
225 for _ in 0..self.n_features_to_select {
226 let mut best_score = F::neg_infinity();
227 let mut best_feature = remaining[0];
228
229 for &candidate in &remaining {
230 let mut trial: Vec<usize> = selected.clone();
231 trial.push(candidate);
232 trial.sort_unstable();
233 let x_sub = select_columns(x, &trial);
234 let score = score_fn(&x_sub, y)?;
235 if score > best_score {
236 best_score = score;
237 best_feature = candidate;
238 }
239 }
240
241 selected.push(best_feature);
242 remaining.retain(|&f| f != best_feature);
243 }
244
245 selected.sort_unstable();
246 Ok(selected)
247 }
248
249 #[allow(clippy::type_complexity)]
251 fn backward_search<F: Float + Send + Sync + 'static>(
252 &self,
253 x: &Array2<F>,
254 y: &Array1<F>,
255 n_features: usize,
256 score_fn: &dyn Fn(&Array2<F>, &Array1<F>) -> Result<F, FerroError>,
257 ) -> Result<Vec<usize>, FerroError> {
258 let mut remaining: Vec<usize> = (0..n_features).collect();
259
260 while remaining.len() > self.n_features_to_select {
261 let mut best_score = F::neg_infinity();
262 let mut worst_feature = remaining[0];
263
264 for &candidate in &remaining {
265 let trial: Vec<usize> = remaining
267 .iter()
268 .copied()
269 .filter(|&f| f != candidate)
270 .collect();
271 let x_sub = select_columns(x, &trial);
272 let score = score_fn(&x_sub, y)?;
273 if score > best_score {
274 best_score = score;
275 worst_feature = candidate;
276 }
277 }
278
279 remaining.retain(|&f| f != worst_feature);
280 }
281
282 remaining.sort_unstable();
283 Ok(remaining)
284 }
285}
286
/// Result of fitting a `SequentialFeatureSelector`: remembers which columns
/// were chosen and how many columns the training matrix had.
///
/// `F` is the element type of the matrices this selector was fitted on; no
/// value of that type is stored.
#[derive(Clone, Debug)]
pub struct FittedSequentialFeatureSelector<F> {
    /// Column count of the matrix seen during fitting; `transform` requires the same.
    n_features_in: usize,
    /// Indices of the columns that were kept.
    selected_indices: Vec<usize>,
    /// Ties the selector to its element type without storing an `F`.
    _marker: std::marker::PhantomData<F>,
}
302
303impl<F: Float + Send + Sync + 'static> FittedSequentialFeatureSelector<F> {
304 #[must_use]
306 pub fn selected_indices(&self) -> &[usize] {
307 &self.selected_indices
308 }
309
310 #[must_use]
312 pub fn n_features_selected(&self) -> usize {
313 self.selected_indices.len()
314 }
315}
316
317impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedSequentialFeatureSelector<F> {
318 type Output = Array2<F>;
319 type Error = FerroError;
320
321 fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
328 if x.ncols() != self.n_features_in {
329 return Err(FerroError::ShapeMismatch {
330 expected: vec![x.nrows(), self.n_features_in],
331 actual: vec![x.nrows(), x.ncols()],
332 context: "FittedSequentialFeatureSelector::transform".into(),
333 });
334 }
335 Ok(select_columns(x, &self.selected_indices))
336 }
337}
338
339fn select_columns<F: Float>(x: &Array2<F>, indices: &[usize]) -> Array2<F> {
341 let nrows = x.nrows();
342 let ncols = indices.len();
343 if ncols == 0 {
344 return Array2::zeros((nrows, 0));
345 }
346 let mut out = Array2::zeros((nrows, ncols));
347 for (new_j, &old_j) in indices.iter().enumerate() {
348 for i in 0..nrows {
349 out[[i, new_j]] = x[[i, old_j]];
350 }
351 }
352 out
353}
354
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    /// Toy scorer: the sum of the per-column means of `x`; the target is ignored.
    fn mean_sum_score(x: &Array2<f64>, _y: &Array1<f64>) -> Result<f64, FerroError> {
        Ok(x.columns()
            .into_iter()
            .map(|column| column.sum() / column.len() as f64)
            .sum::<f64>())
    }

    #[test]
    fn test_forward_selects_best() {
        // Column 1 has the largest mean, so a single forward step must pick it.
        let features = array![[1.0, 10.0, 0.1], [2.0, 20.0, 0.2], [3.0, 30.0, 0.3]];
        let target = array![1.0, 2.0, 3.0];
        let selector = SequentialFeatureSelector::new(1, Direction::Forward);
        let fitted = selector.fit(&features, &target, mean_sum_score).unwrap();
        assert_eq!(fitted.selected_indices(), &[1]);
    }

    #[test]
    fn test_forward_select_two() {
        let features = array![[1.0, 10.0, 100.0], [2.0, 20.0, 200.0]];
        let target = array![1.0, 2.0];
        let selector = SequentialFeatureSelector::new(2, Direction::Forward);
        let fitted = selector.fit(&features, &target, mean_sum_score).unwrap();
        assert_eq!(fitted.n_features_selected(), 2);
        // The two columns with the largest means are 1 and 2.
        let kept = fitted.selected_indices();
        assert!(kept.contains(&1));
        assert!(kept.contains(&2));
    }

    #[test]
    fn test_backward_selects_best() {
        // Backward elimination should end up with the same single best column.
        let features = array![[1.0, 10.0, 0.1], [2.0, 20.0, 0.2], [3.0, 30.0, 0.3]];
        let target = array![1.0, 2.0, 3.0];
        let selector = SequentialFeatureSelector::new(1, Direction::Backward);
        let fitted = selector.fit(&features, &target, mean_sum_score).unwrap();
        assert_eq!(fitted.selected_indices(), &[1]);
    }

    #[test]
    fn test_backward_select_two() {
        let features = array![[1.0, 10.0, 100.0], [2.0, 20.0, 200.0]];
        let target = array![1.0, 2.0];
        let selector = SequentialFeatureSelector::new(2, Direction::Backward);
        let fitted = selector.fit(&features, &target, mean_sum_score).unwrap();
        assert_eq!(fitted.n_features_selected(), 2);
        assert_eq!(fitted.selected_indices(), &[1, 2]);
    }

    #[test]
    fn test_select_all_features_rejected() {
        // Asking for every column is rejected: the count must be strictly smaller.
        let features = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let target = array![1.0, 2.0];
        let selector = SequentialFeatureSelector::new(3, Direction::Forward);
        assert!(selector.fit(&features, &target, mean_sum_score).is_err());
    }

    #[test]
    fn test_transform() {
        let features = array![[1.0, 10.0], [2.0, 20.0]];
        let target = array![1.0, 2.0];
        let selector = SequentialFeatureSelector::new(1, Direction::Forward);
        let fitted = selector.fit(&features, &target, mean_sum_score).unwrap();
        let reduced = fitted.transform(&features).unwrap();
        assert_eq!(reduced.ncols(), 1);
        assert_abs_diff_eq!(reduced[[0, 0]], 10.0, epsilon = 1e-15);
        assert_abs_diff_eq!(reduced[[1, 0]], 20.0, epsilon = 1e-15);
    }

    #[test]
    fn test_zero_features_error() {
        let features = array![[1.0, 2.0]];
        let target = array![1.0];
        let selector = SequentialFeatureSelector::new(0, Direction::Forward);
        assert!(selector.fit(&features, &target, mean_sum_score).is_err());
    }

    #[test]
    fn test_too_many_features_error() {
        let features = array![[1.0, 2.0]];
        let target = array![1.0];
        let selector = SequentialFeatureSelector::new(5, Direction::Forward);
        assert!(selector.fit(&features, &target, mean_sum_score).is_err());
    }

    #[test]
    fn test_zero_rows_error() {
        let features = Array2::<f64>::zeros((0, 3));
        let target = Array1::<f64>::zeros(0);
        let selector = SequentialFeatureSelector::new(1, Direction::Forward);
        assert!(selector.fit(&features, &target, mean_sum_score).is_err());
    }

    #[test]
    fn test_y_length_mismatch() {
        let features = array![[1.0, 2.0], [3.0, 4.0]];
        // One target value for two rows.
        let target = array![1.0];
        let selector = SequentialFeatureSelector::new(1, Direction::Forward);
        assert!(selector.fit(&features, &target, mean_sum_score).is_err());
    }

    #[test]
    fn test_shape_mismatch_on_transform() {
        let features = array![[1.0, 2.0], [3.0, 4.0]];
        let target = array![1.0, 2.0];
        let selector = SequentialFeatureSelector::new(1, Direction::Forward);
        let fitted = selector.fit(&features, &target, mean_sum_score).unwrap();
        // Three columns where the selector was fitted on two.
        let wrong_width = array![[1.0, 2.0, 3.0]];
        assert!(fitted.transform(&wrong_width).is_err());
    }

    #[test]
    fn test_score_fn_error_propagated() {
        /// Scorer that fails unconditionally.
        fn always_fails(_x: &Array2<f64>, _y: &Array1<f64>) -> Result<f64, FerroError> {
            Err(FerroError::NumericalInstability {
                message: "test error".into(),
            })
        }

        let features = array![[1.0, 2.0]];
        let target = array![1.0];
        let selector = SequentialFeatureSelector::new(1, Direction::Forward);
        assert!(selector.fit(&features, &target, always_fails).is_err());
    }

    #[test]
    fn test_indices_sorted() {
        let features = array![[100.0, 1.0, 10.0], [200.0, 2.0, 20.0]];
        let target = array![1.0, 2.0];
        let selector = SequentialFeatureSelector::new(2, Direction::Forward);
        let fitted = selector.fit(&features, &target, mean_sum_score).unwrap();
        let strictly_ascending = fitted
            .selected_indices()
            .windows(2)
            .all(|pair| pair[0] < pair[1]);
        assert!(strictly_ascending);
    }

    #[test]
    fn test_accessors() {
        let selector = SequentialFeatureSelector::new(2, Direction::Backward);
        assert_eq!(selector.n_features_to_select(), 2);
        assert_eq!(selector.direction(), Direction::Backward);
    }
}