1use alloc::vec;
31use alloc::vec::Vec;
32
33use crate::math;
34use crate::ssm::init::mamba_init;
35use crate::ssm::projection::{dot, mat_vec, softplus, Xorshift64};
36use crate::ssm::SSMLayer;
37
/// Selective state-space layer with input-dependent step size, B and C projections.
///
/// Every one of the `d_in` input channels owns `n_state` hidden values; the decay
/// parameters and the per-step projections are shared by all channels.
#[derive(Debug, Clone)]
pub struct SelectiveSSM {
    /// Per-state decay parameter; the forward pass uses `a_n = -exp(log_a[n])`.
    log_a: Vec<f64>,
    /// Weights (length `d_in`) dotted with the input to produce the raw step size.
    w_delta: Vec<f64>,
    /// Scalar bias added to the raw step size before `softplus`.
    b_delta: f64,
    /// `n_state * d_in` weights projecting the input to the per-state B vector.
    /// NOTE(review): presumably row-major `n_state x d_in` — confirm against `mat_vec`.
    w_b: Vec<f64>,
    /// `n_state * d_in` weights projecting the input to the per-state C vector.
    /// NOTE(review): presumably row-major `n_state x d_in` — confirm against `mat_vec`.
    w_c: Vec<f64>,
    /// Per-channel skip weight: `d_skip[d] * input[d]` is added to output `d`.
    d_skip: Vec<f64>,
    /// Hidden state, `d_in * n_state` values; channel `d` occupies
    /// `h[d * n_state..(d + 1) * n_state]`.
    h: Vec<f64>,
    /// Number of hidden states per channel.
    n_state: usize,
    /// Number of input (and output) channels.
    d_in: usize,
}
90
91impl SelectiveSSM {
92 pub fn new(d_in: usize, n_state: usize, seed: u64) -> Self {
112 let log_a = mamba_init(n_state);
113 let mut rng = Xorshift64(seed);
114 let scale = 0.01;
115
116 let w_delta: Vec<f64> = (0..d_in).map(|_| rng.next_normal() * scale).collect();
118 let b_delta = 0.0;
119 let w_b: Vec<f64> = (0..n_state * d_in)
120 .map(|_| rng.next_normal() * scale)
121 .collect();
122 let w_c: Vec<f64> = (0..n_state * d_in)
123 .map(|_| rng.next_normal() * scale)
124 .collect();
125 let d_skip = vec![0.0; d_in];
126 let h = vec![0.0; d_in * n_state];
127
128 Self {
129 log_a,
130 w_delta,
131 b_delta,
132 w_b,
133 w_c,
134 d_skip,
135 h,
136 n_state,
137 d_in,
138 }
139 }
140
141 #[inline]
143 pub fn d_in(&self) -> usize {
144 self.d_in
145 }
146
147 #[inline]
149 pub fn n_state(&self) -> usize {
150 self.n_state
151 }
152
153 fn selective_forward(&mut self, input: &[f64]) -> Vec<f64> {
158 let d_in = self.d_in;
159 let n_state = self.n_state;
160
161 let delta_raw = dot(&self.w_delta, input) + self.b_delta;
163 let delta = softplus(delta_raw);
164
165 let mut b_t = vec![0.0; n_state];
167 mat_vec(&self.w_b, input, n_state, d_in, &mut b_t);
168
169 let mut c_t = vec![0.0; n_state];
171 mat_vec(&self.w_c, input, n_state, d_in, &mut c_t);
172
173 let mut output = vec![0.0; d_in];
175
176 for d in 0..d_in {
177 let h_offset = d * n_state;
178 let mut y = 0.0;
179
180 for n in 0..n_state {
181 let a_n = -math::exp(self.log_a[n]); let a_bar = math::exp(delta * a_n);
185
186 let b_bar = if math::abs(a_n) < 1e-12 {
189 delta * b_t[n]
190 } else {
191 (a_bar - 1.0) / a_n * b_t[n]
192 };
193
194 self.h[h_offset + n] = a_bar * self.h[h_offset + n] + b_bar * input[d];
196
197 y += c_t[n] * self.h[h_offset + n];
199 }
200
201 output[d] = y + self.d_skip[d] * input[d];
203 }
204
205 output
206 }
207}
208
209impl SSMLayer for SelectiveSSM {
210 fn forward(&mut self, input: &[f64]) -> Vec<f64> {
211 debug_assert_eq!(
212 input.len(),
213 self.d_in,
214 "input length {} must match d_in {}",
215 input.len(),
216 self.d_in
217 );
218 self.selective_forward(input)
219 }
220
221 fn state(&self) -> &[f64] {
222 &self.h
223 }
224
225 fn output_dim(&self) -> usize {
226 self.d_in
227 }
228
229 fn reset(&mut self) {
230 for h in self.h.iter_mut() {
231 *h = 0.0;
232 }
233 }
234}
235
#[cfg(test)]
mod tests {
    use super::*;

    /// Sum of squares of `values`.
    fn energy(values: &[f64]) -> f64 {
        values.iter().map(|v| v * v).sum()
    }

    /// Squared Euclidean distance over the zipped prefix of `a` and `b`.
    fn squared_distance(a: &[f64], b: &[f64]) -> f64 {
        a.iter().zip(b.iter()).map(|(x, y)| (x - y) * (x - y)).sum()
    }

    // Constructor reports the requested sizes and allocates d_in * n_state states.
    #[test]
    fn new_creates_correct_dimensions() {
        let (channels, states) = (4, 8);
        let layer = SelectiveSSM::new(channels, states, 42);
        assert_eq!(layer.d_in(), channels);
        assert_eq!(layer.n_state(), states);
        assert_eq!(layer.state().len(), channels * states);
        assert_eq!(layer.output_dim(), channels);
    }

