Skip to main content

ferrolearn_decomp/
cross_decomposition.rs

1//! Cross-decomposition methods: PLS, CCA, and PLSSVD.
2//!
3//! This module provides Partial Least Squares (PLS) and Canonical Correlation
4//! Analysis (CCA) methods for modelling the relationship between two multivariate
5//! datasets X and Y.
6//!
7//! # Algorithms
8//!
9//! - [`PLSSVD`] — SVD of the cross-covariance matrix. The simplest PLS variant;
10//!   computes weight matrices from the leading singular vectors of `X^T Y`.
11//! - [`PLSRegression`] — PLS via the NIPALS algorithm. Maximises covariance
12//!   between X-scores and Y-scores, with asymmetric deflation suitable for
13//!   regression.
14//! - [`PLSCanonical`] — Canonical PLS via NIPALS. Symmetric deflation of both
15//!   X and Y.
16//! - [`CCA`] — Canonical Correlation Analysis via NIPALS. Maximises
17//!   *correlation* (not covariance) between X-scores and Y-scores by
18//!   normalising scores to unit variance.
19//!
20//! # Examples
21//!
22//! ```
23//! use ferrolearn_decomp::cross_decomposition::PLSRegression;
24//! use ferrolearn_core::traits::{Fit, Predict, Transform};
25//! use ndarray::array;
26//!
27//! let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
28//! let y = array![[1.0], [2.0], [3.0], [4.0]];
29//!
30//! let pls = PLSRegression::<f64>::new(1);
31//! let fitted = pls.fit(&x, &y).unwrap();
32//! let y_pred = fitted.predict(&x).unwrap();
33//! assert_eq!(y_pred.ncols(), 1);
34//! let x_scores = fitted.transform(&x).unwrap();
35//! assert_eq!(x_scores.ncols(), 1);
36//! ```
37
38use ferrolearn_core::backend::Backend;
39use ferrolearn_core::backend_faer::NdarrayFaerBackend;
40use ferrolearn_core::error::FerroError;
41use ferrolearn_core::traits::{Fit, Predict, Transform};
42use ndarray::{Array1, Array2};
43use num_traits::Float;
44use std::any::TypeId;
45
46/// Result type for SVD: `(U, S, Vt)`.
47type SvdResult<F> = Result<(Array2<F>, Array1<F>, Array2<F>), FerroError>;
48
49// ---------------------------------------------------------------------------
50// Helper: centre and optionally scale columns of a matrix
51// ---------------------------------------------------------------------------
52
53/// Centre (and optionally scale to unit variance) columns of a matrix.
54///
55/// Returns `(centred_matrix, mean, std)` where `std` is `None` when
56/// `scale` is `false`.
57fn centre_scale<F: Float + Send + Sync + 'static>(
58    x: &Array2<F>,
59    scale: bool,
60) -> (Array2<F>, Array1<F>, Option<Array1<F>>) {
61    let (n_samples, n_features) = x.dim();
62    let n_f = F::from(n_samples).unwrap();
63
64    // Compute column means.
65    let mean = Array1::from_shape_fn(n_features, |j| {
66        x.column(j).iter().copied().fold(F::zero(), |a, b| a + b) / n_f
67    });
68
69    // Centre.
70    let mut xc = x.to_owned();
71    for mut row in xc.rows_mut() {
72        for (v, &m) in row.iter_mut().zip(mean.iter()) {
73            *v = *v - m;
74        }
75    }
76
77    if scale {
78        let n_minus_1 = F::from(n_samples.saturating_sub(1).max(1)).unwrap();
79        let std_dev = Array1::from_shape_fn(n_features, |j| {
80            let var = xc
81                .column(j)
82                .iter()
83                .copied()
84                .fold(F::zero(), |a, b| a + b * b)
85                / n_minus_1;
86            let s = var.sqrt();
87            if s < F::epsilon() { F::one() } else { s }
88        });
89        for mut row in xc.rows_mut() {
90            for (v, &s) in row.iter_mut().zip(std_dev.iter()) {
91                *v = *v / s;
92            }
93        }
94        (xc, mean, Some(std_dev))
95    } else {
96        (xc, mean, None)
97    }
98}
99
100/// Apply centring (and optionally scaling) to new data using stored statistics.
101fn apply_centre_scale<F: Float + Send + Sync + 'static>(
102    x: &Array2<F>,
103    mean: &Array1<F>,
104    std_dev: &Option<Array1<F>>,
105    context: &str,
106) -> Result<Array2<F>, FerroError> {
107    if x.ncols() != mean.len() {
108        return Err(FerroError::ShapeMismatch {
109            expected: vec![x.nrows(), mean.len()],
110            actual: vec![x.nrows(), x.ncols()],
111            context: context.into(),
112        });
113    }
114    let mut xc = x.to_owned();
115    for mut row in xc.rows_mut() {
116        for (v, &m) in row.iter_mut().zip(mean.iter()) {
117            *v = *v - m;
118        }
119    }
120    if let Some(ref s) = *std_dev {
121        for mut row in xc.rows_mut() {
122            for (v, &sd) in row.iter_mut().zip(s.iter()) {
123                *v = *v / sd;
124            }
125        }
126    }
127    Ok(xc)
128}
129
130// ---------------------------------------------------------------------------
131// Helper: SVD dispatch (generic F via f64 fast-path or Jacobi fallback)
132// ---------------------------------------------------------------------------
133
134/// Compute the thin SVD of a general (m x n) matrix.
135///
136/// Returns `(U, S, Vt)` where:
137/// - `U` is `(m, min(m,n))`,
138/// - `S` is `(min(m,n),)`,
139/// - `Vt` is `(min(m,n), n)`.
140///
141/// For `f64` this delegates to `NdarrayFaerBackend::svd` (faer's optimised
142/// routine). For other float types it falls back to a power-iteration approach.
143fn svd_dispatch<F: Float + Send + Sync + 'static>(a: &Array2<F>) -> SvdResult<F> {
144    if TypeId::of::<F>() == TypeId::of::<f64>() {
145        // Cast to f64 and use faer.
146        let a_f64: &Array2<f64> = unsafe { &*(std::ptr::from_ref(a).cast::<Array2<f64>>()) };
147        let (u, s, vt) = NdarrayFaerBackend::svd(a_f64)?;
148        // Thin U and Vt.
149        let k = s.len();
150        let u_thin = u.slice(ndarray::s![.., ..k]).to_owned();
151        let vt_thin = vt.slice(ndarray::s![..k, ..]).to_owned();
152
153        // Cast back to F (which is f64).
154        let u_f: Array2<F> = unsafe { std::mem::transmute_copy::<Array2<f64>, Array2<F>>(&u_thin) };
155        let s_f: Array1<F> = unsafe { std::mem::transmute_copy::<Array1<f64>, Array1<F>>(&s) };
156        let vt_f: Array2<F> =
157            unsafe { std::mem::transmute_copy::<Array2<f64>, Array2<F>>(&vt_thin) };
158        std::mem::forget(u_thin);
159        std::mem::forget(s);
160        std::mem::forget(vt_thin);
161        Ok((u_f, s_f, vt_f))
162    } else if TypeId::of::<F>() == TypeId::of::<f32>() {
163        // Convert f32 -> f64, compute, convert back.
164        let (m, n) = a.dim();
165        let a_f64 =
166            Array2::<f64>::from_shape_fn((m, n), |(i, j)| a[[i, j]].to_f64().unwrap_or(0.0));
167        let (u64, s64, vt64) = NdarrayFaerBackend::svd(&a_f64)?;
168        let k = s64.len();
169        let u_thin = u64.slice(ndarray::s![.., ..k]).to_owned();
170        let vt_thin = vt64.slice(ndarray::s![..k, ..]).to_owned();
171
172        let u_f =
173            Array2::<F>::from_shape_fn(u_thin.dim(), |(i, j)| F::from(u_thin[[i, j]]).unwrap());
174        let s_f = Array1::<F>::from_shape_fn(s64.len(), |i| F::from(s64[i]).unwrap());
175        let vt_f =
176            Array2::<F>::from_shape_fn(vt_thin.dim(), |(i, j)| F::from(vt_thin[[i, j]]).unwrap());
177        Ok((u_f, s_f, vt_f))
178    } else {
179        // Fallback: compute via eigendecomposition of A^T A.
180        svd_via_eigen(a)
181    }
182}
183
184/// Compute SVD via eigendecomposition of `A^T A` (fallback for exotic float types).
185fn svd_via_eigen<F: Float + Send + Sync + 'static>(a: &Array2<F>) -> SvdResult<F> {
186    let (m, n) = a.dim();
187    let k = m.min(n);
188
189    // Compute A^T A.
190    let ata = a.t().dot(a);
191
192    // Jacobi eigendecomposition of A^T A.
193    let max_iter = n * n * 100 + 1000;
194    let (eigenvalues, eigenvectors) = jacobi_eigen_symmetric(&ata, max_iter)?;
195
196    // Sort eigenvalues descending.
197    let mut indices: Vec<usize> = (0..n).collect();
198    indices.sort_by(|&i, &j| {
199        eigenvalues[j]
200            .partial_cmp(&eigenvalues[i])
201            .unwrap_or(std::cmp::Ordering::Equal)
202    });
203
204    // Take top k.
205    let mut s = Array1::<F>::zeros(k);
206    let mut v = Array2::<F>::zeros((n, k));
207    for (col, &idx) in indices.iter().take(k).enumerate() {
208        let eval = eigenvalues[idx];
209        s[col] = if eval > F::zero() {
210            eval.sqrt()
211        } else {
212            F::zero()
213        };
214        for row in 0..n {
215            v[[row, col]] = eigenvectors[[row, idx]];
216        }
217    }
218
219    // U = A V S^{-1}
220    let av = a.dot(&v);
221    let mut u = Array2::<F>::zeros((m, k));
222    for col in 0..k {
223        if s[col] > F::epsilon() {
224            let inv_s = F::one() / s[col];
225            for row in 0..m {
226                u[[row, col]] = av[[row, col]] * inv_s;
227            }
228        }
229    }
230
231    // Vt = V^T
232    let mut vt = Array2::<F>::zeros((k, n));
233    for i in 0..k {
234        for j in 0..n {
235            vt[[i, j]] = v[[j, i]];
236        }
237    }
238
239    Ok((u, s, vt))
240}
241
242/// Jacobi eigendecomposition for symmetric matrices (generic F fallback).
243fn jacobi_eigen_symmetric<F: Float + Send + Sync + 'static>(
244    a: &Array2<F>,
245    max_iter: usize,
246) -> Result<(Array1<F>, Array2<F>), FerroError> {
247    let n = a.nrows();
248    let mut mat = a.to_owned();
249    let mut v = Array2::<F>::zeros((n, n));
250    for i in 0..n {
251        v[[i, i]] = F::one();
252    }
253
254    let tol = F::from(1e-12).unwrap_or_else(F::epsilon);
255
256    for _iteration in 0..max_iter {
257        let mut max_off = F::zero();
258        let mut p = 0;
259        let mut q = 1;
260        for i in 0..n {
261            for j in (i + 1)..n {
262                let val = mat[[i, j]].abs();
263                if val > max_off {
264                    max_off = val;
265                    p = i;
266                    q = j;
267                }
268            }
269        }
270
271        if max_off < tol {
272            let eigenvalues = Array1::from_shape_fn(n, |i| mat[[i, i]]);
273            return Ok((eigenvalues, v));
274        }
275
276        let app = mat[[p, p]];
277        let aqq = mat[[q, q]];
278        let apq = mat[[p, q]];
279
280        let theta = if (app - aqq).abs() < tol {
281            F::from(std::f64::consts::FRAC_PI_4).unwrap_or_else(F::one)
282        } else {
283            let tau = (aqq - app) / (F::from(2.0).unwrap() * apq);
284            let t = if tau >= F::zero() {
285                F::one() / (tau.abs() + (F::one() + tau * tau).sqrt())
286            } else {
287                -F::one() / (tau.abs() + (F::one() + tau * tau).sqrt())
288            };
289            t.atan()
290        };
291
292        let c = theta.cos();
293        let s = theta.sin();
294
295        let mut new_mat = mat.clone();
296        for i in 0..n {
297            if i != p && i != q {
298                let mip = mat[[i, p]];
299                let miq = mat[[i, q]];
300                new_mat[[i, p]] = c * mip - s * miq;
301                new_mat[[p, i]] = new_mat[[i, p]];
302                new_mat[[i, q]] = s * mip + c * miq;
303                new_mat[[q, i]] = new_mat[[i, q]];
304            }
305        }
306
307        new_mat[[p, p]] = c * c * app - F::from(2.0).unwrap() * s * c * apq + s * s * aqq;
308        new_mat[[q, q]] = s * s * app + F::from(2.0).unwrap() * s * c * apq + c * c * aqq;
309        new_mat[[p, q]] = F::zero();
310        new_mat[[q, p]] = F::zero();
311
312        mat = new_mat;
313
314        for i in 0..n {
315            let vip = v[[i, p]];
316            let viq = v[[i, q]];
317            v[[i, p]] = c * vip - s * viq;
318            v[[i, q]] = s * vip + c * viq;
319        }
320    }
321
322    Err(FerroError::ConvergenceFailure {
323        iterations: max_iter,
324        message: "Jacobi eigendecomposition did not converge in cross_decomposition SVD fallback"
325            .into(),
326    })
327}
328
329// ---------------------------------------------------------------------------
330// Helper: vector norm and dot
331// ---------------------------------------------------------------------------
332
333/// L2 norm of a 1-D array.
334fn norm<F: Float>(v: &Array1<F>) -> F {
335    v.iter().copied().fold(F::zero(), |a, b| a + b * b).sqrt()
336}
337
338/// Dot product of two 1-D arrays.
339fn dot<F: Float>(a: &Array1<F>, b: &Array1<F>) -> F {
340    a.iter()
341        .copied()
342        .zip(b.iter().copied())
343        .fold(F::zero(), |acc, (x, y)| acc + x * y)
344}
345
346// ---------------------------------------------------------------------------
347// Helper: solve (P^T W) inverse for PLSRegression predict
348// ---------------------------------------------------------------------------
349
350/// Solve `(P^T W)^{-1}` for a square matrix using Gaussian elimination
351/// with partial pivoting (generic float).
352///
353/// When a pivot is too small (near-singular), it is regularised with a
354/// small perturbation to avoid hard failures. This matches the behaviour
355/// of scikit-learn, which uses `pinv` for the rotation matrix.
356fn invert_square<F: Float + Send + Sync + 'static>(a: &Array2<F>) -> Result<Array2<F>, FerroError> {
357    let n = a.nrows();
358    if n != a.ncols() {
359        return Err(FerroError::ShapeMismatch {
360            expected: vec![n, n],
361            actual: vec![a.nrows(), a.ncols()],
362            context: "invert_square: matrix must be square".into(),
363        });
364    }
365
366    // Augmented matrix [A | I].
367    let mut aug = Array2::<F>::zeros((n, 2 * n));
368    for i in 0..n {
369        for j in 0..n {
370            aug[[i, j]] = a[[i, j]];
371        }
372        aug[[i, n + i]] = F::one();
373    }
374
375    // Compute a tolerance based on the matrix norm.
376    let max_abs = a.iter().copied().fold(F::zero(), |m, v| {
377        let abs = v.abs();
378        if abs > m { abs } else { m }
379    });
380    let regularise_tol = max_abs * F::from(1e-12).unwrap_or_else(F::epsilon)
381        + F::from(1e-15).unwrap_or_else(F::epsilon);
382
383    // Forward elimination with partial pivoting.
384    for col in 0..n {
385        // Find pivot.
386        let mut max_val = aug[[col, col]].abs();
387        let mut max_row = col;
388        for row in (col + 1)..n {
389            let val = aug[[row, col]].abs();
390            if val > max_val {
391                max_val = val;
392                max_row = row;
393            }
394        }
395
396        // Regularise if pivot is too small.
397        if max_val < regularise_tol {
398            aug[[col, col]] = regularise_tol;
399        } else {
400            // Swap rows.
401            if max_row != col {
402                for j in 0..(2 * n) {
403                    let tmp = aug[[col, j]];
404                    aug[[col, j]] = aug[[max_row, j]];
405                    aug[[max_row, j]] = tmp;
406                }
407            }
408        }
409
410        // Eliminate below.
411        let pivot = aug[[col, col]];
412        for row in (col + 1)..n {
413            let factor = aug[[row, col]] / pivot;
414            for j in col..(2 * n) {
415                let above = aug[[col, j]];
416                aug[[row, j]] = aug[[row, j]] - factor * above;
417            }
418        }
419    }
420
421    // Back substitution.
422    for col in (0..n).rev() {
423        let pivot = aug[[col, col]];
424        for j in 0..(2 * n) {
425            aug[[col, j]] = aug[[col, j]] / pivot;
426        }
427        for row in 0..col {
428            let factor = aug[[row, col]];
429            for j in 0..(2 * n) {
430                let below = aug[[col, j]];
431                aug[[row, j]] = aug[[row, j]] - factor * below;
432            }
433        }
434    }
435
436    // Extract inverse.
437    let mut inv = Array2::<F>::zeros((n, n));
438    for i in 0..n {
439        for j in 0..n {
440            inv[[i, j]] = aug[[i, n + j]];
441        }
442    }
443    Ok(inv)
444}
445
446// ===========================================================================
447// PLSSVD
448// ===========================================================================
449
/// Partial Least Squares computed from one SVD of the cross-covariance matrix
/// `X^T Y`.
///
/// X and Y are centred (and, by default, scaled to unit variance). The leading
/// left singular vectors of `X^T Y` become the X-weights, and the leading right
/// singular vectors become the Y-weights.
///
/// There is no iteration: the whole fit is a single matrix decomposition. The
/// fitted model projects data but cannot predict Y from X; use
/// [`PLSRegression`] when a `predict` method is required.
///
/// # Type Parameters
///
/// - `F`: floating-point element type.
///
/// # Examples
///
/// ```
/// use ferrolearn_decomp::cross_decomposition::PLSSVD;
/// use ferrolearn_core::traits::{Fit, Transform};
/// use ndarray::array;
///
/// let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
/// let y = array![[1.0], [2.0], [3.0], [4.0]];
/// let svd = PLSSVD::<f64>::new(1);
/// let fitted = svd.fit(&x, &y).unwrap();
/// let scores = fitted.transform(&x).unwrap();
/// assert_eq!(scores.ncols(), 1);
/// ```
#[derive(Clone, Debug)]
pub struct PLSSVD<F> {
    /// How many components `fit` extracts.
    n_components: usize,
    /// When `true`, X and Y are scaled to unit variance before decomposing.
    scale: bool,
    /// Binds the estimator to its float type without storing a value of it.
    _marker: std::marker::PhantomData<F>,
}
486
487impl<F: Float + Send + Sync + 'static> PLSSVD<F> {
488    /// Create a new `PLSSVD` that extracts `n_components` components.
489    #[must_use]
490    pub fn new(n_components: usize) -> Self {
491        Self {
492            n_components,
493            scale: true,
494            _marker: std::marker::PhantomData,
495        }
496    }
497
498    /// Set whether to scale X and Y to unit variance (default: `true`).
499    #[must_use]
500    pub fn with_scale(mut self, scale: bool) -> Self {
501        self.scale = scale;
502        self
503    }
504
505    /// Return the number of components.
506    #[must_use]
507    pub fn n_components(&self) -> usize {
508        self.n_components
509    }
510}
511
512/// A fitted [`PLSSVD`] model.
513///
514/// Holds the learned weight matrices and centring/scaling statistics.
515/// Implements [`Transform`] to project X data onto the PLS score space.
516#[derive(Debug, Clone)]
517pub struct FittedPLSSVD<F> {
518    /// X-weights, shape `(n_features_x, n_components)`.
519    x_weights_: Array2<F>,
520    /// Y-weights, shape `(n_features_y, n_components)`.
521    y_weights_: Array2<F>,
522    /// Per-feature mean of X.
523    x_mean_: Array1<F>,
524    /// Per-feature mean of Y.
525    y_mean_: Array1<F>,
526    /// Per-feature standard deviation of X (None if not scaled).
527    x_std_: Option<Array1<F>>,
528    /// Per-feature standard deviation of Y (None if not scaled).
529    y_std_: Option<Array1<F>>,
530}
531
532impl<F: Float + Send + Sync + 'static> FittedPLSSVD<F> {
533    /// X-weights matrix, shape `(n_features_x, n_components)`.
534    #[must_use]
535    pub fn x_weights(&self) -> &Array2<F> {
536        &self.x_weights_
537    }
538
539    /// Y-weights matrix, shape `(n_features_y, n_components)`.
540    #[must_use]
541    pub fn y_weights(&self) -> &Array2<F> {
542        &self.y_weights_
543    }
544
545    /// Per-feature mean of X learned during fitting.
546    #[must_use]
547    pub fn x_mean(&self) -> &Array1<F> {
548        &self.x_mean_
549    }
550
551    /// Per-feature mean of Y learned during fitting.
552    #[must_use]
553    pub fn y_mean(&self) -> &Array1<F> {
554        &self.y_mean_
555    }
556
557    /// Transform Y data onto the Y-score space.
558    ///
559    /// # Errors
560    ///
561    /// Returns [`FerroError::ShapeMismatch`] if Y has the wrong number of columns.
562    pub fn transform_y(&self, y: &Array2<F>) -> Result<Array2<F>, FerroError> {
563        let yc = apply_centre_scale(y, &self.y_mean_, &self.y_std_, "FittedPLSSVD::transform_y")?;
564        Ok(yc.dot(&self.y_weights_))
565    }
566}
567
568impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array2<F>> for PLSSVD<F> {
569    type Fitted = FittedPLSSVD<F>;
570    type Error = FerroError;
571
572    /// Fit PLSSVD by computing the SVD of the cross-covariance matrix `X^T Y`.
573    ///
574    /// # Errors
575    ///
576    /// - [`FerroError::InvalidParameter`] if `n_components` is zero or exceeds
577    ///   `min(n_features_x, n_features_y)`.
578    /// - [`FerroError::InsufficientSamples`] if there are fewer than 2 samples.
579    /// - [`FerroError::ShapeMismatch`] if X and Y have different numbers of rows.
580    fn fit(&self, x: &Array2<F>, y: &Array2<F>) -> Result<FittedPLSSVD<F>, FerroError> {
581        let (n_samples_x, n_features_x) = x.dim();
582        let (n_samples_y, n_features_y) = y.dim();
583
584        if n_samples_x != n_samples_y {
585            return Err(FerroError::ShapeMismatch {
586                expected: vec![n_samples_x, n_features_y],
587                actual: vec![n_samples_y, n_features_y],
588                context: "PLSSVD::fit: X and Y must have the same number of rows".into(),
589            });
590        }
591
592        if self.n_components == 0 {
593            return Err(FerroError::InvalidParameter {
594                name: "n_components".into(),
595                reason: "must be at least 1".into(),
596            });
597        }
598
599        let max_components = n_features_x.min(n_features_y);
600        if self.n_components > max_components {
601            return Err(FerroError::InvalidParameter {
602                name: "n_components".into(),
603                reason: format!(
604                    "n_components ({}) exceeds min(n_features_x, n_features_y) ({})",
605                    self.n_components, max_components
606                ),
607            });
608        }
609
610        if n_samples_x < 2 {
611            return Err(FerroError::InsufficientSamples {
612                required: 2,
613                actual: n_samples_x,
614                context: "PLSSVD::fit requires at least 2 samples".into(),
615            });
616        }
617
618        // Centre and optionally scale.
619        let (xc, x_mean, x_std) = centre_scale(x, self.scale);
620        let (yc, y_mean, y_std) = centre_scale(y, self.scale);
621
622        // Cross-covariance: C = X^T Y.
623        let c = xc.t().dot(&yc);
624
625        // SVD of C.
626        let (u, _s, vt) = svd_dispatch(&c)?;
627
628        // Take first n_components columns of U, rows of Vt (= columns of V).
629        let nc = self.n_components;
630        let x_weights = u.slice(ndarray::s![.., ..nc]).to_owned();
631        // V = Vt^T, so columns of V = rows of Vt transposed.
632        let y_weights = vt.t().slice(ndarray::s![.., ..nc]).to_owned();
633
634        Ok(FittedPLSSVD {
635            x_weights_: x_weights,
636            y_weights_: y_weights,
637            x_mean_: x_mean,
638            y_mean_: y_mean,
639            x_std_: x_std,
640            y_std_: y_std,
641        })
642    }
643}
644
645impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedPLSSVD<F> {
646    type Output = Array2<F>;
647    type Error = FerroError;
648
649    /// Project X data onto the PLS score space: `(X - x_mean) / x_std @ x_weights`.
650    ///
651    /// # Errors
652    ///
653    /// Returns [`FerroError::ShapeMismatch`] if X has the wrong number of columns.
654    fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
655        let xc = apply_centre_scale(x, &self.x_mean_, &self.x_std_, "FittedPLSSVD::transform")?;
656        Ok(xc.dot(&self.x_weights_))
657    }
658}
659
660// ===========================================================================
661// NIPALS mode enum (shared by PLSRegression, PLSCanonical, CCA)
662// ===========================================================================
663
/// How the NIPALS loop deflates Y after each component (internal).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum NipalsMode {
    /// Regression-style: remove the X-score contribution from Y (`Y = Y - t q^T`).
    Regression,
    /// Canonical-style: remove Y's own score contribution (`Y = Y - u c^T`).
    Canonical,
}
672
/// Internal switch selecting whether NIPALS rescales its scores to unit
/// variance (CCA) or leaves them as-is (PLS).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum ScoreNorm {
    /// Keep the raw scores (PLS variants).
    None,
    /// Rescale scores to unit variance (CCA).
    UnitVariance,
}
681
682// ---------------------------------------------------------------------------
683// NIPALS core algorithm
684// ---------------------------------------------------------------------------
685
686/// Result from a NIPALS fit.
687#[derive(Debug, Clone)]
688struct NipalsResult<F> {
689    /// X-weights W, shape `(n_features_x, n_components)` — columns are weight vectors.
690    x_weights: Array2<F>,
691    /// X-loadings P, shape `(n_features_x, n_components)`.
692    x_loadings: Array2<F>,
693    /// X-scores T, shape `(n_samples, n_components)`.
694    x_scores: Array2<F>,
695    /// Y-loadings Q, shape `(n_features_y, n_components)`.
696    y_loadings: Array2<F>,
697    /// Y-scores U, shape `(n_samples, n_components)`.
698    y_scores: Array2<F>,
699    /// Number of iterations per component.
700    n_iter: Vec<usize>,
701}
702
703/// Run the NIPALS algorithm.
704fn nipals<F: Float + Send + Sync + 'static>(
705    x: &Array2<F>,
706    y: &Array2<F>,
707    n_components: usize,
708    max_iter: usize,
709    tol: F,
710    mode: NipalsMode,
711    score_norm: ScoreNorm,
712) -> Result<NipalsResult<F>, FerroError> {
713    let (n_samples, n_features_x) = x.dim();
714    let n_features_y = y.ncols();
715
716    let mut xk = x.to_owned();
717    let mut yk = y.to_owned();
718
719    let mut x_weights = Array2::<F>::zeros((n_features_x, n_components));
720    let mut x_loadings = Array2::<F>::zeros((n_features_x, n_components));
721    let mut x_scores = Array2::<F>::zeros((n_samples, n_components));
722    let mut y_loadings = Array2::<F>::zeros((n_features_y, n_components));
723    let mut y_scores = Array2::<F>::zeros((n_samples, n_components));
724    let mut n_iter_vec = Vec::with_capacity(n_components);
725
726    for k in 0..n_components {
727        // Initialise u = column of Y with max variance.
728        let best_col = (0..n_features_y)
729            .max_by(|&a, &b| {
730                let var_a: F = yk
731                    .column(a)
732                    .iter()
733                    .copied()
734                    .fold(F::zero(), |s, v| s + v * v);
735                let var_b: F = yk
736                    .column(b)
737                    .iter()
738                    .copied()
739                    .fold(F::zero(), |s, v| s + v * v);
740                var_a
741                    .partial_cmp(&var_b)
742                    .unwrap_or(std::cmp::Ordering::Equal)
743            })
744            .unwrap_or(0);
745
746        let mut u = yk.column(best_col).to_owned();
747
748        let mut converged = false;
749        let mut iters = 0;
750
751        for iteration in 0..max_iter {
752            iters = iteration + 1;
753
754            // w = X^T u / (u^T u)
755            let utu = dot(&u, &u);
756            let mut w = xk.t().dot(&u);
757            if utu > F::epsilon() {
758                w.mapv_inplace(|v| v / utu);
759            }
760            // Normalise w.
761            let w_norm = norm(&w);
762            if w_norm < F::epsilon() {
763                // Degenerate: zero weight vector.
764                break;
765            }
766            w.mapv_inplace(|v| v / w_norm);
767
768            // t = X w
769            let t = xk.dot(&w);
770
771            // q = Y^T t / (t^T t)
772            let ttt = dot(&t, &t);
773            let mut q = yk.t().dot(&t);
774            if ttt > F::epsilon() {
775                q.mapv_inplace(|v| v / ttt);
776            }
777
778            // For CCA: normalise q.
779            if score_norm == ScoreNorm::UnitVariance {
780                let q_norm = norm(&q);
781                if q_norm > F::epsilon() {
782                    q.mapv_inplace(|v| v / q_norm);
783                }
784            }
785
786            // u_new = Y q / (q^T q)
787            let qtq = dot(&q, &q);
788            let mut u_new = yk.dot(&q);
789            if qtq > F::epsilon() {
790                u_new.mapv_inplace(|v| v / qtq);
791            }
792
793            // For CCA: normalise t and u to unit variance.
794            // (This is done after the loop for storing; here we just check convergence.)
795
796            // Convergence check: ||u_new - u|| / ||u_new||.
797            let diff_norm = {
798                let diff: Array1<F> = &u_new - &u;
799                norm(&diff)
800            };
801            let u_new_norm = norm(&u_new);
802
803            u = u_new;
804
805            if u_new_norm > F::epsilon() && diff_norm / u_new_norm < tol {
806                converged = true;
807                // Recompute final w, t, q with converged u.
808                // w = X^T u / (u^T u), normalised
809                let utu2 = dot(&u, &u);
810                w = xk.t().dot(&u);
811                if utu2 > F::epsilon() {
812                    w.mapv_inplace(|v| v / utu2);
813                }
814                let w_norm2 = norm(&w);
815                if w_norm2 > F::epsilon() {
816                    w.mapv_inplace(|v| v / w_norm2);
817                }
818                // t = X w
819                let t_final = xk.dot(&w);
820                let ttt2 = dot(&t_final, &t_final);
821                q = yk.t().dot(&t_final);
822                if ttt2 > F::epsilon() {
823                    q.mapv_inplace(|v| v / ttt2);
824                }
825                if score_norm == ScoreNorm::UnitVariance {
826                    let q_norm2 = norm(&q);
827                    if q_norm2 > F::epsilon() {
828                        q.mapv_inplace(|v| v / q_norm2);
829                    }
830                }
831                let qtq2 = dot(&q, &q);
832                u = yk.dot(&q);
833                if qtq2 > F::epsilon() {
834                    u.mapv_inplace(|v| v / qtq2);
835                }
836                break;
837            }
838        }
839
840        // Compute final scores and loadings with the converged weights.
841        let utu_final = dot(&u, &u);
842        let mut w_final = xk.t().dot(&u);
843        if utu_final > F::epsilon() {
844            w_final.mapv_inplace(|v| v / utu_final);
845        }
846        let w_norm_final = norm(&w_final);
847        if w_norm_final > F::epsilon() {
848            w_final.mapv_inplace(|v| v / w_norm_final);
849        }
850
851        let mut t_final = xk.dot(&w_final);
852        let ttt_final = dot(&t_final, &t_final);
853
854        // p = X^T t / (t^T t)
855        let mut p = xk.t().dot(&t_final);
856        if ttt_final > F::epsilon() {
857            p.mapv_inplace(|v| v / ttt_final);
858        }
859
860        // q = Y^T t / (t^T t)
861        let mut q_final = yk.t().dot(&t_final);
862        if ttt_final > F::epsilon() {
863            q_final.mapv_inplace(|v| v / ttt_final);
864        }
865
866        if score_norm == ScoreNorm::UnitVariance {
867            let q_norm = norm(&q_final);
868            if q_norm > F::epsilon() {
869                q_final.mapv_inplace(|v| v / q_norm);
870            }
871        }
872
873        let qtq_final = dot(&q_final, &q_final);
874        let mut u_final = yk.dot(&q_final);
875        if qtq_final > F::epsilon() {
876            u_final.mapv_inplace(|v| v / qtq_final);
877        }
878
879        // For CCA: normalise t and u to unit variance.
880        if score_norm == ScoreNorm::UnitVariance {
881            let t_std = {
882                let t_mean = t_final.iter().copied().fold(F::zero(), |a, b| a + b)
883                    / F::from(n_samples).unwrap();
884                let var = t_final
885                    .iter()
886                    .copied()
887                    .fold(F::zero(), |a, b| a + (b - t_mean) * (b - t_mean))
888                    / F::from(n_samples.saturating_sub(1).max(1)).unwrap();
889                var.sqrt()
890            };
891            if t_std > F::epsilon() {
892                t_final.mapv_inplace(|v| v / t_std);
893            }
894
895            let u_std = {
896                let u_mean = u_final.iter().copied().fold(F::zero(), |a, b| a + b)
897                    / F::from(n_samples).unwrap();
898                let var = u_final
899                    .iter()
900                    .copied()
901                    .fold(F::zero(), |a, b| a + (b - u_mean) * (b - u_mean))
902                    / F::from(n_samples.saturating_sub(1).max(1)).unwrap();
903                var.sqrt()
904            };
905            if u_std > F::epsilon() {
906                u_final.mapv_inplace(|v| v / u_std);
907            }
908        }
909
910        // Store component k.
911        x_weights.column_mut(k).assign(&w_final);
912        x_loadings.column_mut(k).assign(&p);
913        x_scores.column_mut(k).assign(&t_final);
914        y_loadings.column_mut(k).assign(&q_final);
915        y_scores.column_mut(k).assign(&u_final);
916
917        // Deflate X: X = X - t p^T.
918        for i in 0..n_samples {
919            let ti = t_final[i];
920            for j in 0..n_features_x {
921                xk[[i, j]] = xk[[i, j]] - ti * p[j];
922            }
923        }
924
925        // Deflate Y.
926        match mode {
927            NipalsMode::Regression => {
928                // Y = Y - t q^T (deflate with X-scores).
929                for i in 0..n_samples {
930                    let ti = t_final[i];
931                    for j in 0..n_features_y {
932                        yk[[i, j]] = yk[[i, j]] - ti * q_final[j];
933                    }
934                }
935            }
936            NipalsMode::Canonical => {
937                // Y = Y - u c^T where c = Y^T u / (u^T u).
938                let utu_c = dot(&u_final, &u_final);
939                let mut c = yk.t().dot(&u_final);
940                if utu_c > F::epsilon() {
941                    c.mapv_inplace(|v| v / utu_c);
942                }
943                for i in 0..n_samples {
944                    let ui = u_final[i];
945                    for j in 0..n_features_y {
946                        yk[[i, j]] = yk[[i, j]] - ui * c[j];
947                    }
948                }
949            }
950        }
951
952        n_iter_vec.push(iters);
953
954        if !converged && n_features_y > 1 && iters == max_iter {
955            return Err(FerroError::ConvergenceFailure {
956                iterations: max_iter,
957                message: format!("NIPALS did not converge for component {k}"),
958            });
959        }
960    }
961
962    Ok(NipalsResult {
963        x_weights,
964        x_loadings,
965        x_scores,
966        y_loadings,
967        y_scores,
968        n_iter: n_iter_vec,
969    })
970}
971
972// ===========================================================================
973// PLSRegression
974// ===========================================================================
975
/// Partial Least Squares Regression via the NIPALS algorithm.
///
/// PLSRegression finds latent components that maximise the covariance
/// between X-scores and Y-scores, with asymmetric deflation of Y using
/// X-scores. This is the standard PLS2 algorithm for multi-target
/// regression.
///
/// Construct it with [`PLSRegression::new`] and adjust it with the
/// `with_max_iter`, `with_tol` and `with_scale` builder methods. Fitting
/// produces a [`FittedPLSRegression`]. `n_components` must be at least 1 and
/// at most `min(n_features_x, n_features_y, n_samples)`.
///
/// # Type Parameters
///
/// - `F`: The floating-point scalar type.
///
/// # Examples
///
/// ```
/// use ferrolearn_decomp::cross_decomposition::PLSRegression;
/// use ferrolearn_core::traits::{Fit, Predict, Transform};
/// use ndarray::array;
///
/// let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
/// let y = array![[1.0], [2.0], [3.0], [4.0]];
///
/// let pls = PLSRegression::<f64>::new(1);
/// let fitted = pls.fit(&x, &y).unwrap();
///
/// let y_pred = fitted.predict(&x).unwrap();
/// assert_eq!(y_pred.ncols(), 1);
///
/// let scores = fitted.transform(&x).unwrap();
/// assert_eq!(scores.ncols(), 1);
/// ```
#[derive(Debug, Clone)]
pub struct PLSRegression<F> {
    /// Number of PLS components to extract.
    n_components: usize,
    /// Maximum NIPALS iterations per component.
    max_iter: usize,
    /// Convergence tolerance for NIPALS.
    tol: F,
    /// Whether to scale X and Y to unit variance.
    scale: bool,
    /// Zero-sized type marker.
    ///
    /// NOTE(review): `tol: F` already uses `F`, so this marker appears
    /// redundant; it is kept so the constructor and `Debug` output are unchanged.
    _marker: std::marker::PhantomData<F>,
}
1018
1019impl<F: Float + Send + Sync + 'static> PLSRegression<F> {
1020    /// Create a new `PLSRegression` with `n_components` components.
1021    ///
1022    /// Defaults: `max_iter = 500`, `tol = 1e-6`, `scale = true`.
1023    #[must_use]
1024    pub fn new(n_components: usize) -> Self {
1025        Self {
1026            n_components,
1027            max_iter: 500,
1028            tol: F::from(1e-6).unwrap_or_else(F::epsilon),
1029            scale: true,
1030            _marker: std::marker::PhantomData,
1031        }
1032    }
1033
1034    /// Set the maximum number of NIPALS iterations per component.
1035    #[must_use]
1036    pub fn with_max_iter(mut self, max_iter: usize) -> Self {
1037        self.max_iter = max_iter;
1038        self
1039    }
1040
1041    /// Set the NIPALS convergence tolerance.
1042    #[must_use]
1043    pub fn with_tol(mut self, tol: F) -> Self {
1044        self.tol = tol;
1045        self
1046    }
1047
1048    /// Set whether to scale X and Y to unit variance (default: `true`).
1049    #[must_use]
1050    pub fn with_scale(mut self, scale: bool) -> Self {
1051        self.scale = scale;
1052        self
1053    }
1054
1055    /// Return the number of components.
1056    #[must_use]
1057    pub fn n_components(&self) -> usize {
1058        self.n_components
1059    }
1060}
1061
/// A fitted [`PLSRegression`] model.
///
/// Holds the learned weight, loading, and score matrices, plus the
/// regression coefficients for prediction. Implements [`Predict`] to
/// predict Y from X, and [`Transform`] to project X onto the score space.
///
/// All matrices are expressed in the preprocessed space used during
/// fitting: X and Y centred with the stored means and, when the model was
/// fitted with scaling enabled, divided by the stored standard deviations.
#[derive(Debug, Clone)]
pub struct FittedPLSRegression<F> {
    /// X-weights W, shape `(n_features_x, n_components)`.
    x_weights_: Array2<F>,
    /// X-loadings P, shape `(n_features_x, n_components)`.
    x_loadings_: Array2<F>,
    /// Y-loadings Q, shape `(n_features_y, n_components)`.
    y_loadings_: Array2<F>,
    /// Regression coefficients B, shape `(n_features_x, n_features_y)`.
    /// B = W (P^T W)^{-1} Q^T.
    ///
    /// B maps centred (and, if scaled, standardised) X to centred (and, if
    /// scaled, standardised) Y. `predict` converts the result back to the
    /// original Y units.
    coefficients_: Array2<F>,
    /// X-scores T from training, shape `(n_samples, n_components)`.
    x_scores_: Array2<F>,
    /// Y-scores U from training, shape `(n_samples, n_components)`.
    y_scores_: Array2<F>,
    /// Number of iterations per component.
    n_iter_: Vec<usize>,
    /// Per-feature mean of X.
    x_mean_: Array1<F>,
    /// Per-feature mean of Y.
    y_mean_: Array1<F>,
    /// Per-feature std of X (None if not scaled).
    x_std_: Option<Array1<F>>,
    /// Per-feature std of Y (None if not scaled).
    y_std_: Option<Array1<F>>,
}
1093
1094impl<F: Float + Send + Sync + 'static> FittedPLSRegression<F> {
1095    /// X-weights matrix W, shape `(n_features_x, n_components)`.
1096    #[must_use]
1097    pub fn x_weights(&self) -> &Array2<F> {
1098        &self.x_weights_
1099    }
1100
1101    /// X-loadings matrix P, shape `(n_features_x, n_components)`.
1102    #[must_use]
1103    pub fn x_loadings(&self) -> &Array2<F> {
1104        &self.x_loadings_
1105    }
1106
1107    /// Y-loadings matrix Q, shape `(n_features_y, n_components)`.
1108    #[must_use]
1109    pub fn y_loadings(&self) -> &Array2<F> {
1110        &self.y_loadings_
1111    }
1112
1113    /// Regression coefficient matrix B, shape `(n_features_x, n_features_y)`.
1114    ///
1115    /// `Y_pred = X_centred @ B + y_mean`.
1116    #[must_use]
1117    pub fn coefficients(&self) -> &Array2<F> {
1118        &self.coefficients_
1119    }
1120
1121    /// X-scores T from training, shape `(n_samples, n_components)`.
1122    #[must_use]
1123    pub fn x_scores(&self) -> &Array2<F> {
1124        &self.x_scores_
1125    }
1126
1127    /// Y-scores U from training, shape `(n_samples, n_components)`.
1128    #[must_use]
1129    pub fn y_scores(&self) -> &Array2<F> {
1130        &self.y_scores_
1131    }
1132
1133    /// Number of NIPALS iterations for each component.
1134    #[must_use]
1135    pub fn n_iter(&self) -> &[usize] {
1136        &self.n_iter_
1137    }
1138}
1139
1140impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array2<F>> for PLSRegression<F> {
1141    type Fitted = FittedPLSRegression<F>;
1142    type Error = FerroError;
1143
1144    /// Fit PLSRegression using the NIPALS algorithm.
1145    ///
1146    /// # Errors
1147    ///
1148    /// - [`FerroError::InvalidParameter`] if `n_components` is zero or too large.
1149    /// - [`FerroError::InsufficientSamples`] if fewer than 2 samples.
1150    /// - [`FerroError::ShapeMismatch`] if X and Y have different row counts.
1151    /// - [`FerroError::ConvergenceFailure`] if NIPALS does not converge.
1152    fn fit(&self, x: &Array2<F>, y: &Array2<F>) -> Result<FittedPLSRegression<F>, FerroError> {
1153        let (n_samples_x, n_features_x) = x.dim();
1154        let (n_samples_y, n_features_y) = y.dim();
1155
1156        if n_samples_x != n_samples_y {
1157            return Err(FerroError::ShapeMismatch {
1158                expected: vec![n_samples_x, n_features_y],
1159                actual: vec![n_samples_y, n_features_y],
1160                context: "PLSRegression::fit: X and Y must have the same number of rows".into(),
1161            });
1162        }
1163
1164        if self.n_components == 0 {
1165            return Err(FerroError::InvalidParameter {
1166                name: "n_components".into(),
1167                reason: "must be at least 1".into(),
1168            });
1169        }
1170
1171        let max_components = n_features_x.min(n_features_y).min(n_samples_x);
1172        if self.n_components > max_components {
1173            return Err(FerroError::InvalidParameter {
1174                name: "n_components".into(),
1175                reason: format!(
1176                    "n_components ({}) exceeds min(n_features_x, n_features_y, n_samples) ({})",
1177                    self.n_components, max_components
1178                ),
1179            });
1180        }
1181
1182        if n_samples_x < 2 {
1183            return Err(FerroError::InsufficientSamples {
1184                required: 2,
1185                actual: n_samples_x,
1186                context: "PLSRegression::fit requires at least 2 samples".into(),
1187            });
1188        }
1189
1190        // Centre and optionally scale.
1191        let (xc, x_mean, x_std) = centre_scale(x, self.scale);
1192        let (yc, y_mean, y_std) = centre_scale(y, self.scale);
1193
1194        // Run NIPALS.
1195        let result = nipals(
1196            &xc,
1197            &yc,
1198            self.n_components,
1199            self.max_iter,
1200            self.tol,
1201            NipalsMode::Regression,
1202            ScoreNorm::None,
1203        )?;
1204
1205        // Compute regression coefficients: B = W (P^T W)^{-1} Q^T.
1206        let ptw = result.x_loadings.t().dot(&result.x_weights);
1207        let ptw_inv = invert_square(&ptw)?;
1208        let coefficients = result.x_weights.dot(&ptw_inv).dot(&result.y_loadings.t());
1209
1210        // If we scaled, adjust coefficients to work on the original scale.
1211        // The stored coefficients operate on centred (and scaled) X, producing
1212        // centred (and scaled) Y. We leave them in this internal space and
1213        // apply the scaling in predict/transform.
1214
1215        Ok(FittedPLSRegression {
1216            x_weights_: result.x_weights,
1217            x_loadings_: result.x_loadings,
1218            y_loadings_: result.y_loadings,
1219            coefficients_: coefficients,
1220            x_scores_: result.x_scores,
1221            y_scores_: result.y_scores,
1222            n_iter_: result.n_iter,
1223            x_mean_: x_mean,
1224            y_mean_: y_mean,
1225            x_std_: x_std,
1226            y_std_: y_std,
1227        })
1228    }
1229}
1230
1231impl<F: Float + Send + Sync + 'static> Predict<Array2<F>> for FittedPLSRegression<F> {
1232    type Output = Array2<F>;
1233    type Error = FerroError;
1234
1235    /// Predict Y from X using the fitted PLS regression model.
1236    ///
1237    /// Computes `Y_pred = X_centred @ B`, then un-scales and un-centres.
1238    ///
1239    /// # Errors
1240    ///
1241    /// Returns [`FerroError::ShapeMismatch`] if X has the wrong number of columns.
1242    fn predict(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
1243        let xc = apply_centre_scale(
1244            x,
1245            &self.x_mean_,
1246            &self.x_std_,
1247            "FittedPLSRegression::predict",
1248        )?;
1249
1250        let mut y_pred = xc.dot(&self.coefficients_);
1251
1252        // Un-scale Y.
1253        if let Some(ref ys) = self.y_std_ {
1254            for mut row in y_pred.rows_mut() {
1255                for (v, &s) in row.iter_mut().zip(ys.iter()) {
1256                    *v = *v * s;
1257                }
1258            }
1259        }
1260
1261        // Un-centre Y.
1262        for mut row in y_pred.rows_mut() {
1263            for (v, &m) in row.iter_mut().zip(self.y_mean_.iter()) {
1264                *v = *v + m;
1265            }
1266        }
1267
1268        Ok(y_pred)
1269    }
1270}
1271
1272impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedPLSRegression<F> {
1273    type Output = Array2<F>;
1274    type Error = FerroError;
1275
1276    /// Project X onto the PLS score space (X-scores).
1277    ///
1278    /// Computes `T = X_centred @ W (P^T W)^{-1}`.
1279    ///
1280    /// # Errors
1281    ///
1282    /// Returns [`FerroError::ShapeMismatch`] if X has the wrong number of columns.
1283    fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
1284        let xc = apply_centre_scale(
1285            x,
1286            &self.x_mean_,
1287            &self.x_std_,
1288            "FittedPLSRegression::transform",
1289        )?;
1290
1291        // T = X_centred @ W (P^T W)^{-1} = X_centred @ rotation
1292        // But simpler: T = X_centred @ W when W columns are the NIPALS weights.
1293        // Actually the correct transform for new data is via the rotation matrix.
1294        let ptw = self.x_loadings_.t().dot(&self.x_weights_);
1295        let ptw_inv = invert_square(&ptw)?;
1296        let rotation = self.x_weights_.dot(&ptw_inv);
1297        Ok(xc.dot(&rotation))
1298    }
1299}
1300
1301// ===========================================================================
1302// PLSCanonical
1303// ===========================================================================
1304
/// Canonical PLS via the NIPALS algorithm.
///
/// PLSCanonical performs a symmetric decomposition: both X and Y are
/// deflated with their own scores. This contrasts with [`PLSRegression`]
/// which deflates Y using X-scores.
///
/// PLSCanonical is appropriate when you want a symmetric analysis of the
/// relationship between X and Y, rather than a predictive model.
///
/// Construct it with [`PLSCanonical::new`] and adjust it with the
/// `with_max_iter`, `with_tol` and `with_scale` builder methods. Fitting
/// produces a [`FittedPLSCanonical`]. `n_components` must be at least 1 and
/// at most `min(n_features_x, n_features_y, n_samples)`.
///
/// # Type Parameters
///
/// - `F`: The floating-point scalar type.
///
/// # Examples
///
/// ```
/// use ferrolearn_decomp::cross_decomposition::PLSCanonical;
/// use ferrolearn_core::traits::{Fit, Transform};
/// use ndarray::array;
///
/// let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
/// let y = array![[1.0, 0.5], [2.0, 1.0], [3.0, 1.5], [4.0, 2.0]];
///
/// let pls = PLSCanonical::<f64>::new(2);
/// let fitted = pls.fit(&x, &y).unwrap();
/// let scores = fitted.transform(&x).unwrap();
/// assert_eq!(scores.ncols(), 2);
/// ```
#[derive(Debug, Clone)]
pub struct PLSCanonical<F> {
    /// Number of PLS components to extract.
    n_components: usize,
    /// Maximum NIPALS iterations per component.
    max_iter: usize,
    /// Convergence tolerance for NIPALS.
    tol: F,
    /// Whether to scale X and Y to unit variance.
    scale: bool,
    /// Zero-sized type marker.
    ///
    /// NOTE(review): `tol: F` already uses `F`, so this marker appears
    /// redundant; it is kept so the constructor and `Debug` output are unchanged.
    _marker: std::marker::PhantomData<F>,
}
1345
1346impl<F: Float + Send + Sync + 'static> PLSCanonical<F> {
1347    /// Create a new `PLSCanonical` with `n_components` components.
1348    ///
1349    /// Defaults: `max_iter = 500`, `tol = 1e-6`, `scale = true`.
1350    #[must_use]
1351    pub fn new(n_components: usize) -> Self {
1352        Self {
1353            n_components,
1354            max_iter: 500,
1355            tol: F::from(1e-6).unwrap_or_else(F::epsilon),
1356            scale: true,
1357            _marker: std::marker::PhantomData,
1358        }
1359    }
1360
1361    /// Set the maximum number of NIPALS iterations per component.
1362    #[must_use]
1363    pub fn with_max_iter(mut self, max_iter: usize) -> Self {
1364        self.max_iter = max_iter;
1365        self
1366    }
1367
1368    /// Set the NIPALS convergence tolerance.
1369    #[must_use]
1370    pub fn with_tol(mut self, tol: F) -> Self {
1371        self.tol = tol;
1372        self
1373    }
1374
1375    /// Set whether to scale X and Y to unit variance (default: `true`).
1376    #[must_use]
1377    pub fn with_scale(mut self, scale: bool) -> Self {
1378        self.scale = scale;
1379        self
1380    }
1381
1382    /// Return the number of components.
1383    #[must_use]
1384    pub fn n_components(&self) -> usize {
1385        self.n_components
1386    }
1387}
1388
/// A fitted [`PLSCanonical`] model.
///
/// Holds the learned weight, loading, and score matrices from the
/// symmetric NIPALS decomposition. Implements [`Transform`] to project
/// X onto the score space.
///
/// All matrices are expressed in the preprocessed space used during
/// fitting: X and Y centred with the stored means and, when the model was
/// fitted with scaling enabled, divided by the stored standard deviations.
#[derive(Debug, Clone)]
pub struct FittedPLSCanonical<F> {
    /// X-weights W, shape `(n_features_x, n_components)`.
    x_weights_: Array2<F>,
    /// X-loadings P, shape `(n_features_x, n_components)`.
    x_loadings_: Array2<F>,
    /// Y-loadings Q, shape `(n_features_y, n_components)`.
    y_loadings_: Array2<F>,
    /// X-scores T from training, shape `(n_samples, n_components)`.
    x_scores_: Array2<F>,
    /// Y-scores U from training, shape `(n_samples, n_components)`.
    y_scores_: Array2<F>,
    /// Number of iterations per component.
    n_iter_: Vec<usize>,
    /// Per-feature mean of X.
    x_mean_: Array1<F>,
    /// Per-feature mean of Y.
    y_mean_: Array1<F>,
    /// Per-feature std of X (None if not scaled).
    x_std_: Option<Array1<F>>,
    /// Per-feature std of Y (None if not scaled).
    y_std_: Option<Array1<F>>,
}
1417
1418impl<F: Float + Send + Sync + 'static> FittedPLSCanonical<F> {
1419    /// X-weights matrix W, shape `(n_features_x, n_components)`.
1420    #[must_use]
1421    pub fn x_weights(&self) -> &Array2<F> {
1422        &self.x_weights_
1423    }
1424
1425    /// X-loadings matrix P, shape `(n_features_x, n_components)`.
1426    #[must_use]
1427    pub fn x_loadings(&self) -> &Array2<F> {
1428        &self.x_loadings_
1429    }
1430
1431    /// Y-loadings matrix Q, shape `(n_features_y, n_components)`.
1432    #[must_use]
1433    pub fn y_loadings(&self) -> &Array2<F> {
1434        &self.y_loadings_
1435    }
1436
1437    /// X-scores T from training, shape `(n_samples, n_components)`.
1438    #[must_use]
1439    pub fn x_scores(&self) -> &Array2<F> {
1440        &self.x_scores_
1441    }
1442
1443    /// Y-scores U from training, shape `(n_samples, n_components)`.
1444    #[must_use]
1445    pub fn y_scores(&self) -> &Array2<F> {
1446        &self.y_scores_
1447    }
1448
1449    /// Number of NIPALS iterations for each component.
1450    #[must_use]
1451    pub fn n_iter(&self) -> &[usize] {
1452        &self.n_iter_
1453    }
1454
1455    /// Transform Y data onto the Y-score space.
1456    ///
1457    /// # Errors
1458    ///
1459    /// Returns [`FerroError::ShapeMismatch`] if Y has the wrong number of columns.
1460    pub fn transform_y(&self, y: &Array2<F>) -> Result<Array2<F>, FerroError> {
1461        let yc = apply_centre_scale(
1462            y,
1463            &self.y_mean_,
1464            &self.y_std_,
1465            "FittedPLSCanonical::transform_y",
1466        )?;
1467        Ok(yc.dot(&self.y_loadings_))
1468    }
1469}
1470
1471impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array2<F>> for PLSCanonical<F> {
1472    type Fitted = FittedPLSCanonical<F>;
1473    type Error = FerroError;
1474
1475    /// Fit PLSCanonical using the NIPALS algorithm with symmetric deflation.
1476    ///
1477    /// # Errors
1478    ///
1479    /// - [`FerroError::InvalidParameter`] if `n_components` is zero or too large.
1480    /// - [`FerroError::InsufficientSamples`] if fewer than 2 samples.
1481    /// - [`FerroError::ShapeMismatch`] if X and Y have different row counts.
1482    /// - [`FerroError::ConvergenceFailure`] if NIPALS does not converge.
1483    fn fit(&self, x: &Array2<F>, y: &Array2<F>) -> Result<FittedPLSCanonical<F>, FerroError> {
1484        let (n_samples_x, n_features_x) = x.dim();
1485        let (n_samples_y, n_features_y) = y.dim();
1486
1487        if n_samples_x != n_samples_y {
1488            return Err(FerroError::ShapeMismatch {
1489                expected: vec![n_samples_x, n_features_y],
1490                actual: vec![n_samples_y, n_features_y],
1491                context: "PLSCanonical::fit: X and Y must have the same number of rows".into(),
1492            });
1493        }
1494
1495        if self.n_components == 0 {
1496            return Err(FerroError::InvalidParameter {
1497                name: "n_components".into(),
1498                reason: "must be at least 1".into(),
1499            });
1500        }
1501
1502        let max_components = n_features_x.min(n_features_y).min(n_samples_x);
1503        if self.n_components > max_components {
1504            return Err(FerroError::InvalidParameter {
1505                name: "n_components".into(),
1506                reason: format!(
1507                    "n_components ({}) exceeds min(n_features_x, n_features_y, n_samples) ({})",
1508                    self.n_components, max_components
1509                ),
1510            });
1511        }
1512
1513        if n_samples_x < 2 {
1514            return Err(FerroError::InsufficientSamples {
1515                required: 2,
1516                actual: n_samples_x,
1517                context: "PLSCanonical::fit requires at least 2 samples".into(),
1518            });
1519        }
1520
1521        let (xc, x_mean, x_std) = centre_scale(x, self.scale);
1522        let (yc, y_mean, y_std) = centre_scale(y, self.scale);
1523
1524        let result = nipals(
1525            &xc,
1526            &yc,
1527            self.n_components,
1528            self.max_iter,
1529            self.tol,
1530            NipalsMode::Canonical,
1531            ScoreNorm::None,
1532        )?;
1533
1534        Ok(FittedPLSCanonical {
1535            x_weights_: result.x_weights,
1536            x_loadings_: result.x_loadings,
1537            y_loadings_: result.y_loadings,
1538            x_scores_: result.x_scores,
1539            y_scores_: result.y_scores,
1540            n_iter_: result.n_iter,
1541            x_mean_: x_mean,
1542            y_mean_: y_mean,
1543            x_std_: x_std,
1544            y_std_: y_std,
1545        })
1546    }
1547}
1548
1549impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedPLSCanonical<F> {
1550    type Output = Array2<F>;
1551    type Error = FerroError;
1552
1553    /// Project X onto the PLS score space (X-scores).
1554    ///
1555    /// # Errors
1556    ///
1557    /// Returns [`FerroError::ShapeMismatch`] if X has the wrong number of columns.
1558    fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
1559        let xc = apply_centre_scale(
1560            x,
1561            &self.x_mean_,
1562            &self.x_std_,
1563            "FittedPLSCanonical::transform",
1564        )?;
1565
1566        let ptw = self.x_loadings_.t().dot(&self.x_weights_);
1567        let ptw_inv = invert_square(&ptw)?;
1568        let rotation = self.x_weights_.dot(&ptw_inv);
1569        Ok(xc.dot(&rotation))
1570    }
1571}
1572
1573// ===========================================================================
1574// CCA (Canonical Correlation Analysis)
1575// ===========================================================================
1576
/// Canonical Correlation Analysis via the NIPALS algorithm.
///
/// CCA maximises the *correlation* (rather than covariance) between
/// X-scores and Y-scores. It runs NIPALS with unit-variance score
/// normalisation: once a component has been extracted, its X- and Y-scores
/// are rescaled to unit sample variance. It uses symmetric (canonical)
/// deflation.
///
/// Construct it with [`CCA::new`] and adjust it with the `with_max_iter`,
/// `with_tol` and `with_scale` builder methods. Fitting produces a
/// [`FittedCCA`]. `n_components` must be at least 1 and at most
/// `min(n_features_x, n_features_y, n_samples)`.
///
/// # Type Parameters
///
/// - `F`: The floating-point scalar type.
///
/// # Examples
///
/// ```
/// use ferrolearn_decomp::cross_decomposition::CCA;
/// use ferrolearn_core::traits::{Fit, Transform};
/// use ndarray::array;
///
/// let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
/// let y = array![[1.0, 0.5], [2.0, 1.0], [3.0, 1.5], [4.0, 2.0]];
///
/// let cca = CCA::<f64>::new(2);
/// let fitted = cca.fit(&x, &y).unwrap();
/// let scores = fitted.transform(&x).unwrap();
/// assert_eq!(scores.ncols(), 2);
/// ```
#[derive(Debug, Clone)]
pub struct CCA<F> {
    /// Number of canonical components to extract.
    n_components: usize,
    /// Maximum NIPALS iterations per component.
    max_iter: usize,
    /// Convergence tolerance for NIPALS.
    tol: F,
    /// Whether to scale X and Y to unit variance.
    scale: bool,
    /// Zero-sized type marker.
    ///
    /// NOTE(review): `tol: F` already uses `F`, so this marker appears
    /// redundant; it is kept so the constructor and `Debug` output are unchanged.
    _marker: std::marker::PhantomData<F>,
}
1614
1615impl<F: Float + Send + Sync + 'static> CCA<F> {
1616    /// Create a new `CCA` with `n_components` components.
1617    ///
1618    /// Defaults: `max_iter = 500`, `tol = 1e-6`, `scale = true`.
1619    #[must_use]
1620    pub fn new(n_components: usize) -> Self {
1621        Self {
1622            n_components,
1623            max_iter: 500,
1624            tol: F::from(1e-6).unwrap_or_else(F::epsilon),
1625            scale: true,
1626            _marker: std::marker::PhantomData,
1627        }
1628    }
1629
1630    /// Set the maximum number of NIPALS iterations per component.
1631    #[must_use]
1632    pub fn with_max_iter(mut self, max_iter: usize) -> Self {
1633        self.max_iter = max_iter;
1634        self
1635    }
1636
1637    /// Set the NIPALS convergence tolerance.
1638    #[must_use]
1639    pub fn with_tol(mut self, tol: F) -> Self {
1640        self.tol = tol;
1641        self
1642    }
1643
1644    /// Set whether to scale X and Y to unit variance (default: `true`).
1645    #[must_use]
1646    pub fn with_scale(mut self, scale: bool) -> Self {
1647        self.scale = scale;
1648        self
1649    }
1650
1651    /// Return the number of components.
1652    #[must_use]
1653    pub fn n_components(&self) -> usize {
1654        self.n_components
1655    }
1656}
1657
/// A fitted [`CCA`] model.
///
/// Holds the learned weight, loading, and score matrices. Implements
/// [`Transform`] to project X onto the canonical score space.
///
/// The training scores (`x_scores_`, `y_scores_`) are rescaled per column
/// to unit sample variance (`n - 1` denominator). A column whose standard
/// deviation is at most `F::epsilon()` is left unscaled.
#[derive(Debug, Clone)]
pub struct FittedCCA<F> {
    /// X-weights W, shape `(n_features_x, n_components)`.
    x_weights_: Array2<F>,
    /// X-loadings P, shape `(n_features_x, n_components)`.
    x_loadings_: Array2<F>,
    /// Y-loadings Q, shape `(n_features_y, n_components)`.
    /// Each column is normalised to unit Euclidean norm (when non-negligible).
    y_loadings_: Array2<F>,
    /// X-scores T from training, shape `(n_samples, n_components)`.
    x_scores_: Array2<F>,
    /// Y-scores U from training, shape `(n_samples, n_components)`.
    y_scores_: Array2<F>,
    /// Number of iterations per component.
    n_iter_: Vec<usize>,
    /// Per-feature mean of X.
    x_mean_: Array1<F>,
    /// Per-feature mean of Y.
    y_mean_: Array1<F>,
    /// Per-feature std of X (None if not scaled).
    x_std_: Option<Array1<F>>,
    /// Per-feature std of Y (None if not scaled).
    y_std_: Option<Array1<F>>,
}
1685
1686impl<F: Float + Send + Sync + 'static> FittedCCA<F> {
1687    /// X-weights matrix W, shape `(n_features_x, n_components)`.
1688    #[must_use]
1689    pub fn x_weights(&self) -> &Array2<F> {
1690        &self.x_weights_
1691    }
1692
1693    /// X-loadings matrix P, shape `(n_features_x, n_components)`.
1694    #[must_use]
1695    pub fn x_loadings(&self) -> &Array2<F> {
1696        &self.x_loadings_
1697    }
1698
1699    /// Y-loadings matrix Q, shape `(n_features_y, n_components)`.
1700    #[must_use]
1701    pub fn y_loadings(&self) -> &Array2<F> {
1702        &self.y_loadings_
1703    }
1704
1705    /// X-scores T from training, shape `(n_samples, n_components)`.
1706    #[must_use]
1707    pub fn x_scores(&self) -> &Array2<F> {
1708        &self.x_scores_
1709    }
1710
1711    /// Y-scores U from training, shape `(n_samples, n_components)`.
1712    #[must_use]
1713    pub fn y_scores(&self) -> &Array2<F> {
1714        &self.y_scores_
1715    }
1716
1717    /// Number of NIPALS iterations for each component.
1718    #[must_use]
1719    pub fn n_iter(&self) -> &[usize] {
1720        &self.n_iter_
1721    }
1722
1723    /// Transform Y data onto the Y-score space.
1724    ///
1725    /// # Errors
1726    ///
1727    /// Returns [`FerroError::ShapeMismatch`] if Y has the wrong number of columns.
1728    pub fn transform_y(&self, y: &Array2<F>) -> Result<Array2<F>, FerroError> {
1729        let yc = apply_centre_scale(y, &self.y_mean_, &self.y_std_, "FittedCCA::transform_y")?;
1730        Ok(yc.dot(&self.y_loadings_))
1731    }
1732}
1733
1734impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array2<F>> for CCA<F> {
1735    type Fitted = FittedCCA<F>;
1736    type Error = FerroError;
1737
1738    /// Fit CCA using the NIPALS algorithm with score normalisation.
1739    ///
1740    /// # Errors
1741    ///
1742    /// - [`FerroError::InvalidParameter`] if `n_components` is zero or too large.
1743    /// - [`FerroError::InsufficientSamples`] if fewer than 2 samples.
1744    /// - [`FerroError::ShapeMismatch`] if X and Y have different row counts.
1745    /// - [`FerroError::ConvergenceFailure`] if NIPALS does not converge.
1746    fn fit(&self, x: &Array2<F>, y: &Array2<F>) -> Result<FittedCCA<F>, FerroError> {
1747        let (n_samples_x, n_features_x) = x.dim();
1748        let (n_samples_y, n_features_y) = y.dim();
1749
1750        if n_samples_x != n_samples_y {
1751            return Err(FerroError::ShapeMismatch {
1752                expected: vec![n_samples_x, n_features_y],
1753                actual: vec![n_samples_y, n_features_y],
1754                context: "CCA::fit: X and Y must have the same number of rows".into(),
1755            });
1756        }
1757
1758        if self.n_components == 0 {
1759            return Err(FerroError::InvalidParameter {
1760                name: "n_components".into(),
1761                reason: "must be at least 1".into(),
1762            });
1763        }
1764
1765        let max_components = n_features_x.min(n_features_y).min(n_samples_x);
1766        if self.n_components > max_components {
1767            return Err(FerroError::InvalidParameter {
1768                name: "n_components".into(),
1769                reason: format!(
1770                    "n_components ({}) exceeds min(n_features_x, n_features_y, n_samples) ({})",
1771                    self.n_components, max_components
1772                ),
1773            });
1774        }
1775
1776        if n_samples_x < 2 {
1777            return Err(FerroError::InsufficientSamples {
1778                required: 2,
1779                actual: n_samples_x,
1780                context: "CCA::fit requires at least 2 samples".into(),
1781            });
1782        }
1783
1784        let (xc, x_mean, x_std) = centre_scale(x, self.scale);
1785        let (yc, y_mean, y_std) = centre_scale(y, self.scale);
1786
1787        let result = nipals(
1788            &xc,
1789            &yc,
1790            self.n_components,
1791            self.max_iter,
1792            self.tol,
1793            NipalsMode::Canonical,
1794            ScoreNorm::UnitVariance,
1795        )?;
1796
1797        Ok(FittedCCA {
1798            x_weights_: result.x_weights,
1799            x_loadings_: result.x_loadings,
1800            y_loadings_: result.y_loadings,
1801            x_scores_: result.x_scores,
1802            y_scores_: result.y_scores,
1803            n_iter_: result.n_iter,
1804            x_mean_: x_mean,
1805            y_mean_: y_mean,
1806            x_std_: x_std,
1807            y_std_: y_std,
1808        })
1809    }
1810}
1811
1812impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedCCA<F> {
1813    type Output = Array2<F>;
1814    type Error = FerroError;
1815
1816    /// Project X onto the canonical score space.
1817    ///
1818    /// # Errors
1819    ///
1820    /// Returns [`FerroError::ShapeMismatch`] if X has the wrong number of columns.
1821    fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
1822        let xc = apply_centre_scale(x, &self.x_mean_, &self.x_std_, "FittedCCA::transform")?;
1823
1824        let ptw = self.x_loadings_.t().dot(&self.x_weights_);
1825        let ptw_inv = invert_square(&ptw)?;
1826        let rotation = self.x_weights_.dot(&ptw_inv);
1827        Ok(xc.dot(&rotation))
1828    }
1829}
1830
1831// ===========================================================================
1832// Tests
1833// ===========================================================================
1834
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    // -----------------------------------------------------------------------
    // PLSSVD tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_plssvd_basic_fit_transform() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0],];
        let y = array![[1.0], [2.0], [3.0], [4.0], [5.0]];
        let svd = PLSSVD::<f64>::new(1);
        let fitted = svd.fit(&x, &y).unwrap();
        let scores = fitted.transform(&x).unwrap();
        assert_eq!(scores.dim(), (5, 1));
    }

    #[test]
    fn test_plssvd_two_components() {
        let x = array![
            [1.0, 2.0, 3.0],
            [4.0, 5.0, 6.0],
            [7.0, 8.0, 9.0],
            [10.0, 11.0, 12.0],
        ];
        let y = array![[1.0, 0.5], [2.0, 1.0], [3.0, 1.5], [4.0, 2.0]];
        let svd = PLSSVD::<f64>::new(2);
        let fitted = svd.fit(&x, &y).unwrap();
        let scores = fitted.transform(&x).unwrap();
        assert_eq!(scores.dim(), (4, 2));
    }

    #[test]
    fn test_plssvd_transform_y() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
        let y = array![[1.0, 0.5], [2.0, 1.0], [3.0, 1.5], [4.0, 2.0]];
        let svd = PLSSVD::<f64>::new(1);
        let fitted = svd.fit(&x, &y).unwrap();
        let y_scores = fitted.transform_y(&y).unwrap();
        assert_eq!(y_scores.ncols(), 1);
    }

    #[test]
    fn test_plssvd_no_scale() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0],];
        let y = array![[1.0], [2.0], [3.0], [4.0], [5.0]];
        let svd = PLSSVD::<f64>::new(1).with_scale(false);
        let fitted = svd.fit(&x, &y).unwrap();
        let scores = fitted.transform(&x).unwrap();
        assert_eq!(scores.ncols(), 1);
    }

    #[test]
    fn test_plssvd_x_weights_shape() {
        let x = array![
            [1.0, 2.0, 3.0],
            [4.0, 5.0, 6.0],
            [7.0, 8.0, 9.0],
            [10.0, 11.0, 12.0],
        ];
        let y = array![[1.0, 0.5], [2.0, 1.0], [3.0, 1.5], [4.0, 2.0]];
        let svd = PLSSVD::<f64>::new(2);
        let fitted = svd.fit(&x, &y).unwrap();
        assert_eq!(fitted.x_weights().dim(), (3, 2));
        assert_eq!(fitted.y_weights().dim(), (2, 2));
    }

    #[test]
    fn test_plssvd_invalid_zero_components() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let y = array![[1.0], [2.0]];
        let svd = PLSSVD::<f64>::new(0);
        assert!(svd.fit(&x, &y).is_err());
    }

    #[test]
    fn test_plssvd_too_many_components() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let y = array![[1.0], [2.0], [3.0]];
        // min(2, 1) = 1, asking for 2 is too many.
        let svd = PLSSVD::<f64>::new(2);
        assert!(svd.fit(&x, &y).is_err());
    }

    #[test]
    fn test_plssvd_row_mismatch() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let y = array![[1.0], [2.0]];
        let svd = PLSSVD::<f64>::new(1);
        assert!(svd.fit(&x, &y).is_err());
    }

    #[test]
    fn test_plssvd_insufficient_samples() {
        let x = array![[1.0, 2.0]];
        let y = array![[1.0]];
        let svd = PLSSVD::<f64>::new(1);
        assert!(svd.fit(&x, &y).is_err());
    }

    #[test]
    fn test_plssvd_transform_shape_mismatch() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let y = array![[1.0], [2.0], [3.0]];
        let svd = PLSSVD::<f64>::new(1);
        let fitted = svd.fit(&x, &y).unwrap();
        let x_bad = array![[1.0, 2.0, 3.0]];
        assert!(fitted.transform(&x_bad).is_err());
    }

    #[test]
    fn test_plssvd_n_components_getter() {
        let svd = PLSSVD::<f64>::new(3);
        assert_eq!(svd.n_components(), 3);
    }

    #[test]
    fn test_plssvd_f32() {
        let x: Array2<f32> = array![
            [1.0f32, 2.0],
            [3.0, 4.0],
            [5.0, 6.0],
            [7.0, 8.0],
            [9.0, 10.0],
        ];
        let y: Array2<f32> = array![[1.0f32], [2.0], [3.0], [4.0], [5.0]];
        let svd = PLSSVD::<f32>::new(1);
        let fitted = svd.fit(&x, &y).unwrap();
        let scores = fitted.transform(&x).unwrap();
        assert_eq!(scores.ncols(), 1);
    }

    // -----------------------------------------------------------------------
    // PLSRegression tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_plsregression_basic_fit_predict() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0],];
        let y = array![[1.0], [2.0], [3.0], [4.0], [5.0]];
        let pls = PLSRegression::<f64>::new(1);
        let fitted = pls.fit(&x, &y).unwrap();
        let y_pred = fitted.predict(&x).unwrap();
        assert_eq!(y_pred.dim(), (5, 1));
    }

    #[test]
    fn test_plsregression_prediction_quality() {
        // Perfect linear relationship: Y = X[:,0] + X[:,1]
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0],];
        let y = array![[3.0], [7.0], [11.0], [15.0], [19.0]];
        let pls = PLSRegression::<f64>::new(1);
        let fitted = pls.fit(&x, &y).unwrap();
        let y_pred = fitted.predict(&x).unwrap();

        // With a perfect linear relationship and 1 component, prediction
        // should be very close.
        for (pred, actual) in y_pred.column(0).iter().zip(y.column(0).iter()) {
            assert_abs_diff_eq!(pred, actual, epsilon = 1e-6);
        }
    }

    #[test]
    fn test_plsregression_multi_target() {
        let x = array![
            [1.0, 2.0, 3.0],
            [4.0, 5.0, 6.0],
            [7.0, 8.0, 9.0],
            [10.0, 11.0, 12.0],
            [13.0, 14.0, 15.0],
        ];
        let y = array![[1.0, 0.5], [2.0, 1.0], [3.0, 1.5], [4.0, 2.0], [5.0, 2.5],];
        let pls = PLSRegression::<f64>::new(2);
        let fitted = pls.fit(&x, &y).unwrap();
        let y_pred = fitted.predict(&x).unwrap();
        assert_eq!(y_pred.dim(), (5, 2));
    }

    #[test]
    fn test_plsregression_transform() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0],];
        let y = array![[1.0], [2.0], [3.0], [4.0], [5.0]];
        let pls = PLSRegression::<f64>::new(1);
        let fitted = pls.fit(&x, &y).unwrap();
        let scores = fitted.transform(&x).unwrap();
        assert_eq!(scores.dim(), (5, 1));
    }

    #[test]
    fn test_plsregression_coefficients_shape() {
        let x = array![
            [1.0, 2.0, 3.0],
            [4.0, 5.0, 6.0],
            [7.0, 8.0, 9.0],
            [10.0, 11.0, 12.0],
        ];
        let y = array![[1.0, 0.5], [2.0, 1.0], [3.0, 1.5], [4.0, 2.0]];
        let pls = PLSRegression::<f64>::new(2);
        let fitted = pls.fit(&x, &y).unwrap();
        // B shape: (n_features_x, n_features_y)
        assert_eq!(fitted.coefficients().dim(), (3, 2));
    }

    #[test]
    fn test_plsregression_no_scale() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0],];
        let y = array![[1.0], [2.0], [3.0], [4.0], [5.0]];
        let pls = PLSRegression::<f64>::new(1).with_scale(false);
        let fitted = pls.fit(&x, &y).unwrap();
        let y_pred = fitted.predict(&x).unwrap();
        assert_eq!(y_pred.dim(), (5, 1));
    }

    #[test]
    fn test_plsregression_builder() {
        let pls = PLSRegression::<f64>::new(2)
            .with_max_iter(1000)
            .with_tol(1e-8)
            .with_scale(false);
        assert_eq!(pls.n_components(), 2);
    }

    #[test]
    fn test_plsregression_invalid_zero_components() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let y = array![[1.0], [2.0]];
        let pls = PLSRegression::<f64>::new(0);
        assert!(pls.fit(&x, &y).is_err());
    }

    #[test]
    fn test_plsregression_too_many_components() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let y = array![[1.0], [2.0], [3.0]];
        // min(2, 1, 3) = 1, asking for 2 is too many.
        let pls = PLSRegression::<f64>::new(2);
        assert!(pls.fit(&x, &y).is_err());
    }

    #[test]
    fn test_plsregression_row_mismatch() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let y = array![[1.0], [2.0]];
        let pls = PLSRegression::<f64>::new(1);
        assert!(pls.fit(&x, &y).is_err());
    }

    #[test]
    fn test_plsregression_insufficient_samples() {
        let x = array![[1.0, 2.0]];
        let y = array![[1.0]];
        let pls = PLSRegression::<f64>::new(1);
        assert!(pls.fit(&x, &y).is_err());
    }

    #[test]
    fn test_plsregression_predict_shape_mismatch() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let y = array![[1.0], [2.0], [3.0]];
        let pls = PLSRegression::<f64>::new(1);
        let fitted = pls.fit(&x, &y).unwrap();
        let x_bad = array![[1.0, 2.0, 3.0]];
        assert!(fitted.predict(&x_bad).is_err());
    }

    #[test]
    fn test_plsregression_transform_shape_mismatch() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let y = array![[1.0], [2.0], [3.0]];
        let pls = PLSRegression::<f64>::new(1);
        let fitted = pls.fit(&x, &y).unwrap();
        let x_bad = array![[1.0, 2.0, 3.0]];
        assert!(fitted.transform(&x_bad).is_err());
    }

    #[test]
    fn test_plsregression_x_scores_shape() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
        let y = array![[1.0], [2.0], [3.0], [4.0]];
        let pls = PLSRegression::<f64>::new(1);
        let fitted = pls.fit(&x, &y).unwrap();
        assert_eq!(fitted.x_scores().dim(), (4, 1));
        assert_eq!(fitted.y_scores().dim(), (4, 1));
        assert_eq!(fitted.n_iter().len(), 1);
    }

    #[test]
    fn test_plsregression_f32() {
        let x: Array2<f32> = array![
            [1.0f32, 2.0],
            [3.0, 4.0],
            [5.0, 6.0],
            [7.0, 8.0],
            [9.0, 10.0],
        ];
        let y: Array2<f32> = array![[1.0f32], [2.0], [3.0], [4.0], [5.0]];
        let pls = PLSRegression::<f32>::new(1);
        let fitted = pls.fit(&x, &y).unwrap();
        let y_pred = fitted.predict(&x).unwrap();
        assert_eq!(y_pred.ncols(), 1);
    }

    // -----------------------------------------------------------------------
    // PLSCanonical tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_plscanonical_basic_fit_transform() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0],];
        let y = array![[1.0, 0.5], [2.0, 1.0], [3.0, 1.5], [4.0, 2.0], [5.0, 2.5],];
        let pls = PLSCanonical::<f64>::new(2);
        let fitted = pls.fit(&x, &y).unwrap();
        let scores = fitted.transform(&x).unwrap();
        assert_eq!(scores.dim(), (5, 2));
    }

    #[test]
    fn test_plscanonical_single_component() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0],];
        let y = array![[1.0], [2.0], [3.0], [4.0], [5.0]];
        let pls = PLSCanonical::<f64>::new(1);
        let fitted = pls.fit(&x, &y).unwrap();
        let scores = fitted.transform(&x).unwrap();
        assert_eq!(scores.ncols(), 1);
    }

    #[test]
    fn test_plscanonical_scores_shape() {
        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]];
        let y = array![[1.0, 0.5], [2.0, 1.0], [3.0, 1.5]];
        let pls = PLSCanonical::<f64>::new(2);
        let fitted = pls.fit(&x, &y).unwrap();
        assert_eq!(fitted.x_scores().dim(), (3, 2));
        assert_eq!(fitted.y_scores().dim(), (3, 2));
        assert_eq!(fitted.x_weights().dim(), (3, 2));
        assert_eq!(fitted.x_loadings().dim(), (3, 2));
        assert_eq!(fitted.y_loadings().dim(), (2, 2));
    }

    #[test]
    fn test_plscanonical_transform_y() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
        let y = array![[1.0, 0.5], [2.0, 1.0], [3.0, 1.5], [4.0, 2.0]];
        let pls = PLSCanonical::<f64>::new(1);
        let fitted = pls.fit(&x, &y).unwrap();
        let y_scores = fitted.transform_y(&y).unwrap();
        assert_eq!(y_scores.ncols(), 1);
    }

    #[test]
    fn test_plscanonical_builder() {
        let pls = PLSCanonical::<f64>::new(2)
            .with_max_iter(1000)
            .with_tol(1e-8)
            .with_scale(false);
        assert_eq!(pls.n_components(), 2);
    }

    #[test]
    fn test_plscanonical_invalid_zero_components() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let y = array![[1.0, 0.5], [2.0, 1.0]];
        let pls = PLSCanonical::<f64>::new(0);
        assert!(pls.fit(&x, &y).is_err());
    }

    #[test]
    fn test_plscanonical_too_many_components() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let y = array![[1.0], [2.0], [3.0]];
        let pls = PLSCanonical::<f64>::new(2);
        assert!(pls.fit(&x, &y).is_err());
    }

    #[test]
    fn test_plscanonical_row_mismatch() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let y = array![[1.0, 0.5], [2.0, 1.0], [3.0, 1.5]];
        let pls = PLSCanonical::<f64>::new(1);
        assert!(pls.fit(&x, &y).is_err());
    }

    #[test]
    fn test_plscanonical_transform_shape_mismatch() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let y = array![[1.0, 0.5], [2.0, 1.0], [3.0, 1.5]];
        let pls = PLSCanonical::<f64>::new(1);
        let fitted = pls.fit(&x, &y).unwrap();
        let x_bad = array![[1.0, 2.0, 3.0]];
        assert!(fitted.transform(&x_bad).is_err());
    }

    #[test]
    fn test_plscanonical_f32() {
        let x: Array2<f32> = array![
            [1.0f32, 2.0],
            [3.0, 4.0],
            [5.0, 6.0],
            [7.0, 8.0],
            [9.0, 10.0],
        ];
        let y: Array2<f32> = array![
            [1.0f32, 0.5],
            [2.0, 1.0],
            [3.0, 1.5],
            [4.0, 2.0],
            [5.0, 2.5],
        ];
        let pls = PLSCanonical::<f32>::new(1);
        let fitted = pls.fit(&x, &y).unwrap();
        let scores = fitted.transform(&x).unwrap();
        assert_eq!(scores.ncols(), 1);
    }

    // -----------------------------------------------------------------------
    // CCA tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_cca_basic_fit_transform() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0],];
        let y = array![[1.0, 0.5], [2.0, 1.0], [3.0, 1.5], [4.0, 2.0], [5.0, 2.5],];
        let cca = CCA::<f64>::new(2);
        let fitted = cca.fit(&x, &y).unwrap();
        let scores = fitted.transform(&x).unwrap();
        assert_eq!(scores.dim(), (5, 2));
    }

    #[test]
    fn test_cca_single_component() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0],];
        let y = array![[1.0], [2.0], [3.0], [4.0], [5.0]];
        let cca = CCA::<f64>::new(1);
        let fitted = cca.fit(&x, &y).unwrap();
        let scores = fitted.transform(&x).unwrap();
        assert_eq!(scores.ncols(), 1);
    }

    #[test]
    fn test_cca_scores_shape() {
        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]];
        let y = array![[1.0, 0.5], [2.0, 1.0], [3.0, 1.5]];
        let cca = CCA::<f64>::new(2);
        let fitted = cca.fit(&x, &y).unwrap();
        assert_eq!(fitted.x_scores().dim(), (3, 2));
        assert_eq!(fitted.y_scores().dim(), (3, 2));
        assert_eq!(fitted.x_weights().dim(), (3, 2));
        assert_eq!(fitted.x_loadings().dim(), (3, 2));
        assert_eq!(fitted.y_loadings().dim(), (2, 2));
    }

    #[test]
    fn test_cca_transform_y() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
        let y = array![[1.0, 0.5], [2.0, 1.0], [3.0, 1.5], [4.0, 2.0]];
        let cca = CCA::<f64>::new(1);
        let fitted = cca.fit(&x, &y).unwrap();
        let y_scores = fitted.transform_y(&y).unwrap();
        assert_eq!(y_scores.ncols(), 1);
    }

    #[test]
    fn test_cca_transform_y_shape_mismatch() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
        let y = array![[1.0, 0.5], [2.0, 1.0], [3.0, 1.5], [4.0, 2.0]];
        let fitted = CCA::<f64>::new(1).fit(&x, &y).unwrap();
        // Fitted with 2 Y columns; 3 columns must be rejected.
        let y_bad = array![[1.0, 2.0, 3.0]];
        assert!(fitted.transform_y(&y_bad).is_err());
    }

    #[test]
    fn test_cca_builder() {
        let cca = CCA::<f64>::new(2)
            .with_max_iter(1000)
            .with_tol(1e-8)
            .with_scale(false);
        assert_eq!(cca.n_components(), 2);
    }

    #[test]
    fn test_cca_invalid_zero_components() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let y = array![[1.0, 0.5], [2.0, 1.0]];
        let cca = CCA::<f64>::new(0);
        assert!(matches!(
            cca.fit(&x, &y),
            Err(FerroError::InvalidParameter { .. })
        ));
    }

    #[test]
    fn test_cca_too_many_components() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let y = array![[1.0], [2.0], [3.0]];
        // min(2, 1, 3) = 1, asking for 2 is too many.
        let cca = CCA::<f64>::new(2);
        assert!(matches!(
            cca.fit(&x, &y),
            Err(FerroError::InvalidParameter { .. })
        ));
    }

    #[test]
    fn test_cca_row_mismatch() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let y = array![[1.0, 0.5], [2.0, 1.0], [3.0, 1.5]];
        let cca = CCA::<f64>::new(1);
        assert!(matches!(
            cca.fit(&x, &y),
            Err(FerroError::ShapeMismatch { .. })
        ));
    }

    #[test]
    fn test_cca_insufficient_samples() {
        let x = array![[1.0, 2.0]];
        let y = array![[1.0, 0.5]];
        let cca = CCA::<f64>::new(1);
        assert!(matches!(
            cca.fit(&x, &y),
            Err(FerroError::InsufficientSamples { .. })
        ));
    }

    #[test]
    fn test_cca_transform_shape_mismatch() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let y = array![[1.0, 0.5], [2.0, 1.0], [3.0, 1.5]];
        let cca = CCA::<f64>::new(1);
        let fitted = cca.fit(&x, &y).unwrap();
        let x_bad = array![[1.0, 2.0, 3.0]];
        assert!(fitted.transform(&x_bad).is_err());
    }

    #[test]
    fn test_cca_f32() {
        let x: Array2<f32> = array![
            [1.0f32, 2.0],
            [3.0, 4.0],
            [5.0, 6.0],
            [7.0, 8.0],
            [9.0, 10.0],
        ];
        let y: Array2<f32> = array![
            [1.0f32, 0.5],
            [2.0, 1.0],
            [3.0, 1.5],
            [4.0, 2.0],
            [5.0, 2.5],
        ];
        let cca = CCA::<f32>::new(1);
        let fitted = cca.fit(&x, &y).unwrap();
        let scores = fitted.transform(&x).unwrap();
        assert_eq!(scores.ncols(), 1);
    }

    // -----------------------------------------------------------------------
    // Cross-cutting tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_pls_regression_and_canonical_scores_are_finite() {
        let x = array![
            [1.0, 2.0, 0.5],
            [3.0, 1.0, 2.5],
            [5.0, 6.0, 1.0],
            [7.0, 3.0, 4.5],
            [9.0, 10.0, 2.0],
        ];
        let y = array![[1.0, 0.5], [2.0, 1.0], [3.0, 1.5], [4.0, 2.0], [5.0, 2.5],];

        let pls_reg = PLSRegression::<f64>::new(2);
        let fitted_reg = pls_reg.fit(&x, &y).unwrap();
        let scores_reg = fitted_reg.transform(&x).unwrap();

        let pls_can = PLSCanonical::<f64>::new(2);
        let fitted_can = pls_can.fit(&x, &y).unwrap();
        let scores_can = fitted_can.transform(&x).unwrap();

        // The two estimators deflate differently, but their scores may still
        // coincide on some data. So only shape and finiteness are checked,
        // element by element.
        assert_eq!(scores_reg.dim(), (5, 2));
        assert_eq!(scores_can.dim(), (5, 2));
        assert!(scores_reg.iter().all(|v| v.is_finite()));
        assert!(scores_can.iter().all(|v| v.is_finite()));
    }

    #[test]
    fn test_centre_scale_helper() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let (xc, mean, std_dev) = centre_scale(&x, true);
        assert_abs_diff_eq!(mean[0], 3.0, epsilon = 1e-10);
        assert_abs_diff_eq!(mean[1], 4.0, epsilon = 1e-10);
        assert!(std_dev.is_some());

        // Centred data should have zero mean.
        let col_mean_0: f64 = xc.column(0).iter().sum::<f64>() / 3.0;
        assert_abs_diff_eq!(col_mean_0, 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_centre_scale_no_scale() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let (_xc, _mean, std_dev) = centre_scale(&x, false);
        assert!(std_dev.is_none());
    }

    #[test]
    fn test_invert_square_identity() {
        let eye = Array2::<f64>::from_shape_fn((3, 3), |(i, j)| if i == j { 1.0 } else { 0.0 });
        let inv = invert_square(&eye).unwrap();
        for i in 0..3 {
            for j in 0..3 {
                let expected = if i == j { 1.0 } else { 0.0 };
                assert_abs_diff_eq!(inv[[i, j]], expected, epsilon = 1e-10);
            }
        }
    }

    #[test]
    fn test_invert_square_2x2() {
        let a = array![[4.0, 7.0], [2.0, 6.0]];
        let inv = invert_square(&a).unwrap();
        // A * A^{-1} should be identity.
        let prod = a.dot(&inv);
        for i in 0..2 {
            for j in 0..2 {
                let expected = if i == j { 1.0 } else { 0.0 };
                assert_abs_diff_eq!(prod[[i, j]], expected, epsilon = 1e-10);
            }
        }
    }
}