use crate::rand;
use serde::{Deserialize, Serialize};
use std::{
f64::consts::PI,
ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign},
};
/// A complex number `re + im·i` stored as a pair of `f64` components.
#[derive(Debug, Default, Clone, Copy, PartialEq)]
#[derive(Serialize, Deserialize)]
pub struct Complex {
    /// Real component.
    pub re: f64,
    /// Imaginary component.
    pub im: f64,
}
impl Complex {
    /// Builds a complex number from its real and imaginary parts.
    ///
    /// `const` so values of this type can also be built in constant contexts.
    #[inline(always)]
    pub const fn new(re: f64, im: f64) -> Self {
        Self { re, im }
    }

    /// The additive identity, `0 + 0i`.
    #[inline(always)]
    pub const fn zero() -> Self {
        Self { re: 0.0, im: 0.0 }
    }

    /// The multiplicative identity, `1 + 0i`.
    #[inline(always)]
    pub const fn one() -> Self {
        Self { re: 1.0, im: 0.0 }
    }

    /// The imaginary unit, `0 + 1i`.
    #[inline(always)]
    pub const fn i() -> Self {
        Self { re: 0.0, im: 1.0 }
    }

    /// Returns `self²` using the expanded form `(re² - im²) + 2·re·im·i`.
    #[inline(always)]
    pub fn square(self) -> Self {
        Self {
            re: self.re * self.re - self.im * self.im,
            im: 2.0 * self.re * self.im,
        }
    }

    /// Returns `self³`.
    #[inline(always)]
    pub fn cube(self) -> Self {
        self.square() * self
    }

    /// Returns the reciprocal `1 / self`.
    ///
    /// A zero input yields zero rather than infinities/NaN, the same
    /// convention the `Div` impls in this file use.
    #[inline(always)]
    pub fn inv(self) -> Self {
        let denom = self.abs_sq();
        if denom == 0.0 {
            Self::zero()
        } else {
            Self {
                re: self.re / denom,
                im: -self.im / denom,
            }
        }
    }

    /// Raises `self` to a non-negative integer power using binary
    /// (square-and-multiply) exponentiation; `pow(0)` is one.
    #[inline(always)]
    pub fn pow(self, p: u32) -> Self {
        let mut base = self;
        let mut exponent = p;
        let mut res = Self::one();
        while exponent > 0 {
            if exponent % 2 == 1 {
                res *= base;
            }
            base = base.square();
            exponent /= 2;
        }
        res
    }

    /// Raises `self` to a signed integer power; `powi(0)` is one.
    ///
    /// A negative exponent inverts the base first, so a zero base with a
    /// negative exponent yields zero (see [`Complex::inv`]).
    #[inline(always)]
    pub fn powi(self, exp: i32) -> Self {
        if exp == 0 {
            return Self::one();
        }
        let mut base = if exp < 0 { self.inv() } else { self };
        // `unsigned_abs` is well-defined even for `i32::MIN`.
        let mut exponent = exp.unsigned_abs();
        let mut res = Self::one();
        while exponent > 0 {
            if exponent % 2 == 1 {
                res *= base;
            }
            base = base.square();
            exponent /= 2;
        }
        res
    }

    /// Raises `self` to a real power as `exp(exp · ln(self))` (principal branch).
    ///
    /// For a zero base this returns one when `exp == 0.0` and zero otherwise
    /// (including negative exponents), instead of going through `ln(0)`.
    #[inline(always)]
    pub fn powf(self, exp: f64) -> Self {
        if self.re == 0.0 && self.im == 0.0 {
            if exp == 0.0 {
                Self::one()
            } else {
                Self::zero()
            }
        } else {
            (self.ln() * exp).exp()
        }
    }

    /// Raises `self` to a complex power as `exp(exp · ln(self))` (principal branch).
    ///
    /// For a zero base this returns one when `exp` is zero and zero otherwise.
    #[inline(always)]
    pub fn powc(self, exp: Self) -> Self {
        if self.re == 0.0 && self.im == 0.0 {
            if exp.re == 0.0 && exp.im == 0.0 {
                Self::one()
            } else {
                Self::zero()
            }
        } else {
            (exp * self.ln()).exp()
        }
    }

