1use anofox_ml_core::{FitUnsupervised, Float, Result, RustMlError, Transform};
2use ndarray::{Array1, Array2, Axis};
3
4#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
35pub struct VarianceThreshold {
36 pub threshold: f64,
39}
40
41impl VarianceThreshold {
42 pub fn new(threshold: f64) -> Self {
46 Self { threshold }
47 }
48
49 pub fn with_threshold(mut self, threshold: f64) -> Self {
51 self.threshold = threshold;
52 self
53 }
54}
55
56impl Default for VarianceThreshold {
57 fn default() -> Self {
58 Self::new(0.0)
59 }
60}
61
62#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
65#[serde(bound(deserialize = "F: serde::de::DeserializeOwned"))]
66pub struct FittedVarianceThreshold<F: Float> {
67 variances: Array1<F>,
69 selected_indices: Vec<usize>,
71 n_features_in: usize,
73}
74
75impl<F: Float> FittedVarianceThreshold<F> {
76 pub fn variances(&self) -> &Array1<F> {
78 &self.variances
79 }
80
81 pub fn selected_indices(&self) -> &[usize] {
83 &self.selected_indices
84 }
85
86 pub fn n_features_selected(&self) -> usize {
88 self.selected_indices.len()
89 }
90}
91
92impl<F: Float> FitUnsupervised<F> for VarianceThreshold {
93 type Fitted = FittedVarianceThreshold<F>;
94
95 fn fit(&self, x: &Array2<F>) -> Result<Self::Fitted> {
96 let (n_samples, n_features) = x.dim();
97
98 if n_samples == 0 || n_features == 0 {
99 return Err(RustMlError::EmptyInput("input array is empty".into()));
100 }
101
102 if self.threshold < 0.0 {
103 return Err(RustMlError::InvalidParameter(
104 "threshold must be non-negative".into(),
105 ));
106 }
107
108 let n = F::from_usize(n_samples).unwrap();
109
110 let mean = x.sum_axis(Axis(0)) / n;
112
113 let mut variances = Array1::<F>::zeros(n_features);
115 for row in x.rows() {
116 for (j, (&val, &m)) in row.iter().zip(mean.iter()).enumerate() {
117 let diff = val - m;
118 variances[j] += diff * diff;
119 }
120 }
121 variances.mapv_inplace(|v| v / n);
122
123 let threshold_f = F::from_f64(self.threshold).unwrap();
125 let selected_indices: Vec<usize> = (0..n_features)
126 .filter(|&j| variances[j] > threshold_f)
127 .collect();
128
129 if selected_indices.is_empty() {
130 return Err(RustMlError::InvalidParameter(
131 "no features meet the variance threshold; all features have variance <= threshold"
132 .into(),
133 ));
134 }
135
136 Ok(FittedVarianceThreshold {
137 variances,
138 selected_indices,
139 n_features_in: n_features,
140 })
141 }
142}
143
144impl<F: Float> Transform<F> for FittedVarianceThreshold<F> {
145 fn transform(&self, x: &Array2<F>) -> Result<Array2<F>> {
146 if x.ncols() != self.n_features_in {
147 return Err(RustMlError::ShapeMismatch(format!(
148 "expected {} features, got {}",
149 self.n_features_in,
150 x.ncols()
151 )));
152 }
153
154 let n_rows = x.nrows();
155 let n_selected = self.selected_indices.len();
156 let mut result = Array2::<F>::zeros((n_rows, n_selected));
157
158 for (i, row) in x.rows().into_iter().enumerate() {
159 for (out_j, &src_j) in self.selected_indices.iter().enumerate() {
160 result[[i, out_j]] = row[src_j];
161 }
162 }
163
164 Ok(result)
165 }
166}
167
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    #[test]
    fn test_removes_constant_features() {
        let x = array![
            [5.0, 1.0, 3.0],
            [5.0, 2.0, 3.0],
            [5.0, 3.0, 3.0],
            [5.0, 4.0, 3.0],
        ];

        let selector = VarianceThreshold::default();
        let fitted = FitUnsupervised::<f64>::fit(&selector, &x).unwrap();

        assert_eq!(fitted.selected_indices(), &[1]);
        assert_eq!(fitted.n_features_selected(), 1);

        assert_abs_diff_eq!(fitted.variances()[0], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(fitted.variances()[2], 0.0, epsilon = 1e-10);
        // Column 1 is [1, 2, 3, 4]: mean 2.5, population variance 5 / 4.
        assert_abs_diff_eq!(fitted.variances()[1], 1.25, epsilon = 1e-10);
    }

    #[test]
    fn test_higher_threshold_removes_low_variance() {
        // Population variances: column 0 = 1.25, column 1 = 125.0, column 2 = 0.1875.
        let x = array![
            [1.0, 10.0, 0.0],
            [2.0, 20.0, 0.0],
            [3.0, 30.0, 0.0],
            [4.0, 40.0, 1.0],
        ];

        let selector = VarianceThreshold::new(1.0);
        let fitted = FitUnsupervised::<f64>::fit(&selector, &x).unwrap();
        assert_eq!(fitted.selected_indices(), &[0, 1]);

        let selector = VarianceThreshold::new(2.0);
        let fitted = FitUnsupervised::<f64>::fit(&selector, &x).unwrap();
        assert_eq!(fitted.selected_indices(), &[1]);
    }

    #[test]
    fn test_transform_outputs_correct_shape() {
        let x = array![
            [0.0, 1.0, 2.0, 3.0],
            [0.0, 4.0, 5.0, 6.0],
            [0.0, 7.0, 8.0, 9.0],
        ];

        let selector = VarianceThreshold::new(0.0);
        let fitted = FitUnsupervised::<f64>::fit(&selector, &x).unwrap();
        let result = fitted.transform(&x).unwrap();

        assert_eq!(result.dim(), (3, 3));
        assert_eq!(
            result,
            array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
        );
    }

    #[test]
    fn test_keeps_all_features_when_all_vary() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];

        let selector = VarianceThreshold::new(0.0);
        let fitted = FitUnsupervised::<f64>::fit(&selector, &x).unwrap();

        assert_eq!(fitted.selected_indices(), &[0, 1]);
        let result = fitted.transform(&x).unwrap();
        assert_eq!(result.dim(), (3, 2));
    }

    #[test]
    fn test_error_when_no_features_survive() {
        let x = array![[1.0, 2.0], [1.0, 2.0], [1.0, 2.0]];

        let selector = VarianceThreshold::new(0.0);
        let result = FitUnsupervised::<f64>::fit(&selector, &x);

        match result {
            Err(RustMlError::InvalidParameter(msg)) => {
                assert!(msg.contains("no features"), "unexpected message: {}", msg);
            }
            other => panic!("expected InvalidParameter, got {:?}", other),
        }
    }

    #[test]
    fn test_error_on_empty_input() {
        let x = Array2::<f64>::zeros((0, 3));

        let selector = VarianceThreshold::new(0.0);
        match FitUnsupervised::<f64>::fit(&selector, &x) {
            Err(RustMlError::EmptyInput(_)) => {}
            other => panic!("expected EmptyInput, got {:?}", other),
        }
    }

    #[test]
    fn test_error_on_zero_features() {
        let x = Array2::<f64>::zeros((3, 0));

        let selector = VarianceThreshold::new(0.0);
        match FitUnsupervised::<f64>::fit(&selector, &x) {
            Err(RustMlError::EmptyInput(_)) => {}
            other => panic!("expected EmptyInput, got {:?}", other),
        }
    }

    #[test]
    fn test_error_on_negative_threshold() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];

        let selector = VarianceThreshold::new(-1.0);
        match FitUnsupervised::<f64>::fit(&selector, &x) {
            Err(RustMlError::InvalidParameter(msg)) => {
                assert!(msg.contains("non-negative"), "unexpected message: {}", msg);
            }
            other => panic!("expected InvalidParameter, got {:?}", other),
        }
    }

    #[test]
    fn test_error_on_nan_threshold() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];

        let selector = VarianceThreshold::new(f64::NAN);
        let result = FitUnsupervised::<f64>::fit(&selector, &x);
        assert!(
            matches!(result, Err(RustMlError::InvalidParameter(_))),
            "expected InvalidParameter, got {:?}",
            result
        );
    }

    #[test]
    fn test_shape_mismatch_on_transform() {
        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];

        let selector = VarianceThreshold::new(0.0);
        let fitted = FitUnsupervised::<f64>::fit(&selector, &x).unwrap();

        let wrong = array![[1.0, 2.0]];
        match fitted.transform(&wrong) {
            Err(RustMlError::ShapeMismatch(_)) => {}
            other => panic!("expected ShapeMismatch, got {:?}", other),
        }
    }

    #[test]
    fn test_transform_with_zero_rows() {
        let x = array![[1.0, 5.0], [2.0, 5.0]];

        let selector = VarianceThreshold::new(0.0);
        let fitted = FitUnsupervised::<f64>::fit(&selector, &x).unwrap();
        assert_eq!(fitted.selected_indices(), &[0]);

        let empty = Array2::<f64>::zeros((0, 2));
        let result = fitted.transform(&empty).unwrap();
        assert_eq!(result.dim(), (0, 1));
    }

    #[test]
    fn test_with_threshold_overrides_value() {
        let selector = VarianceThreshold::default().with_threshold(2.5);
        assert_abs_diff_eq!(selector.threshold, 2.5, epsilon = 0.0);
    }

    #[test]
    fn test_works_with_f32() {
        let x: Array2<f32> = array![[0.0_f32, 1.0], [0.0, 2.0], [0.0, 3.0]];

        let selector = VarianceThreshold::new(0.0);
        let fitted = FitUnsupervised::<f32>::fit(&selector, &x).unwrap();

        assert_eq!(fitted.selected_indices(), &[1]);
        let result = fitted.transform(&x).unwrap();
        assert_eq!(result.dim(), (3, 1));
    }
}