use scirs2_core::ndarray::{Array1, Array2, ArrayBase, Data, Ix2};
use scirs2_core::numeric::{Float, NumCast};
use crate::error::{Result, TransformError};
/// Feature selector that drops columns whose variance does not exceed a threshold.
///
/// The selector is unfitted until `fit` succeeds; until then every fitted
/// field (`variances_`, `selected_features_`, `n_features_in_`) is `None`.
#[derive(Debug, Clone)]
pub struct VarianceThreshold {
/// Variance cut-off; a feature is kept only when its variance is strictly greater than this.
threshold: f64,
/// Per-feature population variance (sum of squared deviations divided by the sample count),
/// stored by `fit`.
variances_: Option<Array1<f64>>,
/// Column indices, in ascending order, of the features whose variance exceeded `threshold`.
selected_features_: Option<Vec<usize>>,
/// Number of columns in the matrix passed to `fit`; `transform` requires the same count.
n_features_in_: Option<usize>,
}
impl VarianceThreshold {
    /// Creates an unfitted selector that keeps features whose variance is
    /// strictly greater than `threshold`.
    ///
    /// # Errors
    ///
    /// Returns `TransformError::InvalidInput` when `threshold` is negative or NaN.
    pub fn new(threshold: f64) -> Result<Self> {
        // `NaN < 0.0` is false, so NaN must be rejected explicitly. A NaN
        // threshold would make every `variance > threshold` comparison false
        // and silently discard all features.
        if threshold.is_nan() {
            return Err(TransformError::InvalidInput(
                "Threshold must not be NaN".to_string(),
            ));
        }
        if threshold < 0.0 {
            return Err(TransformError::InvalidInput(
                "Threshold must be non-negative".to_string(),
            ));
        }
        Ok(VarianceThreshold {
            threshold,
            variances_: None,
            selected_features_: None,
            n_features_in_: None,
        })
    }

    /// Creates an unfitted selector with a threshold of `0.0`, i.e. one that
    /// removes only features with zero variance.
    pub fn with_defaults() -> Self {
        VarianceThreshold {
            threshold: 0.0,
            variances_: None,
            selected_features_: None,
            n_features_in_: None,
        }
    }

    /// Returns the configured variance threshold.
    pub fn threshold(&self) -> f64 {
        self.threshold
    }

    /// Converts one input element to `f64`, reporting a failed cast as an
    /// error instead of silently substituting a value.
    fn cast_to_f64<T: NumCast>(value: T) -> Result<f64> {
        <f64 as NumCast>::from(value).ok_or_else(|| {
            TransformError::InvalidInput(
                "Input value cannot be represented as f64".to_string(),
            )
        })
    }

    /// Computes the population variance of every column of `x` and records
    /// which columns have a variance strictly greater than the threshold.
    ///
    /// Rows of `x` are samples and columns are features. The selector state is
    /// only updated when the whole computation succeeds.
    ///
    /// # Errors
    ///
    /// Returns `TransformError::InvalidInput` when `x` has no rows or no
    /// columns, has fewer than two rows, or contains a value that cannot be
    /// converted to `f64`.
    pub fn fit<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<()>
    where
        S: Data,
        S::Elem: Float + NumCast,
    {
        let (n_samples, n_features) = x.dim();
        if n_samples == 0 || n_features == 0 {
            return Err(TransformError::InvalidInput("Empty input data".to_string()));
        }
        if n_samples < 2 {
            return Err(TransformError::InvalidInput(
                "At least 2 samples required to compute variance".to_string(),
            ));
        }

        let sample_count = n_samples as f64;
        let mut variances = Array1::<f64>::zeros(n_features);
        let mut selected_features = Vec::new();

        for j in 0..n_features {
            // Welford's online algorithm: numerically stable single-pass
            // accumulation of the mean and the sum of squared deviations.
            let mut mean = 0.0_f64;
            let mut m2 = 0.0_f64;
            for (i, &raw) in x.column(j).iter().enumerate() {
                let val = Self::cast_to_f64(raw)?;
                let delta = val - mean;
                mean += delta / (i as f64 + 1.0);
                let delta2 = val - mean;
                m2 += delta * delta2;
            }

            // Population variance (divide by n, not n - 1).
            let variance = m2 / sample_count;
            variances[j] = variance;
            if variance > self.threshold {
                selected_features.push(j);
            }
        }

        self.variances_ = Some(variances);
        self.selected_features_ = Some(selected_features);
        self.n_features_in_ = Some(n_features);
        Ok(())
    }

