1use ferrolearn_core::error::FerroError;
19use ferrolearn_core::traits::Transform;
20use ndarray::{Array1, Array2};
21use num_traits::Float;
22
/// Search strategy used by [`SequentialFeatureSelector`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Direction {
    /// Begin with no features and, at each step, add the candidate whose
    /// inclusion yields the highest score.
    Forward,
    /// Begin with every feature and, at each step, drop the candidate whose
    /// removal yields the highest score.
    Backward,
}
35
36#[must_use]
70#[derive(Debug, Clone)]
71pub struct SequentialFeatureSelector {
72 n_features_to_select: usize,
74 direction: Direction,
76}
77
78impl SequentialFeatureSelector {
79 pub fn new(n_features_to_select: usize, direction: Direction) -> Self {
86 Self {
87 n_features_to_select,
88 direction,
89 }
90 }
91
92 #[must_use]
94 pub fn n_features_to_select(&self) -> usize {
95 self.n_features_to_select
96 }
97
98 #[must_use]
100 pub fn direction(&self) -> Direction {
101 self.direction
102 }
103
104 pub fn fit<F: Float + Send + Sync + 'static>(
120 &self,
121 x: &Array2<F>,
122 y: &Array1<F>,
123 score_fn: impl Fn(&Array2<F>, &Array1<F>) -> Result<F, FerroError>,
124 ) -> Result<FittedSequentialFeatureSelector<F>, FerroError> {
125 let n_features = x.ncols();
126 let n_samples = x.nrows();
127
128 if self.n_features_to_select == 0 {
129 return Err(FerroError::InvalidParameter {
130 name: "n_features_to_select".into(),
131 reason: "must be at least 1".into(),
132 });
133 }
134 if self.n_features_to_select > n_features {
135 return Err(FerroError::InvalidParameter {
136 name: "n_features_to_select".into(),
137 reason: format!(
138 "n_features_to_select ({}) exceeds number of features ({})",
139 self.n_features_to_select, n_features
140 ),
141 });
142 }
143 if n_samples == 0 {
144 return Err(FerroError::InsufficientSamples {
145 required: 1,
146 actual: 0,
147 context: "SequentialFeatureSelector::fit".into(),
148 });
149 }
150 if y.len() != n_samples {
151 return Err(FerroError::ShapeMismatch {
152 expected: vec![n_samples],
153 actual: vec![y.len()],
154 context: "SequentialFeatureSelector::fit — y must match x rows".into(),
155 });
156 }
157
158 let selected_indices = match self.direction {
159 Direction::Forward => self.forward_search(x, y, n_features, &score_fn)?,
160 Direction::Backward => self.backward_search(x, y, n_features, &score_fn)?,
161 };
162
163 Ok(FittedSequentialFeatureSelector {
164 n_features_in: n_features,
165 selected_indices,
166 _marker: std::marker::PhantomData,
167 })
168 }
169
170 #[allow(clippy::type_complexity)]
172 fn forward_search<F: Float + Send + Sync + 'static>(
173 &self,
174 x: &Array2<F>,
175 y: &Array1<F>,
176 n_features: usize,
177 score_fn: &dyn Fn(&Array2<F>, &Array1<F>) -> Result<F, FerroError>,
178 ) -> Result<Vec<usize>, FerroError> {
179 let mut selected: Vec<usize> = Vec::with_capacity(self.n_features_to_select);
180 let mut remaining: Vec<usize> = (0..n_features).collect();
181
182 for _ in 0..self.n_features_to_select {
183 let mut best_score = F::neg_infinity();
184 let mut best_feature = remaining[0];
185
186 for &candidate in &remaining {
187 let mut trial: Vec<usize> = selected.clone();
188 trial.push(candidate);
189 trial.sort_unstable();
190 let x_sub = select_columns(x, &trial);
191 let score = score_fn(&x_sub, y)?;
192 if score > best_score {
193 best_score = score;
194 best_feature = candidate;
195 }
196 }
197
198 selected.push(best_feature);
199 remaining.retain(|&f| f != best_feature);
200 }
201
202 selected.sort_unstable();
203 Ok(selected)
204 }
205
206 #[allow(clippy::type_complexity)]
208 fn backward_search<F: Float + Send + Sync + 'static>(
209 &self,
210 x: &Array2<F>,
211 y: &Array1<F>,
212 n_features: usize,
213 score_fn: &dyn Fn(&Array2<F>, &Array1<F>) -> Result<F, FerroError>,
214 ) -> Result<Vec<usize>, FerroError> {
215 let mut remaining: Vec<usize> = (0..n_features).collect();
216
217 while remaining.len() > self.n_features_to_select {
218 let mut best_score = F::neg_infinity();
219 let mut worst_feature = remaining[0];
220
221 for &candidate in &remaining {
222 let trial: Vec<usize> = remaining
224 .iter()
225 .copied()
226 .filter(|&f| f != candidate)
227 .collect();
228 let x_sub = select_columns(x, &trial);
229 let score = score_fn(&x_sub, y)?;
230 if score > best_score {
231 best_score = score;
232 worst_feature = candidate;
233 }
234 }
235
236 remaining.retain(|&f| f != worst_feature);
237 }
238
239 remaining.sort_unstable();
240 Ok(remaining)
241 }
242}
243
244#[derive(Debug, Clone)]
252pub struct FittedSequentialFeatureSelector<F> {
253 n_features_in: usize,
255 selected_indices: Vec<usize>,
257 _marker: std::marker::PhantomData<F>,
258}
259
260impl<F: Float + Send + Sync + 'static> FittedSequentialFeatureSelector<F> {
261 #[must_use]
263 pub fn selected_indices(&self) -> &[usize] {
264 &self.selected_indices
265 }
266
267 #[must_use]
269 pub fn n_features_selected(&self) -> usize {
270 self.selected_indices.len()
271 }
272}
273
274impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedSequentialFeatureSelector<F> {
275 type Output = Array2<F>;
276 type Error = FerroError;
277
278 fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
285 if x.ncols() != self.n_features_in {
286 return Err(FerroError::ShapeMismatch {
287 expected: vec![x.nrows(), self.n_features_in],
288 actual: vec![x.nrows(), x.ncols()],
289 context: "FittedSequentialFeatureSelector::transform".into(),
290 });
291 }
292 Ok(select_columns(x, &self.selected_indices))
293 }
294}
295
296fn select_columns<F: Float>(x: &Array2<F>, indices: &[usize]) -> Array2<F> {
298 let nrows = x.nrows();
299 let ncols = indices.len();
300 if ncols == 0 {
301 return Array2::zeros((nrows, 0));
302 }
303 let mut out = Array2::zeros((nrows, ncols));
304 for (new_j, &old_j) in indices.iter().enumerate() {
305 for i in 0..nrows {
306 out[[i, new_j]] = x[[i, old_j]];
307 }
308 }
309 out
310}
311
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    /// Score = sum over columns of each column's mean; larger columns win.
    fn mean_sum_score(x: &Array2<f64>, _y: &Array1<f64>) -> Result<f64, FerroError> {
        let score: f64 = x
            .columns()
            .into_iter()
            .map(|c| c.sum() / c.len() as f64)
            .sum();
        Ok(score)
    }

    /// Score function that always fails, used to check error propagation.
    fn failing_score(_x: &Array2<f64>, _y: &Array1<f64>) -> Result<f64, FerroError> {
        Err(FerroError::NumericalInstability {
            message: "test error".into(),
        })
    }

    #[test]
    fn test_forward_selects_best() {
        let sfs = SequentialFeatureSelector::new(1, Direction::Forward);
        let x = array![[1.0, 10.0, 0.1], [2.0, 20.0, 0.2], [3.0, 30.0, 0.3]];
        let y = array![1.0, 2.0, 3.0];
        let fitted = sfs.fit(&x, &y, mean_sum_score).unwrap();
        assert_eq!(fitted.selected_indices(), &[1]);
    }

    #[test]
    fn test_forward_select_two() {
        let sfs = SequentialFeatureSelector::new(2, Direction::Forward);
        let x = array![[1.0, 10.0, 100.0], [2.0, 20.0, 200.0]];
        let y = array![1.0, 2.0];
        let fitted = sfs.fit(&x, &y, mean_sum_score).unwrap();
        assert_eq!(fitted.n_features_selected(), 2);
        assert!(fitted.selected_indices().contains(&1));
        assert!(fitted.selected_indices().contains(&2));
    }

    #[test]
    fn test_backward_selects_best() {
        let sfs = SequentialFeatureSelector::new(1, Direction::Backward);
        let x = array![[1.0, 10.0, 0.1], [2.0, 20.0, 0.2], [3.0, 30.0, 0.3]];
        let y = array![1.0, 2.0, 3.0];
        let fitted = sfs.fit(&x, &y, mean_sum_score).unwrap();
        assert_eq!(fitted.selected_indices(), &[1]);
    }

    #[test]
    fn test_backward_select_two() {
        let sfs = SequentialFeatureSelector::new(2, Direction::Backward);
        let x = array![[1.0, 10.0, 100.0], [2.0, 20.0, 200.0]];
        let y = array![1.0, 2.0];
        let fitted = sfs.fit(&x, &y, mean_sum_score).unwrap();
        assert_eq!(fitted.n_features_selected(), 2);
        assert_eq!(fitted.selected_indices(), &[1, 2]);
    }

    #[test]
    fn test_select_all_features() {
        let sfs = SequentialFeatureSelector::new(3, Direction::Forward);
        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let y = array![1.0, 2.0];
        let fitted = sfs.fit(&x, &y, mean_sum_score).unwrap();
        assert_eq!(fitted.n_features_selected(), 3);
    }

    #[test]
    fn test_backward_select_all_skips_scoring() {
        // Nothing needs removing, so the (failing) score function is never called.
        let sfs = SequentialFeatureSelector::new(2, Direction::Backward);
        let x = array![[1.0, 2.0]];
        let y = array![1.0];
        let fitted = sfs.fit(&x, &y, failing_score).unwrap();
        assert_eq!(fitted.selected_indices(), &[0, 1]);
    }

    #[test]
    fn test_transform() {
        let sfs = SequentialFeatureSelector::new(1, Direction::Forward);
        let x = array![[1.0, 10.0], [2.0, 20.0]];
        let y = array![1.0, 2.0];
        let fitted = sfs.fit(&x, &y, mean_sum_score).unwrap();
        let out = fitted.transform(&x).unwrap();
        assert_eq!(out.ncols(), 1);
        assert_abs_diff_eq!(out[[0, 0]], 10.0, epsilon = 1e-15);
        assert_abs_diff_eq!(out[[1, 0]], 20.0, epsilon = 1e-15);
    }

    #[test]
    fn test_transform_backward() {
        let sfs = SequentialFeatureSelector::new(1, Direction::Backward);
        let x = array![[1.0, 10.0], [2.0, 20.0]];
        let y = array![1.0, 2.0];
        let fitted = sfs.fit(&x, &y, mean_sum_score).unwrap();
        assert_eq!(fitted.selected_indices(), &[1]);
        let out = fitted.transform(&x).unwrap();
        assert_eq!(out.dim(), (2, 1));
        assert_abs_diff_eq!(out[[0, 0]], 10.0, epsilon = 1e-15);
        assert_abs_diff_eq!(out[[1, 0]], 20.0, epsilon = 1e-15);
    }

    #[test]
    fn test_transform_keeps_column_order() {
        let sfs = SequentialFeatureSelector::new(2, Direction::Forward);
        let x = array![[100.0, 1.0, 10.0], [200.0, 2.0, 20.0]];
        let y = array![1.0, 2.0];
        let fitted = sfs.fit(&x, &y, mean_sum_score).unwrap();
        assert_eq!(fitted.selected_indices(), &[0, 2]);
        let out = fitted.transform(&x).unwrap();
        assert_eq!(out, array![[100.0, 10.0], [200.0, 20.0]]);
    }

    #[test]
    fn test_forward_tie_prefers_lowest_index() {
        let sfs = SequentialFeatureSelector::new(1, Direction::Forward);
        let x = array![[1.0, 1.0]];
        let y = array![1.0];
        let fitted = sfs.fit(&x, &y, mean_sum_score).unwrap();
        assert_eq!(fitted.selected_indices(), &[0]);
    }

    #[test]
    fn test_backward_tie_removes_lowest_index() {
        let sfs = SequentialFeatureSelector::new(1, Direction::Backward);
        let x = array![[1.0, 1.0]];
        let y = array![1.0];
        let fitted = sfs.fit(&x, &y, mean_sum_score).unwrap();
        assert_eq!(fitted.selected_indices(), &[1]);
    }

    #[test]
    fn test_zero_features_error() {
        let sfs = SequentialFeatureSelector::new(0, Direction::Forward);
        let x = array![[1.0, 2.0]];
        let y = array![1.0];
        assert!(sfs.fit(&x, &y, mean_sum_score).is_err());
    }

    #[test]
    fn test_too_many_features_error() {
        let sfs = SequentialFeatureSelector::new(5, Direction::Forward);
        let x = array![[1.0, 2.0]];
        let y = array![1.0];
        assert!(sfs.fit(&x, &y, mean_sum_score).is_err());
    }

    #[test]
    fn test_zero_rows_error() {
        let sfs = SequentialFeatureSelector::new(1, Direction::Forward);
        let x: Array2<f64> = Array2::zeros((0, 3));
        let y: Array1<f64> = Array1::zeros(0);
        assert!(sfs.fit(&x, &y, mean_sum_score).is_err());
    }

    #[test]
    fn test_y_length_mismatch() {
        let sfs = SequentialFeatureSelector::new(1, Direction::Forward);
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let y = array![1.0];
        assert!(sfs.fit(&x, &y, mean_sum_score).is_err());
    }

    #[test]
    fn test_shape_mismatch_on_transform() {
        let sfs = SequentialFeatureSelector::new(1, Direction::Forward);
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let y = array![1.0, 2.0];
        let fitted = sfs.fit(&x, &y, mean_sum_score).unwrap();
        let x_bad = array![[1.0, 2.0, 3.0]];
        assert!(fitted.transform(&x_bad).is_err());
    }

    #[test]
    fn test_score_fn_error_propagated() {
        let sfs = SequentialFeatureSelector::new(1, Direction::Forward);
        let x = array![[1.0, 2.0]];
        let y = array![1.0];
        let bad_fn = |_x: &Array2<f64>, _y: &Array1<f64>| -> Result<f64, FerroError> {
            Err(FerroError::NumericalInstability {
                message: "test error".into(),
            })
        };
        assert!(sfs.fit(&x, &y, bad_fn).is_err());
    }

    #[test]
    fn test_backward_score_fn_error_propagated() {
        let sfs = SequentialFeatureSelector::new(1, Direction::Backward);
        let x = array![[1.0, 2.0]];
        let y = array![1.0];
        assert!(sfs.fit(&x, &y, failing_score).is_err());
    }

    #[test]
    fn test_indices_sorted() {
        let sfs = SequentialFeatureSelector::new(2, Direction::Forward);
        let x = array![[100.0, 1.0, 10.0], [200.0, 2.0, 20.0]];
        let y = array![1.0, 2.0];
        let fitted = sfs.fit(&x, &y, mean_sum_score).unwrap();
        let indices = fitted.selected_indices();
        assert!(indices.windows(2).all(|w| w[0] < w[1]));
    }

    #[test]
    fn test_accessors() {
        let sfs = SequentialFeatureSelector::new(2, Direction::Backward);
        assert_eq!(sfs.n_features_to_select(), 2);
        assert_eq!(sfs.direction(), Direction::Backward);
    }
}