use alloc::vec;
use alloc::vec::Vec;
use crate::math;
use crate::ssm::discretize::{exp_trapezoidal_complex, trapezoidal_complex};
use crate::ssm::init::s4d_inv_complex;
/// Selects which discretization routine `ComplexDiagonalSSM` applies on every step.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DiscretizeMethod {
/// Per-mode coefficients come from `trapezoidal_complex`; only the current
/// input drive enters the state update.
Tustin,
/// Per-mode coefficients come from `exp_trapezoidal_complex`; the update mixes
/// the previous step's input drive with the current one and takes a `lambda`
/// argument.
ExpTrapezoidal,
}
/// Recurrent cell holding `n_state` independent complex modes.
///
/// All complex quantities are stored as plain `f64` values; the state and the
/// per-mode parameters use an interleaved `[re, im, re, im, ...]` layout.
#[derive(Debug, Clone)]
pub struct ComplexDiagonalSSM {
    /// Interleaved per-mode parameters, length `2 * n_state`. Entry `2n` is
    /// exponentiated and negated to obtain the real part of mode `n`; entry
    /// `2n + 1` is used directly as its imaginary part.
    log_a_complex: Vec<f64>,
    /// Interleaved hidden state, length `2 * n_state` (`h[2n]` real, `h[2n + 1]` imaginary).
    h: Vec<f64>,
    /// Number of complex modes.
    n_state: usize,
    /// Real part of the input drive (`b * x`) recorded on the previous step.
    prev_bx_re: Vec<f64>,
    /// Imaginary part of the input drive recorded on the previous step.
    prev_bx_im: Vec<f64>,
    /// Discretization rule applied on every step.
    method: DiscretizeMethod,
}
impl ComplexDiagonalSSM {
    /// Creates a zero-state cell with `n_state` complex modes whose parameters
    /// are produced by `s4d_inv_complex`.
    ///
    /// Debug builds additionally check that every even-indexed parameter (the
    /// value later passed to `exp`) is below `20.0`.
    pub fn new(n_state: usize, method: DiscretizeMethod) -> Self {
        let log_a_complex = s4d_inv_complex(n_state);
        debug_assert!(
            log_a_complex.iter().step_by(2).all(|&log_re| log_re < 20.0),
            "log|re| values from s4d_inv_complex must not overflow exp (< 20.0), \
             but some exceed threshold. Max state dim where ln(0.5+N/1) > 20 is N > e^20 ≈ 5e8."
        );
        Self::zeroed(log_a_complex, n_state, method)
    }

    /// Creates a zero-state cell from caller-supplied interleaved parameters.
    ///
    /// The number of modes is `log_a_complex.len() / 2`.
    ///
    /// # Panics
    ///
    /// Panics if `log_a_complex` has odd length.
    pub fn with_init(log_a_complex: Vec<f64>, method: DiscretizeMethod) -> Self {
        let len = log_a_complex.len();
        assert!(
            len % 2 == 0,
            "log_a_complex must have even length (interleaved re/im), got {}",
            len
        );
        Self::zeroed(log_a_complex, len / 2, method)
    }

    /// Assembles a cell with all state and input-history buffers set to zero.
    fn zeroed(log_a_complex: Vec<f64>, n_state: usize, method: DiscretizeMethod) -> Self {
        Self {
            log_a_complex,
            h: vec![0.0; 2 * n_state],
            n_state,
            prev_bx_re: vec![0.0; n_state],
            prev_bx_im: vec![0.0; n_state],
            method,
        }
    }

    /// Complex product `(a_re + i*a_im) * (b_re + i*b_im)` as a `(re, im)` pair.
    #[inline]
    fn complex_mul(a_re: f64, a_im: f64, b_re: f64, b_im: f64) -> (f64, f64) {
        (a_re * b_re - a_im * b_im, a_re * b_im + a_im * b_re)
    }

