1use ferrolearn_core::error::FerroError;
16use ferrolearn_core::pipeline::{FittedPipelineTransformer, PipelineTransformer};
17use ferrolearn_core::traits::{Fit, Transform};
18use ndarray::{Array1, Array2};
19use num_traits::Float;
20
21fn select_columns<F: Float>(x: &Array2<F>, indices: &[usize]) -> Array2<F> {
29 let nrows = x.nrows();
30 let ncols = indices.len();
31 if ncols == 0 {
32 return Array2::zeros((nrows, 0));
34 }
35 let mut out = Array2::zeros((nrows, ncols));
36 for (new_j, &old_j) in indices.iter().enumerate() {
37 for i in 0..nrows {
38 out[[i, new_j]] = x[[i, old_j]];
39 }
40 }
41 out
42}
43
/// Unsupervised feature selector that drops low-variance features.
///
/// A feature is kept only when its population variance, measured during
/// fitting, is strictly greater than `threshold`.
#[derive(Clone, Debug)]
pub struct VarianceThreshold<F> {
    /// Minimum variance (exclusive) a feature needs in order to be kept.
    threshold: F,
}
77
78impl<F: Float + Send + Sync + 'static> VarianceThreshold<F> {
79 pub fn new(threshold: F) -> Self {
87 Self { threshold }
88 }
89
90 #[must_use]
92 pub fn threshold(&self) -> F {
93 self.threshold
94 }
95}
96
97impl<F: Float + Send + Sync + 'static> Default for VarianceThreshold<F> {
98 fn default() -> Self {
99 Self::new(F::zero())
100 }
101}
102
103#[derive(Debug, Clone)]
112pub struct FittedVarianceThreshold<F> {
113 selected_indices: Vec<usize>,
115 variances: Array1<F>,
117}
118
119impl<F: Float + Send + Sync + 'static> FittedVarianceThreshold<F> {
120 #[must_use]
122 pub fn selected_indices(&self) -> &[usize] {
123 &self.selected_indices
124 }
125
126 #[must_use]
128 pub fn variances(&self) -> &Array1<F> {
129 &self.variances
130 }
131}
132
133impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, ()> for VarianceThreshold<F> {
138 type Fitted = FittedVarianceThreshold<F>;
139 type Error = FerroError;
140
141 fn fit(&self, x: &Array2<F>, _y: &()) -> Result<FittedVarianceThreshold<F>, FerroError> {
148 if self.threshold < F::zero() {
149 return Err(FerroError::InvalidParameter {
150 name: "threshold".into(),
151 reason: "variance threshold must be non-negative".into(),
152 });
153 }
154 let n_samples = x.nrows();
155 if n_samples == 0 {
156 return Err(FerroError::InsufficientSamples {
157 required: 1,
158 actual: 0,
159 context: "VarianceThreshold::fit".into(),
160 });
161 }
162
163 let n = F::from(n_samples).unwrap_or_else(F::one);
164 let n_features = x.ncols();
165 let mut variances = Array1::zeros(n_features);
166 let mut selected_indices = Vec::new();
167
168 for j in 0..n_features {
169 let col = x.column(j);
170 let mut mean = F::zero();
177 let mut m2 = F::zero();
178 let mut count = F::zero();
179 for &v in col.iter() {
180 count = count + F::one();
181 let delta = v - mean;
182 mean = mean + delta / count;
183 let delta2 = v - mean;
184 m2 = m2 + delta * delta2;
185 }
186 let var = m2 / n;
187 variances[j] = var;
188 if var > self.threshold {
189 selected_indices.push(j);
190 }
191 }
192
193 Ok(FittedVarianceThreshold {
194 selected_indices,
195 variances,
196 })
197 }
198}
199
200impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedVarianceThreshold<F> {
201 type Output = Array2<F>;
202 type Error = FerroError;
203
204 fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
211 let n_original = self.variances.len();
212 if x.ncols() != n_original {
213 return Err(FerroError::ShapeMismatch {
214 expected: vec![x.nrows(), n_original],
215 actual: vec![x.nrows(), x.ncols()],
216 context: "FittedVarianceThreshold::transform".into(),
217 });
218 }
219 Ok(select_columns(x, &self.selected_indices))
220 }
221}
222
223impl<F: Float + Send + Sync + 'static> PipelineTransformer<F> for VarianceThreshold<F> {
228 fn fit_pipeline(
234 &self,
235 x: &Array2<F>,
236 _y: &Array1<F>,
237 ) -> Result<Box<dyn FittedPipelineTransformer<F>>, FerroError> {
238 let fitted = self.fit(x, &())?;
239 Ok(Box::new(fitted))
240 }
241}
242
243impl<F: Float + Send + Sync + 'static> FittedPipelineTransformer<F> for FittedVarianceThreshold<F> {
244 fn transform_pipeline(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
250 self.transform(x)
251 }
252}
253
/// Scoring function used by [`SelectKBest`] to rank features.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ScoreFunc {
    /// One-way ANOVA F-statistic between each feature and the class labels.
    FClassif,
}
268
269#[derive(Debug, Clone)]
289pub struct SelectKBest<F> {
290 k: usize,
292 score_func: ScoreFunc,
294 _marker: std::marker::PhantomData<F>,
295}
296
297impl<F: Float + Send + Sync + 'static> SelectKBest<F> {
298 #[must_use]
306 pub fn new(k: usize, score_func: ScoreFunc) -> Self {
307 Self {
308 k,
309 score_func,
310 _marker: std::marker::PhantomData,
311 }
312 }
313
314 #[must_use]
316 pub fn k(&self) -> usize {
317 self.k
318 }
319
320 #[must_use]
322 pub fn score_func(&self) -> ScoreFunc {
323 self.score_func
324 }
325}
326
327#[derive(Debug, Clone)]
335pub struct FittedSelectKBest<F> {
336 n_features_in: usize,
338 scores: Array1<F>,
340 selected_indices: Vec<usize>,
342}
343
344impl<F: Float + Send + Sync + 'static> FittedSelectKBest<F> {
345 #[must_use]
347 pub fn scores(&self) -> &Array1<F> {
348 &self.scores
349 }
350
351 #[must_use]
353 pub fn selected_indices(&self) -> &[usize] {
354 &self.selected_indices
355 }
356}
357
358fn anova_f_scores<F: Float>(x: &Array2<F>, y: &Array1<usize>) -> Vec<F> {
376 let n_samples = x.nrows();
377 let n_features = x.ncols();
378
379 let mut class_indices: std::collections::HashMap<usize, Vec<usize>> =
381 std::collections::HashMap::new();
382 for (i, &label) in y.iter().enumerate() {
383 class_indices.entry(label).or_default().push(i);
384 }
385 let n_classes = class_indices.len();
386
387 let mut scores = Vec::with_capacity(n_features);
388
389 for j in 0..n_features {
390 let col = x.column(j);
391
392 let grand_mean =
394 col.iter().copied().fold(F::zero(), |acc, v| acc + v) / F::from(n_samples).unwrap();
395
396 let mut ss_between = F::zero();
398 let mut ss_within = F::zero();
400
401 for rows in class_indices.values() {
402 let n_k = F::from(rows.len()).unwrap();
403 let class_mean = rows
404 .iter()
405 .map(|&i| col[i])
406 .fold(F::zero(), |acc, v| acc + v)
407 / n_k;
408 let diff = class_mean - grand_mean;
409 ss_between = ss_between + n_k * diff * diff;
410 for &i in rows {
411 let d = col[i] - class_mean;
412 ss_within = ss_within + d * d;
413 }
414 }
415
416 let df_between = F::from(n_classes.saturating_sub(1)).unwrap();
417 let df_within = F::from(n_samples.saturating_sub(n_classes)).unwrap();
418
419 let f = if df_between == F::zero() || df_within == F::zero() {
420 F::zero()
421 } else {
422 let ms_between = ss_between / df_between;
423 let ms_within = ss_within / df_within;
424 if ms_within == F::zero() {
425 F::infinity()
426 } else {
427 ms_between / ms_within
428 }
429 };
430
431 scores.push(f);
432 }
433
434 scores
435}
436
437impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array1<usize>> for SelectKBest<F> {
442 type Fitted = FittedSelectKBest<F>;
443 type Error = FerroError;
444
445 fn fit(&self, x: &Array2<F>, y: &Array1<usize>) -> Result<FittedSelectKBest<F>, FerroError> {
454 let n_samples = x.nrows();
455 if n_samples == 0 {
456 return Err(FerroError::InsufficientSamples {
457 required: 1,
458 actual: 0,
459 context: "SelectKBest::fit".into(),
460 });
461 }
462 if y.len() != n_samples {
463 return Err(FerroError::ShapeMismatch {
464 expected: vec![n_samples],
465 actual: vec![y.len()],
466 context: "SelectKBest::fit — y must have the same length as x has rows".into(),
467 });
468 }
469 let n_features = x.ncols();
470 if self.k > n_features {
471 return Err(FerroError::InvalidParameter {
472 name: "k".into(),
473 reason: format!(
474 "k ({}) cannot exceed the number of features ({})",
475 self.k, n_features
476 ),
477 });
478 }
479
480 let raw_scores = match self.score_func {
481 ScoreFunc::FClassif => anova_f_scores(x, y),
482 };
483
484 let scores = Array1::from_vec(raw_scores.clone());
485
486 let mut ranked: Vec<usize> = (0..n_features).collect();
489 ranked.sort_by(|&a, &b| {
490 raw_scores[b]
491 .partial_cmp(&raw_scores[a])
492 .unwrap_or(std::cmp::Ordering::Equal)
493 .then(a.cmp(&b))
495 });
496
497 let mut selected_indices: Vec<usize> = ranked[..self.k].to_vec();
498 selected_indices.sort_unstable();
500
501 Ok(FittedSelectKBest {
502 n_features_in: n_features,
503 scores,
504 selected_indices,
505 })
506 }
507}
508
509impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedSelectKBest<F> {
510 type Output = Array2<F>;
511 type Error = FerroError;
512
513 fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
520 if x.ncols() != self.n_features_in {
521 return Err(FerroError::ShapeMismatch {
522 expected: vec![x.nrows(), self.n_features_in],
523 actual: vec![x.nrows(), x.ncols()],
524 context: "FittedSelectKBest::transform".into(),
525 });
526 }
527 Ok(select_columns(x, &self.selected_indices))
528 }
529}
530
531impl<F: Float + Send + Sync + 'static> PipelineTransformer<F> for SelectKBest<F> {
540 fn fit_pipeline(
548 &self,
549 x: &Array2<F>,
550 y: &Array1<F>,
551 ) -> Result<Box<dyn FittedPipelineTransformer<F>>, FerroError> {
552 let y_usize: Array1<usize> = y.mapv(|v| v.round().to_usize().unwrap_or(0));
553 let fitted = self.fit(x, &y_usize)?;
554 Ok(Box::new(fitted))
555 }
556}
557
558impl<F: Float + Send + Sync + 'static> FittedPipelineTransformer<F> for FittedSelectKBest<F> {
559 fn transform_pipeline(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
565 self.transform(x)
566 }
567}
568
569#[derive(Debug, Clone)]
596pub struct SelectFromModel<F> {
597 importances: Array1<F>,
599 threshold: F,
601 selected_indices: Vec<usize>,
603}
604
605impl<F: Float + Send + Sync + 'static> SelectFromModel<F> {
606 pub fn new_from_importances(
618 importances: &Array1<F>,
619 threshold: Option<F>,
620 ) -> Result<Self, FerroError> {
621 let n = importances.len();
622 if n == 0 {
623 return Err(FerroError::InvalidParameter {
624 name: "importances".into(),
625 reason: "importance vector must not be empty".into(),
626 });
627 }
628
629 let thr = threshold.unwrap_or_else(|| {
630 importances
631 .iter()
632 .copied()
633 .fold(F::zero(), |acc, v| acc + v)
634 / F::from(n).unwrap_or_else(F::one)
635 });
636
637 let selected_indices: Vec<usize> = importances
638 .iter()
639 .enumerate()
640 .filter(|&(_, &imp)| imp >= thr)
641 .map(|(j, _)| j)
642 .collect();
643
644 Ok(Self {
645 importances: importances.clone(),
646 threshold: thr,
647 selected_indices,
648 })
649 }
650
651 #[must_use]
653 pub fn threshold(&self) -> F {
654 self.threshold
655 }
656
657 #[must_use]
659 pub fn importances(&self) -> &Array1<F> {
660 &self.importances
661 }
662
663 #[must_use]
665 pub fn selected_indices(&self) -> &[usize] {
666 &self.selected_indices
667 }
668}
669
670impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for SelectFromModel<F> {
675 type Output = Array2<F>;
676 type Error = FerroError;
677
678 fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
686 let n_features = self.importances.len();
687 if x.ncols() != n_features {
688 return Err(FerroError::ShapeMismatch {
689 expected: vec![x.nrows(), n_features],
690 actual: vec![x.nrows(), x.ncols()],
691 context: "SelectFromModel::transform".into(),
692 });
693 }
694 Ok(select_columns(x, &self.selected_indices))
695 }
696}
697
698impl<F: Float + Send + Sync + 'static> PipelineTransformer<F> for SelectFromModel<F> {
706 fn fit_pipeline(
712 &self,
713 _x: &Array2<F>,
714 _y: &Array1<F>,
715 ) -> Result<Box<dyn FittedPipelineTransformer<F>>, FerroError> {
716 Ok(Box::new(self.clone()))
717 }
718}
719
720impl<F: Float + Send + Sync + 'static> FittedPipelineTransformer<F> for SelectFromModel<F> {
721 fn transform_pipeline(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
727 self.transform(x)
728 }
729}
730
// Unit tests for VarianceThreshold, SelectKBest and SelectFromModel.
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    // ---------------------------------------------------------------------
    // VarianceThreshold
    // ---------------------------------------------------------------------

    #[test]
    fn test_variance_threshold_removes_constant_column() {
        // Column 1 is constant (variance 0), which is not strictly above the
        // threshold of 0, so only column 0 survives.
        let sel = VarianceThreshold::<f64>::new(0.0);
        let x = array![[1.0, 7.0], [2.0, 7.0], [3.0, 7.0]];
        let fitted = sel.fit(&x, &()).unwrap();
        assert_eq!(fitted.selected_indices(), &[0usize]);
        let out = fitted.transform(&x).unwrap();
        assert_eq!(out.ncols(), 1);
        assert_abs_diff_eq!(out[[0, 0]], 1.0, epsilon = 1e-15);
        assert_abs_diff_eq!(out[[1, 0]], 2.0, epsilon = 1e-15);
    }

    #[test]
    fn test_variance_threshold_keeps_all_when_above() {
        // Both columns vary, so both have variance > 0 and are kept.
        let sel = VarianceThreshold::<f64>::new(0.0);
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let fitted = sel.fit(&x, &()).unwrap();
        assert_eq!(fitted.selected_indices().len(), 2);
        let out = fitted.transform(&x).unwrap();
        assert_eq!(out.ncols(), 2);
    }

    #[test]
    fn test_variance_threshold_custom_threshold() {
        // Population variances: column 0 is 2/3 (below 1.5), column 1 is
        // about 66.7 (above 1.5), so only column 1 is kept.
        let sel = VarianceThreshold::<f64>::new(1.5);
        let x = array![[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]];
        let fitted = sel.fit(&x, &()).unwrap();
        assert_eq!(fitted.selected_indices(), &[1usize]);
        let out = fitted.transform(&x).unwrap();
        assert_eq!(out.ncols(), 1);
    }

    #[test]
    fn test_variance_threshold_stores_variances() {
        // An all-zero column must report a variance of exactly zero.
        let sel = VarianceThreshold::<f64>::default();
        let x = array![[0.0], [0.0], [0.0]];
        let fitted = sel.fit(&x, &()).unwrap();
        assert_abs_diff_eq!(fitted.variances()[0], 0.0, epsilon = 1e-15);
    }

    #[test]
    fn test_variance_threshold_zero_rows_error() {
        // Fitting on an empty matrix is rejected.
        let sel = VarianceThreshold::<f64>::new(0.0);
        let x: Array2<f64> = Array2::zeros((0, 2));
        assert!(sel.fit(&x, &()).is_err());
    }

    #[test]
    fn test_variance_threshold_negative_threshold_error() {
        // Negative thresholds are rejected at fit time.
        let sel = VarianceThreshold::<f64>::new(-0.1);
        let x = array![[1.0], [2.0]];
        assert!(sel.fit(&x, &()).is_err());
    }

    #[test]
    fn test_variance_threshold_shape_mismatch_on_transform() {
        // Fitted on 2 columns; transforming 3 columns must fail.
        let sel = VarianceThreshold::<f64>::new(0.0);
        let x_train = array![[1.0, 2.0], [3.0, 4.0]];
        let fitted = sel.fit(&x_train, &()).unwrap();
        let x_bad = array![[1.0, 2.0, 3.0]];
        assert!(fitted.transform(&x_bad).is_err());
    }

    #[test]
    fn test_variance_threshold_all_constant_columns() {
        // Every column is constant: nothing is selected, and the output keeps
        // the row count with zero columns.
        let sel = VarianceThreshold::<f64>::new(0.0);
        let x = array![[5.0, 3.0], [5.0, 3.0], [5.0, 3.0]];
        let fitted = sel.fit(&x, &()).unwrap();
        assert_eq!(fitted.selected_indices().len(), 0);
        let out = fitted.transform(&x).unwrap();
        assert_eq!(out.ncols(), 0);
        assert_eq!(out.nrows(), 3);
    }

    #[test]
    fn test_variance_threshold_pipeline_integration() {
        // The pipeline adapter ignores `y` and gives the same result as a
        // direct fit/transform.
        use ferrolearn_core::pipeline::PipelineTransformer;
        let sel = VarianceThreshold::<f64>::new(0.0);
        let x = array![[1.0, 7.0], [2.0, 7.0], [3.0, 7.0]];
        let y = ndarray::array![0.0, 1.0, 0.0];
        let fitted_box = sel.fit_pipeline(&x, &y).unwrap();
        let out = fitted_box.transform_pipeline(&x).unwrap();
        assert_eq!(out.ncols(), 1);
    }

    #[test]
    fn test_variance_threshold_f32() {
        // Same behaviour with single-precision floats.
        let sel = VarianceThreshold::<f32>::new(0.0f32);
        let x: Array2<f32> = array![[1.0f32, 5.0], [2.0, 5.0], [3.0, 5.0]];
        let fitted = sel.fit(&x, &()).unwrap();
        assert_eq!(fitted.selected_indices(), &[0usize]);
    }

    // ---------------------------------------------------------------------
    // SelectKBest
    // ---------------------------------------------------------------------

    #[test]
    fn test_select_k_best_selects_highest_scoring_feature() {
        // Column 0 separates the classes perfectly (infinite F-score);
        // column 1 has equal class means (F-score 0).
        let sel = SelectKBest::<f64>::new(1, ScoreFunc::FClassif);
        let x = array![[1.0, 5.0], [1.0, 6.0], [10.0, 5.0], [10.0, 6.0]];
        let y: Array1<usize> = array![0, 0, 1, 1];
        let fitted = sel.fit(&x, &y).unwrap();
        assert_eq!(fitted.selected_indices(), &[0usize]);
        let out = fitted.transform(&x).unwrap();
        assert_eq!(out.ncols(), 1);
    }

    #[test]
    fn test_select_k_best_k_equals_n_features_keeps_all() {
        // k equal to the number of features keeps every column.
        let sel = SelectKBest::<f64>::new(2, ScoreFunc::FClassif);
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let y: Array1<usize> = array![0, 1, 0];
        let fitted = sel.fit(&x, &y).unwrap();
        assert_eq!(fitted.selected_indices().len(), 2);
        let out = fitted.transform(&x).unwrap();
        assert_eq!(out.ncols(), 2);
    }

    #[test]
    fn test_select_k_best_scores_stored() {
        // One score is stored per input feature.
        let sel = SelectKBest::<f64>::new(1, ScoreFunc::FClassif);
        let x = array![[1.0, 2.0], [1.0, 4.0]];
        let y: Array1<usize> = array![0, 1];
        let fitted = sel.fit(&x, &y).unwrap();
        assert_eq!(fitted.scores().len(), 2);
    }

    #[test]
    fn test_select_k_best_zero_rows_error() {
        // Fitting on an empty matrix is rejected.
        let sel = SelectKBest::<f64>::new(1, ScoreFunc::FClassif);
        let x: Array2<f64> = Array2::zeros((0, 3));
        let y: Array1<usize> = Array1::zeros(0);
        assert!(sel.fit(&x, &y).is_err());
    }

    #[test]
    fn test_select_k_best_k_exceeds_n_features_error() {
        // k larger than the number of columns is an invalid parameter.
        let sel = SelectKBest::<f64>::new(5, ScoreFunc::FClassif);
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let y: Array1<usize> = array![0, 1];
        assert!(sel.fit(&x, &y).is_err());
    }

    #[test]
    fn test_select_k_best_y_length_mismatch_error() {
        // `y` must have one label per row of `x`.
        let sel = SelectKBest::<f64>::new(1, ScoreFunc::FClassif);
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let y: Array1<usize> = array![0];
        assert!(sel.fit(&x, &y).is_err());
    }

    #[test]
    fn test_select_k_best_shape_mismatch_on_transform() {
        // Fitted on 2 columns; transforming 3 columns must fail.
        let sel = SelectKBest::<f64>::new(1, ScoreFunc::FClassif);
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let y: Array1<usize> = array![0, 1];
        let fitted = sel.fit(&x, &y).unwrap();
        let x_bad = array![[1.0, 2.0, 3.0]];
        assert!(fitted.transform(&x_bad).is_err());
    }

    #[test]
    fn test_select_k_best_selected_indices_in_column_order() {
        // Selected indices are reported in strictly ascending column order.
        let sel = SelectKBest::<f64>::new(2, ScoreFunc::FClassif);
        let x = array![[1.0, 100.0], [2.0, 200.0]];
        let y: Array1<usize> = array![0, 1];
        let fitted = sel.fit(&x, &y).unwrap();
        let indices = fitted.selected_indices();
        assert!(indices.windows(2).all(|w| w[0] < w[1]));
    }

    #[test]
    fn test_select_k_best_pipeline_integration() {
        // Float labels are converted to class labels by the pipeline adapter.
        use ferrolearn_core::pipeline::PipelineTransformer;
        let sel = SelectKBest::<f64>::new(1, ScoreFunc::FClassif);
        let x = array![[1.0, 5.0], [1.0, 6.0], [10.0, 5.0], [10.0, 6.0]];
        let y = ndarray::array![0.0, 0.0, 1.0, 1.0];
        let fitted_box = sel.fit_pipeline(&x, &y).unwrap();
        let out = fitted_box.transform_pipeline(&x).unwrap();
        assert_eq!(out.ncols(), 1);
    }

    #[test]
    fn test_select_k_best_f_score_zero_within_class_variance() {
        // Each class is constant but the class means differ: zero within-class
        // variance with non-zero between-class variance gives an infinite score.
        let sel = SelectKBest::<f64>::new(1, ScoreFunc::FClassif);
        let x = array![[0.0], [0.0], [10.0], [10.0]];
        let y: Array1<usize> = array![0, 0, 1, 1];
        let fitted = sel.fit(&x, &y).unwrap();
        assert!(fitted.scores()[0].is_infinite());
    }

    // ---------------------------------------------------------------------
    // SelectFromModel
    // ---------------------------------------------------------------------

    #[test]
    fn test_select_from_model_mean_threshold() {
        // Default threshold is the mean importance (about 0.333), so the
        // features with importance 0.5 and 0.4 are kept.
        let importances = array![0.1, 0.5, 0.4];
        let sel = SelectFromModel::<f64>::new_from_importances(&importances, None).unwrap();
        assert_eq!(sel.selected_indices(), &[1usize, 2]);
        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let out = sel.transform(&x).unwrap();
        assert_eq!(out.ncols(), 2);
        assert_abs_diff_eq!(out[[0, 0]], 2.0, epsilon = 1e-15);
        assert_abs_diff_eq!(out[[0, 1]], 3.0, epsilon = 1e-15);
    }

    #[test]
    fn test_select_from_model_explicit_threshold() {
        // With threshold 0.45 only the 0.5 feature qualifies.
        let importances = array![0.1, 0.5, 0.4];
        let sel = SelectFromModel::<f64>::new_from_importances(&importances, Some(0.45)).unwrap();
        assert_eq!(sel.selected_indices(), &[1usize]);
        let x = array![[1.0, 2.0, 3.0]];
        let out = sel.transform(&x).unwrap();
        assert_eq!(out.ncols(), 1);
        assert_abs_diff_eq!(out[[0, 0]], 2.0, epsilon = 1e-15);
    }

    #[test]
    fn test_select_from_model_all_selected_when_threshold_zero() {
        // Every non-negative importance is >= 0.
        let importances = array![0.1, 0.2, 0.3];
        let sel = SelectFromModel::<f64>::new_from_importances(&importances, Some(0.0)).unwrap();
        assert_eq!(sel.selected_indices().len(), 3);
    }

    #[test]
    fn test_select_from_model_none_selected_when_threshold_high() {
        // No importance reaches 1.0: output has zero columns.
        let importances = array![0.1, 0.2, 0.3];
        let sel = SelectFromModel::<f64>::new_from_importances(&importances, Some(1.0)).unwrap();
        assert_eq!(sel.selected_indices().len(), 0);
        let x = array![[1.0, 2.0, 3.0]];
        let out = sel.transform(&x).unwrap();
        assert_eq!(out.ncols(), 0);
    }

    #[test]
    fn test_select_from_model_empty_importances_error() {
        // An empty importance vector is rejected.
        let importances: Array1<f64> = Array1::zeros(0);
        assert!(SelectFromModel::<f64>::new_from_importances(&importances, None).is_err());
    }

    #[test]
    fn test_select_from_model_shape_mismatch_on_transform() {
        // Built for 2 features; transforming 3 columns must fail.
        let importances = array![0.3, 0.7];
        let sel = SelectFromModel::<f64>::new_from_importances(&importances, None).unwrap();
        let x_bad = array![[1.0, 2.0, 3.0]];
        assert!(sel.transform(&x_bad).is_err());
    }

    #[test]
    fn test_select_from_model_threshold_accessor() {
        // An explicit threshold is reported unchanged.
        let importances = array![0.3, 0.7];
        let sel = SelectFromModel::<f64>::new_from_importances(&importances, Some(0.5)).unwrap();
        assert_abs_diff_eq!(sel.threshold(), 0.5, epsilon = 1e-15);
    }

    #[test]
    fn test_select_from_model_pipeline_integration() {
        // Mean threshold is 0.5, so only column 1 (importance 0.9) is kept.
        use ferrolearn_core::pipeline::PipelineTransformer;
        let importances = array![0.1, 0.9];
        let sel = SelectFromModel::<f64>::new_from_importances(&importances, None).unwrap();
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let y = ndarray::array![0.0, 1.0];
        let fitted_box = sel.fit_pipeline(&x, &y).unwrap();
        let out = fitted_box.transform_pipeline(&x).unwrap();
        assert_eq!(out.ncols(), 1);
        assert_abs_diff_eq!(out[[0, 0]], 2.0, epsilon = 1e-15);
    }

    #[test]
    fn test_select_from_model_importances_accessor() {
        // The stored importances match the input values.
        let importances = array![0.2, 0.8];
        let sel = SelectFromModel::<f64>::new_from_importances(&importances, None).unwrap();
        assert_abs_diff_eq!(sel.importances()[0], 0.2, epsilon = 1e-15);
        assert_abs_diff_eq!(sel.importances()[1], 0.8, epsilon = 1e-15);
    }
}
1035}