1use alloc::vec;
46use alloc::vec::Vec;
47
48use crate::math;
49use crate::rng::standard_normal;
50use crate::ssm::init::s4d_inv_real;
51use crate::ssm::projection::{dot, mat_vec, softplus, Xorshift64};
52use crate::ssm::SSMLayer;
53
/// Selective state-space layer whose step size and B/C projections are
/// computed from the current input.
///
/// All matrices are stored flat. For the `n_state x d_in` buffers (`w_b`,
/// `w_c`, `h`) the element for state dimension `n` and channel `d` lives at
/// index `n * d_in + d`.
///
/// `Debug` and `Clone` are derived so the layer can be logged, embedded in
/// parent types that derive the same traits, and snapshotted.
#[derive(Debug, Clone)]
pub struct SelectiveSSM {
    /// Log-magnitude of the diagonal state matrix; the forward pass uses
    /// `A_n = -exp(log_a[n])`.
    log_a: Vec<f64>,
    /// Weights of the scalar step-size projection (one per input channel).
    w_delta: Vec<f64>,
    /// Bias added to the step-size projection before `softplus` is applied.
    b_delta: f64,
    /// Input-to-B projection, `n_state * d_in` entries.
    w_b: Vec<f64>,
    /// Input-to-C projection, `n_state * d_in` entries.
    w_c: Vec<f64>,
    /// Per-channel weight of the direct input-to-output skip path.
    d_skip: Vec<f64>,
    /// Hidden state, `n_state * d_in` entries.
    h: Vec<f64>,
    /// Number of state dimensions.
    n_state: usize,
    /// Number of input (and output) channels.
    d_in: usize,
}
106
107impl SelectiveSSM {
108 pub fn new(d_in: usize, n_state: usize, seed: u64) -> Self {
131 let log_a = s4d_inv_real(n_state);
132 let mut rng = Xorshift64(seed);
133 let scale = 0.1;
134
135 let w_delta: Vec<f64> = (0..d_in).map(|_| rng.next_normal() * scale).collect();
137 let b_delta = 0.0;
138 let w_b: Vec<f64> = (0..n_state * d_in)
139 .map(|_| rng.next_normal() * scale)
140 .collect();
141 let w_c: Vec<f64> = (0..n_state * d_in)
142 .map(|_| rng.next_normal() * scale)
143 .collect();
144 let d_skip = vec![1.0; d_in];
145 let h = vec![0.0; d_in * n_state];
146
147 Self {
148 log_a,
149 w_delta,
150 b_delta,
151 w_b,
152 w_c,
153 d_skip,
154 h,
155 n_state,
156 d_in,
157 }
158 }
159
160 #[inline]
162 pub fn d_in(&self) -> usize {
163 self.d_in
164 }
165
166 #[inline]
168 pub fn n_state(&self) -> usize {
169 self.n_state
170 }
171
172 pub fn reinitialize_channel(&mut self, d: usize, rng: &mut u64) {
188 assert!(
189 d < self.d_in,
190 "channel index {} out of range (d_in={})",
191 d,
192 self.d_in
193 );
194
195 let scale = 0.1;
196
197 for n in 0..self.n_state {
199 self.h[n * self.d_in + d] = 0.0;
200 }
201
202 self.w_delta[d] = standard_normal(rng) * scale;
204
205 for n in 0..self.n_state {
207 self.w_b[n * self.d_in + d] = standard_normal(rng) * scale;
208 }
209
210 for n in 0..self.n_state {
212 self.w_c[n * self.d_in + d] = standard_normal(rng) * scale;
213 }
214
215 self.d_skip[d] = 1.0;
217 }
218
219 fn selective_forward(&mut self, input: &[f64]) -> Vec<f64> {
224 let d_in = self.d_in;
225 let n_state = self.n_state;
226
227 let delta_raw = dot(&self.w_delta, input) + self.b_delta;
229 let delta = softplus(delta_raw);
230
231 let mut b_t = vec![0.0; n_state];
233 mat_vec(&self.w_b, input, n_state, d_in, &mut b_t);
234
235 let mut c_t = vec![0.0; n_state];
237 mat_vec(&self.w_c, input, n_state, d_in, &mut c_t);
238
239 let mut a_bar_vec = vec![0.0; n_state];
243 let mut b_bar_vec = vec![0.0; n_state];
244 for n in 0..n_state {
245 let a_n = -math::exp(self.log_a[n]); let ab = math::exp(delta * a_n); a_bar_vec[n] = ab;
248 b_bar_vec[n] = if math::abs(a_n) < 1e-12 {
249 delta * b_t[n]
250 } else {
251 (ab - 1.0) / a_n * b_t[n]
252 };
253 }
254
255 for n in 0..n_state {
260 let h_offset = n * d_in;
261 let a = a_bar_vec[n];
262 let b = b_bar_vec[n];
263 for (d, x_d) in input.iter().enumerate().take(d_in) {
264 self.h[h_offset + d] = a * self.h[h_offset + d] + b * x_d;
265 }
266 }
267
268 let mut output = vec![0.0; d_in];
270 for (n, &c_n) in c_t.iter().enumerate().take(n_state) {
271 let h_offset = n * d_in;
272 for (d, out_d) in output.iter_mut().enumerate().take(d_in) {
273 *out_d += c_n * self.h[h_offset + d];
274 }
275 }
276
277 for (out_d, (&skip, &x_d)) in output.iter_mut().zip(self.d_skip.iter().zip(input.iter())) {
279 *out_d += skip * x_d;
280 }
281
282 output
283 }
284}
285
286impl SSMLayer for SelectiveSSM {
287 fn forward(&mut self, input: &[f64]) -> Vec<f64> {
288 debug_assert_eq!(
289 input.len(),
290 self.d_in,
291 "input length {} must match d_in {}",
292 input.len(),
293 self.d_in
294 );
295 self.selective_forward(input)
296 }
297
298 fn state(&self) -> &[f64] {
299 &self.h
300 }
301
302 fn output_dim(&self) -> usize {
303 self.d_in
304 }
305
306 fn reset(&mut self) {
307 for h in self.h.iter_mut() {
308 *h = 0.0;
309 }
310 }
311}
312
#[cfg(test)]
mod tests {
    use super::*;