    /// Advances every mode by one step with real input and output projections.
    ///
    /// * `delta` - step size handed to the discretization routine.
    /// * `b` - per-mode real input weights; the drive of mode `n` is `b[n] * x`.
    /// * `c` - per-mode real output weights.
    /// * `x` - scalar input sample.
    /// * `lambda` - forwarded to `exp_trapezoidal_complex`; unused for `Tustin`.
    ///
    /// Returns the sum over modes of `c[n]` times the real part of the updated state.
    ///
    /// # Panics
    ///
    /// Debug builds assert that `b` and `c` have exactly `n_state` elements. In
    /// any build, indexing panics if either slice is shorter than `n_state`.
    pub fn step(&mut self, delta: f64, b: &[f64], c: &[f64], x: f64, lambda: f64) -> f64 {
        debug_assert_eq!(b.len(), self.n_state, "b must have n_state elements");
        debug_assert_eq!(c.len(), self.n_state, "c must have n_state elements");
        let mut output = 0.0;
        for mode in 0..self.n_state {
            let (re_slot, im_slot) = (2 * mode, 2 * mode + 1);
            let pole_re = -math::exp(self.log_a_complex[re_slot]);
            let pole_im = self.log_a_complex[im_slot];
            let drive = b[mode] * x;
            let (old_re, old_im) = (self.h[re_slot], self.h[im_slot]);

            let (new_re, new_im) = match self.method {
                DiscretizeMethod::Tustin => {
                    let (abar_re, abar_im, bfac_re, bfac_im) =
                        trapezoidal_complex(pole_re, pole_im, delta);
                    let (carry_re, carry_im) =
                        Self::complex_mul(abar_re, abar_im, old_re, old_im);
                    // The drive is real, so the input factor only scales it.
                    (carry_re + bfac_re * drive, carry_im + bfac_im * drive)
                }
                DiscretizeMethod::ExpTrapezoidal => {
                    let (alpha_re, alpha_im, beta_re, beta_im, gamma_re, gamma_im) =
                        exp_trapezoidal_complex(pole_re, pole_im, delta, lambda);
                    let (carry_re, carry_im) =
                        Self::complex_mul(alpha_re, alpha_im, old_re, old_im);
                    // The stored drive may be complex if `step_complex` ran previously.
                    let (lag_re, lag_im) = Self::complex_mul(
                        beta_re,
                        beta_im,
                        self.prev_bx_re[mode],
                        self.prev_bx_im[mode],
                    );
                    (
                        carry_re + lag_re + gamma_re * drive,
                        carry_im + lag_im + gamma_im * drive,
                    )
                }
            };

            self.h[re_slot] = new_re;
            self.h[im_slot] = new_im;
            // The drive history is recorded for both methods; a real drive has
            // zero imaginary part.
            self.prev_bx_re[mode] = drive;
            self.prev_bx_im[mode] = 0.0;
            output += c[mode] * new_re;
        }
        output
    }

