1use crate::common::PlasticityConfig;
19use crate::error::ConfigError;
20use crate::learner::StreamingLearner;
21use crate::learners::RecursiveLeastSquares;
22use irithyll_core::continual::{ContinualStrategy, NeuronRegeneration};
23
/// Configuration for a [`StreamingLSTM`] model.
///
/// Construct via [`SLSTMConfig::builder`], which validates every field,
/// or use [`Default`] for a ready-made baseline.
#[derive(Debug, Clone)]
pub struct SLSTMConfig {
    /// Hidden-state dimensionality of the sLSTM cell; must be > 0.
    pub d_model: usize,
    /// RLS forgetting factor for the readout; must lie in (0, 1].
    pub forgetting_factor: f64,
    /// Initial RLS covariance scale (delta); must be > 0.
    pub delta_rls: f64,
    /// Number of initial samples consumed before the readout starts training.
    pub warmup: usize,
    /// RNG seed used for cell weight initialization.
    pub seed: u64,
    /// Number of block-diagonal heads; must evenly divide `d_model`.
    pub n_heads: usize,
    /// Optional per-unit forget-gate bias; length must equal `d_model`.
    pub forget_bias_init: Option<Vec<f64>>,
    /// Optional neuron-regeneration (plasticity) settings; `None` disables it.
    pub plasticity: Option<PlasticityConfig>,
}
74
75impl Default for SLSTMConfig {
76 fn default() -> Self {
77 Self {
78 d_model: 32,
79 forgetting_factor: 0.998,
80 delta_rls: 100.0,
81 warmup: 10,
82 seed: 42,
83 n_heads: 1,
84 forget_bias_init: None,
85 plasticity: None,
86 }
87 }
88}
89
90impl std::fmt::Display for SLSTMConfig {
91 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
92 write!(
93 f,
94 "SLSTMConfig(d_model={}, n_heads={}, ff={}, delta_rls={}, warmup={}, seed={}, plasticity={})",
95 self.d_model,
96 self.n_heads,
97 self.forgetting_factor,
98 self.delta_rls,
99 self.warmup,
100 self.seed,
101 self.plasticity.is_some()
102 )
103 }
104}
105
/// Fluent builder for [`SLSTMConfig`], obtained via [`SLSTMConfig::builder`].
pub struct SLSTMConfigBuilder {
    // Accumulates settings; validated once in `build`.
    config: SLSTMConfig,
}
128
129impl SLSTMConfig {
130 pub fn builder() -> SLSTMConfigBuilder {
132 SLSTMConfigBuilder {
133 config: SLSTMConfig::default(),
134 }
135 }
136}
137
138impl SLSTMConfigBuilder {
139 pub fn d_model(mut self, d: usize) -> Self {
141 self.config.d_model = d;
142 self
143 }
144
145 pub fn forgetting_factor(mut self, f: f64) -> Self {
147 self.config.forgetting_factor = f;
148 self
149 }
150
151 pub fn delta_rls(mut self, d: f64) -> Self {
153 self.config.delta_rls = d;
154 self
155 }
156
157 pub fn warmup(mut self, w: usize) -> Self {
159 self.config.warmup = w;
160 self
161 }
162
163 pub fn seed(mut self, s: u64) -> Self {
165 self.config.seed = s;
166 self
167 }
168
169 pub fn n_heads(mut self, n: usize) -> Self {
174 self.config.n_heads = n;
175 self
176 }
177
178 pub fn forget_bias_init(mut self, bias: Option<Vec<f64>>) -> Self {
186 self.config.forget_bias_init = bias;
187 self
188 }
189
190 pub fn plasticity(mut self, p: Option<PlasticityConfig>) -> Self {
197 self.config.plasticity = p;
198 self
199 }
200
201 pub fn build(self) -> Result<SLSTMConfig, ConfigError> {
212 let c = &self.config;
213 if c.d_model == 0 {
214 return Err(ConfigError::out_of_range(
215 "d_model",
216 "must be > 0",
217 c.d_model,
218 ));
219 }
220 if c.forgetting_factor <= 0.0 || c.forgetting_factor > 1.0 {
221 return Err(ConfigError::out_of_range(
222 "forgetting_factor",
223 "must be in (0, 1]",
224 c.forgetting_factor,
225 ));
226 }
227 if c.delta_rls <= 0.0 {
228 return Err(ConfigError::out_of_range(
229 "delta_rls",
230 "must be > 0",
231 c.delta_rls,
232 ));
233 }
234 if c.n_heads == 0 {
235 return Err(ConfigError::out_of_range(
236 "n_heads",
237 "must be > 0",
238 c.n_heads,
239 ));
240 }
241 if c.d_model % c.n_heads != 0 {
242 return Err(ConfigError::invalid(
243 "n_heads",
244 format!("must divide d_model ({}), got {}", c.d_model, c.n_heads),
245 ));
246 }
247 if let Some(ref bias) = c.forget_bias_init {
248 if bias.len() != c.d_model {
249 return Err(ConfigError::invalid(
250 "forget_bias_init",
251 format!(
252 "length must equal d_model ({}), got {}",
253 c.d_model,
254 bias.len()
255 ),
256 ));
257 }
258 }
259 Ok(self.config)
260 }
261}
262
/// Streaming sLSTM regressor: a recurrent feature extractor whose output is
/// fed to a recursive-least-squares linear readout, trained one sample at a
/// time.
pub struct StreamingLSTM {
    /// Validated configuration this model was built from.
    config: SLSTMConfig,
    /// Recurrent sLSTM cell producing the feature vector for the readout.
    cell: irithyll_core::lstm::SLSTMCell,
    /// Linear readout fitted online by recursive least squares.
    readout: RecursiveLeastSquares,
    /// Clamped cell output from the previous training step.
    last_features: Vec<f64>,
    /// Training calls with finite features (non-finite inputs are skipped).
    total_seen: u64,
    /// Number of samples actually used to update the readout (post-warmup).
    samples_trained: u64,
    /// Slow EWMA (alpha = 0.001) of the readout's noise std-dev.
    rolling_uncertainty: f64,
    /// Fast EWMA (0.9/0.1) of the squared one-step prediction error.
    short_term_error: f64,
    /// Readout prediction from the previous step's features.
    prev_prediction: f64,
    /// Running peak/EWMA of the squared norm of the clamped cell output.
    max_frob_sq_ewma: f64,
    /// EWMA of sign-agreement between consecutive prediction accelerations.
    alignment_ewma: f64,
    /// First difference of the prediction at the previous step.
    prev_change: f64,
    /// First difference of the prediction two steps back.
    prev_prev_change: f64,
    /// Optional neuron-regeneration controller (plasticity); `None` when disabled.
    plasticity_guard: Option<NeuronRegeneration>,
    /// |hidden state| from the previous step, fed to the plasticity guard.
    prev_h_energy: Vec<f64>,
    /// Welford running mean per input dimension (for z-score normalization).
    input_mean: Vec<f64>,
    /// Welford running sum of squared deviations per input dimension.
    input_var: Vec<f64>,
    /// Number of samples accumulated into the Welford statistics.
    input_count: u64,
}
320
321impl StreamingLSTM {
322 pub fn new(config: SLSTMConfig) -> Self {
324 let cell = if config.n_heads > 1 || config.forget_bias_init.is_some() {
328 let bias = config
329 .forget_bias_init
330 .clone()
331 .unwrap_or_else(|| vec![1.0; config.d_model]);
332 irithyll_core::lstm::SLSTMCell::with_config(
333 config.d_model,
334 config.n_heads,
335 bias,
336 config.seed,
337 )
338 } else {
339 irithyll_core::lstm::SLSTMCell::new(config.d_model, config.seed)
340 };
341 let readout = RecursiveLeastSquares::with_delta(config.forgetting_factor, config.delta_rls);
342 let last_features = vec![0.0; config.d_model];
343
344 let plasticity_guard = config.plasticity.as_ref().map(|p| {
347 NeuronRegeneration::new(
348 config.d_model,
349 1, p.regen_fraction,
351 p.regen_interval,
352 p.utility_alpha,
353 config.seed.wrapping_add(0x_DEAD_CAFE),
354 )
355 });
356 let prev_h_energy = vec![0.0; config.d_model];
357
358 Self {
359 config,
360 cell,
361 readout,
362 last_features,
363 total_seen: 0,
364 samples_trained: 0,
365 rolling_uncertainty: 0.0,
366 short_term_error: 0.0,
367 prev_prediction: 0.0,
368 max_frob_sq_ewma: 0.0,
369 alignment_ewma: 0.0,
370 prev_change: 0.0,
371 prev_prev_change: 0.0,
372 plasticity_guard,
373 prev_h_energy,
374 input_mean: Vec::new(),
375 input_var: Vec::new(),
376 input_count: 0,
377 }
378 }
379
380 fn normalize_input(&mut self, features: &[f64]) -> Vec<f64> {
385 let d = features.len();
386 if self.input_mean.len() != d {
387 self.input_mean = vec![0.0; d];
388 self.input_var = vec![0.0; d];
389 }
390 self.input_count += 1;
391 let n = self.input_count as f64;
392 let mut out = vec![0.0; d];
393 for i in 0..d {
394 let x = features[i];
395 let delta = x - self.input_mean[i];
396 self.input_mean[i] += delta / n;
397 let delta2 = x - self.input_mean[i];
398 self.input_var[i] += delta * delta2;
399 let std = if n > 1.0 {
400 (self.input_var[i] / (n - 1.0)).sqrt()
401 } else {
402 1.0
403 };
404 let std = if std < 1e-8 { 1.0 } else { std };
405 out[i] = ((x - self.input_mean[i]) / std).clamp(-5.0, 5.0);
406 }
407 out
408 }
409
410 #[inline]
412 pub fn past_warmup(&self) -> bool {
413 self.total_seen > self.config.warmup as u64
414 }
415
416 pub fn config(&self) -> &SLSTMConfig {
418 &self.config
419 }
420
421 #[inline]
428 pub fn prediction_uncertainty(&self) -> f64 {
429 self.readout.noise_variance().sqrt()
430 }
431}
432
impl StreamingLearner for StreamingLSTM {
    /// Processes one (features, target) sample: updates uncertainty tracking,
    /// adapts the RLS forgetting factor, advances the cell, trains the
    /// readout (after warmup), and runs the optional plasticity guard.
    /// `weight` is forwarded to the readout's own `train_one`.
    fn train_one(&mut self, features: &[f64], target: f64, weight: f64) {
        // --- Uncertainty tracking: slow EWMA of the readout noise std. ---
        let current_uncertainty = self.readout.noise_variance().sqrt();
        const UNCERTAINTY_ALPHA: f64 = 0.001;
        if self.total_seen == 0 {
            self.rolling_uncertainty = current_uncertainty;
        } else {
            self.rolling_uncertainty = (1.0 - UNCERTAINTY_ALPHA) * self.rolling_uncertainty
                + UNCERTAINTY_ALPHA * current_uncertainty;
        }

        // --- Adaptive forgetting: forget faster (down to 0.95) when current
        // uncertainty spikes above its rolling baseline. ---
        if self.rolling_uncertainty > 1e-10 {
            let ratio = (current_uncertainty / self.rolling_uncertainty).clamp(0.5, 3.0);
            let base_ff = self.config.forgetting_factor;
            let adaptive_ff = (base_ff - 0.02 * (ratio - 1.0)).clamp(0.95, base_ff);
            self.readout.set_forgetting_factor(adaptive_ff);
        }

        // --- Diagnostics: one-step-ahead error and prediction-acceleration
        // alignment, computed against the PREVIOUS step's features. ---
        if self.past_warmup() {
            let current_pred = self.readout.predict(&self.last_features);
            let pred_error = target - current_pred;

            let sq_err = pred_error * pred_error;
            if self.samples_trained == 0 {
                self.short_term_error = sq_err;
            } else {
                self.short_term_error = 0.9 * self.short_term_error + 0.1 * sq_err;
            }
            // Kept for parity with the EWMA above; currently unused.
            let _short_rmse = self.short_term_error.sqrt();

            let current_change = current_pred - self.prev_prediction;
            if self.samples_trained > 0 {
                // Agreement is +1 when consecutive prediction "accelerations"
                // share a sign, -1 when they disagree, 0 when either is ~0.
                let acceleration = current_change - self.prev_change;
                let prev_acceleration = self.prev_change - self.prev_prev_change;
                let agreement = if acceleration.abs() > 1e-15 && prev_acceleration.abs() > 1e-15 {
                    if (acceleration > 0.0) == (prev_acceleration > 0.0) {
                        1.0
                    } else {
                        -1.0
                    }
                } else {
                    0.0
                };
                if self.samples_trained == 1 {
                    self.alignment_ewma = agreement;
                } else {
                    self.alignment_ewma = 0.95 * self.alignment_ewma + 0.05 * agreement;
                }
            }
            self.prev_prev_change = self.prev_change;
            self.prev_change = current_change;
            self.prev_prediction = current_pred;
        }

        // Skip samples with non-finite features entirely: the cell state,
        // counters, and normalizer stats are left untouched.
        if !features.iter().all(|f| f.is_finite()) {
            return;
        }

        let normalized = self.normalize_input(features);

        // Features used for readout training come from a non-advancing
        // forward pass (the state BEFORE this sample), clamped like the
        // stored features, so train/predict see the same distribution.
        let pre_cell_features: Option<Vec<f64>> = if self.total_seen > 0 {
            let mut out = self.cell.forward_predict(&normalized);
            for v in &mut out {
                *v = v.clamp(-3.0, 3.0);
            }
            Some(out)
        } else {
            None
        };

        // State-advancing forward pass; counted toward warmup.
        let mut cell_output = self.cell.forward(&normalized).to_vec();
        self.total_seen += 1;

        if self.past_warmup() {
            if let Some(ref feats) = pre_cell_features {
                if feats.iter().all(|f| f.is_finite()) {
                    self.readout.train_one(feats, target, weight);
                    self.samples_trained += 1;
                }
            }
        }

        // Clamp stored features to match what `predict` will produce.
        for v in &mut cell_output {
            *v = v.clamp(-3.0, 3.0);
        }

        // Track the peak squared norm of the cell output: jump up instantly,
        // decay slowly — used for the depth-sufficiency diagnostic.
        let frob_sq: f64 = cell_output.iter().map(|s| s * s).sum();
        const FROB_ALPHA: f64 = 0.001;
        self.max_frob_sq_ewma = if frob_sq > self.max_frob_sq_ewma {
            frob_sq
        } else {
            (1.0 - FROB_ALPHA) * self.max_frob_sq_ewma + FROB_ALPHA * frob_sq
        };

        // Optional plasticity: feed per-unit activation energy to the guard
        // and reinitialize any units it flagged for regeneration.
        if let Some(ref mut guard) = self.plasticity_guard {
            let mut h_energy: Vec<f64> = self.cell.hidden_state().iter().map(|x| x.abs()).collect();
            guard.pre_update(&self.prev_h_energy, &mut h_energy);
            guard.post_update(&self.prev_h_energy);
            // Per-step reinit seed derived from the config seed and step count.
            let mut reinit_rng = self
                .config
                .seed
                .wrapping_add(0xCAFE_BABE_u64.wrapping_mul(self.total_seen));
            for j in 0..guard.n_groups() {
                if guard.was_regenerated(j) {
                    self.cell.reinitialize_unit(j, &mut reinit_rng);
                }
            }
            self.prev_h_energy = self.cell.hidden_state().iter().map(|x| x.abs()).collect();
        }

        self.last_features = cell_output;
    }

