Skip to main content

anofox_ml_discriminant/
lib.rs

//! Linear and Quadratic Discriminant Analysis.
//!
//! Mirrors `sklearn.discriminant_analysis.{LinearDiscriminantAnalysis,
//! QuadraticDiscriminantAnalysis}`.
//!
//! - **LDA** assumes all classes share a common covariance Σ. Its decision
//!   function is linear in `x`, and a fitted model can also project data onto
//!   at most `n_classes - 1` discriminant directions via `transform`.
//! - **QDA** estimates a separate covariance Σ_k per class, so its decision
//!   function is quadratic in `x`.

10use anofox_ml_core::{Fit, Predict, PredictProba, Result, RustMlError, Transform};
11use faer::linalg::solvers::{SelfAdjointEigen, Solve};
12use faer::{Mat, Side};
13use ndarray::{Array1, Array2};
14
15/// Common helpers.
16fn class_indices(y: &Array1<f64>) -> (Vec<f64>, Vec<Vec<usize>>) {
17    let mut classes: Vec<f64> = y.iter().copied().collect();
18    classes.sort_by(|a, b| a.partial_cmp(b).unwrap());
19    classes.dedup();
20    let groups: Vec<Vec<usize>> = classes
21        .iter()
22        .map(|&c| {
23            y.iter()
24                .enumerate()
25                .filter(|(_, &v)| v == c)
26                .map(|(i, _)| i)
27                .collect()
28        })
29        .collect();
30    (classes, groups)
31}
32
33fn class_mean(x: &Array2<f64>, idx: &[usize]) -> Array1<f64> {
34    let d = x.ncols();
35    let mut m = Array1::<f64>::zeros(d);
36    for &i in idx {
37        for j in 0..d {
38            m[j] += x[[i, j]];
39        }
40    }
41    let n = idx.len() as f64;
42    m.mapv(|v| v / n)
43}
44
45fn outer_subtract_accum(x: &Array2<f64>, mu: &Array1<f64>, idx: &[usize], accum: &mut Array2<f64>) {
46    let d = x.ncols();
47    for &i in idx {
48        let mut dv = vec![0.0; d];
49        for j in 0..d {
50            dv[j] = x[[i, j]] - mu[j];
51        }
52        for a in 0..d {
53            for b in 0..d {
54                accum[[a, b]] += dv[a] * dv[b];
55            }
56        }
57    }
58}
59
60fn solve_psd(a: &Array2<f64>, b: &Array1<f64>) -> Result<Array1<f64>> {
61    let n = a.nrows();
62    let am = Mat::from_fn(n, n, |i, j| a[[i, j]]);
63    let llt = faer::linalg::solvers::Llt::new(am.as_ref(), Side::Lower)
64        .map_err(|e| RustMlError::InvalidParameter(format!("LLT failed: {e:?}")))?;
65    let bm = Mat::from_fn(n, 1, |i, _| b[i]);
66    let s = llt.solve(&bm);
67    Ok(Array1::from_vec((0..n).map(|i| s[(i, 0)]).collect()))
68}
69
70fn log_det_chol(a: &Array2<f64>) -> Result<f64> {
71    let n = a.nrows();
72    let am = Mat::from_fn(n, n, |i, j| a[[i, j]]);
73    let llt = faer::linalg::solvers::Llt::new(am.as_ref(), Side::Lower)
74        .map_err(|e| RustMlError::InvalidParameter(format!("LLT failed: {e:?}")))?;
75    let lower = llt.L();
76    let mut s = 0.0;
77    for i in 0..n {
78        s += lower[(i, i)].abs().ln();
79    }
80    Ok(2.0 * s)
81}
82
83// ---------------------------------------------------------------------------
84// LinearDiscriminantAnalysis (LDA)
85// ---------------------------------------------------------------------------
86
/// Linear Discriminant Analysis estimator (unfitted hyper-parameters).
///
/// Calling `fit` produces a `FittedLinearDiscriminantAnalysis`.
#[derive(Debug, Clone)]
pub struct LinearDiscriminantAnalysis {
    /// Shrinkage of the pooled within-class covariance toward `(tr(Σ)/d) I`:
    /// `Σ ← (1 − s) Σ + s (tr(Σ)/d) I`. Meaningful values lie in `[0, 1]`;
    /// `0.0` means no shrinkage (sklearn default).
    pub shrinkage: f64,
    /// Ridge term added to the diagonal of Σ, after shrinkage, for numerical
    /// stability. Default `1e-9`.
    pub reg: f64,
}

