/// A primitive covariance function over scalar inputs.
///
/// Every variant stores its own hyperparameters. See [`BaseKernel::eval`] for
/// the closed form that is evaluated for each one.
#[derive(Debug, Clone, PartialEq)]
pub enum BaseKernel {
    /// Squared-exponential kernel, `exp(-d² / (2ℓ²))` with `d = x1 - x2`.
    Rbf { length_scale: f64 },
    /// Matérn 5/2 kernel, `(1 + s + s²/3) · exp(-s)` with `s = √5·|d| / ℓ`.
    Matern52 { length_scale: f64 },
    /// Periodic kernel, `exp(-2·sin²(π·|d| / p) / ℓ²)`.
    Periodic { period: f64, length_scale: f64 },
    /// Linear kernel, `variance · x1 · x2`.
    Linear { variance: f64 },
    /// White-noise kernel: `variance` when the inputs coincide, `0` otherwise.
    WhiteNoise { variance: f64 },
}

impl BaseKernel {
    /// Evaluates the kernel for the pair of scalar inputs `(x1, x2)`.
    ///
    /// Length scales and periods are clamped from below at `1e-10` before
    /// they are used as divisors, so zero or negative values cannot cause a
    /// division by zero.
    ///
    /// The white-noise kernel treats inputs closer than `1e-12` as identical.
    pub fn eval(&self, x1: f64, x2: f64) -> f64 {
        /// Smallest length scale / period actually used in a division.
        const MIN_SCALE: f64 = 1e-10;
        /// Inputs closer than this are considered the same point by `WhiteNoise`.
        const SAME_POINT_TOL: f64 = 1e-12;

        match self {
            BaseKernel::Rbf { length_scale } => {
                let ell = length_scale.max(MIN_SCALE);
                let d = x1 - x2;
                (-d * d / (2.0 * ell * ell)).exp()
            }
            BaseKernel::Matern52 { length_scale } => {
                let ell = length_scale.max(MIN_SCALE);
                let r = (x1 - x2).abs();
                let s = 5.0_f64.sqrt() * r / ell;
                let decay = (-s).exp();
                if decay == 0.0 {
                    // The exponential has underflowed. For very large `s` the
                    // polynomial factor overflows to +inf and `inf * 0` would
                    // be NaN; the mathematical limit is 0, so return it directly.
                    0.0
                } else {
                    (1.0 + s + s * s / 3.0) * decay
                }
            }
            BaseKernel::Periodic {
                period,
                length_scale,
            } => {
                let p = period.max(MIN_SCALE);
                let ell = length_scale.max(MIN_SCALE);
                let d = (x1 - x2).abs();
                let sin_val = (std::f64::consts::PI * d / p).sin();
                (-2.0 * sin_val * sin_val / (ell * ell)).exp()
            }
            BaseKernel::Linear { variance } => variance * x1 * x2,
            BaseKernel::WhiteNoise { variance } => {
                if (x1 - x2).abs() < SAME_POINT_TOL {
                    *variance
                } else {
                    0.0
                }
            }
        }
    }

    /// Returns the short display name of the kernel family.
    pub fn name(&self) -> &'static str {
        match self {
            BaseKernel::Rbf { .. } => "RBF",
            BaseKernel::Matern52 { .. } => "Matern52",
            BaseKernel::Periodic { .. } => "Periodic",
            BaseKernel::Linear { .. } => "Linear",
            BaseKernel::WhiteNoise { .. } => "WhiteNoise",
        }
    }

    /// Returns a kernel of the same family with its scale parameter replaced
    /// by `new_ell`.
    ///
    /// For `Rbf`, `Matern52` and `Periodic` the `length_scale` is replaced
    /// (`Periodic` keeps its `period`). `Linear` and `WhiteNoise` have no
    /// length scale; for them `new_ell` is stored as the `variance`.
    pub fn with_length_scale(&self, new_ell: f64) -> Self {
        match *self {
            BaseKernel::Rbf { .. } => BaseKernel::Rbf {
                length_scale: new_ell,
            },
            BaseKernel::Matern52 { .. } => BaseKernel::Matern52 {
                length_scale: new_ell,
            },
            BaseKernel::Periodic { period, .. } => BaseKernel::Periodic {
                period,
                length_scale: new_ell,
            },
            // NOTE(review): writing a length scale into `variance` looks
            // deliberate (a single tunable per kernel) — confirm with callers
            // before changing it.
            BaseKernel::Linear { .. } => BaseKernel::Linear { variance: new_ell },
            BaseKernel::WhiteNoise { .. } => BaseKernel::WhiteNoise { variance: new_ell },
        }
    }
}
/// A kernel expression tree: base kernels combined by sums and products.
#[derive(Debug, Clone)]
pub enum KernelExpr {
    /// A leaf holding a single base kernel.
    Base(BaseKernel),
    /// The sum of two sub-expressions.
    Sum(Box<KernelExpr>, Box<KernelExpr>),
    /// The product of two sub-expressions.
    Product(Box<KernelExpr>, Box<KernelExpr>),
}

impl KernelExpr {
    /// Evaluates the expression at `(x1, x2)` by recursing through the tree:
    /// leaves delegate to the base kernel, inner nodes add or multiply the
    /// values of their two children.
    pub fn eval(&self, x1: f64, x2: f64) -> f64 {
        match self {
            Self::Base(kernel) => kernel.eval(x1, x2),
            Self::Sum(lhs, rhs) => lhs.eval(x1, x2) + rhs.eval(x1, x2),
            Self::Product(lhs, rhs) => lhs.eval(x1, x2) * rhs.eval(x1, x2),
        }
    }

