anofox_ml_preprocessing/
cca.rs1use anofox_ml_core::{Result, RustMlError};
15use faer::linalg::solvers::Svd;
16use faer::Mat;
17use ndarray::{Array1, Array2};
18
/// Configuration for canonical correlation analysis (CCA).
///
/// Holds only the requested number of canonical components; calling
/// `fit` produces a `FittedCca` carrying the learned projection.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Cca {
    /// Requested number of canonical component pairs. At fit time this is
    /// capped by the number of columns of each input matrix.
    pub n_components: usize,
}
22
23impl Cca {
24 pub fn new(n_components: usize) -> Self {
25 Self { n_components }
26 }
27
28 pub fn fit(&self, x: &Array2<f64>, y: &Array2<f64>) -> Result<FittedCca> {
29 let n = x.nrows();
30 let dx = x.ncols();
31 let dy = y.ncols();
32 if y.nrows() != n {
33 return Err(RustMlError::ShapeMismatch(format!(
34 "X has {} rows but Y has {}",
35 n,
36 y.nrows()
37 )));
38 }
39 let k = self.n_components.min(dx).min(dy);
40 if k == 0 {
41 return Err(RustMlError::InvalidParameter("n_components >= 1".into()));
42 }
43 let n_f = n as f64;
44 let mut x_mean = Array1::<f64>::zeros(dx);
45 for j in 0..dx {
46 x_mean[j] = x.column(j).sum() / n_f;
47 }
48 let mut y_mean = Array1::<f64>::zeros(dy);
49 for j in 0..dy {
50 y_mean[j] = y.column(j).sum() / n_f;
51 }
52
53 let mut xc = x.clone();
54 let mut yc = y.clone();
55 for j in 0..dx {
56 for i in 0..n {
57 xc[[i, j]] -= x_mean[j];
58 }
59 }
60 for j in 0..dy {
61 for i in 0..n {
62 yc[[i, j]] -= y_mean[j];
63 }
64 }
65
66 let scale = (n as f64 - 1.0).sqrt();
69 let kx = whitening(&xc, scale)?;
70 let ky = whitening(&yc, scale)?;
71 let x_white = xc.dot(&kx);
72 let y_white = yc.dot(&ky);
73 let c = x_white.t().dot(&y_white) / (n_f - 1.0).max(1.0);
75 let nx = c.nrows();
76 let ny = c.ncols();
77 let cm = Mat::<f64>::from_fn(nx, ny, |i, j| c[[i, j]]);
78 let svd = Svd::new(cm.as_ref())
79 .map_err(|e| RustMlError::InvalidParameter(format!("SVD failed: {e:?}")))?;
80 let u = svd.U();
81 let s = svd.S();
82 let v = svd.V();
83 let k_real = k.min(nx).min(ny);
85 let mut u_top = Array2::<f64>::zeros((nx, k_real));
86 let mut v_top = Array2::<f64>::zeros((ny, k_real));
87 let mut corrs = Array1::<f64>::zeros(k_real);
88 for c_i in 0..k_real {
89 corrs[c_i] = s.column_vector()[c_i];
90 for i in 0..nx {
91 u_top[[i, c_i]] = u[(i, c_i)];
92 }
93 for i in 0..ny {
94 v_top[[i, c_i]] = v[(i, c_i)];
95 }
96 }
97 let x_weights = kx.dot(&u_top);
98 let y_weights = ky.dot(&v_top);
99 Ok(FittedCca {
100 x_mean,
101 y_mean,
102 x_weights,
103 y_weights,
104 canonical_correlations: corrs,
105 })
106 }
107}
108
109fn whitening(xc: &Array2<f64>, scale: f64) -> Result<Array2<f64>> {
111 let n = xc.nrows();
112 let d = xc.ncols();
113 let m = Mat::<f64>::from_fn(n, d, |i, j| xc[[i, j]]);
114 let svd = Svd::new(m.as_ref())
115 .map_err(|e| RustMlError::InvalidParameter(format!("SVD failed: {e:?}")))?;
116 let s = svd.S();
117 let v = svd.V();
118 let r = s.column_vector().nrows();
119 let mut w = Array2::<f64>::zeros((d, r));
120 for c in 0..r {
121 let sigma = s.column_vector()[c].max(1e-12);
122 for j in 0..d {
123 w[[j, c]] = v[(j, c)] * scale / sigma;
124 }
125 }
126 Ok(w)
127}
128
/// Result of `Cca::fit`: the centring vectors and projection weights needed
/// to map new observations into canonical-variate space.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct FittedCca {
    /// Per-column means of the X matrix seen at fit time (length = number of
    /// X features); subtracted from inputs in `transform_x`.
    pub x_mean: Array1<f64>,
    /// Per-column means of the Y matrix seen at fit time (length = number of
    /// Y features); subtracted from inputs in `transform_y`.
    pub y_mean: Array1<f64>,
    /// Projection weights for X: one row per X feature, one column per
    /// retained canonical component. Applied to centred X.
    pub x_weights: Array2<f64>,
    /// Projection weights for Y: one row per Y feature, one column per
    /// retained canonical component. Applied to centred Y.
    pub y_weights: Array2<f64>,
    /// Leading singular values of the whitened cross-covariance computed in
    /// `Cca::fit`, one per retained component, in the order returned by the SVD.
    pub canonical_correlations: Array1<f64>,
}
140
141impl FittedCca {
142 pub fn transform_x(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
144 if x.ncols() != self.x_mean.len() {
145 return Err(RustMlError::ShapeMismatch(format!(
146 "expected {} X-features, got {}",
147 self.x_mean.len(),
148 x.ncols()
149 )));
150 }
151 let mut xc = x.clone();
152 for j in 0..x.ncols() {
153 for i in 0..x.nrows() {
154 xc[[i, j]] -= self.x_mean[j];
155 }
156 }
157 Ok(xc.dot(&self.x_weights))
158 }
159
160 pub fn transform_y(&self, y: &Array2<f64>) -> Result<Array2<f64>> {
162 if y.ncols() != self.y_mean.len() {
163 return Err(RustMlError::ShapeMismatch(format!(
164 "expected {} Y-features, got {}",
165 self.y_mean.len(),
166 y.ncols()
167 )));
168 }
169 let mut yc = y.clone();
170 for j in 0..y.ncols() {
171 for i in 0..y.nrows() {
172 yc[[i, j]] -= self.y_mean[j];
173 }
174 }
175 Ok(yc.dot(&self.y_weights))
176 }
177}
178
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// X and Y share a sinusoidal signal, so the first canonical correlation
    /// should be close to 1.
    #[test]
    fn test_cca_finds_high_correlation() {
        let n = 100;
        let mut x = Array2::<f64>::zeros((n, 3));
        let mut y = Array2::<f64>::zeros((n, 2));
        for i in 0..n {
            let t = (i as f64) * 0.1;
            x[[i, 0]] = t.sin();
            x[[i, 1]] = (i as f64) - 50.0;
            x[[i, 2]] = ((i * 7) % 13) as f64;
            y[[i, 0]] = t.sin() + 0.01;
            y[[i, 1]] = -2.0 * t.sin();
        }
        let fitted = Cca::new(1).fit(&x, &y).unwrap();
        assert!(
            fitted.canonical_correlations[0] > 0.9,
            "first canonical correlation = {}",
            fitted.canonical_correlations[0]
        );
    }

    /// Transforming the training data yields one column per requested
    /// component and one row per sample.
    #[test]
    fn test_cca_transform_shapes() {
        let x = array![[1.0_f64, 0.0], [0.0, 1.0], [2.0, 1.0], [1.0, 2.0]];
        let y = array![[1.0_f64, 0.0], [0.5, 0.5], [1.5, 0.5], [1.0, 1.0]];
        let fitted = Cca::new(1).fit(&x, &y).unwrap();
        let xt = fitted.transform_x(&x).unwrap();
        let yt = fitted.transform_y(&y).unwrap();
        assert_eq!(xt.shape(), &[4, 1]);
        assert_eq!(yt.shape(), &[4, 1]);
    }

    /// X and Y must have the same number of rows.
    #[test]
    fn test_cca_rejects_row_mismatch() {
        let x = array![[1.0_f64, 0.0], [0.0, 1.0], [2.0, 1.0]];
        let y = array![[1.0_f64, 0.0], [0.5, 0.5]];
        let res = Cca::new(1).fit(&x, &y);
        assert!(matches!(res, Err(RustMlError::ShapeMismatch(_))));
    }

    /// Requesting zero components is rejected.
    #[test]
    fn test_cca_rejects_zero_components() {
        let x = array![[1.0_f64, 0.0], [0.0, 1.0], [2.0, 1.0], [1.0, 2.0]];
        let y = array![[1.0_f64, 0.0], [0.5, 0.5], [1.5, 0.5], [1.0, 1.0]];
        let res = Cca::new(0).fit(&x, &y);
        assert!(matches!(res, Err(RustMlError::InvalidParameter(_))));
    }

    /// Transform inputs must have the same feature count as the training data.
    #[test]
    fn test_transform_rejects_wrong_feature_count() {
        let x = array![[1.0_f64, 0.0], [0.0, 1.0], [2.0, 1.0], [1.0, 2.0]];
        let y = array![[1.0_f64, 0.0], [0.5, 0.5], [1.5, 0.5], [1.0, 1.0]];
        let fitted = Cca::new(1).fit(&x, &y).unwrap();
        let wide = array![[1.0_f64, 2.0, 3.0]];
        assert!(matches!(
            fitted.transform_x(&wide),
            Err(RustMlError::ShapeMismatch(_))
        ));
        assert!(matches!(
            fitted.transform_y(&wide),
            Err(RustMlError::ShapeMismatch(_))
        ));
    }
}