use crate::dataset::Dataset;
use crate::error::{Result, ScryLearnError};
use crate::preprocess::Transformer;
/// Feature selector that drops columns whose (population) variance does not
/// strictly exceed a configurable threshold. With the default threshold of
/// `0.0`, only exactly-constant columns are removed.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct VarianceThreshold {
    /// Minimum variance a feature must strictly exceed to be kept.
    threshold: f64,
    /// Per-feature variances computed by `fit` (trailing underscore marks
    /// fitted state, sklearn-style). Empty before fitting.
    variances_: Vec<f64>,
    /// Keep-mask over features, parallel to `variances_`.
    mask_: Vec<bool>,
    /// Set once `fit` has succeeded; guards `transform`.
    fitted: bool,
}
impl VarianceThreshold {
pub fn new() -> Self {
Self {
threshold: 0.0,
variances_: Vec::new(),
mask_: Vec::new(),
fitted: false,
}
}
pub fn threshold(mut self, t: f64) -> Self {
self.threshold = t;
self
}
pub fn variances(&self) -> &[f64] {
&self.variances_
}
pub fn get_support(&self) -> &[bool] {
&self.mask_
}
}
impl Default for VarianceThreshold {
fn default() -> Self {
Self::new()
}
}
impl Transformer for VarianceThreshold {
    /// Computes the population variance (divide by `n`, not `n - 1`) of every
    /// feature column and marks those whose variance strictly exceeds the
    /// threshold.
    ///
    /// # Errors
    /// Returns `ScryLearnError::EmptyDataset` when `data` has no samples.
    fn fit(&mut self, data: &Dataset) -> Result<()> {
        let n_samples = data.n_samples();
        if n_samples == 0 {
            return Err(ScryLearnError::EmptyDataset);
        }
        let denom = n_samples as f64;
        self.variances_ = data
            .features
            .iter()
            .map(|column| {
                let mean = column.iter().sum::<f64>() / denom;
                column.iter().map(|&v| (v - mean).powi(2)).sum::<f64>() / denom
            })
            .collect();
        self.mask_ = self
            .variances_
            .iter()
            .map(|&var| var > self.threshold)
            .collect();
        self.fitted = true;
        Ok(())
    }

    /// Drops the columns whose variance did not exceed the threshold.
    ///
    /// # Errors
    /// Returns `ScryLearnError::NotFitted` when called before `fit`.
    fn transform(&self, data: &mut Dataset) -> Result<()> {
        if !self.fitted {
            return Err(ScryLearnError::NotFitted);
        }
        filter_features(data, &self.mask_);
        Ok(())
    }

