1use ferrolearn_core::error::FerroError;
17use ferrolearn_core::traits::{Fit, Transform};
18use ndarray::{Array1, Array2};
19use num_traits::Float;
20
21fn select_columns<F: Float>(x: &Array2<F>, indices: &[usize]) -> Array2<F> {
27 let nrows = x.nrows();
28 let ncols = indices.len();
29 if ncols == 0 {
30 return Array2::zeros((nrows, 0));
31 }
32 let mut out = Array2::zeros((nrows, ncols));
33 for (new_j, &old_j) in indices.iter().enumerate() {
34 for i in 0..nrows {
35 out[[i, new_j]] = x[[i, old_j]];
36 }
37 }
38 out
39}
40
41fn validate_inputs(n_features: usize, alpha: f64) -> Result<(), FerroError> {
43 if n_features == 0 {
44 return Err(FerroError::InvalidParameter {
45 name: "p_values".into(),
46 reason: "p-value vector must not be empty".into(),
47 });
48 }
49 if alpha <= 0.0 || alpha > 1.0 {
50 return Err(FerroError::InvalidParameter {
51 name: "alpha".into(),
52 reason: format!("alpha must be in (0, 1], got {alpha}"),
53 });
54 }
55 Ok(())
56}
57
/// Unfitted selector that keeps every feature whose p-value is strictly below
/// `alpha`.
///
/// `F` is the floating-point type of the p-values handed to `fit`; the struct
/// stores no value of that type and only records it through `PhantomData`.
#[derive(Clone, Debug)]
#[must_use]
pub struct SelectFpr<F> {
    /// Significance level each p-value is compared against.
    alpha: f64,
    /// Ties the selector to the float type `F` without storing one.
    _marker: std::marker::PhantomData<F>,
}
88
89impl<F: Float + Send + Sync + 'static> SelectFpr<F> {
90 pub fn new(alpha: f64) -> Self {
92 Self {
93 alpha,
94 _marker: std::marker::PhantomData,
95 }
96 }
97
98 #[must_use]
100 pub fn alpha(&self) -> f64 {
101 self.alpha
102 }
103}
104
105#[derive(Debug, Clone)]
107pub struct FittedSelectFpr<F> {
108 n_features_in: usize,
110 p_values: Array1<F>,
112 selected_indices: Vec<usize>,
114}
115
116impl<F: Float + Send + Sync + 'static> FittedSelectFpr<F> {
117 #[must_use]
119 pub fn p_values(&self) -> &Array1<F> {
120 &self.p_values
121 }
122
123 #[must_use]
125 pub fn selected_indices(&self) -> &[usize] {
126 &self.selected_indices
127 }
128
129 #[must_use]
131 pub fn n_features_selected(&self) -> usize {
132 self.selected_indices.len()
133 }
134}
135
136impl<F: Float + Send + Sync + 'static> Fit<Array1<F>, ()> for SelectFpr<F> {
137 type Fitted = FittedSelectFpr<F>;
138 type Error = FerroError;
139
140 fn fit(&self, x: &Array1<F>, _y: &()) -> Result<FittedSelectFpr<F>, FerroError> {
147 let n = x.len();
148 validate_inputs(n, self.alpha)?;
149
150 let alpha_f = F::from(self.alpha).unwrap_or_else(F::zero);
151 let selected_indices: Vec<usize> = x
152 .iter()
153 .enumerate()
154 .filter(|&(_, &p)| p < alpha_f)
155 .map(|(j, _)| j)
156 .collect();
157
158 Ok(FittedSelectFpr {
159 n_features_in: n,
160 p_values: x.clone(),
161 selected_indices,
162 })
163 }
164}
165
166impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedSelectFpr<F> {
167 type Output = Array2<F>;
168 type Error = FerroError;
169
170 fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
176 if x.ncols() != self.n_features_in {
177 return Err(FerroError::ShapeMismatch {
178 expected: vec![x.nrows(), self.n_features_in],
179 actual: vec![x.nrows(), x.ncols()],
180 context: "FittedSelectFpr::transform".into(),
181 });
182 }
183 Ok(select_columns(x, &self.selected_indices))
184 }
185}
186
/// Unfitted selector that ranks p-values and keeps the features passing a
/// rank-dependent threshold of `alpha * rank / n`.
///
/// `F` is the floating-point type of the p-values handed to `fit`; the struct
/// stores no value of that type and only records it through `PhantomData`.
#[derive(Clone, Debug)]
#[must_use]
pub struct SelectFdr<F> {
    /// Significance level used to build the per-rank thresholds.
    alpha: f64,
    /// Ties the selector to the float type `F` without storing one.
    _marker: std::marker::PhantomData<F>,
}
217
218impl<F: Float + Send + Sync + 'static> SelectFdr<F> {
219 pub fn new(alpha: f64) -> Self {
221 Self {
222 alpha,
223 _marker: std::marker::PhantomData,
224 }
225 }
226
227 #[must_use]
229 pub fn alpha(&self) -> f64 {
230 self.alpha
231 }
232}
233
234#[derive(Debug, Clone)]
236pub struct FittedSelectFdr<F> {
237 n_features_in: usize,
239 p_values: Array1<F>,
241 selected_indices: Vec<usize>,
243}
244
245impl<F: Float + Send + Sync + 'static> FittedSelectFdr<F> {
246 #[must_use]
248 pub fn p_values(&self) -> &Array1<F> {
249 &self.p_values
250 }
251
252 #[must_use]
254 pub fn selected_indices(&self) -> &[usize] {
255 &self.selected_indices
256 }
257
258 #[must_use]
260 pub fn n_features_selected(&self) -> usize {
261 self.selected_indices.len()
262 }
263}
264
265impl<F: Float + Send + Sync + 'static> Fit<Array1<F>, ()> for SelectFdr<F> {
266 type Fitted = FittedSelectFdr<F>;
267 type Error = FerroError;
268
269 fn fit(&self, x: &Array1<F>, _y: &()) -> Result<FittedSelectFdr<F>, FerroError> {
276 let n = x.len();
277 validate_inputs(n, self.alpha)?;
278
279 let alpha_f = F::from(self.alpha).unwrap_or_else(F::zero);
280 let n_f = F::from(n).unwrap_or_else(F::one);
281
282 let mut ranked: Vec<(usize, F)> = x.iter().copied().enumerate().collect();
284 ranked.sort_by(|a, b| {
285 a.1.partial_cmp(&b.1)
286 .unwrap_or(std::cmp::Ordering::Equal)
287 });
288
289 let mut max_qualifying_rank: Option<usize> = None;
291 for (rank, &(_, p_val)) in ranked.iter().enumerate() {
292 let bh_threshold = alpha_f * F::from(rank + 1).unwrap_or_else(F::one) / n_f;
293 if p_val <= bh_threshold {
294 max_qualifying_rank = Some(rank);
295 }
296 }
297
298 let mut selected_indices: Vec<usize> = match max_qualifying_rank {
300 Some(max_rank) => ranked[..=max_rank]
301 .iter()
302 .map(|&(idx, _)| idx)
303 .collect(),
304 None => Vec::new(),
305 };
306 selected_indices.sort_unstable();
307
308 Ok(FittedSelectFdr {
309 n_features_in: n,
310 p_values: x.clone(),
311 selected_indices,
312 })
313 }
314}
315
316impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedSelectFdr<F> {
317 type Output = Array2<F>;
318 type Error = FerroError;
319
320 fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
326 if x.ncols() != self.n_features_in {
327 return Err(FerroError::ShapeMismatch {
328 expected: vec![x.nrows(), self.n_features_in],
329 actual: vec![x.nrows(), x.ncols()],
330 context: "FittedSelectFdr::transform".into(),
331 });
332 }
333 Ok(select_columns(x, &self.selected_indices))
334 }
335}
336
/// Unfitted selector that keeps every feature whose p-value is strictly below
/// `alpha / n`, where `n` is the number of features (a Bonferroni-style
/// correction).
///
/// `F` is the floating-point type of the p-values handed to `fit`; the struct
/// stores no value of that type and only records it through `PhantomData`.
#[derive(Clone, Debug)]
#[must_use]
pub struct SelectFwe<F> {
    /// Significance level before it is divided by the number of features.
    alpha: f64,
    /// Ties the selector to the float type `F` without storing one.
    _marker: std::marker::PhantomData<F>,
}
367
368impl<F: Float + Send + Sync + 'static> SelectFwe<F> {
369 pub fn new(alpha: f64) -> Self {
371 Self {
372 alpha,
373 _marker: std::marker::PhantomData,
374 }
375 }
376
377 #[must_use]
379 pub fn alpha(&self) -> f64 {
380 self.alpha
381 }
382}
383
384#[derive(Debug, Clone)]
386pub struct FittedSelectFwe<F> {
387 n_features_in: usize,
389 p_values: Array1<F>,
391 selected_indices: Vec<usize>,
393}
394
395impl<F: Float + Send + Sync + 'static> FittedSelectFwe<F> {
396 #[must_use]
398 pub fn p_values(&self) -> &Array1<F> {
399 &self.p_values
400 }
401
402 #[must_use]
404 pub fn selected_indices(&self) -> &[usize] {
405 &self.selected_indices
406 }
407
408 #[must_use]
410 pub fn n_features_selected(&self) -> usize {
411 self.selected_indices.len()
412 }
413}
414
415impl<F: Float + Send + Sync + 'static> Fit<Array1<F>, ()> for SelectFwe<F> {
416 type Fitted = FittedSelectFwe<F>;
417 type Error = FerroError;
418
419 fn fit(&self, x: &Array1<F>, _y: &()) -> Result<FittedSelectFwe<F>, FerroError> {
426 let n = x.len();
427 validate_inputs(n, self.alpha)?;
428
429 let adjusted_alpha = self.alpha / n as f64;
430 let adjusted_alpha_f = F::from(adjusted_alpha).unwrap_or_else(F::zero);
431
432 let selected_indices: Vec<usize> = x
433 .iter()
434 .enumerate()
435 .filter(|&(_, &p)| p < adjusted_alpha_f)
436 .map(|(j, _)| j)
437 .collect();
438
439 Ok(FittedSelectFwe {
440 n_features_in: n,
441 p_values: x.clone(),
442 selected_indices,
443 })
444 }
445}
446
447impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedSelectFwe<F> {
448 type Output = Array2<F>;
449 type Error = FerroError;
450
451 fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
457 if x.ncols() != self.n_features_in {
458 return Err(FerroError::ShapeMismatch {
459 expected: vec![x.nrows(), self.n_features_in],
460 actual: vec![x.nrows(), x.ncols()],
461 context: "FittedSelectFwe::transform".into(),
462 });
463 }
464 Ok(select_columns(x, &self.selected_indices))
465 }
466}
467
#[cfg(test)]
mod tests {
    //! Unit tests for the three p-value based selectors.
    //!
    //! Each selector is checked for its selection rule, its `transform`
    //! output, and its error paths (empty input, invalid `alpha`, column-count
    //! mismatch).