    /// Returns a copy of `x`, converted to `f64`, that contains only the
    /// columns selected by `fit`, in their original relative order.
    ///
    /// # Errors
    ///
    /// Returns `TransformError::NotFitted` when `fit` has not succeeded yet,
    /// and `TransformError::InvalidInput` when the column count of `x` differs
    /// from the one seen by `fit` or a value cannot be converted to `f64`.
    pub fn transform<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
    where
        S: Data,
        S::Elem: Float + NumCast,
    {
        let (n_samples, n_features) = x.dim();
        let selected_features = self.selected_features_.as_deref().ok_or_else(|| {
            TransformError::NotFitted("VarianceThreshold has not been fitted".to_string())
        })?;
        // `fit` sets both fields together, so this fallback is never used
        // once `selected_features_` is populated.
        let n_features_in = self.n_features_in_.unwrap_or(0);
        if n_features != n_features_in {
            return Err(TransformError::InvalidInput(format!(
                "x has {} features, but VarianceThreshold was fitted with {} features",
                n_features, n_features_in
            )));
        }

        let mut transformed = Array2::<f64>::zeros((n_samples, selected_features.len()));
        for (new_idx, &old_idx) in selected_features.iter().enumerate() {
            let source = x.column(old_idx);
            let mut target = transformed.column_mut(new_idx);
            for (dst, &src) in target.iter_mut().zip(source.iter()) {
                *dst = Self::cast_to_f64(src)?;
            }
        }
        Ok(transformed)
    }

    /// Fits the selector on `x` and immediately returns the reduced matrix.
    ///
    /// # Errors
    ///
    /// Propagates any error from `fit` or `transform`.
    pub fn fit_transform<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
    where
        S: Data,
        S::Elem: Float + NumCast,
    {
        self.fit(x)?;
        self.transform(x)
    }

    /// Returns the per-feature variances computed by `fit`, or `None` when unfitted.
    pub fn variances(&self) -> Option<&Array1<f64>> {
        self.variances_.as_ref()
    }

    /// Returns the indices of the selected features, or `None` when unfitted.
    pub fn get_support(&self) -> Option<&Vec<usize>> {
        self.selected_features_.as_ref()
    }

    /// Returns a boolean mask with one entry per input feature, `true` for
    /// every selected feature, or `None` when unfitted.
    pub fn get_support_mask(&self) -> Option<Array1<bool>> {
        let n_features_in = self.n_features_in_?;
        let selected = self.selected_features_.as_ref()?;
        let mut mask = Array1::from_elem(n_features_in, false);
        for &idx in selected {
            mask[idx] = true;
        }
        Some(mask)
    }

    /// Returns how many features were selected, or `None` when unfitted.
    pub fn n_features_selected(&self) -> Option<usize> {
        self.selected_features_.as_ref().map(Vec::len)
    }

