use crate::accumulator::BinnedAccumulatorF64;
/// A complex number `re + im·i` with `f64` components.
///
/// Multiplication and division are written as plain IEEE `*`, `+`, `-`, `/`
/// operations in a fixed order (Rust never contracts these into fused
/// multiply-adds), so a given pair of inputs always produces the same bits.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ComplexF64 {
    /// Real component.
    pub re: f64,
    /// Imaginary component.
    pub im: f64,
}

impl ComplexF64 {
    /// The additive identity `0 + 0i`.
    pub const ZERO: ComplexF64 = ComplexF64 { re: 0.0, im: 0.0 };
    /// The multiplicative identity `1 + 0i`.
    pub const ONE: ComplexF64 = ComplexF64 { re: 1.0, im: 0.0 };
    /// The imaginary unit `0 + 1i`.
    pub const I: ComplexF64 = ComplexF64 { re: 0.0, im: 1.0 };

    /// Builds `re + im·i`.
    #[inline]
    pub fn new(re: f64, im: f64) -> Self {
        Self { re, im }
    }

    /// Builds the purely real value `re + 0i`.
    #[inline]
    pub fn real(re: f64) -> Self {
        Self::new(re, 0.0)
    }

    /// Builds the purely imaginary value `0 + im·i`.
    #[inline]
    pub fn imag(im: f64) -> Self {
        Self::new(0.0, im)
    }

    /// Squared magnitude `re² + im²` (no square root taken).
    #[inline]
    pub fn norm_sq(self) -> f64 {
        self.re * self.re + self.im * self.im
    }

    /// Magnitude `|z| = sqrt(re² + im²)`.
    ///
    /// The squares are formed directly, so components above roughly `1e154`
    /// overflow to infinity even when the true magnitude is representable.
    #[inline]
    pub fn abs(self) -> f64 {
        f64::sqrt(self.norm_sq())
    }

    /// Complex conjugate `re - im·i` (a `+0.0` imaginary part becomes `-0.0`).
    #[inline]
    pub fn conj(self) -> Self {
        Self::new(self.re, -self.im)
    }

    /// Product `(a + bi)(c + di) = (ac - bd) + (ad + bc)i`, evaluated in a fixed order.
    #[inline]
    pub fn mul_fixed(self, rhs: Self) -> Self {
        let Self { re: a, im: b } = self;
        let Self { re: c, im: d } = rhs;
        Self::new(a * c - b * d, a * d + b * c)
    }

    /// Component-wise sum.
    #[inline]
    pub fn add(self, rhs: Self) -> Self {
        Self::new(self.re + rhs.re, self.im + rhs.im)
    }

    /// Component-wise difference `self - rhs`.
    #[inline]
    pub fn sub(self, rhs: Self) -> Self {
        Self::new(self.re - rhs.re, self.im - rhs.im)
    }

    /// Negates both components.
    #[inline]
    pub fn neg(self) -> Self {
        Self::new(-self.re, -self.im)
    }

    /// Quotient via the textbook formula
    /// `((ac + bd) + (bc - ad)i) / (c² + d²)`, evaluated in a fixed order.
    ///
    /// No scaling is applied: dividing by zero yields NaN/infinite parts, and
    /// very large or very small divisors can overflow/underflow `c² + d²`.
    #[inline]
    pub fn div_fixed(self, rhs: Self) -> Self {
        let Self { re: a, im: b } = self;
        let Self { re: c, im: d } = rhs;
        let denom = c * c + d * d;
        Self::new((a * c + b * d) / denom, (b * c - a * d) / denom)
    }

    /// Multiplies both components by the real scalar `s`.
    #[inline]
    pub fn scale(self, s: f64) -> Self {
        Self::new(s * self.re, s * self.im)
    }

    /// True if either component is NaN.
    #[inline]
    pub fn is_nan(self) -> bool {
        [self.re, self.im].iter().any(|part| part.is_nan())
    }

