Skip to main content

anofox_ml_preprocessing/
select_k_best.rs

1use anofox_ml_core::{Float, Result, RustMlError, Transform};
2use ndarray::{Array1, Array2, Axis};
3use std::collections::HashMap;
4
5/// Pluggable scoring function used by [`SelectKBest`].
6///
7/// Each variant defines a different univariate statistical test for
8/// ranking features.
9#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
10pub enum ScoringFunction {
11    /// ANOVA F-value for classification (one-way between-groups F-test).
12    ///
13    /// For each feature, groups samples by class label and computes
14    /// between-class variance / within-class variance. Higher F means the
15    /// feature is more discriminative. Requires target labels `y`.
16    FClassif,
17
18    /// Univariate linear-regression F-statistic for regression.
19    ///
20    /// For each feature j, computes the Pearson correlation r with the
21    /// target, then F = r^2 * (n-2) / (1 - r^2). Higher F means the
22    /// feature has a stronger linear relationship with the target.
23    /// Requires target values `y`.
24    FRegression,
25
26    /// Feature variance (unsupervised).
27    ///
28    /// Simply uses the variance of each feature as its score.
29    /// Target `y` is ignored when this variant is used.
30    Variance,
31}
32
33/// Parameters for `SelectKBest` feature selector (unfitted state).
34///
35/// Selects the top-k features according to a pluggable [`ScoringFunction`].
36/// This is more flexible than [`MutualInformationSelector`](crate::MutualInformationSelector),
37/// which is hard-coded to mutual information scoring.
38///
39/// # Example
40///
41/// ```
42/// use anofox_ml_preprocessing::SelectKBest;
43/// use anofox_ml_preprocessing::select_k_best::ScoringFunction;
44/// use anofox_ml_core::Transform;
45/// use ndarray::array;
46///
47/// // Feature 0 perfectly separates the two classes; feature 1 is noise.
48/// let x = array![
49///     [0.0, 0.5],
50///     [0.0, 0.8],
51///     [1.0, 0.3],
52///     [1.0, 0.7],
53/// ];
54/// let y = array![0.0, 0.0, 1.0, 1.0];
55///
56/// let selector = SelectKBest::new(1, ScoringFunction::FClassif);
57/// let fitted = selector.fit(&x, &y).unwrap();
58/// let x_selected = fitted.transform(&x).unwrap();
59///
60/// assert_eq!(x_selected.ncols(), 1);
61/// ```
62#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
63pub struct SelectKBest {
64    /// Number of top features to select.
65    pub k: usize,
66    /// Scoring function to rank features.
67    pub scoring_fn: ScoringFunction,
68}
69
70impl SelectKBest {
71    /// Create a new `SelectKBest` selector that keeps the top `k` features
72    /// ranked by the given scoring function.
73    pub fn new(k: usize, scoring_fn: ScoringFunction) -> Self {
74        Self { k, scoring_fn }
75    }
76
77    /// Fit the selector on the given data.
78    ///
79    /// For [`ScoringFunction::FClassif`] and [`ScoringFunction::FRegression`],
80    /// `y` is used as the target variable. For [`ScoringFunction::Variance`],
81    /// `y` is ignored.
82    pub fn fit<F: Float>(&self, x: &Array2<F>, y: &Array1<F>) -> Result<FittedSelectKBest<F>> {
83        let (n_samples, n_features) = x.dim();
84
85        if n_samples == 0 || n_features == 0 {
86            return Err(RustMlError::EmptyInput("input array is empty".into()));
87        }
88
89        if self.k == 0 {
90            return Err(RustMlError::InvalidParameter("k must be at least 1".into()));
91        }
92
93        if self.k > n_features {
94            return Err(RustMlError::InvalidParameter(format!(
95                "k ({}) exceeds number of features ({})",
96                self.k, n_features
97            )));
98        }
99
100        // For supervised modes, validate y length.
101        if !matches!(self.scoring_fn, ScoringFunction::Variance) {
102            if y.len() != n_samples {
103                return Err(RustMlError::ShapeMismatch(format!(
104                    "X has {} samples but y has {} elements",
105                    n_samples,
106                    y.len()
107                )));
108            }
109        }
110
111        let scores = match &self.scoring_fn {
112            ScoringFunction::FClassif => compute_f_classif(x, y)?,
113            ScoringFunction::FRegression => compute_f_regression(x, y)?,
114            ScoringFunction::Variance => compute_variance(x),
115        };
116
117        // Select top-k features by score (descending).
118        let mut feature_scores: Vec<(usize, F)> = scores.iter().copied().enumerate().collect();
119        feature_scores.sort_by(|a, b| {
120            b.1.partial_cmp(&a.1)
121                .unwrap_or(std::cmp::Ordering::Equal)
122                .then(a.0.cmp(&b.0))
123        });
124
125        let mut selected_indices: Vec<usize> = feature_scores
126            .iter()
127            .take(self.k)
128            .map(|&(idx, _)| idx)
129            .collect();
130        // Sort indices for stable column ordering in transform.
131        selected_indices.sort_unstable();
132
133        Ok(FittedSelectKBest {
134            scores,
135            selected_indices,
136            n_features_in: n_features,
137        })
138    }
139}
140
141/// Fitted `SelectKBest` -- holds per-feature scores and the indices of the
142/// selected top-k features.
143#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
144#[serde(bound(deserialize = "F: serde::de::DeserializeOwned"))]
145pub struct FittedSelectKBest<F: Float> {
146    /// Per-feature score from the chosen scoring function.
147    scores: Array1<F>,
148    /// Indices of the top-k features (sorted ascending for stable column ordering).
149    selected_indices: Vec<usize>,
150    /// Total number of input features (before selection).
151    n_features_in: usize,
152}
153
154impl<F: Float> FittedSelectKBest<F> {
155    /// Per-feature scores computed during fitting.
156    pub fn scores(&self) -> &Array1<F> {
157        &self.scores
158    }
159
160    /// Indices of the selected features, sorted in ascending order.
161    pub fn selected_indices(&self) -> &[usize] {
162        &self.selected_indices
163    }
164}
165
166impl<F: Float> Transform<F> for FittedSelectKBest<F> {
167    fn transform(&self, x: &Array2<F>) -> Result<Array2<F>> {
168        if x.ncols() != self.n_features_in {
169            return Err(RustMlError::ShapeMismatch(format!(
170                "expected {} features, got {}",
171                self.n_features_in,
172                x.ncols()
173            )));
174        }
175
176        let n_rows = x.nrows();
177        let n_selected = self.selected_indices.len();
178        let mut result = Array2::<F>::zeros((n_rows, n_selected));
179
180        for (i, row) in x.rows().into_iter().enumerate() {
181            for (out_j, &src_j) in self.selected_indices.iter().enumerate() {
182                result[[i, out_j]] = row[src_j];
183            }
184        }
185
186        Ok(result)
187    }
188}
189
190// ---------------------------------------------------------------------------
191// Scoring function implementations
192// ---------------------------------------------------------------------------
193
194/// Compute ANOVA F-value for each feature (classification).
195///
196/// For each feature j:
197/// - Group samples by class label.
198/// - Compute between-class mean square (MSB) and within-class mean square (MSW).
199/// - F = MSB / MSW.
200fn compute_f_classif<F: Float>(x: &Array2<F>, y: &Array1<F>) -> Result<Array1<F>> {
201    let (n_samples, n_features) = x.dim();
202    let n_f = F::from_usize(n_samples).unwrap();
203
204    // Map labels to class indices.
205    let mut label_map: HashMap<u64, usize> = HashMap::new();
206    let mut class_indices: Vec<usize> = Vec::with_capacity(n_samples);
207    for &val in y.iter() {
208        let bits = val.to_f64().unwrap().to_bits();
209        let next_id = label_map.len();
210        let id = *label_map.entry(bits).or_insert(next_id);
211        class_indices.push(id);
212    }
213    let n_classes = label_map.len();
214
215    if n_classes < 2 {
216        return Err(RustMlError::InvalidParameter(
217            "FClassif requires at least 2 classes".into(),
218        ));
219    }
220
221    if n_samples <= n_classes {
222        return Err(RustMlError::InvalidParameter(
223            "not enough samples for FClassif (need more samples than classes)".into(),
224        ));
225    }
226
227    // Count samples per class.
228    let mut class_counts = vec![0usize; n_classes];
229    for &c in &class_indices {
230        class_counts[c] += 1;
231    }
232
233    let mut scores = Array1::<F>::zeros(n_features);
234
235    for j in 0..n_features {
236        let col = x.column(j);
237
238        // Grand mean.
239        let grand_mean = col.sum() / n_f;
240
241        // Per-class sum and sum of squares.
242        let mut class_sums = vec![F::zero(); n_classes];
243        for (i, &val) in col.iter().enumerate() {
244            class_sums[class_indices[i]] += val;
245        }
246
247        // Between-class sum of squares (SSB).
248        let mut ssb = F::zero();
249        for c in 0..n_classes {
250            let nc = F::from_usize(class_counts[c]).unwrap();
251            let class_mean = class_sums[c] / nc;
252            let diff = class_mean - grand_mean;
253            ssb += nc * diff * diff;
254        }
255
256        // Within-class sum of squares (SSW).
257        let mut ssw = F::zero();
258        for (i, &val) in col.iter().enumerate() {
259            let c = class_indices[i];
260            let nc = F::from_usize(class_counts[c]).unwrap();
261            let class_mean = class_sums[c] / nc;
262            let diff = val - class_mean;
263            ssw += diff * diff;
264        }
265
266        // Degrees of freedom.
267        let df_between = F::from_usize(n_classes - 1).unwrap();
268        let df_within = F::from_usize(n_samples - n_classes).unwrap();
269
270        let eps = F::from_f64(1e-15).unwrap();
271        if ssw < eps {
272            // All within-class variance is zero: feature is perfectly
273            // discriminative (or constant). Use a large F value.
274            scores[j] = if ssb > eps {
275                F::from_f64(1e12).unwrap()
276            } else {
277                F::zero()
278            };
279        } else {
280            let msb = ssb / df_between;
281            let msw = ssw / df_within;
282            scores[j] = msb / msw;
283        }
284    }
285
286    Ok(scores)
287}
288
289/// Compute univariate linear-regression F-statistic for each feature.
290///
291/// For each feature j:
292///   r = correlation(x[:,j], y)
293///   F = r^2 * (n - 2) / (1 - r^2)
294fn compute_f_regression<F: Float>(x: &Array2<F>, y: &Array1<F>) -> Result<Array1<F>> {
295    let (n_samples, n_features) = x.dim();
296
297    if n_samples < 3 {
298        return Err(RustMlError::InvalidParameter(
299            "FRegression requires at least 3 samples".into(),
300        ));
301    }
302
303    let n_f = F::from_usize(n_samples).unwrap();
304    let eps = F::from_f64(1e-15).unwrap();
305
306    // Compute y statistics once.
307    let y_mean = y.sum() / n_f;
308    let mut y_var = F::zero();
309    for &val in y.iter() {
310        let diff = val - y_mean;
311        y_var += diff * diff;
312    }
313
314    let mut scores = Array1::<F>::zeros(n_features);
315
316    for j in 0..n_features {
317        let col = x.column(j);
318        let x_mean = col.sum() / n_f;
319
320        let mut cov_xy = F::zero();
321        let mut x_var = F::zero();
322        for (&xv, &yv) in col.iter().zip(y.iter()) {
323            let dx = xv - x_mean;
324            let dy = yv - y_mean;
325            cov_xy += dx * dy;
326            x_var += dx * dx;
327        }
328
329        if x_var < eps || y_var < eps {
330            scores[j] = F::zero();
331            continue;
332        }
333
334        let r = cov_xy / (x_var.sqrt() * y_var.sqrt());
335        let r2 = r * r;
336
337        let one = F::one();
338        let denom = one - r2;
339        if denom < eps {
340            // Perfect correlation.
341            scores[j] = F::from_f64(1e12).unwrap();
342        } else {
343            let n_minus_2 = F::from_usize(n_samples - 2).unwrap();
344            scores[j] = r2 * n_minus_2 / denom;
345        }
346    }
347
348    Ok(scores)
349}
350
351/// Compute per-feature variance (unsupervised scoring).
352fn compute_variance<F: Float>(x: &Array2<F>) -> Array1<F> {
353    let n = F::from_usize(x.nrows()).unwrap();
354    let mean = x.sum_axis(Axis(0)) / n;
355    let n_features = x.ncols();
356
357    let mut variances = Array1::<F>::zeros(n_features);
358    for row in x.rows() {
359        for (j, (&val, &m)) in row.iter().zip(mean.iter()).enumerate() {
360            let diff = val - m;
361            variances[j] += diff * diff;
362        }
363    }
364    variances.mapv_inplace(|v| v / n);
365    variances
366}
367
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// FClassif must rank a perfectly separating feature above a noisy one.
    #[test]
    fn test_f_classif_selects_discriminative_feature() {
        // Column 0 equals the class label; column 1 is unrelated noise.
        let x = array![
            [0.0, 0.5],
            [0.0, 0.8],
            [0.0, 0.2],
            [0.0, 0.9],
            [1.0, 0.3],
            [1.0, 0.7],
            [1.0, 0.1],
            [1.0, 0.6],
        ];
        let y = array![0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0];

        let fitted = SelectKBest::new(1, ScoringFunction::FClassif)
            .fit(&x, &y)
            .unwrap();

        assert_eq!(fitted.selected_indices(), &[0]);

        let scores = fitted.scores();
        assert!(
            scores[0] > scores[1],
            "discriminative feature score ({}) should exceed noise ({})",
            scores[0],
            scores[1]
        );
    }

    /// FRegression must prefer a perfectly correlated feature over a constant one.
    #[test]
    fn test_f_regression_selects_correlated_feature() {
        // Column 0 is proportional to y; column 1 never changes.
        let x = array![[1.0, 5.0], [2.0, 5.0], [3.0, 5.0], [4.0, 5.0], [5.0, 5.0],];
        let y = array![2.0, 4.0, 6.0, 8.0, 10.0];

        let fitted = SelectKBest::new(1, ScoringFunction::FRegression)
            .fit(&x, &y)
            .unwrap();

        assert_eq!(fitted.selected_indices(), &[0]);
        // Perfect correlation yields a very large F.
        assert!(fitted.scores()[0] > 100.0_f64);
        // A constant feature scores zero.
        assert!(fitted.scores()[1].abs() < 1e-10_f64);
    }

    /// Variance scoring must pick the column with the widest spread.
    #[test]
    fn test_variance_scoring_selects_high_variance_feature() {
        // Column 0: small spread; column 1: large spread; column 2: constant.
        let x = array![
            [1.0, 10.0, 5.0],
            [1.1, 20.0, 5.0],
            [0.9, 30.0, 5.0],
            [1.0, 40.0, 5.0],
        ];
        // The target plays no role for Variance scoring.
        let y = array![0.0, 0.0, 0.0, 0.0];

        let fitted = SelectKBest::new(1, ScoringFunction::Variance)
            .fit(&x, &y)
            .unwrap();

        assert_eq!(fitted.selected_indices(), &[1]);
    }

    /// The transformed matrix must contain exactly the selected columns.
    #[test]
    fn test_transform_outputs_correct_columns() {
        let x = array![[10.0, 20.0, 30.0], [40.0, 50.0, 60.0], [70.0, 80.0, 90.0],];
        let y = array![1.0, 2.0, 3.0];

        let fitted = SelectKBest::new(2, ScoringFunction::FRegression)
            .fit(&x, &y)
            .unwrap();
        let result = fitted.transform(&x).unwrap();

        assert_eq!(result.nrows(), 3);
        assert_eq!(result.ncols(), 2);

        // Output column `out_pos` must equal input column `idx`.
        for (out_pos, &idx) in fitted.selected_indices().iter().enumerate() {
            let original_col: Vec<f64> = x.column(idx).to_vec();
            let result_col: Vec<f64> = result.column(out_pos).to_vec();
            assert_eq!(original_col, result_col);
        }
    }

    /// k = 0 is rejected.
    #[test]
    fn test_error_k_zero() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let y = array![0.0, 1.0];

        let outcome = SelectKBest::new(0, ScoringFunction::FClassif).fit(&x, &y);
        assert!(outcome.is_err());
    }

    /// k larger than the number of columns is rejected with InvalidParameter.
    #[test]
    fn test_error_k_exceeds_features() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let y = array![0.0, 1.0];

        let outcome = SelectKBest::new(5, ScoringFunction::FClassif).fit(&x, &y);
        match outcome {
            Err(RustMlError::InvalidParameter(msg)) => {
                assert!(msg.contains("exceeds"), "unexpected message: {}", msg);
            }
            Err(other) => panic!("expected InvalidParameter, got {:?}", other),
            Ok(_) => panic!("expected an error, but fit succeeded"),
        }
    }

    /// A target whose length differs from the row count is rejected with ShapeMismatch.
    #[test]
    fn test_error_shape_mismatch_x_y() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        // Three labels for two samples.
        let y = array![0.0, 1.0, 2.0];

        let outcome = SelectKBest::new(1, ScoringFunction::FClassif).fit(&x, &y);
        match outcome {
            Err(RustMlError::ShapeMismatch(msg)) => {
                assert!(msg.contains("samples"), "unexpected message: {}", msg);
            }
            Err(other) => panic!("expected ShapeMismatch, got {:?}", other),
            Ok(_) => panic!("expected an error, but fit succeeded"),
        }
    }

    /// A matrix without rows is rejected.
    #[test]
    fn test_error_on_empty_input() {
        let x = Array2::<f64>::zeros((0, 3));
        let y = Array1::<f64>::zeros(0);

        let outcome = SelectKBest::new(1, ScoringFunction::FRegression).fit(&x, &y);
        assert!(outcome.is_err());
    }

    /// transform rejects input whose column count differs from the fitted one.
    #[test]
    fn test_shape_mismatch_on_transform() {
        let x = array![
            [1.0, 2.0, 3.0],
            [4.0, 5.0, 6.0],
            [7.0, 8.0, 9.0],
            [10.0, 11.0, 12.0],
        ];
        let y = array![0.0, 0.0, 1.0, 1.0];

        let fitted = SelectKBest::new(1, ScoringFunction::FClassif)
            .fit(&x, &y)
            .unwrap();

        // Two columns instead of the three seen during fitting.
        let wrong = array![[1.0, 2.0]];
        assert!(fitted.transform(&wrong).is_err());
    }

    /// With k equal to the number of columns every column is kept.
    #[test]
    fn test_selects_all_when_k_equals_n_features() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let y = array![1.0, 2.0, 3.0];

        let fitted = SelectKBest::new(2, ScoringFunction::FRegression)
            .fit(&x, &y)
            .unwrap();

        assert_eq!(fitted.selected_indices().len(), 2);
        assert_eq!(fitted.selected_indices(), &[0, 1]);
    }

    /// The selector is generic over the float type; exercise it with f32.
    #[test]
    fn test_works_with_f32() {
        let x: Array2<f32> = array![[0.0_f32, 0.5], [0.0, 0.8], [1.0, 0.3], [1.0, 0.7],];
        let y: Array1<f32> = array![0.0_f32, 0.0, 1.0, 1.0];

        let fitted = SelectKBest::new(1, ScoringFunction::FClassif)
            .fit(&x, &y)
            .unwrap();

        assert_eq!(fitted.selected_indices().len(), 1);
        let result = fitted.transform(&x).unwrap();
        assert_eq!(result.ncols(), 1);
    }
}