    /// Always fails: dropped columns cannot be reconstructed.
    ///
    /// # Errors
    ///
    /// Always returns `TransformError::TransformationError`.
    pub fn inverse_transform(&self, _x: &Array2<f64>) -> Result<Array2<f64>> {
        Err(TransformError::TransformationError(
            "inverse_transform is not supported for feature selection".to_string(),
        ))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::Array;

    /// Builds a `rows x cols` matrix from row-major `values`.
    fn matrix(rows: usize, cols: usize, values: Vec<f64>) -> Array2<f64> {
        Array::from_shape_vec((rows, cols), values).expect("test data")
    }

    /// 3x4 fixture in which columns 0 and 2 are constant while 1 and 3 vary.
    fn mixed_variance_data() -> Array2<f64> {
        matrix(
            3,
            4,
            vec![1.0, 1.0, 5.0, 1.0, 1.0, 2.0, 5.0, 3.0, 1.0, 3.0, 5.0, 5.0],
        )
    }

    /// The default threshold removes exactly the constant columns.
    #[test]
    fn test_variance_threshold_basic() {
        let input = mixed_variance_data();
        let mut vt = VarianceThreshold::with_defaults();
        let reduced = vt.fit_transform(&input).expect("fit_transform");
        assert_eq!(reduced.shape(), &[3, 2]);

        let kept = vt.get_support().expect("get_support");
        assert_eq!(kept, &[1, 3]);

        // The first surviving column is original column 1.
        let first_kept_column = [1.0, 2.0, 3.0];
        for (row, expected) in first_kept_column.iter().enumerate() {
            assert_abs_diff_eq!(reduced[[row, 0]], *expected, epsilon = 1e-10);
        }
    }

    /// A custom threshold also removes low-variance (non-constant) columns.
    #[test]
    fn test_variance_threshold_custom() {
        let input = matrix(
            4,
            3,
            vec![1.0, 1.0, 1.0, 2.0, 1.1, 2.0, 3.0, 1.0, 3.0, 4.0, 1.1, 4.0],
        );
        let mut vt = VarianceThreshold::new(0.1).expect("new");
        let reduced = vt.fit_transform(&input).expect("fit_transform");
        assert_eq!(reduced.shape(), &[4, 2]);

        let kept = vt.get_support().expect("get_support");
        assert_eq!(kept, &[0, 2]);
    }

    /// The boolean mask mirrors the selected indices.
    #[test]
    fn test_variance_threshold_support_mask() {
        let input = mixed_variance_data();
        let mut vt = VarianceThreshold::with_defaults();
        vt.fit(&input).expect("fit");

        let mask = vt.get_support_mask().expect("mask");
        assert_eq!(mask.len(), 4);
        let expected_mask = [false, true, false, true];
        for (idx, want) in expected_mask.iter().enumerate() {
            assert_eq!(mask[idx], *want);
        }
        assert_eq!(vt.n_features_selected().expect("n_selected"), 2);
    }

    /// When every column is constant the output has zero columns.
    #[test]
    fn test_variance_threshold_all_constant() {
        let input = matrix(3, 2, vec![5.0, 10.0, 5.0, 10.0, 5.0, 10.0]);
        let mut vt = VarianceThreshold::with_defaults();
        let reduced = vt.fit_transform(&input).expect("fit_transform");
        assert_eq!(reduced.shape(), &[3, 0]);
        assert_eq!(vt.n_features_selected().expect("n_selected"), 0);
    }

    /// Invalid thresholds, too few samples and unfitted use are all rejected.
    #[test]
    fn test_variance_threshold_errors() {
        let negative = VarianceThreshold::new(-0.1);
        assert!(negative.is_err());

        let single_row = matrix(1, 2, vec![1.0, 2.0]);
        let mut vt = VarianceThreshold::with_defaults();
        assert!(vt.fit(&single_row).is_err());

        let input = matrix(3, 2, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let unfitted = VarianceThreshold::with_defaults();
        assert!(unfitted.transform(&input).is_err());
    }

    /// The stored variance is the population variance of the column.
    #[test]
    fn test_variance_calculation_welford() {
        let input = matrix(3, 1, vec![1.0, 2.0, 3.0]);
        let mut vt = VarianceThreshold::with_defaults();
        vt.fit(&input).expect("fit");

        let computed = vt.variances().expect("variances");
        let expected_variance = 2.0 / 3.0;
        assert_abs_diff_eq!(computed[0], expected_variance, epsilon = 1e-10);
    }

    /// Transforming data with a different column count fails.
    #[test]
    fn test_feature_mismatch() {
        let train = matrix(3, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]);
        let test = matrix(2, 2, vec![1.0, 2.0, 3.0, 4.0]);

        let mut vt = VarianceThreshold::with_defaults();
        vt.fit(&train).expect("fit");
        assert!(vt.transform(&test).is_err());
    }
}