use ferrolearn_core::error::FerroError;
use ferrolearn_core::traits::{Fit, Transform};
use ndarray::{Array1, Array2};
use num_traits::Float;
/// Builds a new matrix made of the columns of `x` listed in `indices`, in
/// the order they appear in `indices`.
///
/// Every row of `x` is kept. An empty `indices` slice produces an
/// `nrows x 0` matrix.
///
/// # Panics
///
/// Panics if an entry of `indices` is not a valid column of `x` (no element
/// is read, and therefore nothing panics, when `x` has zero rows).
fn select_columns<F: Float>(x: &Array2<F>, indices: &[usize]) -> Array2<F> {
    let shape = (x.nrows(), indices.len());
    // Element (row, col) of the output is element (row, indices[col]) of `x`.
    Array2::from_shape_fn(shape, |(row, col)| x[[row, indices[col]]])
}
/// Validates the inputs shared by every p-value based selector in this file.
///
/// # Errors
///
/// Returns [`FerroError::InvalidParameter`] when:
/// * `n_features` is zero (the p-value vector was empty), or
/// * `alpha` is not in `(0, 1]`. A `NaN` alpha is rejected as well.
fn validate_inputs(n_features: usize, alpha: f64) -> Result<(), FerroError> {
    if n_features == 0 {
        return Err(FerroError::InvalidParameter {
            name: "p_values".into(),
            reason: "p-value vector must not be empty".into(),
        });
    }
    // Every ordered comparison with NaN is false, so the range check alone
    // would let a NaN alpha through. Test for it explicitly.
    if alpha.is_nan() || alpha <= 0.0 || alpha > 1.0 {
        return Err(FerroError::InvalidParameter {
            name: "alpha".into(),
            reason: format!("alpha must be in (0, 1], got {alpha}"),
        });
    }
    Ok(())
}
/// Univariate feature selector that controls the false positive rate (FPR).
///
/// It is fitted on a vector holding one p-value per feature. It keeps every
/// feature whose p-value is strictly below `alpha`.
#[must_use]
#[derive(Debug, Clone)]
pub struct SelectFpr<F> {
    /// Significance level, used as an exclusive p-value threshold.
    alpha: f64,
    /// Records the float type of the p-values this selector will consume.
    _marker: std::marker::PhantomData<F>,
}

impl<F: Float + Send + Sync + 'static> SelectFpr<F> {
    /// Creates a selector with significance level `alpha`.
    ///
    /// `alpha` is not checked here. It is validated when the selector is fitted.
    pub fn new(alpha: f64) -> Self {
        Self {
            _marker: std::marker::PhantomData,
            alpha,
        }
    }

    /// Returns the configured significance level.
    #[must_use]
    pub fn alpha(&self) -> f64 {
        let Self { alpha, .. } = self;
        *alpha
    }
}
/// A fitted [`SelectFpr`]. It stores the p-values it was fitted on and the
/// feature columns that passed the threshold.
#[derive(Debug, Clone)]
pub struct FittedSelectFpr<F> {
    /// Number of p-values (features) seen at fit time. `transform` requires
    /// its input to have exactly this many columns.
    n_features_in: usize,
    /// Copy of the p-values passed to `fit`, in original feature order.
    p_values: Array1<F>,
    /// Indices of the kept features, in ascending order.
    selected_indices: Vec<usize>,
}

