1use anofox_ml_core::{Float, Result, RustMlError, Transform};
2use ndarray::{Array1, Array2, Axis};
3use std::collections::HashMap;
4
5#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
10pub enum ScoringFunction {
11 FClassif,
17
18 FRegression,
25
26 Variance,
31}
32
/// Univariate feature selector that keeps the `k` features with the highest
/// scores under a [`ScoringFunction`].
///
/// This is the unfitted configuration; [`SelectKBest::fit`] produces a
/// [`FittedSelectKBest`] that can transform data.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct SelectKBest {
    /// Number of features to keep; `fit` rejects `k == 0` and `k` larger than
    /// the number of columns in `X`.
    pub k: usize,
    /// Statistic used to rank the features.
    pub scoring_fn: ScoringFunction,
}
69
70impl SelectKBest {
71 pub fn new(k: usize, scoring_fn: ScoringFunction) -> Self {
74 Self { k, scoring_fn }
75 }
76
77 pub fn fit<F: Float>(&self, x: &Array2<F>, y: &Array1<F>) -> Result<FittedSelectKBest<F>> {
83 let (n_samples, n_features) = x.dim();
84
85 if n_samples == 0 || n_features == 0 {
86 return Err(RustMlError::EmptyInput("input array is empty".into()));
87 }
88
89 if self.k == 0 {
90 return Err(RustMlError::InvalidParameter("k must be at least 1".into()));
91 }
92
93 if self.k > n_features {
94 return Err(RustMlError::InvalidParameter(format!(
95 "k ({}) exceeds number of features ({})",
96 self.k, n_features
97 )));
98 }
99
100 if !matches!(self.scoring_fn, ScoringFunction::Variance) {
102 if y.len() != n_samples {
103 return Err(RustMlError::ShapeMismatch(format!(
104 "X has {} samples but y has {} elements",
105 n_samples,
106 y.len()
107 )));
108 }
109 }
110
111 let scores = match &self.scoring_fn {
112 ScoringFunction::FClassif => compute_f_classif(x, y)?,
113 ScoringFunction::FRegression => compute_f_regression(x, y)?,
114 ScoringFunction::Variance => compute_variance(x),
115 };
116
117 let mut feature_scores: Vec<(usize, F)> = scores.iter().copied().enumerate().collect();
119 feature_scores.sort_by(|a, b| {
120 b.1.partial_cmp(&a.1)
121 .unwrap_or(std::cmp::Ordering::Equal)
122 .then(a.0.cmp(&b.0))
123 });
124
125 let mut selected_indices: Vec<usize> = feature_scores
126 .iter()
127 .take(self.k)
128 .map(|&(idx, _)| idx)
129 .collect();
130 selected_indices.sort_unstable();
132
133 Ok(FittedSelectKBest {
134 scores,
135 selected_indices,
136 n_features_in: n_features,
137 })
138 }
139}
140
/// A fitted [`SelectKBest`]: the score of every input feature plus the column
/// indices chosen for output.
///
/// Produced by [`SelectKBest::fit`]; apply it to data with
/// [`Transform::transform`].
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
// Replaces the bound serde would otherwise infer for `F` on the derived
// `Deserialize` impl with `F: DeserializeOwned`.
#[serde(bound(deserialize = "F: serde::de::DeserializeOwned"))]
pub struct FittedSelectKBest<F: Float> {
    /// Score of every input feature, indexed by column.
    scores: Array1<F>,
    /// Indices of the kept columns; `fit` stores them in ascending order.
    selected_indices: Vec<usize>,
    /// Column count of the training data; `transform` requires the same width.
    n_features_in: usize,
}
153
154impl<F: Float> FittedSelectKBest<F> {
155 pub fn scores(&self) -> &Array1<F> {
157 &self.scores
158 }
159
160 pub fn selected_indices(&self) -> &[usize] {
162 &self.selected_indices
163 }
164}
165
166impl<F: Float> Transform<F> for FittedSelectKBest<F> {
167 fn transform(&self, x: &Array2<F>) -> Result<Array2<F>> {
168 if x.ncols() != self.n_features_in {
169 return Err(RustMlError::ShapeMismatch(format!(
170 "expected {} features, got {}",
171 self.n_features_in,
172 x.ncols()
173 )));
174 }
175
176 let n_rows = x.nrows();
177 let n_selected = self.selected_indices.len();
178 let mut result = Array2::<F>::zeros((n_rows, n_selected));
179
180 for (i, row) in x.rows().into_iter().enumerate() {
181 for (out_j, &src_j) in self.selected_indices.iter().enumerate() {
182 result[[i, out_j]] = row[src_j];
183 }
184 }
185
186 Ok(result)
187 }
188}
189
190fn compute_f_classif<F: Float>(x: &Array2<F>, y: &Array1<F>) -> Result<Array1<F>> {
201 let (n_samples, n_features) = x.dim();
202 let n_f = F::from_usize(n_samples).unwrap();
203
204 let mut label_map: HashMap<u64, usize> = HashMap::new();
206 let mut class_indices: Vec<usize> = Vec::with_capacity(n_samples);
207 for &val in y.iter() {
208 let bits = val.to_f64().unwrap().to_bits();
209 let next_id = label_map.len();
210 let id = *label_map.entry(bits).or_insert(next_id);
211 class_indices.push(id);
212 }
213 let n_classes = label_map.len();
214
215 if n_classes < 2 {
216 return Err(RustMlError::InvalidParameter(
217 "FClassif requires at least 2 classes".into(),
218 ));
219 }
220
221 if n_samples <= n_classes {
222 return Err(RustMlError::InvalidParameter(
223 "not enough samples for FClassif (need more samples than classes)".into(),
224 ));
225 }
226
227 let mut class_counts = vec![0usize; n_classes];
229 for &c in &class_indices {
230 class_counts[c] += 1;
231 }
232
233 let mut scores = Array1::<F>::zeros(n_features);
234
235 for j in 0..n_features {
236 let col = x.column(j);
237
238 let grand_mean = col.sum() / n_f;
240
241 let mut class_sums = vec![F::zero(); n_classes];
243 for (i, &val) in col.iter().enumerate() {
244 class_sums[class_indices[i]] += val;
245 }
246
247 let mut ssb = F::zero();
249 for c in 0..n_classes {
250 let nc = F::from_usize(class_counts[c]).unwrap();
251 let class_mean = class_sums[c] / nc;
252 let diff = class_mean - grand_mean;
253 ssb += nc * diff * diff;
254 }
255
256 let mut ssw = F::zero();
258 for (i, &val) in col.iter().enumerate() {
259 let c = class_indices[i];
260 let nc = F::from_usize(class_counts[c]).unwrap();
261 let class_mean = class_sums[c] / nc;
262 let diff = val - class_mean;
263 ssw += diff * diff;
264 }
265
266 let df_between = F::from_usize(n_classes - 1).unwrap();
268 let df_within = F::from_usize(n_samples - n_classes).unwrap();
269
270 let eps = F::from_f64(1e-15).unwrap();
271 if ssw < eps {
272 scores[j] = if ssb > eps {
275 F::from_f64(1e12).unwrap()
276 } else {
277 F::zero()
278 };
279 } else {
280 let msb = ssb / df_between;
281 let msw = ssw / df_within;
282 scores[j] = msb / msw;
283 }
284 }
285
286 Ok(scores)
287}
288
289fn compute_f_regression<F: Float>(x: &Array2<F>, y: &Array1<F>) -> Result<Array1<F>> {
295 let (n_samples, n_features) = x.dim();
296
297 if n_samples < 3 {
298 return Err(RustMlError::InvalidParameter(
299 "FRegression requires at least 3 samples".into(),
300 ));
301 }
302
303 let n_f = F::from_usize(n_samples).unwrap();
304 let eps = F::from_f64(1e-15).unwrap();
305
306 let y_mean = y.sum() / n_f;
308 let mut y_var = F::zero();
309 for &val in y.iter() {
310 let diff = val - y_mean;
311 y_var += diff * diff;
312 }
313
314 let mut scores = Array1::<F>::zeros(n_features);
315
316 for j in 0..n_features {
317 let col = x.column(j);
318 let x_mean = col.sum() / n_f;
319
320 let mut cov_xy = F::zero();
321 let mut x_var = F::zero();
322 for (&xv, &yv) in col.iter().zip(y.iter()) {
323 let dx = xv - x_mean;
324 let dy = yv - y_mean;
325 cov_xy += dx * dy;
326 x_var += dx * dx;
327 }
328
329 if x_var < eps || y_var < eps {
330 scores[j] = F::zero();
331 continue;
332 }
333
334 let r = cov_xy / (x_var.sqrt() * y_var.sqrt());
335 let r2 = r * r;
336
337 let one = F::one();
338 let denom = one - r2;
339 if denom < eps {
340 scores[j] = F::from_f64(1e12).unwrap();
342 } else {
343 let n_minus_2 = F::from_usize(n_samples - 2).unwrap();
344 scores[j] = r2 * n_minus_2 / denom;
345 }
346 }
347
348 Ok(scores)
349}
350
351fn compute_variance<F: Float>(x: &Array2<F>) -> Array1<F> {
353 let n = F::from_usize(x.nrows()).unwrap();
354 let mean = x.sum_axis(Axis(0)) / n;
355 let n_features = x.ncols();
356
357 let mut variances = Array1::<F>::zeros(n_features);
358 for row in x.rows() {
359 for (j, (&val, &m)) in row.iter().zip(mean.iter()).enumerate() {
360 let diff = val - m;
361 variances[j] += diff * diff;
362 }
363 }
364 variances.mapv_inplace(|v| v / n);
365 variances
366}
367
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Feature 0 separates the two classes perfectly; feature 1 is noise.
    #[test]
    fn test_f_classif_selects_discriminative_feature() {
        let x = array![
            [0.0, 0.5],
            [0.0, 0.8],
            [0.0, 0.2],
            [0.0, 0.9],
            [1.0, 0.3],
            [1.0, 0.7],
            [1.0, 0.1],
            [1.0, 0.6],
        ];
        let y = array![0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0];

        let selector = SelectKBest::new(1, ScoringFunction::FClassif);
        let fitted = selector.fit(&x, &y).unwrap();

        assert_eq!(fitted.selected_indices(), &[0]);
        assert!(
            fitted.scores()[0] > fitted.scores()[1],
            "discriminative feature score ({}) should exceed noise ({})",
            fitted.scores()[0],
            fitted.scores()[1]
        );
    }

    /// A feature linearly related to `y` beats a constant one, which scores 0.
    #[test]
    fn test_f_regression_selects_correlated_feature() {
        let x = array![[1.0, 5.0], [2.0, 5.0], [3.0, 5.0], [4.0, 5.0], [5.0, 5.0]];
        let y = array![2.0, 4.0, 6.0, 8.0, 10.0];

        let selector = SelectKBest::new(1, ScoringFunction::FRegression);
        let fitted = selector.fit(&x, &y).unwrap();

        assert_eq!(fitted.selected_indices(), &[0]);
        assert!(fitted.scores()[0] > 100.0_f64);
        assert!(fitted.scores()[1].abs() < 1e-10_f64);
    }

    /// `Variance` picks the column with the largest spread.
    #[test]
    fn test_variance_scoring_selects_high_variance_feature() {
        let x = array![
            [1.0, 10.0, 5.0],
            [1.1, 20.0, 5.0],
            [0.9, 30.0, 5.0],
            [1.0, 40.0, 5.0],
        ];
        let y = array![0.0, 0.0, 0.0, 0.0];

        let selector = SelectKBest::new(1, ScoringFunction::Variance);
        let fitted = selector.fit(&x, &y).unwrap();

        assert_eq!(fitted.selected_indices(), &[1]);
    }

    /// `Variance` is unsupervised, so `fit` does not check `y` against `x`.
    #[test]
    fn test_variance_ignores_y_length() {
        let x = array![[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]];
        let y = Array1::<f64>::zeros(0);

        let selector = SelectKBest::new(1, ScoringFunction::Variance);
        let fitted = selector.fit(&x, &y).unwrap();

        assert_eq!(fitted.selected_indices(), &[1]);
    }

    /// Selected indices come back in ascending column order, not score order.
    #[test]
    fn test_selected_indices_are_sorted() {
        let x = array![[0.0, 0.0, 0.0], [2.0, 1.0, 10.0], [4.0, 2.0, 20.0]];
        let y = array![0.0, 0.0, 0.0];

        let selector = SelectKBest::new(2, ScoringFunction::Variance);
        let fitted = selector.fit(&x, &y).unwrap();

        assert_eq!(fitted.selected_indices(), &[0, 2]);
    }

    /// Equal scores are broken in favour of the lower column index.
    #[test]
    fn test_ties_prefer_lower_index() {
        let x = array![[1.0, 1.0, 0.0], [3.0, 3.0, 0.5], [5.0, 5.0, 1.0]];
        let y = array![0.0, 0.0, 0.0];

        let selector = SelectKBest::new(1, ScoringFunction::Variance);
        let fitted = selector.fit(&x, &y).unwrap();

        assert_eq!(fitted.selected_indices(), &[0]);
    }

    /// Output column `j` is a copy of input column `selected_indices[j]`.
    #[test]
    fn test_transform_outputs_correct_columns() {
        let x = array![[10.0, 20.0, 30.0], [40.0, 50.0, 60.0], [70.0, 80.0, 90.0]];
        let y = array![1.0, 2.0, 3.0];

        let selector = SelectKBest::new(2, ScoringFunction::FRegression);
        let fitted = selector.fit(&x, &y).unwrap();
        let result = fitted.transform(&x).unwrap();

        assert_eq!(result.nrows(), 3);
        assert_eq!(result.ncols(), 2);

        for (out_pos, &idx) in fitted.selected_indices().iter().enumerate() {
            assert_eq!(
                x.column(idx),
                result.column(out_pos),
                "output column {} should equal input column {}",
                out_pos,
                idx
            );
        }
    }

    #[test]
    fn test_error_k_zero() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let y = array![0.0, 1.0];

        let selector = SelectKBest::new(0, ScoringFunction::FClassif);
        let result = selector.fit(&x, &y);
        assert!(result.is_err());
    }

    #[test]
    fn test_error_k_exceeds_features() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let y = array![0.0, 1.0];

        let selector = SelectKBest::new(5, ScoringFunction::FClassif);
        let result = selector.fit(&x, &y);
        assert!(result.is_err());
        match result.unwrap_err() {
            RustMlError::InvalidParameter(msg) => {
                assert!(msg.contains("exceeds"), "unexpected message: {}", msg);
            }
            other => panic!("expected InvalidParameter, got {:?}", other),
        }
    }

    #[test]
    fn test_error_shape_mismatch_x_y() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let y = array![0.0, 1.0, 2.0];

        let selector = SelectKBest::new(1, ScoringFunction::FClassif);
        let result = selector.fit(&x, &y);
        assert!(result.is_err());
        match result.unwrap_err() {
            RustMlError::ShapeMismatch(msg) => {
                assert!(msg.contains("samples"), "unexpected message: {}", msg);
            }
            other => panic!("expected ShapeMismatch, got {:?}", other),
        }
    }

    /// A single distinct label cannot define an ANOVA.
    #[test]
    fn test_f_classif_requires_two_classes() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let y = array![1.0, 1.0, 1.0];

        let selector = SelectKBest::new(1, ScoringFunction::FClassif);
        match selector.fit(&x, &y).unwrap_err() {
            RustMlError::InvalidParameter(msg) => {
                assert!(msg.contains("2 classes"), "unexpected message: {}", msg);
            }
            other => panic!("expected InvalidParameter, got {:?}", other),
        }
    }

    /// The regression F statistic needs `n - 2 > 0` degrees of freedom.
    #[test]
    fn test_f_regression_requires_three_samples() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let y = array![1.0, 2.0];

        let selector = SelectKBest::new(1, ScoringFunction::FRegression);
        let result = selector.fit(&x, &y);
        assert!(matches!(result, Err(RustMlError::InvalidParameter(_))));
    }

    #[test]
    fn test_error_on_empty_input() {
        let x = Array2::<f64>::zeros((0, 3));
        let y = Array1::<f64>::zeros(0);

        let selector = SelectKBest::new(1, ScoringFunction::FRegression);
        let result = selector.fit(&x, &y);
        assert!(result.is_err());
    }

    #[test]
    fn test_shape_mismatch_on_transform() {
        let x = array![
            [1.0, 2.0, 3.0],
            [4.0, 5.0, 6.0],
            [7.0, 8.0, 9.0],
            [10.0, 11.0, 12.0],
        ];
        let y = array![0.0, 0.0, 1.0, 1.0];

        let selector = SelectKBest::new(1, ScoringFunction::FClassif);
        let fitted = selector.fit(&x, &y).unwrap();

        let wrong = array![[1.0, 2.0]];
        assert!(fitted.transform(&wrong).is_err());
    }

    #[test]
    fn test_selects_all_when_k_equals_n_features() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let y = array![1.0, 2.0, 3.0];

        let selector = SelectKBest::new(2, ScoringFunction::FRegression);
        let fitted = selector.fit(&x, &y).unwrap();

        assert_eq!(fitted.selected_indices().len(), 2);
        assert_eq!(fitted.selected_indices(), &[0, 1]);
    }

    #[test]
    fn test_works_with_f32() {
        let x: Array2<f32> = array![[0.0_f32, 0.5], [0.0, 0.8], [1.0, 0.3], [1.0, 0.7]];
        let y: Array1<f32> = array![0.0_f32, 0.0, 1.0, 1.0];

        let selector = SelectKBest::new(1, ScoringFunction::FClassif);
        let fitted = selector.fit(&x, &y).unwrap();

        assert_eq!(fitted.selected_indices().len(), 1);
        let result = fitted.transform(&x).unwrap();
        assert_eq!(result.ncols(), 1);
    }
}