    /// Sum of the squared entries of `values`.
    fn energy(values: &[f64]) -> f64 {
        values.iter().map(|v| v * v).sum()
    }

    /// Sum of squared element-wise differences over the common prefix.
    fn squared_distance(lhs: &[f64], rhs: &[f64]) -> f64 {
        lhs.iter()
            .zip(rhs.iter())
            .map(|(a, b)| (a - b) * (a - b))
            .sum()
    }

    /// Copies column `col` out of a flat buffer made of `rows` rows with
    /// `cols` entries each.
    fn column(flat: &[f64], rows: usize, cols: usize, col: usize) -> Vec<f64> {
        (0..rows).map(|row| flat[row * cols + col]).collect()
    }

    /// Constructor reports the requested sizes and allocates a matching state.
    #[test]
    fn new_creates_correct_dimensions() {
        let layer = SelectiveSSM::new(4, 8, 42);
        assert_eq!(layer.d_in(), 4);
        assert_eq!(layer.n_state(), 8);
        assert_eq!(layer.state().len(), 4 * 8);
        assert_eq!(layer.output_dim(), 4);
    }

    /// A fresh layer starts with an all-zero hidden state.
    #[test]
    fn initial_state_is_zero() {
        let layer = SelectiveSSM::new(3, 16, 42);
        for value in layer.state().iter().copied() {
            assert!(math::abs(value) < 1e-15, "initial state should be zero");
        }
    }

    /// Output length equals the number of input channels.
    #[test]
    fn forward_produces_correct_output_dim() {
        let mut layer = SelectiveSSM::new(5, 8, 42);
        let output = layer.forward(&[1.0, 2.0, 3.0, 4.0, 5.0]);
        assert_eq!(output.len(), 5, "output dim should match d_in");
    }

    /// A single step on moderate input yields finite values.
    #[test]
    fn forward_produces_finite_output() {
        let mut layer = SelectiveSSM::new(3, 8, 42);
        let output = layer.forward(&[1.0, -1.0, 0.5]);
        for (i, y) in output.into_iter().enumerate() {
            assert!(y.is_finite(), "output[{}] should be finite, got {}", i, y);
        }
    }

