// irithyll/lstm/mod.rs
1//! Streaming sLSTM (stabilized LSTM) with exponential gating.
2//!
3//! sLSTM (Beck et al., 2024 -- xLSTM) replaces sigmoid gates with exponential
4//! gates and adds log-domain stabilization for numerically stable long-range
5//! memory. The output gate remains sigmoid. A normalizer state tracks
6//! cumulative gate products to prevent unbounded cell growth.
7//!
8//! # Architecture
9//!
10//! ```text
11//! x_t -> [sLSTM Cell: exp gates -> log stabilizer -> cell update] -> h_t -> [RLS Readout] -> y_hat_t
12//! ```
13//!
14//! # References
15//!
16//! - Beck et al. (2024) "xLSTM: Extended Long Short-Term Memory" NeurIPS
17
18use crate::common::PlasticityConfig;
19use crate::error::ConfigError;
20use crate::learner::StreamingLearner;
21use crate::learners::RecursiveLeastSquares;
22use irithyll_core::continual::{ContinualStrategy, NeuronRegeneration};
23
24// ---------------------------------------------------------------------------
25// SLSTMConfig
26// ---------------------------------------------------------------------------
27
28/// Configuration for [`StreamingLSTM`].
29///
30/// Create via the builder pattern:
31///
32/// ```
33/// use irithyll::lstm::SLSTMConfig;
34///
35/// let config = SLSTMConfig::builder()
36///     .d_model(32)
37///     .build()
38///     .unwrap();
39/// ```
#[derive(Debug, Clone)]
pub struct SLSTMConfig {
    /// Hidden state dimension (default: 32). Must be > 0 (validated by the builder).
    pub d_model: usize,
    /// RLS forgetting factor for readout (default: 0.998). Must lie in (0, 1].
    pub forgetting_factor: f64,
    /// Initial P matrix diagonal for RLS (default: 100.0). Must be > 0.
    pub delta_rls: f64,
    /// Warmup samples before RLS training starts (default: 10).
    pub warmup: usize,
    /// RNG seed (default: 42).
    pub seed: u64,
    /// Number of heads for block-diagonal recurrent weights (default: 1 = dense).
    ///
    /// When `> 1`, the SLOTS mechanism from Beck et al. (2024) xLSTM §2.2 is used:
    /// each head only mixes within its `d_model / n_heads` units while the input
    /// projection remains dense. Must divide `d_model` (validated by the builder).
    pub n_heads: usize,
    /// Per-unit forget gate bias initializer (default: all 1.0).
    ///
    /// Beck et al. (2024) §3.2 recommend `linspace(3, 6)` across `d_model` units.
    /// Use [`irithyll_core::lstm::SLSTMCell::forget_bias_linspace`] to construct
    /// this vector, or pass `None` to use the default (all 1.0).
    ///
    /// When `Some`, the length must equal `d_model` (validated by the builder);
    /// when `None`, all units default to a bias of 1.0.
    pub forget_bias_init: Option<Vec<f64>>,
    /// Optional plasticity configuration for neuron regeneration (default: None).
    ///
    /// When `Some`, tracks per-hidden-unit state energy and periodically
    /// reinitializes dead units to maintain learning capacity over long
    /// streams (Dohare et al., Nature 2024). Use [`PlasticityConfig::default()`]
    /// for paper-recommended defaults.
    pub plasticity: Option<PlasticityConfig>,
}
74
75impl Default for SLSTMConfig {
76    fn default() -> Self {
77        Self {
78            d_model: 32,
79            forgetting_factor: 0.998,
80            delta_rls: 100.0,
81            warmup: 10,
82            seed: 42,
83            n_heads: 1,
84            forget_bias_init: None,
85            plasticity: None,
86        }
87    }
88}
89
90impl std::fmt::Display for SLSTMConfig {
91    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
92        write!(
93            f,
94            "SLSTMConfig(d_model={}, n_heads={}, ff={}, delta_rls={}, warmup={}, seed={}, plasticity={})",
95            self.d_model,
96            self.n_heads,
97            self.forgetting_factor,
98            self.delta_rls,
99            self.warmup,
100            self.seed,
101            self.plasticity.is_some()
102        )
103    }
104}
105
106// ---------------------------------------------------------------------------
107// SLSTMConfigBuilder
108// ---------------------------------------------------------------------------
109
110/// Builder for [`SLSTMConfig`] with validation.
111///
112/// # Example
113///
114/// ```
115/// use irithyll::lstm::SLSTMConfig;
116///
117/// let config = SLSTMConfig::builder()
118///     .d_model(16)
119///     .forgetting_factor(0.995)
120///     .build()
121///     .unwrap();
122///
123/// assert_eq!(config.d_model, 16);
124/// ```
pub struct SLSTMConfigBuilder {
    /// Pending configuration; validated and returned by `build()`.
    config: SLSTMConfig,
}
128
129impl SLSTMConfig {
130    /// Create a new builder with default values.
131    pub fn builder() -> SLSTMConfigBuilder {
132        SLSTMConfigBuilder {
133            config: SLSTMConfig::default(),
134        }
135    }
136}
137
138impl SLSTMConfigBuilder {
139    /// Set the hidden state dimension (default: 32).
140    pub fn d_model(mut self, d: usize) -> Self {
141        self.config.d_model = d;
142        self
143    }
144
145    /// Set the RLS forgetting factor for the readout (default: 0.998).
146    pub fn forgetting_factor(mut self, f: f64) -> Self {
147        self.config.forgetting_factor = f;
148        self
149    }
150
151    /// Set the initial P matrix diagonal for RLS (default: 100.0).
152    pub fn delta_rls(mut self, d: f64) -> Self {
153        self.config.delta_rls = d;
154        self
155    }
156
157    /// Set the warmup period in samples (default: 10).
158    pub fn warmup(mut self, w: usize) -> Self {
159        self.config.warmup = w;
160        self
161    }
162
163    /// Set the RNG seed (default: 42).
164    pub fn seed(mut self, s: u64) -> Self {
165        self.config.seed = s;
166        self
167    }
168
169    /// Set the number of heads for block-diagonal recurrent weights (default: 1 = dense).
170    ///
171    /// When `> 1`, the SLOTS mechanism from Beck et al. (2024) xLSTM §2.2 is used.
172    /// Must divide `d_model` — the builder validates this at [`build`](Self::build).
173    pub fn n_heads(mut self, n: usize) -> Self {
174        self.config.n_heads = n;
175        self
176    }
177
178    /// Set the per-unit forget gate bias initializer.
179    ///
180    /// Beck et al. (2024) §3.2 recommend `linspace(3, 6)`. Use
181    /// [`irithyll_core::lstm::SLSTMCell::forget_bias_linspace(3.0, 6.0, d_model)`]
182    /// to construct the vector, then pass it here.
183    ///
184    /// When `None` (default), all units are initialized to 1.0.
185    pub fn forget_bias_init(mut self, bias: Option<Vec<f64>>) -> Self {
186        self.config.forget_bias_init = bias;
187        self
188    }
189
190    /// Set the plasticity configuration (default: None = disabled).
191    ///
192    /// When `Some`, tracks per-hidden-unit state energy and periodically
193    /// reinitializes dead units to maintain learning capacity over long
194    /// streams (Dohare et al., Nature 2024). Use [`PlasticityConfig::default()`]
195    /// for paper-recommended defaults.
196    pub fn plasticity(mut self, p: Option<PlasticityConfig>) -> Self {
197        self.config.plasticity = p;
198        self
199    }
200
201    /// Build the config, validating all parameters.
202    ///
203    /// # Errors
204    ///
205    /// Returns [`ConfigError`] if:
206    /// - `d_model` is 0
207    /// - `forgetting_factor` is not in (0, 1]
208    /// - `delta_rls` is not > 0
209    /// - `n_heads` is 0 or does not divide `d_model`
210    /// - `forget_bias_init` length (when `Some`) does not equal `d_model`
211    pub fn build(self) -> Result<SLSTMConfig, ConfigError> {
212        let c = &self.config;
213        if c.d_model == 0 {
214            return Err(ConfigError::out_of_range(
215                "d_model",
216                "must be > 0",
217                c.d_model,
218            ));
219        }
220        if c.forgetting_factor <= 0.0 || c.forgetting_factor > 1.0 {
221            return Err(ConfigError::out_of_range(
222                "forgetting_factor",
223                "must be in (0, 1]",
224                c.forgetting_factor,
225            ));
226        }
227        if c.delta_rls <= 0.0 {
228            return Err(ConfigError::out_of_range(
229                "delta_rls",
230                "must be > 0",
231                c.delta_rls,
232            ));
233        }
234        if c.n_heads == 0 {
235            return Err(ConfigError::out_of_range(
236                "n_heads",
237                "must be > 0",
238                c.n_heads,
239            ));
240        }
241        if c.d_model % c.n_heads != 0 {
242            return Err(ConfigError::invalid(
243                "n_heads",
244                format!("must divide d_model ({}), got {}", c.d_model, c.n_heads),
245            ));
246        }
247        if let Some(ref bias) = c.forget_bias_init {
248            if bias.len() != c.d_model {
249                return Err(ConfigError::invalid(
250                    "forget_bias_init",
251                    format!(
252                        "length must equal d_model ({}), got {}",
253                        c.d_model,
254                        bias.len()
255                    ),
256                ));
257            }
258        }
259        Ok(self.config)
260    }
261}
262
263// ---------------------------------------------------------------------------
264// StreamingLSTM
265// ---------------------------------------------------------------------------
266
267/// Streaming sLSTM model with RLS readout.
268///
269/// Processes one sample at a time. The sLSTM cell uses exponential gating
270/// with log-domain stabilization for numerically stable long-range memory.
271/// An RLS readout maps the cell hidden state to predictions.
272///
273/// # Example
274///
275/// ```no_run
276/// use irithyll::lstm::{StreamingLSTM, SLSTMConfig};
277/// use irithyll::StreamingLearner;
278///
279/// let config = SLSTMConfig::builder().d_model(16).build().unwrap();
280/// let mut model = StreamingLSTM::new(config);
281/// model.train(&[1.0, 2.0, 3.0], 4.0);
282/// let pred = model.predict(&[1.0, 2.0, 3.0]);
283/// ```
pub struct StreamingLSTM {
    /// Configuration this model was built from.
    config: SLSTMConfig,
    /// The stabilized-LSTM recurrent cell (exponential gating, log stabilizer).
    cell: irithyll_core::lstm::SLSTMCell,
    /// Linear readout trained online via recursive least squares.
    readout: RecursiveLeastSquares,
    /// Cached post-update cell output from the most recent training step;
    /// used for diagnostics only — `predict()` does not read it.
    last_features: Vec<f64>,
    /// Total finite samples processed by `train_one` (non-finite inputs skipped).
    total_seen: u64,
    /// Samples the RLS readout actually trained on (post-warmup only).
    samples_trained: u64,
    /// EWMA of prediction uncertainty for forgetting factor modulation.
    rolling_uncertainty: f64,
    /// Fast-reacting EWMA of squared error for drift detection (alpha=0.1).
    short_term_error: f64,
    /// Previous prediction for residual alignment tracking.
    prev_prediction: f64,
    /// EWMA of maximum Frobenius squared norm of cell output for utilization ratio.
    max_frob_sq_ewma: f64,
    /// EWMA of residual alignment signal.
    alignment_ewma: f64,
    /// Previous prediction change for residual alignment tracking.
    prev_change: f64,
    /// Change from two steps ago, for acceleration-based alignment.
    prev_prev_change: f64,
    /// Optional plasticity guard for maintaining learning capacity.
    ///
    /// Tracks per-hidden-unit state energy and reinitializes dead units
    /// to prevent loss of plasticity over long streams.
    plasticity_guard: Option<NeuronRegeneration>,
    /// Snapshot of hidden state from previous step, used for plasticity
    /// utility computation (delta |h|).
    prev_h_energy: Vec<f64>,
    /// Welford online mean for input normalization (per feature).
    input_mean: Vec<f64>,
    /// Welford online variance accumulator for input normalization (per feature).
    input_var: Vec<f64>,
    /// Count of samples seen for Welford normalization.
    input_count: u64,
}
320
321impl StreamingLSTM {
322    /// Create a new StreamingLSTM from config.
323    pub fn new(config: SLSTMConfig) -> Self {
324        // Construct the sLSTM cell, wiring n_heads and forget_bias_init from config.
325        // When n_heads == 1 and forget_bias_init is None, this is equivalent to
326        // SLSTMCell::new(d_model, seed) — the single-head dense default.
327        let cell = if config.n_heads > 1 || config.forget_bias_init.is_some() {
328            let bias = config
329                .forget_bias_init
330                .clone()
331                .unwrap_or_else(|| vec![1.0; config.d_model]);
332            irithyll_core::lstm::SLSTMCell::with_config(
333                config.d_model,
334                config.n_heads,
335                bias,
336                config.seed,
337            )
338        } else {
339            irithyll_core::lstm::SLSTMCell::new(config.d_model, config.seed)
340        };
341        let readout = RecursiveLeastSquares::with_delta(config.forgetting_factor, config.delta_rls);
342        let last_features = vec![0.0; config.d_model];
343
344        // Create plasticity guard if a PlasticityConfig was provided.
345        // Tracks d_model hidden units (group_size=1 = per-unit tracking).
346        let plasticity_guard = config.plasticity.as_ref().map(|p| {
347            NeuronRegeneration::new(
348                config.d_model,
349                1, // group_size = 1 (per-unit tracking)
350                p.regen_fraction,
351                p.regen_interval,
352                p.utility_alpha,
353                config.seed.wrapping_add(0x_DEAD_CAFE),
354            )
355        });
356        let prev_h_energy = vec![0.0; config.d_model];
357
358        Self {
359            config,
360            cell,
361            readout,
362            last_features,
363            total_seen: 0,
364            samples_trained: 0,
365            rolling_uncertainty: 0.0,
366            short_term_error: 0.0,
367            prev_prediction: 0.0,
368            max_frob_sq_ewma: 0.0,
369            alignment_ewma: 0.0,
370            prev_change: 0.0,
371            prev_prev_change: 0.0,
372            plasticity_guard,
373            prev_h_energy,
374            input_mean: Vec::new(),
375            input_var: Vec::new(),
376            input_count: 0,
377        }
378    }
379
380    /// Normalize a feature vector via Welford online mean/std, updating stats.
381    ///
382    /// Returns the normalized features clamped to [-5, 5].
383    /// On the first call the dimension is inferred from `features.len()`.
384    fn normalize_input(&mut self, features: &[f64]) -> Vec<f64> {
385        let d = features.len();
386        if self.input_mean.len() != d {
387            self.input_mean = vec![0.0; d];
388            self.input_var = vec![0.0; d];
389        }
390        self.input_count += 1;
391        let n = self.input_count as f64;
392        let mut out = vec![0.0; d];
393        for i in 0..d {
394            let x = features[i];
395            let delta = x - self.input_mean[i];
396            self.input_mean[i] += delta / n;
397            let delta2 = x - self.input_mean[i];
398            self.input_var[i] += delta * delta2;
399            let std = if n > 1.0 {
400                (self.input_var[i] / (n - 1.0)).sqrt()
401            } else {
402                1.0
403            };
404            let std = if std < 1e-8 { 1.0 } else { std };
405            out[i] = ((x - self.input_mean[i]) / std).clamp(-5.0, 5.0);
406        }
407        out
408    }
409
410    /// Whether the model has seen enough samples for meaningful predictions.
411    #[inline]
412    pub fn past_warmup(&self) -> bool {
413        self.total_seen > self.config.warmup as u64
414    }
415
416    /// Access the config.
417    pub fn config(&self) -> &SLSTMConfig {
418        &self.config
419    }
420
421    /// Forward-looking prediction uncertainty from the RLS readout.
422    ///
423    /// Returns the estimated prediction standard deviation, computed as the
424    /// square root of the RLS noise variance (EWMA of squared residuals).
425    ///
426    /// Returns 0.0 before any training has occurred.
427    #[inline]
428    pub fn prediction_uncertainty(&self) -> f64 {
429        self.readout.noise_variance().sqrt()
430    }
431}
432
impl StreamingLearner for StreamingLSTM {
    /// Ingest one (features, target, weight) sample: adapt the RLS forgetting
    /// factor from uncertainty, update drift/alignment diagnostics, advance
    /// the sLSTM cell on the normalized input, and train the readout on
    /// PRE-update cell features (post-warmup only). The statement ordering is
    /// deliberate — see the numbered sections and "Option D" notes below.
    fn train_one(&mut self, features: &[f64], target: f64, weight: f64) {
        // 1. Uncertainty-modulated RLS forgetting factor
        let current_uncertainty = self.readout.noise_variance().sqrt();
        const UNCERTAINTY_ALPHA: f64 = 0.001;
        if self.total_seen == 0 {
            self.rolling_uncertainty = current_uncertainty;
        } else {
            self.rolling_uncertainty = (1.0 - UNCERTAINTY_ALPHA) * self.rolling_uncertainty
                + UNCERTAINTY_ALPHA * current_uncertainty;
        }

        if self.rolling_uncertainty > 1e-10 {
            // Lower the forgetting factor (floor 0.95) when current uncertainty
            // spikes above its slow EWMA, so the readout adapts faster under drift.
            let ratio = (current_uncertainty / self.rolling_uncertainty).clamp(0.5, 3.0);
            let base_ff = self.config.forgetting_factor;
            let adaptive_ff = (base_ff - 0.02 * (ratio - 1.0)).clamp(0.95, base_ff);
            self.readout.set_forgetting_factor(adaptive_ff);
        }

        // 2. Residual alignment tracking (only after warmup)
        if self.past_warmup() {
            let current_pred = self.readout.predict(&self.last_features);
            let pred_error = target - current_pred;

            // Short-term error tracking for drift
            let sq_err = pred_error * pred_error;
            if self.samples_trained == 0 {
                self.short_term_error = sq_err;
            } else {
                self.short_term_error = 0.9 * self.short_term_error + 0.1 * sq_err;
            }
            // Note: cell.reset() on drift detection is intentionally omitted.
            // Resetting the cell mid-stream destroys the feature distribution that
            // the RLS readout was trained on, causing catastrophic prediction errors
            // that cascade into further resets. Input normalization stabilizes the
            // feature scale at the source instead.
            let _short_rmse = self.short_term_error.sqrt();

            // Alignment tracking: sign agreement between consecutive prediction
            // "accelerations" (second differences), smoothed into an EWMA.
            let current_change = current_pred - self.prev_prediction;
            if self.samples_trained > 0 {
                let acceleration = current_change - self.prev_change;
                let prev_acceleration = self.prev_change - self.prev_prev_change;
                let agreement = if acceleration.abs() > 1e-15 && prev_acceleration.abs() > 1e-15 {
                    if (acceleration > 0.0) == (prev_acceleration > 0.0) {
                        1.0
                    } else {
                        -1.0
                    }
                } else {
                    0.0
                };
                if self.samples_trained == 1 {
                    self.alignment_ewma = agreement;
                } else {
                    self.alignment_ewma = 0.95 * self.alignment_ewma + 0.05 * agreement;
                }
            }
            self.prev_prev_change = self.prev_change;
            self.prev_change = current_change;
            self.prev_prediction = current_pred;
        }

        // 3. Guard: skip non-finite inputs to prevent NaN from entering the cell state.
        if !features.iter().all(|f| f.is_finite()) {
            return;
        }

        // 4. Normalize input via Welford online mean/std before feeding to cell.
        //    Raw feature scale can vary wildly (e.g. Friedman: 0-100, Lorenz: -20 to +20).
        //    Without normalization, the cell receives large pre-activations that make
        //    the hidden state erratic and cause RLS weight explosion.
        let normalized = self.normalize_input(features);

        // Option D step 1: compute readout features from PRE-update cell state.
        // forward_predict uses the current cell state (h_{t-1}) without mutating it.
        // This produces the same hidden state that predict() will query, making
        // train and predict use the exact same feature distribution.
        // Only possible once the cell is initialized (after first forward call = total_seen > 0).
        let pre_cell_features: Option<Vec<f64>> = if self.total_seen > 0 {
            let mut out = self.cell.forward_predict(&normalized);
            for v in &mut out {
                *v = v.clamp(-3.0, 3.0);
            }
            Some(out)
        } else {
            None
        };

        // Advance total_seen before the warmup check so the warmup boundary is consistent
        // with the original (warmup samples = samples where total_seen <= warmup, same count).
        // Option D step 3: advance sLSTM cell state by processing x_t.
        //    Clone immediately to release the borrow on self.cell.
        let mut cell_output = self.cell.forward(&normalized).to_vec();
        self.total_seen += 1;

        // Option D step 2: train RLS on pre-update features (before caching new state).
        // Train only after warmup and when pre-update features are available.
        // Note: total_seen was just incremented, so past_warmup() sees the same boundary
        // as the original implementation's post-forward check.
        if self.past_warmup() {
            if let Some(ref feats) = pre_cell_features {
                if feats.iter().all(|f| f.is_finite()) {
                    self.readout.train_one(feats, target, weight);
                    self.samples_trained += 1;
                }
            }
        }

        // Clamp cell output to [-3, 3] as a safety net.
        //    The sLSTM cell is theoretically bounded (o * c/n where o=sigmoid, c/n ~ EMA of tanh),
        //    but under extreme inputs the normalizer state can briefly allow larger values.
        //    Clamping ensures the RLS readout always sees a stable, bounded feature space.
        for v in &mut cell_output {
            *v = v.clamp(-3.0, 3.0);
        }

        // Track output utilization (post-update state for Frobenius ratio diagnostics).
        // The EWMA decays slowly unless a new peak is observed, so the "max"
        // tracks a recent rather than an all-time high.
        let frob_sq: f64 = cell_output.iter().map(|s| s * s).sum();
        const FROB_ALPHA: f64 = 0.001;
        self.max_frob_sq_ewma = if frob_sq > self.max_frob_sq_ewma {
            frob_sq
        } else {
            (1.0 - FROB_ALPHA) * self.max_frob_sq_ewma + FROB_ALPHA * frob_sq
        };

        // Plasticity maintenance (Dohare et al., Nature 2024):
        //    Track per-unit hidden state activation as utility signal.
        //    When a unit dies (utility drops to bottom fraction), surgically
        //    reinitialize its weight columns while preserving all other units.
        if let Some(ref mut guard) = self.plasticity_guard {
            let mut h_energy: Vec<f64> = self.cell.hidden_state().iter().map(|x| x.abs()).collect();
            guard.pre_update(&self.prev_h_energy, &mut h_energy);
            guard.post_update(&self.prev_h_energy);
            // Surgical per-unit reinit: only dead units get recycled.
            // Uses reinitialize_unit() which reinits weight rows in W_f/W_i/W_o/W_z,
            // resets biases (forget=1.0), and zeros h/c/n/m for just that unit.
            let mut reinit_rng = self
                .config
                .seed
                .wrapping_add(0xCAFE_BABE_u64.wrapping_mul(self.total_seen));
            for j in 0..guard.n_groups() {
                if guard.was_regenerated(j) {
                    self.cell.reinitialize_unit(j, &mut reinit_rng);
                }
            }
            self.prev_h_energy = self.cell.hidden_state().iter().map(|x| x.abs()).collect();
        }

        // Cache post-update cell output for diagnostics (Frobenius ratio, alignment tracking).
        // predict() uses forward_predict() which recomputes from the current cell state —
        // it does not read last_features.
        self.last_features = cell_output;
    }

