1use anofox_ml_core::{Fit, Predict, PredictProba, Result, RustMlError};
16use faer::linalg::solvers::Solve;
17use faer::{Mat, Side};
18use ndarray::{Array1, Array2};
19
20use crate::{build_gram, GpKernel};
21
22pub struct GaussianProcessClassifier {
23 pub kernel: GpKernel,
24 pub max_iter: usize,
25 pub tol: f64,
26}
27
28impl GaussianProcessClassifier {
29 pub fn new(kernel: GpKernel) -> Self {
30 Self {
31 kernel,
32 max_iter: 100,
33 tol: 1e-6,
34 }
35 }
36 pub fn with_max_iter(mut self, m: usize) -> Self {
37 self.max_iter = m;
38 self
39 }
40 pub fn with_tol(mut self, t: f64) -> Self {
41 self.tol = t;
42 self
43 }
44}
45
46pub struct FittedGaussianProcessClassifier {
47 pub x_train: Array2<f64>,
48 pub alpha: Array1<f64>,
50 pub l_lower: Mat<f64>,
52 pub w_sqrt: Array1<f64>,
54 pub kernel: GpKernel,
55 pub classes: [f64; 2],
56}
57
/// Logistic function `σ(z) = 1 / (1 + e^{-z})`.
///
/// Only ever exponentiates a non-positive number, so it never overflows;
/// NaN propagates.
fn sigmoid(z: f64) -> f64 {
    if z < 0.0 {
        // σ(z) = e^z / (1 + e^z) keeps the exponent ≤ 0 for negative z.
        let ez = z.exp();
        ez / (1.0 + ez)
    } else {
        (1.0 + (-z).exp()).recip()
    }
}
66
67fn clone_kernel(k: &GpKernel) -> GpKernel {
68 match k {
69 GpKernel::Rbf {
70 length_scale,
71 signal_var,
72 } => GpKernel::Rbf {
73 length_scale: *length_scale,
74 signal_var: *signal_var,
75 },
76 GpKernel::Matern {
77 length_scale,
78 signal_var,
79 nu,
80 } => GpKernel::Matern {
81 length_scale: *length_scale,
82 signal_var: *signal_var,
83 nu: *nu,
84 },
85 GpKernel::RationalQuadratic {
86 length_scale,
87 signal_var,
88 alpha,
89 } => GpKernel::RationalQuadratic {
90 length_scale: *length_scale,
91 signal_var: *signal_var,
92 alpha: *alpha,
93 },
94 GpKernel::White { noise_level } => GpKernel::White {
95 noise_level: *noise_level,
96 },
97 GpKernel::Constant { value } => GpKernel::Constant { value: *value },
98 GpKernel::Sum(a, b) => GpKernel::Sum(Box::new(clone_kernel(a)), Box::new(clone_kernel(b))),
99 GpKernel::Product(a, b) => {
100 GpKernel::Product(Box::new(clone_kernel(a)), Box::new(clone_kernel(b)))
101 }
102 }
103}
104
105impl Fit<f64> for GaussianProcessClassifier {
106 type Fitted = FittedGaussianProcessClassifier;
107
108 fn fit(&self, x: &Array2<f64>, y: &Array1<f64>) -> Result<Self::Fitted> {
109 let n = x.nrows();
110 if y.len() != n {
111 return Err(RustMlError::ShapeMismatch(format!(
112 "X has {} rows but y has {}",
113 n,
114 y.len()
115 )));
116 }
117 let mut classes: Vec<f64> = y.iter().copied().collect();
119 classes.sort_by(|a, b| a.partial_cmp(b).unwrap());
120 classes.dedup();
121 if classes.len() != 2 {
122 return Err(RustMlError::InvalidParameter(format!(
123 "GPC expects 2 classes, found {}",
124 classes.len()
125 )));
126 }
127 let neg = classes[0];
128 let pos = classes[1];
129 let yb: Vec<f64> = y
131 .iter()
132 .map(|v| if *v == pos { 1.0 } else { 0.0 })
133 .collect();
134
135 let k = build_gram(x, x, &self.kernel);
136 let mut f = Array1::<f64>::zeros(n);
137
138 let mut prev_obj = f64::NEG_INFINITY;
140 let mut alpha = Array1::<f64>::zeros(n);
141 let mut l_lower = Mat::<f64>::zeros(n, n);
142 let mut w_sqrt = Array1::<f64>::zeros(n);
143
144 for _ in 0..self.max_iter {
145 let pi: Vec<f64> = f.iter().map(|&v| sigmoid(v)).collect();
147 let w: Vec<f64> = pi.iter().map(|&p| p * (1.0 - p)).collect();
148 let ws: Vec<f64> = w.iter().map(|&v| v.sqrt()).collect();
149
150 let mut b = Array2::<f64>::zeros((n, n));
152 for i in 0..n {
153 for j in 0..n {
154 b[[i, j]] = ws[i] * k[[i, j]] * ws[j];
155 }
156 b[[i, i]] += 1.0;
157 }
158 let bm = Mat::<f64>::from_fn(n, n, |i, j| b[[i, j]]);
159 let llt = faer::linalg::solvers::Llt::new(bm.as_ref(), Side::Lower)
160 .map_err(|e| RustMlError::InvalidParameter(format!("Cholesky failed: {e:?}")))?;
161 let lower = llt.L();
162 l_lower = Mat::<f64>::from_fn(n, n, |i, j| lower[(i, j)]);
163
164 let mut b_vec = Array1::<f64>::zeros(n);
166 for i in 0..n {
167 b_vec[i] = w[i] * f[i] + (yb[i] - pi[i]);
168 }
169 let mut k_b = Array1::<f64>::zeros(n);
172 for i in 0..n {
173 let mut s = 0.0;
174 for j in 0..n {
175 s += k[[i, j]] * b_vec[j];
176 }
177 k_b[i] = s;
178 }
179 let ws_kb: Vec<f64> = (0..n).map(|i| ws[i] * k_b[i]).collect();
180 let rhs = Mat::<f64>::from_fn(n, 1, |i, _| ws_kb[i]);
181 let v_mat = llt.solve(&rhs);
182 let mut a = Array1::<f64>::zeros(n);
183 for i in 0..n {
184 a[i] = b_vec[i] - ws[i] * v_mat[(i, 0)];
185 }
186 let mut new_f = Array1::<f64>::zeros(n);
188 for i in 0..n {
189 let mut s = 0.0;
190 for j in 0..n {
191 s += k[[i, j]] * a[j];
192 }
193 new_f[i] = s;
194 }
195
196 let mut obj = 0.0;
198 for i in 0..n {
199 obj -= 0.5 * new_f[i] * a[i];
200 let lp = if yb[i] > 0.5 {
202 -(-new_f[i]).ln_1p().min(0.0)
203 - if new_f[i] >= 0.0 {
204 (-new_f[i]).exp().ln_1p()
205 } else {
206 -new_f[i] + new_f[i].exp().ln_1p()
207 }
208 } else {
209 if new_f[i] >= 0.0 {
210 -new_f[i] - (-new_f[i]).exp().ln_1p()
211 } else {
212 -new_f[i].exp().ln_1p()
213 }
214 };
215 obj += lp;
216 }
217
218 f = new_f;
219 alpha = a;
220 for i in 0..n {
221 w_sqrt[i] = ws[i];
222 }
223
224 if (obj - prev_obj).abs() < self.tol {
225 break;
226 }
227 prev_obj = obj;
228 }
229
230 Ok(FittedGaussianProcessClassifier {
231 x_train: x.clone(),
232 alpha,
233 l_lower,
234 w_sqrt,
235 kernel: clone_kernel(&self.kernel),
236 classes: [neg, pos],
237 })
238 }
239}
240
241impl FittedGaussianProcessClassifier {
242 fn latent_predict(&self, x: &Array2<f64>) -> Result<(Array1<f64>, Array1<f64>)> {
244 let n_train = self.x_train.nrows();
245 if x.ncols() != self.x_train.ncols() {
246 return Err(RustMlError::ShapeMismatch(format!(
247 "expected {} features, got {}",
248 self.x_train.ncols(),
249 x.ncols()
250 )));
251 }
252 let n_new = x.nrows();
253 let k_star = build_gram(x, &self.x_train, &self.kernel);
254 let mean = k_star.dot(&self.alpha);
255 let mut var = Array1::<f64>::zeros(n_new);
257 for i in 0..n_new {
258 let mut ws_k = vec![0.0_f64; n_train];
259 for j in 0..n_train {
260 ws_k[j] = self.w_sqrt[j] * k_star[[i, j]];
261 }
262 let mut v = vec![0.0_f64; n_train];
264 for r in 0..n_train {
265 let mut s = ws_k[r];
266 for c in 0..r {
267 s -= self.l_lower[(r, c)] * v[c];
268 }
269 v[r] = s / self.l_lower[(r, r)].max(1e-12);
270 }
271 let v_sq: f64 = v.iter().map(|x| x * x).sum();
272 let xi = x.row(i).to_owned();
273 let k_xx = self.kernel_compute(xi.as_slice().unwrap(), xi.as_slice().unwrap());
274 var[i] = (k_xx - v_sq).max(0.0);
275 }
276 Ok((mean, var))
277 }
278
279 fn kernel_compute(&self, a: &[f64], b: &[f64]) -> f64 {
280 let arr_a = Array2::from_shape_vec((1, a.len()), a.to_vec()).unwrap();
282 let arr_b = Array2::from_shape_vec((1, b.len()), b.to_vec()).unwrap();
283 build_gram(&arr_a, &arr_b, &self.kernel)[[0, 0]]
284 }
285}
286
287impl Predict<f64> for FittedGaussianProcessClassifier {
288 fn predict(&self, x: &Array2<f64>) -> Result<Array1<f64>> {
289 let proba = self.predict_proba(x)?;
290 let mut out = Array1::<f64>::zeros(x.nrows());
291 for i in 0..x.nrows() {
292 out[i] = if proba[[i, 1]] >= 0.5 {
293 self.classes[1]
294 } else {
295 self.classes[0]
296 };
297 }
298 Ok(out)
299 }
300}
301
302impl PredictProba<f64> for FittedGaussianProcessClassifier {
303 fn predict_proba(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
304 let (mean, var) = self.latent_predict(x)?;
305 let n = mean.len();
307 let mut out = Array2::<f64>::zeros((n, 2));
308 let pi8 = std::f64::consts::PI / 8.0;
309 for i in 0..n {
310 let denom = (1.0 + pi8 * var[i]).sqrt();
311 let p1 = sigmoid(mean[i] / denom);
312 out[[i, 0]] = 1.0 - p1;
313 out[[i, 1]] = p1;
314 }
315 Ok(out)
316 }
317}
318
#[cfg(test)]
mod tests {
    use super::*;

