1use ferrolearn_core::error::FerroError;
38use ferrolearn_core::pipeline::{FittedPipelineEstimator, PipelineEstimator};
39use ferrolearn_core::traits::{Fit, Predict, Transform};
40use ndarray::{Array1, Array2};
41use num_traits::{Float, NumCast};
42
/// Linear Discriminant Analysis estimator (unfitted configuration).
///
/// Holds only the hyper-parameters; fitting produces a [`FittedLDA`]
/// containing the learned projection and class statistics.
#[derive(Debug, Clone)]
pub struct LDA<F> {
    // Requested number of discriminant axes; `None` means the maximum
    // allowed, `min(n_classes - 1, n_features)`, is chosen at fit time.
    n_components: Option<usize>,
    // Ties the element type `F` to the estimator without storing data.
    _marker: std::marker::PhantomData<F>,
}
63
64impl<F: Float + Send + Sync + 'static> LDA<F> {
65 #[must_use]
70 pub fn new(n_components: Option<usize>) -> Self {
71 Self {
72 n_components,
73 _marker: std::marker::PhantomData,
74 }
75 }
76
77 #[must_use]
79 pub fn n_components(&self) -> Option<usize> {
80 self.n_components
81 }
82}
83
84impl<F: Float + Send + Sync + 'static> Default for LDA<F> {
85 fn default() -> Self {
86 Self::new(None)
87 }
88}
89
/// The result of fitting [`LDA`]: projection matrix plus per-class
/// statistics used by `transform` and nearest-mean `predict`.
#[derive(Debug, Clone)]
pub struct FittedLDA<F> {
    // (n_features, n_components) projection, applied as `X · scalings`.
    scalings: Array2<F>,

    // (n_classes, n_components) class means, already projected.
    means: Array2<F>,

    // Per retained component: its share of the clamped (non-negative)
    // eigenvalue mass of Sw⁻¹Sb.
    explained_variance_ratio: Array1<F>,

    // Distinct class labels seen at fit time, sorted ascending.
    classes: Vec<usize>,

