// irithyll_core/ssm/selective.rs
//! Mamba-style selective state space model with input-dependent projections.
//!
//! [`SelectiveSSM`] implements the core selective mechanism from Mamba
//! (Gu & Dao, 2023), where the discretization step size Delta, input matrix B,
//! and output matrix C are all functions of the current input. This enables
//! content-aware filtering: the model can learn to selectively remember or
//! forget information based on what it sees.
//!
//! # Architecture
//!
//! For each input timestep `x_t` (a d_in-dimensional vector):
//!
//! ```text
//! Delta_t = softplus(W_delta * x_t + b_delta)    // scalar step size
//! B_t     = W_B * x_t                             // N-dim input projection
//! C_t     = W_C * x_t                             // N-dim output projection
//! A       = -exp(log_A)                           // fixed, always negative diagonal
//!
//! For each input channel d in 0..d_in:
//!   For each state dim n in 0..N:
//!     A_bar = exp(Delta_t * A_n)
//!     B_bar = (A_bar - 1) / A_n * B_t[n]
//!     h[d,n] = A_bar * h[d,n] + B_bar * x_t[d]
//!   output[d] = C_t^T * h[d,:] + D[d] * x_t[d]
//! ```
//!
//! Each input channel maintains its own N-dimensional state vector, allowing
//! the model to track per-channel temporal patterns independently.

use alloc::vec;
use alloc::vec::Vec;