    /// Squared magnitude `re² + im²`. This is cheaper than [`Complex::abs`]
    /// when only comparisons are needed.
    #[inline(always)]
    pub fn abs_sq(self) -> f64 {
        self.re * self.re + self.im * self.im
    }

    /// Magnitude `|self|`, computed as `sqrt(re² + im²)`.
    ///
    /// NOTE(review): squaring overflows to infinity for components above
    /// roughly 1e154; `f64::hypot` would avoid that at some speed cost.
    #[inline(always)]
    pub fn abs(self) -> f64 {
        self.abs_sq().sqrt()
    }

    /// Squared Euclidean distance between `self` and `other`.
    #[inline(always)]
    pub fn distance_sq(self, other: Self) -> f64 {
        (self - other).abs_sq()
    }

    /// Argument (phase angle) in radians, in `[-π, π]`, computed with `atan2`.
    #[inline(always)]
    pub fn arg(self) -> f64 {
        self.im.atan2(self.re)
    }

    /// Complex conjugate `re - im·i`.
    #[inline(always)]
    pub fn conj(self) -> Self {
        Self::new(self.re, -self.im)
    }

    /// Unit-magnitude number with the same argument as `self`; zero for a zero input.
    #[inline(always)]
    pub fn signum(self) -> Self {
        let mag = self.abs();
        if mag == 0.0 { Self::zero() } else { self / mag }
    }

    /// Builds `radius · (cos(angle) + i·sin(angle))`.
    #[inline(always)]
    pub fn from_polar(radius: f64, angle: f64) -> Self {
        Self {
            re: radius * angle.cos(),
            im: radius * angle.sin(),
        }
    }

    /// Returns `(magnitude, argument)`.
    #[inline(always)]
    pub fn to_polar(self) -> (f64, f64) {
        (self.abs(), self.arg())
    }

    /// Unit phasor `cos(theta) + i·sin(theta)`, i.e. `e^{i·theta}`.
    #[inline(always)]
    pub fn cis(theta: f64) -> Self {
        Self::from_polar(1.0, theta)
    }

    /// Multiplies by `phasor`. A unit-magnitude phasor rotates without scaling.
    #[inline(always)]
    pub fn rotate(self, phasor: Self) -> Self {
        self * phasor
    }

    /// Principal square root, whose real part is non-negative.
    ///
    /// Computes `t = sqrt((|z| + |re|) / 2)` and derives the other component by
    /// division. This avoids the cancellation of the naive `sqrt((|z| - re) / 2)`.
    #[inline(always)]
    pub fn sqrt(self) -> Self {
        if self.re == 0.0 && self.im == 0.0 {
            return Self::zero();
        }
        let r = self.abs();
        let t = (0.5 * (r + self.re.abs())).sqrt();
        let (re, im) = if self.re >= 0.0 {
            (t, self.im / (2.0 * t))
        } else {
            // The imaginary part carries the sign of `im`, which selects the
            // side of the branch cut on the negative real axis.
            let re = self.im.abs() / (2.0 * t);
            let im = t.copysign(self.im);
            (re, im)
        };
        Self::new(re, im)
    }

    /// Complex exponential `e^re · (cos(im) + i·sin(im))`.
    #[inline(always)]
    pub fn exp(self) -> Self {
        let r = self.re.exp();
        Self::new(r * self.im.cos(), r * self.im.sin())
    }

    /// Principal natural logarithm `ln|z| + i·arg(z)`. `ln(0)` has real part `-inf`.
    #[inline(always)]
    pub fn ln(self) -> Self {
        Self::new(self.abs().ln(), self.arg())
    }

    /// Logarithm in a real `base`, computed as `ln(self) / ln(base)`.
    ///
    /// For `base == 1.0` the divisor is zero, so the `Div<f64>` convention
    /// yields zero.
    #[inline(always)]
    pub fn log(self, base: f64) -> Self {
        self.ln() / base.ln()
    }

