Skip to main content

anofox_ml_gaussian_process/
classifier.rs

//! Gaussian Process binary classifier with Laplace approximation.
//!
//! Mirrors `sklearn.gaussian_process.GaussianProcessClassifier` for the
//! binary case. Multi-class problems are handled further down in this file by
//! `MulticlassGaussianProcessClassifier`, which stacks binary classifiers
//! one-vs-rest.
//!
//! Algorithm (Rasmussen & Williams §3.4, Algorithms 3.1 and 3.2):
//!
//! 1. Fix a kernel `k` and binary labels `y ∈ {0, 1}`.
//! 2. Find the mode `f̂` of the latent posterior by Newton–Raphson, using the
//!    Cholesky factor of `B = I + Wˢ K Wˢ` (`Wˢ = W^{1/2}`) for numerical stability.
//! 3. The Laplace posterior covariance is `Σ = (K⁻¹ + W)⁻¹`.
//! 4. Predict via the probit approximation
//!      `p(y*=1|x*) ≈ σ(f̄* / √(1 + π/8 · V[f*]))`.
//!    (sklearn uses an error-function mixture at this step, so probabilities
//!    can differ slightly.)
14
15use anofox_ml_core::{Fit, Predict, PredictProba, Result, RustMlError};
16use faer::linalg::solvers::Solve;
17use faer::{Mat, Side};
18use ndarray::{Array1, Array2};
19
20use crate::{build_gram, GpKernel};
21
22pub struct GaussianProcessClassifier {
23    pub kernel: GpKernel,
24    pub max_iter: usize,
25    pub tol: f64,
26}
27
28impl GaussianProcessClassifier {
29    pub fn new(kernel: GpKernel) -> Self {
30        Self {
31            kernel,
32            max_iter: 100,
33            tol: 1e-6,
34        }
35    }
36    pub fn with_max_iter(mut self, m: usize) -> Self {
37        self.max_iter = m;
38        self
39    }
40    pub fn with_tol(mut self, t: f64) -> Self {
41        self.tol = t;
42        self
43    }
44}
45
46pub struct FittedGaussianProcessClassifier {
47    pub x_train: Array2<f64>,
48    /// `y_train - π̂` (the dual coefficient vector at the posterior mode).
49    pub alpha: Array1<f64>,
50    /// Cholesky factor `L` of `I + Wˢ K Wˢ`.
51    pub l_lower: Mat<f64>,
52    /// `Wˢ = W^{1/2}` at the posterior mode.
53    pub w_sqrt: Array1<f64>,
54    pub kernel: GpKernel,
55    pub classes: [f64; 2],
56}
57
/// Logistic function `σ(z) = 1 / (1 + e^{−z})`, evaluated without overflow.
///
/// The exponential is always taken of `−|z|`. For negative `z` the
/// algebraically equal form `e^z / (1 + e^z)` is used, so large `|z|`
/// saturates cleanly to 0 or 1 instead of producing `inf / inf`.
fn sigmoid(z: f64) -> f64 {
    let t = (-z.abs()).exp();
    if z >= 0.0 {
        1.0 / (1.0 + t)
    } else {
        t / (1.0 + t)
    }
}
66
67fn clone_kernel(k: &GpKernel) -> GpKernel {
68    match k {
69        GpKernel::Rbf {
70            length_scale,
71            signal_var,
72        } => GpKernel::Rbf {
73            length_scale: *length_scale,
74            signal_var: *signal_var,
75        },
76        GpKernel::Matern {
77            length_scale,
78            signal_var,
79            nu,
80        } => GpKernel::Matern {
81            length_scale: *length_scale,
82            signal_var: *signal_var,
83            nu: *nu,
84        },
85        GpKernel::RationalQuadratic {
86            length_scale,
87            signal_var,
88            alpha,
89        } => GpKernel::RationalQuadratic {
90            length_scale: *length_scale,
91            signal_var: *signal_var,
92            alpha: *alpha,
93        },
94        GpKernel::White { noise_level } => GpKernel::White {
95            noise_level: *noise_level,
96        },
97        GpKernel::Constant { value } => GpKernel::Constant { value: *value },
98        GpKernel::Sum(a, b) => GpKernel::Sum(Box::new(clone_kernel(a)), Box::new(clone_kernel(b))),
99        GpKernel::Product(a, b) => {
100            GpKernel::Product(Box::new(clone_kernel(a)), Box::new(clone_kernel(b)))
101        }
102    }
103}
104
105impl Fit<f64> for GaussianProcessClassifier {
106    type Fitted = FittedGaussianProcessClassifier;
107
108    fn fit(&self, x: &Array2<f64>, y: &Array1<f64>) -> Result<Self::Fitted> {
109        let n = x.nrows();
110        if y.len() != n {
111            return Err(RustMlError::ShapeMismatch(format!(
112                "X has {} rows but y has {}",
113                n,
114                y.len()
115            )));
116        }
117        // Determine classes.
118        let mut classes: Vec<f64> = y.iter().copied().collect();
119        classes.sort_by(|a, b| a.partial_cmp(b).unwrap());
120        classes.dedup();
121        if classes.len() != 2 {
122            return Err(RustMlError::InvalidParameter(format!(
123                "GPC expects 2 classes, found {}",
124                classes.len()
125            )));
126        }
127        let neg = classes[0];
128        let pos = classes[1];
129        // Encode labels as 0 / 1.
130        let yb: Vec<f64> = y
131            .iter()
132            .map(|v| if *v == pos { 1.0 } else { 0.0 })
133            .collect();
134
135        let k = build_gram(x, x, &self.kernel);
136        let mut f = Array1::<f64>::zeros(n);
137
138        // Newton-Raphson on the Laplace objective.
139        let mut prev_obj = f64::NEG_INFINITY;
140        let mut alpha = Array1::<f64>::zeros(n);
141        let mut l_lower = Mat::<f64>::zeros(n, n);
142        let mut w_sqrt = Array1::<f64>::zeros(n);
143
144        for _ in 0..self.max_iter {
145            // π = sigmoid(f); W = diag(π(1-π))
146            let pi: Vec<f64> = f.iter().map(|&v| sigmoid(v)).collect();
147            let w: Vec<f64> = pi.iter().map(|&p| p * (1.0 - p)).collect();
148            let ws: Vec<f64> = w.iter().map(|&v| v.sqrt()).collect();
149
150            // B = I + Wˢ K Wˢ
151            let mut b = Array2::<f64>::zeros((n, n));
152            for i in 0..n {
153                for j in 0..n {
154                    b[[i, j]] = ws[i] * k[[i, j]] * ws[j];
155                }
156                b[[i, i]] += 1.0;
157            }
158            let bm = Mat::<f64>::from_fn(n, n, |i, j| b[[i, j]]);
159            let llt = faer::linalg::solvers::Llt::new(bm.as_ref(), Side::Lower)
160                .map_err(|e| RustMlError::InvalidParameter(format!("Cholesky failed: {e:?}")))?;
161            let lower = llt.L();
162            l_lower = Mat::<f64>::from_fn(n, n, |i, j| lower[(i, j)]);
163
164            // b_vec = W f + (y - π)
165            let mut b_vec = Array1::<f64>::zeros(n);
166            for i in 0..n {
167                b_vec[i] = w[i] * f[i] + (yb[i] - pi[i]);
168            }
169            // a = b - Wˢ L'^{-1} L^{-1} Wˢ K b
170            // Equivalently: solve B v = Wˢ K b, then a = b - Wˢ v.
171            let mut k_b = Array1::<f64>::zeros(n);
172            for i in 0..n {
173                let mut s = 0.0;
174                for j in 0..n {
175                    s += k[[i, j]] * b_vec[j];
176                }
177                k_b[i] = s;
178            }
179            let ws_kb: Vec<f64> = (0..n).map(|i| ws[i] * k_b[i]).collect();
180            let rhs = Mat::<f64>::from_fn(n, 1, |i, _| ws_kb[i]);
181            let v_mat = llt.solve(&rhs);
182            let mut a = Array1::<f64>::zeros(n);
183            for i in 0..n {
184                a[i] = b_vec[i] - ws[i] * v_mat[(i, 0)];
185            }
186            // f = K a
187            let mut new_f = Array1::<f64>::zeros(n);
188            for i in 0..n {
189                let mut s = 0.0;
190                for j in 0..n {
191                    s += k[[i, j]] * a[j];
192                }
193                new_f[i] = s;
194            }
195
196            // Objective: Ψ(f) = -0.5 fᵀ a + Σ log p(y_i | f_i)
197            let mut obj = 0.0;
198            for i in 0..n {
199                obj -= 0.5 * new_f[i] * a[i];
200                // log p(y_i | f_i) = y log σ(f) + (1-y) log σ(-f)
201                let lp = if yb[i] > 0.5 {
202                    -(-new_f[i]).ln_1p().min(0.0)
203                        - if new_f[i] >= 0.0 {
204                            (-new_f[i]).exp().ln_1p()
205                        } else {
206                            -new_f[i] + new_f[i].exp().ln_1p()
207                        }
208                } else {
209                    if new_f[i] >= 0.0 {
210                        -new_f[i] - (-new_f[i]).exp().ln_1p()
211                    } else {
212                        -new_f[i].exp().ln_1p()
213                    }
214                };
215                obj += lp;
216            }
217
218            f = new_f;
219            alpha = a;
220            for i in 0..n {
221                w_sqrt[i] = ws[i];
222            }
223
224            if (obj - prev_obj).abs() < self.tol {
225                break;
226            }
227            prev_obj = obj;
228        }
229
230        Ok(FittedGaussianProcessClassifier {
231            x_train: x.clone(),
232            alpha,
233            l_lower,
234            w_sqrt,
235            kernel: clone_kernel(&self.kernel),
236            classes: [neg, pos],
237        })
238    }
239}
240
241impl FittedGaussianProcessClassifier {
242    /// Latent posterior mean and variance at query points.
243    fn latent_predict(&self, x: &Array2<f64>) -> Result<(Array1<f64>, Array1<f64>)> {
244        let n_train = self.x_train.nrows();
245        if x.ncols() != self.x_train.ncols() {
246            return Err(RustMlError::ShapeMismatch(format!(
247                "expected {} features, got {}",
248                self.x_train.ncols(),
249                x.ncols()
250            )));
251        }
252        let n_new = x.nrows();
253        let k_star = build_gram(x, &self.x_train, &self.kernel);
254        let mean = k_star.dot(&self.alpha);
255        // Variance: k(x*, x*) - v' v where v = L^{-1} Wˢ k_star.
256        let mut var = Array1::<f64>::zeros(n_new);
257        for i in 0..n_new {
258            let mut ws_k = vec![0.0_f64; n_train];
259            for j in 0..n_train {
260                ws_k[j] = self.w_sqrt[j] * k_star[[i, j]];
261            }
262            // Forward solve L v = ws_k.
263            let mut v = vec![0.0_f64; n_train];
264            for r in 0..n_train {
265                let mut s = ws_k[r];
266                for c in 0..r {
267                    s -= self.l_lower[(r, c)] * v[c];
268                }
269                v[r] = s / self.l_lower[(r, r)].max(1e-12);
270            }
271            let v_sq: f64 = v.iter().map(|x| x * x).sum();
272            let xi = x.row(i).to_owned();
273            let k_xx = self.kernel_compute(xi.as_slice().unwrap(), xi.as_slice().unwrap());
274            var[i] = (k_xx - v_sq).max(0.0);
275        }
276        Ok((mean, var))
277    }
278
279    fn kernel_compute(&self, a: &[f64], b: &[f64]) -> f64 {
280        // Re-use the kernel's compute method via build_gram on 1×1 arrays.
281        let arr_a = Array2::from_shape_vec((1, a.len()), a.to_vec()).unwrap();
282        let arr_b = Array2::from_shape_vec((1, b.len()), b.to_vec()).unwrap();
283        build_gram(&arr_a, &arr_b, &self.kernel)[[0, 0]]
284    }
285}
286
287impl Predict<f64> for FittedGaussianProcessClassifier {
288    fn predict(&self, x: &Array2<f64>) -> Result<Array1<f64>> {
289        let proba = self.predict_proba(x)?;
290        let mut out = Array1::<f64>::zeros(x.nrows());
291        for i in 0..x.nrows() {
292            out[i] = if proba[[i, 1]] >= 0.5 {
293                self.classes[1]
294            } else {
295                self.classes[0]
296            };
297        }
298        Ok(out)
299    }
300}
301
302impl PredictProba<f64> for FittedGaussianProcessClassifier {
303    fn predict_proba(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
304        let (mean, var) = self.latent_predict(x)?;
305        // Probit approximation: p(y=1) ≈ σ(f̄ / sqrt(1 + π/8 · v)).
306        let n = mean.len();
307        let mut out = Array2::<f64>::zeros((n, 2));
308        let pi8 = std::f64::consts::PI / 8.0;
309        for i in 0..n {
310            let denom = (1.0 + pi8 * var[i]).sqrt();
311            let p1 = sigmoid(mean[i] / denom);
312            out[[i, 0]] = 1.0 - p1;
313            out[[i, 1]] = p1;
314        }
315        Ok(out)
316    }
317}
318
#[cfg(test)]
mod tests {
    use super::*;

