Skip to main content

anofox_ml_preprocessing/
variance_threshold.rs

1use anofox_ml_core::{FitUnsupervised, Float, Result, RustMlError, Transform};
2use ndarray::{Array1, Array2, Axis};
3
4/// Parameters for `VarianceThreshold` feature selector (unfitted state).
5///
6/// Removes all features whose variance does not meet a minimum threshold.
7/// By default (`threshold = 0.0`), it removes features that have zero variance,
8/// i.e., features that are constant across all samples.
9///
10/// This is a simple baseline approach to feature selection: a feature with
11/// higher variance is more likely to be informative (though this is not
12/// guaranteed).
13///
14/// # Example
15///
16/// ```
17/// use anofox_ml_preprocessing::VarianceThreshold;
18/// use anofox_ml_core::{FitUnsupervised, Transform};
19/// use ndarray::array;
20///
21/// let x = array![
22///     [0.0, 2.0, 0.0],
23///     [0.0, 4.0, 0.0],
24///     [0.0, 6.0, 0.0],
25/// ];
26///
27/// // Remove zero-variance features (columns 0 and 2 are constant)
28/// let selector = VarianceThreshold::new(0.0);
29/// let fitted = FitUnsupervised::<f64>::fit(&selector, &x).unwrap();
30/// let x_selected = fitted.transform(&x).unwrap();
31///
32/// assert_eq!(x_selected.ncols(), 1); // only the varying column survives
33/// ```
34#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
35pub struct VarianceThreshold {
36    /// Minimum variance required for a feature to be kept.
37    /// Features with variance <= threshold are removed.
38    pub threshold: f64,
39}
40
41impl VarianceThreshold {
42    /// Create a new `VarianceThreshold` with the given threshold.
43    ///
44    /// A threshold of `0.0` removes only constant (zero-variance) features.
45    pub fn new(threshold: f64) -> Self {
46        Self { threshold }
47    }
48
49    /// Set the variance threshold.
50    pub fn with_threshold(mut self, threshold: f64) -> Self {
51        self.threshold = threshold;
52        self
53    }
54}
55
56impl Default for VarianceThreshold {
57    fn default() -> Self {
58        Self::new(0.0)
59    }
60}
61
/// Fitted `VarianceThreshold` — holds learned per-feature variances and the
/// indices of features that exceeded the threshold.
///
/// Its `Transform` implementation keeps only the columns listed in
/// `selected_indices` and rejects inputs whose column count differs from
/// `n_features_in`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
// Overrides serde's inferred deserialize bound so `F` only has to be
// `DeserializeOwned`. Field order below is part of the serialized form.
#[serde(bound(deserialize = "F: serde::de::DeserializeOwned"))]
pub struct FittedVarianceThreshold<F: Float> {
    /// Per-feature variance computed during fitting, one entry per input
    /// column (dropped columns included).
    variances: Array1<F>,
    /// Indices of features whose variance exceeded the threshold. When built
    /// by fitting, these are in ascending column order.
    selected_indices: Vec<usize>,
    /// Total number of input features (before selection).
    n_features_in: usize,
}
74
75impl<F: Float> FittedVarianceThreshold<F> {
76    /// Per-feature variances computed during fitting.
77    pub fn variances(&self) -> &Array1<F> {
78        &self.variances
79    }
80
81    /// Indices of selected features (those with variance > threshold).
82    pub fn selected_indices(&self) -> &[usize] {
83        &self.selected_indices
84    }
85
86    /// Number of features that survived selection.
87    pub fn n_features_selected(&self) -> usize {
88        self.selected_indices.len()
89    }
90}
91
92impl<F: Float> FitUnsupervised<F> for VarianceThreshold {
93    type Fitted = FittedVarianceThreshold<F>;
94
95    fn fit(&self, x: &Array2<F>) -> Result<Self::Fitted> {
96        let (n_samples, n_features) = x.dim();
97
98        if n_samples == 0 || n_features == 0 {
99            return Err(RustMlError::EmptyInput("input array is empty".into()));
100        }
101
102        if self.threshold < 0.0 {
103            return Err(RustMlError::InvalidParameter(
104                "threshold must be non-negative".into(),
105            ));
106        }
107
108        let n = F::from_usize(n_samples).unwrap();
109
110        // Compute per-feature mean.
111        let mean = x.sum_axis(Axis(0)) / n;
112
113        // Compute per-feature variance: Var(X) = E[(X - mean)^2].
114        let mut variances = Array1::<F>::zeros(n_features);
115        for row in x.rows() {
116            for (j, (&val, &m)) in row.iter().zip(mean.iter()).enumerate() {
117                let diff = val - m;
118                variances[j] += diff * diff;
119            }
120        }
121        variances.mapv_inplace(|v| v / n);
122
123        // Select features whose variance exceeds the threshold.
124        let threshold_f = F::from_f64(self.threshold).unwrap();
125        let selected_indices: Vec<usize> = (0..n_features)
126            .filter(|&j| variances[j] > threshold_f)
127            .collect();
128
129        if selected_indices.is_empty() {
130            return Err(RustMlError::InvalidParameter(
131                "no features meet the variance threshold; all features have variance <= threshold"
132                    .into(),
133            ));
134        }
135
136        Ok(FittedVarianceThreshold {
137            variances,
138            selected_indices,
139            n_features_in: n_features,
140        })
141    }
142}
143
144impl<F: Float> Transform<F> for FittedVarianceThreshold<F> {
145    fn transform(&self, x: &Array2<F>) -> Result<Array2<F>> {
146        if x.ncols() != self.n_features_in {
147            return Err(RustMlError::ShapeMismatch(format!(
148                "expected {} features, got {}",
149                self.n_features_in,
150                x.ncols()
151            )));
152        }
153
154        let n_rows = x.nrows();
155        let n_selected = self.selected_indices.len();
156        let mut result = Array2::<F>::zeros((n_rows, n_selected));
157
158        for (i, row) in x.rows().into_iter().enumerate() {
159            for (out_j, &src_j) in self.selected_indices.iter().enumerate() {
160                result[[i, out_j]] = row[src_j];
161            }
162        }
163
164        Ok(result)
165    }
166}
167
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    #[test]
    fn test_removes_constant_features() {
        // Column 0 and 2 are constant, column 1 varies.
        let x = array![
            [5.0, 1.0, 3.0],
            [5.0, 2.0, 3.0],
            [5.0, 3.0, 3.0],
            [5.0, 4.0, 3.0],
        ];

        let selector = VarianceThreshold::default();
        let fitted = FitUnsupervised::<f64>::fit(&selector, &x).unwrap();

        assert_eq!(fitted.selected_indices(), &[1]);
        assert_eq!(fitted.n_features_selected(), 1);

        // Constant columns should have variance 0.
        assert_abs_diff_eq!(fitted.variances()[0], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(fitted.variances()[2], 0.0, epsilon = 1e-10);
        assert!(fitted.variances()[1] > 0.0);
    }

    #[test]
    fn test_higher_threshold_removes_low_variance() {
        // Col 0: values 1,2,3,4 -> var = 1.25
        // Col 1: values 10,20,30,40 -> var = 125.0
        // Col 2: values 0,0,0,1 -> var = 0.1875
        let x = array![
            [1.0, 10.0, 0.0],
            [2.0, 20.0, 0.0],
            [3.0, 30.0, 0.0],
            [4.0, 40.0, 1.0],
        ];

        // Threshold = 1.0 should remove col 2 (var=0.1875) but keep col 0 (var=1.25)
        let selector = VarianceThreshold::new(1.0);
        let fitted = FitUnsupervised::<f64>::fit(&selector, &x).unwrap();

        assert_eq!(fitted.selected_indices(), &[0, 1]);

        // Threshold = 2.0 should keep only col 1 (var=125.0)
        let selector = VarianceThreshold::new(2.0);
        let fitted = FitUnsupervised::<f64>::fit(&selector, &x).unwrap();

        assert_eq!(fitted.selected_indices(), &[1]);
    }

    #[test]
    fn test_variances_are_population_variances() {
        let x = array![
            [1.0, 10.0, 0.0],
            [2.0, 20.0, 0.0],
            [3.0, 30.0, 0.0],
            [4.0, 40.0, 1.0],
        ];

        let fitted = FitUnsupervised::<f64>::fit(&VarianceThreshold::default(), &x).unwrap();

        // Divides by n (population variance), not n - 1.
        assert_abs_diff_eq!(fitted.variances()[0], 1.25, epsilon = 1e-10);
        assert_abs_diff_eq!(fitted.variances()[1], 125.0, epsilon = 1e-10);
        assert_abs_diff_eq!(fitted.variances()[2], 0.1875, epsilon = 1e-10);
    }

    #[test]
    fn test_transform_outputs_correct_shape() {
        let x = array![
            [0.0, 1.0, 2.0, 3.0],
            [0.0, 4.0, 5.0, 6.0],
            [0.0, 7.0, 8.0, 9.0],
        ];

        let selector = VarianceThreshold::new(0.0);
        let fitted = FitUnsupervised::<f64>::fit(&selector, &x).unwrap();
        let result = fitted.transform(&x).unwrap();

        // Column 0 is constant -> removed; columns 1,2,3 survive.
        assert_eq!(result.dim(), (3, 3));

        // Verify the selected columns contain the right data.
        assert_abs_diff_eq!(result[[0, 0]], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(result[[0, 1]], 2.0, epsilon = 1e-10);
        assert_abs_diff_eq!(result[[0, 2]], 3.0, epsilon = 1e-10);
        assert_abs_diff_eq!(result[[2, 0]], 7.0, epsilon = 1e-10);
    }

    #[test]
    fn test_transform_selects_same_columns_on_unseen_data() {
        let train = array![[0.0, 1.0, 5.0], [0.0, 2.0, 5.0], [0.0, 3.0, 5.0]];
        let fitted = FitUnsupervised::<f64>::fit(&VarianceThreshold::default(), &train).unwrap();

        // Selection is fixed at fit time: column 1 is kept even though it is
        // constant here, and columns 0 and 2 are dropped even though they vary.
        let unseen = array![[9.0, 7.0, -1.0], [8.0, 7.0, -2.0]];
        let result = fitted.transform(&unseen).unwrap();

        assert_eq!(result, array![[7.0], [7.0]]);
    }

    #[test]
    fn test_transform_handles_zero_rows() {
        let x = array![[0.0, 1.0, 2.0], [0.0, 3.0, 4.0]];
        let fitted = FitUnsupervised::<f64>::fit(&VarianceThreshold::default(), &x).unwrap();

        let empty = Array2::<f64>::zeros((0, 3));
        let result = fitted.transform(&empty).unwrap();

        assert_eq!(result.dim(), (0, 2));
    }

    #[test]
    fn test_keeps_all_features_when_all_vary() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];

        let selector = VarianceThreshold::new(0.0);
        let fitted = FitUnsupervised::<f64>::fit(&selector, &x).unwrap();

        assert_eq!(fitted.selected_indices(), &[0, 1]);
        let result = fitted.transform(&x).unwrap();
        assert_eq!(result.dim(), (3, 2));
    }

    #[test]
    fn test_error_when_no_features_survive() {
        // All features are constant.
        let x = array![[1.0, 2.0], [1.0, 2.0], [1.0, 2.0]];

        let selector = VarianceThreshold::new(0.0);
        let result = FitUnsupervised::<f64>::fit(&selector, &x);

        assert!(result.is_err());
        match result.unwrap_err() {
            RustMlError::InvalidParameter(msg) => {
                assert!(msg.contains("no features"), "unexpected message: {}", msg);
            }
            other => panic!("expected InvalidParameter, got {:?}", other),
        }
    }

    #[test]
    fn test_error_on_empty_input() {
        let x = Array2::<f64>::zeros((0, 3));

        let selector = VarianceThreshold::new(0.0);
        let result = FitUnsupervised::<f64>::fit(&selector, &x);

        assert!(matches!(result, Err(RustMlError::EmptyInput(_))));
    }

    #[test]
    fn test_error_on_zero_features() {
        let x = Array2::<f64>::zeros((3, 0));

        let selector = VarianceThreshold::new(0.0);
        let result = FitUnsupervised::<f64>::fit(&selector, &x);

        assert!(matches!(result, Err(RustMlError::EmptyInput(_))));
    }

    #[test]
    fn test_error_on_negative_threshold() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];

        let selector = VarianceThreshold::new(-0.5);
        match FitUnsupervised::<f64>::fit(&selector, &x).unwrap_err() {
            RustMlError::InvalidParameter(msg) => {
                assert!(msg.contains("non-negative"), "unexpected message: {}", msg);
            }
            other => panic!("expected InvalidParameter, got {:?}", other),
        }
    }

    #[test]
    fn test_with_threshold_overrides_value() {
        let selector = VarianceThreshold::default().with_threshold(2.5);

        assert_abs_diff_eq!(selector.threshold, 2.5, epsilon = 1e-12);
    }

    #[test]
    fn test_shape_mismatch_on_transform() {
        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];

        let selector = VarianceThreshold::new(0.0);
        let fitted = FitUnsupervised::<f64>::fit(&selector, &x).unwrap();

        let wrong = array![[1.0, 2.0]]; // 2 cols instead of 3
        assert!(fitted.transform(&wrong).is_err());
    }

    #[test]
    fn test_works_with_f32() {
        let x: Array2<f32> = array![[0.0_f32, 1.0], [0.0, 2.0], [0.0, 3.0]];

        let selector = VarianceThreshold::new(0.0);
        let fitted = FitUnsupervised::<f32>::fit(&selector, &x).unwrap();

        assert_eq!(fitted.selected_indices(), &[1]);
        let result = fitted.transform(&x).unwrap();
        assert_eq!(result.dim(), (3, 1));
    }
}