zilla-muf 0.1.1

Shared structured-matrix and numerical primitives for sparse attention and state space models (SSMs).
Documentation
// complex_ops.rs — complex-arithmetic helpers for S4-style models, which
// diagonalize the state matrix into complex eigenvalues. Kept as a thin,
// stable layer over `num-complex` so the SSM code has one canonical
// spelling for each operation even if the backing implementation changes.
use num_complex::Complex;
use num_traits::Float;

/// Complex exponential `e^z` for a complex argument.
///
/// Thin wrapper over `num_complex`'s own `exp`, exposed here so callers
/// have a single, stable name for "exponentiate a complex eigenvalue"
/// (used when discretizing or generating kernels for diagonalized state
/// matrices). For `z = x + iy` this returns `e^x · (cos y + i·sin y)`.
///
/// # Example
///
/// ```
/// use num_complex::Complex;
/// use zilla_muf::complex_ops::complex_exp;
/// use std::f64::consts::PI;
/// // Euler's identity: e^(iπ) = -1
/// let result = complex_exp(Complex::new(0.0_f64, PI));
/// assert!((result.re + 1.0).abs() < 1e-10);
/// assert!(result.im.abs() < 1e-10);
/// ```
pub fn complex_exp<T: Float>(z: Complex<T>) -> Complex<T> {
	z.exp()
}

/// Element-wise exponential of a diagonal complex matrix's eigenvalues.
///
/// For a diagonal matrix `A = diag(λ_0, …, λ_{n-1})`, returns
/// `[exp(λ_0), …, exp(λ_{n-1})]` — the diagonal of `exp(A)`.
/// This is the S4D ZOH discretization step: given continuous-time
/// eigenvalues `Λ`, the discrete-time transition is `diag(exp(Λ · Δt))`.
///
/// # Example
///
/// ```
/// use num_complex::Complex;
/// use zilla_muf::complex_ops::{complex_exp, diag_complex_matrix_exp};
/// let lambdas = vec![Complex::new(0.0_f64, 1.0), Complex::new(-1.0_f64, 0.0)];
/// let result = diag_complex_matrix_exp(&lambdas);
/// for (r, &l) in result.iter().zip(lambdas.iter()) {
///     let expected = complex_exp(l);
///     assert!((r.re - expected.re).abs() < 1e-12);
///     assert!((r.im - expected.im).abs() < 1e-12);
/// }
/// ```
pub fn diag_complex_matrix_exp<T: Float>(lambdas: &[Complex<T>]) -> Vec<Complex<T>> {
	lambdas.iter().map(|&z| z.exp()).collect()
}

/// Projects a conjugate-pair-folded complex state onto a real output.
///
/// S4D and related models keep only one member `h_i` of each conjugate pair
/// `(h_i, conj(h_i))`, with output coefficients paired the same way
/// `(c_i, conj(c_i))`. Because `c·h + conj(c)·conj(h) = 2·Re(c·h)`, the real
/// output of the full, unfolded state is `2 · Re(Σ c_i · h_i)`. The redundant
/// conjugate half therefore never has to be stored explicitly.
///
/// # Arguments
///
/// * `states` — one state `h_i` per conjugate pair (the upper half).
/// * `c_coeffs` — the matching output coefficient `c_i` for each pair.
///
/// Empty inputs produce zero.
///
/// # Panics
///
/// Panics if `states` and `c_coeffs` have different lengths.
///
/// # Example
///
/// ```
/// use num_complex::Complex;
/// use zilla_muf::complex_ops::conjugate_pair_output;
/// let states = vec![Complex::new(1.0_f64, 0.5), Complex::new(0.0_f64, -1.0)];
/// let c = vec![Complex::new(1.0_f64, 0.0), Complex::new(0.0_f64, 1.0)];
/// let y = conjugate_pair_output(&states, &c);
/// // brute-force: 2 * Re(c[0]*h[0] + c[1]*h[1])
/// let dot = c[0] * states[0] + c[1] * states[1];
/// assert!((y - 2.0 * dot.re).abs() < 1e-12);
/// ```
pub fn conjugate_pair_output<T: Float>(states: &[Complex<T>], c_coeffs: &[Complex<T>]) -> T {
	assert_eq!(states.len(), c_coeffs.len(), "states and c_coeffs must be the same length");
	// Sum Re(c_i · h_i) left-to-right, then double once at the end.
	let real_part = states
		.iter()
		.zip(c_coeffs)
		.map(|(&h, &c)| (c * h).re)
		.fold(T::zero(), |sum, term| sum + term);
	let doubling = T::one() + T::one();
	doubling * real_part
}

