// irithyll_core/ssm/selective_bd.rs
//! Block-Diagonal Linear Recurrent Unit (BD-LRU) selective state space model.
//!
//! [`SelectiveSSMBD`] implements a block-diagonal SSM variant inspired by
//! Dubinin et al. (2026), where input channels are partitioned into blocks and
//! each block has a dense A matrix enabling cross-channel state mixing within
//! the block. This sits between the fully diagonal Mamba-1 (no cross-channel
//! mixing) and a full dense SSM (quadratic cost).
//!
//! # Architecture
//!
//! For each input timestep `x_t` (a d_in-dimensional vector):
//!
//! ```text
//! Delta_t = softplus(W_delta * x_t + b_delta)    // scalar step size
//! B_t     = W_B * x_t                             // N-dim input projection
//! C_t     = W_C * x_t                             // N-dim output projection
//!
//! For each block k (channels k*m .. (k+1)*m, m = block_size):
//!   x_block = x_t[k*m .. (k+1)*m]
//!   For each state dim n in 0..N:
//!     // Euler discretization with row-L1-normalized dense A:
//!     // A_disc[i,j] = delta * A[i,j]  for i != j
//!     // A_disc[i,i] = 1 + delta * A[i,i]
//!     h_block[n] = A_disc * h_block[n] + delta * B_t[n] * x_block
//!
//!   // Output: weighted sum over state dims
//!   y_block = sum_n C_t[n] * h_block[n]
//!
//! output[d] = y_block[d_within_block] + D[d] * x_t[d]
//! ```
//!
//! # Block-Diagonal vs Diagonal
//!
//! The key differentiator from [`SelectiveSSM`](crate::ssm::SelectiveSSM) is
//! that channels within a block can influence each other's state evolution
//! through the off-diagonal entries of the block's A matrix. With `block_size=1`,
//! this reduces to a per-channel diagonal recurrence (equivalent to Mamba-1).
//! Larger block sizes enable richer cross-channel dynamics at O(m^2) cost per
//! block instead of O(d_in^2) for a full dense A.
//!
//! # Stability
//!
//! Each block's A matrix is row-wise L1-normalized so that the sum of absolute
//! values in each row is at most 1.0. Combined with Euler discretization
//! (`I + Delta * A`), this ensures the discretized transition matrix has
//! bounded spectral radius for small Delta, preventing state explosion.

