Skip to main content

irithyll_core/ssm/
selective.rs

1//! Mamba-style selective state space model with input-dependent projections.
2//!
3//! [`SelectiveSSM`] implements the core selective mechanism from Mamba
4//! (Gu & Dao, 2023), where the discretization step size Delta, input matrix B,
5//! and output matrix C are all functions of the current input. This enables
6//! content-aware filtering: the model can learn to selectively remember or
7//! forget information based on what it sees.
8//!
9//! # Architecture
10//!
11//! For each input timestep `x_t` (a d_in-dimensional vector):
12//!
13//! ```text
14//! Delta_t = softplus(W_delta * x_t + b_delta)    // scalar step size
15//! B_t     = W_B * x_t                             // N-dim input projection
16//! C_t     = W_C * x_t                             // N-dim output projection
17//! A       = -exp(log_A)                           // fixed, always negative diagonal
18//!
19//! // Pre-compute discretized coefficients (independent of channel d):
20//! A_bar[n] = exp(Delta_t * A_n)
21//! B_bar[n] = (A_bar[n] - 1) / A_n * B_t[n]
22//!
23//! // State update (state-dim-major layout: h[n, d]):
24//! For each state dim n in 0..N:
25//!   For each input channel d in 0..d_in:
26//!     h[n,d] = A_bar[n] * h[n,d] + B_bar[n] * x_t[d]
27//!
28//! // Output accumulation:
29//! For each state dim n in 0..N:
30//!   For each input channel d in 0..d_in:
31//!     output[d] += C_t[n] * h[n,d]
32//! output[d] += D[d] * x_t[d]
33//! ```
34//!
35//! The hidden state uses a transposed (state-dim-major) layout where each
36//! state dimension's channels are contiguous in memory. This enables the
37//! compiler to auto-vectorize the inner d-loop (scalar `a`/`b` broadcast
38//! over contiguous `h` and `input` slices) and maximizes cache line
39//! utilization. The discretized coefficients `A_bar` and `B_bar` are hoisted
40//! out of the channel loop since they depend only on the state index `n`.
41//!
42//! Each input channel maintains its own N-dimensional state vector, allowing
43//! the model to track per-channel temporal patterns independently.
44
45use alloc::vec;
46use alloc::vec::Vec;
47
48use crate::math;
49use crate::rng::standard_normal;
50use crate::ssm::init::s4d_inv_real;
51use crate::ssm::projection::{dot, mat_vec, softplus, Xorshift64};
52use crate::ssm::SSMLayer;
53
/// Mamba-style selective state space model.
///
/// The selective mechanism computes input-dependent B, C, and Delta at each
/// timestep, enabling the model to dynamically control what information is
/// stored in and retrieved from the hidden state.
///
/// The type is `Clone` (an independent deep copy of every weight vector and
/// of the hidden state, useful for snapshotting a running model) and `Debug`
/// (prints every field, including the full weight and state vectors).
///
/// # Dimensions
///
/// - `d_in` -- input/output dimension (number of channels)
/// - `n_state` -- hidden state dimension per channel (N)
/// - Total hidden state size: `d_in * n_state`
///
/// # Weight Shapes
///
/// | Weight | Shape | Purpose |
/// |--------|-------|---------|
/// | `w_delta` | d_in | Projects input to scalar step size |
/// | `w_b` | N x d_in | Projects input to state-input coupling |
/// | `w_c` | N x d_in | Projects input to state-output coupling |
/// | `d_skip` | d_in | Skip connection weights |
/// | `log_a` | N | Fixed state transition (always negative after exp) |
///
/// # Example
///
/// ```
/// use irithyll_core::ssm::selective::SelectiveSSM;
/// use irithyll_core::ssm::SSMLayer;
///
/// let mut ssm = SelectiveSSM::new(4, 8, 42);
/// let output = ssm.forward(&[1.0, 2.0, 3.0, 4.0]);
/// assert_eq!(output.len(), 4);
/// ```
#[derive(Debug, Clone)]
pub struct SelectiveSSM {
    /// Log-space A parameters (N). Actual A_n = -exp(log_a[n]).
    log_a: Vec<f64>,
    /// Delta projection weights (d_in). Maps input to scalar step size.
    w_delta: Vec<f64>,
    /// Delta projection bias, added before the softplus.
    b_delta: f64,
    /// B projection weights (N x d_in, row-major). Maps input to B_t.
    w_b: Vec<f64>,
    /// C projection weights (N x d_in, row-major). Maps input to C_t.
    w_c: Vec<f64>,
    /// Skip connection weights (d_in).
    d_skip: Vec<f64>,
    /// Hidden state (d_in * N, state-dim-major: [state_0_channels, state_1_channels, ...]).
    /// Entry `(n, d)` lives at index `n * d_in + d`.
    h: Vec<f64>,
    /// Number of state dimensions per channel.
    n_state: usize,
    /// Input/output dimension.
    d_in: usize,
}
106
107impl SelectiveSSM {
108    /// Create a new selective SSM with random weight initialization.
109    ///
110    /// Weights are initialized from a small normal distribution (scale 0.1)
111    /// using the provided seed for reproducibility. A is initialized via the
112    /// S4D-Inv (HiPPO-inspired) strategy: `A_n = -(0.5 + n/N)`, which gives
113    /// a bounded spectrum of decay rates that remain meaningful at all state
114    /// sizes. Skip connections (D) are initialized to 1.0 to enable input
115    /// passthrough by default.
116    ///
117    /// # Arguments
118    ///
119    /// * `d_in` -- input/output dimension (number of channels)
120    /// * `n_state` -- hidden state dimension per channel (N)
121    /// * `seed` -- random seed for weight initialization
122    ///
123    /// # Example
124    ///
125    /// ```
126    /// use irithyll_core::ssm::selective::SelectiveSSM;
127    ///
128    /// let ssm = SelectiveSSM::new(4, 16, 42);
129    /// ```
130    pub fn new(d_in: usize, n_state: usize, seed: u64) -> Self {
131        let log_a = s4d_inv_real(n_state);
132        let mut rng = Xorshift64(seed);
133        let scale = 0.1;
134
135        // Initialize projection weights from small normal distribution
136        let w_delta: Vec<f64> = (0..d_in).map(|_| rng.next_normal() * scale).collect();
137        let b_delta = 0.0;
138        let w_b: Vec<f64> = (0..n_state * d_in)
139            .map(|_| rng.next_normal() * scale)
140            .collect();
141        let w_c: Vec<f64> = (0..n_state * d_in)
142            .map(|_| rng.next_normal() * scale)
143            .collect();
144        let d_skip = vec![1.0; d_in];
145        let h = vec![0.0; d_in * n_state];
146
147        Self {
148            log_a,
149            w_delta,
150            b_delta,
151            w_b,
152            w_c,
153            d_skip,
154            h,
155            n_state,
156            d_in,
157        }
158    }
159
160    /// Get the input/output dimension.
161    #[inline]
162    pub fn d_in(&self) -> usize {
163        self.d_in
164    }
165
166    /// Get the number of state dimensions per channel.
167    #[inline]
168    pub fn n_state(&self) -> usize {
169        self.n_state
170    }
171
172    /// Surgically reinitialize a single channel, preserving all other channels.
173    ///
174    /// Resets channel `d`'s hidden state to zero across all state dimensions,
175    /// reinitializes its weight column in `w_b` and `w_c`, its `w_delta` entry,
176    /// and its skip connection `d_skip` to the default (1.0). All other channels
177    /// are left untouched.
178    ///
179    /// # Arguments
180    ///
181    /// * `d` — channel index to reinitialize (must be < `d_in`)
182    /// * `rng` — mutable RNG state for generating fresh weights
183    ///
184    /// # Panics
185    ///
186    /// Panics if `d >= d_in`.
187    pub fn reinitialize_channel(&mut self, d: usize, rng: &mut u64) {
188        assert!(
189            d < self.d_in,
190            "channel index {} out of range (d_in={})",
191            d,
192            self.d_in
193        );
194
195        let scale = 0.1;
196
197        // Zero state: h[n * d_in + d] for each state dim n (state-dim-major layout)
198        for n in 0..self.n_state {
199            self.h[n * self.d_in + d] = 0.0;
200        }
201
202        // Reinit w_delta[d]
203        self.w_delta[d] = standard_normal(rng) * scale;
204
205        // Reinit column d of w_b (N x d_in row-major): w_b[n * d_in + d]
206        for n in 0..self.n_state {
207            self.w_b[n * self.d_in + d] = standard_normal(rng) * scale;
208        }
209
210        // Reinit column d of w_c (N x d_in row-major): w_c[n * d_in + d]
211        for n in 0..self.n_state {
212            self.w_c[n * self.d_in + d] = standard_normal(rng) * scale;
213        }
214
215        // Reset skip connection to default passthrough
216        self.d_skip[d] = 1.0;
217    }
218
219    /// Compute the selective SSM forward pass for one timestep.
220    ///
221    /// This is the core Mamba recurrence: compute input-dependent Delta, B, C,
222    /// then update each channel's state and produce the output.
223    fn selective_forward(&mut self, input: &[f64]) -> Vec<f64> {
224        let d_in = self.d_in;
225        let n_state = self.n_state;
226
227        // 1. Compute delta = softplus(dot(w_delta, input) + b_delta)
228        let delta_raw = dot(&self.w_delta, input) + self.b_delta;
229        let delta = softplus(delta_raw);
230
231        // 2. Compute B_t = W_B * input (shape: N)
232        let mut b_t = vec![0.0; n_state];
233        mat_vec(&self.w_b, input, n_state, d_in, &mut b_t);
234
235        // 3. Compute C_t = W_C * input (shape: N)
236        let mut c_t = vec![0.0; n_state];
237        mat_vec(&self.w_c, input, n_state, d_in, &mut c_t);
238
239        // 4. Pre-compute discretized coefficients (independent of channel d)
240        //    This hoists 2*n_state exp() calls out of the d_in loop, saving
241        //    (d_in - 1) * 2 * n_state redundant transcendental evaluations.
242        let mut a_bar_vec = vec![0.0; n_state];
243        let mut b_bar_vec = vec![0.0; n_state];
244        for n in 0..n_state {
245            let a_n = -math::exp(self.log_a[n]); // negative real diagonal
246            let ab = math::exp(delta * a_n); // ZOH discretization
247            a_bar_vec[n] = ab;
248            b_bar_vec[n] = if math::abs(a_n) < 1e-12 {
249                delta * b_t[n]
250            } else {
251                (ab - 1.0) / a_n * b_t[n]
252            };
253        }
254
255        // 5. State update: for each state dim, process all channels contiguously.
256        //    State layout is state-dim-major: h[n * d_in + d], so the inner
257        //    d-loop touches contiguous memory with scalar a/b broadcasts —
258        //    ideal for auto-vectorization and cache utilization.
259        for n in 0..n_state {
260            let h_offset = n * d_in;
261            let a = a_bar_vec[n];
262            let b = b_bar_vec[n];
263            for (d, x_d) in input.iter().enumerate().take(d_in) {
264                self.h[h_offset + d] = a * self.h[h_offset + d] + b * x_d;
265            }
266        }
267
268        // 6. Output accumulation: y[d] = sum_n C_t[n] * h[n, d]
269        let mut output = vec![0.0; d_in];
270        for (n, &c_n) in c_t.iter().enumerate().take(n_state) {
271            let h_offset = n * d_in;
272            for (d, out_d) in output.iter_mut().enumerate().take(d_in) {
273                *out_d += c_n * self.h[h_offset + d];
274            }
275        }
276
277        // 7. Add skip connection
278        for (out_d, (&skip, &x_d)) in output.iter_mut().zip(self.d_skip.iter().zip(input.iter())) {
279            *out_d += skip * x_d;
280        }
281
282        output
283    }
284}
285
286impl SSMLayer for SelectiveSSM {
287    fn forward(&mut self, input: &[f64]) -> Vec<f64> {
288        debug_assert_eq!(
289            input.len(),
290            self.d_in,
291            "input length {} must match d_in {}",
292            input.len(),
293            self.d_in
294        );
295        self.selective_forward(input)
296    }
297
298    fn state(&self) -> &[f64] {
299        &self.h
300    }
301
302    fn output_dim(&self) -> usize {
303        self.d_in
304    }
305
306    fn reset(&mut self) {
307        for h in self.h.iter_mut() {
308            *h = 0.0;
309        }
310    }
311}
312
#[cfg(test)]
mod tests {
    use super::*;

