Skip to main content

ferrolearn_linear/
lda.rs

1//! Linear Discriminant Analysis (LDA).
2//!
3//! LDA is both a supervised dimensionality reduction technique and a
4//! linear classifier. It finds the directions that maximise the separation
5//! between classes while minimising within-class scatter.
6//!
7//! # Algorithm
8//!
9//! 1. Compute class means `μ_c` and the overall mean `μ`.
10//! 2. Compute the within-class scatter matrix
11//!    `Sw = Σ_c Σ_{x ∈ c} (x - μ_c)(x - μ_c)^T`.
12//! 3. Compute the between-class scatter matrix
13//!    `Sb = Σ_c n_c (μ_c - μ)(μ_c - μ)^T`.
14//! 4. Solve the generalised eigenvalue problem `Sw⁻¹ Sb v = λ v`.
15//! 5. Project data onto the top-`k` eigenvectors.
16//!
17//! The number of discriminant directions is at most `min(n_classes - 1, n_features)`.
18//!
19//! # Examples
20//!
21//! ```
22//! use ferrolearn_linear::lda::LDA;
23//! use ferrolearn_core::{Fit, Predict};
24//! use ndarray::{array, Array1, Array2};
25//!
26//! let lda = LDA::new(Some(1));
27//! let x = Array2::from_shape_vec(
28//!     (6, 2),
29//!     vec![1.0, 1.0, 1.5, 1.2, 1.2, 0.8, 5.0, 5.0, 5.5, 4.8, 4.8, 5.2],
30//! ).unwrap();
31//! let y = array![0usize, 0, 0, 1, 1, 1];
32//! let fitted = lda.fit(&x, &y).unwrap();
33//! let preds = fitted.predict(&x).unwrap();
34//! assert_eq!(preds.len(), 6);
35//! ```
36
37use ferrolearn_core::error::FerroError;
38use ferrolearn_core::pipeline::{FittedPipelineEstimator, PipelineEstimator};
39use ferrolearn_core::traits::{Fit, Predict, Transform};
40use ndarray::{Array1, Array2};
41use num_traits::{Float, NumCast};
42
43// ---------------------------------------------------------------------------
44// LDA (unfitted)
45// ---------------------------------------------------------------------------
46
47/// Linear Discriminant Analysis configuration.
48///
49/// Holds hyperparameters. Calling [`Fit::fit`] computes the discriminant
50/// directions and returns a [`FittedLDA`].
51///
52/// # Type Parameters
53///
54/// - `F`: The floating-point scalar type (`f32` or `f64`).
55#[derive(Debug, Clone)]
56pub struct LDA<F> {
57    /// Number of discriminant components to retain.
58    ///
59    /// If `None`, defaults to `min(n_classes - 1, n_features)` at fit time.
60    n_components: Option<usize>,
61    _marker: std::marker::PhantomData<F>,
62}
63
64impl<F: Float + Send + Sync + 'static> LDA<F> {
65    /// Create a new `LDA`.
66    ///
67    /// - `n_components`: number of discriminant directions to retain.
68    ///   Pass `None` to use `min(n_classes - 1, n_features)`.
69    #[must_use]
70    pub fn new(n_components: Option<usize>) -> Self {
71        Self {
72            n_components,
73            _marker: std::marker::PhantomData,
74        }
75    }
76
77    /// Return the configured number of components (may be `None`).
78    #[must_use]
79    pub fn n_components(&self) -> Option<usize> {
80        self.n_components
81    }
82}
83
84impl<F: Float + Send + Sync + 'static> Default for LDA<F> {
85    fn default() -> Self {
86        Self::new(None)
87    }
88}
89
90// ---------------------------------------------------------------------------
91// FittedLDA
92// ---------------------------------------------------------------------------
93
94/// A fitted LDA model.
95///
96/// Created by calling [`Fit::fit`] on an [`LDA`]. Implements:
97/// - [`Transform<Array2<F>>`] — project data onto discriminant axes.
98/// - [`Predict<Array2<F>>`] — classify by nearest centroid in projected space.
99#[derive(Debug, Clone)]
100pub struct FittedLDA<F> {
101    /// Projection matrix, shape `(n_features, n_components)`.
102    ///
103    /// New data is projected via `X @ scalings`.
104    scalings: Array2<F>,
105
106    /// Class means in the projected space, shape `(n_classes, n_components)`.
107    means: Array2<F>,
108
109    /// Ratio of explained variance per discriminant direction.
110    explained_variance_ratio: Array1<F>,
111
112    /// Class labels corresponding to rows of `means`.
113    classes: Vec<usize>,
114
115    /// Number of features seen during fitting.
116    n_features: usize,
117}
118
119impl<F: Float + Send + Sync + 'static> FittedLDA<F> {
120    /// Projection (scalings) matrix, shape `(n_features, n_components)`.
121    #[must_use]
122    pub fn scalings(&self) -> &Array2<F> {
123        &self.scalings
124    }
125
126    /// Class centroids in the projected space, shape `(n_classes, n_components)`.
127    #[must_use]
128    pub fn means(&self) -> &Array2<F> {
129        &self.means
130    }
131
132    /// Explained-variance ratio per discriminant direction.
133    #[must_use]
134    pub fn explained_variance_ratio(&self) -> &Array1<F> {
135        &self.explained_variance_ratio
136    }
137
138    /// Sorted class labels as seen during fitting.
139    #[must_use]
140    pub fn classes(&self) -> &[usize] {
141        &self.classes
142    }
143}
144
145// ---------------------------------------------------------------------------
146// Internal linear algebra helpers (generic over F)
147// ---------------------------------------------------------------------------
148
149/// Jacobi symmetric eigendecomposition.
150///
151/// Returns `(eigenvalues, eigenvectors_columns)` — column `i` is the
152/// eigenvector for `eigenvalues[i]`.  Eigenvalues are **not** sorted.
153fn jacobi_eigen_f<F: Float + Send + Sync + 'static>(
154    a: &Array2<F>,
155    max_iter: usize,
156) -> Result<(Array1<F>, Array2<F>), FerroError> {
157    let n = a.nrows();
158    let mut mat = a.to_owned();
159    let mut v = Array2::<F>::zeros((n, n));
160    for i in 0..n {
161        v[[i, i]] = F::one();
162    }
163    let tol = F::from(1e-12).unwrap_or(F::epsilon());
164
165    for _ in 0..max_iter {
166        // Find the largest off-diagonal entry.
167        let mut max_off = F::zero();
168        let mut p = 0usize;
169        let mut q = 1usize;
170        for i in 0..n {
171            for j in (i + 1)..n {
172                let val = mat[[i, j]].abs();
173                if val > max_off {
174                    max_off = val;
175                    p = i;
176                    q = j;
177                }
178            }
179        }
180        if max_off < tol {
181            let eigenvalues = Array1::from_shape_fn(n, |i| mat[[i, i]]);
182            return Ok((eigenvalues, v));
183        }
184        let app = mat[[p, p]];
185        let aqq = mat[[q, q]];
186        let apq = mat[[p, q]];
187        let two = F::from(2.0).unwrap();
188        let theta = if (app - aqq).abs() < tol {
189            F::from(std::f64::consts::FRAC_PI_4).unwrap_or(F::one())
190        } else {
191            let tau = (aqq - app) / (two * apq);
192            let t = if tau >= F::zero() {
193                F::one() / (tau.abs() + (F::one() + tau * tau).sqrt())
194            } else {
195                -F::one() / (tau.abs() + (F::one() + tau * tau).sqrt())
196            };
197            t.atan()
198        };
199        let c = theta.cos();
200        let s = theta.sin();
201        let mut new_mat = mat.clone();
202        for i in 0..n {
203            if i != p && i != q {
204                let mip = mat[[i, p]];
205                let miq = mat[[i, q]];
206                new_mat[[i, p]] = c * mip - s * miq;
207                new_mat[[p, i]] = new_mat[[i, p]];
208                new_mat[[i, q]] = s * mip + c * miq;
209                new_mat[[q, i]] = new_mat[[i, q]];
210            }
211        }
212        new_mat[[p, p]] = c * c * app - two * s * c * apq + s * s * aqq;
213        new_mat[[q, q]] = s * s * app + two * s * c * apq + c * c * aqq;
214        new_mat[[p, q]] = F::zero();
215        new_mat[[q, p]] = F::zero();
216        mat = new_mat;
217        for i in 0..n {
218            let vip = v[[i, p]];
219            let viq = v[[i, q]];
220            v[[i, p]] = c * vip - s * viq;
221            v[[i, q]] = s * vip + c * viq;
222        }
223    }
224    Err(FerroError::ConvergenceFailure {
225        iterations: max_iter,
226        message: "Jacobi eigendecomposition did not converge (LDA)".into(),
227    })
228}
229
230/// Gaussian elimination with partial pivoting to solve `A x = b`.
231fn gaussian_solve_f<F: Float>(
232    n: usize,
233    a: &Array2<F>,
234    b: &Array1<F>,
235) -> Result<Array1<F>, FerroError> {
236    let mut aug = Array2::<F>::zeros((n, n + 1));
237    for i in 0..n {
238        for j in 0..n {
239            aug[[i, j]] = a[[i, j]];
240        }
241        aug[[i, n]] = b[i];
242    }
243    for col in 0..n {
244        let mut max_val = aug[[col, col]].abs();
245        let mut max_row = col;
246        for row in (col + 1)..n {
247            let val = aug[[row, col]].abs();
248            if val > max_val {
249                max_val = val;
250                max_row = row;
251            }
252        }
253        if max_val < F::from(1e-12).unwrap_or(F::epsilon()) {
254            return Err(FerroError::NumericalInstability {
255                message: "singular matrix during LDA inversion".into(),
256            });
257        }
258        if max_row != col {
259            for j in 0..=n {
260                let tmp = aug[[col, j]];
261                aug[[col, j]] = aug[[max_row, j]];
262                aug[[max_row, j]] = tmp;
263            }
264        }
265        let pivot = aug[[col, col]];
266        for row in (col + 1)..n {
267            let factor = aug[[row, col]] / pivot;
268            for j in col..=n {
269                let above = aug[[col, j]];
270                aug[[row, j]] = aug[[row, j]] - factor * above;
271            }
272        }
273    }
274    let mut x = Array1::<F>::zeros(n);
275    for i in (0..n).rev() {
276        let mut sum = aug[[i, n]];
277        for j in (i + 1)..n {
278            sum = sum - aug[[i, j]] * x[j];
279        }
280        if aug[[i, i]].abs() < F::from(1e-12).unwrap_or(F::epsilon()) {
281            return Err(FerroError::NumericalInstability {
282                message: "near-zero pivot during LDA back substitution".into(),
283            });
284        }
285        x[i] = sum / aug[[i, i]];
286    }
287    Ok(x)
288}
289
290/// Compute `Sw⁻¹ @ Sb` column by column.
291///
292/// Returns the matrix `M = Sw⁻¹ Sb` of shape `(n, n)`.
293fn sw_inv_sb<F: Float + Send + Sync + 'static>(
294    sw: &Array2<F>,
295    sb: &Array2<F>,
296) -> Result<Array2<F>, FerroError> {
297    let n = sw.nrows();
298    let mut result = Array2::<F>::zeros((n, n));
299    for j in 0..n {
300        let col_sb = Array1::from_shape_fn(n, |i| sb[[i, j]]);
301        let col = gaussian_solve_f(n, sw, &col_sb)?;
302        for i in 0..n {
303            result[[i, j]] = col[i];
304        }
305    }
306    Ok(result)
307}
308
309// ---------------------------------------------------------------------------
310// Fit
311// ---------------------------------------------------------------------------
312
313impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array1<usize>> for LDA<F> {
314    type Fitted = FittedLDA<F>;
315    type Error = FerroError;
316
317    /// Fit the LDA model.
318    ///
319    /// # Errors
320    ///
321    /// - [`FerroError::InsufficientSamples`] if fewer than 2 samples are provided.
322    /// - [`FerroError::InvalidParameter`] if `n_components` is zero or exceeds the
323    ///   maximum allowed (`min(n_classes - 1, n_features)`).
324    /// - [`FerroError::ShapeMismatch`] if `x` and `y` have different numbers of rows.
325    /// - [`FerroError::NumericalInstability`] if `Sw` is singular.
326    /// - [`FerroError::ConvergenceFailure`] if eigendecomposition does not converge.
327    fn fit(&self, x: &Array2<F>, y: &Array1<usize>) -> Result<FittedLDA<F>, FerroError> {
328        let (n_samples, n_features) = x.dim();
329
330        if n_samples != y.len() {
331            return Err(FerroError::ShapeMismatch {
332                expected: vec![n_samples],
333                actual: vec![y.len()],
334                context: "LDA: y length must match number of rows in X".into(),
335            });
336        }
337        if n_samples < 2 {
338            return Err(FerroError::InsufficientSamples {
339                required: 2,
340                actual: n_samples,
341                context: "LDA requires at least 2 samples".into(),
342            });
343        }
344
345        // Gather sorted unique classes.
346        let mut classes: Vec<usize> = y.to_vec();
347        classes.sort_unstable();
348        classes.dedup();
349        let n_classes = classes.len();
350
351        if n_classes < 2 {
352            return Err(FerroError::InsufficientSamples {
353                required: 2,
354                actual: n_classes,
355                context: "LDA requires at least 2 distinct classes".into(),
356            });
357        }
358
359        // Determine effective n_components.
360        let max_components = (n_classes - 1).min(n_features);
361        let n_comp = match self.n_components {
362            None => max_components,
363            Some(0) => {
364                return Err(FerroError::InvalidParameter {
365                    name: "n_components".into(),
366                    reason: "must be at least 1".into(),
367                });
368            }
369            Some(k) if k > max_components => {
370                return Err(FerroError::InvalidParameter {
371                    name: "n_components".into(),
372                    reason: format!(
373                        "n_components ({k}) exceeds max allowed ({max_components} = min(n_classes-1, n_features))"
374                    ),
375                });
376            }
377            Some(k) => k,
378        };
379
380        // --- Step 1: compute overall mean and per-class means ----------------
381        let n_f = F::from(n_samples).unwrap();
382        let mut overall_mean = Array1::<F>::zeros(n_features);
383        for j in 0..n_features {
384            let col = x.column(j);
385            let s = col.iter().copied().fold(F::zero(), |a, b| a + b);
386            overall_mean[j] = s / n_f;
387        }
388
389        // class_means[c] = mean of samples in class c
390        let mut class_means: Vec<Array1<F>> = Vec::with_capacity(n_classes);
391        let mut class_counts: Vec<usize> = Vec::with_capacity(n_classes);
392        for &cls in &classes {
393            let mut mean = Array1::<F>::zeros(n_features);
394            let mut cnt = 0usize;
395            for (i, &label) in y.iter().enumerate() {
396                if label == cls {
397                    for j in 0..n_features {
398                        mean[j] = mean[j] + x[[i, j]];
399                    }
400                    cnt += 1;
401                }
402            }
403            if cnt == 0 {
404                return Err(FerroError::InsufficientSamples {
405                    required: 1,
406                    actual: 0,
407                    context: format!("LDA: class {cls} has no samples"),
408                });
409            }
410            let cnt_f = F::from(cnt).unwrap();
411            mean.mapv_inplace(|v| v / cnt_f);
412            class_means.push(mean);
413            class_counts.push(cnt);
414        }
415
416        // --- Step 2: within-class scatter Sw ----------------------------------
417        let mut sw = Array2::<F>::zeros((n_features, n_features));
418        for (ci, &cls) in classes.iter().enumerate() {
419            let mu_c = &class_means[ci];
420            for (i, &label) in y.iter().enumerate() {
421                if label == cls {
422                    // diff = x[i] - mu_c
423                    let diff: Vec<F> = (0..n_features).map(|j| x[[i, j]] - mu_c[j]).collect();
424                    for r in 0..n_features {
425                        for c in 0..n_features {
426                            sw[[r, c]] = sw[[r, c]] + diff[r] * diff[c];
427                        }
428                    }
429                }
430            }
431        }
432
433        // Add a small regularisation to Sw to avoid singularity.
434        let reg = F::from(1e-6).unwrap();
435        for i in 0..n_features {
436            sw[[i, i]] = sw[[i, i]] + reg;
437        }
438
439        // --- Step 3: between-class scatter Sb ---------------------------------
440        let mut sb = Array2::<F>::zeros((n_features, n_features));
441        for (ci, &nc) in class_counts.iter().enumerate() {
442            let nc_f = F::from(nc).unwrap();
443            let diff: Vec<F> = (0..n_features)
444                .map(|j| class_means[ci][j] - overall_mean[j])
445                .collect();
446            for r in 0..n_features {
447                for c in 0..n_features {
448                    sb[[r, c]] = sb[[r, c]] + nc_f * diff[r] * diff[c];
449                }
450            }
451        }
452
453        // --- Step 4: solve generalised eigenvalue problem Sw⁻¹ Sb v = λ v ----
454        let m = sw_inv_sb(&sw, &sb)?;
455        let max_jacobi = n_features * n_features * 100 + 1000;
456        let (eigenvalues, eigenvectors) = jacobi_eigen_f(&m, max_jacobi)?;
457
458        // Sort eigenvalues descending.
459        let mut indices: Vec<usize> = (0..n_features).collect();
460        indices.sort_by(|&a, &b| {
461            eigenvalues[b]
462                .partial_cmp(&eigenvalues[a])
463                .unwrap_or(std::cmp::Ordering::Equal)
464        });
465
466        // Clamp negative eigenvalues.
467        let total_ev: F = eigenvalues
468            .iter()
469            .copied()
470            .map(|v| if v > F::zero() { v } else { F::zero() })
471            .fold(F::zero(), |a, b| a + b);
472
473        // --- Step 5: build scalings matrix (n_features × n_comp) -------------
474        let mut scalings = Array2::<F>::zeros((n_features, n_comp));
475        let mut explained_variance_ratio = Array1::<F>::zeros(n_comp);
476        for (k, &idx) in indices.iter().take(n_comp).enumerate() {
477            let ev = eigenvalues[idx];
478            let ev_clamped = if ev > F::zero() { ev } else { F::zero() };
479            explained_variance_ratio[k] = if total_ev > F::zero() {
480                ev_clamped / total_ev
481            } else {
482                F::zero()
483            };
484            for j in 0..n_features {
485                scalings[[j, k]] = eigenvectors[[j, idx]];
486            }
487        }
488
489        // --- Project class means into the discriminant space -----------------
490        // means[c, k] = class_means[c] · scalings[:, k]
491        let mut means = Array2::<F>::zeros((n_classes, n_comp));
492        for ci in 0..n_classes {
493            let mu_row = class_means[ci].view();
494            for k in 0..n_comp {
495                let mut dot = F::zero();
496                for j in 0..n_features {
497                    dot = dot + mu_row[j] * scalings[[j, k]];
498                }
499                means[[ci, k]] = dot;
500            }
501        }
502
503        Ok(FittedLDA {
504            scalings,
505            means,
506            explained_variance_ratio,
507            classes,
508            n_features,
509        })
510    }
511}
512
513// ---------------------------------------------------------------------------
514// Transform
515// ---------------------------------------------------------------------------
516
517impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedLDA<F> {
518    type Output = Array2<F>;
519    type Error = FerroError;
520
521    /// Project `x` onto the discriminant axes: `X @ scalings`.
522    ///
523    /// # Errors
524    ///
525    /// Returns [`FerroError::ShapeMismatch`] if `x.ncols()` does not match the
526    /// number of features seen during fitting.
527    fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
528        if x.ncols() != self.n_features {
529            return Err(FerroError::ShapeMismatch {
530                expected: vec![x.nrows(), self.n_features],
531                actual: vec![x.nrows(), x.ncols()],
532                context: "FittedLDA::transform".into(),
533            });
534        }
535        Ok(x.dot(&self.scalings))
536    }
537}
538
539// ---------------------------------------------------------------------------
540// Predict (nearest centroid in projected space)
541// ---------------------------------------------------------------------------
542
543impl<F: Float + Send + Sync + 'static> Predict<Array2<F>> for FittedLDA<F> {
544    type Output = Array1<usize>;
545    type Error = FerroError;
546
547    /// Classify samples by nearest centroid in the projected space.
548    ///
549    /// # Errors
550    ///
551    /// Returns [`FerroError::ShapeMismatch`] if the number of features does not
552    /// match the model.
553    fn predict(&self, x: &Array2<F>) -> Result<Array1<usize>, FerroError> {
554        let projected = self.transform(x)?;
555        let n_samples = projected.nrows();
556        let n_comp = projected.ncols();
557        let n_classes = self.classes.len();
558
559        let mut predictions = Array1::<usize>::zeros(n_samples);
560        for i in 0..n_samples {
561            let mut best_class = 0usize;
562            let mut best_dist = F::infinity();
563            for ci in 0..n_classes {
564                let mut dist = F::zero();
565                for k in 0..n_comp {
566                    let d = projected[[i, k]] - self.means[[ci, k]];
567                    dist = dist + d * d;
568                }
569                if dist < best_dist {
570                    best_dist = dist;
571                    best_class = ci;
572                }
573            }
574            predictions[i] = self.classes[best_class];
575        }
576        Ok(predictions)
577    }
578}
579
580// ---------------------------------------------------------------------------
581// Pipeline integration (generic)
582// ---------------------------------------------------------------------------
583
584impl<F: Float + Send + Sync + 'static> PipelineEstimator<F> for LDA<F> {
585    /// Fit LDA using the pipeline interface.
586    ///
587    /// # Errors
588    ///
589    /// Propagates errors from [`Fit::fit`].
590    fn fit_pipeline(
591        &self,
592        x: &Array2<F>,
593        y: &Array1<F>,
594    ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
595        let y_usize: Array1<usize> = y.mapv(|v| v.to_usize().unwrap_or(0));
596        let fitted = self.fit(x, &y_usize)?;
597        Ok(Box::new(FittedLDAPipeline(fitted)))
598    }
599}
600
601/// Wrapper for pipeline integration that converts predictions to float.
602struct FittedLDAPipeline<F>(FittedLDA<F>);
603
604impl<F: Float + Send + Sync + 'static> FittedPipelineEstimator<F> for FittedLDAPipeline<F> {
605    /// Predict via the pipeline interface, returning float class labels.
606    fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
607        let preds = self.0.predict(x)?;
608        Ok(preds.mapv(|v| NumCast::from(v).unwrap_or(F::nan())))
609    }
610}
611
612// ---------------------------------------------------------------------------
613// Tests
614// ---------------------------------------------------------------------------
615
616#[cfg(test)]
617mod tests {
618    use super::*;
619    use approx::assert_abs_diff_eq;
620    use ndarray::{Array2, array};
621
622    // ------------------------------------------------------------------
623    // Helpers
624    // ------------------------------------------------------------------
625
626    fn linearly_separable_2d() -> (Array2<f64>, Array1<usize>) {
627        // Two well-separated Gaussian clusters.
628        let x = Array2::from_shape_vec(
629            (8, 2),
630            vec![
631                1.0, 1.0, 1.5, 1.2, 0.8, 0.9, 1.1, 1.3, // class 0
632                6.0, 6.0, 6.2, 5.8, 5.9, 6.1, 6.3, 5.7, // class 1
633            ],
634        )
635        .unwrap();
636        let y = array![0, 0, 0, 0, 1, 1, 1, 1];
637        (x, y)
638    }
639
640    fn three_class_data() -> (Array2<f64>, Array1<usize>) {
641        let x = Array2::from_shape_vec(
642            (9, 2),
643            vec![
644                0.0, 0.0, 0.5, 0.1, 0.1, 0.5, // class 0
645                5.0, 0.0, 5.2, 0.3, 4.8, 0.1, // class 1
646                0.0, 5.0, 0.1, 5.2, 0.3, 4.8, // class 2
647            ],
648        )
649        .unwrap();
650        let y = array![0, 0, 0, 1, 1, 1, 2, 2, 2];
651        (x, y)
652    }
653
654    // ------------------------------------------------------------------
655
656    #[test]
657    fn test_lda_fit_returns_fitted() {
658        let (x, y) = linearly_separable_2d();
659        let lda = LDA::<f64>::new(Some(1));
660        let fitted = lda.fit(&x, &y).unwrap();
661        assert_eq!(fitted.scalings().ncols(), 1);
662        assert_eq!(fitted.scalings().nrows(), 2);
663    }
664
665    #[test]
666    fn test_lda_default_n_components() {
667        // With 2 classes the default n_components = min(1, n_features) = 1.
668        let (x, y) = linearly_separable_2d();
669        let lda = LDA::<f64>::default();
670        let fitted = lda.fit(&x, &y).unwrap();
671        assert_eq!(fitted.scalings().ncols(), 1);
672    }
673
674    #[test]
675    fn test_lda_transform_shape() {
676        let (x, y) = linearly_separable_2d();
677        let lda = LDA::<f64>::new(Some(1));
678        let fitted = lda.fit(&x, &y).unwrap();
679        let proj = fitted.transform(&x).unwrap();
680        assert_eq!(proj.dim(), (8, 1));
681    }
682
683    #[test]
684    fn test_lda_predict_accuracy_binary() {
685        let (x, y) = linearly_separable_2d();
686        let lda = LDA::<f64>::new(Some(1));
687        let fitted = lda.fit(&x, &y).unwrap();
688        let preds = fitted.predict(&x).unwrap();
689        let correct = preds.iter().zip(y.iter()).filter(|(p, a)| *p == *a).count();
690        assert_eq!(correct, 8, "All 8 samples should be classified correctly");
691    }
692
693    #[test]
694    fn test_lda_predict_three_classes() {
695        let (x, y) = three_class_data();
696        let lda = LDA::<f64>::new(Some(2));
697        let fitted = lda.fit(&x, &y).unwrap();
698        let preds = fitted.predict(&x).unwrap();
699        let correct = preds.iter().zip(y.iter()).filter(|(p, a)| *p == *a).count();
700        assert!(correct >= 7, "Expected at least 7/9 correct, got {correct}");
701    }
702
703    #[test]
704    fn test_lda_explained_variance_ratio_positive() {
705        let (x, y) = linearly_separable_2d();
706        let lda = LDA::<f64>::new(Some(1));
707        let fitted = lda.fit(&x, &y).unwrap();
708        for &v in fitted.explained_variance_ratio().iter() {
709            assert!(v >= 0.0);
710        }
711    }
712
713    #[test]
714    fn test_lda_explained_variance_ratio_le_1() {
715        let (x, y) = three_class_data();
716        let lda = LDA::<f64>::new(Some(2));
717        let fitted = lda.fit(&x, &y).unwrap();
718        let total: f64 = fitted.explained_variance_ratio().iter().sum();
719        assert!(total <= 1.0 + 1e-9, "total={total}");
720    }
721
722    #[test]
723    fn test_lda_classes_accessor() {
724        let (x, y) = linearly_separable_2d();
725        let lda = LDA::<f64>::new(Some(1));
726        let fitted = lda.fit(&x, &y).unwrap();
727        assert_eq!(fitted.classes(), &[0usize, 1]);
728    }
729
730    #[test]
731    fn test_lda_means_shape() {
732        let (x, y) = three_class_data();
733        let lda = LDA::<f64>::new(Some(2));
734        let fitted = lda.fit(&x, &y).unwrap();
735        assert_eq!(fitted.means().dim(), (3, 2));
736    }
737
738    #[test]
739    fn test_lda_transform_shape_mismatch() {
740        let (x, y) = linearly_separable_2d();
741        let lda = LDA::<f64>::new(Some(1));
742        let fitted = lda.fit(&x, &y).unwrap();
743        let x_bad = Array2::<f64>::zeros((3, 5));
744        assert!(fitted.transform(&x_bad).is_err());
745    }
746
747    #[test]
748    fn test_lda_predict_shape_mismatch() {
749        let (x, y) = linearly_separable_2d();
750        let lda = LDA::<f64>::new(Some(1));
751        let fitted = lda.fit(&x, &y).unwrap();
752        let x_bad = Array2::<f64>::zeros((3, 5));
753        assert!(fitted.predict(&x_bad).is_err());
754    }
755
756    #[test]
757    fn test_lda_error_zero_n_components() {
758        let (x, y) = linearly_separable_2d();
759        let lda = LDA::<f64>::new(Some(0));
760        assert!(lda.fit(&x, &y).is_err());
761    }
762
763    #[test]
764    fn test_lda_error_n_components_too_large() {
765        let (x, y) = linearly_separable_2d(); // 2 classes → max 1 component
766        let lda = LDA::<f64>::new(Some(5));
767        assert!(lda.fit(&x, &y).is_err());
768    }
769
770    #[test]
771    fn test_lda_error_single_class() {
772        let x =
773            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
774        let y = array![0usize, 0, 0, 0];
775        let lda = LDA::<f64>::new(None);
776        assert!(lda.fit(&x, &y).is_err());
777    }
778
779    #[test]
780    fn test_lda_error_shape_mismatch_fit() {
781        let x = Array2::<f64>::zeros((4, 2));
782        let y = array![0usize, 1]; // wrong length
783        let lda = LDA::<f64>::new(None);
784        assert!(lda.fit(&x, &y).is_err());
785    }
786
787    #[test]
788    fn test_lda_error_insufficient_samples() {
789        let x = Array2::<f64>::zeros((1, 2));
790        let y = array![0usize];
791        let lda = LDA::<f64>::new(None);
792        assert!(lda.fit(&x, &y).is_err());
793    }
794
795    #[test]
796    fn test_lda_scalings_accessor() {
797        let (x, y) = linearly_separable_2d();
798        let lda = LDA::<f64>::new(Some(1));
799        let fitted = lda.fit(&x, &y).unwrap();
800        assert_eq!(fitted.scalings().dim(), (2, 1));
801    }
802
803    #[test]
804    fn test_lda_pipeline_estimator() {
805        use ferrolearn_core::pipeline::PipelineEstimator;
806
807        let (x, y_usize) = linearly_separable_2d();
808        let y_f64 = y_usize.mapv(|v| v as f64);
809        let lda = LDA::<f64>::new(Some(1));
810        let fitted = lda.fit_pipeline(&x, &y_f64).unwrap();
811        let preds = fitted.predict_pipeline(&x).unwrap();
812        assert_eq!(preds.len(), 8);
813    }
814
815    #[test]
816    fn test_lda_n_components_getter() {
817        let lda = LDA::<f64>::new(Some(2));
818        assert_eq!(lda.n_components(), Some(2));
819        let lda_none = LDA::<f64>::new(None);
820        assert_eq!(lda_none.n_components(), None);
821    }
822
823    #[test]
824    fn test_lda_transform_then_predict_consistent() {
825        let (x, y) = linearly_separable_2d();
826        let lda = LDA::<f64>::new(Some(1));
827        let fitted = lda.fit(&x, &y).unwrap();
828        // Manually compute nearest-centroid prediction from transform output.
829        let projected = fitted.transform(&x).unwrap();
830        let preds_predict = fitted.predict(&x).unwrap();
831        let n_samples = projected.nrows();
832        let n_comp = projected.ncols();
833        let n_classes = fitted.classes().len();
834        for i in 0..n_samples {
835            let mut best = 0;
836            let mut best_d = f64::INFINITY;
837            for ci in 0..n_classes {
838                let mut d = 0.0;
839                for k in 0..n_comp {
840                    let diff = projected[[i, k]] - fitted.means()[[ci, k]];
841                    d += diff * diff;
842                }
843                if d < best_d {
844                    best_d = d;
845                    best = ci;
846                }
847            }
848            assert_eq!(preds_predict[i], fitted.classes()[best]);
849        }
850    }
851
852    #[test]
853    fn test_lda_projected_class_separation() {
854        let (x, y) = linearly_separable_2d();
855        let lda = LDA::<f64>::new(Some(1));
856        let fitted = lda.fit(&x, &y).unwrap();
857        let projected = fitted.transform(&x).unwrap();
858
859        // Means of class 0 and class 1 in projected space should be far apart.
860        let mean0: f64 = projected
861            .rows()
862            .into_iter()
863            .zip(y.iter())
864            .filter(|&(_, label)| *label == 0)
865            .map(|(row, _)| row[0])
866            .sum::<f64>()
867            / 4.0;
868        let mean1: f64 = projected
869            .rows()
870            .into_iter()
871            .zip(y.iter())
872            .filter(|&(_, label)| *label == 1)
873            .map(|(row, _)| row[0])
874            .sum::<f64>()
875            / 4.0;
876
877        assert!(
878            (mean0 - mean1).abs() > 0.5,
879            "Projected means should differ, got {mean0} vs {mean1}"
880        );
881    }
882
883    #[test]
884    fn test_lda_transform_known_data() {
885        // With perfectly separated data the transform should yield two clearly
886        // distinct groups.
887        let x = Array2::from_shape_vec((4, 1), vec![-2.0, -1.0, 1.0, 2.0]).unwrap();
888        let y = array![0usize, 0, 1, 1];
889        let lda = LDA::<f64>::new(Some(1));
890        let fitted = lda.fit(&x, &y).unwrap();
891        let proj = fitted.transform(&x).unwrap();
892        // The first two samples should project to one side, the other two to the other side.
893        let sign0 = proj[[0, 0]].signum();
894        let sign1 = proj[[2, 0]].signum();
895        // They should be on opposite sides of the origin (or at least the split is correct).
896        assert_ne!(
897            sign0 as i32, sign1 as i32,
898            "Classes should be on opposite sides"
899        );
900    }
901
902    #[test]
903    fn test_lda_abs_diff_eq_means_dimensions() {
904        let (x, y) = linearly_separable_2d();
905        let lda = LDA::<f64>::new(Some(1));
906        let fitted = lda.fit(&x, &y).unwrap();
907        // Each class mean in projected space should be a 1-component vector.
908        assert_eq!(fitted.means().ncols(), 1);
909        let m0 = fitted.means()[[0, 0]];
910        let m1 = fitted.means()[[1, 0]];
911        // For well-separated data the projected means should differ by > 1.0.
912        assert!((m0 - m1).abs() > 0.5, "m0={m0}, m1={m1}");
913        let _ = assert_abs_diff_eq!(0.0_f64, 0.0_f64); // use the import
914    }
915}