1use anofox_ml_core::{Fit, Predict, PredictProba, Result, RustMlError, Transform};
11use faer::linalg::solvers::{SelfAdjointEigen, Solve};
12use faer::{Mat, Side};
13use ndarray::{Array1, Array2};
14
15fn class_indices(y: &Array1<f64>) -> (Vec<f64>, Vec<Vec<usize>>) {
17 let mut classes: Vec<f64> = y.iter().copied().collect();
18 classes.sort_by(|a, b| a.partial_cmp(b).unwrap());
19 classes.dedup();
20 let groups: Vec<Vec<usize>> = classes
21 .iter()
22 .map(|&c| {
23 y.iter()
24 .enumerate()
25 .filter(|(_, &v)| v == c)
26 .map(|(i, _)| i)
27 .collect()
28 })
29 .collect();
30 (classes, groups)
31}
32
33fn class_mean(x: &Array2<f64>, idx: &[usize]) -> Array1<f64> {
34 let d = x.ncols();
35 let mut m = Array1::<f64>::zeros(d);
36 for &i in idx {
37 for j in 0..d {
38 m[j] += x[[i, j]];
39 }
40 }
41 let n = idx.len() as f64;
42 m.mapv(|v| v / n)
43}
44
45fn outer_subtract_accum(x: &Array2<f64>, mu: &Array1<f64>, idx: &[usize], accum: &mut Array2<f64>) {
46 let d = x.ncols();
47 for &i in idx {
48 let mut dv = vec![0.0; d];
49 for j in 0..d {
50 dv[j] = x[[i, j]] - mu[j];
51 }
52 for a in 0..d {
53 for b in 0..d {
54 accum[[a, b]] += dv[a] * dv[b];
55 }
56 }
57 }
58}
59
60fn solve_psd(a: &Array2<f64>, b: &Array1<f64>) -> Result<Array1<f64>> {
61 let n = a.nrows();
62 let am = Mat::from_fn(n, n, |i, j| a[[i, j]]);
63 let llt = faer::linalg::solvers::Llt::new(am.as_ref(), Side::Lower)
64 .map_err(|e| RustMlError::InvalidParameter(format!("LLT failed: {e:?}")))?;
65 let bm = Mat::from_fn(n, 1, |i, _| b[i]);
66 let s = llt.solve(&bm);
67 Ok(Array1::from_vec((0..n).map(|i| s[(i, 0)]).collect()))
68}
69
70fn log_det_chol(a: &Array2<f64>) -> Result<f64> {
71 let n = a.nrows();
72 let am = Mat::from_fn(n, n, |i, j| a[[i, j]]);
73 let llt = faer::linalg::solvers::Llt::new(am.as_ref(), Side::Lower)
74 .map_err(|e| RustMlError::InvalidParameter(format!("LLT failed: {e:?}")))?;
75 let lower = llt.L();
76 let mut s = 0.0;
77 for i in 0..n {
78 s += lower[(i, i)].abs().ln();
79 }
80 Ok(2.0 * s)
81}
82
/// Configuration for linear discriminant analysis (LDA).
///
/// All classes share one pooled within-class covariance matrix Σ. Fitting produces a
/// `FittedLinearDiscriminantAnalysis`.
#[derive(Debug, Clone)]
pub struct LinearDiscriminantAnalysis {
    /// Shrinkage intensity `s`, which should lie in `[0, 1]`.
    ///
    /// The pooled covariance is blended as `(1 - s) Σ + s (tr(Σ) / d) I`. `0.0` disables
    /// shrinkage.
    pub shrinkage: f64,
    /// Ridge term added to the covariance diagonal after shrinkage, to keep Σ positive
    /// definite.
    pub reg: f64,
}

impl LinearDiscriminantAnalysis {
    /// Creates an LDA configuration with no shrinkage and `reg = 1e-9`.
    pub fn new() -> Self {
        Self {
            shrinkage: 0.0,
            reg: 1e-9,
        }
    }

    /// Sets the shrinkage intensity (see the `shrinkage` field).
    pub fn with_shrinkage(mut self, s: f64) -> Self {
        self.shrinkage = s;
        self
    }