    /// Sum of squared hidden-state entries.
    fn state_energy(ssm: &SelectiveSSM) -> f64 {
        ssm.state().iter().map(|h| h * h).sum()
    }

    /// Squared Euclidean distance between two equally long slices.
    fn squared_distance(a: &[f64], b: &[f64]) -> f64 {
        a.iter()
            .zip(b.iter())
            .map(|(x, y)| (x - y) * (x - y))
            .sum()
    }

    /// True if any entry of column `d` (elements at `n * d_in + d`) differs
    /// between `before` and `after` by more than the comparison tolerance.
    fn column_changed(before: &[f64], after: &[f64], n_state: usize, d_in: usize, d: usize) -> bool {
        (0..n_state).any(|n| {
            let idx = n * d_in + d;
            math::abs(after[idx] - before[idx]) > 1e-15
        })
    }

    #[test]
    fn new_creates_correct_dimensions() {
        let ssm = SelectiveSSM::new(4, 8, 42);
        assert_eq!(ssm.d_in(), 4);
        assert_eq!(ssm.n_state(), 8);
        assert_eq!(ssm.state().len(), 4 * 8);
        assert_eq!(ssm.output_dim(), 4);
    }

    #[test]
    fn initial_state_is_zero() {
        let ssm = SelectiveSSM::new(3, 16, 42);
        for &h in ssm.state() {
            assert!(math::abs(h) < 1e-15, "initial state should be zero");
        }
    }