use crate::math;
use crate::ssm::init::mamba_init;
use crate::ssm::projection::{dot, mat_vec, softplus, Xorshift64};
use crate::ssm::SSMLayer;
37
38/// Mamba-style selective state space model.
39///
40/// The selective mechanism computes input-dependent B, C, and Delta at each
41/// timestep, enabling the model to dynamically control what information is
42/// stored in and retrieved from the hidden state.
43///
44/// # Dimensions
45///
46/// - `d_in` -- input/output dimension (number of channels)
47/// - `n_state` -- hidden state dimension per channel (N)
48/// - Total hidden state size: `d_in * n_state`
49///
50/// # Weight Shapes
51///
52/// | Weight | Shape | Purpose |
53/// |--------|-------|---------|
54/// | `w_delta` | d_in | Projects input to scalar step size |
55/// | `w_b` | N x d_in | Projects input to state-input coupling |
56/// | `w_c` | N x d_in | Projects input to state-output coupling |
57/// | `d_skip` | d_in | Skip connection weights |
58/// | `log_a` | N | Fixed state transition (always negative after exp) |
59///
60/// # Example
61///
62/// ```
63/// use irithyll_core::ssm::selective::SelectiveSSM;
64/// use irithyll_core::ssm::SSMLayer;
65///
66/// let mut ssm = SelectiveSSM::new(4, 8, 42);
67/// let output = ssm.forward(&[1.0, 2.0, 3.0, 4.0]);
68/// assert_eq!(output.len(), 4);
69/// ```
70pub struct SelectiveSSM {
71    /// Log-space A parameters (N). Actual A_n = -exp(log_a[n]).
72    log_a: Vec<f64>,
73    /// Delta projection weights (d_in). Maps input to scalar step size.
74    w_delta: Vec<f64>,
75    /// Delta projection bias.
76    b_delta: f64,
77    /// B projection weights (N x d_in, row-major). Maps input to B_t.
78    w_b: Vec<f64>,
79    /// C projection weights (N x d_in, row-major). Maps input to C_t.
80    w_c: Vec<f64>,
81    /// Skip connection weights (d_in).
82    d_skip: Vec<f64>,
83    /// Hidden state (d_in * N, laid out as [channel_0_state, channel_1_state, ...]).
84    h: Vec<f64>,
85    /// Number of state dimensions per channel.
86    n_state: usize,
87    /// Input/output dimension.
88    d_in: usize,
89}
90
91impl SelectiveSSM {
92    /// Create a new selective SSM with random weight initialization.
93    ///
94    /// Weights are initialized from a small normal distribution (scale 0.01)
95    /// using the provided seed for reproducibility. A is initialized via the
96    /// Mamba strategy (A_n = -(n+1)).
97    ///
98    /// # Arguments
99    ///
100    /// * `d_in` -- input/output dimension (number of channels)
101    /// * `n_state` -- hidden state dimension per channel (N)
102    /// * `seed` -- random seed for weight initialization
103    ///
104    /// # Example
105    ///
106    /// ```
107    /// use irithyll_core::ssm::selective::SelectiveSSM;
108    ///
109    /// let ssm = SelectiveSSM::new(4, 16, 42);
110    /// ```
111    pub fn new(d_in: usize, n_state: usize, seed: u64) -> Self {
112        let log_a = mamba_init(n_state);
113        let mut rng = Xorshift64(seed);
114        let scale = 0.01;
115
116        // Initialize projection weights from small normal distribution
117        let w_delta: Vec<f64> = (0..d_in).map(|_| rng.next_normal() * scale).collect();
118        let b_delta = 0.0;
119        let w_b: Vec<f64> = (0..n_state * d_in)
120            .map(|_| rng.next_normal() * scale)
121            .collect();
122        let w_c: Vec<f64> = (0..n_state * d_in)
123            .map(|_| rng.next_normal() * scale)
124            .collect();
125        let d_skip = vec![0.0; d_in];
126        let h = vec![0.0; d_in * n_state];
127
128        Self {
129            log_a,
130            w_delta,
131            b_delta,
132            w_b,
133            w_c,
134            d_skip,
135            h,
136            n_state,
137            d_in,
138        }
139    }
140
141    /// Get the input/output dimension.
142    #[inline]
143    pub fn d_in(&self) -> usize {
144        self.d_in
145    }
146
147    /// Get the number of state dimensions per channel.
148    #[inline]
149    pub fn n_state(&self) -> usize {
150        self.n_state
151    }
152
153    /// Compute the selective SSM forward pass for one timestep.
154    ///
155    /// This is the core Mamba recurrence: compute input-dependent Delta, B, C,
156    /// then update each channel's state and produce the output.
157    fn selective_forward(&mut self, input: &[f64]) -> Vec<f64> {
158        let d_in = self.d_in;
159        let n_state = self.n_state;
160
161        // 1. Compute delta = softplus(dot(w_delta, input) + b_delta)
162        let delta_raw = dot(&self.w_delta, input) + self.b_delta;
163        let delta = softplus(delta_raw);
164
165        // 2. Compute B_t = W_B * input (shape: N)
166        let mut b_t = vec![0.0; n_state];
167        mat_vec(&self.w_b, input, n_state, d_in, &mut b_t);
168
169        // 3. Compute C_t = W_C * input (shape: N)
170        let mut c_t = vec![0.0; n_state];
171        mat_vec(&self.w_c, input, n_state, d_in, &mut c_t);
172
173        // 4. For each input channel, update state and compute output
174        let mut output = vec![0.0; d_in];
175
176        for d in 0..d_in {
177            let h_offset = d * n_state;
178            let mut y = 0.0;
179
180            for n in 0..n_state {
181                let a_n = -math::exp(self.log_a[n]); // negative real diagonal
182
183                // ZOH discretization: A_bar = exp(delta * A_n)
184                let a_bar = math::exp(delta * a_n);
185
186                // B_bar = (A_bar - 1) / A_n * B_t[n]
187                // Handle A_n ~ 0 (shouldn't happen with mamba_init, but be safe)
188                let b_bar = if math::abs(a_n) < 1e-12 {
189                    delta * b_t[n]
190                } else {
191                    (a_bar - 1.0) / a_n * b_t[n]
192                };
193
194                // State update: h[d,n] = A_bar * h[d,n] + B_bar * input[d]
195                self.h[h_offset + n] = a_bar * self.h[h_offset + n] + b_bar * input[d];
196
197                // Accumulate output: y += C_t[n] * h[d,n]
198                y += c_t[n] * self.h[h_offset + n];
199            }
200
201            // Add skip connection
202            output[d] = y + self.d_skip[d] * input[d];
203        }
204
205        output
206    }
207}
208
209impl SSMLayer for SelectiveSSM {
210    fn forward(&mut self, input: &[f64]) -> Vec<f64> {
211        debug_assert_eq!(
212            input.len(),
213            self.d_in,
214            "input length {} must match d_in {}",
215            input.len(),
216            self.d_in
217        );
218        self.selective_forward(input)
219    }
220
221    fn state(&self) -> &[f64] {
222        &self.h
223    }
224
225    fn output_dim(&self) -> usize {
226        self.d_in
227    }
228
229    fn reset(&mut self) {
230        for h in self.h.iter_mut() {
231            *h = 0.0;
232        }
233    }
234}
235
236#[cfg(test)]
237mod tests {
238    use super::*;
239
240    #[test]
241    fn new_creates_correct_dimensions() {
242        let ssm = SelectiveSSM::new(4, 8, 42);
243        assert_eq!(ssm.d_in(), 4);
244        assert_eq!(ssm.n_state(), 8);
245        assert_eq!(ssm.state().len(), 4 * 8);
246        assert_eq!(ssm.output_dim(), 4);
247    }
248
249    #[test]
250    fn initial_state_is_zero() {
251        let ssm = SelectiveSSM::new(3, 16, 42);
252        for &h in ssm.state() {
253            assert!(math::abs(h) < 1e-15, "initial state should be zero");
254        }
255    }
256
257    #[test]
258    fn forward_produces_correct_output_dim() {
259        let mut ssm = SelectiveSSM::new(5, 8, 42);
260        let input = vec![1.0, 2.0, 3.0, 4.0, 5.0];
261        let output = ssm.forward(&input);
262        assert_eq!(output.len(), 5, "output dim should match d_in");
263    }
264
265    #[test]
266    fn forward_produces_finite_output() {
267        let mut ssm = SelectiveSSM::new(3, 8, 42);
268        let input = vec![1.0, -1.0, 0.5];
269        let output = ssm.forward(&input);
270        for (i, &y) in output.iter().enumerate() {
271            assert!(y.is_finite(), "output[{}] should be finite, got {}", i, y);
272        }
273    }
274
275    #[test]
276    fn forward_updates_state() {
277        let mut ssm = SelectiveSSM::new(3, 8, 42);
278        let input = vec![1.0, 2.0, 3.0];
279        let _ = ssm.forward(&input);
280        let state_norm: f64 = ssm.state().iter().map(|h| h * h).sum();
281        assert!(
282            state_norm > 0.0,
283            "state should be non-zero after processing non-zero input"
284        );
285    }
286
287    #[test]
288    fn reset_clears_state() {
289        let mut ssm = SelectiveSSM::new(3, 8, 42);
290        let _ = ssm.forward(&[1.0, 2.0, 3.0]);
291        ssm.reset();
292        for &h in ssm.state() {
293            assert!(math::abs(h) < 1e-15, "state should be zero after reset");
294        }
295    }
296
297    #[test]
298    fn state_decays_without_input() {
299        let mut ssm = SelectiveSSM::new(2, 4, 42);
300        // Inject state
301        let _ = ssm.forward(&[10.0, 10.0]);
302        let energy_after: f64 = ssm.state().iter().map(|h| h * h).sum();
303
304        // Feed zeros for many steps
305        for _ in 0..200 {
306            let _ = ssm.forward(&[0.0, 0.0]);
307        }
308        let energy_decayed: f64 = ssm.state().iter().map(|h| h * h).sum();
309        assert!(
310            energy_decayed < energy_after * 0.01,
311            "state should decay with zero input: initial={}, after={}",
312            energy_after,
313            energy_decayed
314        );
315    }
316
317    #[test]
318    fn deterministic_with_same_seed() {
319        let mut ssm1 = SelectiveSSM::new(3, 8, 42);
320        let mut ssm2 = SelectiveSSM::new(3, 8, 42);
321        let input = vec![1.0, 2.0, 3.0];
322        let out1 = ssm1.forward(&input);
323        let out2 = ssm2.forward(&input);
324        for (i, (&a, &b)) in out1.iter().zip(out2.iter()).enumerate() {
325            assert!(
326                math::abs(a - b) < 1e-15,
327                "output[{}] should be identical for same seed: {} vs {}",
328                i,
329                a,
330                b
331            );
332        }
333    }
334
335    #[test]
336    fn different_seeds_produce_different_outputs() {
337        let mut ssm1 = SelectiveSSM::new(3, 8, 42);
338        let mut ssm2 = SelectiveSSM::new(3, 8, 99);
339        let input = vec![1.0, 2.0, 3.0];
340        let out1 = ssm1.forward(&input);
341        let out2 = ssm2.forward(&input);
342        let diff: f64 = out1
343            .iter()
344            .zip(out2.iter())
345            .map(|(a, b)| (a - b) * (a - b))
346            .sum();
347        assert!(
348            diff > 1e-20,
349            "different seeds should generally produce different outputs"
350        );
351    }
352
353    #[test]
354    fn single_channel_works() {
355        let mut ssm = SelectiveSSM::new(1, 4, 42);
356        let output = ssm.forward(&[3.0]);
357        assert_eq!(output.len(), 1);
358        assert!(output[0].is_finite());
359    }
360
361    #[test]
362    fn single_state_dim_works() {
363        let mut ssm = SelectiveSSM::new(3, 1, 42);
364        let output = ssm.forward(&[1.0, 2.0, 3.0]);
365        assert_eq!(output.len(), 3);
366        for &y in &output {
367            assert!(y.is_finite());
368        }
369    }
370
371    #[test]
372    fn sequential_outputs_differ() {
373        let mut ssm = SelectiveSSM::new(2, 4, 42);
374        let out1 = ssm.forward(&[1.0, 0.0]);
375        let out2 = ssm.forward(&[1.0, 0.0]);
376        // Second call has non-zero state from first call, so outputs should differ
377        let diff: f64 = out1
378            .iter()
379            .zip(out2.iter())
380            .map(|(a, b)| (a - b) * (a - b))
381            .sum();
382        assert!(
383            diff > 1e-20,
384            "sequential calls with same input should differ due to state: out1={:?}, out2={:?}",
385            out1,
386            out2
387        );
388    }
389
390    #[test]
391    fn large_input_no_overflow() {
392        let mut ssm = SelectiveSSM::new(2, 4, 42);
393        let input = vec![1000.0, -1000.0];
394        let output = ssm.forward(&input);
395        for (i, &y) in output.iter().enumerate() {
396            assert!(
397                y.is_finite(),
398                "output[{}] should be finite for large inputs, got {}",
399                i,
400                y
401            );
402        }
403    }
404
405    #[test]
406    fn zero_input_zero_state_gives_zero_output() {
407        let mut ssm = SelectiveSSM::new(3, 8, 42);
408        let output = ssm.forward(&[0.0, 0.0, 0.0]);
409        for (i, &y) in output.iter().enumerate() {
410            assert!(
411                math::abs(y) < 1e-15,
412                "zero input with zero state should give zero output[{}], got {}",
413                i,
414                y
415            );
416        }
417    }
418}