    /// Predicts the target for `features` without mutating any state; the
    /// cell is advanced via its non-mutating `forward_predict` path.
    /// Returns 0.0 before any sample has been seen.
    fn predict(&self, features: &[f64]) -> f64 {
        if self.total_seen == 0 {
            return 0.0;
        }
        let d = features.len();
        let mut normalized = vec![0.0; d];
        // Apply the frozen Welford statistics; if the dimension does not
        // match (or no stats exist yet), pass features through unscaled.
        if self.input_count > 0 && self.input_mean.len() == d {
            let n = self.input_count as f64;
            for i in 0..d {
                let std = if n > 1.0 {
                    (self.input_var[i] / (n - 1.0)).sqrt()
                } else {
                    1.0
                };
                let std = if std < 1e-8 { 1.0 } else { std };
                normalized[i] = ((features[i] - self.input_mean[i]) / std).clamp(-5.0, 5.0);
            }
        } else {
            normalized.copy_from_slice(features);
        }
        let mut cell_features = self.cell.forward_predict(&normalized);
        for v in &mut cell_features {
            *v = v.clamp(-3.0, 3.0);
        }
        self.readout.predict(&cell_features)
    }

    /// Number of samples that actually updated the readout (not total seen).
    #[inline]
    fn n_samples_seen(&self) -> u64 {
        self.samples_trained
    }

