#![allow(clippy::needless_range_loop)]
use crate::error::{KernelError, Result};
use crate::types::Kernel;
use std::f64::consts::PI;
#[derive(Debug, Clone)]
pub struct SpectralComponent {
pub weight: f64,
pub mean: Vec<f64>,
pub variance: Vec<f64>,
}
impl SpectralComponent {
pub fn new(weight: f64, mean: Vec<f64>, variance: Vec<f64>) -> Result<Self> {
if weight <= 0.0 {
return Err(KernelError::InvalidParameter {
parameter: "weight".to_string(),
value: weight.to_string(),
reason: "weight must be positive".to_string(),
});
}
if mean.len() != variance.len() {
return Err(KernelError::InvalidParameter {
parameter: "mean/variance".to_string(),
value: format!(
"mean.len()={}, variance.len()={}",
mean.len(),
variance.len()
),
reason: "mean and variance must have same length".to_string(),
});
}
if mean.is_empty() {
return Err(KernelError::InvalidParameter {
parameter: "mean".to_string(),
value: "[]".to_string(),
reason: "must have at least one dimension".to_string(),
});
}
for (i, &v) in variance.iter().enumerate() {
if v <= 0.0 {
return Err(KernelError::InvalidParameter {
parameter: format!("variance[{}]", i),
value: v.to_string(),
reason: "variance must be positive".to_string(),
});
}
}
Ok(Self {
weight,
mean,
variance,
})
}
pub fn new_1d(weight: f64, mean: f64, variance: f64) -> Result<Self> {
Self::new(weight, vec![mean], vec![variance])
}
pub fn ndim(&self) -> usize {
self.mean.len()
}
}
/// Spectral mixture kernel: the sum over its components of
/// `weight * exp(-2π² Σ_d τ_d² variance_d) * cos(2π Σ_d τ_d mean_d)`,
/// where `τ = x - y`.
#[derive(Debug, Clone)]
pub struct SpectralMixtureKernel {
    /// Mixture components; `new` guarantees they all share `ndim` dimensions.
    components: Vec<SpectralComponent>,
    /// Input dimensionality, taken from the first component.
    ndim: usize,
}

impl SpectralMixtureKernel {
    /// Builds a kernel from a non-empty list of components.
    ///
    /// # Errors
    ///
    /// Returns [`KernelError::InvalidParameter`] when `components` is empty,
    /// when a component's dimensionality differs from the first component's,
    /// or when a component's `variance` length differs from its `mean` length.
    pub fn new(components: Vec<SpectralComponent>) -> Result<Self> {
        let ndim = match components.first() {
            Some(first) => first.ndim(),
            None => {
                return Err(KernelError::InvalidParameter {
                    parameter: "components".to_string(),
                    value: "[]".to_string(),
                    reason: "must have at least one component".to_string(),
                });
            }
        };
        for (i, comp) in components.iter().enumerate() {
            if comp.ndim() != ndim {
                return Err(KernelError::InvalidParameter {
                    parameter: format!("components[{}]", i),
                    value: format!("ndim={}", comp.ndim()),
                    reason: format!("all components must have {} dimensions", ndim),
                });
            }
            // `SpectralComponent` exposes public fields, so a component built
            // with a struct literal may carry a `variance` of the wrong length.
            // Reject it here instead of failing later during evaluation.
            if comp.variance.len() != ndim {
                return Err(KernelError::InvalidParameter {
                    parameter: format!("components[{}]", i),
                    value: format!(
                        "mean.len()={}, variance.len()={}",
                        comp.mean.len(),
                        comp.variance.len()
                    ),
                    reason: "mean and variance must have same length".to_string(),
                });
            }
        }
        Ok(Self { components, ndim })
    }

    /// Builds a one-dimensional kernel from `(weight, mean, variance)` triples.
    ///
    /// # Errors
    ///
    /// Propagates the first error from `SpectralComponent::new_1d`, and fails
    /// like [`SpectralMixtureKernel::new`] when `frequencies` is empty.
    pub fn new_1d(frequencies: Vec<(f64, f64, f64)>) -> Result<Self> {
        let components = frequencies
            .into_iter()
            .map(|(w, m, v)| SpectralComponent::new_1d(w, m, v))
            .collect::<Result<Vec<_>>>()?;
        Self::new(components)
    }