impl LinearDiscriminantAnalysis {
    /// Creates an estimator with no shrinkage and `reg = 1e-9`.
    pub fn new() -> Self {
        Self {
            shrinkage: 0.0,
            reg: 1e-9,
        }
    }

    /// Builder: sets the covariance shrinkage intensity.
    pub fn with_shrinkage(mut self, s: f64) -> Self {
        self.shrinkage = s;
        self
    }

    /// Builder: sets the diagonal ridge term. Provided for parity with
    /// `QuadraticDiscriminantAnalysis::with_reg`.
    pub fn with_reg(mut self, r: f64) -> Self {
        self.reg = r;
        self
    }
}

impl Default for LinearDiscriminantAnalysis {
    fn default() -> Self {
        Self::new()
    }
}
114
115#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
116pub struct FittedLinearDiscriminantAnalysis {
117    pub classes: Vec<f64>,
118    pub means: Vec<Array1<f64>>,
119    pub priors: Vec<f64>,
120    pub coef: Vec<Array1<f64>>, // sigma_inv @ mu_k
121    pub intercept: Vec<f64>,    // -0.5 * mu_k^T sigma_inv mu_k + log(pi_k)
122    /// Projection matrix for dimensionality reduction; columns are the
123    /// generalised eigenvectors of `Σ_b v = λ Σ_w v`. Shape (d, n_classes-1).
124    pub scalings: Array2<f64>,
125    /// Global feature mean (used to center before projecting).
126    pub xbar: Array1<f64>,
127    pub n_features: usize,
128}
129
130impl Fit<f64> for LinearDiscriminantAnalysis {
131    type Fitted = FittedLinearDiscriminantAnalysis;
132
133    fn fit(&self, x: &Array2<f64>, y: &Array1<f64>) -> Result<Self::Fitted> {
134        if x.nrows() != y.len() {
135            return Err(RustMlError::ShapeMismatch(format!(
136                "X has {} rows but y has {} elements",
137                x.nrows(),
138                y.len()
139            )));
140        }
141        let (classes, groups) = class_indices(y);
142        if classes.len() < 2 {
143            return Err(RustMlError::InvalidParameter(
144                "need at least 2 classes for LDA".into(),
145            ));
146        }
147        let d = x.ncols();
148        let n = x.nrows();
149
150        let means: Vec<Array1<f64>> = groups.iter().map(|g| class_mean(x, g)).collect();
151        let priors: Vec<f64> = groups.iter().map(|g| g.len() as f64 / n as f64).collect();
152
153        // Pooled within-class scatter.
154        let mut sigma = Array2::<f64>::zeros((d, d));
155        for (mu, g) in means.iter().zip(groups.iter()) {
156            outer_subtract_accum(x, mu, g, &mut sigma);
157        }
158        // sklearn divides by (n - n_classes) (unbiased).
159        let denom = (n - classes.len()) as f64;
160        sigma.mapv_inplace(|v| v / denom.max(1.0));
161
162        // Optional shrinkage toward diagonal mean.
163        if self.shrinkage > 0.0 {
164            let trace = (0..d).map(|i| sigma[[i, i]]).sum::<f64>() / d as f64;
165            for i in 0..d {
166                for j in 0..d {
167                    if i == j {
168                        sigma[[i, j]] =
169                            (1.0 - self.shrinkage) * sigma[[i, j]] + self.shrinkage * trace;
170                    } else {
171                        sigma[[i, j]] *= 1.0 - self.shrinkage;
172                    }
173                }
174            }
175        }
176        for i in 0..d {
177            sigma[[i, i]] += self.reg;
178        }
179
180        // For each class compute sigma_inv @ mu_k as the linear coef.
181        let mut coef = Vec::with_capacity(classes.len());
182        let mut intercept = Vec::with_capacity(classes.len());
183        for (mu, pi) in means.iter().zip(priors.iter()) {
184            let s_inv_mu = solve_psd(&sigma, mu)?;
185            let q = mu.dot(&s_inv_mu); // mu^T sigma_inv mu
186            coef.push(s_inv_mu);
187            intercept.push(-0.5 * q + pi.ln());
188        }
189
190        // Build the LDA projection: solve the generalised eigenproblem
191        // Σ_b v = λ Σ_w v by Cholesky-whitening Σ_w.
192        let mut xbar = Array1::<f64>::zeros(d);
193        for (mu, g) in means.iter().zip(groups.iter()) {
194            let w = g.len() as f64 / n as f64;
195            for j in 0..d {
196                xbar[j] += w * mu[j];
197            }
198        }
199        let mut s_b = Array2::<f64>::zeros((d, d));
200        for (k_idx, mu) in means.iter().enumerate() {
201            let nk = groups[k_idx].len() as f64;
202            for a in 0..d {
203                for b in 0..d {
204                    s_b[[a, b]] += nk * (mu[a] - xbar[a]) * (mu[b] - xbar[b]);
205                }
206            }
207        }
208        // L Lᵀ = Σ_w; then symmetric matrix L⁻¹ Σ_b L⁻ᵀ.
209        let sw_mat = Mat::from_fn(d, d, |i, j| sigma[[i, j]]);
210        let llt = faer::linalg::solvers::Llt::new(sw_mat.as_ref(), Side::Lower)
211            .map_err(|e| RustMlError::InvalidParameter(format!("Σ_w Cholesky failed: {e:?}")))?;
212        // Solve L A = Σ_b for A (column-wise).
213        let sb_mat = Mat::from_fn(d, d, |i, j| s_b[[i, j]]);
214        let a_mat = llt.solve(&sb_mat);
215        // Then solve Lᵀ B = Aᵀ → B = (L⁻¹ Σ_b L⁻ᵀ) is symmetric.
216        // Equivalently: B = (L⁻¹ Σ_b)ᵀ first, then L⁻¹ on that.
217        let mut a_t = Mat::<f64>::zeros(d, d);
218        for i in 0..d {
219            for j in 0..d {
220                a_t[(i, j)] = a_mat[(j, i)];
221            }
222        }
223        let b_mat = llt.solve(&a_t);
224        // b_mat = L⁻¹ Σ_b L⁻ᵀ. Eigendecompose.
225        let eig = SelfAdjointEigen::new(b_mat.as_ref(), Side::Lower)
226            .map_err(|e| RustMlError::InvalidParameter(format!("eigen failed: {e:?}")))?;
227        let u = eig.U();
228        // Recover original-space eigenvectors V = L⁻ᵀ U; sklearn keeps the
229        // top n_classes-1 in descending eigenvalue order.
230        let n_proj = (classes.len() - 1).min(d);
231        let mut scalings = Array2::<f64>::zeros((d, n_proj));
232        for c in 0..n_proj {
233            let src = d - 1 - c;
234            let mut u_col = Mat::<f64>::zeros(d, 1);
235            for i in 0..d {
236                u_col[(i, 0)] = u[(i, src)];
237            }
238            // Solve Lᵀ v = u → v = L⁻ᵀ u; via two triangular solves on the
239            // same `llt`, we transpose-solve.
240            // faer's Llt::solve does L Lᵀ x = b, so we get the full inverse;
241            // for just Lᵀ x = u use an explicit back-sub on `lower.transpose()`.
242            let lower = llt.L();
243            // Lᵀ v = u; back-substitute from bottom.
244            let mut v = vec![0.0_f64; d];
245            for r in (0..d).rev() {
246                let mut s = u_col[(r, 0)];
247                for cc in (r + 1)..d {
248                    s -= lower[(cc, r)] * v[cc];
249                }
250                v[r] = s / lower[(r, r)].max(1e-12);
251            }
252            for r in 0..d {
253                scalings[[r, c]] = v[r];
254            }
255        }
256
257        Ok(FittedLinearDiscriminantAnalysis {
258            classes,
259            means,
260            priors,
261            coef,
262            intercept,
263            scalings,
264            xbar,
265            n_features: d,
266        })
267    }
268}
269
270impl Transform<f64> for FittedLinearDiscriminantAnalysis {
271    fn transform(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
272        if x.ncols() != self.n_features {
273            return Err(RustMlError::ShapeMismatch(format!(
274                "expected {} features, got {}",
275                self.n_features,
276                x.ncols()
277            )));
278        }
279        let n = x.nrows();
280        let k = self.scalings.ncols();
281        let mut out = Array2::<f64>::zeros((n, k));
282        for i in 0..n {
283            for c in 0..k {
284                let mut s = 0.0;
285                for j in 0..self.n_features {
286                    s += (x[[i, j]] - self.xbar[j]) * self.scalings[[j, c]];
287                }
288                out[[i, c]] = s;
289            }
290        }
291        Ok(out)
292    }
293}
294
295impl Predict<f64> for FittedLinearDiscriminantAnalysis {
296    fn predict(&self, x: &Array2<f64>) -> Result<Array1<f64>> {
297        if x.ncols() != self.n_features {
298            return Err(RustMlError::ShapeMismatch(format!(
299                "expected {} features, got {}",
300                self.n_features,
301                x.ncols()
302            )));
303        }
304        let n = x.nrows();
305        let mut out = Array1::<f64>::zeros(n);
306        for i in 0..n {
307            let row = x.row(i);
308            let mut best = f64::NEG_INFINITY;
309            let mut best_k = 0usize;
310            for (k, (c, b)) in self.coef.iter().zip(self.intercept.iter()).enumerate() {
311                let score = row.dot(c) + b;
312                if score > best {
313                    best = score;
314                    best_k = k;
315                }
316            }
317            out[i] = self.classes[best_k];
318        }
319        Ok(out)
320    }
321}
322
323impl PredictProba<f64> for FittedLinearDiscriminantAnalysis {
324    fn predict_proba(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
325        if x.ncols() != self.n_features {
326            return Err(RustMlError::ShapeMismatch(format!(
327                "expected {} features, got {}",
328                self.n_features,
329                x.ncols()
330            )));
331        }
332        let n = x.nrows();
333        let k = self.classes.len();
334        let mut p = Array2::<f64>::zeros((n, k));
335        for i in 0..n {
336            let row = x.row(i);
337            let mut logits = vec![0.0_f64; k];
338            let mut max_l = f64::NEG_INFINITY;
339            for (c_i, (c, b)) in self.coef.iter().zip(self.intercept.iter()).enumerate() {
340                let s = row.dot(c) + b;
341                logits[c_i] = s;
342                if s > max_l {
343                    max_l = s;
344                }
345            }
346            let mut z = 0.0;
347            for c_i in 0..k {
348                let e = (logits[c_i] - max_l).exp();
349                p[[i, c_i]] = e;
350                z += e;
351            }
352            for c_i in 0..k {
353                p[[i, c_i]] /= z;
354            }
355        }
356        Ok(p)
357    }
358}
359
360// ---------------------------------------------------------------------------
361// QuadraticDiscriminantAnalysis (QDA)
362// ---------------------------------------------------------------------------
363
/// Quadratic Discriminant Analysis estimator (unfitted hyper-parameters).
#[derive(Debug, Clone)]
pub struct QuadraticDiscriminantAnalysis {
    /// Ridge term added to the diagonal of every per-class covariance
    /// estimate. Default `1e-9`.
    pub reg: f64,
}

