use crate::error::{KernelError, Result};
use crate::types::Kernel;
use std::fmt;
use std::sync::Arc;
/// A symbolic kernel expression tree.
///
/// Leaves are concrete kernels; every interior node combines the values of
/// its sub-expressions at the same pair of inputs `(x, y)`.
#[derive(Clone)]
pub enum KernelExpr {
    /// A concrete kernel used as a leaf of the tree.
    Base(Arc<dyn Kernel>),
    /// Evaluates to `scale * kernel(x, y)`.
    Scaled {
        /// The expression being scaled.
        kernel: Box<KernelExpr>,
        /// The constant multiplier.
        scale: f64,
    },
    /// Evaluates to `left(x, y) + right(x, y)`.
    Sum {
        /// Left operand.
        left: Box<KernelExpr>,
        /// Right operand.
        right: Box<KernelExpr>,
    },
    /// Evaluates to `left(x, y) * right(x, y)` (element-wise product).
    Product {
        /// Left factor.
        left: Box<KernelExpr>,
        /// Right factor.
        right: Box<KernelExpr>,
    },
    /// Evaluates to `kernel(x, y)` raised to the integer power `exponent`.
    Power {
        /// The base expression.
        kernel: Box<KernelExpr>,
        /// The integer exponent.
        exponent: u32,
    },
}
// `add` mirrors `std::ops::Add::add`; it is kept as an inherent method so the
// fluent API reads `a.add(b).multiply(c)`.
#[allow(clippy::should_implement_trait)]
impl KernelExpr {
    /// Wraps a concrete kernel as a leaf of the expression tree.
    pub fn base(kernel: Arc<dyn Kernel>) -> Self {
        KernelExpr::Base(kernel)
    }

    /// Returns the expression `scale * self`.
    ///
    /// # Errors
    ///
    /// Returns [`KernelError::InvalidParameter`] if `scale` is NaN, infinite,
    /// or negative. A scale of `0.0` is accepted.
    pub fn scale(self, scale: f64) -> Result<Self> {
        if !scale.is_finite() || scale < 0.0 {
            return Err(KernelError::InvalidParameter {
                parameter: "scale".to_string(),
                value: scale.to_string(),
                reason: "must be finite and non-negative".to_string(),
            });
        }
        Ok(KernelExpr::Scaled {
            kernel: Box::new(self),
            scale,
        })
    }

    /// Returns the expression `self + other`.
    pub fn add(self, other: KernelExpr) -> Self {
        KernelExpr::Sum {
            left: Box::new(self),
            right: Box::new(other),
        }
    }

    /// Returns the element-wise product `self * other`.
    pub fn multiply(self, other: KernelExpr) -> Self {
        KernelExpr::Product {
            left: Box::new(self),
            right: Box::new(other),
        }
    }

    /// Returns the expression `self^exponent`.
    ///
    /// # Errors
    ///
    /// Returns [`KernelError::InvalidParameter`] if `exponent` is `0`.
    pub fn power(self, exponent: u32) -> Result<Self> {
        if exponent == 0 {
            return Err(KernelError::InvalidParameter {
                parameter: "exponent".to_string(),
                value: exponent.to_string(),
                reason: "must be positive".to_string(),
            });
        }
        Ok(KernelExpr::Power {
            kernel: Box::new(self),
            exponent,
        })
    }

    /// Wraps the expression, as-is (no simplification), in a [`SymbolicKernel`].
    pub fn build(self) -> Box<dyn Kernel> {
        Box::new(SymbolicKernel { expr: self })
    }

