1use ferrolearn_core::error::FerroError;
16use ferrolearn_core::pipeline::{FittedPipelineTransformer, PipelineTransformer};
17use ferrolearn_core::traits::{Fit, Transform};
18use ndarray::{Array1, Array2};
19use num_traits::Float;
20
21fn select_columns<F: Float>(x: &Array2<F>, indices: &[usize]) -> Array2<F> {
29 let nrows = x.nrows();
30 let ncols = indices.len();
31 if ncols == 0 {
32 return Array2::zeros((nrows, 0));
34 }
35 let mut out = Array2::zeros((nrows, ncols));
36 for (new_j, &old_j) in indices.iter().enumerate() {
37 for i in 0..nrows {
38 out[[i, new_j]] = x[[i, old_j]];
39 }
40 }
41 out
42}
43
44#[derive(Debug, Clone)]
73pub struct VarianceThreshold<F> {
74 threshold: F,
76}
77
78impl<F: Float + Send + Sync + 'static> VarianceThreshold<F> {
79 pub fn new(threshold: F) -> Self {
87 Self { threshold }
88 }
89
90 #[must_use]
92 pub fn threshold(&self) -> F {
93 self.threshold
94 }
95}
96
97impl<F: Float + Send + Sync + 'static> Default for VarianceThreshold<F> {
98 fn default() -> Self {
99 Self::new(F::zero())
100 }
101}
102
103#[derive(Debug, Clone)]
112pub struct FittedVarianceThreshold<F> {
113 selected_indices: Vec<usize>,
115 variances: Array1<F>,
117}
118
119impl<F: Float + Send + Sync + 'static> FittedVarianceThreshold<F> {
120 #[must_use]
122 pub fn selected_indices(&self) -> &[usize] {
123 &self.selected_indices
124 }
125
126 #[must_use]
128 pub fn variances(&self) -> &Array1<F> {
129 &self.variances
130 }
131}
132
133impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, ()> for VarianceThreshold<F> {
138 type Fitted = FittedVarianceThreshold<F>;
139 type Error = FerroError;
140
141 fn fit(&self, x: &Array2<F>, _y: &()) -> Result<FittedVarianceThreshold<F>, FerroError> {
148 if self.threshold < F::zero() {
149 return Err(FerroError::InvalidParameter {
150 name: "threshold".into(),
151 reason: "variance threshold must be non-negative".into(),
152 });
153 }
154 let n_samples = x.nrows();
155 if n_samples == 0 {
156 return Err(FerroError::InsufficientSamples {
157 required: 1,
158 actual: 0,
159 context: "VarianceThreshold::fit".into(),
160 });
161 }
162
163 let n = F::from(n_samples).unwrap_or(F::one());
164 let n_features = x.ncols();
165 let mut variances = Array1::zeros(n_features);
166 let mut selected_indices = Vec::new();
167
168 for j in 0..n_features {
169 let col = x.column(j);
170 let mean = col.iter().copied().fold(F::zero(), |acc, v| acc + v) / n;
171 let var = col
172 .iter()
173 .copied()
174 .map(|v| (v - mean) * (v - mean))
175 .fold(F::zero(), |acc, v| acc + v)
176 / n;
177 variances[j] = var;
178 if var > self.threshold {
179 selected_indices.push(j);
180 }
181 }
182
183 Ok(FittedVarianceThreshold {
184 selected_indices,
185 variances,
186 })
187 }
188}
189
190impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedVarianceThreshold<F> {
191 type Output = Array2<F>;
192 type Error = FerroError;
193
194 fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
201 let n_original = self.variances.len();
202 if x.ncols() != n_original {
203 return Err(FerroError::ShapeMismatch {
204 expected: vec![x.nrows(), n_original],
205 actual: vec![x.nrows(), x.ncols()],
206 context: "FittedVarianceThreshold::transform".into(),
207 });
208 }
209 Ok(select_columns(x, &self.selected_indices))
210 }
211}
212
213impl<F: Float + Send + Sync + 'static> PipelineTransformer<F> for VarianceThreshold<F> {
218 fn fit_pipeline(
224 &self,
225 x: &Array2<F>,
226 _y: &Array1<F>,
227 ) -> Result<Box<dyn FittedPipelineTransformer<F>>, FerroError> {
228 let fitted = self.fit(x, &())?;
229 Ok(Box::new(fitted))
230 }
231}
232
233impl<F: Float + Send + Sync + 'static> FittedPipelineTransformer<F> for FittedVarianceThreshold<F> {
234 fn transform_pipeline(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
240 self.transform(x)
241 }
242}
243
/// Scoring function used to rank features.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ScoreFunc {
    /// F-statistic based classification score.
    FClassif,
}
258
259#[derive(Debug, Clone)]
279pub struct SelectKBest<F> {
280 k: usize,
282 score_func: ScoreFunc,
284 _marker: std::marker::PhantomData<F>,
285}
286
287impl<F: Float + Send + Sync + 'static> SelectKBest<F> {
288 #[must_use]
296 pub fn new(k: usize, score_func: ScoreFunc) -> Self {
297 Self {
298 k,
299 score_func,
300 _marker: std::marker::PhantomData,
301 }
302 }
303
304 #[must_use]
306 pub fn k(&self) -> usize {
307 self.k
308 }
309
310 #[must_use]
312 pub fn score_func(&self) -> ScoreFunc {
313 self.score_func
314 }
315}
316
317#[derive(Debug, Clone)]
325pub struct FittedSelectKBest<F> {
326 n_features_in: usize,
328 scores: Array1<F>,
330 selected_indices: Vec<usize>,
332}
333
334impl<F: Float + Send + Sync + 'static> FittedSelectKBest<F> {
335 #[must_use]
337 pub fn scores(&self) -> &Array1<F> {
338 &self.scores
339 }
340
341 #[must_use]
343 pub fn selected_indices(&self) -> &[usize] {
344 &self.selected_indices
345 }
346}
347
348fn anova_f_scores<F: Float>(x: &Array2<F>, y: &Array1<usize>) -> Vec<F> {
366 let n_samples = x.nrows();
367 let n_features = x.ncols();
368
369 let mut class_indices: std::collections::HashMap<usize, Vec<usize>> =
371 std::collections::HashMap::new();
372 for (i, &label) in y.iter().enumerate() {
373 class_indices.entry(label).or_default().push(i);
374 }
375 let n_classes = class_indices.len();
376
377 let mut scores = Vec::with_capacity(n_features);
378
379 for j in 0..n_features {
380 let col = x.column(j);
381
382 let grand_mean =
384 col.iter().copied().fold(F::zero(), |acc, v| acc + v) / F::from(n_samples).unwrap();
385
386 let mut ss_between = F::zero();
388 let mut ss_within = F::zero();
390
391 for rows in class_indices.values() {
392 let n_k = F::from(rows.len()).unwrap();
393 let class_mean = rows
394 .iter()
395 .map(|&i| col[i])
396 .fold(F::zero(), |acc, v| acc + v)
397 / n_k;
398 let diff = class_mean - grand_mean;
399 ss_between = ss_between + n_k * diff * diff;
400 for &i in rows {
401 let d = col[i] - class_mean;
402 ss_within = ss_within + d * d;
403 }
404 }
405
406 let df_between = F::from(n_classes.saturating_sub(1)).unwrap();
407 let df_within = F::from(n_samples.saturating_sub(n_classes)).unwrap();
408
409 let f = if df_between == F::zero() || df_within == F::zero() {
410 F::zero()
411 } else {
412 let ms_between = ss_between / df_between;
413 let ms_within = ss_within / df_within;
414 if ms_within == F::zero() {
415 F::infinity()
416 } else {
417 ms_between / ms_within
418 }
419 };
420
421 scores.push(f);
422 }
423
424 scores
425}
426
427impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array1<usize>> for SelectKBest<F> {
432 type Fitted = FittedSelectKBest<F>;
433 type Error = FerroError;
434
435 fn fit(&self, x: &Array2<F>, y: &Array1<usize>) -> Result<FittedSelectKBest<F>, FerroError> {
444 let n_samples = x.nrows();
445 if n_samples == 0 {
446 return Err(FerroError::InsufficientSamples {
447 required: 1,
448 actual: 0,
449 context: "SelectKBest::fit".into(),
450 });
451 }
452 if y.len() != n_samples {
453 return Err(FerroError::ShapeMismatch {
454 expected: vec![n_samples],
455 actual: vec![y.len()],
456 context: "SelectKBest::fit — y must have the same length as x has rows".into(),
457 });
458 }
459 let n_features = x.ncols();
460 if self.k > n_features {
461 return Err(FerroError::InvalidParameter {
462 name: "k".into(),
463 reason: format!(
464 "k ({}) cannot exceed the number of features ({})",
465 self.k, n_features
466 ),
467 });
468 }
469
470 let raw_scores = match self.score_func {
471 ScoreFunc::FClassif => anova_f_scores(x, y),
472 };
473
474 let scores = Array1::from_vec(raw_scores.clone());
475
476 let mut ranked: Vec<usize> = (0..n_features).collect();
479 ranked.sort_by(|&a, &b| {
480 raw_scores[b]
481 .partial_cmp(&raw_scores[a])
482 .unwrap_or(std::cmp::Ordering::Equal)
483 .then(a.cmp(&b))
485 });
486
487 let mut selected_indices: Vec<usize> = ranked[..self.k].to_vec();
488 selected_indices.sort_unstable();
490
491 Ok(FittedSelectKBest {
492 n_features_in: n_features,
493 scores,
494 selected_indices,
495 })
496 }
497}
498
499impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedSelectKBest<F> {
500 type Output = Array2<F>;
501 type Error = FerroError;
502
503 fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
510 if x.ncols() != self.n_features_in {
511 return Err(FerroError::ShapeMismatch {
512 expected: vec![x.nrows(), self.n_features_in],
513 actual: vec![x.nrows(), x.ncols()],
514 context: "FittedSelectKBest::transform".into(),
515 });
516 }
517 Ok(select_columns(x, &self.selected_indices))
518 }
519}
520
521impl<F: Float + Send + Sync + 'static> PipelineTransformer<F> for SelectKBest<F> {
530 fn fit_pipeline(
538 &self,
539 x: &Array2<F>,
540 y: &Array1<F>,
541 ) -> Result<Box<dyn FittedPipelineTransformer<F>>, FerroError> {
542 let y_usize: Array1<usize> = y.mapv(|v| v.round().to_usize().unwrap_or(0));
543 let fitted = self.fit(x, &y_usize)?;
544 Ok(Box::new(fitted))
545 }
546}
547
548impl<F: Float + Send + Sync + 'static> FittedPipelineTransformer<F> for FittedSelectKBest<F> {
549 fn transform_pipeline(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
555 self.transform(x)
556 }
557}
558
559#[derive(Debug, Clone)]
586pub struct SelectFromModel<F> {
587 importances: Array1<F>,
589 threshold: F,
591 selected_indices: Vec<usize>,
593}
594
595impl<F: Float + Send + Sync + 'static> SelectFromModel<F> {
596 pub fn new_from_importances(
608 importances: &Array1<F>,
609 threshold: Option<F>,
610 ) -> Result<Self, FerroError> {
611 let n = importances.len();
612 if n == 0 {
613 return Err(FerroError::InvalidParameter {
614 name: "importances".into(),
615 reason: "importance vector must not be empty".into(),
616 });
617 }
618
619 let thr = threshold.unwrap_or_else(|| {
620 importances
621 .iter()
622 .copied()
623 .fold(F::zero(), |acc, v| acc + v)
624 / F::from(n).unwrap_or(F::one())
625 });
626
627 let selected_indices: Vec<usize> = importances
628 .iter()
629 .enumerate()
630 .filter(|&(_, &imp)| imp >= thr)
631 .map(|(j, _)| j)
632 .collect();
633
634 Ok(Self {
635 importances: importances.clone(),
636 threshold: thr,
637 selected_indices,
638 })
639 }
640
641 #[must_use]
643 pub fn threshold(&self) -> F {
644 self.threshold
645 }
646
647 #[must_use]
649 pub fn importances(&self) -> &Array1<F> {
650 &self.importances
651 }
652
653 #[must_use]
655 pub fn selected_indices(&self) -> &[usize] {
656 &self.selected_indices
657 }
658}
659
660impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for SelectFromModel<F> {
665 type Output = Array2<F>;
666 type Error = FerroError;
667
668 fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
676 let n_features = self.importances.len();
677 if x.ncols() != n_features {
678 return Err(FerroError::ShapeMismatch {
679 expected: vec![x.nrows(), n_features],
680 actual: vec![x.nrows(), x.ncols()],
681 context: "SelectFromModel::transform".into(),
682 });
683 }
684 Ok(select_columns(x, &self.selected_indices))
685 }
686}
687
688impl<F: Float + Send + Sync + 'static> PipelineTransformer<F> for SelectFromModel<F> {
696 fn fit_pipeline(
702 &self,
703 _x: &Array2<F>,
704 _y: &Array1<F>,
705 ) -> Result<Box<dyn FittedPipelineTransformer<F>>, FerroError> {
706 Ok(Box::new(self.clone()))
707 }
708}
709
710impl<F: Float + Send + Sync + 'static> FittedPipelineTransformer<F> for SelectFromModel<F> {
711 fn transform_pipeline(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
717 self.transform(x)
718 }
719}
720
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    // ----- VarianceThreshold -----

    /// The second column is constant and must be removed at threshold 0.
    #[test]
    fn test_variance_threshold_removes_constant_column() {
        let selector = VarianceThreshold::<f64>::new(0.0);
        let data = array![[1.0, 7.0], [2.0, 7.0], [3.0, 7.0]];
        let fitted = selector.fit(&data, &()).unwrap();
        assert_eq!(fitted.selected_indices(), &[0usize]);
        let reduced = fitted.transform(&data).unwrap();
        assert_eq!(reduced.ncols(), 1);
        assert_abs_diff_eq!(reduced[[0, 0]], 1.0, epsilon = 1e-15);
        assert_abs_diff_eq!(reduced[[1, 0]], 2.0, epsilon = 1e-15);
    }

    /// Both columns vary, so both survive a zero threshold.
    #[test]
    fn test_variance_threshold_keeps_all_when_above() {
        let selector = VarianceThreshold::<f64>::new(0.0);
        let data = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let fitted = selector.fit(&data, &()).unwrap();
        assert_eq!(fitted.selected_indices().len(), 2);
        let reduced = fitted.transform(&data).unwrap();
        assert_eq!(reduced.ncols(), 2);
    }

    /// Only the wide-spread second column exceeds a threshold of 1.5.
    #[test]
    fn test_variance_threshold_custom_threshold() {
        let selector = VarianceThreshold::<f64>::new(1.5);
        let data = array![[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]];
        let fitted = selector.fit(&data, &()).unwrap();
        assert_eq!(fitted.selected_indices(), &[1usize]);
        let reduced = fitted.transform(&data).unwrap();
        assert_eq!(reduced.ncols(), 1);
    }

    /// The variance of an all-zero column is stored as zero.
    #[test]
    fn test_variance_threshold_stores_variances() {
        let selector = VarianceThreshold::<f64>::default();
        let data = array![[0.0], [0.0], [0.0]];
        let fitted = selector.fit(&data, &()).unwrap();
        assert_abs_diff_eq!(fitted.variances()[0], 0.0, epsilon = 1e-15);
    }

    /// Fitting on a matrix without rows is an error.
    #[test]
    fn test_variance_threshold_zero_rows_error() {
        let selector = VarianceThreshold::<f64>::new(0.0);
        let empty: Array2<f64> = Array2::zeros((0, 2));
        assert!(selector.fit(&empty, &()).is_err());
    }

    /// A negative threshold is rejected at fit time.
    #[test]
    fn test_variance_threshold_negative_threshold_error() {
        let selector = VarianceThreshold::<f64>::new(-0.1);
        let data = array![[1.0], [2.0]];
        assert!(selector.fit(&data, &()).is_err());
    }

    /// Transforming data with a different column count is an error.
    #[test]
    fn test_variance_threshold_shape_mismatch_on_transform() {
        let selector = VarianceThreshold::<f64>::new(0.0);
        let train = array![[1.0, 2.0], [3.0, 4.0]];
        let fitted = selector.fit(&train, &()).unwrap();
        let wrong_width = array![[1.0, 2.0, 3.0]];
        assert!(fitted.transform(&wrong_width).is_err());
    }

    /// When every column is constant the output keeps its rows but has no columns.
    #[test]
    fn test_variance_threshold_all_constant_columns() {
        let selector = VarianceThreshold::<f64>::new(0.0);
        let data = array![[5.0, 3.0], [5.0, 3.0], [5.0, 3.0]];
        let fitted = selector.fit(&data, &()).unwrap();
        assert_eq!(fitted.selected_indices().len(), 0);
        let reduced = fitted.transform(&data).unwrap();
        assert_eq!(reduced.ncols(), 0);
        assert_eq!(reduced.nrows(), 3);
    }

    /// The selector works through the pipeline traits.
    #[test]
    fn test_variance_threshold_pipeline_integration() {
        use ferrolearn_core::pipeline::PipelineTransformer;
        let selector = VarianceThreshold::<f64>::new(0.0);
        let data = array![[1.0, 7.0], [2.0, 7.0], [3.0, 7.0]];
        let targets = array![0.0, 1.0, 0.0];
        let step = selector.fit_pipeline(&data, &targets).unwrap();
        let reduced = step.transform_pipeline(&data).unwrap();
        assert_eq!(reduced.ncols(), 1);
    }

    /// The selector is usable with `f32` data.
    #[test]
    fn test_variance_threshold_f32() {
        let selector = VarianceThreshold::<f32>::new(0.0f32);
        let data: Array2<f32> = array![[1.0f32, 5.0], [2.0, 5.0], [3.0, 5.0]];
        let fitted = selector.fit(&data, &()).unwrap();
        assert_eq!(fitted.selected_indices(), &[0usize]);
    }

    // ----- SelectKBest -----

    /// Column 0 separates the classes, column 1 does not.
    #[test]
    fn test_select_k_best_selects_highest_scoring_feature() {
        let selector = SelectKBest::<f64>::new(1, ScoreFunc::FClassif);
        let data = array![[1.0, 5.0], [1.0, 6.0], [10.0, 5.0], [10.0, 6.0]];
        let labels: Array1<usize> = array![0, 0, 1, 1];
        let fitted = selector.fit(&data, &labels).unwrap();
        assert_eq!(fitted.selected_indices(), &[0usize]);
        let reduced = fitted.transform(&data).unwrap();
        assert_eq!(reduced.ncols(), 1);
    }

    /// Asking for as many features as exist keeps everything.
    #[test]
    fn test_select_k_best_k_equals_n_features_keeps_all() {
        let selector = SelectKBest::<f64>::new(2, ScoreFunc::FClassif);
        let data = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let labels: Array1<usize> = array![0, 1, 0];
        let fitted = selector.fit(&data, &labels).unwrap();
        assert_eq!(fitted.selected_indices().len(), 2);
        let reduced = fitted.transform(&data).unwrap();
        assert_eq!(reduced.ncols(), 2);
    }

    /// One score is stored per input column.
    #[test]
    fn test_select_k_best_scores_stored() {
        let selector = SelectKBest::<f64>::new(1, ScoreFunc::FClassif);
        let data = array![[1.0, 2.0], [1.0, 4.0]];
        let labels: Array1<usize> = array![0, 1];
        let fitted = selector.fit(&data, &labels).unwrap();
        assert_eq!(fitted.scores().len(), 2);
    }

    /// Fitting on a matrix without rows is an error.
    #[test]
    fn test_select_k_best_zero_rows_error() {
        let selector = SelectKBest::<f64>::new(1, ScoreFunc::FClassif);
        let empty: Array2<f64> = Array2::zeros((0, 3));
        let labels: Array1<usize> = Array1::zeros(0);
        assert!(selector.fit(&empty, &labels).is_err());
    }

    /// `k` larger than the column count is rejected.
    #[test]
    fn test_select_k_best_k_exceeds_n_features_error() {
        let selector = SelectKBest::<f64>::new(5, ScoreFunc::FClassif);
        let data = array![[1.0, 2.0], [3.0, 4.0]];
        let labels: Array1<usize> = array![0, 1];
        assert!(selector.fit(&data, &labels).is_err());
    }

    /// A label vector shorter than the data is rejected.
    #[test]
    fn test_select_k_best_y_length_mismatch_error() {
        let selector = SelectKBest::<f64>::new(1, ScoreFunc::FClassif);
        let data = array![[1.0, 2.0], [3.0, 4.0]];
        let labels: Array1<usize> = array![0];
        assert!(selector.fit(&data, &labels).is_err());
    }

    /// Transforming data with a different column count is an error.
    #[test]
    fn test_select_k_best_shape_mismatch_on_transform() {
        let selector = SelectKBest::<f64>::new(1, ScoreFunc::FClassif);
        let data = array![[1.0, 2.0], [3.0, 4.0]];
        let labels: Array1<usize> = array![0, 1];
        let fitted = selector.fit(&data, &labels).unwrap();
        let wrong_width = array![[1.0, 2.0, 3.0]];
        assert!(fitted.transform(&wrong_width).is_err());
    }

    /// Selected indices are reported in strictly increasing column order.
    #[test]
    fn test_select_k_best_selected_indices_in_column_order() {
        let selector = SelectKBest::<f64>::new(2, ScoreFunc::FClassif);
        let data = array![[1.0, 100.0], [2.0, 200.0]];
        let labels: Array1<usize> = array![0, 1];
        let fitted = selector.fit(&data, &labels).unwrap();
        let kept = fitted.selected_indices();
        assert!(kept.windows(2).all(|w| w[0] < w[1]));
    }

    /// The selector works through the pipeline traits with float targets.
    #[test]
    fn test_select_k_best_pipeline_integration() {
        use ferrolearn_core::pipeline::PipelineTransformer;
        let selector = SelectKBest::<f64>::new(1, ScoreFunc::FClassif);
        let data = array![[1.0, 5.0], [1.0, 6.0], [10.0, 5.0], [10.0, 6.0]];
        let targets = array![0.0, 0.0, 1.0, 1.0];
        let step = selector.fit_pipeline(&data, &targets).unwrap();
        let reduced = step.transform_pipeline(&data).unwrap();
        assert_eq!(reduced.ncols(), 1);
    }

    /// Perfect class separation with zero within-class variance scores infinity.
    #[test]
    fn test_select_k_best_f_score_zero_within_class_variance() {
        let selector = SelectKBest::<f64>::new(1, ScoreFunc::FClassif);
        let data = array![[0.0], [0.0], [10.0], [10.0]];
        let labels: Array1<usize> = array![0, 0, 1, 1];
        let fitted = selector.fit(&data, &labels).unwrap();
        assert!(fitted.scores()[0].is_infinite());
    }

    // ----- SelectFromModel -----

    /// Without an explicit threshold the mean importance is the cut-off.
    #[test]
    fn test_select_from_model_mean_threshold() {
        let weights = array![0.1, 0.5, 0.4];
        let selector = SelectFromModel::<f64>::new_from_importances(&weights, None).unwrap();
        assert_eq!(selector.selected_indices(), &[1usize, 2]);
        let data = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let reduced = selector.transform(&data).unwrap();
        assert_eq!(reduced.ncols(), 2);
        assert_abs_diff_eq!(reduced[[0, 0]], 2.0, epsilon = 1e-15);
        assert_abs_diff_eq!(reduced[[0, 1]], 3.0, epsilon = 1e-15);
    }

    /// An explicit threshold of 0.45 keeps only the 0.5 feature.
    #[test]
    fn test_select_from_model_explicit_threshold() {
        let weights = array![0.1, 0.5, 0.4];
        let selector = SelectFromModel::<f64>::new_from_importances(&weights, Some(0.45)).unwrap();
        assert_eq!(selector.selected_indices(), &[1usize]);
        let data = array![[1.0, 2.0, 3.0]];
        let reduced = selector.transform(&data).unwrap();
        assert_eq!(reduced.ncols(), 1);
        assert_abs_diff_eq!(reduced[[0, 0]], 2.0, epsilon = 1e-15);
    }

    /// A zero threshold keeps every feature with non-negative importance.
    #[test]
    fn test_select_from_model_all_selected_when_threshold_zero() {
        let weights = array![0.1, 0.2, 0.3];
        let selector = SelectFromModel::<f64>::new_from_importances(&weights, Some(0.0)).unwrap();
        assert_eq!(selector.selected_indices().len(), 3);
    }

    /// A threshold above every importance selects nothing.
    #[test]
    fn test_select_from_model_none_selected_when_threshold_high() {
        let weights = array![0.1, 0.2, 0.3];
        let selector = SelectFromModel::<f64>::new_from_importances(&weights, Some(1.0)).unwrap();
        assert_eq!(selector.selected_indices().len(), 0);
        let data = array![[1.0, 2.0, 3.0]];
        let reduced = selector.transform(&data).unwrap();
        assert_eq!(reduced.ncols(), 0);
    }

    /// An empty importance vector is rejected.
    #[test]
    fn test_select_from_model_empty_importances_error() {
        let weights: Array1<f64> = Array1::zeros(0);
        assert!(SelectFromModel::<f64>::new_from_importances(&weights, None).is_err());
    }

    /// Transforming data with a different column count is an error.
    #[test]
    fn test_select_from_model_shape_mismatch_on_transform() {
        let weights = array![0.3, 0.7];
        let selector = SelectFromModel::<f64>::new_from_importances(&weights, None).unwrap();
        let wrong_width = array![[1.0, 2.0, 3.0]];
        assert!(selector.transform(&wrong_width).is_err());
    }

    /// The accessor reports the explicit threshold unchanged.
    #[test]
    fn test_select_from_model_threshold_accessor() {
        let weights = array![0.3, 0.7];
        let selector = SelectFromModel::<f64>::new_from_importances(&weights, Some(0.5)).unwrap();
        assert_abs_diff_eq!(selector.threshold(), 0.5, epsilon = 1e-15);
    }

    /// The selector works through the pipeline traits.
    #[test]
    fn test_select_from_model_pipeline_integration() {
        use ferrolearn_core::pipeline::PipelineTransformer;
        let weights = array![0.1, 0.9];
        let selector = SelectFromModel::<f64>::new_from_importances(&weights, None).unwrap();
        let data = array![[1.0, 2.0], [3.0, 4.0]];
        let targets = array![0.0, 1.0];
        let step = selector.fit_pipeline(&data, &targets).unwrap();
        let reduced = step.transform_pipeline(&data).unwrap();
        assert_eq!(reduced.ncols(), 1);
        assert_abs_diff_eq!(reduced[[0, 0]], 2.0, epsilon = 1e-15);
    }

    /// The accessor returns the importances the selector was built from.
    #[test]
    fn test_select_from_model_importances_accessor() {
        let weights = array![0.2, 0.8];
        let selector = SelectFromModel::<f64>::new_from_importances(&weights, None).unwrap();
        assert_abs_diff_eq!(selector.importances()[0], 0.2, epsilon = 1e-15);
        assert_abs_diff_eq!(selector.importances()[1], 0.8, epsilon = 1e-15);
    }
}