1use alloc::vec;
31use alloc::vec::Vec;
32
33use crate::math;
34use crate::ssm::init::s4d_inv_real;
35use crate::ssm::projection::{dot, mat_vec, softplus, Xorshift64};
36use crate::ssm::SSMLayer;
37
/// Selective state-space layer: the step size and the B/C vectors are
/// recomputed from every input vector, while the diagonal `A` is fixed.
///
/// One hidden vector of `n_state` values is kept per input channel; all of
/// them are stored back to back in `h`.
pub struct SelectiveSSM {
    /// Per-state log-magnitudes, indexed by state index. The forward pass uses
    /// `-exp(log_a[n])` as the continuous-time diagonal entry for state `n`.
    log_a: Vec<f64>,
    /// Step-size weights (`d_in` values), dotted with the input vector.
    w_delta: Vec<f64>,
    /// Bias added to that dot product before `softplus` is applied.
    b_delta: f64,
    /// Weights of the input-dependent B vector (`n_state * d_in` values),
    /// handed to `mat_vec` together with `n_state` and `d_in`.
    w_b: Vec<f64>,
    /// Weights of the input-dependent C vector (`n_state * d_in` values),
    /// handed to `mat_vec` together with `n_state` and `d_in`.
    w_c: Vec<f64>,
    /// Per-channel skip weights (`d_in` values); `d_skip[d] * input[d]` is
    /// added to the output of channel `d`.
    d_skip: Vec<f64>,
    /// Hidden state (`d_in * n_state` values); state `n` of channel `d` lives
    /// at index `d * n_state + n`.
    h: Vec<f64>,
    /// Number of hidden values per channel.
    n_state: usize,
    /// Number of input channels, which is also the output length.
    d_in: usize,
}
90
91impl SelectiveSSM {
92 pub fn new(d_in: usize, n_state: usize, seed: u64) -> Self {
115 let log_a = s4d_inv_real(n_state);
116 let mut rng = Xorshift64(seed);
117 let scale = 0.1;
118
119 let w_delta: Vec<f64> = (0..d_in).map(|_| rng.next_normal() * scale).collect();
121 let b_delta = 0.0;
122 let w_b: Vec<f64> = (0..n_state * d_in)
123 .map(|_| rng.next_normal() * scale)
124 .collect();
125 let w_c: Vec<f64> = (0..n_state * d_in)
126 .map(|_| rng.next_normal() * scale)
127 .collect();
128 let d_skip = vec![1.0; d_in];
129 let h = vec![0.0; d_in * n_state];
130
131 Self {
132 log_a,
133 w_delta,
134 b_delta,
135 w_b,
136 w_c,
137 d_skip,
138 h,
139 n_state,
140 d_in,
141 }
142 }
143
144 #[inline]
146 pub fn d_in(&self) -> usize {
147 self.d_in
148 }
149
150 #[inline]
152 pub fn n_state(&self) -> usize {
153 self.n_state
154 }
155
156 fn selective_forward(&mut self, input: &[f64]) -> Vec<f64> {
161 let d_in = self.d_in;
162 let n_state = self.n_state;
163
164 let delta_raw = dot(&self.w_delta, input) + self.b_delta;
166 let delta = softplus(delta_raw);
167
168 let mut b_t = vec![0.0; n_state];
170 mat_vec(&self.w_b, input, n_state, d_in, &mut b_t);
171
172 let mut c_t = vec![0.0; n_state];
174 mat_vec(&self.w_c, input, n_state, d_in, &mut c_t);
175
176 let mut output = vec![0.0; d_in];
178
179 for d in 0..d_in {
180 let h_offset = d * n_state;
181 let mut y = 0.0;
182
183 for n in 0..n_state {
184 let a_n = -math::exp(self.log_a[n]); let a_bar = math::exp(delta * a_n);
188
189 let b_bar = if math::abs(a_n) < 1e-12 {
192 delta * b_t[n]
193 } else {
194 (a_bar - 1.0) / a_n * b_t[n]
195 };
196
197 self.h[h_offset + n] = a_bar * self.h[h_offset + n] + b_bar * input[d];
199
200 y += c_t[n] * self.h[h_offset + n];
202 }
203
204 output[d] = y + self.d_skip[d] * input[d];
206 }
207
208 output
209 }
210}
211
212impl SSMLayer for SelectiveSSM {
213 fn forward(&mut self, input: &[f64]) -> Vec<f64> {
214 debug_assert_eq!(
215 input.len(),
216 self.d_in,
217 "input length {} must match d_in {}",
218 input.len(),
219 self.d_in
220 );
221 self.selective_forward(input)
222 }
223
224 fn state(&self) -> &[f64] {
225 &self.h
226 }
227
228 fn output_dim(&self) -> usize {
229 self.d_in
230 }
231
232 fn reset(&mut self) {
233 for h in self.h.iter_mut() {
234 *h = 0.0;
235 }
236 }
237}
238
#[cfg(test)]
mod tests {
    use super::*;

    /// Sum of the squares of `values`.
    fn sum_of_squares(values: &[f64]) -> f64 {
        values.iter().map(|v| v * v).sum()
    }

    /// Squared Euclidean distance between `lhs` and `rhs`, taken over the
    /// shorter of the two slices.
    fn squared_distance(lhs: &[f64], rhs: &[f64]) -> f64 {
        lhs.iter()
            .zip(rhs.iter())
            .map(|(a, b)| (a - b) * (a - b))
            .sum()
    }

    // Constructor arguments are reflected by the accessors and the state size.
    #[test]
    fn new_creates_correct_dimensions() {
        let layer = SelectiveSSM::new(4, 8, 42);
        assert_eq!(layer.d_in(), 4);
        assert_eq!(layer.n_state(), 8);
        assert_eq!(layer.state().len(), 4 * 8);
        assert_eq!(layer.output_dim(), 4);
    }

