use std::fmt;
use crate::learner::StreamingLearner;
use crate::learners::RecursiveLeastSquares;
use irithyll_core::reservoir::{DelayBuffer, HighDegreePolynomial};
use super::ngrc_config::NGRCConfig;
/// Next-Generation Reservoir Computer: a tapped-delay-line feature map
/// (linear delay taps plus a polynomial expansion) feeding a recursive
/// least-squares readout. Trained one sample at a time via `StreamingLearner`.
pub struct NextGenRC {
    /// Hyperparameters (`k` delay taps, skip `s`, polynomial degree, bias flag, ...).
    config: NGRCConfig,
    /// Ring buffer of past input observations; sized `(k - 1) * s + 1`.
    buffer: DelayBuffer,
    /// Generates the nonlinear (polynomial) feature terms from the linear taps.
    poly: HighDegreePolynomial,
    /// Recursive-least-squares readout trained on the assembled feature vector.
    rls: RecursiveLeastSquares,
    /// Total observations pushed into `buffer` (drives warm-up detection).
    total_pushed: u64,
    /// Number of pushes required before a full feature vector can be built.
    warmup_count: usize,
    /// Input dimensionality, locked in from the first observation's length.
    n_inputs: Option<usize>,
    /// Number of actual RLS training steps performed (post-warm-up only).
    samples_seen: u64,
    // The three fields below track first/second differences of successive
    // predictions so that `alignment_ewma` can measure whether the model's
    // prediction "acceleration" keeps a consistent sign over time.
    prev_prediction: f64,
    prev_change: f64,
    prev_prev_change: f64,
    /// EWMA in [-1, 1] of sign agreement between successive accelerations.
    alignment_ewma: f64,
    /// Per-feature EWMA of absolute activation, used for entropy diagnostics.
    state_activity_ewma: Vec<f64>,
}
impl NextGenRC {
    /// Build a learner from `config`.
    ///
    /// The delay buffer must hold `(k - 1) * s + 1` observations before the
    /// full tapped-delay feature vector can be assembled, so that capacity
    /// doubles as the warm-up threshold.
    pub fn new(config: NGRCConfig) -> Self {
        let capacity = (config.k - 1) * config.s + 1;
        Self {
            buffer: DelayBuffer::new(capacity),
            poly: HighDegreePolynomial::new(config.degree),
            rls: RecursiveLeastSquares::with_delta(config.forgetting_factor, config.delta),
            total_pushed: 0,
            warmup_count: capacity,
            n_inputs: None,
            samples_seen: 0,
            config,
            prev_prediction: 0.0,
            prev_change: 0.0,
            prev_prev_change: 0.0,
            alignment_ewma: 0.0,
            state_activity_ewma: Vec::new(),
        }
    }

    /// `true` once enough observations have been pushed to fill every delay tap.
    #[inline]
    pub fn is_warm(&self) -> bool {
        self.total_pushed >= self.warmup_count as u64
    }

    /// Assemble the feature vector `[bias | linear taps | polynomial terms]`
    /// from the current buffer contents, or `None` while still warming up
    /// (or if any delay tap is missing from the buffer).
    fn build_features(&self) -> Option<Vec<f64>> {
        if !self.is_warm() {
            return None;
        }
        let dim = self.n_inputs?;
        let (k, s) = (self.config.k, self.config.s);

        // Concatenate the k delayed observations, spaced s steps apart.
        let mut linear = Vec::with_capacity(k * dim);
        for tap in 0..k {
            linear.extend_from_slice(self.buffer.get(tap * s)?);
        }

        let nonlinear = self.poly.generate(&linear);
        let bias_dim = usize::from(self.config.include_bias);

        let mut features = Vec::with_capacity(bias_dim + linear.len() + nonlinear.len());
        if self.config.include_bias {
            features.push(1.0);
        }
        features.extend_from_slice(&linear);
        features.extend_from_slice(&nonlinear);
        Some(features)
    }

