1use ferrolearn_core::error::FerroError;
8use ferrolearn_core::traits::{Fit, Transform};
9use ndarray::{Array1, Array2};
10use num_traits::Float;
11
12use crate::feature_selection::ScoreFunc;
13
14fn anova_f_scores<F: Float>(x: &Array2<F>, y: &Array1<usize>) -> Vec<F> {
20 let n_samples = x.nrows();
21 let n_features = x.ncols();
22
23 let mut class_indices: std::collections::HashMap<usize, Vec<usize>> =
24 std::collections::HashMap::new();
25 for (i, &label) in y.iter().enumerate() {
26 class_indices.entry(label).or_default().push(i);
27 }
28 let n_classes = class_indices.len();
29
30 let mut scores = Vec::with_capacity(n_features);
31
32 for j in 0..n_features {
33 let col = x.column(j);
34 let grand_mean =
35 col.iter().copied().fold(F::zero(), |acc, v| acc + v) / F::from(n_samples).unwrap();
36
37 let mut ss_between = F::zero();
38 let mut ss_within = F::zero();
39
40 for rows in class_indices.values() {
41 let n_k = F::from(rows.len()).unwrap();
42 let class_mean = rows
43 .iter()
44 .map(|&i| col[i])
45 .fold(F::zero(), |acc, v| acc + v)
46 / n_k;
47 let diff = class_mean - grand_mean;
48 ss_between = ss_between + n_k * diff * diff;
49 for &i in rows {
50 let d = col[i] - class_mean;
51 ss_within = ss_within + d * d;
52 }
53 }
54
55 let df_between = F::from(n_classes.saturating_sub(1)).unwrap();
56 let df_within = F::from(n_samples.saturating_sub(n_classes)).unwrap();
57
58 let f = if df_between == F::zero() || df_within == F::zero() {
59 F::zero()
60 } else {
61 let ms_between = ss_between / df_between;
62 let ms_within = ss_within / df_within;
63 if ms_within == F::zero() {
64 F::infinity()
65 } else {
66 ms_between / ms_within
67 }
68 };
69
70 scores.push(f);
71 }
72
73 scores
74}
75
76fn select_columns<F: Float>(x: &Array2<F>, indices: &[usize]) -> Array2<F> {
78 let nrows = x.nrows();
79 let ncols = indices.len();
80 if ncols == 0 {
81 return Array2::zeros((nrows, 0));
82 }
83 let mut out = Array2::zeros((nrows, ncols));
84 for (new_j, &old_j) in indices.iter().enumerate() {
85 for i in 0..nrows {
86 out[[i, new_j]] = x[[i, old_j]];
87 }
88 }
89 out
90}
91
/// Univariate feature selector that keeps the highest-scoring `percentile` percent of
/// features.
///
/// Each feature is scored independently against the class labels with `score_func`.
/// Fitting keeps `ceil(n_features * percentile / 100)` of them, capped at `n_features`,
/// and returns a [`FittedSelectPercentile`] that can drop the remaining columns.
#[must_use]
#[derive(Debug, Clone)]
pub struct SelectPercentile<F> {
    /// Percentage of features to keep; `fit` rejects values above 100.
    percentile: usize,
    /// Scoring function used to rank the features.
    score_func: ScoreFunc,
    /// Ties the selector to the float type `F` without storing any value of it.
    _marker: std::marker::PhantomData<F>,
}
128
129impl<F: Float + Send + Sync + 'static> SelectPercentile<F> {
130 pub fn new(percentile: usize, score_func: ScoreFunc) -> Self {
137 Self {
138 percentile,
139 score_func,
140 _marker: std::marker::PhantomData,
141 }
142 }
143
144 #[must_use]
146 pub fn percentile(&self) -> usize {
147 self.percentile
148 }
149
150 #[must_use]
152 pub fn score_func(&self) -> ScoreFunc {
153 self.score_func
154 }
155}
156
157impl<F: Float + Send + Sync + 'static> Default for SelectPercentile<F> {
158 fn default() -> Self {
159 Self::new(10, ScoreFunc::FClassif)
160 }
161}
162
/// A fitted [`SelectPercentile`]: the per-feature scores and the columns that were kept.
///
/// Produced by fitting a `SelectPercentile`. `transform` keeps only the selected columns
/// of inputs that have the same number of columns as the training data.
#[derive(Debug, Clone)]
pub struct FittedSelectPercentile<F> {
    /// Number of columns in the training data; `transform` rejects inputs of any other width.
    n_features_in: usize,
    /// Score of every training column, in column order.
    scores: Array1<F>,
    /// Indices of the selected columns, sorted ascending by `fit`.
    selected_indices: Vec<usize>,
}
179
180impl<F: Float + Send + Sync + 'static> FittedSelectPercentile<F> {
181 #[must_use]
183 pub fn scores(&self) -> &Array1<F> {
184 &self.scores
185 }
186
187 #[must_use]
189 pub fn selected_indices(&self) -> &[usize] {
190 &self.selected_indices
191 }
192}
193
194impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array1<usize>> for SelectPercentile<F> {
199 type Fitted = FittedSelectPercentile<F>;
200 type Error = FerroError;
201
202 fn fit(
210 &self,
211 x: &Array2<F>,
212 y: &Array1<usize>,
213 ) -> Result<FittedSelectPercentile<F>, FerroError> {
214 let n_samples = x.nrows();
215 if n_samples == 0 {
216 return Err(FerroError::InsufficientSamples {
217 required: 1,
218 actual: 0,
219 context: "SelectPercentile::fit".into(),
220 });
221 }
222 if y.len() != n_samples {
223 return Err(FerroError::ShapeMismatch {
224 expected: vec![n_samples],
225 actual: vec![y.len()],
226 context: "SelectPercentile::fit — y must have same length as x rows".into(),
227 });
228 }
229 if self.percentile > 100 {
230 return Err(FerroError::InvalidParameter {
231 name: "percentile".into(),
232 reason: format!("percentile must be in [0, 100], got {}", self.percentile),
233 });
234 }
235
236 let n_features = x.ncols();
237 let raw_scores = match self.score_func {
238 ScoreFunc::FClassif => anova_f_scores(x, y),
239 };
240 let scores = Array1::from_vec(raw_scores.clone());
241
242 let k = (n_features * self.percentile).div_ceil(100);
244 let k = k.min(n_features);
245
246 let mut ranked: Vec<usize> = (0..n_features).collect();
248 ranked.sort_by(|&a, &b| {
249 raw_scores[b]
250 .partial_cmp(&raw_scores[a])
251 .unwrap_or(std::cmp::Ordering::Equal)
252 .then(a.cmp(&b))
253 });
254
255 let mut selected_indices: Vec<usize> = ranked[..k].to_vec();
256 selected_indices.sort_unstable();
257
258 Ok(FittedSelectPercentile {
259 n_features_in: n_features,
260 scores,
261 selected_indices,
262 })
263 }
264}
265
266impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedSelectPercentile<F> {
267 type Output = Array2<F>;
268 type Error = FerroError;
269
270 fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
277 if x.ncols() != self.n_features_in {
278 return Err(FerroError::ShapeMismatch {
279 expected: vec![x.nrows(), self.n_features_in],
280 actual: vec![x.nrows(), x.ncols()],
281 context: "FittedSelectPercentile::transform".into(),
282 });
283 }
284 Ok(select_columns(x, &self.selected_indices))
285 }
286}
287
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    #[test]
    fn test_select_percentile_50_percent() {
        let sel = SelectPercentile::<f64>::new(50, ScoreFunc::FClassif);
        let x = array![
            [1.0, 5.0, 0.1, 0.01],
            [1.0, 6.0, 0.2, 0.02],
            [10.0, 5.0, 0.1, 0.01],
            [10.0, 6.0, 0.2, 0.02]
        ];
        let y: Array1<usize> = array![0, 0, 1, 1];
        let fitted = sel.fit(&x, &y).unwrap();
        let out = fitted.transform(&x).unwrap();
        assert_eq!(out.ncols(), 2);
    }

    #[test]
    fn test_select_percentile_100_percent_keeps_all() {
        let sel = SelectPercentile::<f64>::new(100, ScoreFunc::FClassif);
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let y: Array1<usize> = array![0, 1];
        let fitted = sel.fit(&x, &y).unwrap();
        let out = fitted.transform(&x).unwrap();
        assert_eq!(out.ncols(), 2);
    }

    #[test]
    fn test_select_percentile_selects_highest_scoring() {
        let sel = SelectPercentile::<f64>::new(50, ScoreFunc::FClassif);
        let x = array![[0.0, 5.0], [0.0, 5.5], [10.0, 5.0], [10.0, 5.5]];
        let y: Array1<usize> = array![0, 0, 1, 1];
        let fitted = sel.fit(&x, &y).unwrap();
        assert!(fitted.selected_indices().contains(&0));
    }

    #[test]
    fn test_select_percentile_scores_stored() {
        let sel = SelectPercentile::<f64>::new(50, ScoreFunc::FClassif);
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let y: Array1<usize> = array![0, 1];
        let fitted = sel.fit(&x, &y).unwrap();
        assert_eq!(fitted.scores().len(), 2);
    }

    #[test]
    fn test_select_percentile_zero_rows_error() {
        let sel = SelectPercentile::<f64>::new(50, ScoreFunc::FClassif);
        let x: Array2<f64> = Array2::zeros((0, 3));
        let y: Array1<usize> = Array1::zeros(0);
        assert!(sel.fit(&x, &y).is_err());
    }

    #[test]
    fn test_select_percentile_over_100_error() {
        let sel = SelectPercentile::<f64>::new(150, ScoreFunc::FClassif);
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let y: Array1<usize> = array![0, 1];
        assert!(sel.fit(&x, &y).is_err());
    }

    #[test]
    fn test_select_percentile_y_length_mismatch_error() {
        let sel = SelectPercentile::<f64>::new(50, ScoreFunc::FClassif);
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let y: Array1<usize> = array![0];
        assert!(sel.fit(&x, &y).is_err());
    }

    #[test]
    fn test_select_percentile_shape_mismatch_on_transform() {
        let sel = SelectPercentile::<f64>::new(50, ScoreFunc::FClassif);
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let y: Array1<usize> = array![0, 1];
        let fitted = sel.fit(&x, &y).unwrap();
        let x_bad = array![[1.0, 2.0, 3.0]];
        assert!(fitted.transform(&x_bad).is_err());
    }

    #[test]
    fn test_select_percentile_default() {
        let sel = SelectPercentile::<f64>::default();
        assert_eq!(sel.percentile(), 10);
    }

    #[test]
    fn test_select_percentile_indices_sorted() {
        let sel = SelectPercentile::<f64>::new(50, ScoreFunc::FClassif);
        let x = array![
            [1.0, 100.0, 0.5, 0.01],
            [2.0, 200.0, 0.6, 0.02],
            [10.0, 100.0, 0.5, 0.01],
            [20.0, 200.0, 0.6, 0.02]
        ];
        let y: Array1<usize> = array![0, 0, 1, 1];
        let fitted = sel.fit(&x, &y).unwrap();
        let indices = fitted.selected_indices();
        assert!(indices.windows(2).all(|w| w[0] < w[1]));
    }

    // percentile = 0 is a valid boundary: nothing is kept, but the row count survives.
    #[test]
    fn test_select_percentile_zero_percent_selects_nothing() {
        let sel = SelectPercentile::<f64>::new(0, ScoreFunc::FClassif);
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let y: Array1<usize> = array![0, 1, 1];
        let fitted = sel.fit(&x, &y).unwrap();
        assert!(fitted.selected_indices().is_empty());
        let out = fitted.transform(&x).unwrap();
        assert_eq!(out.dim(), (3, 0));
    }

    // transform must copy the selected column's values, not just produce the right shape.
    #[test]
    fn test_select_percentile_transform_copies_selected_column() {
        let sel = SelectPercentile::<f64>::new(50, ScoreFunc::FClassif);
        let x = array![[0.0, 5.0], [0.0, 5.5], [10.0, 5.0], [10.0, 5.5]];
        let y: Array1<usize> = array![0, 0, 1, 1];
        let fitted = sel.fit(&x, &y).unwrap();
        assert_eq!(fitted.selected_indices().to_vec(), vec![0]);
        let out = fitted.transform(&x).unwrap();
        assert_eq!(out, array![[0.0], [0.0], [10.0], [10.0]]);
    }
}