    use super::*;
    use ndarray::array;

    // ----- SelectFpr: keep p < alpha -----

    #[test]
    fn test_fpr_selects_below_alpha() {
        let sel = SelectFpr::<f64>::new(0.05);
        let p = array![0.01, 0.5, 0.03, 0.9];
        let fitted = sel.fit(&p, &()).unwrap();
        assert_eq!(fitted.selected_indices(), &[0, 2]);
    }

    #[test]
    fn test_fpr_none_below_alpha() {
        let sel = SelectFpr::<f64>::new(0.001);
        let p = array![0.01, 0.5, 0.03];
        let fitted = sel.fit(&p, &()).unwrap();
        assert_eq!(fitted.n_features_selected(), 0);
    }

    #[test]
    fn test_fpr_all_below_alpha() {
        let sel = SelectFpr::<f64>::new(0.99);
        let p = array![0.01, 0.5, 0.03];
        let fitted = sel.fit(&p, &()).unwrap();
        assert_eq!(fitted.n_features_selected(), 3);
    }

    #[test]
    fn test_fpr_transform() {
        let sel = SelectFpr::<f64>::new(0.05);
        let p = array![0.01, 0.5, 0.03];
        let fitted = sel.fit(&p, &()).unwrap();
        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let out = fitted.transform(&x).unwrap();
        // Columns 0 and 2 survive, in that order.
        assert_eq!(out.ncols(), 2);
        assert_eq!(out[[0, 0]], 1.0);
        assert_eq!(out[[0, 1]], 3.0);
    }

