1use ferrolearn_core::error::FerroError;
15use ferrolearn_core::traits::Transform;
16use ndarray::{Array1, Array2};
17use num_traits::Float;
18
19fn select_columns<F: Float>(x: &Array2<F>, indices: &[usize]) -> Array2<F> {
21 let nrows = x.nrows();
22 let ncols = indices.len();
23 if ncols == 0 {
24 return Array2::zeros((nrows, 0));
25 }
26 let mut out = Array2::zeros((nrows, ncols));
27 for (new_j, &old_j) in indices.iter().enumerate() {
28 for i in 0..nrows {
29 out[[i, new_j]] = x[[i, old_j]];
30 }
31 }
32 out
33}
34
/// Result of recursive feature elimination driven by a fixed importance
/// vector.
///
/// An instance stores, for every input feature, its elimination rank and
/// whether it was kept, plus the sorted list of kept column indices that
/// `transform` uses to slice a matrix.
#[derive(Debug, Clone)]
pub struct RFE<F> {
    /// Per-feature rank: `1` for kept features, larger values for features
    /// that were eliminated in earlier rounds.
    ranking: Vec<usize>,
    /// Per-feature mask; `true` exactly where `ranking` is `1`.
    support: Vec<bool>,
    /// Indices of the kept features, in ascending order.
    selected_indices: Vec<usize>,
    /// Length of the importance vector the selector was built from; the
    /// column count `transform` expects.
    n_features_in: usize,
    /// Ties the selector to the float type `F` without storing a value.
    _marker: std::marker::PhantomData<F>,
}
73
74impl<F: Float + Send + Sync + 'static> RFE<F> {
75 pub fn new(
88 importances: &Array1<F>,
89 n_features_to_select: usize,
90 step: usize,
91 ) -> Result<Self, FerroError> {
92 let n_features = importances.len();
93 if n_features == 0 {
94 return Err(FerroError::InvalidParameter {
95 name: "importances".into(),
96 reason: "importance vector must not be empty".into(),
97 });
98 }
99 if step == 0 {
100 return Err(FerroError::InvalidParameter {
101 name: "step".into(),
102 reason: "step must be at least 1".into(),
103 });
104 }
105 if n_features_to_select == 0 || n_features_to_select > n_features {
106 return Err(FerroError::InvalidParameter {
107 name: "n_features_to_select".into(),
108 reason: format!(
109 "n_features_to_select ({}) must be in [1, {}]",
110 n_features_to_select, n_features
111 ),
112 });
113 }
114
115 let mut ranking = vec![0usize; n_features];
120 let mut remaining: Vec<usize> = (0..n_features).collect();
121 let mut elimination_rounds: Vec<Vec<usize>> = Vec::new();
122
123 let imp: Vec<F> = importances.iter().copied().collect();
125
126 while remaining.len() > n_features_to_select {
127 remaining.sort_by(|&a, &b| {
129 imp[a]
130 .partial_cmp(&imp[b])
131 .unwrap_or(std::cmp::Ordering::Equal)
132 });
133
134 let n_to_remove = step.min(remaining.len() - n_features_to_select);
136 let removed: Vec<usize> = remaining[..n_to_remove].to_vec();
137 elimination_rounds.push(removed);
138 remaining = remaining[n_to_remove..].to_vec();
139 }
140
141 for &idx in &remaining {
144 ranking[idx] = 1;
145 }
146 for (round_idx, round) in elimination_rounds.iter().rev().enumerate() {
147 let rank = round_idx + 2;
148 for &idx in round {
149 ranking[idx] = rank;
150 }
151 }
152
153 let support: Vec<bool> = ranking.iter().map(|&r| r == 1).collect();
154 let mut selected_indices: Vec<usize> = remaining;
155 selected_indices.sort_unstable();
156
157 Ok(Self {
158 ranking,
159 support,
160 selected_indices,
161 n_features_in: n_features,
162 _marker: std::marker::PhantomData,
163 })
164 }
165
166 #[must_use]
168 pub fn ranking(&self) -> &[usize] {
169 &self.ranking
170 }
171
172 #[must_use]
174 pub fn support(&self) -> &[bool] {
175 &self.support
176 }
177
178 #[must_use]
180 pub fn selected_indices(&self) -> &[usize] {
181 &self.selected_indices
182 }
183
184 #[must_use]
186 pub fn n_features_selected(&self) -> usize {
187 self.selected_indices.len()
188 }
189}
190
191impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for RFE<F> {
192 type Output = Array2<F>;
193 type Error = FerroError;
194
195 fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
202 if x.ncols() != self.n_features_in {
203 return Err(FerroError::ShapeMismatch {
204 expected: vec![x.nrows(), self.n_features_in],
205 actual: vec![x.nrows(), x.ncols()],
206 context: "RFE::transform".into(),
207 });
208 }
209 Ok(select_columns(x, &self.selected_indices))
210 }
211}
212
213#[derive(Debug, Clone)]
242pub struct RFECV<F> {
243 rfe: RFE<F>,
245 cv_scores: Vec<f64>,
247 optimal_n_features: usize,
249}
250
251impl<F: Float + Send + Sync + 'static> RFECV<F> {
252 pub fn new(
265 importances: &Array1<F>,
266 cv_scores: &[f64],
267 step: usize,
268 ) -> Result<Self, FerroError> {
269 let n_features = importances.len();
270 if n_features == 0 {
271 return Err(FerroError::InvalidParameter {
272 name: "importances".into(),
273 reason: "importance vector must not be empty".into(),
274 });
275 }
276 if cv_scores.len() != n_features {
277 return Err(FerroError::InvalidParameter {
278 name: "cv_scores".into(),
279 reason: format!(
280 "cv_scores length ({}) must equal number of features ({})",
281 cv_scores.len(),
282 n_features
283 ),
284 });
285 }
286
287 let mut best_idx = 0;
289 let mut best_score = f64::NEG_INFINITY;
290 for (i, &score) in cv_scores.iter().enumerate() {
291 if score > best_score {
292 best_score = score;
293 best_idx = i;
294 }
295 }
296 let optimal_n_features = best_idx + 1;
297
298 let rfe = RFE::new(importances, optimal_n_features, step)?;
299
300 Ok(Self {
301 rfe,
302 cv_scores: cv_scores.to_vec(),
303 optimal_n_features,
304 })
305 }
306
307 #[must_use]
309 pub fn cv_scores(&self) -> &[f64] {
310 &self.cv_scores
311 }
312
313 #[must_use]
315 pub fn optimal_n_features(&self) -> usize {
316 self.optimal_n_features
317 }
318
319 #[must_use]
321 pub fn n_features_selected(&self) -> usize {
322 self.rfe.n_features_selected()
323 }
324
325 #[must_use]
327 pub fn ranking(&self) -> &[usize] {
328 self.rfe.ranking()
329 }
330
331 #[must_use]
333 pub fn support(&self) -> &[bool] {
334 self.rfe.support()
335 }
336
337 #[must_use]
339 pub fn selected_indices(&self) -> &[usize] {
340 self.rfe.selected_indices()
341 }
342}
343
344impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for RFECV<F> {
345 type Output = Array2<F>;
346 type Error = FerroError;
347
348 fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
354 self.rfe.transform(x)
355 }
356}
357
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    // ----- RFE -----

    /// Eliminating one feature per round down to a single survivor yields
    /// ranks 1, 2, 3 in order of decreasing importance.
    #[test]
    fn test_rfe_basic_ranking() {
        let importances = array![0.6, 0.3, 0.1];
        let selector = RFE::<f64>::new(&importances, 1, 1).unwrap();
        assert_eq!(selector.ranking(), &[1, 2, 3]);
        assert_eq!(selector.support(), &[true, false, false]);
        assert_eq!(selector.selected_indices(), &[0]);
    }

    /// Keeping two of three features drops only the least important one.
    #[test]
    fn test_rfe_select_two() {
        let importances = array![0.5, 0.3, 0.2];
        let selector = RFE::<f64>::new(&importances, 2, 1).unwrap();
        assert_eq!(selector.n_features_selected(), 2);
        assert_eq!(selector.ranking(), &[1, 1, 2]);
    }

    /// A step of two removes both weakest features in a single round.
    #[test]
    fn test_rfe_step_two() {
        let importances = array![0.5, 0.3, 0.2, 0.1];
        let selector = RFE::<f64>::new(&importances, 2, 2).unwrap();
        assert_eq!(selector.n_features_selected(), 2);
        assert_eq!(selector.support(), &[true, true, false, false]);
    }

    /// Transforming keeps only the selected column, values untouched.
    #[test]
    fn test_rfe_transform() {
        let importances = array![0.6, 0.3, 0.1];
        let selector = RFE::<f64>::new(&importances, 1, 1).unwrap();
        let data = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let reduced = selector.transform(&data).unwrap();
        assert_eq!(reduced.ncols(), 1);
        assert_abs_diff_eq!(reduced[[0, 0]], 1.0, epsilon = 1e-15);
        assert_abs_diff_eq!(reduced[[1, 0]], 4.0, epsilon = 1e-15);
    }

    /// Asking for every feature eliminates nothing.
    #[test]
    fn test_rfe_all_features_selected() {
        let importances = array![0.5, 0.3, 0.2];
        let selector = RFE::<f64>::new(&importances, 3, 1).unwrap();
        assert_eq!(selector.n_features_selected(), 3);
        assert!(selector.support().iter().all(|&kept| kept));
    }

    #[test]
    fn test_rfe_empty_importances_error() {
        let importances: Array1<f64> = Array1::zeros(0);
        let outcome = RFE::<f64>::new(&importances, 1, 1);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_rfe_zero_step_error() {
        let importances = array![0.5, 0.3];
        let outcome = RFE::<f64>::new(&importances, 1, 0);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_rfe_n_features_too_large_error() {
        let importances = array![0.5, 0.3];
        let outcome = RFE::<f64>::new(&importances, 5, 1);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_rfe_n_features_zero_error() {
        let importances = array![0.5, 0.3];
        let outcome = RFE::<f64>::new(&importances, 0, 1);
        assert!(outcome.is_err());
    }

    /// A matrix with the wrong number of columns is rejected.
    #[test]
    fn test_rfe_shape_mismatch_error() {
        let importances = array![0.5, 0.3];
        let selector = RFE::<f64>::new(&importances, 1, 1).unwrap();
        let three_columns = array![[1.0, 2.0, 3.0]];
        assert!(selector.transform(&three_columns).is_err());
    }

    // ----- RFECV -----

    /// The best score sits at index 1, so two features are kept.
    #[test]
    fn test_rfecv_selects_optimal() {
        let importances = array![0.5, 0.3, 0.2];
        let scores = vec![0.85, 0.95, 0.90];
        let selector = RFECV::<f64>::new(&importances, &scores, 1).unwrap();
        assert_eq!(selector.optimal_n_features(), 2);
        assert_eq!(selector.n_features_selected(), 2);
    }

    #[test]
    fn test_rfecv_transform() {
        let importances = array![0.5, 0.3, 0.2];
        let scores = vec![0.85, 0.95, 0.90];
        let selector = RFECV::<f64>::new(&importances, &scores, 1).unwrap();
        let data = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let reduced = selector.transform(&data).unwrap();
        assert_eq!(reduced.ncols(), 2);
    }

    /// The stored scores are returned unchanged; best score at index 0
    /// means a single feature is optimal.
    #[test]
    fn test_rfecv_cv_scores_accessor() {
        let importances = array![0.5, 0.3];
        let scores = vec![0.9, 0.8];
        let selector = RFECV::<f64>::new(&importances, &scores, 1).unwrap();
        assert_eq!(selector.cv_scores(), &[0.9, 0.8]);
        assert_eq!(selector.optimal_n_features(), 1);
    }

    /// Two scores for three features is a length mismatch.
    #[test]
    fn test_rfecv_mismatched_scores_error() {
        let importances = array![0.5, 0.3, 0.2];
        let scores = vec![0.85, 0.95];
        let outcome = RFECV::<f64>::new(&importances, &scores, 1);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_rfecv_empty_importances_error() {
        let importances: Array1<f64> = Array1::zeros(0);
        let scores: Vec<f64> = vec![];
        let outcome = RFECV::<f64>::new(&importances, &scores, 1);
        assert!(outcome.is_err());
    }

    /// The support mask agrees with the number of selected features.
    #[test]
    fn test_rfecv_ranking_and_support() {
        let importances = array![0.5, 0.3, 0.2];
        let scores = vec![0.80, 0.95, 0.90];
        let selector = RFECV::<f64>::new(&importances, &scores, 1).unwrap();
        assert_eq!(selector.n_features_selected(), 2);
        let kept = selector.support().iter().filter(|&&flag| flag).count();
        assert_eq!(kept, 2);
    }
}