    /// Removes identity nodes from the expression tree.
    ///
    /// `Scaled` nodes whose scale is within `1e-10` of `1.0` and `Power` nodes
    /// with exponent `1` are replaced by their (simplified) operand. Every
    /// other node is kept and its children are simplified recursively.
    ///
    /// Because of the tolerance, a scale such as `1.0 + 1e-11` is dropped,
    /// which can perturb results by that relative amount.
    pub fn simplify(self) -> Self {
        match self {
            KernelExpr::Scaled { kernel, scale } if (scale - 1.0).abs() < 1e-10 => {
                kernel.simplify()
            }
            KernelExpr::Scaled { kernel, scale } => KernelExpr::Scaled {
                kernel: Box::new(kernel.simplify()),
                scale,
            },
            KernelExpr::Sum { left, right } => KernelExpr::Sum {
                left: Box::new(left.simplify()),
                right: Box::new(right.simplify()),
            },
            KernelExpr::Product { left, right } => KernelExpr::Product {
                left: Box::new(left.simplify()),
                right: Box::new(right.simplify()),
            },
            KernelExpr::Power { kernel, exponent: 1 } => kernel.simplify(),
            KernelExpr::Power { kernel, exponent } => KernelExpr::Power {
                kernel: Box::new(kernel.simplify()),
                exponent,
            },
            base @ KernelExpr::Base(_) => base,
        }
    }
}
impl fmt::Debug for KernelExpr {
    /// Renders the expression in infix notation, e.g. `(0.5*Linear + RBF^2)`.
    ///
    /// Sums and products are always parenthesised. The base of a power is
    /// parenthesised when it is itself a scale or a power. Without that,
    /// `Scaled { Power { k, 2 }, 2 }` and `Power { Scaled { k, 2 }, 2 }`
    /// would both print as `2*k^2`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            KernelExpr::Base(k) => write!(f, "{}", k.name()),
            // Shortest round-trip formatting: a fixed `{:.2}` would render
            // small scales such as 0.001 as `0.00`.
            KernelExpr::Scaled { kernel, scale } => write!(f, "{}*{:?}", scale, kernel),
            KernelExpr::Sum { left, right } => write!(f, "({:?} + {:?})", left, right),
            KernelExpr::Product { left, right } => write!(f, "({:?} * {:?})", left, right),
            KernelExpr::Power { kernel, exponent } => match kernel.as_ref() {
                KernelExpr::Scaled { .. } | KernelExpr::Power { .. } => {
                    write!(f, "({:?})^{}", kernel, exponent)
                }
                _ => write!(f, "{:?}^{}", kernel, exponent),
            },
        }
    }
}
/// A [`Kernel`] whose value is obtained by evaluating a [`KernelExpr`] tree
/// at each requested pair of inputs.
#[derive(Clone)]
pub struct SymbolicKernel {
    /// The expression evaluated on every call to `compute`.
    expr: KernelExpr,
}
impl SymbolicKernel {
    /// Wraps `expr` without simplifying it.
    pub fn new(expr: KernelExpr) -> Self {
        Self { expr }
    }

    /// Returns the expression tree this kernel evaluates.
    pub fn expression(&self) -> &KernelExpr {
        &self.expr
    }

    /// Returns a kernel whose expression has been passed through
    /// [`KernelExpr::simplify`].
    pub fn simplify(self) -> Self {
        Self {
            expr: self.expr.simplify(),
        }
    }

    /// Evaluates the wrapped expression at `(x, y)`.
    fn eval(&self, x: &[f64], y: &[f64]) -> Result<f64> {
        Self::eval_expr(&self.expr, x, y)
    }

    /// Recursively evaluates `expr` at `(x, y)`.
    ///
    /// Operands are evaluated left to right, and the first error returned by a
    /// leaf kernel is propagated immediately. This is an associated function
    /// because evaluation needs no state beyond `expr` itself.
    fn eval_expr(expr: &KernelExpr, x: &[f64], y: &[f64]) -> Result<f64> {
        match expr {
            KernelExpr::Base(kernel) => kernel.compute(x, y),
            KernelExpr::Scaled { kernel, scale } => {
                let value = Self::eval_expr(kernel, x, y)?;
                Ok(scale * value)
            }
            KernelExpr::Sum { left, right } => {
                let left_val = Self::eval_expr(left, x, y)?;
                let right_val = Self::eval_expr(right, x, y)?;
                Ok(left_val + right_val)
            }
            KernelExpr::Product { left, right } => {
                let left_val = Self::eval_expr(left, x, y)?;
                let right_val = Self::eval_expr(right, x, y)?;
                Ok(left_val * right_val)
            }
            KernelExpr::Power { kernel, exponent } => {
                let value = Self::eval_expr(kernel, x, y)?;
                Ok(Self::int_pow(value, *exponent))
            }
        }
    }