48use alloc::vec;
49use alloc::vec::Vec;
50
51use crate::math;
52use crate::rng::standard_normal;
53use crate::ssm::init::s4d_inv_real;
54use crate::ssm::projection::{dot, mat_vec, softplus, Xorshift64};
55use crate::ssm::SSMLayer;
56
/// Block-Diagonal Linear Recurrent Unit selective state space model.
///
/// Partitions `d_in` channels into `n_blocks = d_in / block_size` blocks, each
/// with a dense `block_size x block_size` A matrix for within-block
/// cross-channel state mixing. B, C, and Delta projections are shared across
/// blocks (same structure as Mamba-1).
///
/// Construct with [`SelectiveSSMBD::new`], which panics if `d_in` is not a
/// multiple of `block_size`.
///
/// # Dimensions
///
/// - `d_in` -- input/output dimension (number of channels)
/// - `n_state` -- hidden state dimension per block-channel (N)
/// - `block_size` -- number of channels per block (m)
/// - `n_blocks` -- number of blocks (d_in / block_size)
/// - Total hidden state size: `n_blocks * n_state * block_size`
///
/// # Weight Shapes
///
/// | Weight | Shape | Purpose |
/// |--------|-------|---------|
/// | `a_matrices` | n_blocks * m * m | Dense A per block (row-major, L1-normalized) |
/// | `w_b` | N x d_in | Projects input to state-input coupling |
/// | `w_c` | N x d_in | Projects input to state-output coupling |
/// | `w_delta` | d_in | Projects input to scalar step size |
/// | `d_skip` | d_in | Skip connection weights |
///
/// # Example
///
/// ```
/// use irithyll_core::ssm::selective_bd::SelectiveSSMBD;
/// use irithyll_core::ssm::SSMLayer;
///
/// let mut ssm = SelectiveSSMBD::new(4, 8, 2, 42);
/// let output = ssm.forward(&[1.0, 2.0, 3.0, 4.0]);
/// assert_eq!(output.len(), 4);
/// ```
pub struct SelectiveSSMBD {
    /// Per-block A matrices: n_blocks * block_size * block_size, row-major per block.
    /// Each block's m x m matrix is contiguous, L1-row-normalized for stability.
    a_matrices: Vec<f64>,
    /// B projection weights (n_state x d_in, row-major). Maps input to B_t.
    w_b: Vec<f64>,
    /// C projection weights (n_state x d_in, row-major). Maps input to C_t.
    w_c: Vec<f64>,
    /// Delta projection weights (d_in). Maps input to scalar step size.
    w_delta: Vec<f64>,
    /// Delta projection bias (added before the softplus in the forward pass).
    b_delta: f64,
    /// Skip connection weights (d_in). Initialized to 1.0 (identity passthrough).
    d_skip: Vec<f64>,
    /// Hidden state: n_blocks * n_state * block_size.
    /// Layout: h[block * n_state * block_size + state_dim * block_size + channel_within_block]
    h: Vec<f64>,
    /// Input/output dimension.
    d_in: usize,
    /// Number of state dimensions per block-channel.
    n_state: usize,
    /// Number of channels per block.
    block_size: usize,
    /// Number of blocks (d_in / block_size). Invariant: n_blocks * block_size == d_in.
    n_blocks: usize,
}
118
119/// Normalize each row of an m x m matrix in-place so that the L1 norm
120/// (sum of absolute values) of each row is at most 1.0.
121///
122/// Rows with L1 norm already <= 1.0 are left unchanged.
123fn normalize_row_l1(a: &mut [f64], m: usize) {
124    for row in 0..m {
125        let start = row * m;
126        let row_sum: f64 = a[start..start + m].iter().map(|x| math::abs(*x)).sum();
127        if row_sum > 1.0 {
128            for j in 0..m {
129                a[start + j] /= row_sum;
130            }
131        }
132    }
133}
134
135impl SelectiveSSMBD {
136    /// Create a new block-diagonal selective SSM with random weight initialization.
137    ///
138    /// A matrices are initialized with S4D-Inv diagonal values and small random
139    /// off-diagonal entries (scale 0.02), then row-wise L1-normalized. Projection
140    /// weights are initialized from a small normal distribution (scale 0.1).
141    /// Skip connections (D) are initialized to 1.0 for input passthrough.
142    ///
143    /// # Arguments
144    ///
145    /// * `d_in` -- input/output dimension (must be divisible by `block_size`)
146    /// * `n_state` -- hidden state dimension per block-channel (N)
147    /// * `block_size` -- number of channels per block (m)
148    /// * `seed` -- random seed for weight initialization
149    ///
150    /// # Panics
151    ///
152    /// Panics if `d_in` is not evenly divisible by `block_size`.
153    ///
154    /// # Example
155    ///
156    /// ```
157    /// use irithyll_core::ssm::selective_bd::SelectiveSSMBD;
158    ///
159    /// let ssm = SelectiveSSMBD::new(6, 8, 2, 42);
160    /// ```
161    pub fn new(d_in: usize, n_state: usize, block_size: usize, seed: u64) -> Self {
162        assert!(
163            d_in % block_size == 0,
164            "d_in ({}) must be evenly divisible by block_size ({})",
165            d_in,
166            block_size
167        );
168
169        let n_blocks = d_in / block_size;
170        let m = block_size;
171        let mut rng = Xorshift64(seed);
172        let scale = 0.1;
173        let off_diag_scale = 0.02;
174
175        // Initialize A matrices: S4D-Inv diagonal + small random off-diagonal
176        let log_a = s4d_inv_real(m);
177        let mut a_matrices = vec![0.0; n_blocks * m * m];
178
179        for blk in 0..n_blocks {
180            let base = blk * m * m;
181            // Fill with small random off-diagonal values
182            for i in 0..m {
183                for j in 0..m {
184                    if i == j {
185                        // Diagonal: negative S4D-Inv values
186                        // A_i = -(0.5 + i/m), use directly (not log-space here)
187                        a_matrices[base + i * m + j] = -math::exp(log_a[i]);
188                    } else {
189                        // Off-diagonal: small random normal
190                        a_matrices[base + i * m + j] = rng.next_normal() * off_diag_scale;
191                    }
192                }
193            }
194            // Apply row-wise L1 normalization for stability
195            normalize_row_l1(&mut a_matrices[base..base + m * m], m);
196        }
197
198        // Initialize projection weights from small normal distribution
199        let w_delta: Vec<f64> = (0..d_in).map(|_| rng.next_normal() * scale).collect();
200        let b_delta = 0.0;
201        let w_b: Vec<f64> = (0..n_state * d_in)
202            .map(|_| rng.next_normal() * scale)
203            .collect();
204        let w_c: Vec<f64> = (0..n_state * d_in)
205            .map(|_| rng.next_normal() * scale)
206            .collect();
207        let d_skip = vec![1.0; d_in];
208        let h = vec![0.0; n_blocks * n_state * block_size];
209
210        Self {
211            a_matrices,
212            w_b,
213            w_c,
214            w_delta,
215            b_delta,
216            d_skip,
217            h,
218            d_in,
219            n_state,
220            block_size,
221            n_blocks,
222        }
223    }
224
    /// Get the input/output dimension (`d_in`).
    #[inline]
    pub fn d_in(&self) -> usize {
        self.d_in
    }

    /// Get the number of state dimensions per block-channel (`n_state`, N).
    #[inline]
    pub fn n_state(&self) -> usize {
        self.n_state
    }

    /// Get the number of channels per block (`block_size`, m).
    #[inline]
    pub fn block_size(&self) -> usize {
        self.block_size
    }

    /// Get the number of blocks (`d_in / block_size`).
    #[inline]
    pub fn n_blocks(&self) -> usize {
        self.n_blocks
    }
248
249    /// Surgically reinitialize a single block, preserving all other blocks.
250    ///
251    /// Resets block `b`'s hidden state to zero, reinitializes its A matrix
252    /// with S4D diagonal + small random off-diagonal values (then L1 row-
253    /// normalizes), and resets the skip connections for the block's channels
254    /// to 1.0. All other blocks are left untouched.
255    ///
256    /// # Arguments
257    ///
258    /// * `b` — block index to reinitialize (must be < `n_blocks`)
259    /// * `rng` — mutable RNG state for generating fresh weights
260    ///
261    /// # Panics
262    ///
263    /// Panics if `b >= n_blocks`.
264    pub fn reinitialize_block(&mut self, b: usize, rng: &mut u64) {
265        assert!(
266            b < self.n_blocks,
267            "block index {} out of range (n_blocks={})",
268            b,
269            self.n_blocks
270        );
271
272        let m = self.block_size;
273        let off_diag_scale = 0.02;
274
275        // Zero state: h[b * n_state * block_size .. (b+1) * n_state * block_size]
276        let h_start = b * self.n_state * m;
277        let h_end = h_start + self.n_state * m;
278        for h in self.h[h_start..h_end].iter_mut() {
279            *h = 0.0;
280        }
281
282        // Reinit A matrix for block b: S4D diagonal + small random off-diagonal
283        let log_a = s4d_inv_real(m);
284        let a_base = b * m * m;
285        for (i, &la_i) in log_a.iter().enumerate().take(m) {
286            for j in 0..m {
287                if i == j {
288                    self.a_matrices[a_base + i * m + j] = -math::exp(la_i);
289                } else {
290                    self.a_matrices[a_base + i * m + j] = standard_normal(rng) * off_diag_scale;
291                }
292            }
293        }
294        // Apply row-wise L1 normalization for stability
295        normalize_row_l1(&mut self.a_matrices[a_base..a_base + m * m], m);
296
297        // Reset d_skip for channels in this block to default passthrough
298        let ch_start = b * m;
299        for d in ch_start..ch_start + m {
300            self.d_skip[d] = 1.0;
301        }
302    }
303
304    /// Compute the block-diagonal SSM forward pass for one timestep.
305    ///
306    /// This is the core BD-LRU recurrence: compute input-dependent Delta, B, C,
307    /// then for each block apply the dense A state update with Euler
308    /// discretization and accumulate the output.
309    fn bd_forward(&mut self, input: &[f64]) -> Vec<f64> {
310        let d_in = self.d_in;
311        let n_state = self.n_state;
312        let m = self.block_size;
313        let n_blocks = self.n_blocks;
314
315        // 1. Compute delta = softplus(dot(w_delta, input) + b_delta).
316        //    Clamp to 1.0: the Euler discretization (I + delta*A) is only
317        //    stable for small delta because A diagonal entries are negative
318        //    (S4D-Inv). For large delta the term (1 + delta*A[i,i]) goes
319        //    strongly negative, causing exponential state divergence on
320        //    datasets with large-magnitude features (e.g. Power Plant).
321        //    ZOH (exp(delta*A)) is unconditionally stable but more expensive;
322        //    clamping delta is the minimal fix that preserves the architecture.
323        let delta_raw = dot(&self.w_delta, input) + self.b_delta;
324        let delta = softplus(delta_raw).min(1.0);
325
326        // 2. Compute B_t = W_B * input (shape: n_state)
327        let mut b_t = vec![0.0; n_state];
328        mat_vec(&self.w_b, input, n_state, d_in, &mut b_t);
329
330        // 3. Compute C_t = W_C * input (shape: n_state)
331        let mut c_t = vec![0.0; n_state];
332        mat_vec(&self.w_c, input, n_state, d_in, &mut c_t);
333
334        // 4. For each block: apply dense A state update
335        let mut output = vec![0.0; d_in];
336
337        for blk in 0..n_blocks {
338            let a_base = blk * m * m;
339            let x_start = blk * m;
340            let h_block_base = blk * n_state * m;
341
342            for (n, &b_n) in b_t.iter().enumerate().take(n_state) {
343                let h_offset = h_block_base + n * m;
344
345                // Apply block state update with Euler discretization:
346                // h_new[i] = sum_j(A_disc[i,j] * h_old[j]) + delta * B_t[n] * x_block[i]
347                // where A_disc[i,j] = delta * A[i,j] for i != j
348                //       A_disc[i,i] = 1 + delta * A[i,i]
349                //
350                // We compute h_new into a temp buffer to avoid reading stale values.
351                let db = delta * b_n;
352
353                // Temporary buffer for new state (avoid allocation for small blocks
354                // by using a stack array would be nice, but we need Vec for generality)
355                let mut h_new = vec![0.0; m];
356
357                for i in 0..m {
358                    let a_row = a_base + i * m;
359                    let mut sum = 0.0;
360                    for j in 0..m {
361                        let a_disc = if i == j {
362                            1.0 + delta * self.a_matrices[a_row + j]
363                        } else {
364                            delta * self.a_matrices[a_row + j]
365                        };
366                        sum += a_disc * self.h[h_offset + j];
367                    }
368                    // Input injection: delta * B_t[n] * x_block[i]
369                    h_new[i] = sum + db * input[x_start + i];
370                }
371
372                // Write back new state
373                self.h[h_offset..h_offset + m].copy_from_slice(&h_new);
374            }
375
376            // 5. Output accumulation: y_block[i] = sum_n C_t[n] * h[block, n, i]
377            for (n, &c_n) in c_t.iter().enumerate().take(n_state) {
378                let h_offset = h_block_base + n * m;
379                for i in 0..m {
380                    output[x_start + i] += c_n * self.h[h_offset + i];
381                }
382            }
383        }
384
385        // 6. Add skip connection: output[d] += D[d] * input[d]
386        for (out_d, (&skip, &x_d)) in output.iter_mut().zip(self.d_skip.iter().zip(input.iter())) {
387            *out_d += skip * x_d;
388        }
389
390        output
391    }
392}
393
394impl SSMLayer for SelectiveSSMBD {
395    fn forward(&mut self, input: &[f64]) -> Vec<f64> {
396        debug_assert_eq!(
397            input.len(),
398            self.d_in,
399            "input length {} must match d_in {}",
400            input.len(),
401            self.d_in
402        );
403        self.bd_forward(input)
404    }
405
406    fn state(&self) -> &[f64] {
407        &self.h
408    }
409
410    fn output_dim(&self) -> usize {
411        self.d_in
412    }
413
414    fn reset(&mut self) {
415        for h in self.h.iter_mut() {
416            *h = 0.0;
417        }
418    }
419}
420
#[cfg(test)]
mod tests {
    use super::*;