    #[test]
    fn forward_produces_correct_output_dim() {
        let mut ssm = SelectiveSSM::new(5, 8, 42);
        let input = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let output = ssm.forward(&input);
        assert_eq!(output.len(), 5, "output dim should match d_in");
    }

    #[test]
    fn forward_produces_finite_output() {
        let mut ssm = SelectiveSSM::new(3, 8, 42);
        let input = vec![1.0, -1.0, 0.5];
        let output = ssm.forward(&input);
        for (i, &y) in output.iter().enumerate() {
            assert!(y.is_finite(), "output[{}] should be finite, got {}", i, y);
        }
    }

    #[test]
    fn forward_updates_state() {
        let mut ssm = SelectiveSSM::new(3, 8, 42);
        let input = vec![1.0, 2.0, 3.0];
        let _ = ssm.forward(&input);
        assert!(
            state_energy(&ssm) > 0.0,
            "state should be non-zero after processing non-zero input"
        );
    }

    #[test]
    fn reset_clears_state() {
        let mut ssm = SelectiveSSM::new(3, 8, 42);
        let _ = ssm.forward(&[1.0, 2.0, 3.0]);
        ssm.reset();
        for &h in ssm.state() {
            assert!(math::abs(h) < 1e-15, "state should be zero after reset");
        }
    }

