Skip to main content

ferrolearn_preprocess/
feature_scoring.rs

1//! Feature scoring functions for feature selection.
2//!
3//! This module provides standalone univariate scoring functions that compute
4//! per-feature statistics and p-values:
5//!
6//! - [`f_classif`] — ANOVA F-statistic for classification.
7//! - [`f_regression`] — univariate F-statistic via Pearson correlation.
8//! - [`chi2`] — chi-squared statistic for non-negative features.
9//!
10//! These functions return `(F-statistics, p-values)` tuples and can be used
11//! directly or passed to [`SelectKBest`](crate::feature_selection::SelectKBest)
12//! / [`SelectPercentile`](crate::select_percentile::SelectPercentile) via the
13//! [`ScoreFunc`](crate::feature_selection::ScoreFunc) enum.
14
15use ferrolearn_core::error::FerroError;
16use ndarray::{Array1, Array2};
17use num_traits::Float;
18
19// ===========================================================================
20// f_classif — ANOVA F-statistic
21// ===========================================================================
22
23/// Compute the ANOVA F-statistic and approximate p-values for each feature.
24///
25/// For each feature column the between-class and within-class sum of squares
26/// are computed. The F-statistic is:
27///
28/// ```text
29/// F = (SSB / (k - 1)) / (SSW / (n - k))
30/// ```
31///
32/// where *k* is the number of distinct classes and *n* is the number of
33/// samples.
34///
35/// P-values are approximated using the regularized incomplete beta function
36/// from `ferrolearn-numerical`. If the numerical CDF is unavailable, `NaN`
37/// is returned for the p-value.
38///
39/// # Returns
40///
41/// `(f_statistics, p_values)` — two `Array1<F>` of length `n_features`.
42///
43/// # Errors
44///
45/// - [`FerroError::InsufficientSamples`] if `x` has zero rows.
46/// - [`FerroError::ShapeMismatch`] if `x.nrows() != y.len()`.
47/// - [`FerroError::InvalidParameter`] if fewer than 2 classes are present.
48///
49/// # Examples
50///
51/// ```
52/// use ferrolearn_preprocess::feature_scoring::f_classif;
53/// use ndarray::{array, Array1};
54///
55/// let x = array![[1.0_f64, 100.0], [2.0, 200.0], [10.0, 100.0], [11.0, 200.0]];
56/// let y: Array1<usize> = array![0, 0, 1, 1];
57/// let (f_stats, p_vals) = f_classif(&x, &y).unwrap();
58/// assert_eq!(f_stats.len(), 2);
59/// // Feature 0 separates classes well → high F
60/// assert!(f_stats[0] > f_stats[1]);
61/// ```
62pub fn f_classif<F: Float + Send + Sync + 'static>(
63    x: &Array2<F>,
64    y: &Array1<usize>,
65) -> Result<(Array1<F>, Array1<F>), FerroError> {
66    let n_samples = x.nrows();
67    if n_samples == 0 {
68        return Err(FerroError::InsufficientSamples {
69            required: 1,
70            actual: 0,
71            context: "f_classif".into(),
72        });
73    }
74    if y.len() != n_samples {
75        return Err(FerroError::ShapeMismatch {
76            expected: vec![n_samples],
77            actual: vec![y.len()],
78            context: "f_classif — y must have same length as x rows".into(),
79        });
80    }
81
82    // Collect per-class row indices
83    let mut class_indices: std::collections::HashMap<usize, Vec<usize>> =
84        std::collections::HashMap::new();
85    for (i, &label) in y.iter().enumerate() {
86        class_indices.entry(label).or_default().push(i);
87    }
88    let n_classes = class_indices.len();
89    if n_classes < 2 {
90        return Err(FerroError::InvalidParameter {
91            name: "y".into(),
92            reason: format!("f_classif requires at least 2 classes, got {n_classes}"),
93        });
94    }
95
96    let n_features = x.ncols();
97    let n_f = F::from(n_samples).unwrap();
98
99    let df_between = n_classes - 1;
100    let df_within = n_samples - n_classes;
101    let df_b = F::from(df_between).unwrap();
102    let df_w = F::from(df_within).unwrap();
103
104    let mut f_stats = Array1::zeros(n_features);
105    let mut p_vals = Array1::zeros(n_features);
106
107    for j in 0..n_features {
108        let col = x.column(j);
109        let grand_mean = col.iter().copied().fold(F::zero(), |acc, v| acc + v) / n_f;
110
111        let mut ss_between = F::zero();
112        let mut ss_within = F::zero();
113
114        for rows in class_indices.values() {
115            let n_k = F::from(rows.len()).unwrap();
116            let class_mean = rows
117                .iter()
118                .map(|&i| col[i])
119                .fold(F::zero(), |acc, v| acc + v)
120                / n_k;
121            let diff = class_mean - grand_mean;
122            ss_between = ss_between + n_k * diff * diff;
123            for &i in rows {
124                let d = col[i] - class_mean;
125                ss_within = ss_within + d * d;
126            }
127        }
128
129        let f = if df_w == F::zero() {
130            F::zero()
131        } else {
132            let ms_between = ss_between / df_b;
133            let ms_within = ss_within / df_w;
134            if ms_within == F::zero() {
135                F::infinity()
136            } else {
137                ms_between / ms_within
138            }
139        };
140
141        f_stats[j] = f;
142        p_vals[j] = f_distribution_sf(f, df_between, df_within);
143    }
144
145    Ok((f_stats, p_vals))
146}
147
148// ===========================================================================
149// f_regression — Pearson correlation-based F-statistic
150// ===========================================================================
151
152/// Compute univariate F-statistics via Pearson correlation for regression.
153///
154/// For each feature the Pearson correlation coefficient *r* with the target
155/// is computed, then:
156///
157/// ```text
158/// F = r^2 * (n - 2) / (1 - r^2)
159/// ```
160///
161/// # Returns
162///
163/// `(f_statistics, p_values)` — two `Array1<F>` of length `n_features`.
164///
165/// # Errors
166///
167/// - [`FerroError::InsufficientSamples`] if `x` has fewer than 3 rows.
168/// - [`FerroError::ShapeMismatch`] if `x.nrows() != y.len()`.
169///
170/// # Examples
171///
172/// ```
173/// use ferrolearn_preprocess::feature_scoring::f_regression;
174/// use ndarray::{array, Array1};
175///
176/// let x = array![[1.0_f64, 100.0], [2.0, 200.0], [3.0, 100.0], [4.0, 200.0]];
177/// let y: Array1<f64> = array![1.0, 2.0, 3.0, 4.0];
178/// let (f_stats, _p_vals) = f_regression(&x, &y).unwrap();
179/// assert_eq!(f_stats.len(), 2);
180/// ```
181pub fn f_regression<F: Float + Send + Sync + 'static>(
182    x: &Array2<F>,
183    y: &Array1<F>,
184) -> Result<(Array1<F>, Array1<F>), FerroError> {
185    let n_samples = x.nrows();
186    if n_samples < 3 {
187        return Err(FerroError::InsufficientSamples {
188            required: 3,
189            actual: n_samples,
190            context: "f_regression requires at least 3 samples".into(),
191        });
192    }
193    if y.len() != n_samples {
194        return Err(FerroError::ShapeMismatch {
195            expected: vec![n_samples],
196            actual: vec![y.len()],
197            context: "f_regression — y must have same length as x rows".into(),
198        });
199    }
200
201    let n_f = F::from(n_samples).unwrap();
202    let n_features = x.ncols();
203
204    // Precompute y stats
205    let y_mean = y.iter().copied().fold(F::zero(), |acc, v| acc + v) / n_f;
206    let y_var = y
207        .iter()
208        .copied()
209        .map(|v| (v - y_mean) * (v - y_mean))
210        .fold(F::zero(), |acc, v| acc + v);
211
212    let two = F::from(2.0).unwrap();
213
214    let mut f_stats = Array1::zeros(n_features);
215    let mut p_vals = Array1::zeros(n_features);
216
217    for j in 0..n_features {
218        let col = x.column(j);
219        let x_mean = col.iter().copied().fold(F::zero(), |acc, v| acc + v) / n_f;
220        let x_var = col
221            .iter()
222            .copied()
223            .map(|v| (v - x_mean) * (v - x_mean))
224            .fold(F::zero(), |acc, v| acc + v);
225
226        let cov = col
227            .iter()
228            .copied()
229            .zip(y.iter().copied())
230            .map(|(xi, yi)| (xi - x_mean) * (yi - y_mean))
231            .fold(F::zero(), |acc, v| acc + v);
232
233        let denom = x_var * y_var;
234        let r = if denom == F::zero() {
235            F::zero()
236        } else {
237            cov / denom.sqrt()
238        };
239
240        let r2 = r * r;
241        let f = if r2 >= F::one() {
242            F::infinity()
243        } else {
244            r2 * (n_f - two) / (F::one() - r2)
245        };
246
247        f_stats[j] = f;
248        // F-distribution with df1=1, df2=n-2
249        p_vals[j] = f_distribution_sf(f, 1, n_samples - 2);
250    }
251
252    Ok((f_stats, p_vals))
253}
254
255// ===========================================================================
256// chi2 — Chi-squared statistic
257// ===========================================================================
258
259/// Compute chi-squared statistics between each non-negative feature and the
260/// class labels.
261///
262/// For each feature the observed and expected frequencies per class are
263/// computed, then:
264///
265/// ```text
266/// chi2 = sum_class (observed - expected)^2 / expected
267/// ```
268///
269/// where `observed` is the sum of feature values for samples of that class,
270/// and `expected` is the expected sum under the null hypothesis (proportional
271/// to the class frequency and the overall feature sum).
272///
273/// # Returns
274///
275/// `(chi2_statistics, p_values)` — two `Array1<F>` of length `n_features`.
276///
277/// # Errors
278///
279/// - [`FerroError::InsufficientSamples`] if `x` has zero rows.
280/// - [`FerroError::ShapeMismatch`] if `x.nrows() != y.len()`.
281/// - [`FerroError::InvalidParameter`] if any feature value is negative.
282///
283/// # Examples
284///
285/// ```
286/// use ferrolearn_preprocess::feature_scoring::chi2;
287/// use ndarray::{array, Array1};
288///
289/// let x = array![[1.0_f64, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]];
290/// let y: Array1<usize> = array![0, 1, 0, 1];
291/// let (chi2_stats, _p_vals) = chi2(&x, &y).unwrap();
292/// assert_eq!(chi2_stats.len(), 2);
293/// ```
294pub fn chi2<F: Float + Send + Sync + 'static>(
295    x: &Array2<F>,
296    y: &Array1<usize>,
297) -> Result<(Array1<F>, Array1<F>), FerroError> {
298    let n_samples = x.nrows();
299    if n_samples == 0 {
300        return Err(FerroError::InsufficientSamples {
301            required: 1,
302            actual: 0,
303            context: "chi2".into(),
304        });
305    }
306    if y.len() != n_samples {
307        return Err(FerroError::ShapeMismatch {
308            expected: vec![n_samples],
309            actual: vec![y.len()],
310            context: "chi2 — y must have same length as x rows".into(),
311        });
312    }
313
314    // Validate non-negative
315    for j in 0..x.ncols() {
316        for i in 0..n_samples {
317            if x[[i, j]] < F::zero() {
318                return Err(FerroError::InvalidParameter {
319                    name: "x".into(),
320                    reason: format!(
321                        "chi2 requires non-negative features, found negative value at ({i}, {j})"
322                    ),
323                });
324            }
325        }
326    }
327
328    // Collect per-class row indices
329    let mut class_indices: std::collections::HashMap<usize, Vec<usize>> =
330        std::collections::HashMap::new();
331    for (i, &label) in y.iter().enumerate() {
332        class_indices.entry(label).or_default().push(i);
333    }
334
335    let n_classes = class_indices.len();
336    let n_features = x.ncols();
337    let n_f = F::from(n_samples).unwrap();
338
339    let mut chi2_stats = Array1::zeros(n_features);
340    let mut p_vals = Array1::zeros(n_features);
341
342    for j in 0..n_features {
343        let col = x.column(j);
344        let total_sum = col.iter().copied().fold(F::zero(), |acc, v| acc + v);
345
346        if total_sum == F::zero() {
347            // All zero → chi2 = 0, p = 1
348            chi2_stats[j] = F::zero();
349            p_vals[j] = F::one();
350            continue;
351        }
352
353        let mut chi2_val = F::zero();
354
355        for rows in class_indices.values() {
356            let n_k = F::from(rows.len()).unwrap();
357            let observed = rows
358                .iter()
359                .map(|&i| col[i])
360                .fold(F::zero(), |acc, v| acc + v);
361            let expected = total_sum * n_k / n_f;
362
363            if expected > F::zero() {
364                let diff = observed - expected;
365                chi2_val = chi2_val + diff * diff / expected;
366            }
367        }
368
369        chi2_stats[j] = chi2_val;
370        // Chi-squared distribution with df = n_classes - 1
371        let df = n_classes.saturating_sub(1);
372        p_vals[j] = chi2_distribution_sf(chi2_val, df);
373    }
374
375    Ok((chi2_stats, p_vals))
376}
377
378// ===========================================================================
// Distribution helpers: F / chi-squared survival functions (1 - CDF) and the special functions they use
380// ===========================================================================
381
382/// Approximate the survival function (1 - CDF) of the F-distribution.
383///
384/// Uses the relationship between the F-distribution and the regularized
385/// incomplete beta function:
386///
387/// ```text
388/// P(F > x) = I_{d2/(d2 + d1*x)}(d2/2, d1/2)
389/// ```
390///
391/// Returns `NaN` if the computation cannot be performed.
392fn f_distribution_sf<F: Float>(x: F, df1: usize, df2: usize) -> F {
393    if x <= F::zero() {
394        return F::one();
395    }
396    if df1 == 0 || df2 == 0 {
397        return F::nan();
398    }
399
400    let d1 = F::from(df1).unwrap();
401    let d2 = F::from(df2).unwrap();
402
403    // I_{d2/(d2 + d1*x)}(d2/2, d1/2)
404    let z = d2 / (d2 + d1 * x);
405    let a = d2 / F::from(2.0).unwrap();
406    let b = d1 / F::from(2.0).unwrap();
407
408    regularized_incomplete_beta(z, a, b)
409}
410
411/// Approximate the survival function (1 - CDF) of the chi-squared distribution.
412///
413/// Uses the relationship: chi2 with k df = Gamma(k/2, 2), and
414/// P(X > x) = 1 - gamma_cdf = upper regularized gamma Q(k/2, x/2).
415///
416/// We use the relationship to the regularized incomplete beta function:
417/// Q(a, x) = I_{x/(x+a)}(... ) — but more simply, chi2 with k df is
418/// equivalent to F(k, inf) scaled. We use a direct series approximation.
419fn chi2_distribution_sf<F: Float>(x: F, df: usize) -> F {
420    if x <= F::zero() {
421        return F::one();
422    }
423    if df == 0 {
424        return F::nan();
425    }
426
427    // Use the upper regularized gamma function Q(k/2, x/2)
428    let a = F::from(df).unwrap() / F::from(2.0).unwrap();
429    let z = x / F::from(2.0).unwrap();
430
431    upper_regularized_gamma(a, z)
432}
433
434/// Upper regularized gamma function Q(a, x) = 1 - P(a, x).
435///
436/// Uses a continued fraction expansion for x >= a + 1, and the series
437/// expansion otherwise.
438fn upper_regularized_gamma<F: Float>(a: F, x: F) -> F {
439    if x <= F::zero() {
440        return F::one();
441    }
442
443    let one = F::one();
444    let two = F::from(2.0).unwrap();
445
446    // Use series for P(a, x) when x < a + 1, then Q = 1 - P
447    if x < a + one {
448        let p = lower_regularized_gamma_series(a, x);
449        return one - p;
450    }
451
452    // Continued fraction for Q(a, x) — Lentz's method
453    let eps = F::from(1.0e-12).unwrap();
454    let tiny = F::from(1.0e-30).unwrap();
455
456    let mut c = tiny;
457    let mut d = F::one() / (x + one - a);
458    let mut f = d;
459
460    for n_iter in 1..200 {
461        let n = F::from(n_iter).unwrap();
462        // Even term
463        let an_even = n * (a - n);
464        let bn_even = x + two * n + one - a;
465        d = F::one() / (bn_even + an_even * d);
466        c = bn_even + an_even / c;
467        let delta = c * d;
468        f = f * delta;
469
470        if (delta - one).abs() < eps {
471            break;
472        }
473    }
474
475    // Q(a, x) = e^(-x) * x^a / Gamma(a) * f
476    let log_prefix = a * x.ln() - x - ln_gamma(a);
477    let prefix = log_prefix.exp();
478    let result = prefix * f;
479
480    // Clamp to [0, 1]
481    if result < F::zero() {
482        F::zero()
483    } else if result > one {
484        one
485    } else {
486        result
487    }
488}
489
490/// Lower regularized gamma function P(a, x) via series expansion.
491fn lower_regularized_gamma_series<F: Float>(a: F, x: F) -> F {
492    let eps = F::from(1.0e-12).unwrap();
493    let one = F::one();
494
495    let mut sum = one / a;
496    let mut term = one / a;
497
498    for n in 1..200 {
499        let n_f = F::from(n).unwrap();
500        term = term * x / (a + n_f);
501        sum = sum + term;
502        if term.abs() < eps * sum.abs() {
503            break;
504        }
505    }
506
507    let log_prefix = a * x.ln() - x - ln_gamma(a);
508    let result = log_prefix.exp() * sum;
509
510    // Clamp to [0, 1]
511    if result < F::zero() {
512        F::zero()
513    } else if result > one {
514        one
515    } else {
516        result
517    }
518}
519
520/// Regularized incomplete beta function I_x(a, b) using a continued fraction
521/// (Lentz's method).
522fn regularized_incomplete_beta<F: Float>(x: F, a: F, b: F) -> F {
523    let one = F::one();
524    let two = F::from(2.0).unwrap();
525
526    if x <= F::zero() {
527        return F::zero();
528    }
529    if x >= one {
530        return one;
531    }
532
533    // Use the symmetry relation if x > (a+1)/(a+b+2) for better convergence
534    if x > (a + one) / (a + b + two) {
535        return one - regularized_incomplete_beta(one - x, b, a);
536    }
537
538    // Prefix: x^a * (1-x)^b / (a * Beta(a,b))
539    let log_prefix = a * x.ln() + b * (one - x).ln() - ln_beta(a, b) - a.ln();
540    let prefix = log_prefix.exp();
541
542    // Continued fraction (Lentz's algorithm)
543    let eps = F::from(1.0e-12).unwrap();
544    let tiny = F::from(1.0e-30).unwrap();
545
546    let mut f = tiny;
547    let mut c = tiny;
548    let mut d = one;
549
550    for m in 0..200 {
551        let m_f = F::from(m).unwrap();
552        let (a_m, b_m) = if m == 0 {
553            (one, one)
554        } else if m % 2 == 0 {
555            // Even: d_{2m} term
556            let k = m_f / two;
557            let num = k * (b - k) * x / ((a + two * k - one) * (a + two * k));
558            (num, one)
559        } else {
560            // Odd: d_{2m+1} term
561            let k = (m_f - one) / two;
562            let num =
563                -((a + k) * (a + b + k) * x) / ((a + two * k) * (a + two * k + one));
564            (num, one)
565        };
566
567        if m == 0 {
568            f = b_m;
569            c = b_m;
570            d = one / b_m;
571            continue;
572        }
573
574        d = b_m + a_m * d;
575        if d.abs() < tiny {
576            d = tiny;
577        }
578        d = one / d;
579
580        c = b_m + a_m / c;
581        if c.abs() < tiny {
582            c = tiny;
583        }
584
585        let delta = c * d;
586        f = f * delta;
587
588        if (delta - one).abs() < eps {
589            break;
590        }
591    }
592
593    let result = prefix * f;
594    if result < F::zero() {
595        F::zero()
596    } else if result > one {
597        one
598    } else {
599        result
600    }
601}
602
603/// Log of the beta function: ln(Beta(a, b)) = lnGamma(a) + lnGamma(b) - lnGamma(a+b).
604fn ln_beta<F: Float>(a: F, b: F) -> F {
605    ln_gamma(a) + ln_gamma(b) - ln_gamma(a + b)
606}
607
608/// Lanczos approximation of ln(Gamma(x)) for x > 0.
609fn ln_gamma<F: Float>(x: F) -> F {
610    // Lanczos coefficients (g=7, n=9)
611    let coefs: [f64; 9] = [
612        0.999_999_999_999_809_9,
613        676.520_368_121_885_1,
614        -1_259.139_216_722_402_8,
615        771.323_428_777_653_1,
616        -176.615_029_162_140_6,
617        12.507_343_278_686_905,
618        -0.138_571_095_265_720_12,
619        9.984_369_578_019_572e-6,
620        1.505_632_735_149_311_6e-7,
621    ];
622
623    let one = F::one();
624    let half = F::from(0.5).unwrap();
625    let g = F::from(7.0).unwrap();
626
627    if x < half {
628        // Reflection formula
629        let pi = F::from(std::f64::consts::PI).unwrap();
630        return pi.ln() - (pi * x).sin().ln() - ln_gamma(one - x);
631    }
632
633    let z = x - one;
634    let mut sum = F::from(coefs[0]).unwrap();
635    for (i, &c) in coefs.iter().enumerate().skip(1) {
636        sum = sum + F::from(c).unwrap() / (z + F::from(i).unwrap());
637    }
638
639    let t = z + g + half;
640    let sqrt_2pi = F::from(2.506_628_274_631_000_5).unwrap();
641
642    sqrt_2pi.ln() + (z + half) * t.ln() - t + sum.ln()
643}
644
645// ===========================================================================
646// ScoreFunc integration
647// ===========================================================================
648
649/// Add `FRegression` and `Chi2` variants to `ScoreFunc`.
650///
651/// This cannot extend the existing enum directly, so we provide adapter
652/// functions that compute scores in the format expected by `SelectKBest`.
653///
654/// Compute scores for the given score function name, returning F-scores only.
655///
656/// This is a convenience dispatcher for integration with feature selection.
657pub fn compute_scores_classif<F: Float + Send + Sync + 'static>(
658    x: &Array2<F>,
659    y: &Array1<usize>,
660    func: &str,
661) -> Result<Vec<F>, FerroError> {
662    match func {
663        "f_classif" => {
664            let (f_stats, _) = f_classif(x, y)?;
665            Ok(f_stats.to_vec())
666        }
667        "chi2" => {
668            let (chi2_stats, _) = chi2(x, y)?;
669            Ok(chi2_stats.to_vec())
670        }
671        _ => Err(FerroError::InvalidParameter {
672            name: "func".into(),
673            reason: format!("unknown classification score function: {func}"),
674        }),
675    }
676}
677
678/// Compute regression scores, returning F-statistics only.
679pub fn compute_scores_regression<F: Float + Send + Sync + 'static>(
680    x: &Array2<F>,
681    y: &Array1<F>,
682) -> Result<Vec<F>, FerroError> {
683    let (f_stats, _) = f_regression(x, y)?;
684    Ok(f_stats.to_vec())
685}
686
687// ===========================================================================
688// Tests
689// ===========================================================================
690
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    // -----------------------------------------------------------------------
    // f_classif tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_f_classif_basic() {
        // Feature 0 separates classes well, feature 1 does not
        let x = array![
            [1.0_f64, 5.0],
            [1.5, 5.5],
            [2.0, 4.5],
            [10.0, 5.0],
            [10.5, 4.5],
            [11.0, 5.5]
        ];
        let y: Array1<usize> = array![0, 0, 0, 1, 1, 1];
        let (f_stats, p_vals) = f_classif(&x, &y).unwrap();
        assert_eq!(f_stats.len(), 2);
        assert_eq!(p_vals.len(), 2);
        // Feature 0 should have much higher F than feature 1
        assert!(f_stats[0] > f_stats[1]);
        // p-value for feature 0 should be very small
        assert!(p_vals[0] < 0.05);
    }

    #[test]
    fn test_f_classif_empty_input() {
        let x = Array2::<f64>::zeros((0, 2));
        let y: Array1<usize> = Array1::zeros(0);
        assert!(f_classif(&x, &y).is_err());
    }

    #[test]
    fn test_f_classif_shape_mismatch() {
        let x = array![[1.0_f64, 2.0], [3.0, 4.0]];
        let y: Array1<usize> = array![0, 1, 2]; // wrong length
        assert!(f_classif(&x, &y).is_err());
    }

    #[test]
    fn test_f_classif_single_class_error() {
        let x = array![[1.0_f64], [2.0], [3.0]];
        let y: Array1<usize> = array![0, 0, 0];
        assert!(f_classif(&x, &y).is_err());
    }

    #[test]
    fn test_f_classif_perfect_separation() {
        // Feature perfectly separates classes → infinite F
        let x = array![[0.0_f64], [0.0], [1.0], [1.0]];
        let y: Array1<usize> = array![0, 0, 1, 1];
        let (f_stats, _) = f_classif(&x, &y).unwrap();
        assert!(f_stats[0].is_infinite());
    }

    #[test]
    fn test_f_classif_p_values_bounded() {
        let x = array![
            [1.0_f64, 10.0],
            [2.0, 20.0],
            [3.0, 10.0],
            [4.0, 20.0],
            [5.0, 10.0],
            [6.0, 20.0]
        ];
        let y: Array1<usize> = array![0, 0, 0, 1, 1, 1];
        let (_, p_vals) = f_classif(&x, &y).unwrap();
        for &p in p_vals.iter() {
            assert!((0.0..=1.0).contains(&p), "p-value {p} out of bounds");
        }
    }

    // -----------------------------------------------------------------------
    // f_regression tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_f_regression_perfect_correlation() {
        // Both features are exact linear functions of the target → r=1 → F=infinity
        let x = array![[1.0_f64, 10.0], [2.0, 20.0], [3.0, 30.0], [4.0, 40.0]];
        let y: Array1<f64> = array![1.0, 2.0, 3.0, 4.0];
        let (f_stats, _) = f_regression(&x, &y).unwrap();
        assert!(f_stats[0].is_infinite() || f_stats[0] > 1.0e6);
    }

    #[test]
    fn test_f_regression_no_correlation() {
        // Orthogonal feature → r≈0 → F≈0
        let x = array![[1.0_f64], [-1.0], [1.0], [-1.0]];
        let y: Array1<f64> = array![1.0, 1.0, -1.0, -1.0];
        let (f_stats, _) = f_regression(&x, &y).unwrap();
        assert!(f_stats[0].abs() < 1.0e-6);
    }

    #[test]
    fn test_f_regression_too_few_samples() {
        let x = array![[1.0_f64], [2.0]];
        let y: Array1<f64> = array![1.0, 2.0];
        assert!(f_regression(&x, &y).is_err());
    }

    #[test]
    fn test_f_regression_shape_mismatch() {
        let x = array![[1.0_f64], [2.0], [3.0]];
        let y: Array1<f64> = array![1.0, 2.0]; // wrong length
        assert!(f_regression(&x, &y).is_err());
    }

    #[test]
    fn test_f_regression_p_values_bounded() {
        let x = array![
            [1.0_f64, 10.0],
            [2.0, 20.0],
            [3.0, 15.0],
            [4.0, 25.0],
            [5.0, 10.0]
        ];
        let y: Array1<f64> = array![1.0, 2.0, 3.0, 4.0, 5.0];
        let (_, p_vals) = f_regression(&x, &y).unwrap();
        for &p in p_vals.iter() {
            assert!((0.0..=1.0).contains(&p), "p-value {p} out of bounds");
        }
    }

    #[test]
    fn test_f_regression_constant_feature() {
        // Constant feature → r=0 → F=0
        let x = array![[5.0_f64], [5.0], [5.0], [5.0]];
        let y: Array1<f64> = array![1.0, 2.0, 3.0, 4.0];
        let (f_stats, _) = f_regression(&x, &y).unwrap();
        assert!(f_stats[0].abs() < 1.0e-6);
    }

    // -----------------------------------------------------------------------
    // chi2 tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_chi2_basic() {
        // Feature 0 matches the class exactly, feature 1 is independent of it
        let x = array![
            [1.0_f64, 1.0],
            [1.0, 0.0],
            [0.0, 1.0],
            [0.0, 0.0],
            [1.0, 1.0],
            [1.0, 0.0],
            [0.0, 1.0],
            [0.0, 0.0]
        ];
        let y: Array1<usize> = array![1, 1, 0, 0, 1, 1, 0, 0];
        let (chi2_stats, p_vals) = chi2(&x, &y).unwrap();
        assert_eq!(chi2_stats.len(), 2);
        assert_eq!(p_vals.len(), 2);
        // Feature 0 perfectly correlates → higher chi2
        assert!(chi2_stats[0] > chi2_stats[1]);
    }

    #[test]
    fn test_chi2_negative_value_error() {
        let x = array![[1.0_f64, -1.0], [0.0, 1.0]];
        let y: Array1<usize> = array![0, 1];
        assert!(chi2(&x, &y).is_err());
    }

    #[test]
    fn test_chi2_empty_input() {
        let x = Array2::<f64>::zeros((0, 2));
        let y: Array1<usize> = Array1::zeros(0);
        assert!(chi2(&x, &y).is_err());
    }

    #[test]
    fn test_chi2_shape_mismatch() {
        let x = array![[1.0_f64], [2.0]];
        let y: Array1<usize> = array![0]; // wrong length
        assert!(chi2(&x, &y).is_err());
    }

    #[test]
    fn test_chi2_all_zeros() {
        let x = array![[0.0_f64, 0.0], [0.0, 0.0]];
        let y: Array1<usize> = array![0, 1];
        let (chi2_stats, p_vals) = chi2(&x, &y).unwrap();
        assert_eq!(chi2_stats[0], 0.0);
        assert_eq!(p_vals[0], 1.0);
    }

    #[test]
    fn test_chi2_p_values_bounded() {
        let x = array![
            [1.0_f64, 0.0],
            [0.0, 1.0],
            [1.0, 1.0],
            [0.0, 0.0],
            [1.0, 0.0],
            [0.0, 1.0]
        ];
        let y: Array1<usize> = array![0, 1, 0, 1, 0, 1];
        let (_, p_vals) = chi2(&x, &y).unwrap();
        for &p in p_vals.iter() {
            assert!((0.0..=1.0).contains(&p), "p-value {p} out of bounds");
        }
    }

    // -----------------------------------------------------------------------
    // Distribution helper tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_ln_gamma_known_values() {
        // Gamma(1) = 1 → ln = 0
        let val: f64 = ln_gamma(1.0);
        assert!(val.abs() < 1.0e-10);

        // Gamma(2) = 1 → ln = 0
        let val2: f64 = ln_gamma(2.0);
        assert!(val2.abs() < 1.0e-10);

        // Gamma(3) = 2 → ln = ln(2)
        let val3: f64 = ln_gamma(3.0);
        assert!((val3 - 2.0_f64.ln()).abs() < 1.0e-10);

        // Gamma(0.5) = sqrt(pi) → ln = 0.5 * ln(pi)
        let val4: f64 = ln_gamma(0.5);
        let expected = 0.5 * std::f64::consts::PI.ln();
        assert!((val4 - expected).abs() < 1.0e-8);
    }

    #[test]
    fn test_regularized_incomplete_beta_boundaries() {
        // I_0(a, b) = 0
        let val: f64 = regularized_incomplete_beta(0.0, 1.0, 1.0);
        assert!(val.abs() < 1.0e-10);

        // I_1(a, b) = 1
        let val2: f64 = regularized_incomplete_beta(1.0, 1.0, 1.0);
        assert!((val2 - 1.0).abs() < 1.0e-10);
    }

    #[test]
    fn test_f_distribution_sf_zero() {
        // P(F > 0) = 1
        let val: f64 = f_distribution_sf(0.0, 2, 10);
        assert!((val - 1.0).abs() < 1.0e-10);
    }

    #[test]
    fn test_f_distribution_sf_large_f() {
        // Very large F → p ≈ 0
        let val: f64 = f_distribution_sf(1000.0, 2, 100);
        assert!(val < 0.001);
    }

    // -----------------------------------------------------------------------
    // compute_scores_classif / compute_scores_regression
    // -----------------------------------------------------------------------

    #[test]
    fn test_compute_scores_classif_f_classif() {
        let x = array![
            [1.0_f64, 5.0],
            [1.5, 5.5],
            [10.0, 5.0],
            [10.5, 4.5]
        ];
        let y: Array1<usize> = array![0, 0, 1, 1];
        let scores = compute_scores_classif(&x, &y, "f_classif").unwrap();
        assert_eq!(scores.len(), 2);
        assert!(scores[0] > scores[1]);
    }

    #[test]
    fn test_compute_scores_classif_chi2() {
        let x = array![[1.0_f64, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]];
        let y: Array1<usize> = array![0, 1, 0, 1];
        let scores = compute_scores_classif(&x, &y, "chi2").unwrap();
        assert_eq!(scores.len(), 2);
    }

    #[test]
    fn test_compute_scores_classif_unknown() {
        let x = array![[1.0_f64]];
        let y: Array1<usize> = array![0];
        assert!(compute_scores_classif(&x, &y, "unknown").is_err());
    }

    #[test]
    fn test_compute_scores_regression() {
        let x = array![[1.0_f64, 10.0], [2.0, 20.0], [3.0, 30.0], [4.0, 40.0]];
        let y: Array1<f64> = array![1.0, 2.0, 3.0, 4.0];
        let scores = compute_scores_regression(&x, &y).unwrap();
        assert_eq!(scores.len(), 2);
    }
}