    // Construction: dimensions and derived quantities.
    #[test]
    fn bd_new_correct_dimensions() {
        let ssm = SelectiveSSMBD::new(6, 8, 2, 42);
        assert_eq!(ssm.d_in(), 6);
        assert_eq!(ssm.n_state(), 8);
        assert_eq!(ssm.block_size(), 2);
        assert_eq!(ssm.n_blocks(), 3);
        assert_eq!(
            ssm.state().len(),
            3 * 8 * 2,
            "state size = n_blocks * n_state * block_size"
        );
        assert_eq!(ssm.output_dim(), 6);
    }

    #[test]
    fn bd_initial_state_zero() {
        let ssm = SelectiveSSMBD::new(4, 8, 2, 42);
        for &h in ssm.state() {
            assert!(math::abs(h) < 1e-15, "initial state should be zero");
        }
    }

    #[test]
    fn bd_forward_correct_output_dim() {
        let mut ssm = SelectiveSSMBD::new(6, 8, 3, 42);
        let input = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let output = ssm.forward(&input);
        assert_eq!(output.len(), 6, "output dim should match d_in");
    }

    #[test]
    fn bd_forward_finite_output() {
        let mut ssm = SelectiveSSMBD::new(4, 8, 2, 42);
        let input = vec![1.0, -1.0, 0.5, -0.5];
        let output = ssm.forward(&input);
        for (i, &y) in output.iter().enumerate() {
            assert!(y.is_finite(), "output[{}] should be finite, got {}", i, y);
        }
    }