    #[test]
    fn state_decays_without_input() {
        let mut ssm = SelectiveSSM::new(2, 4, 42);
        // Inject state
        let _ = ssm.forward(&[10.0, 10.0]);
        let energy_after = state_energy(&ssm);

        // Feed zeros for many steps
        for _ in 0..200 {
            let _ = ssm.forward(&[0.0, 0.0]);
        }
        let energy_decayed = state_energy(&ssm);
        assert!(
            energy_decayed < energy_after * 0.01,
            "state should decay with zero input: initial={}, after={}",
            energy_after,
            energy_decayed
        );
    }

    #[test]
    fn deterministic_with_same_seed() {
        let mut ssm1 = SelectiveSSM::new(3, 8, 42);
        let mut ssm2 = SelectiveSSM::new(3, 8, 42);
        let input = vec![1.0, 2.0, 3.0];
        let out1 = ssm1.forward(&input);
        let out2 = ssm2.forward(&input);
        for (i, (&a, &b)) in out1.iter().zip(out2.iter()).enumerate() {
            assert!(
                math::abs(a - b) < 1e-15,
                "output[{}] should be identical for same seed: {} vs {}",
                i,
                a,
                b
            );
        }
    }

    #[test]
    fn different_seeds_produce_different_outputs() {
        let mut ssm1 = SelectiveSSM::new(3, 8, 42);
        let mut ssm2 = SelectiveSSM::new(3, 8, 99);
        let input = vec![1.0, 2.0, 3.0];
        let out1 = ssm1.forward(&input);
        let out2 = ssm2.forward(&input);
        assert!(
            squared_distance(&out1, &out2) > 1e-20,
            "different seeds should generally produce different outputs"
        );
    }

    #[test]
    fn single_channel_works() {
        let mut ssm = SelectiveSSM::new(1, 4, 42);
        let output = ssm.forward(&[3.0]);
        assert_eq!(output.len(), 1);
        assert!(output[0].is_finite());
    }

    #[test]
    fn single_state_dim_works() {
        let mut ssm = SelectiveSSM::new(3, 1, 42);
        let output = ssm.forward(&[1.0, 2.0, 3.0]);
        assert_eq!(output.len(), 3);
        for &y in &output {
            assert!(y.is_finite());
        }
    }

    #[test]
    fn sequential_outputs_differ() {
        let mut ssm = SelectiveSSM::new(2, 4, 42);
        let out1 = ssm.forward(&[1.0, 0.0]);
        let out2 = ssm.forward(&[1.0, 0.0]);
        // Second call has non-zero state from first call, so outputs should differ
        assert!(
            squared_distance(&out1, &out2) > 1e-20,
            "sequential calls with same input should differ due to state: out1={:?}, out2={:?}",
            out1,
            out2
        );
    }

