use crate::common::PlasticityConfig;
use crate::error::{ConfigError, IrithyllError, Result};
/// Hyper-parameters for an echo state network (ESN) model.
///
/// Prefer [`ESNConfig::builder`], whose `build` step range-checks the fields;
/// constructing the struct literally (or mutating the public fields) bypasses
/// those checks. The constraints listed per field are the ones `build` enforces.
#[derive(Debug, Clone)]
pub struct ESNConfig {
    /// Reservoir size (presumably the number of recurrent units). Default `100`;
    /// must be `>= 1`.
    pub n_reservoir: usize,
    /// Spectral radius for the reservoir — presumably the value the recurrent
    /// weight matrix is rescaled to (confirm in the model code). Default `0.9`;
    /// must be `> 0`.
    pub spectral_radius: f64,
    /// Leak rate of the state update. Default `0.3`; must lie in `(0, 1]`.
    pub leak_rate: f64,
    /// Scaling applied to input weights. Default `1.0`; must be `>= 0`.
    pub input_scaling: f64,
    /// Scaling applied to bias weights. Default `0.0`; must be `>= 0`.
    pub bias_scaling: f64,
    /// Forgetting factor, presumably for a recursive-least-squares readout
    /// (the readout is not visible here). Default `0.998`; must lie in `(0, 1]`.
    pub forgetting_factor: f64,
    /// Initialisation scale, presumably for the RLS inverse-covariance matrix.
    /// Default `100.0`; must be `> 0`.
    pub delta: f64,
    /// Seed, presumably for the RNG that draws the random weights. Default `42`.
    pub seed: u64,
    /// Warm-up length, presumably a number of initial samples. Default `50`;
    /// not validated.
    pub warmup: usize,
    /// Presumably whether the raw input is fed to the readout alongside the
    /// reservoir state — confirm against the model. Default `true`.
    pub passthrough_input: bool,
    /// Optional readout dimension. Default `None`; if still `None` when
    /// `build` runs and `n_reservoir > 200`, it is set to
    /// `min(ceil(n_reservoir / 3), 64)`.
    pub readout_dim: Option<usize>,
    /// Optional plasticity configuration. Default `None`; not validated by `build`.
    pub plasticity: Option<PlasticityConfig>,
}
impl Default for ESNConfig {
fn default() -> Self {
Self {
n_reservoir: 100,
spectral_radius: 0.9,
leak_rate: 0.3,
input_scaling: 1.0,
bias_scaling: 0.0,
forgetting_factor: 0.998,
delta: 100.0,
seed: 42,
warmup: 50,
passthrough_input: true,
readout_dim: None,
plasticity: None,
}
}
}
impl ESNConfig {
    /// Starts a builder pre-filled with [`ESNConfig::default`] values.
    ///
    /// Finish with [`ESNConfigBuilder::build`], which validates every field.
    pub fn builder() -> ESNConfigBuilder {
        ESNConfigBuilder::default()
    }
}
/// Builder for [`ESNConfig`].
///
/// Starts from [`ESNConfig::default`]; each setter overwrites one field and
/// `build` validates the result before handing it back.
#[derive(Debug, Clone)]
pub struct ESNConfigBuilder {
    /// Configuration being assembled; returned by `build` once validated.
    config: ESNConfig,
}
impl ESNConfigBuilder {
    /// Reservoirs strictly larger than this receive an automatic `readout_dim`
    /// when the caller did not set one.
    const AUTO_READOUT_MIN_RESERVOIR: usize = 200;
    /// The automatic `readout_dim` starts as `ceil(n_reservoir / AUTO_READOUT_DIVISOR)`...
    const AUTO_READOUT_DIVISOR: usize = 3;
    /// ...and is then capped at this value.
    const AUTO_READOUT_CAP: usize = 64;

    /// Creates a builder initialised with [`ESNConfig::default`] values.
    pub fn new() -> Self {
        Self {
            config: ESNConfig::default(),
        }
    }

    /// Sets `n_reservoir`; [`build`](Self::build) rejects `0`.
    pub fn n_reservoir(mut self, n: usize) -> Self {
        self.config.n_reservoir = n;
        self
    }