    /// Complex sine: `sin(re)·cosh(im) + i·cos(re)·sinh(im)`.
    #[inline(always)]
    pub fn sin(self) -> Self {
        Self::new(
            self.re.sin() * self.im.cosh(),
            self.re.cos() * self.im.sinh(),
        )
    }

    /// Complex cosine: `cos(re)·cosh(im) - i·sin(re)·sinh(im)`.
    #[inline(always)]
    pub fn cos(self) -> Self {
        Self::new(
            self.re.cos() * self.im.cosh(),
            -self.re.sin() * self.im.sinh(),
        )
    }

    /// Complex tangent `sin / cos`. If the denominator is exactly zero, the
    /// `Div` convention yields zero.
    #[inline(always)]
    pub fn tan(self) -> Self {
        self.sin() / self.cos()
    }

    /// Cosecant `1 / sin`.
    #[inline(always)]
    pub fn csc(self) -> Self {
        Self::one() / self.sin()
    }

    /// Secant `1 / cos`.
    #[inline(always)]
    pub fn sec(self) -> Self {
        Self::one() / self.cos()
    }

    /// Cotangent `cos / sin`.
    #[inline(always)]
    pub fn cot(self) -> Self {
        self.cos() / self.sin()
    }

    /// Hyperbolic sine: `sinh(re)·cos(im) + i·cosh(re)·sin(im)`.
    #[inline(always)]
    pub fn sinh(self) -> Self {
        Self::new(
            self.re.sinh() * self.im.cos(),
            self.re.cosh() * self.im.sin(),
        )
    }

    /// Hyperbolic cosine: `cosh(re)·cos(im) + i·sinh(re)·sin(im)`.
    #[inline(always)]
    pub fn cosh(self) -> Self {
        Self::new(
            self.re.cosh() * self.im.cos(),
            self.re.sinh() * self.im.sin(),
        )
    }

    /// Hyperbolic tangent `sinh / cosh`.
    #[inline(always)]
    pub fn tanh(self) -> Self {
        self.sinh() / self.cosh()
    }

    /// Hyperbolic cosecant `1 / sinh`.
    #[inline(always)]
    pub fn csch(self) -> Self {
        Self::one() / self.sinh()
    }

    /// Hyperbolic secant `1 / cosh`.
    #[inline(always)]
    pub fn sech(self) -> Self {
        Self::one() / self.cosh()
    }

    /// Hyperbolic cotangent `cosh / sinh`.
    #[inline(always)]
    pub fn coth(self) -> Self {
        self.cosh() / self.sinh()
    }

    /// Inverse sine: `-i · ln(i·z + sqrt(1 - z²))`.
    #[inline(always)]
    pub fn asin(self) -> Self {
        let i_z = Self::new(-self.im, self.re);
        let sqrt_term = (Self::one() - self.square()).sqrt();
        let ln_term = (i_z + sqrt_term).ln();
        // Multiplying by -i maps (a + bi) to (b - ai).
        Self::new(ln_term.im, -ln_term.re)
    }

    /// Inverse cosine: `π/2 - asin(z)`.
    #[inline(always)]
    pub fn acos(self) -> Self {
        Self::new(std::f64::consts::FRAC_PI_2, 0.0) - self.asin()
    }

    /// Inverse tangent: `(i/2) · ln((i + z) / (i - z))`.
    #[inline(always)]
    pub fn atan(self) -> Self {
        let i = Self::i();
        let num = i + self;
        let den = i - self;
        let ln_term = (num / den).ln();
        // Multiplying by i/2 maps (a + bi) to (-b/2 + (a/2)i).
        Self::new(-0.5 * ln_term.im, 0.5 * ln_term.re)
    }