    /// Borrow the configuration this learner was built with.
    pub fn config(&self) -> &NGRCConfig {
        &self.config
    }

    /// Total number of observations pushed into the delay buffer.
    pub fn total_pushed(&self) -> u64 {
        self.total_pushed
    }
}
impl StreamingLearner for NextGenRC {
    /// Push one observation, update diagnostics, and (once warm) run an RLS step.
    ///
    /// During warm-up the observation is only buffered; `samples_seen` does
    /// not advance until full feature vectors can be built.
    fn train_one(&mut self, features: &[f64], target: f64, weight: f64) {
        // Lock in the input dimensionality from the very first observation.
        if self.n_inputs.is_none() {
            self.n_inputs = Some(features.len());
        }
        self.buffer.push(features);
        self.total_pushed += 1;

        let phi = match self.build_features() {
            Some(v) => v,
            None => return, // still warming up: nothing to learn from yet
        };

        // EWMA of per-feature magnitude, feeding the entropy diagnostic.
        const STATE_ALPHA: f64 = 0.01;
        if self.state_activity_ewma.len() != phi.len() {
            self.state_activity_ewma = vec![0.0; phi.len()];
        }
        for (ewma, &x) in self.state_activity_ewma.iter_mut().zip(phi.iter()) {
            *ewma = (1.0 - STATE_ALPHA) * *ewma + STATE_ALPHA * x.abs();
        }

        // Track whether successive prediction "accelerations" (second
        // differences) keep the same sign; +1 agreement, -1 flip, 0 when
        // either is numerically negligible.
        let pred = self.rls.predict(&phi);
        let change = pred - self.prev_prediction;
        if self.samples_seen > 0 {
            let accel = change - self.prev_change;
            let prev_accel = self.prev_change - self.prev_prev_change;
            let agreement = if accel.abs() > 1e-15 && prev_accel.abs() > 1e-15 {
                if (accel > 0.0) == (prev_accel > 0.0) {
                    1.0
                } else {
                    -1.0
                }
            } else {
                0.0
            };
            const ALIGN_ALPHA: f64 = 0.05;
            self.alignment_ewma =
                (1.0 - ALIGN_ALPHA) * self.alignment_ewma + ALIGN_ALPHA * agreement;
        }
        // Shift the difference history *after* the agreement computation.
        self.prev_prev_change = self.prev_change;
        self.prev_change = change;
        self.prev_prediction = pred;

        self.rls.train_one(&phi, target, weight);
        self.samples_seen += 1;
    }

    /// Predict from the *current* buffer state. The argument is ignored:
    /// the relevant input was already pushed into the buffer by `train_one`.
    /// Returns 0.0 while cold or whenever features cannot be assembled.
    fn predict(&self, _features: &[f64]) -> f64 {
        // build_features() already returns None when cold or before the
        // input dimension is known, so no extra guards are needed here.
        self.build_features()
            .map_or(0.0, |phi| self.rls.predict(&phi))
    }

    #[inline]
    fn n_samples_seen(&self) -> u64 {
        self.samples_seen
    }

    /// Return to the cold-start state, clearing all buffers and statistics.
    fn reset(&mut self) {
        self.buffer.reset();
        self.rls.reset();
        self.total_pushed = 0;
        self.n_inputs = None;
        self.samples_seen = 0;
        self.prev_prediction = 0.0;
        self.prev_change = 0.0;
        self.prev_prev_change = 0.0;
        self.alignment_ewma = 0.0;
        self.state_activity_ewma.clear();
    }

    // Deprecated trait-level accessors delegate to the canonical impls below.
    #[allow(deprecated)]
    fn diagnostics_array(&self) -> [f64; 5] {
        <Self as crate::learner::Tunable>::diagnostics_array(self)
    }

