use ferrolearn_core::error::FerroError;
use ferrolearn_core::pipeline::{FittedPipelineEstimator, PipelineEstimator};
use ferrolearn_core::traits::{Fit, Predict, Transform};
use ndarray::{Array1, Array2};
use num_traits::{Float, NumCast};

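/// Linear Discriminant Analysis (LDA).
///
/// Learns a linear projection that maximizes between-class scatter relative to
/// within-class scatter, keeping at most `min(n_classes - 1, n_features)`
/// components. The fitted model can project data into the discriminant space
/// and classify samples by their nearest projected class centroid.
///
/// A minimal usage sketch, mirroring the unit tests below (paths outside this
/// module are assumed, so the snippet is marked `ignore` rather than compiled
/// as a doctest):
///
/// ```ignore
/// use ferrolearn_core::traits::{Fit, Predict, Transform};
/// use ndarray::array;
///
/// let x = array![[1.0, 1.0], [1.2, 0.9], [6.0, 6.1], [5.9, 6.3]];
/// let y = array![0usize, 0, 1, 1];
/// let lda = LDA::<f64>::new(Some(1));
/// let fitted = lda.fit(&x, &y).unwrap();
/// let projected = fitted.transform(&x).unwrap(); // shape (4, 1)
/// let labels = fitted.predict(&x).unwrap();      // nearest projected centroid
/// ```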
#[derive(Debug, Clone)]
pub struct LDA<F> {
    /// Number of discriminant components to keep; `None` means use the maximum.
    n_components: Option<usize>,
    _marker: std::marker::PhantomData<F>,
}

impl<F: Float + Send + Sync + 'static> LDA<F> {
    /// Creates a new LDA estimator with the requested number of components.
    #[must_use]
    pub fn new(n_components: Option<usize>) -> Self {
        Self {
            n_components,
            _marker: std::marker::PhantomData,
        }
    }

    /// Returns the configured number of components, if any.
    #[must_use]
    pub fn n_components(&self) -> Option<usize> {
        self.n_components
    }
}

impl<F: Float + Send + Sync + 'static> Default for LDA<F> {
    fn default() -> Self {
        Self::new(None)
    }
}

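/// A fitted LDA model: discriminant directions, projected class centroids,
/// explained-variance ratios, and the class labels seen during fitting.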
#[derive(Debug, Clone)]
pub struct FittedLDA<F> {
    /// Projection matrix of shape `(n_features, n_components)`.
    scalings: Array2<F>,

    /// Class centroids in the projected space, shape `(n_classes, n_components)`.
    means: Array2<F>,

    /// Fraction of the positive discriminant eigenvalue mass per kept component.
    explained_variance_ratio: Array1<F>,

    /// Sorted, distinct class labels observed in `y`.
    classes: Vec<usize>,

    /// Number of input features expected by `transform` and `predict`.
    n_features: usize,
}

impl<F: Float + Send + Sync + 'static> FittedLDA<F> {
    /// The learned projection matrix, shape `(n_features, n_components)`.
    #[must_use]
    pub fn scalings(&self) -> &Array2<F> {
        &self.scalings
    }

    /// Class centroids in the projected space, shape `(n_classes, n_components)`.
    #[must_use]
    pub fn means(&self) -> &Array2<F> {
        &self.means
    }

    /// Explained-variance ratio of each kept component.
    #[must_use]
    pub fn explained_variance_ratio(&self) -> &Array1<F> {
        &self.explained_variance_ratio
    }

    /// Sorted, distinct class labels seen during fitting.
    #[must_use]
    pub fn classes(&self) -> &[usize] {
        &self.classes
    }

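    /// Class membership probabilities for each row of `x`.
    ///
    /// Samples are projected into the discriminant space and a softmax is taken
    /// over `-0.5 * d^2`, where `d` is the Euclidean distance to each projected
    /// class centroid; this amounts to assuming equal priors and an identity
    /// covariance in the projected space.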
    pub fn predict_proba(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
        let projected = self.transform(x)?;
        let n_samples = projected.nrows();
        let n_comp = projected.ncols();
        let n_classes = self.classes.len();
        let neg_half = F::from(-0.5).unwrap();
        let mut proba = Array2::<F>::zeros((n_samples, n_classes));
        for i in 0..n_samples {
            // Logit per class: -0.5 * squared distance to the projected centroid.
            let mut logits = vec![F::zero(); n_classes];
            for ci in 0..n_classes {
                let mut dist_sq = F::zero();
                for k in 0..n_comp {
                    let d = projected[[i, k]] - self.means[[ci, k]];
                    dist_sq = dist_sq + d * d;
                }
                logits[ci] = neg_half * dist_sq;
            }
            // Numerically stable softmax: subtract the maximum logit before exp.
            let max_l = logits
                .iter()
                .copied()
                .fold(F::neg_infinity(), |a, b| if b > a { b } else { a });
            let mut sum_exp = F::zero();
            for ci in 0..n_classes {
                let e = (logits[ci] - max_l).exp();
                proba[[i, ci]] = e;
                sum_exp = sum_exp + e;
            }
            for ci in 0..n_classes {
                proba[[i, ci]] = proba[[i, ci]] / sum_exp;
            }
        }
        Ok(proba)
    }

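    /// Natural log of `predict_proba`, computed via the crate-level `log_proba` helper.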
    pub fn predict_log_proba(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
        let proba = self.predict_proba(x)?;
        Ok(crate::log_proba(&proba))
    }

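    /// Per-class decision scores: `-0.5 * d^2` for the Euclidean distance `d`
    /// between the projected sample and each projected class centroid
    /// (the unnormalized logits used by `predict_proba`).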
    pub fn decision_function(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
        let projected = self.transform(x)?;
        let n_samples = projected.nrows();
        let n_comp = projected.ncols();
        let n_classes = self.classes.len();
        let neg_half = F::from(-0.5).unwrap();
        let mut out = Array2::<F>::zeros((n_samples, n_classes));
        for i in 0..n_samples {
            for ci in 0..n_classes {
                let mut dist_sq = F::zero();
                for k in 0..n_comp {
                    let d = projected[[i, k]] - self.means[[ci, k]];
                    dist_sq = dist_sq + d * d;
                }
                out[[i, ci]] = neg_half * dist_sq;
            }
        }
        Ok(out)
    }
}

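/// Eigendecomposition of a symmetric matrix via classical Jacobi rotations.
///
/// Repeatedly zeroes the off-diagonal element of largest magnitude until all
/// off-diagonal entries fall below a tolerance, returning `(eigenvalues,
/// eigenvectors)` with eigenvectors stored as columns. Fails with
/// `ConvergenceFailure` if the off-diagonals do not vanish within `max_iter`
/// rotations.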
fn jacobi_eigen_f<F: Float + Send + Sync + 'static>(
    a: &Array2<F>,
    max_iter: usize,
) -> Result<(Array1<F>, Array2<F>), FerroError> {
    let n = a.nrows();
    let mut mat = a.to_owned();
    // Accumulate the rotations in `v`, starting from the identity matrix.
    let mut v = Array2::<F>::zeros((n, n));
    for i in 0..n {
        v[[i, i]] = F::one();
    }
    let tol = F::from(1e-12).unwrap_or_else(F::epsilon);

    for _ in 0..max_iter {
        // Locate the off-diagonal element with the largest magnitude.
        let mut max_off = F::zero();
        let mut p = 0usize;
        let mut q = 1usize;
        for i in 0..n {
            for j in (i + 1)..n {
                let val = mat[[i, j]].abs();
                if val > max_off {
                    max_off = val;
                    p = i;
                    q = j;
                }
            }
        }
        // Converged: the matrix is numerically diagonal.
        if max_off < tol {
            let eigenvalues = Array1::from_shape_fn(n, |i| mat[[i, i]]);
            return Ok((eigenvalues, v));
        }
        let app = mat[[p, p]];
        let aqq = mat[[q, q]];
        let apq = mat[[p, q]];
        let two = F::from(2.0).unwrap();
        // Rotation angle chosen to annihilate mat[[p, q]].
        let theta = if (app - aqq).abs() < tol {
            F::from(std::f64::consts::FRAC_PI_4).unwrap_or_else(F::one)
        } else {
            let tau = (aqq - app) / (two * apq);
            let t = if tau >= F::zero() {
                F::one() / (tau.abs() + (F::one() + tau * tau).sqrt())
            } else {
                -F::one() / (tau.abs() + (F::one() + tau * tau).sqrt())
            };
            t.atan()
        };
        let c = theta.cos();
        let s = theta.sin();
        // Apply the rotation to rows/columns p and q, keeping the matrix symmetric.
        let mut new_mat = mat.clone();
        for i in 0..n {
            if i != p && i != q {
                let mip = mat[[i, p]];
                let miq = mat[[i, q]];
                new_mat[[i, p]] = c * mip - s * miq;
                new_mat[[p, i]] = new_mat[[i, p]];
                new_mat[[i, q]] = s * mip + c * miq;
                new_mat[[q, i]] = new_mat[[i, q]];
            }
        }
        new_mat[[p, p]] = c * c * app - two * s * c * apq + s * s * aqq;
        new_mat[[q, q]] = s * s * app + two * s * c * apq + c * c * aqq;
        new_mat[[p, q]] = F::zero();
        new_mat[[q, p]] = F::zero();
        mat = new_mat;
        // Accumulate the rotation into the eigenvector matrix.
        for i in 0..n {
            let vip = v[[i, p]];
            let viq = v[[i, q]];
            v[[i, p]] = c * vip - s * viq;
            v[[i, q]] = s * vip + c * viq;
        }
    }
    Err(FerroError::ConvergenceFailure {
        iterations: max_iter,
        message: "Jacobi eigendecomposition did not converge (LDA)".into(),
    })
}

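/// Solves the `n x n` linear system `a * x = b` by Gaussian elimination with
/// partial pivoting, returning `NumericalInstability` if a pivot is near zero.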
fn gaussian_solve_f<F: Float>(
    n: usize,
    a: &Array2<F>,
    b: &Array1<F>,
) -> Result<Array1<F>, FerroError> {
    // Build the augmented matrix [a | b].
    let mut aug = Array2::<F>::zeros((n, n + 1));
    for i in 0..n {
        for j in 0..n {
            aug[[i, j]] = a[[i, j]];
        }
        aug[[i, n]] = b[i];
    }
    for col in 0..n {
        // Partial pivoting: pick the row with the largest entry in this column.
        let mut max_val = aug[[col, col]].abs();
        let mut max_row = col;
        for row in (col + 1)..n {
            let val = aug[[row, col]].abs();
            if val > max_val {
                max_val = val;
                max_row = row;
            }
        }
        if max_val < F::from(1e-12).unwrap_or_else(F::epsilon) {
            return Err(FerroError::NumericalInstability {
                message: "singular matrix during LDA inversion".into(),
            });
        }
        if max_row != col {
            for j in 0..=n {
                let tmp = aug[[col, j]];
                aug[[col, j]] = aug[[max_row, j]];
                aug[[max_row, j]] = tmp;
            }
        }
        // Eliminate the entries below the pivot.
        let pivot = aug[[col, col]];
        for row in (col + 1)..n {
            let factor = aug[[row, col]] / pivot;
            for j in col..=n {
                let above = aug[[col, j]];
                aug[[row, j]] = aug[[row, j]] - factor * above;
            }
        }
    }
    // Back substitution.
    let mut x = Array1::<F>::zeros(n);
    for i in (0..n).rev() {
        let mut sum = aug[[i, n]];
        for j in (i + 1)..n {
            sum = sum - aug[[i, j]] * x[j];
        }
        if aug[[i, i]].abs() < F::from(1e-12).unwrap_or_else(F::epsilon) {
            return Err(FerroError::NumericalInstability {
                message: "near-zero pivot during LDA back substitution".into(),
            });
        }
        x[i] = sum / aug[[i, i]];
    }
    Ok(x)
}

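/// Computes `Sw^-1 * Sb` column by column, solving `Sw * x = Sb[:, j]` for each
/// column `j` instead of forming the inverse explicitly.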
fn sw_inv_sb<F: Float + Send + Sync + 'static>(
    sw: &Array2<F>,
    sb: &Array2<F>,
) -> Result<Array2<F>, FerroError> {
    let n = sw.nrows();
    let mut result = Array2::<F>::zeros((n, n));
    for j in 0..n {
        let col_sb = Array1::from_shape_fn(n, |i| sb[[i, j]]);
        let col = gaussian_solve_f(n, sw, &col_sb)?;
        for i in 0..n {
            result[[i, j]] = col[i];
        }
    }
    Ok(result)
}

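// Fitting procedure: validate the inputs, compute the per-class and overall
// means, build the within-class scatter Sw (with a small diagonal ridge) and
// the between-class scatter Sb, then take the leading eigenvectors of
// Sw^-1 * Sb (via the Jacobi routine above) as the discriminant directions.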
impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array1<usize>> for LDA<F> {
    type Fitted = FittedLDA<F>;
    type Error = FerroError;

    fn fit(&self, x: &Array2<F>, y: &Array1<usize>) -> Result<FittedLDA<F>, FerroError> {
        let (n_samples, n_features) = x.dim();

        if n_samples != y.len() {
            return Err(FerroError::ShapeMismatch {
                expected: vec![n_samples],
                actual: vec![y.len()],
                context: "LDA: y length must match number of rows in X".into(),
            });
        }
        if n_samples < 2 {
            return Err(FerroError::InsufficientSamples {
                required: 2,
                actual: n_samples,
                context: "LDA requires at least 2 samples".into(),
            });
        }

        // Sorted, distinct class labels.
        let mut classes: Vec<usize> = y.to_vec();
        classes.sort_unstable();
        classes.dedup();
        let n_classes = classes.len();

        if n_classes < 2 {
            return Err(FerroError::InsufficientSamples {
                required: 2,
                actual: n_classes,
                context: "LDA requires at least 2 distinct classes".into(),
            });
        }

        // The discriminant subspace has at most min(n_classes - 1, n_features) dimensions.
        let max_components = (n_classes - 1).min(n_features);
        let n_comp = match self.n_components {
            None => max_components,
            Some(0) => {
                return Err(FerroError::InvalidParameter {
                    name: "n_components".into(),
                    reason: "must be at least 1".into(),
                });
            }
            Some(k) if k > max_components => {
                return Err(FerroError::InvalidParameter {
                    name: "n_components".into(),
                    reason: format!(
                        "n_components ({k}) exceeds max allowed ({max_components} = min(n_classes-1, n_features))"
                    ),
                });
            }
            Some(k) => k,
        };

        // Overall (grand) mean of each feature.
        let n_f = F::from(n_samples).unwrap();
        let mut overall_mean = Array1::<F>::zeros(n_features);
        for j in 0..n_features {
            let col = x.column(j);
            let s = col.iter().copied().fold(F::zero(), |a, b| a + b);
            overall_mean[j] = s / n_f;
        }

        // Per-class means and sample counts.
        let mut class_means: Vec<Array1<F>> = Vec::with_capacity(n_classes);
        let mut class_counts: Vec<usize> = Vec::with_capacity(n_classes);
        for &cls in &classes {
            let mut mean = Array1::<F>::zeros(n_features);
            let mut cnt = 0usize;
            for (i, &label) in y.iter().enumerate() {
                if label == cls {
                    for j in 0..n_features {
                        mean[j] = mean[j] + x[[i, j]];
                    }
                    cnt += 1;
                }
            }
            if cnt == 0 {
                return Err(FerroError::InsufficientSamples {
                    required: 1,
                    actual: 0,
                    context: format!("LDA: class {cls} has no samples"),
                });
            }
            let cnt_f = F::from(cnt).unwrap();
            mean.mapv_inplace(|v| v / cnt_f);
            class_means.push(mean);
            class_counts.push(cnt);
        }

        // Within-class scatter: Sw = sum over classes of (x_i - mu_c)(x_i - mu_c)^T.
        let mut sw = Array2::<F>::zeros((n_features, n_features));
        for (ci, &cls) in classes.iter().enumerate() {
            let mu_c = &class_means[ci];
            for (i, &label) in y.iter().enumerate() {
                if label == cls {
                    let diff: Vec<F> = (0..n_features).map(|j| x[[i, j]] - mu_c[j]).collect();
                    for r in 0..n_features {
                        for c in 0..n_features {
                            sw[[r, c]] = sw[[r, c]] + diff[r] * diff[c];
                        }
                    }
                }
            }
        }

        // Small diagonal ridge so Sw stays invertible.
        let reg = F::from(1e-6).unwrap();
        for i in 0..n_features {
            sw[[i, i]] = sw[[i, i]] + reg;
        }

        // Between-class scatter: Sb = sum over classes of n_c (mu_c - mu)(mu_c - mu)^T.
        let mut sb = Array2::<F>::zeros((n_features, n_features));
        for (ci, &nc) in class_counts.iter().enumerate() {
            let nc_f = F::from(nc).unwrap();
            let diff: Vec<F> = (0..n_features)
                .map(|j| class_means[ci][j] - overall_mean[j])
                .collect();
            for r in 0..n_features {
                for c in 0..n_features {
                    sb[[r, c]] = sb[[r, c]] + nc_f * diff[r] * diff[c];
                }
            }
        }

        // Discriminant directions are the leading eigenvectors of Sw^-1 * Sb.
        let m = sw_inv_sb(&sw, &sb)?;
        let max_jacobi = n_features * n_features * 100 + 1000;
        let (eigenvalues, eigenvectors) = jacobi_eigen_f(&m, max_jacobi)?;

        // Sort eigenvalue indices in descending order.
        let mut indices: Vec<usize> = (0..n_features).collect();
        indices.sort_by(|&a, &b| {
            eigenvalues[b]
                .partial_cmp(&eigenvalues[a])
                .unwrap_or(std::cmp::Ordering::Equal)
        });

        // Total positive eigenvalue mass, used to normalize the variance ratios.
        let total_ev: F = eigenvalues
            .iter()
            .copied()
            .map(|v| if v > F::zero() { v } else { F::zero() })
            .fold(F::zero(), |a, b| a + b);

        // Keep the top `n_comp` eigenvectors as the projection matrix.
        let mut scalings = Array2::<F>::zeros((n_features, n_comp));
        let mut explained_variance_ratio = Array1::<F>::zeros(n_comp);
        for (k, &idx) in indices.iter().take(n_comp).enumerate() {
            let ev = eigenvalues[idx];
            let ev_clamped = if ev > F::zero() { ev } else { F::zero() };
            explained_variance_ratio[k] = if total_ev > F::zero() {
                ev_clamped / total_ev
            } else {
                F::zero()
            };
            for j in 0..n_features {
                scalings[[j, k]] = eigenvectors[[j, idx]];
            }
        }

        // Project the class means into the discriminant space; `predict` uses
        // these centroids for nearest-centroid classification.
        let mut means = Array2::<F>::zeros((n_classes, n_comp));
        for ci in 0..n_classes {
            let mu_row = class_means[ci].view();
            for k in 0..n_comp {
                let mut dot = F::zero();
                for j in 0..n_features {
                    dot = dot + mu_row[j] * scalings[[j, k]];
                }
                means[[ci, k]] = dot;
            }
        }

        Ok(FittedLDA {
            scalings,
            means,
            explained_variance_ratio,
            classes,
            n_features,
        })
    }
}

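// Transform: project rows of `x` onto the discriminant directions via `x.dot(&scalings)`.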
impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedLDA<F> {
    type Output = Array2<F>;
    type Error = FerroError;

    fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
        if x.ncols() != self.n_features {
            return Err(FerroError::ShapeMismatch {
                expected: vec![x.nrows(), self.n_features],
                actual: vec![x.nrows(), x.ncols()],
                context: "FittedLDA::transform".into(),
            });
        }
        Ok(x.dot(&self.scalings))
    }
}

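// Predict: assign each sample to the class whose projected centroid is nearest
// (squared Euclidean distance) in the discriminant space.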
impl<F: Float + Send + Sync + 'static> Predict<Array2<F>> for FittedLDA<F> {
    type Output = Array1<usize>;
    type Error = FerroError;

    fn predict(&self, x: &Array2<F>) -> Result<Array1<usize>, FerroError> {
        let projected = self.transform(x)?;
        let n_samples = projected.nrows();
        let n_comp = projected.ncols();
        let n_classes = self.classes.len();

        let mut predictions = Array1::<usize>::zeros(n_samples);
        for i in 0..n_samples {
            let mut best_class = 0usize;
            let mut best_dist = F::infinity();
            for ci in 0..n_classes {
                let mut dist = F::zero();
                for k in 0..n_comp {
                    let d = projected[[i, k]] - self.means[[ci, k]];
                    dist = dist + d * d;
                }
                if dist < best_dist {
                    best_dist = dist;
                    best_class = ci;
                }
            }
            predictions[i] = self.classes[best_class];
        }
        Ok(predictions)
    }
}

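// Pipeline adapter: labels arrive as floats in the pipeline API, so they are
// cast to `usize` before fitting (non-castable values fall back to 0).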
impl<F: Float + Send + Sync + 'static> PipelineEstimator<F> for LDA<F> {
    fn fit_pipeline(
        &self,
        x: &Array2<F>,
        y: &Array1<F>,
    ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
        let y_usize: Array1<usize> = y.mapv(|v| v.to_usize().unwrap_or(0));
        let fitted = self.fit(x, &y_usize)?;
        Ok(Box::new(FittedLDAPipeline(fitted)))
    }
}

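/// Newtype wrapper exposing a [`FittedLDA`] through the dynamic pipeline interface.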
struct FittedLDAPipeline<F>(FittedLDA<F>);

impl<F: Float + Send + Sync + 'static> FittedPipelineEstimator<F> for FittedLDAPipeline<F> {
    fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
        // Cast the usize class labels back to the pipeline's float type.
        let preds = self.0.predict(x)?;
        Ok(preds.mapv(|v| NumCast::from(v).unwrap_or_else(F::nan)))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::{Array2, array};

    /// Two well-separated 2-D clusters: four samples of class 0 near (1, 1)
    /// and four samples of class 1 near (6, 6).
    fn linearly_separable_2d() -> (Array2<f64>, Array1<usize>) {
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 1.0, 1.5, 1.2, 0.8, 0.9, 1.1, 1.3, // class 0
                6.0, 6.0, 6.2, 5.8, 5.9, 6.1, 6.3, 5.7, // class 1
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];
        (x, y)
    }

    /// Three 2-D clusters with three samples each: class 0 near the origin,
    /// class 1 near (5, 0), and class 2 near (0, 5).
    fn three_class_data() -> (Array2<f64>, Array1<usize>) {
        let x = Array2::from_shape_vec(
            (9, 2),
            vec![
                0.0, 0.0, 0.5, 0.1, 0.1, 0.5, // class 0
                5.0, 0.0, 5.2, 0.3, 4.8, 0.1, // class 1
                0.0, 5.0, 0.1, 5.2, 0.3, 4.8, // class 2
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 1, 1, 1, 2, 2, 2];
        (x, y)
    }

    #[test]
    fn test_lda_fit_returns_fitted() {
        let (x, y) = linearly_separable_2d();
        let lda = LDA::<f64>::new(Some(1));
        let fitted = lda.fit(&x, &y).unwrap();
        assert_eq!(fitted.scalings().ncols(), 1);
        assert_eq!(fitted.scalings().nrows(), 2);
    }

    #[test]
    fn test_lda_default_n_components() {
        let (x, y) = linearly_separable_2d();
        let lda = LDA::<f64>::default();
        let fitted = lda.fit(&x, &y).unwrap();
        assert_eq!(fitted.scalings().ncols(), 1);
    }

    #[test]
    fn test_lda_transform_shape() {
        let (x, y) = linearly_separable_2d();
        let lda = LDA::<f64>::new(Some(1));
        let fitted = lda.fit(&x, &y).unwrap();
        let proj = fitted.transform(&x).unwrap();
        assert_eq!(proj.dim(), (8, 1));
    }

    #[test]
    fn test_lda_predict_accuracy_binary() {
        let (x, y) = linearly_separable_2d();
        let lda = LDA::<f64>::new(Some(1));
        let fitted = lda.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        let correct = preds.iter().zip(y.iter()).filter(|(p, a)| *p == *a).count();
        assert_eq!(correct, 8, "All 8 samples should be classified correctly");
    }

    #[test]
    fn test_lda_predict_three_classes() {
        let (x, y) = three_class_data();
        let lda = LDA::<f64>::new(Some(2));
        let fitted = lda.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        let correct = preds.iter().zip(y.iter()).filter(|(p, a)| *p == *a).count();
        assert!(correct >= 7, "Expected at least 7/9 correct, got {correct}");
    }

    #[test]
    fn test_lda_explained_variance_ratio_positive() {
        let (x, y) = linearly_separable_2d();
        let lda = LDA::<f64>::new(Some(1));
        let fitted = lda.fit(&x, &y).unwrap();
        for &v in fitted.explained_variance_ratio() {
            assert!(v >= 0.0);
        }
    }

    #[test]
    fn test_lda_explained_variance_ratio_le_1() {
        let (x, y) = three_class_data();
        let lda = LDA::<f64>::new(Some(2));
        let fitted = lda.fit(&x, &y).unwrap();
        let total: f64 = fitted.explained_variance_ratio().iter().sum();
        assert!(total <= 1.0 + 1e-9, "total={total}");
    }

    #[test]
    fn test_lda_classes_accessor() {
        let (x, y) = linearly_separable_2d();
        let lda = LDA::<f64>::new(Some(1));
        let fitted = lda.fit(&x, &y).unwrap();
        assert_eq!(fitted.classes(), &[0usize, 1]);
    }

    #[test]
    fn test_lda_means_shape() {
        let (x, y) = three_class_data();
        let lda = LDA::<f64>::new(Some(2));
        let fitted = lda.fit(&x, &y).unwrap();
        assert_eq!(fitted.means().dim(), (3, 2));
    }

    #[test]
    fn test_lda_transform_shape_mismatch() {
        let (x, y) = linearly_separable_2d();
        let lda = LDA::<f64>::new(Some(1));
        let fitted = lda.fit(&x, &y).unwrap();
        let x_bad = Array2::<f64>::zeros((3, 5));
        assert!(fitted.transform(&x_bad).is_err());
    }

    #[test]
    fn test_lda_predict_shape_mismatch() {
        let (x, y) = linearly_separable_2d();
        let lda = LDA::<f64>::new(Some(1));
        let fitted = lda.fit(&x, &y).unwrap();
        let x_bad = Array2::<f64>::zeros((3, 5));
        assert!(fitted.predict(&x_bad).is_err());
    }

    #[test]
    fn test_lda_error_zero_n_components() {
        let (x, y) = linearly_separable_2d();
        let lda = LDA::<f64>::new(Some(0));
        assert!(lda.fit(&x, &y).is_err());
    }

    #[test]
    fn test_lda_error_n_components_too_large() {
        let (x, y) = linearly_separable_2d();
        let lda = LDA::<f64>::new(Some(5));
        assert!(lda.fit(&x, &y).is_err());
    }

    #[test]
    fn test_lda_error_single_class() {
        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![0usize, 0, 0, 0];
        let lda = LDA::<f64>::new(None);
        assert!(lda.fit(&x, &y).is_err());
    }

    #[test]
    fn test_lda_error_shape_mismatch_fit() {
        let x = Array2::<f64>::zeros((4, 2));
        let y = array![0usize, 1];
        let lda = LDA::<f64>::new(None);
        assert!(lda.fit(&x, &y).is_err());
    }

    #[test]
    fn test_lda_error_insufficient_samples() {
        let x = Array2::<f64>::zeros((1, 2));
        let y = array![0usize];
        let lda = LDA::<f64>::new(None);
        assert!(lda.fit(&x, &y).is_err());
    }

    #[test]
    fn test_lda_scalings_accessor() {
        let (x, y) = linearly_separable_2d();
        let lda = LDA::<f64>::new(Some(1));
        let fitted = lda.fit(&x, &y).unwrap();
        assert_eq!(fitted.scalings().dim(), (2, 1));
    }

    #[test]
    fn test_lda_pipeline_estimator() {
        use ferrolearn_core::pipeline::PipelineEstimator;

        let (x, y_usize) = linearly_separable_2d();
        let y_f64 = y_usize.mapv(|v| v as f64);
        let lda = LDA::<f64>::new(Some(1));
        let fitted = lda.fit_pipeline(&x, &y_f64).unwrap();
        let preds = fitted.predict_pipeline(&x).unwrap();
        assert_eq!(preds.len(), 8);
    }

    #[test]
    fn test_lda_n_components_getter() {
        let lda = LDA::<f64>::new(Some(2));
        assert_eq!(lda.n_components(), Some(2));
        let lda_none = LDA::<f64>::new(None);
        assert_eq!(lda_none.n_components(), None);
    }

    #[test]
    fn test_lda_transform_then_predict_consistent() {
        let (x, y) = linearly_separable_2d();
        let lda = LDA::<f64>::new(Some(1));
        let fitted = lda.fit(&x, &y).unwrap();
        let projected = fitted.transform(&x).unwrap();
        let preds_predict = fitted.predict(&x).unwrap();
        let n_samples = projected.nrows();
        let n_comp = projected.ncols();
        let n_classes = fitted.classes().len();
        for i in 0..n_samples {
            let mut best = 0;
            let mut best_d = f64::INFINITY;
            for ci in 0..n_classes {
                let mut d = 0.0;
                for k in 0..n_comp {
                    let diff = projected[[i, k]] - fitted.means()[[ci, k]];
                    d += diff * diff;
                }
                if d < best_d {
                    best_d = d;
                    best = ci;
                }
            }
            assert_eq!(preds_predict[i], fitted.classes()[best]);
        }
    }

    #[test]
    fn test_lda_projected_class_separation() {
        let (x, y) = linearly_separable_2d();
        let lda = LDA::<f64>::new(Some(1));
        let fitted = lda.fit(&x, &y).unwrap();
        let projected = fitted.transform(&x).unwrap();

        let mean0: f64 = projected
            .rows()
            .into_iter()
            .zip(y.iter())
            .filter(|&(_, label)| *label == 0)
            .map(|(row, _)| row[0])
            .sum::<f64>()
            / 4.0;
        let mean1: f64 = projected
            .rows()
            .into_iter()
            .zip(y.iter())
            .filter(|&(_, label)| *label == 1)
            .map(|(row, _)| row[0])
            .sum::<f64>()
            / 4.0;

        assert!(
            (mean0 - mean1).abs() > 0.5,
            "Projected means should differ, got {mean0} vs {mean1}"
        );
    }

    #[test]
    fn test_lda_transform_known_data() {
        let x = Array2::from_shape_vec((4, 1), vec![-2.0, -1.0, 1.0, 2.0]).unwrap();
        let y = array![0usize, 0, 1, 1];
        let lda = LDA::<f64>::new(Some(1));
        let fitted = lda.fit(&x, &y).unwrap();
        let proj = fitted.transform(&x).unwrap();
        let sign0 = proj[[0, 0]].signum();
        let sign1 = proj[[2, 0]].signum();
        assert_ne!(
            sign0 as i32, sign1 as i32,
            "Classes should be on opposite sides"
        );
    }

    #[test]
    fn test_lda_abs_diff_eq_means_dimensions() {
        let (x, y) = linearly_separable_2d();
        let lda = LDA::<f64>::new(Some(1));
        let fitted = lda.fit(&x, &y).unwrap();
        assert_eq!(fitted.means().ncols(), 1);
        let m0 = fitted.means()[[0, 0]];
        let m1 = fitted.means()[[1, 0]];
        assert!((m0 - m1).abs() > 0.5, "m0={m0}, m1={m1}");
        assert_abs_diff_eq!(0.0_f64, 0.0_f64);
    }
}