    #[test]
    fn test_fpr_empty_error() {
        let sel = SelectFpr::<f64>::new(0.05);
        let p: Array1<f64> = Array1::zeros(0);
        assert!(sel.fit(&p, &()).is_err());
    }

    #[test]
    fn test_fpr_invalid_alpha() {
        let sel = SelectFpr::<f64>::new(0.0);
        let p = array![0.01];
        assert!(sel.fit(&p, &()).is_err());

        let sel2 = SelectFpr::<f64>::new(1.5);
        assert!(sel2.fit(&p, &()).is_err());
    }

    #[test]
    fn test_fpr_shape_mismatch() {
        let sel = SelectFpr::<f64>::new(0.05);
        let p = array![0.01, 0.5];
        let fitted = sel.fit(&p, &()).unwrap();
        let x_bad = array![[1.0, 2.0, 3.0]];
        assert!(fitted.transform(&x_bad).is_err());
    }

    #[test]
    fn test_fpr_accessor() {
        let sel = SelectFpr::<f64>::new(0.05);
        assert_eq!(sel.alpha(), 0.05);
    }

    #[test]
    fn test_fpr_p_values_accessor() {
        let sel = SelectFpr::<f64>::new(0.05);
        let p = array![0.01, 0.5];
        let fitted = sel.fit(&p, &()).unwrap();
        assert_eq!(fitted.p_values().len(), 2);
    }

