Skip to main content

irithyll_core/ssm/
diagonal.rs

1//! Non-selective diagonal state space model with fixed parameters.
2//!
3//! [`DiagonalSSM`] implements a simple diagonal SSM where the A, B, C matrices
4//! and step size delta are fixed at construction time. This provides a baseline
5//! temporal feature extractor without the input-dependent selectivity of Mamba.
6//!
7//! The model processes scalar inputs and produces scalar outputs:
8//!
9//! ```text
10//! h_t = A_bar .* h_{t-1} + B_bar * x_t    (state update, element-wise)
11//! y_t = C^T * h_t + D * x_t                 (output)
12//! ```
13//!
14//! where `A_bar` and `B_bar` are computed once via ZOH discretization of the
15//! continuous-time parameters.
16
17use alloc::vec;
18use alloc::vec::Vec;
19
20use crate::math;
21use crate::ssm::discretize::zoh_discretize;
22use crate::ssm::init::mamba_init;
23use crate::ssm::projection::dot;
24use crate::ssm::SSMLayer;
25
/// Non-selective diagonal SSM with fixed A, B, C, and step size.
///
/// This is the simplest SSM variant: all parameters are determined at
/// construction time and do not adapt to input content. The hidden state
/// is an N-dimensional vector that evolves through element-wise recurrence.
///
/// The type is `Clone`: a clone copies the parameters, the cached
/// discretization *and* the current hidden state, so it continues from the
/// same point in the input sequence as the original.
///
/// # When to Use
///
/// Use `DiagonalSSM` when:
/// - You want a simple temporal smoothing/memory layer
/// - Input-dependent selectivity is not needed
/// - You need a fast baseline to compare against [`SelectiveSSM`](super::SelectiveSSM)
///
/// # Example
///
/// ```
/// use irithyll_core::ssm::diagonal::DiagonalSSM;
/// use irithyll_core::ssm::SSMLayer;
///
/// let mut ssm = DiagonalSSM::new(8, 0.1);
/// let output = ssm.forward(&[1.0]);
/// assert_eq!(output.len(), 1);
/// ```
#[derive(Debug, Clone)]
pub struct DiagonalSSM {
    /// Log-space A parameters (N). Actual A_n = -exp(log_a_n).
    log_a: Vec<f64>,
    /// B vector (N) -- input-to-state projection.
    b: Vec<f64>,
    /// C vector (N) -- state-to-output projection.
    c: Vec<f64>,
    /// Fixed discretization step size.
    delta: f64,
    /// Skip connection weight (D in the SSM equations).
    d_skip: f64,
    /// Hidden state vector (N).
    h: Vec<f64>,
    /// Pre-computed discretized A_bar values (N). Cached for efficiency.
    a_bar: Vec<f64>,
    /// Pre-computed discretized B_bar factor values (N). Cached for efficiency.
    b_bar_factor: Vec<f64>,
}
67
68impl DiagonalSSM {
69    /// Create a new non-selective diagonal SSM.
70    ///
71    /// Initializes A via the Mamba strategy (`A_n = -(n+1)`), sets B and C
72    /// to ones, D (skip connection) to zero, and pre-computes the discretized
73    /// matrices using ZOH.
74    ///
75    /// # Arguments
76    ///
77    /// * `n_state` -- number of hidden state dimensions (N)
78    /// * `delta` -- fixed discretization step size (positive)
79    ///
80    /// # Example
81    ///
82    /// ```
83    /// use irithyll_core::ssm::diagonal::DiagonalSSM;
84    ///
85    /// let ssm = DiagonalSSM::new(16, 0.01);
86    /// ```
87    pub fn new(n_state: usize, delta: f64) -> Self {
88        let log_a = mamba_init(n_state);
89        let b = vec![1.0; n_state];
90        let c = vec![1.0; n_state];
91        let h = vec![0.0; n_state];
92
93        // Pre-compute discretized A_bar and B_bar_factor
94        let mut a_bar = Vec::with_capacity(n_state);
95        let mut b_bar_factor = Vec::with_capacity(n_state);
96        for la in &log_a {
97            let a_n = -math::exp(*la);
98            let (ab, bbf) = zoh_discretize(a_n, delta);
99            a_bar.push(ab);
100            b_bar_factor.push(bbf);
101        }
102
103        Self {
104            log_a,
105            b,
106            c,
107            delta,
108            d_skip: 0.0,
109            h,
110            a_bar,
111            b_bar_factor,
112        }
113    }
114
115    /// Create a diagonal SSM with custom B, C vectors and skip connection.
116    ///
117    /// # Arguments
118    ///
119    /// * `n_state` -- number of hidden state dimensions
120    /// * `delta` -- fixed discretization step size
121    /// * `b` -- input projection vector (length n_state)
122    /// * `c` -- output projection vector (length n_state)
123    /// * `d_skip` -- skip connection weight
124    ///
125    /// # Panics
126    ///
127    /// Debug-asserts that `b.len() == n_state` and `c.len() == n_state`.
128    pub fn with_params(n_state: usize, delta: f64, b: Vec<f64>, c: Vec<f64>, d_skip: f64) -> Self {
129        debug_assert_eq!(b.len(), n_state);
130        debug_assert_eq!(c.len(), n_state);
131        let log_a = mamba_init(n_state);
132        let h = vec![0.0; n_state];
133
134        let mut a_bar = Vec::with_capacity(n_state);
135        let mut b_bar_factor = Vec::with_capacity(n_state);
136        for la in &log_a {
137            let a_n = -math::exp(*la);
138            let (ab, bbf) = zoh_discretize(a_n, delta);
139            a_bar.push(ab);
140            b_bar_factor.push(bbf);
141        }
142
143        Self {
144            log_a,
145            b,
146            c,
147            delta,
148            d_skip,
149            h,
150            a_bar,
151            b_bar_factor,
152        }
153    }
154
155    /// Process a single scalar input and return the scalar output.
156    ///
157    /// Updates the hidden state via:
158    /// ```text
159    /// h_n = a_bar_n * h_n + b_bar_factor_n * b_n * x
160    /// y = C^T * h + D * x
161    /// ```
162    #[inline]
163    pub fn forward_scalar(&mut self, x: f64) -> f64 {
164        let n_state = self.h.len();
165        for n in 0..n_state {
166            self.h[n] = self.a_bar[n] * self.h[n] + self.b_bar_factor[n] * self.b[n] * x;
167        }
168        dot(&self.c, &self.h) + self.d_skip * x
169    }
170
171    /// Get the number of state dimensions.
172    #[inline]
173    pub fn n_state(&self) -> usize {
174        self.h.len()
175    }
176
177    /// Get the log-A parameters (for inspection/serialization).
178    #[inline]
179    pub fn log_a(&self) -> &[f64] {
180        &self.log_a
181    }
182
183    /// Get the fixed step size.
184    #[inline]
185    pub fn delta(&self) -> f64 {
186        self.delta
187    }
188}
189
190impl SSMLayer for DiagonalSSM {
191    fn forward(&mut self, input: &[f64]) -> Vec<f64> {
192        // DiagonalSSM processes scalar input -- take first element or 0
193        let x = if input.is_empty() { 0.0 } else { input[0] };
194        vec![self.forward_scalar(x)]
195    }
196
197    fn state(&self) -> &[f64] {
198        &self.h
199    }
200
201    fn output_dim(&self) -> usize {
202        1
203    }
204
205    fn reset(&mut self) {
206        for h in self.h.iter_mut() {
207            *h = 0.0;
208        }
209    }
210}
211
#[cfg(test)]
mod tests {
    use super::*;

