1use ferrolearn_core::error::FerroError;
17use ferrolearn_core::traits::{Fit, Transform};
18use ndarray::{Array1, Array2};
19use num_traits::Float;
20
21fn select_columns<F: Float>(x: &Array2<F>, indices: &[usize]) -> Array2<F> {
27 let nrows = x.nrows();
28 let ncols = indices.len();
29 if ncols == 0 {
30 return Array2::zeros((nrows, 0));
31 }
32 let mut out = Array2::zeros((nrows, ncols));
33 for (new_j, &old_j) in indices.iter().enumerate() {
34 for i in 0..nrows {
35 out[[i, new_j]] = x[[i, old_j]];
36 }
37 }
38 out
39}
40
41fn validate_inputs(n_features: usize, alpha: f64) -> Result<(), FerroError> {
43 if n_features == 0 {
44 return Err(FerroError::InvalidParameter {
45 name: "p_values".into(),
46 reason: "p-value vector must not be empty".into(),
47 });
48 }
49 if alpha <= 0.0 || alpha > 1.0 {
50 return Err(FerroError::InvalidParameter {
51 name: "alpha".into(),
52 reason: format!("alpha must be in (0, 1], got {alpha}"),
53 });
54 }
55 Ok(())
56}
57
58#[must_use]
82#[derive(Debug, Clone)]
83pub struct SelectFpr<F> {
84 alpha: f64,
86 _marker: std::marker::PhantomData<F>,
87}
88
89impl<F: Float + Send + Sync + 'static> SelectFpr<F> {
90 pub fn new(alpha: f64) -> Self {
92 Self {
93 alpha,
94 _marker: std::marker::PhantomData,
95 }
96 }
97
98 #[must_use]
100 pub fn alpha(&self) -> f64 {
101 self.alpha
102 }
103}
104
105#[derive(Debug, Clone)]
107pub struct FittedSelectFpr<F> {
108 n_features_in: usize,
110 p_values: Array1<F>,
112 selected_indices: Vec<usize>,
114}
115
116impl<F: Float + Send + Sync + 'static> FittedSelectFpr<F> {
117 #[must_use]
119 pub fn p_values(&self) -> &Array1<F> {
120 &self.p_values
121 }
122
123 #[must_use]
125 pub fn selected_indices(&self) -> &[usize] {
126 &self.selected_indices
127 }
128
129 #[must_use]
131 pub fn n_features_selected(&self) -> usize {
132 self.selected_indices.len()
133 }
134}
135
136impl<F: Float + Send + Sync + 'static> Fit<Array1<F>, ()> for SelectFpr<F> {
137 type Fitted = FittedSelectFpr<F>;
138 type Error = FerroError;
139
140 fn fit(&self, x: &Array1<F>, _y: &()) -> Result<FittedSelectFpr<F>, FerroError> {
147 let n = x.len();
148 validate_inputs(n, self.alpha)?;
149
150 let alpha_f = F::from(self.alpha).unwrap_or_else(F::zero);
151 let selected_indices: Vec<usize> = x
152 .iter()
153 .enumerate()
154 .filter(|&(_, &p)| p < alpha_f)
155 .map(|(j, _)| j)
156 .collect();
157
158 Ok(FittedSelectFpr {
159 n_features_in: n,
160 p_values: x.clone(),
161 selected_indices,
162 })
163 }
164}
165
166impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedSelectFpr<F> {
167 type Output = Array2<F>;
168 type Error = FerroError;
169
170 fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
176 if x.ncols() != self.n_features_in {
177 return Err(FerroError::ShapeMismatch {
178 expected: vec![x.nrows(), self.n_features_in],
179 actual: vec![x.nrows(), x.ncols()],
180 context: "FittedSelectFpr::transform".into(),
181 });
182 }
183 Ok(select_columns(x, &self.selected_indices))
184 }
185}
186
187#[must_use]
211#[derive(Debug, Clone)]
212pub struct SelectFdr<F> {
213 alpha: f64,
215 _marker: std::marker::PhantomData<F>,
216}
217
218impl<F: Float + Send + Sync + 'static> SelectFdr<F> {
219 pub fn new(alpha: f64) -> Self {
221 Self {
222 alpha,
223 _marker: std::marker::PhantomData,
224 }
225 }
226
227 #[must_use]
229 pub fn alpha(&self) -> f64 {
230 self.alpha
231 }
232}
233
234#[derive(Debug, Clone)]
236pub struct FittedSelectFdr<F> {
237 n_features_in: usize,
239 p_values: Array1<F>,
241 selected_indices: Vec<usize>,
243}
244
245impl<F: Float + Send + Sync + 'static> FittedSelectFdr<F> {
246 #[must_use]
248 pub fn p_values(&self) -> &Array1<F> {
249 &self.p_values
250 }
251
252 #[must_use]
254 pub fn selected_indices(&self) -> &[usize] {
255 &self.selected_indices
256 }
257
258 #[must_use]
260 pub fn n_features_selected(&self) -> usize {
261 self.selected_indices.len()
262 }
263}
264
265impl<F: Float + Send + Sync + 'static> Fit<Array1<F>, ()> for SelectFdr<F> {
266 type Fitted = FittedSelectFdr<F>;
267 type Error = FerroError;
268
269 fn fit(&self, x: &Array1<F>, _y: &()) -> Result<FittedSelectFdr<F>, FerroError> {
276 let n = x.len();
277 validate_inputs(n, self.alpha)?;
278
279 let alpha_f = F::from(self.alpha).unwrap_or_else(F::zero);
280 let n_f = F::from(n).unwrap_or_else(F::one);
281
282 let mut ranked: Vec<(usize, F)> = x.iter().copied().enumerate().collect();
284 ranked.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
285
286 let mut max_qualifying_rank: Option<usize> = None;
288 for (rank, &(_, p_val)) in ranked.iter().enumerate() {
289 let bh_threshold = alpha_f * F::from(rank + 1).unwrap_or_else(F::one) / n_f;
290 if p_val <= bh_threshold {
291 max_qualifying_rank = Some(rank);
292 }
293 }
294
295 let mut selected_indices: Vec<usize> = match max_qualifying_rank {
297 Some(max_rank) => ranked[..=max_rank].iter().map(|&(idx, _)| idx).collect(),
298 None => Vec::new(),
299 };
300 selected_indices.sort_unstable();
301
302 Ok(FittedSelectFdr {
303 n_features_in: n,
304 p_values: x.clone(),
305 selected_indices,
306 })
307 }
308}
309
310impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedSelectFdr<F> {
311 type Output = Array2<F>;
312 type Error = FerroError;
313
314 fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
320 if x.ncols() != self.n_features_in {
321 return Err(FerroError::ShapeMismatch {
322 expected: vec![x.nrows(), self.n_features_in],
323 actual: vec![x.nrows(), x.ncols()],
324 context: "FittedSelectFdr::transform".into(),
325 });
326 }
327 Ok(select_columns(x, &self.selected_indices))
328 }
329}
330
331#[must_use]
355#[derive(Debug, Clone)]
356pub struct SelectFwe<F> {
357 alpha: f64,
359 _marker: std::marker::PhantomData<F>,
360}
361
362impl<F: Float + Send + Sync + 'static> SelectFwe<F> {
363 pub fn new(alpha: f64) -> Self {
365 Self {
366 alpha,
367 _marker: std::marker::PhantomData,
368 }
369 }
370
371 #[must_use]
373 pub fn alpha(&self) -> f64 {
374 self.alpha
375 }
376}
377
378#[derive(Debug, Clone)]
380pub struct FittedSelectFwe<F> {
381 n_features_in: usize,
383 p_values: Array1<F>,
385 selected_indices: Vec<usize>,
387}
388
389impl<F: Float + Send + Sync + 'static> FittedSelectFwe<F> {
390 #[must_use]
392 pub fn p_values(&self) -> &Array1<F> {
393 &self.p_values
394 }
395
396 #[must_use]
398 pub fn selected_indices(&self) -> &[usize] {
399 &self.selected_indices
400 }
401
402 #[must_use]
404 pub fn n_features_selected(&self) -> usize {
405 self.selected_indices.len()
406 }
407}
408
409impl<F: Float + Send + Sync + 'static> Fit<Array1<F>, ()> for SelectFwe<F> {
410 type Fitted = FittedSelectFwe<F>;
411 type Error = FerroError;
412
413 fn fit(&self, x: &Array1<F>, _y: &()) -> Result<FittedSelectFwe<F>, FerroError> {
420 let n = x.len();
421 validate_inputs(n, self.alpha)?;
422
423 let adjusted_alpha = self.alpha / n as f64;
424 let adjusted_alpha_f = F::from(adjusted_alpha).unwrap_or_else(F::zero);
425
426 let selected_indices: Vec<usize> = x
427 .iter()
428 .enumerate()
429 .filter(|&(_, &p)| p < adjusted_alpha_f)
430 .map(|(j, _)| j)
431 .collect();
432
433 Ok(FittedSelectFwe {
434 n_features_in: n,
435 p_values: x.clone(),
436 selected_indices,
437 })
438 }
439}
440
441impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedSelectFwe<F> {
442 type Output = Array2<F>;
443 type Error = FerroError;
444
445 fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
451 if x.ncols() != self.n_features_in {
452 return Err(FerroError::ShapeMismatch {
453 expected: vec![x.nrows(), self.n_features_in],
454 actual: vec![x.nrows(), x.ncols()],
455 context: "FittedSelectFwe::transform".into(),
456 });
457 }
458 Ok(select_columns(x, &self.selected_indices))
459 }
460}
461
/// Unit tests for the three p-value based selectors.
///
/// Expected selections are worked out by hand in the comments so a failure can
/// be checked against the arithmetic rather than against the implementation.
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    // ----- SelectFpr: keep p < alpha -----

    #[test]
    fn test_fpr_selects_below_alpha() {
        let sel = SelectFpr::<f64>::new(0.05);
        let p = array![0.01, 0.5, 0.03, 0.9];
        let fitted = sel.fit(&p, &()).unwrap();
        // 0.01 and 0.03 are below 0.05.
        assert_eq!(fitted.selected_indices(), &[0, 2]);
    }

    #[test]
    fn test_fpr_none_below_alpha() {
        let sel = SelectFpr::<f64>::new(0.001);
        let p = array![0.01, 0.5, 0.03];
        let fitted = sel.fit(&p, &()).unwrap();
        assert_eq!(fitted.n_features_selected(), 0);
    }

    #[test]
    fn test_fpr_all_below_alpha() {
        let sel = SelectFpr::<f64>::new(0.99);
        let p = array![0.01, 0.5, 0.03];
        let fitted = sel.fit(&p, &()).unwrap();
        assert_eq!(fitted.n_features_selected(), 3);
    }

    #[test]
    fn test_fpr_transform() {
        let sel = SelectFpr::<f64>::new(0.05);
        let p = array![0.01, 0.5, 0.03];
        let fitted = sel.fit(&p, &()).unwrap();
        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let out = fitted.transform(&x).unwrap();
        // Columns 0 and 2 survive, for every row.
        assert_eq!(out.ncols(), 2);
        assert_eq!(out[[0, 0]], 1.0);
        assert_eq!(out[[0, 1]], 3.0);
        assert_eq!(out, array![[1.0, 3.0], [4.0, 6.0]]);
    }

    #[test]
    fn test_fpr_empty_error() {
        let sel = SelectFpr::<f64>::new(0.05);
        let p: Array1<f64> = Array1::zeros(0);
        assert!(sel.fit(&p, &()).is_err());
    }

    #[test]
    fn test_fpr_invalid_alpha() {
        let sel = SelectFpr::<f64>::new(0.0);
        let p = array![0.01];
        assert!(sel.fit(&p, &()).is_err());

        let sel2 = SelectFpr::<f64>::new(1.5);
        assert!(sel2.fit(&p, &()).is_err());
    }

    #[test]
    fn test_fpr_shape_mismatch() {
        let sel = SelectFpr::<f64>::new(0.05);
        let p = array![0.01, 0.5];
        let fitted = sel.fit(&p, &()).unwrap();
        let x_bad = array![[1.0, 2.0, 3.0]];
        assert!(fitted.transform(&x_bad).is_err());
    }

    #[test]
    fn test_fpr_accessor() {
        let sel = SelectFpr::<f64>::new(0.05);
        assert_eq!(sel.alpha(), 0.05);
    }

    #[test]
    fn test_fpr_p_values_accessor() {
        let sel = SelectFpr::<f64>::new(0.05);
        let p = array![0.01, 0.5];
        let fitted = sel.fit(&p, &()).unwrap();
        assert_eq!(fitted.p_values().len(), 2);
    }

    // ----- SelectFdr: Benjamini–Hochberg, threshold_k = alpha * k / n -----

    #[test]
    fn test_fdr_basic() {
        let sel = SelectFdr::<f64>::new(0.05);
        let p = array![0.01, 0.5, 0.03, 0.9];
        // Sorted: 0.01, 0.03, 0.5, 0.9; thresholds: 0.0125, 0.025, 0.0375, 0.05.
        // Only rank 1 (0.01 <= 0.0125) qualifies, so exactly feature 0 is kept.
        let fitted = sel.fit(&p, &()).unwrap();
        assert!(fitted.selected_indices().contains(&0));
        assert_eq!(fitted.selected_indices(), &[0]);
    }

    #[test]
    fn test_fdr_multiple_pass() {
        let sel = SelectFdr::<f64>::new(0.10);
        let p = array![0.02, 0.5, 0.005, 0.04];
        // Sorted: 0.005 (idx 2), 0.02 (idx 0), 0.04 (idx 3), 0.5 (idx 1).
        // Thresholds: 0.025, 0.05, 0.075, 0.1 -> ranks 1..=3 qualify.
        let fitted = sel.fit(&p, &()).unwrap();
        assert_eq!(fitted.n_features_selected(), 3);
        // Indices come back in ascending order, not in rank order.
        assert_eq!(fitted.selected_indices(), &[0, 2, 3]);
    }

    #[test]
    fn test_fdr_none_selected() {
        let sel = SelectFdr::<f64>::new(0.001);
        let p = array![0.01, 0.5, 0.03];
        let fitted = sel.fit(&p, &()).unwrap();
        assert_eq!(fitted.n_features_selected(), 0);
    }

    #[test]
    fn test_fdr_transform() {
        let sel = SelectFdr::<f64>::new(0.10);
        let p = array![0.001, 0.5, 0.9];
        // Thresholds: 0.0333.., 0.0666.., 0.1 -> only 0.001 qualifies.
        let fitted = sel.fit(&p, &()).unwrap();
        let x = array![[1.0, 2.0, 3.0]];
        let out = fitted.transform(&x).unwrap();
        assert_eq!(out, array![[1.0]]);
    }

    #[test]
    fn test_fdr_empty_error() {
        let sel = SelectFdr::<f64>::new(0.05);
        let p: Array1<f64> = Array1::zeros(0);
        assert!(sel.fit(&p, &()).is_err());
    }

    #[test]
    fn test_fdr_invalid_alpha() {
        let sel = SelectFdr::<f64>::new(0.0);
        let p = array![0.01];
        assert!(sel.fit(&p, &()).is_err());
    }

    #[test]
    fn test_fdr_shape_mismatch() {
        let sel = SelectFdr::<f64>::new(0.05);
        let p = array![0.01, 0.5];
        let fitted = sel.fit(&p, &()).unwrap();
        let x_bad = array![[1.0, 2.0, 3.0]];
        assert!(fitted.transform(&x_bad).is_err());
    }

    #[test]
    fn test_fdr_accessor() {
        let sel = SelectFdr::<f64>::new(0.05);
        assert_eq!(sel.alpha(), 0.05);
    }

    // ----- SelectFwe: keep p < alpha / n -----

    #[test]
    fn test_fwe_basic() {
        let sel = SelectFwe::<f64>::new(0.05);
        let p = array![0.001, 0.5, 0.03, 0.9];
        // Cutoff: 0.05 / 4 = 0.0125.
        let fitted = sel.fit(&p, &()).unwrap();
        assert_eq!(fitted.selected_indices(), &[0]);
    }

    #[test]
    fn test_fwe_two_features() {
        let sel = SelectFwe::<f64>::new(0.10);
        let p = array![0.01, 0.02, 0.5];
        // Cutoff: 0.10 / 3 = 0.0333..
        let fitted = sel.fit(&p, &()).unwrap();
        assert_eq!(fitted.selected_indices(), &[0, 1]);
    }

    #[test]
    fn test_fwe_none_selected() {
        let sel = SelectFwe::<f64>::new(0.01);
        let p = array![0.005, 0.5, 0.03];
        // Cutoff: 0.01 / 3 = 0.0033.., below every p-value.
        let fitted = sel.fit(&p, &()).unwrap();
        assert_eq!(fitted.n_features_selected(), 0);
    }

    #[test]
    fn test_fwe_transform() {
        let sel = SelectFwe::<f64>::new(0.05);
        let p = array![0.001, 0.5, 0.9];
        let fitted = sel.fit(&p, &()).unwrap();
        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let out = fitted.transform(&x).unwrap();
        assert_eq!(out.ncols(), 1);
        assert_eq!(out[[0, 0]], 1.0);
        assert_eq!(out, array![[1.0], [4.0]]);
    }

    #[test]
    fn test_fwe_empty_error() {
        let sel = SelectFwe::<f64>::new(0.05);
        let p: Array1<f64> = Array1::zeros(0);
        assert!(sel.fit(&p, &()).is_err());
    }

    #[test]
    fn test_fwe_invalid_alpha() {
        let sel = SelectFwe::<f64>::new(0.0);
        let p = array![0.01];
        assert!(sel.fit(&p, &()).is_err());
    }

    #[test]
    fn test_fwe_shape_mismatch() {
        let sel = SelectFwe::<f64>::new(0.05);
        let p = array![0.01, 0.5];
        let fitted = sel.fit(&p, &()).unwrap();
        let x_bad = array![[1.0, 2.0, 3.0]];
        assert!(fitted.transform(&x_bad).is_err());
    }

    #[test]
    fn test_fwe_accessor() {
        let sel = SelectFwe::<f64>::new(0.05);
        assert_eq!(sel.alpha(), 0.05);
    }

    #[test]
    fn test_fwe_single_feature() {
        let sel = SelectFwe::<f64>::new(0.05);
        let p = array![0.01];
        // With one feature the cutoff equals alpha itself.
        let fitted = sel.fit(&p, &()).unwrap();
        assert_eq!(fitted.selected_indices(), &[0]);
    }

    #[test]
    fn test_fwe_f32() {
        let sel = SelectFwe::<f32>::new(0.05);
        let p: Array1<f32> = array![0.001f32, 0.5];
        // Cutoff: 0.05 / 2 = 0.025, narrowed to f32.
        let fitted = sel.fit(&p, &()).unwrap();
        assert_eq!(fitted.selected_indices(), &[0]);
    }
}