pub mod features;
pub mod learner_impl;
use irithyll_core::continual::NeuronRegeneration;
use irithyll_core::ssm::{
SSMLayer, SelectiveSSM, SelectiveSSMBD, SelectiveSSMv3, SelectiveSSMv3Exp, SelectiveSSMv3Mimo,
};
use crate::learner::StreamingLearner;
use crate::learners::RecursiveLeastSquares;
use crate::ssm::mamba_config::{MambaConfig, MambaVersion};
use irithyll_core::rng::standard_normal;
/// Internal dispatch over the available selective-SSM backbone
/// implementations. The variant is chosen from `MambaConfig::version`
/// in `StreamingMamba::new`.
pub(crate) enum SSMVariant {
// Original selective SSM (MambaVersion::V1).
V1(SelectiveSSM),
// Grouped v3 selective SSM (MambaVersion::V3).
V3(SelectiveSSMv3),
// v3 experimental variant with optional B/C normalization.
V3Exp(SelectiveSSMv3Exp),
// v3 MIMO variant with a configurable rank.
V3Mimo(SelectiveSSMv3Mimo),
// Block-diagonal SSM (MambaVersion::BlockDiagonal).
BD(SelectiveSSMBD),
}
impl SSMVariant {
/// Advance the selected backbone one step on `input` and return its output.
fn forward(&mut self, input: &[f64]) -> Vec<f64> {
match self {
SSMVariant::V1(ssm) => ssm.forward(input),
SSMVariant::V3(ssm) => ssm.forward(input),
SSMVariant::V3Exp(ssm) => ssm.forward(input),
SSMVariant::V3Mimo(ssm) => ssm.forward(input),
SSMVariant::BD(ssm) => ssm.forward(input),
}
}
/// Borrow the backbone's current hidden-state vector.
fn state(&self) -> &[f64] {
match self {
SSMVariant::V1(ssm) => ssm.state(),
SSMVariant::V3(ssm) => ssm.state(),
SSMVariant::V3Exp(ssm) => ssm.state(),
SSMVariant::V3Mimo(ssm) => ssm.state(),
SSMVariant::BD(ssm) => ssm.state(),
}
}
/// Reset the backbone's internal state.
fn reset(&mut self) {
match self {
SSMVariant::V1(ssm) => ssm.reset(),
SSMVariant::V3(ssm) => ssm.reset(),
SSMVariant::V3Exp(ssm) => ssm.reset(),
SSMVariant::V3Mimo(ssm) => ssm.reset(),
SSMVariant::BD(ssm) => ssm.reset(),
}
}
}
/// Streaming (online) Mamba-style learner: a selective-SSM backbone,
/// an input-dependent gate, optional fixed random-lift features, and a
/// recursive-least-squares readout trained sample by sample.
pub struct StreamingMamba {
// Construction-time configuration (version, dimensions, seeds, ...).
pub(crate) config: MambaConfig,
// SSM backbone selected by `config.version`.
pub(crate) ssm: SSMVariant,
// Linear readout updated online (forgetting factor / delta from config).
pub(crate) readout: RecursiveLeastSquares,
// Gate projection: d_in * d_in random-normal weights scaled by 1/sqrt(d_in).
pub(crate) gate_weights: Vec<f64>,
// Gate bias, length d_in, zero-initialized.
pub(crate) gate_bias: Vec<f64>,
// Most recent readout feature vector (length `readout_dim_for_config`).
pub(crate) last_features: Vec<f64>,
// Total samples consumed so far (returned by `n_samples_seen`).
pub(crate) n_samples: u64,
// Prediction/change history; updated in `learner_impl` (not shown here).
pub(crate) prev_prediction: f64,
pub(crate) prev_change: f64,
pub(crate) prev_prev_change: f64,
// EWMA reported as `residual_alignment` in `config_diagnostics`.
pub(crate) alignment_ewma: f64,
// Running scale of the state's squared Frobenius norm; used to normalize
// state energy in `config_diagnostics` (maintained in learner_impl).
pub(crate) max_frob_sq_ewma: f64,
// Optional regeneration guard, present only when `config.plasticity` is set.
pub(crate) plasticity_guard: Option<NeuronRegeneration>,
// Per-unit state energy from the previous step (length = plasticity unit count).
pub(crate) prev_state_energy: Vec<f64>,
// Raw SSM output from the previous forward pass (length d_in).
pub(crate) last_ssm_output: Vec<f64>,
// Fixed random-lift layer; `Some` only when `n_lift > 0` (V3Exp version).
pub(crate) lift_weights: Option<Vec<f64>>,
pub(crate) lift_bias: Option<Vec<f64>>,
// Number of lift features (`n_lift_for_config`).
pub(crate) n_lift: usize,
// Input dimension of the lift layer (== config.d_in at construction).
pub(crate) lift_input_dim: usize,
}
impl StreamingMamba {
/// Build a `StreamingMamba` from `config`: select the SSM backbone by
/// `config.version`, size the RLS readout / gate / optional lift layer,
/// and zero-initialize all running statistics.
pub fn new(config: MambaConfig) -> Self {
// Instantiate the backbone implementation requested by the config.
let ssm = match config.version {
MambaVersion::V1 => {
SSMVariant::V1(SelectiveSSM::new(config.d_in, config.n_state, config.seed))
}
MambaVersion::V3 => SSMVariant::V3(SelectiveSSMv3::new(
config.d_in,
config.n_state,
config.n_groups,
config.seed,
)),
MambaVersion::V3Exp { use_bcnorm } => SSMVariant::V3Exp(SelectiveSSMv3Exp::new(
config.d_in,
config.n_state,
config.n_groups,
config.seed,
use_bcnorm,
)),
MambaVersion::V3Mimo { rank, use_bcnorm } => {
SSMVariant::V3Mimo(SelectiveSSMv3Mimo::new(
config.d_in,
config.n_state,
config.n_groups,
rank,
config.seed,
use_bcnorm,
))
}
MambaVersion::BlockDiagonal { block_size } => SSMVariant::BD(SelectiveSSMBD::new(
config.d_in,
config.n_state,
block_size,
config.seed,
)),
};
let readout = RecursiveLeastSquares::with_delta(config.forgetting_factor, config.delta_rls);
// Feature vector fed to the readout: base features plus lift features.
let readout_dim = Self::readout_dim_for_config(&config);
let last_features = vec![0.0; readout_dim];
let (gate_weights, gate_bias) = Self::init_gate_weights(config.d_in, config.seed);
// Unit count tracked by the plasticity guard depends on the version:
// per-channel for V1, per-group for the v3 family, per-block for BD.
let plasticity_n_units = match config.version {
MambaVersion::V1 => config.d_in,
MambaVersion::V3 | MambaVersion::V3Exp { .. } | MambaVersion::V3Mimo { .. } => {
config.n_groups
}
MambaVersion::BlockDiagonal { block_size } => config.d_in / block_size,
};
// The guard gets its own seed offset so its RNG stream is decorrelated
// from the weight initializers.
let plasticity_guard = config.plasticity.as_ref().map(|p| {
NeuronRegeneration::new(
plasticity_n_units,
1, p.regen_fraction,
p.regen_interval,
p.utility_alpha,
config.seed.wrapping_add(0x_DEAD_CAFE),
)
});
let prev_state_energy = vec![0.0; plasticity_n_units];
let last_ssm_output = vec![0.0; config.d_in];
// Lift layer exists only when the version requests lift features.
let n_lift = Self::n_lift_for_config(&config);
let lift_input_dim = config.d_in;
let (lift_weights, lift_bias) = if n_lift > 0 {
Self::init_lift_weights(lift_input_dim, n_lift, config.seed)
} else {
(None, None)
};
Self {
config,
ssm,
readout,
gate_weights,
gate_bias,
last_features,
n_samples: 0,
prev_prediction: 0.0,
prev_change: 0.0,
prev_prev_change: 0.0,
alignment_ewma: 0.0,
max_frob_sq_ewma: 0.0,
plasticity_guard,
prev_state_energy,
last_ssm_output,
lift_weights,
lift_bias,
n_lift,
lift_input_dim,
}
}
/// Initialize the fixed random-lift layer.
///
/// Returns `(Some(weights), Some(biases))` where `weights` holds
/// `n_lift * lift_input_dim` Gaussian entries scaled by `1/sqrt(lift_input_dim)`
/// and each bias is a tanh-squashed Gaussian draw. The seed offset differs
/// from the gate initializer so the two random projections are decorrelated.
pub(crate) fn init_lift_weights(
    lift_input_dim: usize,
    n_lift: usize,
    seed: u64,
) -> (Option<Vec<f64>>, Option<Vec<f64>>) {
    // `.max(1)` guards the degenerate all-zero RNG state: it maps 0 -> 1
    // and leaves every non-zero state unchanged.
    let mut state = seed.wrapping_add(0xF1F7_F1F7_F1F7_F1F7).max(1);
    let scale = (lift_input_dim as f64).sqrt().recip();
    // Weights are drawn first, then biases, so the draw order (and thus the
    // resulting values for a given seed) is fixed.
    let mut weights = Vec::with_capacity(n_lift * lift_input_dim);
    for _ in 0..n_lift * lift_input_dim {
        weights.push(standard_normal(&mut state) * scale);
    }
    let mut biases = Vec::with_capacity(n_lift);
    for _ in 0..n_lift {
        biases.push(standard_normal(&mut state).tanh());
    }
    (Some(weights), Some(biases))
}
/// Initialize the gate projection: `d_in * d_in` Gaussian weights scaled by
/// `1/sqrt(d_in)` and a zero bias vector of length `d_in`.
pub(crate) fn init_gate_weights(d_in: usize, seed: u64) -> (Vec<f64>, Vec<f64>) {
    // `.max(1)` maps a zero RNG state to 1 and leaves any other state alone.
    let mut state = seed.wrapping_add(0x9E37_79B9_7F4A_7C15).max(1);
    let scale = (d_in as f64).sqrt().recip();
    let gate_weights: Vec<f64> = (0..d_in * d_in)
        .map(|_| standard_normal(&mut state) * scale)
        .collect();
    (gate_weights, vec![0.0; d_in])
}
/// Size of the readout feature vector *before* lift features are appended
/// (the full size is `readout_dim_for_config`). The actual feature vector
/// is assembled in `features::build_readout_features`.
pub(crate) fn base_readout_dim_for_config(config: &MambaConfig) -> usize {
match config.version {
MambaVersion::V1 => config.d_in * 2,
MambaVersion::V3 | MambaVersion::V3Mimo { .. } => config.d_in + config.n_groups,
// V3Exp additionally exposes 4 values per (group, state) pair.
MambaVersion::V3Exp { .. } => {
config.d_in + config.n_groups + 4 * config.n_groups * config.n_state
}
MambaVersion::BlockDiagonal { block_size } => config.d_in + config.d_in / block_size,
}
}
/// Number of fixed random-lift features appended to the readout input.
/// Only the `V3Exp` version uses the lift layer; all others get 0.
pub(crate) fn n_lift_for_config(config: &MambaConfig) -> usize {
    if matches!(config.version, MambaVersion::V3Exp { .. }) {
        // At least 64 features, growing linearly with the input width.
        (32 * config.d_in).max(64)
    } else {
        0
    }
}
/// Total readout feature dimension: base features plus lift features.
pub(crate) fn readout_dim_for_config(config: &MambaConfig) -> usize {
Self::base_readout_dim_for_config(config) + Self::n_lift_for_config(config)
}
/// Assemble the feature vector for the RLS readout from the gated SSM
/// output, the SSM state, and the raw input. Delegates to the `features`
/// submodule; the result has length `readout_dim_for_config`.
pub(crate) fn build_readout_features(
&self,
gated_output: &[f64],
state: &[f64],
raw_input: &[f64],
) -> Vec<f64> {
features::build_readout_features(self, gated_output, state, raw_input)
}
/// The configuration this model was constructed with.
pub fn config(&self) -> &MambaConfig {
&self.config
}
/// Current hidden state of the SSM backbone.
pub fn ssm_state(&self) -> &[f64] {
self.ssm.state()
}
/// Predictive uncertainty: the square root of the readout's estimated
/// noise variance.
#[inline]
pub fn prediction_uncertainty(&self) -> f64 {
self.readout.noise_variance().sqrt()
}
/// The most recent readout feature vector.
pub fn last_features(&self) -> &[f64] {
&self.last_features
}
}
impl StreamingLearner for StreamingMamba {
    /// Predict a target value for `features` (pure delegation to `learner_impl`).
    fn predict(&self, features: &[f64]) -> f64 {
        learner_impl::predict(self, features)
    }
    /// Consume one (features, target, weight) sample.
    fn train_one(&mut self, features: &[f64], target: f64, weight: f64) {
        learner_impl::train_one(self, features, target, weight);
    }
    /// Reset all internal state back to a freshly-constructed model.
    fn reset(&mut self) {
        learner_impl::reset(self);
    }
    fn n_samples_seen(&self) -> u64 {
        self.n_samples
    }
    #[allow(deprecated)]
    fn diagnostics_array(&self) -> [f64; 5] {
        // Forward to the Tunable implementation of the same method.
        <Self as crate::learner::Tunable>::diagnostics_array(self)
    }
    #[allow(deprecated)]
    fn readout_weights(&self) -> Option<&[f64]> {
        // Map an empty weight slice to None for this Option-returning API.
        let weights = <Self as crate::learner::HasReadout>::readout_weights(self);
        if !weights.is_empty() {
            Some(weights)
        } else {
            None
        }
    }
}
impl crate::learner::Tunable for StreamingMamba {
    /// Flatten the config diagnostics into the fixed 5-slot array expected by
    /// the tuner; all zeros when diagnostics are unavailable.
    fn diagnostics_array(&self) -> [f64; 5] {
        use crate::automl::DiagnosticSource;
        self.config_diagnostics().map_or([0.0; 5], |d| {
            [
                d.residual_alignment,
                d.regularization_sensitivity,
                d.depth_sufficiency,
                d.effective_dof,
                d.uncertainty,
            ]
        })
    }
    /// Forward only the learning-rate multiplier to the RLS readout; the
    /// lambda delta is deliberately ignored here (a fixed 0.0 is passed).
    fn adjust_config(&mut self, lr_multiplier: f64, _lambda_delta: f64) {
        <crate::learners::RecursiveLeastSquares as crate::learner::Tunable>::adjust_config(
            &mut self.readout,
            lr_multiplier,
            0.0,
        );
    }
}
impl crate::learner::HasReadout for StreamingMamba {
/// The current RLS readout weight vector (may be empty before training).
fn readout_weights(&self) -> &[f64] {
self.readout.weights()
}
}
impl crate::automl::DiagnosticSource for StreamingMamba {
    /// Summarize the model's current health for the auto-ML tuner.
    ///
    /// Always returns `Some`; degenerate inputs (no weights yet, zero delta,
    /// near-zero state-energy scale) fall back to 0.0 components.
    fn config_diagnostics(&self) -> Option<crate::automl::ConfigDiagnostics> {
        let weights = self.readout.weights();
        let dim = weights.len();
        // How far the RLS precision matrix has contracted relative to its
        // prior scale delta: 1 - trace(P) / (delta * dim), clamped to [0, 1].
        let rls_saturation = if dim > 0 && self.readout.delta() > 0.0 {
            let p = self.readout.p_matrix();
            // Diagonal of the dim x dim matrix stored row-major.
            let trace: f64 = (0..dim).map(|i| p[i * dim + i]).sum();
            (1.0 - trace / (self.readout.delta() * dim as f64)).clamp(0.0, 1.0)
        } else {
            0.0
        };
        // Current state energy relative to its running scale.
        let frob_sq: f64 = self.ssm.state().iter().map(|s| s * s).sum();
        let state_frob_ratio = if self.max_frob_sq_ewma > 1e-15 {
            (frob_sq / self.max_frob_sq_ewma).clamp(0.0, 1.0)
        } else {
            0.0
        };
        // RMS magnitude of the readout weights (zero before any training).
        let effective_dof = if weights.is_empty() {
            0.0
        } else {
            let sq_sum: f64 = weights.iter().map(|wi| wi * wi).sum();
            sq_sum.sqrt() / (dim as f64).sqrt()
        };
        Some(crate::automl::ConfigDiagnostics {
            residual_alignment: self.alignment_ewma,
            regularization_sensitivity: 1.0 - self.config.forgetting_factor,
            depth_sufficiency: 0.5 * rls_saturation + 0.5 * state_frob_ratio,
            effective_dof,
            uncertainty: self.prediction_uncertainty(),
        })
    }
}
#[cfg(test)]
mod tests;