    /// Inverse hyperbolic sine, principal value `ln(z + sqrt(1 + z²))`.
    #[inline(always)]
    pub fn asinh(self) -> Self {
        // asinh is odd and its branch cuts lie on the imaginary axis. For
        // Re(z) < 0 we therefore evaluate at -z and negate. This avoids the
        // catastrophic cancellation in `z + sqrt(1 + z²)`; asinh(-1e8) would
        // otherwise collapse to ln(0) = -inf.
        let flip = self.re < 0.0;
        let w = if flip { -self } else { self };
        let res = (w + (Self::one() + w.square()).sqrt()).ln();
        if flip {
            -res
        } else {
            res
        }
    }

    /// Inverse hyperbolic cosine, principal value
    /// `ln(z + sqrt(z + 1) · sqrt(z - 1))` (non-negative real part).
    ///
    /// For example, `acosh(-2) = 1.3169… + iπ`.
    #[inline(always)]
    pub fn acosh(self) -> Self {
        // In the left half-plane, sqrt(z + 1)·sqrt(z - 1) can differ in sign
        // from sqrt(z² - 1). Only the split form selects the principal branch
        // there, which also keeps acosh(cosh(w)) == w for Re(w) > 0.
        let root = (self + Self::one()).sqrt() * (self - Self::one()).sqrt();
        (self + root).ln()
    }

    /// Inverse hyperbolic tangent: `ln((1 + z) / (1 - z)) / 2`.
    #[inline(always)]
    pub fn atanh(self) -> Self {
        let num = Self::one() + self;
        let den = Self::one() - self;
        (num / den).ln() * 0.5
    }

    /// Newton-iteration term `f(z) / f'(z)` for `f(z) = z^p - 1`, that is
    /// `(z^p - 1) / (p · z^(p-1))`. Returns zero for `p == 0`.
    #[inline(always)]
    pub fn newton_step_term(self, p: u32) -> Self {
        if p == 0 {
            return Self::zero();
        }
        // Reuse z^(p-1) for both numerator and derivative.
        let z_p_minus_1 = self.pow(p - 1);
        let f_z = z_p_minus_1 * self - Self::one();
        let f_prime_z = z_p_minus_1 * (p as f64);
        f_z / f_prime_z
    }

    /// Radial fade weight for a point at distance `|self|` from the origin.
    ///
    /// The weight is `1.0` while `|self| / max_radius < flat_ratio`, and `0.0`
    /// once `|self| / max_radius >= 1.0`. In between, a smoothstep
    /// `1 - (3t² - 2t³)` is applied with `t` rescaled to `[0, 1]`.
    #[inline(always)]
    pub fn circular_fade(self, max_radius: f64, flat_ratio: f64) -> f64 {
        let dist = self.abs();
        let r = dist / max_radius;
        if r < flat_ratio {
            1.0
        } else if r < 1.0 {
            let t = (r - flat_ratio) / (1.0 - flat_ratio);
            1.0 - t * t * (3.0 - 2.0 * t)
        } else {
            0.0
        }
    }

    /// Random unit phasor with angle `rand::<f64>() · 2π`.
    ///
    /// NOTE(review): assumes `crate::rand::<f64>()` returns a value in `[0, 1)`,
    /// making the angle uniform over a full turn. Confirm against `crate::rand`.
    #[inline]
    pub fn sample_rotation() -> Self {
        let radians = rand::<f64>() * 2.0 * PI;
        Self::cis(radians)
    }

