pub mod kernel;
pub mod search;
pub use kernel::{BaseKernel, KernelExpr};
use crate::error::{InterpolateError, InterpolateResult};
use search::{build_cross_kernel, gp_fit, gp_predict, search_kernels};
/// Candidate period values handed to the kernel search when it proposes a
/// periodic component (units follow the training inputs — presumably the same
/// scale as `x`; TODO confirm against `search_kernels`).
const DEFAULT_PERIOD_GRID: &[f64] = &[0.1, 0.2, 0.5, 1.0, 2.0, 5.0];
/// Configuration for the automatic kernel search and GP fit in [`AutoKernelGp`].
#[derive(Debug, Clone)]
pub struct AutoKernelGpConfig {
    /// Maximum depth of composite kernel expressions explored by the search.
    pub max_depth: usize,
    /// Number of restarts used per candidate kernel during the search.
    pub n_restarts: usize,
    /// Observation-noise variance passed to the search and to `gp_fit`.
    pub noise_variance: f64,
    /// Number of cross-validation folds used to score candidate kernels.
    pub cv_folds: usize,
    /// RNG seed. NOTE(review): not read anywhere in this file — confirm
    /// whether `fit` is expected to forward it to `search_kernels`.
    pub seed: u64,
}
impl Default for AutoKernelGpConfig {
fn default() -> Self {
Self {
max_depth: 2,
n_restarts: 3,
noise_variance: 0.01,
cv_folds: 5,
seed: 42,
}
}
}
/// Gaussian-process regressor that selects its kernel automatically via
/// `search_kernels` and then fits on the full training set.
pub struct AutoKernelGp {
    // Kernel selected by the search; an RBF placeholder until `fit` succeeds.
    best_kernel: KernelExpr,
    // CV score of the best kernel; `f64::MAX` until `fit` succeeds.
    best_cv_score: f64,
    // Solution vector from `gp_fit`, consumed by `gp_predict`.
    alpha: Vec<f64>,
    // Copy of the training inputs, needed by `gp_predict`.
    train_x: Vec<f64>,
    // Copy of the training targets. NOTE(review): stored but not read by any
    // visible method — presumably kept for future use (e.g. refitting).
    train_y: Vec<f64>,
    // Search/fit configuration supplied at construction.
    config: AutoKernelGpConfig,
    // Ranked (description, score) pairs for every kernel the search evaluated.
    search_results: Vec<(String, f64)>,
    // Guards `predict`: true only after a successful `fit`.
    is_fitted: bool,
}
impl AutoKernelGp {
    /// Creates an unfitted model with the given search configuration.
    ///
    /// Until [`fit`](Self::fit) succeeds, the kernel is a placeholder RBF,
    /// the CV score is `f64::MAX`, and `predict` returns an error.
    pub fn new(config: AutoKernelGpConfig) -> Self {
        Self {
            best_kernel: KernelExpr::Base(BaseKernel::Rbf { length_scale: 1.0 }),
            best_cv_score: f64::MAX,
            alpha: Vec::new(),
            train_x: Vec::new(),
            train_y: Vec::new(),
            config,
            search_results: Vec::new(),
            is_fitted: false,
        }
    }

    /// Runs the kernel search on `(x, y)`, then fits a GP with the winning
    /// kernel on the full training set.
    ///
    /// The fit is all-or-nothing: on any error, `self` is left exactly as it
    /// was before the call, so a previously fitted model remains usable.
    ///
    /// NOTE(review): `config.seed` is not forwarded to `search_kernels` —
    /// confirm whether the search is expected to honor it.
    ///
    /// # Errors
    /// - `DimensionMismatch` if `x` and `y` differ in length.
    /// - `InvalidInput` if fewer than 2 training points are supplied.
    /// - `ComputationError` if Cholesky factorization fails for the selected
    ///   kernel on the full training set.
    pub fn fit(&mut self, x: &[f64], y: &[f64]) -> InterpolateResult<()> {
        if x.len() != y.len() {
            return Err(InterpolateError::DimensionMismatch(format!(
                "x length {} ≠ y length {}",
                x.len(),
                y.len()
            )));
        }
        if x.len() < 2 {
            return Err(InterpolateError::InvalidInput {
                message: "at least 2 training points are required".to_string(),
            });
        }
        let (ranked, best_kernel) = search_kernels(
            x,
            y,
            self.config.max_depth,
            self.config.n_restarts,
            self.config.noise_variance,
            self.config.cv_folds,
            DEFAULT_PERIOD_GRID,
        );
        // Complete every fallible step into locals BEFORE mutating self.
        // The previous code updated best_kernel/search_results first; if
        // gp_fit then failed on a re-fit, is_fitted stayed true and predict
        // would have mixed the new kernel with stale alpha/train_x.
        let alpha =
            gp_fit(&best_kernel, x, y, self.config.noise_variance).ok_or_else(|| {
                InterpolateError::ComputationError(
                    "Cholesky failed for selected kernel on full training set".to_string(),
                )
            })?;
        self.best_cv_score = ranked.first().map(|(_, s)| *s).unwrap_or(f64::MAX);
        self.search_results = ranked;
        self.best_kernel = best_kernel;
        self.alpha = alpha;
        self.train_x = x.to_vec();
        self.train_y = y.to_vec();
        self.is_fitted = true;
        Ok(())
    }