    /// Squared L2 norm of the hidden state, used to measure injection/decay.
    fn state_energy(ssm: &DiagonalSSM) -> f64 {
        ssm.state().iter().map(|h| h * h).sum()
    }

    #[test]
    fn new_creates_zero_state() {
        let ssm = DiagonalSSM::new(8, 0.1);
        assert_eq!(ssm.n_state(), 8);
        for &h in ssm.state() {
            assert!(math::abs(h) < 1e-15, "initial state should be zero");
        }
    }

    #[test]
    fn forward_scalar_produces_finite_output() {
        let mut ssm = DiagonalSSM::new(4, 0.1);
        let y = ssm.forward_scalar(1.0);
        assert!(y.is_finite(), "output should be finite, got {}", y);
    }

    #[test]
    fn forward_updates_state() {
        let mut ssm = DiagonalSSM::new(4, 0.1);
        let _ = ssm.forward_scalar(1.0);
        assert!(
            state_energy(&ssm) > 0.0,
            "state should be non-zero after processing input"
        );
    }

    #[test]
    fn reset_clears_state() {
        let mut ssm = DiagonalSSM::new(4, 0.1);
        let _ = ssm.forward_scalar(1.0);
        ssm.reset();
        for &h in ssm.state() {
            assert!(math::abs(h) < 1e-15, "state should be zero after reset");
        }
    }

