zilla_muf/complex_ops.rs
1// complex_ops.rs — complex-arithmetic helpers for S4-style models, which
2// diagonalize the state matrix into complex eigenvalues. Kept as a thin,
3// stable layer over `num-complex` so the SSM code has one canonical
4// spelling for each operation even if the backing implementation changes.
5use num_complex::Complex;
6use num_traits::Float;
7
8/// Complex exponential `e^z` for a complex argument.
9///
10/// Thin wrapper over `num_complex`'s own `exp`, exposed here so callers
11/// have a single, stable name for "exponentiate a complex eigenvalue"
12/// (used when discretizing or generating kernels for diagonalized state
13/// matrices). For `z = x + iy` this returns `e^x · (cos y + i·sin y)`.
14///
15/// # Example
16///
17/// ```
18/// use num_complex::Complex;
19/// use zilla_muf::complex_ops::complex_exp;
20/// use std::f64::consts::PI;
21/// // Euler's identity: e^(iπ) = -1
22/// let result = complex_exp(Complex::new(0.0_f64, PI));
23/// assert!((result.re + 1.0).abs() < 1e-10);
24/// assert!(result.im.abs() < 1e-10);
25/// ```
26pub fn complex_exp<T: Float>(z: Complex<T>) -> Complex<T> {
27 z.exp()
28}
29
30/// Element-wise exponential of a diagonal complex matrix's eigenvalues.
31///
32/// For a diagonal matrix `A = diag(λ_0, …, λ_{n-1})`, returns
33/// `[exp(λ_0), …, exp(λ_{n-1})]` — the diagonal of `exp(A)`.
34/// This is the S4D ZOH discretization step: given continuous-time
35/// eigenvalues `Λ`, the discrete-time transition is `diag(exp(Λ · Δt))`.
36///
37/// # Example
38///
39/// ```
40/// use num_complex::Complex;
41/// use zilla_muf::complex_ops::{complex_exp, diag_complex_matrix_exp};
42/// let lambdas = vec![Complex::new(0.0_f64, 1.0), Complex::new(-1.0_f64, 0.0)];
43/// let result = diag_complex_matrix_exp(&lambdas);
44/// for (r, &l) in result.iter().zip(lambdas.iter()) {
45/// let expected = complex_exp(l);
46/// assert!((r.re - expected.re).abs() < 1e-12);
47/// assert!((r.im - expected.im).abs() < 1e-12);
48/// }
49/// ```
50pub fn diag_complex_matrix_exp<T: Float>(lambdas: &[Complex<T>]) -> Vec<Complex<T>> {
51 lambdas.iter().map(|&z| z.exp()).collect()
52}
53
54/// Real output from a conjugate-pair-folded complex state.
55///
56/// In S4D and related models the state vector is stored as conjugate pairs
57/// `(h_i, conj(h_i))`. The real output projection then reduces to
58/// `2 · Re(Σ c_i · h_i)`, which avoids keeping the redundant conjugate
59/// half explicitly.
60///
61/// `states` and `c_coeffs` must have the same length; both hold one
62/// element per conjugate pair (the upper half).
63///
64/// # Example
65///
66/// ```
67/// use num_complex::Complex;
68/// use zilla_muf::complex_ops::conjugate_pair_output;
69/// let states = vec![Complex::new(1.0_f64, 0.5), Complex::new(0.0_f64, -1.0)];
70/// let c = vec![Complex::new(1.0_f64, 0.0), Complex::new(0.0_f64, 1.0)];
71/// let y = conjugate_pair_output(&states, &c);
72/// // brute-force: 2 * Re(c[0]*h[0] + c[1]*h[1])
73/// let dot = c[0] * states[0] + c[1] * states[1];
74/// assert!((y - 2.0 * dot.re).abs() < 1e-12);
75/// ```
76pub fn conjugate_pair_output<T: Float>(states: &[Complex<T>], c_coeffs: &[Complex<T>]) -> T {
77 assert_eq!(states.len(), c_coeffs.len(), "states and c_coeffs must be the same length");
78 let two = T::one() + T::one();
79 let re_sum = states
80 .iter()
81 .zip(c_coeffs.iter())
82 .fold(T::zero(), |acc, (&h, &c)| acc + (c * h).re);
83 two * re_sum
84}
85
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::PI;

    #[test]
    fn euler_identity() {
        // e^(iπ) = -1 + 0i
        let result = complex_exp(Complex::new(0.0_f64, PI));
        assert!((result.re + 1.0).abs() < 1e-10, "re={}", result.re);
        assert!(result.im.abs() < 1e-10, "im={}", result.im);
    }

    #[test]
    fn diag_matrix_exp_matches_elementwise() {
        let lambdas = vec![
            Complex::new(0.0_f64, PI),
            Complex::new(-1.0, 0.5),
            Complex::new(0.0, 0.0),
        ];
        let result = diag_complex_matrix_exp(&lambdas);
        assert_eq!(result.len(), lambdas.len());
        for (&l, r) in lambdas.iter().zip(result.iter()) {
            let expected = complex_exp(l);
            assert!((r.re - expected.re).abs() < 1e-12);
            assert!((r.im - expected.im).abs() < 1e-12);
        }
    }

    #[test]
    fn diag_matrix_exp_empty() {
        assert!(diag_complex_matrix_exp::<f64>(&[]).is_empty());
    }

    #[test]
    fn conjugate_pair_output_matches_brute_force() {
        let states = vec![
            Complex::new(1.0_f64, 0.5),
            Complex::new(-0.3, 0.7),
        ];
        let c = vec![
            Complex::new(2.0_f64, -1.0),
            Complex::new(0.5, 0.5),
        ];
        let y = conjugate_pair_output(&states, &c);
        let dot = c[0] * states[0] + c[1] * states[1];
        assert!((y - 2.0 * dot.re).abs() < 1e-12, "y={y}, expected={}", 2.0 * dot.re);
    }

    #[test]
    fn conjugate_pair_output_matches_unfolded_sum() {
        // Validate the folding identity itself: expand every stored pair into
        // both (c_i, h_i) and (conj(c_i), conj(h_i)), sum all 2N products, and
        // compare against the folded 2·Re(Σ c_i·h_i) result.
        let states = vec![
            Complex::new(1.0_f64, 0.5),
            Complex::new(-0.3, 0.7),
        ];
        let c = vec![
            Complex::new(2.0_f64, -1.0),
            Complex::new(0.5, 0.5),
        ];
        let full = states
            .iter()
            .zip(c.iter())
            .fold(Complex::new(0.0_f64, 0.0), |acc, (&h, &ci)| {
                acc + ci * h + ci.conj() * h.conj()
            });
        // The conjugate half cancels the imaginary part, so the output is real.
        assert!(full.im.abs() < 1e-12, "im={}", full.im);
        let y = conjugate_pair_output(&states, &c);
        assert!((y - full.re).abs() < 1e-12, "y={y}, expected={}", full.re);
    }

    #[test]
    fn conjugate_pair_output_empty_is_zero() {
        assert_eq!(conjugate_pair_output::<f64>(&[], &[]), 0.0);
    }

    #[test]
    #[should_panic(expected = "states and c_coeffs must be the same length")]
    fn conjugate_pair_output_rejects_length_mismatch() {
        let states = [Complex::new(1.0_f64, 0.0)];
        let _ = conjugate_pair_output(&states, &[]);
    }
}