use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
use scirs2_core::random::{Rng, RngExt};
use crate::error::{OptimizeError, OptimizeResult};
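/// Covariance (kernel) function for the GP surrogate.
///
/// Implementors supply the pointwise `eval`; `covariance_matrix` and
/// `cross_covariance` have default implementations built on top of it.
/// Hyperparameters are exchanged in log space so that unconstrained tuning
/// keeps the underlying values positive, and `clone_box` enables `Clone`
/// for boxed trait objects.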
pub trait SurrogateKernel: Send + Sync {
fn eval(&self, x1: &ArrayView1<f64>, x2: &ArrayView1<f64>) -> f64;
fn covariance_matrix(&self, x: &Array2<f64>) -> Array2<f64> {
let n = x.nrows();
let mut k = Array2::zeros((n, n));
for i in 0..n {
for j in 0..=i {
let kij = self.eval(&x.row(i), &x.row(j));
k[[i, j]] = kij;
if i != j {
k[[j, i]] = kij;
}
}
}
k
}
fn cross_covariance(&self, x1: &Array2<f64>, x2: &Array2<f64>) -> Array2<f64> {
let n1 = x1.nrows();
let n2 = x2.nrows();
let mut k = Array2::zeros((n1, n2));
for i in 0..n1 {
for j in 0..n2 {
k[[i, j]] = self.eval(&x1.row(i), &x2.row(j));
}
}
k
}
fn get_log_params(&self) -> Vec<f64>;
fn set_log_params(&mut self, params: &[f64]);
fn n_params(&self) -> usize {
self.get_log_params().len()
}
fn clone_box(&self) -> Box<dyn SurrogateKernel>;
fn name(&self) -> &str;
}
impl Clone for Box<dyn SurrogateKernel> {
fn clone(&self) -> Self {
self.clone_box()
}
}
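/// Squared-exponential (RBF) kernel,
/// `k(x, x') = signal_variance * exp(-||x - x'||² / (2 * length_scale²))`.
/// Yields very smooth (infinitely differentiable) sample paths. Parameters
/// are clamped to at least 1e-10 to stay positive.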
#[derive(Debug, Clone)]
pub struct RbfKernel {
pub length_scale: f64,
pub signal_variance: f64,
}
impl RbfKernel {
pub fn new(length_scale: f64, signal_variance: f64) -> Self {
Self {
length_scale: length_scale.max(1e-10),
signal_variance: signal_variance.max(1e-10),
}
}
}
impl Default for RbfKernel {
fn default() -> Self {
Self::new(1.0, 1.0)
}
}
impl SurrogateKernel for RbfKernel {
fn eval(&self, x1: &ArrayView1<f64>, x2: &ArrayView1<f64>) -> f64 {
let sq_dist = squared_distance(x1, x2);
self.signal_variance * (-0.5 * sq_dist / (self.length_scale * self.length_scale)).exp()
}
fn get_log_params(&self) -> Vec<f64> {
vec![self.length_scale.ln(), self.signal_variance.ln()]
}
fn set_log_params(&mut self, params: &[f64]) {
if params.len() >= 2 {
self.length_scale = params[0].exp().max(1e-10);
self.signal_variance = params[1].exp().max(1e-10);
}
}
fn clone_box(&self) -> Box<dyn SurrogateKernel> {
Box::new(self.clone())
}
fn name(&self) -> &str {
"RBF"
}
}
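/// Smoothness parameter ν of the Matérn family: `OneHalf` (ν = 1/2, the
/// exponential kernel), `ThreeHalves` (ν = 3/2), and `FiveHalves` (ν = 5/2).
/// Larger ν gives smoother sample paths.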
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum MaternVariant {
OneHalf,
ThreeHalves,
FiveHalves,
}
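/// Matérn kernel with the closed forms for ν ∈ {1/2, 3/2, 5/2}, writing
/// r = ||x - x'|| and ℓ = length_scale:
///
/// * ν = 1/2: `σ² · exp(-r/ℓ)`
/// * ν = 3/2: `σ² · (1 + √3·r/ℓ) · exp(-√3·r/ℓ)`
/// * ν = 5/2: `σ² · (1 + √5·r/ℓ + 5r²/(3ℓ²)) · exp(-√5·r/ℓ)`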
#[derive(Debug, Clone)]
pub struct MaternKernel {
pub variant: MaternVariant,
pub length_scale: f64,
pub signal_variance: f64,
}
impl MaternKernel {
pub fn new(variant: MaternVariant, length_scale: f64, signal_variance: f64) -> Self {
Self {
variant,
length_scale: length_scale.max(1e-10),
signal_variance: signal_variance.max(1e-10),
}
}
pub fn one_half(length_scale: f64, signal_variance: f64) -> Self {
Self::new(MaternVariant::OneHalf, length_scale, signal_variance)
}
pub fn three_halves(length_scale: f64, signal_variance: f64) -> Self {
Self::new(MaternVariant::ThreeHalves, length_scale, signal_variance)
}
pub fn five_halves(length_scale: f64, signal_variance: f64) -> Self {
Self::new(MaternVariant::FiveHalves, length_scale, signal_variance)
}
}
impl Default for MaternKernel {
fn default() -> Self {
Self::new(MaternVariant::FiveHalves, 1.0, 1.0)
}
}
impl SurrogateKernel for MaternKernel {
fn eval(&self, x1: &ArrayView1<f64>, x2: &ArrayView1<f64>) -> f64 {
let r = squared_distance(x1, x2).sqrt();
let l = self.length_scale;
let sv = self.signal_variance;
match self.variant {
MaternVariant::OneHalf => sv * (-r / l).exp(),
MaternVariant::ThreeHalves => {
let sqrt3_r_l = 3.0_f64.sqrt() * r / l;
sv * (1.0 + sqrt3_r_l) * (-sqrt3_r_l).exp()
}
MaternVariant::FiveHalves => {
let sqrt5_r_l = 5.0_f64.sqrt() * r / l;
let r2_l2 = r * r / (l * l);
sv * (1.0 + sqrt5_r_l + 5.0 * r2_l2 / 3.0) * (-sqrt5_r_l).exp()
}
}
}
fn get_log_params(&self) -> Vec<f64> {
vec![self.length_scale.ln(), self.signal_variance.ln()]
}
fn set_log_params(&mut self, params: &[f64]) {
if params.len() >= 2 {
self.length_scale = params[0].exp().max(1e-10);
self.signal_variance = params[1].exp().max(1e-10);
}
}
fn clone_box(&self) -> Box<dyn SurrogateKernel> {
Box::new(self.clone())
}
fn name(&self) -> &str {
match self.variant {
MaternVariant::OneHalf => "Matern12",
MaternVariant::ThreeHalves => "Matern32",
MaternVariant::FiveHalves => "Matern52",
}
}
}
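/// Rational quadratic kernel,
/// `k(x, x') = σ² · (1 + ||x - x'||² / (2·α·ℓ²))^(-α)`,
/// a scale mixture of RBF kernels over length scales; it converges to the
/// RBF kernel as α → ∞.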
#[derive(Debug, Clone)]
pub struct RationalQuadraticKernel {
pub length_scale: f64,
pub signal_variance: f64,
pub alpha: f64,
}
impl RationalQuadraticKernel {
pub fn new(length_scale: f64, signal_variance: f64, alpha: f64) -> Self {
Self {
length_scale: length_scale.max(1e-10),
signal_variance: signal_variance.max(1e-10),
alpha: alpha.max(1e-10),
}
}
}
impl Default for RationalQuadraticKernel {
fn default() -> Self {
Self::new(1.0, 1.0, 1.0)
}
}
impl SurrogateKernel for RationalQuadraticKernel {
fn eval(&self, x1: &ArrayView1<f64>, x2: &ArrayView1<f64>) -> f64 {
let sq_dist = squared_distance(x1, x2);
let base = 1.0 + sq_dist / (2.0 * self.alpha * self.length_scale * self.length_scale);
self.signal_variance * base.powf(-self.alpha)
}
fn get_log_params(&self) -> Vec<f64> {
vec![
self.length_scale.ln(),
self.signal_variance.ln(),
self.alpha.ln(),
]
}
fn set_log_params(&mut self, params: &[f64]) {
if params.len() >= 3 {
self.length_scale = params[0].exp().max(1e-10);
self.signal_variance = params[1].exp().max(1e-10);
self.alpha = params[2].exp().max(1e-10);
}
}
fn clone_box(&self) -> Box<dyn SurrogateKernel> {
Box::new(self.clone())
}
fn name(&self) -> &str {
"RationalQuadratic"
}
}
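/// Pointwise sum of two kernels, `k(x, x') = k1(x, x') + k2(x, x')`; the sum
/// of two valid covariance functions is itself a valid covariance function.
/// Log-parameters are concatenated, `kernel1`'s first.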
#[derive(Clone)]
pub struct SumKernel {
pub kernel1: Box<dyn SurrogateKernel>,
pub kernel2: Box<dyn SurrogateKernel>,
}
impl SumKernel {
pub fn new(k1: Box<dyn SurrogateKernel>, k2: Box<dyn SurrogateKernel>) -> Self {
Self {
kernel1: k1,
kernel2: k2,
}
}
}
impl SurrogateKernel for SumKernel {
fn eval(&self, x1: &ArrayView1<f64>, x2: &ArrayView1<f64>) -> f64 {
self.kernel1.eval(x1, x2) + self.kernel2.eval(x1, x2)
}
fn get_log_params(&self) -> Vec<f64> {
let mut p = self.kernel1.get_log_params();
p.extend(self.kernel2.get_log_params());
p
}
fn set_log_params(&mut self, params: &[f64]) {
let n1 = self.kernel1.n_params();
if params.len() >= n1 {
self.kernel1.set_log_params(&params[..n1]);
}
if params.len() > n1 {
self.kernel2.set_log_params(&params[n1..]);
}
}
fn clone_box(&self) -> Box<dyn SurrogateKernel> {
Box::new(self.clone())
}
fn name(&self) -> &str {
"Sum"
}
}
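/// Pointwise product of two kernels, `k(x, x') = k1(x, x') * k2(x, x')`;
/// products of valid kernels are also valid kernels. Log-parameters are
/// concatenated, `kernel1`'s first.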
#[derive(Clone)]
pub struct ProductKernel {
pub kernel1: Box<dyn SurrogateKernel>,
pub kernel2: Box<dyn SurrogateKernel>,
}
impl ProductKernel {
pub fn new(k1: Box<dyn SurrogateKernel>, k2: Box<dyn SurrogateKernel>) -> Self {
Self {
kernel1: k1,
kernel2: k2,
}
}
}
impl SurrogateKernel for ProductKernel {
fn eval(&self, x1: &ArrayView1<f64>, x2: &ArrayView1<f64>) -> f64 {
self.kernel1.eval(x1, x2) * self.kernel2.eval(x1, x2)
}
fn get_log_params(&self) -> Vec<f64> {
let mut p = self.kernel1.get_log_params();
p.extend(self.kernel2.get_log_params());
p
}
fn set_log_params(&mut self, params: &[f64]) {
let n1 = self.kernel1.n_params();
if params.len() >= n1 {
self.kernel1.set_log_params(&params[..n1]);
}
if params.len() > n1 {
self.kernel2.set_log_params(&params[n1..]);
}
}
fn clone_box(&self) -> Box<dyn SurrogateKernel> {
Box::new(self.clone())
}
fn name(&self) -> &str {
"Product"
}
}
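/// Configuration for [`GpSurrogate`]:
///
/// * `noise_variance` — observation-noise term added to the kernel diagonal
///   (doubles as numerical jitter).
/// * `optimize_hyperparams` — whether `fit` tunes kernel hyperparameters by
///   maximizing the log marginal likelihood.
/// * `n_restarts` — number of restarts for the hyperparameter search.
/// * `max_opt_iters` — maximum coordinate-search sweeps per restart.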
#[derive(Clone)]
pub struct GpSurrogateConfig {
pub noise_variance: f64,
pub optimize_hyperparams: bool,
pub n_restarts: usize,
pub max_opt_iters: usize,
}
impl Default for GpSurrogateConfig {
fn default() -> Self {
Self {
noise_variance: 1e-6,
optimize_hyperparams: true,
n_restarts: 3,
max_opt_iters: 50,
}
}
}
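/// Gaussian process regression surrogate with a zero prior mean over
/// standardized targets.
///
/// Targets are standardized to zero mean and unit variance at fit time, the
/// noise-augmented kernel matrix is factored once with a Cholesky
/// decomposition, and predictions are mapped back to the original units.
///
/// A minimal usage sketch (illustrative only, not compiled as a doctest):
///
/// ```ignore
/// let x = Array2::from_shape_vec((3, 1), vec![0.0, 1.0, 2.0]).unwrap();
/// let y = Array1::from(vec![0.0, 0.8, 0.9]);
/// let mut gp = GpSurrogate::default_rbf();
/// gp.fit(&x, &y)?;
/// let (mean, variance) = gp.predict(&x)?;
/// ```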
pub struct GpSurrogate {
kernel: Box<dyn SurrogateKernel>,
config: GpSurrogateConfig,
x_train: Option<Array2<f64>>,
y_train: Option<Array1<f64>>,
y_mean: f64,
y_std: f64,
l_factor: Option<Array2<f64>>,
alpha: Option<Array1<f64>>,
}
impl GpSurrogate {
pub fn new(kernel: Box<dyn SurrogateKernel>, config: GpSurrogateConfig) -> Self {
Self {
kernel,
config,
x_train: None,
y_train: None,
y_mean: 0.0,
y_std: 1.0,
l_factor: None,
alpha: None,
}
}
pub fn default_rbf() -> Self {
Self::new(Box::new(RbfKernel::default()), GpSurrogateConfig::default())
}
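/// Fits the GP to inputs `x` (one sample per row) and targets `y`.
///
/// Records standardization statistics for `y`, optionally tunes kernel
/// hyperparameters by maximizing the log marginal likelihood (only when at
/// least 3 samples are available), then factors the kernel matrix. Errors
/// on empty input or when sample counts disagree.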
pub fn fit(&mut self, x: &Array2<f64>, y: &Array1<f64>) -> OptimizeResult<()> {
if x.nrows() != y.len() {
return Err(OptimizeError::InvalidInput(format!(
"x has {} rows but y has {} elements",
x.nrows(),
y.len()
)));
}
if x.nrows() == 0 {
return Err(OptimizeError::InvalidInput(
"Cannot fit GP with zero training samples".to_string(),
));
}
self.x_train = Some(x.clone());
self.y_train = Some(y.clone());
self.y_mean = y.iter().sum::<f64>() / y.len() as f64;
let variance = y.iter().map(|&v| (v - self.y_mean).powi(2)).sum::<f64>() / y.len() as f64;
self.y_std = if variance > 1e-12 {
variance.sqrt()
} else {
1.0
};
if self.config.optimize_hyperparams && x.nrows() >= 3 {
self.optimize_hyperparameters()?;
}
self.update_model()
}
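/// Appends new observations and refits from scratch.
///
/// This is a full O(n³) refit rather than an incremental Cholesky update:
/// simple and robust, but costly for large training sets.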
pub fn update(&mut self, x_new: &Array2<f64>, y_new: &Array1<f64>) -> OptimizeResult<()> {
if x_new.nrows() != y_new.len() {
return Err(OptimizeError::InvalidInput(
"x_new and y_new must have same number of rows".to_string(),
));
}
match (&self.x_train, &self.y_train) {
(Some(xt), Some(yt)) => {
let mut x_all = Array2::zeros((xt.nrows() + x_new.nrows(), xt.ncols()));
for i in 0..xt.nrows() {
for j in 0..xt.ncols() {
x_all[[i, j]] = xt[[i, j]];
}
}
for i in 0..x_new.nrows() {
for j in 0..x_new.ncols() {
x_all[[xt.nrows() + i, j]] = x_new[[i, j]];
}
}
let mut y_all = Array1::zeros(yt.len() + y_new.len());
for i in 0..yt.len() {
y_all[i] = yt[i];
}
for i in 0..y_new.len() {
y_all[yt.len() + i] = y_new[i];
}
self.fit(&x_all, &y_all)
}
_ => self.fit(x_new, y_new),
}
}
pub fn predict_mean(&self, x_test: &Array2<f64>) -> OptimizeResult<Array1<f64>> {
let (mean, _) = self.predict(x_test)?;
Ok(mean)
}
pub fn predict_variance(&self, x_test: &Array2<f64>) -> OptimizeResult<Array1<f64>> {
let (_, var) = self.predict(x_test)?;
Ok(var)
}
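/// Posterior mean and variance at each row of `x_test`.
///
/// The mean is `K* · α` mapped back to the original target units; the
/// variance is the latent-function variance `k(x, x) - ‖v‖²` with
/// `v = L⁻¹ k*`, clamped at zero and rescaled by `y_std²`. Observation
/// noise is not added to the returned variance.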
pub fn predict(&self, x_test: &Array2<f64>) -> OptimizeResult<(Array1<f64>, Array1<f64>)> {
let x_train = self.x_train.as_ref().ok_or_else(|| {
OptimizeError::ComputationError("GP must be fitted before prediction".to_string())
})?;
let alpha = self
.alpha
.as_ref()
.ok_or_else(|| OptimizeError::ComputationError("GP model not fitted".to_string()))?;
let l_factor = self
.l_factor
.as_ref()
.ok_or_else(|| OptimizeError::ComputationError("GP model not fitted".to_string()))?;
let k_star = self.kernel.cross_covariance(x_test, x_train);
let mean_std = k_star.dot(alpha);
let mean = mean_std.mapv(|v| v * self.y_std + self.y_mean);
let n_test = x_test.nrows();
let mut variance = Array1::zeros(n_test);
for i in 0..n_test {
let k_self = self.kernel.eval(&x_test.row(i), &x_test.row(i));
let k_col = k_star.row(i).to_owned();
let v = forward_solve(l_factor, &k_col)?;
let v_sq_sum: f64 = v.iter().map(|&vi| vi * vi).sum();
let var = (k_self - v_sq_sum).max(0.0);
variance[i] = var * self.y_std * self.y_std;
}
Ok((mean, variance))
}
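/// Single-point convenience wrapper; note it returns
/// `(mean, standard deviation)`, not variance.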
pub fn predict_single(&self, x: &ArrayView1<f64>) -> OptimizeResult<(f64, f64)> {
let x_mat = x
.to_owned()
.into_shape_with_order((1, x.len()))
.map_err(|e| OptimizeError::ComputationError(format!("Shape error: {}", e)))?;
let (mean, var) = self.predict(&x_mat)?;
Ok((mean[0], var[0].max(0.0).sqrt()))
}
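/// Log marginal likelihood of the standardized targets,
/// `log p(y|X) = -½·yᵀK⁻¹y - ½·log|K| - (n/2)·log 2π`,
/// evaluated via the cached Cholesky factor (`½·log|K| = Σᵢ log Lᵢᵢ`).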
pub fn log_marginal_likelihood(&self) -> OptimizeResult<f64> {
let y_train = self
.y_train
.as_ref()
.ok_or_else(|| OptimizeError::ComputationError("GP must be fitted".to_string()))?;
let l_factor = self
.l_factor
.as_ref()
.ok_or_else(|| OptimizeError::ComputationError("GP model not fitted".to_string()))?;
let alpha = self
.alpha
.as_ref()
.ok_or_else(|| OptimizeError::ComputationError("GP model not fitted".to_string()))?;
let y_std = self.standardize_y(y_train);
let n = y_std.len() as f64;
let data_fit = -0.5 * y_std.dot(alpha);
let log_det: f64 = l_factor.diag().iter().map(|&v| v.abs().ln()).sum();
let norm = -0.5 * n * (2.0 * std::f64::consts::PI).ln();
Ok(data_fit - log_det + norm)
}
pub fn kernel(&self) -> &dyn SurrogateKernel {
self.kernel.as_ref()
}
pub fn kernel_mut(&mut self) -> &mut dyn SurrogateKernel {
self.kernel.as_mut()
}
pub fn n_train(&self) -> usize {
self.x_train.as_ref().map_or(0, |x| x.nrows())
}
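/// Maps targets to zero mean and unit variance using the `y_mean`/`y_std`
/// statistics recorded during `fit`.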
fn standardize_y(&self, y: &Array1<f64>) -> Array1<f64> {
y.mapv(|v| (v - self.y_mean) / self.y_std)
}
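/// Builds `K + noise_variance·I`, factors it as `L·Lᵀ` (retrying with extra
/// jitter if the factorization fails), and caches `L` and `α = K⁻¹·y` for
/// prediction.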
fn update_model(&mut self) -> OptimizeResult<()> {
let x_train = self
.x_train
.as_ref()
.ok_or_else(|| OptimizeError::ComputationError("No training data".to_string()))?;
let y_train = self
.y_train
.as_ref()
.ok_or_else(|| OptimizeError::ComputationError("No training data".to_string()))?;
let y_std = self.standardize_y(y_train);
let mut k = self.kernel.covariance_matrix(x_train);
let n = k.nrows();
for i in 0..n {
k[[i, i]] += self.config.noise_variance;
}
let l = match cholesky(&k) {
Ok(l) => l,
Err(_) => {
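// Retry with progressively larger diagonal jitter; the additions
// accumulate on `k` across attempts.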
let jitters = [1e-6, 1e-5, 1e-4, 1e-3, 1e-2];
let mut result = Err(OptimizeError::ComputationError(
"Cholesky failed with all jitter levels".to_string(),
));
for &jitter in &jitters {
for i in 0..n {
k[[i, i]] += jitter;
}
match cholesky(&k) {
Ok(l) => {
result = Ok(l);
break;
}
Err(_) => continue,
}
}
result?
}
};
let alpha1 = forward_solve(&l, &y_std)?;
let alpha = backward_solve_transpose(&l, &alpha1)?;
self.l_factor = Some(l);
self.alpha = Some(alpha);
Ok(())
}
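/// Tunes kernel hyperparameters in log space by maximizing the log marginal
/// likelihood with a multi-restart coordinate search: each coordinate tries
/// a fixed set of steps (clamped to [-5, 5]) and keeps the best, sweeping
/// until no coordinate improves.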
fn optimize_hyperparameters(&mut self) -> OptimizeResult<()> {
let x_train = self
.x_train
.as_ref()
.ok_or_else(|| OptimizeError::ComputationError("No training data".to_string()))?
.clone();
let y_train = self
.y_train
.as_ref()
.ok_or_else(|| OptimizeError::ComputationError("No training data".to_string()))?
.clone();
let y_std = self.standardize_y(&y_train);
let n_params = self.kernel.n_params();
if n_params == 0 {
return Ok(());
}
let mut best_params = self.kernel.get_log_params();
let mut best_lml = f64::NEG_INFINITY;
if let Ok(lml) = self.eval_lml_at_params(&best_params, &x_train, &y_std) {
best_lml = lml;
}
let mut rng = scirs2_core::random::rng();
for restart in 0..self.config.n_restarts {
let init_params: Vec<f64> = if restart == 0 {
best_params.clone()
} else {
(0..n_params).map(|_| rng.random_range(-2.0..2.0)).collect()
};
let mut current_params = init_params;
for _iter in 0..self.config.max_opt_iters {
let mut improved = false;
for p in 0..n_params {
let original = current_params[p];
let steps = [0.1, 0.3, 1.0, -0.1, -0.3, -1.0];
let mut best_step_lml =
match self.eval_lml_at_params(&current_params, &x_train, &y_std) {
Ok(v) => v,
Err(_) => f64::NEG_INFINITY,
};
let mut best_step_val = original;
for &step in &steps {
current_params[p] = original + step;
current_params[p] = current_params[p].clamp(-5.0, 5.0);
if let Ok(lml) = self.eval_lml_at_params(&current_params, &x_train, &y_std)
{
if lml > best_step_lml {
best_step_lml = lml;
best_step_val = current_params[p];
improved = true;
}
}
}
current_params[p] = best_step_val;
}
if !improved {
break;
}
}
if let Ok(lml) = self.eval_lml_at_params(&current_params, &x_train, &y_std) {
if lml > best_lml {
best_lml = lml;
best_params = current_params;
}
}
}
self.kernel.set_log_params(&best_params);
Ok(())
}
fn eval_lml_at_params(
&self,
log_params: &[f64],
x_train: &Array2<f64>,
y_std: &Array1<f64>,
) -> OptimizeResult<f64> {
let mut kernel = self.kernel.clone();
kernel.set_log_params(log_params);
let mut k = kernel.covariance_matrix(x_train);
let n = k.nrows();
for i in 0..n {
k[[i, i]] += self.config.noise_variance;
}
let l = cholesky(&k)?;
let alpha1 = forward_solve(&l, y_std)?;
let alpha = backward_solve_transpose(&l, &alpha1)?;
let n_f = n as f64;
let data_fit = -0.5 * y_std.dot(&alpha);
let log_det: f64 = l.diag().iter().map(|&v| v.abs().ln()).sum();
let norm = -0.5 * n_f * (2.0 * std::f64::consts::PI).ln();
Ok(data_fit - log_det + norm)
}
}
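/// Squared Euclidean distance between two equal-length vectors.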
fn squared_distance(x1: &ArrayView1<f64>, x2: &ArrayView1<f64>) -> f64 {
let mut s = 0.0;
for i in 0..x1.len() {
let d = x1[i] - x2[i];
s += d * d;
}
s
}
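/// Unpivoted Cholesky factorization `A = L·Lᵀ` (L lower triangular); errors
/// if `A` is not symmetric positive-definite.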
fn cholesky(a: &Array2<f64>) -> OptimizeResult<Array2<f64>> {
let n = a.nrows();
if n != a.ncols() {
return Err(OptimizeError::ComputationError(
"Cholesky: matrix must be square".to_string(),
));
}
let mut l = Array2::zeros((n, n));
for i in 0..n {
for j in 0..=i {
let mut s = 0.0;
for k in 0..j {
s += l[[i, k]] * l[[j, k]];
}
if i == j {
let diag = a[[i, i]] - s;
if diag <= 0.0 {
return Err(OptimizeError::ComputationError(format!(
"Cholesky: matrix not positive-definite (diag[{}] = {:.6e})",
i, diag
)));
}
l[[i, j]] = diag.sqrt();
} else {
l[[i, j]] = (a[[i, j]] - s) / l[[j, j]];
}
}
}
Ok(l)
}
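/// Solves the lower-triangular system `L·x = b` by forward substitution.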
fn forward_solve(l: &Array2<f64>, b: &Array1<f64>) -> OptimizeResult<Array1<f64>> {
let n = l.nrows();
let mut x = Array1::zeros(n);
for i in 0..n {
let mut s = 0.0;
for j in 0..i {
s += l[[i, j]] * x[j];
}
let diag = l[[i, i]];
if diag.abs() < 1e-15 {
return Err(OptimizeError::ComputationError(
"Forward solve: near-zero diagonal".to_string(),
));
}
x[i] = (b[i] - s) / diag;
}
Ok(x)
}
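/// Solves `Lᵀ·x = b` by backward substitution, indexing `L` column-wise so
/// the transpose is never formed.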
fn backward_solve_transpose(l: &Array2<f64>, b: &Array1<f64>) -> OptimizeResult<Array1<f64>> {
let n = l.nrows();
let mut x = Array1::zeros(n);
for i in (0..n).rev() {
let mut s = 0.0;
for j in (i + 1)..n {
s += l[[j, i]] * x[j];
}
let diag = l[[i, i]];
if diag.abs() < 1e-15 {
return Err(OptimizeError::ComputationError(
"Backward solve: near-zero diagonal".to_string(),
));
}
x[i] = (b[i] - s) / diag;
}
Ok(x)
}
#[cfg(test)]
mod tests {
use super::*;
use scirs2_core::ndarray::array;
fn make_train_data() -> (Array2<f64>, Array1<f64>) {
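// Targets are sin(x) at x = 0, 1, 2, 3, 4, rounded to three decimals.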
let x = Array2::from_shape_vec((5, 1), vec![0.0, 1.0, 2.0, 3.0, 4.0]).expect("shape ok");
let y = array![0.0, 0.841, 0.909, 0.141, -0.757];
(x, y)
}
#[test]
fn test_rbf_kernel_symmetry() {
let k = RbfKernel::default();
let a = array![1.0, 2.0];
let b = array![3.0, 4.0];
assert!((k.eval(&a.view(), &b.view()) - k.eval(&b.view(), &a.view())).abs() < 1e-14);
}
#[test]
fn test_rbf_kernel_self_covariance() {
let k = RbfKernel::new(1.0, 2.0);
let a = array![1.0, 2.0];
assert!((k.eval(&a.view(), &a.view()) - 2.0).abs() < 1e-14);
}
#[test]
fn test_matern_variants() {
let a = array![0.0];
let b = array![1.0];
for variant in &[
MaternVariant::OneHalf,
MaternVariant::ThreeHalves,
MaternVariant::FiveHalves,
] {
let k = MaternKernel::new(*variant, 1.0, 1.0);
let val = k.eval(&a.view(), &b.view());
assert!(val > 0.0 && val < 1.0, "Matern({:?}) = {}", variant, val);
assert!((k.eval(&a.view(), &a.view()) - 1.0).abs() < 1e-14);
}
}
#[test]
fn test_rational_quadratic_kernel() {
let k = RationalQuadraticKernel::new(1.0, 1.0, 1.0);
let a = array![0.0];
let b = array![1.0];
let val = k.eval(&a.view(), &b.view());
assert!((val - 2.0 / 3.0).abs() < 1e-10);
}
#[test]
fn test_rational_quadratic_approaches_rbf() {
let rbf = RbfKernel::new(1.0, 1.0);
let rq = RationalQuadraticKernel::new(1.0, 1.0, 1e6);
let a = array![0.0, 1.0];
let b = array![2.0, 3.0];
let rbf_val = rbf.eval(&a.view(), &b.view());
let rq_val = rq.eval(&a.view(), &b.view());
assert!(
(rbf_val - rq_val).abs() < 1e-4,
"RBF={}, RQ(alpha=1e6)={}",
rbf_val,
rq_val
);
}
#[test]
fn test_sum_kernel() {
let k1 = Box::new(RbfKernel::new(1.0, 1.0));
let k2 = Box::new(MaternKernel::five_halves(1.0, 0.5));
let sum = SumKernel::new(k1.clone(), k2.clone());
let a = array![1.0];
let b = array![2.0];
let expected = k1.eval(&a.view(), &b.view()) + k2.eval(&a.view(), &b.view());
assert!((sum.eval(&a.view(), &b.view()) - expected).abs() < 1e-14);
}
#[test]
fn test_product_kernel() {
let k1 = Box::new(RbfKernel::new(1.0, 1.0));
let k2 = Box::new(MaternKernel::five_halves(1.0, 0.5));
let prod = ProductKernel::new(k1.clone(), k2.clone());
let a = array![1.0];
let b = array![2.0];
let expected = k1.eval(&a.view(), &b.view()) * k2.eval(&a.view(), &b.view());
assert!((prod.eval(&a.view(), &b.view()) - expected).abs() < 1e-14);
}
#[test]
fn test_gp_fit_predict_basic() {
let (x, y) = make_train_data();
let mut gp = GpSurrogate::new(
Box::new(RbfKernel::default()),
GpSurrogateConfig {
optimize_hyperparams: false,
noise_variance: 1e-4,
..Default::default()
},
);
gp.fit(&x, &y).expect("fit should succeed");
let (mean, var) = gp.predict(&x).expect("predict should succeed");
for i in 0..y.len() {
assert!(
(mean[i] - y[i]).abs() < 0.15,
"mean[{}]={:.4} vs y[{}]={:.4}",
i,
mean[i],
i,
y[i]
);
assert!(
var[i] < 0.5,
"var[{}]={:.4} should be small at training point",
i,
var[i]
);
}
}
#[test]
fn test_gp_uncertainty_away_from_data() {
let (x, y) = make_train_data();
let mut gp = GpSurrogate::new(
Box::new(RbfKernel::default()),
GpSurrogateConfig {
optimize_hyperparams: false,
noise_variance: 1e-4,
..Default::default()
},
);
gp.fit(&x, &y).expect("fit should succeed");
let x_far = Array2::from_shape_vec((1, 1), vec![10.0]).expect("shape ok");
let (_, var_far) = gp.predict(&x_far).expect("predict ok");
let x_near = Array2::from_shape_vec((1, 1), vec![2.0]).expect("shape ok");
let (_, var_near) = gp.predict(&x_near).expect("predict ok");
assert!(
var_far[0] > var_near[0],
"var_far={:.4} should be > var_near={:.4}",
var_far[0],
var_near[0]
);
}
#[test]
fn test_gp_predict_single() {
let (x, y) = make_train_data();
let mut gp = GpSurrogate::default_rbf();
gp.config.optimize_hyperparams = false;
gp.config.noise_variance = 1e-4;
gp.fit(&x, &y).expect("fit ok");
let point = array![1.5];
let (mean, std) = gp.predict_single(&point.view()).expect("predict_single ok");
assert!(mean.is_finite());
assert!(std >= 0.0);
}
#[test]
fn test_gp_log_marginal_likelihood() {
let (x, y) = make_train_data();
let mut gp = GpSurrogate::new(
Box::new(RbfKernel::default()),
GpSurrogateConfig {
optimize_hyperparams: false,
noise_variance: 1e-4,
..Default::default()
},
);
gp.fit(&x, &y).expect("fit ok");
let lml = gp.log_marginal_likelihood().expect("lml ok");
assert!(lml.is_finite(), "LML should be finite, got {}", lml);
}
#[test]
fn test_gp_update_incremental() {
let (x, y) = make_train_data();
let mut gp = GpSurrogate::new(
Box::new(RbfKernel::default()),
GpSurrogateConfig {
optimize_hyperparams: false,
noise_variance: 1e-4,
..Default::default()
},
);
gp.fit(&x, &y).expect("fit ok");
assert_eq!(gp.n_train(), 5);
let x_new = Array2::from_shape_vec((1, 1), vec![5.0]).expect("shape ok");
let y_new = array![-0.959];
gp.update(&x_new, &y_new).expect("update ok");
assert_eq!(gp.n_train(), 6);
}
#[test]
fn test_gp_hyperparameter_optimization() {
let (x, y) = make_train_data();
let mut gp = GpSurrogate::new(
Box::new(RbfKernel::default()),
GpSurrogateConfig {
optimize_hyperparams: true,
n_restarts: 2,
max_opt_iters: 20,
noise_variance: 1e-4,
},
);
gp.fit(&x, &y).expect("fit with optimization ok");
let x_test = Array2::from_shape_vec((1, 1), vec![1.5]).expect("shape ok");
let (mean, var) = gp.predict(&x_test).expect("predict ok");
assert!(mean[0].is_finite());
assert!(var[0].is_finite());
}
#[test]
fn test_cholesky_positive_definite() {
let a = Array2::from_shape_vec((2, 2), vec![4.0, 2.0, 2.0, 3.0]).expect("shape ok");
let l = cholesky(&a).expect("should succeed");
let reconstructed = l.dot(&l.t());
for i in 0..2 {
for j in 0..2 {
assert!(
(reconstructed[[i, j]] - a[[i, j]]).abs() < 1e-10,
"Mismatch at [{},{}]",
i,
j
);
}
}
}
#[test]
fn test_cholesky_non_pd_fails() {
let a = Array2::from_shape_vec((2, 2), vec![1.0, 10.0, 10.0, 1.0]).expect("shape ok");
assert!(cholesky(&a).is_err());
}
#[test]
fn test_kernel_log_params_roundtrip() {
let mut k = RbfKernel::new(2.5, 0.3);
let params = k.get_log_params();
k.set_log_params(&params);
assert!((k.length_scale - 2.5).abs() < 1e-10);
assert!((k.signal_variance - 0.3).abs() < 1e-10);
}
#[test]
fn test_matern_kernel_covariance_matrix() {
let k = MaternKernel::three_halves(1.0, 1.0);
let x = Array2::from_shape_vec((3, 1), vec![0.0, 1.0, 2.0]).expect("shape ok");
let cov = k.covariance_matrix(&x);
assert_eq!(cov.nrows(), 3);
assert_eq!(cov.ncols(), 3);
for i in 0..3 {
for j in 0..3 {
assert!(
(cov[[i, j]] - cov[[j, i]]).abs() < 1e-14,
"Not symmetric at [{},{}]",
i,
j
);
}
}
for i in 0..3 {
assert!(cov[[i, i]] > 0.0);
}
}
#[test]
fn test_gp_multidimensional() {
let x = Array2::from_shape_vec(
(6, 2),
vec![0.0, 0.0, 1.0, 0.0, 0.0, 1.0, -1.0, 0.0, 0.0, -1.0, 0.5, 0.5],
)
.expect("shape ok");
let y = array![0.0, 1.0, 1.0, 1.0, 1.0, 0.5];
let mut gp = GpSurrogate::new(
Box::new(RbfKernel::default()),
GpSurrogateConfig {
optimize_hyperparams: false,
noise_variance: 1e-4,
..Default::default()
},
);
gp.fit(&x, &y).expect("fit ok");
let x_test = Array2::from_shape_vec((1, 2), vec![0.0, 0.0]).expect("shape ok");
let (mean, _) = gp.predict(&x_test).expect("predict ok");
assert!(
mean[0].abs() < 0.3,
"Prediction at origin should be close to 0, got {}",
mean[0]
);
}
}