// ferrolearn_preprocess/stat_selectors.rs

//! Statistical-test-based feature selectors.
//!
//! Three selectors that choose features based on p-values obtained from a
//! statistical test (e.g., ANOVA F-test, chi-squared test):
//!
//! - [`SelectFpr`] — **False Positive Rate**: selects every feature whose
//!   p-value is strictly below `alpha`.
//! - [`SelectFdr`] — **False Discovery Rate**: applies the Benjamini-Hochberg
//!   procedure to control the expected proportion of false positives among
//!   the selected features.
//! - [`SelectFwe`] — **Family-Wise Error**: applies the Bonferroni correction
//!   (`alpha / n_features`) to control the probability of any false positive.
//!
//! All three take a pre-computed vector of p-values (one per feature) at fit
//! time, allowing integration with any upstream scoring function.

use ferrolearn_core::error::FerroError;
use ferrolearn_core::traits::{Fit, Transform};
use ndarray::{Array1, Array2};
use num_traits::Float;

21// ---------------------------------------------------------------------------
22// Shared helper
23// ---------------------------------------------------------------------------
24
25/// Build a new `Array2<F>` containing only the columns listed in `indices`.
26fn select_columns<F: Float>(x: &Array2<F>, indices: &[usize]) -> Array2<F> {
27    let nrows = x.nrows();
28    let ncols = indices.len();
29    if ncols == 0 {
30        return Array2::zeros((nrows, 0));
31    }
32    let mut out = Array2::zeros((nrows, ncols));
33    for (new_j, &old_j) in indices.iter().enumerate() {
34        for i in 0..nrows {
35            out[[i, new_j]] = x[[i, old_j]];
36        }
37    }
38    out
39}
40
41/// Validate common inputs for all three selectors.
42fn validate_inputs(n_features: usize, alpha: f64) -> Result<(), FerroError> {
43    if n_features == 0 {
44        return Err(FerroError::InvalidParameter {
45            name: "p_values".into(),
46            reason: "p-value vector must not be empty".into(),
47        });
48    }
49    if alpha <= 0.0 || alpha > 1.0 {
50        return Err(FerroError::InvalidParameter {
51            name: "alpha".into(),
52            reason: format!("alpha must be in (0, 1], got {alpha}"),
53        });
54    }
55    Ok(())
56}
57
58// ===========================================================================
59// SelectFpr — False Positive Rate
60// ===========================================================================
61
62/// Select features with p-values below `alpha`.
63///
64/// A feature is selected if its p-value is strictly less than `alpha`.
65/// This controls the per-feature false positive rate but does not adjust
66/// for multiple comparisons.
67///
68/// # Examples
69///
70/// ```
71/// use ferrolearn_preprocess::stat_selectors::SelectFpr;
72/// use ferrolearn_core::traits::{Fit, Transform};
73/// use ndarray::array;
74///
75/// let sel = SelectFpr::<f64>::new(0.05);
76/// let p_values = array![0.01, 0.5, 0.03, 0.9];
77/// let fitted = sel.fit(&p_values, &()).unwrap();
78/// // Features 0 (p=0.01) and 2 (p=0.03) are below alpha=0.05
79/// assert_eq!(fitted.selected_indices(), &[0, 2]);
80/// ```
81#[must_use]
82#[derive(Debug, Clone)]
83pub struct SelectFpr<F> {
84    /// Significance threshold.
85    alpha: f64,
86    _marker: std::marker::PhantomData<F>,
87}
88
89impl<F: Float + Send + Sync + 'static> SelectFpr<F> {
90    /// Create a new `SelectFpr` with the given significance level.
91    pub fn new(alpha: f64) -> Self {
92        Self {
93            alpha,
94            _marker: std::marker::PhantomData,
95        }
96    }
97
98    /// Return the significance level.
99    #[must_use]
100    pub fn alpha(&self) -> f64 {
101        self.alpha
102    }
103}
104
105/// A fitted `SelectFpr` holding the selected indices.
106#[derive(Debug, Clone)]
107pub struct FittedSelectFpr<F> {
108    /// Number of features seen during fitting.
109    n_features_in: usize,
110    /// P-values supplied during fitting.
111    p_values: Array1<F>,
112    /// Indices of selected columns (sorted).
113    selected_indices: Vec<usize>,
114}
115
116impl<F: Float + Send + Sync + 'static> FittedSelectFpr<F> {
117    /// Return the p-values.
118    #[must_use]
119    pub fn p_values(&self) -> &Array1<F> {
120        &self.p_values
121    }
122
123    /// Return the indices of the selected columns.
124    #[must_use]
125    pub fn selected_indices(&self) -> &[usize] {
126        &self.selected_indices
127    }
128
129    /// Return the number of selected features.
130    #[must_use]
131    pub fn n_features_selected(&self) -> usize {
132        self.selected_indices.len()
133    }
134}
135
136impl<F: Float + Send + Sync + 'static> Fit<Array1<F>, ()> for SelectFpr<F> {
137    type Fitted = FittedSelectFpr<F>;
138    type Error = FerroError;
139
140    /// Fit by selecting features whose p-value is below `alpha`.
141    ///
142    /// # Errors
143    ///
144    /// - [`FerroError::InvalidParameter`] if p-values are empty or alpha is
145    ///   not in `(0, 1]`.
146    fn fit(&self, x: &Array1<F>, _y: &()) -> Result<FittedSelectFpr<F>, FerroError> {
147        let n = x.len();
148        validate_inputs(n, self.alpha)?;
149
150        let alpha_f = F::from(self.alpha).unwrap_or_else(F::zero);
151        let selected_indices: Vec<usize> = x
152            .iter()
153            .enumerate()
154            .filter(|&(_, &p)| p < alpha_f)
155            .map(|(j, _)| j)
156            .collect();
157
158        Ok(FittedSelectFpr {
159            n_features_in: n,
160            p_values: x.clone(),
161            selected_indices,
162        })
163    }
164}
165
166impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedSelectFpr<F> {
167    type Output = Array2<F>;
168    type Error = FerroError;
169
170    /// Return a matrix containing only the selected columns.
171    ///
172    /// # Errors
173    ///
174    /// Returns [`FerroError::ShapeMismatch`] if column count does not match.
175    fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
176        if x.ncols() != self.n_features_in {
177            return Err(FerroError::ShapeMismatch {
178                expected: vec![x.nrows(), self.n_features_in],
179                actual: vec![x.nrows(), x.ncols()],
180                context: "FittedSelectFpr::transform".into(),
181            });
182        }
183        Ok(select_columns(x, &self.selected_indices))
184    }
185}
186
187// ===========================================================================
188// SelectFdr — False Discovery Rate (Benjamini-Hochberg)
189// ===========================================================================
190
191/// Select features controlling the false discovery rate via the
192/// Benjamini-Hochberg procedure.
193///
194/// Features are sorted by p-value. Feature *i* (0-indexed, sorted ascending)
195/// is selected if `p_value[i] <= alpha * (i+1) / n_features`. All features
196/// with rank at or below the highest qualifying rank are selected.
197///
198/// # Examples
199///
200/// ```
201/// use ferrolearn_preprocess::stat_selectors::SelectFdr;
202/// use ferrolearn_core::traits::{Fit, Transform};
203/// use ndarray::array;
204///
205/// let sel = SelectFdr::<f64>::new(0.05);
206/// let p_values = array![0.01, 0.5, 0.03, 0.9];
207/// let fitted = sel.fit(&p_values, &()).unwrap();
208/// assert!(fitted.selected_indices().contains(&0));
209/// ```
210#[must_use]
211#[derive(Debug, Clone)]
212pub struct SelectFdr<F> {
213    /// Target false discovery rate.
214    alpha: f64,
215    _marker: std::marker::PhantomData<F>,
216}
217
218impl<F: Float + Send + Sync + 'static> SelectFdr<F> {
219    /// Create a new `SelectFdr` with the given FDR level.
220    pub fn new(alpha: f64) -> Self {
221        Self {
222            alpha,
223            _marker: std::marker::PhantomData,
224        }
225    }
226
227    /// Return the FDR level.
228    #[must_use]
229    pub fn alpha(&self) -> f64 {
230        self.alpha
231    }
232}
233
234/// A fitted `SelectFdr` holding the selected indices.
235#[derive(Debug, Clone)]
236pub struct FittedSelectFdr<F> {
237    /// Number of features seen during fitting.
238    n_features_in: usize,
239    /// P-values supplied during fitting.
240    p_values: Array1<F>,
241    /// Indices of selected columns (sorted in original order).
242    selected_indices: Vec<usize>,
243}
244
245impl<F: Float + Send + Sync + 'static> FittedSelectFdr<F> {
246    /// Return the p-values.
247    #[must_use]
248    pub fn p_values(&self) -> &Array1<F> {
249        &self.p_values
250    }
251
252    /// Return the indices of the selected columns.
253    #[must_use]
254    pub fn selected_indices(&self) -> &[usize] {
255        &self.selected_indices
256    }
257
258    /// Return the number of selected features.
259    #[must_use]
260    pub fn n_features_selected(&self) -> usize {
261        self.selected_indices.len()
262    }
263}
264
265impl<F: Float + Send + Sync + 'static> Fit<Array1<F>, ()> for SelectFdr<F> {
266    type Fitted = FittedSelectFdr<F>;
267    type Error = FerroError;
268
269    /// Fit using the Benjamini-Hochberg procedure.
270    ///
271    /// # Errors
272    ///
273    /// - [`FerroError::InvalidParameter`] if p-values are empty or alpha is
274    ///   not in `(0, 1]`.
275    fn fit(&self, x: &Array1<F>, _y: &()) -> Result<FittedSelectFdr<F>, FerroError> {
276        let n = x.len();
277        validate_inputs(n, self.alpha)?;
278
279        let alpha_f = F::from(self.alpha).unwrap_or_else(F::zero);
280        let n_f = F::from(n).unwrap_or_else(F::one);
281
282        // Sort features by p-value (ascending), keeping original indices
283        let mut ranked: Vec<(usize, F)> = x.iter().copied().enumerate().collect();
284        ranked.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
285
286        // Find the largest rank k where p_(k) <= alpha * (k+1) / n
287        let mut max_qualifying_rank: Option<usize> = None;
288        for (rank, &(_, p_val)) in ranked.iter().enumerate() {
289            let bh_threshold = alpha_f * F::from(rank + 1).unwrap_or_else(F::one) / n_f;
290            if p_val <= bh_threshold {
291                max_qualifying_rank = Some(rank);
292            }
293        }
294
295        // Select all features at or below the max qualifying rank
296        let mut selected_indices: Vec<usize> = match max_qualifying_rank {
297            Some(max_rank) => ranked[..=max_rank].iter().map(|&(idx, _)| idx).collect(),
298            None => Vec::new(),
299        };
300        selected_indices.sort_unstable();
301
302        Ok(FittedSelectFdr {
303            n_features_in: n,
304            p_values: x.clone(),
305            selected_indices,
306        })
307    }
308}
309
310impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedSelectFdr<F> {
311    type Output = Array2<F>;
312    type Error = FerroError;
313
314    /// Return a matrix containing only the selected columns.
315    ///
316    /// # Errors
317    ///
318    /// Returns [`FerroError::ShapeMismatch`] if column count does not match.
319    fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
320        if x.ncols() != self.n_features_in {
321            return Err(FerroError::ShapeMismatch {
322                expected: vec![x.nrows(), self.n_features_in],
323                actual: vec![x.nrows(), x.ncols()],
324                context: "FittedSelectFdr::transform".into(),
325            });
326        }
327        Ok(select_columns(x, &self.selected_indices))
328    }
329}
330
331// ===========================================================================
332// SelectFwe — Family-Wise Error (Bonferroni)
333// ===========================================================================
334
335/// Select features controlling the family-wise error rate via the
336/// Bonferroni correction.
337///
338/// A feature is selected if its p-value is strictly less than
339/// `alpha / n_features`.
340///
341/// # Examples
342///
343/// ```
344/// use ferrolearn_preprocess::stat_selectors::SelectFwe;
345/// use ferrolearn_core::traits::{Fit, Transform};
346/// use ndarray::array;
347///
348/// let sel = SelectFwe::<f64>::new(0.05);
349/// let p_values = array![0.001, 0.5, 0.03, 0.9];
350/// let fitted = sel.fit(&p_values, &()).unwrap();
351/// // Bonferroni threshold = 0.05/4 = 0.0125; only feature 0 qualifies
352/// assert_eq!(fitted.selected_indices(), &[0]);
353/// ```
354#[must_use]
355#[derive(Debug, Clone)]
356pub struct SelectFwe<F> {
357    /// Significance level before Bonferroni correction.
358    alpha: f64,
359    _marker: std::marker::PhantomData<F>,
360}
361
362impl<F: Float + Send + Sync + 'static> SelectFwe<F> {
363    /// Create a new `SelectFwe` with the given significance level.
364    pub fn new(alpha: f64) -> Self {
365        Self {
366            alpha,
367            _marker: std::marker::PhantomData,
368        }
369    }
370
371    /// Return the significance level.
372    #[must_use]
373    pub fn alpha(&self) -> f64 {
374        self.alpha
375    }
376}
377
378/// A fitted `SelectFwe` holding the selected indices.
379#[derive(Debug, Clone)]
380pub struct FittedSelectFwe<F> {
381    /// Number of features seen during fitting.
382    n_features_in: usize,
383    /// P-values supplied during fitting.
384    p_values: Array1<F>,
385    /// Indices of selected columns (sorted).
386    selected_indices: Vec<usize>,
387}
388
389impl<F: Float + Send + Sync + 'static> FittedSelectFwe<F> {
390    /// Return the p-values.
391    #[must_use]
392    pub fn p_values(&self) -> &Array1<F> {
393        &self.p_values
394    }
395
396    /// Return the indices of the selected columns.
397    #[must_use]
398    pub fn selected_indices(&self) -> &[usize] {
399        &self.selected_indices
400    }
401
402    /// Return the number of selected features.
403    #[must_use]
404    pub fn n_features_selected(&self) -> usize {
405        self.selected_indices.len()
406    }
407}
408
409impl<F: Float + Send + Sync + 'static> Fit<Array1<F>, ()> for SelectFwe<F> {
410    type Fitted = FittedSelectFwe<F>;
411    type Error = FerroError;
412
413    /// Fit using the Bonferroni correction: `p < alpha / n_features`.
414    ///
415    /// # Errors
416    ///
417    /// - [`FerroError::InvalidParameter`] if p-values are empty or alpha is
418    ///   not in `(0, 1]`.
419    fn fit(&self, x: &Array1<F>, _y: &()) -> Result<FittedSelectFwe<F>, FerroError> {
420        let n = x.len();
421        validate_inputs(n, self.alpha)?;
422
423        let adjusted_alpha = self.alpha / n as f64;
424        let adjusted_alpha_f = F::from(adjusted_alpha).unwrap_or_else(F::zero);
425
426        let selected_indices: Vec<usize> = x
427            .iter()
428            .enumerate()
429            .filter(|&(_, &p)| p < adjusted_alpha_f)
430            .map(|(j, _)| j)
431            .collect();
432
433        Ok(FittedSelectFwe {
434            n_features_in: n,
435            p_values: x.clone(),
436            selected_indices,
437        })
438    }
439}
440
441impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedSelectFwe<F> {
442    type Output = Array2<F>;
443    type Error = FerroError;
444
445    /// Return a matrix containing only the selected columns.
446    ///
447    /// # Errors
448    ///
449    /// Returns [`FerroError::ShapeMismatch`] if column count does not match.
450    fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
451        if x.ncols() != self.n_features_in {
452            return Err(FerroError::ShapeMismatch {
453                expected: vec![x.nrows(), self.n_features_in],
454                actual: vec![x.nrows(), x.ncols()],
455                context: "FittedSelectFwe::transform".into(),
456            });
457        }
458        Ok(select_columns(x, &self.selected_indices))
459    }
460}
461
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    // ========================================================================
    // SelectFpr tests
    // ========================================================================

    #[test]
    fn test_fpr_selects_below_alpha() {
        let sel = SelectFpr::<f64>::new(0.05);
        let p = array![0.01, 0.5, 0.03, 0.9];
        let fitted = sel.fit(&p, &()).unwrap();
        assert_eq!(fitted.selected_indices(), &[0, 2]);
    }

    #[test]
    fn test_fpr_none_below_alpha() {
        let sel = SelectFpr::<f64>::new(0.001);
        let p = array![0.01, 0.5, 0.03];
        let fitted = sel.fit(&p, &()).unwrap();
        assert_eq!(fitted.n_features_selected(), 0);
    }

    #[test]
    fn test_fpr_all_below_alpha() {
        let sel = SelectFpr::<f64>::new(0.99);
        let p = array![0.01, 0.5, 0.03];
        let fitted = sel.fit(&p, &()).unwrap();
        assert_eq!(fitted.n_features_selected(), 3);
    }

    #[test]
    fn test_fpr_threshold_is_strict() {
        // A p-value exactly equal to alpha is *not* selected.
        let sel = SelectFpr::<f64>::new(0.05);
        let p = array![0.05, 0.04];
        let fitted = sel.fit(&p, &()).unwrap();
        assert_eq!(fitted.selected_indices(), &[1]);
    }

    #[test]
    fn test_fpr_nan_not_selected() {
        let sel = SelectFpr::<f64>::new(0.05);
        let p = array![f64::NAN, 0.01];
        let fitted = sel.fit(&p, &()).unwrap();
        assert_eq!(fitted.selected_indices(), &[1]);
    }

    #[test]
    fn test_fpr_transform() {
        let sel = SelectFpr::<f64>::new(0.05);
        let p = array![0.01, 0.5, 0.03];
        let fitted = sel.fit(&p, &()).unwrap();
        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let out = fitted.transform(&x).unwrap();
        // Features 0 and 2 are kept, in their original order.
        assert_eq!(out, array![[1.0, 3.0], [4.0, 6.0]]);
    }

    #[test]
    fn test_fpr_transform_nothing_selected() {
        let sel = SelectFpr::<f64>::new(0.001);
        let p = array![0.5, 0.9];
        let fitted = sel.fit(&p, &()).unwrap();
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let out = fitted.transform(&x).unwrap();
        // Row count is preserved even when no column survives.
        assert_eq!(out.dim(), (2, 0));
    }

    #[test]
    fn test_fpr_empty_error() {
        let sel = SelectFpr::<f64>::new(0.05);
        let p: Array1<f64> = Array1::zeros(0);
        assert!(sel.fit(&p, &()).is_err());
    }

    #[test]
    fn test_fpr_invalid_alpha() {
        let sel = SelectFpr::<f64>::new(0.0);
        let p = array![0.01];
        assert!(sel.fit(&p, &()).is_err());

        let sel2 = SelectFpr::<f64>::new(1.5);
        assert!(sel2.fit(&p, &()).is_err());
    }

    #[test]
    fn test_fpr_shape_mismatch() {
        let sel = SelectFpr::<f64>::new(0.05);
        let p = array![0.01, 0.5];
        let fitted = sel.fit(&p, &()).unwrap();
        let x_bad = array![[1.0, 2.0, 3.0]];
        assert!(fitted.transform(&x_bad).is_err());
    }

    #[test]
    fn test_fpr_accessor() {
        let sel = SelectFpr::<f64>::new(0.05);
        assert_eq!(sel.alpha(), 0.05);
    }

    #[test]
    fn test_fpr_p_values_accessor() {
        let sel = SelectFpr::<f64>::new(0.05);
        let p = array![0.01, 0.5];
        let fitted = sel.fit(&p, &()).unwrap();
        assert_eq!(fitted.p_values(), &p);
    }

    // ========================================================================
    // SelectFdr tests (Benjamini-Hochberg)
    // ========================================================================

    #[test]
    fn test_fdr_basic() {
        let sel = SelectFdr::<f64>::new(0.05);
        // Sorted p-values: 0.01 (feat 0), 0.03 (feat 2), 0.5 (feat 1), 0.9 (feat 3)
        // BH thresholds:   0.0125,        0.025,         0.0375,       0.05
        // Only rank 0 satisfies p <= threshold (0.03 > 0.025, 0.5 > 0.0375,
        // 0.9 > 0.05), so the max qualifying rank is 0 and only feature 0 is kept.
        let p = array![0.01, 0.5, 0.03, 0.9];
        let fitted = sel.fit(&p, &()).unwrap();
        assert_eq!(fitted.selected_indices(), &[0]);
    }

    #[test]
    fn test_fdr_multiple_pass() {
        let sel = SelectFdr::<f64>::new(0.10);
        // Sorted: 0.005 (rank 0), 0.02 (rank 1), 0.04 (rank 2), 0.5 (rank 3)
        // BH: 0.1*1/4=0.025, 0.1*2/4=0.05, 0.1*3/4=0.075, 0.1*4/4=0.1
        // 0.005 <= 0.025 ✓
        // 0.02  <= 0.05  ✓
        // 0.04  <= 0.075 ✓
        // 0.5   >  0.1   ✗ → max rank = 2 → select ranks 0, 1, 2
        let p = array![0.02, 0.5, 0.005, 0.04];
        let fitted = sel.fit(&p, &()).unwrap();
        assert_eq!(fitted.selected_indices(), &[0, 2, 3]);
    }

    #[test]
    fn test_fdr_step_up_keeps_lower_ranks_that_failed() {
        let sel = SelectFdr::<f64>::new(0.05);
        // Sorted: 0.03 (feat 1), 0.04 (feat 2), 0.045 (feat 0)
        // BH: 0.05/3≈0.0167 ✗, 0.1/3≈0.0333 ✗, 0.15/3=0.05 ✓
        // Only the last rank qualifies on its own, but step-up keeps every
        // feature ranked at or below it.
        let p = array![0.045, 0.03, 0.04];
        let fitted = sel.fit(&p, &()).unwrap();
        assert_eq!(fitted.selected_indices(), &[0, 1, 2]);
    }

    #[test]
    fn test_fdr_threshold_is_inclusive() {
        // BH uses `<=`: with one feature the threshold is alpha itself.
        let sel = SelectFdr::<f64>::new(0.05);
        let p = array![0.05];
        let fitted = sel.fit(&p, &()).unwrap();
        assert_eq!(fitted.selected_indices(), &[0]);
    }

    #[test]
    fn test_fdr_none_selected() {
        let sel = SelectFdr::<f64>::new(0.001);
        let p = array![0.01, 0.5, 0.03];
        let fitted = sel.fit(&p, &()).unwrap();
        assert_eq!(fitted.n_features_selected(), 0);
    }

    #[test]
    fn test_fdr_transform() {
        let sel = SelectFdr::<f64>::new(0.10);
        let p = array![0.001, 0.5, 0.9];
        let fitted = sel.fit(&p, &()).unwrap();
        let x = array![[1.0, 2.0, 3.0]];
        let out = fitted.transform(&x).unwrap();
        // BH thresholds: 0.1/3≈0.033, 0.2/3≈0.067, 0.1. Only p=0.001 qualifies.
        assert_eq!(out, array![[1.0]]);
    }

    #[test]
    fn test_fdr_empty_error() {
        let sel = SelectFdr::<f64>::new(0.05);
        let p: Array1<f64> = Array1::zeros(0);
        assert!(sel.fit(&p, &()).is_err());
    }

    #[test]
    fn test_fdr_invalid_alpha() {
        let sel = SelectFdr::<f64>::new(0.0);
        let p = array![0.01];
        assert!(sel.fit(&p, &()).is_err());

        let sel2 = SelectFdr::<f64>::new(1.5);
        assert!(sel2.fit(&p, &()).is_err());
    }

    #[test]
    fn test_fdr_shape_mismatch() {
        let sel = SelectFdr::<f64>::new(0.05);
        let p = array![0.01, 0.5];
        let fitted = sel.fit(&p, &()).unwrap();
        let x_bad = array![[1.0, 2.0, 3.0]];
        assert!(fitted.transform(&x_bad).is_err());
    }

    #[test]
    fn test_fdr_accessor() {
        let sel = SelectFdr::<f64>::new(0.05);
        assert_eq!(sel.alpha(), 0.05);
    }

    // ========================================================================
    // SelectFwe tests (Bonferroni)
    // ========================================================================

    #[test]
    fn test_fwe_basic() {
        let sel = SelectFwe::<f64>::new(0.05);
        // Bonferroni threshold = 0.05/4 = 0.0125
        let p = array![0.001, 0.5, 0.03, 0.9];
        let fitted = sel.fit(&p, &()).unwrap();
        assert_eq!(fitted.selected_indices(), &[0]);
    }

    #[test]
    fn test_fwe_two_features() {
        let sel = SelectFwe::<f64>::new(0.10);
        // Bonferroni: 0.1/3 ≈ 0.0333
        let p = array![0.01, 0.02, 0.5];
        let fitted = sel.fit(&p, &()).unwrap();
        assert_eq!(fitted.selected_indices(), &[0, 1]);
    }

    #[test]
    fn test_fwe_threshold_is_strict() {
        let sel = SelectFwe::<f64>::new(0.10);
        // Bonferroni: 0.1/2 = 0.05; p == 0.05 is *not* selected.
        let p = array![0.05, 0.049];
        let fitted = sel.fit(&p, &()).unwrap();
        assert_eq!(fitted.selected_indices(), &[1]);
    }

    #[test]
    fn test_fwe_none_selected() {
        let sel = SelectFwe::<f64>::new(0.01);
        // Bonferroni: 0.01/3 ≈ 0.00333
        let p = array![0.005, 0.5, 0.03];
        let fitted = sel.fit(&p, &()).unwrap();
        assert_eq!(fitted.n_features_selected(), 0);
    }

    #[test]
    fn test_fwe_transform() {
        let sel = SelectFwe::<f64>::new(0.05);
        let p = array![0.001, 0.5, 0.9];
        let fitted = sel.fit(&p, &()).unwrap();
        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let out = fitted.transform(&x).unwrap();
        assert_eq!(out, array![[1.0], [4.0]]);
    }

    #[test]
    fn test_fwe_empty_error() {
        let sel = SelectFwe::<f64>::new(0.05);
        let p: Array1<f64> = Array1::zeros(0);
        assert!(sel.fit(&p, &()).is_err());
    }

    #[test]
    fn test_fwe_invalid_alpha() {
        let sel = SelectFwe::<f64>::new(0.0);
        let p = array![0.01];
        assert!(sel.fit(&p, &()).is_err());

        let sel2 = SelectFwe::<f64>::new(1.5);
        assert!(sel2.fit(&p, &()).is_err());
    }

    #[test]
    fn test_fwe_shape_mismatch() {
        let sel = SelectFwe::<f64>::new(0.05);
        let p = array![0.01, 0.5];
        let fitted = sel.fit(&p, &()).unwrap();
        let x_bad = array![[1.0, 2.0, 3.0]];
        assert!(fitted.transform(&x_bad).is_err());
    }

    #[test]
    fn test_fwe_accessor() {
        let sel = SelectFwe::<f64>::new(0.05);
        assert_eq!(sel.alpha(), 0.05);
    }

    #[test]
    fn test_fwe_single_feature() {
        let sel = SelectFwe::<f64>::new(0.05);
        // Bonferroni: 0.05/1 = 0.05; p=0.01 < 0.05 ✓
        let p = array![0.01];
        let fitted = sel.fit(&p, &()).unwrap();
        assert_eq!(fitted.selected_indices(), &[0]);
    }

    #[test]
    fn test_fwe_f32() {
        let sel = SelectFwe::<f32>::new(0.05);
        let p: Array1<f32> = array![0.001f32, 0.5];
        let fitted = sel.fit(&p, &()).unwrap();
        // Bonferroni: 0.05/2 = 0.025; p=0.001 < 0.025 ✓
        assert_eq!(fitted.selected_indices(), &[0]);
    }
}