    // ----- SelectFdr: step-up rule, rank k passes when p <= alpha * k / n -----

    #[test]
    fn test_fdr_basic() {
        let sel = SelectFdr::<f64>::new(0.05);
        let p = array![0.01, 0.5, 0.03, 0.9];
        // Sorted: 0.01, 0.03, 0.5, 0.9 against 0.0125, 0.025, 0.0375, 0.05.
        // Only rank 1 (feature 0) passes.
        let fitted = sel.fit(&p, &()).unwrap();
        assert_eq!(fitted.selected_indices(), &[0]);
    }

    #[test]
    fn test_fdr_multiple_pass() {
        let sel = SelectFdr::<f64>::new(0.10);
        let p = array![0.02, 0.5, 0.005, 0.04];
        // Sorted: 0.005, 0.02, 0.04, 0.5 against 0.025, 0.05, 0.075, 0.1.
        // Ranks 1-3 pass, i.e. features 2, 0 and 3.
        let fitted = sel.fit(&p, &()).unwrap();
        assert_eq!(fitted.n_features_selected(), 3);
        assert_eq!(fitted.selected_indices(), &[0, 2, 3]);
    }

    #[test]
    fn test_fdr_none_selected() {
        let sel = SelectFdr::<f64>::new(0.001);
        let p = array![0.01, 0.5, 0.03];
        let fitted = sel.fit(&p, &()).unwrap();
        assert_eq!(fitted.n_features_selected(), 0);
    }

    #[test]
    fn test_fdr_transform() {
        let sel = SelectFdr::<f64>::new(0.10);
        let p = array![0.001, 0.5, 0.9];
        let fitted = sel.fit(&p, &()).unwrap();
        let x = array![[1.0, 2.0, 3.0]];
        let out = fitted.transform(&x).unwrap();
        // Only feature 0 passes (0.001 <= 0.1 / 3), so exactly one column remains.
        assert_eq!(out.ncols(), 1);
        assert_eq!(out[[0, 0]], 1.0);
    }