    /// The mixture components, in the order they were supplied.
    pub fn components(&self) -> &[SpectralComponent] {
        &self.components
    }

    /// Number of mixture components.
    pub fn num_components(&self) -> usize {
        self.components.len()
    }

    /// Input dimensionality expected by `compute`.
    pub fn ndim(&self) -> usize {
        self.ndim
    }

    /// Evaluates one component's contribution for the lag vector `tau`.
    fn compute_component(&self, comp: &SpectralComponent, tau: &[f64]) -> f64 {
        // Accumulate both sums in a single pass over the dimensions.
        let (exp_term, cos_term) = tau
            .iter()
            .zip(comp.variance.iter())
            .zip(comp.mean.iter())
            .fold(
                (0.0_f64, 0.0_f64),
                |(exp_acc, cos_acc), ((&tau_d, &var_d), &mean_d)| {
                    (exp_acc + tau_d * tau_d * var_d, cos_acc + tau_d * mean_d)
                },
            );
        comp.weight * (-2.0 * PI * PI * exp_term).exp() * (2.0 * PI * cos_term).cos()
    }
}

impl Kernel for SpectralMixtureKernel {
    /// Evaluates the kernel on `x` and `y`.
    ///
    /// # Errors
    ///
    /// Returns [`KernelError::DimensionMismatch`] when `x` or `y` does not
    /// have exactly `ndim` entries; `x` is checked first.
    fn compute(&self, x: &[f64], y: &[f64]) -> Result<f64> {
        for &len in &[x.len(), y.len()] {
            if len != self.ndim {
                return Err(KernelError::DimensionMismatch {
                    expected: vec![self.ndim],
                    got: vec![len],
                    context: "Spectral Mixture kernel".to_string(),
                });
            }
        }
        let tau: Vec<f64> = x.iter().zip(y.iter()).map(|(a, b)| a - b).collect();
        let total = self
            .components
            .iter()
            .fold(0.0, |acc, comp| acc + self.compute_component(comp, &tau));
        Ok(total)
    }

    /// Human-readable kernel name.
    fn name(&self) -> &str {
        "SpectralMixture"
    }
}
/// Periodic kernel
/// `exp(-2 sin²(π ‖x − y‖ / period) / length_scale²)`.
#[derive(Debug, Clone)]
pub struct ExpSineSquaredKernel {
    /// Period of the sine term.
    period: f64,
    /// Length scale dividing the squared sine term.
    length_scale: f64,
}

impl ExpSineSquaredKernel {
    /// Creates the kernel after validating both parameters.
    ///
    /// # Errors
    ///
    /// Returns [`KernelError::InvalidParameter`] when `period` or
    /// `length_scale` is NaN, zero or negative. `period` is checked first.
    pub fn new(period: f64, length_scale: f64) -> Result<Self> {
        for &(name, value) in &[("period", period), ("length_scale", length_scale)] {
            // `NaN <= 0.0` is false, so NaN must be tested explicitly or it
            // would be accepted and poison every kernel evaluation.
            if value.is_nan() || value <= 0.0 {
                return Err(KernelError::InvalidParameter {
                    parameter: name.to_string(),
                    value: value.to_string(),
                    reason: format!("{} must be positive", name),
                });
            }
        }
        Ok(Self {
            period,
            length_scale,
        })
    }

    /// The period passed to [`ExpSineSquaredKernel::new`].
    pub fn period(&self) -> f64 {
        self.period
    }

    /// The length scale passed to [`ExpSineSquaredKernel::new`].
    pub fn length_scale(&self) -> f64 {
        self.length_scale
    }
}
impl Kernel for ExpSineSquaredKernel {
    /// Evaluates `exp(-2 sin²(π d / period) / length_scale²)`, where `d` is
    /// the Euclidean distance between `x` and `y`.
    ///
    /// # Errors
    ///
    /// Returns [`KernelError::DimensionMismatch`] when the inputs differ in
    /// length.
    fn compute(&self, x: &[f64], y: &[f64]) -> Result<f64> {
        let (len_x, len_y) = (x.len(), y.len());
        if len_x != len_y {
            return Err(KernelError::DimensionMismatch {
                expected: vec![len_x],
                got: vec![len_y],
                context: "ExpSineSquared kernel".to_string(),
            });
        }

        let squared_distance = x
            .iter()
            .zip(y.iter())
            .map(|(xi, yi)| {
                let delta = xi - yi;
                delta * delta
            })
            .sum::<f64>();
        let distance = squared_distance.sqrt();

        let sine = (PI * distance / self.period).sin();
        let scale_squared = self.length_scale * self.length_scale;
        let exponent = -2.0 * sine * sine / scale_squared;
        Ok(exponent.exp())
    }