    #[test]
    fn bd_forward_updates_state() {
        let mut ssm = SelectiveSSMBD::new(4, 8, 2, 42);
        let input = vec![1.0, 2.0, 3.0, 4.0];
        let _ = ssm.forward(&input);
        // Squared L2 norm of the state; must be non-zero after a non-zero input.
        let state_norm: f64 = ssm.state().iter().map(|h| h * h).sum();
        assert!(
            state_norm > 0.0,
            "state should be non-zero after processing non-zero input"
        );
    }

    #[test]
    fn bd_reset_clears_state() {
        let mut ssm = SelectiveSSMBD::new(4, 8, 2, 42);
        let _ = ssm.forward(&[1.0, 2.0, 3.0, 4.0]);
        ssm.reset();
        for &h in ssm.state() {
            assert!(math::abs(h) < 1e-15, "state should be zero after reset");
        }
    }

    // Seeded construction must be fully deterministic.
    #[test]
    fn bd_deterministic_same_seed() {
        let mut ssm1 = SelectiveSSMBD::new(4, 8, 2, 42);
        let mut ssm2 = SelectiveSSMBD::new(4, 8, 2, 42);
        let input = vec![1.0, -1.0, 0.5, -0.5];
        let out1 = ssm1.forward(&input);
        let out2 = ssm2.forward(&input);
        for (i, (&a, &b)) in out1.iter().zip(out2.iter()).enumerate() {
            assert!(
                math::abs(a - b) < 1e-15,
                "output[{}] should be identical for same seed: {} vs {}",
                i,
                a,
                b
            );
        }
    }

