1use ferrolearn_core::error::FerroError;
15use ferrolearn_core::traits::Transform;
16use ndarray::{Array1, Array2};
17use num_traits::Float;
18
19fn select_columns<F: Float>(x: &Array2<F>, indices: &[usize]) -> Array2<F> {
21 let nrows = x.nrows();
22 let ncols = indices.len();
23 if ncols == 0 {
24 return Array2::zeros((nrows, 0));
25 }
26 let mut out = Array2::zeros((nrows, ncols));
27 for (new_j, &old_j) in indices.iter().enumerate() {
28 for i in 0..nrows {
29 out[[i, new_j]] = x[[i, old_j]];
30 }
31 }
32 out
33}
34
/// Recursive Feature Elimination over a fixed vector of feature importances.
///
/// Built with [`RFE::new`]; applies the resulting column selection through
/// its [`Transform`] implementation.
#[derive(Debug, Clone)]
pub struct RFE<F> {
    /// Per-feature rank: `1` for kept features, larger values for features
    /// eliminated in earlier rounds.
    ranking: Vec<usize>,
    /// `true` exactly where `ranking` is `1`.
    support: Vec<bool>,
    /// Indices of the kept features, in ascending order.
    selected_indices: Vec<usize>,
    /// Number of features the importance vector described; `transform`
    /// requires inputs with exactly this many columns.
    n_features_in: usize,
    /// Ties the struct to the float type `F` without storing any `F` values.
    _marker: std::marker::PhantomData<F>,
}
73
74impl<F: Float + Send + Sync + 'static> RFE<F> {
75 pub fn new(
88 importances: &Array1<F>,
89 n_features_to_select: usize,
90 step: usize,
91 ) -> Result<Self, FerroError> {
92 let n_features = importances.len();
93 if n_features == 0 {
94 return Err(FerroError::InvalidParameter {
95 name: "importances".into(),
96 reason: "importance vector must not be empty".into(),
97 });
98 }
99 if step == 0 {
100 return Err(FerroError::InvalidParameter {
101 name: "step".into(),
102 reason: "step must be at least 1".into(),
103 });
104 }
105 if n_features_to_select == 0 || n_features_to_select > n_features {
106 return Err(FerroError::InvalidParameter {
107 name: "n_features_to_select".into(),
108 reason: format!(
109 "n_features_to_select ({n_features_to_select}) must be in [1, {n_features}]"
110 ),
111 });
112 }
113
114 let mut ranking = vec![0usize; n_features];
119 let mut remaining: Vec<usize> = (0..n_features).collect();
120 let mut elimination_rounds: Vec<Vec<usize>> = Vec::new();
121
122 let imp: Vec<F> = importances.iter().copied().collect();
124
125 while remaining.len() > n_features_to_select {
126 remaining.sort_by(|&a, &b| {
128 imp[a]
129 .partial_cmp(&imp[b])
130 .unwrap_or(std::cmp::Ordering::Equal)
131 });
132
133 let n_to_remove = step.min(remaining.len() - n_features_to_select);
135 let removed: Vec<usize> = remaining[..n_to_remove].to_vec();
136 elimination_rounds.push(removed);
137 remaining = remaining[n_to_remove..].to_vec();
138 }
139
140 for &idx in &remaining {
143 ranking[idx] = 1;
144 }
145 for (round_idx, round) in elimination_rounds.iter().rev().enumerate() {
146 let rank = round_idx + 2;
147 for &idx in round {
148 ranking[idx] = rank;
149 }
150 }
151
152 let support: Vec<bool> = ranking.iter().map(|&r| r == 1).collect();
153 let mut selected_indices: Vec<usize> = remaining;
154 selected_indices.sort_unstable();
155
156 Ok(Self {
157 ranking,
158 support,
159 selected_indices,
160 n_features_in: n_features,
161 _marker: std::marker::PhantomData,
162 })
163 }
164
165 #[must_use]
167 pub fn ranking(&self) -> &[usize] {
168 &self.ranking
169 }
170
171 #[must_use]
173 pub fn support(&self) -> &[bool] {
174 &self.support
175 }
176
177 #[must_use]
179 pub fn selected_indices(&self) -> &[usize] {
180 &self.selected_indices
181 }
182
183 #[must_use]
185 pub fn n_features_selected(&self) -> usize {
186 self.selected_indices.len()
187 }
188}
189
190impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for RFE<F> {
191 type Output = Array2<F>;
192 type Error = FerroError;
193
194 fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
201 if x.ncols() != self.n_features_in {
202 return Err(FerroError::ShapeMismatch {
203 expected: vec![x.nrows(), self.n_features_in],
204 actual: vec![x.nrows(), x.ncols()],
205 context: "RFE::transform".into(),
206 });
207 }
208 Ok(select_columns(x, &self.selected_indices))
209 }
210}
211
/// Recursive Feature Elimination whose target feature count is chosen from
/// caller-supplied cross-validation scores.
///
/// Built with [`RFECV::new`]; applies the selection through its
/// [`Transform`] implementation by delegating to the inner [`RFE`].
#[derive(Debug, Clone)]
pub struct RFECV<F> {
    /// Elimination performed with `optimal_n_features` as the target count.
    rfe: RFE<F>,
    /// Copy of the scores passed to `new`; entry `i` is the score for keeping
    /// `i + 1` features.
    cv_scores: Vec<f64>,
    /// Feature count chosen from `cv_scores` (best index plus one).
    optimal_n_features: usize,
}
249
250impl<F: Float + Send + Sync + 'static> RFECV<F> {
251 pub fn new(
264 importances: &Array1<F>,
265 cv_scores: &[f64],
266 step: usize,
267 ) -> Result<Self, FerroError> {
268 let n_features = importances.len();
269 if n_features == 0 {
270 return Err(FerroError::InvalidParameter {
271 name: "importances".into(),
272 reason: "importance vector must not be empty".into(),
273 });
274 }
275 if cv_scores.len() != n_features {
276 return Err(FerroError::InvalidParameter {
277 name: "cv_scores".into(),
278 reason: format!(
279 "cv_scores length ({}) must equal number of features ({})",
280 cv_scores.len(),
281 n_features
282 ),
283 });
284 }
285
286 let mut best_idx = 0;
288 let mut best_score = f64::NEG_INFINITY;
289 for (i, &score) in cv_scores.iter().enumerate() {
290 if score > best_score {
291 best_score = score;
292 best_idx = i;
293 }
294 }
295 let optimal_n_features = best_idx + 1;
296
297 let rfe = RFE::new(importances, optimal_n_features, step)?;
298
299 Ok(Self {
300 rfe,
301 cv_scores: cv_scores.to_vec(),
302 optimal_n_features,
303 })
304 }
305
306 #[must_use]
308 pub fn cv_scores(&self) -> &[f64] {
309 &self.cv_scores
310 }
311
312 #[must_use]
314 pub fn optimal_n_features(&self) -> usize {
315 self.optimal_n_features
316 }
317
318 #[must_use]
320 pub fn n_features_selected(&self) -> usize {
321 self.rfe.n_features_selected()
322 }
323
324 #[must_use]
326 pub fn ranking(&self) -> &[usize] {
327 self.rfe.ranking()
328 }
329
330 #[must_use]
332 pub fn support(&self) -> &[bool] {
333 self.rfe.support()
334 }
335
336 #[must_use]
338 pub fn selected_indices(&self) -> &[usize] {
339 self.rfe.selected_indices()
340 }
341}
342
343impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for RFECV<F> {
344 type Output = Array2<F>;
345 type Error = FerroError;
346
347 fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
353 self.rfe.transform(x)
354 }
355}
356
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    /// With step 1, each feature is eliminated in its own round, least
    /// important first, so ranks follow importance order.
    #[test]
    fn test_rfe_basic_ranking() {
        let imp = array![0.6, 0.3, 0.1];
        let rfe = RFE::<f64>::new(&imp, 1, 1).unwrap();
        assert_eq!(rfe.ranking(), &[1, 2, 3]);
        assert_eq!(rfe.support(), &[true, false, false]);
        assert_eq!(rfe.selected_indices(), &[0]);
    }

    #[test]
    fn test_rfe_select_two() {
        let imp = array![0.5, 0.3, 0.2];
        let rfe = RFE::<f64>::new(&imp, 2, 1).unwrap();
        assert_eq!(rfe.n_features_selected(), 2);
        assert_eq!(rfe.ranking()[0], 1);
        assert_eq!(rfe.ranking()[1], 1);
        assert_eq!(rfe.ranking()[2], 2);
    }

    #[test]
    fn test_rfe_step_two() {
        let imp = array![0.5, 0.3, 0.2, 0.1];
        let rfe = RFE::<f64>::new(&imp, 2, 2).unwrap();
        assert_eq!(rfe.n_features_selected(), 2);
        assert!(rfe.support()[0]);
        assert!(rfe.support()[1]);
        assert!(!rfe.support()[2]);
        assert!(!rfe.support()[3]);
    }

    /// Step 2 over four features with one kept: round 1 removes {3, 2}
    /// (rank 3), round 2 removes only {1} (rank 2) because the step is capped
    /// by the number of features left to remove.
    #[test]
    fn test_rfe_step_two_multi_round_ranking() {
        let imp = array![0.5, 0.3, 0.2, 0.1];
        let rfe = RFE::<f64>::new(&imp, 1, 2).unwrap();
        assert_eq!(rfe.ranking(), &[1, 2, 3, 3]);
        assert_eq!(rfe.selected_indices(), &[0]);
    }

    /// Equal importances are eliminated in ascending index order.
    #[test]
    fn test_rfe_ties_eliminate_lower_index_first() {
        let imp = array![0.2, 0.2, 0.5];
        let rfe = RFE::<f64>::new(&imp, 2, 1).unwrap();
        assert_eq!(rfe.ranking(), &[2, 1, 1]);
        assert_eq!(rfe.selected_indices(), &[1, 2]);
    }

    /// Selected indices are reported in ascending order regardless of their
    /// importance order, and `transform` emits columns in that order.
    #[test]
    fn test_rfe_selected_indices_sorted() {
        let imp = array![0.1, 0.9, 0.5, 0.7];
        let rfe = RFE::<f64>::new(&imp, 2, 1).unwrap();
        assert_eq!(rfe.selected_indices(), &[1, 3]);
        assert_eq!(rfe.ranking(), &[3, 1, 2, 1]);

        let x = array![[1.0, 2.0, 3.0, 4.0]];
        let out = rfe.transform(&x).unwrap();
        assert_eq!(out.ncols(), 2);
        assert_abs_diff_eq!(out[[0, 0]], 2.0, epsilon = 1e-15);
        assert_abs_diff_eq!(out[[0, 1]], 4.0, epsilon = 1e-15);
    }

    #[test]
    fn test_rfe_transform() {
        let imp = array![0.6, 0.3, 0.1];
        let rfe = RFE::<f64>::new(&imp, 1, 1).unwrap();
        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let out = rfe.transform(&x).unwrap();
        assert_eq!(out.ncols(), 1);
        assert_abs_diff_eq!(out[[0, 0]], 1.0, epsilon = 1e-15);
        assert_abs_diff_eq!(out[[1, 0]], 4.0, epsilon = 1e-15);
    }

    #[test]
    fn test_rfe_all_features_selected() {
        let imp = array![0.5, 0.3, 0.2];
        let rfe = RFE::<f64>::new(&imp, 3, 1).unwrap();
        assert_eq!(rfe.n_features_selected(), 3);
        assert!(rfe.support().iter().all(|&s| s));
    }

    #[test]
    fn test_rfe_empty_importances_error() {
        let imp: Array1<f64> = Array1::zeros(0);
        assert!(RFE::<f64>::new(&imp, 1, 1).is_err());
    }

    #[test]
    fn test_rfe_zero_step_error() {
        let imp = array![0.5, 0.3];
        assert!(RFE::<f64>::new(&imp, 1, 0).is_err());
    }

    #[test]
    fn test_rfe_n_features_too_large_error() {
        let imp = array![0.5, 0.3];
        assert!(RFE::<f64>::new(&imp, 5, 1).is_err());
    }

    #[test]
    fn test_rfe_n_features_zero_error() {
        let imp = array![0.5, 0.3];
        assert!(RFE::<f64>::new(&imp, 0, 1).is_err());
    }

    #[test]
    fn test_rfe_shape_mismatch_error() {
        let imp = array![0.5, 0.3];
        let rfe = RFE::<f64>::new(&imp, 1, 1).unwrap();
        let x_bad = array![[1.0, 2.0, 3.0]];
        assert!(rfe.transform(&x_bad).is_err());
    }

    /// The feature count with the highest CV score is selected.
    #[test]
    fn test_rfecv_selects_optimal() {
        let imp = array![0.5, 0.3, 0.2];
        let cv_scores = vec![0.85, 0.95, 0.90];
        let rfecv = RFECV::<f64>::new(&imp, &cv_scores, 1).unwrap();
        assert_eq!(rfecv.optimal_n_features(), 2);
        assert_eq!(rfecv.n_features_selected(), 2);
    }

    /// When several feature counts tie for the best score, the smallest wins.
    #[test]
    fn test_rfecv_tie_prefers_fewest_features() {
        let imp = array![0.5, 0.3, 0.2];
        let cv_scores = vec![0.9, 0.9, 0.8];
        let rfecv = RFECV::<f64>::new(&imp, &cv_scores, 1).unwrap();
        assert_eq!(rfecv.optimal_n_features(), 1);
        assert_eq!(rfecv.selected_indices(), &[0]);
    }

    #[test]
    fn test_rfecv_transform() {
        let imp = array![0.5, 0.3, 0.2];
        let cv_scores = vec![0.85, 0.95, 0.90];
        let rfecv = RFECV::<f64>::new(&imp, &cv_scores, 1).unwrap();
        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let out = rfecv.transform(&x).unwrap();
        assert_eq!(out.ncols(), 2);
    }

    #[test]
    fn test_rfecv_cv_scores_accessor() {
        let imp = array![0.5, 0.3];
        let cv_scores = vec![0.9, 0.8];
        let rfecv = RFECV::<f64>::new(&imp, &cv_scores, 1).unwrap();
        assert_eq!(rfecv.cv_scores(), &[0.9, 0.8]);
        assert_eq!(rfecv.optimal_n_features(), 1);
    }

    #[test]
    fn test_rfecv_mismatched_scores_error() {
        let imp = array![0.5, 0.3, 0.2];
        let cv_scores = vec![0.85, 0.95];
        assert!(RFECV::<f64>::new(&imp, &cv_scores, 1).is_err());
    }

    #[test]
    fn test_rfecv_empty_importances_error() {
        let imp: Array1<f64> = Array1::zeros(0);
        let cv_scores: Vec<f64> = vec![];
        assert!(RFECV::<f64>::new(&imp, &cv_scores, 1).is_err());
    }

    #[test]
    fn test_rfecv_ranking_and_support() {
        let imp = array![0.5, 0.3, 0.2];
        let cv_scores = vec![0.80, 0.95, 0.90];
        let rfecv = RFECV::<f64>::new(&imp, &cv_scores, 1).unwrap();
        assert_eq!(rfecv.n_features_selected(), 2);
        let support = rfecv.support();
        assert_eq!(support.iter().filter(|&&s| s).count(), 2);
        assert_eq!(rfecv.ranking(), &[1, 1, 2]);
    }
}