    /// Human-readable kernel name.
    fn name(&self) -> &str {
        "ExpSineSquared"
    }
}
/// Product of an RBF factor `exp(-‖x − y‖² / (2 rbf_length_scale²))` and a
/// periodic factor
/// `exp(-2 sin²(π ‖x − y‖ / period) / periodic_length_scale²)`.
#[derive(Debug, Clone)]
pub struct LocallyPeriodicKernel {
    /// Period of the sine term in the periodic factor.
    period: f64,
    /// Length scale of the periodic factor.
    periodic_length_scale: f64,
    /// Length scale of the RBF factor.
    rbf_length_scale: f64,
}

impl LocallyPeriodicKernel {
    /// Creates the kernel after validating all three parameters.
    ///
    /// # Errors
    ///
    /// Returns [`KernelError::InvalidParameter`] when any parameter is NaN,
    /// zero or negative. Parameters are checked in argument order and the
    /// first offender is reported.
    pub fn new(period: f64, periodic_length_scale: f64, rbf_length_scale: f64) -> Result<Self> {
        let checks = [
            ("period", period),
            ("periodic_length_scale", periodic_length_scale),
            ("rbf_length_scale", rbf_length_scale),
        ];
        for &(name, value) in &checks {
            // `NaN <= 0.0` is false, so NaN must be tested explicitly or it
            // would be accepted and poison every kernel evaluation.
            if value.is_nan() || value <= 0.0 {
                return Err(KernelError::InvalidParameter {
                    parameter: name.to_string(),
                    value: value.to_string(),
                    reason: format!("{} must be positive", name),
                });
            }
        }
        Ok(Self {
            period,
            periodic_length_scale,
            rbf_length_scale,
        })
    }

    /// The period passed to [`LocallyPeriodicKernel::new`].
    pub fn period(&self) -> f64 {
        self.period
    }

    /// The periodic length scale passed to [`LocallyPeriodicKernel::new`].
    pub fn periodic_length_scale(&self) -> f64 {
        self.periodic_length_scale
    }

    /// The RBF length scale passed to [`LocallyPeriodicKernel::new`].
    pub fn rbf_length_scale(&self) -> f64 {
        self.rbf_length_scale
    }
}
impl Kernel for LocallyPeriodicKernel {
    /// Evaluates the product of the RBF factor and the periodic factor on the
    /// Euclidean distance between `x` and `y`.
    ///
    /// # Errors
    ///
    /// Returns [`KernelError::DimensionMismatch`] when the inputs differ in
    /// length.
    fn compute(&self, x: &[f64], y: &[f64]) -> Result<f64> {
        let (len_x, len_y) = (x.len(), y.len());
        if len_x != len_y {
            return Err(KernelError::DimensionMismatch {
                expected: vec![len_x],
                got: vec![len_y],
                context: "Locally Periodic kernel".to_string(),
            });
        }

        let squared_distance = x
            .iter()
            .zip(y.iter())
            .map(|(xi, yi)| {
                let delta = xi - yi;
                delta * delta
            })
            .sum::<f64>();

        // Smooth envelope that decays with the squared distance.
        let rbf_scale_squared = self.rbf_length_scale * self.rbf_length_scale;
        let envelope = (-0.5 * squared_distance / rbf_scale_squared).exp();

        // Periodic factor driven by the plain distance.
        let distance = squared_distance.sqrt();
        let sine = (PI * distance / self.period).sin();
        let periodic_scale_squared = self.periodic_length_scale * self.periodic_length_scale;
        let oscillation = (-2.0 * sine * sine / periodic_scale_squared).exp();

        Ok(envelope * oscillation)
    }

