use anofox_ml_core::{Fit, FitWeighted, Predict, Result, RustMlError};
use anofox_ml_svm::SvmKernel;
use faer::linalg::solvers::Solve;
use faer::{Mat, Side};
use ndarray::{Array1, Array2};
/// Kernel ridge regression estimator (unfitted configuration).
///
/// Fitting solves for one dual coefficient per training sample, using the Gram matrix of
/// `kernel` with `alpha` added to its diagonal.
#[derive(Clone, Debug)]
pub struct KernelRidge {
    /// Regularisation strength added to the diagonal of the training Gram matrix.
    /// Fitting rejects negative values.
    pub alpha: f64,
    /// Kernel used to compare pairs of samples.
    pub kernel: SvmKernel,
}
impl KernelRidge {
    /// Creates an estimator with `alpha = 1.0` and a linear kernel.
    pub fn new() -> Self {
        Self {
            kernel: SvmKernel::Linear,
            alpha: 1.0,
        }
    }

    /// Returns this estimator with its regularisation strength replaced by `alpha`.
    ///
    /// The value is not checked here; fitting rejects negative values.
    pub fn with_alpha(self, alpha: f64) -> Self {
        Self { alpha, ..self }
    }

    /// Returns this estimator with its kernel replaced by `kernel`.
    pub fn with_kernel(self, kernel: SvmKernel) -> Self {
        Self { kernel, ..self }
    }
}
impl Default for KernelRidge {
fn default() -> Self {
Self::new()
}
}
/// A fitted kernel ridge regression model.
///
/// Predictions are the kernel matrix between new samples and `x_train`, multiplied by
/// `dual_coef`.
#[derive(Clone, Debug, serde::Deserialize, serde::Serialize)]
pub struct FittedKernelRidge {
    /// Copy of the training feature matrix. At prediction time, kernels are evaluated
    /// against its rows.
    pub x_train: Array2<f64>,
    /// Dual coefficients. Fitting produces one per row of `x_train`.
    pub dual_coef: Array1<f64>,
    /// Kernel the model was fitted with.
    pub kernel: SvmKernel,
}
/// Builds the kernel (Gram) matrix between the rows of `x_a` and the rows of `x_b`.
///
/// The result has shape `(x_a.nrows(), x_b.nrows())`. Entry `(i, j)` is
/// `kernel.compute` applied to row `i` of `x_a` and row `j` of `x_b`.
fn build_gram(x_a: &Array2<f64>, x_b: &Array2<f64>, kernel: &SvmKernel) -> Array2<f64> {
    let shape = (x_a.nrows(), x_b.nrows());
    Array2::from_shape_fn(shape, |(row_a, row_b)| {
        kernel.compute(&x_a.row(row_a), &x_b.row(row_b))
    })
}
impl Fit<f64> for KernelRidge {
type Fitted = FittedKernelRidge;
fn fit(&self, x: &Array2<f64>, y: &Array1<f64>) -> Result<Self::Fitted> {
if x.nrows() != y.len() {
return Err(RustMlError::ShapeMismatch(format!(
"X has {} rows but y has {} elements",
x.nrows(),
y.len()
)));
}
if x.is_empty() {
return Err(RustMlError::EmptyInput("training data is empty".into()));
}
if self.alpha < 0.0 {
return Err(RustMlError::InvalidParameter(
"alpha must be non-negative".into(),
));
}
let n = x.nrows();
let mut k = build_gram(x, x, &self.kernel);
for i in 0..n {
k[[i, i]] += self.alpha;
}
let k_mat = Mat::from_fn(n, n, |i, j| k[[i, j]]);
let llt = faer::linalg::solvers::Llt::new(k_mat.as_ref(), Side::Lower)
.map_err(|e| RustMlError::InvalidParameter(format!("Cholesky failed: {e:?}")))?;
let y_mat = Mat::from_fn(n, 1, |i, _| y[i]);
let sol = llt.solve(&y_mat);
let dual = Array1::from_vec((0..n).map(|i| sol[(i, 0)]).collect());
Ok(FittedKernelRidge {
x_train: x.clone(),
dual_coef: dual,
kernel: self.kernel.clone(),
})
}
}
impl FitWeighted<f64> for KernelRidge {
type Fitted = FittedKernelRidge;
fn fit_weighted(
&self,
x: &Array2<f64>,
y: &Array1<f64>,
sample_weight: Option<&Array1<f64>>,
) -> Result<Self::Fitted> {
if x.nrows() != y.len() {
return Err(RustMlError::ShapeMismatch(format!(
"X has {} rows but y has {} elements",
x.nrows(),
y.len()
)));
}
if x.is_empty() {
return Err(RustMlError::EmptyInput("training data is empty".into()));
}
if self.alpha < 0.0 {
return Err(RustMlError::InvalidParameter(
"alpha must be non-negative".into(),
));
}
if let Some(w) = sample_weight {
if w.len() != y.len() {
return Err(RustMlError::ShapeMismatch(format!(
"sample_weight len {} != y len {}",
w.len(),
y.len()
)));
}
for &wi in w.iter() {
if !wi.is_finite() || wi < 0.0 {
return Err(RustMlError::InvalidParameter(
"sample_weight must be non-negative finite".into(),
));
}
}
}
let n = x.nrows();
let mut k = build_gram(x, x, &self.kernel);
if let Some(w) = sample_weight {
let sqrtw: Vec<f64> = w.iter().map(|v| v.sqrt()).collect();
for i in 0..n {
for j in 0..n {
k[[i, j]] *= sqrtw[i] * sqrtw[j];
}
}
}
for i in 0..n {
k[[i, i]] += self.alpha;
}
let k_mat = Mat::from_fn(n, n, |i, j| k[[i, j]]);
let llt = faer::linalg::solvers::Llt::new(k_mat.as_ref(), Side::Lower)
.map_err(|e| RustMlError::InvalidParameter(format!("Cholesky failed: {e:?}")))?;
let rhs: Vec<f64> = match sample_weight {
Some(w) => (0..n).map(|i| w[i].sqrt() * y[i]).collect(),
None => y.iter().copied().collect(),
};
let y_mat = Mat::from_fn(n, 1, |i, _| rhs[i]);
let sol = llt.solve(&y_mat);
let dual = match sample_weight {
Some(w) => Array1::from_vec((0..n).map(|i| w[i].sqrt() * sol[(i, 0)]).collect()),
None => Array1::from_vec((0..n).map(|i| sol[(i, 0)]).collect()),
};
Ok(FittedKernelRidge {
x_train: x.clone(),
dual_coef: dual,
kernel: self.kernel.clone(),
})
}
}
impl Predict<f64> for FittedKernelRidge {
    /// Predicts targets for the rows of `x` as `K(x, x_train) · dual_coef`.
    ///
    /// # Errors
    ///
    /// Returns [`RustMlError::ShapeMismatch`] in two cases:
    /// * `x` has a different number of columns than `x_train`.
    /// * `dual_coef` does not hold exactly one coefficient per row of `x_train`.
    ///
    /// The second case cannot arise from fitting. It can arise for a hand-built or
    /// deserialised model, because all fields are public and serde-derived. Checking it
    /// here turns what would be a panic in the matrix–vector product into an error.
    fn predict(&self, x: &Array2<f64>) -> Result<Array1<f64>> {
        if x.ncols() != self.x_train.ncols() {
            return Err(RustMlError::ShapeMismatch(format!(
                "expected {} features, got {}",
                self.x_train.ncols(),
                x.ncols()
            )));
        }
        if self.dual_coef.len() != self.x_train.nrows() {
            return Err(RustMlError::ShapeMismatch(format!(
                "model has {} training rows but {} dual coefficients",
                self.x_train.nrows(),
                self.dual_coef.len()
            )));
        }
        let k_test = build_gram(x, &self.x_train, &self.kernel);
        Ok(k_test.dot(&self.dual_coef))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    #[test]
    fn test_linear_kernel_ridge_recovers_ridge_solution() {
        let x = array![[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, 0.0]];
        let y = array![1.0, 2.0, 3.0, 2.0];
        let alpha = 0.5;
        let kr = KernelRidge::new()
            .with_alpha(alpha)
            .with_kernel(SvmKernel::Linear);
        let fitted = kr.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 4);
        // With a linear kernel, kernel ridge equals primal ridge without an intercept:
        // w = (XᵀX + αI)⁻¹ Xᵀy. For this data and α = 0.5:
        // XᵀX + αI = [[6.5, 1], [1, 2.5]] (det 15.25), Xᵀy = [8, 5],
        // so w = [15, 24.5] / 15.25.
        let w = [15.0 / 15.25, 24.5 / 15.25];
        for i in 0..4 {
            let expected = x[[i, 0]] * w[0] + x[[i, 1]] * w[1];
            assert_abs_diff_eq!(preds[i], expected, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_rbf_perfect_fit_zero_alpha() {
        let x = array![[0.0], [1.0], [2.0], [3.0], [4.0]];
        let y = array![1.0, -1.0, 0.5, 2.0, -0.5];
        let fitted = KernelRidge::new()
            .with_alpha(1e-10)
            .with_kernel(SvmKernel::Rbf { gamma: 1.0 })
            .fit(&x, &y)
            .unwrap();
        let pred = fitted.predict(&x).unwrap();
        for i in 0..5 {
            assert_abs_diff_eq!(pred[i], y[i], epsilon = 1e-5);
        }
    }

    #[test]
    fn test_negative_alpha_errors() {
        let x = array![[1.0]];
        let y = array![1.0];
        assert!(KernelRidge::new().with_alpha(-1.0).fit(&x, &y).is_err());
    }

    #[test]
    fn test_predict_feature_mismatch_errors() {
        let x = array![[0.0, 1.0], [1.0, 0.0]];
        let y = array![1.0, 2.0];
        let fitted = KernelRidge::new().fit(&x, &y).unwrap();
        assert!(fitted.predict(&array![[1.0, 2.0, 3.0]]).is_err());
    }

    #[test]
    fn test_kernel_ridge_uniform_weights_match_unweighted() {
        let x = array![[0.0, 1.0], [1.0, 0.0], [1.0, 1.0], [2.0, 3.0]];
        let y = array![1.0, 2.0, 3.0, 4.0];
        let kr = KernelRidge::new()
            .with_alpha(0.5)
            .with_kernel(SvmKernel::Rbf { gamma: 0.5 });
        let unw = kr.fit(&x, &y).unwrap();
        let ones = Array1::<f64>::ones(4);
        let w = kr.fit_weighted(&x, &y, Some(&ones)).unwrap();
        for i in 0..4 {
            assert_abs_diff_eq!(unw.dual_coef[i], w.dual_coef[i], epsilon = 1e-10);
        }
    }

    #[test]
    fn test_kernel_ridge_zero_weight_matches_dropping_sample() {
        // A zero-weight sample must not influence the fit: predictions should equal
        // those of a model trained without that sample.
        let x = array![[0.0], [1.0], [2.0], [3.0]];
        let y = array![1.0, -1.0, 0.5, 42.0];
        let kr = KernelRidge::new()
            .with_alpha(0.1)
            .with_kernel(SvmKernel::Rbf { gamma: 0.5 });
        let w = array![1.0, 1.0, 1.0, 0.0];
        let weighted = kr.fit_weighted(&x, &y, Some(&w)).unwrap();
        let dropped = kr
            .fit(&array![[0.0], [1.0], [2.0]], &array![1.0, -1.0, 0.5])
            .unwrap();
        let x_new = array![[0.5], [1.5], [2.5], [3.0]];
        let p_w = weighted.predict(&x_new).unwrap();
        let p_d = dropped.predict(&x_new).unwrap();
        for i in 0..4 {
            assert_abs_diff_eq!(p_w[i], p_d[i], epsilon = 1e-9);
        }
        assert_abs_diff_eq!(weighted.dual_coef[3], 0.0, epsilon = 1e-12);
    }

    #[test]
    fn test_kernel_ridge_high_weight_dominates() {
        let x = array![[0.0], [1.0], [2.0], [3.0], [4.0], [10.0]];
        let y = array![0.0, 0.5, 0.5, 0.0, 0.0, 100.0];
        let kr = KernelRidge::new()
            .with_alpha(1e-3)
            .with_kernel(SvmKernel::Rbf { gamma: 0.5 });
        let w = array![1.0, 1.0, 1.0, 1.0, 1.0, 1e6];
        let fitted = kr.fit_weighted(&x, &y, Some(&w)).unwrap();
        let p = fitted.predict(&array![[10.0]]).unwrap();
        assert!(
            (p[0] - 100.0).abs() < 1.0,
            "high-weight anchor pred={}",
            p[0]
        );
    }
}
// Empty impl: no methods are overridden, so `FittedKernelRidge` uses the provided
// implementations of `anofox_ml_core::RegressorScore`.
impl anofox_ml_core::RegressorScore<f64> for FittedKernelRidge {}