Skip to main content

anofox_ml_preprocessing/
cca.rs

1//! Canonical Correlation Analysis.
2//!
3//! Mirrors `sklearn.cross_decomposition.CCA` (and `PLSCanonical` for the
4//! univariate case). Given paired (X, Y), CCA finds linear combinations
5//! `Xa` and `Yb` that are maximally correlated.
6//!
7//! Closed form via SVD:
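//! 1. Centre X and Y, then whiten them: `X_white = X_c · K_x`,
//!    `Y_white = Y_c · K_y`, where `X_c` and `Y_c` are the centred matrices.
//! 2. Compute the cross-covariance `C = X_whiteᵀ · Y_white / (n - 1)`.
//! 3. SVD `C = U Σ Vᵀ`. The first `k` columns of `K_x · U` and `K_y · V`
//!    are the weights `x_weights` and `y_weights`. They are named after
//!    sklearn's `x_weights_` / `y_weights_` but normalised differently: here
//!    each canonical variate `X_c · x_weights[:, i]` has unit sample variance
//!    (for full-rank X), rather than each weight vector having unit norm.
//! 4. The singular values `Σ_ii` are the canonical correlations.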
13
14use anofox_ml_core::{Result, RustMlError};
15use faer::linalg::solvers::Svd;
16use faer::Mat;
17use ndarray::{Array1, Array2};
18
/// Canonical Correlation Analysis estimator (the unfitted configuration).
///
/// Call [`Cca::fit`] with paired `(X, Y)` samples to obtain a [`FittedCca`].
#[derive(Debug, Clone)]
pub struct Cca {
    /// Requested number of canonical component pairs. `fit` extracts at most
    /// `min(n_components, n_features_x, n_features_y, n_samples)` of them.
    pub n_components: usize,
}
22
23impl Cca {
24    pub fn new(n_components: usize) -> Self {
25        Self { n_components }
26    }
27
28    pub fn fit(&self, x: &Array2<f64>, y: &Array2<f64>) -> Result<FittedCca> {
29        let n = x.nrows();
30        let dx = x.ncols();
31        let dy = y.ncols();
32        if y.nrows() != n {
33            return Err(RustMlError::ShapeMismatch(format!(
34                "X has {} rows but Y has {}",
35                n,
36                y.nrows()
37            )));
38        }
39        let k = self.n_components.min(dx).min(dy);
40        if k == 0 {
41            return Err(RustMlError::InvalidParameter("n_components >= 1".into()));
42        }
43        let n_f = n as f64;
44        let mut x_mean = Array1::<f64>::zeros(dx);
45        for j in 0..dx {
46            x_mean[j] = x.column(j).sum() / n_f;
47        }
48        let mut y_mean = Array1::<f64>::zeros(dy);
49        for j in 0..dy {
50            y_mean[j] = y.column(j).sum() / n_f;
51        }
52
53        let mut xc = x.clone();
54        let mut yc = y.clone();
55        for j in 0..dx {
56            for i in 0..n {
57                xc[[i, j]] -= x_mean[j];
58            }
59        }
60        for j in 0..dy {
61            for i in 0..n {
62                yc[[i, j]] -= y_mean[j];
63            }
64        }
65
66        // Whitening for X via SVD: X_centred = U_x Σ_x V_xᵀ.
67        // K_x = V_x Σ_x⁻¹ √(n-1).
68        let scale = (n as f64 - 1.0).sqrt();
69        let kx = whitening(&xc, scale)?;
70        let ky = whitening(&yc, scale)?;
71        let x_white = xc.dot(&kx);
72        let y_white = yc.dot(&ky);
73        // C = X_white' Y_white / (n - 1).
74        let c = x_white.t().dot(&y_white) / (n_f - 1.0).max(1.0);
75        let nx = c.nrows();
76        let ny = c.ncols();
77        let cm = Mat::<f64>::from_fn(nx, ny, |i, j| c[[i, j]]);
78        let svd = Svd::new(cm.as_ref())
79            .map_err(|e| RustMlError::InvalidParameter(format!("SVD failed: {e:?}")))?;
80        let u = svd.U();
81        let s = svd.S();
82        let v = svd.V();
83        // x_weights = K_x · U[:, :k]; y_weights = K_y · V[:, :k].
84        let k_real = k.min(nx).min(ny);
85        let mut u_top = Array2::<f64>::zeros((nx, k_real));
86        let mut v_top = Array2::<f64>::zeros((ny, k_real));
87        let mut corrs = Array1::<f64>::zeros(k_real);
88        for c_i in 0..k_real {
89            corrs[c_i] = s.column_vector()[c_i];
90            for i in 0..nx {
91                u_top[[i, c_i]] = u[(i, c_i)];
92            }
93            for i in 0..ny {
94                v_top[[i, c_i]] = v[(i, c_i)];
95            }
96        }
97        let x_weights = kx.dot(&u_top);
98        let y_weights = ky.dot(&v_top);
99        Ok(FittedCca {
100            x_mean,
101            y_mean,
102            x_weights,
103            y_weights,
104            canonical_correlations: corrs,
105        })
106    }
107}
108
109/// Returns whitening matrix `W` such that `X_centred · W` has identity covariance.
110fn whitening(xc: &Array2<f64>, scale: f64) -> Result<Array2<f64>> {
111    let n = xc.nrows();
112    let d = xc.ncols();
113    let m = Mat::<f64>::from_fn(n, d, |i, j| xc[[i, j]]);
114    let svd = Svd::new(m.as_ref())
115        .map_err(|e| RustMlError::InvalidParameter(format!("SVD failed: {e:?}")))?;
116    let s = svd.S();
117    let v = svd.V();
118    let r = s.column_vector().nrows();
119    let mut w = Array2::<f64>::zeros((d, r));
120    for c in 0..r {
121        let sigma = s.column_vector()[c].max(1e-12);
122        for j in 0..d {
123            w[[j, c]] = v[(j, c)] * scale / sigma;
124        }
125    }
126    Ok(w)
127}
128
129#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
130pub struct FittedCca {
131    pub x_mean: Array1<f64>,
132    pub y_mean: Array1<f64>,
133    /// X loadings, shape (n_features_x, n_components).
134    pub x_weights: Array2<f64>,
135    /// Y loadings, shape (n_features_y, n_components).
136    pub y_weights: Array2<f64>,
137    /// Canonical correlations (diagonal of the cross-covariance SVD).
138    pub canonical_correlations: Array1<f64>,
139}
140
141impl FittedCca {
142    /// Project X into canonical x-space: (X − x̄) · x_weights.
143    pub fn transform_x(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
144        if x.ncols() != self.x_mean.len() {
145            return Err(RustMlError::ShapeMismatch(format!(
146                "expected {} X-features, got {}",
147                self.x_mean.len(),
148                x.ncols()
149            )));
150        }
151        let mut xc = x.clone();
152        for j in 0..x.ncols() {
153            for i in 0..x.nrows() {
154                xc[[i, j]] -= self.x_mean[j];
155            }
156        }
157        Ok(xc.dot(&self.x_weights))
158    }
159
160    /// Project Y into canonical y-space: (Y − ȳ) · y_weights.
161    pub fn transform_y(&self, y: &Array2<f64>) -> Result<Array2<f64>> {
162        if y.ncols() != self.y_mean.len() {
163            return Err(RustMlError::ShapeMismatch(format!(
164                "expected {} Y-features, got {}",
165                self.y_mean.len(),
166                y.ncols()
167            )));
168        }
169        let mut yc = y.clone();
170        for j in 0..y.ncols() {
171            for i in 0..y.nrows() {
172                yc[[i, j]] -= self.y_mean[j];
173            }
174        }
175        Ok(yc.dot(&self.y_weights))
176    }
177}
178
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    #[test]
    fn test_cca_finds_high_correlation() {
        // Both Y columns are exact affine functions of X[:, 0] (sin(t) plus a
        // constant, and -2·sin(t)), so the first canonical correlation should
        // be close to 1 even though X carries two unrelated columns.
        let n = 100;
        let mut x = Array2::<f64>::zeros((n, 3));
        let mut y = Array2::<f64>::zeros((n, 2));
        for i in 0..n {
            let t = (i as f64) * 0.1;
            x[[i, 0]] = t.sin();
            x[[i, 1]] = (i as f64) - 50.0;
            x[[i, 2]] = ((i * 7) % 13) as f64;
            y[[i, 0]] = t.sin() + 0.01;
            y[[i, 1]] = -2.0 * t.sin();
        }
        let fitted = Cca::new(1).fit(&x, &y).unwrap();
        assert!(
            fitted.canonical_correlations[0] > 0.9,
            "first canonical correlation = {}",
            fitted.canonical_correlations[0]
        );
    }

    #[test]
    fn test_cca_transform_shapes() {
        let x = array![[1.0_f64, 0.0], [0.0, 1.0], [2.0, 1.0], [1.0, 2.0]];
        let y = array![[1.0_f64, 0.0], [0.5, 0.5], [1.5, 0.5], [1.0, 1.0]];
        let fitted = Cca::new(1).fit(&x, &y).unwrap();
        let xt = fitted.transform_x(&x).unwrap();
        let yt = fitted.transform_y(&y).unwrap();
        assert_eq!(xt.shape(), &[4, 1]);
        assert_eq!(yt.shape(), &[4, 1]);
    }

    #[test]
    fn test_cca_caps_components_at_feature_count() {
        // Requesting 5 components with d_x = 3, d_y = 2 yields min(5, 3, 2) = 2.
        let x = array![
            [1.0_f64, 0.0, 2.0],
            [0.0, 1.0, 1.0],
            [2.0, 1.0, 0.0],
            [1.0, 2.0, 3.0],
            [3.0, 0.5, 1.0]
        ];
        let y = array![[1.0_f64, 0.0], [0.5, 0.5], [1.5, 0.5], [1.0, 1.0], [0.0, 2.0]];
        let fitted = Cca::new(5).fit(&x, &y).unwrap();
        assert_eq!(fitted.x_weights.shape(), &[3, 2]);
        assert_eq!(fitted.y_weights.shape(), &[2, 2]);
        assert_eq!(fitted.canonical_correlations.len(), 2);
    }

    #[test]
    fn test_cca_rejects_mismatched_rows() {
        let x = array![[1.0_f64], [2.0], [3.0]];
        let y = array![[1.0_f64], [2.0]];
        assert!(Cca::new(1).fit(&x, &y).is_err());
    }

    #[test]
    fn test_cca_rejects_zero_components() {
        let x = array![[1.0_f64], [2.0], [3.0]];
        let y = array![[2.0_f64], [1.0], [0.0]];
        assert!(Cca::new(0).fit(&x, &y).is_err());
    }

    #[test]
    fn test_transform_rejects_wrong_feature_count() {
        let x = array![[1.0_f64, 0.0], [0.0, 1.0], [2.0, 1.0], [1.0, 2.0]];
        let y = array![[1.0_f64, 0.0], [0.5, 0.5], [1.5, 0.5], [1.0, 1.0]];
        let fitted = Cca::new(1).fit(&x, &y).unwrap();
        let bad = array![[1.0_f64, 2.0, 3.0]];
        assert!(fitted.transform_x(&bad).is_err());
        assert!(fitted.transform_y(&bad).is_err());
    }
}