impl<F: Float + Send + Sync + 'static> FittedSelectFpr<F> {
    /// Borrows the p-values the selector was fitted on.
    #[must_use]
    pub fn p_values(&self) -> &Array1<F> {
        let Self { p_values, .. } = self;
        p_values
    }

    /// Borrows the indices of the kept features.
    #[must_use]
    pub fn selected_indices(&self) -> &[usize] {
        self.selected_indices.as_slice()
    }

    /// Returns how many features were kept.
    #[must_use]
    pub fn n_features_selected(&self) -> usize {
        self.selected_indices().len()
    }
}
impl<F: Float + Send + Sync + 'static> Fit<Array1<F>, ()> for SelectFpr<F> {
    type Fitted = FittedSelectFpr<F>;
    type Error = FerroError;

    /// Keeps every feature whose p-value in `x` is strictly less than `alpha`.
    ///
    /// `x` holds one p-value per feature. The target argument is unused.
    ///
    /// # Errors
    ///
    /// Propagates the error from `validate_inputs`, which is returned when `x`
    /// is empty or `alpha` is out of range.
    fn fit(&self, x: &Array1<F>, _y: &()) -> Result<FittedSelectFpr<F>, FerroError> {
        let n_features_in = x.len();
        validate_inputs(n_features_in, self.alpha)?;

        let threshold = F::from(self.alpha).unwrap_or_else(F::zero);
        let mut selected_indices = Vec::new();
        for (j, &p) in x.iter().enumerate() {
            if p < threshold {
                selected_indices.push(j);
            }
        }

        Ok(FittedSelectFpr {
            n_features_in,
            p_values: x.clone(),
            selected_indices,
        })
    }
}
impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedSelectFpr<F> {
    type Output = Array2<F>;
    type Error = FerroError;

    /// Returns a copy of `x` that keeps only the selected feature columns.
    ///
    /// # Errors
    ///
    /// Returns [`FerroError::ShapeMismatch`] when `x` does not have exactly as
    /// many columns as there were p-values at fit time.
    fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
        let (n_rows, n_cols) = x.dim();
        if n_cols == self.n_features_in {
            Ok(select_columns(x, &self.selected_indices))
        } else {
            Err(FerroError::ShapeMismatch {
                expected: vec![n_rows, self.n_features_in],
                actual: vec![n_rows, n_cols],
                context: "FittedSelectFpr::transform".into(),
            })
        }
    }
}
/// Univariate feature selector that controls the false discovery rate (FDR)
/// with the Benjamini–Hochberg step-up procedure.
///
/// It is fitted on a vector holding one p-value per feature.
#[must_use]
#[derive(Debug, Clone)]
pub struct SelectFdr<F> {
    /// Target false discovery rate.
    alpha: f64,
    /// Records the float type of the p-values this selector will consume.
    _marker: std::marker::PhantomData<F>,
}

impl<F: Float + Send + Sync + 'static> SelectFdr<F> {
    /// Creates a selector with target false discovery rate `alpha`.
    ///
    /// `alpha` is not checked here. It is validated when the selector is fitted.
    pub fn new(alpha: f64) -> Self {
        Self {
            _marker: std::marker::PhantomData,
            alpha,
        }
    }

    /// Returns the configured target false discovery rate.
    #[must_use]
    pub fn alpha(&self) -> f64 {
        let Self { alpha, .. } = self;
        *alpha
    }
}
/// A fitted [`SelectFdr`]. It stores the p-values it was fitted on and the
/// feature columns retained by the Benjamini–Hochberg procedure.
#[derive(Debug, Clone)]
pub struct FittedSelectFdr<F> {
    /// Number of p-values (features) seen at fit time. `transform` requires
    /// its input to have exactly this many columns.
    n_features_in: usize,
    /// Copy of the p-values passed to `fit`, in original feature order.
    p_values: Array1<F>,
    /// Indices of the kept features, sorted ascending.
    selected_indices: Vec<usize>,
}