    /// Yields `rotations` unit phasors evenly spaced around the circle,
    /// starting at `1 + 0i`. Yields nothing when `rotations == 0`.
    #[inline]
    pub fn rotation_phasors(rotations: usize) -> impl Iterator<Item = Self> {
        (0..rotations).map(move |step| Self::cis((step as f64) * 2.0 * PI / (rotations as f64)))
    }
}
/// Component-wise sum of two complex numbers.
impl Add for Complex {
    type Output = Self;
    #[inline(always)]
    fn add(self, rhs: Self) -> Self::Output {
        let Self { re, im } = self;
        Self {
            re: re + rhs.re,
            im: im + rhs.im,
        }
    }
}
/// Adds a real scalar; only the real component moves.
impl Add<f64> for Complex {
    type Output = Self;
    #[inline(always)]
    fn add(self, rhs: f64) -> Self {
        Self {
            re: self.re + rhs,
            ..self
        }
    }
}
impl Add<Complex> for f64 {
type Output = Complex;
#[inline(always)]
fn add(self, other: Complex) -> Complex {
Complex::new(self + other.re, other.im)
}
}
/// Component-wise difference of two complex numbers.
impl Sub for Complex {
    type Output = Self;
    #[inline(always)]
    fn sub(self, rhs: Self) -> Self::Output {
        let re = self.re - rhs.re;
        let im = self.im - rhs.im;
        Self { re, im }
    }
}
/// Subtracts a real scalar; only the real component moves.
impl Sub<f64> for Complex {
    type Output = Self;
    #[inline(always)]
    fn sub(self, rhs: f64) -> Self {
        Self {
            re: self.re - rhs,
            ..self
        }
    }
}
impl Sub<Complex> for f64 {
type Output = Complex;
#[inline(always)]
fn sub(self, other: Complex) -> Complex {
Complex::new(self - other.re, -other.im)
}
}
/// Complex product `(a + bi)(c + di) = (ac - bd) + (ad + bc)i`.
impl Mul for Complex {
    type Output = Self;
    #[inline(always)]
    fn mul(self, rhs: Self) -> Self::Output {
        let (a, b) = (self.re, self.im);
        let (c, d) = (rhs.re, rhs.im);
        Self {
            re: a * c - b * d,
            im: a * d + b * c,
        }
    }
}
/// Complex quotient `self / rhs`, computed as `self · conj(rhs) / |rhs|²`.
///
/// Dividing by zero yields zero instead of infinities/NaN.
impl Div for Complex {
    type Output = Self;
    #[inline(always)]
    fn div(self, rhs: Self) -> Self::Output {
        let denom = rhs.re * rhs.re + rhs.im * rhs.im;
        if denom == 0.0 {
            return Self { re: 0.0, im: 0.0 };
        }
        let re_num = self.re * rhs.re + self.im * rhs.im;
        let im_num = self.im * rhs.re - self.re * rhs.im;
        Self {
            re: re_num / denom,
            im: im_num / denom,
        }
    }
}
/// Negates both components.
impl Neg for Complex {
    type Output = Self;
    #[inline(always)]
    fn neg(self) -> Self::Output {
        let Self { re, im } = self;
        Self { re: -re, im: -im }
    }
}
/// Scales both components by a real factor.
impl Mul<f64> for Complex {
    type Output = Self;
    #[inline(always)]
    fn mul(self, rhs: f64) -> Self {
        let Self { re, im } = self;
        Self {
            re: re * rhs,
            im: im * rhs,
        }
    }
}
impl Mul<Complex> for f64 {
type Output = Complex;
#[inline(always)]
fn mul(self, other: Complex) -> Complex {
Complex::new(self * other.re, self * other.im)
}
}
/// Divides both components by a real scalar; dividing by zero yields zero.
impl Div<f64> for Complex {
    type Output = Self;
    #[inline(always)]
    fn div(self, rhs: f64) -> Self {
        if rhs != 0.0 {
            Self {
                re: self.re / rhs,
                im: self.im / rhs,
            }
        } else {
            Self { re: 0.0, im: 0.0 }
        }
    }
}
impl Div<Complex> for f64 {
type Output = Complex;
#[inline(always)]
fn div(self, other: Complex) -> Complex {
let denom = other.abs_sq();
if denom == 0.0 {
Complex::new(0.0, 0.0)
} else {
Complex::new((self * other.re) / denom, (-self * other.im) / denom)
}
}
}
/// In-place component-wise addition.
impl AddAssign for Complex {
    #[inline(always)]
    fn add_assign(&mut self, rhs: Self) {
        *self = Self {
            re: self.re + rhs.re,
            im: self.im + rhs.im,
        };
    }
}
/// In-place addition of a real scalar (imaginary part untouched).
impl AddAssign<f64> for Complex {
    #[inline(always)]
    fn add_assign(&mut self, rhs: f64) {
        self.re = self.re + rhs;
    }
}
/// In-place component-wise subtraction.
impl SubAssign for Complex {
    #[inline(always)]
    fn sub_assign(&mut self, rhs: Self) {
        let Self { re, im } = rhs;
        self.re -= re;
        self.im -= im;
    }
}
/// In-place subtraction of a real scalar (imaginary part untouched).
impl SubAssign<f64> for Complex {
    #[inline(always)]
    fn sub_assign(&mut self, rhs: f64) {
        self.re = self.re - rhs;
    }
}
/// In-place complex multiplication, delegating to the `Mul` impl.
impl MulAssign for Complex {
    #[inline(always)]
    fn mul_assign(&mut self, rhs: Self) {
        let lhs = *self;
        *self = lhs * rhs;
    }
}
/// In-place scaling of both components by a real factor.
impl MulAssign<f64> for Complex {
    #[inline(always)]
    fn mul_assign(&mut self, rhs: f64) {
        *self = Self {
            re: self.re * rhs,
            im: self.im * rhs,
        };
    }
}
/// In-place complex division, delegating to the `Div` impl (and so
/// inheriting its divide-by-zero-yields-zero convention).
impl DivAssign for Complex {
    #[inline(always)]
    fn div_assign(&mut self, rhs: Self) {
        let lhs = *self;
        *self = lhs / rhs;
    }
}
/// In-place division by a real scalar; dividing by zero resets to zero.
impl DivAssign<f64> for Complex {
    #[inline(always)]
    fn div_assign(&mut self, rhs: f64) {
        *self = if rhs == 0.0 {
            Self { re: 0.0, im: 0.0 }
        } else {
            Self {
                re: self.re / rhs,
                im: self.im / rhs,
            }
        };
    }
}
#[cfg(test)]
mod tests_complex {
    use super::*;

