use std::fmt;
use crate::learner::StreamingLearner;
use crate::learners::RecursiveLeastSquares;
use irithyll_core::continual::{ContinualStrategy, NeuronRegeneration};
use irithyll_core::reservoir::{CycleReservoir, Xorshift64Rng};
use super::esn_config::ESNConfig;
/// Streaming echo state network: a fixed `CycleReservoir` driven by the
/// input stream, with an online recursive-least-squares linear readout.
pub struct EchoStateNetwork {
    // Hyper-parameters (reservoir size, warmup, projection, plasticity, ...).
    config: ESNConfig,
    // Fixed recurrent reservoir providing the nonlinear state expansion.
    reservoir: CycleReservoir,
    // Online linear readout trained by recursive least squares.
    rls: RecursiveLeastSquares,
    // Samples presented overall, including warmup and pre-training ones.
    total_seen: u64,
    // Samples actually used to update the RLS readout (post-warmup only).
    samples_trained: u64,
    // Input width, learned lazily from the first trained sample.
    n_inputs: Option<usize>,
    // Previous readout prediction; used for the alignment diagnostic.
    prev_prediction: f64,
    // Last prediction delta (first difference of predictions).
    prev_change: f64,
    // Delta before last (for the second difference, i.e. "acceleration").
    prev_prev_change: f64,
    // EWMA of sign agreement between consecutive prediction accelerations;
    // surfaced as `residual_alignment` in diagnostics.
    alignment_ewma: f64,
    // Slow EWMA of per-unit |activation|; feeds the entropy diagnostic.
    state_activity_ewma: Vec<f64>,
    // Optional row-major random sign-projection matrix compressing the
    // readout features down to `config.readout_dim` rows.
    readout_projection: Option<Vec<f64>>,
    // Optional neuron-regeneration strategy (continual-learning guard).
    plasticity_guard: Option<NeuronRegeneration>,
    // Per-unit |activation| from the previous step, fed to the guard.
    prev_state_energy: Vec<f64>,
}
impl EchoStateNetwork {
pub fn new(config: ESNConfig) -> Self {
let rls = RecursiveLeastSquares::with_delta(config.forgetting_factor, config.delta);
let plasticity_guard = config.plasticity.as_ref().map(|p| {
NeuronRegeneration::new(
config.n_reservoir,
1, p.regen_fraction,
p.regen_interval,
p.utility_alpha,
config.seed.wrapping_add(0x_DEAD_CAFE),
)
});
let prev_state_energy = vec![0.0; config.n_reservoir];
Self {
reservoir: CycleReservoir::new(
config.n_reservoir,
1, config.spectral_radius,
config.input_scaling,
config.leak_rate,
config.bias_scaling,
config.seed,
),
rls,
total_seen: 0,
samples_trained: 0,
n_inputs: None,
config,
prev_prediction: 0.0,
prev_change: 0.0,
prev_prev_change: 0.0,
alignment_ewma: 0.0,
state_activity_ewma: Vec::new(),
readout_projection: None,
plasticity_guard,
prev_state_energy,
}
}
#[inline]
pub fn past_warmup(&self) -> bool {
self.total_seen > self.config.warmup as u64
}
fn build_readout_features(&self, input: &[f64]) -> Vec<f64> {
let state = self.reservoir.state();
let full_features = if self.config.passthrough_input {
let mut features = Vec::with_capacity(state.len() + input.len());
features.extend_from_slice(state);
features.extend_from_slice(input);
features
} else {
state.to_vec()
};
if let Some(ref proj) = self.readout_projection {
let k = self.config.readout_dim.unwrap();
let n = full_features.len();
let mut projected = vec![0.0; k];
for (i, p_i) in projected.iter_mut().enumerate() {
let row_start = i * n;
let mut sum = 0.0;
for (j, &f) in full_features.iter().enumerate() {
sum += proj[row_start + j] * f;
}
*p_i = sum;
}
projected
} else {
full_features
}
}
pub fn config(&self) -> &ESNConfig {
&self.config
}
pub fn total_seen(&self) -> u64 {
self.total_seen
}
#[inline]
pub fn prediction_uncertainty(&self) -> f64 {
self.rls.noise_variance().sqrt()
}
pub fn reservoir_state(&self) -> &[f64] {
self.reservoir.state()
}
fn ensure_reservoir(&mut self, n_inputs: usize) {
if self.n_inputs.is_none() || self.n_inputs != Some(n_inputs) {
self.n_inputs = Some(n_inputs);
self.reservoir = CycleReservoir::new(
self.config.n_reservoir,
n_inputs,
self.config.spectral_radius,
self.config.input_scaling,
self.config.leak_rate,
self.config.bias_scaling,
self.config.seed,
);
self.state_activity_ewma = vec![0.0; self.config.n_reservoir];
self.init_readout_projection(n_inputs);
}
}
fn init_readout_projection(&mut self, n_inputs: usize) {
if let Some(k) = self.config.readout_dim {
let n_full = if self.config.passthrough_input {
self.config.n_reservoir + n_inputs
} else {
self.config.n_reservoir
};
if k >= n_full {
self.readout_projection = None;
return;
}
let scale = 1.0 / (k as f64).sqrt();
let mut rng = Xorshift64Rng::new(self.config.seed ^ 0xCAFE_BABE);
let proj: Vec<f64> = (0..k * n_full)
.map(|_| if rng.next_f64() < 0.5 { -scale } else { scale })
.collect();
self.readout_projection = Some(proj);
} else {
self.readout_projection = None;
}
}
}
impl StreamingLearner for EchoStateNetwork {
    /// Consumes one `(features, target)` pair. The ordering here is
    /// deliberate and load-bearing:
    /// 1. non-finite inputs are rejected before any counter moves;
    /// 2. past warmup, the readout is trained on the reservoir state left
    ///    by the *previous* sample (plus, with passthrough, the current
    ///    raw input);
    /// 3. only then is the reservoir advanced with the current input;
    /// 4. finally the activity EWMA and plasticity bookkeeping run.
    fn train_one(&mut self, features: &[f64], target: f64, weight: f64) {
        self.ensure_reservoir(features.len());
        if !features.iter().all(|f| f.is_finite()) {
            // Skip NaN/inf samples entirely; counters stay untouched.
            return;
        }
        self.total_seen += 1;
        if self.past_warmup() {
            let readout_features = self.build_readout_features(features);
            if !readout_features.iter().all(|f| f.is_finite()) {
                // Readout features blew up: still advance the reservoir so
                // state keeps evolving, but skip the RLS update.
                self.reservoir.update(features);
                return;
            }
            let current_pred = self.rls.predict(&readout_features);
            let current_change = current_pred - self.prev_prediction;
            if self.samples_trained > 0 {
                // Sign agreement between consecutive prediction
                // "accelerations" (second differences); its EWMA becomes
                // the residual-alignment diagnostic.
                let acceleration = current_change - self.prev_change;
                let prev_acceleration = self.prev_change - self.prev_prev_change;
                let agreement = if acceleration.abs() > 1e-15 && prev_acceleration.abs() > 1e-15 {
                    if (acceleration > 0.0) == (prev_acceleration > 0.0) {
                        1.0
                    } else {
                        -1.0
                    }
                } else {
                    // Either acceleration is ~zero: treat as neutral.
                    0.0
                };
                const ALIGN_ALPHA: f64 = 0.05;
                if self.samples_trained == 1 {
                    // First real data point seeds the EWMA directly.
                    self.alignment_ewma = agreement;
                } else {
                    self.alignment_ewma =
                        (1.0 - ALIGN_ALPHA) * self.alignment_ewma + ALIGN_ALPHA * agreement;
                }
            }
            // Shift the difference history before training.
            self.prev_prev_change = self.prev_change;
            self.prev_change = current_change;
            self.prev_prediction = current_pred;
            self.rls.train_one(&readout_features, target, weight);
            self.samples_trained += 1;
        }
        self.reservoir.update(features);
        // Slow EWMA of per-unit |activation|; feeds the entropy diagnostic.
        const STATE_ALPHA: f64 = 0.01;
        let state = self.reservoir.state();
        for (ewma, &s) in self.state_activity_ewma.iter_mut().zip(state.iter()) {
            *ewma = (1.0 - STATE_ALPHA) * *ewma + STATE_ALPHA * s.abs();
        }
        if let Some(ref mut guard) = self.plasticity_guard {
            // Neuron-regeneration bookkeeping: feed per-unit energy to the
            // guard, then re-initialize any units it chose to regenerate.
            let state = self.reservoir.state();
            let mut unit_energy: Vec<f64> = state.iter().map(|s| s.abs()).collect();
            guard.pre_update(&self.prev_state_energy, &mut unit_energy);
            guard.post_update(&self.prev_state_energy);
            // Deterministic per-step seed for unit re-initialization.
            let mut reinit_rng = self.config.seed.wrapping_add(self.total_seen);
            for j in 0..guard.n_groups() {
                if guard.was_regenerated(j) {
                    self.reservoir.reinitialize_unit(j, &mut reinit_rng);
                }
            }
            self.prev_state_energy = unit_energy;
        }
    }

