use anofox_ml_core::{Float, Result, RustMlError, Transform};
use ndarray::{Array1, Array2};
/// Configuration for selecting features from a vector of importance scores.
///
/// `fit` rejects a configuration in which both fields are `None`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct SelectFromModel {
/// Minimum importance a feature needs to be kept; `fit` compares with `>=`.
/// `None` disables threshold filtering.
pub threshold: Option<f64>,
/// Upper bound on the number of features kept. When more candidates remain
/// than this, `fit` keeps the highest-importance ones. `fit` rejects `Some(0)`.
/// `None` disables the cap.
pub max_features: Option<usize>,
}
impl SelectFromModel {
    /// Creates a selector with neither a threshold nor a feature cap.
    ///
    /// At least one of [`with_threshold`](Self::with_threshold) or
    /// [`with_max_features`](Self::with_max_features) has to be applied before
    /// [`fit`](Self::fit) will succeed.
    pub fn new() -> Self {
        Self {
            threshold: None,
            max_features: None,
        }
    }

    /// Sets the minimum importance (inclusive) a feature needs to be selected.
    pub fn with_threshold(mut self, threshold: f64) -> Self {
        self.threshold = Some(threshold);
        self
    }

    /// Caps the number of selected features at `max_features`.
    pub fn with_max_features(mut self, max_features: usize) -> Self {
        self.max_features = Some(max_features);
        self
    }

    /// Chooses the feature indices to keep, given one importance per feature.
    ///
    /// Features are first filtered by `threshold` (kept when
    /// `importance >= threshold`). If more candidates remain than
    /// `max_features`, only the highest-importance ones are kept; ties are
    /// broken in favour of the lower feature index. NaN importances never
    /// satisfy a threshold and are ranked below every real value when the cap
    /// is applied. The resulting indices are stored in ascending order.
    ///
    /// # Errors
    ///
    /// * [`RustMlError::EmptyInput`] if `importances` is empty.
    /// * [`RustMlError::InvalidParameter`] if neither criterion is set, if
    ///   `max_features` is `0`, if `threshold` is NaN, or if no feature
    ///   satisfies the criteria.
    pub fn fit(&self, importances: &Array1<f64>) -> Result<FittedSelectFromModel> {
        let n_features = importances.len();
        if n_features == 0 {
            return Err(RustMlError::EmptyInput(
                "importances vector is empty".into(),
            ));
        }
        if self.threshold.is_none() && self.max_features.is_none() {
            return Err(RustMlError::InvalidParameter(
                "at least one of threshold or max_features must be set".into(),
            ));
        }
        if self.max_features == Some(0) {
            return Err(RustMlError::InvalidParameter(
                "max_features must be at least 1".into(),
            ));
        }
        // `x >= NaN` is false for every x, so a NaN threshold can never select
        // anything; report the real cause instead of "no features meet ...".
        if matches!(self.threshold, Some(t) if t.is_nan()) {
            return Err(RustMlError::InvalidParameter(
                "threshold must not be NaN".into(),
            ));
        }

        let scored = importances.iter().copied().enumerate();
        let mut candidates: Vec<(usize, f64)> = match self.threshold {
            Some(thresh) => scored.filter(|&(_, imp)| imp >= thresh).collect(),
            None => scored.collect(),
        };

        if let Some(max_f) = self.max_features {
            if candidates.len() > max_f {
                // The comparator is a total order with no equal elements
                // (indices are unique), so an unstable sort is deterministic.
                candidates.sort_unstable_by(Self::by_importance_desc);
                candidates.truncate(max_f);
            }
        }

        if candidates.is_empty() {
            return Err(RustMlError::InvalidParameter(
                "no features meet the selection criteria".into(),
            ));
        }

        let mut selected_indices: Vec<usize> = candidates.iter().map(|&(idx, _)| idx).collect();
        selected_indices.sort_unstable();

        Ok(FittedSelectFromModel {
            importances: importances.clone(),
            selected_indices,
            n_features_in: n_features,
        })
    }