    #[test]
    fn bd_different_seeds_differ() {
        let mut ssm1 = SelectiveSSMBD::new(4, 8, 2, 42);
        let mut ssm2 = SelectiveSSMBD::new(4, 8, 2, 99);
        let input = vec![1.0, 2.0, 3.0, 4.0];
        let out1 = ssm1.forward(&input);
        let out2 = ssm2.forward(&input);
        // Sum of squared differences between the two outputs.
        let diff: f64 = out1
            .iter()
            .zip(out2.iter())
            .map(|(a, b)| (a - b) * (a - b))
            .sum();
        assert!(
            diff > 1e-20,
            "different seeds should generally produce different outputs"
        );
    }

    // With h = 0 and x = 0, both the state readout and the skip path are zero.
    #[test]
    fn bd_zero_input_zero_state_zero_output() {
        let mut ssm = SelectiveSSMBD::new(4, 8, 2, 42);
        let output = ssm.forward(&[0.0, 0.0, 0.0, 0.0]);
        for (i, &y) in output.iter().enumerate() {
            assert!(
                math::abs(y) < 1e-15,
                "zero input with zero state should give zero output[{}], got {}",
                i,
                y
            );
        }
    }

    #[test]
    fn bd_cross_channel_mixing() {
        // With block_size > 1, off-diagonal A entries cause cross-channel mixing
        // within each block. With block_size=1, there are no off-diagonal entries,
        // so channels evolve independently. Verify the two produce different results.
        let d_in = 4;
        let n_state = 4;
        let seed = 42;

        let mut ssm_blk1 = SelectiveSSMBD::new(d_in, n_state, 1, seed);
        let mut ssm_blk2 = SelectiveSSMBD::new(d_in, n_state, 2, seed);

        let input = vec![1.0, 2.0, 3.0, 4.0];

        // Run a few steps to accumulate state differences from cross-channel mixing
        for _ in 0..5 {
            let _ = ssm_blk1.forward(&input);
            let _ = ssm_blk2.forward(&input);
        }

        let out1 = ssm_blk1.forward(&input);
        let out2 = ssm_blk2.forward(&input);

        // Both should be valid
        for &y in &out1 {
            assert!(y.is_finite(), "block_size=1 output should be finite");
        }
        for &y in &out2 {
            assert!(y.is_finite(), "block_size=2 output should be finite");
        }

        // They should differ because block_size=2 has cross-channel mixing
        let diff: f64 = out1
            .iter()
            .zip(out2.iter())
            .map(|(a, b)| (a - b) * (a - b))
            .sum();
        assert!(
            diff > 1e-20,
            "block_size=1 vs block_size=2 should produce different outputs due to cross-channel mixing: diff={}",
            diff
        );
    }