    /// Renders the expression as text, e.g. `RBF + Periodic × Linear`.
    ///
    /// Sums are joined with ` + `, products with ` × `. A sum that appears
    /// directly as an operand of a product is wrapped in parentheses.
    pub fn description(&self) -> String {
        /// Renders one operand of a product, parenthesising it if it is a sum.
        fn factor(expr: &KernelExpr) -> String {
            let text = expr.description();
            match expr {
                KernelExpr::Sum(..) => format!("({text})"),
                KernelExpr::Base(_) | KernelExpr::Product(..) => text,
            }
        }

        match self {
            Self::Base(kernel) => kernel.name().to_string(),
            Self::Sum(lhs, rhs) => format!("{} + {}", lhs.description(), rhs.description()),
            Self::Product(lhs, rhs) => format!("{} × {}", factor(lhs), factor(rhs)),
        }
    }

    /// Returns the height of the tree: `0` for a base kernel, otherwise one
    /// more than the deeper of the two children.
    pub fn depth(&self) -> usize {
        match self {
            Self::Base(_) => 0,
            Self::Sum(lhs, rhs) | Self::Product(lhs, rhs) => {
                let deepest_child = std::cmp::max(lhs.depth(), rhs.depth());
                deepest_child + 1
            }
        }
    }
}
/// Returns one instance of every base kernel family with starting
/// hyperparameters, in the order RBF, Matérn 5/2, periodic, linear, white noise.
///
/// All length scales, the period and the linear variance start at `1.0`; the
/// white-noise variance starts at `0.1`.
pub fn base_kernels() -> Vec<BaseKernel> {
    const UNIT: f64 = 1.0;
    const NOISE_VARIANCE: f64 = 0.1;

    Vec::from([
        BaseKernel::Rbf { length_scale: UNIT },
        BaseKernel::Matern52 { length_scale: UNIT },
        BaseKernel::Periodic {
            period: UNIT,
            length_scale: UNIT,
        },
        BaseKernel::Linear { variance: UNIT },
        BaseKernel::WhiteNoise {
            variance: NOISE_VARIANCE,
        },
    ])
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used when comparing kernel values.
    const TOL: f64 = 1e-12;

    /// Wraps a base kernel in a leaf expression.
    fn leaf(kernel: BaseKernel) -> KernelExpr {
        KernelExpr::Base(kernel)
    }

    /// Boxes a clone of `expr` so it can be used as a child node.
    fn child(expr: &KernelExpr) -> Box<KernelExpr> {
        Box::new(expr.clone())
    }

    /// Swapping the two arguments of an RBF kernel must not change its value.
    #[test]
    fn kernel_expr_eval_rbf_is_symmetric() {
        let k = leaf(BaseKernel::Rbf { length_scale: 1.0 });
        let (a, b) = (k.eval(0.5, 1.5), k.eval(1.5, 0.5));
        assert!(
            (a - b).abs() < TOL,
            "RBF must be symmetric: k(0.5,1.5)={a}, k(1.5,0.5)={b}"
        );
    }

    /// A `Sum` node evaluates to the sum of its two children.
    #[test]
    fn kernel_expr_sum_is_additive() {
        let k1 = leaf(BaseKernel::Rbf { length_scale: 1.0 });
        let k2 = leaf(BaseKernel::Matern52 { length_scale: 1.0 });
        let sum = KernelExpr::Sum(child(&k1), child(&k2));
        let (x, y) = (0.3, 1.2);
        let expected = k1.eval(x, y) + k2.eval(x, y);
        let actual = sum.eval(x, y);
        assert!(
            (actual - expected).abs() < TOL,
            "Sum kernel must equal sum of parts: {actual} vs {expected}"
        );
    }

    /// A `Product` node evaluates to the product of its two children.
    #[test]
    fn kernel_expr_product_is_multiplicative() {
        let k1 = leaf(BaseKernel::Rbf { length_scale: 0.8 });
        let k2 = leaf(BaseKernel::Linear { variance: 2.0 });
        let prod = KernelExpr::Product(child(&k1), child(&k2));
        let (x, y) = (1.5, 0.5);
        let expected = k1.eval(x, y) * k2.eval(x, y);
        let actual = prod.eval(x, y);
        assert!(
            (actual - expected).abs() < TOL,
            "Product kernel must equal product of parts: {actual} vs {expected}"
        );
    }

    /// The description of a composite kernel names the kernels it contains.
    #[test]
    fn kernel_description_nonempty() {
        let rbf = leaf(BaseKernel::Rbf { length_scale: 1.0 });
        let periodic = leaf(BaseKernel::Periodic {
            period: 1.0,
            length_scale: 0.5,
        });
        let linear = leaf(BaseKernel::Linear { variance: 1.0 });
        let periodic_times_linear = KernelExpr::Product(Box::new(periodic), Box::new(linear));
        let k = KernelExpr::Sum(Box::new(rbf), Box::new(periodic_times_linear));

        let desc = k.description();
        assert!(!desc.is_empty(), "description must not be empty");
        assert!(desc.contains("RBF"), "description should contain RBF");
        assert!(
            desc.contains("Periodic"),
            "description should contain Periodic"
        );
    }
}