    /// Orders `(index, importance)` pairs by descending importance, then by
    /// ascending index, with NaN importances placed after every real value.
    ///
    /// Mapping an incomparable pair to `Equal` would not be a total order
    /// (NaN would equal both 1.0 and 2.0 while 1.0 < 2.0), which makes the
    /// sort result unspecified and may make the standard sort panic.
    fn by_importance_desc(a: &(usize, f64), b: &(usize, f64)) -> std::cmp::Ordering {
        use std::cmp::Ordering;

        let by_value = match (a.1.is_nan(), b.1.is_nan()) {
            (true, true) => Ordering::Equal,
            (true, false) => Ordering::Greater,
            (false, true) => Ordering::Less,
            // Neither side is NaN, so `partial_cmp` always yields `Some`.
            (false, false) => b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal),
        };
        by_value.then(a.0.cmp(&b.0))
    }
}
impl Default for SelectFromModel {
fn default() -> Self {
Self::new()
}
}
/// Result of `SelectFromModel::fit`: the chosen feature indices together with
/// the importances they were derived from.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct FittedSelectFromModel {
/// Copy of the importance vector that was passed to `fit`.
importances: Array1<f64>,
/// Indices of the selected features; `fit` stores them in ascending order.
selected_indices: Vec<usize>,
/// Length of the importance vector at fit time; `transform` requires its
/// input to have exactly this many columns.
n_features_in: usize,
}
impl FittedSelectFromModel {
    /// Importance scores stored in this fitted selector.
    pub fn importances(&self) -> &Array1<f64> {
        let Self { importances, .. } = self;
        importances
    }

    /// Indices of the features that were selected.
    pub fn selected_indices(&self) -> &[usize] {
        self.selected_indices.as_slice()
    }