    // Stability check: the delta clamp + L1-normalized A should keep the state
    // from exploding under repeated constant input.
    #[test]
    fn bd_state_bounded_under_constant_input() {
        let mut ssm = SelectiveSSMBD::new(4, 8, 2, 42);
        let input = vec![1.0, -0.5, 0.3, -0.8];
        for step in 0..1000 {
            let output = ssm.forward(&input);
            for (i, &y) in output.iter().enumerate() {
                assert!(
                    y.is_finite(),
                    "output[{}] is not finite at step {}: {}",
                    i,
                    step,
                    y
                );
            }
        }
        // Verify state has no NaN/Inf
        for (i, &h) in ssm.state().iter().enumerate() {
            assert!(
                h.is_finite(),
                "state[{}] is not finite after 1000 steps: {}",
                i,
                h
            );
        }
        // Verify state norm is bounded (not exploding)
        let state_norm: f64 = ssm.state().iter().map(|h| h * h).sum();
        assert!(
            state_norm < 1e12,
            "state norm should be bounded after 1000 constant-input steps, got {}",
            state_norm
        );
    }

    #[test]
    fn reinitialize_block_preserves_others() {
        // 6 channels, 4 state dims, block_size=2 → 3 blocks
        let mut ssm = SelectiveSSMBD::new(6, 4, 2, 42);

        // Forward 10 steps to build up state
        for step in 0..10 {
            let s = step as f64;
            let x = vec![s * 0.1, s * -0.2, s * 0.3, s * -0.1, s * 0.2, s * -0.3];
            let _ = ssm.forward(&x);
        }

        // Snapshot state and A matrices for blocks 0 and 2
        let state_before: Vec<f64> = ssm.state().to_vec();
        let a_before: Vec<f64> = ssm.a_matrices.clone();
        let n_state = ssm.n_state();
        let m = ssm.block_size();

        // Reinitialize block 1
        let mut rng = 0xBEEF_u64;
        ssm.reinitialize_block(1, &mut rng);

        // Block 0 state unchanged
        let b0_start = 0;
        let b0_end = n_state * m;
        for (i, &sb) in state_before.iter().enumerate().take(b0_end).skip(b0_start) {
            assert!(
                math::abs(ssm.h[i] - sb) < 1e-15,
                "block 0 state[{}] should be preserved after reinit of block 1",
                i
            );
        }

        // Block 2 state unchanged
        let b2_start = 2 * n_state * m;
        let b2_end = 3 * n_state * m;
        for (i, &sb) in state_before.iter().enumerate().take(b2_end).skip(b2_start) {
            assert!(
                math::abs(ssm.h[i] - sb) < 1e-15,
                "block 2 state[{}] should be preserved after reinit of block 1",
                i
            );
        }

        // Block 1 state zeroed
        let b1_start = n_state * m;
        let b1_end = 2 * n_state * m;
        for i in b1_start..b1_end {
            assert!(
                math::abs(ssm.h[i]) < 1e-15,
                "block 1 state[{}] should be zero after reinit, got {}",
                i,
                ssm.h[i]
            );
        }

        // Block 0 A matrix unchanged
        let a0_start = 0;
        let a0_end = m * m;
        for (i, &ab) in a_before.iter().enumerate().take(a0_end).skip(a0_start) {
            assert!(
                math::abs(ssm.a_matrices[i] - ab) < 1e-15,
                "block 0 A[{}] should be preserved",
                i
            );
        }

        // Block 2 A matrix unchanged
        let a2_start = 2 * m * m;
        let a2_end = 3 * m * m;
        for (i, &ab) in a_before.iter().enumerate().take(a2_end).skip(a2_start) {
            assert!(
                math::abs(ssm.a_matrices[i] - ab) < 1e-15,
                "block 2 A[{}] should be preserved",
                i
            );
        }

        // Block 1 A matrix should have changed (reinitialised)
        let a1_start = m * m;
        let a1_end = 2 * m * m;
        let mut any_a_diff = false;
        for (i, &ab) in a_before.iter().enumerate().take(a1_end).skip(a1_start) {
            if math::abs(ssm.a_matrices[i] - ab) > 1e-15 {
                any_a_diff = true;
                break;
            }
        }
        assert!(any_a_diff, "block 1 A matrix should differ after reinit");

        // d_skip for block 1 channels (indices 2, 3) should be 1.0
        assert!(
            math::abs(ssm.d_skip[2] - 1.0) < 1e-15,
            "d_skip[2] should be 1.0 after block 1 reinit"
        );
        assert!(
            math::abs(ssm.d_skip[3] - 1.0) < 1e-15,
            "d_skip[3] should be 1.0 after block 1 reinit"
        );
    }

