1use ferrolearn_core::error::FerroError;
19use ferrolearn_core::traits::Transform;
20use ndarray::{Array1, Array2};
21use num_traits::Float;
22
23
/// Search strategy used by [`SequentialFeatureSelector`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Direction {
    /// Start with no features and greedily add one feature per round.
    Forward,
    /// Start with every feature and greedily remove one feature per round.
    Backward,
}
36
37#[must_use]
71#[derive(Debug, Clone)]
72pub struct SequentialFeatureSelector {
73 n_features_to_select: usize,
75 direction: Direction,
77}
78
79impl SequentialFeatureSelector {
80 pub fn new(n_features_to_select: usize, direction: Direction) -> Self {
87 Self {
88 n_features_to_select,
89 direction,
90 }
91 }
92
93 #[must_use]
95 pub fn n_features_to_select(&self) -> usize {
96 self.n_features_to_select
97 }
98
99 #[must_use]
101 pub fn direction(&self) -> Direction {
102 self.direction
103 }
104
105 pub fn fit<F: Float + Send + Sync + 'static>(
121 &self,
122 x: &Array2<F>,
123 y: &Array1<F>,
124 score_fn: impl Fn(&Array2<F>, &Array1<F>) -> Result<F, FerroError>,
125 ) -> Result<FittedSequentialFeatureSelector<F>, FerroError> {
126 let n_features = x.ncols();
127 let n_samples = x.nrows();
128
129 if self.n_features_to_select == 0 {
130 return Err(FerroError::InvalidParameter {
131 name: "n_features_to_select".into(),
132 reason: "must be at least 1".into(),
133 });
134 }
135 if self.n_features_to_select > n_features {
136 return Err(FerroError::InvalidParameter {
137 name: "n_features_to_select".into(),
138 reason: format!(
139 "n_features_to_select ({}) exceeds number of features ({})",
140 self.n_features_to_select, n_features
141 ),
142 });
143 }
144 if n_samples == 0 {
145 return Err(FerroError::InsufficientSamples {
146 required: 1,
147 actual: 0,
148 context: "SequentialFeatureSelector::fit".into(),
149 });
150 }
151 if y.len() != n_samples {
152 return Err(FerroError::ShapeMismatch {
153 expected: vec![n_samples],
154 actual: vec![y.len()],
155 context: "SequentialFeatureSelector::fit — y must match x rows".into(),
156 });
157 }
158
159 let selected_indices = match self.direction {
160 Direction::Forward => {
161 self.forward_search(x, y, n_features, &score_fn)?
162 }
163 Direction::Backward => {
164 self.backward_search(x, y, n_features, &score_fn)?
165 }
166 };
167
168 Ok(FittedSequentialFeatureSelector {
169 n_features_in: n_features,
170 selected_indices,
171 _marker: std::marker::PhantomData,
172 })
173 }
174
175 #[allow(clippy::type_complexity)]
177 fn forward_search<F: Float + Send + Sync + 'static>(
178 &self,
179 x: &Array2<F>,
180 y: &Array1<F>,
181 n_features: usize,
182 score_fn: &dyn Fn(&Array2<F>, &Array1<F>) -> Result<F, FerroError>,
183 ) -> Result<Vec<usize>, FerroError> {
184 let mut selected: Vec<usize> = Vec::with_capacity(self.n_features_to_select);
185 let mut remaining: Vec<usize> = (0..n_features).collect();
186
187 for _ in 0..self.n_features_to_select {
188 let mut best_score = F::neg_infinity();
189 let mut best_feature = remaining[0];
190
191 for &candidate in &remaining {
192 let mut trial: Vec<usize> = selected.clone();
193 trial.push(candidate);
194 trial.sort_unstable();
195 let x_sub = select_columns(x, &trial);
196 let score = score_fn(&x_sub, y)?;
197 if score > best_score {
198 best_score = score;
199 best_feature = candidate;
200 }
201 }
202
203 selected.push(best_feature);
204 remaining.retain(|&f| f != best_feature);
205 }
206
207 selected.sort_unstable();
208 Ok(selected)
209 }
210
211 #[allow(clippy::type_complexity)]
213 fn backward_search<F: Float + Send + Sync + 'static>(
214 &self,
215 x: &Array2<F>,
216 y: &Array1<F>,
217 n_features: usize,
218 score_fn: &dyn Fn(&Array2<F>, &Array1<F>) -> Result<F, FerroError>,
219 ) -> Result<Vec<usize>, FerroError> {
220 let mut remaining: Vec<usize> = (0..n_features).collect();
221
222 while remaining.len() > self.n_features_to_select {
223 let mut best_score = F::neg_infinity();
224 let mut worst_feature = remaining[0];
225
226 for &candidate in &remaining {
227 let trial: Vec<usize> = remaining
229 .iter()
230 .copied()
231 .filter(|&f| f != candidate)
232 .collect();
233 let x_sub = select_columns(x, &trial);
234 let score = score_fn(&x_sub, y)?;
235 if score > best_score {
236 best_score = score;
237 worst_feature = candidate;
238 }
239 }
240
241 remaining.retain(|&f| f != worst_feature);
242 }
243
244 remaining.sort_unstable();
245 Ok(remaining)
246 }
247}
248
/// Result of fitting a `SequentialFeatureSelector`: the column indices that
/// survived the search, plus the input width they refer to.
///
/// `F` is the element type of the arrays the selector was fitted on; it is
/// only tracked at the type level.
#[derive(Clone, Debug)]
pub struct FittedSequentialFeatureSelector<F> {
    /// Number of columns of the matrix seen during fitting.
    n_features_in: usize,
    /// Indices of the selected columns.
    selected_indices: Vec<usize>,
    /// Ties the fitted selector to its element type without storing a value.
    _marker: std::marker::PhantomData<F>,
}
264
265impl<F: Float + Send + Sync + 'static> FittedSequentialFeatureSelector<F> {
266 #[must_use]
268 pub fn selected_indices(&self) -> &[usize] {
269 &self.selected_indices
270 }
271
272 #[must_use]
274 pub fn n_features_selected(&self) -> usize {
275 self.selected_indices.len()
276 }
277}
278
279impl<F: Float + Send + Sync + 'static> Transform<Array2<F>>
280 for FittedSequentialFeatureSelector<F>
281{
282 type Output = Array2<F>;
283 type Error = FerroError;
284
285 fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
292 if x.ncols() != self.n_features_in {
293 return Err(FerroError::ShapeMismatch {
294 expected: vec![x.nrows(), self.n_features_in],
295 actual: vec![x.nrows(), x.ncols()],
296 context: "FittedSequentialFeatureSelector::transform".into(),
297 });
298 }
299 Ok(select_columns(x, &self.selected_indices))
300 }
301}
302
303fn select_columns<F: Float>(x: &Array2<F>, indices: &[usize]) -> Array2<F> {
305 let nrows = x.nrows();
306 let ncols = indices.len();
307 if ncols == 0 {
308 return Array2::zeros((nrows, 0));
309 }
310 let mut out = Array2::zeros((nrows, ncols));
311 for (new_j, &old_j) in indices.iter().enumerate() {
312 for i in 0..nrows {
313 out[[i, new_j]] = x[[i, old_j]];
314 }
315 }
316 out
317}
318
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    /// Toy scorer: the sum of the per-column means of `x`; `y` is ignored.
    ///
    /// Columns with larger values therefore always look "better", which
    /// makes the expected selections easy to work out by hand.
    fn mean_sum_score(x: &Array2<f64>, _y: &Array1<f64>) -> Result<f64, FerroError> {
        Ok(x.columns()
            .into_iter()
            .map(|column| column.sum() / column.len() as f64)
            .sum::<f64>())
    }

    /// Fits a selector keeping `k` features in `direction` with the toy scorer.
    fn fit_toy(
        k: usize,
        direction: Direction,
        x: &Array2<f64>,
        y: &Array1<f64>,
    ) -> Result<FittedSequentialFeatureSelector<f64>, FerroError> {
        SequentialFeatureSelector::new(k, direction).fit(x, y, mean_sum_score)
    }

    /// Three samples, three features; the middle column dominates.
    fn three_by_three() -> (Array2<f64>, Array1<f64>) {
        (
            array![[1.0, 10.0, 0.1], [2.0, 20.0, 0.2], [3.0, 30.0, 0.3]],
            array![1.0, 2.0, 3.0],
        )
    }

    /// Two samples, three features with strictly increasing magnitude.
    fn two_by_three() -> (Array2<f64>, Array1<f64>) {
        (
            array![[1.0, 10.0, 100.0], [2.0, 20.0, 200.0]],
            array![1.0, 2.0],
        )
    }

    #[test]
    fn test_forward_selects_best() {
        let (features, targets) = three_by_three();
        let model = fit_toy(1, Direction::Forward, &features, &targets).unwrap();
        assert_eq!(model.selected_indices(), &[1]);
    }

    #[test]
    fn test_forward_select_two() {
        let (features, targets) = two_by_three();
        let model = fit_toy(2, Direction::Forward, &features, &targets).unwrap();
        assert_eq!(model.n_features_selected(), 2);
        let picked = model.selected_indices();
        assert!(picked.contains(&1));
        assert!(picked.contains(&2));
    }

    #[test]
    fn test_backward_selects_best() {
        let (features, targets) = three_by_three();
        let model = fit_toy(1, Direction::Backward, &features, &targets).unwrap();
        assert_eq!(model.selected_indices(), &[1]);
    }

    #[test]
    fn test_backward_select_two() {
        let (features, targets) = two_by_three();
        let model = fit_toy(2, Direction::Backward, &features, &targets).unwrap();
        assert_eq!(model.n_features_selected(), 2);
        assert_eq!(model.selected_indices(), &[1, 2]);
    }

    #[test]
    fn test_select_all_features() {
        let features = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let targets = array![1.0, 2.0];
        let model = fit_toy(3, Direction::Forward, &features, &targets).unwrap();
        assert_eq!(model.n_features_selected(), 3);
    }

    #[test]
    fn test_transform() {
        let features = array![[1.0, 10.0], [2.0, 20.0]];
        let targets = array![1.0, 2.0];
        let model = fit_toy(1, Direction::Forward, &features, &targets).unwrap();
        let reduced = model.transform(&features).unwrap();
        assert_eq!(reduced.ncols(), 1);
        assert_abs_diff_eq!(reduced[[0, 0]], 10.0, epsilon = 1e-15);
        assert_abs_diff_eq!(reduced[[1, 0]], 20.0, epsilon = 1e-15);
    }

    #[test]
    fn test_zero_features_error() {
        let features = array![[1.0, 2.0]];
        let targets = array![1.0];
        assert!(fit_toy(0, Direction::Forward, &features, &targets).is_err());
    }

    #[test]
    fn test_too_many_features_error() {
        let features = array![[1.0, 2.0]];
        let targets = array![1.0];
        assert!(fit_toy(5, Direction::Forward, &features, &targets).is_err());
    }

    #[test]
    fn test_zero_rows_error() {
        let features: Array2<f64> = Array2::zeros((0, 3));
        let targets: Array1<f64> = Array1::zeros(0);
        assert!(fit_toy(1, Direction::Forward, &features, &targets).is_err());
    }

    #[test]
    fn test_y_length_mismatch() {
        let features = array![[1.0, 2.0], [3.0, 4.0]];
        // One target for two rows: lengths disagree on purpose.
        let targets = array![1.0];
        assert!(fit_toy(1, Direction::Forward, &features, &targets).is_err());
    }

    #[test]
    fn test_shape_mismatch_on_transform() {
        let features = array![[1.0, 2.0], [3.0, 4.0]];
        let targets = array![1.0, 2.0];
        let model = fit_toy(1, Direction::Forward, &features, &targets).unwrap();
        // Fitted on two columns, so a three-column input must be rejected.
        let too_wide = array![[1.0, 2.0, 3.0]];
        assert!(model.transform(&too_wide).is_err());
    }

    #[test]
    fn test_score_fn_error_propagated() {
        /// Scorer that fails unconditionally.
        fn always_fails(_x: &Array2<f64>, _y: &Array1<f64>) -> Result<f64, FerroError> {
            Err(FerroError::NumericalInstability {
                message: "test error".into(),
            })
        }

        let features = array![[1.0, 2.0]];
        let targets = array![1.0];
        let selector = SequentialFeatureSelector::new(1, Direction::Forward);
        assert!(selector.fit(&features, &targets, always_fails).is_err());
    }

    #[test]
    fn test_indices_sorted() {
        let features = array![[100.0, 1.0, 10.0], [200.0, 2.0, 20.0]];
        let targets = array![1.0, 2.0];
        let model = fit_toy(2, Direction::Forward, &features, &targets).unwrap();
        let picked = model.selected_indices();
        assert!(picked.windows(2).all(|pair| pair[0] < pair[1]));
    }

    #[test]
    fn test_accessors() {
        let selector = SequentialFeatureSelector::new(2, Direction::Backward);
        assert_eq!(selector.n_features_to_select(), 2);
        assert_eq!(selector.direction(), Direction::Backward);
    }
}