irithyll_core/ssm/
diagonal.rs1use alloc::vec;
18use alloc::vec::Vec;
19
20use crate::math;
21use crate::ssm::discretize::zoh_discretize;
22use crate::ssm::init::mamba_init;
23use crate::ssm::projection::dot;
24use crate::ssm::SSMLayer;
25
/// Single-input, single-output diagonal state-space model.
///
/// Each of the `n_state` channels evolves independently; one step with input `x` is
///
/// ```text
/// h[n] <- a_bar[n] * h[n] + b_bar_factor[n] * b[n] * x
/// y     = dot(c, h) + d_skip * x
/// ```
///
/// The continuous-time diagonal entries are `A[n] = -exp(log_a[n])` (always negative).
/// `a_bar` / `b_bar_factor` cache the output of `zoh_discretize(A[n], delta)` so the
/// discretization is done once at construction rather than on every step.
#[derive(Debug, Clone)]
pub struct DiagonalSSM {
    /// Log-magnitude of the continuous-time diagonal state matrix: `A[n] = -exp(log_a[n])`.
    log_a: Vec<f64>,
    /// Per-channel input projection `B`.
    b: Vec<f64>,
    /// Per-channel output projection `C` (read-out weights applied to `h`).
    c: Vec<f64>,
    /// Discretization step size used to compute `a_bar` / `b_bar_factor`.
    delta: f64,
    /// Direct feed-through gain from the input `x` to the output `y`.
    d_skip: f64,
    /// Hidden state, one entry per channel; starts at zero and is zeroed again on reset.
    h: Vec<f64>,
    /// Discretized state-transition coefficient per channel (first value of `zoh_discretize`).
    a_bar: Vec<f64>,
    /// Discretized input coefficient per channel (second value of `zoh_discretize`);
    /// multiplied by `b[n] * x` on every step.
    b_bar_factor: Vec<f64>,
}
67
68impl DiagonalSSM {
69 pub fn new(n_state: usize, delta: f64) -> Self {
88 let log_a = mamba_init(n_state);
89 let b = vec![1.0; n_state];
90 let c = vec![1.0; n_state];
91 let h = vec![0.0; n_state];
92
93 let mut a_bar = Vec::with_capacity(n_state);
95 let mut b_bar_factor = Vec::with_capacity(n_state);
96 for la in &log_a {
97 let a_n = -math::exp(*la);
98 let (ab, bbf) = zoh_discretize(a_n, delta);
99 a_bar.push(ab);
100 b_bar_factor.push(bbf);
101 }
102
103 Self {
104 log_a,
105 b,
106 c,
107 delta,
108 d_skip: 0.0,
109 h,
110 a_bar,
111 b_bar_factor,
112 }
113 }
114
115 pub fn with_params(n_state: usize, delta: f64, b: Vec<f64>, c: Vec<f64>, d_skip: f64) -> Self {
129 debug_assert_eq!(b.len(), n_state);
130 debug_assert_eq!(c.len(), n_state);
131 let log_a = mamba_init(n_state);
132 let h = vec![0.0; n_state];
133
134 let mut a_bar = Vec::with_capacity(n_state);
135 let mut b_bar_factor = Vec::with_capacity(n_state);
136 for la in &log_a {
137 let a_n = -math::exp(*la);
138 let (ab, bbf) = zoh_discretize(a_n, delta);
139 a_bar.push(ab);
140 b_bar_factor.push(bbf);
141 }
142
143 Self {
144 log_a,
145 b,
146 c,
147 delta,
148 d_skip,
149 h,
150 a_bar,
151 b_bar_factor,
152 }
153 }
154
155 #[inline]
163 pub fn forward_scalar(&mut self, x: f64) -> f64 {
164 let n_state = self.h.len();
165 for n in 0..n_state {
166 self.h[n] = self.a_bar[n] * self.h[n] + self.b_bar_factor[n] * self.b[n] * x;
167 }
168 dot(&self.c, &self.h) + self.d_skip * x
169 }
170
171 #[inline]
173 pub fn n_state(&self) -> usize {
174 self.h.len()
175 }
176
177 #[inline]
179 pub fn log_a(&self) -> &[f64] {
180 &self.log_a
181 }
182
183 #[inline]
185 pub fn delta(&self) -> f64 {
186 self.delta
187 }
188}
189
190impl SSMLayer for DiagonalSSM {
191 fn forward(&mut self, input: &[f64]) -> Vec<f64> {
192 let x = if input.is_empty() { 0.0 } else { input[0] };
194 vec![self.forward_scalar(x)]
195 }
196
197 fn state(&self) -> &[f64] {
198 &self.h
199 }
200
201 fn output_dim(&self) -> usize {
202 1
203 }
204
205 fn reset(&mut self) {
206 for h in self.h.iter_mut() {
207 *h = 0.0;
208 }
209 }
210}
211
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn new_creates_zero_state() {
        let ssm = DiagonalSSM::new(8, 0.1);
        assert_eq!(ssm.n_state(), 8);
        for &h in ssm.state() {
            assert!(math::abs(h) < 1e-15, "initial state should be zero");
        }
    }

    #[test]
    fn forward_scalar_produces_finite_output() {
        let mut ssm = DiagonalSSM::new(4, 0.1);
        let y = ssm.forward_scalar(1.0);
        assert!(y.is_finite(), "output should be finite, got {}", y);
    }

    #[test]
    fn forward_updates_state() {
        let mut ssm = DiagonalSSM::new(4, 0.1);
        let _ = ssm.forward_scalar(1.0);
        let state_norm: f64 = ssm.state().iter().map(|h| h * h).sum();
        assert!(
            state_norm > 0.0,
            "state should be non-zero after processing input"
        );
    }

    #[test]
    fn reset_clears_state() {
        let mut ssm = DiagonalSSM::new(4, 0.1);
        let _ = ssm.forward_scalar(1.0);
        ssm.reset();
        for &h in ssm.state() {
            assert!(math::abs(h) < 1e-15, "state should be zero after reset");
        }
    }

    #[test]
    fn state_decays_without_input() {
        let mut ssm = DiagonalSSM::new(4, 0.1);
        // Excite the state with a single large impulse.
        let _ = ssm.forward_scalar(10.0);
        let energy_after_input: f64 = ssm.state().iter().map(|h| h * h).sum();

        for _ in 0..100 {
            // Zero input: only the a_bar decay acts on the state.
            let _ = ssm.forward_scalar(0.0);
        }
        let energy_after_decay: f64 = ssm.state().iter().map(|h| h * h).sum();
        assert!(
            energy_after_decay < energy_after_input * 0.01,
            "state energy should decay: initial={}, after={}",
            energy_after_input,
            energy_after_decay
        );
    }

    #[test]
    fn ssm_layer_trait_works() {
        let mut ssm = DiagonalSSM::new(4, 0.1);
        let out = ssm.forward(&[1.0]);
        assert_eq!(out.len(), 1, "output_dim should be 1");
        assert_eq!(ssm.output_dim(), 1);
    }

    #[test]
    fn constant_input_converges() {
        let mut ssm = DiagonalSSM::new(4, 0.1);
        // Under a constant input the stable recurrence approaches a fixed point.
        let mut prev_y = 0.0;
        let mut settled = false;
        for i in 0..500 {
            let y = ssm.forward_scalar(1.0);
            if i > 10 && math::abs(y - prev_y) < 1e-10 {
                settled = true;
                break;
            }
            prev_y = y;
        }
        assert!(settled, "output should converge for constant input");
    }

    #[test]
    fn skip_connection_passes_through() {
        // Zero B and C disconnect the state path, leaving only the d_skip feed-through.
        let b = vec![0.0; 4];
        let c = vec![0.0; 4];
        let mut ssm = DiagonalSSM::with_params(4, 0.1, b, c, 1.0);
        let y = ssm.forward_scalar(5.0);
        assert!(
            math::abs(y - 5.0) < 1e-12,
            "with zero B/C and d_skip=1, output should equal input: got {}",
            y
        );
    }

    #[test]
    fn empty_input_treated_as_zero() {
        let mut ssm = DiagonalSSM::new(4, 0.1);
        let out = ssm.forward(&[]);
        assert_eq!(out.len(), 1);
        assert!(
            math::abs(out[0]) < 1e-15,
            "empty input should be treated as zero"
        );
    }

    #[test]
    fn different_delta_changes_dynamics() {
        let mut ssm_fast = DiagonalSSM::new(4, 1.0);
        let mut ssm_slow = DiagonalSSM::new(4, 0.001);

        // Impulse followed by one zero step: the response depends on the discretization.
        let _ = ssm_fast.forward_scalar(1.0);
        let y_fast = ssm_fast.forward_scalar(0.0);

        let _ = ssm_slow.forward_scalar(1.0);
        let y_slow = ssm_slow.forward_scalar(0.0);

        assert!(
            math::abs(y_fast - y_slow) > 1e-6,
            "different delta should produce different dynamics: fast={}, slow={}",
            y_fast,
            y_slow
        );
    }

    #[test]
    fn new_matches_with_params_defaults() {
        // `new` must be indistinguishable from `with_params` with unit B/C and no skip.
        let n = 6;
        let mut via_new = DiagonalSSM::new(n, 0.05);
        let mut via_params = DiagonalSSM::with_params(n, 0.05, vec![1.0; n], vec![1.0; n], 0.0);
        for &x in &[1.0, -0.5, 0.0, 2.0, 0.0, 0.0] {
            let y_new = via_new.forward_scalar(x);
            let y_params = via_params.forward_scalar(x);
            assert_eq!(
                y_new, y_params,
                "new() should behave like with_params(ones, ones, 0.0)"
            );
        }
        assert_eq!(via_new.state(), via_params.state());
    }

    #[test]
    fn accessors_report_construction_parameters() {
        let ssm = DiagonalSSM::new(5, 0.25);
        assert_eq!(ssm.n_state(), 5);
        assert_eq!(ssm.log_a().len(), 5);
        assert_eq!(ssm.delta(), 0.25);
    }
}