    /// Restores the model to its freshly-constructed state, including the
    /// cell, readout, all EWMAs, plasticity guard, and input normalizer.
    fn reset(&mut self) {
        self.cell.reset();
        self.readout.reset();
        self.last_features.iter_mut().for_each(|f| *f = 0.0);
        self.total_seen = 0;
        self.samples_trained = 0;
        self.rolling_uncertainty = 0.0;
        self.short_term_error = 0.0;
        self.prev_prediction = 0.0;
        self.prev_change = 0.0;
        self.prev_prev_change = 0.0;
        self.alignment_ewma = 0.0;
        self.max_frob_sq_ewma = 0.0;
        if let Some(ref mut guard) = self.plasticity_guard {
            guard.reset();
        }
        self.prev_h_energy.fill(0.0);
        self.input_mean.clear();
        self.input_var.clear();
        self.input_count = 0;
    }

    /// Deprecated trait hook; delegates to the `Tunable` implementation.
    #[allow(deprecated)]
    fn diagnostics_array(&self) -> [f64; 5] {
        <Self as crate::learner::Tunable>::diagnostics_array(self)
    }

    /// Deprecated trait hook; wraps `HasReadout::readout_weights`, mapping an
    /// empty slice (untrained readout) to `None`.
    #[allow(deprecated)]
    fn readout_weights(&self) -> Option<&[f64]> {
        let w = <Self as crate::learner::HasReadout>::readout_weights(self);
        if w.is_empty() {
            None
        } else {
            Some(w)
        }
    }
}
658
659impl crate::learner::Tunable for StreamingLSTM {
660 fn diagnostics_array(&self) -> [f64; 5] {
661 use crate::automl::DiagnosticSource;
662 match self.config_diagnostics() {
663 Some(d) => [
664 d.residual_alignment,
665 d.regularization_sensitivity,
666 d.depth_sufficiency,
667 d.effective_dof,
668 d.uncertainty,
669 ],
670 None => [0.0; 5],
671 }
672 }
673
674 fn adjust_config(&mut self, lr_multiplier: f64, _lambda_delta: f64) {
675 <crate::learners::RecursiveLeastSquares as crate::learner::Tunable>::adjust_config(
677 &mut self.readout,
678 lr_multiplier,
679 0.0,
680 );
681 }
682}
683
impl crate::learner::HasReadout for StreamingLSTM {
    /// Exposes the RLS readout's weight vector (empty before training).
    fn readout_weights(&self) -> &[f64] {
        self.readout.weights()
    }
}
689
// Manual Debug: the cell, readout, and buffers are large/opaque, so only a
// compact summary of configuration and progress counters is printed.
impl std::fmt::Debug for StreamingLSTM {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("StreamingLSTM")
            .field("d_model", &self.config.d_model)
            .field("warmup", &self.config.warmup)
            .field("total_seen", &self.total_seen)
            .field("samples_trained", &self.samples_trained)
            .field("past_warmup", &self.past_warmup())
            .finish()
    }
}
705
706impl crate::automl::DiagnosticSource for StreamingLSTM {
711 fn config_diagnostics(&self) -> Option<crate::automl::ConfigDiagnostics> {
712 let rls_saturation = {
714 let p = self.readout.p_matrix();
715 let d = self.readout.weights().len();
716 if d > 0 && self.readout.delta() > 0.0 {
717 let trace: f64 = (0..d).map(|i| p[i * d + i]).sum();
718 (1.0 - trace / (self.readout.delta() * d as f64)).clamp(0.0, 1.0)
719 } else {
720 0.0
721 }
722 };
723
724 let state_frob_ratio = {
726 let frob_sq: f64 = self.last_features.iter().map(|s| s * s).sum();
727 if self.max_frob_sq_ewma > 1e-15 {
728 (frob_sq / self.max_frob_sq_ewma).clamp(0.0, 1.0)
729 } else {
730 0.0
731 }
732 };
733
734 let depth_sufficiency = 0.5 * rls_saturation + 0.5 * state_frob_ratio;
735
736 let w = self.readout.weights();
738 let effective_dof = if !w.is_empty() {
739 let sq_sum: f64 = w.iter().map(|wi| wi * wi).sum();
740 sq_sum.sqrt() / (w.len() as f64).sqrt()
741 } else {
742 0.0
743 };
744
745 Some(crate::automl::ConfigDiagnostics {
746 residual_alignment: self.alignment_ewma,
747 regularization_sensitivity: 0.0,
748 depth_sufficiency,
749 effective_dof,
750 uncertainty: self.readout.noise_variance().sqrt(),
751 })
752 }
753}
754
/// Alias for [`StreamingLSTM`] kept so callers using the older
/// "StreamingsLSTM" spelling keep compiling (exercised by the tests below).
pub type StreamingsLSTM = StreamingLSTM;
761
#[cfg(test)]
mod tests {
    //! Unit tests: config validation, train/predict sanity, reset behavior,
    //! plasticity wiring, multi-head setup, NaN handling, and the alias.
    use super::*;