    #[test]
    fn test_fdr_empty_error() {
        let sel = SelectFdr::<f64>::new(0.05);
        let p: Array1<f64> = Array1::zeros(0);
        assert!(sel.fit(&p, &()).is_err());
    }

    #[test]
    fn test_fdr_invalid_alpha() {
        let sel = SelectFdr::<f64>::new(0.0);
        let p = array![0.01];
        assert!(sel.fit(&p, &()).is_err());
    }

    #[test]
    fn test_fdr_shape_mismatch() {
        let sel = SelectFdr::<f64>::new(0.05);
        let p = array![0.01, 0.5];
        let fitted = sel.fit(&p, &()).unwrap();
        let x_bad = array![[1.0, 2.0, 3.0]];
        assert!(fitted.transform(&x_bad).is_err());
    }

    #[test]
    fn test_fdr_accessor() {
        let sel = SelectFdr::<f64>::new(0.05);
        assert_eq!(sel.alpha(), 0.05);
    }

    // ----- SelectFwe: keep p < alpha / n -----

    #[test]
    fn test_fwe_basic() {
        let sel = SelectFwe::<f64>::new(0.05);
        let p = array![0.001, 0.5, 0.03, 0.9];
        // Cutoff is 0.05 / 4 = 0.0125.
        let fitted = sel.fit(&p, &()).unwrap();
        assert_eq!(fitted.selected_indices(), &[0]);
    }

    #[test]
    fn test_fwe_two_features() {
        let sel = SelectFwe::<f64>::new(0.10);
        let p = array![0.01, 0.02, 0.5];
        // Cutoff is 0.10 / 3, roughly 0.0333.
        let fitted = sel.fit(&p, &()).unwrap();
        assert_eq!(fitted.selected_indices(), &[0, 1]);
    }

    #[test]
    fn test_fwe_none_selected() {
        let sel = SelectFwe::<f64>::new(0.01);
        let p = array![0.005, 0.5, 0.03];
        // Cutoff is 0.01 / 3, roughly 0.0033; even 0.005 is too large.
        let fitted = sel.fit(&p, &()).unwrap();
        assert_eq!(fitted.n_features_selected(), 0);
    }

    #[test]
    fn test_fwe_transform() {
        let sel = SelectFwe::<f64>::new(0.05);
        let p = array![0.001, 0.5, 0.9];
        let fitted = sel.fit(&p, &()).unwrap();
        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let out = fitted.transform(&x).unwrap();
        assert_eq!(out.ncols(), 1);
        assert_eq!(out[[0, 0]], 1.0);
    }

    #[test]
    fn test_fwe_empty_error() {
        let sel = SelectFwe::<f64>::new(0.05);
        let p: Array1<f64> = Array1::zeros(0);
        assert!(sel.fit(&p, &()).is_err());
    }

    #[test]
    fn test_fwe_invalid_alpha() {
        let sel = SelectFwe::<f64>::new(0.0);
        let p = array![0.01];
        assert!(sel.fit(&p, &()).is_err());
    }

    #[test]
    fn test_fwe_shape_mismatch() {
        let sel = SelectFwe::<f64>::new(0.05);
        let p = array![0.01, 0.5];
        let fitted = sel.fit(&p, &()).unwrap();
        let x_bad = array![[1.0, 2.0, 3.0]];
        assert!(fitted.transform(&x_bad).is_err());
    }

    #[test]
    fn test_fwe_accessor() {
        let sel = SelectFwe::<f64>::new(0.05);
        assert_eq!(sel.alpha(), 0.05);
    }

    #[test]
    fn test_fwe_single_feature() {
        let sel = SelectFwe::<f64>::new(0.05);
        let p = array![0.01];
        // With one feature the cutoff equals alpha itself.
        let fitted = sel.fit(&p, &()).unwrap();
        assert_eq!(fitted.selected_indices(), &[0]);
    }

    #[test]
    fn test_fwe_f32() {
        let sel = SelectFwe::<f32>::new(0.05);
        let p: Array1<f32> = array![0.001f32, 0.5];
        let fitted = sel.fit(&p, &()).unwrap();
        assert_eq!(fitted.selected_indices(), &[0]);
    }
}