    // Expected column count of inputs to `transform`/`predict`.
    n_features: usize,
}
118
impl<F: Float + Send + Sync + 'static> FittedLDA<F> {
    /// Projection matrix of shape `(n_features, n_components)`.
    #[must_use]
    pub fn scalings(&self) -> &Array2<F> {
        &self.scalings
    }

    /// Projected class means, shape `(n_classes, n_components)`.
    #[must_use]
    pub fn means(&self) -> &Array2<F> {
        &self.means
    }

    /// Per-component share of the retained discriminant eigenvalue mass.
    #[must_use]
    pub fn explained_variance_ratio(&self) -> &Array1<F> {
        &self.explained_variance_ratio
    }

    /// Sorted distinct class labels observed during fitting.
    #[must_use]
    pub fn classes(&self) -> &[usize] {
        &self.classes
    }
}
144
145fn jacobi_eigen_f<F: Float + Send + Sync + 'static>(
154 a: &Array2<F>,
155 max_iter: usize,
156) -> Result<(Array1<F>, Array2<F>), FerroError> {
157 let n = a.nrows();
158 let mut mat = a.to_owned();
159 let mut v = Array2::<F>::zeros((n, n));
160 for i in 0..n {
161 v[[i, i]] = F::one();
162 }
163 let tol = F::from(1e-12).unwrap_or(F::epsilon());
164
165 for _ in 0..max_iter {
166 let mut max_off = F::zero();
168 let mut p = 0usize;
169 let mut q = 1usize;
170 for i in 0..n {
171 for j in (i + 1)..n {
172 let val = mat[[i, j]].abs();
173 if val > max_off {
174 max_off = val;
175 p = i;
176 q = j;
177 }
178 }
179 }
180 if max_off < tol {
181 let eigenvalues = Array1::from_shape_fn(n, |i| mat[[i, i]]);
182 return Ok((eigenvalues, v));
183 }
184 let app = mat[[p, p]];
185 let aqq = mat[[q, q]];
186 let apq = mat[[p, q]];
187 let two = F::from(2.0).unwrap();
188 let theta = if (app - aqq).abs() < tol {
189 F::from(std::f64::consts::FRAC_PI_4).unwrap_or(F::one())
190 } else {
191 let tau = (aqq - app) / (two * apq);
192 let t = if tau >= F::zero() {
193 F::one() / (tau.abs() + (F::one() + tau * tau).sqrt())
194 } else {
195 -F::one() / (tau.abs() + (F::one() + tau * tau).sqrt())
196 };
197 t.atan()
198 };
199 let c = theta.cos();
200 let s = theta.sin();
201 let mut new_mat = mat.clone();
202 for i in 0..n {
203 if i != p && i != q {
204 let mip = mat[[i, p]];
205 let miq = mat[[i, q]];
206 new_mat[[i, p]] = c * mip - s * miq;
207 new_mat[[p, i]] = new_mat[[i, p]];
208 new_mat[[i, q]] = s * mip + c * miq;
209 new_mat[[q, i]] = new_mat[[i, q]];
210 }
211 }
212 new_mat[[p, p]] = c * c * app - two * s * c * apq + s * s * aqq;
213 new_mat[[q, q]] = s * s * app + two * s * c * apq + c * c * aqq;
214 new_mat[[p, q]] = F::zero();
215 new_mat[[q, p]] = F::zero();
216 mat = new_mat;
217 for i in 0..n {
218 let vip = v[[i, p]];
219 let viq = v[[i, q]];
220 v[[i, p]] = c * vip - s * viq;
221 v[[i, q]] = s * vip + c * viq;
222 }
223 }
224 Err(FerroError::ConvergenceFailure {
225 iterations: max_iter,
226 message: "Jacobi eigendecomposition did not converge (LDA)".into(),
227 })
228}
229
230fn gaussian_solve_f<F: Float>(
232 n: usize,
233 a: &Array2<F>,
234 b: &Array1<F>,
235) -> Result<Array1<F>, FerroError> {
236 let mut aug = Array2::<F>::zeros((n, n + 1));
237 for i in 0..n {
238 for j in 0..n {
239 aug[[i, j]] = a[[i, j]];
240 }
241 aug[[i, n]] = b[i];
242 }
243 for col in 0..n {
244 let mut max_val = aug[[col, col]].abs();
245 let mut max_row = col;
246 for row in (col + 1)..n {
247 let val = aug[[row, col]].abs();
248 if val > max_val {
249 max_val = val;
250 max_row = row;
251 }
252 }
253 if max_val < F::from(1e-12).unwrap_or(F::epsilon()) {
254 return Err(FerroError::NumericalInstability {
255 message: "singular matrix during LDA inversion".into(),
256 });
257 }
258 if max_row != col {
259 for j in 0..=n {
260 let tmp = aug[[col, j]];
261 aug[[col, j]] = aug[[max_row, j]];
262 aug[[max_row, j]] = tmp;
263 }
264 }
265 let pivot = aug[[col, col]];
266 for row in (col + 1)..n {
267 let factor = aug[[row, col]] / pivot;
268 for j in col..=n {
269 let above = aug[[col, j]];
270 aug[[row, j]] = aug[[row, j]] - factor * above;
271 }
272 }
273 }
274 let mut x = Array1::<F>::zeros(n);
275 for i in (0..n).rev() {
276 let mut sum = aug[[i, n]];
277 for j in (i + 1)..n {
278 sum = sum - aug[[i, j]] * x[j];
279 }
280 if aug[[i, i]].abs() < F::from(1e-12).unwrap_or(F::epsilon()) {
281 return Err(FerroError::NumericalInstability {
282 message: "near-zero pivot during LDA back substitution".into(),
283 });
284 }
285 x[i] = sum / aug[[i, i]];
286 }
287 Ok(x)
288}
289
290fn sw_inv_sb<F: Float + Send + Sync + 'static>(
294 sw: &Array2<F>,
295 sb: &Array2<F>,
296) -> Result<Array2<F>, FerroError> {
297 let n = sw.nrows();
298 let mut result = Array2::<F>::zeros((n, n));
299 for j in 0..n {
300 let col_sb = Array1::from_shape_fn(n, |i| sb[[i, j]]);
301 let col = gaussian_solve_f(n, sw, &col_sb)?;
302 for i in 0..n {
303 result[[i, j]] = col[i];
304 }
305 }
306 Ok(result)
307}
308
impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array1<usize>> for LDA<F> {
    type Fitted = FittedLDA<F>;
    type Error = FerroError;

    /// Fit LDA on `x` (shape `(n_samples, n_features)`) with integer class
    /// labels `y` (length `n_samples`).
    ///
    /// Pipeline: validate shapes/sample counts, discover the sorted set of
    /// classes, resolve the component count, compute overall and per-class
    /// means, build the within-class scatter `Sw` (ridge-regularized) and
    /// between-class scatter `Sb`, eigendecompose `Sw⁻¹Sb` via Jacobi
    /// rotations, and keep the top eigenvectors as the projection.
    ///
    /// # Errors
    ///
    /// `ShapeMismatch` if `y` does not match `x`'s row count;
    /// `InsufficientSamples` for fewer than 2 samples or classes;
    /// `InvalidParameter` for `n_components == 0` or above the maximum;
    /// plus any numerical error from the solve/eigendecomposition.
    fn fit(&self, x: &Array2<F>, y: &Array1<usize>) -> Result<FittedLDA<F>, FerroError> {
        let (n_samples, n_features) = x.dim();

        // --- input validation ---
        if n_samples != y.len() {
            return Err(FerroError::ShapeMismatch {
                expected: vec![n_samples],
                actual: vec![y.len()],
                context: "LDA: y length must match number of rows in X".into(),
            });
        }
        if n_samples < 2 {
            return Err(FerroError::InsufficientSamples {
                required: 2,
                actual: n_samples,
                context: "LDA requires at least 2 samples".into(),
            });
        }

        // --- class discovery: sorted, deduplicated label set ---
        let mut classes: Vec<usize> = y.to_vec();
        classes.sort_unstable();
        classes.dedup();
        let n_classes = classes.len();

        if n_classes < 2 {
            return Err(FerroError::InsufficientSamples {
                required: 2,
                actual: n_classes,
                context: "LDA requires at least 2 distinct classes".into(),
            });
        }

        // --- resolve the number of components ---
        // LDA yields at most n_classes - 1 meaningful discriminant axes.
        let max_components = (n_classes - 1).min(n_features);
        let n_comp = match self.n_components {
            None => max_components,
            Some(0) => {
                return Err(FerroError::InvalidParameter {
                    name: "n_components".into(),
                    reason: "must be at least 1".into(),
                });
            }
            Some(k) if k > max_components => {
                return Err(FerroError::InvalidParameter {
                    name: "n_components".into(),
                    reason: format!(
                        "n_components ({k}) exceeds max allowed ({max_components} = min(n_classes-1, n_features))"
                    ),
                });
            }
            Some(k) => k,
        };

        // --- overall (grand) mean per feature ---
        let n_f = F::from(n_samples).unwrap();
        let mut overall_mean = Array1::<F>::zeros(n_features);
        for j in 0..n_features {
            let col = x.column(j);
            let s = col.iter().copied().fold(F::zero(), |a, b| a + b);
            overall_mean[j] = s / n_f;
        }

        // --- per-class means and sample counts ---
        let mut class_means: Vec<Array1<F>> = Vec::with_capacity(n_classes);
        let mut class_counts: Vec<usize> = Vec::with_capacity(n_classes);
        for &cls in &classes {
            let mut mean = Array1::<F>::zeros(n_features);
            let mut cnt = 0usize;
            for (i, &label) in y.iter().enumerate() {
                if label == cls {
                    for j in 0..n_features {
                        mean[j] = mean[j] + x[[i, j]];
                    }
                    cnt += 1;
                }
            }
            // Defensive: cannot happen since `classes` was built from `y`.
            if cnt == 0 {
                return Err(FerroError::InsufficientSamples {
                    required: 1,
                    actual: 0,
                    context: format!("LDA: class {cls} has no samples"),
                });
            }
            let cnt_f = F::from(cnt).unwrap();
            mean.mapv_inplace(|v| v / cnt_f);
            class_means.push(mean);
            class_counts.push(cnt);
        }

        // --- within-class scatter Sw = sum over samples of
        //     (x_i - mu_class) (x_i - mu_class)^T ---
        let mut sw = Array2::<F>::zeros((n_features, n_features));
        for (ci, &cls) in classes.iter().enumerate() {
            let mu_c = &class_means[ci];
            for (i, &label) in y.iter().enumerate() {
                if label == cls {
                    let diff: Vec<F> = (0..n_features).map(|j| x[[i, j]] - mu_c[j]).collect();
                    for r in 0..n_features {
                        for c in 0..n_features {
                            sw[[r, c]] = sw[[r, c]] + diff[r] * diff[c];
                        }
                    }
                }
            }
        }

        // Ridge regularization keeps Sw invertible when it is singular
        // (e.g. fewer samples than features).
        let reg = F::from(1e-6).unwrap();
        for i in 0..n_features {
            sw[[i, i]] = sw[[i, i]] + reg;
        }

        // --- between-class scatter Sb = sum over classes of
        //     n_c (mu_c - mu) (mu_c - mu)^T ---
        let mut sb = Array2::<F>::zeros((n_features, n_features));
        for (ci, &nc) in class_counts.iter().enumerate() {
            let nc_f = F::from(nc).unwrap();
            let diff: Vec<F> = (0..n_features)
                .map(|j| class_means[ci][j] - overall_mean[j])
                .collect();
            for r in 0..n_features {
                for c in 0..n_features {
                    sb[[r, c]] = sb[[r, c]] + nc_f * diff[r] * diff[c];
                }
            }
        }

        // --- eigendecomposition of Sw⁻¹Sb ---
        let m = sw_inv_sb(&sw, &sb)?;
        // Iteration budget scales with matrix size; the +1000 floor covers
        // tiny problems.
        let max_jacobi = n_features * n_features * 100 + 1000;
        let (eigenvalues, eigenvectors) = jacobi_eigen_f(&m, max_jacobi)?;

        // Sort eigenvalue indices in descending order (NaNs compare Equal).
        let mut indices: Vec<usize> = (0..n_features).collect();
        indices.sort_by(|&a, &b| {
            eigenvalues[b]
                .partial_cmp(&eigenvalues[a])
                .unwrap_or(std::cmp::Ordering::Equal)
        });

        // Total positive eigenvalue mass; negatives are clamped to zero.
        let total_ev: F = eigenvalues
            .iter()
            .copied()
            .map(|v| if v > F::zero() { v } else { F::zero() })
            .fold(F::zero(), |a, b| a + b);

        // --- keep the top n_comp eigenvectors as the projection and
        //     compute each component's explained-variance ratio ---
        let mut scalings = Array2::<F>::zeros((n_features, n_comp));
        let mut explained_variance_ratio = Array1::<F>::zeros(n_comp);
        for (k, &idx) in indices.iter().take(n_comp).enumerate() {
            let ev = eigenvalues[idx];
            let ev_clamped = if ev > F::zero() { ev } else { F::zero() };
            explained_variance_ratio[k] = if total_ev > F::zero() {
                ev_clamped / total_ev
            } else {
                F::zero()
            };
            for j in 0..n_features {
                scalings[[j, k]] = eigenvectors[[j, idx]];
            }
        }

        // --- project each class mean into the discriminant space, used by
        //     nearest-mean prediction ---
        let mut means = Array2::<F>::zeros((n_classes, n_comp));
        for ci in 0..n_classes {
            let mu_row = class_means[ci].view();
            for k in 0..n_comp {
                let mut dot = F::zero();
                for j in 0..n_features {
                    dot = dot + mu_row[j] * scalings[[j, k]];
                }
                means[[ci, k]] = dot;
            }
        }

        Ok(FittedLDA {
            scalings,
            means,
            explained_variance_ratio,
            classes,
            n_features,
        })
    }
}
512
513impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedLDA<F> {
518 type Output = Array2<F>;
519 type Error = FerroError;
520
521 fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
528 if x.ncols() != self.n_features {
529 return Err(FerroError::ShapeMismatch {
530 expected: vec![x.nrows(), self.n_features],
531 actual: vec![x.nrows(), x.ncols()],
532 context: "FittedLDA::transform".into(),
533 });
534 }
535 Ok(x.dot(&self.scalings))
536 }
537}
538
539impl<F: Float + Send + Sync + 'static> Predict<Array2<F>> for FittedLDA<F> {
544 type Output = Array1<usize>;
545 type Error = FerroError;
546
547 fn predict(&self, x: &Array2<F>) -> Result<Array1<usize>, FerroError> {
554 let projected = self.transform(x)?;
555 let n_samples = projected.nrows();
556 let n_comp = projected.ncols();
557 let n_classes = self.classes.len();
558
559 let mut predictions = Array1::<usize>::zeros(n_samples);
560 for i in 0..n_samples {
561 let mut best_class = 0usize;
562 let mut best_dist = F::infinity();
563 for ci in 0..n_classes {
564 let mut dist = F::zero();
565 for k in 0..n_comp {
566 let d = projected[[i, k]] - self.means[[ci, k]];
567 dist = dist + d * d;
568 }
569 if dist < best_dist {
570 best_dist = dist;
571 best_class = ci;
572 }
573 }
574 predictions[i] = self.classes[best_class];
575 }
576 Ok(predictions)
577 }
578}
579
580impl<F: Float + Send + Sync + 'static> PipelineEstimator<F> for LDA<F> {
585 fn fit_pipeline(
591 &self,
592 x: &Array2<F>,
593 y: &Array1<F>,
594 ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
595 let y_usize: Array1<usize> = y.mapv(|v| v.to_usize().unwrap_or(0));
596 let fitted = self.fit(x, &y_usize)?;
597 Ok(Box::new(FittedLDAPipeline(fitted)))
598 }
599}
600
601struct FittedLDAPipeline<F>(FittedLDA<F>);
603
604impl<F: Float + Send + Sync + 'static> FittedPipelineEstimator<F> for FittedLDAPipeline<F> {
605 fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
607 let preds = self.0.predict(x)?;
608 Ok(preds.mapv(|v| NumCast::from(v).unwrap_or(F::nan())))
609 }
610}
611
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::{Array2, array};

    /// Two well-separated 2-D clusters of four points each (classes 0/1).
    fn linearly_separable_2d() -> (Array2<f64>, Array1<usize>) {
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 1.0, //
                1.5, 1.2, //
                0.8, 0.9, //
                1.1, 1.3, //
                6.0, 6.0, //
                6.2, 5.8, //
                5.9, 6.1, //
                6.3, 5.7, //
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];
        (x, y)
    }

    /// Three compact 2-D clusters of three points each (classes 0/1/2).
    fn three_class_data() -> (Array2<f64>, Array1<usize>) {
        let x = Array2::from_shape_vec(
            (9, 2),
            vec![
                0.0, 0.0, //
                0.5, 0.1, //
                0.1, 0.5, //
                5.0, 0.0, //
                5.2, 0.3, //
                4.8, 0.1, //
                0.0, 5.0, //
                0.1, 5.2, //
                0.3, 4.8, //
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 1, 1, 1, 2, 2, 2];
        (x, y)
    }

    #[test]
    fn test_lda_fit_returns_fitted() {
        let (x, y) = linearly_separable_2d();
        let lda = LDA::<f64>::new(Some(1));
        let fitted = lda.fit(&x, &y).unwrap();
        assert_eq!(fitted.scalings().ncols(), 1);
        assert_eq!(fitted.scalings().nrows(), 2);
    }

    #[test]
    fn test_lda_default_n_components() {
        let (x, y) = linearly_separable_2d();
        let lda = LDA::<f64>::default();
        let fitted = lda.fit(&x, &y).unwrap();
        // Two classes => min(n_classes - 1, n_features) == 1 component.
        assert_eq!(fitted.scalings().ncols(), 1);
    }

    #[test]
    fn test_lda_transform_shape() {
        let (x, y) = linearly_separable_2d();
        let lda = LDA::<f64>::new(Some(1));
        let fitted = lda.fit(&x, &y).unwrap();
        let proj = fitted.transform(&x).unwrap();
        assert_eq!(proj.dim(), (8, 1));
    }

    #[test]
    fn test_lda_predict_accuracy_binary() {
        let (x, y) = linearly_separable_2d();
        let lda = LDA::<f64>::new(Some(1));
        let fitted = lda.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        let correct = preds.iter().zip(y.iter()).filter(|(p, a)| *p == *a).count();
        assert_eq!(correct, 8, "All 8 samples should be classified correctly");
    }

    #[test]
    fn test_lda_predict_three_classes() {
        let (x, y) = three_class_data();
        let lda = LDA::<f64>::new(Some(2));
        let fitted = lda.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        let correct = preds.iter().zip(y.iter()).filter(|(p, a)| *p == *a).count();
        assert!(correct >= 7, "Expected at least 7/9 correct, got {correct}");
    }

    #[test]
    fn test_lda_explained_variance_ratio_positive() {
        let (x, y) = linearly_separable_2d();
        let lda = LDA::<f64>::new(Some(1));
        let fitted = lda.fit(&x, &y).unwrap();
        for &v in fitted.explained_variance_ratio().iter() {
            assert!(v >= 0.0);
        }
    }

    #[test]
    fn test_lda_explained_variance_ratio_le_1() {
        let (x, y) = three_class_data();
        let lda = LDA::<f64>::new(Some(2));
        let fitted = lda.fit(&x, &y).unwrap();
        let total: f64 = fitted.explained_variance_ratio().iter().sum();
        assert!(total <= 1.0 + 1e-9, "total={total}");
    }

    #[test]
    fn test_lda_classes_accessor() {
        let (x, y) = linearly_separable_2d();
        let lda = LDA::<f64>::new(Some(1));
        let fitted = lda.fit(&x, &y).unwrap();
        assert_eq!(fitted.classes(), &[0usize, 1]);
    }

    #[test]
    fn test_lda_means_shape() {
        let (x, y) = three_class_data();
        let lda = LDA::<f64>::new(Some(2));
        let fitted = lda.fit(&x, &y).unwrap();
        assert_eq!(fitted.means().dim(), (3, 2));
    }

    #[test]
    fn test_lda_transform_shape_mismatch() {
        let (x, y) = linearly_separable_2d();
        let lda = LDA::<f64>::new(Some(1));
        let fitted = lda.fit(&x, &y).unwrap();
        let x_bad = Array2::<f64>::zeros((3, 5));
        assert!(fitted.transform(&x_bad).is_err());
    }

    #[test]
    fn test_lda_predict_shape_mismatch() {
        let (x, y) = linearly_separable_2d();
        let lda = LDA::<f64>::new(Some(1));
        let fitted = lda.fit(&x, &y).unwrap();
        let x_bad = Array2::<f64>::zeros((3, 5));
        assert!(fitted.predict(&x_bad).is_err());
    }

    #[test]
    fn test_lda_error_zero_n_components() {
        let (x, y) = linearly_separable_2d();
        let lda = LDA::<f64>::new(Some(0));
        assert!(lda.fit(&x, &y).is_err());
    }

    #[test]
    fn test_lda_error_n_components_too_large() {
        let (x, y) = linearly_separable_2d();
        let lda = LDA::<f64>::new(Some(5));
        assert!(lda.fit(&x, &y).is_err());
    }

    #[test]
    fn test_lda_error_single_class() {
        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![0usize, 0, 0, 0];
        let lda = LDA::<f64>::new(None);
        assert!(lda.fit(&x, &y).is_err());
    }

    #[test]
    fn test_lda_error_shape_mismatch_fit() {
        let x = Array2::<f64>::zeros((4, 2));
        let y = array![0usize, 1];
        let lda = LDA::<f64>::new(None);
        assert!(lda.fit(&x, &y).is_err());
    }

    #[test]
    fn test_lda_error_insufficient_samples() {
        let x = Array2::<f64>::zeros((1, 2));
        let y = array![0usize];
        let lda = LDA::<f64>::new(None);
        assert!(lda.fit(&x, &y).is_err());
    }

    #[test]
    fn test_lda_scalings_accessor() {
        let (x, y) = linearly_separable_2d();
        let lda = LDA::<f64>::new(Some(1));
        let fitted = lda.fit(&x, &y).unwrap();
        assert_eq!(fitted.scalings().dim(), (2, 1));
    }

    #[test]
    fn test_lda_pipeline_estimator() {
        use ferrolearn_core::pipeline::PipelineEstimator;

        let (x, y_usize) = linearly_separable_2d();
        let y_f64 = y_usize.mapv(|v| v as f64);
        let lda = LDA::<f64>::new(Some(1));
        let fitted = lda.fit_pipeline(&x, &y_f64).unwrap();
        let preds = fitted.predict_pipeline(&x).unwrap();
        assert_eq!(preds.len(), 8);
    }

    #[test]
    fn test_lda_n_components_getter() {
        let lda = LDA::<f64>::new(Some(2));
        assert_eq!(lda.n_components(), Some(2));
        let lda_none = LDA::<f64>::new(None);
        assert_eq!(lda_none.n_components(), None);
    }

    #[test]
    fn test_lda_transform_then_predict_consistent() {
        // `predict` must agree with a manual nearest-mean search over the
        // transformed data.
        let (x, y) = linearly_separable_2d();
        let lda = LDA::<f64>::new(Some(1));
        let fitted = lda.fit(&x, &y).unwrap();
        let projected = fitted.transform(&x).unwrap();
        let preds_predict = fitted.predict(&x).unwrap();
        let n_samples = projected.nrows();
        let n_comp = projected.ncols();
        let n_classes = fitted.classes().len();
        for i in 0..n_samples {
            let mut best = 0;
            let mut best_d = f64::INFINITY;
            for ci in 0..n_classes {
                let mut d = 0.0;
                for k in 0..n_comp {
                    let diff = projected[[i, k]] - fitted.means()[[ci, k]];
                    d += diff * diff;
                }
                if d < best_d {
                    best_d = d;
                    best = ci;
                }
            }
            assert_eq!(preds_predict[i], fitted.classes()[best]);
        }
    }

    #[test]
    fn test_lda_projected_class_separation() {
        let (x, y) = linearly_separable_2d();
        let lda = LDA::<f64>::new(Some(1));
        let fitted = lda.fit(&x, &y).unwrap();
        let projected = fitted.transform(&x).unwrap();

        let mean0: f64 = projected
            .rows()
            .into_iter()
            .zip(y.iter())
            .filter(|&(_, label)| *label == 0)
            .map(|(row, _)| row[0])
            .sum::<f64>()
            / 4.0;
        let mean1: f64 = projected
            .rows()
            .into_iter()
            .zip(y.iter())
            .filter(|&(_, label)| *label == 1)
            .map(|(row, _)| row[0])
            .sum::<f64>()
            / 4.0;

        assert!(
            (mean0 - mean1).abs() > 0.5,
            "Projected means should differ, got {mean0} vs {mean1}"
        );
    }

    #[test]
    fn test_lda_transform_known_data() {
        // 1-D data symmetric around zero: the two classes must project to
        // opposite signs regardless of the eigenvector's orientation.
        let x = Array2::from_shape_vec((4, 1), vec![-2.0, -1.0, 1.0, 2.0]).unwrap();
        let y = array![0usize, 0, 1, 1];
        let lda = LDA::<f64>::new(Some(1));
        let fitted = lda.fit(&x, &y).unwrap();
        let proj = fitted.transform(&x).unwrap();
        let sign0 = proj[[0, 0]].signum();
        let sign1 = proj[[2, 0]].signum();
        assert_ne!(
            sign0 as i32, sign1 as i32,
            "Classes should be on opposite sides"
        );
    }

    #[test]
    fn test_lda_abs_diff_eq_means_dimensions() {
        let (x, y) = linearly_separable_2d();
        let lda = LDA::<f64>::new(Some(1));
        let fitted = lda.fit(&x, &y).unwrap();
        assert_eq!(fitted.means().ncols(), 1);
        let m0 = fitted.means()[[0, 0]];
        let m1 = fitted.means()[[1, 0]];
        assert!((m0 - m1).abs() > 0.5, "m0={m0}, m1={m1}");
        // Trivial check that keeps the `approx` import exercised.
        assert_abs_diff_eq!(0.0_f64, 0.0_f64);
    }
}