    #[test]
    fn state_decays_without_input() {
        let mut ssm = DiagonalSSM::new(4, 0.1);
        // Inject some state
        let _ = ssm.forward_scalar(10.0);
        let energy_after_input = state_energy(&ssm);

        // Process many zero inputs -- state should decay
        for _ in 0..100 {
            let _ = ssm.forward_scalar(0.0);
        }
        let energy_after_decay = state_energy(&ssm);
        assert!(
            energy_after_decay < energy_after_input * 0.01,
            "state energy should decay: initial={}, after={}",
            energy_after_input,
            energy_after_decay
        );
    }

    #[test]
    fn ssm_layer_trait_works() {
        let mut ssm = DiagonalSSM::new(4, 0.1);
        let out = ssm.forward(&[1.0]);
        assert_eq!(out.len(), 1, "output_dim should be 1");
        assert_eq!(ssm.output_dim(), 1);
    }

    #[test]
    fn constant_input_converges() {
        // With constant input, output should converge to a steady state
        let mut ssm = DiagonalSSM::new(4, 0.1);
        let mut prev_y = 0.0;
        let mut settled = false;
        for i in 0..500 {
            let y = ssm.forward_scalar(1.0);
            if i > 10 && math::abs(y - prev_y) < 1e-10 {
                settled = true;
                break;
            }
            prev_y = y;
        }
        assert!(settled, "output should converge for constant input");
    }

    #[test]
    fn skip_connection_passes_through() {
        let b = vec![0.0; 4]; // Zero B means no state update from input
        let c = vec![0.0; 4]; // Zero C means no state contribution to output
        let mut ssm = DiagonalSSM::with_params(4, 0.1, b, c, 1.0);
        let y = ssm.forward_scalar(5.0);
        assert!(
            math::abs(y - 5.0) < 1e-12,
            "with zero B/C and d_skip=1, output should equal input: got {}",
            y
        );
    }

    #[test]
    fn empty_input_treated_as_zero() {
        let mut ssm = DiagonalSSM::new(4, 0.1);
        let out = ssm.forward(&[]);
        assert_eq!(out.len(), 1);
        assert!(
            math::abs(out[0]) < 1e-15,
            "empty input should be treated as zero"
        );
    }

    #[test]
    fn different_delta_changes_dynamics() {
        let mut ssm_fast = DiagonalSSM::new(4, 1.0);
        let mut ssm_slow = DiagonalSSM::new(4, 0.001);

        // Same input sequence
        let _ = ssm_fast.forward_scalar(1.0);
        let y_fast = ssm_fast.forward_scalar(0.0);

        let _ = ssm_slow.forward_scalar(1.0);
        let y_slow = ssm_slow.forward_scalar(0.0);

        // The faster step size should produce different decay behavior
        assert!(
            math::abs(y_fast - y_slow) > 1e-6,
            "different delta should produce different dynamics: fast={}, slow={}",
            y_fast,
            y_slow
        );
    }

    #[test]
    fn new_matches_with_params_defaults() {
        // `new` is documented as all-ones B/C with no skip connection.
        let mut default_ssm = DiagonalSSM::new(4, 0.1);
        let mut explicit_ssm = DiagonalSSM::with_params(4, 0.1, vec![1.0; 4], vec![1.0; 4], 0.0);
        for &x in &[1.0, -0.5, 0.0, 2.0] {
            assert_eq!(default_ssm.forward_scalar(x), explicit_ssm.forward_scalar(x));
        }
        assert_eq!(default_ssm.state(), explicit_ssm.state());
    }

    #[test]
    fn forward_scalar_matches_zoh_recurrence() {
        // From a zero state with B = C = 1 and D = 0, one step gives
        // h_n = b_bar_factor_n * x and y = sum_n h_n.
        let delta = 0.1;
        let x = 2.0;
        let mut ssm = DiagonalSSM::new(3, delta);
        let y = ssm.forward_scalar(x);

        let mut expected = 0.0;
        for (n, &la) in ssm.log_a().iter().enumerate() {
            let (_, bbf) = zoh_discretize(-math::exp(la), delta);
            let h_n = bbf * x;
            assert!(
                math::abs(ssm.state()[n] - h_n) < 1e-12,
                "state[{}] = {}, expected {}",
                n,
                ssm.state()[n],
                h_n
            );
            expected += h_n;
        }
        assert!(
            math::abs(y - expected) < 1e-12,
            "output {} should equal C^T h = {}",
            y,
            expected
        );
    }

    #[test]
    #[should_panic]
    fn with_params_rejects_short_b() {
        // A `b` shorter than n_state must never be processed silently.
        let mut ssm = DiagonalSSM::with_params(4, 0.1, vec![1.0; 3], vec![1.0; 4], 0.0);
        let _ = ssm.forward_scalar(1.0);
    }
}