    /// Predicts from the current reservoir state and `features` (via
    /// passthrough/projection). Returns 0.0 before warmup completes or
    /// before the input width is known. Side-effect free: does not advance
    /// the reservoir or any counters.
    fn predict(&self, features: &[f64]) -> f64 {
        if !self.past_warmup() || self.n_inputs.is_none() {
            return 0.0;
        }
        let readout_features = self.build_readout_features(features);
        self.rls.predict(&readout_features)
    }

    /// Number of samples the readout was actually trained on (post-warmup),
    /// not the total presented.
    #[inline]
    fn n_samples_seen(&self) -> u64 {
        self.samples_trained
    }

    /// Clears all learned state. The readout projection is rebuilt from the
    /// configured seed (if an input width is known), so retraining on the
    /// same stream reproduces identical predictions.
    fn reset(&mut self) {
        self.reservoir.reset();
        self.rls.reset();
        self.total_seen = 0;
        self.samples_trained = 0;
        self.prev_prediction = 0.0;
        self.prev_change = 0.0;
        self.prev_prev_change = 0.0;
        self.alignment_ewma = 0.0;
        self.state_activity_ewma.fill(0.0);
        if let Some(n_inputs) = self.n_inputs {
            self.init_readout_projection(n_inputs);
        }
        if let Some(ref mut guard) = self.plasticity_guard {
            guard.reset();
        }
        self.prev_state_energy.fill(0.0);
    }