    /// Two well-separated 2-D clusters (near the origin and near (5, 5))
    /// should be classified almost perfectly. Each row of `predict_proba`
    /// must be a valid distribution.
    #[test]
    fn test_gpc_separates_two_clusters() {
        let mut x_data = Vec::new();
        let mut y_data = Vec::new();
        for i in 0..6 {
            let f = i as f64 * 0.1;
            x_data.extend([f, f + 0.1]);
            y_data.push(0.0);
            x_data.extend([5.0 + f, 5.0 - f]);
            y_data.push(1.0);
        }
        let x = Array2::from_shape_vec((12, 2), x_data).unwrap();
        let y = Array1::from_vec(y_data);
        let kernel = GpKernel::Rbf {
            length_scale: 2.0,
            signal_var: 1.0,
        };
        let fitted = GaussianProcessClassifier::new(kernel)
            .with_max_iter(50)
            .fit(&x, &y)
            .unwrap();
        let preds = fitted.predict(&x).unwrap();
        let correct = preds
            .iter()
            .zip(y.iter())
            .filter(|(p, t)| (*p - *t).abs() < 0.5)
            .count();
        assert!(correct >= 11, "got {}/{} correct", correct, y.len());

        let proba = fitted.predict_proba(&x).unwrap();
        for i in 0..12 {
            let s = proba[[i, 0]] + proba[[i, 1]];
            assert!((s - 1.0).abs() < 1e-9, "row {} sum = {}", i, s);
            assert!(
                (0.0..=1.0).contains(&proba[[i, 1]]),
                "row {} p = {}",
                i,
                proba[[i, 1]]
            );
        }
    }
}
364
// Opts the binary model into `ClassifierScore`. The impl body is empty, so
// only the trait's provided methods are used and nothing is overridden.
impl anofox_ml_core::ClassifierScore<f64> for FittedGaussianProcessClassifier {}
366
367pub struct MulticlassGaussianProcessClassifier {
376 pub kernel: GpKernel,
377 pub max_iter: usize,
378 pub tol: f64,
379}
380
381impl MulticlassGaussianProcessClassifier {
382 pub fn new(kernel: GpKernel) -> Self {
383 Self {
384 kernel,
385 max_iter: 100,
386 tol: 1e-6,
387 }
388 }
389 pub fn with_max_iter(mut self, m: usize) -> Self {
390 self.max_iter = m;
391 self
392 }
393 pub fn with_tol(mut self, t: f64) -> Self {
394 self.tol = t;
395 self
396 }
397}
398
399pub struct FittedMulticlassGaussianProcessClassifier {
400 pub classes: Vec<f64>,
401 pub binary: Vec<FittedGaussianProcessClassifier>,
402}
403
404fn clone_kernel_local(k: &GpKernel) -> GpKernel {
405 match k {
406 GpKernel::Rbf {
407 length_scale,
408 signal_var,
409 } => GpKernel::Rbf {
410 length_scale: *length_scale,
411 signal_var: *signal_var,
412 },
413 GpKernel::Matern {
414 length_scale,
415 signal_var,
416 nu,
417 } => GpKernel::Matern {
418 length_scale: *length_scale,
419 signal_var: *signal_var,
420 nu: *nu,
421 },
422 GpKernel::RationalQuadratic {
423 length_scale,
424 signal_var,
425 alpha,
426 } => GpKernel::RationalQuadratic {
427 length_scale: *length_scale,
428 signal_var: *signal_var,
429 alpha: *alpha,
430 },
431 GpKernel::White { noise_level } => GpKernel::White {
432 noise_level: *noise_level,
433 },
434 GpKernel::Constant { value } => GpKernel::Constant { value: *value },
435 GpKernel::Sum(a, b) => GpKernel::Sum(
436 Box::new(clone_kernel_local(a)),
437 Box::new(clone_kernel_local(b)),
438 ),
439 GpKernel::Product(a, b) => GpKernel::Product(
440 Box::new(clone_kernel_local(a)),
441 Box::new(clone_kernel_local(b)),
442 ),
443 }
444}
445
446impl Fit<f64> for MulticlassGaussianProcessClassifier {
447 type Fitted = FittedMulticlassGaussianProcessClassifier;
448
449 fn fit(&self, x: &Array2<f64>, y: &Array1<f64>) -> Result<Self::Fitted> {
450 let mut classes: Vec<f64> = y.iter().copied().collect();
451 classes.sort_by(|a, b| a.partial_cmp(b).unwrap());
452 classes.dedup();
453 if classes.len() < 2 {
454 return Err(RustMlError::InvalidParameter(format!(
455 "multi-class GPC needs ≥2 classes, found {}",
456 classes.len()
457 )));
458 }
459 let mut binary = Vec::with_capacity(classes.len());
460 for &c in &classes {
461 let y_bin: Array1<f64> = y.mapv(|v| if v == c { 1.0 } else { 0.0 });
462 let inner = GaussianProcessClassifier {
463 kernel: clone_kernel_local(&self.kernel),
464 max_iter: self.max_iter,
465 tol: self.tol,
466 };
467 binary.push(inner.fit(x, &y_bin)?);
468 }
469 Ok(FittedMulticlassGaussianProcessClassifier { classes, binary })
470 }
471}
472
473impl Predict<f64> for FittedMulticlassGaussianProcessClassifier {
474 fn predict(&self, x: &Array2<f64>) -> Result<Array1<f64>> {
475 let proba = self.predict_proba(x)?;
476 let n = x.nrows();
477 let mut out = Array1::<f64>::zeros(n);
478 for i in 0..n {
479 let mut best = f64::NEG_INFINITY;
480 let mut best_c = 0;
481 for c in 0..self.classes.len() {
482 if proba[[i, c]] > best {
483 best = proba[[i, c]];
484 best_c = c;
485 }
486 }
487 out[i] = self.classes[best_c];
488 }
489 Ok(out)
490 }
491}
492
493impl PredictProba<f64> for FittedMulticlassGaussianProcessClassifier {
494 fn predict_proba(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
495 let n = x.nrows();
496 let k = self.classes.len();
497 let mut p = Array2::<f64>::zeros((n, k));
498 for c in 0..k {
499 let pc = self.binary[c].predict_proba(x)?;
500 for i in 0..n {
502 p[[i, c]] = pc[[i, 1]];
503 }
504 }
505 for i in 0..n {
507 let s: f64 = (0..k).map(|c| p[[i, c]]).sum::<f64>().max(1e-12);
508 for c in 0..k {
509 p[[i, c]] /= s;
510 }
511 }
512 Ok(p)
513 }
514}
515
// Opts the one-vs-rest model into `ClassifierScore`. The impl body is empty,
// so only the trait's provided methods are used and nothing is overridden.
impl anofox_ml_core::ClassifierScore<f64> for FittedMulticlassGaussianProcessClassifier {}
517
#[cfg(test)]
mod multiclass_tests {
    // `GpKernel`, `Array1` and `Array2` are already in scope via the parent's imports.
    use super::*;