    #[test]
    fn slstm_config_builder_default() {
        let config = SLSTMConfig::builder().build().unwrap();
        assert_eq!(config.d_model, 32);
        assert_eq!(config.warmup, 10);
    }

    #[test]
    fn slstm_config_rejects_zero_d_model() {
        assert!(SLSTMConfig::builder().d_model(0).build().is_err());
    }

    #[test]
    fn slstm_new_creates_model() {
        let config = SLSTMConfig::builder().d_model(16).build().unwrap();
        let model = StreamingsLSTM::new(config);
        assert_eq!(model.n_samples_seen(), 0);
        assert!(!model.past_warmup());
    }

    #[test]
    fn slstm_train_and_predict_finite() {
        let config = SLSTMConfig::builder()
            .d_model(16)
            .warmup(5)
            .build()
            .unwrap();
        let mut model = StreamingsLSTM::new(config);
        for i in 0..50 {
            let x = [i as f64 * 0.1, (i as f64).sin()];
            let y = x[0] * 2.0 + 1.0;
            model.train(&x, y);
        }
        let pred = model.predict(&[1.0, 0.5]);
        assert!(pred.is_finite(), "prediction must be finite, got {pred}");
        // 50 samples minus the 5 warmup samples = 45 readout updates.
        assert_eq!(model.n_samples_seen(), 45);
    }