    /// Asserts that `c1` and `c2` agree component-wise within `tol`.
    fn assert_approx(c1: Complex, c2: Complex, tol: f64) {
        assert!(
            (c1.re - c2.re).abs() < tol,
            "Real parts differ: expected {}, got {} (diff {})",
            c2.re,
            c1.re,
            (c1.re - c2.re).abs()
        );
        assert!(
            (c1.im - c2.im).abs() < tol,
            "Imaginary parts differ: expected {}, got {} (diff {})",
            c2.im,
            c1.im,
            (c1.im - c2.im).abs()
        );
    }

    #[test]
    fn test_addition() {
        let c1 = Complex::new(1.0, 2.0);
        let c2 = Complex::new(3.0, 4.0);
        let result = c1 + c2;
        assert_eq!(result, Complex::new(4.0, 6.0));
    }

    #[test]
    fn test_subtraction() {
        let c1 = Complex::new(5.0, 7.0);
        let c2 = Complex::new(2.0, 3.0);
        let result = c1 - c2;
        assert_eq!(result, Complex::new(3.0, 4.0));
    }

    #[test]
    fn test_multiplication() {
        let c1 = Complex::new(1.0, 2.0);
        let c2 = Complex::new(3.0, 4.0);
        let result = c1 * c2;
        assert_eq!(result, Complex::new(-5.0, 10.0));
    }

    #[test]
    fn test_division() {
        let c1 = Complex::new(1.0, 2.0);
        let c2 = Complex::new(3.0, 4.0);
        let result = c1 / c2;
        assert!((result.re - 0.44).abs() < 1e-9);
        assert!((result.im - 0.08).abs() < 1e-9);
    }

    #[test]
    fn test_division_by_zero() {
        let c1 = Complex::new(1.0, 2.0);
        let c2 = Complex::new(0.0, 0.0);
        let result = c1 / c2;
        assert_eq!(result, Complex::new(0.0, 0.0));
    }