    // A freshly built layer has an all-zero hidden state.
    #[test]
    fn initial_state_is_zero() {
        let layer = SelectiveSSM::new(3, 16, 42);
        assert!(
            layer.state().iter().all(|&h| math::abs(h) < 1e-15),
            "initial state should be zero"
        );
    }

    // The output has one entry per input channel.
    #[test]
    fn forward_produces_correct_output_dim() {
        let mut layer = SelectiveSSM::new(5, 8, 42);
        let output = layer.forward(&[1.0, 2.0, 3.0, 4.0, 5.0]);
        assert_eq!(output.len(), 5, "output dim should match d_in");
    }

    // Ordinary inputs never yield NaN or infinity.
    #[test]
    fn forward_produces_finite_output() {
        let mut layer = SelectiveSSM::new(3, 8, 42);
        let output = layer.forward(&[1.0, -1.0, 0.5]);
        output.iter().enumerate().for_each(|(i, y)| {
            assert!(y.is_finite(), "output[{}] should be finite, got {}", i, y);
        });
    }

    // A non-zero input leaves a non-zero hidden state behind.
    #[test]
    fn forward_updates_state() {
        let mut layer = SelectiveSSM::new(3, 8, 42);
        let _ = layer.forward(&[1.0, 2.0, 3.0]);
        let state_norm = energy(layer.state());
        assert!(
            state_norm > 0.0,
            "state should be non-zero after processing non-zero input"
        );
    }

    // `reset` brings the hidden state back to zero.
    #[test]
    fn reset_clears_state() {
        let mut layer = SelectiveSSM::new(3, 8, 42);
        let _ = layer.forward(&[1.0, 2.0, 3.0]);
        layer.reset();
        assert!(
            layer.state().iter().all(|&h| math::abs(h) < 1e-15),
            "state should be zero after reset"
        );
    }

    // After one impulse, 200 zero-input steps must shrink the state energy by > 99%.
    #[test]
    fn state_decays_without_input() {
        let mut layer = SelectiveSSM::new(2, 4, 42);
        let _ = layer.forward(&[10.0, 10.0]);
        let energy_after = energy(layer.state());

        (0..200).for_each(|_| {
            let _ = layer.forward(&[0.0, 0.0]);
        });

        let energy_decayed = energy(layer.state());
        assert!(
            energy_decayed < energy_after * 0.01,
            "state should decay with zero input: initial={}, after={}",
            energy_after,
            energy_decayed
        );
    }

    // Two layers built from the same seed agree on their first output.
    #[test]
    fn deterministic_with_same_seed() {
        let input = [1.0, 2.0, 3.0];
        let first = SelectiveSSM::new(3, 8, 42).forward(&input);
        let second = SelectiveSSM::new(3, 8, 42).forward(&input);
        first
            .iter()
            .zip(second.iter())
            .enumerate()
            .for_each(|(i, (a, b))| {
                assert!(
                    math::abs(a - b) < 1e-15,
                    "output[{}] should be identical for same seed: {} vs {}",
                    i,
                    a,
                    b
                );
            });
    }

    // Different seeds give different weights and hence different outputs.
    #[test]
    fn different_seeds_produce_different_outputs() {
        let input = [1.0, 2.0, 3.0];
        let first = SelectiveSSM::new(3, 8, 42).forward(&input);
        let second = SelectiveSSM::new(3, 8, 99).forward(&input);
        assert!(
            squared_distance(&first, &second) > 1e-20,
            "different seeds should generally produce different outputs"
        );
    }

    // Degenerate width: a single channel.
    #[test]
    fn single_channel_works() {
        let mut layer = SelectiveSSM::new(1, 4, 42);
        let output = layer.forward(&[3.0]);
        assert_eq!(output.len(), 1);
        assert!(output[0].is_finite());
    }

    // Degenerate depth: a single hidden state per channel.
    #[test]
    fn single_state_dim_works() {
        let mut layer = SelectiveSSM::new(3, 1, 42);
        let output = layer.forward(&[1.0, 2.0, 3.0]);
        assert_eq!(output.len(), 3);
        assert!(output.iter().all(|y| y.is_finite()));
    }

    // The same input fed twice gives different outputs because state carries over.
    #[test]
    fn sequential_outputs_differ() {
        let mut layer = SelectiveSSM::new(2, 4, 42);
        let out1 = layer.forward(&[1.0, 0.0]);
        let out2 = layer.forward(&[1.0, 0.0]);
        let diff = squared_distance(&out1, &out2);
        assert!(
            diff > 1e-20,
            "sequential calls with same input should differ due to state: out1={:?}, out2={:?}",
            out1,
            out2
        );
    }

    // Very large magnitudes must not overflow to infinity or NaN.
    #[test]
    fn large_input_no_overflow() {
        let mut layer = SelectiveSSM::new(2, 4, 42);
        let output = layer.forward(&[1000.0, -1000.0]);
        output.iter().enumerate().for_each(|(i, y)| {
            assert!(
                y.is_finite(),
                "output[{}] should be finite for large inputs, got {}",
                i,
                y
            );
        });
    }

    // Zero input on a zero state gives exactly-zero output.
    #[test]
    fn zero_input_zero_state_gives_zero_output() {
        let mut layer = SelectiveSSM::new(3, 8, 42);
        let output = layer.forward(&[0.0, 0.0, 0.0]);
        output.iter().enumerate().for_each(|(i, &y)| {
            assert!(
                math::abs(y) < 1e-15,
                "zero input with zero state should give zero output[{}], got {}",
                i,
                y
            );
        });
    }
}