    // Deprecated shim: forwards to the `Tunable` implementation below.
    #[allow(deprecated)]
    fn diagnostics_array(&self) -> [f64; 5] {
        <Self as crate::learner::Tunable>::diagnostics_array(self)
    }

    // Deprecated shim: forwards to `HasReadout`, mapping an empty weight
    // slice to `None` to preserve the old Option-returning contract.
    #[allow(deprecated)]
    fn readout_weights(&self) -> Option<&[f64]> {
        let w = <Self as crate::learner::HasReadout>::readout_weights(self);
        if w.is_empty() {
            None
        } else {
            Some(w)
        }
    }

    // Deprecated shim: forwards to the `Tunable` implementation below.
    #[allow(deprecated)]
    fn adjust_config(&mut self, lr_multiplier: f64, lambda_delta: f64) {
        <Self as crate::learner::Tunable>::adjust_config(self, lr_multiplier, lambda_delta);
    }
}
impl crate::learner::Tunable for EchoStateNetwork {
    /// Packs the five diagnostics into a fixed array for the AutoML layer;
    /// all-zero when no diagnostics are available yet.
    fn diagnostics_array(&self) -> [f64; 5] {
        use crate::automl::DiagnosticSource;
        self.config_diagnostics().map_or([0.0; 5], |d| {
            [
                d.residual_alignment,
                d.regularization_sensitivity,
                d.depth_sufficiency,
                d.effective_dof,
                d.uncertainty,
            ]
        })
    }