impl QuadraticDiscriminantAnalysis {
    /// Creates an estimator with `reg = 1e-9`.
    pub fn new() -> Self {
        QuadraticDiscriminantAnalysis { reg: 1e-9 }
    }

    /// Builder: sets the diagonal ridge term.
    pub fn with_reg(self, r: f64) -> Self {
        let mut updated = self;
        updated.reg = r;
        updated
    }
}

impl Default for QuadraticDiscriminantAnalysis {
    fn default() -> Self {
        QuadraticDiscriminantAnalysis::new()
    }
}
384
385#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
386pub struct FittedQuadraticDiscriminantAnalysis {
387    pub classes: Vec<f64>,
388    pub means: Vec<Array1<f64>>,
389    pub priors: Vec<f64>,
390    pub sigmas: Vec<Array2<f64>>,
391    pub log_det: Vec<f64>,
392    pub n_features: usize,
393}
394
395impl Fit<f64> for QuadraticDiscriminantAnalysis {
396    type Fitted = FittedQuadraticDiscriminantAnalysis;
397
398    fn fit(&self, x: &Array2<f64>, y: &Array1<f64>) -> Result<Self::Fitted> {
399        if x.nrows() != y.len() {
400            return Err(RustMlError::ShapeMismatch(format!(
401                "X has {} rows but y has {} elements",
402                x.nrows(),
403                y.len()
404            )));
405        }
406        let (classes, groups) = class_indices(y);
407        if classes.len() < 2 {
408            return Err(RustMlError::InvalidParameter(
409                "need at least 2 classes for QDA".into(),
410            ));
411        }
412        let d = x.ncols();
413        let n = x.nrows();
414
415        let means: Vec<Array1<f64>> = groups.iter().map(|g| class_mean(x, g)).collect();
416        let priors: Vec<f64> = groups.iter().map(|g| g.len() as f64 / n as f64).collect();
417
418        let mut sigmas = Vec::with_capacity(classes.len());
419        let mut log_det = Vec::with_capacity(classes.len());
420        for (k, g) in groups.iter().enumerate() {
421            let mut s = Array2::<f64>::zeros((d, d));
422            outer_subtract_accum(x, &means[k], g, &mut s);
423            let denom = (g.len() as f64 - 1.0).max(1.0);
424            s.mapv_inplace(|v| v / denom);
425            for i in 0..d {
426                s[[i, i]] += self.reg;
427            }
428            log_det.push(log_det_chol(&s)?);
429            sigmas.push(s);
430        }
431
432        Ok(FittedQuadraticDiscriminantAnalysis {
433            classes,
434            means,
435            priors,
436            sigmas,
437            log_det,
438            n_features: d,
439        })
440    }
441}
442
443impl Predict<f64> for FittedQuadraticDiscriminantAnalysis {
444    fn predict(&self, x: &Array2<f64>) -> Result<Array1<f64>> {
445        if x.ncols() != self.n_features {
446            return Err(RustMlError::ShapeMismatch(format!(
447                "expected {} features, got {}",
448                self.n_features,
449                x.ncols()
450            )));
451        }
452        let n = x.nrows();
453        let d = self.n_features;
454        let mut out = Array1::<f64>::zeros(n);
455        for i in 0..n {
456            let mut best = f64::NEG_INFINITY;
457            let mut best_k = 0usize;
458            for k in 0..self.classes.len() {
459                // discriminant_k(x) = -0.5 (x-mu)^T Σ_k^{-1} (x-mu) - 0.5 log|Σ_k| + log π_k
460                let mut diff = Array1::<f64>::zeros(d);
461                for j in 0..d {
462                    diff[j] = x[[i, j]] - self.means[k][j];
463                }
464                let s_inv_diff = solve_psd(&self.sigmas[k], &diff)?;
465                let m = diff.dot(&s_inv_diff);
466                let score = -0.5 * m - 0.5 * self.log_det[k] + self.priors[k].ln();
467                if score > best {
468                    best = score;
469                    best_k = k;
470                }
471            }
472            out[i] = self.classes[best_k];
473        }
474        Ok(out)
475    }
476}
477
478impl PredictProba<f64> for FittedQuadraticDiscriminantAnalysis {
479    fn predict_proba(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
480        if x.ncols() != self.n_features {
481            return Err(RustMlError::ShapeMismatch(format!(
482                "expected {} features, got {}",
483                self.n_features,
484                x.ncols()
485            )));
486        }
487        let n = x.nrows();
488        let k = self.classes.len();
489        let d = self.n_features;
490        let mut p = Array2::<f64>::zeros((n, k));
491        for i in 0..n {
492            let mut logits = vec![0.0_f64; k];
493            let mut max_l = f64::NEG_INFINITY;
494            for c_i in 0..k {
495                let mut diff = Array1::<f64>::zeros(d);
496                for j in 0..d {
497                    diff[j] = x[[i, j]] - self.means[c_i][j];
498                }
499                let s_inv_diff = solve_psd(&self.sigmas[c_i], &diff)?;
500                let m = diff.dot(&s_inv_diff);
501                let score = -0.5 * m - 0.5 * self.log_det[c_i] + self.priors[c_i].ln();
502                logits[c_i] = score;
503                if score > max_l {
504                    max_l = score;
505                }
506            }
507            let mut z = 0.0;
508            for c_i in 0..k {
509                let e = (logits[c_i] - max_l).exp();
510                p[[i, c_i]] = e;
511                z += e;
512            }
513            for c_i in 0..k {
514                p[[i, c_i]] /= z;
515            }
516        }
517        Ok(p)
518    }
519}
520
#[cfg(test)]
mod tests {
    use super::*;
    use anofox_ml_core::DecisionFunction;
    use ndarray::array;

