Skip to main content

irithyll_core/ssm/
complex_diag.rs

1//! Complex Diagonal SSM — standalone reusable streaming primitive.
2//!
3//! [`ComplexDiagonalSSM`] implements the complex-valued diagonal recurrence:
4//!
5//! ```text
6//! A = Diag(-exp(log|re|) + j · im)    (stable complex eigenvalues)
7//! h_t = discretize(A, Δ_t) · h_{t-1} + input_contribution(B, x_t, Δ_t)
8//! y_t = Re(C^T · h_t)                  (real output)
9//! ```
10//!
11//! This is the mathematical core shared by all Mamba-3 variants. Extracting it
12//! as a standalone primitive enables:
13//!
14//! 1. **Reservoir computing**: complex echo state networks with oscillatory dynamics.
15//! 2. **Signal processing**: market-data phase tracking via complex rotation.
16//! 3. **Composability**: `StreamingMamba V3Exp` and `V3Mimo` both use this cell.
17//! 4. **Testability**: unit-test the recurrence math independently of the
18//!    full Mamba plumbing.
19//!
20//! ## Parameterization
21//!
22//! Complex A is stored as `log_a_complex: Vec<f64>` with interleaved layout
23//! `[log|re_0|, im_0, log|re_1|, im_1, ...]`. Actual eigenvalues:
24//!
25//! ```text
26//! A_n = -exp(log_a_complex[2n]) + j · log_a_complex[2n+1]
27//! ```
28//!
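//! Stability: Re(A_n) < 0 is structurally enforced by the negated exp, so the
//! discrete transition coefficient lies strictly inside the unit disk for any
//! Δ > 0 under both methods: `|α_n| = exp(Δ · Re(A_n)) < 1` for exp-trapezoidal,
//! and `|α_n| = |1 + Δ/2·A_n| / |1 − Δ/2·A_n| < 1` for Tustin.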
31//!
32//! ## Discretization methods
33//!
34//! Two methods are supported via [`DiscretizeMethod`]:
35//!
36//! - **Tustin** (default, `trapezoidal_complex`): S4-style bilinear transform.
37//! - **ExpTrapezoidal** (Mamba-3 spec, `exp_trapezoidal_complex`): 3-term
38//!   recurrence with data-dependent λ_t.
39//!
40//! ## References
41//!
42//! - Lahoti et al. "Mamba-3: Improved Sequence Modeling using State Space
43//!   Principles." arXiv:2603.15569, ICLR 2026. §2-3 (complex SSM, exp-trap).
44//! - Gu et al. "On the Parameterization and Initialization of Diagonal State
45//!   Space Models." NeurIPS 2022. (S4D, s4d_inv_complex init).
46//! - Proposition 2 (Mamba-3 paper): complex SSM of dim N/2 ≡ real SSM of dim N
47//!   with block-diagonal 2×2 rotation matrices.
48
49use alloc::vec;
50use alloc::vec::Vec;
51
52use crate::math;
53use crate::ssm::discretize::{exp_trapezoidal_complex, trapezoidal_complex};
54use crate::ssm::init::s4d_inv_complex;
55
/// Discretization method for [`ComplexDiagonalSSM`].
///
/// Selects how the continuous-time complex eigenvalue `A` is mapped to a
/// discrete-time transition coefficient `α`. The default is
/// [`DiscretizeMethod::Tustin`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum DiscretizeMethod {
    /// Bilinear (Tustin) transform: S4-style, 2-term recurrence.
    ///
    /// `α = (I + Δ/2·A)(I - Δ/2·A)⁻¹`. Maps the left-half s-plane onto the
    /// unit disk exactly, so `|α| < 1` whenever `Re(A) < 0` and `Δ > 0`.
    /// Good for S4-style oscillatory SSMs. The per-step `lambda` argument is
    /// ignored under this method.
    #[default]
    Tustin,
    /// Exponential-trapezoidal: Mamba-3 spec, 3-term recurrence.
    ///
    /// `α = exp(Δ·A)`, so `|α| = exp(Δ·Re(A)) < 1` whenever `Re(A) < 0` and
    /// `Δ > 0`. Requires `lambda` per step — either fixed or data-dependent
    /// from the model.
    ///
    /// See `exp_trapezoidal_complex` for the full 3-term derivation.
    /// (Lahoti et al., arXiv:2603.15569, ICLR 2026, Table 1.)
    ExpTrapezoidal,
}
76
77/// Complex Diagonal SSM — standalone streaming primitive.
78///
79/// Maintains a complex hidden state `h ∈ C^N` (stored as 2N real values,
80/// interleaved re/im) evolving via a stable complex diagonal recurrence.
81/// The real-valued scalar output `y = Re(C^T · h)` projects the complex
82/// state back to a real observation.
83///
84/// ## State layout
85///
86/// `h` is a flat `Vec<f64>` of length `2 * n_state`:
87/// `[re_0, im_0, re_1, im_1, ..., re_{N-1}, im_{N-1}]`.
88///
89/// ## Stability guarantee
90///
91/// The A parameterization `A_n = -exp(log_a_complex[2n]) + j·log_a_complex[2n+1]`
92/// ensures `Re(A_n) < 0` structurally. Combined with either Tustin or
93/// exp-trapezoidal discretization, the spectral radius of the state transition
94/// is guaranteed < 1 for any valid Δ > 0.
95///
96/// ## Initialization
97///
98/// Default: S4D-Inv complex (`s4d_inv_complex`), which gives harmonically-spaced
99/// eigenvalues with oscillatory imaginary parts. This is the Mamba-3 default.
100///
101/// # Example
102///
103/// ```
104/// use irithyll_core::ssm::complex_diag::{ComplexDiagonalSSM, DiscretizeMethod};
105///
106/// let mut cell = ComplexDiagonalSSM::new(8, DiscretizeMethod::Tustin);
107/// let b = vec![1.0; 8];
108/// let c = vec![1.0; 8];
109/// let y = cell.step(0.1, &b, &c, 1.0, 0.5);
110/// assert!(y.is_finite(), "output must be finite");
111/// assert_eq!(cell.state().len(), 16, "state is 2*n_state complex values");
112/// ```
113pub struct ComplexDiagonalSSM {
114    /// Log-magnitude of A's real part and direct imaginary values.
115    /// Layout: `[log|re_0|, im_0, log|re_1|, im_1, ...]` (length = 2*n_state).
116    /// Actual: `A_n = -exp(log_a[2n]) + j·log_a[2n+1]`.
117    log_a_complex: Vec<f64>,
118    /// Complex hidden state (length = 2 * n_state, interleaved re/im).
119    h: Vec<f64>,
120    /// Number of complex state dimensions.
121    n_state: usize,
122    /// Previous B·x contribution for 3-term recurrence (exp-trapezoidal only).
123    /// Layout: `[re_0, im_0, re_1, im_1, ...]` (length = 2 * n_state).
124    prev_bx_re: Vec<f64>,
125    prev_bx_im: Vec<f64>,
126    /// Discretization method.
127    method: DiscretizeMethod,
128}
129
130impl ComplexDiagonalSSM {
131    /// Create a new complex diagonal SSM with S4D-Inv complex initialization.
132    ///
133    /// Uses `s4d_inv_complex` for harmonically-spaced eigenvalues with
134    /// oscillatory imaginary parts (Gu et al., NeurIPS 2022).
135    ///
136    /// # Arguments
137    ///
138    /// * `n_state` -- number of complex state dimensions (total state dim = 2*N)
139    /// * `method` -- discretization method (Tustin or ExpTrapezoidal)
140    pub fn new(n_state: usize, method: DiscretizeMethod) -> Self {
141        let log_a_complex = s4d_inv_complex(n_state);
142        debug_assert!(
143            log_a_complex
144                .iter()
145                .enumerate()
146                .step_by(2)
147                .all(|(_i, &v)| v < 20.0),
148            "log|re| values from s4d_inv_complex must not overflow exp (< 20.0), \
149             but some exceed threshold. Max state dim where ln(0.5+N/1) > 20 is N > e^20 ≈ 5e8."
150        );
151        Self {
152            h: vec![0.0; 2 * n_state],
153            prev_bx_re: vec![0.0; n_state],
154            prev_bx_im: vec![0.0; n_state],
155            n_state,
156            log_a_complex,
157            method,
158        }
159    }
160
161    /// Create a ComplexDiagonalSSM with custom A-matrix log-parameters.
162    ///
163    /// # Arguments
164    ///
165    /// * `log_a_complex` -- 2*n_state values in `[log|re_0|, im_0, ...]` layout.
166    ///   Real parts: `A_re = -exp(log_a[2n])` (must satisfy `log_a[2n] > 0` for |A_re| > 1).
167    /// * `method` -- discretization method
168    ///
169    /// # Panics
170    ///
171    /// Panics if `log_a_complex.len()` is not even.
172    pub fn with_init(log_a_complex: Vec<f64>, method: DiscretizeMethod) -> Self {
173        assert!(
174            log_a_complex.len() % 2 == 0,
175            "log_a_complex must have even length (interleaved re/im), got {}",
176            log_a_complex.len()
177        );
178        let n_state = log_a_complex.len() / 2;
179        Self {
180            h: vec![0.0; 2 * n_state],
181            prev_bx_re: vec![0.0; n_state],
182            prev_bx_im: vec![0.0; n_state],
183            n_state,
184            log_a_complex,
185            method,
186        }
187    }
188
189    /// Advance state by one timestep and return the real-valued scalar output.
190    ///
191    /// Implements the SISO (single-input single-output) forward pass:
192    ///
193    /// ```text
194    /// A_n = -exp(log_a[2n]) + j·log_a[2n+1]
195    /// (α, β, γ) = discretize(A_n, delta, lambda)
196    /// h_n ← α·h_n + β·prev_bx_n + γ·(B[n]·x)
197    /// y += Re(C[n]·h_n)  = C[n] · Re(h_n)  (real C case)
198    /// ```
199    ///
200    /// For [`DiscretizeMethod::Tustin`], the 3-term β·prev_bx term is zero
201    /// (β=0 by construction — `prev_bx` is not used).
202    ///
203    /// # Arguments
204    ///
205    /// * `delta` -- step size (positive, data-dependent in Mamba-3)
206    /// * `b` -- real input projection vector (length = n_state)
207    /// * `c` -- real output projection vector (length = n_state)
208    /// * `x` -- scalar input at this timestep
209    /// * `lambda` -- exp-trapezoidal mixing parameter ∈ [0,1] (ignored for Tustin)
210    ///
211    /// # Returns
212    ///
213    /// Real scalar output `y = Re(C^T · h)`.
214    pub fn step(&mut self, delta: f64, b: &[f64], c: &[f64], x: f64, lambda: f64) -> f64 {
215        debug_assert_eq!(b.len(), self.n_state, "b must have n_state elements");
216        debug_assert_eq!(c.len(), self.n_state, "c must have n_state elements");
217
218        let mut y = 0.0;
219
220        for n in 0..self.n_state {
221            let a_re = -math::exp(self.log_a_complex[2 * n]);
222            let a_im = self.log_a_complex[2 * n + 1];
223
224            // Current B·x contribution (real B, scalar x → real scalar)
225            let bx = b[n] * x;
226
227            // Compute state update based on discretization method
228            let (h_re_new, h_im_new) = match self.method {
229                DiscretizeMethod::Tustin => {
230                    let (a_bar_re, a_bar_im, b_fac_re, b_fac_im) =
231                        trapezoidal_complex(a_re, a_im, delta);
232                    // 2-term: h = α·h + b_fac·bx
233                    let h_re_old = self.h[2 * n];
234                    let h_im_old = self.h[2 * n + 1];
235                    let h_re = a_bar_re * h_re_old - a_bar_im * h_im_old + b_fac_re * bx;
236                    let h_im = a_bar_re * h_im_old + a_bar_im * h_re_old + b_fac_im * bx;
237                    (h_re, h_im)
238                }
239                DiscretizeMethod::ExpTrapezoidal => {
240                    let (alpha_re, alpha_im, beta_re, beta_im, gamma_re, gamma_im) =
241                        exp_trapezoidal_complex(a_re, a_im, delta, lambda);
242
243                    let h_re_old = self.h[2 * n];
244                    let h_im_old = self.h[2 * n + 1];
245
246                    // 3-term: h = α·h + β·prev_bx + γ·bx
247                    // α·h (complex × complex):
248                    let ah_re = alpha_re * h_re_old - alpha_im * h_im_old;
249                    let ah_im = alpha_re * h_im_old + alpha_im * h_re_old;
250
251                    // β·prev_bx (complex β, real prev_bx stored as [re, im] of complex state):
252                    let pbx_re = self.prev_bx_re[n];
253                    let pbx_im = self.prev_bx_im[n];
254                    let b_prev_re = beta_re * pbx_re - beta_im * pbx_im;
255                    let b_prev_im = beta_re * pbx_im + beta_im * pbx_re;
256
257                    // γ·bx (real γ_re from paper: γ = λ·Δ is real, γ_im=0):
258                    let b_curr_re = gamma_re * bx;
259                    let b_curr_im = gamma_im * bx;
260
261                    let h_re = ah_re + b_prev_re + b_curr_re;
262                    let h_im = ah_im + b_prev_im + b_curr_im;
263                    (h_re, h_im)
264                }
265            };
266
267            self.h[2 * n] = h_re_new;
268            self.h[2 * n + 1] = h_im_new;
269
270            // Cache B·x as complex for next step's β term (only meaningful for ExpTrapezoidal)
271            // For Tustin this is a no-op store that adds no cost.
272            self.prev_bx_re[n] = bx;
273            self.prev_bx_im[n] = 0.0; // real B·x has no imaginary part
274
275            // Output: y += Re(C[n] · h_n) = C[n] · Re(h_n) (real C)
276            y += c[n] * h_re_new;
277        }
278
279        y
280    }
281
282    /// Advance state with complex B and C projections.
283    ///
284    /// Returns `Re(C^* · h)` where `C^*` is the complex conjugate of C.
285    /// This is the full complex SSM output per Mamba-3 Proposition 2:
286    /// `y = Re((C_re + j·C_im)^* · h) = C_re·Re(h) + C_im·Im(h)`.
287    ///
288    /// # Arguments
289    ///
290    /// * `delta` -- step size
291    /// * `b_re` -- real part of B vector (length = n_state)
292    /// * `b_im` -- imaginary part of B vector (length = n_state)
293    /// * `c_re` -- real part of C vector (length = n_state)
294    /// * `c_im` -- imaginary part of C vector (length = n_state)
295    /// * `x` -- scalar input
296    /// * `lambda` -- exp-trapezoidal λ_t (ignored for Tustin)
297    ///
298    /// # Returns
299    ///
300    /// Real scalar output `Re(C^* · h)`.
301    #[allow(clippy::too_many_arguments)]
302    pub fn step_complex(
303        &mut self,
304        delta: f64,
305        b_re: &[f64],
306        b_im: &[f64],
307        c_re: &[f64],
308        c_im: &[f64],
309        x: f64,
310        lambda: f64,
311    ) -> f64 {
312        debug_assert_eq!(b_re.len(), self.n_state);
313        debug_assert_eq!(b_im.len(), self.n_state);
314        debug_assert_eq!(c_re.len(), self.n_state);
315        debug_assert_eq!(c_im.len(), self.n_state);
316
317        let mut y = 0.0;
318
319        for n in 0..self.n_state {
320            let a_re = -math::exp(self.log_a_complex[2 * n]);
321            let a_im = self.log_a_complex[2 * n + 1];
322
323            // Complex B·x: bx = (B_re[n] + j·B_im[n]) * x
324            let bx_re = b_re[n] * x;
325            let bx_im = b_im[n] * x;
326
327            let (h_re_new, h_im_new) = match self.method {
328                DiscretizeMethod::Tustin => {
329                    let (a_bar_re, a_bar_im, b_fac_re, b_fac_im) =
330                        trapezoidal_complex(a_re, a_im, delta);
331                    let h_re_old = self.h[2 * n];
332                    let h_im_old = self.h[2 * n + 1];
333                    // α·h
334                    let ah_re = a_bar_re * h_re_old - a_bar_im * h_im_old;
335                    let ah_im = a_bar_re * h_im_old + a_bar_im * h_re_old;
336                    // b_fac · bx (complex × complex):
337                    let b_contrib_re = b_fac_re * bx_re - b_fac_im * bx_im;
338                    let b_contrib_im = b_fac_re * bx_im + b_fac_im * bx_re;
339                    (ah_re + b_contrib_re, ah_im + b_contrib_im)
340                }
341                DiscretizeMethod::ExpTrapezoidal => {
342                    let (alpha_re, alpha_im, beta_re, beta_im, gamma_re, gamma_im) =
343                        exp_trapezoidal_complex(a_re, a_im, delta, lambda);
344
345                    let h_re_old = self.h[2 * n];
346                    let h_im_old = self.h[2 * n + 1];
347
348                    // α·h
349                    let ah_re = alpha_re * h_re_old - alpha_im * h_im_old;
350                    let ah_im = alpha_re * h_im_old + alpha_im * h_re_old;
351
352                    // β·prev_bx (both complex)
353                    let pbx_re = self.prev_bx_re[n];
354                    let pbx_im = self.prev_bx_im[n];
355                    let b_prev_re = beta_re * pbx_re - beta_im * pbx_im;
356                    let b_prev_im = beta_re * pbx_im + beta_im * pbx_re;
357
358                    // γ·bx (γ is real: gamma_im=0)
359                    let b_curr_re = gamma_re * bx_re - gamma_im * bx_im;
360                    let b_curr_im = gamma_re * bx_im + gamma_im * bx_re;
361
362                    (ah_re + b_prev_re + b_curr_re, ah_im + b_prev_im + b_curr_im)
363                }
364            };
365
366            self.h[2 * n] = h_re_new;
367            self.h[2 * n + 1] = h_im_new;
368
369            // Cache complex B·x for 3-term recurrence
370            self.prev_bx_re[n] = bx_re;
371            self.prev_bx_im[n] = bx_im;
372
373            // Re(C^* · h) = C_re · Re(h) + C_im · Im(h)
374            y += c_re[n] * h_re_new + c_im[n] * h_im_new;
375        }
376
377        y
378    }
379
380    /// Per-state-dimension L2 energy for plasticity and diagnostics.
381    ///
382    /// Returns a vec of length `n_state` where each element is
383    /// `sqrt(Re(h_n)² + Im(h_n)²)` — the magnitude of the n-th complex state.
384    ///
385    /// Bounded by stability: `|h_n| ≤ Σ_{t} |α|^{T-t} · |input_t|`. Under a
386    /// stable recurrence, this converges to a finite value.
387    pub fn state_energies(&self) -> Vec<f64> {
388        (0..self.n_state)
389            .map(|n| {
390                let re = self.h[2 * n];
391                let im = self.h[2 * n + 1];
392                math::sqrt(re * re + im * im)
393            })
394            .collect()
395    }
396
397    /// Get the complex hidden state as interleaved re/im f64 slice.
398    ///
399    /// Layout: `[re_0, im_0, re_1, im_1, ...]`, length = `2 * n_state`.
400    #[inline]
401    pub fn state(&self) -> &[f64] {
402        &self.h
403    }
404
405    /// Number of complex state dimensions.
406    #[inline]
407    pub fn n_state(&self) -> usize {
408        self.n_state
409    }
410
411    /// Current discretization method.
412    #[inline]
413    pub fn method(&self) -> DiscretizeMethod {
414        self.method
415    }
416
417    /// Reset the hidden state and previous B·x cache to zero.
418    ///
419    /// Clears all temporal memory without changing A-matrix parameters.
420    pub fn reset(&mut self) {
421        self.h.fill(0.0);
422        self.prev_bx_re.fill(0.0);
423        self.prev_bx_im.fill(0.0);
424    }
425}
426
#[cfg(test)]
mod tests {
    use super::*;

    /// Stability over 10^6 steps: complex state magnitude must stay bounded.
    /// Validates the structural stability guarantee from the A parameterization.
    #[test]
    fn complex_diag_million_step_finite() {
        let mut cell = ComplexDiagonalSSM::new(8, DiscretizeMethod::ExpTrapezoidal);
        let b: Vec<f64> = (0..8).map(|n| 0.1 * (n as f64 + 1.0)).collect();
        let c: Vec<f64> = (0..8).map(|n| 0.1 * (n as f64 + 1.0)).collect();

        let mut max_abs_output = 0.0_f64;
        for step in 0..1_000_000u64 {
            let x = if step % 2 == 0 { 1.0 } else { -1.0 };
            let lambda = 0.5;
            let delta = 0.1;
            let y = cell.step(delta, &b, &c, x, lambda);
            assert!(
                y.is_finite(),
                "output must be finite at step {}: got {}",
                step,
                y
            );
            max_abs_output = max_abs_output.max(y.abs());
        }

        // State must remain finite
        for (n, &s) in cell.state().iter().enumerate() {
            assert!(
                s.is_finite(),
                "state[{}] must be finite after 10^6 steps: got {}",
                n,
                s
            );
        }

        // State must be bounded (not growing without bound)
        let state_norm: f64 = cell.state().iter().map(|s| s * s).sum::<f64>().sqrt();
        assert!(
            state_norm < 1e6,
            "state Euclidean norm must be bounded after 10^6 steps: got {}",
            state_norm
        );
    }

    #[test]
    fn complex_diag_tustin_stable() {
        let mut cell = ComplexDiagonalSSM::new(4, DiscretizeMethod::Tustin);
        let b = vec![0.1; 4];
        let c = vec![0.1; 4];
        for step in 0..1000 {
            let y = cell.step(0.1, &b, &c, 1.0, 0.5);
            assert!(
                y.is_finite(),
                "Tustin output must be finite at step {}",
                step
            );
        }
        for &s in cell.state() {
            assert!(s.is_finite(), "Tustin state must remain finite");
        }
    }

    #[test]
    fn complex_diag_reset_clears_state() {
        let mut cell = ComplexDiagonalSSM::new(4, DiscretizeMethod::Tustin);
        let b = vec![1.0; 4];
        let c = vec![1.0; 4];
        let _ = cell.step(0.1, &b, &c, 1.0, 0.5);

        let energy_before: f64 = cell.state().iter().map(|s| s * s).sum();
        assert!(energy_before > 0.0, "state must be non-zero after step");

        cell.reset();
        for &s in cell.state() {
            assert!(s.abs() < 1e-15, "state must be zero after reset, got {}", s);
        }
        for &s in &cell.prev_bx_re {
            assert!(s.abs() < 1e-15, "prev_bx_re must be zero after reset");
        }
        for &s in &cell.prev_bx_im {
            assert!(s.abs() < 1e-15, "prev_bx_im must be zero after reset");
        }
    }

    /// After `reset`, a used cell must behave exactly like a freshly built one,
    /// including the imaginary part of the cached B·x used by the β term.
    #[test]
    fn complex_diag_reset_restores_fresh_behavior() {
        let log_a = vec![0.5, 0.3, 1.0, 0.5];
        let b_re = vec![0.3, 0.6];
        let b_im = vec![0.2, -0.4];
        let c_re = vec![0.5, 0.5];
        let c_im = vec![0.1, 0.2];

        let mut used =
            ComplexDiagonalSSM::with_init(log_a.clone(), DiscretizeMethod::ExpTrapezoidal);
        for _ in 0..5 {
            let _ = used.step_complex(0.1, &b_re, &b_im, &c_re, &c_im, 1.0, 0.5);
        }
        assert!(
            used.prev_bx_im.iter().any(|&v| v != 0.0),
            "complex B must leave a non-zero imaginary B·x cache"
        );

        used.reset();
        assert!(used.state().iter().all(|&s| s == 0.0));
        assert!(used.prev_bx_re.iter().all(|&s| s == 0.0));
        assert!(used.prev_bx_im.iter().all(|&s| s == 0.0));

        let mut fresh = ComplexDiagonalSSM::with_init(log_a, DiscretizeMethod::ExpTrapezoidal);
        for step in 0..5 {
            let y_used = used.step_complex(0.1, &b_re, &b_im, &c_re, &c_im, 1.0, 0.5);
            let y_fresh = fresh.step_complex(0.1, &b_re, &b_im, &c_re, &c_im, 1.0, 0.5);
            assert_eq!(
                y_used, y_fresh,
                "reset cell must match a fresh cell at step {}",
                step
            );
        }
    }

    #[test]
    fn complex_diag_zero_input_zero_output_from_zero_state() {
        let mut cell = ComplexDiagonalSSM::new(4, DiscretizeMethod::ExpTrapezoidal);
        let b = vec![1.0; 4];
        let c = vec![1.0; 4];
        let y = cell.step(0.1, &b, &c, 0.0, 0.5);
        assert!(
            y.abs() < 1e-15,
            "zero input from zero state must give zero output, got {}",
            y
        );
    }

    #[test]
    fn complex_diag_state_energies_bounded() {
        let mut cell = ComplexDiagonalSSM::new(8, DiscretizeMethod::ExpTrapezoidal);
        let b = vec![0.5; 8];
        let c = vec![0.5; 8];
        for _ in 0..1000 {
            let _ = cell.step(0.1, &b, &c, 1.0, 0.5);
        }
        let energies = cell.state_energies();
        assert_eq!(
            energies.len(),
            8,
            "state_energies must have n_state entries"
        );
        for (n, &e) in energies.iter().enumerate() {
            assert!(
                e.is_finite() && e >= 0.0,
                "energy[{}] must be finite non-negative, got {}",
                n,
                e
            );
        }
    }

    #[test]
    fn complex_diag_with_init_custom_params() {
        // Custom init with 2 complex state dims
        let log_a = vec![
            0.5, 1.0, // n=0: A_re=-exp(0.5)≈-1.65, A_im=1.0
            1.0, 2.0, // n=1: A_re=-exp(1.0)≈-2.72, A_im=2.0
        ];
        let mut cell = ComplexDiagonalSSM::with_init(log_a, DiscretizeMethod::Tustin);
        assert_eq!(cell.n_state(), 2);
        let b = vec![0.5, 0.5];
        let c = vec![0.5, 0.5];
        let y = cell.step(0.1, &b, &c, 1.0, 0.5);
        assert!(y.is_finite());
    }

    #[test]
    #[should_panic(expected = "even length")]
    fn complex_diag_with_init_rejects_odd_length() {
        let _ = ComplexDiagonalSSM::with_init(vec![0.5, 1.0, 0.3], DiscretizeMethod::Tustin);
    }

    #[test]
    fn complex_diag_step_complex_produces_finite_output() {
        let mut cell = ComplexDiagonalSSM::new(4, DiscretizeMethod::ExpTrapezoidal);
        let b_re = vec![0.3; 4];
        let b_im = vec![0.1; 4];
        let c_re = vec![0.3; 4];
        let c_im = vec![0.1; 4];
        for _ in 0..100 {
            let y = cell.step_complex(0.1, &b_re, &b_im, &c_re, &c_im, 1.0, 0.5);
            assert!(
                y.is_finite(),
                "complex step output must be finite: got {}",
                y
            );
        }
    }

    /// With zero imaginary projections, `step_complex` must reproduce `step`
    /// for both discretization methods (real B/C is a special case of complex).
    #[test]
    fn complex_diag_step_matches_step_complex_with_real_projections() {
        let log_a = vec![0.5, 0.3, 1.0, 0.5, -0.5, 1.2];
        let b = vec![0.4, -0.2, 0.7];
        let c = vec![0.3, 0.9, -0.5];
        let zeros = vec![0.0; 3];

        for method in [DiscretizeMethod::Tustin, DiscretizeMethod::ExpTrapezoidal] {
            let mut real_cell = ComplexDiagonalSSM::with_init(log_a.clone(), method);
            let mut cplx_cell = ComplexDiagonalSSM::with_init(log_a.clone(), method);
            for t in 0..50 {
                let x = ((t % 7) as f64 - 3.0) * 0.25;
                let delta = 0.05 + 0.01 * (t % 3) as f64;
                let lambda = 0.25 + 0.25 * (t % 3) as f64;
                let y_real = real_cell.step(delta, &b, &c, x, lambda);
                let y_cplx = cplx_cell.step_complex(delta, &b, &zeros, &c, &zeros, x, lambda);
                assert!(
                    (y_real - y_cplx).abs() < 1e-12,
                    "{:?} step {}: step={} step_complex={}",
                    method,
                    t,
                    y_real,
                    y_cplx
                );
            }
            for (s_real, s_cplx) in real_cell.state().iter().zip(cplx_cell.state()) {
                assert!(
                    (s_real - s_cplx).abs() < 1e-12,
                    "{:?}: states diverged ({} vs {})",
                    method,
                    s_real,
                    s_cplx
                );
            }
        }
    }

    /// The 3-term recurrence (exp-trap) and 2-term (Tustin) must agree at Δ→0.
    /// This validates that `prev_bx` cache does not corrupt when beta→0.
    #[test]
    fn complex_diag_exp_trap_and_tustin_agree_small_delta() {
        let n = 4;
        let log_a = vec![0.5, 0.3, 1.0, 0.5, 1.5, 0.8, 2.0, 1.0];
        let mut cell_et =
            ComplexDiagonalSSM::with_init(log_a.clone(), DiscretizeMethod::ExpTrapezoidal);
        let mut cell_tu = ComplexDiagonalSSM::with_init(log_a, DiscretizeMethod::Tustin);
        let b = vec![0.2_f64; n];
        let c = vec![0.2_f64; n];
        let delta = 0.0001; // very small delta
        let lambda = 1.0; // λ=1 → β=0 → exp-trap collapses to 2-term

        let mut y_et = 0.0;
        let mut y_tu = 0.0;
        for _ in 0..10 {
            y_et = cell_et.step(delta, &b, &c, 1.0, lambda);
            y_tu = cell_tu.step(delta, &b, &c, 1.0, 0.5);
        }
        // At λ=1 and tiny Δ, exp-trap ≈ Tustin
        assert!(
            (y_et - y_tu).abs() < 1e-4,
            "at small delta and lambda=1, exp-trap should approximate Tustin: et={}, tu={}",
            y_et,
            y_tu
        );
    }

    #[test]
    fn complex_diag_n_state_accessor() {
        let cell = ComplexDiagonalSSM::new(16, DiscretizeMethod::Tustin);
        assert_eq!(cell.n_state(), 16);
        assert_eq!(cell.state().len(), 32);
    }
}