    /// Scales the spectral radius by `lr_multiplier`, capped at 1.5.
    /// The lambda parameter is ignored for this learner.
    fn adjust_config(&mut self, lr_multiplier: f64, _lambda_delta: f64) {
        let scaled = self.config.spectral_radius * lr_multiplier;
        self.config.spectral_radius = scaled.min(1.5);
    }
}
impl crate::learner::HasReadout for EchoStateNetwork {
    /// Delegates to the RLS weight vector; the deprecated
    /// `StreamingLearner::readout_weights` shim treats an empty slice as
    /// "no weights yet".
    fn readout_weights(&self) -> &[f64] {
        self.rls.weights()
    }
}
/// Convenience alias for callers that prefer the streaming-oriented name.
pub type StreamingESN = EchoStateNetwork;
impl fmt::Debug for EchoStateNetwork {
    /// Compact debug view: key hyper-parameters plus progress counters.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut dbg = f.debug_struct("EchoStateNetwork");
        dbg.field("n_reservoir", &self.config.n_reservoir);
        dbg.field("spectral_radius", &self.config.spectral_radius);
        dbg.field("leak_rate", &self.config.leak_rate);
        dbg.field("warmup", &self.config.warmup);
        dbg.field("total_seen", &self.total_seen);
        dbg.field("samples_trained", &self.samples_trained);
        dbg.field("past_warmup", &self.past_warmup());
        dbg.finish()
    }
}
impl crate::automl::DiagnosticSource for EchoStateNetwork {
    /// Summarizes internal learner state for the AutoML layer. Saturation,
    /// entropy, and depth sufficiency are clamped to [0, 1];
    /// `residual_alignment` is an EWMA in [-1, 1]; `uncertainty` is a
    /// standard deviation (unbounded above).
    fn config_diagnostics(&self) -> Option<crate::automl::ConfigDiagnostics> {
        // How far the RLS precision matrix has shrunk from its initial
        // `delta`-scaled diagonal: 0 = untrained, 1 = fully saturated.
        let rls_saturation = {
            let p = self.rls.p_matrix();
            let d = self.rls.weights().len();
            if d > 0 && self.rls.delta() > 0.0 {
                // Diagonal sum; assumes `p` is a row-major d x d matrix
                // (consistent with the `i * d + i` indexing here) — confirm
                // against the RLS implementation if this changes.
                let trace: f64 = (0..d).map(|i| p[i * d + i]).sum();
                (1.0 - trace / (self.rls.delta() * d as f64)).clamp(0.0, 1.0)
            } else {
                0.0
            }
        };
        // Normalized Shannon entropy of the per-unit activity EWMA:
        // 1 = uniformly active reservoir, 0 = degenerate or no activity.
        let reservoir_entropy = {
            let sum: f64 = self.state_activity_ewma.iter().sum();
            if sum > 1e-15 && self.state_activity_ewma.len() > 1 {
                let n = self.state_activity_ewma.len();
                let ln_n = (n as f64).ln();
                let mut h = 0.0;
                for &a in &self.state_activity_ewma {
                    let p = a / sum;
                    if p > 1e-15 {
                        // Guard against ln(0); tiny masses contribute ~0.
                        h -= p * p.ln();
                    }
                }
                (h / ln_n).clamp(0.0, 1.0)
            } else {
                0.0
            }
        };
        // Equal-weight blend of readout saturation and reservoir usage.
        let depth_sufficiency = 0.5 * rls_saturation + 0.5 * reservoir_entropy;
        let w = self.rls.weights();
        // RMS-style weight magnitude: ||w||_2 / sqrt(len(w)).
        let effective_dof = if !w.is_empty() {
            let sq_sum: f64 = w.iter().map(|wi| wi * wi).sum();
            sq_sum.sqrt() / (w.len() as f64).sqrt()
        } else {
            0.0
        };
        Some(crate::automl::ConfigDiagnostics {
            residual_alignment: self.alignment_ewma,
            regularization_sensitivity: 1.0 - self.config.forgetting_factor,
            depth_sufficiency,
            effective_dof,
            uncertainty: self.prediction_uncertainty(),
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Small ESN used by most tests: 50 units, 10-sample warmup.
    fn default_esn() -> EchoStateNetwork {
        let config = ESNConfig::builder()
            .n_reservoir(50)
            .warmup(10)
            .build()
            .unwrap();
        EchoStateNetwork::new(config)
    }

    // Before any training, predictions are 0.0 and counters are zero.
    #[test]
    fn cold_start_returns_zero() {
        let esn = default_esn();
        assert_eq!(esn.predict(&[1.0]), 0.0);
        assert_eq!(esn.n_samples_seen(), 0);
        assert!(!esn.past_warmup());
    }

    // Warmup is exclusive: training starts on sample warmup + 1.
    #[test]
    fn warmup_period_no_training() {
        let mut esn = default_esn();
        for i in 0..10 {
            esn.train(&[i as f64 * 0.1], 0.0);
        }
        assert!(
            !esn.past_warmup(),
            "10th sample is last warmup sample, not yet past warmup"
        );
        assert_eq!(esn.n_samples_seen(), 0);
        esn.train(&[1.0], 0.0);
        assert!(esn.past_warmup(), "11th sample should be past warmup");
        assert_eq!(esn.n_samples_seen(), 1);
    }

    // total_seen counts all samples; n_samples_seen only post-warmup ones.
    #[test]
    fn trains_after_warmup() {
        let mut esn = default_esn();
        for i in 0..15 {
            esn.train(&[i as f64 * 0.1], 0.0);
        }
        assert_eq!(esn.n_samples_seen(), 5);
        assert_eq!(esn.total_seen(), 15);
    }

    // predict() must not mutate counters (or any learner state).
    #[test]
    fn predict_is_side_effect_free() {
        let mut esn = default_esn();
        for i in 0..20 {
            esn.train(&[i as f64 * 0.1], i as f64);
        }
        let total_before = esn.total_seen();
        let samples_before = esn.n_samples_seen();
        let _ = esn.predict(&[99.0]);
        assert_eq!(
            esn.total_seen(),
            total_before,
            "predict should not increment total_seen",
        );
        assert_eq!(
            esn.n_samples_seen(),
            samples_before,
            "predict should not increment samples_trained",
        );
    }

    // reset() zeroes counters and reservoir activations.
    #[test]
    fn reset_clears_learned_state() {
        let mut esn = default_esn();
        for i in 0..30 {
            esn.train(&[i as f64 * 0.1], i as f64);
        }
        assert!(esn.n_samples_seen() > 0);
        esn.reset();
        assert_eq!(esn.n_samples_seen(), 0);
        assert_eq!(esn.total_seen(), 0);
        assert!(!esn.past_warmup());
        for &s in esn.reservoir_state() {
            assert_eq!(s, 0.0);
        }
    }

    // Two ESNs with identical config/seed must agree bit-for-bit.
    #[test]
    fn deterministic_with_same_seed() {
        let config1 = ESNConfig::builder()
            .n_reservoir(30)
            .warmup(5)
            .seed(42)
            .build()
            .unwrap();
        let config2 = config1.clone();
        let mut esn1 = EchoStateNetwork::new(config1);
        let mut esn2 = EchoStateNetwork::new(config2);
        for i in 0..20 {
            let x = (i as f64 * 0.3).sin();
            let y = (i as f64 * 0.3 + 0.3).sin();
            esn1.train(&[x], y);
            esn2.train(&[x], y);
        }
        let pred1 = esn1.predict(&[0.5]);
        let pred2 = esn2.predict(&[0.5]);
        assert!(
            (pred1 - pred2).abs() < 1e-12,
            "same seed should produce identical predictions: {} vs {}",
            pred1,
            pred2,
        );
    }

    // End-to-end regression sanity check: one-step-ahead sine prediction.
    #[test]
    fn sine_wave_regression() {
        let config = ESNConfig::builder()
            .n_reservoir(100)
            .spectral_radius(0.9)
            .leak_rate(0.3)
            .input_scaling(1.0)
            .warmup(50)
            .forgetting_factor(0.999)
            .seed(42)
            .build()
            .unwrap();
        let mut esn = EchoStateNetwork::new(config);
        let dt = 0.1;
        let n_train = 500;
        for i in 0..n_train {
            let t = i as f64 * dt;
            let x = t.sin();
            let target = (t + dt).sin();
            esn.train(&[x], target);
        }
        let mut total_error = 0.0;
        let n_test = 50;
        for i in n_train..(n_train + n_test) {
            let t = i as f64 * dt;
            let x = t.sin();
            let target = (t + dt).sin();
            esn.train(&[x], target);
            let pred = esn.predict(&[x]);
            total_error += (pred - target).abs();
        }
        let mae = total_error / n_test as f64;
        assert!(mae < 0.5, "ESN sine wave MAE should be < 0.5, got {}", mae,);
    }

    // The ESN must be usable behind a `dyn StreamingLearner`.
    #[test]
    fn trait_object_compatibility() {
        let config = ESNConfig::builder()
            .n_reservoir(20)
            .warmup(5)
            .build()
            .unwrap();
        let esn = EchoStateNetwork::new(config);
        let mut boxed: Box<dyn StreamingLearner> = Box::new(esn);
        for i in 0..20 {
            boxed.train(&[i as f64 * 0.1], i as f64);
        }
        let pred = boxed.predict(&[1.0]);
        assert!(pred.is_finite());
    }

    // Disabling input passthrough still yields finite predictions.
    #[test]
    fn no_passthrough_input() {
        let config = ESNConfig::builder()
            .n_reservoir(30)
            .warmup(5)
            .passthrough_input(false)
            .build()
            .unwrap();
        let mut esn = EchoStateNetwork::new(config);
        for i in 0..30 {
            esn.train(&[i as f64 * 0.1], i as f64);
        }
        let pred = esn.predict(&[1.0]);
        assert!(
            pred.is_finite(),
            "prediction should be finite without passthrough"
        );
    }

    // Feeding input must perturb at least some reservoir units.
    #[test]
    fn reservoir_state_evolves() {
        let mut esn = default_esn();
        esn.train(&[1.0], 0.0);
        esn.train(&[2.0], 0.0);
        let nonzero = esn
            .reservoir_state()
            .iter()
            .filter(|&&s| s.abs() > 1e-15)
            .count();
        assert!(nonzero > 0, "reservoir state should be nonzero after input",);
    }

    // Uncertainty starts at 0 and becomes positive + finite after training.
    #[test]
    fn esn_prediction_uncertainty() {
        let config = ESNConfig::builder()
            .n_reservoir(50)
            .warmup(10)
            .build()
            .unwrap();
        let mut esn = EchoStateNetwork::new(config);
        assert!(
            esn.prediction_uncertainty().abs() < 1e-15,
            "uncertainty should be 0.0 before training, got {}",
            esn.prediction_uncertainty()
        );
        for i in 0..100 {
            let t = i as f64 * 0.1;
            let x = t.sin();
            let target = (t + 0.1).sin();
            esn.train(&[x], target);
        }
        let unc = esn.prediction_uncertainty();
        assert!(
            unc > 0.0,
            "prediction_uncertainty should be > 0 after training, got {}",
            unc
        );
        assert!(
            unc.is_finite(),
            "prediction_uncertainty should be finite, got {}",
            unc
        );
    }

    // With readout_dim set, the RLS should see the projected dimension only.
    #[test]
    #[allow(deprecated)]
    fn readout_projection_reduces_rls_dim() {
        let config = ESNConfig::builder()
            .n_reservoir(100)
            .readout_dim(30)
            .warmup(5)
            .seed(42)
            .build()
            .unwrap();
        assert_eq!(config.readout_dim, Some(30));
        let mut esn = EchoStateNetwork::new(config);
        for i in 0..20 {
            esn.train(&[i as f64 * 0.1], i as f64);
        }
        let weights = esn
            .readout_weights()
            .expect("should have weights after training");
        assert_eq!(
            weights.len(),
            30,
            "RLS should have 30 weights (readout_dim), not {} (full reservoir + input)",
            weights.len(),
        );
    }

    // Same seed => same projection => identical projected predictions.
    #[test]
    fn readout_projection_deterministic() {
        let config = ESNConfig::builder()
            .n_reservoir(100)
            .warmup(5)
            .seed(99)
            .build()
            .unwrap();
        let mut esn1 = EchoStateNetwork::new(config.clone());
        let mut esn2 = EchoStateNetwork::new(config);
        for i in 0..30 {
            let x = (i as f64 * 0.2).sin();
            let y = (i as f64 * 0.2 + 0.2).sin();
            esn1.train(&[x], y);
            esn2.train(&[x], y);
        }
        let pred1 = esn1.predict(&[0.5]);
        let pred2 = esn2.predict(&[0.5]);
        assert!(
            (pred1 - pred2).abs() < 1e-12,
            "projected ESN predictions should be deterministic: {} vs {}",
            pred1,
            pred2,
        );
    }

    // reset() rebuilds the projection from the seed, so retraining on the
    // same data must reproduce the same prediction.
    #[test]
    fn readout_projection_reset_preserves_determinism() {
        let config = ESNConfig::builder()
            .n_reservoir(100)
            .warmup(5)
            .seed(42)
            .build()
            .unwrap();
        let mut esn = EchoStateNetwork::new(config);
        let train_data: Vec<(f64, f64)> = (0..30)
            .map(|i| {
                let x = (i as f64 * 0.2).sin();
                let y = (i as f64 * 0.2 + 0.2).sin();
                (x, y)
            })
            .collect();
        for &(x, y) in &train_data {
            esn.train(&[x], y);
        }
        let pred_before = esn.predict(&[0.5]);
        esn.reset();
        for &(x, y) in &train_data {
            esn.train(&[x], y);
        }
        let pred_after = esn.predict(&[0.5]);
        assert!(
            (pred_before - pred_after).abs() < 1e-12,
            "predictions after reset should match: {} vs {}",
            pred_before,
            pred_after,
        );
    }

    // Small reservoirs default to no projection: RLS sees all features
    // (50 reservoir units + 1 passthrough input = 51).
    #[test]
    #[allow(deprecated)]
    fn small_reservoir_no_projection() {
        let config = ESNConfig::builder()
            .n_reservoir(50)
            .warmup(5)
            .build()
            .unwrap();
        assert_eq!(
            config.readout_dim, None,
            "small reservoir should have no readout_dim",
        );
        let mut esn = EchoStateNetwork::new(config);
        for i in 0..20 {
            esn.train(&[i as f64 * 0.1], i as f64);
        }
        let weights = esn.readout_weights().expect("should have weights");
        assert_eq!(
            weights.len(),
            51,
            "small reservoir RLS should see all 51 features, got {}",
            weights.len(),
        );
    }

    // Large reservoirs auto-default readout_dim (300 -> 64) and must still
    // learn the sine task through the projection.
    #[test]
    fn large_reservoir_sine_wave_with_projection() {
        let config = ESNConfig::builder()
            .n_reservoir(300)
            .spectral_radius(0.9)
            .leak_rate(0.3)
            .input_scaling(1.0)
            .warmup(50)
            .forgetting_factor(0.999)
            .seed(42)
            .build()
            .unwrap();
        assert_eq!(
            config.readout_dim,
            Some(64),
            "n=300 should auto-default readout_dim to 64",
        );
        let mut esn = EchoStateNetwork::new(config);
        let dt = 0.1;
        let n_train = 500;
        for i in 0..n_train {
            let t = i as f64 * dt;
            let x = t.sin();
            let target = (t + dt).sin();
            esn.train(&[x], target);
        }
        let mut total_error = 0.0;
        let n_test = 50;
        for i in n_train..(n_train + n_test) {
            let t = i as f64 * dt;
            let x = t.sin();
            let target = (t + dt).sin();
            esn.train(&[x], target);
            let pred = esn.predict(&[x]);
            total_error += (pred - target).abs();
        }
        let mae = total_error / n_test as f64;
        assert!(
            mae < 0.5,
            "large projected ESN sine MAE should be < 0.5, got {}",
            mae,
        );
    }

    // Plasticity is opt-in: no guard unless configured.
    #[test]
    fn esn_plasticity_disabled_by_default() {
        let config = ESNConfig::builder().n_reservoir(50).build().unwrap();
        assert!(
            config.plasticity.is_none(),
            "plasticity should default to None"
        );
        let esn = EchoStateNetwork::new(config);
        assert!(
            esn.plasticity_guard.is_none(),
            "guard should be None when plasticity is disabled"
        );
    }

    // Enabling plasticity creates one regeneration group per unit.
    #[test]
    fn esn_plasticity_enabled_creates_guard() {
        use crate::common::PlasticityConfig;
        let config = ESNConfig::builder()
            .n_reservoir(50)
            .plasticity(Some(PlasticityConfig::default()))
            .build()
            .unwrap();
        let esn = EchoStateNetwork::new(config);
        assert!(
            esn.plasticity_guard.is_some(),
            "guard should be Some when plasticity is enabled"
        );
        assert_eq!(
            esn.plasticity_guard.as_ref().unwrap().n_groups(),
            50,
            "should have one group per reservoir unit"
        );
    }

    // Long run with plasticity enabled: no panic, finite output.
    #[test]
    fn esn_plasticity_train_runs_without_panic() {
        use crate::common::PlasticityConfig;
        let config = ESNConfig::builder()
            .n_reservoir(30)
            .warmup(10)
            .plasticity(Some(PlasticityConfig::default()))
            .build()
            .unwrap();
        let mut esn = EchoStateNetwork::new(config);
        for i in 0..600 {
            let t = i as f64 * 0.1;
            esn.train(&[t.sin()], (t + 0.1).sin());
        }
        let pred = esn.predict(&[0.5]);
        assert!(
            pred.is_finite(),
            "plasticity-enabled ESN should produce finite predictions, got {pred}"
        );
    }

    // NaN inputs are dropped without advancing training counters.
    #[test]
    fn test_esn_nan_skipped() {
        let config = ESNConfig::builder()
            .n_reservoir(30)
            .warmup(10)
            .build()
            .unwrap();
        let mut esn = EchoStateNetwork::new(config);
        for i in 0..30 {
            let t = i as f64 * 0.1;
            esn.train(&[t.sin()], (t + 0.1).sin());
        }
        let samples_before = esn.n_samples_seen();
        esn.train(&[f64::NAN], 1.0);
        assert_eq!(
            esn.n_samples_seen(),
            samples_before,
            "NaN input should not increment samples_trained: before={}, after={}",
            samples_before,
            esn.n_samples_seen()
        );
        let pred = esn.predict(&[0.5]);
        assert!(
            pred.is_finite(),
            "prediction should be finite after NaN input, got {pred}"
        );
    }

    // The StreamingESN type alias behaves exactly like EchoStateNetwork.
    #[test]
    fn test_esn_streaming_alias() {
        let config = ESNConfig::builder()
            .n_reservoir(30)
            .warmup(5)
            .build()
            .unwrap();
        let mut esn: StreamingESN = StreamingESN::new(config);
        for i in 0..20 {
            esn.train(&[i as f64 * 0.1], i as f64);
        }
        let pred = esn.predict(&[1.0]);
        assert!(
            pred.is_finite(),
            "StreamingESN alias should work, got {pred}"
        );
    }

    // With passthrough enabled, predict() must incorporate the current
    // input (different inputs => different predictions bit-for-bit).
    #[test]
    fn predict_reads_current_input() {
        let config = ESNConfig::builder()
            .n_reservoir(50)
            .warmup(10)
            .passthrough_input(true)
            .seed(42)
            .build()
            .unwrap();
        let mut esn = EchoStateNetwork::new(config);
        for i in 0..100 {
            let t = i as f64 * 0.1;
            esn.train(&[t.sin()], (t + 0.1).sin());
        }
        let pred_a = esn.predict(&[0.0]);
        let pred_b = esn.predict(&[1.0]);
        assert!(
            pred_a.is_finite(),
            "predict(0.0) should be finite, got {pred_a}"
        );
        assert!(
            pred_b.is_finite(),
            "predict(1.0) should be finite, got {pred_b}"
        );
        assert_ne!(
            pred_a.to_bits(),
            pred_b.to_bits(),
            "ESN predict must use current input: predict(0.0)={pred_a} == predict(1.0)={pred_b}"
        );
    }
}