    /// Human-readable kernel name.
    fn name(&self) -> &str {
        "LocallyPeriodic"
    }
}
/// Product of an RBF factor `exp(-‖x − y‖² / (2 length_scale²))` and a linear
/// factor `variance * ⟨x, y⟩`.
#[derive(Debug, Clone)]
pub struct RbfLinearKernel {
    /// Length scale of the RBF factor.
    length_scale: f64,
    /// Multiplier applied to the dot product in the linear factor.
    variance: f64,
}

impl RbfLinearKernel {
    /// Creates the kernel after validating both parameters.
    ///
    /// # Errors
    ///
    /// Returns [`KernelError::InvalidParameter`] when `length_scale` or
    /// `variance` is NaN, zero or negative. `length_scale` is checked first.
    pub fn new(length_scale: f64, variance: f64) -> Result<Self> {
        for &(name, value) in &[("length_scale", length_scale), ("variance", variance)] {
            // `NaN <= 0.0` is false, so NaN must be tested explicitly or it
            // would be accepted and poison every kernel evaluation.
            if value.is_nan() || value <= 0.0 {
                return Err(KernelError::InvalidParameter {
                    parameter: name.to_string(),
                    value: value.to_string(),
                    reason: format!("{} must be positive", name),
                });
            }
        }
        Ok(Self {
            length_scale,
            variance,
        })
    }

    /// The length scale passed to [`RbfLinearKernel::new`].
    pub fn length_scale(&self) -> f64 {
        self.length_scale
    }

    /// The variance passed to [`RbfLinearKernel::new`].
    pub fn variance(&self) -> f64 {
        self.variance
    }
}
impl Kernel for RbfLinearKernel {
    /// Evaluates `exp(-‖x − y‖² / (2 length_scale²)) * (variance * ⟨x, y⟩)`.
    ///
    /// # Errors
    ///
    /// Returns [`KernelError::DimensionMismatch`] when the inputs differ in
    /// length.
    fn compute(&self, x: &[f64], y: &[f64]) -> Result<f64> {
        let (len_x, len_y) = (x.len(), y.len());
        if len_x != len_y {
            return Err(KernelError::DimensionMismatch {
                expected: vec![len_x],
                got: vec![len_y],
                context: "RBF-Linear kernel".to_string(),
            });
        }

        let squared_distance = x
            .iter()
            .zip(y.iter())
            .map(|(xi, yi)| {
                let delta = xi - yi;
                delta * delta
            })
            .sum::<f64>();
        let scale_squared = self.length_scale * self.length_scale;
        let gaussian = (-0.5 * squared_distance / scale_squared).exp();

        let inner_product = x
            .iter()
            .zip(y.iter())
            .map(|(xi, yi)| xi * yi)
            .sum::<f64>();
        let scaled_product = self.variance * inner_product;

        Ok(gaussian * scaled_product)
    }

    /// Human-readable kernel name.
    fn name(&self) -> &str {
        "RBF-Linear"
    }