impl<F: Float + Send + Sync + 'static> FittedSelectFdr<F> {
    /// Borrows the p-values the selector was fitted on.
    #[must_use]
    pub fn p_values(&self) -> &Array1<F> {
        let Self { p_values, .. } = self;
        p_values
    }

    /// Borrows the indices of the kept features.
    #[must_use]
    pub fn selected_indices(&self) -> &[usize] {
        self.selected_indices.as_slice()
    }

    /// Returns how many features were kept.
    #[must_use]
    pub fn n_features_selected(&self) -> usize {
        self.selected_indices().len()
    }
}
impl<F: Float + Send + Sync + 'static> Fit<Array1<F>, ()> for SelectFdr<F> {
    type Fitted = FittedSelectFdr<F>;
    type Error = FerroError;

    /// Applies the Benjamini–Hochberg step-up procedure to the p-values in `x`.
    ///
    /// The p-values are ranked in ascending order. Let `k` be the largest
    /// 1-based rank with `p_(k) <= alpha * k / n`. Every feature with rank
    /// `<= k` is selected, even one that misses its own threshold. The
    /// selected indices are stored in ascending order.
    ///
    /// `NaN` p-values are ranked after all other values and never satisfy a
    /// threshold, so they are never selected. They still count towards `n`.
    ///
    /// # Errors
    ///
    /// Propagates the error from `validate_inputs`, which is returned when `x`
    /// is empty or `alpha` is out of range.
    fn fit(&self, x: &Array1<F>, _y: &()) -> Result<FittedSelectFdr<F>, FerroError> {
        let n = x.len();
        validate_inputs(n, self.alpha)?;
        let alpha_f = F::from(self.alpha).unwrap_or(F::zero());
        let n_f = F::from(n).unwrap_or(F::one());

        let mut ranked: Vec<(usize, F)> = x.iter().copied().enumerate().collect();
        // The comparator must be a total order. Treating NaN as "equal" to
        // everything breaks transitivity. That can scramble the ranking, so a
        // NaN feature gets selected, and the std sort is allowed to panic.
        // `partial_cmp` only fails when a NaN is involved; NaN then sorts last.
        ranked.sort_by(|a, b| {
            a.1.partial_cmp(&b.1)
                .unwrap_or_else(|| a.1.is_nan().cmp(&b.1.is_nan()))
        });

        // BH threshold for the 0-based `rank`: alpha * (rank + 1) / n.
        let threshold = |rank: usize| alpha_f * F::from(rank + 1).unwrap_or(F::one()) / n_f;

        // Step-up: search from the largest rank down for the first p-value
        // under its threshold. Everything ranked at or before it is selected.
        let n_selected = ranked
            .iter()
            .enumerate()
            .rev()
            .find(|&(rank, &(_, p))| p <= threshold(rank))
            .map_or(0, |(rank, _)| rank + 1);

        let mut selected_indices: Vec<usize> = ranked[..n_selected]
            .iter()
            .map(|&(idx, _)| idx)
            .collect();
        selected_indices.sort_unstable();

        Ok(FittedSelectFdr {
            n_features_in: n,
            p_values: x.clone(),
            selected_indices,
        })
    }
}
impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedSelectFdr<F> {
    type Output = Array2<F>;
    type Error = FerroError;

    /// Returns a copy of `x` that keeps only the selected feature columns.
    ///
    /// # Errors
    ///
    /// Returns [`FerroError::ShapeMismatch`] when `x` does not have exactly as
    /// many columns as there were p-values at fit time.
    fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
        let (n_rows, n_cols) = x.dim();
        if n_cols == self.n_features_in {
            Ok(select_columns(x, &self.selected_indices))
        } else {
            Err(FerroError::ShapeMismatch {
                expected: vec![n_rows, self.n_features_in],
                actual: vec![n_rows, n_cols],
                context: "FittedSelectFdr::transform".into(),
            })
        }
    }
}
/// Univariate feature selector that controls the family-wise error rate
/// (FWE) with a Bonferroni correction.
///
/// It is fitted on a vector holding one p-value per feature.
#[must_use]
#[derive(Debug, Clone)]
pub struct SelectFwe<F> {
    /// Family-wise significance level, before correction.
    alpha: f64,
    /// Records the float type of the p-values this selector will consume.
    _marker: std::marker::PhantomData<F>,
}

impl<F: Float + Send + Sync + 'static> SelectFwe<F> {
    /// Creates a selector with family-wise significance level `alpha`.
    ///
    /// `alpha` is not checked here. It is validated when the selector is fitted.
    pub fn new(alpha: f64) -> Self {
        Self {
            _marker: std::marker::PhantomData,
            alpha,
        }
    }

    /// Returns the configured family-wise significance level.
    #[must_use]
    pub fn alpha(&self) -> f64 {
        let Self { alpha, .. } = self;
        *alpha
    }
}
/// A fitted [`SelectFwe`]. It stores the p-values it was fitted on and the
/// feature columns that passed the Bonferroni-corrected threshold.
#[derive(Debug, Clone)]
pub struct FittedSelectFwe<F> {
    /// Number of p-values (features) seen at fit time. `transform` requires
    /// its input to have exactly this many columns.
    n_features_in: usize,
    /// Copy of the p-values passed to `fit`, in original feature order.
    p_values: Array1<F>,
    /// Indices of the kept features, in ascending order.
    selected_indices: Vec<usize>,
}

