1use ferrolearn_core::error::FerroError;
38use ferrolearn_core::pipeline::{FittedPipelineEstimator, PipelineEstimator};
39use ferrolearn_core::traits::{Fit, Predict, Transform};
40use ndarray::{Array1, Array2};
41use num_traits::Float;
42
/// Linear Discriminant Analysis (LDA) estimator.
///
/// Configured before fitting; `fit` produces a [`FittedLDA`] that holds the
/// learned projection and class statistics.
#[derive(Debug, Clone)]
pub struct LDA<F> {
    /// Requested number of discriminant components. `None` means "use the
    /// maximum allowed", resolved at fit time as min(n_classes - 1, n_features).
    n_components: Option<usize>,
    /// Ties the estimator to the float type `F` without storing a value of it.
    _marker: std::marker::PhantomData<F>,
}
63
64impl<F: Float + Send + Sync + 'static> LDA<F> {
65 #[must_use]
70 pub fn new(n_components: Option<usize>) -> Self {
71 Self {
72 n_components,
73 _marker: std::marker::PhantomData,
74 }
75 }
76
77 #[must_use]
79 pub fn n_components(&self) -> Option<usize> {
80 self.n_components
81 }
82}
83
84impl<F: Float + Send + Sync + 'static> Default for LDA<F> {
85 fn default() -> Self {
86 Self::new(None)
87 }
88}
89
/// Trained LDA model produced by fitting an [`LDA`] estimator.
#[derive(Debug, Clone)]
pub struct FittedLDA<F> {
    /// Projection matrix of shape (n_features, n_components); columns are
    /// discriminant axes ordered by decreasing eigenvalue.
    scalings: Array2<F>,

    /// Class means already projected into the discriminant space, shape
    /// (n_classes, n_components); used for nearest-centroid prediction.
    means: Array2<F>,

    /// Share of the (clamped, non-negative) eigenvalue mass captured by
    /// each kept component.
    explained_variance_ratio: Array1<F>,

    /// Distinct class labels seen during fitting, in ascending order.
    classes: Vec<usize>,