    #[test]
    fn slstm_reset_clears_state() {
        let config = SLSTMConfig::builder().d_model(8).warmup(3).build().unwrap();
        let mut model = StreamingsLSTM::new(config);
        for i in 0..20 {
            model.train(&[i as f64], i as f64 * 2.0);
        }
        assert!(model.n_samples_seen() > 0);
        model.reset();
        assert_eq!(model.n_samples_seen(), 0);
        assert!(!model.past_warmup());
    }

    #[test]
    fn slstm_predict_before_train_returns_zero() {
        let config = SLSTMConfig::builder().d_model(8).build().unwrap();
        let model = StreamingsLSTM::new(config);
        assert_eq!(model.predict(&[1.0, 2.0]), 0.0);
    }

    #[test]
    #[allow(deprecated)]
    fn slstm_diagnostics_array_finite() {
        let config = SLSTMConfig::builder().d_model(8).warmup(3).build().unwrap();
        let mut model = StreamingsLSTM::new(config);
        for i in 0..30 {
            model.train(&[i as f64 * 0.1], i as f64);
        }
        let diag = model.diagnostics_array();
        for (idx, val) in diag.iter().enumerate() {
            assert!(
                val.is_finite(),
                "diagnostics[{idx}] must be finite, got {val}"
            );
        }
    }