    /// Two well-separated 2-D blobs with four points each (labels 0 and 1).
    fn two_blobs() -> (Array2<f64>, Array1<f64>) {
        let x = array![
            [0.0, 0.0],
            [0.5, 0.1],
            [-0.3, -0.2],
            [0.2, -0.1],
            [5.0, 5.0],
            [5.1, 4.9],
            [4.7, 5.3],
            [5.0, 5.2],
        ];
        let y = array![0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0];
        (x, y)
    }

    /// Index of the first maximum of `v` (ties keep the lowest index).
    fn argmax(v: ndarray::ArrayView1<f64>) -> usize {
        let mut best = 0;
        for (k, &s) in v.iter().enumerate() {
            if s > v[best] {
                best = k;
            }
        }
        best
    }

    #[test]
    fn test_lda_two_well_separated_classes() {
        let (x, y) = two_blobs();
        let fitted = LinearDiscriminantAnalysis::new().fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        for (p, t) in preds.iter().zip(y.iter()) {
            assert_eq!(*p, *t);
        }
    }

    #[test]
    fn test_lda_transform_separates() {
        // 3-class problem in 2D, projected to 2 dims (n_classes - 1 = 2).
        let x = array![
            [0.0, 0.0],
            [0.5, 0.0],
            [0.0, 0.3],
            [4.0, 0.0],
            [4.2, 0.1],
            [4.0, 0.3],
            [0.0, 4.0],
            [0.1, 4.2],
            [-0.1, 4.0],
        ];
        let y = array![0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0];
        let fitted = LinearDiscriminantAnalysis::new().fit(&x, &y).unwrap();
        let t = fitted.transform(&x).unwrap();
        assert_eq!(t.shape(), &[9, 2]);
        // After projection, points of the same class should be much closer
        // together than points of different classes.
        let d_within: f64 = (0..3)
            .map(|c| {
                let base = 3 * c;
                ((t[[base, 0]] - t[[base + 1, 0]]).powi(2)
                    + (t[[base, 1]] - t[[base + 1, 1]]).powi(2))
                .sqrt()
            })
            .sum::<f64>()
            / 3.0;
        let d_between: f64 =
            ((t[[0, 0]] - t[[3, 0]]).powi(2) + (t[[0, 1]] - t[[3, 1]]).powi(2)).sqrt();
        assert!(
            d_between > 5.0 * d_within,
            "within={d_within}, between={d_between}"
        );
    }