    /// Number of input columns expected by `transform` / `predict`.
    n_features: usize,
}
118
119impl<F: Float + Send + Sync + 'static> FittedLDA<F> {
120 #[must_use]
122 pub fn scalings(&self) -> &Array2<F> {
123 &self.scalings
124 }
125
126 #[must_use]
128 pub fn means(&self) -> &Array2<F> {
129 &self.means
130 }
131
132 #[must_use]
134 pub fn explained_variance_ratio(&self) -> &Array1<F> {
135 &self.explained_variance_ratio
136 }
137
138 #[must_use]
140 pub fn classes(&self) -> &[usize] {
141 &self.classes
142 }
143}
144
145fn jacobi_eigen_f<F: Float + Send + Sync + 'static>(
154 a: &Array2<F>,
155 max_iter: usize,
156) -> Result<(Array1<F>, Array2<F>), FerroError> {
157 let n = a.nrows();
158 let mut mat = a.to_owned();
159 let mut v = Array2::<F>::zeros((n, n));
160 for i in 0..n {
161 v[[i, i]] = F::one();
162 }
163 let tol = F::from(1e-12).unwrap_or(F::epsilon());
164
165 for _ in 0..max_iter {
166 let mut max_off = F::zero();
168 let mut p = 0usize;
169 let mut q = 1usize;
170 for i in 0..n {
171 for j in (i + 1)..n {
172 let val = mat[[i, j]].abs();
173 if val > max_off {
174 max_off = val;
175 p = i;
176 q = j;
177 }
178 }
179 }
180 if max_off < tol {
181 let eigenvalues = Array1::from_shape_fn(n, |i| mat[[i, i]]);
182 return Ok((eigenvalues, v));
183 }
184 let app = mat[[p, p]];
185 let aqq = mat[[q, q]];
186 let apq = mat[[p, q]];
187 let two = F::from(2.0).unwrap();
188 let theta = if (app - aqq).abs() < tol {
189 F::from(std::f64::consts::FRAC_PI_4).unwrap_or(F::one())
190 } else {
191 let tau = (aqq - app) / (two * apq);
192 let t = if tau >= F::zero() {
193 F::one() / (tau.abs() + (F::one() + tau * tau).sqrt())
194 } else {
195 -F::one() / (tau.abs() + (F::one() + tau * tau).sqrt())
196 };
197 t.atan()
198 };
199 let c = theta.cos();
200 let s = theta.sin();
201 let mut new_mat = mat.clone();
202 for i in 0..n {
203 if i != p && i != q {
204 let mip = mat[[i, p]];
205 let miq = mat[[i, q]];
206 new_mat[[i, p]] = c * mip - s * miq;
207 new_mat[[p, i]] = new_mat[[i, p]];
208 new_mat[[i, q]] = s * mip + c * miq;
209 new_mat[[q, i]] = new_mat[[i, q]];
210 }
211 }
212 new_mat[[p, p]] = c * c * app - two * s * c * apq + s * s * aqq;
213 new_mat[[q, q]] = s * s * app + two * s * c * apq + c * c * aqq;
214 new_mat[[p, q]] = F::zero();
215 new_mat[[q, p]] = F::zero();
216 mat = new_mat;
217 for i in 0..n {
218 let vip = v[[i, p]];
219 let viq = v[[i, q]];
220 v[[i, p]] = c * vip - s * viq;
221 v[[i, q]] = s * vip + c * viq;
222 }
223 }
224 Err(FerroError::ConvergenceFailure {
225 iterations: max_iter,
226 message: "Jacobi eigendecomposition did not converge (LDA)".into(),
227 })
228}
229
230fn gaussian_solve_f<F: Float>(
232 n: usize,
233 a: &Array2<F>,
234 b: &Array1<F>,
235) -> Result<Array1<F>, FerroError> {
236 let mut aug = Array2::<F>::zeros((n, n + 1));
237 for i in 0..n {
238 for j in 0..n {
239 aug[[i, j]] = a[[i, j]];
240 }
241 aug[[i, n]] = b[i];
242 }
243 for col in 0..n {
244 let mut max_val = aug[[col, col]].abs();
245 let mut max_row = col;
246 for row in (col + 1)..n {
247 let val = aug[[row, col]].abs();
248 if val > max_val {
249 max_val = val;
250 max_row = row;
251 }
252 }
253 if max_val < F::from(1e-12).unwrap_or(F::epsilon()) {
254 return Err(FerroError::NumericalInstability {
255 message: "singular matrix during LDA inversion".into(),
256 });
257 }
258 if max_row != col {
259 for j in 0..=n {
260 let tmp = aug[[col, j]];
261 aug[[col, j]] = aug[[max_row, j]];
262 aug[[max_row, j]] = tmp;
263 }
264 }
265 let pivot = aug[[col, col]];
266 for row in (col + 1)..n {
267 let factor = aug[[row, col]] / pivot;
268 for j in col..=n {
269 let above = aug[[col, j]];
270 aug[[row, j]] = aug[[row, j]] - factor * above;
271 }
272 }
273 }
274 let mut x = Array1::<F>::zeros(n);
275 for i in (0..n).rev() {
276 let mut sum = aug[[i, n]];
277 for j in (i + 1)..n {
278 sum = sum - aug[[i, j]] * x[j];
279 }
280 if aug[[i, i]].abs() < F::from(1e-12).unwrap_or(F::epsilon()) {
281 return Err(FerroError::NumericalInstability {
282 message: "near-zero pivot during LDA back substitution".into(),
283 });
284 }
285 x[i] = sum / aug[[i, i]];
286 }
287 Ok(x)
288}
289
290fn sw_inv_sb<F: Float + Send + Sync + 'static>(
294 sw: &Array2<F>,
295 sb: &Array2<F>,
296) -> Result<Array2<F>, FerroError> {
297 let n = sw.nrows();
298 let mut result = Array2::<F>::zeros((n, n));
299 for j in 0..n {
300 let col_sb = Array1::from_shape_fn(n, |i| sb[[i, j]]);
301 let col = gaussian_solve_f(n, sw, &col_sb)?;
302 for i in 0..n {
303 result[[i, j]] = col[i];
304 }
305 }
306 Ok(result)
307}
308
impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array1<usize>> for LDA<F> {
    type Fitted = FittedLDA<F>;
    type Error = FerroError;

    /// Fits LDA on samples `x` (n_samples x n_features) with integer class
    /// labels `y` (length n_samples).
    ///
    /// Pipeline: validate input -> resolve the component count -> compute
    /// overall and per-class means -> build the within-class scatter `Sw`
    /// (with a small ridge term) and between-class scatter `Sb` ->
    /// eigendecompose `Sw^-1 Sb` -> keep the leading eigenvectors as the
    /// projection and pre-project the class means for prediction.
    ///
    /// # Errors
    ///
    /// * `ShapeMismatch` — `y` length differs from the row count of `x`.
    /// * `InsufficientSamples` — fewer than 2 samples, or fewer than 2
    ///   distinct classes.
    /// * `InvalidParameter` — `n_components` is 0 or exceeds
    ///   min(n_classes - 1, n_features).
    /// * Numerical errors propagated from the linear solve or the
    ///   eigendecomposition.
    fn fit(&self, x: &Array2<F>, y: &Array1<usize>) -> Result<FittedLDA<F>, FerroError> {
        let (n_samples, n_features) = x.dim();

        if n_samples != y.len() {
            return Err(FerroError::ShapeMismatch {
                expected: vec![n_samples],
                actual: vec![y.len()],
                context: "LDA: y length must match number of rows in X".into(),
            });
        }
        if n_samples < 2 {
            return Err(FerroError::InsufficientSamples {
                required: 2,
                actual: n_samples,
                context: "LDA requires at least 2 samples".into(),
            });
        }

        // Distinct class labels, ascending.
        let mut classes: Vec<usize> = y.to_vec();
        classes.sort_unstable();
        classes.dedup();
        let n_classes = classes.len();

        if n_classes < 2 {
            return Err(FerroError::InsufficientSamples {
                required: 2,
                actual: n_classes,
                context: "LDA requires at least 2 distinct classes".into(),
            });
        }

        // LDA yields at most n_classes - 1 non-trivial discriminant axes,
        // and never more than the input dimensionality.
        let max_components = (n_classes - 1).min(n_features);
        let n_comp = match self.n_components {
            None => max_components,
            Some(0) => {
                return Err(FerroError::InvalidParameter {
                    name: "n_components".into(),
                    reason: "must be at least 1".into(),
                });
            }
            Some(k) if k > max_components => {
                return Err(FerroError::InvalidParameter {
                    name: "n_components".into(),
                    reason: format!(
                        "n_components ({k}) exceeds max allowed ({max_components} = min(n_classes-1, n_features))"
                    ),
                });
            }
            Some(k) => k,
        };

        // Overall (grand) mean of each feature column.
        let n_f = F::from(n_samples).unwrap();
        let mut overall_mean = Array1::<F>::zeros(n_features);
        for j in 0..n_features {
            let col = x.column(j);
            let s = col.iter().copied().fold(F::zero(), |a, b| a + b);
            overall_mean[j] = s / n_f;
        }

        // Per-class feature means and sample counts.
        let mut class_means: Vec<Array1<F>> = Vec::with_capacity(n_classes);
        let mut class_counts: Vec<usize> = Vec::with_capacity(n_classes);
        for &cls in &classes {
            let mut mean = Array1::<F>::zeros(n_features);
            let mut cnt = 0usize;
            for (i, &label) in y.iter().enumerate() {
                if label == cls {
                    for j in 0..n_features {
                        mean[j] = mean[j] + x[[i, j]];
                    }
                    cnt += 1;
                }
            }
            // Defensive: unreachable in practice, since `classes` is derived
            // from `y` and therefore every class has at least one sample.
            if cnt == 0 {
                return Err(FerroError::InsufficientSamples {
                    required: 1,
                    actual: 0,
                    context: format!("LDA: class {cls} has no samples"),
                });
            }
            let cnt_f = F::from(cnt).unwrap();
            mean.mapv_inplace(|v| v / cnt_f);
            class_means.push(mean);
            class_counts.push(cnt);
        }

        // Within-class scatter: Sw = sum over classes and samples of the
        // outer product (x_i - mu_c)(x_i - mu_c)^T.
        let mut sw = Array2::<F>::zeros((n_features, n_features));
        for (ci, &cls) in classes.iter().enumerate() {
            let mu_c = &class_means[ci];
            for (i, &label) in y.iter().enumerate() {
                if label == cls {
                    let diff: Vec<F> = (0..n_features).map(|j| x[[i, j]] - mu_c[j]).collect();
                    for r in 0..n_features {
                        for c in 0..n_features {
                            sw[[r, c]] = sw[[r, c]] + diff[r] * diff[c];
                        }
                    }
                }
            }
        }

        // Small ridge regularization keeps Sw invertible when features are
        // collinear or a class is (near-)degenerate.
        let reg = F::from(1e-6).unwrap();
        for i in 0..n_features {
            sw[[i, i]] = sw[[i, i]] + reg;
        }

        // Between-class scatter: Sb = sum_c n_c (mu_c - mu)(mu_c - mu)^T.
        let mut sb = Array2::<F>::zeros((n_features, n_features));
        for (ci, &nc) in class_counts.iter().enumerate() {
            let nc_f = F::from(nc).unwrap();
            let diff: Vec<F> = (0..n_features)
                .map(|j| class_means[ci][j] - overall_mean[j])
                .collect();
            for r in 0..n_features {
                for c in 0..n_features {
                    sb[[r, c]] = sb[[r, c]] + nc_f * diff[r] * diff[c];
                }
            }
        }

        // Discriminant axes are the leading eigenvectors of Sw^-1 Sb.
        let m = sw_inv_sb(&sw, &sb)?;
        let max_jacobi = n_features * n_features * 100 + 1000;
        let (eigenvalues, eigenvectors) = jacobi_eigen_f(&m, max_jacobi)?;

        // Order eigenpairs by decreasing eigenvalue (NaNs compare as equal).
        let mut indices: Vec<usize> = (0..n_features).collect();
        indices.sort_by(|&a, &b| {
            eigenvalues[b]
                .partial_cmp(&eigenvalues[a])
                .unwrap_or(std::cmp::Ordering::Equal)
        });

        // Total positive eigenvalue mass (negatives clamped to zero) used to
        // normalize the explained-variance ratios.
        let total_ev: F = eigenvalues
            .iter()
            .copied()
            .map(|v| if v > F::zero() { v } else { F::zero() })
            .fold(F::zero(), |a, b| a + b);

        // Keep the top `n_comp` eigenvectors as projection columns.
        let mut scalings = Array2::<F>::zeros((n_features, n_comp));
        let mut explained_variance_ratio = Array1::<F>::zeros(n_comp);
        for (k, &idx) in indices.iter().take(n_comp).enumerate() {
            let ev = eigenvalues[idx];
            let ev_clamped = if ev > F::zero() { ev } else { F::zero() };
            explained_variance_ratio[k] = if total_ev > F::zero() {
                ev_clamped / total_ev
            } else {
                F::zero()
            };
            for j in 0..n_features {
                scalings[[j, k]] = eigenvectors[[j, idx]];
            }
        }

        // Pre-project the class means into discriminant space so `predict`
        // can do nearest-centroid lookups directly.
        let mut means = Array2::<F>::zeros((n_classes, n_comp));
        for ci in 0..n_classes {
            let mu_row = class_means[ci].view();
            for k in 0..n_comp {
                let mut dot = F::zero();
                for j in 0..n_features {
                    dot = dot + mu_row[j] * scalings[[j, k]];
                }
                means[[ci, k]] = dot;
            }
        }

        Ok(FittedLDA {
            scalings,
            means,
            explained_variance_ratio,
            classes,
            n_features,
        })
    }
}
512
513impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedLDA<F> {
518 type Output = Array2<F>;
519 type Error = FerroError;
520
521 fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
528 if x.ncols() != self.n_features {
529 return Err(FerroError::ShapeMismatch {
530 expected: vec![x.nrows(), self.n_features],
531 actual: vec![x.nrows(), x.ncols()],
532 context: "FittedLDA::transform".into(),
533 });
534 }
535 Ok(x.dot(&self.scalings))
536 }
537}
538
539impl<F: Float + Send + Sync + 'static> Predict<Array2<F>> for FittedLDA<F> {
544 type Output = Array1<usize>;
545 type Error = FerroError;
546
547 fn predict(&self, x: &Array2<F>) -> Result<Array1<usize>, FerroError> {
554 let projected = self.transform(x)?;
555 let n_samples = projected.nrows();
556 let n_comp = projected.ncols();
557 let n_classes = self.classes.len();
558
559 let mut predictions = Array1::<usize>::zeros(n_samples);
560 for i in 0..n_samples {
561 let mut best_class = 0usize;
562 let mut best_dist = F::infinity();
563 for ci in 0..n_classes {
564 let mut dist = F::zero();
565 for k in 0..n_comp {
566 let d = projected[[i, k]] - self.means[[ci, k]];
567 dist = dist + d * d;
568 }
569 if dist < best_dist {
570 best_dist = dist;
571 best_class = ci;
572 }
573 }
574 predictions[i] = self.classes[best_class];
575 }
576 Ok(predictions)
577 }
578}
579
impl PipelineEstimator<f64> for LDA<f64> {
    /// Fits the estimator through the type-erased pipeline interface.
    ///
    /// Labels arrive as `f64` and are converted with `as usize`, which per
    /// Rust float-to-int cast semantics truncates toward zero and saturates
    /// at the `usize` range. NOTE(review): fractional or negative label
    /// values are therefore silently coerced rather than rejected — confirm
    /// this is the intended contract for pipeline labels.
    fn fit_pipeline(
        &self,
        x: &Array2<f64>,
        y: &Array1<f64>,
    ) -> Result<Box<dyn FittedPipelineEstimator<f64>>, FerroError> {
        // Re-encode f64 labels as usize class ids for the typed `fit`.
        let y_usize: Array1<usize> = y.mapv(|v| v as usize);
        let fitted = self.fit(x, &y_usize)?;
        Ok(Box::new(FittedLDAPipeline(fitted)))
    }
}
600
601struct FittedLDAPipeline(FittedLDA<f64>);
603
604unsafe impl Send for FittedLDAPipeline {}
606unsafe impl Sync for FittedLDAPipeline {}
607
608impl FittedPipelineEstimator<f64> for FittedLDAPipeline {
609 fn predict_pipeline(&self, x: &Array2<f64>) -> Result<Array1<f64>, FerroError> {
611 let preds = self.0.predict(x)?;
612 Ok(preds.mapv(|v| v as f64))
613 }
614}
615
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::{Array2, array};

    /// Two well-separated 2-D clusters of four points each (labels 0 and 1).
    fn linearly_separable_2d() -> (Array2<f64>, Array1<usize>) {
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 1.0, 1.5, 1.2, 0.8, 0.9, 1.1, 1.3, 6.0, 6.0, 6.2, 5.8, 5.9, 6.1, 6.3, 5.7,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];
        (x, y)
    }

    /// Three clusters of three points each at distinct corners of the plane.
    fn three_class_data() -> (Array2<f64>, Array1<usize>) {
        let x = Array2::from_shape_vec(
            (9, 2),
            vec![
                0.0, 0.0, 0.5, 0.1, 0.1, 0.5, 5.0, 0.0, 5.2, 0.3, 4.8, 0.1, 0.0, 5.0, 0.1, 5.2,
                0.3, 4.8,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 1, 1, 1, 2, 2, 2];
        (x, y)
    }

    #[test]
    fn test_lda_fit_returns_fitted() {
        let (x, y) = linearly_separable_2d();
        let lda = LDA::<f64>::new(Some(1));
        let fitted = lda.fit(&x, &y).unwrap();
        assert_eq!(fitted.scalings().ncols(), 1);
        assert_eq!(fitted.scalings().nrows(), 2);
    }

    #[test]
    fn test_lda_default_n_components() {
        // With 2 classes and 2 features, the automatic count is
        // min(n_classes - 1, n_features) = 1.
        let (x, y) = linearly_separable_2d();
        let lda = LDA::<f64>::default();
        let fitted = lda.fit(&x, &y).unwrap();
        assert_eq!(fitted.scalings().ncols(), 1);
    }

    #[test]
    fn test_lda_transform_shape() {
        let (x, y) = linearly_separable_2d();
        let lda = LDA::<f64>::new(Some(1));
        let fitted = lda.fit(&x, &y).unwrap();
        let proj = fitted.transform(&x).unwrap();
        assert_eq!(proj.dim(), (8, 1));
    }

    #[test]
    fn test_lda_predict_accuracy_binary() {
        let (x, y) = linearly_separable_2d();
        let lda = LDA::<f64>::new(Some(1));
        let fitted = lda.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        let correct = preds.iter().zip(y.iter()).filter(|(p, a)| *p == *a).count();
        assert_eq!(correct, 8, "All 8 samples should be classified correctly");
    }

    #[test]
    fn test_lda_predict_three_classes() {
        let (x, y) = three_class_data();
        let lda = LDA::<f64>::new(Some(2));
        let fitted = lda.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        let correct = preds.iter().zip(y.iter()).filter(|(p, a)| *p == *a).count();
        assert!(correct >= 7, "Expected at least 7/9 correct, got {correct}");
    }

    #[test]
    fn test_lda_explained_variance_ratio_positive() {
        let (x, y) = linearly_separable_2d();
        let lda = LDA::<f64>::new(Some(1));
        let fitted = lda.fit(&x, &y).unwrap();
        for &v in fitted.explained_variance_ratio().iter() {
            assert!(v >= 0.0);
        }
    }

    #[test]
    fn test_lda_explained_variance_ratio_le_1() {
        let (x, y) = three_class_data();
        let lda = LDA::<f64>::new(Some(2));
        let fitted = lda.fit(&x, &y).unwrap();
        let total: f64 = fitted.explained_variance_ratio().iter().sum();
        assert!(total <= 1.0 + 1e-9, "total={total}");
    }

    #[test]
    fn test_lda_classes_accessor() {
        let (x, y) = linearly_separable_2d();
        let lda = LDA::<f64>::new(Some(1));
        let fitted = lda.fit(&x, &y).unwrap();
        assert_eq!(fitted.classes(), &[0usize, 1]);
    }

    #[test]
    fn test_lda_means_shape() {
        let (x, y) = three_class_data();
        let lda = LDA::<f64>::new(Some(2));
        let fitted = lda.fit(&x, &y).unwrap();
        assert_eq!(fitted.means().dim(), (3, 2));
    }

    #[test]
    fn test_lda_transform_shape_mismatch() {
        let (x, y) = linearly_separable_2d();
        let lda = LDA::<f64>::new(Some(1));
        let fitted = lda.fit(&x, &y).unwrap();
        let x_bad = Array2::<f64>::zeros((3, 5));
        assert!(fitted.transform(&x_bad).is_err());
    }

    #[test]
    fn test_lda_predict_shape_mismatch() {
        let (x, y) = linearly_separable_2d();
        let lda = LDA::<f64>::new(Some(1));
        let fitted = lda.fit(&x, &y).unwrap();
        let x_bad = Array2::<f64>::zeros((3, 5));
        assert!(fitted.predict(&x_bad).is_err());
    }

    #[test]
    fn test_lda_error_zero_n_components() {
        let (x, y) = linearly_separable_2d();
        let lda = LDA::<f64>::new(Some(0));
        assert!(lda.fit(&x, &y).is_err());
    }

    #[test]
    fn test_lda_error_n_components_too_large() {
        // 5 exceeds min(n_classes - 1, n_features) = 1.
        let (x, y) = linearly_separable_2d();
        let lda = LDA::<f64>::new(Some(5));
        assert!(lda.fit(&x, &y).is_err());
    }

    #[test]
    fn test_lda_error_single_class() {
        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![0usize, 0, 0, 0];
        let lda = LDA::<f64>::new(None);
        assert!(lda.fit(&x, &y).is_err());
    }

    #[test]
    fn test_lda_error_shape_mismatch_fit() {
        let x = Array2::<f64>::zeros((4, 2));
        // Only 2 labels for 4 rows.
        let y = array![0usize, 1];
        let lda = LDA::<f64>::new(None);
        assert!(lda.fit(&x, &y).is_err());
    }

    #[test]
    fn test_lda_error_insufficient_samples() {
        let x = Array2::<f64>::zeros((1, 2));
        let y = array![0usize];
        let lda = LDA::<f64>::new(None);
        assert!(lda.fit(&x, &y).is_err());
    }

    #[test]
    fn test_lda_scalings_accessor() {
        let (x, y) = linearly_separable_2d();
        let lda = LDA::<f64>::new(Some(1));
        let fitted = lda.fit(&x, &y).unwrap();
        assert_eq!(fitted.scalings().dim(), (2, 1));
    }

    #[test]
    fn test_lda_pipeline_estimator() {
        use ferrolearn_core::pipeline::PipelineEstimator;

        let (x, y_usize) = linearly_separable_2d();
        let y_f64 = y_usize.mapv(|v| v as f64);
        let lda = LDA::<f64>::new(Some(1));
        let fitted = lda.fit_pipeline(&x, &y_f64).unwrap();
        let preds = fitted.predict_pipeline(&x).unwrap();
        assert_eq!(preds.len(), 8);
    }

    #[test]
    fn test_lda_n_components_getter() {
        let lda = LDA::<f64>::new(Some(2));
        assert_eq!(lda.n_components(), Some(2));
        let lda_none = LDA::<f64>::new(None);
        assert_eq!(lda_none.n_components(), None);
    }

    #[test]
    fn test_lda_transform_then_predict_consistent() {
        // predict() must agree with a manual nearest-centroid pass over
        // the transform() output.
        let (x, y) = linearly_separable_2d();
        let lda = LDA::<f64>::new(Some(1));
        let fitted = lda.fit(&x, &y).unwrap();
        let projected = fitted.transform(&x).unwrap();
        let preds_predict = fitted.predict(&x).unwrap();
        let n_samples = projected.nrows();
        let n_comp = projected.ncols();
        let n_classes = fitted.classes().len();
        for i in 0..n_samples {
            let mut best = 0;
            let mut best_d = f64::INFINITY;
            for ci in 0..n_classes {
                let mut d = 0.0;
                for k in 0..n_comp {
                    let diff = projected[[i, k]] - fitted.means()[[ci, k]];
                    d += diff * diff;
                }
                if d < best_d {
                    best_d = d;
                    best = ci;
                }
            }
            assert_eq!(preds_predict[i], fitted.classes()[best]);
        }
    }

    #[test]
    fn test_lda_projected_class_separation() {
        let (x, y) = linearly_separable_2d();
        let lda = LDA::<f64>::new(Some(1));
        let fitted = lda.fit(&x, &y).unwrap();
        let projected = fitted.transform(&x).unwrap();

        // Mean of the 1-D projection within each class.
        let mean0: f64 = projected
            .rows()
            .into_iter()
            .zip(y.iter())
            .filter(|&(_, label)| *label == 0)
            .map(|(row, _)| row[0])
            .sum::<f64>()
            / 4.0;
        let mean1: f64 = projected
            .rows()
            .into_iter()
            .zip(y.iter())
            .filter(|&(_, label)| *label == 1)
            .map(|(row, _)| row[0])
            .sum::<f64>()
            / 4.0;

        assert!(
            (mean0 - mean1).abs() > 0.5,
            "Projected means should differ, got {mean0} vs {mean1}"
        );
    }

    #[test]
    fn test_lda_transform_known_data() {
        // 1-D input with classes on opposite sides of zero: the projection
        // must preserve that separation (up to overall sign).
        let x = Array2::from_shape_vec((4, 1), vec![-2.0, -1.0, 1.0, 2.0]).unwrap();
        let y = array![0usize, 0, 1, 1];
        let lda = LDA::<f64>::new(Some(1));
        let fitted = lda.fit(&x, &y).unwrap();
        let proj = fitted.transform(&x).unwrap();
        let sign0 = proj[[0, 0]].signum();
        let sign1 = proj[[2, 0]].signum();
        assert_ne!(
            sign0 as i32, sign1 as i32,
            "Classes should be on opposite sides"
        );
    }

    #[test]
    fn test_lda_abs_diff_eq_means_dimensions() {
        let (x, y) = linearly_separable_2d();
        let lda = LDA::<f64>::new(Some(1));
        let fitted = lda.fit(&x, &y).unwrap();
        assert_eq!(fitted.means().ncols(), 1);
        let m0 = fitted.means()[[0, 0]];
        let m1 = fitted.means()[[1, 0]];
        assert!((m0 - m1).abs() > 0.5, "m0={m0}, m1={m1}");
        // Smoke-check of the approx macro; the macro returns `()`, so the
        // previous `let _ =` binding was pointless and has been dropped.
        assert_abs_diff_eq!(0.0_f64, 0.0_f64);
    }
}