ferrolearn_linear/lda.rs

//! Linear Discriminant Analysis (LDA).
//!
//! LDA is both a supervised dimensionality reduction technique and a
//! linear classifier. It finds the directions that maximise the separation
//! between classes while minimising within-class scatter.
//!
//! # Algorithm
//!
//! 1. Compute class means `μ_c` and the overall mean `μ`.
//! 2. Compute the within-class scatter matrix
//!    `Sw = Σ_c Σ_{x ∈ c} (x - μ_c)(x - μ_c)^T`.
//! 3. Compute the between-class scatter matrix
//!    `Sb = Σ_c n_c (μ_c - μ)(μ_c - μ)^T`.
//! 4. Solve the generalised eigenvalue problem `Sw⁻¹ Sb v = λ v`.
//! 5. Project data onto the top-`k` eigenvectors.
//!
//! The number of discriminant directions is at most `min(n_classes - 1, n_features)`.
//!
//! # Examples
//!
//! ```
//! use ferrolearn_linear::lda::LDA;
//! use ferrolearn_core::{Fit, Predict};
//! use ndarray::{array, Array2};
//!
//! let lda = LDA::new(Some(1));
//! let x = Array2::from_shape_vec(
//!     (6, 2),
//!     vec![1.0, 1.0, 1.5, 1.2, 1.2, 0.8, 5.0, 5.0, 5.5, 4.8, 4.8, 5.2],
//! ).unwrap();
//! let y = array![0usize, 0, 0, 1, 1, 1];
//! let fitted = lda.fit(&x, &y).unwrap();
//! let preds = fitted.predict(&x).unwrap();
//! assert_eq!(preds.len(), 6);
//! ```
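//!
//! The fitted model also implements `Transform` for dimensionality
//! reduction. A minimal sketch, assuming `Transform` is re-exported from
//! `ferrolearn_core` like `Fit` and `Predict` above:
//!
//! ```
//! use ferrolearn_linear::lda::LDA;
//! use ferrolearn_core::{Fit, Transform};
//! use ndarray::{array, Array2};
//!
//! let x = Array2::from_shape_vec(
//!     (6, 2),
//!     vec![1.0, 1.0, 1.5, 1.2, 1.2, 0.8, 5.0, 5.0, 5.5, 4.8, 4.8, 5.2],
//! ).unwrap();
//! let y = array![0usize, 0, 0, 1, 1, 1];
//! let fitted = LDA::new(Some(1)).fit(&x, &y).unwrap();
//! // Project the two-feature data onto the single discriminant axis.
//! let projected = fitted.transform(&x).unwrap();
//! assert_eq!(projected.dim(), (6, 1));
//! ```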

use ferrolearn_core::error::FerroError;
use ferrolearn_core::pipeline::{FittedPipelineEstimator, PipelineEstimator};
use ferrolearn_core::traits::{Fit, Predict, Transform};
use ndarray::{Array1, Array2};
use num_traits::{Float, NumCast};

// ---------------------------------------------------------------------------
// LDA (unfitted)
// ---------------------------------------------------------------------------

/// Linear Discriminant Analysis configuration.
///
/// Holds hyperparameters. Calling [`Fit::fit`] computes the discriminant
/// directions and returns a [`FittedLDA`].
///
/// # Type Parameters
///
/// - `F`: The floating-point scalar type (`f32` or `f64`).
#[derive(Debug, Clone)]
pub struct LDA<F> {
    /// Number of discriminant components to retain.
    ///
    /// If `None`, defaults to `min(n_classes - 1, n_features)` at fit time.
    n_components: Option<usize>,
    _marker: std::marker::PhantomData<F>,
}

impl<F: Float + Send + Sync + 'static> LDA<F> {
    /// Create a new `LDA`.
    ///
    /// - `n_components`: number of discriminant directions to retain.
    ///   Pass `None` to use `min(n_classes - 1, n_features)`.
    #[must_use]
    pub fn new(n_components: Option<usize>) -> Self {
        Self {
            n_components,
            _marker: std::marker::PhantomData,
        }
    }

    /// Return the configured number of components (may be `None`).
    #[must_use]
    pub fn n_components(&self) -> Option<usize> {
        self.n_components
    }
}

impl<F: Float + Send + Sync + 'static> Default for LDA<F> {
    fn default() -> Self {
        Self::new(None)
    }
}

// ---------------------------------------------------------------------------
// FittedLDA
// ---------------------------------------------------------------------------

/// A fitted LDA model.
///
/// Created by calling [`Fit::fit`] on an [`LDA`]. Implements:
/// - [`Transform<Array2<F>>`] — project data onto discriminant axes.
/// - [`Predict<Array2<F>>`] — classify by nearest centroid in projected space.
#[derive(Debug, Clone)]
pub struct FittedLDA<F> {
    /// Projection matrix, shape `(n_features, n_components)`.
    ///
    /// New data is projected via `X @ scalings`.
    scalings: Array2<F>,

    /// Class means in the projected space, shape `(n_classes, n_components)`.
    means: Array2<F>,

    /// Ratio of explained variance per discriminant direction.
    explained_variance_ratio: Array1<F>,

    /// Class labels corresponding to rows of `means`.
    classes: Vec<usize>,

    /// Number of features seen during fitting.
    n_features: usize,
}

impl<F: Float + Send + Sync + 'static> FittedLDA<F> {
    /// Projection (scalings) matrix, shape `(n_features, n_components)`.
    #[must_use]
    pub fn scalings(&self) -> &Array2<F> {
        &self.scalings
    }

    /// Class centroids in the projected space, shape `(n_classes, n_components)`.
    #[must_use]
    pub fn means(&self) -> &Array2<F> {
        &self.means
    }

    /// Explained-variance ratio per discriminant direction.
    #[must_use]
    pub fn explained_variance_ratio(&self) -> &Array1<F> {
        &self.explained_variance_ratio
    }

    /// Sorted class labels as seen during fitting.
    #[must_use]
    pub fn classes(&self) -> &[usize] {
        &self.classes
    }

    /// Predict per-class probabilities, analogous to sklearn's
    /// `LinearDiscriminantAnalysis.predict_proba` with uniform priors.
    ///
    /// Computes a softmax over `-½ ‖z - μ_c‖²` in the projected space.
    /// `FittedLDA` does not store per-class priors, so the full sklearn
    /// formula reduces to this equal-priors form. Returns shape
    /// `(n_samples, n_classes)`; each row sums to 1.
    ///
    /// # Errors
    ///
    /// Returns [`FerroError::ShapeMismatch`] if the number of features
    /// does not match the fitted model.
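    ///
    /// # Examples
    ///
    /// A minimal sketch (assuming the `Fit` re-export used in the
    /// module-level example):
    ///
    /// ```
    /// use ferrolearn_linear::lda::LDA;
    /// use ferrolearn_core::Fit;
    /// use ndarray::{array, Array2};
    ///
    /// let x = Array2::from_shape_vec((4, 1), vec![0.0, 0.2, 5.0, 5.2]).unwrap();
    /// let y = array![0usize, 0, 1, 1];
    /// let fitted = LDA::new(Some(1)).fit(&x, &y).unwrap();
    /// let proba = fitted.predict_proba(&x).unwrap();
    /// // Each row is a probability distribution over the two classes.
    /// for row in proba.rows() {
    ///     assert!((row.sum() - 1.0_f64).abs() < 1e-9);
    /// }
    /// ```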
    pub fn predict_proba(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
        let projected = self.transform(x)?;
        let n_samples = projected.nrows();
        let n_comp = projected.ncols();
        let n_classes = self.classes.len();
        let neg_half = F::from(-0.5).unwrap();
        let mut proba = Array2::<F>::zeros((n_samples, n_classes));
        for i in 0..n_samples {
            let mut logits = vec![F::zero(); n_classes];
            for ci in 0..n_classes {
                let mut dist_sq = F::zero();
                for k in 0..n_comp {
                    let d = projected[[i, k]] - self.means[[ci, k]];
                    dist_sq = dist_sq + d * d;
                }
                logits[ci] = neg_half * dist_sq;
            }
            let max_l = logits
                .iter()
                .copied()
                .fold(F::neg_infinity(), |a, b| if b > a { b } else { a });
            let mut sum_exp = F::zero();
            for ci in 0..n_classes {
                let e = (logits[ci] - max_l).exp();
                proba[[i, ci]] = e;
                sum_exp = sum_exp + e;
            }
            for ci in 0..n_classes {
                proba[[i, ci]] = proba[[i, ci]] / sum_exp;
            }
        }
        Ok(proba)
    }

    /// Element-wise log of [`predict_proba`](Self::predict_proba).
    ///
    /// # Errors
    ///
    /// Forwards any error from [`predict_proba`](Self::predict_proba).
    pub fn predict_log_proba(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
        let proba = self.predict_proba(x)?;
        Ok(crate::log_proba(&proba))
    }

    /// Per-class discriminant scores, analogous to sklearn's
    /// `LinearDiscriminantAnalysis.decision_function`.
    ///
    /// Returns shape `(n_samples, n_classes)` holding `-½ ‖z - μ_c‖²` in
    /// the projected space. The argmax of each row agrees with [`Predict`].
    ///
    /// # Errors
    ///
    /// Returns [`FerroError::ShapeMismatch`] if the number of features
    /// does not match the fitted model.
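    ///
    /// # Examples
    ///
    /// A minimal sketch (same re-export assumptions as the module-level
    /// example):
    ///
    /// ```
    /// use ferrolearn_linear::lda::LDA;
    /// use ferrolearn_core::{Fit, Predict};
    /// use ndarray::{array, Array2};
    ///
    /// let x = Array2::from_shape_vec((4, 1), vec![0.0, 0.2, 5.0, 5.2]).unwrap();
    /// let y = array![0usize, 0, 1, 1];
    /// let fitted = LDA::new(Some(1)).fit(&x, &y).unwrap();
    /// // One score per (sample, class); the row argmax matches `predict`.
    /// let scores = fitted.decision_function(&x).unwrap();
    /// let preds = fitted.predict(&x).unwrap();
    /// assert_eq!(scores.dim(), (4, 2));
    /// assert_eq!(preds.len(), 4);
    /// ```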
    pub fn decision_function(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
        let projected = self.transform(x)?;
        let n_samples = projected.nrows();
        let n_comp = projected.ncols();
        let n_classes = self.classes.len();
        let neg_half = F::from(-0.5).unwrap();
        let mut out = Array2::<F>::zeros((n_samples, n_classes));
        for i in 0..n_samples {
            for ci in 0..n_classes {
                let mut dist_sq = F::zero();
                for k in 0..n_comp {
                    let d = projected[[i, k]] - self.means[[ci, k]];
                    dist_sq = dist_sq + d * d;
                }
                out[[i, ci]] = neg_half * dist_sq;
            }
        }
        Ok(out)
    }
}

// ---------------------------------------------------------------------------
// Internal linear algebra helpers (generic over F)
// ---------------------------------------------------------------------------

/// Jacobi eigendecomposition. The input matrix `a` is assumed symmetric.
///
/// Returns `(eigenvalues, eigenvectors_columns)` — column `i` is the
/// eigenvector for `eigenvalues[i]`. Eigenvalues are **not** sorted.
fn jacobi_eigen_f<F: Float + Send + Sync + 'static>(
    a: &Array2<F>,
    max_iter: usize,
) -> Result<(Array1<F>, Array2<F>), FerroError> {
    let n = a.nrows();
    let mut mat = a.to_owned();
    let mut v = Array2::<F>::zeros((n, n));
    for i in 0..n {
        v[[i, i]] = F::one();
    }
    let tol = F::from(1e-12).unwrap_or_else(F::epsilon);

    for _ in 0..max_iter {
        // Find the largest off-diagonal entry.
        let mut max_off = F::zero();
        let mut p = 0usize;
        let mut q = 1usize;
        for i in 0..n {
            for j in (i + 1)..n {
                let val = mat[[i, j]].abs();
                if val > max_off {
                    max_off = val;
                    p = i;
                    q = j;
                }
            }
        }
        if max_off < tol {
            let eigenvalues = Array1::from_shape_fn(n, |i| mat[[i, i]]);
            return Ok((eigenvalues, v));
        }
        let app = mat[[p, p]];
        let aqq = mat[[q, q]];
        let apq = mat[[p, q]];
        let two = F::from(2.0).unwrap();
        let theta = if (app - aqq).abs() < tol {
            F::from(std::f64::consts::FRAC_PI_4).unwrap_or_else(F::one)
        } else {
            let tau = (aqq - app) / (two * apq);
            let t = if tau >= F::zero() {
                F::one() / (tau.abs() + (F::one() + tau * tau).sqrt())
            } else {
                -F::one() / (tau.abs() + (F::one() + tau * tau).sqrt())
            };
            t.atan()
        };
        let c = theta.cos();
        let s = theta.sin();
        let mut new_mat = mat.clone();
        for i in 0..n {
            if i != p && i != q {
                let mip = mat[[i, p]];
                let miq = mat[[i, q]];
                new_mat[[i, p]] = c * mip - s * miq;
                new_mat[[p, i]] = new_mat[[i, p]];
                new_mat[[i, q]] = s * mip + c * miq;
                new_mat[[q, i]] = new_mat[[i, q]];
            }
        }
        new_mat[[p, p]] = c * c * app - two * s * c * apq + s * s * aqq;
        new_mat[[q, q]] = s * s * app + two * s * c * apq + c * c * aqq;
        new_mat[[p, q]] = F::zero();
        new_mat[[q, p]] = F::zero();
        mat = new_mat;
        for i in 0..n {
            let vip = v[[i, p]];
            let viq = v[[i, q]];
            v[[i, p]] = c * vip - s * viq;
            v[[i, q]] = s * vip + c * viq;
        }
    }
    Err(FerroError::ConvergenceFailure {
        iterations: max_iter,
        message: "Jacobi eigendecomposition did not converge (LDA)".into(),
    })
}

/// Gaussian elimination with partial pivoting to solve `A x = b`.
fn gaussian_solve_f<F: Float>(
    n: usize,
    a: &Array2<F>,
    b: &Array1<F>,
) -> Result<Array1<F>, FerroError> {
    let mut aug = Array2::<F>::zeros((n, n + 1));
    for i in 0..n {
        for j in 0..n {
            aug[[i, j]] = a[[i, j]];
        }
        aug[[i, n]] = b[i];
    }
    for col in 0..n {
        let mut max_val = aug[[col, col]].abs();
        let mut max_row = col;
        for row in (col + 1)..n {
            let val = aug[[row, col]].abs();
            if val > max_val {
                max_val = val;
                max_row = row;
            }
        }
        if max_val < F::from(1e-12).unwrap_or_else(F::epsilon) {
            return Err(FerroError::NumericalInstability {
                message: "singular matrix during LDA inversion".into(),
            });
        }
        if max_row != col {
            for j in 0..=n {
                let tmp = aug[[col, j]];
                aug[[col, j]] = aug[[max_row, j]];
                aug[[max_row, j]] = tmp;
            }
        }
        let pivot = aug[[col, col]];
        for row in (col + 1)..n {
            let factor = aug[[row, col]] / pivot;
            for j in col..=n {
                let above = aug[[col, j]];
                aug[[row, j]] = aug[[row, j]] - factor * above;
            }
        }
    }
    let mut x = Array1::<F>::zeros(n);
    for i in (0..n).rev() {
        let mut sum = aug[[i, n]];
        for j in (i + 1)..n {
            sum = sum - aug[[i, j]] * x[j];
        }
        if aug[[i, i]].abs() < F::from(1e-12).unwrap_or_else(F::epsilon) {
            return Err(FerroError::NumericalInstability {
                message: "near-zero pivot during LDA back substitution".into(),
            });
        }
        x[i] = sum / aug[[i, i]];
    }
    Ok(x)
}

/// Compute `Sw⁻¹ @ Sb` column by column.
///
/// Returns the matrix `M = Sw⁻¹ Sb` of shape `(n, n)`.
fn sw_inv_sb<F: Float + Send + Sync + 'static>(
    sw: &Array2<F>,
    sb: &Array2<F>,
) -> Result<Array2<F>, FerroError> {
    let n = sw.nrows();
    let mut result = Array2::<F>::zeros((n, n));
    for j in 0..n {
        let col_sb = Array1::from_shape_fn(n, |i| sb[[i, j]]);
        let col = gaussian_solve_f(n, sw, &col_sb)?;
        for i in 0..n {
            result[[i, j]] = col[i];
        }
    }
    Ok(result)
}

// ---------------------------------------------------------------------------
// Fit
// ---------------------------------------------------------------------------

impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array1<usize>> for LDA<F> {
    type Fitted = FittedLDA<F>;
    type Error = FerroError;

    /// Fit the LDA model.
    ///
    /// # Errors
    ///
    /// - [`FerroError::InsufficientSamples`] if fewer than 2 samples are provided.
    /// - [`FerroError::InvalidParameter`] if `n_components` is zero or exceeds the
    ///   maximum allowed (`min(n_classes - 1, n_features)`).
    /// - [`FerroError::ShapeMismatch`] if `x` and `y` have different numbers of rows.
    /// - [`FerroError::NumericalInstability`] if `Sw` is singular.
    /// - [`FerroError::ConvergenceFailure`] if eigendecomposition does not converge.
    fn fit(&self, x: &Array2<F>, y: &Array1<usize>) -> Result<FittedLDA<F>, FerroError> {
        let (n_samples, n_features) = x.dim();

        if n_samples != y.len() {
            return Err(FerroError::ShapeMismatch {
                expected: vec![n_samples],
                actual: vec![y.len()],
                context: "LDA: y length must match number of rows in X".into(),
            });
        }
        if n_samples < 2 {
            return Err(FerroError::InsufficientSamples {
                required: 2,
                actual: n_samples,
                context: "LDA requires at least 2 samples".into(),
            });
        }

        // Gather sorted unique classes.
        let mut classes: Vec<usize> = y.to_vec();
        classes.sort_unstable();
        classes.dedup();
        let n_classes = classes.len();

        if n_classes < 2 {
            return Err(FerroError::InsufficientSamples {
                required: 2,
                actual: n_classes,
                context: "LDA requires at least 2 distinct classes".into(),
            });
        }

        // Determine effective n_components.
        let max_components = (n_classes - 1).min(n_features);
        let n_comp = match self.n_components {
            None => max_components,
            Some(0) => {
                return Err(FerroError::InvalidParameter {
                    name: "n_components".into(),
                    reason: "must be at least 1".into(),
                });
            }
            Some(k) if k > max_components => {
                return Err(FerroError::InvalidParameter {
                    name: "n_components".into(),
                    reason: format!(
                        "n_components ({k}) exceeds max allowed ({max_components} = min(n_classes-1, n_features))"
                    ),
                });
            }
            Some(k) => k,
        };

        // --- Step 1: compute overall mean and per-class means ----------------
        let n_f = F::from(n_samples).unwrap();
        let mut overall_mean = Array1::<F>::zeros(n_features);
        for j in 0..n_features {
            let col = x.column(j);
            let s = col.iter().copied().fold(F::zero(), |a, b| a + b);
            overall_mean[j] = s / n_f;
        }

        // class_means[c] = mean of samples in class c
        let mut class_means: Vec<Array1<F>> = Vec::with_capacity(n_classes);
        let mut class_counts: Vec<usize> = Vec::with_capacity(n_classes);
        for &cls in &classes {
            let mut mean = Array1::<F>::zeros(n_features);
            let mut cnt = 0usize;
            for (i, &label) in y.iter().enumerate() {
                if label == cls {
                    for j in 0..n_features {
                        mean[j] = mean[j] + x[[i, j]];
                    }
                    cnt += 1;
                }
            }
            if cnt == 0 {
                return Err(FerroError::InsufficientSamples {
                    required: 1,
                    actual: 0,
                    context: format!("LDA: class {cls} has no samples"),
                });
            }
            let cnt_f = F::from(cnt).unwrap();
            mean.mapv_inplace(|v| v / cnt_f);
            class_means.push(mean);
            class_counts.push(cnt);
        }

        // --- Step 2: within-class scatter Sw ----------------------------------
        let mut sw = Array2::<F>::zeros((n_features, n_features));
        for (ci, &cls) in classes.iter().enumerate() {
            let mu_c = &class_means[ci];
            for (i, &label) in y.iter().enumerate() {
                if label == cls {
                    // diff = x[i] - mu_c
                    let diff: Vec<F> = (0..n_features).map(|j| x[[i, j]] - mu_c[j]).collect();
                    for r in 0..n_features {
                        for c in 0..n_features {
                            sw[[r, c]] = sw[[r, c]] + diff[r] * diff[c];
                        }
                    }
                }
            }
        }

        // Add a small regularisation to Sw to avoid singularity.
        let reg = F::from(1e-6).unwrap();
        for i in 0..n_features {
            sw[[i, i]] = sw[[i, i]] + reg;
        }

        // --- Step 3: between-class scatter Sb ---------------------------------
        let mut sb = Array2::<F>::zeros((n_features, n_features));
        for (ci, &nc) in class_counts.iter().enumerate() {
            let nc_f = F::from(nc).unwrap();
            let diff: Vec<F> = (0..n_features)
                .map(|j| class_means[ci][j] - overall_mean[j])
                .collect();
            for r in 0..n_features {
                for c in 0..n_features {
                    sb[[r, c]] = sb[[r, c]] + nc_f * diff[r] * diff[c];
                }
            }
        }

        // --- Step 4: solve generalised eigenvalue problem Sw⁻¹ Sb v = λ v ----
        // Note: Sw⁻¹ Sb is not symmetric in general, while `jacobi_eigen_f`
        // assumes a symmetric input, so the decomposition below is an
        // approximation (exact when the product happens to be symmetric).
        let m = sw_inv_sb(&sw, &sb)?;
        let max_jacobi = n_features * n_features * 100 + 1000;
        let (eigenvalues, eigenvectors) = jacobi_eigen_f(&m, max_jacobi)?;

        // Sort eigenvalues descending.
        let mut indices: Vec<usize> = (0..n_features).collect();
        indices.sort_by(|&a, &b| {
            eigenvalues[b]
                .partial_cmp(&eigenvalues[a])
                .unwrap_or(std::cmp::Ordering::Equal)
        });

        // Total of the eigenvalues, with negatives clamped to zero.
        let total_ev: F = eigenvalues
            .iter()
            .copied()
            .map(|v| if v > F::zero() { v } else { F::zero() })
            .fold(F::zero(), |a, b| a + b);

        // --- Step 5: build scalings matrix (n_features × n_comp) -------------
        let mut scalings = Array2::<F>::zeros((n_features, n_comp));
        let mut explained_variance_ratio = Array1::<F>::zeros(n_comp);
        for (k, &idx) in indices.iter().take(n_comp).enumerate() {
            let ev = eigenvalues[idx];
            let ev_clamped = if ev > F::zero() { ev } else { F::zero() };
            explained_variance_ratio[k] = if total_ev > F::zero() {
                ev_clamped / total_ev
            } else {
                F::zero()
            };
            for j in 0..n_features {
                scalings[[j, k]] = eigenvectors[[j, idx]];
            }
        }

        // --- Project class means into the discriminant space -----------------
        // means[c, k] = class_means[c] · scalings[:, k]
        let mut means = Array2::<F>::zeros((n_classes, n_comp));
        for ci in 0..n_classes {
            let mu_row = class_means[ci].view();
            for k in 0..n_comp {
                let mut dot = F::zero();
                for j in 0..n_features {
                    dot = dot + mu_row[j] * scalings[[j, k]];
                }
                means[[ci, k]] = dot;
            }
        }

        Ok(FittedLDA {
            scalings,
            means,
            explained_variance_ratio,
            classes,
            n_features,
        })
    }
}

// ---------------------------------------------------------------------------
// Transform
// ---------------------------------------------------------------------------

impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedLDA<F> {
    type Output = Array2<F>;
    type Error = FerroError;

    /// Project `x` onto the discriminant axes: `X @ scalings`.
    ///
    /// # Errors
    ///
    /// Returns [`FerroError::ShapeMismatch`] if `x.ncols()` does not match the
    /// number of features seen during fitting.
    fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
        if x.ncols() != self.n_features {
            return Err(FerroError::ShapeMismatch {
                expected: vec![x.nrows(), self.n_features],
                actual: vec![x.nrows(), x.ncols()],
                context: "FittedLDA::transform".into(),
            });
        }
        Ok(x.dot(&self.scalings))
    }
}

// ---------------------------------------------------------------------------
// Predict (nearest centroid in projected space)
// ---------------------------------------------------------------------------

impl<F: Float + Send + Sync + 'static> Predict<Array2<F>> for FittedLDA<F> {
    type Output = Array1<usize>;
    type Error = FerroError;

    /// Classify samples by nearest centroid in the projected space.
    ///
    /// # Errors
    ///
    /// Returns [`FerroError::ShapeMismatch`] if the number of features does not
    /// match the model.
    fn predict(&self, x: &Array2<F>) -> Result<Array1<usize>, FerroError> {
        let projected = self.transform(x)?;
        let n_samples = projected.nrows();
        let n_comp = projected.ncols();
        let n_classes = self.classes.len();

        let mut predictions = Array1::<usize>::zeros(n_samples);
        for i in 0..n_samples {
            let mut best_class = 0usize;
            let mut best_dist = F::infinity();
            for ci in 0..n_classes {
                let mut dist = F::zero();
                for k in 0..n_comp {
                    let d = projected[[i, k]] - self.means[[ci, k]];
                    dist = dist + d * d;
                }
                if dist < best_dist {
                    best_dist = dist;
                    best_class = ci;
                }
            }
            predictions[i] = self.classes[best_class];
        }
        Ok(predictions)
    }
}

// ---------------------------------------------------------------------------
// Pipeline integration (generic)
// ---------------------------------------------------------------------------

impl<F: Float + Send + Sync + 'static> PipelineEstimator<F> for LDA<F> {
    /// Fit LDA using the pipeline interface.
    ///
    /// # Errors
    ///
    /// Propagates errors from [`Fit::fit`].
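    ///
    /// # Examples
    ///
    /// A minimal sketch of the pipeline path, mirroring the
    /// `test_lda_pipeline_estimator` test below:
    ///
    /// ```
    /// use ferrolearn_linear::lda::LDA;
    /// use ferrolearn_core::pipeline::{FittedPipelineEstimator, PipelineEstimator};
    /// use ndarray::{array, Array2};
    ///
    /// let x = Array2::from_shape_vec((4, 1), vec![0.0, 0.2, 5.0, 5.2]).unwrap();
    /// // Pipeline targets are floats; `fit_pipeline` casts them to `usize`.
    /// let y = array![0.0_f64, 0.0, 1.0, 1.0];
    /// let fitted = LDA::<f64>::new(Some(1)).fit_pipeline(&x, &y).unwrap();
    /// let preds = fitted.predict_pipeline(&x).unwrap();
    /// assert_eq!(preds.len(), 4);
    /// ```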
    fn fit_pipeline(
        &self,
        x: &Array2<F>,
        y: &Array1<F>,
    ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
        let y_usize: Array1<usize> = y.mapv(|v| v.to_usize().unwrap_or(0));
        let fitted = self.fit(x, &y_usize)?;
        Ok(Box::new(FittedLDAPipeline(fitted)))
    }
}

/// Wrapper for pipeline integration that converts predictions to float.
struct FittedLDAPipeline<F>(FittedLDA<F>);

impl<F: Float + Send + Sync + 'static> FittedPipelineEstimator<F> for FittedLDAPipeline<F> {
    /// Predict via the pipeline interface, returning float class labels.
    fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
        let preds = self.0.predict(x)?;
        Ok(preds.mapv(|v| NumCast::from(v).unwrap_or_else(F::nan)))
    }
}

// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{Array2, array};

    // ------------------------------------------------------------------
    // Helpers
    // ------------------------------------------------------------------

    fn linearly_separable_2d() -> (Array2<f64>, Array1<usize>) {
        // Two well-separated Gaussian clusters.
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 1.0, 1.5, 1.2, 0.8, 0.9, 1.1, 1.3, // class 0
                6.0, 6.0, 6.2, 5.8, 5.9, 6.1, 6.3, 5.7, // class 1
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];
        (x, y)
    }

    fn three_class_data() -> (Array2<f64>, Array1<usize>) {
        let x = Array2::from_shape_vec(
            (9, 2),
            vec![
                0.0, 0.0, 0.5, 0.1, 0.1, 0.5, // class 0
                5.0, 0.0, 5.2, 0.3, 4.8, 0.1, // class 1
                0.0, 5.0, 0.1, 5.2, 0.3, 4.8, // class 2
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 1, 1, 1, 2, 2, 2];
        (x, y)
    }

    // ------------------------------------------------------------------

    #[test]
    fn test_lda_fit_returns_fitted() {
        let (x, y) = linearly_separable_2d();
        let lda = LDA::<f64>::new(Some(1));
        let fitted = lda.fit(&x, &y).unwrap();
        assert_eq!(fitted.scalings().ncols(), 1);
        assert_eq!(fitted.scalings().nrows(), 2);
    }

    #[test]
    fn test_lda_default_n_components() {
        // With 2 classes the default n_components = min(1, n_features) = 1.
        let (x, y) = linearly_separable_2d();
        let lda = LDA::<f64>::default();
        let fitted = lda.fit(&x, &y).unwrap();
        assert_eq!(fitted.scalings().ncols(), 1);
    }

    #[test]
    fn test_lda_transform_shape() {
        let (x, y) = linearly_separable_2d();
        let lda = LDA::<f64>::new(Some(1));
        let fitted = lda.fit(&x, &y).unwrap();
        let proj = fitted.transform(&x).unwrap();
        assert_eq!(proj.dim(), (8, 1));
    }

    #[test]
    fn test_lda_predict_accuracy_binary() {
        let (x, y) = linearly_separable_2d();
        let lda = LDA::<f64>::new(Some(1));
        let fitted = lda.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        let correct = preds.iter().zip(y.iter()).filter(|(p, a)| *p == *a).count();
        assert_eq!(correct, 8, "All 8 samples should be classified correctly");
    }

    #[test]
    fn test_lda_predict_three_classes() {
        let (x, y) = three_class_data();
        let lda = LDA::<f64>::new(Some(2));
        let fitted = lda.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        let correct = preds.iter().zip(y.iter()).filter(|(p, a)| *p == *a).count();
        assert!(correct >= 7, "Expected at least 7/9 correct, got {correct}");
    }

    #[test]
    fn test_lda_explained_variance_ratio_positive() {
        let (x, y) = linearly_separable_2d();
        let lda = LDA::<f64>::new(Some(1));
        let fitted = lda.fit(&x, &y).unwrap();
        for &v in fitted.explained_variance_ratio() {
            assert!(v >= 0.0);
        }
    }

    #[test]
    fn test_lda_explained_variance_ratio_le_1() {
        let (x, y) = three_class_data();
        let lda = LDA::<f64>::new(Some(2));
        let fitted = lda.fit(&x, &y).unwrap();
        let total: f64 = fitted.explained_variance_ratio().iter().sum();
        assert!(total <= 1.0 + 1e-9, "total={total}");
    }

    #[test]
    fn test_lda_classes_accessor() {
        let (x, y) = linearly_separable_2d();
        let lda = LDA::<f64>::new(Some(1));
        let fitted = lda.fit(&x, &y).unwrap();
        assert_eq!(fitted.classes(), &[0usize, 1]);
    }

    #[test]
    fn test_lda_means_shape() {
        let (x, y) = three_class_data();
        let lda = LDA::<f64>::new(Some(2));
        let fitted = lda.fit(&x, &y).unwrap();
        assert_eq!(fitted.means().dim(), (3, 2));
    }

    #[test]
    fn test_lda_transform_shape_mismatch() {
        let (x, y) = linearly_separable_2d();
        let lda = LDA::<f64>::new(Some(1));
        let fitted = lda.fit(&x, &y).unwrap();
        let x_bad = Array2::<f64>::zeros((3, 5));
        assert!(fitted.transform(&x_bad).is_err());
    }

    #[test]
    fn test_lda_predict_shape_mismatch() {
        let (x, y) = linearly_separable_2d();
        let lda = LDA::<f64>::new(Some(1));
        let fitted = lda.fit(&x, &y).unwrap();
        let x_bad = Array2::<f64>::zeros((3, 5));
        assert!(fitted.predict(&x_bad).is_err());
    }

    #[test]
    fn test_lda_error_zero_n_components() {
        let (x, y) = linearly_separable_2d();
        let lda = LDA::<f64>::new(Some(0));
        assert!(lda.fit(&x, &y).is_err());
    }

    #[test]
    fn test_lda_error_n_components_too_large() {
        let (x, y) = linearly_separable_2d(); // 2 classes → max 1 component
        let lda = LDA::<f64>::new(Some(5));
        assert!(lda.fit(&x, &y).is_err());
    }

    #[test]
    fn test_lda_error_single_class() {
        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![0usize, 0, 0, 0];
        let lda = LDA::<f64>::new(None);
        assert!(lda.fit(&x, &y).is_err());
    }

    #[test]
    fn test_lda_error_shape_mismatch_fit() {
        let x = Array2::<f64>::zeros((4, 2));
        let y = array![0usize, 1]; // wrong length
        let lda = LDA::<f64>::new(None);
        assert!(lda.fit(&x, &y).is_err());
    }

    #[test]
    fn test_lda_error_insufficient_samples() {
        let x = Array2::<f64>::zeros((1, 2));
        let y = array![0usize];
        let lda = LDA::<f64>::new(None);
        assert!(lda.fit(&x, &y).is_err());
    }

    #[test]
    fn test_lda_scalings_accessor() {
        let (x, y) = linearly_separable_2d();
        let lda = LDA::<f64>::new(Some(1));
        let fitted = lda.fit(&x, &y).unwrap();
        assert_eq!(fitted.scalings().dim(), (2, 1));
    }

    #[test]
    fn test_lda_pipeline_estimator() {
        use ferrolearn_core::pipeline::PipelineEstimator;

        let (x, y_usize) = linearly_separable_2d();
        let y_f64 = y_usize.mapv(|v| v as f64);
        let lda = LDA::<f64>::new(Some(1));
        let fitted = lda.fit_pipeline(&x, &y_f64).unwrap();
        let preds = fitted.predict_pipeline(&x).unwrap();
        assert_eq!(preds.len(), 8);
    }

    #[test]
    fn test_lda_n_components_getter() {
        let lda = LDA::<f64>::new(Some(2));
        assert_eq!(lda.n_components(), Some(2));
        let lda_none = LDA::<f64>::new(None);
        assert_eq!(lda_none.n_components(), None);
    }

    #[test]
    fn test_lda_transform_then_predict_consistent() {
        let (x, y) = linearly_separable_2d();
        let lda = LDA::<f64>::new(Some(1));
        let fitted = lda.fit(&x, &y).unwrap();
        // Manually compute nearest-centroid prediction from transform output.
        let projected = fitted.transform(&x).unwrap();
        let preds_predict = fitted.predict(&x).unwrap();
        let n_samples = projected.nrows();
        let n_comp = projected.ncols();
        let n_classes = fitted.classes().len();
        for i in 0..n_samples {
            let mut best = 0;
            let mut best_d = f64::INFINITY;
            for ci in 0..n_classes {
                let mut d = 0.0;
                for k in 0..n_comp {
                    let diff = projected[[i, k]] - fitted.means()[[ci, k]];
                    d += diff * diff;
                }
                if d < best_d {
                    best_d = d;
                    best = ci;
                }
            }
            assert_eq!(preds_predict[i], fitted.classes()[best]);
        }
    }

    #[test]
    fn test_lda_projected_class_separation() {
        let (x, y) = linearly_separable_2d();
        let lda = LDA::<f64>::new(Some(1));
        let fitted = lda.fit(&x, &y).unwrap();
        let projected = fitted.transform(&x).unwrap();

        // Means of class 0 and class 1 in projected space should be far apart.
        let mean0: f64 = projected
            .rows()
            .into_iter()
            .zip(y.iter())
            .filter(|&(_, label)| *label == 0)
            .map(|(row, _)| row[0])
            .sum::<f64>()
            / 4.0;
        let mean1: f64 = projected
            .rows()
            .into_iter()
            .zip(y.iter())
            .filter(|&(_, label)| *label == 1)
            .map(|(row, _)| row[0])
            .sum::<f64>()
            / 4.0;

        assert!(
            (mean0 - mean1).abs() > 0.5,
            "Projected means should differ, got {mean0} vs {mean1}"
        );
    }

    #[test]
    fn test_lda_transform_known_data() {
        // With perfectly separated, zero-mean 1-D data the two classes
        // should project to opposite sides of the origin.
        let x = Array2::from_shape_vec((4, 1), vec![-2.0, -1.0, 1.0, 2.0]).unwrap();
        let y = array![0usize, 0, 1, 1];
        let lda = LDA::<f64>::new(Some(1));
        let fitted = lda.fit(&x, &y).unwrap();
        let proj = fitted.transform(&x).unwrap();
        let sign0 = proj[[0, 0]].signum();
        let sign1 = proj[[2, 0]].signum();
        assert_ne!(
            sign0 as i32, sign1 as i32,
            "Classes should be on opposite sides"
        );
    }

    #[test]
    fn test_lda_projected_means_dimensions() {
        let (x, y) = linearly_separable_2d();
        let lda = LDA::<f64>::new(Some(1));
        let fitted = lda.fit(&x, &y).unwrap();
        // Each class mean in projected space should be a 1-component vector.
        assert_eq!(fitted.means().ncols(), 1);
        let m0 = fitted.means()[[0, 0]];
        let m1 = fitted.means()[[1, 0]];
        // For well-separated data the projected means should be clearly apart.
        assert!((m0 - m1).abs() > 0.5, "m0={m0}, m1={m1}");
    }
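
    // ------------------------------------------------------------------
    // Internal helpers
    // ------------------------------------------------------------------

    // A hedged sketch exercising the private `jacobi_eigen_f` helper on a
    // hand-checkable symmetric matrix (illustrative values, not from the
    // docs): [[2, 1], [1, 2]] has eigenvalues 3 and 1.
    #[test]
    fn test_jacobi_eigen_known_symmetric() {
        let a = array![[2.0_f64, 1.0], [1.0, 2.0]];
        let (vals, vecs) = jacobi_eigen_f(&a, 100).unwrap();
        let mut sorted = vals.to_vec();
        sorted.sort_by(|p, q| p.partial_cmp(q).unwrap());
        assert!((sorted[0] - 1.0).abs() < 1e-9);
        assert!((sorted[1] - 3.0).abs() < 1e-9);
        // Eigenvector columns should be orthonormal: V^T V ≈ I.
        let vtv = vecs.t().dot(&vecs);
        for i in 0..2 {
            for j in 0..2 {
                let expect = if i == j { 1.0 } else { 0.0 };
                assert!((vtv[[i, j]] - expect).abs() < 1e-9);
            }
        }
    }

    // A hedged sketch for the private Gaussian-elimination solver on a
    // hand-checkable 2×2 system (illustrative values).
    #[test]
    fn test_gaussian_solve_known_system() {
        // [[3, 1], [1, 2]] x = [5, 5]  =>  x = [1, 2].
        let a = array![[3.0_f64, 1.0], [1.0, 2.0]];
        let b = array![5.0, 5.0];
        let x = gaussian_solve_f(2, &a, &b).unwrap();
        assert!((x[0] - 1.0).abs() < 1e-9);
        assert!((x[1] - 2.0).abs() < 1e-9);
    }

    // A hedged sketch for `sw_inv_sb`: with Sw = 2·I the result should be
    // Sb / 2 (illustrative values; no regularisation is applied here).
    #[test]
    fn test_sw_inv_sb_diagonal_sw() {
        let sw = array![[2.0_f64, 0.0], [0.0, 2.0]];
        let sb = array![[4.0_f64, 2.0], [2.0, 6.0]];
        let m = sw_inv_sb(&sw, &sb).unwrap();
        for i in 0..2 {
            for j in 0..2 {
                assert!((m[[i, j]] - sb[[i, j]] / 2.0).abs() < 1e-9);
            }
        }
    }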
}