    /// Advances every mode by one step with complex input and output projections.
    ///
    /// * `delta` - step size handed to the discretization routine.
    /// * `b_re`, `b_im` - per-mode complex input weights; the drive is `b * x`.
    /// * `c_re`, `c_im` - per-mode output weights.
    /// * `x` - scalar input sample.
    /// * `lambda` - forwarded to `exp_trapezoidal_complex`; unused for `Tustin`.
    ///
    /// Returns the sum over modes of `c_re[n] * h_re + c_im[n] * h_im`, where
    /// `(h_re, h_im)` is the updated state of mode `n`.
    ///
    /// # Panics
    ///
    /// Debug builds assert that all four slices have exactly `n_state`
    /// elements. In any build, indexing panics if one is shorter than `n_state`.
    // Real/imaginary parts are passed as separate slices, hence the long signature.
    #[allow(clippy::too_many_arguments)]
    pub fn step_complex(
        &mut self,
        delta: f64,
        b_re: &[f64],
        b_im: &[f64],
        c_re: &[f64],
        c_im: &[f64],
        x: f64,
        lambda: f64,
    ) -> f64 {
        debug_assert_eq!(b_re.len(), self.n_state);
        debug_assert_eq!(b_im.len(), self.n_state);
        debug_assert_eq!(c_re.len(), self.n_state);
        debug_assert_eq!(c_im.len(), self.n_state);
        let mut output = 0.0;
        for mode in 0..self.n_state {
            let (re_slot, im_slot) = (2 * mode, 2 * mode + 1);
            let pole_re = -math::exp(self.log_a_complex[re_slot]);
            let pole_im = self.log_a_complex[im_slot];
            let drive_re = b_re[mode] * x;
            let drive_im = b_im[mode] * x;
            let (old_re, old_im) = (self.h[re_slot], self.h[im_slot]);

            let (new_re, new_im) = match self.method {
                DiscretizeMethod::Tustin => {
                    let (abar_re, abar_im, bfac_re, bfac_im) =
                        trapezoidal_complex(pole_re, pole_im, delta);
                    let (carry_re, carry_im) =
                        Self::complex_mul(abar_re, abar_im, old_re, old_im);
                    let (inject_re, inject_im) =
                        Self::complex_mul(bfac_re, bfac_im, drive_re, drive_im);
                    (carry_re + inject_re, carry_im + inject_im)
                }
                DiscretizeMethod::ExpTrapezoidal => {
                    let (alpha_re, alpha_im, beta_re, beta_im, gamma_re, gamma_im) =
                        exp_trapezoidal_complex(pole_re, pole_im, delta, lambda);
                    let (carry_re, carry_im) =
                        Self::complex_mul(alpha_re, alpha_im, old_re, old_im);
                    let (lag_re, lag_im) = Self::complex_mul(
                        beta_re,
                        beta_im,
                        self.prev_bx_re[mode],
                        self.prev_bx_im[mode],
                    );
                    let (inject_re, inject_im) =
                        Self::complex_mul(gamma_re, gamma_im, drive_re, drive_im);
                    (
                        carry_re + lag_re + inject_re,
                        carry_im + lag_im + inject_im,
                    )
                }
            };

            self.h[re_slot] = new_re;
            self.h[im_slot] = new_im;
            self.prev_bx_re[mode] = drive_re;
            self.prev_bx_im[mode] = drive_im;
            output += c_re[mode] * new_re + c_im[mode] * new_im;
        }
        output
    }

    /// Returns the modulus `sqrt(re^2 + im^2)` of each mode's state, in mode order.
    pub fn state_energies(&self) -> Vec<f64> {
        // `h` always holds exactly `2 * n_state` values, so each chunk is one mode.
        self.h
            .chunks_exact(2)
            .map(|pair| {
                let (re, im) = (pair[0], pair[1]);
                math::sqrt(re * re + im * im)
            })
            .collect()
    }

    /// Borrows the interleaved `[re, im, ...]` state buffer (length `2 * n_state`).
    #[inline]
    pub fn state(&self) -> &[f64] {
        self.h.as_slice()
    }

    /// Number of complex modes in the cell.
    #[inline]
    pub fn n_state(&self) -> usize {
        self.n_state
    }

    /// Discretization rule the cell was constructed with.
    #[inline]
    pub fn method(&self) -> DiscretizeMethod {
        self.method
    }

