1use ferrolearn_core::error::FerroError;
41use ferrolearn_core::traits::{Fit, Transform};
42use ndarray::{Array1, Array2};
43use num_traits::Float;
44
45fn select_columns<F: Float>(x: &Array2<F>, indices: &[usize]) -> Array2<F> {
51 let nrows = x.nrows();
52 let ncols = indices.len();
53 if ncols == 0 {
54 return Array2::zeros((nrows, 0));
55 }
56 let mut out = Array2::zeros((nrows, ncols));
57 for (new_j, &old_j) in indices.iter().enumerate() {
58 for i in 0..nrows {
59 out[[i, new_j]] = x[[i, old_j]];
60 }
61 }
62 out
63}
64
65fn validate_inputs(n_features: usize, alpha: f64) -> Result<(), FerroError> {
67 if n_features == 0 {
68 return Err(FerroError::InvalidParameter {
69 name: "p_values".into(),
70 reason: "p-value vector must not be empty".into(),
71 });
72 }
73 if !(0.0..=1.0).contains(&alpha) {
74 return Err(FerroError::InvalidParameter {
75 name: "alpha".into(),
76 reason: format!("alpha must be in [0, 1], got {alpha}"),
77 });
78 }
79 Ok(())
80}
81
82#[must_use]
106#[derive(Debug, Clone)]
107pub struct SelectFpr<F> {
108 alpha: f64,
110 _marker: std::marker::PhantomData<F>,
111}
112
113impl<F: Float + Send + Sync + 'static> SelectFpr<F> {
114 pub fn new(alpha: f64) -> Self {
116 Self {
117 alpha,
118 _marker: std::marker::PhantomData,
119 }
120 }
121
122 #[must_use]
124 pub fn alpha(&self) -> f64 {
125 self.alpha
126 }
127}
128
129#[derive(Debug, Clone)]
131pub struct FittedSelectFpr<F> {
132 n_features_in: usize,
134 p_values: Array1<F>,
136 selected_indices: Vec<usize>,
138}
139
140impl<F: Float + Send + Sync + 'static> FittedSelectFpr<F> {
141 #[must_use]
143 pub fn p_values(&self) -> &Array1<F> {
144 &self.p_values
145 }
146
147 #[must_use]
149 pub fn selected_indices(&self) -> &[usize] {
150 &self.selected_indices
151 }
152
153 #[must_use]
155 pub fn n_features_selected(&self) -> usize {
156 self.selected_indices.len()
157 }
158}
159
160impl<F: Float + Send + Sync + 'static> Fit<Array1<F>, ()> for SelectFpr<F> {
161 type Fitted = FittedSelectFpr<F>;
162 type Error = FerroError;
163
164 fn fit(&self, x: &Array1<F>, _y: &()) -> Result<FittedSelectFpr<F>, FerroError> {
171 let n = x.len();
172 validate_inputs(n, self.alpha)?;
173
174 let alpha_f = F::from(self.alpha).unwrap_or_else(F::zero);
175 let selected_indices: Vec<usize> = x
176 .iter()
177 .enumerate()
178 .filter(|&(_, &p)| p < alpha_f)
179 .map(|(j, _)| j)
180 .collect();
181
182 Ok(FittedSelectFpr {
183 n_features_in: n,
184 p_values: x.clone(),
185 selected_indices,
186 })
187 }
188}
189
190impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedSelectFpr<F> {
191 type Output = Array2<F>;
192 type Error = FerroError;
193
194 fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
200 if x.ncols() != self.n_features_in {
201 return Err(FerroError::ShapeMismatch {
202 expected: vec![x.nrows(), self.n_features_in],
203 actual: vec![x.nrows(), x.ncols()],
204 context: "FittedSelectFpr::transform".into(),
205 });
206 }
207 Ok(select_columns(x, &self.selected_indices))
208 }
209}
210
211#[must_use]
235#[derive(Debug, Clone)]
236pub struct SelectFdr<F> {
237 alpha: f64,
239 _marker: std::marker::PhantomData<F>,
240}
241
242impl<F: Float + Send + Sync + 'static> SelectFdr<F> {
243 pub fn new(alpha: f64) -> Self {
245 Self {
246 alpha,
247 _marker: std::marker::PhantomData,
248 }
249 }
250
251 #[must_use]
253 pub fn alpha(&self) -> f64 {
254 self.alpha
255 }
256}
257
258#[derive(Debug, Clone)]
260pub struct FittedSelectFdr<F> {
261 n_features_in: usize,
263 p_values: Array1<F>,
265 selected_indices: Vec<usize>,
267}
268
269impl<F: Float + Send + Sync + 'static> FittedSelectFdr<F> {
270 #[must_use]
272 pub fn p_values(&self) -> &Array1<F> {
273 &self.p_values
274 }
275
276 #[must_use]
278 pub fn selected_indices(&self) -> &[usize] {
279 &self.selected_indices
280 }
281
282 #[must_use]
284 pub fn n_features_selected(&self) -> usize {
285 self.selected_indices.len()
286 }
287}
288
289impl<F: Float + Send + Sync + 'static> Fit<Array1<F>, ()> for SelectFdr<F> {
290 type Fitted = FittedSelectFdr<F>;
291 type Error = FerroError;
292
293 fn fit(&self, x: &Array1<F>, _y: &()) -> Result<FittedSelectFdr<F>, FerroError> {
300 let n = x.len();
301 validate_inputs(n, self.alpha)?;
302
303 let alpha_f = F::from(self.alpha).unwrap_or_else(F::zero);
304 let n_f = F::from(n).unwrap_or_else(F::one);
305
306 let mut ranked: Vec<(usize, F)> = x.iter().copied().enumerate().collect();
308 ranked.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
309
310 let mut max_qualifying_rank: Option<usize> = None;
312 for (rank, &(_, p_val)) in ranked.iter().enumerate() {
313 let bh_threshold = alpha_f * F::from(rank + 1).unwrap_or_else(F::one) / n_f;
314 if p_val <= bh_threshold {
315 max_qualifying_rank = Some(rank);
316 }
317 }
318
319 let mut selected_indices: Vec<usize> = match max_qualifying_rank {
321 Some(max_rank) => ranked[..=max_rank].iter().map(|&(idx, _)| idx).collect(),
322 None => Vec::new(),
323 };
324 selected_indices.sort_unstable();
325
326 Ok(FittedSelectFdr {
327 n_features_in: n,
328 p_values: x.clone(),
329 selected_indices,
330 })
331 }
332}
333
334impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedSelectFdr<F> {
335 type Output = Array2<F>;
336 type Error = FerroError;
337
338 fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
344 if x.ncols() != self.n_features_in {
345 return Err(FerroError::ShapeMismatch {
346 expected: vec![x.nrows(), self.n_features_in],
347 actual: vec![x.nrows(), x.ncols()],
348 context: "FittedSelectFdr::transform".into(),
349 });
350 }
351 Ok(select_columns(x, &self.selected_indices))
352 }
353}
354
355#[must_use]
379#[derive(Debug, Clone)]
380pub struct SelectFwe<F> {
381 alpha: f64,
383 _marker: std::marker::PhantomData<F>,
384}
385
386impl<F: Float + Send + Sync + 'static> SelectFwe<F> {
387 pub fn new(alpha: f64) -> Self {
389 Self {
390 alpha,
391 _marker: std::marker::PhantomData,
392 }
393 }
394
395 #[must_use]
397 pub fn alpha(&self) -> f64 {
398 self.alpha
399 }
400}
401
402#[derive(Debug, Clone)]
404pub struct FittedSelectFwe<F> {
405 n_features_in: usize,
407 p_values: Array1<F>,
409 selected_indices: Vec<usize>,
411}
412
413impl<F: Float + Send + Sync + 'static> FittedSelectFwe<F> {
414 #[must_use]
416 pub fn p_values(&self) -> &Array1<F> {
417 &self.p_values
418 }
419
420 #[must_use]
422 pub fn selected_indices(&self) -> &[usize] {
423 &self.selected_indices
424 }
425
426 #[must_use]
428 pub fn n_features_selected(&self) -> usize {
429 self.selected_indices.len()
430 }
431}
432
433impl<F: Float + Send + Sync + 'static> Fit<Array1<F>, ()> for SelectFwe<F> {
434 type Fitted = FittedSelectFwe<F>;
435 type Error = FerroError;
436
437 fn fit(&self, x: &Array1<F>, _y: &()) -> Result<FittedSelectFwe<F>, FerroError> {
444 let n = x.len();
445 validate_inputs(n, self.alpha)?;
446
447 let adjusted_alpha = self.alpha / n as f64;
448 let adjusted_alpha_f = F::from(adjusted_alpha).unwrap_or_else(F::zero);
449
450 let selected_indices: Vec<usize> = x
451 .iter()
452 .enumerate()
453 .filter(|&(_, &p)| p < adjusted_alpha_f)
454 .map(|(j, _)| j)
455 .collect();
456
457 Ok(FittedSelectFwe {
458 n_features_in: n,
459 p_values: x.clone(),
460 selected_indices,
461 })
462 }
463}
464
465impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedSelectFwe<F> {
466 type Output = Array2<F>;
467 type Error = FerroError;
468
469 fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
475 if x.ncols() != self.n_features_in {
476 return Err(FerroError::ShapeMismatch {
477 expected: vec![x.nrows(), self.n_features_in],
478 actual: vec![x.nrows(), x.ncols()],
479 context: "FittedSelectFwe::transform".into(),
480 });
481 }
482 Ok(select_columns(x, &self.selected_indices))
483 }
484}
485
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    // ---------------------------------------------------------------------
    // SelectFpr
    // ---------------------------------------------------------------------

    #[test]
    fn test_fpr_selects_below_alpha() {
        let sel = SelectFpr::<f64>::new(0.05);
        let p = array![0.01, 0.5, 0.03, 0.9];
        let fitted = sel.fit(&p, &()).unwrap();
        assert_eq!(fitted.selected_indices(), &[0, 2]);
    }

    #[test]
    fn test_fpr_none_below_alpha() {
        let sel = SelectFpr::<f64>::new(0.001);
        let p = array![0.01, 0.5, 0.03];
        let fitted = sel.fit(&p, &()).unwrap();
        assert_eq!(fitted.n_features_selected(), 0);
    }

    #[test]
    fn test_fpr_all_below_alpha() {
        let sel = SelectFpr::<f64>::new(0.99);
        let p = array![0.01, 0.5, 0.03];
        let fitted = sel.fit(&p, &()).unwrap();
        assert_eq!(fitted.n_features_selected(), 3);
    }

    #[test]
    fn test_fpr_threshold_is_strict() {
        // A p-value exactly equal to alpha is not selected.
        let sel = SelectFpr::<f64>::new(0.5);
        let p = array![0.5, 0.1];
        let fitted = sel.fit(&p, &()).unwrap();
        assert_eq!(fitted.selected_indices(), &[1]);
    }

    #[test]
    fn test_fpr_nan_not_selected() {
        let sel = SelectFpr::<f64>::new(0.05);
        let p = array![f64::NAN, 0.01];
        let fitted = sel.fit(&p, &()).unwrap();
        assert_eq!(fitted.selected_indices(), &[1]);
    }

    #[test]
    fn test_fpr_transform() {
        let sel = SelectFpr::<f64>::new(0.05);
        let p = array![0.01, 0.5, 0.03];
        let fitted = sel.fit(&p, &()).unwrap();
        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let out = fitted.transform(&x).unwrap();
        assert_eq!(out.ncols(), 2);
        assert_eq!(out[[0, 0]], 1.0);
        assert_eq!(out[[0, 1]], 3.0);
        assert_eq!(out, array![[1.0, 3.0], [4.0, 6.0]]);
    }

    #[test]
    fn test_fpr_empty_error() {
        let sel = SelectFpr::<f64>::new(0.05);
        let p: Array1<f64> = Array1::zeros(0);
        assert!(sel.fit(&p, &()).is_err());
    }

    #[test]
    fn test_fpr_invalid_alpha() {
        let p = array![0.01];

        let neg = SelectFpr::<f64>::new(-0.1);
        assert!(neg.fit(&p, &()).is_err());

        let sel2 = SelectFpr::<f64>::new(1.5);
        assert!(sel2.fit(&p, &()).is_err());

        // NaN is not contained in [0, 1], so it is rejected as well.
        let nan = SelectFpr::<f64>::new(f64::NAN);
        assert!(nan.fit(&p, &()).is_err());
    }

    #[test]
    fn test_fpr_alpha_zero_valid() {
        let sel = SelectFpr::<f64>::new(0.0);
        let p = array![0.01, 0.5, 0.03];
        let fitted = sel.fit(&p, &());
        assert!(fitted.is_ok(), "alpha=0 is valid (closed=both)");
        if let Ok(f) = fitted {
            assert_eq!(f.n_features_selected(), 0);
        }
    }

    #[test]
    fn test_fpr_shape_mismatch() {
        let sel = SelectFpr::<f64>::new(0.05);
        let p = array![0.01, 0.5];
        let fitted = sel.fit(&p, &()).unwrap();
        let x_bad = array![[1.0, 2.0, 3.0]];
        assert!(fitted.transform(&x_bad).is_err());
    }

    #[test]
    fn test_fpr_accessor() {
        let sel = SelectFpr::<f64>::new(0.05);
        assert_eq!(sel.alpha(), 0.05);
    }

    #[test]
    fn test_fpr_p_values_accessor() {
        let sel = SelectFpr::<f64>::new(0.05);
        let p = array![0.01, 0.5];
        let fitted = sel.fit(&p, &()).unwrap();
        assert_eq!(fitted.p_values().len(), 2);
        assert_eq!(fitted.p_values(), &p);
    }

    // ---------------------------------------------------------------------
    // SelectFdr (Benjamini–Hochberg)
    // ---------------------------------------------------------------------

    #[test]
    fn test_fdr_basic() {
        // n = 4, thresholds 0.0125, 0.025, 0.0375, 0.05 for sorted
        // p-values 0.01, 0.03, 0.5, 0.9: only rank 1 qualifies.
        let sel = SelectFdr::<f64>::new(0.05);
        let p = array![0.01, 0.5, 0.03, 0.9];
        let fitted = sel.fit(&p, &()).unwrap();
        assert!(fitted.selected_indices().contains(&0));
        assert_eq!(fitted.selected_indices(), &[0]);
    }

    #[test]
    fn test_fdr_multiple_pass() {
        // n = 4, thresholds 0.025, 0.05, 0.075, 0.1 for sorted
        // p-values 0.005, 0.02, 0.04, 0.5: ranks 1..=3 qualify.
        let sel = SelectFdr::<f64>::new(0.10);
        let p = array![0.02, 0.5, 0.005, 0.04];
        let fitted = sel.fit(&p, &()).unwrap();
        assert_eq!(fitted.n_features_selected(), 3);
        assert!(fitted.selected_indices().contains(&0));
        assert!(fitted.selected_indices().contains(&2));
        assert!(fitted.selected_indices().contains(&3));
        assert_eq!(fitted.selected_indices(), &[0, 2, 3]);
    }

    #[test]
    fn test_fdr_step_up_rescues_lower_ranks() {
        // n = 3, thresholds 0.01, 0.02, 0.03. Ranks 1 and 2 fail their own
        // thresholds, but rank 3 passes, so the step-up rule keeps all three.
        let sel = SelectFdr::<f64>::new(0.03);
        let p = array![0.015, 0.025, 0.029];
        let fitted = sel.fit(&p, &()).unwrap();
        assert_eq!(fitted.selected_indices(), &[0, 1, 2]);
    }

    #[test]
    fn test_fdr_none_selected() {
        let sel = SelectFdr::<f64>::new(0.001);
        let p = array![0.01, 0.5, 0.03];
        let fitted = sel.fit(&p, &()).unwrap();
        assert_eq!(fitted.n_features_selected(), 0);
    }

    #[test]
    fn test_fdr_alpha_one_selects_all_valid() {
        // With alpha = 1 the k-th threshold is k / n, which every sorted
        // p-value in [0, 1] here satisfies.
        let sel = SelectFdr::<f64>::new(1.0);
        let p = array![0.01, 0.5, 0.03];
        let fitted = sel.fit(&p, &()).unwrap();
        assert_eq!(fitted.selected_indices(), &[0, 1, 2]);
    }

    #[test]
    fn test_fdr_transform() {
        let sel = SelectFdr::<f64>::new(0.10);
        let p = array![0.001, 0.5, 0.9];
        let fitted = sel.fit(&p, &()).unwrap();
        let x = array![[1.0, 2.0, 3.0]];
        let out = fitted.transform(&x).unwrap();
        assert_eq!(out.ncols(), 1);
        assert_eq!(out, array![[1.0]]);
    }

    #[test]
    fn test_fdr_empty_error() {
        let sel = SelectFdr::<f64>::new(0.05);
        let p: Array1<f64> = Array1::zeros(0);
        assert!(sel.fit(&p, &()).is_err());
    }

    #[test]
    fn test_fdr_invalid_alpha() {
        let p = array![0.01];

        let neg = SelectFdr::<f64>::new(-0.1);
        assert!(neg.fit(&p, &()).is_err());

        let big = SelectFdr::<f64>::new(1.5);
        assert!(big.fit(&p, &()).is_err());
    }

    #[test]
    fn test_fdr_alpha_zero_valid() {
        let sel = SelectFdr::<f64>::new(0.0);
        let p = array![0.01, 0.5, 0.03];
        let fitted = sel.fit(&p, &());
        assert!(fitted.is_ok(), "alpha=0 is valid (closed=both)");
        if let Ok(f) = fitted {
            assert_eq!(f.n_features_selected(), 0);
        }
    }

    #[test]
    fn test_fdr_shape_mismatch() {
        let sel = SelectFdr::<f64>::new(0.05);
        let p = array![0.01, 0.5];
        let fitted = sel.fit(&p, &()).unwrap();
        let x_bad = array![[1.0, 2.0, 3.0]];
        assert!(fitted.transform(&x_bad).is_err());
    }

    #[test]
    fn test_fdr_accessor() {
        let sel = SelectFdr::<f64>::new(0.05);
        assert_eq!(sel.alpha(), 0.05);
    }

    // ---------------------------------------------------------------------
    // SelectFwe (Bonferroni)
    // ---------------------------------------------------------------------

    #[test]
    fn test_fwe_basic() {
        // Corrected level 0.05 / 4 = 0.0125.
        let sel = SelectFwe::<f64>::new(0.05);
        let p = array![0.001, 0.5, 0.03, 0.9];
        let fitted = sel.fit(&p, &()).unwrap();
        assert_eq!(fitted.selected_indices(), &[0]);
    }

    #[test]
    fn test_fwe_two_features() {
        // Corrected level 0.10 / 3 ≈ 0.0333.
        let sel = SelectFwe::<f64>::new(0.10);
        let p = array![0.01, 0.02, 0.5];
        let fitted = sel.fit(&p, &()).unwrap();
        assert_eq!(fitted.selected_indices(), &[0, 1]);
    }

    #[test]
    fn test_fwe_none_selected() {
        // Corrected level 0.01 / 3 ≈ 0.0033.
        let sel = SelectFwe::<f64>::new(0.01);
        let p = array![0.005, 0.5, 0.03];
        let fitted = sel.fit(&p, &()).unwrap();
        assert_eq!(fitted.n_features_selected(), 0);
    }

    #[test]
    fn test_fwe_nan_not_selected() {
        // Corrected level 0.10 / 2 = 0.05.
        let sel = SelectFwe::<f64>::new(0.10);
        let p = array![f64::NAN, 0.01];
        let fitted = sel.fit(&p, &()).unwrap();
        assert_eq!(fitted.selected_indices(), &[1]);
    }

    #[test]
    fn test_fwe_transform() {
        let sel = SelectFwe::<f64>::new(0.05);
        let p = array![0.001, 0.5, 0.9];
        let fitted = sel.fit(&p, &()).unwrap();
        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let out = fitted.transform(&x).unwrap();
        assert_eq!(out.ncols(), 1);
        assert_eq!(out[[0, 0]], 1.0);
        assert_eq!(out, array![[1.0], [4.0]]);
    }

    #[test]
    fn test_fwe_empty_error() {
        let sel = SelectFwe::<f64>::new(0.05);
        let p: Array1<f64> = Array1::zeros(0);
        assert!(sel.fit(&p, &()).is_err());
    }

    #[test]
    fn test_fwe_invalid_alpha() {
        let p = array![0.01];

        let neg = SelectFwe::<f64>::new(-0.1);
        assert!(neg.fit(&p, &()).is_err());

        let big = SelectFwe::<f64>::new(1.5);
        assert!(big.fit(&p, &()).is_err());
    }

    #[test]
    fn test_fwe_alpha_zero_valid() {
        let sel = SelectFwe::<f64>::new(0.0);
        let p = array![0.01, 0.5, 0.03];
        let fitted = sel.fit(&p, &());
        assert!(fitted.is_ok(), "alpha=0 is valid (closed=both)");
        if let Ok(f) = fitted {
            assert_eq!(f.n_features_selected(), 0);
        }
    }

    #[test]
    fn test_fwe_shape_mismatch() {
        let sel = SelectFwe::<f64>::new(0.05);
        let p = array![0.01, 0.5];
        let fitted = sel.fit(&p, &()).unwrap();
        let x_bad = array![[1.0, 2.0, 3.0]];
        assert!(fitted.transform(&x_bad).is_err());
    }

    #[test]
    fn test_fwe_accessor() {
        let sel = SelectFwe::<f64>::new(0.05);
        assert_eq!(sel.alpha(), 0.05);
    }

    #[test]
    fn test_fwe_single_feature() {
        // With one feature the Bonferroni correction is a no-op.
        let sel = SelectFwe::<f64>::new(0.05);
        let p = array![0.01];
        let fitted = sel.fit(&p, &()).unwrap();
        assert_eq!(fitted.selected_indices(), &[0]);
    }

    #[test]
    fn test_fwe_f32() {
        let sel = SelectFwe::<f32>::new(0.05);
        let p: Array1<f32> = array![0.001f32, 0.5];
        let fitted = sel.fit(&p, &()).unwrap();
        assert_eq!(fitted.selected_indices(), &[0]);
    }
}