    /// Sets the diagonal regularization term added to the pooled covariance.
    ///
    /// Mirrors `QuadraticDiscriminantAnalysis::with_reg`.
    pub fn with_reg(mut self, r: f64) -> Self {
        self.reg = r;
        self
    }
}

impl Default for LinearDiscriminantAnalysis {
    fn default() -> Self {
        Self::new()
    }
}
114
/// Parameters learned by fitting `LinearDiscriminantAnalysis`.
///
/// Index `k` in every per-class field refers to `classes[k]`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct FittedLinearDiscriminantAnalysis {
    /// Distinct class labels, in ascending order.
    pub classes: Vec<f64>,
    /// Per-class mean feature vectors.
    pub means: Vec<Array1<f64>>,
    /// Per-class priors: the fraction of training samples in each class.
    pub priors: Vec<f64>,
    /// Per-class linear weights `Σ⁻¹ μ_k` of the discriminant score.
    ///
    /// Σ is the pooled covariance after shrinkage and `reg`.
    pub coef: Vec<Array1<f64>>,
    /// Per-class intercepts `-½ μ_kᵀ Σ⁻¹ μ_k + ln(prior_k)`.
    pub intercept: Vec<f64>,
    /// Projection matrix used by `transform`.
    ///
    /// It has `n_features` rows and `min(n_classes - 1, n_features)` columns.
    pub scalings: Array2<f64>,
    /// Prior-weighted mean of the class means, i.e. the overall training mean.
    ///
    /// It is subtracted from inputs before projection.
    pub xbar: Array1<f64>,
    /// Number of feature columns seen at fit time; inputs must match it.
    pub n_features: usize,
}
129
130impl Fit<f64> for LinearDiscriminantAnalysis {
131 type Fitted = FittedLinearDiscriminantAnalysis;
132
133 fn fit(&self, x: &Array2<f64>, y: &Array1<f64>) -> Result<Self::Fitted> {
134 if x.nrows() != y.len() {
135 return Err(RustMlError::ShapeMismatch(format!(
136 "X has {} rows but y has {} elements",
137 x.nrows(),
138 y.len()
139 )));
140 }
141 let (classes, groups) = class_indices(y);
142 if classes.len() < 2 {
143 return Err(RustMlError::InvalidParameter(
144 "need at least 2 classes for LDA".into(),
145 ));
146 }
147 let d = x.ncols();
148 let n = x.nrows();
149
150 let means: Vec<Array1<f64>> = groups.iter().map(|g| class_mean(x, g)).collect();
151 let priors: Vec<f64> = groups.iter().map(|g| g.len() as f64 / n as f64).collect();
152
153 let mut sigma = Array2::<f64>::zeros((d, d));
155 for (mu, g) in means.iter().zip(groups.iter()) {
156 outer_subtract_accum(x, mu, g, &mut sigma);
157 }
158 let denom = (n - classes.len()) as f64;
160 sigma.mapv_inplace(|v| v / denom.max(1.0));
161
162 if self.shrinkage > 0.0 {
164 let trace = (0..d).map(|i| sigma[[i, i]]).sum::<f64>() / d as f64;
165 for i in 0..d {
166 for j in 0..d {
167 if i == j {
168 sigma[[i, j]] =
169 (1.0 - self.shrinkage) * sigma[[i, j]] + self.shrinkage * trace;
170 } else {
171 sigma[[i, j]] *= 1.0 - self.shrinkage;
172 }
173 }
174 }
175 }
176 for i in 0..d {
177 sigma[[i, i]] += self.reg;
178 }
179
180 let mut coef = Vec::with_capacity(classes.len());
182 let mut intercept = Vec::with_capacity(classes.len());
183 for (mu, pi) in means.iter().zip(priors.iter()) {
184 let s_inv_mu = solve_psd(&sigma, mu)?;
185 let q = mu.dot(&s_inv_mu); coef.push(s_inv_mu);
187 intercept.push(-0.5 * q + pi.ln());
188 }
189
190 let mut xbar = Array1::<f64>::zeros(d);
193 for (mu, g) in means.iter().zip(groups.iter()) {
194 let w = g.len() as f64 / n as f64;
195 for j in 0..d {
196 xbar[j] += w * mu[j];
197 }
198 }
199 let mut s_b = Array2::<f64>::zeros((d, d));
200 for (k_idx, mu) in means.iter().enumerate() {
201 let nk = groups[k_idx].len() as f64;
202 for a in 0..d {
203 for b in 0..d {
204 s_b[[a, b]] += nk * (mu[a] - xbar[a]) * (mu[b] - xbar[b]);
205 }
206 }
207 }
208 let sw_mat = Mat::from_fn(d, d, |i, j| sigma[[i, j]]);
210 let llt = faer::linalg::solvers::Llt::new(sw_mat.as_ref(), Side::Lower)
211 .map_err(|e| RustMlError::InvalidParameter(format!("Σ_w Cholesky failed: {e:?}")))?;
212 let sb_mat = Mat::from_fn(d, d, |i, j| s_b[[i, j]]);
214 let a_mat = llt.solve(&sb_mat);
215 let mut a_t = Mat::<f64>::zeros(d, d);
218 for i in 0..d {
219 for j in 0..d {
220 a_t[(i, j)] = a_mat[(j, i)];
221 }
222 }
223 let b_mat = llt.solve(&a_t);
224 let eig = SelfAdjointEigen::new(b_mat.as_ref(), Side::Lower)
226 .map_err(|e| RustMlError::InvalidParameter(format!("eigen failed: {e:?}")))?;
227 let u = eig.U();
228 let n_proj = (classes.len() - 1).min(d);
231 let mut scalings = Array2::<f64>::zeros((d, n_proj));
232 for c in 0..n_proj {
233 let src = d - 1 - c;
234 let mut u_col = Mat::<f64>::zeros(d, 1);
235 for i in 0..d {
236 u_col[(i, 0)] = u[(i, src)];
237 }
238 let lower = llt.L();
243 let mut v = vec![0.0_f64; d];
245 for r in (0..d).rev() {
246 let mut s = u_col[(r, 0)];
247 for cc in (r + 1)..d {
248 s -= lower[(cc, r)] * v[cc];
249 }
250 v[r] = s / lower[(r, r)].max(1e-12);
251 }
252 for r in 0..d {
253 scalings[[r, c]] = v[r];
254 }
255 }
256
257 Ok(FittedLinearDiscriminantAnalysis {
258 classes,
259 means,
260 priors,
261 coef,
262 intercept,
263 scalings,
264 xbar,
265 n_features: d,
266 })
267 }
268}
269
270impl Transform<f64> for FittedLinearDiscriminantAnalysis {
271 fn transform(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
272 if x.ncols() != self.n_features {
273 return Err(RustMlError::ShapeMismatch(format!(
274 "expected {} features, got {}",
275 self.n_features,
276 x.ncols()
277 )));
278 }
279 let n = x.nrows();
280 let k = self.scalings.ncols();
281 let mut out = Array2::<f64>::zeros((n, k));
282 for i in 0..n {
283 for c in 0..k {
284 let mut s = 0.0;
285 for j in 0..self.n_features {
286 s += (x[[i, j]] - self.xbar[j]) * self.scalings[[j, c]];
287 }
288 out[[i, c]] = s;
289 }
290 }
291 Ok(out)
292 }
293}
294
295impl Predict<f64> for FittedLinearDiscriminantAnalysis {
296 fn predict(&self, x: &Array2<f64>) -> Result<Array1<f64>> {
297 if x.ncols() != self.n_features {
298 return Err(RustMlError::ShapeMismatch(format!(
299 "expected {} features, got {}",
300 self.n_features,
301 x.ncols()
302 )));
303 }
304 let n = x.nrows();
305 let mut out = Array1::<f64>::zeros(n);
306 for i in 0..n {
307 let row = x.row(i);
308 let mut best = f64::NEG_INFINITY;
309 let mut best_k = 0usize;
310 for (k, (c, b)) in self.coef.iter().zip(self.intercept.iter()).enumerate() {
311 let score = row.dot(c) + b;
312 if score > best {
313 best = score;
314 best_k = k;
315 }
316 }
317 out[i] = self.classes[best_k];
318 }
319 Ok(out)
320 }
321}
322
323impl PredictProba<f64> for FittedLinearDiscriminantAnalysis {
324 fn predict_proba(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
325 if x.ncols() != self.n_features {
326 return Err(RustMlError::ShapeMismatch(format!(
327 "expected {} features, got {}",
328 self.n_features,
329 x.ncols()
330 )));
331 }
332 let n = x.nrows();
333 let k = self.classes.len();
334 let mut p = Array2::<f64>::zeros((n, k));
335 for i in 0..n {
336 let row = x.row(i);
337 let mut logits = vec![0.0_f64; k];
338 let mut max_l = f64::NEG_INFINITY;
339 for (c_i, (c, b)) in self.coef.iter().zip(self.intercept.iter()).enumerate() {
340 let s = row.dot(c) + b;
341 logits[c_i] = s;
342 if s > max_l {
343 max_l = s;
344 }
345 }
346 let mut z = 0.0;
347 for c_i in 0..k {
348 let e = (logits[c_i] - max_l).exp();
349 p[[i, c_i]] = e;
350 z += e;
351 }
352 for c_i in 0..k {
353 p[[i, c_i]] /= z;
354 }
355 }
356 Ok(p)
357 }
358}
359
/// Configuration for quadratic discriminant analysis (QDA).
///
/// Each class gets its own covariance matrix, in contrast to LDA's pooled covariance.
#[derive(Debug, Clone)]
pub struct QuadraticDiscriminantAnalysis {
    /// Ridge term added to the diagonal of every class covariance.
    pub reg: f64,
}