    #[test]
    fn large_input_no_overflow() {
        let mut ssm = SelectiveSSM::new(2, 4, 42);
        let input = vec![1000.0, -1000.0];
        let output = ssm.forward(&input);
        for (i, &y) in output.iter().enumerate() {
            assert!(
                y.is_finite(),
                "output[{}] should be finite for large inputs, got {}",
                i,
                y
            );
        }
    }

    #[test]
    fn zero_input_zero_state_gives_zero_output() {
        let mut ssm = SelectiveSSM::new(3, 8, 42);
        let output = ssm.forward(&[0.0, 0.0, 0.0]);
        for (i, &y) in output.iter().enumerate() {
            assert!(
                math::abs(y) < 1e-15,
                "zero input with zero state should give zero output[{}], got {}",
                i,
                y
            );
        }
    }

    #[test]
    fn reinitialize_channel_preserves_others() {
        let mut ssm = SelectiveSSM::new(3, 8, 42);
        let n_state = ssm.n_state();
        let d_in = ssm.d_in();

        // Forward 10 steps to build up state
        for step in 0..10 {
            let x = vec![
                (step as f64) * 0.3,
                (step as f64) * -0.2,
                (step as f64) * 0.1,
            ];
            let _ = ssm.forward(&x);
        }

        // Snapshot the full state and all weights before reinit, so both
        // "others preserved" and "channel 1 actually changed" can be checked.
        let state_before: Vec<f64> = ssm.state().to_vec();
        let w_delta_before: Vec<f64> = ssm.w_delta.clone();
        let w_b_before: Vec<f64> = ssm.w_b.clone();
        let w_c_before: Vec<f64> = ssm.w_c.clone();

        // Reinitialize channel 1
        let mut rng = 0xBEEF_u64;
        ssm.reinitialize_channel(1, &mut rng);

        // Channels 0 and 2: state and weights unchanged
        for &d in &[0usize, 2] {
            assert!(
                math::abs(ssm.w_delta[d] - w_delta_before[d]) < 1e-15,
                "w_delta[{}] should be preserved",
                d
            );
            for n in 0..n_state {
                let idx = n * d_in + d;
                assert!(
                    math::abs(ssm.h[idx] - state_before[idx]) < 1e-15,
                    "channel {} state[{}] should be preserved after reinit of channel 1",
                    d,
                    n
                );
                assert!(
                    math::abs(ssm.w_b[idx] - w_b_before[idx]) < 1e-15,
                    "w_b col {} row {} should be preserved",
                    d,
                    n
                );
                assert!(
                    math::abs(ssm.w_c[idx] - w_c_before[idx]) < 1e-15,
                    "w_c col {} row {} should be preserved",
                    d,
                    n
                );
            }
        }

        // Channel 1 state zeroed
        for n in 0..n_state {
            let idx = n * d_in + 1;
            assert!(
                math::abs(ssm.h[idx]) < 1e-15,
                "channel 1 state[{}] should be zeroed after reinit, got {}",
                n,
                ssm.h[idx]
            );
        }

        // Channel 1 weights must be non-zero after reinit ...
        let any_wb_nonzero = (0..n_state).any(|n| math::abs(ssm.w_b[n * d_in + 1]) > 1e-15);
        assert!(
            any_wb_nonzero,
            "reinitialised channel 1 w_b should have non-zero weights"
        );

        // ... and must actually differ from their pre-reinit values. A
        // non-zero check alone would also pass if the weights were never
        // touched, because the constructor already made them non-zero.
        assert!(
            column_changed(&w_b_before, &ssm.w_b, n_state, d_in, 1),
            "w_b col 1 should differ from its pre-reinit values"
        );
        assert!(
            column_changed(&w_c_before, &ssm.w_c, n_state, d_in, 1),
            "w_c col 1 should differ from its pre-reinit values"
        );
        assert!(
            math::abs(ssm.w_delta[1] - w_delta_before[1]) > 1e-15,
            "w_delta[1] should differ from its pre-reinit value"
        );

        // d_skip[1] should be reset to 1.0
        assert!(
            math::abs(ssm.d_skip[1] - 1.0) < 1e-15,
            "d_skip[1] should be reset to 1.0 after reinit, got {}",
            ssm.d_skip[1]
        );
    }
}