    /// Always reports `true` for this kernel.
    fn is_psd(&self) -> bool {
        true
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used by every floating-point equality check below.
    const TOLERANCE: f64 = 1e-10;

    /// Builds a one-dimensional spectral mixture kernel with a single component.
    fn one_component_kernel(weight: f64, mean: f64, variance: f64) -> SpectralMixtureKernel {
        let component = SpectralComponent::new_1d(weight, mean, variance).expect("unwrap");
        SpectralMixtureKernel::new(vec![component]).expect("unwrap")
    }

    #[test]
    fn test_spectral_component_1d() {
        let component = SpectralComponent::new_1d(1.0, 0.5, 0.1).expect("unwrap");
        assert!((component.weight - 1.0).abs() < TOLERANCE);
        assert_eq!(component.ndim(), 1);
    }

    #[test]
    fn test_spectral_component_multidim() {
        let means = vec![0.1, 0.2];
        let variances = vec![0.01, 0.02];
        let component = SpectralComponent::new(1.0, means, variances).expect("unwrap");
        assert_eq!(component.ndim(), 2);
    }

    #[test]
    fn test_spectral_component_invalid_weight() {
        for &weight in [0.0, -1.0].iter() {
            assert!(SpectralComponent::new_1d(weight, 0.5, 0.1).is_err());
        }
    }

    #[test]
    fn test_spectral_component_invalid_variance() {
        for &variance in [0.0, -0.1].iter() {
            assert!(SpectralComponent::new_1d(1.0, 0.5, variance).is_err());
        }
    }

    #[test]
    fn test_spectral_component_mismatched_dims() {
        let outcome = SpectralComponent::new(1.0, vec![0.1, 0.2], vec![0.01]);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_spectral_mixture_kernel_single_component() {
        let kernel = one_component_kernel(1.0, 0.0, 0.1);
        assert_eq!(kernel.name(), "SpectralMixture");
        assert_eq!(kernel.num_components(), 1);

        // Identical inputs give a zero lag, so the single unit-weight
        // component must evaluate to one.
        let origin = [0.0];
        let value = kernel.compute(&origin, &origin).expect("unwrap");
        assert!((value - 1.0).abs() < TOLERANCE);
    }

    #[test]
    fn test_spectral_mixture_kernel_multiple_components() {
        let first = SpectralComponent::new_1d(0.5, 0.1, 0.01).expect("unwrap");
        let second = SpectralComponent::new_1d(0.5, 1.0, 0.1).expect("unwrap");
        let kernel = SpectralMixtureKernel::new(vec![first, second]).expect("unwrap");
        assert_eq!(kernel.num_components(), 2);

        // At zero lag the kernel equals the sum of the weights (0.5 + 0.5).
        let origin = [0.0];
        let value = kernel.compute(&origin, &origin).expect("unwrap");
        assert!((value - 1.0).abs() < TOLERANCE);
    }

    #[test]
    fn test_spectral_mixture_kernel_1d_convenience() {
        let triples = vec![(1.0, 0.5, 0.1), (0.5, 1.0, 0.05)];
        let kernel = SpectralMixtureKernel::new_1d(triples).expect("unwrap");
        assert_eq!(kernel.num_components(), 2);
        assert_eq!(kernel.ndim(), 1);
    }

    #[test]
    fn test_spectral_mixture_kernel_periodicity() {
        // A mean of 0.25 puts a full cosine cycle at a lag of 4.
        let frequency = 0.25;
        let kernel = one_component_kernel(1.0, frequency, 0.0001);

        let origin = [0.0];
        let full_period = [4.0];
        let half_period = [2.0];
        let at_period = kernel.compute(&origin, &full_period).expect("unwrap");
        let at_half = kernel.compute(&origin, &half_period).expect("unwrap");

        assert!(
            at_period > at_half,
            "Period value {} should exceed half-period value {}",
            at_period,
            at_half
        );
        assert!(
            at_period > 0.5,
            "Period value {} should be > 0.5",
            at_period
        );
    }

    #[test]
    fn test_spectral_mixture_kernel_symmetry() {
        let kernel = one_component_kernel(1.0, 0.5, 0.1);
        let left = [1.0];
        let right = [2.0];
        let forward = kernel.compute(&left, &right).expect("unwrap");
        let backward = kernel.compute(&right, &left).expect("unwrap");
        assert!((forward - backward).abs() < TOLERANCE);
    }

    #[test]
    fn test_spectral_mixture_kernel_empty_components() {
        assert!(SpectralMixtureKernel::new(Vec::new()).is_err());
    }

    #[test]
    fn test_spectral_mixture_kernel_dimension_mismatch() {
        let kernel = one_component_kernel(1.0, 0.5, 0.1);
        // The kernel is one-dimensional, so a two-element input must fail.
        let too_long = [0.0, 0.0];
        let matching = [0.0];
        assert!(kernel.compute(&too_long, &matching).is_err());
    }

    #[test]
    fn test_exp_sine_squared_kernel_basic() {
        let kernel = ExpSineSquaredKernel::new(10.0, 1.0).expect("unwrap");
        assert_eq!(kernel.name(), "ExpSineSquared");

        let origin = [0.0];
        let value = kernel.compute(&origin, &origin).expect("unwrap");
        assert!((value - 1.0).abs() < TOLERANCE);
    }

    #[test]
    fn test_exp_sine_squared_kernel_periodicity() {
        let period = 10.0;
        let kernel = ExpSineSquaredKernel::new(period, 1.0).expect("unwrap");
        let origin = [0.0];

        // Points one and two whole periods away should look (almost) identical.
        let one_period = [period];
        let two_periods = [2.0 * period];
        let after_one = kernel.compute(&origin, &one_period).expect("unwrap");
        let after_two = kernel.compute(&origin, &two_periods).expect("unwrap");
        assert!(after_one > 0.99);
        assert!(after_two > 0.99);
    }

    #[test]
    fn test_exp_sine_squared_kernel_invalid() {
        for &(period, length_scale) in [(0.0, 1.0), (10.0, 0.0)].iter() {
            assert!(ExpSineSquaredKernel::new(period, length_scale).is_err());
        }
    }

    #[test]
    fn test_locally_periodic_kernel_basic() {
        let kernel = LocallyPeriodicKernel::new(10.0, 1.0, 100.0).expect("unwrap");
        assert_eq!(kernel.name(), "LocallyPeriodic");

        let origin = [0.0];
        let value = kernel.compute(&origin, &origin).expect("unwrap");
        assert!((value - 1.0).abs() < TOLERANCE);
    }

    #[test]
    fn test_locally_periodic_kernel_decay() {
        let kernel = LocallyPeriodicKernel::new(10.0, 1.0, 5.0).expect("unwrap");
        let origin = [0.0];

        // Both points sit on whole multiples of the period; only the RBF
        // envelope separates them.
        let near = [10.0];
        let far = [100.0];
        let near_value = kernel.compute(&origin, &near).expect("unwrap");
        let far_value = kernel.compute(&origin, &far).expect("unwrap");
        assert!(near_value > far_value);
    }

    #[test]
    fn test_locally_periodic_kernel_invalid() {
        let bad_parameters = [(0.0, 1.0, 1.0), (10.0, 0.0, 1.0), (10.0, 1.0, 0.0)];
        for &(period, periodic_scale, rbf_scale) in bad_parameters.iter() {
            assert!(LocallyPeriodicKernel::new(period, periodic_scale, rbf_scale).is_err());
        }
    }

    #[test]
    fn test_rbf_linear_kernel_basic() {
        let kernel = RbfLinearKernel::new(1.0, 1.0).expect("unwrap");
        assert_eq!(kernel.name(), "RBF-Linear");
        assert!(kernel.is_psd());

        // Identical inputs: the RBF factor is 1 and the dot product is 1 + 4.
        let point = [1.0, 2.0];
        let same_point = [1.0, 2.0];
        let value = kernel.compute(&point, &same_point).expect("unwrap");
        assert!((value - 5.0).abs() < TOLERANCE);
    }

    #[test]
    fn test_rbf_linear_kernel_symmetry() {
        let kernel = RbfLinearKernel::new(1.0, 1.0).expect("unwrap");
        let left = [1.0, 2.0];
        let right = [3.0, 4.0];
        let forward = kernel.compute(&left, &right).expect("unwrap");
        let backward = kernel.compute(&right, &left).expect("unwrap");
        assert!((forward - backward).abs() < TOLERANCE);
    }

    #[test]
    fn test_rbf_linear_kernel_invalid() {
        for &(length_scale, variance) in [(0.0, 1.0), (1.0, 0.0)].iter() {
            assert!(RbfLinearKernel::new(length_scale, variance).is_err());
        }
    }

    #[test]
    fn test_spectral_kernels_symmetry() {
        let kernels: Vec<Box<dyn Kernel>> = vec![
            Box::new(one_component_kernel(1.0, 0.5, 0.1)),
            Box::new(ExpSineSquaredKernel::new(10.0, 1.0).expect("unwrap")),
            Box::new(LocallyPeriodicKernel::new(10.0, 1.0, 10.0).expect("unwrap")),
            Box::new(RbfLinearKernel::new(1.0, 1.0).expect("unwrap")),
        ];

        let left = [1.0];
        let right = [2.0];
        for kernel in kernels.iter() {
            let forward = kernel.compute(&left, &right).expect("unwrap");
            let backward = kernel.compute(&right, &left).expect("unwrap");
            assert!(
                (forward - backward).abs() < TOLERANCE,
                "{} not symmetric",
                kernel.name()
            );
        }
    }
}