    #[allow(deprecated)]
    fn readout_weights(&self) -> Option<&[f64]> {
        let w = <Self as crate::learner::HasReadout>::readout_weights(self);
        (!w.is_empty()).then_some(w)
    }
}
impl crate::learner::Tunable for NextGenRC {
    /// Pack the five configuration diagnostics into a fixed array,
    /// falling back to all zeros when diagnostics are unavailable.
    fn diagnostics_array(&self) -> [f64; 5] {
        use crate::automl::DiagnosticSource;
        self.config_diagnostics().map_or([0.0; 5], |d| {
            [
                d.residual_alignment,
                d.regularization_sensitivity,
                d.depth_sufficiency,
                d.effective_dof,
                d.uncertainty,
            ]
        })
    }

    /// Forward tuning adjustments to the inner RLS readout. The lambda
    /// delta is not applicable at this level and is always passed as zero.
    fn adjust_config(&mut self, lr_multiplier: f64, _lambda_delta: f64) {
        <crate::learners::RecursiveLeastSquares as crate::learner::Tunable>::adjust_config(
            &mut self.rls,
            lr_multiplier,
            0.0,
        );
    }
}
impl crate::learner::HasReadout for NextGenRC {
    /// Borrow the readout weights from the inner RLS. The slice may be
    /// empty before any training step has occurred (callers that need an
    /// `Option` can use the `StreamingLearner` accessor instead).
    fn readout_weights(&self) -> &[f64] {
        self.rls.weights()
    }
}
impl fmt::Debug for NextGenRC {
    // Hand-written (rather than derived) so the output is a compact summary:
    // key hyperparameters and counters only, omitting the buffer contents,
    // RLS state, and diagnostic EWMAs.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("NextGenRC")
            .field("k", &self.config.k)
            .field("s", &self.config.s)
            .field("degree", &self.config.degree)
            .field("warmup_count", &self.warmup_count)
            .field("total_pushed", &self.total_pushed)
            .field("samples_seen", &self.samples_seen)
            .field("is_warm", &self.is_warm())
            .finish()
    }
}
impl crate::automl::DiagnosticSource for NextGenRC {
    /// Summarize the learner's internal state as configuration diagnostics.
    /// Always returns `Some`; individual fields degrade to 0.0 when their
    /// inputs are unavailable (e.g. before training).
    fn config_diagnostics(&self) -> Option<crate::automl::ConfigDiagnostics> {
        // How far the RLS covariance trace has contracted from its initial
        // value of delta * I (0 = untouched, 1 = fully saturated).
        let p = self.rls.p_matrix();
        let dim = self.rls.weights().len();
        let rls_saturation = if dim > 0 && self.rls.delta() > 0.0 {
            let trace: f64 = (0..dim).map(|i| p[i * dim + i]).sum();
            (1.0 - trace / (self.rls.delta() * dim as f64)).clamp(0.0, 1.0)
        } else {
            0.0
        };

        // Normalized Shannon entropy of the per-feature activity EWMAs:
        // 1.0 means all features are equally active, 0.0 means degenerate.
        let activity_sum: f64 = self.state_activity_ewma.iter().sum();
        let feature_entropy = if activity_sum > 1e-15 && self.state_activity_ewma.len() > 1 {
            let ln_n = (self.state_activity_ewma.len() as f64).ln();
            let entropy: f64 = self
                .state_activity_ewma
                .iter()
                .map(|&a| a / activity_sum)
                .filter(|&p_i| p_i > 1e-15)
                .map(|p_i| -p_i * p_i.ln())
                .sum();
            (entropy / ln_n).clamp(0.0, 1.0)
        } else {
            0.0
        };

        let depth_sufficiency = 0.5 * rls_saturation + 0.5 * feature_entropy;

        // RMS weight magnitude as a crude effective-degrees-of-freedom proxy.
        let w = self.rls.weights();
        let effective_dof = if w.is_empty() {
            0.0
        } else {
            let sq_sum: f64 = w.iter().map(|wi| wi * wi).sum();
            sq_sum.sqrt() / (w.len() as f64).sqrt()
        };

        Some(crate::automl::ConfigDiagnostics {
            residual_alignment: self.alignment_ewma,
            regularization_sensitivity: 1.0 - self.config.forgetting_factor,
            depth_sufficiency,
            effective_dof,
            uncertainty: self.rls.noise_variance().sqrt(),
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Learner built from the default config (used by most tests below).
    fn default_ngrc() -> NextGenRC {
        let config = NGRCConfig::builder().build().unwrap();
        NextGenRC::new(config)
    }

    // Before any training, predict() must return 0.0 and no samples counted.
    #[test]
    fn cold_start_returns_zero() {
        let ngrc = default_ngrc();
        assert_eq!(ngrc.predict(&[1.0]), 0.0);
        assert_eq!(ngrc.n_samples_seen(), 0);
        assert!(!ngrc.is_warm());
    }

    // samples_seen only advances once the delay buffer is full (warm).
    // With the default config this appears to happen after 2 pushes.
    #[test]
    fn warmup_then_trains() {
        let mut ngrc = default_ngrc();
        ngrc.train(&[1.0], 2.0);
        assert!(!ngrc.is_warm());
        assert_eq!(ngrc.n_samples_seen(), 0);
        ngrc.train(&[2.0], 3.0);
        assert!(ngrc.is_warm());
        assert_eq!(ngrc.n_samples_seen(), 1);
    }

    // predict() must be a pure read: no buffer pushes, no sample counting.
    #[test]
    fn predict_is_side_effect_free() {
        let mut ngrc = default_ngrc();
        ngrc.train(&[1.0], 2.0);
        ngrc.train(&[2.0], 3.0);
        ngrc.train(&[3.0], 4.0);
        let samples_before = ngrc.n_samples_seen();
        let pushed_before = ngrc.total_pushed();
        let _ = ngrc.predict(&[99.0]);
        assert_eq!(
            ngrc.n_samples_seen(),
            samples_before,
            "predict should not change samples_seen",
        );
        assert_eq!(
            ngrc.total_pushed(),
            pushed_before,
            "predict should not push to buffer",
        );
    }

    // reset() must restore the exact cold-start contract checked above.
    #[test]
    fn reset_returns_to_cold_state() {
        let mut ngrc = default_ngrc();
        for i in 0..20 {
            ngrc.train(&[i as f64], (i + 1) as f64);
        }
        assert!(ngrc.is_warm());
        assert!(ngrc.n_samples_seen() > 0);
        ngrc.reset();
        assert!(!ngrc.is_warm());
        assert_eq!(ngrc.n_samples_seen(), 0);
        assert_eq!(ngrc.total_pushed(), 0);
        assert_eq!(ngrc.predict(&[1.0]), 0.0);
    }

    // End-to-end regression check: with forgetting_factor = 1.0 the RLS
    // readout should fit y = 3x + 1 to within a loose tolerance.
    #[test]
    fn learns_linear_trend() {
        let config = NGRCConfig::builder()
            .k(2)
            .s(1)
            .degree(2)
            .forgetting_factor(1.0)
            .build()
            .unwrap();
        let mut ngrc = NextGenRC::new(config);
        for i in 0..300 {
            let x = i as f64 * 0.01;
            let y = 3.0 * x + 1.0;
            ngrc.train(&[x], y);
        }
        // Push the test point first: predict() reads the current buffer,
        // not its argument.
        let x_test = 3.0;
        ngrc.train(&[x_test], 3.0 * x_test + 1.0);
        let pred = ngrc.predict(&[x_test]);
        let expected = 3.0 * x_test + 1.0;
        assert!(
            (pred - expected).abs() < 1.0,
            "expected ~{}, got {} (error = {})",
            expected,
            pred,
            (pred - expected).abs(),
        );
    }

    // One-step-ahead sine prediction: train on sin(t) -> sin(t + dt),
    // then check mean absolute error over a held-out continuation window.
    #[test]
    fn sine_wave_regression() {
        let config = NGRCConfig::builder()
            .k(3)
            .s(1)
            .degree(2)
            .forgetting_factor(0.999)
            .delta(100.0)
            .build()
            .unwrap();
        let mut ngrc = NextGenRC::new(config);
        let dt = 0.1;
        let n_train = 500;
        for i in 0..n_train {
            let t = i as f64 * dt;
            let x = (t).sin();
            let target = (t + dt).sin();
            ngrc.train(&[x], target);
        }
        let mut total_error = 0.0;
        let n_test = 50;
        for i in n_train..(n_train + n_test) {
            let t = i as f64 * dt;
            let x = (t).sin();
            let target = (t + dt).sin();
            // Keep training during evaluation so the buffer tracks the signal.
            ngrc.train(&[x], target);
            let pred = ngrc.predict(&[x]);
            total_error += (pred - target).abs();
        }
        let mae = total_error / n_test as f64;
        assert!(mae < 0.5, "sine wave MAE should be < 0.5, got {}", mae,);
    }

    // NextGenRC must be usable behind Box<dyn StreamingLearner>.
    #[test]
    fn trait_object_compatibility() {
        let config = NGRCConfig::builder().build().unwrap();
        let ngrc = NextGenRC::new(config);
        let mut boxed: Box<dyn StreamingLearner> = Box::new(ngrc);
        boxed.train(&[1.0], 2.0);
        boxed.train(&[2.0], 3.0);
        boxed.train(&[3.0], 4.0);
        let pred = boxed.predict(&[4.0]);
        assert!(pred.is_finite());
    }

    // Warm-up threshold with skip s > 1: k=2, s=3 needs (2-1)*3 + 1 = 4 pushes.
    #[test]
    fn skip_greater_than_one() {
        let config = NGRCConfig::builder().k(2).s(3).degree(2).build().unwrap();
        let mut ngrc = NextGenRC::new(config);
        for i in 0..3 {
            ngrc.train(&[i as f64], 0.0);
            assert!(!ngrc.is_warm(), "should not be warm after {} pushes", i + 1);
        }
        ngrc.train(&[3.0], 0.0);
        assert!(
            ngrc.is_warm(),
            "should be warm after 4 pushes with k=2, s=3"
        );
    }

    // Smoke test: the bias-free feature layout must still produce finite output.
    #[test]
    fn no_bias_config() {
        let config = NGRCConfig::builder().include_bias(false).build().unwrap();
        let mut ngrc = NextGenRC::new(config);
        for i in 0..10 {
            ngrc.train(&[i as f64], (i * 2) as f64);
        }
        let pred = ngrc.predict(&[5.0]);
        assert!(pred.is_finite());
    }

    // predict() ignores its argument and reads the delay buffer, so pushing
    // a new observation (via train) must change the prediction bit-for-bit.
    #[test]
    fn predict_reads_current_input() {
        let config = NGRCConfig::builder()
            .k(2)
            .s(1)
            .degree(2)
            .forgetting_factor(0.999)
            .build()
            .unwrap();
        let mut ngrc = NextGenRC::new(config);
        for i in 0..100 {
            let t = i as f64 * 0.1;
            ngrc.train(&[t.sin()], (t + 0.1).sin());
        }
        let pred_initial = ngrc.predict(&[]);
        ngrc.train(&[99.0_f64.sin()], 0.0);
        let pred_after = ngrc.predict(&[]);
        assert!(
            pred_initial.is_finite(),
            "initial predict should be finite, got {pred_initial}"
        );
        assert!(
            pred_after.is_finite(),
            "post-train predict should be finite, got {pred_after}"
        );
        assert_ne!(
            pred_initial.to_bits(),
            pred_after.to_bits(),
            "NGRC predict must reflect current buffer state: {pred_initial} == {pred_after}"
        );
    }
}