    #[test]
    fn test_qda_two_well_separated_classes() {
        let x = array![
            [0.0, 0.0],
            [0.5, 0.1],
            [-0.3, -0.2],
            [0.2, -0.1],
            [0.1, 0.2],
            [-0.1, 0.0],
            [5.0, 5.0],
            [5.1, 4.9],
            [4.7, 5.3],
            [5.0, 5.2],
            [5.2, 5.1],
            [4.8, 5.0],
        ];
        let y = array![0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0];
        let fitted = QuadraticDiscriminantAnalysis::new().fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        for (p, t) in preds.iter().zip(y.iter()) {
            assert_eq!(*p, *t);
        }
    }

    #[test]
    fn test_lda_with_shrinkage_still_separates() {
        let (x, y) = two_blobs();
        let fitted = LinearDiscriminantAnalysis::new()
            .with_shrinkage(0.5)
            .fit(&x, &y)
            .unwrap();
        assert_eq!(fitted.predict(&x).unwrap(), y);
    }

    #[test]
    fn test_lda_proba_and_decision_function_agree_with_predict() {
        let (x, y) = two_blobs();
        let fitted = LinearDiscriminantAnalysis::new().fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        let proba = fitted.predict_proba(&x).unwrap();
        let scores = fitted.decision_function(&x).unwrap();
        assert_eq!(proba.shape(), &[8, 2]);
        assert_eq!(scores.shape(), &[8, 2]);
        for i in 0..x.nrows() {
            let row_sum: f64 = proba.row(i).sum();
            assert!((row_sum - 1.0).abs() < 1e-12, "row {i} sums to {row_sum}");
            assert_eq!(fitted.classes[argmax(proba.row(i))], preds[i]);
            assert_eq!(fitted.classes[argmax(scores.row(i))], preds[i]);
        }
    }

