// ferrolearn_linear/lda.rs
1//! Linear Discriminant Analysis (LDA).
2//!
3//! LDA is both a supervised dimensionality reduction technique and a
4//! linear classifier. It finds the directions that maximise the separation
5//! between classes while minimising within-class scatter.
6//!
7//! # Algorithm
8//!
9//! 1. Compute class means `μ_c` and the overall mean `μ`.
10//! 2. Compute the within-class scatter matrix
11//!    `Sw = Σ_c Σ_{x ∈ c} (x - μ_c)(x - μ_c)^T`.
12//! 3. Compute the between-class scatter matrix
13//!    `Sb = Σ_c n_c (μ_c - μ)(μ_c - μ)^T`.
14//! 4. Solve the generalised eigenvalue problem `Sw⁻¹ Sb v = λ v`.
15//! 5. Project data onto the top-`k` eigenvectors.
16//!
17//! The number of discriminant directions is at most `min(n_classes - 1, n_features)`.
18//!
19//! # Examples
20//!
21//! ```
22//! use ferrolearn_linear::lda::LDA;
23//! use ferrolearn_core::{Fit, Predict};
24//! use ndarray::{array, Array1, Array2};
25//!
26//! let lda = LDA::new(Some(1));
27//! let x = Array2::from_shape_vec(
28//!     (6, 2),
29//!     vec![1.0, 1.0, 1.5, 1.2, 1.2, 0.8, 5.0, 5.0, 5.5, 4.8, 4.8, 5.2],
30//! ).unwrap();
31//! let y = array![0usize, 0, 0, 1, 1, 1];
32//! let fitted = lda.fit(&x, &y).unwrap();
33//! let preds = fitted.predict(&x).unwrap();
34//! assert_eq!(preds.len(), 6);
35//! ```
36
37use ferrolearn_core::error::FerroError;
38use ferrolearn_core::pipeline::{FittedPipelineEstimator, PipelineEstimator};
39use ferrolearn_core::traits::{Fit, Predict, Transform};
40use ndarray::{Array1, Array2};
41use num_traits::Float;
42
43// ---------------------------------------------------------------------------
44// LDA (unfitted)
45// ---------------------------------------------------------------------------
46
/// Linear Discriminant Analysis configuration.
///
/// Holds hyperparameters. Calling [`Fit::fit`] computes the discriminant
/// directions and returns a [`FittedLDA`].
///
/// # Type Parameters
///
/// - `F`: The floating-point scalar type (`f32` or `f64`).
#[derive(Debug, Clone)]
pub struct LDA<F> {
    /// Number of discriminant components to retain.
    ///
    /// If `None`, defaults to `min(n_classes - 1, n_features)` at fit time.
    /// Validated in `fit`: `Some(0)` and values above the maximum are rejected.
    n_components: Option<usize>,
    // Ties the generic scalar type `F` to the struct without storing a value.
    _marker: std::marker::PhantomData<F>,
}
63
64impl<F: Float + Send + Sync + 'static> LDA<F> {
65    /// Create a new `LDA`.
66    ///
67    /// - `n_components`: number of discriminant directions to retain.
68    ///   Pass `None` to use `min(n_classes - 1, n_features)`.
69    #[must_use]
70    pub fn new(n_components: Option<usize>) -> Self {
71        Self {
72            n_components,
73            _marker: std::marker::PhantomData,
74        }
75    }
76
77    /// Return the configured number of components (may be `None`).
78    #[must_use]
79    pub fn n_components(&self) -> Option<usize> {
80        self.n_components
81    }
82}
83
84impl<F: Float + Send + Sync + 'static> Default for LDA<F> {
85    fn default() -> Self {
86        Self::new(None)
87    }
88}
89
90// ---------------------------------------------------------------------------
91// FittedLDA
92// ---------------------------------------------------------------------------
93
/// A fitted LDA model.
///
/// Created by calling [`Fit::fit`] on an [`LDA`]. Implements:
/// - [`Transform<Array2<F>>`] — project data onto discriminant axes.
/// - [`Predict<Array2<F>>`] — classify by nearest centroid in projected space.
#[derive(Debug, Clone)]
pub struct FittedLDA<F> {
    /// Projection matrix, shape `(n_features, n_components)`.
    ///
    /// New data is projected via `X @ scalings`.
    scalings: Array2<F>,

    /// Class means in the projected space, shape `(n_classes, n_components)`.
    /// Row `i` corresponds to `classes[i]`.
    means: Array2<F>,