    /// Computes `value^exponent` for any `u32` exponent.
    ///
    /// `f64::powi` takes an `i32`. Casting an exponent above `i32::MAX` with
    /// `as` would wrap it to a negative power and silently yield a reciprocal.
    /// Such exponents go through `powf` instead; every `u32` is exactly
    /// representable as an `f64`.
    fn int_pow(value: f64, exponent: u32) -> f64 {
        if exponent <= i32::MAX as u32 {
            value.powi(exponent as i32)
        } else {
            value.powf(f64::from(exponent))
        }
    }
}
impl Kernel for SymbolicKernel {
fn compute(&self, x: &[f64], y: &[f64]) -> Result<f64> {
self.eval(x, y)
}
fn name(&self) -> &str {
"Symbolic"
}
fn is_psd(&self) -> bool {
check_psd(&self.expr)
}
}
/// Conservatively decides whether `expr` is positive semi-definite, using the
/// standard closure rules for PSD kernels:
///
/// * a leaf is PSD iff its kernel reports so;
/// * a non-negative multiple of a PSD kernel is PSD;
/// * sums and element-wise (Schur) products of PSD kernels are PSD;
/// * a positive integer power of a PSD kernel is PSD, since `k^n` is the
///   element-wise product `k * k * ... * k`.
///
/// Returns `false` whenever PSD-ness cannot be established this way. A
/// `false` result does not prove that the kernel is indefinite.
fn check_psd(expr: &KernelExpr) -> bool {
    match expr {
        KernelExpr::Base(kernel) => kernel.is_psd(),
        // `>=` is false for NaN, so a NaN scale is never treated as PSD.
        KernelExpr::Scaled { kernel, scale } => *scale >= 0.0 && check_psd(kernel),
        KernelExpr::Sum { left, right } | KernelExpr::Product { left, right } => {
            check_psd(left) && check_psd(right)
        }
        // By the Schur product theorem, any power n >= 1 of a PSD kernel is
        // PSD, just like the equivalent chain of `Product` nodes. An exponent
        // of 0 cannot be built through `KernelExpr::power`. If it is constructed
        // directly, deferring to the base is a conservative answer.
        KernelExpr::Power { kernel, .. } => check_psd(kernel),
    }
}
/// Fluent builder that accumulates a [`KernelExpr`] one term at a time.
///
/// The builder starts empty. See `KernelBuilder::build` for the kernel an
/// empty builder produces.
#[derive(Clone, Debug)]
pub struct KernelBuilder {
    /// The expression accumulated so far, or `None` if no kernel was added yet.
    expr: Option<KernelExpr>,
}
#[allow(clippy::should_implement_trait)]
impl KernelBuilder {
pub fn new() -> Self {
Self { expr: None }
}
pub fn add(mut self, kernel: Arc<dyn Kernel>) -> Self {
let new_expr = KernelExpr::base(kernel);
self.expr = Some(match self.expr {
Some(existing) => existing.add(new_expr),
None => new_expr,
});
self
}
pub fn add_scaled(mut self, kernel: Arc<dyn Kernel>, scale: f64) -> Self {
let scaled = KernelExpr::base(kernel)
.scale(scale)
.expect("scale value is valid f64");
self.expr = Some(match self.expr {
Some(existing) => existing.add(scaled),
None => scaled,
});
self
}
pub fn multiply(mut self, kernel: Arc<dyn Kernel>) -> Self {
let new_expr = KernelExpr::base(kernel);
self.expr = Some(match self.expr {
Some(existing) => existing.multiply(new_expr),
None => new_expr,
});
self
}
pub fn scale(mut self, scale: f64) -> Result<Self> {
if let Some(expr) = self.expr {
self.expr = Some(expr.scale(scale)?);
}
Ok(self)
}
pub fn power(mut self, exponent: u32) -> Result<Self> {
if let Some(expr) = self.expr {
self.expr = Some(expr.power(exponent)?);
}
Ok(self)
}
pub fn build(self) -> Box<dyn Kernel> {
match self.expr {
Some(expr) => expr.simplify().build(),
None => {
KernelExpr::base(Arc::new(crate::tensor_kernels::LinearKernel::new()))
.scale(0.0)
.expect("0.0 is a valid scale value")
.build()
}
}
}
}
impl Default for KernelBuilder {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor_kernels::{LinearKernel, RbfKernel};
    use crate::types::RbfKernelConfig;