    /// Twelve 2-D points: six near (0, 0) labelled 0 and six near (5, 5)
    /// labelled 1.
    fn two_clusters() -> (Array2<f64>, Array1<f64>) {
        let mut x_data = Vec::new();
        let mut y_data = Vec::new();
        for i in 0..6 {
            let f = i as f64 * 0.1;
            x_data.extend([f, f + 0.1]);
            y_data.push(0.0);
            x_data.extend([5.0 + f, 5.0 - f]);
            y_data.push(1.0);
        }
        let x = Array2::from_shape_vec((12, 2), x_data).unwrap();
        (x, Array1::from_vec(y_data))
    }

    fn rbf() -> GpKernel {
        GpKernel::Rbf {
            length_scale: 2.0,
            signal_var: 1.0,
        }
    }

    #[test]
    fn test_gpc_separates_two_clusters() {
        let (x, y) = two_clusters();
        let fitted = GaussianProcessClassifier::new(rbf())
            .with_max_iter(50)
            .fit(&x, &y)
            .unwrap();
        let preds = fitted.predict(&x).unwrap();
        // Well-separated clusters: allow at most one training mistake.
        let correct = preds
            .iter()
            .zip(y.iter())
            .filter(|(p, t)| (*p - *t).abs() < 0.5)
            .count();
        assert!(correct >= 11, "got {}/{} correct", correct, y.len());

        // predict_proba returns valid probabilities (rows sum to 1).
        let proba = fitted.predict_proba(&x).unwrap();
        for i in 0..12 {
            let s = proba[[i, 0]] + proba[[i, 1]];
            assert!((s - 1.0).abs() < 1e-9, "row {} sum = {}", i, s);
        }
    }