    /// True only if both components are finite (neither NaN nor infinite).
    #[inline]
    pub fn is_finite(self) -> bool {
        [self.re, self.im].iter().all(|part| part.is_finite())
    }
}
impl std::fmt::Display for ComplexF64 {
    /// Renders the value as `re±imi`, e.g. `3-4i` or `1+2i`.
    ///
    /// The sign between the parts comes from the sign bit of `im`. That way
    /// `-0.0` renders as `3-0i` instead of the malformed `3+-0i`, and a NaN
    /// imaginary part renders as `+NaN` instead of losing its separator.
    /// A precision (`{:.3}`) is applied to both components. Width and fill
    /// flags are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // NaN has no meaningful sign, so always join it with '+'.
        let sign = if self.im.is_sign_negative() && !self.im.is_nan() {
            '-'
        } else {
            '+'
        };
        // The sign is written explicitly, so print only the magnitude of `im`.
        let im_mag = self.im.abs();
        match f.precision() {
            Some(p) => write!(f, "{:.*}{}{:.*}i", p, self.re, sign, p, im_mag),
            None => write!(f, "{}{}{}i", self.re, sign, im_mag),
        }
    }
}
/// Conjugated dot product `Σ a[i] · conj(b[i])`.
///
/// The conjugate is applied to the second argument `b`, not to `a`. The real
/// and imaginary parts of each term go into separate `BinnedAccumulatorF64`s,
/// which are finalized independently.
///
/// # Panics
///
/// Panics if `a.len() != b.len()`.
pub fn complex_dot(a: &[ComplexF64], b: &[ComplexF64]) -> ComplexF64 {
    // Enforced in release builds as well. With only a debug assertion, a longer
    // `b` had its tail silently ignored and a shorter one hit an opaque index panic.
    assert_eq!(a.len(), b.len(), "complex_dot: slices must have equal length");
    let mut re_acc = BinnedAccumulatorF64::new();
    let mut im_acc = BinnedAccumulatorF64::new();
    for (&x, &y) in a.iter().zip(b) {
        let z = x.mul_fixed(y.conj());
        re_acc.add(z.re);
        im_acc.add(z.im);
    }
    ComplexF64 {
        re: re_acc.finalize(),
        im: im_acc.finalize(),
    }
}
pub fn complex_sum(values: &[ComplexF64]) -> ComplexF64 {
let mut re_acc = BinnedAccumulatorF64::new();
let mut im_acc = BinnedAccumulatorF64::new();
for &z in values {
re_acc.add(z.re);
im_acc.add(z.im);
}
ComplexF64 {
re: re_acc.finalize(),
im: im_acc.finalize(),
}
}
/// Dense row-major complex matrix product `out = a · b`.
///
/// `a` is `m×k`, `b` is `k×n` and `out` is `m×n`. All are stored row-major:
/// element `(r, c)` of an `R×C` matrix lives at index `r * C + c`. Every
/// element of `out` is overwritten. Each output entry sums its `k` products
/// in increasing `p`, with separate `BinnedAccumulatorF64`s for the real and
/// imaginary parts.
///
/// # Panics
///
/// Panics if any slice length disagrees with `m`, `k`, `n`.
pub fn complex_matmul(
    a: &[ComplexF64], b: &[ComplexF64], out: &mut [ComplexF64],
    m: usize, k: usize, n: usize,
) {
    // Enforced in release builds as well. With only debug assertions, oversized
    // inputs were silently accepted and an oversized `out` was only partially written.
    assert_eq!(a.len(), m * k, "complex_matmul: a must hold m*k elements");
    assert_eq!(b.len(), k * n, "complex_matmul: b must hold k*n elements");
    assert_eq!(out.len(), m * n, "complex_matmul: out must hold m*n elements");
    for i in 0..m {
        // Row i of `a` is reused for every output column; slice it once.
        let a_row = &a[i * k..(i + 1) * k];
        for j in 0..n {
            let mut re_acc = BinnedAccumulatorF64::new();
            let mut im_acc = BinnedAccumulatorF64::new();
            for (p, &a_ip) in a_row.iter().enumerate() {
                let prod = a_ip.mul_fixed(b[p * n + j]);
                re_acc.add(prod.re);
                im_acc.add(prod.im);
            }
            out[i * n + j] = ComplexF64 {
                re: re_acc.finalize(),
                im: im_acc.finalize(),
            };
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Feeds one chunk into fresh (real, imaginary) accumulators.
    fn accumulate_chunk(chunk: &[ComplexF64]) -> (BinnedAccumulatorF64, BinnedAccumulatorF64) {
        let mut re = BinnedAccumulatorF64::new();
        let mut im = BinnedAccumulatorF64::new();
        for z in chunk {
            re.add(z.re);
            im.add(z.im);
        }
        (re, im)
    }

    #[test]
    fn test_complex_mul_basic() {
        let a = ComplexF64::new(1.0, 2.0);
        let b = ComplexF64::new(3.0, 4.0);
        let c = a.mul_fixed(b);
        assert_eq!(c.re, -5.0);
        assert_eq!(c.im, 10.0);
    }

    #[test]
    fn test_complex_mul_commutative() {
        let a = ComplexF64::new(1.23456789, -9.87654321);
        let b = ComplexF64::new(-3.14159265, 2.71828183);
        let ab = a.mul_fixed(b);
        let ba = b.mul_fixed(a);
        assert_eq!(ab.re.to_bits(), ba.re.to_bits());
        assert_eq!(ab.im.to_bits(), ba.im.to_bits());
    }

    #[test]
    fn test_complex_mul_identity() {
        let a = ComplexF64::new(7.0, -3.0);
        let one = ComplexF64::ONE;
        let result = a.mul_fixed(one);
        assert_eq!(result.re, a.re);
        assert_eq!(result.im, a.im);
    }

    #[test]
    fn test_complex_mul_i_squared() {
        let i = ComplexF64::I;
        let result = i.mul_fixed(i);
        assert_eq!(result.re, -1.0);
        assert_eq!(result.im, 0.0);
    }

    #[test]
    fn test_complex_conj() {
        let z = ComplexF64::new(3.0, 4.0);
        let c = z.conj();
        assert_eq!(c.re, 3.0);
        assert_eq!(c.im, -4.0);
    }

    #[test]
    fn test_complex_abs() {
        let z = ComplexF64::new(3.0, 4.0);
        assert_eq!(z.abs(), 5.0);
    }

    #[test]
    fn test_complex_dot_basic() {
        let a = vec![ComplexF64::new(1.0, 0.0), ComplexF64::new(0.0, 1.0)];
        let b = vec![ComplexF64::new(1.0, 0.0), ComplexF64::new(0.0, 1.0)];
        let result = complex_dot(&a, &b);
        assert_eq!(result.re, 2.0);
        assert_eq!(result.im, 0.0);
    }

    #[test]
    fn test_complex_dot_deterministic() {
        let n = 500;
        let a: Vec<ComplexF64> = (0..n)
            .map(|i| ComplexF64::new(i as f64 * 0.001, -(i as f64 * 0.002)))
            .collect();
        let b: Vec<ComplexF64> = (0..n)
            .map(|i| ComplexF64::new((n - i) as f64 * 0.003, i as f64 * 0.004))
            .collect();
        let r1 = complex_dot(&a, &b);
        let r2 = complex_dot(&a, &b);
        assert_eq!(r1.re.to_bits(), r2.re.to_bits());
        assert_eq!(r1.im.to_bits(), r2.im.to_bits());
    }

    #[test]
    fn test_complex_sum_deterministic() {
        let values: Vec<ComplexF64> = (0..1000)
            .map(|i| ComplexF64::new(i as f64 * 0.7 - 350.0, -(i as f64) * 0.3 + 150.0))
            .collect();
        let r1 = complex_sum(&values);
        let r2 = complex_sum(&values);
        assert_eq!(r1.re.to_bits(), r2.re.to_bits());
        assert_eq!(r1.im.to_bits(), r2.im.to_bits());
    }

    #[test]
    fn test_complex_sum_near_order_invariant() {
        let values: Vec<ComplexF64> = (0..100)
            .map(|i| ComplexF64::new(i as f64 * 1.1 - 50.0, -(i as f64) * 0.9 + 45.0))
            .collect();
        let mut reversed = values.clone();
        reversed.reverse();
        let r1 = complex_sum(&values);
        let r2 = complex_sum(&reversed);
        // Both sums are positive here, so the bit-pattern difference is the ULP distance.
        let re_ulps = (r1.re.to_bits() as i64 - r2.re.to_bits() as i64).unsigned_abs();
        let im_ulps = (r1.im.to_bits() as i64 - r2.im.to_bits() as i64).unsigned_abs();
        assert!(re_ulps < 10, "Real parts near-order-invariant: {re_ulps} ULPs");
        assert!(im_ulps < 10, "Imaginary parts near-order-invariant: {im_ulps} ULPs");
    }

    #[test]
    fn test_complex_sum_merge_order_invariant() {
        let values: Vec<ComplexF64> = (0..100)
            .map(|i| ComplexF64::new(i as f64 * 1.1 - 50.0, -(i as f64) * 0.9 + 45.0))
            .collect();

        let mut re_fwd = BinnedAccumulatorF64::new();
        let mut im_fwd = BinnedAccumulatorF64::new();
        for chunk in values.chunks(10) {
            let (re_c, im_c) = accumulate_chunk(chunk);
            re_fwd.merge(&re_c);
            im_fwd.merge(&im_c);
        }

        // Same chunks, merged last-to-first.
        let mut re_rev = BinnedAccumulatorF64::new();
        let mut im_rev = BinnedAccumulatorF64::new();
        for chunk in values.chunks(10).rev() {
            let (re_c, im_c) = accumulate_chunk(chunk);
            re_rev.merge(&re_c);
            im_rev.merge(&im_c);
        }

        assert_eq!(re_fwd.finalize().to_bits(), re_rev.finalize().to_bits(),
            "Complex real merge must be order-invariant");
        assert_eq!(im_fwd.finalize().to_bits(), im_rev.finalize().to_bits(),
            "Complex imaginary merge must be order-invariant");
    }

    #[test]
    fn test_complex_matmul_identity() {
        let identity = vec![
            ComplexF64::ONE, ComplexF64::ZERO,
            ComplexF64::ZERO, ComplexF64::ONE,
        ];
        let b = vec![
            ComplexF64::new(1.0, 2.0), ComplexF64::new(3.0, 4.0),
            ComplexF64::new(5.0, 6.0), ComplexF64::new(7.0, 8.0),
        ];
        let mut out = vec![ComplexF64::ZERO; 4];
        complex_matmul(&identity, &b, &mut out, 2, 2, 2);
        for (i, &v) in out.iter().enumerate() {
            assert_eq!(v.re, b[i].re);
            assert_eq!(v.im, b[i].im);
        }
    }

    #[test]
    fn test_complex_matmul_rectangular() {
        // (1x2) · (2x3): with m, k, n all distinct, a mix-up between the k and n
        // strides in the row-major indexing would be caught here.
        let a = vec![ComplexF64::ONE, ComplexF64::I];
        let b = vec![
            ComplexF64::new(1.0, 0.0), ComplexF64::new(2.0, 0.0), ComplexF64::new(3.0, 0.0),
            ComplexF64::I, ComplexF64::ZERO, ComplexF64::new(0.0, -1.0),
        ];
        let mut out = vec![ComplexF64::ZERO; 3];
        complex_matmul(&a, &b, &mut out, 1, 2, 3);
        // 1·1 + i·i = 0; 1·2 + i·0 = 2; 1·3 + i·(-i) = 4.
        let expected = [0.0, 2.0, 4.0];
        for (v, &want) in out.iter().zip(expected.iter()) {
            assert_eq!(v.re, want);
            assert_eq!(v.im, 0.0);
        }
    }

    #[test]
    fn test_complex_matmul_deterministic() {
        let a: Vec<ComplexF64> = (0..9)
            .map(|i| ComplexF64::new(i as f64 * 0.3, -(i as f64) * 0.2))
            .collect();
        let b: Vec<ComplexF64> = (0..9)
            .map(|i| ComplexF64::new(-(i as f64) * 0.1, i as f64 * 0.4))
            .collect();
        let mut out1 = vec![ComplexF64::ZERO; 9];
        let mut out2 = vec![ComplexF64::ZERO; 9];
        complex_matmul(&a, &b, &mut out1, 3, 3, 3);
        complex_matmul(&a, &b, &mut out2, 3, 3, 3);
        for i in 0..9 {
            assert_eq!(out1[i].re.to_bits(), out2[i].re.to_bits());
            assert_eq!(out1[i].im.to_bits(), out2[i].im.to_bits());
        }
    }

    #[test]
    fn test_complex_div_basic() {
        let a = ComplexF64::new(1.0, 2.0);
        let one = ComplexF64::new(1.0, 0.0);
        let c = a.div_fixed(one);
        assert_eq!(c.re, 1.0);
        assert_eq!(c.im, 2.0);
    }

    #[test]
    fn test_complex_div_nontrivial() {
        let a = ComplexF64::new(3.0, 4.0);
        let b = ComplexF64::new(1.0, 2.0);
        let c = a.div_fixed(b);
        let tol = 1e-15;
        assert!((c.re - 2.2).abs() < tol, "re: {} vs 2.2", c.re);
        assert!((c.im - (-0.4)).abs() < tol, "im: {} vs -0.4", c.im);
    }

    #[test]
    fn test_complex_div_by_zero() {
        let a = ComplexF64::new(1.0, 2.0);
        let zero = ComplexF64::ZERO;
        let c = a.div_fixed(zero);
        // Both numerators are sums of x·0 = 0 and the denominator is 0, so each part is 0/0.
        assert!(c.re.is_nan(), "re: {}", c.re);
        assert!(c.im.is_nan(), "im: {}", c.im);
        assert!(!c.is_finite());
    }

    #[test]
    fn test_complex_div_roundtrip() {
        let z = ComplexF64::new(3.7, -2.1);
        let w = ComplexF64::new(1.5, 0.8);
        let product = z.mul_fixed(w);
        let back = product.div_fixed(w);
        let tol = 1e-12;
        assert!((back.re - z.re).abs() < tol, "re roundtrip: {} vs {}", back.re, z.re);
        assert!((back.im - z.im).abs() < tol, "im roundtrip: {} vs {}", back.im, z.im);
    }

    #[test]
    fn test_complex_signed_zero_preserved() {
        let pos = ComplexF64::new(0.0, 0.0);
        let neg = ComplexF64::new(-0.0, -0.0);
        // IEEE 754 round-to-nearest: (+0) + (-0) = +0.
        let mixed = pos.add(neg);
        assert_eq!(mixed.re.to_bits(), 0.0f64.to_bits());
        assert_eq!(mixed.im.to_bits(), 0.0f64.to_bits());
        // (-0) + (-0) = -0: the negative sign must survive, not be normalised away.
        let both_neg = neg.add(neg);
        assert_eq!(both_neg.re.to_bits(), (-0.0f64).to_bits());
        assert_eq!(both_neg.im.to_bits(), (-0.0f64).to_bits());
    }

    #[test]
    fn test_complex_nan_propagation() {
        let nan_z = ComplexF64::new(f64::NAN, 1.0);
        let normal = ComplexF64::new(1.0, 1.0);
        let result = nan_z.mul_fixed(normal);
        assert!(result.is_nan());
    }

    #[test]
    fn test_complex_display() {
        let z = ComplexF64::new(3.0, -4.0);
        assert_eq!(format!("{z}"), "3-4i");
        let z2 = ComplexF64::new(1.0, 2.0);
        assert_eq!(format!("{z2}"), "1+2i");
    }
}