    #[test]
    fn test_kernel_expr_scale() {
        let linear = Arc::new(LinearKernel::new());
        let expr = KernelExpr::base(linear).scale(2.0).expect("unwrap");
        let kernel = SymbolicKernel::new(expr);
        let x = vec![1.0, 2.0, 3.0];
        let y = vec![4.0, 5.0, 6.0];
        let result = kernel.compute(&x, &y).expect("unwrap");
        assert!((result - 64.0).abs() < 1e-10);
    }

    #[test]
    fn test_kernel_expr_add() {
        let linear1 = Arc::new(LinearKernel::new());
        let linear2 = Arc::new(LinearKernel::new());
        let expr1 = KernelExpr::base(linear1);
        let expr2 = KernelExpr::base(linear2);
        let sum = expr1.add(expr2);
        let kernel = SymbolicKernel::new(sum);
        let x = vec![1.0, 2.0, 3.0];
        let y = vec![4.0, 5.0, 6.0];
        let result = kernel.compute(&x, &y).expect("unwrap");
        assert!((result - 64.0).abs() < 1e-10);
    }

    #[test]
    fn test_kernel_expr_multiply() {
        let linear1 = Arc::new(LinearKernel::new());
        let linear2 = Arc::new(LinearKernel::new());
        let expr1 = KernelExpr::base(linear1);
        let expr2 = KernelExpr::base(linear2);
        let product = expr1.multiply(expr2);
        let kernel = SymbolicKernel::new(product);
        let x = vec![1.0, 2.0, 3.0];
        let y = vec![4.0, 5.0, 6.0];
        let result = kernel.compute(&x, &y).expect("unwrap");
        assert!((result - 1024.0).abs() < 1e-10);
    }

    #[test]
    fn test_kernel_expr_power() {
        let linear = Arc::new(LinearKernel::new());
        let expr = KernelExpr::base(linear).power(2).expect("unwrap");
        let kernel = SymbolicKernel::new(expr);
        let x = vec![1.0, 2.0, 3.0];
        let y = vec![4.0, 5.0, 6.0];
        let result = kernel.compute(&x, &y).expect("unwrap");
        assert!((result - 1024.0).abs() < 1e-10);
    }

    #[test]
    fn test_kernel_expr_complex() {
        let linear = Arc::new(LinearKernel::new());
        let rbf = Arc::new(RbfKernel::new(RbfKernelConfig::new(0.5)).expect("unwrap"));
        let expr = KernelExpr::base(linear)
            .scale(0.5)
            .expect("unwrap")
            .add(KernelExpr::base(rbf).scale(0.3).expect("unwrap"));
        let kernel = SymbolicKernel::new(expr);
        let x = vec![1.0, 2.0, 3.0];
        let y = vec![1.0, 2.0, 3.0];
        let result = kernel.compute(&x, &y).expect("unwrap");
        assert!((result - 7.3).abs() < 1e-6);
    }

    #[test]
    fn test_kernel_builder() {
        let builder = KernelBuilder::new();
        let kernel = builder
            .add_scaled(Arc::new(LinearKernel::new()), 0.5)
            .add_scaled(Arc::new(LinearKernel::new()), 0.3)
            .build();
        let x = vec![1.0, 2.0, 3.0];
        let y = vec![4.0, 5.0, 6.0];
        let result = kernel.compute(&x, &y).expect("unwrap");
        assert!((result - 25.6).abs() < 1e-10);
    }