    /// Non-zero input leaves a non-zero hidden state behind.
    #[test]
    fn forward_updates_state() {
        let mut layer = SelectiveSSM::new(3, 8, 42);
        let _ = layer.forward(&[1.0, 2.0, 3.0]);
        let state_norm = energy(layer.state());
        assert!(
            state_norm > 0.0,
            "state should be non-zero after processing non-zero input"
        );
    }

    /// `reset` returns the hidden state to zero.
    #[test]
    fn reset_clears_state() {
        let mut layer = SelectiveSSM::new(3, 8, 42);
        let _ = layer.forward(&[1.0, 2.0, 3.0]);
        layer.reset();
        for value in layer.state().iter().copied() {
            assert!(math::abs(value) < 1e-15, "state should be zero after reset");
        }
    }

    /// After one large input, 200 zero-input steps shrink the state energy to
    /// below 1% of its value right after that input.
    #[test]
    fn state_decays_without_input() {
        let mut layer = SelectiveSSM::new(2, 4, 42);
        let _ = layer.forward(&[10.0, 10.0]);
        let energy_after = energy(layer.state());

        (0..200).for_each(|_| {
            let _ = layer.forward(&[0.0, 0.0]);
        });
        let energy_decayed = energy(layer.state());

        assert!(
            energy_decayed < energy_after * 0.01,
            "state should decay with zero input: initial={}, after={}",
            energy_after,
            energy_decayed
        );
    }

    /// Two layers built from the same seed agree on the same input.
    #[test]
    fn deterministic_with_same_seed() {
        let mut first = SelectiveSSM::new(3, 8, 42);
        let mut second = SelectiveSSM::new(3, 8, 42);
        let input = [1.0, 2.0, 3.0];
        let out_first = first.forward(&input);
        let out_second = second.forward(&input);
        for (i, (a, b)) in out_first
            .iter()
            .copied()
            .zip(out_second.iter().copied())
            .enumerate()
        {
            assert!(
                math::abs(a - b) < 1e-15,
                "output[{}] should be identical for same seed: {} vs {}",
                i,
                a,
                b
            );
        }
    }

    /// Layers built from different seeds disagree on the same input.
    #[test]
    fn different_seeds_produce_different_outputs() {
        let mut first = SelectiveSSM::new(3, 8, 42);
        let mut second = SelectiveSSM::new(3, 8, 99);
        let input = [1.0, 2.0, 3.0];
        let out_first = first.forward(&input);
        let out_second = second.forward(&input);
        let diff = squared_distance(&out_first, &out_second);
        assert!(
            diff > 1e-20,
            "different seeds should generally produce different outputs"
        );
    }

    /// The layer works with a single input channel.
    #[test]
    fn single_channel_works() {
        let mut layer = SelectiveSSM::new(1, 4, 42);
        let output = layer.forward(&[3.0]);
        assert_eq!(output.len(), 1);
        assert!(output[0].is_finite());
    }

    /// The layer works with a single state dimension.
    #[test]
    fn single_state_dim_works() {
        let mut layer = SelectiveSSM::new(3, 1, 42);
        let output = layer.forward(&[1.0, 2.0, 3.0]);
        assert_eq!(output.len(), 3);
        for y in output.iter() {
            assert!(y.is_finite());
        }
    }

    /// Feeding the same input twice gives different outputs because the
    /// hidden state carries over between calls.
    #[test]
    fn sequential_outputs_differ() {
        let mut layer = SelectiveSSM::new(2, 4, 42);
        let out1 = layer.forward(&[1.0, 0.0]);
        let out2 = layer.forward(&[1.0, 0.0]);
        let diff = squared_distance(&out1, &out2);
        assert!(
            diff > 1e-20,
            "sequential calls with same input should differ due to state: out1={:?}, out2={:?}",
            out1,
            out2
        );
    }

    /// Inputs of magnitude 1000 still produce finite outputs.
    #[test]
    fn large_input_no_overflow() {
        let mut layer = SelectiveSSM::new(2, 4, 42);
        let output = layer.forward(&[1000.0, -1000.0]);
        for (i, y) in output.into_iter().enumerate() {
            assert!(
                y.is_finite(),
                "output[{}] should be finite for large inputs, got {}",
                i,
                y
            );
        }
    }