    /// Number of features that were selected.
    pub fn n_features_selected(&self) -> usize {
        self.selected_indices().len()
    }
}
impl<F: Float> Transform<F> for FittedSelectFromModel {
    /// Builds a new matrix containing only the selected columns of `x`, in
    /// the order given by `selected_indices`.
    ///
    /// # Errors
    ///
    /// Returns [`RustMlError::ShapeMismatch`] when the column count of `x`
    /// differs from the number of features seen at fit time.
    fn transform(&self, x: &Array2<F>) -> Result<Array2<F>> {
        let n_cols = x.ncols();
        if n_cols != self.n_features_in {
            return Err(RustMlError::ShapeMismatch(format!(
                "expected {} features, got {}",
                self.n_features_in, n_cols
            )));
        }

        let kept = &self.selected_indices;
        // Output cell (row, col) is read from source column `kept[col]`.
        let projected =
            Array2::from_shape_fn((x.nrows(), kept.len()), |(row, col)| x[[row, kept[col]]]);
        Ok(projected)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// A threshold alone keeps exactly the features that reach it.
    #[test]
    fn test_threshold_selects_important_features() {
        let fitted = SelectFromModel::new()
            .with_threshold(0.20)
            .fit(&array![0.05, 0.40, 0.10, 0.45])
            .unwrap();
        assert_eq!(fitted.selected_indices(), &[1, 3]);
    }

    /// A cap alone keeps the top-N features, reported in index order.
    #[test]
    fn test_max_features_selects_top_n() {
        let fitted = SelectFromModel::new()
            .with_max_features(2)
            .fit(&array![0.1, 0.5, 0.3, 0.8, 0.2])
            .unwrap();
        assert_eq!(fitted.selected_indices(), &[1, 3]);
    }

    /// The threshold filters first, then the cap trims the survivors.
    #[test]
    fn test_threshold_and_max_features_combined() {
        let fitted = SelectFromModel::new()
            .with_threshold(0.20)
            .with_max_features(2)
            .fit(&array![0.05, 0.40, 0.30, 0.45, 0.35])
            .unwrap();
        assert_eq!(fitted.selected_indices(), &[1, 3]);
    }

    /// `transform` copies exactly the selected columns.
    #[test]
    fn test_transform_selects_correct_columns() {
        let fitted = SelectFromModel::new()
            .with_max_features(2)
            .fit(&array![0.1, 0.9, 0.5])
            .unwrap();
        assert_eq!(fitted.selected_indices(), &[1, 2]);

        let x = array![[10.0, 20.0, 30.0], [40.0, 50.0, 60.0],];
        let projected = fitted.transform(&x).unwrap();
        assert_eq!(projected.dim(), (2, 2));
        assert_eq!(projected[[0, 0]], 20.0);
        assert_eq!(projected[[0, 1]], 30.0);
        assert_eq!(projected[[1, 0]], 50.0);
        assert_eq!(projected[[1, 1]], 60.0);
    }

    /// Fitting with neither criterion configured is an `InvalidParameter` error.
    #[test]
    fn test_error_no_criteria_set() {
        let scores = array![0.1, 0.2, 0.3];
        let outcome = SelectFromModel::new().fit(&scores);
        assert!(outcome.is_err());

        let err = outcome.unwrap_err();
        if let RustMlError::InvalidParameter(msg) = &err {
            assert!(
                msg.contains("threshold") || msg.contains("max_features"),
                "unexpected message: {}",
                msg
            );
        } else {
            panic!("expected InvalidParameter, got {:?}", err);
        }
    }

    /// A threshold above every importance leaves nothing to select.
    #[test]
    fn test_error_no_features_survive_threshold() {
        let scores = array![0.01, 0.02, 0.03];
        let outcome = SelectFromModel::new().with_threshold(0.50).fit(&scores);
        assert!(outcome.is_err());

        let err = outcome.unwrap_err();
        if let RustMlError::InvalidParameter(msg) = &err {
            assert!(msg.contains("no features"), "unexpected message: {}", msg);
        } else {
            panic!("expected InvalidParameter, got {:?}", err);
        }
    }

    /// An empty importance vector is rejected.
    #[test]
    fn test_error_empty_importances() {
        let no_scores = Array1::<f64>::zeros(0);
        let outcome = SelectFromModel::new().with_threshold(0.0).fit(&no_scores);
        assert!(outcome.is_err());
    }

    /// `transform` rejects input whose column count differs from fit time.
    #[test]
    fn test_shape_mismatch_on_transform() {
        let fitted = SelectFromModel::new()
            .with_threshold(0.0)
            .fit(&array![0.5, 0.5, 0.5])
            .unwrap();
        let wrong = array![[1.0, 2.0]];
        let outcome = Transform::<f64>::transform(&fitted, &wrong);
        assert!(outcome.is_err());
    }

    /// The fitted selector also transforms `f32` matrices.
    #[test]
    fn test_works_with_f32_transform() {
        let fitted = SelectFromModel::new()
            .with_max_features(1)
            .fit(&array![0.1, 0.9])
            .unwrap();
        assert_eq!(fitted.selected_indices(), &[1]);

        let x: Array2<f32> = array![[1.0_f32, 2.0], [3.0, 4.0]];
        let projected = Transform::<f32>::transform(&fitted, &x).unwrap();
        assert_eq!(projected.dim(), (2, 1));
        assert_eq!(projected[[0, 0]], 2.0_f32);
    }

    /// A cap of zero is rejected.
    #[test]
    fn test_max_features_zero_is_error() {
        let scores = array![0.1, 0.2];
        let outcome = SelectFromModel::new().with_max_features(0).fit(&scores);
        assert!(outcome.is_err());
    }

    /// `n_features_selected` agrees with the length of `selected_indices`.
    #[test]
    fn test_n_features_selected() {
        let fitted = SelectFromModel::new()
            .with_threshold(0.25)
            .fit(&array![0.1, 0.5, 0.3, 0.8])
            .unwrap();
        assert_eq!(fitted.n_features_selected(), 3);
        assert_eq!(fitted.selected_indices(), &[1, 2, 3]);
    }
}