use ferrolearn_core::error::FerroError;
use ferrolearn_core::pipeline::{FittedPipelineTransformer, PipelineTransformer};
use ferrolearn_core::traits::{Fit, Transform};
use ndarray::{Array1, Array2};
use num_traits::Float;
/// How the importance threshold is derived when a `SelectFromModelExt` is fitted.
///
/// After fitting, a feature is kept when its importance is greater than or
/// equal to the resulting threshold.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub enum ThresholdStrategy {
/// Use the arithmetic mean of the importance vector (the default strategy).
#[default]
Mean,
/// Use the median of the importance vector; for an even number of
/// importances this is the average of the two middle values.
Median,
/// Use the given value as the threshold, converted to the selector's float type.
Value(f64),
/// Use the importance found at rank `(100 - p) / 100 * (n - 1)` of the
/// ascending-sorted importances, linearly interpolated between neighbours,
/// so a larger `p` yields a lower threshold. Fitting returns an error when
/// `p <= 0` or `p > 100`.
Percentile(f64),
}
/// Unfitted feature selector driven by a vector of per-feature importances.
///
/// Fitting takes an importance vector, derives a threshold according to
/// `threshold`, and records which feature indices meet it. The type parameter
/// `F` is the float type of the importances and of the matrices transformed
/// later.
#[must_use]
#[derive(Debug, Clone)]
pub struct SelectFromModelExt<F> {
// Strategy used to turn the importance vector into a threshold.
threshold: ThresholdStrategy,
// Optional cap on the number of selected features; when more features pass
// the threshold, only the ones with the highest importance are kept.
max_features: Option<usize>,
// Ties the selector to the float type `F` without storing a value of it.
_marker: std::marker::PhantomData<F>,
}
impl<F: Float + Send + Sync + 'static> SelectFromModelExt<F> {
    /// Builds a selector from a threshold strategy and an optional cap on the
    /// number of features that may be selected.
    ///
    /// No validation happens here; an out-of-range percentile is only
    /// reported when the selector is fitted.
    pub fn new(threshold: ThresholdStrategy, max_features: Option<usize>) -> Self {
        let _marker = std::marker::PhantomData;
        Self {
            threshold,
            max_features,
            _marker,
        }
    }

    /// Returns the configured threshold strategy.
    #[must_use]
    pub fn threshold_strategy(&self) -> ThresholdStrategy {
        let Self { threshold, .. } = self;
        *threshold
    }

    /// Returns the configured cap on the number of selected features, if any.
    #[must_use]
    pub fn max_features(&self) -> Option<usize> {
        let Self { max_features, .. } = self;
        *max_features
    }
}
impl<F: Float + Send + Sync + 'static> Default for SelectFromModelExt<F> {
    /// Returns a selector that thresholds on the mean importance and places
    /// no cap on the number of selected features.
    fn default() -> Self {
        Self {
            threshold: ThresholdStrategy::Mean,
            max_features: None,
            _marker: std::marker::PhantomData,
        }
    }
}
/// Result of fitting a `SelectFromModelExt` on an importance vector.
///
/// Holds the computed threshold, a copy of the importances, and the indices
/// of the features that were selected.
#[derive(Debug, Clone)]
pub struct FittedSelectFromModelExt<F> {
// Length of the importance vector seen at fit time; `transform` requires
// its input to have exactly this many columns.
n_features_in: usize,
// Threshold derived from the importances by the configured strategy.
threshold_value: F,
// Copy of the importance vector passed to `fit`.
importances: Array1<F>,
// Indices of the selected features, in ascending order.
selected_indices: Vec<usize>,
}
impl<F: Float + Send + Sync + 'static> FittedSelectFromModelExt<F> {
    /// Returns the threshold that was computed during fitting.
    #[must_use]
    pub fn threshold_value(&self) -> F {
        let Self { threshold_value, .. } = self;
        *threshold_value
    }

    /// Returns the importance vector the selector was fitted on.
    #[must_use]
    pub fn importances(&self) -> &Array1<F> {
        let Self { importances, .. } = self;
        importances
    }

    /// Returns the indices of the selected features.
    #[must_use]
    pub fn selected_indices(&self) -> &[usize] {
        self.selected_indices.as_slice()
    }

    /// Returns how many features were selected.
    #[must_use]
    pub fn n_features_selected(&self) -> usize {
        self.selected_indices().len()
    }
}
/// Returns the median of `values`.
///
/// The data is copied and sorted ascending. For an odd number of elements the
/// middle element is returned; for an even number, the average of the two
/// middle elements.
///
/// Returns NaN when `values` is empty (the median is undefined) or when any
/// element is NaN, so that a NaN importance propagates into the threshold
/// the same way it does for an arithmetic mean.
fn compute_median<F: Float>(values: &[F]) -> F {
    // NaN is incomparable, so a comparator built on `partial_cmp` would not be
    // a total order; sorting with it yields an unspecified order and may even
    // panic. Bail out before sorting. The empty case would otherwise underflow
    // `n / 2 - 1` below.
    if values.is_empty() || values.iter().any(|v| v.is_nan()) {
        return F::nan();
    }

    let mut sorted: Vec<F> = values.to_vec();
    // No NaN is present here, so `partial_cmp` always returns `Some` and the
    // `Equal` fallback is never taken.
    sorted.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));

    let n = sorted.len();
    let mid = n / 2;
    if n % 2 == 0 {
        let two = F::one() + F::one();
        (sorted[mid - 1] + sorted[mid]) / two
    } else {
        sorted[mid]
    }
}
/// Returns the threshold that keeps roughly the top `pct` percent of `values`.
///
/// The data is copied and sorted ascending, and the threshold is read at the
/// fractional rank `(100 - pct) / 100 * (n - 1)`, linearly interpolating
/// between the two neighbouring elements when the rank is not an integer.
/// `pct = 100` therefore yields the smallest value. Ranks that fall outside
/// the data (a `pct` outside `[0, 100]`) resolve to the nearest end of the
/// sorted data; callers are expected to validate `pct` beforehand.
///
/// Returns NaN when `values` is empty or when any element is NaN, so that a
/// NaN importance propagates into the threshold instead of being sorted into
/// an arbitrary position.
fn compute_percentile_threshold<F: Float>(values: &[F], pct: f64) -> F {
    // An empty slice has no element to index, and NaN elements would make the
    // `partial_cmp`-based comparator a non-total order (unspecified sort
    // result, possible panic). Handle both before sorting.
    if values.is_empty() || values.iter().any(|v| v.is_nan()) {
        return F::nan();
    }

    let mut sorted: Vec<F> = values.to_vec();
    // No NaN is present here, so the `Equal` fallback is never taken.
    sorted.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));

    let last = sorted.len() - 1;
    let rank = ((100.0 - pct) / 100.0) * last as f64;
    // Float-to-usize casts saturate (negative -> 0), and `min` caps the upper
    // side, so both indices are always valid.
    let lower = (rank.floor() as usize).min(last);
    let upper = (rank.ceil() as usize).min(last);

    if lower == upper {
        sorted[lower]
    } else {
        // Weight of the upper neighbour in the linear interpolation.
        let frac = F::from(rank - rank.floor()).unwrap_or(F::zero());
        sorted[lower] * (F::one() - frac) + sorted[upper] * frac
    }
}
/// Builds a new matrix out of the columns of `x` listed in `indices`.
///
/// Output column `j` is a copy of input column `indices[j]`, so the result
/// has `x.nrows()` rows and `indices.len()` columns (possibly zero).
///
/// # Panics
///
/// Panics if `x` has at least one row and an entry of `indices` is not a
/// valid column index of `x`.
fn select_columns<F: Float>(x: &Array2<F>, indices: &[usize]) -> Array2<F> {
    let shape = (x.nrows(), indices.len());
    // The closure is invoked once per output element; with zero rows or zero
    // columns it is never called and an empty matrix of `shape` is returned.
    Array2::from_shape_fn(shape, |(row, col)| x[[row, indices[col]]])
}
impl<F: Float + Send + Sync + 'static> Fit<Array1<F>, ()> for SelectFromModelExt<F> {
    type Fitted = FittedSelectFromModelExt<F>;
    type Error = FerroError;

    /// Fits the selector on a vector of per-feature importances.
    ///
    /// `x` holds one importance per feature; `_y` is unused. A threshold is
    /// derived from `x` according to the configured [`ThresholdStrategy`],
    /// and every feature whose importance is `>=` that threshold is selected
    /// (a NaN importance or a NaN threshold never satisfies the comparison).
    /// If `max_features` is set and more features pass, only the
    /// `max_features` most important ones are kept; among equal importances
    /// the lower index wins. The selected indices are stored in ascending
    /// order.
    ///
    /// # Errors
    ///
    /// Returns [`FerroError::InvalidParameter`] when `x` is empty, or when
    /// the strategy is `Percentile(p)` and `p` is NaN or outside `(0, 100]`.
    fn fit(
        &self,
        x: &Array1<F>,
        _y: &(),
    ) -> Result<FittedSelectFromModelExt<F>, FerroError> {
        let n = x.len();
        if n == 0 {
            return Err(FerroError::InvalidParameter {
                name: "importances".into(),
                reason: "importance vector must not be empty".into(),
            });
        }
        let values: Vec<F> = x.iter().copied().collect();

        let threshold_value = match self.threshold {
            ThresholdStrategy::Mean => {
                values.iter().copied().fold(F::zero(), |acc, v| acc + v)
                    / F::from(n).unwrap_or(F::one())
            }
            ThresholdStrategy::Median => compute_median(&values),
            ThresholdStrategy::Value(v) => F::from(v).unwrap_or(F::zero()),
            ThresholdStrategy::Percentile(pct) => {
                // NaN compares false against both bounds, so it has to be
                // rejected explicitly; otherwise it would slip through and be
                // silently treated as "keep everything".
                if pct.is_nan() || pct <= 0.0 || pct > 100.0 {
                    return Err(FerroError::InvalidParameter {
                        name: "percentile".into(),
                        reason: format!(
                            "percentile must be in (0, 100], got {}",
                            pct
                        ),
                    });
                }
                compute_percentile_threshold(&values, pct)
            }
        };

        // Enumeration order makes this list ascending by feature index.
        let mut selected_indices: Vec<usize> = values
            .iter()
            .enumerate()
            .filter(|&(_, &imp)| imp >= threshold_value)
            .map(|(j, _)| j)
            .collect();

        if let Some(max_f) = self.max_features {
            if selected_indices.len() > max_f {
                // Every surviving importance passed `>= threshold`, so none is
                // NaN and `partial_cmp` is a total order here. The sort is
                // stable and descending, so ties keep the lower index first.
                selected_indices.sort_by(|&a, &b| {
                    values[b]
                        .partial_cmp(&values[a])
                        .unwrap_or(std::cmp::Ordering::Equal)
                });
                selected_indices.truncate(max_f);
                // Restore ascending index order for column extraction.
                selected_indices.sort_unstable();
            }
        }

        Ok(FittedSelectFromModelExt {
            n_features_in: n,
            threshold_value,
            importances: x.clone(),
            selected_indices,
        })
    }
}
impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedSelectFromModelExt<F> {
    type Output = Array2<F>;
    type Error = FerroError;

    /// Returns a copy of `x` restricted to the selected feature columns.
    ///
    /// # Errors
    ///
    /// Returns [`FerroError::ShapeMismatch`] when the number of columns of
    /// `x` differs from the length of the importance vector seen at fit time.
    fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
        let (n_rows, n_cols) = x.dim();
        if n_cols == self.n_features_in {
            return Ok(select_columns(x, &self.selected_indices));
        }
        Err(FerroError::ShapeMismatch {
            expected: vec![n_rows, self.n_features_in],
            actual: vec![n_rows, n_cols],
            context: "FittedSelectFromModelExt::transform".into(),
        })
    }
}
impl<F: Float + Send + Sync + 'static> PipelineTransformer<F> for FittedSelectFromModelExt<F> {
    /// Pipeline fitting step for an already-fitted selector.
    ///
    /// Both `_x` and `_y` are ignored: the selection was fixed when the
    /// importances were fitted, so this simply hands back a boxed clone of
    /// `self`.
    fn fit_pipeline(
        &self,
        _x: &Array2<F>,
        _y: &Array1<F>,
    ) -> Result<Box<dyn FittedPipelineTransformer<F>>, FerroError> {
        let fitted: Box<dyn FittedPipelineTransformer<F>> = Box::new(self.clone());
        Ok(fitted)
    }
}
impl<F: Float + Send + Sync + 'static> FittedPipelineTransformer<F>
for FittedSelectFromModelExt<F>
{
fn transform_pipeline(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
self.transform(x)
}
}
// Unit tests covering each threshold strategy, the `max_features` cap,
// error paths, accessors, pipeline integration and `f32` support.
#[cfg(test)]
mod tests {
use super::*;
use approx::assert_abs_diff_eq;
use ndarray::array;
// Mean of [0.1, 0.5, 0.4] is 1/3, so features 1 and 2 pass and are the
// columns kept by `transform`.
#[test]
fn test_mean_threshold() {
let sel = SelectFromModelExt::<f64>::new(ThresholdStrategy::Mean, None);
let importances = array![0.1, 0.5, 0.4];
let fitted = sel.fit(&importances, &()).unwrap();
assert_eq!(fitted.selected_indices(), &[1, 2]);
let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
let out = fitted.transform(&x).unwrap();
assert_eq!(out.ncols(), 2);
assert_abs_diff_eq!(out[[0, 0]], 2.0, epsilon = 1e-15);
assert_abs_diff_eq!(out[[0, 1]], 3.0, epsilon = 1e-15);
}
// Odd length: the median is the middle value 0.3, which itself passes the
// `>=` comparison.
#[test]
fn test_median_threshold() {
let sel = SelectFromModelExt::<f64>::new(ThresholdStrategy::Median, None);
let importances = array![0.1, 0.5, 0.3];
let fitted = sel.fit(&importances, &()).unwrap();
assert_eq!(fitted.selected_indices(), &[1, 2]);
}
// Even length: the median is the average of the two middle values
// (0.2 and 0.5), i.e. 0.35.
#[test]
fn test_median_threshold_even() {
let sel = SelectFromModelExt::<f64>::new(ThresholdStrategy::Median, None);
let importances = array![0.1, 0.5, 0.2, 0.6];
let fitted = sel.fit(&importances, &()).unwrap();
assert_eq!(fitted.selected_indices(), &[1, 3]);
}
// An explicit threshold of 0.45 keeps only the 0.5 importance.
#[test]
fn test_explicit_value_threshold() {
let sel = SelectFromModelExt::<f64>::new(ThresholdStrategy::Value(0.45), None);
let importances = array![0.1, 0.5, 0.4];
let fitted = sel.fit(&importances, &()).unwrap();
assert_eq!(fitted.selected_indices(), &[1]);
}
// Percentile(50) on four importances keeps the two largest (indices 0 and 2).
#[test]
fn test_percentile_threshold_top_50() {
let sel = SelectFromModelExt::<f64>::new(ThresholdStrategy::Percentile(50.0), None);
let importances = array![0.5, 0.1, 0.7, 0.3];
let fitted = sel.fit(&importances, &()).unwrap();
assert!(fitted.selected_indices().contains(&0));
assert!(fitted.selected_indices().contains(&2));
assert_eq!(fitted.n_features_selected(), 2);
}
// Percentile(100) puts the threshold at the smallest importance, so every
// feature is kept.
#[test]
fn test_percentile_100_keeps_all() {
let sel = SelectFromModelExt::<f64>::new(ThresholdStrategy::Percentile(100.0), None);
let importances = array![0.1, 0.5, 0.3];
let fitted = sel.fit(&importances, &()).unwrap();
assert_eq!(fitted.n_features_selected(), 3);
}
// Percentiles outside (0, 100] are rejected at fit time.
#[test]
fn test_percentile_invalid() {
let sel = SelectFromModelExt::<f64>::new(ThresholdStrategy::Percentile(0.0), None);
let importances = array![0.1, 0.5, 0.3];
assert!(sel.fit(&importances, &()).is_err());
let sel2 = SelectFromModelExt::<f64>::new(ThresholdStrategy::Percentile(101.0), None);
assert!(sel2.fit(&importances, &()).is_err());
}
// All four features pass a threshold of 0.0; the cap of 2 keeps the two most
// important ones, reported in ascending index order.
#[test]
fn test_max_features_cap() {
let sel = SelectFromModelExt::<f64>::new(ThresholdStrategy::Value(0.0), Some(2));
let importances = array![0.3, 0.5, 0.1, 0.7];
let fitted = sel.fit(&importances, &()).unwrap();
assert_eq!(fitted.n_features_selected(), 2);
assert_eq!(fitted.selected_indices(), &[1, 3]);
}
// A cap larger than the number of passing features has no effect.
#[test]
fn test_max_features_not_needed() {
let sel = SelectFromModelExt::<f64>::new(ThresholdStrategy::Value(0.4), Some(5));
let importances = array![0.1, 0.5, 0.4];
let fitted = sel.fit(&importances, &()).unwrap();
assert_eq!(fitted.n_features_selected(), 2);
}
// Fitting on an empty importance vector is an error.
#[test]
fn test_empty_importances_error() {
let sel = SelectFromModelExt::<f64>::new(ThresholdStrategy::Mean, None);
let importances: Array1<f64> = Array1::zeros(0);
assert!(sel.fit(&importances, &()).is_err());
}
// Transforming a matrix with three columns after fitting on two importances
// must fail.
#[test]
fn test_shape_mismatch_on_transform() {
let sel = SelectFromModelExt::<f64>::new(ThresholdStrategy::Mean, None);
let importances = array![0.5, 0.5];
let fitted = sel.fit(&importances, &()).unwrap();
let x_bad = array![[1.0, 2.0, 3.0]]; assert!(fitted.transform(&x_bad).is_err());
}
// The accessor reports the explicit threshold that was configured.
#[test]
fn test_threshold_value_accessor() {
let sel = SelectFromModelExt::<f64>::new(ThresholdStrategy::Value(0.42), None);
let importances = array![0.1, 0.5];
let fitted = sel.fit(&importances, &()).unwrap();
assert_abs_diff_eq!(fitted.threshold_value(), 0.42, epsilon = 1e-15);
}
// The default selector uses the mean strategy with no feature cap.
#[test]
fn test_default() {
let sel = SelectFromModelExt::<f64>::default();
assert_eq!(sel.threshold_strategy(), ThresholdStrategy::Mean);
assert_eq!(sel.max_features(), None);
}
// The fitted selector can be boxed through `fit_pipeline` and then used via
// `transform_pipeline`; only the 0.9 feature survives the mean threshold.
#[test]
fn test_pipeline_integration() {
let sel = SelectFromModelExt::<f64>::new(ThresholdStrategy::Mean, None);
let importances = array![0.1, 0.9];
let fitted = sel.fit(&importances, &()).unwrap();
let x = array![[1.0, 2.0], [3.0, 4.0]];
let y = array![0.0, 1.0];
let fitted_box = fitted.fit_pipeline(&x, &y).unwrap();
let out = fitted_box.transform_pipeline(&x).unwrap();
assert_eq!(out.ncols(), 1);
}
// Same mean-threshold scenario with `f32` importances.
#[test]
fn test_f32() {
let sel = SelectFromModelExt::<f32>::new(ThresholdStrategy::Mean, None);
let importances: Array1<f32> = array![0.1f32, 0.5, 0.4];
let fitted = sel.fit(&importances, &()).unwrap();
assert_eq!(fitted.n_features_selected(), 2);
}
// A threshold above every importance selects nothing; `transform` then
// returns a matrix with the original row count and zero columns.
#[test]
fn test_none_selected_high_threshold() {
let sel = SelectFromModelExt::<f64>::new(ThresholdStrategy::Value(10.0), None);
let importances = array![0.1, 0.5, 0.4];
let fitted = sel.fit(&importances, &()).unwrap();
assert_eq!(fitted.n_features_selected(), 0);
let x = array![[1.0, 2.0, 3.0]];
let out = fitted.transform(&x).unwrap();
assert_eq!(out.ncols(), 0);
assert_eq!(out.nrows(), 1);
}
}