    /// Three blobs (near the origin, shifted along x, shifted along y) should
    /// be recovered for at least 90% of training points. Each row of the
    /// normalised probabilities must sum to one.
    #[test]
    fn test_multiclass_gpc_three_classes() {
        let n_per = 6;
        let mut x_data = Vec::new();
        let mut y_data = Vec::new();
        for i in 0..n_per {
            let f = i as f64 * 0.1;
            x_data.extend([f, f]);
            y_data.push(0.0);
            x_data.extend([5.0 + f, f]);
            y_data.push(1.0);
            x_data.extend([f, 5.0 + f]);
            y_data.push(2.0);
        }
        let x = Array2::from_shape_vec((n_per * 3, 2), x_data).unwrap();
        let y = Array1::from_vec(y_data);
        let mc = MulticlassGaussianProcessClassifier::new(GpKernel::Rbf {
            length_scale: 2.0,
            signal_var: 1.0,
        })
        .with_max_iter(50);
        let fitted = mc.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        let correct = preds
            .iter()
            .zip(y.iter())
            .filter(|(p, t)| (*p - *t).abs() < 0.5)
            .count();
        assert!(
            correct >= (n_per * 3) * 9 / 10,
            "got {}/{} correct",
            correct,
            n_per * 3
        );
        let p = fitted.predict_proba(&x).unwrap();
        for i in 0..(n_per * 3) {
            let s: f64 = (0..3).map(|c| p[[i, c]]).sum();
            assert!((s - 1.0).abs() < 1e-9);
        }
    }
}