    /// Sets `spectral_radius`; [`build`](Self::build) requires a finite value `> 0`.
    pub fn spectral_radius(mut self, sr: f64) -> Self {
        self.config.spectral_radius = sr;
        self
    }

    /// Sets `leak_rate`; [`build`](Self::build) requires a value in `(0, 1]`.
    pub fn leak_rate(mut self, lr: f64) -> Self {
        self.config.leak_rate = lr;
        self
    }

    /// Sets `input_scaling`; [`build`](Self::build) requires a finite value `>= 0`.
    pub fn input_scaling(mut self, is: f64) -> Self {
        self.config.input_scaling = is;
        self
    }

    /// Sets `bias_scaling`; [`build`](Self::build) requires a finite value `>= 0`.
    pub fn bias_scaling(mut self, bs: f64) -> Self {
        self.config.bias_scaling = bs;
        self
    }

    /// Sets `forgetting_factor`; [`build`](Self::build) requires a value in `(0, 1]`.
    pub fn forgetting_factor(mut self, ff: f64) -> Self {
        self.config.forgetting_factor = ff;
        self
    }

    /// Sets `delta`; [`build`](Self::build) requires a finite value `> 0`.
    pub fn delta(mut self, delta: f64) -> Self {
        self.config.delta = delta;
        self
    }

    /// Sets `seed` (not validated).
    pub fn seed(mut self, seed: u64) -> Self {
        self.config.seed = seed;
        self
    }

    /// Sets `warmup` (not validated).
    pub fn warmup(mut self, warmup: usize) -> Self {
        self.config.warmup = warmup;
        self
    }

    /// Sets `passthrough_input`.
    pub fn passthrough_input(mut self, pt: bool) -> Self {
        self.config.passthrough_input = pt;
        self
    }

    /// Sets an explicit `readout_dim`, which suppresses the automatic default
    /// applied by [`build`](Self::build). The value itself is not range-checked.
    pub fn readout_dim(mut self, dim: usize) -> Self {
        self.config.readout_dim = Some(dim);
        self
    }

    /// Sets (or clears, with `None`) the plasticity configuration.
    pub fn plasticity(mut self, p: Option<PlasticityConfig>) -> Self {
        self.config.plasticity = p;
        self
    }