    /// Predict for `features` without mutating state: frozen Welford
    /// normalization, non-mutating cell pass, [-3, 3] clamp (matching
    /// training), then the RLS readout. Returns 0.0 before any training.
    fn predict(&self, features: &[f64]) -> f64 {
        if self.total_seen == 0 {
            return 0.0;
        }
        // Apply same Welford normalization as train_one (read-only: use frozen stats).
        let d = features.len();
        let mut normalized = vec![0.0; d];
        if self.input_count > 0 && self.input_mean.len() == d {
            let n = self.input_count as f64;
            for i in 0..d {
                let std = if n > 1.0 {
                    (self.input_var[i] / (n - 1.0)).sqrt()
                } else {
                    1.0
                };
                let std = if std < 1e-8 { 1.0 } else { std };
                normalized[i] = ((features[i] - self.input_mean[i]) / std).clamp(-5.0, 5.0);
            }
        } else {
            // No statistics for this dimension yet: pass raw features through.
            normalized.copy_from_slice(features);
        }
        let mut cell_features = self.cell.forward_predict(&normalized);
        for v in &mut cell_features {
            *v = v.clamp(-3.0, 3.0);
        }
        self.readout.predict(&cell_features)
    }

    /// Samples the readout actually trained on (excludes warmup and skipped
    /// non-finite inputs).
    #[inline]
    fn n_samples_seen(&self) -> u64 {
        self.samples_trained
    }