    #[test]
    fn test_pow() {
        let c = Complex::new(1.0, 1.0);
        let result = c.pow(3);
        assert_eq!(result, Complex::new(-2.0, 2.0));
    }

    #[test]
    fn test_powi_negative_and_zero() {
        let z = Complex::new(0.0, 2.0);
        assert_eq!(z.powi(0), Complex::one());
        // (2i)^-2 = 1 / (-4)
        assert_approx(z.powi(-2), Complex::new(-0.25, 0.0), 1e-12);
        // (2i)^3 = -8i
        assert_approx(z.powi(3), Complex::new(0.0, -8.0), 1e-12);
    }

    #[test]
    fn test_inv() {
        assert_approx(Complex::new(3.0, 4.0).inv(), Complex::new(0.12, -0.16), 1e-12);
        assert_eq!(Complex::zero().inv(), Complex::zero());
    }

    #[test]
    fn test_powf_and_powc() {
        assert_approx(Complex::new(4.0, 0.0).powf(0.5), Complex::new(2.0, 0.0), 1e-12);
        assert_eq!(Complex::zero().powf(0.0), Complex::one());
        assert_eq!(Complex::zero().powf(2.0), Complex::zero());
        // i^i = e^{-π/2}
        let expected = (-std::f64::consts::FRAC_PI_2).exp();
        assert_approx(Complex::i().powc(Complex::i()), Complex::new(expected, 0.0), 1e-12);
    }

    #[test]
    fn test_abs_sq() {
        let c = Complex::new(3.0, -4.0);
        assert_eq!(c.abs_sq(), 25.0);
    }

    #[test]
    fn test_conj_and_signum() {
        let z = Complex::new(3.0, -4.0);
        assert_eq!(z.conj(), Complex::new(3.0, 4.0));
        assert_approx(z.signum(), Complex::new(0.6, -0.8), 1e-12);
        assert_eq!(Complex::zero().signum(), Complex::zero());
    }

    #[test]
    fn test_polar_round_trip() {
        let z = Complex::from_polar(2.0, 0.5);
        let (r, theta) = z.to_polar();
        assert!((r - 2.0).abs() < 1e-12);
        assert!((theta - 0.5).abs() < 1e-12);
        assert_approx(Complex::cis(std::f64::consts::FRAC_PI_2), Complex::i(), 1e-12);
    }

    #[test]
    fn test_rotation_phasors() {
        let phasors: Vec<Complex> = Complex::rotation_phasors(4).collect();
        assert_eq!(phasors.len(), 4);
        let expected = [Complex::one(), Complex::i(), -Complex::one(), -Complex::i()];
        for (got, want) in phasors.iter().zip(expected.iter()) {
            assert_approx(*got, *want, 1e-12);
        }
        assert_eq!(Complex::rotation_phasors(0).count(), 0);
    }

    #[test]
    fn test_circular_fade() {
        assert_eq!(Complex::new(0.1, 0.0).circular_fade(1.0, 0.5), 1.0);
        // Halfway through the fade band the smoothstep is exactly 0.5.
        assert!((Complex::new(0.75, 0.0).circular_fade(1.0, 0.5) - 0.5).abs() < 1e-12);
        assert_eq!(Complex::new(0.0, 2.0).circular_fade(1.0, 0.5), 0.0);
    }

    #[test]
    fn test_scalar_multiplication_complex_f64() {
        let c = Complex::new(1.5, -2.0);
        let scalar = 2.0;
        let result = c * scalar;
        assert_eq!(result, Complex::new(3.0, -4.0));
    }

    #[test]
    fn test_scalar_multiplication_f64_complex() {
        let c = Complex::new(1.5, -2.0);
        let scalar = 2.0;
        let result = scalar * c;
        assert_eq!(result, Complex::new(3.0, -4.0));
    }

    #[test]
    fn test_scalar_division_complex_f64() {
        let c = Complex::new(3.0, -4.0);
        let scalar = 2.0;
        let result = c / scalar;
        assert_eq!(result, Complex::new(1.5, -2.0));
    }