    #[test]
    #[allow(deprecated)]
    fn slstm_readout_weights_available_after_training() {
        let config = SLSTMConfig::builder().d_model(8).warmup(3).build().unwrap();
        let mut model = StreamingsLSTM::new(config);
        assert!(model.readout_weights().is_none());
        for i in 0..20 {
            model.train(&[i as f64], i as f64);
        }
        assert!(model.readout_weights().is_some());
    }

    #[test]
    fn slstm_streaming_learner_boxable() {
        let config = SLSTMConfig::builder().d_model(8).build().unwrap();
        let model = StreamingsLSTM::new(config);
        let _boxed: Box<dyn StreamingLearner> = Box::new(model);
    }

    #[test]
    fn slstm_plasticity_disabled_by_default() {
        let config = SLSTMConfig::builder().d_model(8).build().unwrap();
        assert!(
            config.plasticity.is_none(),
            "plasticity should default to None"
        );
        let model = StreamingsLSTM::new(config);
        assert!(
            model.plasticity_guard.is_none(),
            "guard should be None when plasticity is disabled"
        );
    }

    #[test]
    fn slstm_plasticity_enabled_creates_guard() {
        use crate::common::PlasticityConfig;
        let config = SLSTMConfig::builder()
            .d_model(16)
            .plasticity(Some(PlasticityConfig::default()))
            .build()
            .unwrap();
        assert!(
            config.plasticity.is_some(),
            "plasticity should be Some when set"
        );
        let model = StreamingsLSTM::new(config);
        assert!(
            model.plasticity_guard.is_some(),
            "guard should be Some when plasticity is enabled"
        );
        let guard = model.plasticity_guard.as_ref().unwrap();
        assert_eq!(
            guard.n_groups(),
            16,
            "should have one group per hidden unit"
        );
    }