    /// Zero input on a zero state produces zero output.
    #[test]
    fn zero_input_zero_state_gives_zero_output() {
        let mut layer = SelectiveSSM::new(3, 8, 42);
        let output = layer.forward(&[0.0, 0.0, 0.0]);
        for (i, y) in output.into_iter().enumerate() {
            assert!(
                math::abs(y) < 1e-15,
                "zero input with zero state should give zero output[{}], got {}",
                i,
                y
            );
        }
    }

    /// Reinitialising channel 1 zeroes that channel's state, redraws its
    /// weights and resets its skip weight, while channels 0 and 2 keep their
    /// state and weights.
    #[test]
    fn reinitialize_channel_preserves_others() {
        let mut layer = SelectiveSSM::new(3, 8, 42);

        // Run several steps first so the hidden state has been updated
        // before the snapshot is taken.
        for step in 0..10 {
            let t = step as f64;
            let _ = layer.forward(&[t * 0.3, t * -0.2, t * 0.1]);
        }

        // Snapshot everything that belongs to channels 0 and 2.
        let (rows, cols) = (layer.n_state, layer.d_in);
        let h_before = layer.state().to_vec();
        let w_delta_0 = layer.w_delta[0];
        let w_delta_2 = layer.w_delta[2];
        let wb_col0 = column(&layer.w_b, rows, cols, 0);
        let wb_col2 = column(&layer.w_b, rows, cols, 2);
        let wc_col0 = column(&layer.w_c, rows, cols, 0);
        let wc_col2 = column(&layer.w_c, rows, cols, 2);

        let mut rng = 0xBEEF_u64;
        layer.reinitialize_channel(1, &mut rng);

        // Hidden state of the untouched channels.
        for n in 0..rows {
            let at = n * cols;
            assert!(
                math::abs(layer.h[at] - h_before[at]) < 1e-15,
                "channel 0 state[{}] should be preserved after reinit of channel 1",
                n
            );
        }

        for n in 0..rows {
            let at = n * cols + 2;
            assert!(
                math::abs(layer.h[at] - h_before[at]) < 1e-15,
                "channel 2 state[{}] should be preserved after reinit of channel 1",
                n
            );
        }

        // Hidden state of the reinitialised channel.
        for n in 0..rows {
            let at = n * cols + 1;
            assert!(
                math::abs(layer.h[at]) < 1e-15,
                "channel 1 state[{}] should be zeroed after reinit, got {}",
                n,
                layer.h[at]
            );
        }

        // Weights of the untouched channels.
        assert!(
            math::abs(layer.w_delta[0] - w_delta_0) < 1e-15,
            "w_delta[0] should be preserved"
        );
        assert!(
            math::abs(layer.w_delta[2] - w_delta_2) < 1e-15,
            "w_delta[2] should be preserved"
        );
        for n in 0..rows {
            let first = n * cols;
            let third = first + 2;
            assert!(
                math::abs(layer.w_b[first] - wb_col0[n]) < 1e-15,
                "w_b col 0 row {} should be preserved",
                n
            );
            assert!(
                math::abs(layer.w_b[third] - wb_col2[n]) < 1e-15,
                "w_b col 2 row {} should be preserved",
                n
            );
            assert!(
                math::abs(layer.w_c[first] - wc_col0[n]) < 1e-15,
                "w_c col 0 row {} should be preserved",
                n
            );
            assert!(
                math::abs(layer.w_c[third] - wc_col2[n]) < 1e-15,
                "w_c col 2 row {} should be preserved",
                n
            );
        }

        // The redrawn column must contain at least one non-zero weight.
        let any_wb_diff = (0..rows).any(|n| math::abs(layer.w_b[n * cols + 1]) > 1e-15);
        assert!(
            any_wb_diff,
            "reinitialised channel 1 w_b should have non-zero weights"
        );

        assert!(
            math::abs(layer.d_skip[1] - 1.0) < 1e-15,
            "d_skip[1] should be reset to 1.0 after reinit, got {}",
            layer.d_skip[1]
        );
    }
}