Skip to main content

ferrolearn_decomp/
cross_decomposition.rs

1//! Cross-decomposition methods: PLS, CCA, and PLSSVD.
2//!
3//! This module provides Partial Least Squares (PLS) and Canonical Correlation
4//! Analysis (CCA) methods for modelling the relationship between two multivariate
5//! datasets X and Y.
6//!
7//! # Algorithms
8//!
9//! - [`PLSSVD`] — SVD of the cross-covariance matrix. The simplest PLS variant;
10//!   computes weight matrices from the leading singular vectors of `X^T Y`.
11//! - [`PLSRegression`] — PLS via the NIPALS algorithm. Maximises covariance
12//!   between X-scores and Y-scores, with asymmetric deflation suitable for
13//!   regression.
14//! - [`PLSCanonical`] — Canonical PLS via NIPALS. Symmetric deflation of both
15//!   X and Y.
16//! - [`CCA`] — Canonical Correlation Analysis via NIPALS. Maximises
17//!   *correlation* (not covariance) between X-scores and Y-scores by
18//!   normalising scores to unit variance.
19//!
20//! # Examples
21//!
22//! ```
23//! use ferrolearn_decomp::cross_decomposition::PLSRegression;
24//! use ferrolearn_core::traits::{Fit, Predict, Transform};
25//! use ndarray::array;
26//!
27//! let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
28//! let y = array![[1.0], [2.0], [3.0], [4.0]];
29//!
30//! let pls = PLSRegression::<f64>::new(1);
31//! let fitted = pls.fit(&x, &y).unwrap();
32//! let y_pred = fitted.predict(&x).unwrap();
33//! assert_eq!(y_pred.ncols(), 1);
34//! let x_scores = fitted.transform(&x).unwrap();
35//! assert_eq!(x_scores.ncols(), 1);
36//! ```
37
38use ferrolearn_core::backend::Backend;
39use ferrolearn_core::backend_faer::NdarrayFaerBackend;
40use ferrolearn_core::error::FerroError;
41use ferrolearn_core::traits::{Fit, Predict, Transform};
42use ndarray::{Array1, Array2};
43use num_traits::Float;
44use std::any::TypeId;
45
/// Result type for SVD: `(U, S, Vt)` — left singular vectors, singular
/// values (descending), and transposed right singular vectors.
type SvdResult<F> = Result<(Array2<F>, Array1<F>, Array2<F>), FerroError>;
48
49// ---------------------------------------------------------------------------
50// Helper: centre and optionally scale columns of a matrix
51// ---------------------------------------------------------------------------
52
53/// Centre (and optionally scale to unit variance) columns of a matrix.
54///
55/// Returns `(centred_matrix, mean, std)` where `std` is `None` when
56/// `scale` is `false`.
57fn centre_scale<F: Float + Send + Sync + 'static>(
58    x: &Array2<F>,
59    scale: bool,
60) -> (Array2<F>, Array1<F>, Option<Array1<F>>) {
61    let (n_samples, n_features) = x.dim();
62    let n_f = F::from(n_samples).unwrap();
63
64    // Compute column means.
65    let mean = Array1::from_shape_fn(n_features, |j| {
66        x.column(j).iter().copied().fold(F::zero(), |a, b| a + b) / n_f
67    });
68
69    // Centre.
70    let mut xc = x.to_owned();
71    for mut row in xc.rows_mut() {
72        for (v, &m) in row.iter_mut().zip(mean.iter()) {
73            *v = *v - m;
74        }
75    }
76
77    if scale {
78        let n_minus_1 = F::from(n_samples.saturating_sub(1).max(1)).unwrap();
79        let std_dev = Array1::from_shape_fn(n_features, |j| {
80            let var = xc
81                .column(j)
82                .iter()
83                .copied()
84                .fold(F::zero(), |a, b| a + b * b)
85                / n_minus_1;
86            let s = var.sqrt();
87            if s < F::epsilon() { F::one() } else { s }
88        });
89        for mut row in xc.rows_mut() {
90            for (v, &s) in row.iter_mut().zip(std_dev.iter()) {
91                *v = *v / s;
92            }
93        }
94        (xc, mean, Some(std_dev))
95    } else {
96        (xc, mean, None)
97    }
98}
99
100/// Apply centring (and optionally scaling) to new data using stored statistics.
101fn apply_centre_scale<F: Float + Send + Sync + 'static>(
102    x: &Array2<F>,
103    mean: &Array1<F>,
104    std_dev: &Option<Array1<F>>,
105    context: &str,
106) -> Result<Array2<F>, FerroError> {
107    if x.ncols() != mean.len() {
108        return Err(FerroError::ShapeMismatch {
109            expected: vec![x.nrows(), mean.len()],
110            actual: vec![x.nrows(), x.ncols()],
111            context: context.into(),
112        });
113    }
114    let mut xc = x.to_owned();
115    for mut row in xc.rows_mut() {
116        for (v, &m) in row.iter_mut().zip(mean.iter()) {
117            *v = *v - m;
118        }
119    }
120    if let Some(ref s) = *std_dev {
121        for mut row in xc.rows_mut() {
122            for (v, &sd) in row.iter_mut().zip(s.iter()) {
123                *v = *v / sd;
124            }
125        }
126    }
127    Ok(xc)
128}
129
130// ---------------------------------------------------------------------------
131// Helper: SVD dispatch (generic F via f64 fast-path or Jacobi fallback)
132// ---------------------------------------------------------------------------
133
/// Compute the thin SVD of a general (m x n) matrix.
///
/// Returns `(U, S, Vt)` where:
/// - `U` is `(m, min(m,n))`,
/// - `S` is `(min(m,n),)`,
/// - `Vt` is `(min(m,n), n)`.
///
/// For `f64` this delegates to `NdarrayFaerBackend::svd` (faer's optimised
/// routine). For `f32` the input is round-tripped through `f64` so faer can
/// still be used. Any other float type falls back to a Jacobi
/// eigendecomposition of `A^T A` (see [`svd_via_eigen`]).
fn svd_dispatch<F: Float + Send + Sync + 'static>(a: &Array2<F>) -> SvdResult<F> {
    if TypeId::of::<F>() == TypeId::of::<f64>() {
        // F is statically f64 here; reinterpret the reference so the
        // concrete f64 backend can be called without copying the data.
        // SAFETY: the TypeId check above guarantees F == f64, so
        // Array2<F> and Array2<f64> are the same type.
        let a_f64: &Array2<f64> = unsafe { &*(std::ptr::from_ref(a).cast::<Array2<f64>>()) };
        let (u, s, vt) = NdarrayFaerBackend::svd(a_f64)?;
        // Thin U and Vt: keep only the first k = len(S) columns/rows.
        let k = s.len();
        let u_thin = u.slice(ndarray::s![.., ..k]).to_owned();
        let vt_thin = vt.slice(ndarray::s![..k, ..]).to_owned();

        // Cast back to F (which is f64).
        // SAFETY: F == f64 (checked above). `transmute_copy` duplicates the
        // array headers without taking ownership, so the originals must be
        // `forget`-ed below to avoid double-freeing the shared buffers.
        let u_f: Array2<F> = unsafe { std::mem::transmute_copy::<Array2<f64>, Array2<F>>(&u_thin) };
        let s_f: Array1<F> = unsafe { std::mem::transmute_copy::<Array1<f64>, Array1<F>>(&s) };
        let vt_f: Array2<F> =
            unsafe { std::mem::transmute_copy::<Array2<f64>, Array2<F>>(&vt_thin) };
        // Leak the original headers: ownership has been transferred to the
        // transmuted copies above.
        std::mem::forget(u_thin);
        std::mem::forget(s);
        std::mem::forget(vt_thin);
        Ok((u_f, s_f, vt_f))
    } else if TypeId::of::<F>() == TypeId::of::<f32>() {
        // Convert f32 -> f64, compute with faer, convert back element-wise.
        let (m, n) = a.dim();
        let a_f64 =
            Array2::<f64>::from_shape_fn((m, n), |(i, j)| a[[i, j]].to_f64().unwrap_or(0.0));
        let (u64, s64, vt64) = NdarrayFaerBackend::svd(&a_f64)?;
        let k = s64.len();
        let u_thin = u64.slice(ndarray::s![.., ..k]).to_owned();
        let vt_thin = vt64.slice(ndarray::s![..k, ..]).to_owned();

        // F::from(f64) cannot fail for F == f32 (lossy rounding is fine here).
        let u_f =
            Array2::<F>::from_shape_fn(u_thin.dim(), |(i, j)| F::from(u_thin[[i, j]]).unwrap());
        let s_f = Array1::<F>::from_shape_fn(s64.len(), |i| F::from(s64[i]).unwrap());
        let vt_f =
            Array2::<F>::from_shape_fn(vt_thin.dim(), |(i, j)| F::from(vt_thin[[i, j]]).unwrap());
        Ok((u_f, s_f, vt_f))
    } else {
        // Fallback: compute via eigendecomposition of A^T A.
        svd_via_eigen(a)
    }
}
183
184/// Compute SVD via eigendecomposition of `A^T A` (fallback for exotic float types).
185fn svd_via_eigen<F: Float + Send + Sync + 'static>(a: &Array2<F>) -> SvdResult<F> {
186    let (m, n) = a.dim();
187    let k = m.min(n);
188
189    // Compute A^T A.
190    let ata = a.t().dot(a);
191
192    // Jacobi eigendecomposition of A^T A.
193    let max_iter = n * n * 100 + 1000;
194    let (eigenvalues, eigenvectors) = jacobi_eigen_symmetric(&ata, max_iter)?;
195
196    // Sort eigenvalues descending.
197    let mut indices: Vec<usize> = (0..n).collect();
198    indices.sort_by(|&i, &j| {
199        eigenvalues[j]
200            .partial_cmp(&eigenvalues[i])
201            .unwrap_or(std::cmp::Ordering::Equal)
202    });
203
204    // Take top k.
205    let mut s = Array1::<F>::zeros(k);
206    let mut v = Array2::<F>::zeros((n, k));
207    for (col, &idx) in indices.iter().take(k).enumerate() {
208        let eval = eigenvalues[idx];
209        s[col] = if eval > F::zero() {
210            eval.sqrt()
211        } else {
212            F::zero()
213        };
214        for row in 0..n {
215            v[[row, col]] = eigenvectors[[row, idx]];
216        }
217    }
218
219    // U = A V S^{-1}
220    let av = a.dot(&v);
221    let mut u = Array2::<F>::zeros((m, k));
222    for col in 0..k {
223        if s[col] > F::epsilon() {
224            let inv_s = F::one() / s[col];
225            for row in 0..m {
226                u[[row, col]] = av[[row, col]] * inv_s;
227            }
228        }
229    }
230
231    // Vt = V^T
232    let mut vt = Array2::<F>::zeros((k, n));
233    for i in 0..k {
234        for j in 0..n {
235            vt[[i, j]] = v[[j, i]];
236        }
237    }
238
239    Ok((u, s, vt))
240}
241
242/// Jacobi eigendecomposition for symmetric matrices (generic F fallback).
243fn jacobi_eigen_symmetric<F: Float + Send + Sync + 'static>(
244    a: &Array2<F>,
245    max_iter: usize,
246) -> Result<(Array1<F>, Array2<F>), FerroError> {
247    let n = a.nrows();
248    let mut mat = a.to_owned();
249    let mut v = Array2::<F>::zeros((n, n));
250    for i in 0..n {
251        v[[i, i]] = F::one();
252    }
253
254    let tol = F::from(1e-12).unwrap_or(F::epsilon());
255
256    for _iteration in 0..max_iter {
257        let mut max_off = F::zero();
258        let mut p = 0;
259        let mut q = 1;
260        for i in 0..n {
261            for j in (i + 1)..n {
262                let val = mat[[i, j]].abs();
263                if val > max_off {
264                    max_off = val;
265                    p = i;
266                    q = j;
267                }
268            }
269        }
270
271        if max_off < tol {
272            let eigenvalues = Array1::from_shape_fn(n, |i| mat[[i, i]]);
273            return Ok((eigenvalues, v));
274        }
275
276        let app = mat[[p, p]];
277        let aqq = mat[[q, q]];
278        let apq = mat[[p, q]];
279
280        let theta = if (app - aqq).abs() < tol {
281            F::from(std::f64::consts::FRAC_PI_4).unwrap_or(F::one())
282        } else {
283            let tau = (aqq - app) / (F::from(2.0).unwrap() * apq);
284            let t = if tau >= F::zero() {
285                F::one() / (tau.abs() + (F::one() + tau * tau).sqrt())
286            } else {
287                -F::one() / (tau.abs() + (F::one() + tau * tau).sqrt())
288            };
289            t.atan()
290        };
291
292        let c = theta.cos();
293        let s = theta.sin();
294
295        let mut new_mat = mat.clone();
296        for i in 0..n {
297            if i != p && i != q {
298                let mip = mat[[i, p]];
299                let miq = mat[[i, q]];
300                new_mat[[i, p]] = c * mip - s * miq;
301                new_mat[[p, i]] = new_mat[[i, p]];
302                new_mat[[i, q]] = s * mip + c * miq;
303                new_mat[[q, i]] = new_mat[[i, q]];
304            }
305        }
306
307        new_mat[[p, p]] = c * c * app - F::from(2.0).unwrap() * s * c * apq + s * s * aqq;
308        new_mat[[q, q]] = s * s * app + F::from(2.0).unwrap() * s * c * apq + c * c * aqq;
309        new_mat[[p, q]] = F::zero();
310        new_mat[[q, p]] = F::zero();
311
312        mat = new_mat;
313
314        for i in 0..n {
315            let vip = v[[i, p]];
316            let viq = v[[i, q]];
317            v[[i, p]] = c * vip - s * viq;
318            v[[i, q]] = s * vip + c * viq;
319        }
320    }
321
322    Err(FerroError::ConvergenceFailure {
323        iterations: max_iter,
324        message: "Jacobi eigendecomposition did not converge in cross_decomposition SVD fallback"
325            .into(),
326    })
327}
328
329// ---------------------------------------------------------------------------
330// Helper: vector norm and dot
331// ---------------------------------------------------------------------------
332
333/// L2 norm of a 1-D array.
334fn norm<F: Float>(v: &Array1<F>) -> F {
335    v.iter().copied().fold(F::zero(), |a, b| a + b * b).sqrt()
336}
337
338/// Dot product of two 1-D arrays.
339fn dot<F: Float>(a: &Array1<F>, b: &Array1<F>) -> F {
340    a.iter()
341        .copied()
342        .zip(b.iter().copied())
343        .fold(F::zero(), |acc, (x, y)| acc + x * y)
344}
345
346// ---------------------------------------------------------------------------
347// Helper: solve (P^T W) inverse for PLSRegression predict
348// ---------------------------------------------------------------------------
349
350/// Solve `(P^T W)^{-1}` for a square matrix using Gaussian elimination
351/// with partial pivoting (generic float).
352///
353/// When a pivot is too small (near-singular), it is regularised with a
354/// small perturbation to avoid hard failures. This matches the behaviour
355/// of scikit-learn, which uses `pinv` for the rotation matrix.
356fn invert_square<F: Float + Send + Sync + 'static>(a: &Array2<F>) -> Result<Array2<F>, FerroError> {
357    let n = a.nrows();
358    if n != a.ncols() {
359        return Err(FerroError::ShapeMismatch {
360            expected: vec![n, n],
361            actual: vec![a.nrows(), a.ncols()],
362            context: "invert_square: matrix must be square".into(),
363        });
364    }
365
366    // Augmented matrix [A | I].
367    let mut aug = Array2::<F>::zeros((n, 2 * n));
368    for i in 0..n {
369        for j in 0..n {
370            aug[[i, j]] = a[[i, j]];
371        }
372        aug[[i, n + i]] = F::one();
373    }
374
375    // Compute a tolerance based on the matrix norm.
376    let max_abs = a.iter().copied().fold(F::zero(), |m, v| {
377        let abs = v.abs();
378        if abs > m { abs } else { m }
379    });
380    let regularise_tol =
381        max_abs * F::from(1e-12).unwrap_or(F::epsilon()) + F::from(1e-15).unwrap_or(F::epsilon());
382
383    // Forward elimination with partial pivoting.
384    for col in 0..n {
385        // Find pivot.
386        let mut max_val = aug[[col, col]].abs();
387        let mut max_row = col;
388        for row in (col + 1)..n {
389            let val = aug[[row, col]].abs();
390            if val > max_val {
391                max_val = val;
392                max_row = row;
393            }
394        }
395
396        // Regularise if pivot is too small.
397        if max_val < regularise_tol {
398            aug[[col, col]] = regularise_tol;
399        } else {
400            // Swap rows.
401            if max_row != col {
402                for j in 0..(2 * n) {
403                    let tmp = aug[[col, j]];
404                    aug[[col, j]] = aug[[max_row, j]];
405                    aug[[max_row, j]] = tmp;
406                }
407            }
408        }
409
410        // Eliminate below.
411        let pivot = aug[[col, col]];
412        for row in (col + 1)..n {
413            let factor = aug[[row, col]] / pivot;
414            for j in col..(2 * n) {
415                let above = aug[[col, j]];
416                aug[[row, j]] = aug[[row, j]] - factor * above;
417            }
418        }
419    }
420
421    // Back substitution.
422    for col in (0..n).rev() {
423        let pivot = aug[[col, col]];
424        for j in 0..(2 * n) {
425            aug[[col, j]] = aug[[col, j]] / pivot;
426        }
427        for row in 0..col {
428            let factor = aug[[row, col]];
429            for j in 0..(2 * n) {
430                let below = aug[[col, j]];
431                aug[[row, j]] = aug[[row, j]] - factor * below;
432            }
433        }
434    }
435
436    // Extract inverse.
437    let mut inv = Array2::<F>::zeros((n, n));
438    for i in 0..n {
439        for j in 0..n {
440            inv[[i, j]] = aug[[i, n + j]];
441        }
442    }
443    Ok(inv)
444}
445
446// ===========================================================================
447// PLSSVD
448// ===========================================================================
449
450/// PLS via Singular Value Decomposition of the cross-covariance matrix.
451///
452/// This is the simplest PLS variant. It computes the weight matrices by
453/// taking the leading left and right singular vectors of `X^T Y` after
454/// optional centring and scaling.
455///
456/// Unlike [`PLSRegression`], PLSSVD does not iterate; it is a single
457/// matrix decomposition. It cannot predict Y from X — use
458/// [`PLSRegression`] if you need a `predict` method.
459///
460/// # Type Parameters
461///
462/// - `F`: The floating-point scalar type.
463///
464/// # Examples
465///
466/// ```
467/// use ferrolearn_decomp::cross_decomposition::PLSSVD;
468/// use ferrolearn_core::traits::{Fit, Transform};
469/// use ndarray::array;
470///
471/// let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
472/// let y = array![[1.0], [2.0], [3.0], [4.0]];
473/// let svd = PLSSVD::<f64>::new(1);
474/// let fitted = svd.fit(&x, &y).unwrap();
475/// let scores = fitted.transform(&x).unwrap();
476/// assert_eq!(scores.ncols(), 1);
477/// ```
#[derive(Debug, Clone)]
pub struct PLSSVD<F> {
    /// Number of components to extract.
    n_components: usize,
    /// Whether to scale X and Y to unit variance before decomposition.
    scale: bool,
    /// Ties the otherwise-unused scalar type `F` to the struct.
    _marker: std::marker::PhantomData<F>,
}
486
487impl<F: Float + Send + Sync + 'static> PLSSVD<F> {
488    /// Create a new `PLSSVD` that extracts `n_components` components.
489    #[must_use]
490    pub fn new(n_components: usize) -> Self {
491        Self {
492            n_components,
493            scale: true,
494            _marker: std::marker::PhantomData,
495        }
496    }
497
498    /// Set whether to scale X and Y to unit variance (default: `true`).
499    #[must_use]
500    pub fn with_scale(mut self, scale: bool) -> Self {
501        self.scale = scale;
502        self
503    }
504
505    /// Return the number of components.
506    #[must_use]
507    pub fn n_components(&self) -> usize {
508        self.n_components
509    }
510}
511
512/// A fitted [`PLSSVD`] model.
513///
514/// Holds the learned weight matrices and centring/scaling statistics.
515/// Implements [`Transform`] to project X data onto the PLS score space.
#[derive(Debug, Clone)]
pub struct FittedPLSSVD<F> {
    /// X-weights, shape `(n_features_x, n_components)`.
    x_weights_: Array2<F>,
    /// Y-weights, shape `(n_features_y, n_components)`.
    y_weights_: Array2<F>,
    /// Per-feature mean of X, learned during fitting.
    x_mean_: Array1<F>,
    /// Per-feature mean of Y, learned during fitting.
    y_mean_: Array1<F>,
    /// Per-feature standard deviation of X (`None` if not scaled).
    x_std_: Option<Array1<F>>,
    /// Per-feature standard deviation of Y (`None` if not scaled).
    y_std_: Option<Array1<F>>,
}
531
532impl<F: Float + Send + Sync + 'static> FittedPLSSVD<F> {
533    /// X-weights matrix, shape `(n_features_x, n_components)`.
534    #[must_use]
535    pub fn x_weights(&self) -> &Array2<F> {
536        &self.x_weights_
537    }
538
539    /// Y-weights matrix, shape `(n_features_y, n_components)`.
540    #[must_use]
541    pub fn y_weights(&self) -> &Array2<F> {
542        &self.y_weights_
543    }
544
545    /// Per-feature mean of X learned during fitting.
546    #[must_use]
547    pub fn x_mean(&self) -> &Array1<F> {
548        &self.x_mean_
549    }
550
551    /// Per-feature mean of Y learned during fitting.
552    #[must_use]
553    pub fn y_mean(&self) -> &Array1<F> {
554        &self.y_mean_
555    }
556
557    /// Transform Y data onto the Y-score space.
558    ///
559    /// # Errors
560    ///
561    /// Returns [`FerroError::ShapeMismatch`] if Y has the wrong number of columns.
562    pub fn transform_y(&self, y: &Array2<F>) -> Result<Array2<F>, FerroError> {
563        let yc = apply_centre_scale(y, &self.y_mean_, &self.y_std_, "FittedPLSSVD::transform_y")?;
564        Ok(yc.dot(&self.y_weights_))
565    }
566}
567
568impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array2<F>> for PLSSVD<F> {
569    type Fitted = FittedPLSSVD<F>;
570    type Error = FerroError;
571
572    /// Fit PLSSVD by computing the SVD of the cross-covariance matrix `X^T Y`.
573    ///
574    /// # Errors
575    ///
576    /// - [`FerroError::InvalidParameter`] if `n_components` is zero or exceeds
577    ///   `min(n_features_x, n_features_y)`.
578    /// - [`FerroError::InsufficientSamples`] if there are fewer than 2 samples.
579    /// - [`FerroError::ShapeMismatch`] if X and Y have different numbers of rows.
580    fn fit(&self, x: &Array2<F>, y: &Array2<F>) -> Result<FittedPLSSVD<F>, FerroError> {
581        let (n_samples_x, n_features_x) = x.dim();
582        let (n_samples_y, n_features_y) = y.dim();
583
584        if n_samples_x != n_samples_y {
585            return Err(FerroError::ShapeMismatch {
586                expected: vec![n_samples_x, n_features_y],
587                actual: vec![n_samples_y, n_features_y],
588                context: "PLSSVD::fit: X and Y must have the same number of rows".into(),
589            });
590        }
591
592        if self.n_components == 0 {
593            return Err(FerroError::InvalidParameter {
594                name: "n_components".into(),
595                reason: "must be at least 1".into(),
596            });
597        }
598
599        let max_components = n_features_x.min(n_features_y);
600        if self.n_components > max_components {
601            return Err(FerroError::InvalidParameter {
602                name: "n_components".into(),
603                reason: format!(
604                    "n_components ({}) exceeds min(n_features_x, n_features_y) ({})",
605                    self.n_components, max_components
606                ),
607            });
608        }
609
610        if n_samples_x < 2 {
611            return Err(FerroError::InsufficientSamples {
612                required: 2,
613                actual: n_samples_x,
614                context: "PLSSVD::fit requires at least 2 samples".into(),
615            });
616        }
617
618        // Centre and optionally scale.
619        let (xc, x_mean, x_std) = centre_scale(x, self.scale);
620        let (yc, y_mean, y_std) = centre_scale(y, self.scale);
621
622        // Cross-covariance: C = X^T Y.
623        let c = xc.t().dot(&yc);
624
625        // SVD of C.
626        let (u, _s, vt) = svd_dispatch(&c)?;
627
628        // Take first n_components columns of U, rows of Vt (= columns of V).
629        let nc = self.n_components;
630        let x_weights = u.slice(ndarray::s![.., ..nc]).to_owned();
631        // V = Vt^T, so columns of V = rows of Vt transposed.
632        let y_weights = vt.t().slice(ndarray::s![.., ..nc]).to_owned();
633
634        Ok(FittedPLSSVD {
635            x_weights_: x_weights,
636            y_weights_: y_weights,
637            x_mean_: x_mean,
638            y_mean_: y_mean,
639            x_std_: x_std,
640            y_std_: y_std,
641        })
642    }
643}
644
645impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedPLSSVD<F> {
646    type Output = Array2<F>;
647    type Error = FerroError;
648
649    /// Project X data onto the PLS score space: `(X - x_mean) / x_std @ x_weights`.
650    ///
651    /// # Errors
652    ///
653    /// Returns [`FerroError::ShapeMismatch`] if X has the wrong number of columns.
654    fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
655        let xc = apply_centre_scale(x, &self.x_mean_, &self.x_std_, "FittedPLSSVD::transform")?;
656        Ok(xc.dot(&self.x_weights_))
657    }
658}
659
660// ===========================================================================
661// NIPALS mode enum (shared by PLSRegression, PLSCanonical, CCA)
662// ===========================================================================
663
/// Internal NIPALS deflation mode.
///
/// Controls how Y is deflated after each extracted component (X is
/// deflated with its X-scores and loadings in both modes).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum NipalsMode {
    /// Regression: deflate Y with X-scores (`Y = Y - t q^T`).
    Regression,
    /// Canonical: deflate Y with its own scores (`Y = Y - u c^T`).
    Canonical,
}
672
/// Internal flag: whether to normalise scores to unit variance (CCA).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ScoreNorm {
    /// Do not normalise scores (PLS).
    None,
    /// Normalise scores to unit variance (CCA, which maximises
    /// correlation rather than covariance).
    UnitVariance,
}
681
682// ---------------------------------------------------------------------------
683// NIPALS core algorithm
684// ---------------------------------------------------------------------------
685
/// Result from a NIPALS fit.
///
/// Each matrix stores one extracted component per column, in extraction order.
#[derive(Debug, Clone)]
struct NipalsResult<F> {
    /// X-weights W, shape `(n_features_x, n_components)` — columns are weight vectors.
    x_weights: Array2<F>,
    /// X-loadings P, shape `(n_features_x, n_components)`.
    x_loadings: Array2<F>,
    /// X-scores T, shape `(n_samples, n_components)`.
    x_scores: Array2<F>,
    /// Y-loadings Q, shape `(n_features_y, n_components)`.
    y_loadings: Array2<F>,
    /// Y-scores U, shape `(n_samples, n_components)`.
    y_scores: Array2<F>,
    /// Number of NIPALS iterations used for each component.
    n_iter: Vec<usize>,
}
702
703/// Run the NIPALS algorithm.
704fn nipals<F: Float + Send + Sync + 'static>(
705    x: &Array2<F>,
706    y: &Array2<F>,
707    n_components: usize,
708    max_iter: usize,
709    tol: F,
710    mode: NipalsMode,
711    score_norm: ScoreNorm,
712) -> Result<NipalsResult<F>, FerroError> {
713    let (n_samples, n_features_x) = x.dim();
714    let n_features_y = y.ncols();
715
716    let mut xk = x.to_owned();
717    let mut yk = y.to_owned();
718
719    let mut x_weights = Array2::<F>::zeros((n_features_x, n_components));
720    let mut x_loadings = Array2::<F>::zeros((n_features_x, n_components));
721    let mut x_scores = Array2::<F>::zeros((n_samples, n_components));
722    let mut y_loadings = Array2::<F>::zeros((n_features_y, n_components));
723    let mut y_scores = Array2::<F>::zeros((n_samples, n_components));
724    let mut n_iter_vec = Vec::with_capacity(n_components);
725
726    for k in 0..n_components {
727        // Initialise u = column of Y with max variance.
728        let best_col = (0..n_features_y)
729            .max_by(|&a, &b| {
730                let var_a: F = yk
731                    .column(a)
732                    .iter()
733                    .copied()
734                    .fold(F::zero(), |s, v| s + v * v);
735                let var_b: F = yk
736                    .column(b)
737                    .iter()
738                    .copied()
739                    .fold(F::zero(), |s, v| s + v * v);
740                var_a
741                    .partial_cmp(&var_b)
742                    .unwrap_or(std::cmp::Ordering::Equal)
743            })
744            .unwrap_or(0);
745
746        let mut u = yk.column(best_col).to_owned();
747
748        let mut converged = false;
749        let mut iters = 0;
750
751        for iteration in 0..max_iter {
752            iters = iteration + 1;
753
754            // w = X^T u / (u^T u)
755            let utu = dot(&u, &u);
756            let mut w = xk.t().dot(&u);
757            if utu > F::epsilon() {
758                w.mapv_inplace(|v| v / utu);
759            }
760            // Normalise w.
761            let w_norm = norm(&w);
762            if w_norm < F::epsilon() {
763                // Degenerate: zero weight vector.
764                break;
765            }
766            w.mapv_inplace(|v| v / w_norm);
767
768            // t = X w
769            let t = xk.dot(&w);
770
771            // q = Y^T t / (t^T t)
772            let ttt = dot(&t, &t);
773            let mut q = yk.t().dot(&t);
774            if ttt > F::epsilon() {
775                q.mapv_inplace(|v| v / ttt);
776            }
777
778            // For CCA: normalise q.
779            if score_norm == ScoreNorm::UnitVariance {
780                let q_norm = norm(&q);
781                if q_norm > F::epsilon() {
782                    q.mapv_inplace(|v| v / q_norm);
783                }
784            }
785
786            // u_new = Y q / (q^T q)
787            let qtq = dot(&q, &q);
788            let mut u_new = yk.dot(&q);
789            if qtq > F::epsilon() {
790                u_new.mapv_inplace(|v| v / qtq);
791            }
792
793            // For CCA: normalise t and u to unit variance.
794            // (This is done after the loop for storing; here we just check convergence.)
795
796            // Convergence check: ||u_new - u|| / ||u_new||.
797            let diff_norm = {
798                let diff: Array1<F> = &u_new - &u;
799                norm(&diff)
800            };
801            let u_new_norm = norm(&u_new);
802
803            u = u_new;
804
805            if u_new_norm > F::epsilon() && diff_norm / u_new_norm < tol {
806                converged = true;
807                // Recompute final w, t, q with converged u.
808                // w = X^T u / (u^T u), normalised
809                let utu2 = dot(&u, &u);
810                w = xk.t().dot(&u);
811                if utu2 > F::epsilon() {
812                    w.mapv_inplace(|v| v / utu2);
813                }
814                let w_norm2 = norm(&w);
815                if w_norm2 > F::epsilon() {
816                    w.mapv_inplace(|v| v / w_norm2);
817                }
818                // t = X w
819                let t_final = xk.dot(&w);
820                let ttt2 = dot(&t_final, &t_final);
821                q = yk.t().dot(&t_final);
822                if ttt2 > F::epsilon() {
823                    q.mapv_inplace(|v| v / ttt2);
824                }
825                if score_norm == ScoreNorm::UnitVariance {
826                    let q_norm2 = norm(&q);
827                    if q_norm2 > F::epsilon() {
828                        q.mapv_inplace(|v| v / q_norm2);
829                    }
830                }
831                let qtq2 = dot(&q, &q);
832                u = yk.dot(&q);
833                if qtq2 > F::epsilon() {
834                    u.mapv_inplace(|v| v / qtq2);
835                }
836                break;
837            }
838        }
839
840        // Compute final scores and loadings with the converged weights.
841        let utu_final = dot(&u, &u);
842        let mut w_final = xk.t().dot(&u);
843        if utu_final > F::epsilon() {
844            w_final.mapv_inplace(|v| v / utu_final);
845        }
846        let w_norm_final = norm(&w_final);
847        if w_norm_final > F::epsilon() {
848            w_final.mapv_inplace(|v| v / w_norm_final);
849        }
850
851        let mut t_final = xk.dot(&w_final);
852        let ttt_final = dot(&t_final, &t_final);
853
854        // p = X^T t / (t^T t)
855        let mut p = xk.t().dot(&t_final);
856        if ttt_final > F::epsilon() {
857            p.mapv_inplace(|v| v / ttt_final);
858        }
859
860        // q = Y^T t / (t^T t)
861        let mut q_final = yk.t().dot(&t_final);
862        if ttt_final > F::epsilon() {
863            q_final.mapv_inplace(|v| v / ttt_final);
864        }
865
866        if score_norm == ScoreNorm::UnitVariance {
867            let q_norm = norm(&q_final);
868            if q_norm > F::epsilon() {
869                q_final.mapv_inplace(|v| v / q_norm);
870            }
871        }
872
873        let qtq_final = dot(&q_final, &q_final);
874        let mut u_final = yk.dot(&q_final);
875        if qtq_final > F::epsilon() {
876            u_final.mapv_inplace(|v| v / qtq_final);
877        }
878
879        // For CCA: normalise t and u to unit variance.
880        if score_norm == ScoreNorm::UnitVariance {
881            let t_std = {
882                let t_mean = t_final.iter().copied().fold(F::zero(), |a, b| a + b)
883                    / F::from(n_samples).unwrap();
884                let var = t_final
885                    .iter()
886                    .copied()
887                    .fold(F::zero(), |a, b| a + (b - t_mean) * (b - t_mean))
888                    / F::from(n_samples.saturating_sub(1).max(1)).unwrap();
889                var.sqrt()
890            };
891            if t_std > F::epsilon() {
892                t_final.mapv_inplace(|v| v / t_std);
893            }
894
895            let u_std = {
896                let u_mean = u_final.iter().copied().fold(F::zero(), |a, b| a + b)
897                    / F::from(n_samples).unwrap();
898                let var = u_final
899                    .iter()
900                    .copied()
901                    .fold(F::zero(), |a, b| a + (b - u_mean) * (b - u_mean))
902                    / F::from(n_samples.saturating_sub(1).max(1)).unwrap();
903                var.sqrt()
904            };
905            if u_std > F::epsilon() {
906                u_final.mapv_inplace(|v| v / u_std);
907            }
908        }
909
910        // Store component k.
911        x_weights.column_mut(k).assign(&w_final);
912        x_loadings.column_mut(k).assign(&p);
913        x_scores.column_mut(k).assign(&t_final);
914        y_loadings.column_mut(k).assign(&q_final);
915        y_scores.column_mut(k).assign(&u_final);
916
917        // Deflate X: X = X - t p^T.
918        for i in 0..n_samples {
919            let ti = t_final[i];
920            for j in 0..n_features_x {
921                xk[[i, j]] = xk[[i, j]] - ti * p[j];
922            }
923        }
924
925        // Deflate Y.
926        match mode {
927            NipalsMode::Regression => {
928                // Y = Y - t q^T (deflate with X-scores).
929                for i in 0..n_samples {
930                    let ti = t_final[i];
931                    for j in 0..n_features_y {
932                        yk[[i, j]] = yk[[i, j]] - ti * q_final[j];
933                    }
934                }
935            }
936            NipalsMode::Canonical => {
937                // Y = Y - u c^T where c = Y^T u / (u^T u).
938                let utu_c = dot(&u_final, &u_final);
939                let mut c = yk.t().dot(&u_final);
940                if utu_c > F::epsilon() {
941                    c.mapv_inplace(|v| v / utu_c);
942                }
943                for i in 0..n_samples {
944                    let ui = u_final[i];
945                    for j in 0..n_features_y {
946                        yk[[i, j]] = yk[[i, j]] - ui * c[j];
947                    }
948                }
949            }
950        }
951
952        n_iter_vec.push(iters);
953
954        if !converged && n_features_y > 1 && iters == max_iter {
955            return Err(FerroError::ConvergenceFailure {
956                iterations: max_iter,
957                message: format!("NIPALS did not converge for component {k}"),
958            });
959        }
960    }
961
962    Ok(NipalsResult {
963        x_weights,
964        x_loadings,
965        x_scores,
966        y_loadings,
967        y_scores,
968        n_iter: n_iter_vec,
969    })
970}
971
972// ===========================================================================
973// PLSRegression
974// ===========================================================================
975
/// Partial Least Squares Regression via the NIPALS algorithm.
///
/// PLSRegression finds latent components that maximise the covariance
/// between X-scores and Y-scores, with asymmetric deflation of Y using
/// X-scores. This is the standard PLS2 algorithm for multi-target
/// regression.
///
/// # Type Parameters
///
/// - `F`: The floating-point scalar type.
///
/// # Examples
///
/// ```
/// use ferrolearn_decomp::cross_decomposition::PLSRegression;
/// use ferrolearn_core::traits::{Fit, Predict, Transform};
/// use ndarray::array;
///
/// let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
/// let y = array![[1.0], [2.0], [3.0], [4.0]];
///
/// let pls = PLSRegression::<f64>::new(1);
/// let fitted = pls.fit(&x, &y).unwrap();
///
/// let y_pred = fitted.predict(&x).unwrap();
/// assert_eq!(y_pred.ncols(), 1);
///
/// let scores = fitted.transform(&x).unwrap();
/// assert_eq!(scores.ncols(), 1);
/// ```
#[derive(Debug, Clone)]
pub struct PLSRegression<F> {
    /// Number of PLS components to extract.
    n_components: usize,
    /// Maximum NIPALS iterations per component (default 500, see `new`).
    max_iter: usize,
    /// Convergence tolerance for NIPALS (default 1e-6, see `new`).
    tol: F,
    /// Whether to scale X and Y to unit variance (default `true`).
    scale: bool,
    // NOTE(review): `_marker` is redundant here — `F` is already constrained
    // by the `tol` field. Removing it would also require updating the struct
    // literal in `new`, so it is left in place for now.
    _marker: std::marker::PhantomData<F>,
}
1018
1019impl<F: Float + Send + Sync + 'static> PLSRegression<F> {
1020    /// Create a new `PLSRegression` with `n_components` components.
1021    ///
1022    /// Defaults: `max_iter = 500`, `tol = 1e-6`, `scale = true`.
1023    #[must_use]
1024    pub fn new(n_components: usize) -> Self {
1025        Self {
1026            n_components,
1027            max_iter: 500,
1028            tol: F::from(1e-6).unwrap_or(F::epsilon()),
1029            scale: true,
1030            _marker: std::marker::PhantomData,
1031        }
1032    }
1033
1034    /// Set the maximum number of NIPALS iterations per component.
1035    #[must_use]
1036    pub fn with_max_iter(mut self, max_iter: usize) -> Self {
1037        self.max_iter = max_iter;
1038        self
1039    }
1040
1041    /// Set the NIPALS convergence tolerance.
1042    #[must_use]
1043    pub fn with_tol(mut self, tol: F) -> Self {
1044        self.tol = tol;
1045        self
1046    }
1047
1048    /// Set whether to scale X and Y to unit variance (default: `true`).
1049    #[must_use]
1050    pub fn with_scale(mut self, scale: bool) -> Self {
1051        self.scale = scale;
1052        self
1053    }
1054
1055    /// Return the number of components.
1056    #[must_use]
1057    pub fn n_components(&self) -> usize {
1058        self.n_components
1059    }
1060}
1061
/// A fitted [`PLSRegression`] model.
///
/// Holds the learned weight, loading, and score matrices, plus the
/// regression coefficients for prediction. Implements [`Predict`] to
/// predict Y from X, and [`Transform`] to project X onto the score space.
#[derive(Debug, Clone)]
pub struct FittedPLSRegression<F> {
    /// X-weights W, shape `(n_features_x, n_components)`.
    x_weights_: Array2<F>,
    /// X-loadings P, shape `(n_features_x, n_components)`.
    x_loadings_: Array2<F>,
    /// Y-loadings Q, shape `(n_features_y, n_components)`.
    y_loadings_: Array2<F>,
    /// Regression coefficients B, shape `(n_features_x, n_features_y)`.
    /// B = W (P^T W)^{-1} Q^T. Operates in the centred/scaled space;
    /// `predict` applies the un-scaling and un-centring afterwards.
    coefficients_: Array2<F>,
    /// X-scores T from training, shape `(n_samples, n_components)`.
    x_scores_: Array2<F>,
    /// Y-scores U from training, shape `(n_samples, n_components)`.
    y_scores_: Array2<F>,
    /// Number of NIPALS iterations per extracted component.
    n_iter_: Vec<usize>,
    /// Per-feature mean of X, used to centre new inputs.
    x_mean_: Array1<F>,
    /// Per-feature mean of Y, added back onto predictions.
    y_mean_: Array1<F>,
    /// Per-feature std of X (None if not scaled at fit time).
    x_std_: Option<Array1<F>>,
    /// Per-feature std of Y (None if not scaled at fit time).
    y_std_: Option<Array1<F>>,
}
1093
impl<F: Float + Send + Sync + 'static> FittedPLSRegression<F> {
    /// X-weights matrix W, shape `(n_features_x, n_components)`; one column
    /// per extracted component.
    #[must_use]
    pub fn x_weights(&self) -> &Array2<F> {
        &self.x_weights_
    }

    /// X-loadings matrix P, shape `(n_features_x, n_components)`.
    #[must_use]
    pub fn x_loadings(&self) -> &Array2<F> {
        &self.x_loadings_
    }

    /// Y-loadings matrix Q, shape `(n_features_y, n_components)`.
    #[must_use]
    pub fn y_loadings(&self) -> &Array2<F> {
        &self.y_loadings_
    }

    /// Regression coefficient matrix B, shape `(n_features_x, n_features_y)`.
    ///
    /// `Y_pred = X_centred @ B + y_mean`. Note B operates in the internal
    /// (centred, optionally scaled) space; `predict` handles the rescaling.
    #[must_use]
    pub fn coefficients(&self) -> &Array2<F> {
        &self.coefficients_
    }

    /// X-scores T from training, shape `(n_samples, n_components)`.
    #[must_use]
    pub fn x_scores(&self) -> &Array2<F> {
        &self.x_scores_
    }

    /// Y-scores U from training, shape `(n_samples, n_components)`.
    #[must_use]
    pub fn y_scores(&self) -> &Array2<F> {
        &self.y_scores_
    }

    /// Number of NIPALS iterations for each component, in extraction order.
    #[must_use]
    pub fn n_iter(&self) -> &[usize] {
        &self.n_iter_
    }
}
1139
1140impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array2<F>> for PLSRegression<F> {
1141    type Fitted = FittedPLSRegression<F>;
1142    type Error = FerroError;
1143
1144    /// Fit PLSRegression using the NIPALS algorithm.
1145    ///
1146    /// # Errors
1147    ///
1148    /// - [`FerroError::InvalidParameter`] if `n_components` is zero or too large.
1149    /// - [`FerroError::InsufficientSamples`] if fewer than 2 samples.
1150    /// - [`FerroError::ShapeMismatch`] if X and Y have different row counts.
1151    /// - [`FerroError::ConvergenceFailure`] if NIPALS does not converge.
1152    fn fit(&self, x: &Array2<F>, y: &Array2<F>) -> Result<FittedPLSRegression<F>, FerroError> {
1153        let (n_samples_x, n_features_x) = x.dim();
1154        let (n_samples_y, n_features_y) = y.dim();
1155
1156        if n_samples_x != n_samples_y {
1157            return Err(FerroError::ShapeMismatch {
1158                expected: vec![n_samples_x, n_features_y],
1159                actual: vec![n_samples_y, n_features_y],
1160                context: "PLSRegression::fit: X and Y must have the same number of rows".into(),
1161            });
1162        }
1163
1164        if self.n_components == 0 {
1165            return Err(FerroError::InvalidParameter {
1166                name: "n_components".into(),
1167                reason: "must be at least 1".into(),
1168            });
1169        }
1170
1171        let max_components = n_features_x.min(n_features_y).min(n_samples_x);
1172        if self.n_components > max_components {
1173            return Err(FerroError::InvalidParameter {
1174                name: "n_components".into(),
1175                reason: format!(
1176                    "n_components ({}) exceeds min(n_features_x, n_features_y, n_samples) ({})",
1177                    self.n_components, max_components
1178                ),
1179            });
1180        }
1181
1182        if n_samples_x < 2 {
1183            return Err(FerroError::InsufficientSamples {
1184                required: 2,
1185                actual: n_samples_x,
1186                context: "PLSRegression::fit requires at least 2 samples".into(),
1187            });
1188        }
1189
1190        // Centre and optionally scale.
1191        let (xc, x_mean, x_std) = centre_scale(x, self.scale);
1192        let (yc, y_mean, y_std) = centre_scale(y, self.scale);
1193
1194        // Run NIPALS.
1195        let result = nipals(
1196            &xc,
1197            &yc,
1198            self.n_components,
1199            self.max_iter,
1200            self.tol,
1201            NipalsMode::Regression,
1202            ScoreNorm::None,
1203        )?;
1204
1205        // Compute regression coefficients: B = W (P^T W)^{-1} Q^T.
1206        let ptw = result.x_loadings.t().dot(&result.x_weights);
1207        let ptw_inv = invert_square(&ptw)?;
1208        let coefficients = result.x_weights.dot(&ptw_inv).dot(&result.y_loadings.t());
1209
1210        // If we scaled, adjust coefficients to work on the original scale.
1211        // The stored coefficients operate on centred (and scaled) X, producing
1212        // centred (and scaled) Y. We leave them in this internal space and
1213        // apply the scaling in predict/transform.
1214
1215        Ok(FittedPLSRegression {
1216            x_weights_: result.x_weights,
1217            x_loadings_: result.x_loadings,
1218            y_loadings_: result.y_loadings,
1219            coefficients_: coefficients,
1220            x_scores_: result.x_scores,
1221            y_scores_: result.y_scores,
1222            n_iter_: result.n_iter,
1223            x_mean_: x_mean,
1224            y_mean_: y_mean,
1225            x_std_: x_std,
1226            y_std_: y_std,
1227        })
1228    }
1229}
1230
1231impl<F: Float + Send + Sync + 'static> Predict<Array2<F>> for FittedPLSRegression<F> {
1232    type Output = Array2<F>;
1233    type Error = FerroError;
1234
1235    /// Predict Y from X using the fitted PLS regression model.
1236    ///
1237    /// Computes `Y_pred = X_centred @ B`, then un-scales and un-centres.
1238    ///
1239    /// # Errors
1240    ///
1241    /// Returns [`FerroError::ShapeMismatch`] if X has the wrong number of columns.
1242    fn predict(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
1243        let xc = apply_centre_scale(
1244            x,
1245            &self.x_mean_,
1246            &self.x_std_,
1247            "FittedPLSRegression::predict",
1248        )?;
1249
1250        let mut y_pred = xc.dot(&self.coefficients_);
1251
1252        // Un-scale Y.
1253        if let Some(ref ys) = self.y_std_ {
1254            for mut row in y_pred.rows_mut() {
1255                for (v, &s) in row.iter_mut().zip(ys.iter()) {
1256                    *v = *v * s;
1257                }
1258            }
1259        }
1260
1261        // Un-centre Y.
1262        for mut row in y_pred.rows_mut() {
1263            for (v, &m) in row.iter_mut().zip(self.y_mean_.iter()) {
1264                *v = *v + m;
1265            }
1266        }
1267
1268        Ok(y_pred)
1269    }
1270}
1271
1272impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedPLSRegression<F> {
1273    type Output = Array2<F>;
1274    type Error = FerroError;
1275
1276    /// Project X onto the PLS score space (X-scores).
1277    ///
1278    /// Computes `T = X_centred @ W (P^T W)^{-1}`.
1279    ///
1280    /// # Errors
1281    ///
1282    /// Returns [`FerroError::ShapeMismatch`] if X has the wrong number of columns.
1283    fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
1284        let xc = apply_centre_scale(
1285            x,
1286            &self.x_mean_,
1287            &self.x_std_,
1288            "FittedPLSRegression::transform",
1289        )?;
1290
1291        // T = X_centred @ W (P^T W)^{-1} = X_centred @ rotation
1292        // But simpler: T = X_centred @ W when W columns are the NIPALS weights.
1293        // Actually the correct transform for new data is via the rotation matrix.
1294        let ptw = self.x_loadings_.t().dot(&self.x_weights_);
1295        let ptw_inv = invert_square(&ptw)?;
1296        let rotation = self.x_weights_.dot(&ptw_inv);
1297        Ok(xc.dot(&rotation))
1298    }
1299}
1300
1301// ===========================================================================
1302// PLSCanonical
1303// ===========================================================================
1304
/// Canonical PLS via the NIPALS algorithm.
///
/// PLSCanonical performs a symmetric decomposition: both X and Y are
/// deflated with their own scores. This contrasts with [`PLSRegression`]
/// which deflates Y using X-scores.
///
/// PLSCanonical is appropriate when you want a symmetric analysis of the
/// relationship between X and Y, rather than a predictive model.
///
/// # Type Parameters
///
/// - `F`: The floating-point scalar type.
///
/// # Examples
///
/// ```
/// use ferrolearn_decomp::cross_decomposition::PLSCanonical;
/// use ferrolearn_core::traits::{Fit, Transform};
/// use ndarray::array;
///
/// let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
/// let y = array![[1.0, 0.5], [2.0, 1.0], [3.0, 1.5], [4.0, 2.0]];
///
/// let pls = PLSCanonical::<f64>::new(2);
/// let fitted = pls.fit(&x, &y).unwrap();
/// let scores = fitted.transform(&x).unwrap();
/// assert_eq!(scores.ncols(), 2);
/// ```
#[derive(Debug, Clone)]
pub struct PLSCanonical<F> {
    /// Number of PLS components to extract.
    n_components: usize,
    /// Maximum NIPALS iterations per component (default 500, see `new`).
    max_iter: usize,
    /// Convergence tolerance for NIPALS (default 1e-6, see `new`).
    tol: F,
    /// Whether to scale X and Y to unit variance (default `true`).
    scale: bool,
    // NOTE(review): `_marker` is redundant — `F` is already used by `tol`.
    // Left in place to avoid churning the struct literal in `new`.
    _marker: std::marker::PhantomData<F>,
}
1345
1346impl<F: Float + Send + Sync + 'static> PLSCanonical<F> {
1347    /// Create a new `PLSCanonical` with `n_components` components.
1348    ///
1349    /// Defaults: `max_iter = 500`, `tol = 1e-6`, `scale = true`.
1350    #[must_use]
1351    pub fn new(n_components: usize) -> Self {
1352        Self {
1353            n_components,
1354            max_iter: 500,
1355            tol: F::from(1e-6).unwrap_or(F::epsilon()),
1356            scale: true,
1357            _marker: std::marker::PhantomData,
1358        }
1359    }
1360
1361    /// Set the maximum number of NIPALS iterations per component.
1362    #[must_use]
1363    pub fn with_max_iter(mut self, max_iter: usize) -> Self {
1364        self.max_iter = max_iter;
1365        self
1366    }
1367
1368    /// Set the NIPALS convergence tolerance.
1369    #[must_use]
1370    pub fn with_tol(mut self, tol: F) -> Self {
1371        self.tol = tol;
1372        self
1373    }
1374
1375    /// Set whether to scale X and Y to unit variance (default: `true`).
1376    #[must_use]
1377    pub fn with_scale(mut self, scale: bool) -> Self {
1378        self.scale = scale;
1379        self
1380    }
1381
1382    /// Return the number of components.
1383    #[must_use]
1384    pub fn n_components(&self) -> usize {
1385        self.n_components
1386    }
1387}
1388
1389/// A fitted [`PLSCanonical`] model.
1390///
1391/// Holds the learned weight, loading, and score matrices from the
1392/// symmetric NIPALS decomposition. Implements [`Transform`] to project
1393/// X onto the score space.
1394#[derive(Debug, Clone)]
1395pub struct FittedPLSCanonical<F> {
1396    /// X-weights W, shape `(n_features_x, n_components)`.
1397    x_weights_: Array2<F>,
1398    /// X-loadings P, shape `(n_features_x, n_components)`.
1399    x_loadings_: Array2<F>,
1400    /// Y-loadings Q, shape `(n_features_y, n_components)`.
1401    y_loadings_: Array2<F>,
1402    /// X-scores T from training, shape `(n_samples, n_components)`.
1403    x_scores_: Array2<F>,
1404    /// Y-scores U from training, shape `(n_samples, n_components)`.
1405    y_scores_: Array2<F>,
1406    /// Number of iterations per component.
1407    n_iter_: Vec<usize>,
1408    /// Per-feature mean of X.
1409    x_mean_: Array1<F>,
1410    /// Per-feature mean of Y.
1411    y_mean_: Array1<F>,
1412    /// Per-feature std of X (None if not scaled).
1413    x_std_: Option<Array1<F>>,
1414    /// Per-feature std of Y (None if not scaled).
1415    #[allow(dead_code)]
1416    y_std_: Option<Array1<F>>,
1417}
1418
impl<F: Float + Send + Sync + 'static> FittedPLSCanonical<F> {
    /// X-weights matrix W, shape `(n_features_x, n_components)`; one column
    /// per extracted component.
    #[must_use]
    pub fn x_weights(&self) -> &Array2<F> {
        &self.x_weights_
    }

    /// X-loadings matrix P, shape `(n_features_x, n_components)`.
    #[must_use]
    pub fn x_loadings(&self) -> &Array2<F> {
        &self.x_loadings_
    }

    /// Y-loadings matrix Q, shape `(n_features_y, n_components)`.
    #[must_use]
    pub fn y_loadings(&self) -> &Array2<F> {
        &self.y_loadings_
    }

    /// X-scores T from training, shape `(n_samples, n_components)`.
    #[must_use]
    pub fn x_scores(&self) -> &Array2<F> {
        &self.x_scores_
    }

    /// Y-scores U from training, shape `(n_samples, n_components)`.
    #[must_use]
    pub fn y_scores(&self) -> &Array2<F> {
        &self.y_scores_
    }

    /// Number of NIPALS iterations for each component, in extraction order.
    #[must_use]
    pub fn n_iter(&self) -> &[usize] {
        &self.n_iter_
    }

    /// Transform Y data onto the Y-score space.
    ///
    /// NOTE(review): this projects with the Y-loadings Q directly, whereas
    /// the X-side `transform` uses a rotation matrix `W (P^T W)^{-1}`. If a
    /// rotation-based Y projection was intended here, confirm and align.
    ///
    /// # Errors
    ///
    /// Returns [`FerroError::ShapeMismatch`] if Y has the wrong number of columns.
    pub fn transform_y(&self, y: &Array2<F>) -> Result<Array2<F>, FerroError> {
        let yc = apply_centre_scale(
            y,
            &self.y_mean_,
            &self.y_std_,
            "FittedPLSCanonical::transform_y",
        )?;
        Ok(yc.dot(&self.y_loadings_))
    }
}
1471
1472impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array2<F>> for PLSCanonical<F> {
1473    type Fitted = FittedPLSCanonical<F>;
1474    type Error = FerroError;
1475
1476    /// Fit PLSCanonical using the NIPALS algorithm with symmetric deflation.
1477    ///
1478    /// # Errors
1479    ///
1480    /// - [`FerroError::InvalidParameter`] if `n_components` is zero or too large.
1481    /// - [`FerroError::InsufficientSamples`] if fewer than 2 samples.
1482    /// - [`FerroError::ShapeMismatch`] if X and Y have different row counts.
1483    /// - [`FerroError::ConvergenceFailure`] if NIPALS does not converge.
1484    fn fit(&self, x: &Array2<F>, y: &Array2<F>) -> Result<FittedPLSCanonical<F>, FerroError> {
1485        let (n_samples_x, n_features_x) = x.dim();
1486        let (n_samples_y, n_features_y) = y.dim();
1487
1488        if n_samples_x != n_samples_y {
1489            return Err(FerroError::ShapeMismatch {
1490                expected: vec![n_samples_x, n_features_y],
1491                actual: vec![n_samples_y, n_features_y],
1492                context: "PLSCanonical::fit: X and Y must have the same number of rows".into(),
1493            });
1494        }
1495
1496        if self.n_components == 0 {
1497            return Err(FerroError::InvalidParameter {
1498                name: "n_components".into(),
1499                reason: "must be at least 1".into(),
1500            });
1501        }
1502
1503        let max_components = n_features_x.min(n_features_y).min(n_samples_x);
1504        if self.n_components > max_components {
1505            return Err(FerroError::InvalidParameter {
1506                name: "n_components".into(),
1507                reason: format!(
1508                    "n_components ({}) exceeds min(n_features_x, n_features_y, n_samples) ({})",
1509                    self.n_components, max_components
1510                ),
1511            });
1512        }
1513
1514        if n_samples_x < 2 {
1515            return Err(FerroError::InsufficientSamples {
1516                required: 2,
1517                actual: n_samples_x,
1518                context: "PLSCanonical::fit requires at least 2 samples".into(),
1519            });
1520        }
1521
1522        let (xc, x_mean, x_std) = centre_scale(x, self.scale);
1523        let (yc, y_mean, y_std) = centre_scale(y, self.scale);
1524
1525        let result = nipals(
1526            &xc,
1527            &yc,
1528            self.n_components,
1529            self.max_iter,
1530            self.tol,
1531            NipalsMode::Canonical,
1532            ScoreNorm::None,
1533        )?;
1534
1535        Ok(FittedPLSCanonical {
1536            x_weights_: result.x_weights,
1537            x_loadings_: result.x_loadings,
1538            y_loadings_: result.y_loadings,
1539            x_scores_: result.x_scores,
1540            y_scores_: result.y_scores,
1541            n_iter_: result.n_iter,
1542            x_mean_: x_mean,
1543            y_mean_: y_mean,
1544            x_std_: x_std,
1545            y_std_: y_std,
1546        })
1547    }
1548}
1549
1550impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedPLSCanonical<F> {
1551    type Output = Array2<F>;
1552    type Error = FerroError;
1553
1554    /// Project X onto the PLS score space (X-scores).
1555    ///
1556    /// # Errors
1557    ///
1558    /// Returns [`FerroError::ShapeMismatch`] if X has the wrong number of columns.
1559    fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
1560        let xc = apply_centre_scale(
1561            x,
1562            &self.x_mean_,
1563            &self.x_std_,
1564            "FittedPLSCanonical::transform",
1565        )?;
1566
1567        let ptw = self.x_loadings_.t().dot(&self.x_weights_);
1568        let ptw_inv = invert_square(&ptw)?;
1569        let rotation = self.x_weights_.dot(&ptw_inv);
1570        Ok(xc.dot(&rotation))
1571    }
1572}
1573
1574// ===========================================================================
1575// CCA (Canonical Correlation Analysis)
1576// ===========================================================================
1577
/// Canonical Correlation Analysis via the NIPALS algorithm.
///
/// CCA maximises the *correlation* (rather than covariance) between
/// X-scores and Y-scores by normalising scores to unit variance after
/// each NIPALS iteration. It uses symmetric (canonical) deflation.
///
/// # Type Parameters
///
/// - `F`: The floating-point scalar type.
///
/// # Examples
///
/// ```
/// use ferrolearn_decomp::cross_decomposition::CCA;
/// use ferrolearn_core::traits::{Fit, Transform};
/// use ndarray::array;
///
/// let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
/// let y = array![[1.0, 0.5], [2.0, 1.0], [3.0, 1.5], [4.0, 2.0]];
///
/// let cca = CCA::<f64>::new(2);
/// let fitted = cca.fit(&x, &y).unwrap();
/// let scores = fitted.transform(&x).unwrap();
/// assert_eq!(scores.ncols(), 2);
/// ```
#[derive(Debug, Clone)]
pub struct CCA<F> {
    /// Number of canonical components to extract.
    n_components: usize,
    /// Maximum NIPALS iterations per component (default 500, see `new`).
    max_iter: usize,
    /// Convergence tolerance for NIPALS (default 1e-6, see `new`).
    tol: F,
    /// Whether to scale X and Y to unit variance (default `true`).
    scale: bool,
    // NOTE(review): `_marker` is redundant — `F` is already used by `tol`.
    // Left in place to avoid churning the struct literal in `new`.
    _marker: std::marker::PhantomData<F>,
}
1615
1616impl<F: Float + Send + Sync + 'static> CCA<F> {
1617    /// Create a new `CCA` with `n_components` components.
1618    ///
1619    /// Defaults: `max_iter = 500`, `tol = 1e-6`, `scale = true`.
1620    #[must_use]
1621    pub fn new(n_components: usize) -> Self {
1622        Self {
1623            n_components,
1624            max_iter: 500,
1625            tol: F::from(1e-6).unwrap_or(F::epsilon()),
1626            scale: true,
1627            _marker: std::marker::PhantomData,
1628        }
1629    }
1630
1631    /// Set the maximum number of NIPALS iterations per component.
1632    #[must_use]
1633    pub fn with_max_iter(mut self, max_iter: usize) -> Self {
1634        self.max_iter = max_iter;
1635        self
1636    }
1637
1638    /// Set the NIPALS convergence tolerance.
1639    #[must_use]
1640    pub fn with_tol(mut self, tol: F) -> Self {
1641        self.tol = tol;
1642        self
1643    }
1644
1645    /// Set whether to scale X and Y to unit variance (default: `true`).
1646    #[must_use]
1647    pub fn with_scale(mut self, scale: bool) -> Self {
1648        self.scale = scale;
1649        self
1650    }
1651
1652    /// Return the number of components.
1653    #[must_use]
1654    pub fn n_components(&self) -> usize {
1655        self.n_components
1656    }
1657}
1658
1659/// A fitted [`CCA`] model.
1660///
1661/// Holds the learned weight, loading, and score matrices. Implements
1662/// [`Transform`] to project X onto the canonical score space.
1663#[derive(Debug, Clone)]
1664pub struct FittedCCA<F> {
1665    /// X-weights W, shape `(n_features_x, n_components)`.
1666    x_weights_: Array2<F>,
1667    /// X-loadings P, shape `(n_features_x, n_components)`.
1668    x_loadings_: Array2<F>,
1669    /// Y-loadings Q, shape `(n_features_y, n_components)`.
1670    y_loadings_: Array2<F>,
1671    /// X-scores T from training, shape `(n_samples, n_components)`.
1672    x_scores_: Array2<F>,
1673    /// Y-scores U from training, shape `(n_samples, n_components)`.
1674    y_scores_: Array2<F>,
1675    /// Number of iterations per component.
1676    n_iter_: Vec<usize>,
1677    /// Per-feature mean of X.
1678    x_mean_: Array1<F>,
1679    /// Per-feature mean of Y.
1680    y_mean_: Array1<F>,
1681    /// Per-feature std of X (None if not scaled).
1682    x_std_: Option<Array1<F>>,
1683    /// Per-feature std of Y (None if not scaled).
1684    #[allow(dead_code)]
1685    y_std_: Option<Array1<F>>,
1686}
1687
impl<F: Float + Send + Sync + 'static> FittedCCA<F> {
    /// X-weights matrix W, shape `(n_features_x, n_components)`; one column
    /// per extracted component.
    #[must_use]
    pub fn x_weights(&self) -> &Array2<F> {
        &self.x_weights_
    }

    /// X-loadings matrix P, shape `(n_features_x, n_components)`.
    #[must_use]
    pub fn x_loadings(&self) -> &Array2<F> {
        &self.x_loadings_
    }

    /// Y-loadings matrix Q, shape `(n_features_y, n_components)`.
    #[must_use]
    pub fn y_loadings(&self) -> &Array2<F> {
        &self.y_loadings_
    }

    /// X-scores T from training, shape `(n_samples, n_components)`.
    #[must_use]
    pub fn x_scores(&self) -> &Array2<F> {
        &self.x_scores_
    }

    /// Y-scores U from training, shape `(n_samples, n_components)`.
    #[must_use]
    pub fn y_scores(&self) -> &Array2<F> {
        &self.y_scores_
    }

    /// Number of NIPALS iterations for each component, in extraction order.
    #[must_use]
    pub fn n_iter(&self) -> &[usize] {
        &self.n_iter_
    }

    /// Transform Y data onto the Y-score space.
    ///
    /// NOTE(review): this projects with the Y-loadings Q directly, whereas
    /// the X-side `transform` uses a rotation matrix `W (P^T W)^{-1}`. If a
    /// rotation-based Y projection was intended here, confirm and align.
    ///
    /// # Errors
    ///
    /// Returns [`FerroError::ShapeMismatch`] if Y has the wrong number of columns.
    pub fn transform_y(&self, y: &Array2<F>) -> Result<Array2<F>, FerroError> {
        let yc = apply_centre_scale(y, &self.y_mean_, &self.y_std_, "FittedCCA::transform_y")?;
        Ok(yc.dot(&self.y_loadings_))
    }
}
1735
1736impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array2<F>> for CCA<F> {
1737    type Fitted = FittedCCA<F>;
1738    type Error = FerroError;
1739
1740    /// Fit CCA using the NIPALS algorithm with score normalisation.
1741    ///
1742    /// # Errors
1743    ///
1744    /// - [`FerroError::InvalidParameter`] if `n_components` is zero or too large.
1745    /// - [`FerroError::InsufficientSamples`] if fewer than 2 samples.
1746    /// - [`FerroError::ShapeMismatch`] if X and Y have different row counts.
1747    /// - [`FerroError::ConvergenceFailure`] if NIPALS does not converge.
1748    fn fit(&self, x: &Array2<F>, y: &Array2<F>) -> Result<FittedCCA<F>, FerroError> {
1749        let (n_samples_x, n_features_x) = x.dim();
1750        let (n_samples_y, n_features_y) = y.dim();
1751
1752        if n_samples_x != n_samples_y {
1753            return Err(FerroError::ShapeMismatch {
1754                expected: vec![n_samples_x, n_features_y],
1755                actual: vec![n_samples_y, n_features_y],
1756                context: "CCA::fit: X and Y must have the same number of rows".into(),
1757            });
1758        }
1759
1760        if self.n_components == 0 {
1761            return Err(FerroError::InvalidParameter {
1762                name: "n_components".into(),
1763                reason: "must be at least 1".into(),
1764            });
1765        }
1766
1767        let max_components = n_features_x.min(n_features_y).min(n_samples_x);
1768        if self.n_components > max_components {
1769            return Err(FerroError::InvalidParameter {
1770                name: "n_components".into(),
1771                reason: format!(
1772                    "n_components ({}) exceeds min(n_features_x, n_features_y, n_samples) ({})",
1773                    self.n_components, max_components
1774                ),
1775            });
1776        }
1777
1778        if n_samples_x < 2 {
1779            return Err(FerroError::InsufficientSamples {
1780                required: 2,
1781                actual: n_samples_x,
1782                context: "CCA::fit requires at least 2 samples".into(),
1783            });
1784        }
1785
1786        let (xc, x_mean, x_std) = centre_scale(x, self.scale);
1787        let (yc, y_mean, y_std) = centre_scale(y, self.scale);
1788
1789        let result = nipals(
1790            &xc,
1791            &yc,
1792            self.n_components,
1793            self.max_iter,
1794            self.tol,
1795            NipalsMode::Canonical,
1796            ScoreNorm::UnitVariance,
1797        )?;
1798
1799        Ok(FittedCCA {
1800            x_weights_: result.x_weights,
1801            x_loadings_: result.x_loadings,
1802            y_loadings_: result.y_loadings,
1803            x_scores_: result.x_scores,
1804            y_scores_: result.y_scores,
1805            n_iter_: result.n_iter,
1806            x_mean_: x_mean,
1807            y_mean_: y_mean,
1808            x_std_: x_std,
1809            y_std_: y_std,
1810        })
1811    }
1812}
1813
1814impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedCCA<F> {
1815    type Output = Array2<F>;
1816    type Error = FerroError;
1817
1818    /// Project X onto the canonical score space.
1819    ///
1820    /// # Errors
1821    ///
1822    /// Returns [`FerroError::ShapeMismatch`] if X has the wrong number of columns.
1823    fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
1824        let xc = apply_centre_scale(x, &self.x_mean_, &self.x_std_, "FittedCCA::transform")?;
1825
1826        let ptw = self.x_loadings_.t().dot(&self.x_weights_);
1827        let ptw_inv = invert_square(&ptw)?;
1828        let rotation = self.x_weights_.dot(&ptw_inv);
1829        Ok(xc.dot(&rotation))
1830    }
1831}
1832
1833// ===========================================================================
1834// Tests
1835// ===========================================================================
1836
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    // These tests cover, for each estimator: the happy path (fit/transform or
    // fit/predict shapes), parameter validation errors, input-shape errors,
    // `f32` support, and finally the private `centre_scale` / `invert_square`
    // helpers.

    // -----------------------------------------------------------------------
    // PLSSVD tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_plssvd_basic_fit_transform() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0],];
        let y = array![[1.0], [2.0], [3.0], [4.0], [5.0]];
        let svd = PLSSVD::<f64>::new(1);
        let fitted = svd.fit(&x, &y).unwrap();
        let scores = fitted.transform(&x).unwrap();
        // Scores keep the sample count and gain one column per component.
        assert_eq!(scores.dim(), (5, 1));
    }

    #[test]
    fn test_plssvd_two_components() {
        let x = array![
            [1.0, 2.0, 3.0],
            [4.0, 5.0, 6.0],
            [7.0, 8.0, 9.0],
            [10.0, 11.0, 12.0],
        ];
        let y = array![[1.0, 0.5], [2.0, 1.0], [3.0, 1.5], [4.0, 2.0]];
        let svd = PLSSVD::<f64>::new(2);
        let fitted = svd.fit(&x, &y).unwrap();
        let scores = fitted.transform(&x).unwrap();
        assert_eq!(scores.dim(), (4, 2));
    }

    #[test]
    fn test_plssvd_transform_y() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
        let y = array![[1.0, 0.5], [2.0, 1.0], [3.0, 1.5], [4.0, 2.0]];
        let svd = PLSSVD::<f64>::new(1);
        let fitted = svd.fit(&x, &y).unwrap();
        let y_scores = fitted.transform_y(&y).unwrap();
        assert_eq!(y_scores.ncols(), 1);
    }

    #[test]
    fn test_plssvd_no_scale() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0],];
        let y = array![[1.0], [2.0], [3.0], [4.0], [5.0]];
        let svd = PLSSVD::<f64>::new(1).with_scale(false);
        let fitted = svd.fit(&x, &y).unwrap();
        let scores = fitted.transform(&x).unwrap();
        assert_eq!(scores.ncols(), 1);
    }

    #[test]
    fn test_plssvd_x_weights_shape() {
        let x = array![
            [1.0, 2.0, 3.0],
            [4.0, 5.0, 6.0],
            [7.0, 8.0, 9.0],
            [10.0, 11.0, 12.0],
        ];
        let y = array![[1.0, 0.5], [2.0, 1.0], [3.0, 1.5], [4.0, 2.0]];
        let svd = PLSSVD::<f64>::new(2);
        let fitted = svd.fit(&x, &y).unwrap();
        // Weight matrices are (n_features, n_components) for each block.
        assert_eq!(fitted.x_weights().dim(), (3, 2));
        assert_eq!(fitted.y_weights().dim(), (2, 2));
    }

    #[test]
    fn test_plssvd_invalid_zero_components() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let y = array![[1.0], [2.0]];
        let svd = PLSSVD::<f64>::new(0);
        assert!(svd.fit(&x, &y).is_err());
    }

    #[test]
    fn test_plssvd_too_many_components() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let y = array![[1.0], [2.0], [3.0]];
        // min(2, 1) = 1, asking for 2 is too many.
        let svd = PLSSVD::<f64>::new(2);
        assert!(svd.fit(&x, &y).is_err());
    }

    #[test]
    fn test_plssvd_row_mismatch() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let y = array![[1.0], [2.0]];
        let svd = PLSSVD::<f64>::new(1);
        assert!(svd.fit(&x, &y).is_err());
    }

    #[test]
    fn test_plssvd_insufficient_samples() {
        // A single sample cannot support a cross-covariance estimate.
        let x = array![[1.0, 2.0]];
        let y = array![[1.0]];
        let svd = PLSSVD::<f64>::new(1);
        assert!(svd.fit(&x, &y).is_err());
    }

    #[test]
    fn test_plssvd_transform_shape_mismatch() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let y = array![[1.0], [2.0], [3.0]];
        let svd = PLSSVD::<f64>::new(1);
        let fitted = svd.fit(&x, &y).unwrap();
        // Fitted on 2 features; transforming 3-feature data must fail.
        let x_bad = array![[1.0, 2.0, 3.0]];
        assert!(fitted.transform(&x_bad).is_err());
    }

    #[test]
    fn test_plssvd_n_components_getter() {
        let svd = PLSSVD::<f64>::new(3);
        assert_eq!(svd.n_components(), 3);
    }

    #[test]
    fn test_plssvd_f32() {
        let x: Array2<f32> = array![
            [1.0f32, 2.0],
            [3.0, 4.0],
            [5.0, 6.0],
            [7.0, 8.0],
            [9.0, 10.0],
        ];
        let y: Array2<f32> = array![[1.0f32], [2.0], [3.0], [4.0], [5.0]];
        let svd = PLSSVD::<f32>::new(1);
        let fitted = svd.fit(&x, &y).unwrap();
        let scores = fitted.transform(&x).unwrap();
        assert_eq!(scores.ncols(), 1);
    }

    // -----------------------------------------------------------------------
    // PLSRegression tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_plsregression_basic_fit_predict() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0],];
        let y = array![[1.0], [2.0], [3.0], [4.0], [5.0]];
        let pls = PLSRegression::<f64>::new(1);
        let fitted = pls.fit(&x, &y).unwrap();
        let y_pred = fitted.predict(&x).unwrap();
        assert_eq!(y_pred.dim(), (5, 1));
    }

    #[test]
    fn test_plsregression_prediction_quality() {
        // Perfect linear relationship: Y = X[:,0] + X[:,1]
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0],];
        let y = array![[3.0], [7.0], [11.0], [15.0], [19.0]];
        let pls = PLSRegression::<f64>::new(1);
        let fitted = pls.fit(&x, &y).unwrap();
        let y_pred = fitted.predict(&x).unwrap();

        // With a perfect linear relationship and 1 component, prediction
        // should be very close.
        for (pred, actual) in y_pred.column(0).iter().zip(y.column(0).iter()) {
            assert_abs_diff_eq!(pred, actual, epsilon = 1e-6);
        }
    }

    #[test]
    fn test_plsregression_multi_target() {
        let x = array![
            [1.0, 2.0, 3.0],
            [4.0, 5.0, 6.0],
            [7.0, 8.0, 9.0],
            [10.0, 11.0, 12.0],
            [13.0, 14.0, 15.0],
        ];
        let y = array![[1.0, 0.5], [2.0, 1.0], [3.0, 1.5], [4.0, 2.0], [5.0, 2.5],];
        let pls = PLSRegression::<f64>::new(2);
        let fitted = pls.fit(&x, &y).unwrap();
        let y_pred = fitted.predict(&x).unwrap();
        // Predictions have one column per Y target.
        assert_eq!(y_pred.dim(), (5, 2));
    }

    #[test]
    fn test_plsregression_transform() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0],];
        let y = array![[1.0], [2.0], [3.0], [4.0], [5.0]];
        let pls = PLSRegression::<f64>::new(1);
        let fitted = pls.fit(&x, &y).unwrap();
        let scores = fitted.transform(&x).unwrap();
        assert_eq!(scores.dim(), (5, 1));
    }

    #[test]
    fn test_plsregression_coefficients_shape() {
        let x = array![
            [1.0, 2.0, 3.0],
            [4.0, 5.0, 6.0],
            [7.0, 8.0, 9.0],
            [10.0, 11.0, 12.0],
        ];
        let y = array![[1.0, 0.5], [2.0, 1.0], [3.0, 1.5], [4.0, 2.0]];
        let pls = PLSRegression::<f64>::new(2);
        let fitted = pls.fit(&x, &y).unwrap();
        // B shape: (n_features_x, n_features_y)
        assert_eq!(fitted.coefficients().dim(), (3, 2));
    }

    #[test]
    fn test_plsregression_no_scale() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0],];
        let y = array![[1.0], [2.0], [3.0], [4.0], [5.0]];
        let pls = PLSRegression::<f64>::new(1).with_scale(false);
        let fitted = pls.fit(&x, &y).unwrap();
        let y_pred = fitted.predict(&x).unwrap();
        assert_eq!(y_pred.dim(), (5, 1));
    }

    #[test]
    fn test_plsregression_builder() {
        // Builder methods chain and leave n_components untouched.
        let pls = PLSRegression::<f64>::new(2)
            .with_max_iter(1000)
            .with_tol(1e-8)
            .with_scale(false);
        assert_eq!(pls.n_components(), 2);
    }

    #[test]
    fn test_plsregression_invalid_zero_components() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let y = array![[1.0], [2.0]];
        let pls = PLSRegression::<f64>::new(0);
        assert!(pls.fit(&x, &y).is_err());
    }

    #[test]
    fn test_plsregression_too_many_components() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let y = array![[1.0], [2.0], [3.0]];
        // min(2, 1, 3) = 1, asking for 2 is too many.
        let pls = PLSRegression::<f64>::new(2);
        assert!(pls.fit(&x, &y).is_err());
    }

    #[test]
    fn test_plsregression_row_mismatch() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let y = array![[1.0], [2.0]];
        let pls = PLSRegression::<f64>::new(1);
        assert!(pls.fit(&x, &y).is_err());
    }

    #[test]
    fn test_plsregression_insufficient_samples() {
        let x = array![[1.0, 2.0]];
        let y = array![[1.0]];
        let pls = PLSRegression::<f64>::new(1);
        assert!(pls.fit(&x, &y).is_err());
    }

    #[test]
    fn test_plsregression_predict_shape_mismatch() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let y = array![[1.0], [2.0], [3.0]];
        let pls = PLSRegression::<f64>::new(1);
        let fitted = pls.fit(&x, &y).unwrap();
        let x_bad = array![[1.0, 2.0, 3.0]];
        assert!(fitted.predict(&x_bad).is_err());
    }

    #[test]
    fn test_plsregression_transform_shape_mismatch() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let y = array![[1.0], [2.0], [3.0]];
        let pls = PLSRegression::<f64>::new(1);
        let fitted = pls.fit(&x, &y).unwrap();
        let x_bad = array![[1.0, 2.0, 3.0]];
        assert!(fitted.transform(&x_bad).is_err());
    }

    #[test]
    fn test_plsregression_x_scores_shape() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
        let y = array![[1.0], [2.0], [3.0], [4.0]];
        let pls = PLSRegression::<f64>::new(1);
        let fitted = pls.fit(&x, &y).unwrap();
        assert_eq!(fitted.x_scores().dim(), (4, 1));
        assert_eq!(fitted.y_scores().dim(), (4, 1));
        // One iteration count per fitted component.
        assert_eq!(fitted.n_iter().len(), 1);
    }

    #[test]
    fn test_plsregression_f32() {
        let x: Array2<f32> = array![
            [1.0f32, 2.0],
            [3.0, 4.0],
            [5.0, 6.0],
            [7.0, 8.0],
            [9.0, 10.0],
        ];
        let y: Array2<f32> = array![[1.0f32], [2.0], [3.0], [4.0], [5.0]];
        let pls = PLSRegression::<f32>::new(1);
        let fitted = pls.fit(&x, &y).unwrap();
        let y_pred = fitted.predict(&x).unwrap();
        assert_eq!(y_pred.ncols(), 1);
    }

    // -----------------------------------------------------------------------
    // PLSCanonical tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_plscanonical_basic_fit_transform() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0],];
        let y = array![[1.0, 0.5], [2.0, 1.0], [3.0, 1.5], [4.0, 2.0], [5.0, 2.5],];
        let pls = PLSCanonical::<f64>::new(2);
        let fitted = pls.fit(&x, &y).unwrap();
        let scores = fitted.transform(&x).unwrap();
        assert_eq!(scores.dim(), (5, 2));
    }

    #[test]
    fn test_plscanonical_single_component() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0],];
        let y = array![[1.0], [2.0], [3.0], [4.0], [5.0]];
        let pls = PLSCanonical::<f64>::new(1);
        let fitted = pls.fit(&x, &y).unwrap();
        let scores = fitted.transform(&x).unwrap();
        assert_eq!(scores.ncols(), 1);
    }

    #[test]
    fn test_plscanonical_scores_shape() {
        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]];
        let y = array![[1.0, 0.5], [2.0, 1.0], [3.0, 1.5]];
        let pls = PLSCanonical::<f64>::new(2);
        let fitted = pls.fit(&x, &y).unwrap();
        // Scores: (n_samples, n_components); weights/loadings:
        // (n_features, n_components) per block.
        assert_eq!(fitted.x_scores().dim(), (3, 2));
        assert_eq!(fitted.y_scores().dim(), (3, 2));
        assert_eq!(fitted.x_weights().dim(), (3, 2));
        assert_eq!(fitted.x_loadings().dim(), (3, 2));
        assert_eq!(fitted.y_loadings().dim(), (2, 2));
    }

    #[test]
    fn test_plscanonical_transform_y() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
        let y = array![[1.0, 0.5], [2.0, 1.0], [3.0, 1.5], [4.0, 2.0]];
        let pls = PLSCanonical::<f64>::new(1);
        let fitted = pls.fit(&x, &y).unwrap();
        let y_scores = fitted.transform_y(&y).unwrap();
        assert_eq!(y_scores.ncols(), 1);
    }

    #[test]
    fn test_plscanonical_builder() {
        let pls = PLSCanonical::<f64>::new(2)
            .with_max_iter(1000)
            .with_tol(1e-8)
            .with_scale(false);
        assert_eq!(pls.n_components(), 2);
    }

    #[test]
    fn test_plscanonical_invalid_zero_components() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let y = array![[1.0, 0.5], [2.0, 1.0]];
        let pls = PLSCanonical::<f64>::new(0);
        assert!(pls.fit(&x, &y).is_err());
    }

    #[test]
    fn test_plscanonical_too_many_components() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let y = array![[1.0], [2.0], [3.0]];
        let pls = PLSCanonical::<f64>::new(2);
        assert!(pls.fit(&x, &y).is_err());
    }

    #[test]
    fn test_plscanonical_row_mismatch() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let y = array![[1.0, 0.5], [2.0, 1.0], [3.0, 1.5]];
        let pls = PLSCanonical::<f64>::new(1);
        assert!(pls.fit(&x, &y).is_err());
    }

    #[test]
    fn test_plscanonical_transform_shape_mismatch() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let y = array![[1.0, 0.5], [2.0, 1.0], [3.0, 1.5]];
        let pls = PLSCanonical::<f64>::new(1);
        let fitted = pls.fit(&x, &y).unwrap();
        let x_bad = array![[1.0, 2.0, 3.0]];
        assert!(fitted.transform(&x_bad).is_err());
    }

    #[test]
    fn test_plscanonical_f32() {
        let x: Array2<f32> = array![
            [1.0f32, 2.0],
            [3.0, 4.0],
            [5.0, 6.0],
            [7.0, 8.0],
            [9.0, 10.0],
        ];
        let y: Array2<f32> = array![
            [1.0f32, 0.5],
            [2.0, 1.0],
            [3.0, 1.5],
            [4.0, 2.0],
            [5.0, 2.5],
        ];
        let pls = PLSCanonical::<f32>::new(1);
        let fitted = pls.fit(&x, &y).unwrap();
        let scores = fitted.transform(&x).unwrap();
        assert_eq!(scores.ncols(), 1);
    }

    // -----------------------------------------------------------------------
    // CCA tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_cca_basic_fit_transform() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0],];
        let y = array![[1.0, 0.5], [2.0, 1.0], [3.0, 1.5], [4.0, 2.0], [5.0, 2.5],];
        let cca = CCA::<f64>::new(2);
        let fitted = cca.fit(&x, &y).unwrap();
        let scores = fitted.transform(&x).unwrap();
        assert_eq!(scores.dim(), (5, 2));
    }

    #[test]
    fn test_cca_single_component() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0],];
        let y = array![[1.0], [2.0], [3.0], [4.0], [5.0]];
        let cca = CCA::<f64>::new(1);
        let fitted = cca.fit(&x, &y).unwrap();
        let scores = fitted.transform(&x).unwrap();
        assert_eq!(scores.ncols(), 1);
    }

    #[test]
    fn test_cca_scores_shape() {
        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]];
        let y = array![[1.0, 0.5], [2.0, 1.0], [3.0, 1.5]];
        let cca = CCA::<f64>::new(2);
        let fitted = cca.fit(&x, &y).unwrap();
        assert_eq!(fitted.x_scores().dim(), (3, 2));
        assert_eq!(fitted.y_scores().dim(), (3, 2));
        assert_eq!(fitted.x_weights().dim(), (3, 2));
        assert_eq!(fitted.x_loadings().dim(), (3, 2));
        assert_eq!(fitted.y_loadings().dim(), (2, 2));
    }

    #[test]
    fn test_cca_transform_y() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
        let y = array![[1.0, 0.5], [2.0, 1.0], [3.0, 1.5], [4.0, 2.0]];
        let cca = CCA::<f64>::new(1);
        let fitted = cca.fit(&x, &y).unwrap();
        let y_scores = fitted.transform_y(&y).unwrap();
        assert_eq!(y_scores.ncols(), 1);
    }

    #[test]
    fn test_cca_builder() {
        let cca = CCA::<f64>::new(2)
            .with_max_iter(1000)
            .with_tol(1e-8)
            .with_scale(false);
        assert_eq!(cca.n_components(), 2);
    }

    #[test]
    fn test_cca_invalid_zero_components() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let y = array![[1.0, 0.5], [2.0, 1.0]];
        let cca = CCA::<f64>::new(0);
        assert!(cca.fit(&x, &y).is_err());
    }

    #[test]
    fn test_cca_too_many_components() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let y = array![[1.0], [2.0], [3.0]];
        let cca = CCA::<f64>::new(2);
        assert!(cca.fit(&x, &y).is_err());
    }

    #[test]
    fn test_cca_row_mismatch() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let y = array![[1.0, 0.5], [2.0, 1.0], [3.0, 1.5]];
        let cca = CCA::<f64>::new(1);
        assert!(cca.fit(&x, &y).is_err());
    }

    #[test]
    fn test_cca_transform_shape_mismatch() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let y = array![[1.0, 0.5], [2.0, 1.0], [3.0, 1.5]];
        let cca = CCA::<f64>::new(1);
        let fitted = cca.fit(&x, &y).unwrap();
        let x_bad = array![[1.0, 2.0, 3.0]];
        assert!(fitted.transform(&x_bad).is_err());
    }

    #[test]
    fn test_cca_f32() {
        let x: Array2<f32> = array![
            [1.0f32, 2.0],
            [3.0, 4.0],
            [5.0, 6.0],
            [7.0, 8.0],
            [9.0, 10.0],
        ];
        let y: Array2<f32> = array![
            [1.0f32, 0.5],
            [2.0, 1.0],
            [3.0, 1.5],
            [4.0, 2.0],
            [5.0, 2.5],
        ];
        let cca = CCA::<f32>::new(1);
        let fitted = cca.fit(&x, &y).unwrap();
        let scores = fitted.transform(&x).unwrap();
        assert_eq!(scores.ncols(), 1);
    }

    // -----------------------------------------------------------------------
    // Cross-cutting tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_pls_regression_and_canonical_give_different_scores() {
        let x = array![
            [1.0, 2.0, 0.5],
            [3.0, 1.0, 2.5],
            [5.0, 6.0, 1.0],
            [7.0, 3.0, 4.5],
            [9.0, 10.0, 2.0],
        ];
        let y = array![[1.0, 0.5], [2.0, 1.0], [3.0, 1.5], [4.0, 2.0], [5.0, 2.5],];

        let pls_reg = PLSRegression::<f64>::new(2);
        let fitted_reg = pls_reg.fit(&x, &y).unwrap();
        let scores_reg = fitted_reg.transform(&x).unwrap();

        let pls_can = PLSCanonical::<f64>::new(2);
        let fitted_can = pls_can.fit(&x, &y).unwrap();
        let scores_can = fitted_can.transform(&x).unwrap();

        // They should produce different results (different deflation).
        let diff: f64 = scores_reg
            .iter()
            .zip(scores_can.iter())
            .map(|(a, b)| (a - b).abs())
            .sum();
        // The scores should not be identical (unless the data is degenerate).
        // We just check they are both valid matrices.
        assert_eq!(scores_reg.dim(), scores_can.dim());
        // In practice diff may be zero for some data; just check no NaN.
        assert!(diff.is_finite());
    }

    #[test]
    fn test_centre_scale_helper() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let (xc, mean, std_dev) = centre_scale(&x, true);
        // Column means of the input data.
        assert_abs_diff_eq!(mean[0], 3.0, epsilon = 1e-10);
        assert_abs_diff_eq!(mean[1], 4.0, epsilon = 1e-10);
        // scale=true means a std vector is returned.
        assert!(std_dev.is_some());

        // Centred data should have zero mean.
        let col_mean_0: f64 = xc.column(0).iter().sum::<f64>() / 3.0;
        assert_abs_diff_eq!(col_mean_0, 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_centre_scale_no_scale() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        // scale=false means no std vector is computed.
        let (_xc, _mean, std_dev) = centre_scale(&x, false);
        assert!(std_dev.is_none());
    }

    #[test]
    fn test_invert_square_identity() {
        // The identity matrix is its own inverse.
        let eye = Array2::<f64>::from_shape_fn((3, 3), |(i, j)| if i == j { 1.0 } else { 0.0 });
        let inv = invert_square(&eye).unwrap();
        for i in 0..3 {
            for j in 0..3 {
                let expected = if i == j { 1.0 } else { 0.0 };
                assert_abs_diff_eq!(inv[[i, j]], expected, epsilon = 1e-10);
            }
        }
    }

    #[test]
    fn test_invert_square_2x2() {
        let a = array![[4.0, 7.0], [2.0, 6.0]];
        let inv = invert_square(&a).unwrap();
        // A * A^{-1} should be identity.
        let prod = a.dot(&inv);
        for i in 0..2 {
            for j in 0..2 {
                let expected = if i == j { 1.0 } else { 0.0 };
                assert_abs_diff_eq!(prod[[i, j]], expected, epsilon = 1e-10);
            }
        }
    }
}