    /// Validates the accumulated configuration and returns it.
    ///
    /// The checks run in a fixed order (n_reservoir, spectral_radius, leak_rate,
    /// input_scaling, bias_scaling, forgetting_factor, delta) and the first
    /// failure is reported.
    ///
    /// If validation passes, `readout_dim` is unset and `n_reservoir > 200`,
    /// `readout_dim` is set to `min(ceil(n_reservoir / 3), 64)`.
    ///
    /// # Errors
    ///
    /// Returns [`IrithyllError::InvalidConfig`] wrapping an out-of-range
    /// [`ConfigError`] when any of the following holds:
    /// - `n_reservoir == 0`;
    /// - `spectral_radius` or `delta` is NaN, infinite, or `<= 0`;
    /// - `leak_rate` or `forgetting_factor` is NaN or outside `(0, 1]`;
    /// - `input_scaling` or `bias_scaling` is NaN, infinite, or `< 0`.
    pub fn build(self) -> Result<ESNConfig> {
        let mut config = self.config;

        if config.n_reservoir == 0 {
            return Err(IrithyllError::InvalidConfig(ConfigError::out_of_range(
                "n_reservoir",
                "must be >= 1",
                config.n_reservoir,
            )));
        }
        // Every comparison with NaN is false, so bare `<= 0.0` / `> 1.0` checks
        // would silently accept NaN. Finiteness and NaN are therefore tested
        // explicitly; an infinite scale would also poison the state with inf/NaN.
        if !config.spectral_radius.is_finite() || config.spectral_radius <= 0.0 {
            return Err(IrithyllError::InvalidConfig(ConfigError::out_of_range(
                "spectral_radius",
                "must be finite and > 0",
                config.spectral_radius,
            )));
        }
        if config.leak_rate.is_nan() || config.leak_rate <= 0.0 || config.leak_rate > 1.0 {
            return Err(IrithyllError::InvalidConfig(ConfigError::out_of_range(
                "leak_rate",
                "must be in (0, 1]",
                config.leak_rate,
            )));
        }
        if !config.input_scaling.is_finite() || config.input_scaling < 0.0 {
            return Err(IrithyllError::InvalidConfig(ConfigError::out_of_range(
                "input_scaling",
                "must be finite and >= 0",
                config.input_scaling,
            )));
        }
        if !config.bias_scaling.is_finite() || config.bias_scaling < 0.0 {
            return Err(IrithyllError::InvalidConfig(ConfigError::out_of_range(
                "bias_scaling",
                "must be finite and >= 0",
                config.bias_scaling,
            )));
        }
        if config.forgetting_factor.is_nan()
            || config.forgetting_factor <= 0.0
            || config.forgetting_factor > 1.0
        {
            return Err(IrithyllError::InvalidConfig(ConfigError::out_of_range(
                "forgetting_factor",
                "must be in (0, 1]",
                config.forgetting_factor,
            )));
        }
        if !config.delta.is_finite() || config.delta <= 0.0 {
            return Err(IrithyllError::InvalidConfig(ConfigError::out_of_range(
                "delta",
                "must be finite and > 0",
                config.delta,
            )));
        }
        // NOTE(review): an explicit `readout_dim` (including 0) is not range-checked
        // here; confirm the consumer handles degenerate values.

        if config.readout_dim.is_none() && config.n_reservoir > Self::AUTO_READOUT_MIN_RESERVOIR {
            // Any n above the threshold gives ceil(n / 3) >= 67, so today the cap is
            // what actually applies; the division keeps the rule proportional if the
            // threshold or cap is ever changed.
            config.readout_dim = Some(
                config
                    .n_reservoir
                    .div_ceil(Self::AUTO_READOUT_DIVISOR)
                    .min(Self::AUTO_READOUT_CAP),
            );
        }
        Ok(config)
    }
}
impl Default for ESNConfigBuilder {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `result` is the `InvalidConfig` error that `build` produces.
    fn assert_invalid_config(result: Result<ESNConfig>) {
        assert!(
            matches!(&result, Err(IrithyllError::InvalidConfig(_))),
            "expected InvalidConfig, got {result:?}",
        );
    }

    #[test]
    fn default_config_builds() {
        let config = ESNConfig::builder().build().unwrap();
        assert_eq!(config.n_reservoir, 100);
        assert!((config.spectral_radius - 0.9).abs() < 1e-12);
        assert!((config.leak_rate - 0.3).abs() < 1e-12);
        assert!((config.input_scaling - 1.0).abs() < 1e-12);
        assert!((config.bias_scaling).abs() < 1e-12);
        assert!((config.forgetting_factor - 0.998).abs() < 1e-12);
        assert!((config.delta - 100.0).abs() < 1e-12);
        assert_eq!(config.seed, 42);
        assert_eq!(config.warmup, 50);
        assert!(config.passthrough_input);
        assert!(config.plasticity.is_none());
        assert_eq!(
            config.readout_dim, None,
            "n=100 should not auto-default readout_dim",
        );
    }

    #[test]
    fn custom_config_builds() {
        let config = ESNConfig::builder()
            .n_reservoir(200)
            .spectral_radius(0.95)
            .leak_rate(0.5)
            .input_scaling(0.5)
            .bias_scaling(0.1)
            .forgetting_factor(0.99)
            .delta(50.0)
            .seed(123)
            .warmup(100)
            .passthrough_input(false)
            .build()
            .unwrap();
        assert_eq!(config.n_reservoir, 200);
        assert!((config.spectral_radius - 0.95).abs() < 1e-12);
        assert!((config.leak_rate - 0.5).abs() < 1e-12);
        assert!((config.input_scaling - 0.5).abs() < 1e-12);
        assert!((config.bias_scaling - 0.1).abs() < 1e-12);
        assert!((config.forgetting_factor - 0.99).abs() < 1e-12);
        assert!((config.delta - 50.0).abs() < 1e-12);
        assert_eq!(config.seed, 123);
        assert_eq!(config.warmup, 100);
        assert!(!config.passthrough_input);
        assert_eq!(
            config.readout_dim, None,
            "n=200 should not auto-default readout_dim",
        );
    }