    /// Ratio of explained variance per discriminant direction.
    /// Entries are non-negative (negative eigenvalues are clamped at fit time).
    explained_variance_ratio: Array1<F>,

    /// Class labels corresponding to rows of `means`, sorted ascending.
    classes: Vec<usize>,

    /// Number of features seen during fitting; used to validate inputs
    /// in `transform`/`predict`.
    n_features: usize,
}
118
119impl<F: Float + Send + Sync + 'static> FittedLDA<F> {
120    /// Projection (scalings) matrix, shape `(n_features, n_components)`.
121    #[must_use]
122    pub fn scalings(&self) -> &Array2<F> {
123        &self.scalings
124    }
125
126    /// Class centroids in the projected space, shape `(n_classes, n_components)`.
127    #[must_use]
128    pub fn means(&self) -> &Array2<F> {
129        &self.means
130    }
131
132    /// Explained-variance ratio per discriminant direction.
133    #[must_use]
134    pub fn explained_variance_ratio(&self) -> &Array1<F> {
135        &self.explained_variance_ratio
136    }
137
138    /// Sorted class labels as seen during fitting.
139    #[must_use]
140    pub fn classes(&self) -> &[usize] {
141        &self.classes
142    }
143}
144
145// ---------------------------------------------------------------------------
146// Internal linear algebra helpers (generic over F)
147// ---------------------------------------------------------------------------
148
/// Jacobi symmetric eigendecomposition.
///
/// Returns `(eigenvalues, eigenvectors_columns)` — column `i` is the
/// eigenvector for `eigenvalues[i]`.  Eigenvalues are **not** sorted.
///
/// NOTE(review): the classic Jacobi rotation scheme assumes `a` is
/// symmetric. The caller passes `Sw⁻¹ Sb`, which is not symmetric in
/// general — confirm this approximation is acceptable for this solver.
fn jacobi_eigen_f<F: Float + Send + Sync + 'static>(
    a: &Array2<F>,
    max_iter: usize,
) -> Result<(Array1<F>, Array2<F>), FerroError> {
    let n = a.nrows();
    let mut mat = a.to_owned();
    // `v` accumulates the product of all rotations; starts as the identity.
    let mut v = Array2::<F>::zeros((n, n));
    for i in 0..n {
        v[[i, i]] = F::one();
    }
    let tol = F::from(1e-12).unwrap_or(F::epsilon());

    for _ in 0..max_iter {
        // Find the largest off-diagonal entry (classical pivot selection).
        let mut max_off = F::zero();
        let mut p = 0usize;
        let mut q = 1usize;
        for i in 0..n {
            for j in (i + 1)..n {
                let val = mat[[i, j]].abs();
                if val > max_off {
                    max_off = val;
                    p = i;
                    q = j;
                }
            }
        }
        // Converged: the matrix is (numerically) diagonal.
        if max_off < tol {
            let eigenvalues = Array1::from_shape_fn(n, |i| mat[[i, i]]);
            return Ok((eigenvalues, v));
        }
        let app = mat[[p, p]];
        let aqq = mat[[q, q]];
        let apq = mat[[p, q]];
        let two = F::from(2.0).unwrap();
        // Rotation angle that zeroes mat[p, q]. When the diagonal entries
        // are (nearly) equal, the limit angle is π/4.
        let theta = if (app - aqq).abs() < tol {
            F::from(std::f64::consts::FRAC_PI_4).unwrap_or(F::one())
        } else {
            let tau = (aqq - app) / (two * apq);
            // t = sign(tau) / (|tau| + sqrt(1 + tau²)) — the smaller root of
            // t² + 2·tau·t − 1 = 0, chosen for numerical stability.
            let t = if tau >= F::zero() {
                F::one() / (tau.abs() + (F::one() + tau * tau).sqrt())
            } else {
                -F::one() / (tau.abs() + (F::one() + tau * tau).sqrt())
            };
            t.atan()
        };
        let c = theta.cos();
        let s = theta.sin();
        // Apply the Givens rotation on rows/columns p and q, keeping the
        // matrix symmetric by mirroring each update.
        let mut new_mat = mat.clone();
        for i in 0..n {
            if i != p && i != q {
                let mip = mat[[i, p]];
                let miq = mat[[i, q]];
                new_mat[[i, p]] = c * mip - s * miq;
                new_mat[[p, i]] = new_mat[[i, p]];
                new_mat[[i, q]] = s * mip + c * miq;
                new_mat[[q, i]] = new_mat[[i, q]];
            }
        }
        new_mat[[p, p]] = c * c * app - two * s * c * apq + s * s * aqq;
        new_mat[[q, q]] = s * s * app + two * s * c * apq + c * c * aqq;
        // The pivot pair is annihilated by construction.
        new_mat[[p, q]] = F::zero();
        new_mat[[q, p]] = F::zero();
        mat = new_mat;
        // Accumulate the rotation into the eigenvector matrix.
        for i in 0..n {
            let vip = v[[i, p]];
            let viq = v[[i, q]];
            v[[i, p]] = c * vip - s * viq;
            v[[i, q]] = s * vip + c * viq;
        }
    }
    Err(FerroError::ConvergenceFailure {
        iterations: max_iter,
        message: "Jacobi eigendecomposition did not converge (LDA)".into(),
    })
}
229
230/// Gaussian elimination with partial pivoting to solve `A x = b`.
231fn gaussian_solve_f<F: Float>(
232    n: usize,
233    a: &Array2<F>,
234    b: &Array1<F>,
235) -> Result<Array1<F>, FerroError> {
236    let mut aug = Array2::<F>::zeros((n, n + 1));
237    for i in 0..n {
238        for j in 0..n {
239            aug[[i, j]] = a[[i, j]];
240        }
241        aug[[i, n]] = b[i];
242    }
243    for col in 0..n {
244        let mut max_val = aug[[col, col]].abs();
245        let mut max_row = col;
246        for row in (col + 1)..n {
247            let val = aug[[row, col]].abs();
248            if val > max_val {
249                max_val = val;
250                max_row = row;
251            }
252        }
253        if max_val < F::from(1e-12).unwrap_or(F::epsilon()) {
254            return Err(FerroError::NumericalInstability {
255                message: "singular matrix during LDA inversion".into(),
256            });
257        }
258        if max_row != col {
259            for j in 0..=n {
260                let tmp = aug[[col, j]];
261                aug[[col, j]] = aug[[max_row, j]];
262                aug[[max_row, j]] = tmp;
263            }
264        }
265        let pivot = aug[[col, col]];
266        for row in (col + 1)..n {
267            let factor = aug[[row, col]] / pivot;
268            for j in col..=n {
269                let above = aug[[col, j]];
270                aug[[row, j]] = aug[[row, j]] - factor * above;
271            }
272        }
273    }
274    let mut x = Array1::<F>::zeros(n);
275    for i in (0..n).rev() {
276        let mut sum = aug[[i, n]];
277        for j in (i + 1)..n {
278            sum = sum - aug[[i, j]] * x[j];
279        }
280        if aug[[i, i]].abs() < F::from(1e-12).unwrap_or(F::epsilon()) {
281            return Err(FerroError::NumericalInstability {
282                message: "near-zero pivot during LDA back substitution".into(),
283            });
284        }
285        x[i] = sum / aug[[i, i]];
286    }
287    Ok(x)
288}
289
290/// Compute `Sw⁻¹ @ Sb` column by column.
291///
292/// Returns the matrix `M = Sw⁻¹ Sb` of shape `(n, n)`.
293fn sw_inv_sb<F: Float + Send + Sync + 'static>(
294    sw: &Array2<F>,
295    sb: &Array2<F>,
296) -> Result<Array2<F>, FerroError> {
297    let n = sw.nrows();
298    let mut result = Array2::<F>::zeros((n, n));
299    for j in 0..n {
300        let col_sb = Array1::from_shape_fn(n, |i| sb[[i, j]]);
301        let col = gaussian_solve_f(n, sw, &col_sb)?;
302        for i in 0..n {
303            result[[i, j]] = col[i];
304        }
305    }
306    Ok(result)
307}
308
309// ---------------------------------------------------------------------------
310// Fit
311// ---------------------------------------------------------------------------
312
impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array1<usize>> for LDA<F> {
    type Fitted = FittedLDA<F>;
    type Error = FerroError;

    /// Fit the LDA model.
    ///
    /// Computes within-class (`Sw`) and between-class (`Sb`) scatter,
    /// solves `Sw⁻¹ Sb v = λ v`, and retains the top eigenvectors as the
    /// projection (see the module-level algorithm description).
    ///
    /// # Errors
    ///
    /// - [`FerroError::InsufficientSamples`] if fewer than 2 samples are provided.
    /// - [`FerroError::InvalidParameter`] if `n_components` is zero or exceeds the
    ///   maximum allowed (`min(n_classes - 1, n_features)`).
    /// - [`FerroError::ShapeMismatch`] if `x` and `y` have different numbers of rows.
    /// - [`FerroError::NumericalInstability`] if `Sw` is singular.
    /// - [`FerroError::ConvergenceFailure`] if eigendecomposition does not converge.
    fn fit(&self, x: &Array2<F>, y: &Array1<usize>) -> Result<FittedLDA<F>, FerroError> {
        let (n_samples, n_features) = x.dim();

        // Input validation: one label per row, at least two rows.
        if n_samples != y.len() {
            return Err(FerroError::ShapeMismatch {
                expected: vec![n_samples],
                actual: vec![y.len()],
                context: "LDA: y length must match number of rows in X".into(),
            });
        }
        if n_samples < 2 {
            return Err(FerroError::InsufficientSamples {
                required: 2,
                actual: n_samples,
                context: "LDA requires at least 2 samples".into(),
            });
        }

        // Gather sorted unique classes.
        let mut classes: Vec<usize> = y.to_vec();
        classes.sort_unstable();
        classes.dedup();
        let n_classes = classes.len();

        if n_classes < 2 {
            return Err(FerroError::InsufficientSamples {
                required: 2,
                actual: n_classes,
                context: "LDA requires at least 2 distinct classes".into(),
            });
        }

        // Determine effective n_components.
        // Sb has rank at most n_classes - 1, so more directions carry no signal.
        let max_components = (n_classes - 1).min(n_features);
        let n_comp = match self.n_components {
            None => max_components,
            Some(0) => {
                return Err(FerroError::InvalidParameter {
                    name: "n_components".into(),
                    reason: "must be at least 1".into(),
                });
            }
            Some(k) if k > max_components => {
                return Err(FerroError::InvalidParameter {
                    name: "n_components".into(),
                    reason: format!(
                        "n_components ({k}) exceeds max allowed ({max_components} = min(n_classes-1, n_features))"
                    ),
                });
            }
            Some(k) => k,
        };

        // --- Step 1: compute overall mean and per-class means ----------------
        let n_f = F::from(n_samples).unwrap();
        let mut overall_mean = Array1::<F>::zeros(n_features);
        for j in 0..n_features {
            let col = x.column(j);
            let s = col.iter().copied().fold(F::zero(), |a, b| a + b);
            overall_mean[j] = s / n_f;
        }

        // class_means[c] = mean of samples in class c
        let mut class_means: Vec<Array1<F>> = Vec::with_capacity(n_classes);
        let mut class_counts: Vec<usize> = Vec::with_capacity(n_classes);
        for &cls in &classes {
            let mut mean = Array1::<F>::zeros(n_features);
            let mut cnt = 0usize;
            for (i, &label) in y.iter().enumerate() {
                if label == cls {
                    for j in 0..n_features {
                        mean[j] = mean[j] + x[[i, j]];
                    }
                    cnt += 1;
                }
            }
            // Defensive: unreachable in practice, since `classes` is derived
            // from `y`, so every class has at least one sample.
            if cnt == 0 {
                return Err(FerroError::InsufficientSamples {
                    required: 1,
                    actual: 0,
                    context: format!("LDA: class {cls} has no samples"),
                });
            }
            let cnt_f = F::from(cnt).unwrap();
            mean.mapv_inplace(|v| v / cnt_f);
            class_means.push(mean);
            class_counts.push(cnt);
        }

        // --- Step 2: within-class scatter Sw ----------------------------------
        // Sw = Σ_c Σ_{x ∈ c} (x - μ_c)(x - μ_c)^T
        let mut sw = Array2::<F>::zeros((n_features, n_features));
        for (ci, &cls) in classes.iter().enumerate() {
            let mu_c = &class_means[ci];
            for (i, &label) in y.iter().enumerate() {
                if label == cls {
                    // diff = x[i] - mu_c
                    let diff: Vec<F> = (0..n_features).map(|j| x[[i, j]] - mu_c[j]).collect();
                    for r in 0..n_features {
                        for c in 0..n_features {
                            sw[[r, c]] = sw[[r, c]] + diff[r] * diff[c];
                        }
                    }
                }
            }
        }

        // Add a small regularisation (1e-6 on the diagonal) to Sw to avoid
        // singularity when features are collinear or classes are tiny.
        let reg = F::from(1e-6).unwrap();
        for i in 0..n_features {
            sw[[i, i]] = sw[[i, i]] + reg;
        }

        // --- Step 3: between-class scatter Sb ---------------------------------
        // Sb = Σ_c n_c (μ_c - μ)(μ_c - μ)^T
        let mut sb = Array2::<F>::zeros((n_features, n_features));
        for (ci, &nc) in class_counts.iter().enumerate() {
            let nc_f = F::from(nc).unwrap();
            let diff: Vec<F> = (0..n_features)
                .map(|j| class_means[ci][j] - overall_mean[j])
                .collect();
            for r in 0..n_features {
                for c in 0..n_features {
                    sb[[r, c]] = sb[[r, c]] + nc_f * diff[r] * diff[c];
                }
            }
        }

        // --- Step 4: solve generalised eigenvalue problem Sw⁻¹ Sb v = λ v ----
        // NOTE(review): `Sw⁻¹ Sb` is not symmetric in general, while the
        // Jacobi solver assumes symmetry — confirm this approximation is
        // intended (see `jacobi_eigen_f`).
        let m = sw_inv_sb(&sw, &sb)?;
        let max_jacobi = n_features * n_features * 100 + 1000;
        let (eigenvalues, eigenvectors) = jacobi_eigen_f(&m, max_jacobi)?;

        // Sort eigenvalues descending.
        let mut indices: Vec<usize> = (0..n_features).collect();
        indices.sort_by(|&a, &b| {
            eigenvalues[b]
                .partial_cmp(&eigenvalues[a])
                .unwrap_or(std::cmp::Ordering::Equal)
        });

        // Clamp negative eigenvalues (numerical noise) so that the
        // explained-variance denominator stays non-negative.
        let total_ev: F = eigenvalues
            .iter()
            .copied()
            .map(|v| if v > F::zero() { v } else { F::zero() })
            .fold(F::zero(), |a, b| a + b);

        // --- Step 5: build scalings matrix (n_features × n_comp) -------------
        let mut scalings = Array2::<F>::zeros((n_features, n_comp));
        let mut explained_variance_ratio = Array1::<F>::zeros(n_comp);
        for (k, &idx) in indices.iter().take(n_comp).enumerate() {
            let ev = eigenvalues[idx];
            let ev_clamped = if ev > F::zero() { ev } else { F::zero() };
            explained_variance_ratio[k] = if total_ev > F::zero() {
                ev_clamped / total_ev
            } else {
                F::zero()
            };
            for j in 0..n_features {
                scalings[[j, k]] = eigenvectors[[j, idx]];
            }
        }

        // --- Project class means into the discriminant space -----------------
        // means[c, k] = class_means[c] · scalings[:, k]
        let mut means = Array2::<F>::zeros((n_classes, n_comp));
        for ci in 0..n_classes {
            let mu_row = class_means[ci].view();
            for k in 0..n_comp {
                let mut dot = F::zero();
                for j in 0..n_features {
                    dot = dot + mu_row[j] * scalings[[j, k]];
                }
                means[[ci, k]] = dot;
            }
        }

        Ok(FittedLDA {
            scalings,
            means,
            explained_variance_ratio,
            classes,
            n_features,
        })
    }
}
512
513// ---------------------------------------------------------------------------
514// Transform
515// ---------------------------------------------------------------------------
516
517impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedLDA<F> {
518    type Output = Array2<F>;
519    type Error = FerroError;
520
521    /// Project `x` onto the discriminant axes: `X @ scalings`.
522    ///
523    /// # Errors
524    ///
525    /// Returns [`FerroError::ShapeMismatch`] if `x.ncols()` does not match the
526    /// number of features seen during fitting.
527    fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
528        if x.ncols() != self.n_features {
529            return Err(FerroError::ShapeMismatch {
530                expected: vec![x.nrows(), self.n_features],
531                actual: vec![x.nrows(), x.ncols()],
532                context: "FittedLDA::transform".into(),
533            });
534        }
535        Ok(x.dot(&self.scalings))
536    }
537}
538
539// ---------------------------------------------------------------------------
540// Predict (nearest centroid in projected space)
541// ---------------------------------------------------------------------------
542
543impl<F: Float + Send + Sync + 'static> Predict<Array2<F>> for FittedLDA<F> {
544    type Output = Array1<usize>;
545    type Error = FerroError;
546
547    /// Classify samples by nearest centroid in the projected space.
548    ///
549    /// # Errors
550    ///
551    /// Returns [`FerroError::ShapeMismatch`] if the number of features does not
552    /// match the model.
553    fn predict(&self, x: &Array2<F>) -> Result<Array1<usize>, FerroError> {
554        let projected = self.transform(x)?;
555        let n_samples = projected.nrows();
556        let n_comp = projected.ncols();
557        let n_classes = self.classes.len();
558
559        let mut predictions = Array1::<usize>::zeros(n_samples);
560        for i in 0..n_samples {
561            let mut best_class = 0usize;
562            let mut best_dist = F::infinity();
563            for ci in 0..n_classes {
564                let mut dist = F::zero();
565                for k in 0..n_comp {
566                    let d = projected[[i, k]] - self.means[[ci, k]];
567                    dist = dist + d * d;
568                }
569                if dist < best_dist {
570                    best_dist = dist;
571                    best_class = ci;
572                }
573            }
574            predictions[i] = self.classes[best_class];
575        }
576        Ok(predictions)
577    }
578}
579
580// ---------------------------------------------------------------------------
581// Pipeline integration (f64 specialisation)
582// ---------------------------------------------------------------------------
583
584impl PipelineEstimator<f64> for LDA<f64> {
585    /// Fit LDA using the pipeline interface.
586    ///
587    /// # Errors
588    ///
589    /// Propagates errors from [`Fit::fit`].
590    fn fit_pipeline(
591        &self,
592        x: &Array2<f64>,
593        y: &Array1<f64>,
594    ) -> Result<Box<dyn FittedPipelineEstimator<f64>>, FerroError> {
595        let y_usize: Array1<usize> = y.mapv(|v| v as usize);
596        let fitted = self.fit(x, &y_usize)?;
597        Ok(Box::new(FittedLDAPipeline(fitted)))
598    }
599}
600
601/// Wrapper for pipeline integration that converts predictions to `f64`.
602struct FittedLDAPipeline(FittedLDA<f64>);
603
604// Safety: the inner type is Send + Sync.
605unsafe impl Send for FittedLDAPipeline {}
606unsafe impl Sync for FittedLDAPipeline {}
607
608impl FittedPipelineEstimator<f64> for FittedLDAPipeline {
609    /// Predict via the pipeline interface, returning `f64` class labels.
610    fn predict_pipeline(&self, x: &Array2<f64>) -> Result<Array1<f64>, FerroError> {
611        let preds = self.0.predict(x)?;
612        Ok(preds.mapv(|v| v as f64))
613    }
614}
615
616// ---------------------------------------------------------------------------
617// Tests
618// ---------------------------------------------------------------------------
619
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::{Array2, array};

    // ------------------------------------------------------------------
    // Helpers
    // ------------------------------------------------------------------

    fn linearly_separable_2d() -> (Array2<f64>, Array1<usize>) {
        // Two well-separated Gaussian clusters.
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 1.0, 1.5, 1.2, 0.8, 0.9, 1.1, 1.3, // class 0
                6.0, 6.0, 6.2, 5.8, 5.9, 6.1, 6.3, 5.7, // class 1
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];
        (x, y)
    }

    fn three_class_data() -> (Array2<f64>, Array1<usize>) {
        let x = Array2::from_shape_vec(
            (9, 2),
            vec![
                0.0, 0.0, 0.5, 0.1, 0.1, 0.5, // class 0
                5.0, 0.0, 5.2, 0.3, 4.8, 0.1, // class 1
                0.0, 5.0, 0.1, 5.2, 0.3, 4.8, // class 2
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 1, 1, 1, 2, 2, 2];
        (x, y)
    }

    // ------------------------------------------------------------------

    #[test]
    fn test_lda_fit_returns_fitted() {
        let (x, y) = linearly_separable_2d();
        let lda = LDA::<f64>::new(Some(1));
        let fitted = lda.fit(&x, &y).unwrap();
        assert_eq!(fitted.scalings().ncols(), 1);
        assert_eq!(fitted.scalings().nrows(), 2);
    }

    #[test]
    fn test_lda_default_n_components() {
        // With 2 classes the default n_components = min(1, n_features) = 1.
        let (x, y) = linearly_separable_2d();
        let lda = LDA::<f64>::default();
        let fitted = lda.fit(&x, &y).unwrap();
        assert_eq!(fitted.scalings().ncols(), 1);
    }

    #[test]
    fn test_lda_transform_shape() {
        let (x, y) = linearly_separable_2d();
        let lda = LDA::<f64>::new(Some(1));
        let fitted = lda.fit(&x, &y).unwrap();
        let proj = fitted.transform(&x).unwrap();
        assert_eq!(proj.dim(), (8, 1));
    }

    #[test]
    fn test_lda_predict_accuracy_binary() {
        let (x, y) = linearly_separable_2d();
        let lda = LDA::<f64>::new(Some(1));
        let fitted = lda.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        let correct = preds.iter().zip(y.iter()).filter(|(p, a)| *p == *a).count();
        assert_eq!(correct, 8, "All 8 samples should be classified correctly");
    }

    #[test]
    fn test_lda_predict_three_classes() {
        let (x, y) = three_class_data();
        let lda = LDA::<f64>::new(Some(2));
        let fitted = lda.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        let correct = preds.iter().zip(y.iter()).filter(|(p, a)| *p == *a).count();
        assert!(correct >= 7, "Expected at least 7/9 correct, got {correct}");
    }

    #[test]
    fn test_lda_explained_variance_ratio_positive() {
        let (x, y) = linearly_separable_2d();
        let lda = LDA::<f64>::new(Some(1));
        let fitted = lda.fit(&x, &y).unwrap();
        for &v in fitted.explained_variance_ratio().iter() {
            assert!(v >= 0.0);
        }
    }

    #[test]
    fn test_lda_explained_variance_ratio_le_1() {
        let (x, y) = three_class_data();
        let lda = LDA::<f64>::new(Some(2));
        let fitted = lda.fit(&x, &y).unwrap();
        let total: f64 = fitted.explained_variance_ratio().iter().sum();
        assert!(total <= 1.0 + 1e-9, "total={total}");
    }

    #[test]
    fn test_lda_classes_accessor() {
        let (x, y) = linearly_separable_2d();
        let lda = LDA::<f64>::new(Some(1));
        let fitted = lda.fit(&x, &y).unwrap();
        assert_eq!(fitted.classes(), &[0usize, 1]);
    }

    #[test]
    fn test_lda_means_shape() {
        let (x, y) = three_class_data();
        let lda = LDA::<f64>::new(Some(2));
        let fitted = lda.fit(&x, &y).unwrap();
        assert_eq!(fitted.means().dim(), (3, 2));
    }

    #[test]
    fn test_lda_transform_shape_mismatch() {
        let (x, y) = linearly_separable_2d();
        let lda = LDA::<f64>::new(Some(1));
        let fitted = lda.fit(&x, &y).unwrap();
        let x_bad = Array2::<f64>::zeros((3, 5));
        assert!(fitted.transform(&x_bad).is_err());
    }

    #[test]
    fn test_lda_predict_shape_mismatch() {
        let (x, y) = linearly_separable_2d();
        let lda = LDA::<f64>::new(Some(1));
        let fitted = lda.fit(&x, &y).unwrap();
        let x_bad = Array2::<f64>::zeros((3, 5));
        assert!(fitted.predict(&x_bad).is_err());
    }

    #[test]
    fn test_lda_error_zero_n_components() {
        let (x, y) = linearly_separable_2d();
        let lda = LDA::<f64>::new(Some(0));
        assert!(lda.fit(&x, &y).is_err());
    }

    #[test]
    fn test_lda_error_n_components_too_large() {
        let (x, y) = linearly_separable_2d(); // 2 classes → max 1 component
        let lda = LDA::<f64>::new(Some(5));
        assert!(lda.fit(&x, &y).is_err());
    }

    #[test]
    fn test_lda_error_single_class() {
        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![0usize, 0, 0, 0];
        let lda = LDA::<f64>::new(None);
        assert!(lda.fit(&x, &y).is_err());
    }

    #[test]
    fn test_lda_error_shape_mismatch_fit() {
        let x = Array2::<f64>::zeros((4, 2));
        let y = array![0usize, 1]; // wrong length
        let lda = LDA::<f64>::new(None);
        assert!(lda.fit(&x, &y).is_err());
    }

    #[test]
    fn test_lda_error_insufficient_samples() {
        let x = Array2::<f64>::zeros((1, 2));
        let y = array![0usize];
        let lda = LDA::<f64>::new(None);
        assert!(lda.fit(&x, &y).is_err());
    }

    #[test]
    fn test_lda_scalings_accessor() {
        let (x, y) = linearly_separable_2d();
        let lda = LDA::<f64>::new(Some(1));
        let fitted = lda.fit(&x, &y).unwrap();
        assert_eq!(fitted.scalings().dim(), (2, 1));
    }

    #[test]
    fn test_lda_pipeline_estimator() {
        use ferrolearn_core::pipeline::PipelineEstimator;

        let (x, y_usize) = linearly_separable_2d();
        let y_f64 = y_usize.mapv(|v| v as f64);
        let lda = LDA::<f64>::new(Some(1));
        let fitted = lda.fit_pipeline(&x, &y_f64).unwrap();
        let preds = fitted.predict_pipeline(&x).unwrap();
        assert_eq!(preds.len(), 8);
    }

    #[test]
    fn test_lda_n_components_getter() {
        let lda = LDA::<f64>::new(Some(2));
        assert_eq!(lda.n_components(), Some(2));
        let lda_none = LDA::<f64>::new(None);
        assert_eq!(lda_none.n_components(), None);
    }

    #[test]
    fn test_lda_transform_then_predict_consistent() {
        let (x, y) = linearly_separable_2d();
        let lda = LDA::<f64>::new(Some(1));
        let fitted = lda.fit(&x, &y).unwrap();
        // Manually compute nearest-centroid prediction from transform output.
        let projected = fitted.transform(&x).unwrap();
        let preds_predict = fitted.predict(&x).unwrap();
        let n_samples = projected.nrows();
        let n_comp = projected.ncols();
        let n_classes = fitted.classes().len();
        for i in 0..n_samples {
            let mut best = 0;
            let mut best_d = f64::INFINITY;
            for ci in 0..n_classes {
                let mut d = 0.0;
                for k in 0..n_comp {
                    let diff = projected[[i, k]] - fitted.means()[[ci, k]];
                    d += diff * diff;
                }
                if d < best_d {
                    best_d = d;
                    best = ci;
                }
            }
            assert_eq!(preds_predict[i], fitted.classes()[best]);
        }
    }

    #[test]
    fn test_lda_projected_class_separation() {
        let (x, y) = linearly_separable_2d();
        let lda = LDA::<f64>::new(Some(1));
        let fitted = lda.fit(&x, &y).unwrap();
        let projected = fitted.transform(&x).unwrap();

        // Means of class 0 and class 1 in projected space should be far apart.
        let mean0: f64 = projected
            .rows()
            .into_iter()
            .zip(y.iter())
            .filter(|&(_, label)| *label == 0)
            .map(|(row, _)| row[0])
            .sum::<f64>()
            / 4.0;
        let mean1: f64 = projected
            .rows()
            .into_iter()
            .zip(y.iter())
            .filter(|&(_, label)| *label == 1)
            .map(|(row, _)| row[0])
            .sum::<f64>()
            / 4.0;

        assert!(
            (mean0 - mean1).abs() > 0.5,
            "Projected means should differ, got {mean0} vs {mean1}"
        );
    }

    #[test]
    fn test_lda_transform_known_data() {
        // With perfectly separated data the transform should yield two clearly
        // distinct groups.
        let x = Array2::from_shape_vec((4, 1), vec![-2.0, -1.0, 1.0, 2.0]).unwrap();
        let y = array![0usize, 0, 1, 1];
        let lda = LDA::<f64>::new(Some(1));
        let fitted = lda.fit(&x, &y).unwrap();
        let proj = fitted.transform(&x).unwrap();
        // The first two samples should project to one side, the other two to the other side.
        let sign0 = proj[[0, 0]].signum();
        let sign1 = proj[[2, 0]].signum();
        // They should be on opposite sides of the origin (or at least the split is correct).
        assert_ne!(
            sign0 as i32, sign1 as i32,
            "Classes should be on opposite sides"
        );
    }

    #[test]
    fn test_lda_abs_diff_eq_means_dimensions() {
        let (x, y) = linearly_separable_2d();
        let lda = LDA::<f64>::new(Some(1));
        let fitted = lda.fit(&x, &y).unwrap();
        // Each class mean in projected space should be a 1-component vector.
        assert_eq!(fitted.means().ncols(), 1);
        let m0 = fitted.means()[[0, 0]];
        let m1 = fitted.means()[[1, 0]];
        // For well-separated data the projected means should differ by > 0.5.
        assert!((m0 - m1).abs() > 0.5, "m0={m0}, m1={m1}");
        // By linearity of the projection, the mean of the projected class-0
        // samples must equal the projection of the class-0 mean (row 0 of
        // `means`, since classes are sorted ascending).
        let projected = fitted.transform(&x).unwrap();
        let mean0: f64 = projected
            .rows()
            .into_iter()
            .zip(y.iter())
            .filter(|&(_, label)| *label == 0)
            .map(|(row, _)| row[0])
            .sum::<f64>()
            / 4.0;
        assert_abs_diff_eq!(mean0, m0, epsilon = 1e-9);
    }
}
919}