    #[test]
    fn test_kernel_builder_multiply() {
        let builder = KernelBuilder::new();
        let kernel = builder
            .add(Arc::new(LinearKernel::new()))
            .multiply(Arc::new(LinearKernel::new()))
            .build();
        let x = vec![1.0, 2.0];
        let y = vec![3.0, 4.0];
        let result = kernel.compute(&x, &y).expect("unwrap");
        assert!((result - 121.0).abs() < 1e-10);
    }

    #[test]
    fn test_kernel_builder_power() {
        let builder = KernelBuilder::new();
        let kernel = builder
            .add(Arc::new(LinearKernel::new()))
            .power(3)
            .expect("unwrap")
            .build();
        let x = vec![2.0];
        let y = vec![3.0];
        let result = kernel.compute(&x, &y).expect("unwrap");
        assert!((result - 216.0).abs() < 1e-10);
    }

    #[test]
    fn test_kernel_expr_simplify() {
        let linear = Arc::new(LinearKernel::new());
        let expr = KernelExpr::base(linear)
            .scale(1.0)
            .expect("unwrap")
            .power(1)
            .expect("unwrap");
        let simplified = expr.simplify();
        // `matches!` only yields a bool; it must be asserted to test anything.
        assert!(
            matches!(simplified, KernelExpr::Base(_)),
            "expected a bare Base node, got {:?}",
            simplified
        );
    }

    #[test]
    fn test_simplify_preserves_value() {
        let expr = KernelExpr::base(Arc::new(LinearKernel::new()))
            .scale(1.0)
            .expect("unwrap")
            .power(1)
            .expect("unwrap");
        let x = vec![1.0, 2.0, 3.0];
        let y = vec![4.0, 5.0, 6.0];
        let before = SymbolicKernel::new(expr.clone())
            .compute(&x, &y)
            .expect("unwrap");
        let after = SymbolicKernel::new(expr.simplify())
            .compute(&x, &y)
            .expect("unwrap");
        assert!((before - after).abs() < 1e-10);
    }

    #[test]
    fn test_invalid_scale() {
        let linear = Arc::new(LinearKernel::new());
        let result = KernelExpr::base(linear).scale(-1.0);
        assert!(result.is_err());
    }

    #[test]
    fn test_invalid_power() {
        let linear = Arc::new(LinearKernel::new());
        let result = KernelExpr::base(linear).power(0);
        assert!(result.is_err());
    }

    #[test]
    fn test_builder_rejects_invalid_parameters() {
        let builder = KernelBuilder::new().add(Arc::new(LinearKernel::new()));
        assert!(builder.scale(-1.0).is_err());
        let builder = KernelBuilder::new().add(Arc::new(LinearKernel::new()));
        assert!(builder.power(0).is_err());
    }

    #[test]
    fn test_psd_property() {
        let expr = KernelExpr::base(Arc::new(LinearKernel::new()))
            .add(KernelExpr::base(Arc::new(LinearKernel::new())));
        let kernel = SymbolicKernel::new(expr);
        assert!(kernel.is_psd());
        let expr = KernelExpr::base(Arc::new(LinearKernel::new()))
            .multiply(KernelExpr::base(Arc::new(LinearKernel::new())));
        let kernel = SymbolicKernel::new(expr);
        assert!(kernel.is_psd());
        let expr = KernelExpr::base(Arc::new(LinearKernel::new()))
            .scale(2.0)
            .expect("unwrap");
        let kernel = SymbolicKernel::new(expr);
        assert!(kernel.is_psd());
    }

    #[test]
    fn test_kernel_name() {
        let linear = Arc::new(LinearKernel::new());
        let expr = KernelExpr::base(linear);
        let kernel = SymbolicKernel::new(expr);
        assert_eq!(kernel.name(), "Symbolic");
    }

    #[test]
    fn test_empty_builder() {
        let builder = KernelBuilder::new();
        let kernel = builder.build();
        let x = vec![1.0, 2.0, 3.0];
        let y = vec![4.0, 5.0, 6.0];
        let result = kernel.compute(&x, &y).expect("unwrap");
        assert!((result - 0.0).abs() < 1e-10);
    }
}