    #[test]
    fn test_gpc_rejects_single_class() {
        let x = Array2::<f64>::zeros((3, 1));
        let y = Array1::from_elem(3, 1.0);
        assert!(GaussianProcessClassifier::new(rbf()).fit(&x, &y).is_err());
    }

    #[test]
    fn test_gpc_rejects_row_mismatch() {
        let (x, _) = two_clusters();
        let y = Array1::from_vec(vec![0.0, 1.0]);
        assert!(GaussianProcessClassifier::new(rbf()).fit(&x, &y).is_err());
    }

    #[test]
    fn test_gpc_predict_rejects_wrong_feature_count() {
        let (x, y) = two_clusters();
        let fitted = GaussianProcessClassifier::new(rbf()).fit(&x, &y).unwrap();
        let bad = Array2::<f64>::zeros((2, 3));
        assert!(fitted.predict_proba(&bad).is_err());
    }
}
364
365impl anofox_ml_core::ClassifierScore<f64> for FittedGaussianProcessClassifier {}
366
// ---------------------------------------------------------------------------
// Multi-class GPC via one-vs-rest.
// ---------------------------------------------------------------------------
370
371/// Multi-class Gaussian Process Classifier built as a one-vs-rest stack of
372/// binary `GaussianProcessClassifier` instances. Mirrors sklearn's
373/// `GaussianProcessClassifier(multi_class='one_vs_rest')` for the case of
374/// arbitrary discrete class labels.
375pub struct MulticlassGaussianProcessClassifier {
376    pub kernel: GpKernel,
377    pub max_iter: usize,
378    pub tol: f64,
379}
380
381impl MulticlassGaussianProcessClassifier {
382    pub fn new(kernel: GpKernel) -> Self {
383        Self {
384            kernel,
385            max_iter: 100,
386            tol: 1e-6,
387        }
388    }
389    pub fn with_max_iter(mut self, m: usize) -> Self {
390        self.max_iter = m;
391        self
392    }
393    pub fn with_tol(mut self, t: f64) -> Self {
394        self.tol = t;
395        self
396    }
397}
398
399pub struct FittedMulticlassGaussianProcessClassifier {
400    pub classes: Vec<f64>,
401    pub binary: Vec<FittedGaussianProcessClassifier>,
402}
403
404fn clone_kernel_local(k: &GpKernel) -> GpKernel {
405    match k {
406        GpKernel::Rbf {
407            length_scale,
408            signal_var,
409        } => GpKernel::Rbf {
410            length_scale: *length_scale,
411            signal_var: *signal_var,
412        },
413        GpKernel::Matern {
414            length_scale,
415            signal_var,
416            nu,
417        } => GpKernel::Matern {
418            length_scale: *length_scale,
419            signal_var: *signal_var,
420            nu: *nu,
421        },
422        GpKernel::RationalQuadratic {
423            length_scale,
424            signal_var,
425            alpha,
426        } => GpKernel::RationalQuadratic {
427            length_scale: *length_scale,
428            signal_var: *signal_var,
429            alpha: *alpha,
430        },
431        GpKernel::White { noise_level } => GpKernel::White {
432            noise_level: *noise_level,
433        },
434        GpKernel::Constant { value } => GpKernel::Constant { value: *value },
435        GpKernel::Sum(a, b) => GpKernel::Sum(
436            Box::new(clone_kernel_local(a)),
437            Box::new(clone_kernel_local(b)),
438        ),
439        GpKernel::Product(a, b) => GpKernel::Product(
440            Box::new(clone_kernel_local(a)),
441            Box::new(clone_kernel_local(b)),
442        ),
443    }
444}
445
446impl Fit<f64> for MulticlassGaussianProcessClassifier {
447    type Fitted = FittedMulticlassGaussianProcessClassifier;
448
449    fn fit(&self, x: &Array2<f64>, y: &Array1<f64>) -> Result<Self::Fitted> {
450        let mut classes: Vec<f64> = y.iter().copied().collect();
451        classes.sort_by(|a, b| a.partial_cmp(b).unwrap());
452        classes.dedup();
453        if classes.len() < 2 {
454            return Err(RustMlError::InvalidParameter(format!(
455                "multi-class GPC needs ≥2 classes, found {}",
456                classes.len()
457            )));
458        }
459        let mut binary = Vec::with_capacity(classes.len());
460        for &c in &classes {
461            let y_bin: Array1<f64> = y.mapv(|v| if v == c { 1.0 } else { 0.0 });
462            let inner = GaussianProcessClassifier {
463                kernel: clone_kernel_local(&self.kernel),
464                max_iter: self.max_iter,
465                tol: self.tol,
466            };
467            binary.push(inner.fit(x, &y_bin)?);
468        }
469        Ok(FittedMulticlassGaussianProcessClassifier { classes, binary })
470    }
471}
472
473impl Predict<f64> for FittedMulticlassGaussianProcessClassifier {
474    fn predict(&self, x: &Array2<f64>) -> Result<Array1<f64>> {
475        let proba = self.predict_proba(x)?;
476        let n = x.nrows();
477        let mut out = Array1::<f64>::zeros(n);
478        for i in 0..n {
479            let mut best = f64::NEG_INFINITY;
480            let mut best_c = 0;
481            for c in 0..self.classes.len() {
482                if proba[[i, c]] > best {
483                    best = proba[[i, c]];
484                    best_c = c;
485                }
486            }
487            out[i] = self.classes[best_c];
488        }
489        Ok(out)
490    }
491}
492
493impl PredictProba<f64> for FittedMulticlassGaussianProcessClassifier {
494    fn predict_proba(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
495        let n = x.nrows();
496        let k = self.classes.len();
497        let mut p = Array2::<f64>::zeros((n, k));
498        for c in 0..k {
499            let pc = self.binary[c].predict_proba(x)?;
500            // Take the "is-class-c" column (= column 1 since label 1.0 in fit).
501            for i in 0..n {
502                p[[i, c]] = pc[[i, 1]];
503            }
504        }
505        // Renormalise rows to sum to 1 (sklearn does the same for OvR).
506        for i in 0..n {
507            let s: f64 = (0..k).map(|c| p[[i, c]]).sum::<f64>().max(1e-12);
508            for c in 0..k {
509                p[[i, c]] /= s;
510            }
511        }
512        Ok(p)
513    }
514}
515
516impl anofox_ml_core::ClassifierScore<f64> for FittedMulticlassGaussianProcessClassifier {}
517
#[cfg(test)]
mod multiclass_tests {
    use super::*;

