1use anofox_ml_core::{FitUnsupervised, InverseTransform, Result, RustMlError, Transform};
8use faer::linalg::solvers::SelfAdjointEigen;
9use faer::{Mat, Side};
10use ndarray::{Array1, Array2};
11
12#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
13pub enum KpcaKernel {
14 Linear,
15 Rbf {
16 gamma: f64,
17 },
18 Polynomial {
19 degree: usize,
20 coef0: f64,
21 gamma: f64,
22 },
23}
24
25impl KpcaKernel {
26 fn compute(&self, a: &[f64], b: &[f64]) -> f64 {
27 match self {
28 KpcaKernel::Linear => a.iter().zip(b.iter()).map(|(x, y)| x * y).sum(),
29 KpcaKernel::Rbf { gamma } => {
30 let sd: f64 = a.iter().zip(b.iter()).map(|(x, y)| (x - y).powi(2)).sum();
31 (-gamma * sd).exp()
32 }
33 KpcaKernel::Polynomial {
34 degree,
35 coef0,
36 gamma,
37 } => {
38 let dot: f64 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
39 (gamma * dot + coef0).powi(*degree as i32)
40 }
41 }
42 }
43}
44
/// Unfitted kernel PCA configuration; call `fit` to obtain a [`FittedKernelPca`].
#[derive(Debug, Clone)]
pub struct KernelPca {
    /// Requested number of components. `fit` clamps it to the number of
    /// training samples and rejects a resulting value of 0.
    pub n_components: usize,
    /// Kernel used to build the training and cross Gram matrices.
    pub kernel: KpcaKernel,
}
50
51impl KernelPca {
52 pub fn new(n_components: usize, kernel: KpcaKernel) -> Self {
53 Self {
54 n_components,
55 kernel,
56 }
57 }
58}
59
/// State produced by fitting [`KernelPca`]; everything needed to project new data.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct FittedKernelPca {
    /// Copy of the training matrix (one sample per row); new points are compared
    /// against these rows through the kernel.
    pub x_train: Array2<f64>,
    /// Eigenvectors of the double-centred training kernel, one column per kept
    /// component (`n_train × n_components`).
    pub alphas: Array2<f64>,
    /// Eigenvalues matching the columns of `alphas`. `fit` reads them from the end
    /// of faer's decomposition, so they are intended to be largest first
    /// (relies on faer returning ascending eigenvalues — confirm).
    pub eigenvalues: Array1<f64>,
    /// Per-row means of the uncentred training kernel; used to centre new kernels.
    pub row_means: Array1<f64>,
    /// Mean over all entries of the uncentred training kernel.
    pub global_mean: f64,
    /// Kernel used at fit time; reused when projecting new data.
    pub kernel: KpcaKernel,
}
69
70fn build_kernel(x_a: &Array2<f64>, x_b: &Array2<f64>, k: &KpcaKernel) -> Array2<f64> {
71 let na = x_a.nrows();
72 let nb = x_b.nrows();
73 let mut out = Array2::<f64>::zeros((na, nb));
74 for i in 0..na {
75 let ai = x_a.row(i).to_owned();
76 for j in 0..nb {
77 let bj = x_b.row(j).to_owned();
78 out[[i, j]] = k.compute(ai.as_slice().unwrap(), bj.as_slice().unwrap());
79 }
80 }
81 out
82}
83
84impl FitUnsupervised<f64> for KernelPca {
85 type Fitted = FittedKernelPca;
86
87 fn fit(&self, x: &Array2<f64>) -> Result<Self::Fitted> {
88 let n = x.nrows();
89 if n == 0 {
90 return Err(RustMlError::EmptyInput("empty input".into()));
91 }
92 let k_target = self.n_components.min(n);
93 if k_target == 0 {
94 return Err(RustMlError::InvalidParameter(
95 "n_components must be >= 1".into(),
96 ));
97 }
98
99 let mut k = build_kernel(x, x, &self.kernel);
100 let row_means: Array1<f64> =
102 Array1::from_vec((0..n).map(|i| k.row(i).sum() / n as f64).collect());
103 let col_means: Array1<f64> = row_means.clone(); let global_mean: f64 = k.iter().copied().sum::<f64>() / (n as f64).powi(2);
105 for i in 0..n {
106 for j in 0..n {
107 k[[i, j]] += global_mean - row_means[i] - col_means[j];
108 }
109 }
110
111 let m = Mat::<f64>::from_fn(n, n, |i, j| k[[i, j]]);
113 let eig = SelfAdjointEigen::new(m.as_ref(), Side::Lower)
114 .map_err(|e| RustMlError::InvalidParameter(format!("eigen failed: {e:?}")))?;
115 let s = eig.S(); let v = eig.U();
117
118 let mut alphas = Array2::<f64>::zeros((n, k_target));
120 let mut eigenvalues = Array1::<f64>::zeros(k_target);
121 for c in 0..k_target {
122 let src = n - 1 - c; let val = s.column_vector()[src];
124 eigenvalues[c] = val;
125 for i in 0..n {
128 alphas[[i, c]] = v[(i, src)];
129 }
130 }
131
132 Ok(FittedKernelPca {
133 x_train: x.clone(),
134 alphas,
135 eigenvalues,
136 row_means,
137 global_mean,
138 kernel: self.kernel.clone(),
139 })
140 }
141}
142
143impl Transform<f64> for FittedKernelPca {
144 fn transform(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
145 let n_train = self.x_train.nrows();
146 if x.ncols() != self.x_train.ncols() {
147 return Err(RustMlError::ShapeMismatch(format!(
148 "expected {} features, got {}",
149 self.x_train.ncols(),
150 x.ncols()
151 )));
152 }
153 let n_new = x.nrows();
154 let mut k_new = build_kernel(x, &self.x_train, &self.kernel);
155 let new_row_means: Array1<f64> = Array1::from_vec(
158 (0..n_new)
159 .map(|i| k_new.row(i).sum() / n_train as f64)
160 .collect(),
161 );
162 for i in 0..n_new {
163 for j in 0..n_train {
164 k_new[[i, j]] += self.global_mean - new_row_means[i] - self.row_means[j];
165 }
166 }
167 let mut out = Array2::<f64>::zeros((n_new, self.alphas.ncols()));
173 for c in 0..self.alphas.ncols() {
174 let lam = self.eigenvalues[c];
175 let sqrt_lam = lam.abs().sqrt().max(1e-12);
176 for i in 0..n_new {
177 let mut s = 0.0;
178 for j in 0..n_train {
179 s += k_new[[i, j]] * self.alphas[[j, c]];
180 }
181 out[[i, c]] = s / sqrt_lam;
182 }
183 }
184 Ok(out)
185 }
186}
187
188impl InverseTransform<f64> for FittedKernelPca {
189 fn inverse_transform(&self, t: &Array2<f64>) -> Result<Array2<f64>> {
198 let k = self.alphas.ncols();
199 if t.ncols() != k {
200 return Err(RustMlError::ShapeMismatch(format!(
201 "expected {} components, got {}",
202 k,
203 t.ncols()
204 )));
205 }
206 let n_train = self.x_train.nrows();
207 let n_new = t.nrows();
208
209 let mut coef = Array2::<f64>::zeros((n_new, n_train));
212 for i in 0..n_new {
213 for j in 0..n_train {
214 let mut s = 0.0;
215 for c in 0..k {
216 let lam_sqrt = self.eigenvalues[c].abs().sqrt().max(1e-12);
217 s += t[[i, c]] * lam_sqrt * self.alphas[[j, c]];
218 }
219 coef[[i, j]] = s;
220 }
221 }
222 Ok(coef.dot(&self.x_train))
226 }
227}
228
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Five points on the parabola y = x², shared by several tests.
    fn parabola() -> Array2<f64> {
        array![
            [0.0_f64, 0.0],
            [1.0, 1.0],
            [2.0, 4.0],
            [3.0, 9.0],
            [4.0, 16.0],
        ]
    }

    #[test]
    fn test_kernel_pca_runs_rbf() {
        let x = parabola();
        let kpca = KernelPca::new(2, KpcaKernel::Rbf { gamma: 0.1 });
        let fitted = kpca.fit(&x).unwrap();
        let t = fitted.transform(&x).unwrap();
        assert_eq!(t.shape(), &[5, 2]);
        assert!(fitted.eigenvalues[0] >= fitted.eigenvalues[1]);
    }

    #[test]
    fn test_kernel_pca_training_scores_are_centred() {
        // The centred kernel has zero column sums. The training scores of every
        // non-degenerate component must therefore sum to (numerically) zero.
        let x = parabola();
        let fitted = KernelPca::new(2, KpcaKernel::Rbf { gamma: 0.1 })
            .fit(&x)
            .unwrap();
        let t = fitted.transform(&x).unwrap();
        for c in 0..t.ncols() {
            let col_sum: f64 = t.column(c).sum();
            assert!(col_sum.abs() < 1e-9, "component {c} sums to {col_sum}");
        }
    }

    #[test]
    fn test_kernel_pca_inverse_transform_runs() {
        let x = array![
            [0.0_f64, 0.0, 1.0],
            [1.0, 1.0, 0.0],
            [2.0, 4.0, -1.0],
            [3.0, 9.0, 2.0],
            [4.0, 16.0, 0.5],
        ];
        let fitted = KernelPca::new(2, KpcaKernel::Linear).fit(&x).unwrap();
        let t = fitted.transform(&x).unwrap();
        let back = fitted.inverse_transform(&t).unwrap();
        assert_eq!(back.shape(), &[5, 3]);
        for v in back.iter() {
            assert!(v.is_finite());
        }
    }

    #[test]
    fn test_fit_rejects_empty_input() {
        let x = Array2::<f64>::zeros((0, 2));
        let err = KernelPca::new(1, KpcaKernel::Linear).fit(&x).unwrap_err();
        assert!(matches!(err, RustMlError::EmptyInput(_)));
    }

    #[test]
    fn test_fit_rejects_zero_components() {
        let err = KernelPca::new(0, KpcaKernel::Linear)
            .fit(&parabola())
            .unwrap_err();
        assert!(matches!(err, RustMlError::InvalidParameter(_)));
    }

    #[test]
    fn test_transform_rejects_wrong_feature_count() {
        let fitted = KernelPca::new(1, KpcaKernel::Linear)
            .fit(&parabola())
            .unwrap();
        let bad = Array2::<f64>::zeros((2, 3));
        assert!(matches!(
            fitted.transform(&bad),
            Err(RustMlError::ShapeMismatch(_))
        ));
    }

    #[test]
    fn test_inverse_transform_rejects_wrong_component_count() {
        let fitted = KernelPca::new(2, KpcaKernel::Linear)
            .fit(&parabola())
            .unwrap();
        let bad = Array2::<f64>::zeros((1, 3));
        assert!(matches!(
            fitted.inverse_transform(&bad),
            Err(RustMlError::ShapeMismatch(_))
        ));
    }
}