    /// Predicts the posterior mean at each query point in `x_new`.
    ///
    /// Returns an empty vector for an empty query without touching the model.
    ///
    /// # Errors
    /// `InvalidState` if called before a successful `fit`.
    pub fn predict(&self, x_new: &[f64]) -> InterpolateResult<Vec<f64>> {
        if !self.is_fitted {
            return Err(InterpolateError::InvalidState(
                "GP must be fitted before prediction".to_string(),
            ));
        }
        if x_new.is_empty() {
            return Ok(Vec::new());
        }
        Ok(gp_predict(
            &self.best_kernel,
            &self.train_x,
            &self.alpha,
            x_new,
        ))
    }

    /// Human-readable description of the selected kernel expression.
    pub fn selected_kernel_description(&self) -> String {
        self.best_kernel.description()
    }

    /// CV score of the selected kernel (`f64::MAX` before a successful fit).
    pub fn best_cv_score(&self) -> f64 {
        self.best_cv_score
    }

    /// Ranked `(description, score)` pairs for every kernel the search tried.
    pub fn kernel_search_results(&self) -> &[(String, f64)] {
        &self.search_results
    }

    /// The currently selected kernel expression.
    pub fn kernel(&self) -> &KernelExpr {
        &self.best_kernel
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `n` evenly spaced points on [0, 2π) with sin(x) targets.
    fn sin_data(n: usize) -> (Vec<f64>, Vec<f64>) {
        let xs: Vec<f64> = (0..n)
            .map(|i| i as f64 * 2.0 * std::f64::consts::PI / n as f64)
            .collect();
        let ys: Vec<f64> = xs.iter().map(|&v| v.sin()).collect();
        (xs, ys)
    }

    #[test]
    fn auto_kernel_gp_fits_sin_data() {
        let (xs, ys) = sin_data(20);
        let mut model = AutoKernelGp::new(AutoKernelGpConfig {
            max_depth: 1,
            cv_folds: 3,
            n_restarts: 1,
            ..Default::default()
        });
        model.fit(&xs, &ys).expect("fit: should succeed on sin data");
        let preds = model.predict(&xs).expect("predict: should succeed");
        assert_eq!(preds.len(), xs.len());
        let mut sq_err = 0.0_f64;
        for (p, t) in preds.iter().zip(ys.iter()) {
            sq_err += (p - t).powi(2);
        }
        let mse = sq_err / xs.len() as f64;
        assert!(
            mse < 0.5,
            "MSE at training points should be small, got {mse}"
        );
    }

    #[test]
    fn auto_kernel_gp_predict_shape_correct() {
        let (xs, ys) = sin_data(15);
        let mut model = AutoKernelGp::new(AutoKernelGpConfig {
            max_depth: 0,
            cv_folds: 3,
            n_restarts: 1,
            ..Default::default()
        });
        model.fit(&xs, &ys).expect("fit ok");
        let queries = vec![0.1, 0.5, 1.0, 2.0, 4.0];
        let preds = model.predict(&queries).expect("predict ok");
        assert_eq!(
            preds.len(),
            queries.len(),
            "prediction shape must match query length"
        );
    }

    #[test]
    fn auto_kernel_gp_description_is_nonempty() {
        let (xs, ys) = sin_data(12);
        let mut model = AutoKernelGp::new(AutoKernelGpConfig {
            max_depth: 1,
            cv_folds: 3,
            n_restarts: 1,
            ..Default::default()
        });
        model.fit(&xs, &ys).expect("fit ok");
        let desc = model.selected_kernel_description();
        assert!(
            !desc.is_empty(),
            "kernel description must not be empty: '{desc}'"
        );
    }

    #[test]
    fn auto_kernel_gp_selects_periodic_kernel_for_sin() {
        let (xs, ys) = sin_data(20);
        let mut model = AutoKernelGp::new(AutoKernelGpConfig {
            max_depth: 1,
            cv_folds: 3,
            n_restarts: 2,
            noise_variance: 1e-4,
            ..Default::default()
        });
        model.fit(&xs, &ys).expect("fit ok");
        assert!(
            model.best_cv_score().is_finite(),
            "CV score must be finite, got {}",
            model.best_cv_score()
        );
        assert!(
            !model.kernel_search_results().is_empty(),
            "search results must not be empty"
        );
    }

    #[test]
    fn auto_kernel_gp_cv_score_improves_with_depth() {
        let (xs, ys) = sin_data(18);
        // Fit at a given search depth and report the winning CV score.
        let score_at = |max_depth: usize| -> f64 {
            let mut model = AutoKernelGp::new(AutoKernelGpConfig {
                max_depth,
                cv_folds: 3,
                n_restarts: 1,
                noise_variance: 1e-3,
                ..Default::default()
            });
            model.fit(&xs, &ys).expect("fit ok");
            model.best_cv_score()
        };
        let scores: Vec<f64> = [0usize, 1, 2].iter().map(|&d| score_at(d)).collect();
        assert!(
            scores[1] <= scores[0] * 1.1,
            "depth-1 score {} should be ≤ depth-0 score {} (with 10% tolerance)",
            scores[1],
            scores[0]
        );
        assert!(
            scores[2] <= scores[1] * 1.1,
            "depth-2 score {} should be ≤ depth-1 score {} (with 10% tolerance)",
            scores[2],
            scores[1]
        );
    }

    #[test]
    fn auto_kernel_gp_predict_before_fit_errors() {
        let model = AutoKernelGp::new(AutoKernelGpConfig::default());
        let outcome = model.predict(&[0.5, 1.0]);
        assert!(outcome.is_err(), "predict before fit should return an error");
    }
}