//! Kolmogorov-Arnold network (KAN) layers and networks.
//!
//! Each edge (input i, output j) carries its own learnable 1-D activation,
//! a B-spline or a rational function, and each output is the sum of its
//! incoming edge activations.

use scirs2_core::ndarray::{Array1, Array2};
use crate::NeuralError;
use super::{
rational::{RationalActivation, RationalConfig},
spline::{BSplineActivation, SplineConfig},
KanResult,
};
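/// Which family of learnable 1-D activation functions a KAN edge uses.
///
/// Marked `#[non_exhaustive]` so further activation families can be added
/// without a breaking change.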
#[non_exhaustive]
#[derive(Debug, Clone)]
pub enum ActivationType {
BSpline(SplineConfig),
Rational(RationalConfig),
}
impl Default for ActivationType {
fn default() -> Self {
ActivationType::BSpline(SplineConfig::default())
}
}
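/// Configuration for a single [`KanLayer`].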
#[derive(Debug, Clone)]
pub struct KanLayerConfig {
    /// Number of input features.
    pub n_in: usize,
    /// Number of output features.
    pub n_out: usize,
    /// Activation family used on every edge of this layer.
    pub activation_type: ActivationType,
}
impl Default for KanLayerConfig {
fn default() -> Self {
Self {
n_in: 4,
n_out: 4,
activation_type: ActivationType::default(),
}
}
}
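/// A single KAN layer: a dense grid of `n_in * n_out` learnable edge
/// activations.
///
/// Exactly one of `spline_activations` / `rational_activations` is `Some`,
/// matching `config.activation_type`; [`KanLayer::new`] enforces this
/// invariant.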
pub struct KanLayer {
config: KanLayerConfig,
spline_activations: Option<Vec<BSplineActivation>>,
rational_activations: Option<Vec<RationalActivation>>,
}
impl KanLayer {
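    /// Builds a layer from `config`, allocating one activation per edge.
    ///
    /// # Errors
    ///
    /// Returns [`NeuralError::InvalidArgument`] if `n_in` or `n_out` is zero,
    /// and propagates any error from constructing the edge activations.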
pub fn new(config: KanLayerConfig) -> KanResult<Self> {
if config.n_in == 0 {
return Err(NeuralError::InvalidArgument(
"KanLayer: n_in must be > 0".to_string(),
));
}
if config.n_out == 0 {
return Err(NeuralError::InvalidArgument(
"KanLayer: n_out must be > 0".to_string(),
));
}
let n_edges = config.n_in * config.n_out;
match &config.activation_type {
ActivationType::BSpline(sc) => {
let mut activations = Vec::with_capacity(n_edges);
for _ in 0..n_edges {
activations.push(BSplineActivation::new(sc)?);
}
Ok(Self {
config,
spline_activations: Some(activations),
rational_activations: None,
})
}
ActivationType::Rational(rc) => {
let mut activations = Vec::with_capacity(n_edges);
for _ in 0..n_edges {
activations.push(RationalActivation::new(rc)?);
}
Ok(Self {
config,
spline_activations: None,
rational_activations: Some(activations),
})
}
}
}
    /// Evaluates the learnable activation on edge (i, j) at `x`.
    /// Edges are stored row-major: edge (i, j) lives at index `i * n_out + j`.
    fn eval_activation(&self, i: usize, j: usize, x: f64) -> f64 {
        let idx = i * self.config.n_out + j;
        match (&self.spline_activations, &self.rational_activations) {
            (Some(splines), _) => splines[idx].evaluate(x),
            (_, Some(rationals)) => rationals[idx].evaluate(x),
            // Unreachable: `new` always populates exactly one activation set.
            _ => 0.0,
        }
    }
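    /// Forward pass for a single sample: `output[j] = Σ_i φ_{i,j}(input[i])`,
    /// where `φ_{i,j}` is the learnable activation on edge (i, j).
    ///
    /// # Errors
    ///
    /// Returns [`NeuralError::DimensionMismatch`] if `input.len() != n_in`.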
pub fn forward(&self, input: &Array1<f64>) -> KanResult<Array1<f64>> {
if input.len() != self.config.n_in {
return Err(NeuralError::DimensionMismatch(format!(
"KanLayer::forward expected n_in={} but got {}",
self.config.n_in,
input.len()
)));
}
let mut output = Array1::zeros(self.config.n_out);
for i in 0..self.config.n_in {
let x_i = input[i];
for j in 0..self.config.n_out {
output[j] += self.eval_activation(i, j, x_i);
}
}
Ok(output)
}
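    /// Forward pass for a batch of shape `(batch, n_in)`, returning
    /// `(batch, n_out)`.
    ///
    /// # Errors
    ///
    /// Returns [`NeuralError::DimensionMismatch`] if the second dimension
    /// is not `n_in`.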
pub fn forward_batch(&self, input: &Array2<f64>) -> KanResult<Array2<f64>> {
let (batch, n_in) = input.dim();
if n_in != self.config.n_in {
return Err(NeuralError::DimensionMismatch(format!(
"KanLayer::forward_batch expected n_in={} but got {n_in}",
self.config.n_in
)));
}
let mut output = Array2::zeros((batch, self.config.n_out));
for b in 0..batch {
for i in 0..self.config.n_in {
let x_i = input[(b, i)];
for j in 0..self.config.n_out {
output[(b, j)] += self.eval_activation(i, j, x_i);
}
}
}
Ok(output)
}
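    /// Total number of learnable parameters across all edge activations.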
pub fn n_params(&self) -> usize {
match (&self.spline_activations, &self.rational_activations) {
(Some(splines), _) => splines.iter().map(|s| s.n_params()).sum(),
(_, Some(rationals)) => rationals.iter().map(|r| r.n_params()).sum(),
_ => 0,
}
}
    /// Number of inputs this layer accepts.
    pub fn n_in(&self) -> usize {
        self.config.n_in
    }
    /// Number of outputs this layer produces.
    pub fn n_out(&self) -> usize {
        self.config.n_out
    }
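    /// Zeroes every edge whose coefficient L1 norm falls below `threshold`
    /// and returns how many edges were zeroed.
    ///
    /// Note that edges pruned by an earlier call still have an L1 norm of
    /// zero, so they are counted again on subsequent calls.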
pub fn prune_edges(&mut self, threshold: f64) -> usize {
let mut pruned = 0;
if let Some(ref mut splines) = self.spline_activations {
for sp in splines.iter_mut() {
let l1: f64 = sp.coefficients.iter().map(|c| c.abs()).sum();
if l1 < threshold {
sp.coefficients.fill(0.0);
pruned += 1;
}
}
}
if let Some(ref mut rationals) = self.rational_activations {
for ra in rationals.iter_mut() {
let l1: f64 = ra
.p_coeffs
.iter()
.chain(ra.q_coeffs.iter())
.map(|c| c.abs())
.sum();
if l1 < threshold {
ra.p_coeffs.fill(0.0);
ra.q_coeffs.fill(0.0);
pruned += 1;
}
}
}
pruned
}
    /// Mutable access to the per-edge B-spline activations, if this layer uses them.
    pub fn spline_activations_mut(&mut self) -> Option<&mut Vec<BSplineActivation>> {
        self.spline_activations.as_mut()
    }
    /// Mutable access to the per-edge rational activations, if this layer uses them.
    pub fn rational_activations_mut(&mut self) -> Option<&mut Vec<RationalActivation>> {
        self.rational_activations.as_mut()
    }
    /// The configuration this layer was built from.
    pub fn config(&self) -> &KanLayerConfig {
        &self.config
    }
}
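/// Configuration for a full [`KanNetwork`].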
#[derive(Debug, Clone)]
pub struct KanConfig {
    /// Layer widths including input and output, e.g. `vec![2, 8, 1]`
    /// builds a 2-input, 8-hidden, 1-output network.
    pub layer_widths: Vec<usize>,
    /// Activation family shared by every layer.
    pub activation_type: ActivationType,
}
impl Default for KanConfig {
fn default() -> Self {
Self {
layer_widths: vec![2, 8, 1],
activation_type: ActivationType::default(),
}
}
}
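/// A KAN built by chaining [`KanLayer`]s over consecutive entries of
/// `layer_widths`.
///
/// # Example
///
/// A minimal usage sketch (marked `ignore` because the import path depends
/// on where this module sits in the consuming crate):
///
/// ```ignore
/// use your_crate::layers::kan::{KanConfig, KanNetwork};
/// use scirs2_core::ndarray::Array1;
///
/// // Default config: widths [2, 8, 1] with B-spline edge activations.
/// let net = KanNetwork::new(KanConfig::default()).expect("valid config");
/// let y = net.forward(&Array1::from_vec(vec![0.3, -0.7])).expect("forward ok");
/// assert_eq!(y.len(), 1);
/// ```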
pub struct KanNetwork {
layers: Vec<KanLayer>,
config: KanConfig,
}
impl KanNetwork {
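    /// Builds one layer per adjacent pair in `config.layer_widths`.
    ///
    /// # Errors
    ///
    /// Returns [`NeuralError::InvalidArgument`] if fewer than two widths are
    /// given or any width is zero.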
pub fn new(config: KanConfig) -> KanResult<Self> {
if config.layer_widths.len() < 2 {
return Err(NeuralError::InvalidArgument(
"KanNetwork: layer_widths must have at least 2 entries (input and output)"
.to_string(),
));
}
for (i, &w) in config.layer_widths.iter().enumerate() {
if w == 0 {
return Err(NeuralError::InvalidArgument(format!(
"KanNetwork: layer_widths[{i}] must be > 0"
)));
}
}
let mut layers = Vec::with_capacity(config.layer_widths.len() - 1);
for pair in config.layer_widths.windows(2) {
let layer_cfg = KanLayerConfig {
n_in: pair[0],
n_out: pair[1],
activation_type: config.activation_type.clone(),
};
layers.push(KanLayer::new(layer_cfg)?);
}
Ok(Self { layers, config })
}
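    /// Runs a single sample through every layer in order.
    ///
    /// # Errors
    ///
    /// Propagates a dimension mismatch from the first layer if
    /// `input.len()` differs from the network's input width.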
pub fn forward(&self, input: &Array1<f64>) -> KanResult<Array1<f64>> {
let mut x = input.clone();
for layer in &self.layers {
x = layer.forward(&x)?;
}
Ok(x)
}
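    /// Runs a `(batch, n_in)` matrix through every layer, returning
    /// `(batch, n_out)`.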
pub fn forward_batch(&self, input: &Array2<f64>) -> KanResult<Array2<f64>> {
let mut x = input.clone();
for layer in &self.layers {
x = layer.forward_batch(&x)?;
}
Ok(x)
}
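    /// Total number of learnable parameters across all layers.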
pub fn n_params(&self) -> usize {
self.layers.iter().map(|l| l.n_params()).sum()
}
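    /// Number of layers, i.e. `layer_widths.len() - 1`.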
pub fn n_layers(&self) -> usize {
self.layers.len()
}
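    /// Prunes low-magnitude edges in every layer (see
    /// [`KanLayer::prune_edges`]) and returns the total count.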
pub fn prune_edges(&mut self, threshold: f64) -> usize {
self.layers
.iter_mut()
.map(|l| l.prune_edges(threshold))
.sum()
}
pub fn layers(&self) -> &[KanLayer] {
&self.layers
}
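    /// Mutable access to the layer list. Callers that add or remove layers
    /// must keep adjacent widths consistent themselves.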
pub fn layers_mut(&mut self) -> &mut Vec<KanLayer> {
&mut self.layers
}
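    /// The configuration this network was built from.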
pub fn config(&self) -> &KanConfig {
&self.config
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::layers::kan::rational::RationalConfig;
use crate::layers::kan::spline::SplineConfig;
fn default_spline_layer(n_in: usize, n_out: usize) -> KanLayer {
KanLayer::new(KanLayerConfig {
n_in,
n_out,
activation_type: ActivationType::BSpline(SplineConfig::default()),
})
.expect("valid layer config")
}
#[test]
fn kan_layer_forward_shape() {
let layer = default_spline_layer(3, 5);
let input = Array1::zeros(3);
let output = layer.forward(&input).expect("forward ok");
assert_eq!(output.len(), 5, "Output length mismatch");
}
#[test]
fn kan_layer_batch_forward_shape() {
let layer = default_spline_layer(4, 6);
let input = Array2::zeros((8, 4));
let output = layer.forward_batch(&input).expect("batch forward ok");
assert_eq!(output.dim(), (8, 6), "Batch output shape mismatch");
}
    #[test]
    fn kan_layer_zero_coeffs_outputs_zero() {
        // Freshly constructed spline activations start with zero coefficients,
        // so the forward pass must produce (numerically) zero outputs.
        let layer = default_spline_layer(3, 4);
        let input = Array1::from_vec(vec![0.5, -0.3, 0.9]);
        let output = layer.forward(&input).expect("forward ok");
        for (j, &v) in output.iter().enumerate() {
            assert!(v.abs() < 1e-14, "Expected 0 at output[{j}] but got {v}");
        }
    }
#[test]
fn kan_layer_dimension_mismatch() {
let layer = default_spline_layer(3, 4);
let bad_input = Array1::zeros(5);
assert!(
layer.forward(&bad_input).is_err(),
"Should return error on dimension mismatch"
);
}
#[test]
fn kan_layer_single_input_single_output() {
let layer = default_spline_layer(1, 1);
let input = Array1::from_vec(vec![0.1]);
let output = layer.forward(&input).expect("forward ok");
assert_eq!(output.len(), 1);
}
#[test]
fn kan_rational_layer() {
let config = KanLayerConfig {
n_in: 3,
n_out: 2,
activation_type: ActivationType::Rational(RationalConfig::default()),
};
let layer = KanLayer::new(config).expect("valid rational layer");
let input = Array1::from_vec(vec![0.1, -0.5, 0.8]);
let output = layer.forward(&input).expect("forward ok");
assert_eq!(output.len(), 2);
for (j, &v) in output.iter().enumerate() {
assert!(v.abs() < 1e-14, "Expected 0 at output[{j}] got {v}");
}
}
    #[test]
    fn kan_prune_edges() {
        let mut layer = default_spline_layer(2, 3);
        let pruned = layer.prune_edges(1.0);
        assert_eq!(pruned, 6, "All 2×3=6 edges should be pruned");
        // Revive one edge; it should survive the next pruning pass while the
        // remaining five (already zeroed) are counted as pruned again.
        if let Some(splines) = layer.spline_activations_mut() {
            splines[0].coefficients[0] = 10.0;
        }
        let pruned2 = layer.prune_edges(1.0);
        assert_eq!(pruned2, 5, "Only the revived edge should survive");
    }
fn default_network() -> KanNetwork {
KanNetwork::new(KanConfig {
layer_widths: vec![2, 4, 3, 1],
activation_type: ActivationType::BSpline(SplineConfig::default()),
})
.expect("valid network config")
}
#[test]
fn kan_network_forward() {
let net = default_network();
let input = Array1::from_vec(vec![0.3, -0.7]);
let output = net.forward(&input).expect("network forward ok");
assert_eq!(output.len(), 1, "Output should be scalar (width=1)");
}
#[test]
fn kan_network_batch_forward() {
let net = default_network();
let input = Array2::zeros((5, 2));
let output = net.forward_batch(&input).expect("batch forward ok");
assert_eq!(output.dim(), (5, 1), "Batch output shape mismatch");
}
    #[test]
    fn kan_network_n_params() {
        let net = default_network();
        let n = net.n_params();
        // Widths [2, 4, 3, 1] give 2*4 + 4*3 + 3*1 = 23 edges; the expected
        // total assumes the default spline config stores 184 / 23 = 8
        // parameters per edge.
        assert_eq!(n, 184, "n_params mismatch: got {n}");
    }
#[test]
fn kan_invalid_config_single_width() {
let result = KanNetwork::new(KanConfig {
layer_widths: vec![4],
activation_type: ActivationType::default(),
});
assert!(result.is_err(), "Should fail with single-element layer_widths");
}
#[test]
fn kan_network_n_layers() {
let net = default_network();
assert_eq!(net.n_layers(), 3);
}
    #[test]
    fn kan_network_prune_all_zero() {
        let mut net = default_network();
        let pruned = net.prune_edges(1.0);
        // Widths [2, 4, 3, 1] give 2*4 + 4*3 + 3*1 = 23 zero-initialized edges.
        assert_eq!(pruned, 23, "All 23 edges should be pruned; got {pruned}");
    }
}