    #[test]
    fn test_scalar_division_complex_f64_by_zero() {
        let c = Complex::new(3.0, -4.0);
        let result = c / 0.0;
        assert_eq!(result, Complex::new(0.0, 0.0));
    }

    #[test]
    fn test_scalar_division_f64_complex() {
        let scalar = 5.0;
        let c = Complex::new(3.0, 4.0);
        let result = scalar / c;
        assert_eq!(result, Complex::new(0.6, -0.8));
    }

    #[test]
    fn test_scalar_division_f64_complex_by_zero() {
        let scalar = 5.0;
        let c = Complex::new(0.0, 0.0);
        let result = scalar / c;
        assert_eq!(result, Complex::new(0.0, 0.0));
    }

    #[test]
    fn test_newton_step_term() {
        let z = Complex::new(2.0, 0.0);
        let term = z.newton_step_term(3);
        assert!((term.re - 0.58333333333).abs() < 1e-9);
        assert_eq!(term.im, 0.0);
    }

    #[test]
    fn test_euler_identity() {
        let pi = std::f64::consts::PI;
        let exponent = Complex::i() * pi;
        let result = exponent.exp();
        assert_approx(result, -Complex::one(), 1e-9);
    }

    #[test]
    fn test_sqrt() {
        let c1 = Complex::new(-4.0, 0.0);
        assert_approx(c1.sqrt(), Complex::new(0.0, 2.0), 1e-9);
        let c2 = Complex::new(3.0, 4.0);
        assert_approx(c2.sqrt(), Complex::new(2.0, 1.0), 1e-9);
        assert_eq!(Complex::zero().sqrt(), Complex::zero());
    }

    #[test]
    fn test_trig_identity() {
        let z = Complex::new(1.0, 2.0);
        let lhs = z.sin().square() + z.cos().square();
        assert_approx(lhs, Complex::one(), 1e-9);
    }

    #[test]
    fn test_hyperbolic_identity() {
        let z = Complex::new(1.0, 2.0);
        let lhs = z.cosh().square() - z.sinh().square();
        assert_approx(lhs, Complex::one(), 1e-9);
    }

    #[test]
    fn test_inverse_trig() {
        let z = Complex::new(0.5, 0.5);
        assert_approx(z.sin().asin(), z, 1e-9);
        assert_approx(z.cos().acos(), z, 1e-9);
        assert_approx(z.tan().atan(), z, 1e-9);
    }

    #[test]
    fn test_inverse_hyperbolic() {
        let z = Complex::new(0.5, 0.5);
        assert_approx(z.sinh().asinh(), z, 1e-9);
        assert_approx(z.cosh().acosh(), z, 1e-9);
        assert_approx(z.tanh().atanh(), z, 1e-9);
    }

    #[test]
    fn test_logarithm() {
        let z = Complex::new(10.0, 0.0);
        assert_approx(z.log(10.0), Complex::one(), 1e-9);
    }

    #[test]
    fn test_assignment_ops() {
        let mut c = Complex::new(1.0, 2.0);
        c += Complex::new(2.0, 3.0);
        assert_eq!(c, Complex::new(3.0, 5.0));
        c -= Complex::new(1.0, 1.0);
        assert_eq!(c, Complex::new(2.0, 4.0));
        c *= 2.0;
        assert_eq!(c, Complex::new(4.0, 8.0));
        c /= 4.0;
        assert_eq!(c, Complex::new(1.0, 2.0));
    }

    #[test]
    fn test_complex_assignment_ops() {
        let mut c = Complex::new(1.0, 2.0);
        c *= Complex::i();
        assert_eq!(c, Complex::new(-2.0, 1.0));
        c /= Complex::i();
        assert_eq!(c, Complex::new(1.0, 2.0));
        c += 1.0;
        assert_eq!(c, Complex::new(2.0, 2.0));
        c -= 0.5;
        assert_eq!(c, Complex::new(1.5, 2.0));
    }
}