    #[test]
    fn test_qda_proba_rows_sum_to_one_and_agree_with_predict() {
        let (x, y) = two_blobs();
        let fitted = QuadraticDiscriminantAnalysis::new().fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        let proba = fitted.predict_proba(&x).unwrap();
        assert_eq!(proba.shape(), &[8, 2]);
        for i in 0..x.nrows() {
            let row_sum: f64 = proba.row(i).sum();
            assert!((row_sum - 1.0).abs() < 1e-12, "row {i} sums to {row_sum}");
            assert_eq!(fitted.classes[argmax(proba.row(i))], preds[i]);
        }
    }

    #[test]
    fn test_fit_rejects_single_class_and_shape_mismatch() {
        let (x, _) = two_blobs();
        let one_class = Array1::<f64>::zeros(x.nrows());
        assert!(LinearDiscriminantAnalysis::new().fit(&x, &one_class).is_err());
        assert!(QuadraticDiscriminantAnalysis::new().fit(&x, &one_class).is_err());

        let short_y = Array1::<f64>::zeros(x.nrows() - 1);
        assert!(matches!(
            LinearDiscriminantAnalysis::new().fit(&x, &short_y),
            Err(RustMlError::ShapeMismatch(_))
        ));
        assert!(matches!(
            QuadraticDiscriminantAnalysis::new().fit(&x, &short_y),
            Err(RustMlError::ShapeMismatch(_))
        ));
    }

