use num_complex::Complex;
use crate::machine::BesselFloat;
#[inline]
pub(crate) fn mul_add<T: BesselFloat>(s: Complex<T>, a: Complex<T>, b: Complex<T>) -> Complex<T> {
    // Complex multiply-add: returns `s·a + b`, reading `BesselFloat::fma(x, y, z)`
    // as `x·y + z` (which is how the algebra below uses it).
    //
    // With s = p + qi and a = c + di:
    //   real = (p·c + b.re) − q·d
    //   imag = p·d + (q·c + b.im)
    // The imaginary part nests two `fma` calls; the real part fuses only
    // `p·c + b.re` and then subtracts the separately rounded product `q·d`.
    let (p, q) = (s.re, s.im);
    let (c, d) = (a.re, a.im);

    let real = BesselFloat::fma(p, c, b.re) - q * d;
    let imag_inner = BesselFloat::fma(q, c, b.im);
    let imag = BesselFloat::fma(p, d, imag_inner);

    Complex::new(real, imag)
}
#[inline]
pub(crate) fn mul_add_scalar<T: BesselFloat>(s: Complex<T>, a: T, b: Complex<T>) -> Complex<T> {
Complex::new(s.re.fma(a, b.re), s.im.fma(a, b.im))
}
#[inline]
pub(crate) fn mul_i<T: BesselFloat>(c: Complex<T>) -> Complex<T> {
    // Multiply by the imaginary unit: i·(x + yi) = −y + xi.
    let (x, y) = (c.re, c.im);
    Complex::new(-y, x)
}
#[inline]
pub(crate) fn mul_neg_i<T: BesselFloat>(c: Complex<T>) -> Complex<T> {
    // Multiply by −i: (−i)·(x + yi) = y − xi.
    let (x, y) = (c.re, c.im);
    Complex::new(y, -x)
}
#[inline]
pub(crate) fn zabs<T: BesselFloat>(z: Complex<T>) -> T {
    // Modulus |z| = sqrt(re² + im²), computed without forming re² or im².
    //
    // The larger component m is factored out, giving |z| = m·sqrt(1 + (n/m)²)
    // with n ≤ m. The squared ratio is therefore at most 1 and cannot overflow,
    // even when re² + im² itself would overflow or underflow.
    //
    // Special cases:
    // * z = 0 returns 0 (avoids the 0/0 ratio).
    // * |re| == |im| returns |re|·√2 directly. For finite values this is the
    //   same result the general formula produces (q = 1). For two infinite
    //   components it yields +∞ instead of the NaN that ∞/∞ would give.
    // * A NaN component propagates to a NaN result.
    let u = z.re.abs();
    let v = z.im.abs();
    let s = u + v;
    if s == T::zero() {
        return T::zero();
    }
    let one = T::one();
    if u == v {
        return u * (one + one).sqrt();
    }
    if u > v {
        let q = v / u;
        u * (one + q * q).sqrt()
    } else {
        let q = u / v;
        v * (one + q * q).sqrt()
    }
}
#[inline]
pub(crate) fn zdiv<T: BesselFloat>(a: Complex<T>, b: Complex<T>) -> Complex<T> {
    // Complex division a / b, arranged so that |b|² is never formed.
    //
    // With b̂ = b / |b| (a unit vector), a / b = (a · conj(b̂)) / |b|.
    // |b| comes from `zabs`, which is itself overflow-safe.
    // For IEEE float types, b = 0 gives 1/|b| = ∞ and 0·∞ = NaN, so the
    // components come out NaN.
    let inv_abs = T::one() / zabs(b);

    // Direction of b.
    let unit_re = b.re * inv_abs;
    let unit_im = b.im * inv_abs;

    // a · conj(unit), followed by the second 1/|b| factor.
    let re = (a.re * unit_re + a.im * unit_im) * inv_abs;
    let im = (a.im * unit_re - a.re * unit_im) * inv_abs;
    Complex::new(re, im)
}
#[inline]
pub(crate) fn reciprocal_z<T: BesselFloat>(z: Complex<T>) -> Complex<T> {
    // Returns 2 / z. Note the factor of two, despite the name.
    //
    // Computed as 2·conj(z) / |z|². Each component is scaled by 1/|z| twice
    // instead of dividing by a squared modulus.
    // NOTE(review): the name suggests 1/z; confirm every caller expects 2/z
    // before renaming or "fixing" this.
    let inv_abs = T::one() / zabs(z);

    // conj(z) / |z|
    let (cr, ci) = (z.re * inv_abs, -z.im * inv_abs);

    // Double by self-addition, then apply the second 1/|z| factor.
    Complex::new((cr + cr) * inv_abs, (ci + ci) * inv_abs)
}
#[inline]
pub(crate) fn sinpi<T: BesselFloat>(x: T) -> T {
    // sin(π·x), exact at integers (0) and half-integers (±1).
    //
    // x is reduced modulo 2 *before* multiplying by π, and the remainder is
    // folded into (0, 1/2), so the product with π never sees a large argument.
    // Odd symmetry, sin(−πx) = −sin(πx), is applied through a separate sign.
    let zero = T::zero();
    let one = T::one();
    let half = T::from_f64(0.5);
    let three_halves = T::from_f64(1.5);
    let two = T::from_f64(2.0);
    let pi = T::from_f64(core::f64::consts::PI);

    // Work with |x| and remember the sign.
    let negative = x < zero;
    let sign = if negative { -one } else { one };
    let magnitude = if negative { -x } else { x };

    // Period 2: for finite input r lies in [0, 2).
    let r = magnitude % two;

    // Exact results at the zeros and the extrema.
    if r == zero || r == one {
        return zero;
    }
    if r == half {
        return sign;
    }
    if r == three_halves {
        return -sign;
    }

    // Fold r onto (0, 1/2). Values in the second half-period (r > 1) flip the
    // sign of the result.
    let (reduced, flip) = if r < one {
        (if r < half { r } else { one - r }, false)
    } else {
        (if r < three_halves { r - one } else { two - r }, true)
    };

    let s = (reduced * pi).sin();
    let signed = if flip { -s } else { s };
    sign * signed
}
#[inline]
pub(crate) fn cospi<T: BesselFloat>(x: T) -> T {
    // cos(π·x), exact at integers (±1) and half-integers (0).
    //
    // |x| is reduced modulo 2 before multiplying by π. The remainder is then
    // folded onto t ∈ (0, 1) and evaluated as
    //   cos(t·π)            for t ≤ 1/4,
    //   sin((1/2 − t)·π)    for 1/4 < t < 3/4,
    //   −cos((1 − t)·π)     for t ≥ 3/4.
    // Each subtraction is exact in binary floating point (Sterbenz lemma).
    // The middle branch keeps full *relative* accuracy near the zeros at
    // half-integers. There, cos(t·π) with t·π ≈ π/2 is only absolutely
    // accurate, the same way `sinpi` stays accurate near its own zeros.
    // For IEEE float types, non-finite input gives NaN (∞ % 2 is NaN).
    let zero = T::zero();
    let one = T::one();
    let two = T::from_f64(2.0);
    let quarter = T::from_f64(0.25);
    let half = T::from_f64(0.5);
    let three_quarters = T::from_f64(0.75);
    let one_half = T::from_f64(1.5);
    let pi = T::from_f64(core::f64::consts::PI);

    // cos is even, so only |x| matters. For finite input r lies in [0, 2).
    let r = x.abs() % two;

    // Exact values at integers and half-integers.
    if r == zero {
        return one;
    }
    if r == half || r == one_half {
        return zero;
    }
    if r == one {
        return -one;
    }

    // cos((2 − r)·π) = cos(r·π), so fold (1, 2) onto (0, 1).
    // `two - r` is exact for r ∈ [1, 2].
    let t = if r > one { two - r } else { r };

    if t <= quarter {
        (t * pi).cos()
    } else if t < three_quarters {
        // cos(t·π) = sin((1/2 − t)·π). `half - t` is exact for t ∈ [1/4, 1].
        ((half - t) * pi).sin()
    } else {
        // cos(t·π) = −cos((1 − t)·π). `one - t` is exact for t ∈ [1/2, 1].
        -((one - t) * pi).cos()
    }
}
#[cfg(test)]
// Unit tests for the complex helpers (`zabs`, `zdiv`) and the `sinpi`/`cospi`
// functions, covering both f64 and f32 instantiations.
mod tests {
use super::*;
use num_complex::Complex64;
#[test]
fn zabs_zero() {
assert_eq!(zabs(Complex64::new(0.0, 0.0)), 0.0);
}
#[test]
fn zabs_real_only() {
let z = Complex64::new(3.0, 0.0);
assert!((zabs(z) - 3.0).abs() < 1e-15);
let z = Complex64::new(-5.0, 0.0);
assert!((zabs(z) - 5.0).abs() < 1e-15);
}
#[test]
fn zabs_imag_only() {
let z = Complex64::new(0.0, 4.0);
assert!((zabs(z) - 4.0).abs() < 1e-15);
}
#[test]
fn zabs_3_4_triangle() {
let z = Complex64::new(3.0, 4.0);
assert!((zabs(z) - 5.0).abs() < 1e-15);
}
// 1e154² = 1e308, so a naive re² + im² = 2e308 would exceed f64::MAX (~1.8e308)
// and overflow to infinity; the scaled formula must stay finite and accurate.
#[test]
fn zabs_large_values_no_overflow() {
let big = 1.0e154;
let z = Complex64::new(big, big);
let result = zabs(z);
let expected = big * 2.0_f64.sqrt();
assert!((result - expected).abs() / expected < 1e-15);
}
// 1e-308 squared underflows to 0, so a naive re² + im² would return 0;
// the result must stay positive.
#[test]
fn zabs_tiny_values_no_underflow() {
let tiny = 1.0e-308;
let z = Complex64::new(tiny, tiny);
let result = zabs(z);
assert!(result > 0.0);
let expected = tiny * 2.0_f64.sqrt();
assert!((result - expected).abs() / expected < 1e-15);
}
#[test]
fn zabs_asymmetric() {
let z = Complex64::new(1.0, 1.0);
let expected = 2.0_f64.sqrt();
assert!((zabs(z) - expected).abs() < 1e-15);
}
#[test]
fn zdiv_simple() {
let a = Complex64::new(1.0, 0.0);
let b = Complex64::new(1.0, 0.0);
let c = zdiv(a, b);
assert!((c.re - 1.0).abs() < 1e-15);
assert!(c.im.abs() < 1e-15);
}
#[test]
fn zdiv_i_div_i() {
let a = Complex64::new(0.0, 1.0);
let b = Complex64::new(0.0, 1.0);
let c = zdiv(a, b);
assert!((c.re - 1.0).abs() < 1e-15);
assert!(c.im.abs() < 1e-15);
}
// (3 + 4i) / (1 + 2i) = (3 + 4i)(1 − 2i) / 5 = (11 − 2i) / 5 = 2.2 − 0.4i.
#[test]
fn zdiv_known_result() {
let a = Complex64::new(3.0, 4.0);
let b = Complex64::new(1.0, 2.0);
let c = zdiv(a, b);
assert!((c.re - 2.2).abs() < 1e-14);
assert!((c.im - (-0.4)).abs() < 1e-14);
}
// Round trip: multiplying the quotient back by b must recover a.
#[test]
fn zdiv_inverse() {
let a = Complex64::new(1.23456789, -9.87654321);
let b = Complex64::new(0.314159, 2.71829);
let c = zdiv(a, b);
let recovered = Complex64::new(c.re * b.re - c.im * b.im, c.re * b.im + c.im * b.re);
assert!((recovered.re - a.re).abs() < 1e-13);
assert!((recovered.im - a.im).abs() < 1e-13);
}
// |b|² ≈ 2e400 is not representable in f64; the exact quotient
// (1 + i) / (1e200·(1 + i)) is 1e-200 + 0i.
#[test]
fn zdiv_large_denominator_no_overflow() {
let a = Complex64::new(1.0, 1.0);
let b = Complex64::new(1.0e200, 1.0e200);
let c = zdiv(a, b);
let expected_re = 1.0e-200;
assert!((c.re - expected_re).abs() / expected_re < 1e-14);
assert!(c.im.abs() < 1e-214);
}
#[test]
fn zabs_f32() {
use num_complex::Complex32;
let z = Complex32::new(3.0, 4.0);
assert!((zabs(z) - 5.0).abs() < 1e-6);
}
#[test]
fn zdiv_f32() {
use num_complex::Complex32;
let a = Complex32::new(3.0, 4.0);
let b = Complex32::new(1.0, 2.0);
let c = zdiv(a, b);
assert!((c.re - 2.2).abs() < 1e-5);
assert!((c.im - (-0.4)).abs() < 1e-5);
}
#[test]
fn sinpi_integers_are_zero() {
for n in -5..=5 {
let x = n as f64;
assert_eq!(sinpi(x), 0.0, "sinpi({x}) should be exactly 0");
}
}
#[test]
fn sinpi_half_integers() {
assert_eq!(sinpi(0.5_f64), 1.0);
assert_eq!(sinpi(1.5_f64), -1.0);
assert_eq!(sinpi(2.5_f64), 1.0);
assert_eq!(sinpi(-0.5_f64), -1.0);
assert_eq!(sinpi(-1.5_f64), 1.0);
}
#[test]
fn sinpi_quarter() {
let val = sinpi(0.25_f64);
let expected = core::f64::consts::FRAC_1_SQRT_2;
assert!((val - expected).abs() < 1e-15);
}
#[test]
fn sinpi_general_values() {
let val = sinpi(1.0_f64 / 6.0);
assert!((val - 0.5).abs() < 1e-15);
let val = sinpi(1.0_f64 / 3.0);
assert!((val - 3.0_f64.sqrt() / 2.0).abs() < 1e-15);
}
// 1e15 is an even integer below 2^53, and 1e15 + 0.5 is exactly representable
// (the f64 spacing at 1e15 is 0.125), so reducing modulo 2 is exact and the
// results must be exactly 0 and ±1.
#[test]
fn sinpi_large_argument() {
assert_eq!(sinpi(1e15_f64), 0.0);
assert!(sinpi(1e15_f64 + 0.5).abs() == 1.0);
}
#[test]
fn sinpi_f32() {
assert_eq!(sinpi(0.0_f32), 0.0);
assert_eq!(sinpi(0.5_f32), 1.0);
assert_eq!(sinpi(1.0_f32), 0.0);
assert_eq!(sinpi(1.5_f32), -1.0);
}
#[test]
fn cospi_integers() {
assert_eq!(cospi(0.0_f64), 1.0);
assert_eq!(cospi(1.0_f64), -1.0);
assert_eq!(cospi(2.0_f64), 1.0);
assert_eq!(cospi(-1.0_f64), -1.0);
assert_eq!(cospi(-2.0_f64), 1.0);
}
#[test]
fn cospi_half_integers_are_zero() {
for n in -5..=5 {
let x = n as f64 + 0.5;
assert_eq!(cospi(x), 0.0, "cospi({x}) should be exactly 0");
}
}
#[test]
fn cospi_quarter() {
let val = cospi(0.25_f64);
let expected = core::f64::consts::FRAC_1_SQRT_2;
assert!((val - expected).abs() < 1e-15);
}
#[test]
fn cospi_general_values() {
let val = cospi(1.0_f64 / 3.0);
assert!((val - 0.5).abs() < 1e-15);
let val = cospi(1.0_f64 / 6.0);
assert!((val - 3.0_f64.sqrt() / 2.0).abs() < 1e-15);
}
// Same exact-reduction argument as `sinpi_large_argument`: the remainder
// modulo 2 is exactly 0 or 0.5.
#[test]
fn cospi_large_argument() {
assert!(cospi(1e15_f64).abs() == 1.0);
assert_eq!(cospi(1e15_f64 + 0.5), 0.0);
}
#[test]
fn cospi_f32() {
assert_eq!(cospi(0.0_f32), 1.0);
assert_eq!(cospi(0.5_f32), 0.0);
assert_eq!(cospi(1.0_f32), -1.0);
assert_eq!(cospi(1.5_f32), 0.0);
}
}