use super::kernel::{Gamma, Kernel};
use crate::dataset::Dataset;
use crate::error::{Result, ScryLearnError};
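/// Epsilon-insensitive Support Vector Regression with a configurable kernel.
///
/// Training uses a simplified SMO-style coordinate solver: pairs of signed
/// coefficients are updated until no sample violates the epsilon-tube KKT
/// conditions (within `tol`) or the iteration budget is exhausted.
/// Hyperparameters are set through a consuming builder; call
/// [`KernelSVR::fit`] before [`KernelSVR::predict`].
///
/// # Examples
///
/// A minimal usage sketch, assuming `Dataset::new` takes column-major feature
/// vectors as in the tests at the bottom of this file:
///
/// ```ignore
/// let features = vec![vec![1.0, 2.0, 3.0, 4.0]];
/// let target = vec![2.0, 4.0, 6.0, 8.0];
/// let data = Dataset::new(features, target, vec!["x".into()], "y");
///
/// let mut svr = KernelSVR::new()
///     .kernel(Kernel::Linear)
///     .c(10.0)
///     .epsilon(0.1);
/// svr.fit(&data)?;
/// let preds = svr.predict(&[vec![2.5]])?;
/// ```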
#[derive(Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct KernelSVR {
kernel: Kernel,
c: f64,
epsilon: f64,
tol: f64,
max_iter: usize,
gamma_strategy: Option<Gamma>,
b: f64,
sv_x: Vec<Vec<f64>>,
sv_coeff: Vec<f64>,
fitted: bool,
#[cfg_attr(feature = "serde", serde(default))]
_schema_version: u32,
}
impl KernelSVR {
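/// Creates a regressor with default hyperparameters: `C = 1.0`,
/// `epsilon = 0.1`, `tol = 1e-3`, `max_iter = 1000`, and the default kernel
/// with `Gamma::Scale` resolution applied at fit time.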
pub fn new() -> Self {
Self {
kernel: Kernel::default(),
c: 1.0,
epsilon: 0.1,
tol: 1e-3,
max_iter: 1000,
gamma_strategy: Some(Gamma::Scale),
b: 0.0,
sv_x: Vec::new(),
sv_coeff: Vec::new(),
fitted: false,
_schema_version: crate::version::SCHEMA_VERSION,
}
}
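/// Sets the kernel. An explicitly supplied kernel takes precedence over
/// automatic gamma resolution; call [`KernelSVR::gamma`] afterwards to
/// re-enable it.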
pub fn kernel(mut self, k: Kernel) -> Self {
// An explicitly provided kernel (including an RBF with an explicit gamma)
// takes precedence over automatic gamma resolution; `gamma()` re-enables it.
self.gamma_strategy = None;
self.kernel = k;
self
}
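/// Sets the regularization parameter `C`; must be finite and positive
/// (validated in [`KernelSVR::fit`]).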
pub fn c(mut self, c: f64) -> Self {
self.c = c;
self
}
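/// Sets the half-width of the epsilon-insensitive tube; residuals within the
/// tube incur no loss.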
pub fn epsilon(mut self, e: f64) -> Self {
self.epsilon = e;
self
}
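/// Sets the numerical tolerance used when checking for KKT violations.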
pub fn tol(mut self, t: f64) -> Self {
self.tol = t;
self
}
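/// Sets the iteration budget: training stops after `max_iter` consecutive
/// passes without a coefficient update, or after `max_iter * n_samples`
/// passes in total.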
pub fn max_iter(mut self, n: usize) -> Self {
self.max_iter = n;
self
}
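/// Sets the strategy for resolving the RBF `gamma` at fit time. Note that
/// supplying a strategy implies an RBF kernel: `fit` replaces the current
/// kernel with `Kernel::RBF` using the resolved value.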
pub fn gamma(mut self, g: Gamma) -> Self {
self.gamma_strategy = Some(g);
self
}
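/// Fits the model on `data` with a simplified SMO-style solver.
///
/// Returns an error if the dataset contains non-finite values, is empty, or
/// if `C` is not finite and positive.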
pub fn fit(&mut self, data: &Dataset) -> Result<()> {
data.validate_finite()?;
let n = data.n_samples();
if n == 0 {
return Err(ScryLearnError::EmptyDataset);
}
if self.c <= 0.0 || !self.c.is_finite() {
return Err(ScryLearnError::InvalidParameter(
"C must be finite and positive".into(),
));
}
if let Some(ref gs) = self.gamma_strategy {
let m = data.n_features();
let rows = data.feature_matrix();
let variance = compute_feature_variance(&rows, m);
let g = gs.resolve(m, variance);
self.kernel = Kernel::RBF { gamma: g };
}
let rows = data.feature_matrix();
let y = &data.target;
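// Precompute the symmetric kernel (Gram) matrix; only the upper triangle is
// evaluated and then mirrored.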
let mut k_matrix = vec![vec![0.0; n]; n];
for i in 0..n {
for j in i..n {
let val = self.kernel.eval(&rows[i], &rows[j]);
k_matrix[i][j] = val;
k_matrix[j][i] = val;
}
}
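// Simplified SMO-style solver: a[i] is the signed coefficient
// (alpha_i - alpha_i*), constrained to [-C, C]; b is the bias term.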
let mut a = vec![0.0_f64; n];
let mut b = 0.0_f64;
let mut passes = 0_usize;
let mut total_iter = 0_usize;
let hard_cap = self.max_iter.saturating_mul(n);
while passes < self.max_iter && total_iter < hard_cap {
let mut num_changed = 0_usize;
total_iter += 1;
for i in 0..n {
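// Residual of sample i; the sample violates the KKT conditions when it lies
// outside the epsilon-tube (beyond tol) and its coefficient still has room
// to move in the correcting direction.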
let f_i = svr_predict_raw(&a, &k_matrix[i], b);
let r_i = f_i - y[i];
let violates_pos = (r_i > self.epsilon + self.tol && a[i] < self.c)
|| (r_i < self.epsilon - self.tol && a[i] > 0.0);
let violates_neg = (-r_i > self.epsilon + self.tol && -a[i] < self.c)
|| (-r_i < self.epsilon - self.tol && -a[i] > 0.0);
if !violates_pos && !violates_neg {
continue;
}
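// Deterministic choice of the second index; the offset rotates with the
// pass counter so that different pairs are visited over time.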
let j = (i + 1 + (passes % n.saturating_sub(1).max(1))) % n;
if j == i {
continue;
}
let f_j = svr_predict_raw(&a, &k_matrix[j], b);
let r_j = f_j - y[j];
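// eta = K(i,i) + K(j,j) - 2*K(i,j) is the curvature along the update
// direction; skip near-degenerate pairs.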
let eta = k_matrix[i][i] + k_matrix[j][j] - 2.0 * k_matrix[i][j];
if eta < 1e-12 {
continue;
}
let a_i_old = a[i];
let a_j_old = a[j];
let delta_i = if r_i > self.epsilon {
r_i - self.epsilon
} else if r_i < -self.epsilon {
r_i + self.epsilon
} else {
0.0
};
let delta_j = if r_j > self.epsilon {
r_j - self.epsilon
} else if r_j < -self.epsilon {
r_j + self.epsilon
} else {
0.0
};
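// Step a[i] by the tube violations scaled by 1/eta, shift a[j] by the
// opposite amount so the pair stays balanced, and clamp both to [-C, C].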
let new_a_i = a[i] - (delta_i - delta_j) / eta;
let new_a_i = new_a_i.clamp(-self.c, self.c);
if (new_a_i - a_i_old).abs() < 1e-8 {
continue;
}
a[i] = new_a_i;
a[j] = a_j_old + (a_i_old - new_a_i);
a[j] = a[j].clamp(-self.c, self.c);
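// Re-estimate the bias from both updated samples and average the two values.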
let b1 =
y[i] - self.epsilon * a[i].signum() - svr_predict_raw_no_b(&a, &k_matrix[i]);
let b2 =
y[j] - self.epsilon * a[j].signum() - svr_predict_raw_no_b(&a, &k_matrix[j]);
b = (b1 + b2) / 2.0;
num_changed += 1;
}
if num_changed == 0 {
passes += 1;
} else {
passes = 0;
}
}
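// Keep only samples with non-negligible coefficients as support vectors.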
self.sv_x = Vec::new();
self.sv_coeff = Vec::new();
for i in 0..n {
if a[i].abs() > 1e-10 {
self.sv_x.push(rows[i].clone());
self.sv_coeff.push(a[i]);
}
}
self.b = b;
self.fitted = true;
Ok(())
}
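/// Predicts targets for `features`, one row (feature vector) per sample.
/// Returns an error if the model has not been fitted.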
pub fn predict(&self, features: &[Vec<f64>]) -> Result<Vec<f64>> {
crate::version::check_schema_version(self._schema_version)?;
if !self.fitted {
return Err(ScryLearnError::NotFitted);
}
Ok(features
.iter()
.map(|x| {
let mut sum = self.b;
for (sv, &coeff) in self.sv_x.iter().zip(self.sv_coeff.iter()) {
sum += coeff * self.kernel.eval(sv, x);
}
sum
})
.collect())
}
}
impl Default for KernelSVR {
fn default() -> Self {
Self::new()
}
}
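/// Decision value for one training sample: `b + sum_i a[i] * K(x_i, x)`,
/// computed from a precomputed kernel row.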
#[inline]
fn svr_predict_raw(a: &[f64], k_row: &[f64], b: f64) -> f64 {
let mut sum = b;
for (&ai, &ki) in a.iter().zip(k_row.iter()) {
sum += ai * ki;
}
sum
}
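/// Same as [`svr_predict_raw`] but without the bias term; used when
/// re-estimating `b` during training.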
#[inline]
fn svr_predict_raw_no_b(a: &[f64], k_row: &[f64]) -> f64 {
let mut sum = 0.0;
for (&ai, &ki) in a.iter().zip(k_row.iter()) {
sum += ai * ki;
}
sum
}
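/// Mean per-feature (population) variance, passed to `Gamma::resolve` when a
/// gamma strategy is set; falls back to 1.0 for degenerate inputs.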
fn compute_feature_variance(rows: &[Vec<f64>], n_features: usize) -> f64 {
let n = rows.len() as f64;
if n <= 1.0 || n_features == 0 {
return 1.0;
}
let mut total_var = 0.0;
for j in 0..n_features {
let mean = rows.iter().map(|r| r[j]).sum::<f64>() / n;
let var = rows.iter().map(|r| (r[j] - mean).powi(2)).sum::<f64>() / n;
total_var += var;
}
total_var / n_features as f64
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_kernel_svr_linear() {
let features = vec![vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]];
let target = vec![2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0];
let data = Dataset::new(features, target, vec!["x".into()], "y");
let mut svr = KernelSVR::new()
.kernel(Kernel::Linear)
.c(100.0)
.epsilon(0.1)
.max_iter(2000);
svr.fit(&data).unwrap();
let preds = svr.predict(&[vec![3.0], vec![5.0]]).unwrap();
assert!(
(preds[0] - 6.0).abs() < 3.0,
"Expected ~6.0, got {}",
preds[0]
);
assert!(
(preds[1] - 10.0).abs() < 3.0,
"Expected ~10.0, got {}",
preds[1]
);
}
#[test]
fn test_kernel_svr_rbf() {
let n = 30;
let x: Vec<f64> = (0..n)
.map(|i| -3.0 + 6.0 * i as f64 / (n - 1) as f64)
.collect();
let y: Vec<f64> = x.iter().map(|&xi| xi * xi).collect();
let data = Dataset::new(vec![x.clone()], y, vec!["x".into()], "y");
let mut svr = KernelSVR::new()
.kernel(Kernel::RBF { gamma: 0.5 })
.c(100.0)
.epsilon(0.1)
.max_iter(2000);
svr.fit(&data).unwrap();
let test_x = vec![vec![0.0], vec![1.0], vec![-1.0]];
let preds = svr.predict(&test_x).unwrap();
assert!(preds[0].abs() < 2.0, "Expected ~0, got {}", preds[0]);
assert!(
(preds[1] - 1.0).abs() < 2.0,
"Expected ~1, got {}",
preds[1]
);
assert!(
(preds[2] - 1.0).abs() < 2.0,
"Expected ~1, got {}",
preds[2]
);
}
#[test]
fn test_kernel_svr_not_fitted() {
let svr = KernelSVR::new();
assert!(svr.predict(&[vec![1.0]]).is_err());
}
#[test]
fn test_kernel_svr_invalid_c() {
let features = vec![vec![1.0]];
let target = vec![0.0];
let data = Dataset::new(features, target, vec!["x".into()], "y");
let mut svr = KernelSVR::new().c(-1.0);
assert!(svr.fit(&data).is_err());
}
#[test]
fn test_kernel_svr_gamma_auto() {
let features = vec![vec![1.0, 2.0, 3.0, 4.0, 5.0], vec![2.0, 3.0, 4.0, 5.0, 6.0]];
let target = vec![2.0, 4.0, 6.0, 8.0, 10.0];
let data = Dataset::new(features, target, vec!["a".into(), "b".into()], "y");
let mut svr = KernelSVR::new().gamma(Gamma::Auto).c(10.0);
svr.fit(&data).unwrap();
match &svr.kernel {
Kernel::RBF { gamma } => {
assert!(
(*gamma - 0.5).abs() < 1e-10,
"Gamma::Auto should give 1/n_features=0.5, got {gamma}",
);
}
other => panic!("expected RBF kernel, got {:?}", other),
}
}
}