    #[test]
    fn test_predict_rejects_wrong_feature_count() {
        let (x, y) = two_blobs();
        let wide = Array2::<f64>::zeros((2, 3));
        let lda = LinearDiscriminantAnalysis::new().fit(&x, &y).unwrap();
        assert!(lda.predict(&wide).is_err());
        assert!(lda.predict_proba(&wide).is_err());
        assert!(lda.transform(&wide).is_err());
        let qda = QuadraticDiscriminantAnalysis::new().fit(&x, &y).unwrap();
        assert!(qda.predict(&wide).is_err());
        assert!(qda.predict_proba(&wide).is_err());
    }
}
607
608impl anofox_ml_core::ClassifierScore<f64> for FittedLinearDiscriminantAnalysis {}
609impl anofox_ml_core::ClassifierScore<f64> for FittedQuadraticDiscriminantAnalysis {}
610
611impl anofox_ml_core::PredictLogProba<f64> for FittedLinearDiscriminantAnalysis {}
612impl anofox_ml_core::PredictLogProba<f64> for FittedQuadraticDiscriminantAnalysis {}
613
614impl anofox_ml_core::DecisionFunction<f64> for FittedLinearDiscriminantAnalysis {
615    fn decision_function(
616        &self,
617        x: &ndarray::Array2<f64>,
618    ) -> anofox_ml_core::Result<ndarray::Array2<f64>> {
619        if x.ncols() != self.n_features {
620            return Err(anofox_ml_core::RustMlError::ShapeMismatch(format!(
621                "expected {} features, got {}",
622                self.n_features,
623                x.ncols()
624            )));
625        }
626        let n = x.nrows();
627        let k = self.classes.len();
628        let mut out = ndarray::Array2::<f64>::zeros((n, k));
629        for i in 0..n {
630            let row = x.row(i);
631            for (c_i, (c, b)) in self.coef.iter().zip(self.intercept.iter()).enumerate() {
632                out[[i, c_i]] = row.dot(c) + b;
633            }
634        }
635        Ok(out)
636    }
637}