#[cfg(test)]
mod tests {
	use super::*;
	use std::f64::consts::PI;

	#[test]
	fn euler_identity() {
		// e^(iπ) = -1 + 0i
		let result = complex_exp(Complex::new(0.0_f64, PI));
		assert!((result.re + 1.0).abs() < 1e-10, "re={}", result.re);
		assert!(result.im.abs() < 1e-10, "im={}", result.im);
	}

	#[test]
	fn diag_matrix_exp_matches_elementwise() {
		let lambdas = vec![
			Complex::new(0.0_f64, PI),
			Complex::new(-1.0, 0.5),
			Complex::new(0.0, 0.0),
		];
		let result = diag_complex_matrix_exp(&lambdas);
		// `zip` below silently truncates, so pin the length explicitly;
		// otherwise a short or empty result would pass vacuously.
		assert_eq!(result.len(), lambdas.len());
		for (&l, r) in lambdas.iter().zip(result.iter()) {
			let expected = complex_exp(l);
			assert!((r.re - expected.re).abs() < 1e-12);
			assert!((r.im - expected.im).abs() < 1e-12);
		}
	}

	#[test]
	fn diag_matrix_exp_empty() {
		assert!(diag_complex_matrix_exp::<f64>(&[]).is_empty());
	}

	#[test]
	fn conjugate_pair_output_matches_brute_force() {
		let states = vec![
			Complex::new(1.0_f64, 0.5),
			Complex::new(-0.3, 0.7),
		];
		let c = vec![
			Complex::new(2.0_f64, -1.0),
			Complex::new(0.5, 0.5),
		];
		let y = conjugate_pair_output(&states, &c);
		let dot = c[0] * states[0] + c[1] * states[1];
		assert!((y - 2.0 * dot.re).abs() < 1e-12, "y={y}, expected={}", 2.0 * dot.re);
	}

	#[test]
	fn conjugate_pair_output_matches_unfolded_state() {
		// Materialize both halves of every conjugate pair and sum them
		// explicitly; the imaginary parts must cancel and the real part
		// must equal the folded result.
		let states = vec![
			Complex::new(0.8_f64, -0.2),
			Complex::new(-1.5, 0.4),
			Complex::new(0.0, 2.0),
		];
		let c = vec![
			Complex::new(-0.7_f64, 1.1),
			Complex::new(0.25, 0.0),
			Complex::new(1.0, -3.0),
		];
		let unfolded = states
			.iter()
			.zip(c.iter())
			.fold(Complex::new(0.0_f64, 0.0), |acc, (&h, &ci)| {
				acc + ci * h + ci.conj() * h.conj()
			});
		let y = conjugate_pair_output(&states, &c);
		assert!(unfolded.im.abs() < 1e-12, "im={}", unfolded.im);
		assert!((y - unfolded.re).abs() < 1e-12, "y={y}, expected={}", unfolded.re);
	}

	#[test]
	fn conjugate_pair_output_empty_is_zero() {
		assert_eq!(conjugate_pair_output::<f64>(&[], &[]), 0.0);
	}

	#[test]
	#[should_panic(expected = "states and c_coeffs must be the same length")]
	fn conjugate_pair_output_rejects_length_mismatch() {
		let states = vec![Complex::new(1.0_f64, 0.0)];
		let c: Vec<Complex<f64>> = Vec::new();
		conjugate_pair_output(&states, &c);
	}
}