    /// Zeroes the hidden state and the stored previous input drive.
    pub fn reset(&mut self) {
        for buffer in [&mut self.h, &mut self.prev_bx_re, &mut self.prev_bx_im] {
            buffer.fill(0.0);
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Soak test: one million alternating-sign inputs must leave every output
    /// and every state component finite, and the state norm bounded.
    #[test]
    fn complex_diag_million_step_finite() {
        let mut cell = ComplexDiagonalSSM::new(8, DiscretizeMethod::ExpTrapezoidal);
        let b: Vec<f64> = (0..8).map(|n| 0.1 * (n as f64 + 1.0)).collect();
        let c: Vec<f64> = (0..8).map(|n| 0.1 * (n as f64 + 1.0)).collect();
        for step in 0..1_000_000u64 {
            let x = if step % 2 == 0 { 1.0 } else { -1.0 };
            let lambda = 0.5;
            let delta = 0.1;
            let y = cell.step(delta, &b, &c, x, lambda);
            assert!(
                y.is_finite(),
                "output must be finite at step {}: got {}",
                step,
                y
            );
        }
        for (n, &s) in cell.state().iter().enumerate() {
            assert!(
                s.is_finite(),
                "state[{}] must be finite after 10^6 steps: got {}",
                n,
                s
            );
        }
        let state_norm: f64 = cell.state().iter().map(|s| s * s).sum::<f64>().sqrt();
        assert!(
            state_norm < 1e6,
            "state Frobenius norm must be bounded after 10^6 steps: got {}",
            state_norm
        );
    }

    /// Constant input through the Tustin path stays finite.
    #[test]
    fn complex_diag_tustin_stable() {
        let mut cell = ComplexDiagonalSSM::new(4, DiscretizeMethod::Tustin);
        let b = vec![0.1; 4];
        let c = vec![0.1; 4];
        for step in 0..1000 {
            let y = cell.step(0.1, &b, &c, 1.0, 0.5);
            assert!(
                y.is_finite(),
                "Tustin output must be finite at step {}",
                step
            );
        }
        for &s in cell.state() {
            assert!(s.is_finite(), "Tustin state must remain finite");
        }
    }

    /// `reset` must clear the state and both halves of the stored input drive.
    #[test]
    fn complex_diag_reset_clears_state() {
        let mut cell = ComplexDiagonalSSM::new(4, DiscretizeMethod::Tustin);
        let b = vec![1.0; 4];
        let c = vec![1.0; 4];
        let _ = cell.step(0.1, &b, &c, 1.0, 0.5);
        let energy_before: f64 = cell.state().iter().map(|s| s * s).sum();
        assert!(energy_before > 0.0, "state must be non-zero after step");
        // A complex step is needed to populate `prev_bx_im`; the real-valued
        // `step` always stores zero there, which would make the check vacuous.
        let _ = cell.step_complex(0.1, &b, &b, &c, &c, 1.0, 0.5);
        assert!(
            cell.prev_bx_re.iter().any(|&v| v != 0.0),
            "prev_bx_re must be non-zero after a complex step"
        );
        assert!(
            cell.prev_bx_im.iter().any(|&v| v != 0.0),
            "prev_bx_im must be non-zero after a complex step"
        );
        cell.reset();
        for &s in cell.state() {
            assert!(s.abs() < 1e-15, "state must be zero after reset, got {}", s);
        }
        for &s in &cell.prev_bx_re {
            assert!(s.abs() < 1e-15, "prev_bx_re must be zero after reset");
        }
        for &s in &cell.prev_bx_im {
            assert!(s.abs() < 1e-15, "prev_bx_im must be zero after reset");
        }
    }

    /// A zero input applied to a zero state yields zero output.
    #[test]
    fn complex_diag_zero_input_zero_output_from_zero_state() {
        let mut cell = ComplexDiagonalSSM::new(4, DiscretizeMethod::ExpTrapezoidal);
        let b = vec![1.0; 4];
        let c = vec![1.0; 4];
        let y = cell.step(0.1, &b, &c, 0.0, 0.5);
        assert!(
            y.abs() < 1e-15,
            "zero input from zero state must give zero output, got {}",
            y
        );
    }

    /// `state_energies` returns one finite, non-negative value per mode.
    #[test]
    fn complex_diag_state_energies_bounded() {
        let mut cell = ComplexDiagonalSSM::new(8, DiscretizeMethod::ExpTrapezoidal);
        let b = vec![0.5; 8];
        let c = vec![0.5; 8];
        for _ in 0..1000 {
            let _ = cell.step(0.1, &b, &c, 1.0, 0.5);
        }
        let energies = cell.state_energies();
        assert_eq!(
            energies.len(),
            8,
            "state_energies must have n_state entries"
        );
        for (n, &e) in energies.iter().enumerate() {
            assert!(
                e.is_finite() && e >= 0.0,
                "energy[{}] must be finite non-negative, got {}",
                n,
                e
            );
        }
    }

    /// `with_init` derives the mode count from the parameter length.
    #[test]
    fn complex_diag_with_init_custom_params() {
        let log_a = vec![0.5, 1.0, 1.0, 2.0];
        let mut cell = ComplexDiagonalSSM::with_init(log_a, DiscretizeMethod::Tustin);
        assert_eq!(cell.n_state(), 2);
        let b = vec![0.5, 0.5];
        let c = vec![0.5, 0.5];
        let y = cell.step(0.1, &b, &c, 1.0, 0.5);
        assert!(y.is_finite());
    }

    /// An odd-length parameter vector cannot be split into (re, im) pairs.
    #[test]
    #[should_panic(expected = "even length")]
    fn complex_diag_with_init_rejects_odd_length() {
        let _ = ComplexDiagonalSSM::with_init(vec![0.5, 1.0, 1.0], DiscretizeMethod::Tustin);
    }

    /// The complex-projection path stays finite.
    #[test]
    fn complex_diag_step_complex_produces_finite_output() {
        let mut cell = ComplexDiagonalSSM::new(4, DiscretizeMethod::ExpTrapezoidal);
        let b_re = vec![0.3; 4];
        let b_im = vec![0.1; 4];
        let c_re = vec![0.3; 4];
        let c_im = vec![0.1; 4];
        for _ in 0..100 {
            let y = cell.step_complex(0.1, &b_re, &b_im, &c_re, &c_im, 1.0, 0.5);
            assert!(
                y.is_finite(),
                "complex step output must be finite: got {}",
                y
            );
        }
    }

    /// With zero imaginary projections, `step_complex` must reproduce `step`.
    #[test]
    fn complex_diag_step_matches_step_complex_for_real_inputs() {
        for &method in &[DiscretizeMethod::Tustin, DiscretizeMethod::ExpTrapezoidal] {
            let mut real_cell = ComplexDiagonalSSM::new(4, method);
            let mut complex_cell = ComplexDiagonalSSM::new(4, method);
            let b = vec![0.3; 4];
            let c = vec![0.7; 4];
            let zeros = vec![0.0; 4];
            for step in 0..50 {
                let x = if step % 3 == 0 { 1.0 } else { -0.5 };
                let y_real = real_cell.step(0.1, &b, &c, x, 0.5);
                let y_complex = complex_cell.step_complex(0.1, &b, &zeros, &c, &zeros, x, 0.5);
                assert!(
                    (y_real - y_complex).abs() < 1e-12,
                    "{:?}: step and step_complex diverged at step {}: {} vs {}",
                    method,
                    step,
                    y_real,
                    y_complex
                );
            }
            for (lhs, rhs) in real_cell.state().iter().zip(complex_cell.state()) {
                assert!(
                    (lhs - rhs).abs() < 1e-12,
                    "{:?}: states diverged: {} vs {}",
                    method,
                    lhs,
                    rhs
                );
            }
        }
    }

    /// For a very small step size the two discretizations give nearly equal output.
    #[test]
    fn complex_diag_exp_trap_and_tustin_agree_small_delta() {
        let n = 4;
        let log_a = vec![0.5, 0.3, 1.0, 0.5, 1.5, 0.8, 2.0, 1.0];
        let mut cell_et =
            ComplexDiagonalSSM::with_init(log_a.clone(), DiscretizeMethod::ExpTrapezoidal);
        let mut cell_tu = ComplexDiagonalSSM::with_init(log_a, DiscretizeMethod::Tustin);
        let b = vec![0.2_f64; n];
        let c = vec![0.2_f64; n];
        let delta = 0.0001;
        let lambda = 1.0;
        let mut y_et = 0.0;
        let mut y_tu = 0.0;
        for _ in 0..10 {
            y_et = cell_et.step(delta, &b, &c, 1.0, lambda);
            y_tu = cell_tu.step(delta, &b, &c, 1.0, 0.5);
        }
        assert!(
            (y_et - y_tu).abs() < 1e-4,
            "at small delta and lambda=1, exp-trap should approximate Tustin: et={}, tu={}",
            y_et,
            y_tu
        );
    }

    /// Accessors report the configuration the cell was built with.
    #[test]
    fn complex_diag_n_state_accessor() {
        let cell = ComplexDiagonalSSM::new(16, DiscretizeMethod::Tustin);
        assert_eq!(cell.n_state(), 16);
        assert_eq!(cell.state().len(), 32);
        assert_eq!(cell.method(), DiscretizeMethod::Tustin);
    }
}