    /// Restore the model to its freshly-constructed state: cell, readout,
    /// all EWMAs/counters, plasticity guard, and Welford input statistics.
    fn reset(&mut self) {
        self.cell.reset();
        self.readout.reset();
        self.last_features.iter_mut().for_each(|f| *f = 0.0);
        self.total_seen = 0;
        self.samples_trained = 0;
        self.rolling_uncertainty = 0.0;
        self.short_term_error = 0.0;
        self.prev_prediction = 0.0;
        self.prev_change = 0.0;
        self.prev_prev_change = 0.0;
        self.alignment_ewma = 0.0;
        self.max_frob_sq_ewma = 0.0;
        if let Some(ref mut guard) = self.plasticity_guard {
            guard.reset();
        }
        self.prev_h_energy.fill(0.0);
        self.input_mean.clear();
        self.input_var.clear();
        self.input_count = 0;
    }

    // Deprecated trait-level shim: forwards to the Tunable impl.
    #[allow(deprecated)]
    fn diagnostics_array(&self) -> [f64; 5] {
        <Self as crate::learner::Tunable>::diagnostics_array(self)
    }

    // Deprecated trait-level shim: forwards to HasReadout, mapping an empty
    // weight slice (readout not yet trained) to None.
    #[allow(deprecated)]
    fn readout_weights(&self) -> Option<&[f64]> {
        let w = <Self as crate::learner::HasReadout>::readout_weights(self);
        if w.is_empty() {
            None
        } else {
            Some(w)
        }
    }
}
658
659impl crate::learner::Tunable for StreamingLSTM {
660    fn diagnostics_array(&self) -> [f64; 5] {
661        use crate::automl::DiagnosticSource;
662        match self.config_diagnostics() {
663            Some(d) => [
664                d.residual_alignment,
665                d.regularization_sensitivity,
666                d.depth_sufficiency,
667                d.effective_dof,
668                d.uncertainty,
669            ],
670            None => [0.0; 5],
671        }
672    }
673
674    fn adjust_config(&mut self, lr_multiplier: f64, _lambda_delta: f64) {
675        // Scale the RLS readout forgetting factor as the primary tuning knob.
676        <crate::learners::RecursiveLeastSquares as crate::learner::Tunable>::adjust_config(
677            &mut self.readout,
678            lr_multiplier,
679            0.0,
680        );
681    }
682}
683
impl crate::learner::HasReadout for StreamingLSTM {
    /// Borrow the RLS readout weight vector (empty until training initializes it).
    fn readout_weights(&self) -> &[f64] {
        self.readout.weights()
    }
}
689
690// ---------------------------------------------------------------------------
691// Debug impl
692// ---------------------------------------------------------------------------
693
694impl std::fmt::Debug for StreamingLSTM {
695    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
696        f.debug_struct("StreamingLSTM")
697            .field("d_model", &self.config.d_model)
698            .field("warmup", &self.config.warmup)
699            .field("total_seen", &self.total_seen)
700            .field("samples_trained", &self.samples_trained)
701            .field("past_warmup", &self.past_warmup())
702            .finish()
703    }
704}
705
706// ---------------------------------------------------------------------------
707// DiagnosticSource impl
708// ---------------------------------------------------------------------------
709
710impl crate::automl::DiagnosticSource for StreamingLSTM {
711    fn config_diagnostics(&self) -> Option<crate::automl::ConfigDiagnostics> {
712        // RLS saturation: 1.0 - trace(P) / (delta * d).
713        let rls_saturation = {
714            let p = self.readout.p_matrix();
715            let d = self.readout.weights().len();
716            if d > 0 && self.readout.delta() > 0.0 {
717                let trace: f64 = (0..d).map(|i| p[i * d + i]).sum();
718                (1.0 - trace / (self.readout.delta() * d as f64)).clamp(0.0, 1.0)
719            } else {
720                0.0
721            }
722        };
723
724        // sLSTM output Frobenius ratio: current ||h||_2^2 / max(||h||_2^2).
725        let state_frob_ratio = {
726            let frob_sq: f64 = self.last_features.iter().map(|s| s * s).sum();
727            if self.max_frob_sq_ewma > 1e-15 {
728                (frob_sq / self.max_frob_sq_ewma).clamp(0.0, 1.0)
729            } else {
730                0.0
731            }
732        };
733
734        let depth_sufficiency = 0.5 * rls_saturation + 0.5 * state_frob_ratio;
735
736        // Weight magnitude: ||w||_2 / sqrt(d).
737        let w = self.readout.weights();
738        let effective_dof = if !w.is_empty() {
739            let sq_sum: f64 = w.iter().map(|wi| wi * wi).sum();
740            sq_sum.sqrt() / (w.len() as f64).sqrt()
741        } else {
742            0.0
743        };
744
745        Some(crate::automl::ConfigDiagnostics {
746            residual_alignment: self.alignment_ewma,
747            regularization_sensitivity: 0.0,
748            depth_sufficiency,
749            effective_dof,
750            uncertainty: self.readout.noise_variance().sqrt(),
751        })
752    }
753}
754
755// ---------------------------------------------------------------------------
756// Backward-compatible type alias
757// ---------------------------------------------------------------------------
758
/// Deprecated name kept for backward compatibility — use [`StreamingLSTM`] instead.
pub type StreamingsLSTM = StreamingLSTM;
761
762// ---------------------------------------------------------------------------
763// Tests
764// ---------------------------------------------------------------------------
765
766#[cfg(test)]
767mod tests {
768    use super::*;
769
    #[test]
    fn slstm_config_builder_default() {
        // A builder with no overrides must produce the documented defaults.
        let config = SLSTMConfig::builder().build().unwrap();
        assert_eq!(config.d_model, 32);
        assert_eq!(config.warmup, 10);
    }
776
    #[test]
    fn slstm_config_rejects_zero_d_model() {
        // d_model == 0 must fail validation in build().
        assert!(SLSTMConfig::builder().d_model(0).build().is_err());
    }
781
    #[test]
    fn slstm_new_creates_model() {
        // A freshly built model has trained on nothing and is still in warmup.
        let config = SLSTMConfig::builder().d_model(16).build().unwrap();
        let model = StreamingsLSTM::new(config);
        assert_eq!(model.n_samples_seen(), 0);
        assert!(!model.past_warmup());
    }
789
    #[test]
    fn slstm_train_and_predict_finite() {
        // Train on a simple linear target and check predictions stay finite.
        let config = SLSTMConfig::builder()
            .d_model(16)
            .warmup(5)
            .build()
            .unwrap();
        let mut model = StreamingsLSTM::new(config);
        for i in 0..50 {
            let x = [i as f64 * 0.1, (i as f64).sin()];
            let y = x[0] * 2.0 + 1.0;
            model.train(&x, y);
        }
        let pred = model.predict(&[1.0, 0.5]);
        assert!(pred.is_finite(), "prediction must be finite, got {pred}");
        // n_samples_seen counts only post-warmup readout updates.
        assert_eq!(model.n_samples_seen(), 45); // 50 - 5 warmup
    }
807
808    #[test]
809    fn slstm_reset_clears_state() {
810        let config = SLSTMConfig::builder().d_model(8).warmup(3).build().unwrap();
811        let mut model = StreamingsLSTM::new(config);
812        for i in 0..20 {
813            model.train(&[i as f64], i as f64 * 2.0);
814        }
815        assert!(model.n_samples_seen() > 0);
816        model.reset();
817        assert_eq!(model.n_samples_seen(), 0);
818        assert!(!model.past_warmup());
819    }
820
821    #[test]
822    fn slstm_predict_before_train_returns_zero() {
823        let config = SLSTMConfig::builder().d_model(8).build().unwrap();
824        let model = StreamingsLSTM::new(config);
825        assert_eq!(model.predict(&[1.0, 2.0]), 0.0);
826    }
827
828    #[test]
829    #[allow(deprecated)]
830    fn slstm_diagnostics_array_finite() {
831        let config = SLSTMConfig::builder().d_model(8).warmup(3).build().unwrap();
832        let mut model = StreamingsLSTM::new(config);
833        for i in 0..30 {
834            model.train(&[i as f64 * 0.1], i as f64);
835        }
836        let diag = model.diagnostics_array();
837        for (idx, val) in diag.iter().enumerate() {
838            assert!(
839                val.is_finite(),
840                "diagnostics[{idx}] must be finite, got {val}"
841            );
842        }
843    }
844
845    #[test]
846    #[allow(deprecated)]
847    fn slstm_readout_weights_available_after_training() {
848        let config = SLSTMConfig::builder().d_model(8).warmup(3).build().unwrap();
849        let mut model = StreamingsLSTM::new(config);
850        assert!(model.readout_weights().is_none());
851        for i in 0..20 {
852            model.train(&[i as f64], i as f64);
853        }
854        assert!(model.readout_weights().is_some());
855    }
856
857    #[test]
858    fn slstm_streaming_learner_boxable() {
859        let config = SLSTMConfig::builder().d_model(8).build().unwrap();
860        let model = StreamingsLSTM::new(config);
861        let _boxed: Box<dyn StreamingLearner> = Box::new(model);
862    }
863
864    #[test]
865    fn slstm_plasticity_disabled_by_default() {
866        let config = SLSTMConfig::builder().d_model(8).build().unwrap();
867        assert!(
868            config.plasticity.is_none(),
869            "plasticity should default to None"
870        );
871        let model = StreamingsLSTM::new(config);
872        assert!(
873            model.plasticity_guard.is_none(),
874            "guard should be None when plasticity is disabled"
875        );
876    }
877
878    #[test]
879    fn slstm_plasticity_enabled_creates_guard() {
880        use crate::common::PlasticityConfig;
881        let config = SLSTMConfig::builder()
882            .d_model(16)
883            .plasticity(Some(PlasticityConfig::default()))
884            .build()
885            .unwrap();
886        assert!(
887            config.plasticity.is_some(),
888            "plasticity should be Some when set"
889        );
890        let model = StreamingsLSTM::new(config);
891        assert!(
892            model.plasticity_guard.is_some(),
893            "guard should be Some when plasticity is enabled"
894        );
895        let guard = model.plasticity_guard.as_ref().unwrap();
896        assert_eq!(
897            guard.n_groups(),
898            16,
899            "should have one group per hidden unit"
900        );
901    }
902
903    #[test]
904    fn slstm_plasticity_train_runs_without_panic() {
905        use crate::common::PlasticityConfig;
906        let config = SLSTMConfig::builder()
907            .d_model(8)
908            .warmup(3)
909            .plasticity(Some(PlasticityConfig::default()))
910            .build()
911            .unwrap();
912        let mut model = StreamingsLSTM::new(config);
913        for i in 0..600 {
914            let x = [i as f64 * 0.01, (i as f64 * 0.1).sin()];
915            let y = x[0] * 2.0 + 1.0;
916            model.train(&x, y);
917        }
918        let pred = model.predict(&[1.0, 0.5]);
919        assert!(
920            pred.is_finite(),
921            "plasticity-enabled model should produce finite predictions, got {pred}"
922        );
923    }
924
925    #[test]
926    fn slstm_plasticity_reset_clears_guard() {
927        use crate::common::PlasticityConfig;
928        let config = SLSTMConfig::builder()
929            .d_model(8)
930            .warmup(3)
931            .plasticity(Some(PlasticityConfig::default()))
932            .build()
933            .unwrap();
934        let mut model = StreamingsLSTM::new(config);
935        for i in 0..20 {
936            model.train(&[i as f64], i as f64);
937        }
938        model.reset();
939        let guard = model.plasticity_guard.as_ref().unwrap();
940        assert_eq!(
941            guard.n_updates(),
942            0,
943            "plasticity guard should be reset after model reset"
944        );
945        assert!(
946            model.prev_h_energy.iter().all(|&e| e == 0.0),
947            "prev_h_energy should be zeroed after reset"
948        );
949    }
950
951    #[test]
952    fn slstm_rejects_invalid_forgetting_factor() {
953        assert!(
954            SLSTMConfig::builder()
955                .d_model(8)
956                .forgetting_factor(0.0)
957                .build()
958                .is_err(),
959            "forgetting_factor=0 must be rejected"
960        );
961        assert!(
962            SLSTMConfig::builder()
963                .d_model(8)
964                .forgetting_factor(1.01)
965                .build()
966                .is_err(),
967            "forgetting_factor>1 must be rejected"
968        );
969    }
970
971    #[test]
972    fn slstm_rejects_invalid_delta_rls() {
973        assert!(
974            SLSTMConfig::builder()
975                .d_model(8)
976                .delta_rls(0.0)
977                .build()
978                .is_err(),
979            "delta_rls=0 must be rejected"
980        );
981        assert!(
982            SLSTMConfig::builder()
983                .d_model(8)
984                .delta_rls(-1.0)
985                .build()
986                .is_err(),
987            "delta_rls<0 must be rejected"
988        );
989    }
990
991    #[test]
992    fn test_lstm_nan_input_skipped() {
993        // Train with valid samples past warmup, then send a NaN sample.
994        // Model should not panic and should remain healthy (weights finite).
995        let config = SLSTMConfig::builder().d_model(8).warmup(3).build().unwrap();
996        let mut model = StreamingLSTM::new(config);
997        for i in 0..20 {
998            model.train(&[i as f64 * 0.1], i as f64);
999        }
1000        let samples_before = model.n_samples_seen();
1001        // Feed NaN — should be silently skipped
1002        model.train(&[f64::NAN], 1.0);
1003        // Sample count should not change (NaN skipped at readout step)
1004        // Note: total_seen increments before the finiteness check, but samples_trained does not.
1005        assert_eq!(
1006            model.n_samples_seen(),
1007            samples_before,
1008            "NaN sample should not increment samples_trained: before={}, after={}",
1009            samples_before,
1010            model.n_samples_seen()
1011        );
1012        // Prediction should still be finite after NaN was fed
1013        let pred = model.predict(&[1.0]);
1014        assert!(
1015            pred.is_finite(),
1016            "prediction should be finite after NaN input, got {pred}"
1017        );
1018    }
1019
1020    #[test]
1021    fn test_streaming_lstm_alias() {
1022        // StreamingLSTM (new name) and StreamingsLSTM (alias) refer to the same type.
1023        let config = SLSTMConfig::builder().d_model(8).build().unwrap();
1024        let model: StreamingLSTM = StreamingLSTM::new(config.clone());
1025        let _alias: StreamingsLSTM = StreamingsLSTM::new(config);
1026        assert_eq!(
1027            model.config().d_model,
1028            8,
1029            "StreamingLSTM should have correct d_model"
1030        );
1031    }
1032
1033    /// Regression test: sLSTM must achieve reasonable RMSE on a sine regression task.
1034    ///
1035    /// Before the input normalization + cell output clamp fix, the RLS readout
1036    /// would receive an inconsistent feature distribution (cell resets mid-stream)
1037    /// and produce RMSE ~180. After the fix, RMSE should be well under 5.0.
1038    #[test]
1039    fn test_slstm_sine_regression_reasonable() {
1040        let config = SLSTMConfig::builder()
1041            .d_model(16)
1042            .warmup(10)
1043            .forgetting_factor(0.998)
1044            .build()
1045            .unwrap();
1046        let mut model = StreamingLSTM::new(config);
1047
1048        // Train on sin(x) for 500 samples.
1049        let n = 500usize;
1050        for i in 0..n {
1051            let x = i as f64 * 0.05;
1052            model.train(&[x], x.sin());
1053        }
1054
1055        // Compute RMSE on training samples (streaming: we evaluate as we trained).
1056        // Re-run predictions over the same sequence to measure fit quality.
1057        let mut model2 = {
1058            let config2 = SLSTMConfig::builder()
1059                .d_model(16)
1060                .warmup(10)
1061                .forgetting_factor(0.998)
1062                .build()
1063                .unwrap();
1064            StreamingLSTM::new(config2)
1065        };
1066        let mut sq_err_sum = 0.0;
1067        let mut count = 0usize;
1068        for i in 0..n {
1069            let x = i as f64 * 0.05;
1070            let y = x.sin();
1071            if model2.past_warmup() {
1072                let pred = model2.predict(&[x]);
1073                let err = pred - y;
1074                sq_err_sum += err * err;
1075                count += 1;
1076            }
1077            model2.train(&[x], y);
1078        }
1079        let rmse = if count > 0 {
1080            (sq_err_sum / count as f64).sqrt()
1081        } else {
1082            f64::INFINITY
1083        };
1084        assert!(
1085            rmse < 5.0,
1086            "sLSTM sine regression RMSE should be < 5.0 after fix, got {rmse:.4} (count={count})"
1087        );
1088    }
1089
1090    /// Option D correctness: predict(x_t) must strongly correlate with x_t, not x_{t-1}.
1091    ///
1092    /// Train on y_t = x_t[0] * 2.0 for many steps. Then verify that predict(x_a) and
1093    /// predict(x_b) differ meaningfully when x_a and x_b differ in the current input.
1094    /// A model that predicts from stale (prior-step) features would fail to distinguish
1095    /// inputs that differ only in the current timestep.
1096    #[test]
1097    fn lstm_predict_reads_current_input() {
1098        let config = SLSTMConfig::builder()
1099            .d_model(16)
1100            .warmup(5)
1101            .forgetting_factor(0.999)
1102            .build()
1103            .unwrap();
1104        let mut model = StreamingLSTM::new(config);
1105
1106        // Train on y_t = x_t[0] * 2.0 for 200 samples.
1107        for i in 0..200 {
1108            let x0 = (i as f64) * 0.05;
1109            model.train(&[x0], x0 * 2.0);
1110        }
1111
1112        // predict(x_a) and predict(x_b) should differ for x_a != x_b.
1113        let pred_a = model.predict(&[1.0]);
1114        let pred_b = model.predict(&[5.0]);
1115
1116        assert!(
1117            pred_a.is_finite() && pred_b.is_finite(),
1118            "both predictions must be finite: pred_a={pred_a}, pred_b={pred_b}"
1119        );
1120        assert!(
1121            (pred_a - pred_b).abs() > 0.1,
1122            "predict must respond to current input: pred_a={pred_a} (x=1.0), pred_b={pred_b} (x=5.0), diff={}",
1123            (pred_a - pred_b).abs()
1124        );
1125    }
1126
1127    /// Verify that n_heads and forget_bias_init are wired from SLSTMConfig through
1128    /// to the underlying SLSTMCell (Beck et al. 2024 §2.2 + §3.2).
1129    ///
1130    /// A model with n_heads=2 must use the block-diagonal recurrent path.
1131    /// The cell's n_heads() accessor must return 2, confirming the wiring.
1132    /// The model must also produce finite predictions after training.
1133    #[test]
1134    fn slstm_model_uses_multi_head_block_diagonal() {
1135        let d_model = 8usize;
1136        let bias = irithyll_core::lstm::SLSTMCell::forget_bias_linspace(3.0, 6.0, d_model);
1137
1138        let config = SLSTMConfig::builder()
1139            .d_model(d_model)
1140            .n_heads(2)
1141            .forget_bias_init(Some(bias))
1142            .warmup(5)
1143            .build()
1144            .unwrap();
1145
1146        assert_eq!(config.n_heads, 2, "config must store n_heads=2");
1147        assert!(
1148            config.forget_bias_init.is_some(),
1149            "config must store forget_bias_init"
1150        );
1151
1152        let mut model = StreamingLSTM::new(config);
1153
1154        // Verify the cell's n_heads accessor sees the wired value.
1155        assert_eq!(
1156            model.cell.n_heads(),
1157            2,
1158            "StreamingLSTM cell must have n_heads=2 from config"
1159        );
1160
1161        // Train and predict must work without panic.
1162        for i in 0..50 {
1163            let x = [i as f64 * 0.1, (i as f64).sin()];
1164            model.train(&x, x[0] * 2.0 + 1.0);
1165        }
1166        let pred = model.predict(&[1.0, 0.5]);
1167        assert!(
1168            pred.is_finite(),
1169            "multi-head model prediction must be finite, got {pred}"
1170        );
1171    }
1172
1173    /// Verify builder rejects n_heads that does not divide d_model.
1174    #[test]
1175    fn slstm_config_rejects_invalid_n_heads() {
1176        // 3 does not divide 8.
1177        assert!(
1178            SLSTMConfig::builder()
1179                .d_model(8)
1180                .n_heads(3)
1181                .build()
1182                .is_err(),
1183            "n_heads=3 must be rejected when d_model=8"
1184        );
1185        // n_heads=0 is always invalid.
1186        assert!(
1187            SLSTMConfig::builder()
1188                .d_model(8)
1189                .n_heads(0)
1190                .build()
1191                .is_err(),
1192            "n_heads=0 must be rejected"
1193        );
1194    }
1195
1196    /// Verify builder rejects forget_bias_init with wrong length.
1197    #[test]
1198    fn slstm_config_rejects_wrong_bias_length() {
1199        let wrong_bias = vec![1.0f64; 5]; // d_model is 8
1200        assert!(
1201            SLSTMConfig::builder()
1202                .d_model(8)
1203                .forget_bias_init(Some(wrong_bias))
1204                .build()
1205                .is_err(),
1206            "forget_bias_init of wrong length must be rejected"
1207        );
1208    }
1209}