    /// Always fails: dropped columns cannot be reconstructed.
    fn inverse_transform(&self, _data: &mut Dataset) -> Result<()> {
        Err(ScryLearnError::InvalidParameter(
            "VarianceThreshold is not invertible — dropped columns cannot be restored".into(),
        ))
    }
}
/// Univariate scoring function used by [`SelectKBest`] to rank features.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum ScoreFn {
    /// ANOVA F-value between each feature and the class target
    /// (computed by [`f_classif`]).
    FClassif,
}
/// Feature selector that keeps the `k` features with the highest scores
/// under the configured scoring function and drops the rest.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct SelectKBest {
    /// Number of features to keep (capped at the dataset width during fit).
    k: usize,
    /// Scoring function used to rank features.
    score_fn: ScoreFn,
    /// Per-feature scores computed by `fit` (empty before fitting).
    scores_: Vec<f64>,
    /// Keep-mask over features, parallel to `scores_`.
    mask_: Vec<bool>,
    /// Set once `fit` has succeeded; guards `transform`.
    fitted: bool,
}
impl SelectKBest {
pub fn new(score_fn: ScoreFn) -> Self {
Self {
k: 10,
score_fn,
scores_: Vec::new(),
mask_: Vec::new(),
fitted: false,
}
}
pub fn k(mut self, k: usize) -> Self {
self.k = k;
self
}
pub fn scores(&self) -> &[f64] {
&self.scores_
}
pub fn get_support(&self) -> &[bool] {
&self.mask_
}
}
impl Transformer for SelectKBest {
    /// Scores every feature and records a keep-mask over the `k` highest
    /// scorers (`k` is capped at the number of features).
    ///
    /// Selection is an exact top-k via a stable descending argsort of the
    /// scores, so ties are broken by feature index (earlier columns win).
    /// The previous cutoff-plus-epsilon scheme could mark a feature whose
    /// score was within `1e-12` *below* the k-th best while skipping the
    /// genuinely higher-scoring column; the argsort has no such tolerance.
    ///
    /// # Errors
    /// Returns `ScryLearnError::EmptyDataset` when `data` has no samples.
    fn fit(&mut self, data: &Dataset) -> Result<()> {
        let n = data.n_samples();
        if n == 0 {
            return Err(ScryLearnError::EmptyDataset);
        }
        self.scores_ = match self.score_fn {
            ScoreFn::FClassif => f_classif(data),
        };
        let k = self.k.min(data.n_features());
        // Stable sort of feature indices by descending score; NaN scores
        // compare as equal and thus sink by original position.
        let mut order: Vec<usize> = (0..self.scores_.len()).collect();
        order.sort_by(|&a, &b| {
            self.scores_[b]
                .partial_cmp(&self.scores_[a])
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        self.mask_ = vec![false; self.scores_.len()];
        for &idx in order.iter().take(k) {
            self.mask_[idx] = true;
        }
        self.fitted = true;
        Ok(())
    }

    /// Drops every column not selected during `fit`.
    ///
    /// # Errors
    /// Returns `ScryLearnError::NotFitted` when called before `fit`.
    fn transform(&self, data: &mut Dataset) -> Result<()> {
        if !self.fitted {
            return Err(ScryLearnError::NotFitted);
        }
        filter_features(data, &self.mask_);
        Ok(())
    }

    /// Always fails: dropped columns cannot be reconstructed.
    fn inverse_transform(&self, _data: &mut Dataset) -> Result<()> {
        Err(ScryLearnError::InvalidParameter(
            "SelectKBest is not invertible — dropped columns cannot be restored".into(),
        ))
    }
}
/// One-way ANOVA F-statistic of each feature against the class target.
///
/// Class labels are derived by truncating each target value to `i64`.
/// Returns all zeros when fewer than two distinct classes exist. A feature
/// whose within-group scatter is numerically zero but whose between-group
/// scatter is not receives `f64::MAX`.
pub fn f_classif(data: &Dataset) -> Vec<f64> {
    let n_samples = data.n_samples();
    let n_features = data.n_features();
    // Distinct class labels, sorted ascending.
    let mut labels: Vec<i64> = data.target.iter().map(|&t| t as i64).collect();
    labels.sort_unstable();
    labels.dedup();
    let n_classes = labels.len();
    if n_classes < 2 {
        return vec![0.0; n_features];
    }
    // Row indices belonging to each class, in label order.
    let groups: Vec<Vec<usize>> = labels
        .iter()
        .map(|&label| {
            (0..n_samples)
                .filter(|&row| data.target[row] as i64 == label)
                .collect()
        })
        .collect();
    // Degrees of freedom are invariant across features; hoist them.
    let df_between = (n_classes - 1) as f64;
    let df_within = (n_samples - n_classes) as f64;
    data.features
        .iter()
        .map(|column| {
            let grand_mean = column.iter().sum::<f64>() / n_samples as f64;
            let mut ss_between = 0.0;
            let mut ss_within = 0.0;
            for rows in &groups {
                let count = rows.len() as f64;
                if count == 0.0 {
                    continue;
                }
                let group_mean = rows.iter().map(|&row| column[row]).sum::<f64>() / count;
                ss_between += count * (group_mean - grand_mean).powi(2);
                for &row in rows {
                    ss_within += (column[row] - group_mean).powi(2);
                }
            }
            if df_within > 0.0 && ss_within > 1e-15 {
                (ss_between / df_between) / (ss_within / df_within)
            } else if ss_between > 1e-15 {
                f64::MAX
            } else {
                0.0
            }
        })
        .collect()
}
/// Retains only the columns of `data` whose mask entry is `true`, then
/// resyncs the dataset's matrix view.
///
/// Columns and names are moved out of the dataset with `std::mem::take`
/// rather than deep-cloned, so no feature data is copied. Mask entries
/// beyond the number of columns are ignored (the previous indexed version
/// would panic on an over-long mask).
fn filter_features(data: &mut Dataset, mask: &[bool]) {
    let old_features = std::mem::take(&mut data.features);
    let old_names = std::mem::take(&mut data.feature_names);
    let keep_count = mask.iter().filter(|&&m| m).count();
    let mut kept_features = Vec::with_capacity(keep_count);
    let mut kept_names = Vec::with_capacity(keep_count);
    for ((column, name), &keep) in old_features.into_iter().zip(old_names).zip(mask) {
        if keep {
            kept_features.push(column);
            kept_names.push(name);
        }
    }
    data.features = kept_features;
    data.feature_names = kept_names;
    data.sync_matrix();
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::pipeline::Pipeline;
    use crate::preprocess::StandardScaler;
    use crate::tree::DecisionTreeClassifier;

    /// Builds a deterministic (seeded RNG) 3-class synthetic dataset shaped
    /// like Iris: 30 samples per class, 4 features. The two "petal" features
    /// are constructed with much wider between-class separation than the
    /// "sepal" features, so F-scores should rank them higher.
    fn iris_like() -> Dataset {
        let n_per_class = 30;
        let n = n_per_class * 3;
        let mut f0 = Vec::with_capacity(n);
        let mut f1 = Vec::with_capacity(n);
        let mut f2 = Vec::with_capacity(n);
        let mut f3 = Vec::with_capacity(n);
        let mut target = Vec::with_capacity(n);
        let mut rng = crate::rng::FastRng::new(123);
        // Class 0: small petal values.
        for _ in 0..n_per_class {
            f0.push(5.0 + rng.f64() * 0.5); f1.push(3.4 + rng.f64() * 0.4); f2.push(1.0 + rng.f64() * 0.5); f3.push(0.1 + rng.f64() * 0.2); target.push(0.0);
        }
        // Class 1: mid-range petal values.
        for _ in 0..n_per_class {
            f0.push(5.5 + rng.f64() * 0.8); f1.push(2.5 + rng.f64() * 0.5); f2.push(4.0 + rng.f64() * 0.5); f3.push(1.2 + rng.f64() * 0.3); target.push(1.0);
        }
        // Class 2: large petal values.
        for _ in 0..n_per_class {
            f0.push(6.0 + rng.f64() * 1.0); f1.push(2.8 + rng.f64() * 0.5); f2.push(5.5 + rng.f64() * 0.5); f3.push(2.0 + rng.f64() * 0.3); target.push(2.0);
        }
        Dataset::new(
            vec![f0, f1, f2, f3],
            target,
            vec![
                "sepal_len".into(),
                "sepal_wid".into(),
                "petal_len".into(),
                "petal_wid".into(),
            ],
            "species",
        )
    }

    /// Default threshold (0.0) must drop the constant column "b" and keep
    /// both varying columns, preserving their names and order.
    #[test]
    fn test_variance_threshold_removes_constant() {
        let mut data = Dataset::new(
            vec![
                vec![1.0, 2.0, 3.0, 4.0], vec![5.0, 5.0, 5.0, 5.0], vec![0.0, 1.0, 0.0, 1.0], ],
            vec![0.0, 1.0, 0.0, 1.0],
            vec!["a".into(), "b".into(), "c".into()],
            "t",
        );
        let mut vt = VarianceThreshold::new();
        vt.fit_transform(&mut data).unwrap();
        assert_eq!(data.n_features(), 2);
        assert_eq!(data.feature_names, vec!["a", "c"]);
    }

    /// A custom threshold of 0.01 must drop the nearly-constant column
    /// (variance well below 0.01) while keeping the high-variance one.
    #[test]
    fn test_variance_threshold_custom() {
        let mut data = Dataset::new(
            vec![
                vec![1.0, 1.0, 1.0, 1.1], vec![0.0, 10.0, 0.0, 10.0], ],
            vec![0.0; 4],
            vec!["low_var".into(), "high_var".into()],
            "t",
        );
        let mut vt = VarianceThreshold::new().threshold(0.01);
        vt.fit_transform(&mut data).unwrap();
        assert_eq!(data.n_features(), 1);
        assert_eq!(data.feature_names, vec!["high_var"]);
    }

    /// On the iris-like data, the petal features must out-score their sepal
    /// counterparts, and SelectKBest(k=2) must keep exactly the two petal
    /// columns.
    #[test]
    fn test_select_k_best_petal_features_rank_highest() {
        let data = iris_like();
        let mut sel = SelectKBest::new(ScoreFn::FClassif).k(2);
        sel.fit(&data).unwrap();
        let scores = sel.scores();
        assert!(
            scores[2] > scores[0],
            "petal_len ({:.1}) should rank higher than sepal_len ({:.1})",
            scores[2],
            scores[0]
        );
        assert!(
            scores[3] > scores[1],
            "petal_wid ({:.1}) should rank higher than sepal_wid ({:.1})",
            scores[3],
            scores[1]
        );
        // Transform a copy so the fitted selector's view of the original
        // feature count stays valid for the support-mask checks below.
        let mut data_copy = data.clone();
        sel.transform(&mut data_copy).unwrap();
        assert_eq!(data_copy.n_features(), 2);
        let support = sel.get_support();
        assert!(!support[0], "sepal_len should be dropped");
        assert!(!support[1], "sepal_wid should be dropped");
        assert!(support[2], "petal_len should be kept");
        assert!(support[3], "petal_wid should be kept");
    }

    /// Calling transform before fit must fail (NotFitted guard).
    #[test]
    fn test_select_k_best_not_fitted() {
        let sel = SelectKBest::new(ScoreFn::FClassif);
        let mut data = Dataset::new(vec![vec![1.0]], vec![0.0], vec!["x".into()], "t");
        assert!(sel.transform(&mut data).is_err());
    }

    /// A feature that separates the two classes perfectly must receive a
    /// higher F-value than a feature that is pure noise.
    #[test]
    fn test_f_classif_basic() {
        let data = Dataset::new(
            vec![
                vec![1.0, 1.0, 1.0, 10.0, 10.0, 10.0], vec![3.0, 7.0, 2.0, 5.0, 8.0, 1.0], ],
            vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0],
            vec!["good".into(), "noise".into()],
            "class",
        );
        let scores = f_classif(&data);
        assert!(
            scores[0] > scores[1],
            "good feature ({:.1}) should have higher F-value than noise ({:.1})",
            scores[0],
            scores[1]
        );
    }

    /// End-to-end smoke test: VarianceThreshold (drops the constant column
    /// "b") feeding StandardScaler feeding a decision tree inside a
    /// Pipeline must fit and predict one label per sample.
    #[test]
    fn test_pipeline_vt_scaler_dt() {
        let features = vec![
            vec![1.0, 2.0, 3.0, 10.0, 11.0, 12.0], vec![5.0, 5.0, 5.0, 5.0, 5.0, 5.0], vec![0.0, 0.5, 1.0, 5.0, 5.5, 6.0], ];
        let target = vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0];
        let data = Dataset::new(
            features,
            target,
            vec!["a".into(), "b".into(), "c".into()],
            "class",
        );
        let mut pipeline = Pipeline::new()
            .add_transformer(VarianceThreshold::new())
            .add_transformer(StandardScaler::new())
            .set_model(DecisionTreeClassifier::new());
        pipeline.fit(&data).unwrap();
        let preds = pipeline.predict(&data).unwrap();
        assert_eq!(preds.len(), 6);
    }
}