    #[test]
    fn bd_block_sizes_produce_different_outputs() {
        // block_size=2 vs block_size=4 should produce different outputs
        // because the A block structure differs
        let d_in = 8;
        let n_state = 4;
        let seed = 42;

        let mut ssm_bs2 = SelectiveSSMBD::new(d_in, n_state, 2, seed);
        let mut ssm_bs4 = SelectiveSSMBD::new(d_in, n_state, 4, seed);

        let input = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];

        // Run a few steps
        for _ in 0..5 {
            let _ = ssm_bs2.forward(&input);
            let _ = ssm_bs4.forward(&input);
        }

        let out_bs2 = ssm_bs2.forward(&input);
        let out_bs4 = ssm_bs4.forward(&input);

        assert_eq!(out_bs2.len(), d_in);
        assert_eq!(out_bs4.len(), d_in);

        for &y in &out_bs2 {
            assert!(y.is_finite(), "block_size=2 output should be finite");
        }
        for &y in &out_bs4 {
            assert!(y.is_finite(), "block_size=4 output should be finite");
        }

        let diff: f64 = out_bs2
            .iter()
            .zip(out_bs4.iter())
            .map(|(a, b)| (a - b) * (a - b))
            .sum();
        assert!(
            diff > 1e-20,
            "block_size=2 vs block_size=4 should produce different outputs: diff={}",
            diff
        );
    }
}