    // A fresh layer starts with an all-zero hidden state.
    #[test]
    fn initial_state_is_zero() {
        let layer = SelectiveSSM::new(3, 16, 42);
        for value in layer.state().iter().copied() {
            assert!(math::abs(value) < 1e-15, "initial state should be zero");
        }
    }

    // The output has one entry per input channel.
    #[test]
    fn forward_produces_correct_output_dim() {
        let mut layer = SelectiveSSM::new(5, 8, 42);
        let result = layer.forward(&[1.0, 2.0, 3.0, 4.0, 5.0]);
        assert_eq!(result.len(), 5, "output dim should match d_in");
    }

    // Moderate inputs never yield NaN or infinity.
    #[test]
    fn forward_produces_finite_output() {
        let mut layer = SelectiveSSM::new(3, 8, 42);
        let result = layer.forward(&[1.0, -1.0, 0.5]);
        for (i, y) in result.iter().copied().enumerate() {
            assert!(y.is_finite(), "output[{}] should be finite, got {}", i, y);
        }
    }

    // A non-zero input leaves a trace in the hidden state.
    #[test]
    fn forward_updates_state() {
        let mut layer = SelectiveSSM::new(3, 8, 42);
        let _ = layer.forward(&[1.0, 2.0, 3.0]);
        let state_norm = sum_of_squares(layer.state());
        assert!(
            state_norm > 0.0,
            "state should be non-zero after processing non-zero input"
        );
    }

    // `reset` wipes whatever a previous step wrote into the state.
    #[test]
    fn reset_clears_state() {
        let mut layer = SelectiveSSM::new(3, 8, 42);
        let _ = layer.forward(&[1.0, 2.0, 3.0]);
        layer.reset();
        for value in layer.state().iter().copied() {
            assert!(math::abs(value) < 1e-15, "state should be zero after reset");
        }
    }

    // After an initial kick, 200 zero-input steps shrink the state energy to
    // below 1% of its starting value.
    #[test]
    fn state_decays_without_input() {
        let mut layer = SelectiveSSM::new(2, 4, 42);
        let _ = layer.forward(&[10.0, 10.0]);
        let energy_after = sum_of_squares(layer.state());

        let silence = [0.0, 0.0];
        for _ in 0..200 {
            let _ = layer.forward(&silence);
        }
        let energy_decayed = sum_of_squares(layer.state());
        assert!(
            energy_decayed < energy_after * 0.01,
            "state should decay with zero input: initial={}, after={}",
            energy_after,
            energy_decayed
        );
    }

    // Two layers built from the same seed agree on their output.
    #[test]
    fn deterministic_with_same_seed() {
        let input = [1.0, 2.0, 3.0];
        let first = SelectiveSSM::new(3, 8, 42).forward(&input);
        let second = SelectiveSSM::new(3, 8, 42).forward(&input);
        for (i, (a, b)) in first.iter().copied().zip(second.iter().copied()).enumerate() {
            assert!(
                math::abs(a - b) < 1e-15,
                "output[{}] should be identical for same seed: {} vs {}",
                i,
                a,
                b
            );
        }
    }

    // Changing the seed changes the weights and therefore the output.
    #[test]
    fn different_seeds_produce_different_outputs() {
        let input = [1.0, 2.0, 3.0];
        let first = SelectiveSSM::new(3, 8, 42).forward(&input);
        let second = SelectiveSSM::new(3, 8, 99).forward(&input);
        let diff = squared_distance(&first, &second);
        assert!(
            diff > 1e-20,
            "different seeds should generally produce different outputs"
        );
    }

    // Degenerate shape: a single input channel.
    #[test]
    fn single_channel_works() {
        let mut layer = SelectiveSSM::new(1, 4, 42);
        let result = layer.forward(&[3.0]);
        assert_eq!(result.len(), 1);
        assert!(result[0].is_finite());
    }

    // Degenerate shape: a single hidden value per channel.
    #[test]
    fn single_state_dim_works() {
        let mut layer = SelectiveSSM::new(3, 1, 42);
        let result = layer.forward(&[1.0, 2.0, 3.0]);
        assert_eq!(result.len(), 3);
        for y in result.iter().copied() {
            assert!(y.is_finite());
        }
    }

    // The layer is stateful: repeating an input does not repeat the output.
    #[test]
    fn sequential_outputs_differ() {
        let mut layer = SelectiveSSM::new(2, 4, 42);
        let input = [1.0, 0.0];
        let out1 = layer.forward(&input);
        let out2 = layer.forward(&input);
        let diff = squared_distance(&out1, &out2);
        assert!(
            diff > 1e-20,
            "sequential calls with same input should differ due to state: out1={:?}, out2={:?}",
            out1,
            out2
        );
    }

    // Inputs of magnitude 1000 still give finite outputs.
    #[test]
    fn large_input_no_overflow() {
        let mut layer = SelectiveSSM::new(2, 4, 42);
        let result = layer.forward(&[1000.0, -1000.0]);
        for (i, y) in result.iter().copied().enumerate() {
            assert!(
                y.is_finite(),
                "output[{}] should be finite for large inputs, got {}",
                i,
                y
            );
        }
    }

    // With zero state and zero input every output term vanishes.
    #[test]
    fn zero_input_zero_state_gives_zero_output() {
        let mut layer = SelectiveSSM::new(3, 8, 42);
        let result = layer.forward(&[0.0, 0.0, 0.0]);
        for (i, y) in result.iter().copied().enumerate() {
            assert!(
                math::abs(y) < 1e-15,
                "zero input with zero state should give zero output[{}], got {}",
                i,
                y
            );
        }
    }
}