Skip to main content

irithyll_core/ssm/
selective.rs

//! Mamba-style selective state space model with input-dependent projections.
//!
//! [`SelectiveSSM`] implements the core selective mechanism from Mamba
//! (Gu & Dao, 2023), where the discretization step size Delta, input matrix B,
//! and output matrix C are all functions of the current input. This enables
//! content-aware filtering: the model can learn to selectively remember or
//! forget information based on what it sees.
//!
//! # Architecture
//!
//! For each input timestep `x_t` (a d_in-dimensional vector):
//!
//! ```text
//! Delta_t = softplus(W_delta * x_t + b_delta)    // scalar step size (shared by all channels)
//! B_t     = W_B * x_t                             // N-dim input projection
//! C_t     = W_C * x_t                             // N-dim output projection
//! A       = -exp(log_A)                           // fixed, always negative diagonal
//!
//! For each input channel d in 0..d_in:
//!   For each state dim n in 0..N:
//!     A_bar = exp(Delta_t * A_n)
//!     B_bar = (A_bar - 1) / A_n * B_t[n]
//!     h[d,n] = A_bar * h[d,n] + B_bar * x_t[d]
//!   output[d] = C_t^T * h[d,:] + D[d] * x_t[d]
//! ```
//!
//! Note that `A_bar` and `B_bar` depend only on the state index `n` (through the
//! shared `Delta_t` and `B_t`), not on the channel `d`.
//!
//! Each input channel maintains its own N-dimensional state vector, allowing
//! the model to track per-channel temporal patterns independently.
29
30use alloc::vec;
31use alloc::vec::Vec;
32
33use crate::math;
34use crate::ssm::init::s4d_inv_real;
35use crate::ssm::projection::{dot, mat_vec, softplus, Xorshift64};
36use crate::ssm::SSMLayer;
37
/// Mamba-style selective state space model.
///
/// The selective mechanism computes input-dependent B, C, and Delta at each
/// timestep, enabling the model to dynamically control what information is
/// stored in and retrieved from the hidden state.
///
/// Cloning a `SelectiveSSM` copies its weights *and* its current hidden state,
/// so the clone continues the same sequence independently of the original.
///
/// # Dimensions
///
/// - `d_in` -- input/output dimension (number of channels)
/// - `n_state` -- hidden state dimension per channel (N)
/// - Total hidden state size: `d_in * n_state`
///
/// # Weight Shapes
///
/// | Weight | Shape | Purpose |
/// |--------|-------|---------|
/// | `w_delta` | d_in | Projects input to scalar step size |
/// | `w_b` | N x d_in | Projects input to state-input coupling |
/// | `w_c` | N x d_in | Projects input to state-output coupling |
/// | `d_skip` | d_in | Skip connection weights |
/// | `log_a` | N | Fixed state transition (always negative after exp) |
///
/// # Example
///
/// ```
/// use irithyll_core::ssm::selective::SelectiveSSM;
/// use irithyll_core::ssm::SSMLayer;
///
/// let mut ssm = SelectiveSSM::new(4, 8, 42);
/// let output = ssm.forward(&[1.0, 2.0, 3.0, 4.0]);
/// assert_eq!(output.len(), 4);
/// ```
#[derive(Debug, Clone)]
pub struct SelectiveSSM {
    /// Log-space A parameters (N). Actual A_n = -exp(log_a[n]).
    log_a: Vec<f64>,
    /// Delta projection weights (d_in). Maps input to scalar step size.
    w_delta: Vec<f64>,
    /// Delta projection bias.
    b_delta: f64,
    /// B projection weights (N x d_in, row-major). Maps input to B_t.
    w_b: Vec<f64>,
    /// C projection weights (N x d_in, row-major). Maps input to C_t.
    w_c: Vec<f64>,
    /// Skip connection weights (d_in).
    d_skip: Vec<f64>,
    /// Hidden state (d_in * N, laid out as [channel_0_state, channel_1_state, ...]).
    h: Vec<f64>,
    /// Number of state dimensions per channel.
    n_state: usize,
    /// Input/output dimension.
    d_in: usize,
}
90
91impl SelectiveSSM {
92    /// Create a new selective SSM with random weight initialization.
93    ///
94    /// Weights are initialized from a small normal distribution (scale 0.1)
95    /// using the provided seed for reproducibility. A is initialized via the
96    /// S4D-Inv (HiPPO-inspired) strategy: `A_n = -(0.5 + n/N)`, which gives
97    /// a bounded spectrum of decay rates that remain meaningful at all state
98    /// sizes. Skip connections (D) are initialized to 1.0 to enable input
99    /// passthrough by default.
100    ///
101    /// # Arguments
102    ///
103    /// * `d_in` -- input/output dimension (number of channels)
104    /// * `n_state` -- hidden state dimension per channel (N)
105    /// * `seed` -- random seed for weight initialization
106    ///
107    /// # Example
108    ///
109    /// ```
110    /// use irithyll_core::ssm::selective::SelectiveSSM;
111    ///
112    /// let ssm = SelectiveSSM::new(4, 16, 42);
113    /// ```
114    pub fn new(d_in: usize, n_state: usize, seed: u64) -> Self {
115        let log_a = s4d_inv_real(n_state);
116        let mut rng = Xorshift64(seed);
117        let scale = 0.1;
118
119        // Initialize projection weights from small normal distribution
120        let w_delta: Vec<f64> = (0..d_in).map(|_| rng.next_normal() * scale).collect();
121        let b_delta = 0.0;
122        let w_b: Vec<f64> = (0..n_state * d_in)
123            .map(|_| rng.next_normal() * scale)
124            .collect();
125        let w_c: Vec<f64> = (0..n_state * d_in)
126            .map(|_| rng.next_normal() * scale)
127            .collect();
128        let d_skip = vec![1.0; d_in];
129        let h = vec![0.0; d_in * n_state];
130
131        Self {
132            log_a,
133            w_delta,
134            b_delta,
135            w_b,
136            w_c,
137            d_skip,
138            h,
139            n_state,
140            d_in,
141        }
142    }
143
144    /// Get the input/output dimension.
145    #[inline]
146    pub fn d_in(&self) -> usize {
147        self.d_in
148    }
149
150    /// Get the number of state dimensions per channel.
151    #[inline]
152    pub fn n_state(&self) -> usize {
153        self.n_state
154    }
155
156    /// Compute the selective SSM forward pass for one timestep.
157    ///
158    /// This is the core Mamba recurrence: compute input-dependent Delta, B, C,
159    /// then update each channel's state and produce the output.
160    fn selective_forward(&mut self, input: &[f64]) -> Vec<f64> {
161        let d_in = self.d_in;
162        let n_state = self.n_state;
163
164        // 1. Compute delta = softplus(dot(w_delta, input) + b_delta)
165        let delta_raw = dot(&self.w_delta, input) + self.b_delta;
166        let delta = softplus(delta_raw);
167
168        // 2. Compute B_t = W_B * input (shape: N)
169        let mut b_t = vec![0.0; n_state];
170        mat_vec(&self.w_b, input, n_state, d_in, &mut b_t);
171
172        // 3. Compute C_t = W_C * input (shape: N)
173        let mut c_t = vec![0.0; n_state];
174        mat_vec(&self.w_c, input, n_state, d_in, &mut c_t);
175
176        // 4. For each input channel, update state and compute output
177        let mut output = vec![0.0; d_in];
178
179        for d in 0..d_in {
180            let h_offset = d * n_state;
181            let mut y = 0.0;
182
183            for n in 0..n_state {
184                let a_n = -math::exp(self.log_a[n]); // negative real diagonal
185
186                // ZOH discretization: A_bar = exp(delta * A_n)
187                let a_bar = math::exp(delta * a_n);
188
189                // B_bar = (A_bar - 1) / A_n * B_t[n]
190                // Handle A_n ~ 0 (shouldn't happen with s4d_inv, but be safe)
191                let b_bar = if math::abs(a_n) < 1e-12 {
192                    delta * b_t[n]
193                } else {
194                    (a_bar - 1.0) / a_n * b_t[n]
195                };
196
197                // State update: h[d,n] = A_bar * h[d,n] + B_bar * input[d]
198                self.h[h_offset + n] = a_bar * self.h[h_offset + n] + b_bar * input[d];
199
200                // Accumulate output: y += C_t[n] * h[d,n]
201                y += c_t[n] * self.h[h_offset + n];
202            }
203
204            // Add skip connection
205            output[d] = y + self.d_skip[d] * input[d];
206        }
207
208        output
209    }
210}
211
212impl SSMLayer for SelectiveSSM {
213    fn forward(&mut self, input: &[f64]) -> Vec<f64> {
214        debug_assert_eq!(
215            input.len(),
216            self.d_in,
217            "input length {} must match d_in {}",
218            input.len(),
219            self.d_in
220        );
221        self.selective_forward(input)
222    }
223
224    fn state(&self) -> &[f64] {
225        &self.h
226    }
227
228    fn output_dim(&self) -> usize {
229        self.d_in
230    }
231
232    fn reset(&mut self) {
233        for h in self.h.iter_mut() {
234            *h = 0.0;
235        }
236    }
237}
238
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn new_creates_correct_dimensions() {
        let ssm = SelectiveSSM::new(4, 8, 42);
        assert_eq!(ssm.d_in(), 4);
        assert_eq!(ssm.n_state(), 8);
        assert_eq!(ssm.state().len(), 4 * 8);
        assert_eq!(ssm.output_dim(), 4);
    }

    #[test]
    fn initial_state_is_zero() {
        let ssm = SelectiveSSM::new(3, 16, 42);
        for &h in ssm.state() {
            assert!(math::abs(h) < 1e-15, "initial state should be zero");
        }
    }

    #[test]
    fn forward_produces_correct_output_dim() {
        let mut ssm = SelectiveSSM::new(5, 8, 42);
        let input = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let output = ssm.forward(&input);
        assert_eq!(output.len(), 5, "output dim should match d_in");
    }

    #[test]
    fn forward_produces_finite_output() {
        let mut ssm = SelectiveSSM::new(3, 8, 42);
        let input = vec![1.0, -1.0, 0.5];
        let output = ssm.forward(&input);
        for (i, &y) in output.iter().enumerate() {
            assert!(y.is_finite(), "output[{}] should be finite, got {}", i, y);
        }
    }

    #[test]
    fn forward_updates_state() {
        let mut ssm = SelectiveSSM::new(3, 8, 42);
        let input = vec![1.0, 2.0, 3.0];
        let _ = ssm.forward(&input);
        let state_norm: f64 = ssm.state().iter().map(|h| h * h).sum();
        assert!(
            state_norm > 0.0,
            "state should be non-zero after processing non-zero input"
        );
    }

    #[test]
    fn reset_clears_state() {
        let mut ssm = SelectiveSSM::new(3, 8, 42);
        let _ = ssm.forward(&[1.0, 2.0, 3.0]);
        ssm.reset();
        for &h in ssm.state() {
            assert!(math::abs(h) < 1e-15, "state should be zero after reset");
        }
    }

    #[test]
    fn state_decays_without_input() {
        let mut ssm = SelectiveSSM::new(2, 4, 42);
        // Inject state
        let _ = ssm.forward(&[10.0, 10.0]);
        let energy_after: f64 = ssm.state().iter().map(|h| h * h).sum();

        // Feed zeros for many steps
        for _ in 0..200 {
            let _ = ssm.forward(&[0.0, 0.0]);
        }
        let energy_decayed: f64 = ssm.state().iter().map(|h| h * h).sum();
        assert!(
            energy_decayed < energy_after * 0.01,
            "state should decay with zero input: initial={}, after={}",
            energy_after,
            energy_decayed
        );
    }

    #[test]
    fn deterministic_with_same_seed() {
        let mut ssm1 = SelectiveSSM::new(3, 8, 42);
        let mut ssm2 = SelectiveSSM::new(3, 8, 42);
        let input = vec![1.0, 2.0, 3.0];
        let out1 = ssm1.forward(&input);
        let out2 = ssm2.forward(&input);
        for (i, (&a, &b)) in out1.iter().zip(out2.iter()).enumerate() {
            assert!(
                math::abs(a - b) < 1e-15,
                "output[{}] should be identical for same seed: {} vs {}",
                i,
                a,
                b
            );
        }
    }

    #[test]
    fn different_seeds_produce_different_outputs() {
        let mut ssm1 = SelectiveSSM::new(3, 8, 42);
        let mut ssm2 = SelectiveSSM::new(3, 8, 99);
        let input = vec![1.0, 2.0, 3.0];
        let out1 = ssm1.forward(&input);
        let out2 = ssm2.forward(&input);
        let diff: f64 = out1
            .iter()
            .zip(out2.iter())
            .map(|(a, b)| (a - b) * (a - b))
            .sum();
        assert!(
            diff > 1e-20,
            "different seeds should generally produce different outputs"
        );
    }

    #[test]
    fn single_channel_works() {
        let mut ssm = SelectiveSSM::new(1, 4, 42);
        let output = ssm.forward(&[3.0]);
        assert_eq!(output.len(), 1);
        assert!(output[0].is_finite());
    }

    #[test]
    fn single_state_dim_works() {
        let mut ssm = SelectiveSSM::new(3, 1, 42);
        let output = ssm.forward(&[1.0, 2.0, 3.0]);
        assert_eq!(output.len(), 3);
        for &y in &output {
            assert!(y.is_finite());
        }
    }

    #[test]
    fn sequential_outputs_differ() {
        let mut ssm = SelectiveSSM::new(2, 4, 42);
        let out1 = ssm.forward(&[1.0, 0.0]);
        let out2 = ssm.forward(&[1.0, 0.0]);
        // Second call has non-zero state from first call, so outputs should differ
        let diff: f64 = out1
            .iter()
            .zip(out2.iter())
            .map(|(a, b)| (a - b) * (a - b))
            .sum();
        assert!(
            diff > 1e-20,
            "sequential calls with same input should differ due to state: out1={:?}, out2={:?}",
            out1,
            out2
        );
    }

    #[test]
    fn large_input_no_overflow() {
        let mut ssm = SelectiveSSM::new(2, 4, 42);
        let input = vec![1000.0, -1000.0];
        let output = ssm.forward(&input);
        for (i, &y) in output.iter().enumerate() {
            assert!(
                y.is_finite(),
                "output[{}] should be finite for large inputs, got {}",
                i,
                y
            );
        }
    }

    #[test]
    fn zero_input_zero_state_gives_zero_output() {
        let mut ssm = SelectiveSSM::new(3, 8, 42);
        let output = ssm.forward(&[0.0, 0.0, 0.0]);
        for (i, &y) in output.iter().enumerate() {
            assert!(
                math::abs(y) < 1e-15,
                "zero input with zero state should give zero output[{}], got {}",
                i,
                y
            );
        }
    }

    /// Numerically pins the recurrence. Every step is recomputed from the
    /// module-level equations, evaluated per (channel, state) pair exactly as
    /// documented, and the layer's outputs and hidden state must agree over
    /// several timesteps. This catches any restructuring of the forward pass
    /// that changes results.
    #[test]
    fn forward_matches_reference_recurrence() {
        let mut ssm = SelectiveSSM::new(3, 5, 7);
        let d_in = ssm.d_in();
        let n_state = ssm.n_state();
        let mut h_ref = vec![0.0; d_in * n_state];
        let sequence = [[0.5, -1.0, 2.0], [1.5, 0.25, -0.75], [0.0, 3.0, 1.0]];

        for row in sequence.iter() {
            let x: &[f64] = &row[..];

            let delta = softplus(dot(&ssm.w_delta, x) + ssm.b_delta);
            let mut b_t = vec![0.0; n_state];
            mat_vec(&ssm.w_b, x, n_state, d_in, &mut b_t);
            let mut c_t = vec![0.0; n_state];
            mat_vec(&ssm.w_c, x, n_state, d_in, &mut c_t);

            let mut expected = vec![0.0; d_in];
            for d in 0..d_in {
                let mut y = 0.0;
                for n in 0..n_state {
                    let a_n = -math::exp(ssm.log_a[n]);
                    let a_bar = math::exp(delta * a_n);
                    let b_bar = if math::abs(a_n) < 1e-12 {
                        delta * b_t[n]
                    } else {
                        (a_bar - 1.0) / a_n * b_t[n]
                    };
                    let idx = d * n_state + n;
                    h_ref[idx] = a_bar * h_ref[idx] + b_bar * x[d];
                    y += c_t[n] * h_ref[idx];
                }
                expected[d] = y + ssm.d_skip[d] * x[d];
            }

            let got = ssm.forward(x);
            assert_eq!(got.len(), expected.len());
            for (i, (&g, &e)) in got.iter().zip(expected.iter()).enumerate() {
                assert!(
                    math::abs(g - e) <= 1e-12 * (1.0 + math::abs(e)),
                    "output[{}] mismatch: got {}, expected {}",
                    i,
                    g,
                    e
                );
            }
            for (i, (&g, &e)) in ssm.state().iter().zip(h_ref.iter()).enumerate() {
                assert!(
                    math::abs(g - e) <= 1e-12 * (1.0 + math::abs(e)),
                    "state[{}] mismatch: got {}, expected {}",
                    i,
                    g,
                    e
                );
            }
        }
    }
}