    #[test]
    fn slstm_plasticity_train_runs_without_panic() {
        use crate::common::PlasticityConfig;
        let config = SLSTMConfig::builder()
            .d_model(8)
            .warmup(3)
            .plasticity(Some(PlasticityConfig::default()))
            .build()
            .unwrap();
        let mut model = StreamingsLSTM::new(config);
        // 600 steps so the regeneration interval is crossed several times.
        for i in 0..600 {
            let x = [i as f64 * 0.01, (i as f64 * 0.1).sin()];
            let y = x[0] * 2.0 + 1.0;
            model.train(&x, y);
        }
        let pred = model.predict(&[1.0, 0.5]);
        assert!(
            pred.is_finite(),
            "plasticity-enabled model should produce finite predictions, got {pred}"
        );
    }

    #[test]
    fn slstm_plasticity_reset_clears_guard() {
        use crate::common::PlasticityConfig;
        let config = SLSTMConfig::builder()
            .d_model(8)
            .warmup(3)
            .plasticity(Some(PlasticityConfig::default()))
            .build()
            .unwrap();
        let mut model = StreamingsLSTM::new(config);
        for i in 0..20 {
            model.train(&[i as f64], i as f64);
        }
        model.reset();
        let guard = model.plasticity_guard.as_ref().unwrap();
        assert_eq!(
            guard.n_updates(),
            0,
            "plasticity guard should be reset after model reset"
        );
        assert!(
            model.prev_h_energy.iter().all(|&e| e == 0.0),
            "prev_h_energy should be zeroed after reset"
        );
    }

    #[test]
    fn slstm_rejects_invalid_forgetting_factor() {
        assert!(
            SLSTMConfig::builder()
                .d_model(8)
                .forgetting_factor(0.0)
                .build()
                .is_err(),
            "forgetting_factor=0 must be rejected"
        );
        assert!(
            SLSTMConfig::builder()
                .d_model(8)
                .forgetting_factor(1.01)
                .build()
                .is_err(),
            "forgetting_factor>1 must be rejected"
        );
    }

    #[test]
    fn slstm_rejects_invalid_delta_rls() {
        assert!(
            SLSTMConfig::builder()
                .d_model(8)
                .delta_rls(0.0)
                .build()
                .is_err(),
            "delta_rls=0 must be rejected"
        );
        assert!(
            SLSTMConfig::builder()
                .d_model(8)
                .delta_rls(-1.0)
                .build()
                .is_err(),
            "delta_rls<0 must be rejected"
        );
    }

    #[test]
    fn test_lstm_nan_input_skipped() {
        let config = SLSTMConfig::builder().d_model(8).warmup(3).build().unwrap();
        let mut model = StreamingLSTM::new(config);
        for i in 0..20 {
            model.train(&[i as f64 * 0.1], i as f64);
        }
        let samples_before = model.n_samples_seen();
        model.train(&[f64::NAN], 1.0);
        assert_eq!(
            model.n_samples_seen(),
            samples_before,
            "NaN sample should not increment samples_trained: before={}, after={}",
            samples_before,
            model.n_samples_seen()
        );
        let pred = model.predict(&[1.0]);
        assert!(
            pred.is_finite(),
            "prediction should be finite after NaN input, got {pred}"
        );
    }

    #[test]
    fn test_streaming_lstm_alias() {
        let config = SLSTMConfig::builder().d_model(8).build().unwrap();
        let model: StreamingLSTM = StreamingLSTM::new(config.clone());
        let _alias: StreamingsLSTM = StreamingsLSTM::new(config);
        assert_eq!(
            model.config().d_model,
            8,
            "StreamingLSTM should have correct d_model"
        );
    }

