use crate::common::PlasticityConfig;
use crate::error::ConfigError;
/// Selects which gating scheme a KAN configuration requests.
///
/// The enum is `#[non_exhaustive]`, so code outside this crate must include a
/// wildcard arm when matching on it.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
#[non_exhaustive]
pub enum GateMode {
    /// No gating; this is the value returned by `GateMode::default()`.
    #[default]
    None,
    /// Residual-mix gating. The deprecated `KANConfigBuilder::temporal(true)`
    /// maps to this variant.
    ResidualMix,
    /// Full LSTM-style gating.
    // NOTE(review): exact semantics live in the layer implementation, which is
    // not visible from this file; confirm before relying on this description.
    LstmFull,
}
/// Hyper-parameters describing a KAN model.
///
/// All fields are public, but range validation only happens in
/// `KANConfigBuilder::build`; a value assembled by hand is not checked.
#[derive(Clone, Debug)]
pub struct KANConfig {
    /// Width of every layer, input first. The builder requires at least two
    /// entries, each greater than zero, with the last one equal to `1`.
    pub layer_sizes: Vec<usize>,
    /// Spline order (shown as `k` by `Display`); the builder requires `> 0`.
    pub spline_order: usize,
    /// Grid size (shown as `g` by `Display`); the builder requires `> 0`.
    pub grid_size: usize,
    /// Learning rate; the builder requires a strictly positive value.
    pub learning_rate: f64,
    /// Momentum factor; the builder requires a value in `[0, 1)`.
    pub momentum: f64,
    /// Coefficient decay factor; the builder requires a value in `[0, 1)`.
    pub coefficient_decay: f64,
    /// Requested gating scheme.
    pub gate_mode: GateMode,
    /// Seed value carried by the configuration.
    // NOTE(review): presumably seeds weight initialisation; confirm at the use site.
    pub seed: u64,
    /// Optional plasticity settings; `None` leaves plasticity unconfigured.
    pub plasticity: Option<PlasticityConfig>,
}
impl Default for KANConfig {
fn default() -> Self {
Self {
layer_sizes: vec![1, 5, 1],
spline_order: 3,
grid_size: 8,
learning_rate: 0.1,
momentum: 0.0,
coefficient_decay: 0.0,
gate_mode: GateMode::None,
seed: 42,
plasticity: None,
}
}
}
impl std::fmt::Display for KANConfig {
    /// Writes a single-line summary of the configuration.
    ///
    /// `layer_sizes` and `gate_mode` are rendered with `Debug`, everything else
    /// with `Display`. The `plasticity` field is not part of the summary.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Bind the fields by reference so the argument list mirrors the
        // placeholder order in the format string.
        let Self {
            layer_sizes,
            spline_order,
            grid_size,
            learning_rate,
            momentum,
            coefficient_decay,
            gate_mode,
            seed,
            ..
        } = self;
        f.write_fmt(format_args!(
            "KANConfig(layers={:?}, k={}, g={}, lr={}, momentum={}, decay={}, gate_mode={:?}, seed={})",
            layer_sizes,
            spline_order,
            grid_size,
            learning_rate,
            momentum,
            coefficient_decay,
            gate_mode,
            seed
        ))
    }
}
/// Fluent builder for [`KANConfig`].
///
/// Obtained from `KANConfig::builder()`, it starts from `KANConfig::default()`.
/// Each setter overwrites one field, and `build` validates the accumulated
/// values before handing out the finished configuration.
///
/// `Debug` and `Clone` are derived so a partially configured builder can be
/// logged or reused as a template; both are available because `KANConfig`
/// itself derives them.
#[derive(Debug, Clone)]
pub struct KANConfigBuilder {
    /// Configuration under construction; unchecked until `build` runs.
    config: KANConfig,
}
impl KANConfig {
    /// Creates a [`KANConfigBuilder`] seeded with the default configuration.
    ///
    /// Calling `build()` on the returned builder without any setter yields the
    /// same values as `KANConfig::default()`, after validation.
    pub fn builder() -> KANConfigBuilder {
        let config = Self::default();
        KANConfigBuilder { config }
    }
}
impl KANConfigBuilder {
    /// Sets the layer widths, ordered from input to output.
    ///
    /// [`build`](Self::build) requires at least two entries, every entry to be
    /// non-zero, and the final entry to be exactly `1`.
    pub fn layer_sizes(mut self, sizes: Vec<usize>) -> Self {
        self.config.layer_sizes = sizes;
        self
    }

    /// Sets the spline order; [`build`](Self::build) rejects `0`.
    pub fn spline_order(mut self, k: usize) -> Self {
        self.config.spline_order = k;
        self
    }

    /// Sets the grid size; [`build`](Self::build) rejects `0`.
    pub fn grid_size(mut self, g: usize) -> Self {
        self.config.grid_size = g;
        self
    }

    /// Sets the learning rate; [`build`](Self::build) requires a finite value
    /// strictly greater than zero.
    pub fn learning_rate(mut self, lr: f64) -> Self {
        self.config.learning_rate = lr;
        self
    }

