Skip to main content

ferrolearn_preprocess/
select_percentile.rs

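//! Select features by percentile of highest scores.
//!
//! [`SelectPercentile`] retains features whose ANOVA F-score ranks in the top
//! `percentile` percent. It takes its scoring-function selector
//! ([`ScoreFunc`](crate::feature_selection::ScoreFunc)) from
//! [`crate::feature_selection`]; the ANOVA F-score computation itself is a
//! local copy kept in this module.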
6
7use ferrolearn_core::error::FerroError;
8use ferrolearn_core::traits::{Fit, Transform};
9use ndarray::{Array1, Array2};
10use num_traits::Float;
11
12use crate::feature_selection::ScoreFunc;
13
14// ---------------------------------------------------------------------------
// Helper: ANOVA F-scores (a copy of the feature_selection helper, so that one need not be made pub(crate))
16// ---------------------------------------------------------------------------
17
18/// Compute per-feature ANOVA F-scores.
19fn anova_f_scores<F: Float>(x: &Array2<F>, y: &Array1<usize>) -> Vec<F> {
20    let n_samples = x.nrows();
21    let n_features = x.ncols();
22
23    let mut class_indices: std::collections::HashMap<usize, Vec<usize>> =
24        std::collections::HashMap::new();
25    for (i, &label) in y.iter().enumerate() {
26        class_indices.entry(label).or_default().push(i);
27    }
28    let n_classes = class_indices.len();
29
30    let mut scores = Vec::with_capacity(n_features);
31
32    for j in 0..n_features {
33        let col = x.column(j);
34        let grand_mean =
35            col.iter().copied().fold(F::zero(), |acc, v| acc + v) / F::from(n_samples).unwrap();
36
37        let mut ss_between = F::zero();
38        let mut ss_within = F::zero();
39
40        for rows in class_indices.values() {
41            let n_k = F::from(rows.len()).unwrap();
42            let class_mean = rows
43                .iter()
44                .map(|&i| col[i])
45                .fold(F::zero(), |acc, v| acc + v)
46                / n_k;
47            let diff = class_mean - grand_mean;
48            ss_between = ss_between + n_k * diff * diff;
49            for &i in rows {
50                let d = col[i] - class_mean;
51                ss_within = ss_within + d * d;
52            }
53        }
54
55        let df_between = F::from(n_classes.saturating_sub(1)).unwrap();
56        let df_within = F::from(n_samples.saturating_sub(n_classes)).unwrap();
57
58        let f = if df_between == F::zero() || df_within == F::zero() {
59            F::zero()
60        } else {
61            let ms_between = ss_between / df_between;
62            let ms_within = ss_within / df_within;
63            if ms_within == F::zero() {
64                F::infinity()
65            } else {
66                ms_between / ms_within
67            }
68        };
69
70        scores.push(f);
71    }
72
73    scores
74}
75
76/// Build a new `Array2<F>` containing only the columns listed in `indices`.
77fn select_columns<F: Float>(x: &Array2<F>, indices: &[usize]) -> Array2<F> {
78    let nrows = x.nrows();
79    let ncols = indices.len();
80    if ncols == 0 {
81        return Array2::zeros((nrows, 0));
82    }
83    let mut out = Array2::zeros((nrows, ncols));
84    for (new_j, &old_j) in indices.iter().enumerate() {
85        for i in 0..nrows {
86            out[[i, new_j]] = x[[i, old_j]];
87        }
88    }
89    out
90}
91
92// ---------------------------------------------------------------------------
93// SelectPercentile
94// ---------------------------------------------------------------------------
95
96/// An unfitted percentile-based feature selector.
97///
98/// Retains the features whose ANOVA F-score ranks in the top `percentile`
99/// percent.
100///
101/// # Examples
102///
103/// ```
104/// use ferrolearn_preprocess::select_percentile::SelectPercentile;
105/// use ferrolearn_preprocess::feature_selection::ScoreFunc;
106/// use ferrolearn_core::traits::{Fit, Transform};
107/// use ndarray::{array, Array1};
108///
109/// let sel = SelectPercentile::<f64>::new(50, ScoreFunc::FClassif);
110/// let x = array![[1.0, 10.0, 0.1, 0.01],
111///                 [1.0, 20.0, 0.2, 0.02],
112///                 [2.0, 10.0, 0.1, 0.01],
113///                 [2.0, 20.0, 0.2, 0.02]];
114/// let y: Array1<usize> = array![0, 0, 1, 1];
115/// let fitted = sel.fit(&x, &y).unwrap();
116/// let out = fitted.transform(&x).unwrap();
117/// assert_eq!(out.ncols(), 2); // 50% of 4 features = 2
118/// ```
119#[must_use]
120#[derive(Debug, Clone)]
121pub struct SelectPercentile<F> {
122    /// Percentile of features to keep (0-100).
123    percentile: usize,
124    /// Scoring function.
125    score_func: ScoreFunc,
126    _marker: std::marker::PhantomData<F>,
127}
128
129impl<F: Float + Send + Sync + 'static> SelectPercentile<F> {
130    /// Create a new `SelectPercentile` selector.
131    ///
132    /// # Parameters
133    ///
134    /// - `percentile` — the percentile of top-scoring features to keep (0-100).
135    /// - `score_func` — the scoring function to use.
136    pub fn new(percentile: usize, score_func: ScoreFunc) -> Self {
137        Self {
138            percentile,
139            score_func,
140            _marker: std::marker::PhantomData,
141        }
142    }
143
144    /// Return the percentile.
145    #[must_use]
146    pub fn percentile(&self) -> usize {
147        self.percentile
148    }
149
150    /// Return the score function.
151    #[must_use]
152    pub fn score_func(&self) -> ScoreFunc {
153        self.score_func
154    }
155}
156
157impl<F: Float + Send + Sync + 'static> Default for SelectPercentile<F> {
158    fn default() -> Self {
159        Self::new(10, ScoreFunc::FClassif)
160    }
161}
162
163// ---------------------------------------------------------------------------
164// FittedSelectPercentile
165// ---------------------------------------------------------------------------
166
167/// A fitted percentile selector holding scores and selected indices.
168///
169/// Created by calling [`Fit::fit`] on a [`SelectPercentile`].
170#[derive(Debug, Clone)]
171pub struct FittedSelectPercentile<F> {
172    /// Number of features seen during fitting.
173    n_features_in: usize,
174    /// Per-feature scores.
175    scores: Array1<F>,
176    /// Indices of selected columns (in original column order).
177    selected_indices: Vec<usize>,
178}
179
180impl<F: Float + Send + Sync + 'static> FittedSelectPercentile<F> {
181    /// Return the per-feature scores.
182    #[must_use]
183    pub fn scores(&self) -> &Array1<F> {
184        &self.scores
185    }
186
187    /// Return the indices of selected columns.
188    #[must_use]
189    pub fn selected_indices(&self) -> &[usize] {
190        &self.selected_indices
191    }
192}
193
194// ---------------------------------------------------------------------------
195// Trait implementations
196// ---------------------------------------------------------------------------
197
198impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array1<usize>> for SelectPercentile<F> {
199    type Fitted = FittedSelectPercentile<F>;
200    type Error = FerroError;
201
202    /// Fit by computing per-feature scores and selecting the top percentile.
203    ///
204    /// # Errors
205    ///
206    /// - [`FerroError::InsufficientSamples`] if the input has zero rows.
207    /// - [`FerroError::InvalidParameter`] if `percentile` > 100.
208    /// - [`FerroError::ShapeMismatch`] if `x` and `y` have different row counts.
209    fn fit(
210        &self,
211        x: &Array2<F>,
212        y: &Array1<usize>,
213    ) -> Result<FittedSelectPercentile<F>, FerroError> {
214        let n_samples = x.nrows();
215        if n_samples == 0 {
216            return Err(FerroError::InsufficientSamples {
217                required: 1,
218                actual: 0,
219                context: "SelectPercentile::fit".into(),
220            });
221        }
222        if y.len() != n_samples {
223            return Err(FerroError::ShapeMismatch {
224                expected: vec![n_samples],
225                actual: vec![y.len()],
226                context: "SelectPercentile::fit — y must have same length as x rows".into(),
227            });
228        }
229        if self.percentile > 100 {
230            return Err(FerroError::InvalidParameter {
231                name: "percentile".into(),
232                reason: format!("percentile must be in [0, 100], got {}", self.percentile),
233            });
234        }
235
236        let n_features = x.ncols();
237        let raw_scores = match self.score_func {
238            ScoreFunc::FClassif => anova_f_scores(x, y),
239        };
240        let scores = Array1::from_vec(raw_scores.clone());
241
242        // Compute how many features to keep
243        let k = (n_features * self.percentile).div_ceil(100);
244        let k = k.min(n_features);
245
246        // Rank features by score (descending)
247        let mut ranked: Vec<usize> = (0..n_features).collect();
248        ranked.sort_by(|&a, &b| {
249            raw_scores[b]
250                .partial_cmp(&raw_scores[a])
251                .unwrap_or(std::cmp::Ordering::Equal)
252                .then(a.cmp(&b))
253        });
254
255        let mut selected_indices: Vec<usize> = ranked[..k].to_vec();
256        selected_indices.sort_unstable();
257
258        Ok(FittedSelectPercentile {
259            n_features_in: n_features,
260            scores,
261            selected_indices,
262        })
263    }
264}
265
266impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedSelectPercentile<F> {
267    type Output = Array2<F>;
268    type Error = FerroError;
269
270    /// Return a matrix containing only the selected columns.
271    ///
272    /// # Errors
273    ///
274    /// Returns [`FerroError::ShapeMismatch`] if the number of columns differs
275    /// from the number of features seen during fitting.
276    fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
277        if x.ncols() != self.n_features_in {
278            return Err(FerroError::ShapeMismatch {
279                expected: vec![x.nrows(), self.n_features_in],
280                actual: vec![x.nrows(), x.ncols()],
281                context: "FittedSelectPercentile::transform".into(),
282            });
283        }
284        Ok(select_columns(x, &self.selected_indices))
285    }
286}
287
288// ---------------------------------------------------------------------------
289// Tests
290// ---------------------------------------------------------------------------
291
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Small 2x2 matrix shared by several tests.
    fn small_x() -> Array2<f64> {
        array![[1.0, 2.0], [3.0, 4.0]]
    }

    /// Labels for `small_x`: one sample per class.
    fn small_y() -> Array1<usize> {
        array![0, 1]
    }

    /// The configuration most tests use: keep 50% with the F-classif score.
    fn half() -> SelectPercentile<f64> {
        SelectPercentile::new(50, ScoreFunc::FClassif)
    }

    #[test]
    fn test_select_percentile_50_percent() {
        // Column 0 differs between the classes; columns 1-3 repeat the same
        // values in both classes.
        let x = array![
            [1.0, 5.0, 0.1, 0.01],
            [1.0, 6.0, 0.2, 0.02],
            [10.0, 5.0, 0.1, 0.01],
            [10.0, 6.0, 0.2, 0.02]
        ];
        let y: Array1<usize> = array![0, 0, 1, 1];
        let transformed = half().fit(&x, &y).unwrap().transform(&x).unwrap();
        // Half of four features is two.
        assert_eq!(transformed.ncols(), 2);
    }

    #[test]
    fn test_select_percentile_100_percent_keeps_all() {
        let x = small_x();
        let fitted = SelectPercentile::<f64>::new(100, ScoreFunc::FClassif)
            .fit(&x, &small_y())
            .unwrap();
        assert_eq!(fitted.transform(&x).unwrap().ncols(), 2);
    }

    #[test]
    fn test_select_percentile_selects_highest_scoring() {
        // Column 0 splits the classes cleanly; column 1 does not.
        let x = array![[0.0, 5.0], [0.0, 5.5], [10.0, 5.0], [10.0, 5.5]];
        let y: Array1<usize> = array![0, 0, 1, 1];
        let fitted = half().fit(&x, &y).unwrap();
        assert!(fitted.selected_indices().contains(&0));
    }

    #[test]
    fn test_select_percentile_scores_stored() {
        let fitted = half().fit(&small_x(), &small_y()).unwrap();
        assert_eq!(fitted.scores().len(), 2);
    }

    #[test]
    fn test_select_percentile_zero_rows_error() {
        let x: Array2<f64> = Array2::zeros((0, 3));
        let y: Array1<usize> = Array1::zeros(0);
        assert!(half().fit(&x, &y).is_err());
    }

    #[test]
    fn test_select_percentile_over_100_error() {
        let too_many = SelectPercentile::<f64>::new(150, ScoreFunc::FClassif);
        assert!(too_many.fit(&small_x(), &small_y()).is_err());
    }

    #[test]
    fn test_select_percentile_y_length_mismatch_error() {
        let short_y: Array1<usize> = array![0]; // one label for two rows
        assert!(half().fit(&small_x(), &short_y).is_err());
    }

    #[test]
    fn test_select_percentile_shape_mismatch_on_transform() {
        let fitted = half().fit(&small_x(), &small_y()).unwrap();
        let three_columns = array![[1.0, 2.0, 3.0]];
        assert!(fitted.transform(&three_columns).is_err());
    }

    #[test]
    fn test_select_percentile_default() {
        assert_eq!(SelectPercentile::<f64>::default().percentile(), 10);
    }

    #[test]
    fn test_select_percentile_indices_sorted() {
        let x = array![
            [1.0, 100.0, 0.5, 0.01],
            [2.0, 200.0, 0.6, 0.02],
            [10.0, 100.0, 0.5, 0.01],
            [20.0, 200.0, 0.6, 0.02]
        ];
        let y: Array1<usize> = array![0, 0, 1, 1];
        let fitted = half().fit(&x, &y).unwrap();
        // Selected indices must come back strictly ascending.
        assert!(fitted
            .selected_indices()
            .windows(2)
            .all(|pair| pair[0] < pair[1]));
    }
}