    #[test]
    fn zero_reservoir_fails() {
        assert_invalid_config(ESNConfig::builder().n_reservoir(0).build());
    }

    #[test]
    fn negative_spectral_radius_fails() {
        assert_invalid_config(ESNConfig::builder().spectral_radius(-0.1).build());
    }

    #[test]
    fn zero_spectral_radius_fails() {
        assert_invalid_config(ESNConfig::builder().spectral_radius(0.0).build());
    }

    #[test]
    fn leak_rate_zero_fails() {
        assert_invalid_config(ESNConfig::builder().leak_rate(0.0).build());
    }

    #[test]
    fn leak_rate_above_one_fails() {
        assert_invalid_config(ESNConfig::builder().leak_rate(1.01).build());
    }

    #[test]
    fn negative_input_scaling_fails() {
        assert_invalid_config(ESNConfig::builder().input_scaling(-0.1).build());
    }

    #[test]
    fn negative_bias_scaling_fails() {
        assert_invalid_config(ESNConfig::builder().bias_scaling(-0.1).build());
    }

    #[test]
    fn forgetting_factor_zero_fails() {
        assert_invalid_config(ESNConfig::builder().forgetting_factor(0.0).build());
    }

    #[test]
    fn forgetting_factor_above_one_fails() {
        assert_invalid_config(ESNConfig::builder().forgetting_factor(1.01).build());
    }

    #[test]
    fn delta_zero_fails() {
        assert_invalid_config(ESNConfig::builder().delta(0.0).build());
    }

    #[test]
    fn negative_delta_fails() {
        assert_invalid_config(ESNConfig::builder().delta(-1.0).build());
    }

    #[test]
    fn inclusive_bounds_and_zero_scalings_build() {
        let config = ESNConfig::builder()
            .leak_rate(1.0)
            .forgetting_factor(1.0)
            .input_scaling(0.0)
            .bias_scaling(0.0)
            .build()
            .unwrap();
        assert!((config.leak_rate - 1.0).abs() < 1e-12);
        assert!((config.forgetting_factor - 1.0).abs() < 1e-12);
        assert!(config.input_scaling.abs() < 1e-12);
        assert!(config.bias_scaling.abs() < 1e-12);
    }

    #[test]
    fn readout_dim_auto_defaults_for_large_reservoir() {
        let config = ESNConfig::builder().n_reservoir(300).build().unwrap();
        assert_eq!(config.readout_dim, Some(64));
        let config = ESNConfig::builder().n_reservoir(500).build().unwrap();
        assert_eq!(config.readout_dim, Some(64));
        let config = ESNConfig::builder().n_reservoir(250).build().unwrap();
        assert_eq!(config.readout_dim, Some(64));
        // First size above the threshold: ceil(201 / 3) = 67, capped to 64.
        let config = ESNConfig::builder().n_reservoir(201).build().unwrap();
        assert_eq!(config.readout_dim, Some(64));
    }

    #[test]
    fn readout_dim_no_auto_default_for_small_reservoir() {
        let config = ESNConfig::builder().n_reservoir(50).build().unwrap();
        assert_eq!(
            config.readout_dim, None,
            "small reservoirs should not auto-default readout_dim",
        );
        let config = ESNConfig::builder().n_reservoir(200).build().unwrap();
        assert_eq!(config.readout_dim, None);
        let config = ESNConfig::builder().n_reservoir(100).build().unwrap();
        assert_eq!(config.readout_dim, None);
    }

    #[test]
    fn readout_dim_explicit_overrides_auto() {
        let config = ESNConfig::builder()
            .n_reservoir(300)
            .readout_dim(128)
            .build()
            .unwrap();
        assert_eq!(
            config.readout_dim,
            Some(128),
            "explicit readout_dim should not be overridden by auto-default",
        );
    }

    #[test]
    fn readout_dim_explicit_kept_for_small_reservoir() {
        let config = ESNConfig::builder()
            .n_reservoir(50)
            .readout_dim(10)
            .build()
            .unwrap();
        assert_eq!(config.readout_dim, Some(10));
    }
}