irithyll_core/ssm/selective_bd.rs
//! Block-Diagonal Linear Recurrent Unit (BD-LRU) selective state space model.
//!
//! [`SelectiveSSMBD`] implements a block-diagonal SSM variant inspired by
//! Dubinin et al. (2026), where input channels are partitioned into blocks and
//! each block has a dense A matrix enabling cross-channel state mixing within
//! the block. This sits between the fully diagonal Mamba-1 (no cross-channel
//! mixing) and a full dense SSM (quadratic cost).
//!
//! # Architecture
//!
//! For each input timestep `x_t` (a d_in-dimensional vector):
//!
//! ```text
//! Delta_t = softplus(W_delta * x_t + b_delta)   // scalar step size
//! B_t     = W_B * x_t                           // N-dim input projection
//! C_t     = W_C * x_t                           // N-dim output projection
//!
//! For each block k (channels k*m .. (k+1)*m, m = block_size):
//!     x_block = x_t[k*m .. (k+1)*m]
//!     For each state dim n in 0..N:
//!         // Euler discretization with row-L1-normalized dense A:
//!         //   A_disc[i,j] = delta * A[i,j]       for i != j
//!         //   A_disc[i,i] = 1 + delta * A[i,i]
//!         h_block[n] = A_disc * h_block[n] + delta * B_t[n] * x_block
//!
//!     // Output: weighted sum over state dims
//!     y_block = sum_n C_t[n] * h_block[n]
//!
//! output[d] = y_block[d_within_block] + D[d] * x_t[d]
//! ```
//!
//! # Block-Diagonal vs Diagonal
//!
//! The key differentiator from [`SelectiveSSM`](crate::ssm::SelectiveSSM) is
//! that channels within a block can influence each other's state evolution
//! through the off-diagonal entries of the block's A matrix. With `block_size=1`,
//! this reduces to a per-channel diagonal recurrence (equivalent to Mamba-1).
//! Larger block sizes enable richer cross-channel dynamics at O(m^2) cost per
//! block instead of O(d_in^2) for a full dense A.
//!
//! # Stability
//!
//! Each block's A matrix is row-wise L1-normalized so that the sum of absolute
//! values in each row is at most 1.0. Combined with Euler discretization
//! (`I + Delta * A`), this ensures the discretized transition matrix has
//! bounded spectral radius for small Delta, preventing state explosion.
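//!
//! # Example
//!
//! A minimal sequence-processing sketch using the public API defined in this
//! module (the input values and seed are arbitrary illustrative data):
//!
//! ```
//! use irithyll_core::ssm::selective_bd::SelectiveSSMBD;
//! use irithyll_core::ssm::SSMLayer;
//!
//! // 4 channels, 8 state dims per block-channel, blocks of 2 channels.
//! let mut ssm = SelectiveSSMBD::new(4, 8, 2, 7);
//! for t in 0..16 {
//!     let x = [0.1 * t as f64, -0.2, 0.3, -0.1];
//!     let y = ssm.forward(&x);
//!     assert_eq!(y.len(), 4);
//! }
//! // Clear the recurrent state before the next sequence.
//! ssm.reset();
//! ```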

use alloc::vec;
use alloc::vec::Vec;

use crate::math;
use crate::rng::standard_normal;
use crate::ssm::init::s4d_inv_real;
use crate::ssm::projection::{dot, mat_vec, softplus, Xorshift64};
use crate::ssm::SSMLayer;

/// Block-Diagonal Linear Recurrent Unit selective state space model.
///
/// Partitions `d_in` channels into `n_blocks = d_in / block_size` blocks, each
/// with a dense `block_size x block_size` A matrix for within-block
/// cross-channel state mixing. B, C, and Delta projections are shared across
/// blocks (same structure as Mamba-1).
///
/// # Dimensions
///
/// - `d_in` -- input/output dimension (number of channels)
/// - `n_state` -- hidden state dimension per block-channel (N)
/// - `block_size` -- number of channels per block (m)
/// - `n_blocks` -- number of blocks (d_in / block_size)
/// - Total hidden state size: `n_blocks * n_state * block_size`
///
/// # Weight Shapes
///
/// | Weight | Shape | Purpose |
/// |--------|-------|---------|
/// | `a_matrices` | n_blocks * m * m | Dense A per block (row-major, L1-normalized) |
/// | `w_b` | N x d_in | Projects input to state-input coupling |
/// | `w_c` | N x d_in | Projects input to state-output coupling |
/// | `w_delta` | d_in | Projects input to scalar step size |
/// | `d_skip` | d_in | Skip connection weights |
///
/// # Example
///
/// ```
/// use irithyll_core::ssm::selective_bd::SelectiveSSMBD;
/// use irithyll_core::ssm::SSMLayer;
///
/// let mut ssm = SelectiveSSMBD::new(4, 8, 2, 42);
/// let output = ssm.forward(&[1.0, 2.0, 3.0, 4.0]);
/// assert_eq!(output.len(), 4);
/// ```
pub struct SelectiveSSMBD {
    /// Per-block A matrices: n_blocks * block_size * block_size, row-major per block.
    /// Each block's m x m matrix is contiguous, L1-row-normalized for stability.
    a_matrices: Vec<f64>,
    /// B projection weights (n_state x d_in, row-major). Maps input to B_t.
    w_b: Vec<f64>,
    /// C projection weights (n_state x d_in, row-major). Maps input to C_t.
    w_c: Vec<f64>,
    /// Delta projection weights (d_in). Maps input to scalar step size.
    w_delta: Vec<f64>,
    /// Delta projection bias.
    b_delta: f64,
    /// Skip connection weights (d_in).
    d_skip: Vec<f64>,
    /// Hidden state: n_blocks * n_state * block_size.
    /// Layout: h[block * n_state * block_size + state_dim * block_size + channel_within_block]
    h: Vec<f64>,
    /// Input/output dimension.
    d_in: usize,
    /// Number of state dimensions per block-channel.
    n_state: usize,
    /// Number of channels per block.
    block_size: usize,
    /// Number of blocks (d_in / block_size).
    n_blocks: usize,
}

/// Normalize each row of an m x m matrix in-place so that the L1 norm
/// (sum of absolute values) of each row is at most 1.0.
///
/// Rows with L1 norm already <= 1.0 are left unchanged.
fn normalize_row_l1(a: &mut [f64], m: usize) {
    for row in 0..m {
        let start = row * m;
        let row_sum: f64 = a[start..start + m].iter().map(|x| math::abs(*x)).sum();
        if row_sum > 1.0 {
            for j in 0..m {
                a[start + j] /= row_sum;
            }
        }
    }
}

impl SelectiveSSMBD {
    /// Create a new block-diagonal selective SSM with random weight initialization.
    ///
    /// A matrices are initialized with S4D-Inv diagonal values and small random
    /// off-diagonal entries (scale 0.02), then row-wise L1-normalized. Projection
    /// weights are initialized from a small normal distribution (scale 0.1).
    /// Skip connections (D) are initialized to 1.0 for input passthrough.
    ///
    /// # Arguments
    ///
    /// * `d_in` -- input/output dimension (must be divisible by `block_size`)
    /// * `n_state` -- hidden state dimension per block-channel (N)
    /// * `block_size` -- number of channels per block (m)
    /// * `seed` -- random seed for weight initialization
    ///
    /// # Panics
    ///
    /// Panics if `d_in` is not evenly divisible by `block_size`.
    ///
    /// # Example
    ///
    /// ```
    /// use irithyll_core::ssm::selective_bd::SelectiveSSMBD;
    ///
    /// let ssm = SelectiveSSMBD::new(6, 8, 2, 42);
    /// ```
    pub fn new(d_in: usize, n_state: usize, block_size: usize, seed: u64) -> Self {
        assert!(
            d_in % block_size == 0,
            "d_in ({}) must be evenly divisible by block_size ({})",
            d_in,
            block_size
        );

        let n_blocks = d_in / block_size;
        let m = block_size;
        let mut rng = Xorshift64(seed);
        let scale = 0.1;
        let off_diag_scale = 0.02;

        // Initialize A matrices: S4D-Inv diagonal + small random off-diagonal
        let log_a = s4d_inv_real(m);
        let mut a_matrices = vec![0.0; n_blocks * m * m];

        for blk in 0..n_blocks {
            let base = blk * m * m;
            // Fill the block: S4D-Inv diagonal, small random off-diagonal values
            for i in 0..m {
                for j in 0..m {
                    if i == j {
                        // Diagonal: negative S4D-Inv value. `s4d_inv_real`
                        // returns log-space magnitudes, so exponentiate and negate.
                        a_matrices[base + i * m + j] = -math::exp(log_a[i]);
                    } else {
                        // Off-diagonal: small random normal
                        a_matrices[base + i * m + j] = rng.next_normal() * off_diag_scale;
                    }
                }
            }
            // Apply row-wise L1 normalization for stability
            normalize_row_l1(&mut a_matrices[base..base + m * m], m);
        }

        // Initialize projection weights from small normal distribution
        let w_delta: Vec<f64> = (0..d_in).map(|_| rng.next_normal() * scale).collect();
        let b_delta = 0.0;
        let w_b: Vec<f64> = (0..n_state * d_in)
            .map(|_| rng.next_normal() * scale)
            .collect();
        let w_c: Vec<f64> = (0..n_state * d_in)
            .map(|_| rng.next_normal() * scale)
            .collect();
        let d_skip = vec![1.0; d_in];
        let h = vec![0.0; n_blocks * n_state * block_size];

        Self {
            a_matrices,
            w_b,
            w_c,
            w_delta,
            b_delta,
            d_skip,
            h,
            d_in,
            n_state,
            block_size,
            n_blocks,
        }
    }

    /// Get the input/output dimension.
    #[inline]
    pub fn d_in(&self) -> usize {
        self.d_in
    }

    /// Get the number of state dimensions per block-channel.
    #[inline]
    pub fn n_state(&self) -> usize {
        self.n_state
    }

    /// Get the number of channels per block.
    #[inline]
    pub fn block_size(&self) -> usize {
        self.block_size
    }

    /// Get the number of blocks.
    #[inline]
    pub fn n_blocks(&self) -> usize {
        self.n_blocks
    }

    /// Surgically reinitialize a single block, preserving all other blocks.
    ///
    /// Resets block `b`'s hidden state to zero, reinitializes its A matrix
    /// with S4D diagonal + small random off-diagonal values (then L1 row-
    /// normalizes), and resets the skip connections for the block's channels
    /// to 1.0. All other blocks are left untouched.
    ///
    /// # Arguments
    ///
    /// * `b` -- block index to reinitialize (must be < `n_blocks`)
    /// * `rng` -- mutable RNG state for generating fresh weights
    ///
    /// # Panics
    ///
    /// Panics if `b >= n_blocks`.
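    ///
    /// # Example
    ///
    /// A small sketch (seed and inputs are arbitrary illustrative values):
    ///
    /// ```
    /// use irithyll_core::ssm::selective_bd::SelectiveSSMBD;
    /// use irithyll_core::ssm::SSMLayer;
    ///
    /// let mut ssm = SelectiveSSMBD::new(4, 8, 2, 42);
    /// let _ = ssm.forward(&[1.0, 2.0, 3.0, 4.0]);
    ///
    /// // Reinitialize block 0; block 1 keeps its state and weights.
    /// let mut rng = 0xC0FFEE_u64;
    /// ssm.reinitialize_block(0, &mut rng);
    /// ```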
    pub fn reinitialize_block(&mut self, b: usize, rng: &mut u64) {
        assert!(
            b < self.n_blocks,
            "block index {} out of range (n_blocks={})",
            b,
            self.n_blocks
        );

        let m = self.block_size;
        let off_diag_scale = 0.02;

        // Zero state: h[b * n_state * block_size .. (b+1) * n_state * block_size]
        let h_start = b * self.n_state * m;
        let h_end = h_start + self.n_state * m;
        for h in self.h[h_start..h_end].iter_mut() {
            *h = 0.0;
        }

        // Reinit A matrix for block b: S4D diagonal + small random off-diagonal
        let log_a = s4d_inv_real(m);
        let a_base = b * m * m;
        for (i, &la_i) in log_a.iter().enumerate().take(m) {
            for j in 0..m {
                if i == j {
                    self.a_matrices[a_base + i * m + j] = -math::exp(la_i);
                } else {
                    self.a_matrices[a_base + i * m + j] = standard_normal(rng) * off_diag_scale;
                }
            }
        }
        // Apply row-wise L1 normalization for stability
        normalize_row_l1(&mut self.a_matrices[a_base..a_base + m * m], m);

        // Reset d_skip for channels in this block to default passthrough
        let ch_start = b * m;
        for d in ch_start..ch_start + m {
            self.d_skip[d] = 1.0;
        }
    }

    /// Compute the block-diagonal SSM forward pass for one timestep.
    ///
    /// This is the core BD-LRU recurrence: compute input-dependent Delta, B, C,
    /// then for each block apply the dense A state update with Euler
    /// discretization and accumulate the output.
    fn bd_forward(&mut self, input: &[f64]) -> Vec<f64> {
        let d_in = self.d_in;
        let n_state = self.n_state;
        let m = self.block_size;
        let n_blocks = self.n_blocks;

        // 1. Compute delta = softplus(dot(w_delta, input) + b_delta).
        //    Clamp to 1.0: the Euler discretization (I + delta*A) is only
        //    stable for small delta because A diagonal entries are negative
        //    (S4D-Inv). For large delta the term (1 + delta*A[i,i]) goes
        //    strongly negative, causing exponential state divergence on
        //    datasets with large-magnitude features (e.g. Power Plant).
        //    ZOH (exp(delta*A)) is unconditionally stable but more expensive;
        //    clamping delta is the minimal fix that preserves the architecture.
        let delta_raw = dot(&self.w_delta, input) + self.b_delta;
        let delta = softplus(delta_raw).min(1.0);

        // 2. Compute B_t = W_B * input (shape: n_state)
        let mut b_t = vec![0.0; n_state];
        mat_vec(&self.w_b, input, n_state, d_in, &mut b_t);

        // 3. Compute C_t = W_C * input (shape: n_state)
        let mut c_t = vec![0.0; n_state];
        mat_vec(&self.w_c, input, n_state, d_in, &mut c_t);

        // 4. For each block: apply dense A state update
        let mut output = vec![0.0; d_in];

        for blk in 0..n_blocks {
            let a_base = blk * m * m;
            let x_start = blk * m;
            let h_block_base = blk * n_state * m;

            for (n, &b_n) in b_t.iter().enumerate().take(n_state) {
                let h_offset = h_block_base + n * m;

                // Apply block state update with Euler discretization:
                //   h_new[i] = sum_j(A_disc[i,j] * h_old[j]) + delta * B_t[n] * x_block[i]
                //   where A_disc[i,j] = delta * A[i,j]       for i != j
                //         A_disc[i,i] = 1 + delta * A[i,i]
                //
                // We compute h_new into a temp buffer to avoid reading stale values.
                let db = delta * b_n;

                // Temporary buffer for the new state. A small stack array could
                // avoid this per-step allocation, but a Vec keeps it generic in m.
                let mut h_new = vec![0.0; m];

                for i in 0..m {
                    let a_row = a_base + i * m;
                    let mut sum = 0.0;
                    for j in 0..m {
                        let a_disc = if i == j {
                            1.0 + delta * self.a_matrices[a_row + j]
                        } else {
                            delta * self.a_matrices[a_row + j]
                        };
                        sum += a_disc * self.h[h_offset + j];
                    }
                    // Input injection: delta * B_t[n] * x_block[i]
                    h_new[i] = sum + db * input[x_start + i];
                }

                // Write back new state
                self.h[h_offset..h_offset + m].copy_from_slice(&h_new);
            }

            // 5. Output accumulation: y_block[i] = sum_n C_t[n] * h[block, n, i]
            for (n, &c_n) in c_t.iter().enumerate().take(n_state) {
                let h_offset = h_block_base + n * m;
                for i in 0..m {
                    output[x_start + i] += c_n * self.h[h_offset + i];
                }
            }
        }

        // 6. Add skip connection: output[d] += D[d] * input[d]
        for (out_d, (&skip, &x_d)) in output.iter_mut().zip(self.d_skip.iter().zip(input.iter())) {
            *out_d += skip * x_d;
        }

        output
    }
}

impl SSMLayer for SelectiveSSMBD {
    fn forward(&mut self, input: &[f64]) -> Vec<f64> {
        debug_assert_eq!(
            input.len(),
            self.d_in,
            "input length {} must match d_in {}",
            input.len(),
            self.d_in
        );
        self.bd_forward(input)
    }

    fn state(&self) -> &[f64] {
        &self.h
    }

    fn output_dim(&self) -> usize {
        self.d_in
    }

    fn reset(&mut self) {
        for h in self.h.iter_mut() {
            *h = 0.0;
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn bd_new_correct_dimensions() {
        let ssm = SelectiveSSMBD::new(6, 8, 2, 42);
        assert_eq!(ssm.d_in(), 6);
        assert_eq!(ssm.n_state(), 8);
        assert_eq!(ssm.block_size(), 2);
        assert_eq!(ssm.n_blocks(), 3);
        assert_eq!(
            ssm.state().len(),
            3 * 8 * 2,
            "state size = n_blocks * n_state * block_size"
        );
        assert_eq!(ssm.output_dim(), 6);
    }

    #[test]
    fn bd_initial_state_zero() {
        let ssm = SelectiveSSMBD::new(4, 8, 2, 42);
        for &h in ssm.state() {
            assert!(math::abs(h) < 1e-15, "initial state should be zero");
        }
    }

    #[test]
    fn bd_forward_correct_output_dim() {
        let mut ssm = SelectiveSSMBD::new(6, 8, 3, 42);
        let input = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let output = ssm.forward(&input);
        assert_eq!(output.len(), 6, "output dim should match d_in");
    }

    #[test]
    fn bd_forward_finite_output() {
        let mut ssm = SelectiveSSMBD::new(4, 8, 2, 42);
        let input = vec![1.0, -1.0, 0.5, -0.5];
        let output = ssm.forward(&input);
        for (i, &y) in output.iter().enumerate() {
            assert!(y.is_finite(), "output[{}] should be finite, got {}", i, y);
        }
    }

    #[test]
    fn bd_forward_updates_state() {
        let mut ssm = SelectiveSSMBD::new(4, 8, 2, 42);
        let input = vec![1.0, 2.0, 3.0, 4.0];
        let _ = ssm.forward(&input);
        let state_norm: f64 = ssm.state().iter().map(|h| h * h).sum();
        assert!(
            state_norm > 0.0,
            "state should be non-zero after processing non-zero input"
        );
    }

    #[test]
    fn bd_reset_clears_state() {
        let mut ssm = SelectiveSSMBD::new(4, 8, 2, 42);
        let _ = ssm.forward(&[1.0, 2.0, 3.0, 4.0]);
        ssm.reset();
        for &h in ssm.state() {
            assert!(math::abs(h) < 1e-15, "state should be zero after reset");
        }
    }

    #[test]
    fn bd_deterministic_same_seed() {
        let mut ssm1 = SelectiveSSMBD::new(4, 8, 2, 42);
        let mut ssm2 = SelectiveSSMBD::new(4, 8, 2, 42);
        let input = vec![1.0, -1.0, 0.5, -0.5];
        let out1 = ssm1.forward(&input);
        let out2 = ssm2.forward(&input);
        for (i, (&a, &b)) in out1.iter().zip(out2.iter()).enumerate() {
            assert!(
                math::abs(a - b) < 1e-15,
                "output[{}] should be identical for same seed: {} vs {}",
                i,
                a,
                b
            );
        }
    }

    #[test]
    fn bd_different_seeds_differ() {
        let mut ssm1 = SelectiveSSMBD::new(4, 8, 2, 42);
        let mut ssm2 = SelectiveSSMBD::new(4, 8, 2, 99);
        let input = vec![1.0, 2.0, 3.0, 4.0];
        let out1 = ssm1.forward(&input);
        let out2 = ssm2.forward(&input);
        let diff: f64 = out1
            .iter()
            .zip(out2.iter())
            .map(|(a, b)| (a - b) * (a - b))
            .sum();
        assert!(
            diff > 1e-20,
            "different seeds should generally produce different outputs"
        );
    }

    #[test]
    fn bd_zero_input_zero_state_zero_output() {
        let mut ssm = SelectiveSSMBD::new(4, 8, 2, 42);
        let output = ssm.forward(&[0.0, 0.0, 0.0, 0.0]);
        for (i, &y) in output.iter().enumerate() {
            assert!(
                math::abs(y) < 1e-15,
                "zero input with zero state should give zero output[{}], got {}",
                i,
                y
            );
        }
    }

    #[test]
    fn bd_cross_channel_mixing() {
        // With block_size > 1, off-diagonal A entries cause cross-channel mixing
        // within each block. With block_size=1, there are no off-diagonal entries,
        // so channels evolve independently. Verify the two produce different results.
        let d_in = 4;
        let n_state = 4;
        let seed = 42;

        let mut ssm_blk1 = SelectiveSSMBD::new(d_in, n_state, 1, seed);
        let mut ssm_blk2 = SelectiveSSMBD::new(d_in, n_state, 2, seed);

        let input = vec![1.0, 2.0, 3.0, 4.0];

        // Run a few steps to accumulate state differences from cross-channel mixing
        for _ in 0..5 {
            let _ = ssm_blk1.forward(&input);
            let _ = ssm_blk2.forward(&input);
        }

        let out1 = ssm_blk1.forward(&input);
        let out2 = ssm_blk2.forward(&input);

        // Both should be valid
        for &y in &out1 {
            assert!(y.is_finite(), "block_size=1 output should be finite");
        }
        for &y in &out2 {
            assert!(y.is_finite(), "block_size=2 output should be finite");
        }

        // They should differ because block_size=2 has cross-channel mixing
        let diff: f64 = out1
            .iter()
            .zip(out2.iter())
            .map(|(a, b)| (a - b) * (a - b))
            .sum();
        assert!(
            diff > 1e-20,
            "block_size=1 vs block_size=2 should produce different outputs due to cross-channel mixing: diff={}",
            diff
        );
    }

    #[test]
    fn bd_state_bounded_under_constant_input() {
        let mut ssm = SelectiveSSMBD::new(4, 8, 2, 42);
        let input = vec![1.0, -0.5, 0.3, -0.8];
        for step in 0..1000 {
            let output = ssm.forward(&input);
            for (i, &y) in output.iter().enumerate() {
                assert!(
                    y.is_finite(),
                    "output[{}] is not finite at step {}: {}",
                    i,
                    step,
                    y
                );
            }
        }
        // Verify state has no NaN/Inf
        for (i, &h) in ssm.state().iter().enumerate() {
            assert!(
                h.is_finite(),
                "state[{}] is not finite after 1000 steps: {}",
                i,
                h
            );
        }
        // Verify state norm is bounded (not exploding)
        let state_norm: f64 = ssm.state().iter().map(|h| h * h).sum();
        assert!(
            state_norm < 1e12,
            "state norm should be bounded after 1000 constant-input steps, got {}",
            state_norm
        );
    }
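
    // The delta clamp in `bd_forward` targets large-magnitude features, so
    // exercise that path with inputs far outside the unit range and only
    // check finiteness. The magnitudes and step count here are arbitrary
    // illustrative choices, not taken from any specific dataset.
    #[test]
    fn bd_large_magnitude_inputs_stay_finite() {
        let mut ssm = SelectiveSSMBD::new(4, 8, 2, 42);
        // Large raw feature values, roughly the scale of unnormalized tabular data.
        let input = vec![480.0, -1012.5, 59.3, 25.7];
        for step in 0..200 {
            let output = ssm.forward(&input);
            for (i, &y) in output.iter().enumerate() {
                assert!(
                    y.is_finite(),
                    "output[{}] is not finite at step {}: {}",
                    i,
                    step,
                    y
                );
            }
        }
        for &h in ssm.state() {
            assert!(h.is_finite(), "state should stay finite for large inputs");
        }
    }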

    #[test]
    fn reinitialize_block_preserves_others() {
        // 6 channels, 4 state dims, block_size=2 → 3 blocks
        let mut ssm = SelectiveSSMBD::new(6, 4, 2, 42);

        // Forward 10 steps to build up state
        for step in 0..10 {
            let s = step as f64;
            let x = vec![s * 0.1, s * -0.2, s * 0.3, s * -0.1, s * 0.2, s * -0.3];
            let _ = ssm.forward(&x);
        }

        // Snapshot state and A matrices for blocks 0 and 2
        let state_before: Vec<f64> = ssm.state().to_vec();
        let a_before: Vec<f64> = ssm.a_matrices.clone();
        let n_state = ssm.n_state();
        let m = ssm.block_size();

        // Reinitialize block 1
        let mut rng = 0xBEEF_u64;
        ssm.reinitialize_block(1, &mut rng);

        // Block 0 state unchanged
        let b0_start = 0;
        let b0_end = n_state * m;
        for (i, &sb) in state_before.iter().enumerate().take(b0_end).skip(b0_start) {
            assert!(
                math::abs(ssm.h[i] - sb) < 1e-15,
                "block 0 state[{}] should be preserved after reinit of block 1",
                i
            );
        }

        // Block 2 state unchanged
        let b2_start = 2 * n_state * m;
        let b2_end = 3 * n_state * m;
        for (i, &sb) in state_before.iter().enumerate().take(b2_end).skip(b2_start) {
            assert!(
                math::abs(ssm.h[i] - sb) < 1e-15,
                "block 2 state[{}] should be preserved after reinit of block 1",
                i
            );
        }

        // Block 1 state zeroed
        let b1_start = n_state * m;
        let b1_end = 2 * n_state * m;
        for i in b1_start..b1_end {
            assert!(
                math::abs(ssm.h[i]) < 1e-15,
                "block 1 state[{}] should be zero after reinit, got {}",
                i,
                ssm.h[i]
            );
        }

        // Block 0 A matrix unchanged
        let a0_start = 0;
        let a0_end = m * m;
        for (i, &ab) in a_before.iter().enumerate().take(a0_end).skip(a0_start) {
            assert!(
                math::abs(ssm.a_matrices[i] - ab) < 1e-15,
                "block 0 A[{}] should be preserved",
                i
            );
        }

        // Block 2 A matrix unchanged
        let a2_start = 2 * m * m;
        let a2_end = 3 * m * m;
        for (i, &ab) in a_before.iter().enumerate().take(a2_end).skip(a2_start) {
            assert!(
                math::abs(ssm.a_matrices[i] - ab) < 1e-15,
                "block 2 A[{}] should be preserved",
                i
            );
        }

        // Block 1 A matrix should have changed (reinitialized)
        let a1_start = m * m;
        let a1_end = 2 * m * m;
        let mut any_a_diff = false;
        for (i, &ab) in a_before.iter().enumerate().take(a1_end).skip(a1_start) {
            if math::abs(ssm.a_matrices[i] - ab) > 1e-15 {
                any_a_diff = true;
                break;
            }
        }
        assert!(any_a_diff, "block 1 A matrix should differ after reinit");

        // d_skip for block 1 channels (indices 2, 3) should be 1.0
        assert!(
            math::abs(ssm.d_skip[2] - 1.0) < 1e-15,
            "d_skip[2] should be 1.0 after block 1 reinit"
        );
        assert!(
            math::abs(ssm.d_skip[3] - 1.0) < 1e-15,
            "d_skip[3] should be 1.0 after block 1 reinit"
        );
    }

    #[test]
    fn bd_block_sizes_produce_different_outputs() {
        // block_size=2 vs block_size=4 should produce different outputs
        // because the A block structure differs
        let d_in = 8;
        let n_state = 4;
        let seed = 42;

        let mut ssm_bs2 = SelectiveSSMBD::new(d_in, n_state, 2, seed);
        let mut ssm_bs4 = SelectiveSSMBD::new(d_in, n_state, 4, seed);

        let input = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];

        // Run a few steps
        for _ in 0..5 {
            let _ = ssm_bs2.forward(&input);
            let _ = ssm_bs4.forward(&input);
        }

        let out_bs2 = ssm_bs2.forward(&input);
        let out_bs4 = ssm_bs4.forward(&input);

        assert_eq!(out_bs2.len(), d_in);
        assert_eq!(out_bs4.len(), d_in);

        for &y in &out_bs2 {
            assert!(y.is_finite(), "block_size=2 output should be finite");
        }
        for &y in &out_bs4 {
            assert!(y.is_finite(), "block_size=4 output should be finite");
        }

        let diff: f64 = out_bs2
            .iter()
            .zip(out_bs4.iter())
            .map(|(a, b)| (a - b) * (a - b))
            .sum();
        assert!(
            diff > 1e-20,
            "block_size=2 vs block_size=4 should produce different outputs: diff={}",
            diff
        );
    }
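
    // Exercises the `normalize_row_l1` helper directly on a small 2x2 block
    // (the values are arbitrary): rows with L1 norm above 1.0 are rescaled to
    // exactly 1.0, rows already within the bound are left untouched.
    #[test]
    fn normalize_row_l1_caps_row_sums() {
        // Row 0 sums to |2.0| + |-2.0| = 4.0 and must be rescaled;
        // row 1 sums to 0.3 + 0.2 = 0.5 and must be left as-is.
        let mut a = vec![2.0, -2.0, 0.3, 0.2];
        normalize_row_l1(&mut a, 2);

        let row0_sum: f64 = a[0..2].iter().map(|x| math::abs(*x)).sum();
        assert!(
            math::abs(row0_sum - 1.0) < 1e-12,
            "row 0 should be rescaled to unit L1 norm, got {}",
            row0_sum
        );
        assert!(math::abs(a[2] - 0.3) < 1e-15, "row 1 should be unchanged");
        assert!(math::abs(a[3] - 0.2) < 1e-15, "row 1 should be unchanged");
    }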
}