impl<F: Float + Send + Sync + 'static> FittedSelectFwe<F> {
    /// Borrows the p-values the selector was fitted on.
    #[must_use]
    pub fn p_values(&self) -> &Array1<F> {
        let Self { p_values, .. } = self;
        p_values
    }

    /// Borrows the indices of the kept features.
    #[must_use]
    pub fn selected_indices(&self) -> &[usize] {
        self.selected_indices.as_slice()
    }

    /// Returns how many features were kept.
    #[must_use]
    pub fn n_features_selected(&self) -> usize {
        self.selected_indices().len()
    }
}
impl<F: Float + Send + Sync + 'static> Fit<Array1<F>, ()> for SelectFwe<F> {
    type Fitted = FittedSelectFwe<F>;
    type Error = FerroError;

    /// Keeps every feature whose p-value in `x` is strictly below `alpha / n`,
    /// where `n` is the number of p-values (Bonferroni correction).
    ///
    /// # Errors
    ///
    /// Propagates the error from `validate_inputs`, which is returned when `x`
    /// is empty or `alpha` is out of range.
    fn fit(&self, x: &Array1<F>, _y: &()) -> Result<FittedSelectFwe<F>, FerroError> {
        let n_features_in = x.len();
        validate_inputs(n_features_in, self.alpha)?;

        // The correction is computed in f64 and only then converted to `F`.
        let per_test_alpha =
            F::from(self.alpha / n_features_in as f64).unwrap_or_else(F::zero);

        let selected_indices: Vec<usize> = x
            .indexed_iter()
            .filter(|&(_, &p)| p < per_test_alpha)
            .map(|(j, _)| j)
            .collect();

        Ok(FittedSelectFwe {
            n_features_in,
            p_values: x.clone(),
            selected_indices,
        })
    }
}
impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedSelectFwe<F> {
    type Output = Array2<F>;
    type Error = FerroError;

    /// Returns a copy of `x` that keeps only the selected feature columns.
    ///
    /// # Errors
    ///
    /// Returns [`FerroError::ShapeMismatch`] when `x` does not have exactly as
    /// many columns as there were p-values at fit time.
    fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
        let (n_rows, n_cols) = x.dim();
        if n_cols == self.n_features_in {
            Ok(select_columns(x, &self.selected_indices))
        } else {
            Err(FerroError::ShapeMismatch {
                expected: vec![n_rows, self.n_features_in],
                actual: vec![n_rows, n_cols],
                context: "FittedSelectFwe::transform".into(),
            })
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    // ----- SelectFpr -----------------------------------------------------

    #[test]
    fn test_fpr_selects_below_alpha() {
        let sel = SelectFpr::<f64>::new(0.05);
        let p = array![0.01, 0.5, 0.03, 0.9];
        let fitted = sel.fit(&p, &()).unwrap();
        assert_eq!(fitted.selected_indices(), &[0, 2]);
    }

    #[test]
    fn test_fpr_threshold_is_strict() {
        // A p-value exactly equal to alpha is not selected.
        let sel = SelectFpr::<f64>::new(0.05);
        let p = array![0.05, 0.049];
        let fitted = sel.fit(&p, &()).unwrap();
        assert_eq!(fitted.selected_indices(), &[1]);
    }

    #[test]
    fn test_fpr_none_below_alpha() {
        let sel = SelectFpr::<f64>::new(0.001);
        let p = array![0.01, 0.5, 0.03];
        let fitted = sel.fit(&p, &()).unwrap();
        assert_eq!(fitted.n_features_selected(), 0);
    }

    #[test]
    fn test_fpr_all_below_alpha() {
        let sel = SelectFpr::<f64>::new(0.99);
        let p = array![0.01, 0.5, 0.03];
        let fitted = sel.fit(&p, &()).unwrap();
        assert_eq!(fitted.n_features_selected(), 3);
    }

    #[test]
    fn test_fpr_transform() {
        let sel = SelectFpr::<f64>::new(0.05);
        let p = array![0.01, 0.5, 0.03];
        let fitted = sel.fit(&p, &()).unwrap();
        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let out = fitted.transform(&x).unwrap();
        assert_eq!(out.ncols(), 2);
        assert_eq!(out[[0, 0]], 1.0);
        assert_eq!(out[[0, 1]], 3.0);
        assert_eq!(out, array![[1.0, 3.0], [4.0, 6.0]]);
    }

    #[test]
    fn test_fpr_empty_error() {
        let sel = SelectFpr::<f64>::new(0.05);
        let p: Array1<f64> = Array1::zeros(0);
        assert!(sel.fit(&p, &()).is_err());
    }

    #[test]
    fn test_fpr_invalid_alpha() {
        let sel = SelectFpr::<f64>::new(0.0);
        let p = array![0.01];
        assert!(sel.fit(&p, &()).is_err());
        let sel2 = SelectFpr::<f64>::new(1.5);
        assert!(sel2.fit(&p, &()).is_err());
    }

    #[test]
    fn test_fpr_shape_mismatch() {
        let sel = SelectFpr::<f64>::new(0.05);
        let p = array![0.01, 0.5];
        let fitted = sel.fit(&p, &()).unwrap();
        let x_bad = array![[1.0, 2.0, 3.0]];
        assert!(fitted.transform(&x_bad).is_err());
    }

    #[test]
    fn test_fpr_accessor() {
        let sel = SelectFpr::<f64>::new(0.05);
        assert_eq!(sel.alpha(), 0.05);
    }

    #[test]
    fn test_fpr_p_values_accessor() {
        let sel = SelectFpr::<f64>::new(0.05);
        let p = array![0.01, 0.5];
        let fitted = sel.fit(&p, &()).unwrap();
        assert_eq!(fitted.p_values().len(), 2);
        assert_eq!(fitted.p_values(), &p);
    }

    // ----- SelectFdr -----------------------------------------------------

    #[test]
    fn test_fdr_basic() {
        // Sorted: 0.01 <= 0.0125 passes; 0.03, 0.5 and 0.9 miss their thresholds.
        let sel = SelectFdr::<f64>::new(0.05);
        let p = array![0.01, 0.5, 0.03, 0.9];
        let fitted = sel.fit(&p, &()).unwrap();
        assert_eq!(fitted.selected_indices(), &[0]);
    }

    #[test]
    fn test_fdr_multiple_pass() {
        let sel = SelectFdr::<f64>::new(0.10);
        let p = array![0.02, 0.5, 0.005, 0.04];
        let fitted = sel.fit(&p, &()).unwrap();
        assert_eq!(fitted.n_features_selected(), 3);
        assert_eq!(fitted.selected_indices(), &[0, 2, 3]);
    }

    #[test]
    fn test_fdr_step_up_includes_earlier_ranks() {
        // Rank 1 (0.03 > 0.025) misses its own threshold, but rank 2
        // (0.04 <= 0.05) passes, so the step-up rule keeps both.
        let sel = SelectFdr::<f64>::new(0.05);
        let p = array![0.03, 0.04];
        let fitted = sel.fit(&p, &()).unwrap();
        assert_eq!(fitted.selected_indices(), &[0, 1]);
    }

    #[test]
    fn test_fdr_none_selected() {
        let sel = SelectFdr::<f64>::new(0.001);
        let p = array![0.01, 0.5, 0.03];
        let fitted = sel.fit(&p, &()).unwrap();
        assert_eq!(fitted.n_features_selected(), 0);
    }

    #[test]
    fn test_fdr_transform() {
        let sel = SelectFdr::<f64>::new(0.10);
        let p = array![0.001, 0.5, 0.9];
        let fitted = sel.fit(&p, &()).unwrap();
        let x = array![[1.0, 2.0, 3.0]];
        let out = fitted.transform(&x).unwrap();
        assert_eq!(out.ncols(), 1);
        assert_eq!(out, array![[1.0]]);
    }

    #[test]
    fn test_fdr_empty_error() {
        let sel = SelectFdr::<f64>::new(0.05);
        let p: Array1<f64> = Array1::zeros(0);
        assert!(sel.fit(&p, &()).is_err());
    }

    #[test]
    fn test_fdr_invalid_alpha() {
        let sel = SelectFdr::<f64>::new(0.0);
        let p = array![0.01];
        assert!(sel.fit(&p, &()).is_err());
    }

    #[test]
    fn test_fdr_shape_mismatch() {
        let sel = SelectFdr::<f64>::new(0.05);
        let p = array![0.01, 0.5];
        let fitted = sel.fit(&p, &()).unwrap();
        let x_bad = array![[1.0, 2.0, 3.0]];
        assert!(fitted.transform(&x_bad).is_err());
    }

    #[test]
    fn test_fdr_accessor() {
        let sel = SelectFdr::<f64>::new(0.05);
        assert_eq!(sel.alpha(), 0.05);
    }

    // ----- SelectFwe -----------------------------------------------------

    #[test]
    fn test_fwe_basic() {
        let sel = SelectFwe::<f64>::new(0.05);
        let p = array![0.001, 0.5, 0.03, 0.9];
        let fitted = sel.fit(&p, &()).unwrap();
        assert_eq!(fitted.selected_indices(), &[0]);
    }

    #[test]
    fn test_fwe_two_features() {
        let sel = SelectFwe::<f64>::new(0.10);
        let p = array![0.01, 0.02, 0.5];
        let fitted = sel.fit(&p, &()).unwrap();
        assert_eq!(fitted.selected_indices(), &[0, 1]);
    }

    #[test]
    fn test_fwe_threshold_is_strict() {
        // alpha / n = 0.05 / 2 = 0.025 exactly; a p-value equal to it is rejected.
        let sel = SelectFwe::<f64>::new(0.05);
        let p = array![0.025, 0.02];
        let fitted = sel.fit(&p, &()).unwrap();
        assert_eq!(fitted.selected_indices(), &[1]);
    }

    #[test]
    fn test_fwe_none_selected() {
        let sel = SelectFwe::<f64>::new(0.01);
        let p = array![0.005, 0.5, 0.03];
        let fitted = sel.fit(&p, &()).unwrap();
        assert_eq!(fitted.n_features_selected(), 0);
    }

    #[test]
    fn test_fwe_transform() {
        let sel = SelectFwe::<f64>::new(0.05);
        let p = array![0.001, 0.5, 0.9];
        let fitted = sel.fit(&p, &()).unwrap();
        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let out = fitted.transform(&x).unwrap();
        assert_eq!(out.ncols(), 1);
        assert_eq!(out[[0, 0]], 1.0);
        assert_eq!(out, array![[1.0], [4.0]]);
    }

    #[test]
    fn test_fwe_empty_error() {
        let sel = SelectFwe::<f64>::new(0.05);
        let p: Array1<f64> = Array1::zeros(0);
        assert!(sel.fit(&p, &()).is_err());
    }

    #[test]
    fn test_fwe_invalid_alpha() {
        let sel = SelectFwe::<f64>::new(0.0);
        let p = array![0.01];
        assert!(sel.fit(&p, &()).is_err());
    }

    #[test]
    fn test_fwe_shape_mismatch() {
        let sel = SelectFwe::<f64>::new(0.05);
        let p = array![0.01, 0.5];
        let fitted = sel.fit(&p, &()).unwrap();
        let x_bad = array![[1.0, 2.0, 3.0]];
        assert!(fitted.transform(&x_bad).is_err());
    }

    #[test]
    fn test_fwe_accessor() {
        let sel = SelectFwe::<f64>::new(0.05);
        assert_eq!(sel.alpha(), 0.05);
    }

    #[test]
    fn test_fwe_single_feature() {
        let sel = SelectFwe::<f64>::new(0.05);
        let p = array![0.01];
        let fitted = sel.fit(&p, &()).unwrap();
        assert_eq!(fitted.selected_indices(), &[0]);
    }

    #[test]
    fn test_fwe_f32() {
        let sel = SelectFwe::<f32>::new(0.05);
        let p: Array1<f32> = array![0.001f32, 0.5];
        let fitted = sel.fit(&p, &()).unwrap();
        assert_eq!(fitted.selected_indices(), &[0]);
    }
}