use super::bspline;
use irithyll_core::math::{silu, silu_derivative};
use irithyll_core::rng::standard_normal;
// A single Kolmogorov–Arnold Network layer: every (output, input) edge
// carries its own learnable 1-D spline plus a weighted SiLU base branch.
// allow(dead_code): not every build configuration constructs this type yet.
#[allow(dead_code)] pub(crate) struct KANLayer {
    // Per-edge B-spline coefficients, edge-major layout:
    // index = (j * n_in + i) * n_coeffs + c for edge (output j, input i).
    coefficients: Vec<f64>,
    // Momentum buffer for coefficient updates; same layout as `coefficients`.
    velocity: Vec<f64>,
    // RMSProp running mean of squared gradients; same layout as
    // `coefficients`, initialized to 1.0 so early adaptive steps stay ~lr.
    grad_sq_accum: Vec<f64>,
    // Base-branch weight per edge (scales silu(x)); length n_out * n_in.
    w_b: Vec<f64>,
    // Spline-branch weight per edge; length n_out * n_in.
    w_s: Vec<f64>,
    // Knot vector shared by all edges, built over the domain [-1, 1].
    grid: Vec<f64>,
    // Number of layer inputs.
    n_in: usize,
    // Number of layer outputs.
    n_out: usize,
    // Spline order parameter passed to the bspline helpers.
    k: usize,
    // Number of grid intervals passed to the bspline helpers.
    g: usize,
    // Coefficients per edge: g + k.
    n_coeffs: usize,
    // Momentum factor for the velocity update in `backward`.
    momentum: f64,
}
impl KANLayer {
/// Builds a layer with `n_out * n_in` edges, each carrying `g + k`
/// spline coefficients drawn from N(0, 0.1²).
///
/// Optimizer state starts neutral (velocity zero, RMSProp accumulator
/// one), base weights at `1 / n_in`, and spline weights at one.
/// `seed` is advanced once per coefficient drawn.
pub fn new(
    n_in: usize,
    n_out: usize,
    k: usize,
    g: usize,
    momentum: f64,
    seed: &mut u64,
) -> Self {
    let n_coeffs = g + k;
    let n_edges = n_out * n_in;
    let total = n_edges * n_coeffs;
    // Small random init keeps the initial spline contribution modest.
    let coefficients: Vec<f64> = (0..total)
        .map(|_| standard_normal(seed) * 0.1)
        .collect();
    Self {
        coefficients,
        velocity: vec![0.0; total],
        // Starting at 1.0 keeps the first adaptive step close to `lr`.
        grad_sq_accum: vec![1.0; total],
        w_b: vec![1.0 / n_in as f64; n_edges],
        w_s: vec![1.0; n_edges],
        grid: bspline::make_grid(-1.0, 1.0, g, k),
        n_in,
        n_out,
        k,
        g,
        n_coeffs,
        momentum,
    }
}
pub fn forward(&self, input: &[f64]) -> Vec<f64> {
debug_assert_eq!(input.len(), self.n_in, "input length mismatch");
let mut output = vec![0.0; self.n_out];
for (j, out_j) in output.iter_mut().enumerate() {
for (i, &x_raw) in input.iter().enumerate() {
let x = x_raw.clamp(-1.0, 1.0);
let edge = j * self.n_in + i;
let coeff_base = edge * self.n_coeffs;
let (span, bases) = bspline::evaluate_basis(x, &self.grid, self.g, self.k);
let basis_start = span - self.k; let mut spline_val = 0.0;
for (b, &basis_val) in bases.iter().enumerate() {
let coeff_idx = basis_start + b;
if coeff_idx < self.n_coeffs {
spline_val += self.coefficients[coeff_base + coeff_idx] * basis_val;
}
}
let phi = self.w_b[edge] * silu(x) + self.w_s[edge] * spline_val;
*out_j += phi;
}
}
for val in &mut output {
*val = val.clamp(-1e6, 1e6);
}
output
}
/// Backpropagates `output_grad` through the layer and returns the
/// gradient with respect to `input`.
///
/// Spline coefficients are updated with RMSProp-scaled momentum;
/// `w_b`/`w_s` get plain SGD steps clamped to ±UPDATE_CLIP. Incoming
/// deltas are clamped to ±GRAD_CLIP and non-finite values are skipped
/// so one bad sample cannot poison the parameters. Note the update
/// order within each edge: coefficients first (using the pre-update
/// `w_s`), then `w_b`/`w_s`, and finally the input gradient, which
/// therefore sees the freshly-updated weights but the pre-update
/// spline derivative.
pub fn backward(&mut self, input: &[f64], output_grad: &[f64], lr: f64) -> Vec<f64> {
    debug_assert_eq!(input.len(), self.n_in, "input length mismatch");
    debug_assert_eq!(output_grad.len(), self.n_out, "output_grad length mismatch");
    let mut input_grad = vec![0.0; self.n_in];
    const GRAD_CLIP: f64 = 10.0;
    const UPDATE_CLIP: f64 = 0.5;
    const RMSPROP_BETA: f64 = 0.9;
    const RMSPROP_EPS: f64 = 1e-8;
    for (j, &delta_j_raw) in output_grad.iter().enumerate() {
        // Clamp first; a clamped NaN stays NaN, so the finite check below
        // still catches it and skips the whole output row.
        let delta_j = delta_j_raw.clamp(-GRAD_CLIP, GRAD_CLIP);
        if !delta_j.is_finite() {
            continue;
        }
        for (i, &x_raw) in input.iter().enumerate() {
            // Same domain clamp as forward() so gradients match the
            // activations actually used.
            let x = x_raw.clamp(-1.0, 1.0);
            let edge = j * self.n_in + i;
            let coeff_base = edge * self.n_coeffs;
            let (span, bases) = bspline::evaluate_basis(x, &self.grid, self.g, self.k);
            let (_, derivs) =
                bspline::evaluate_basis_derivatives(x, &self.grid, self.g, self.k);
            // NOTE(review): assumes span >= k (see forward()); usize
            // underflow here would panic.
            let basis_start = span - self.k;
            // Reconstruct the edge's spline value and derivative from the
            // pre-update coefficients; both are needed below.
            let mut spline_val = 0.0;
            let mut spline_deriv = 0.0;
            for (b, (&basis_val, &deriv_val)) in bases.iter().zip(derivs.iter()).enumerate() {
                let coeff_idx = basis_start + b;
                if coeff_idx < self.n_coeffs {
                    let c = self.coefficients[coeff_base + coeff_idx];
                    spline_val += c * basis_val;
                    spline_deriv += c * deriv_val;
                }
            }
            // dL/dc = delta_j * w_s * basis — uses w_s before its own update.
            let coeff_grad_base = delta_j * self.w_s[edge];
            if coeff_grad_base.is_finite() {
                for (b, &basis_val) in bases.iter().enumerate() {
                    let coeff_idx = basis_start + b;
                    if coeff_idx < self.n_coeffs {
                        let grad = coeff_grad_base * basis_val;
                        if grad.is_finite() {
                            let vi = coeff_base + coeff_idx;
                            // RMSProp second-moment EMA...
                            self.grad_sq_accum[vi] = RMSPROP_BETA * self.grad_sq_accum[vi]
                                + (1.0 - RMSPROP_BETA) * grad * grad;
                            // ...scales the step size per coefficient...
                            let adaptive_lr =
                                lr / (self.grad_sq_accum[vi].sqrt() + RMSPROP_EPS);
                            // ...and momentum folds in the previous velocity.
                            let update = self.momentum * self.velocity[vi] + adaptive_lr * grad;
                            self.velocity[vi] = update;
                            self.coefficients[vi] -= update;
                        }
                    }
                }
            }
            // Plain SGD on the base-branch weight: dL/dw_b = delta * silu(x).
            let wb_grad = delta_j * silu(x);
            if wb_grad.is_finite() {
                let wb_update = (lr * wb_grad).clamp(-UPDATE_CLIP, UPDATE_CLIP);
                self.w_b[edge] -= wb_update;
            }
            // Plain SGD on the spline-branch weight: dL/dw_s = delta * spline.
            let ws_grad = delta_j * spline_val;
            if ws_grad.is_finite() {
                let ws_update = (lr * ws_grad).clamp(-UPDATE_CLIP, UPDATE_CLIP);
                self.w_s[edge] -= ws_update;
            }
            // Chain rule through the edge; reads the just-updated weights.
            let dphi_dx = self.w_b[edge] * silu_derivative(x) + self.w_s[edge] * spline_deriv;
            input_grad[i] += delta_j * dphi_dx;
        }
    }
    // Clip the outgoing gradient the same way incoming deltas were clipped.
    for g in &mut input_grad {
        *g = g.clamp(-GRAD_CLIP, GRAD_CLIP);
    }
    input_grad
}
/// Restores the layer to its freshly-constructed state: coefficients
/// are redrawn from N(0, 0.1²), optimizer state is cleared, and the
/// edge weights return to their construction-time defaults.
pub fn reset(&mut self, seed: &mut u64) {
    for c in self.coefficients.iter_mut() {
        *c = standard_normal(seed) * 0.1;
    }
    self.velocity.fill(0.0);
    // 1.0 matches the constructor so early adaptive steps stay ~lr.
    self.grad_sq_accum.fill(1.0);
    self.w_b.fill(1.0 / self.n_in as f64);
    self.w_s.fill(1.0);
}
/// Re-randomizes every edge feeding output node `j` with a
/// Xavier-style scale and resets that node's optimizer state and edge
/// weights to their defaults.
///
/// # Panics
/// Panics when `j >= n_out`.
pub fn reinitialize_output_node(&mut self, j: usize, rng: &mut u64) {
    assert!(
        j < self.n_out,
        "output node index {} out of range (n_out={})",
        j,
        self.n_out
    );
    // sqrt(2 / fan_in + fan_out): larger than the constructor's 0.1 so a
    // revived node starts with a meaningful signal.
    let scale = (2.0 / (self.n_in + self.n_out) as f64).sqrt();
    let w_b_default = 1.0 / self.n_in as f64;
    for i in 0..self.n_in {
        let edge = j * self.n_in + i;
        let lo = edge * self.n_coeffs;
        let hi = lo + self.n_coeffs;
        for c in &mut self.coefficients[lo..hi] {
            *c = standard_normal(rng) * scale;
        }
        self.velocity[lo..hi].fill(0.0);
        self.grad_sq_accum[lo..hi].fill(1.0);
        self.w_b[edge] = w_b_default;
        self.w_s[edge] = 1.0;
    }
}
/// Re-randomizes every edge fed by input node `i` with a Xavier-style
/// scale and resets those edges' optimizer state and weights to their
/// defaults. Mirror of `reinitialize_output_node` along the other axis.
///
/// # Panics
/// Panics when `i >= n_in`.
pub fn reinitialize_input_node(&mut self, i: usize, rng: &mut u64) {
    assert!(
        i < self.n_in,
        "input node index {} out of range (n_in={})",
        i,
        self.n_in
    );
    let scale = (2.0 / (self.n_in + self.n_out) as f64).sqrt();
    let w_b_default = 1.0 / self.n_in as f64;
    for j in 0..self.n_out {
        let edge = j * self.n_in + i;
        let lo = edge * self.n_coeffs;
        let hi = lo + self.n_coeffs;
        for c in &mut self.coefficients[lo..hi] {
            *c = standard_normal(rng) * scale;
        }
        self.velocity[lo..hi].fill(0.0);
        self.grad_sq_accum[lo..hi].fill(1.0);
        self.w_b[edge] = w_b_default;
        self.w_s[edge] = 1.0;
    }
}
/// Number of inputs this layer accepts.
#[allow(dead_code)] // kept for API symmetry with n_out()
pub fn n_in(&self) -> usize {
    self.n_in
}
/// Number of outputs this layer produces.
#[allow(dead_code)] // kept for API symmetry with n_in()
pub fn n_out(&self) -> usize {
    self.n_out
}
/// Total trainable parameter count: each of the `n_out * n_in` edges
/// owns `n_coeffs` spline coefficients plus its two weights (w_b, w_s).
pub fn n_params(&self) -> usize {
    self.n_out * self.n_in * (self.n_coeffs + 2)
}
/// Read-only view of the spline coefficients (edge-major layout).
#[allow(dead_code)] // diagnostic accessor; not used by all builds
#[inline]
pub fn coefficients(&self) -> &[f64] {
    &self.coefficients
}
/// Mutable view of the spline coefficients (edge-major layout).
/// Writes bypass the optimizer state, which is left untouched.
#[inline]
pub fn coefficients_mut(&mut self) -> &mut [f64] {
    &mut self.coefficients
}
/// Returns `(dead, total_edges)` where an edge counts as dead when every
/// entry of its momentum buffer is below `threshold` in magnitude.
///
/// NOTE(review): "dead" is judged from the velocity buffer rather than
/// raw gradients — an edge qualifies once its momentum has decayed under
/// the threshold for all of its coefficients.
pub fn count_dead_edges(&self, threshold: f64) -> (usize, usize) {
    let n_edges = self.n_out * self.n_in;
    let dead = (0..n_edges)
        .filter(|&edge| {
            let lo = edge * self.n_coeffs;
            self.velocity[lo..lo + self.n_coeffs]
                .iter()
                .all(|v| v.abs() < threshold)
        })
        .count();
    (dead, n_edges)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // Fixed seed so every test run is deterministic.
    fn make_seed() -> u64 {
        0xDEAD_BEEF_CAFE_BABEu64
    }

    // forward() must return exactly n_out values.
    #[test]
    fn forward_dimensions() {
        let mut seed = make_seed();
        let layer = KANLayer::new(4, 3, 3, 5, 0.9, &mut seed);
        let input = vec![0.1, -0.5, 0.3, 0.8];
        let output = layer.forward(&input);
        assert_eq!(
            output.len(),
            3,
            "output should have n_out=3 elements, got {}",
            output.len()
        );
    }

    // A freshly initialized layer must not emit NaN or infinity.
    #[test]
    fn forward_finite() {
        let mut seed = make_seed();
        let layer = KANLayer::new(4, 3, 3, 5, 0.9, &mut seed);
        let input = vec![0.1, -0.5, 0.3, 0.8];
        let output = layer.forward(&input);
        for (idx, &val) in output.iter().enumerate() {
            assert!(val.is_finite(), "output[{}] is not finite: {}", idx, val);
        }
    }

    // One backward pass with nonzero gradients must move at least one
    // spline coefficient.
    #[test]
    fn backward_updates_coefficients() {
        let mut seed = make_seed();
        let mut layer = KANLayer::new(2, 2, 3, 5, 0.9, &mut seed);
        let input = vec![0.3, -0.7];
        let output_grad = vec![1.0, -0.5];
        let coeffs_before: Vec<f64> = layer.coefficients.clone();
        let _input_grad = layer.backward(&input, &output_grad, 0.01);
        let changed = coeffs_before
            .iter()
            .zip(layer.coefficients.iter())
            .filter(|(&a, &b)| (a - b).abs() > 1e-15)
            .count();
        assert!(
            changed > 0,
            "no coefficients were updated during backward pass"
        );
    }

    // B-spline locality: a single sample activates at most k+1 basis
    // functions per edge, so one backward pass may change at most
    // n_edges * (k+1) coefficients.
    #[test]
    fn backward_sparse() {
        let n_in = 3;
        let n_out = 2;
        let k = 3;
        let g = 5;
        let mut seed = make_seed();
        let mut layer = KANLayer::new(n_in, n_out, k, g, 0.9, &mut seed);
        let input = vec![0.2, -0.4, 0.6];
        let output_grad = vec![1.0, 1.0];
        let coeffs_before: Vec<f64> = layer.coefficients.clone();
        let _input_grad = layer.backward(&input, &output_grad, 0.01);
        let changed = coeffs_before
            .iter()
            .zip(layer.coefficients.iter())
            .filter(|(&a, &b)| (a - b).abs() > 1e-15)
            .count();
        let n_edges = n_out * n_in;
        let max_changed = n_edges * (k + 1);
        assert!(
            changed <= max_changed,
            "too many coefficients changed: {} > max {} (n_edges={} * (k+1)={})",
            changed,
            max_changed,
            n_edges,
            k + 1
        );
        // Guard against the vacuous case where nothing changed at all.
        assert!(changed > 0, "no coefficients changed — sparsity check moot");
    }

    // n_params must equal n_out * n_in * (g + k + 2): spline coefficients
    // plus w_b and w_s per edge.
    #[test]
    fn n_params_formula() {
        let n_in = 4;
        let n_out = 3;
        let k = 3;
        let g = 5;
        let mut seed = make_seed();
        let layer = KANLayer::new(n_in, n_out, k, g, 0.9, &mut seed);
        let n_coeffs = g + k;
        let expected = n_out * n_in * (n_coeffs + 2);
        assert_eq!(
            layer.n_params(),
            expected,
            "n_params should be {} but got {}",
            expected,
            layer.n_params()
        );
    }

    // End-to-end training sanity check: SGD on y = x^2 over [-1, 1] must
    // cut the MSE at least in half within 200 epochs.
    #[test]
    fn learning_y_equals_x_squared() {
        let mut seed = make_seed();
        let mut layer = KANLayer::new(1, 1, 3, 5, 0.9, &mut seed);
        let lr = 0.01;
        let n_samples = 20;
        // Evenly spaced sample points across the layer's domain [-1, 1].
        let mut xs: Vec<f64> = Vec::with_capacity(n_samples);
        for i in 0..n_samples {
            xs.push(-1.0 + 2.0 * i as f64 / (n_samples - 1) as f64);
        }
        let compute_mse = |layer: &KANLayer| -> f64 {
            let mut mse = 0.0;
            for &x in &xs {
                let y_pred = layer.forward(&[x])[0];
                let y_true = x * x;
                mse += (y_pred - y_true).powi(2);
            }
            mse / xs.len() as f64
        };
        let initial_error = compute_mse(&layer);
        for _ in 0..200 {
            for &x in &xs {
                let y_true = x * x;
                let y_pred = layer.forward(&[x])[0];
                // dL/dy_pred for squared error.
                let grad = 2.0 * (y_pred - y_true);
                let _input_grad = layer.backward(&[x], &[grad], lr);
            }
        }
        let final_error = compute_mse(&layer);
        assert!(
            final_error < initial_error,
            "error should decrease: initial={:.6}, final={:.6}",
            initial_error,
            final_error
        );
        assert!(
            final_error < initial_error * 0.5,
            "error should decrease substantially: initial={:.6}, final={:.6} (ratio={:.4})",
            initial_error,
            final_error,
            final_error / initial_error
        );
    }
}