use num_complex::Complex;
use num_traits::Float;
pub fn complex_exp<T: Float>(z: Complex<T>) -> Complex<T> {
z.exp()
}
pub fn diag_complex_matrix_exp<T: Float>(lambdas: &[Complex<T>]) -> Vec<Complex<T>> {
lambdas.iter().map(|&z| z.exp()).collect()
}
/// Real output `2 · Re(Σ_k c_k · h_k)` from one member of each conjugate pair.
///
/// Each stored state `h_k` and coefficient `c_k` stands for a pair (`h_k`, `conj(h_k)`)
/// with coefficients (`c_k`, `conj(c_k)`). The pair's combined contribution
/// `c_k·h_k + conj(c_k)·conj(h_k)` equals `2·Re(c_k·h_k)`, so only the real part of
/// each product is accumulated before doubling. Empty slices produce zero.
///
/// # Panics
///
/// Panics if `states` and `c_coeffs` have different lengths.
pub fn conjugate_pair_output<T: Float>(states: &[Complex<T>], c_coeffs: &[Complex<T>]) -> T {
    assert_eq!(
        states.len(),
        c_coeffs.len(),
        "states and c_coeffs must be the same length"
    );

    let mut real_sum = T::zero();
    for (state, coeff) in states.iter().copied().zip(c_coeffs.iter().copied()) {
        real_sum = real_sum + (coeff * state).re;
    }

    // Doubling accounts for the implicit conjugate partner of every stored term.
    let two = T::one() + T::one();
    two * real_sum
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::PI;

    /// e^{iπ} = -1 (Euler's identity), checked to within floating-point tolerance.
    #[test]
    fn euler_identity() {
        let result = complex_exp(Complex::new(0.0_f64, PI));
        assert!((result.re + 1.0).abs() < 1e-10, "re={}", result.re);
        assert!(result.im.abs() < 1e-10, "im={}", result.im);
    }

    /// The diagonal exponential must match the scalar exponential entry by entry,
    /// and must return exactly one entry per input (zip alone would hide truncation).
    #[test]
    fn diag_matrix_exp_matches_elementwise() {
        let lambdas = [
            Complex::new(0.0_f64, PI),
            Complex::new(-1.0, 0.5),
            Complex::new(0.0, 0.0),
        ];
        let result = diag_complex_matrix_exp(&lambdas);
        assert_eq!(result.len(), lambdas.len());
        for (&l, r) in lambdas.iter().zip(result.iter()) {
            let expected = complex_exp(l);
            assert!((r.re - expected.re).abs() < 1e-12);
            assert!((r.im - expected.im).abs() < 1e-12);
        }
    }

    #[test]
    fn diag_matrix_exp_empty() {
        assert!(diag_complex_matrix_exp::<f64>(&[]).is_empty());
    }

    #[test]
    fn conjugate_pair_output_matches_brute_force() {
        let states = [Complex::new(1.0_f64, 0.5), Complex::new(-0.3, 0.7)];
        let c = [Complex::new(2.0_f64, -1.0), Complex::new(0.5, 0.5)];
        let y = conjugate_pair_output(&states, &c);
        let dot = c[0] * states[0] + c[1] * states[1];
        assert!((y - 2.0 * dot.re).abs() < 1e-12, "y={y}, expected={}", 2.0 * dot.re);
    }

    /// Adding each term's explicit conjugate partner must give a purely real sum that
    /// equals the reduced output; this pins the "conjugate pair" contract itself.
    #[test]
    fn conjugate_pair_output_equals_explicit_pair_sum() {
        let states = [Complex::new(0.25_f64, -1.5), Complex::new(2.0, 0.125)];
        let c = [Complex::new(-0.75_f64, 0.5), Complex::new(1.0, 3.0)];
        let y = conjugate_pair_output(&states, &c);
        let full = states
            .iter()
            .zip(c.iter())
            .fold(Complex::new(0.0_f64, 0.0), |acc, (&h, &ck)| {
                acc + ck * h + ck.conj() * h.conj()
            });
        assert!(full.im.abs() < 1e-12, "im={}", full.im);
        assert!((y - full.re).abs() < 1e-12, "y={y}, expected={}", full.re);
    }

    #[test]
    fn conjugate_pair_output_empty_is_zero() {
        assert_eq!(conjugate_pair_output::<f64>(&[], &[]), 0.0);
    }

    /// Mismatched lengths violate the documented precondition and must panic
    /// with the function's message rather than silently truncating.
    #[test]
    #[should_panic(expected = "states and c_coeffs must be the same length")]
    fn conjugate_pair_output_rejects_length_mismatch() {
        let states = [Complex::new(1.0_f64, 0.0)];
        let _ = conjugate_pair_output(&states, &[]);
    }
}