//! Kernel Recursive Least Squares (KRLS) with Approximate Linear
//! Dependency (ALD) sparsification, a hard dictionary budget, and an
//! optional forgetting factor for non-stationary streams.

use std::fmt;
use crate::learner::StreamingLearner;
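/// A Mercer kernel k(a, b): a symmetric similarity measure between two
/// feature vectors of equal length.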
pub trait Kernel: Send + Sync {
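/// Evaluates k(a, b). Both slices must have the same length.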
fn eval(&self, a: &[f64], b: &[f64]) -> f64;
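/// Short human-readable name, used by the `Debug` impl of [`KRLS`].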
fn name(&self) -> &str {
"Kernel"
}
}
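/// Gaussian RBF kernel: k(a, b) = exp(-gamma * ||a - b||^2). Larger
/// `gamma` makes the kernel more local, so fits become more flexible
/// but generalize over a shorter range.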
#[derive(Clone, Debug)]
pub struct RBFKernel {
pub gamma: f64,
}
impl RBFKernel {
pub fn new(gamma: f64) -> Self {
assert!(gamma > 0.0, "RBF gamma must be > 0, got {gamma}");
Self { gamma }
}
}
impl Kernel for RBFKernel {
#[inline]
fn eval(&self, a: &[f64], b: &[f64]) -> f64 {
debug_assert_eq!(a.len(), b.len());
let sq_dist: f64 = a
.iter()
.zip(b.iter())
.map(|(ai, bi)| (ai - bi).powi(2))
.sum();
(-self.gamma * sq_dist).exp()
}
fn name(&self) -> &str {
"RBF"
}
}
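/// Polynomial kernel: k(a, b) = (a · b + coef0)^degree. `coef0` trades
/// off the influence of higher-order versus lower-order terms.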
#[derive(Clone, Debug)]
pub struct PolynomialKernel {
pub degree: usize,
pub coef0: f64,
}
impl PolynomialKernel {
pub fn new(degree: usize, coef0: f64) -> Self {
assert!(degree >= 1, "polynomial degree must be >= 1, got {degree}");
Self { degree, coef0 }
}
}
impl Kernel for PolynomialKernel {
#[inline]
fn eval(&self, a: &[f64], b: &[f64]) -> f64 {
debug_assert_eq!(a.len(), b.len());
let dot: f64 = a.iter().zip(b.iter()).map(|(ai, bi)| ai * bi).sum();
(dot + self.coef0).powi(self.degree as i32)
}
fn name(&self) -> &str {
"Polynomial"
}
}
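/// Linear kernel: k(a, b) = a · b. With this kernel, KRLS reduces to
/// ordinary recursive least squares in dual form.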
#[derive(Clone, Debug, Default)]
pub struct LinearKernel;
impl Kernel for LinearKernel {
#[inline]
fn eval(&self, a: &[f64], b: &[f64]) -> f64 {
debug_assert_eq!(a.len(), b.len());
a.iter().zip(b.iter()).map(|(ai, bi)| ai * bi).sum()
}
fn name(&self) -> &str {
"Linear"
}
}
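/// Kernel Recursive Least Squares in the style of Engel, Mannor & Meir
/// (2004), with Approximate Linear Dependency (ALD) sparsification, a
/// hard dictionary budget, and an optional forgetting factor.
///
/// The model predicts f(x) = sum_i weights[i] * k(x, dictionary[i]).
/// `p_matrix` holds the inverse Gram matrix of the dictionary in
/// row-major order. A sample joins the dictionary only when its ALD
/// residual exceeds `ald_threshold` and the dictionary is under
/// `budget`; otherwise only the weights are refreshed with an
/// exponentially weighted RLS step.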
pub struct KRLS {
kernel: Box<dyn Kernel>,
dictionary: Vec<Vec<f64>>,
weights: Vec<f64>,
p_matrix: Vec<f64>,
budget: usize,
ald_threshold: f64,
forgetting_factor: f64,
samples_seen: u64,
}
impl KRLS {
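/// Creates a KRLS learner with no forgetting (forgetting factor fixed at 1.0).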
pub fn new(kernel: Box<dyn Kernel>, budget: usize, ald_threshold: f64) -> Self {
Self::with_forgetting(kernel, budget, ald_threshold, 1.0)
}
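/// Creates a KRLS learner with an explicit forgetting factor in (0, 1].
/// Values below 1.0 exponentially down-weight old samples, letting the
/// model track drifting targets at the cost of higher variance.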
pub fn with_forgetting(
kernel: Box<dyn Kernel>,
budget: usize,
ald_threshold: f64,
forgetting_factor: f64,
) -> Self {
assert!(budget > 0, "KRLS budget must be > 0, got {budget}");
assert!(
ald_threshold > 0.0,
"ALD threshold must be > 0, got {ald_threshold}"
);
// Zero is excluded: `update_weights_only` divides by the forgetting factor.
assert!(
forgetting_factor > 0.0 && forgetting_factor <= 1.0,
"forgetting_factor must be in (0, 1], got {forgetting_factor}"
);
Self {
kernel,
dictionary: Vec::new(),
weights: Vec::new(),
p_matrix: Vec::new(),
budget,
ald_threshold,
forgetting_factor,
samples_seen: 0,
}
}
#[inline]
pub fn dict_size(&self) -> usize {
self.dictionary.len()
}
#[inline]
pub fn budget(&self) -> usize {
self.budget
}
#[inline]
pub fn ald_threshold(&self) -> f64 {
self.ald_threshold
}
#[inline]
pub fn forgetting_factor(&self) -> f64 {
self.forgetting_factor
}
pub fn dictionary(&self) -> &[Vec<f64>] {
&self.dictionary
}
pub fn weights(&self) -> &[f64] {
&self.weights
}
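/// Evaluates the kernel between `x` and every dictionary entry.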
fn kernel_vector(&self, x: &[f64]) -> Vec<f64> {
self.dictionary
.iter()
.map(|di| self.kernel.eval(x, di))
.collect()
}
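/// Row-major dense matrix-vector product `p_matrix * v`.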
fn p_times_vec(&self, v: &[f64]) -> Vec<f64> {
let n = self.dictionary.len();
let mut result = vec![0.0; n];
for (i, ri) in result.iter_mut().enumerate() {
let row_start = i * n;
for (j, &vj) in v.iter().enumerate() {
*ri += self.p_matrix[row_start + j] * vj;
}
}
result
}
#[inline]
fn dot(a: &[f64], b: &[f64]) -> f64 {
irithyll_core::simd::simd_dot(a, b)
}
#[inline]
fn at_budget(&self) -> bool {
self.dictionary.len() >= self.budget
}
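/// Adds `x` to the dictionary: grows `p_matrix` (the inverse Gram
/// matrix) with the block-matrix inversion identity and updates the
/// weights so that `weights = K^{-1} y` continues to hold exactly.
/// `delta` is the ALD residual already computed by the caller.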
fn add_to_dictionary(&mut self, x: Vec<f64>, k_t: &[f64], delta: f64, target: f64) {
let n = self.dictionary.len();
let pred = Self::dot(&self.weights, k_t);
let error = target - pred;
let a_t = if n > 0 {
self.p_times_vec(k_t)
} else {
Vec::new()
};
let new_n = n + 1;
let mut new_p = vec![0.0; new_n * new_n];
let inv_delta = 1.0 / delta;
for i in 0..n {
for j in 0..n {
new_p[i * new_n + j] = self.p_matrix[i * n + j] + inv_delta * a_t[i] * a_t[j];
}
}
for i in 0..n {
new_p[i * new_n + n] = -inv_delta * a_t[i];
new_p[n * new_n + i] = -inv_delta * a_t[i];
}
new_p[n * new_n + n] = inv_delta;
self.p_matrix = new_p;
for (wi, &ai) in self.weights.iter_mut().zip(a_t.iter()) {
*wi -= inv_delta * ai * error;
}
self.weights.push(inv_delta * error);
self.dictionary.push(x);
}
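/// Exponentially weighted RLS step on the fixed dictionary basis: the
/// sample's kernel vector acts as the regressor and `forgetting_factor`
/// discounts past data in both the gain and the `p_matrix` update.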
fn update_weights_only(&mut self, k_t: &[f64], target: f64) {
let n = self.dictionary.len();
if n == 0 {
return;
}
let a_t = self.p_times_vec(k_t);
let pred = Self::dot(&self.weights, k_t);
let error = target - pred;
let denom = self.forgetting_factor + Self::dot(k_t, &a_t);
let inv_denom = 1.0 / denom;
for (wi, &ai) in self.weights.iter_mut().zip(a_t.iter()) {
*wi += ai * inv_denom * error;
}
let inv_lambda = 1.0 / self.forgetting_factor;
for (i, &a_i) in a_t.iter().enumerate() {
let qi = a_i * inv_denom;
let row_start = i * n;
for (j, &a_j) in a_t.iter().enumerate() {
self.p_matrix[row_start + j] =
(self.p_matrix[row_start + j] - qi * a_j) * inv_lambda;
}
}
}
}
impl StreamingLearner for KRLS {
fn train_one(&mut self, features: &[f64], target: f64, _weight: f64) {
self.samples_seen += 1;
// k(x, x), clamped away from zero so the first inverse stays finite.
let k_tt = self.kernel.eval(features, features).max(1e-15);
let n = self.dictionary.len();
if n == 0 {
// First sample: dictionary = {x}, weights = K^{-1} y, P = K^{-1} = [1 / k(x, x)].
self.dictionary.push(features.to_vec());
self.weights.push(target / k_tt);
self.p_matrix.push(1.0 / k_tt);
return;
}
let k_t = self.kernel_vector(features);
let p_k = self.p_times_vec(&k_t);
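// ALD novelty test: delta = k(x, x) - k_t^T P k_t is the squared
// residual of projecting phi(x) onto the span of the dictionary.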
let delta = k_tt - Self::dot(&k_t, &p_k);
if delta > self.ald_threshold && !self.at_budget() {
self.add_to_dictionary(features.to_vec(), &k_t, delta, target);
} else {
self.update_weights_only(&k_t, target);
}
}
fn predict(&self, features: &[f64]) -> f64 {
if self.dictionary.is_empty() {
return 0.0;
}
let k_t = self.kernel_vector(features);
Self::dot(&self.weights, &k_t)
}
#[inline]
fn n_samples_seen(&self) -> u64 {
self.samples_seen
}
fn reset(&mut self) {
self.dictionary.clear();
self.weights.clear();
self.p_matrix.clear();
self.samples_seen = 0;
}
#[allow(deprecated)]
fn diagnostics_array(&self) -> [f64; 5] {
<Self as crate::learner::Tunable>::diagnostics_array(self)
}
#[allow(deprecated)]
fn adjust_config(&mut self, lr_multiplier: f64, lambda_delta: f64) {
<Self as crate::learner::Tunable>::adjust_config(self, lr_multiplier, lambda_delta);
}
}
impl crate::learner::Tunable for KRLS {
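// Diagnostics layout (slots 0 and 2 are unused by KRLS):
// [0, 1 - forgetting_factor, 0, dictionary size, dictionary fill ratio].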
fn diagnostics_array(&self) -> [f64; 5] {
let budget = self.budget.max(1) as f64;
[
0.0,
1.0 - self.forgetting_factor,
0.0,
self.dictionary.len() as f64,
self.dictionary.len() as f64 / budget,
]
}
fn adjust_config(&mut self, lr_multiplier: f64, _lambda_delta: f64) {
self.forgetting_factor = (self.forgetting_factor * lr_multiplier).clamp(1e-6, 1.0);
}
}
impl fmt::Debug for KRLS {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("KRLS")
.field("kernel", &self.kernel.name())
.field("dict_size", &self.dictionary.len())
.field("budget", &self.budget)
.field("ald_threshold", &self.ald_threshold)
.field("forgetting_factor", &self.forgetting_factor)
.field("samples_seen", &self.samples_seen)
.finish()
}
}
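// The dictionary size doubles as the effective degrees of freedom, and
// the dictionary fill ratio as a rough uncertainty proxy.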
impl crate::automl::DiagnosticSource for KRLS {
fn config_diagnostics(&self) -> Option<crate::automl::ConfigDiagnostics> {
let budget = self.budget().max(1) as f64;
Some(crate::automl::ConfigDiagnostics {
effective_dof: self.dict_size() as f64,
regularization_sensitivity: 1.0 - self.forgetting_factor(),
uncertainty: self.dict_size() as f64 / budget,
..Default::default()
})
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::learner::StreamingLearner;
#[test]
fn construction_and_initial_state() {
let krls = KRLS::new(Box::new(RBFKernel::new(1.0)), 50, 1e-4);
assert_eq!(krls.dict_size(), 0);
assert_eq!(krls.budget(), 50);
assert_eq!(krls.n_samples_seen(), 0);
assert!((krls.predict(&[1.0]) - 0.0).abs() < 1e-15);
}
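// Direct sanity checks on the kernels: RBF evaluates to 1 at zero
// distance, and all three kernels are symmetric in their arguments.
#[test]
fn kernels_are_symmetric() {
let a = [1.0, 2.0];
let b = [0.5, -1.0];
let rbf = RBFKernel::new(0.7);
assert!((rbf.eval(&a, &a) - 1.0).abs() < 1e-12);
assert!((rbf.eval(&a, &b) - rbf.eval(&b, &a)).abs() < 1e-12);
let poly = PolynomialKernel::new(3, 1.0);
assert!((poly.eval(&a, &b) - poly.eval(&b, &a)).abs() < 1e-12);
assert!((LinearKernel.eval(&a, &b) - LinearKernel.eval(&b, &a)).abs() < 1e-12);
}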
#[test]
fn learn_sine_rbf() {
let mut krls = KRLS::new(Box::new(RBFKernel::new(0.5)), 100, 1e-4);
for i in 0..500 {
let x = i as f64 * 0.02;
krls.train(&[x], x.sin());
}
let test_points = [1.0, 2.0, 3.0, 5.0];
let mut max_err = 0.0_f64;
for &x in &test_points {
let pred = krls.predict(&[x]);
let err = (pred - x.sin()).abs();
max_err = max_err.max(err);
}
assert!(
max_err < 1.0,
"KRLS should learn sin(x), max error = {}",
max_err
);
}
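// When every sample passes the ALD test, the growth step keeps
// weights = K^{-1} y exactly, so KRLS should interpolate its
// dictionary points up to floating-point error.
#[test]
fn growth_step_interpolates_dictionary_points() {
let mut krls = KRLS::new(Box::new(RBFKernel::new(1.0)), 10, 1e-4);
let xs = [0.0, 1.0, 2.0];
let ys = [0.5, -1.0, 2.0];
for (&x, &y) in xs.iter().zip(ys.iter()) {
krls.train(&[x], y);
}
assert_eq!(krls.dict_size(), 3);
for (&x, &y) in xs.iter().zip(ys.iter()) {
assert!((krls.predict(&[x]) - y).abs() < 1e-8);
}
}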
#[test]
fn dictionary_respects_budget() {
let budget = 20;
let mut krls = KRLS::new(Box::new(RBFKernel::new(0.5)), budget, 1e-6);
for i in 0..200 {
let x = i as f64 * 0.1;
krls.train(&[x], x.sin());
}
assert!(
krls.dict_size() <= budget,
"dict_size={} exceeds budget={}",
krls.dict_size(),
budget
);
}
#[test]
fn ald_sparsifies_dictionary() {
let mut krls = KRLS::new(Box::new(RBFKernel::new(1.0)), 500, 0.01);
for i in 0..200 {
let x = i as f64 * 0.05;
krls.train(&[x], x.sin());
}
assert!(
krls.dict_size() < 200,
"ALD should sparsify: dict_size={}, samples=200",
krls.dict_size()
);
}
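// An exact repeat of a dictionary point has ALD residual zero, so it
// should trigger a weight-only update rather than dictionary growth.
#[test]
fn duplicate_sample_does_not_grow_dictionary() {
let mut krls = KRLS::new(Box::new(RBFKernel::new(1.0)), 50, 1e-4);
krls.train(&[1.0, 2.0], 3.0);
krls.train(&[1.0, 2.0], 3.0);
assert_eq!(krls.dict_size(), 1);
assert_eq!(krls.n_samples_seen(), 2);
}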
#[test]
fn forgetting_adapts_to_shift() {
let mut krls = KRLS::with_forgetting(Box::new(RBFKernel::new(1.0)), 50, 1e-4, 0.98);
for i in 0..200 {
let x = i as f64 * 0.01;
krls.train(&[x], x);
}
for i in 0..200 {
let x = i as f64 * 0.01;
krls.train(&[x], -x);
}
let pred = krls.predict(&[1.0]);
assert!(
pred < 0.5,
"forgetting KRLS should adapt to shift, pred at 1.0 = {}",
pred
);
}
#[test]
fn reset_clears_all_state() {
let mut krls = KRLS::new(Box::new(RBFKernel::new(1.0)), 50, 1e-4);
krls.train(&[1.0], 1.0);
krls.train(&[2.0], 4.0);
assert!(krls.dict_size() > 0);
assert_eq!(krls.n_samples_seen(), 2);
krls.reset();
assert_eq!(krls.dict_size(), 0);
assert_eq!(krls.n_samples_seen(), 0);
assert!(krls.weights().is_empty());
}
#[test]
fn trait_object_works() {
let krls = KRLS::new(Box::new(RBFKernel::new(1.0)), 50, 1e-4);
let mut boxed: Box<dyn StreamingLearner> = Box::new(krls);
boxed.train(&[1.0], 2.0);
boxed.train(&[2.0], 4.0);
assert_eq!(boxed.n_samples_seen(), 2);
let pred = boxed.predict(&[1.5]);
assert!(pred.is_finite());
boxed.reset();
assert_eq!(boxed.n_samples_seen(), 0);
}
#[test]
fn polynomial_kernel_works() {
let mut krls = KRLS::new(Box::new(PolynomialKernel::new(2, 1.0)), 100, 1e-4);
for i in 0..300 {
let x = (i as f64 - 150.0) * 0.02;
krls.train(&[x], x * x);
}
let pred = krls.predict(&[2.0]);
assert!(
(pred - 4.0).abs() < 2.0,
"poly kernel should approximate x^2, pred(2.0) = {}",
pred
);
}
#[test]
fn linear_kernel_matches_linear() {
let mut krls = KRLS::new(Box::new(LinearKernel), 100, 1e-4);
for i in 0..200 {
let x = i as f64 * 0.05;
krls.train(&[x], 3.0 * x);
}
let pred = krls.predict(&[5.0]);
assert!(
(pred - 15.0).abs() < 2.0,
"linear kernel KRLS should learn y=3x, pred(5.0) = {}",
pred
);
}
#[test]
fn debug_format_works() {
let krls = KRLS::new(Box::new(RBFKernel::new(1.0)), 50, 1e-4);
let debug = format!("{:?}", krls);
assert!(debug.contains("KRLS"));
assert!(debug.contains("RBF"));
}
}