    /// Sets the momentum factor; [`build`](Self::build) requires a value in
    /// `[0, 1)`.
    pub fn momentum(mut self, m: f64) -> Self {
        self.config.momentum = m;
        self
    }

    /// Sets the coefficient decay factor; [`build`](Self::build) requires a
    /// value in `[0, 1)`.
    pub fn coefficient_decay(mut self, d: f64) -> Self {
        self.config.coefficient_decay = d;
        self
    }

    /// Sets the gating scheme.
    pub fn gate_mode(mut self, mode: GateMode) -> Self {
        self.config.gate_mode = mode;
        self
    }

    /// Legacy boolean switch for gating.
    ///
    /// `true` selects [`GateMode::ResidualMix`] and `false` selects
    /// [`GateMode::None`], overwriting any gate mode set earlier on this
    /// builder.
    #[deprecated(
        since = "10.0.0",
        note = "Use `.gate_mode(GateMode::ResidualMix)` or `.gate_mode(GateMode::None)` instead"
    )]
    pub fn temporal(mut self, t: bool) -> Self {
        self.config.gate_mode = if t {
            GateMode::ResidualMix
        } else {
            GateMode::None
        };
        self
    }

    /// Sets the seed stored in the configuration.
    pub fn seed(mut self, s: u64) -> Self {
        self.config.seed = s;
        self
    }

    /// Sets or clears the optional plasticity settings.
    pub fn plasticity(mut self, p: Option<PlasticityConfig>) -> Self {
        self.config.plasticity = p;
        self
    }

    /// Validates the accumulated settings and returns the finished
    /// [`KANConfig`].
    ///
    /// Checks run in a fixed order and the first failure is reported.
    ///
    /// # Errors
    ///
    /// Returns a [`ConfigError`] when:
    /// - `layer_sizes` has fewer than two entries, contains a `0`, or does not
    ///   end in `1`;
    /// - `spline_order` or `grid_size` is `0`;
    /// - `learning_rate` is NaN, infinite, zero, or negative;
    /// - `momentum` or `coefficient_decay` is NaN or outside `[0, 1)`.
    pub fn build(self) -> Result<KANConfig, ConfigError> {
        let c = &self.config;

        if c.layer_sizes.len() < 2 {
            return Err(ConfigError::invalid(
                "layer_sizes",
                format!(
                    "need at least 2 layers (input + output), got {}",
                    c.layer_sizes.len()
                ),
            ));
        }
        // Report the first zero-width layer by index.
        if let Some(i) = c.layer_sizes.iter().position(|&size| size == 0) {
            return Err(ConfigError::out_of_range(
                "layer_sizes",
                "all layer sizes must be > 0",
                format!("layer_sizes[{}] = 0", i),
            ));
        }
        // The length check above guarantees `last()` is `Some`.
        if let Some(&output) = c.layer_sizes.last().filter(|&&size| size != 1) {
            return Err(ConfigError::invalid(
                "layer_sizes",
                format!(
                    "last layer must be 1 (regression output), got {}",
                    output
                ),
            ));
        }
        if c.spline_order == 0 {
            return Err(ConfigError::out_of_range(
                "spline_order",
                "must be > 0",
                c.spline_order,
            ));
        }
        if c.grid_size == 0 {
            return Err(ConfigError::out_of_range(
                "grid_size",
                "must be > 0",
                c.grid_size,
            ));
        }
        // `NaN <= 0.0` is false, so a bare `<= 0.0` test would accept NaN;
        // the finiteness check rejects NaN and infinities explicitly.
        if !c.learning_rate.is_finite() || c.learning_rate <= 0.0 {
            return Err(ConfigError::out_of_range(
                "learning_rate",
                "must be finite and > 0",
                c.learning_rate,
            ));
        }
        // `Range::contains` is false for NaN, so NaN is rejected together with
        // values outside the half-open interval.
        if !(0.0..1.0).contains(&c.momentum) {
            return Err(ConfigError::out_of_range(
                "momentum",
                "must be in [0, 1)",
                c.momentum,
            ));
        }
        if !(0.0..1.0).contains(&c.coefficient_decay) {
            return Err(ConfigError::out_of_range(
                "coefficient_decay",
                "must be in [0, 1)",
                c.coefficient_decay,
            ));
        }

        Ok(self.config)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Starts a builder whose only customisation is the given layer layout.
    fn with_layers(sizes: &[usize]) -> KANConfigBuilder {
        KANConfig::builder().layer_sizes(sizes.to_vec())
    }

    /// An untouched builder must reproduce the documented defaults.
    #[test]
    fn config_builder_default() {
        let built = KANConfig::builder().build().unwrap();
        assert_eq!(built.layer_sizes, vec![1, 5, 1]);
        assert_eq!(built.spline_order, 3);
        assert_eq!(built.grid_size, 8);
        assert!(
            (built.learning_rate - 0.1).abs() < 1e-12,
            "default learning_rate should be 0.1, got {}",
            built.learning_rate
        );
        assert!(
            built.momentum.abs() < 1e-12,
            "default momentum should be 0.0, got {}",
            built.momentum
        );
        assert!(
            built.coefficient_decay.abs() < 1e-12,
            "default coefficient_decay should be 0.0, got {}",
            built.coefficient_decay
        );
    }

    /// Values passed to the setters must come back unchanged.
    #[test]
    fn config_builder_custom() {
        let built = with_layers(&[3, 10, 1])
            .spline_order(4)
            .grid_size(8)
            .learning_rate(0.005)
            .seed(123)
            .build()
            .unwrap();
        assert_eq!(built.layer_sizes, vec![3, 10, 1]);
        assert_eq!(built.spline_order, 4);
        assert_eq!(built.grid_size, 8);
        assert!((built.learning_rate - 0.005).abs() < 1e-12);
        assert_eq!(built.seed, 123);
    }

    #[test]
    fn config_rejects_single_layer() {
        let outcome = with_layers(&[5]).build();
        assert!(outcome.is_err(), "single layer should be rejected");
    }

    #[test]
    fn config_rejects_zero_size() {
        let outcome = with_layers(&[0, 1]).build();
        assert!(outcome.is_err(), "zero-size layer should be rejected");
    }

    #[test]
    fn config_rejects_non_unit_output() {
        let outcome = with_layers(&[3, 5, 2]).build();
        assert!(outcome.is_err(), "non-unit output layer should be rejected");
    }

    #[test]
    fn config_rejects_zero_spline_order() {
        let outcome = with_layers(&[3, 1]).spline_order(0).build();
        assert!(outcome.is_err(), "zero spline order should be rejected");
    }

    #[test]
    fn config_rejects_zero_grid_size() {
        let outcome = with_layers(&[3, 1]).grid_size(0).build();
        assert!(outcome.is_err(), "zero grid size should be rejected");
    }

    /// The one-line summary must mention the key hyper-parameters.
    #[test]
    fn config_display() {
        let built = with_layers(&[3, 10, 1]).build().unwrap();
        let rendered = format!("{built}");
        assert!(rendered.contains("layers="), "display should contain layers");
        assert!(rendered.contains("k=3"), "display should contain spline order");
        assert!(rendered.contains("momentum="), "display should contain momentum");
        assert!(rendered.contains("decay="), "display should contain decay");
    }

    #[test]
    fn config_clone() {
        let original = with_layers(&[3, 10, 1]).seed(99).build().unwrap();
        let duplicate = original.clone();
        assert_eq!(duplicate.layer_sizes, original.layer_sizes);
        assert_eq!(duplicate.seed, original.seed);
    }

    #[test]
    fn config_rejects_zero_learning_rate() {
        for (lr, reason) in [
            (0.0, "learning_rate=0 must be rejected"),
            (-0.1, "negative learning_rate must be rejected"),
        ] {
            let outcome = with_layers(&[3, 5, 1]).learning_rate(lr).build();
            assert!(outcome.is_err(), "{}", reason);
        }
    }

    #[test]
    fn config_rejects_invalid_momentum() {
        for (m, reason) in [
            (1.0, "momentum=1 must be rejected"),
            (-0.1, "negative momentum must be rejected"),
        ] {
            let outcome = with_layers(&[3, 5, 1]).momentum(m).build();
            assert!(outcome.is_err(), "{}", reason);
        }
    }

    #[test]
    fn config_rejects_invalid_coefficient_decay() {
        for (d, reason) in [
            (1.0, "coefficient_decay=1 must be rejected"),
            (-0.1, "negative coefficient_decay must be rejected"),
        ] {
            let outcome = with_layers(&[3, 5, 1]).coefficient_decay(d).build();
            assert!(outcome.is_err(), "{}", reason);
        }
    }

    #[test]
    fn config_accepts_zero_coefficient_decay() {
        let outcome = with_layers(&[3, 5, 1]).coefficient_decay(0.0).build();
        assert!(
            outcome.is_ok(),
            "coefficient_decay=0 (disabled) should be valid"
        );
    }

    #[test]
    fn config_accepts_zero_momentum() {
        let outcome = with_layers(&[3, 5, 1]).momentum(0.0).build();
        assert!(outcome.is_ok(), "momentum=0 (disabled) should be valid");
    }

    #[test]
    fn plasticity_disabled_by_default() {
        let built = with_layers(&[3, 5, 1]).build().unwrap();
        assert!(
            built.plasticity.is_none(),
            "plasticity should default to None"
        );
    }

    #[test]
    fn plasticity_enabled_via_config() {
        use crate::common::PlasticityConfig;
        let settings = Some(PlasticityConfig::default());
        let built = with_layers(&[3, 5, 1])
            .plasticity(settings)
            .build()
            .unwrap();
        assert!(built.plasticity.is_some());
    }
}