    #[test]
    fn test_slstm_sine_regression_reasonable() {
        let config = SLSTMConfig::builder()
            .d_model(16)
            .warmup(10)
            .forgetting_factor(0.998)
            .build()
            .unwrap();
        let mut model = StreamingLSTM::new(config);

        let n = 500usize;
        for i in 0..n {
            let x = i as f64 * 0.05;
            model.train(&[x], x.sin());
        }

        // Fresh model: measure prequential (predict-then-train) RMSE.
        let mut model2 = {
            let config2 = SLSTMConfig::builder()
                .d_model(16)
                .warmup(10)
                .forgetting_factor(0.998)
                .build()
                .unwrap();
            StreamingLSTM::new(config2)
        };
        let mut sq_err_sum = 0.0;
        let mut count = 0usize;
        for i in 0..n {
            let x = i as f64 * 0.05;
            let y = x.sin();
            if model2.past_warmup() {
                let pred = model2.predict(&[x]);
                let err = pred - y;
                sq_err_sum += err * err;
                count += 1;
            }
            model2.train(&[x], y);
        }
        let rmse = if count > 0 {
            (sq_err_sum / count as f64).sqrt()
        } else {
            f64::INFINITY
        };
        assert!(
            rmse < 5.0,
            "sLSTM sine regression RMSE should be < 5.0 after fix, got {rmse:.4} (count={count})"
        );
    }

    #[test]
    fn lstm_predict_reads_current_input() {
        let config = SLSTMConfig::builder()
            .d_model(16)
            .warmup(5)
            .forgetting_factor(0.999)
            .build()
            .unwrap();
        let mut model = StreamingLSTM::new(config);

        for i in 0..200 {
            let x0 = (i as f64) * 0.05;
            model.train(&[x0], x0 * 2.0);
        }

        let pred_a = model.predict(&[1.0]);
        let pred_b = model.predict(&[5.0]);

        assert!(
            pred_a.is_finite() && pred_b.is_finite(),
            "both predictions must be finite: pred_a={pred_a}, pred_b={pred_b}"
        );
        // Guards against a regression where predict ignored its argument.
        assert!(
            (pred_a - pred_b).abs() > 0.1,
            "predict must respond to current input: pred_a={pred_a} (x=1.0), pred_b={pred_b} (x=5.0), diff={}",
            (pred_a - pred_b).abs()
        );
    }

    #[test]
    fn slstm_model_uses_multi_head_block_diagonal() {
        let d_model = 8usize;
        let bias = irithyll_core::lstm::SLSTMCell::forget_bias_linspace(3.0, 6.0, d_model);

        let config = SLSTMConfig::builder()
            .d_model(d_model)
            .n_heads(2)
            .forget_bias_init(Some(bias))
            .warmup(5)
            .build()
            .unwrap();

        assert_eq!(config.n_heads, 2, "config must store n_heads=2");
        assert!(
            config.forget_bias_init.is_some(),
            "config must store forget_bias_init"
        );

        let mut model = StreamingLSTM::new(config);

        assert_eq!(
            model.cell.n_heads(),
            2,
            "StreamingLSTM cell must have n_heads=2 from config"
        );

        for i in 0..50 {
            let x = [i as f64 * 0.1, (i as f64).sin()];
            model.train(&x, x[0] * 2.0 + 1.0);
        }
        let pred = model.predict(&[1.0, 0.5]);
        assert!(
            pred.is_finite(),
            "multi-head model prediction must be finite, got {pred}"
        );
    }

    #[test]
    fn slstm_config_rejects_invalid_n_heads() {
        assert!(
            SLSTMConfig::builder()
                .d_model(8)
                .n_heads(3)
                .build()
                .is_err(),
            "n_heads=3 must be rejected when d_model=8"
        );
        assert!(
            SLSTMConfig::builder()
                .d_model(8)
                .n_heads(0)
                .build()
                .is_err(),
            "n_heads=0 must be rejected"
        );
    }

    #[test]
    fn slstm_config_rejects_wrong_bias_length() {
        let wrong_bias = vec![1.0f64; 5];
        assert!(
            SLSTMConfig::builder()
                .d_model(8)
                .forget_bias_init(Some(wrong_bias))
                .build()
                .is_err(),
            "forget_bias_init of wrong length must be rejected"
        );
    }
}