impl QuadraticDiscriminantAnalysis {
    /// Creates a QDA configuration with `reg = 1e-9`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Replaces the diagonal regularization term.
    pub fn with_reg(self, r: f64) -> Self {
        let mut cfg = self;
        cfg.reg = r;
        cfg
    }
}

impl Default for QuadraticDiscriminantAnalysis {
    fn default() -> Self {
        Self { reg: 1e-9 }
    }
}
384
/// Parameters learned by fitting `QuadraticDiscriminantAnalysis`.
///
/// Index `k` in every per-class field refers to `classes[k]`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct FittedQuadraticDiscriminantAnalysis {
    /// Distinct class labels, in ascending order.
    pub classes: Vec<f64>,
    /// Per-class mean feature vectors.
    pub means: Vec<Array1<f64>>,
    /// Per-class priors: the fraction of training samples in each class.
    pub priors: Vec<f64>,
    /// Per-class covariance matrices, with `reg` added to the diagonal.
    ///
    /// Each is the class scatter divided by `n_k - 1`, or by 1 for a single-sample class.
    pub sigmas: Vec<Array2<f64>>,
    /// `ln det(sigmas[k])`, computed from a Cholesky factorization at fit time.
    pub log_det: Vec<f64>,
    /// Number of feature columns seen at fit time; inputs must match it.
    pub n_features: usize,
}
394
395impl Fit<f64> for QuadraticDiscriminantAnalysis {
396 type Fitted = FittedQuadraticDiscriminantAnalysis;
397
398 fn fit(&self, x: &Array2<f64>, y: &Array1<f64>) -> Result<Self::Fitted> {
399 if x.nrows() != y.len() {
400 return Err(RustMlError::ShapeMismatch(format!(
401 "X has {} rows but y has {} elements",
402 x.nrows(),
403 y.len()
404 )));
405 }
406 let (classes, groups) = class_indices(y);
407 if classes.len() < 2 {
408 return Err(RustMlError::InvalidParameter(
409 "need at least 2 classes for QDA".into(),
410 ));
411 }
412 let d = x.ncols();
413 let n = x.nrows();
414
415 let means: Vec<Array1<f64>> = groups.iter().map(|g| class_mean(x, g)).collect();
416 let priors: Vec<f64> = groups.iter().map(|g| g.len() as f64 / n as f64).collect();
417
418 let mut sigmas = Vec::with_capacity(classes.len());
419 let mut log_det = Vec::with_capacity(classes.len());
420 for (k, g) in groups.iter().enumerate() {
421 let mut s = Array2::<f64>::zeros((d, d));
422 outer_subtract_accum(x, &means[k], g, &mut s);
423 let denom = (g.len() as f64 - 1.0).max(1.0);
424 s.mapv_inplace(|v| v / denom);
425 for i in 0..d {
426 s[[i, i]] += self.reg;
427 }
428 log_det.push(log_det_chol(&s)?);
429 sigmas.push(s);
430 }
431
432 Ok(FittedQuadraticDiscriminantAnalysis {
433 classes,
434 means,
435 priors,
436 sigmas,
437 log_det,
438 n_features: d,
439 })
440 }
441}
442
443impl Predict<f64> for FittedQuadraticDiscriminantAnalysis {
444 fn predict(&self, x: &Array2<f64>) -> Result<Array1<f64>> {
445 if x.ncols() != self.n_features {
446 return Err(RustMlError::ShapeMismatch(format!(
447 "expected {} features, got {}",
448 self.n_features,
449 x.ncols()
450 )));
451 }
452 let n = x.nrows();
453 let d = self.n_features;
454 let mut out = Array1::<f64>::zeros(n);
455 for i in 0..n {
456 let mut best = f64::NEG_INFINITY;
457 let mut best_k = 0usize;
458 for k in 0..self.classes.len() {
459 let mut diff = Array1::<f64>::zeros(d);
461 for j in 0..d {
462 diff[j] = x[[i, j]] - self.means[k][j];
463 }
464 let s_inv_diff = solve_psd(&self.sigmas[k], &diff)?;
465 let m = diff.dot(&s_inv_diff);
466 let score = -0.5 * m - 0.5 * self.log_det[k] + self.priors[k].ln();
467 if score > best {
468 best = score;
469 best_k = k;
470 }
471 }
472 out[i] = self.classes[best_k];
473 }
474 Ok(out)
475 }
476}
477
478impl PredictProba<f64> for FittedQuadraticDiscriminantAnalysis {
479 fn predict_proba(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
480 if x.ncols() != self.n_features {
481 return Err(RustMlError::ShapeMismatch(format!(
482 "expected {} features, got {}",
483 self.n_features,
484 x.ncols()
485 )));
486 }
487 let n = x.nrows();
488 let k = self.classes.len();
489 let d = self.n_features;
490 let mut p = Array2::<f64>::zeros((n, k));
491 for i in 0..n {
492 let mut logits = vec![0.0_f64; k];
493 let mut max_l = f64::NEG_INFINITY;
494 for c_i in 0..k {
495 let mut diff = Array1::<f64>::zeros(d);
496 for j in 0..d {
497 diff[j] = x[[i, j]] - self.means[c_i][j];
498 }
499 let s_inv_diff = solve_psd(&self.sigmas[c_i], &diff)?;
500 let m = diff.dot(&s_inv_diff);
501 let score = -0.5 * m - 0.5 * self.log_det[c_i] + self.priors[c_i].ln();
502 logits[c_i] = score;
503 if score > max_l {
504 max_l = score;
505 }
506 }
507 let mut z = 0.0;
508 for c_i in 0..k {
509 let e = (logits[c_i] - max_l).exp();
510 p[[i, c_i]] = e;
511 z += e;
512 }
513 for c_i in 0..k {
514 p[[i, c_i]] /= z;
515 }
516 }
517 Ok(p)
518 }
519}
520
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Asserts that every predicted label equals the corresponding true label.
    fn assert_labels_reproduced(predicted: &Array1<f64>, truth: &Array1<f64>) {
        predicted
            .iter()
            .zip(truth.iter())
            .for_each(|(p, t)| assert_eq!(*p, *t));
    }

    /// Euclidean distance between rows `a` and `b` over the first two projected components.
    fn planar_distance(t: &Array2<f64>, a: usize, b: usize) -> f64 {
        let dx = t[[a, 0]] - t[[b, 0]];
        let dy = t[[a, 1]] - t[[b, 1]];
        (dx.powi(2) + dy.powi(2)).sqrt()
    }

    #[test]
    fn test_lda_two_well_separated_classes() {
        let x = array![
            [0.0, 0.0],
            [0.5, 0.1],
            [-0.3, -0.2],
            [0.2, -0.1],
            [5.0, 5.0],
            [5.1, 4.9],
            [4.7, 5.3],
            [5.0, 5.2],
        ];
        let y = array![0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0];
        let model = LinearDiscriminantAnalysis::new().fit(&x, &y).unwrap();
        assert_labels_reproduced(&model.predict(&x).unwrap(), &y);
    }

    #[test]
    fn test_lda_transform_separates() {
        // Three tight clusters near (0, 0), (4, 0) and (0, 4), three points each.
        let x = array![
            [0.0, 0.0],
            [0.5, 0.0],
            [0.0, 0.3],
            [4.0, 0.0],
            [4.2, 0.1],
            [4.0, 0.3],
            [0.0, 4.0],
            [0.1, 4.2],
            [-0.1, 4.0],
        ];
        let y = array![0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0];
        let model = LinearDiscriminantAnalysis::new().fit(&x, &y).unwrap();
        let t = model.transform(&x).unwrap();
        assert_eq!(t.shape(), &[9, 2]);
        // Mean distance between the first two points of each cluster.
        let d_within: f64 = (0..3)
            .map(|c| planar_distance(&t, 3 * c, 3 * c + 1))
            .sum::<f64>()
            / 3.0;
        // Distance between the first points of clusters 0 and 1.
        let d_between: f64 = planar_distance(&t, 0, 3);
        assert!(
            d_between > 5.0 * d_within,
            "within={d_within}, between={d_between}"
        );
    }

    #[test]
    fn test_qda_two_well_separated_classes() {
        let x = array![
            [0.0, 0.0],
            [0.5, 0.1],
            [-0.3, -0.2],
            [0.2, -0.1],
            [0.1, 0.2],
            [-0.1, 0.0],
            [5.0, 5.0],
            [5.1, 4.9],
            [4.7, 5.3],
            [5.0, 5.2],
            [5.2, 5.1],
            [4.8, 5.0],
        ];
        let y = array![0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0];
        let model = QuadraticDiscriminantAnalysis::new().fit(&x, &y).unwrap();
        assert_labels_reproduced(&model.predict(&x).unwrap(), &y);
    }
}
607
// Empty impl bodies: every method of these traits comes from the trait's default
// implementations in `anofox_ml_core`. NOTE(review): presumably those defaults are built on
// `predict` / `predict_proba` — confirm against the trait definitions.
impl anofox_ml_core::ClassifierScore<f64> for FittedLinearDiscriminantAnalysis {}
impl anofox_ml_core::ClassifierScore<f64> for FittedQuadraticDiscriminantAnalysis {}

impl anofox_ml_core::PredictLogProba<f64> for FittedLinearDiscriminantAnalysis {}
impl anofox_ml_core::PredictLogProba<f64> for FittedQuadraticDiscriminantAnalysis {}
613
614impl anofox_ml_core::DecisionFunction<f64> for FittedLinearDiscriminantAnalysis {
615 fn decision_function(
616 &self,
617 x: &ndarray::Array2<f64>,
618 ) -> anofox_ml_core::Result<ndarray::Array2<f64>> {
619 if x.ncols() != self.n_features {
620 return Err(anofox_ml_core::RustMlError::ShapeMismatch(format!(
621 "expected {} features, got {}",
622 self.n_features,
623 x.ncols()
624 )));
625 }
626 let n = x.nrows();
627 let k = self.classes.len();
628 let mut out = ndarray::Array2::<f64>::zeros((n, k));
629 for i in 0..n {
630 let row = x.row(i);
631 for (c_i, (c, b)) in self.coef.iter().zip(self.intercept.iter()).enumerate() {
632 out[[i, c_i]] = row.dot(c) + b;
633 }
634 }
635 Ok(out)
636 }
637}