    #[test]
    fn test_multiclass_gpc_three_classes() {
        // Three clusters anchored at (0,0), (5,0) and (0,5), labelled 0, 1, 2.
        let n_per = 6;
        let n_total = n_per * 3;
        let centres = [(0.0, 0.0, 0.0), (5.0, 0.0, 1.0), (0.0, 5.0, 2.0)];
        let mut x_data = Vec::with_capacity(n_total * 2);
        let mut y_data = Vec::with_capacity(n_total);
        for step in 0..n_per {
            let offset = step as f64 * 0.1;
            for &(cx, cy, label) in &centres {
                x_data.extend([cx + offset, cy + offset]);
                y_data.push(label);
            }
        }
        let x = Array2::from_shape_vec((n_total, 2), x_data).unwrap();
        let y = Array1::from_vec(y_data);

        let kernel = GpKernel::Rbf {
            length_scale: 2.0,
            signal_var: 1.0,
        };
        let fitted = MulticlassGaussianProcessClassifier::new(kernel)
            .with_max_iter(50)
            .fit(&x, &y)
            .unwrap();

        let preds = fitted.predict(&x).unwrap();
        let correct = preds
            .iter()
            .zip(y.iter())
            .filter(|(p, t)| (*p - *t).abs() < 0.5)
            .count();
        assert!(
            correct >= n_total * 9 / 10,
            "got {}/{} correct",
            correct,
            n_total
        );

        let proba = fitted.predict_proba(&x).unwrap();
        for row in proba.outer_iter() {
            let s: f64 = row.iter().sum();
            assert!((s - 1.0).abs() < 1e-9);
        }
    }
}