use num_complex::Complex;
use num_traits::{
Num,
Zero,
};
use std::f64;
/// Scalar operations needed by the complex elementary functions in
/// [`ComplexMath`].
///
/// On top of the arithmetic supplied by [`Num`], an implementor provides
/// conversion helpers, the usual mathematical constants, and the real
/// trigonometric, hyperbolic, exponential, logarithmic and power functions.
pub trait Real:
    Num + std::ops::Neg<Output = Self> + Clone + PartialEq + PartialOrd + std::fmt::Debug
{
    /// Converts an `f64` into `Self`.
    fn from_f64(v: f64) -> Self;

    /// Converts `self` to an `i32`.
    ///
    /// Rounding and out-of-range handling are up to the implementation; the
    /// `f64` implementation truncates toward zero, saturates at the `i32`
    /// bounds and maps NaN/±∞ to 0.
    fn to_i32(&self) -> i32;

    /// Returns `true` when `self` has no fractional part and lies within
    /// `i32::MIN..=i32::MAX`.
    fn is_i32_compatible(&self) -> bool {
        // Bail out before building the bounds when there is a fractional part.
        if !self.clone().fract().is_zero() {
            return false;
        }
        *self >= Self::from_f64(i32::MIN as f64) && *self <= Self::from_f64(i32::MAX as f64)
    }

    /// Fractional part of `self`.
    fn fract(self) -> Self;
    /// Integer part of `self`, rounded toward zero.
    fn trunc(self) -> Self;

    /// Euler's number `e`.
    fn e() -> Self;
    /// `1/π`.
    fn frac_1_pi() -> Self;
    /// `1/√2`.
    fn frac_1_sqrt_2() -> Self;
    /// `2/π`.
    fn frac_2_pi() -> Self;
    /// `2/√π`.
    fn frac_2_sqrt_pi() -> Self;
    /// `π/2`.
    fn frac_pi_2() -> Self;
    /// `π/3`.
    fn frac_pi_3() -> Self;
    /// `π/4`.
    fn frac_pi_4() -> Self;
    /// `π/6`.
    fn frac_pi_6() -> Self;
    /// `π/8`.
    fn frac_pi_8() -> Self;
    /// `ln 2`.
    fn ln_2() -> Self;
    /// `ln 10`.
    fn ln_10() -> Self;
    /// `log₂ 10`.
    fn log2_10() -> Self;
    /// `log₂ e`.
    fn log2_e() -> Self;
    /// `log₁₀ 2`.
    fn log10_2() -> Self;
    /// `log₁₀ e`.
    fn log10_e() -> Self;
    /// `π`.
    fn pi() -> Self;
    /// `√2`.
    fn sqrt_2() -> Self;
    /// `τ = 2π`.
    fn tau() -> Self;

    /// Sine (radians).
    fn sin(self) -> Self;
    /// Cosine (radians).
    fn cos(self) -> Self;
    /// Tangent (radians).
    fn tan(self) -> Self;
    /// Arcsine.
    fn asin(self) -> Self;
    /// Arccosine.
    fn acos(self) -> Self;
    /// Arctangent.
    fn atan(self) -> Self;
    /// Four-quadrant arctangent of `self / other` (`self` is the y-coordinate).
    fn atan2(self, other: Self) -> Self;
    /// Returns `(sin self, cos self)`.
    fn sin_cos(self) -> (Self, Self);

    /// Hyperbolic sine.
    fn sinh(self) -> Self;
    /// Hyperbolic cosine.
    fn cosh(self) -> Self;
    /// Hyperbolic tangent.
    fn tanh(self) -> Self;
    /// Inverse hyperbolic sine.
    fn asinh(self) -> Self;
    /// Inverse hyperbolic cosine.
    fn acosh(self) -> Self;
    /// Inverse hyperbolic tangent.
    fn atanh(self) -> Self;

    /// `e^self`.
    fn exp(self) -> Self;
    /// Natural logarithm.
    fn ln(self) -> Self;
    /// Base-10 logarithm.
    fn log10(self) -> Self;
    /// Square root.
    fn sqrt(self) -> Self;
    /// Absolute value.
    fn abs(self) -> Self;
    /// `√(self² + other²)`.
    fn hypot(self, other: Self) -> Self;
    /// `self` raised to the real power `rhs`.
    fn pow(self, rhs: Self) -> Self;
    /// `self` raised to the integer power `n`.
    fn powi(self, n: i32) -> Self;
}
/// [`Real`] for `f64`: constants come from `std::f64::consts`, functions
/// forward to the inherent `f64` methods.
///
/// Inside this impl, calls like `self.sin()` resolve to the inherent `f64`
/// method (inherent methods take precedence over trait methods in method-call
/// resolution), so the one-line bodies below are forwards, not recursion.
impl Real for f64 {
    fn from_f64(v: f64) -> Self {
        v
    }

    /// NaN and ±∞ become 0. Finite values are truncated toward zero and clamped
    /// to `i32::MIN..=i32::MAX`, which is exactly what a float-to-integer `as`
    /// cast does.
    fn to_i32(&self) -> i32 {
        if self.is_finite() {
            *self as i32
        } else {
            0
        }
    }

    /// `true` for integral values inside the `i32` range (NaN and ±∞ excluded).
    fn is_i32_compatible(&self) -> bool {
        let bounds = (i32::MIN as f64)..=(i32::MAX as f64);
        self.fract() == 0.0 && bounds.contains(self)
    }

    fn fract(self) -> Self { self.fract() }
    fn trunc(self) -> Self { self.trunc() }

    fn e() -> Self { std::f64::consts::E }
    fn frac_1_pi() -> Self { std::f64::consts::FRAC_1_PI }
    fn frac_1_sqrt_2() -> Self { std::f64::consts::FRAC_1_SQRT_2 }
    fn frac_2_pi() -> Self { std::f64::consts::FRAC_2_PI }
    fn frac_2_sqrt_pi() -> Self { std::f64::consts::FRAC_2_SQRT_PI }
    fn frac_pi_2() -> Self { std::f64::consts::FRAC_PI_2 }
    fn frac_pi_3() -> Self { std::f64::consts::FRAC_PI_3 }
    fn frac_pi_4() -> Self { std::f64::consts::FRAC_PI_4 }
    fn frac_pi_6() -> Self { std::f64::consts::FRAC_PI_6 }
    fn frac_pi_8() -> Self { std::f64::consts::FRAC_PI_8 }
    fn ln_2() -> Self { std::f64::consts::LN_2 }
    fn ln_10() -> Self { std::f64::consts::LN_10 }
    fn log2_10() -> Self { std::f64::consts::LOG2_10 }
    fn log2_e() -> Self { std::f64::consts::LOG2_E }
    fn log10_2() -> Self { std::f64::consts::LOG10_2 }
    fn log10_e() -> Self { std::f64::consts::LOG10_E }
    fn pi() -> Self { std::f64::consts::PI }
    fn sqrt_2() -> Self { std::f64::consts::SQRT_2 }
    fn tau() -> Self { std::f64::consts::TAU }

    fn sin(self) -> Self { self.sin() }
    fn cos(self) -> Self { self.cos() }
    fn tan(self) -> Self { self.tan() }
    fn asin(self) -> Self { self.asin() }
    fn acos(self) -> Self { self.acos() }
    fn atan(self) -> Self { self.atan() }
    fn atan2(self, other: Self) -> Self { self.atan2(other) }
    fn sin_cos(self) -> (Self, Self) { self.sin_cos() }

    fn sinh(self) -> Self { self.sinh() }
    fn cosh(self) -> Self { self.cosh() }
    fn tanh(self) -> Self { self.tanh() }
    fn asinh(self) -> Self { self.asinh() }
    fn acosh(self) -> Self { self.acosh() }
    fn atanh(self) -> Self { self.atanh() }

    fn exp(self) -> Self { self.exp() }
    fn ln(self) -> Self { self.ln() }
    fn log10(self) -> Self { self.log10() }
    fn sqrt(self) -> Self { self.sqrt() }
    fn abs(self) -> Self { self.abs() }
    fn hypot(self, other: Self) -> Self { self.hypot(other) }
    /// Real power via the inherent `f64::powf`.
    fn pow(self, rhs: Self) -> Self { self.powf(rhs) }
    fn powi(self, n: i32) -> Self { self.powi(n) }
}
/// Elementary functions on complex numbers.
///
/// Every method takes `self` by value and returns a value of the same type.
pub trait ComplexMath {
    /// Sine.
    fn sin(self) -> Self;
    /// Cosine.
    fn cos(self) -> Self;
    /// Tangent.
    fn tan(self) -> Self;
    /// Arcsine.
    fn asin(self) -> Self;
    /// Arccosine.
    fn acos(self) -> Self;
    /// Arctangent.
    fn atan(self) -> Self;

    /// Hyperbolic sine.
    fn sinh(self) -> Self;
    /// Hyperbolic cosine.
    fn cosh(self) -> Self;
    /// Hyperbolic tangent.
    fn tanh(self) -> Self;
    /// Inverse hyperbolic sine.
    fn asinh(self) -> Self;
    /// Inverse hyperbolic cosine.
    fn acosh(self) -> Self;
    /// Inverse hyperbolic tangent.
    fn atanh(self) -> Self;

    /// Exponential `e^self`.
    fn exp(self) -> Self;
    /// Natural logarithm.
    fn ln(self) -> Self;
    /// Base-10 logarithm.
    fn log10(self) -> Self;
    /// Square root.
    fn sqrt(self) -> Self;

    /// Modulus `|self|`, returned as `Self` rather than as a real scalar (the
    /// `Complex<T>` implementation puts it in the real part with a zero
    /// imaginary part).
    fn abs(self) -> Self;
    /// Complex conjugate.
    fn conj(self) -> Self;
    /// `self` raised to the complex power `rhs`.
    fn powc(self, rhs: Self) -> Self;
    /// `self` raised to the integer power `n`.
    fn powi(self, n: i32) -> Self;
}
impl<T: Real> ComplexMath for Complex<T> {
fn sin(self) -> Self
{
let (a, b) = (self.re, self.im);
let (sin_a, cos_a) = a.sin_cos();
Self {
re: sin_a * b.clone().cosh(),
im: cos_a * b.sinh(),
}
}
fn cos(self) -> Self
{
let (a, b) = (self.re, self.im);
let (sin_a, cos_a) = a.sin_cos();
Self {
re: cos_a * b.clone().cosh(),
im: -(sin_a * b.sinh()),
}
}
fn tan(self) -> Self {
let (a, b) = (self.re, self.im);
let (sin_a, cos_a) = a.sin_cos();
let (sinh_b, cosh_b) = (b.clone().sinh(), b.cosh());
let sin = Self {
re: sin_a.clone() * cosh_b.clone(),
im: cos_a.clone() * sinh_b.clone(),
};
let cos = Self {
re: cos_a * cosh_b,
im: -sin_a * sinh_b,
};
sin / cos
}
fn asin(self) -> Self {
let z = self;
let z2 = z.clone() * z.clone(); let iz = Complex::new(-z.im, z.re);
let i = Complex::new(T::zero(), T::one());
let one = Complex::new(T::one(), T::zero());
-i * (iz + (one - z2).sqrt()).ln()
}
fn acos(self) -> Self {
let z = self;
let z2 = z.clone() * z.clone(); let i = Complex::new(T::zero(), T::one());
let one = Complex::new(T::one(), T::zero());
-(i.clone()) * (z + i * (one - z2).sqrt()).ln()
}
fn atan(self) -> Self {
let i = Complex::new(T::zero(), T::one());
let one = Complex::new(T::one(), T::zero());
let iz = i * self;
let num = one.clone() - iz.clone();
let den = one + iz;
let half_i = Complex::new(T::zero(), T::one() * T::from_f64(0.5));
half_i * (num / den).ln()
}
fn sinh(self) -> Self {
let (x, y) = (self.re, self.im);
let (sin_y, cos_y) = y.sin_cos();
Self {
re: x.clone().sinh() * cos_y,
im: x.cosh() * sin_y,
}
}
fn cosh(self) -> Self {
let (x, y) = (self.re, self.im);
let (sin_y, cos_y) = y.sin_cos();
Self {
re: x.clone().cosh() * cos_y,
im: x.sinh() * sin_y,
}
}
fn tanh(self) -> Self {
let e2 = (self.clone() + self).exp();
let one = T::one();
(e2.clone() - one.clone()) / (e2 + one)
}
fn asinh(self) -> Self {
let one = Complex::new(T::one(), T::zero());
(self.clone() + (self.clone() * self + one).sqrt()).ln()
}
fn acosh(self) -> Self {
let one = Complex::new(T::one(), T::zero());
(self.clone() + (self.clone() - one.clone()).sqrt() * (self + one).sqrt()).ln()
}
fn atanh(self) -> Self {
let one = Complex::new(T::one(), T::zero());
((one.clone() + self.clone()) / (one - self)).ln() * T::from_f64(0.5)
}
fn exp(self) -> Self
{
let re = self.re;
let im = self.im;
let exp = re.exp();
Self {
re: exp.clone() * im.clone().cos(),
im: exp * im.sin(),
}
}
fn ln(self) -> Self
{
let r = self.re.clone().hypot(self.im.clone());
let theta = self.im.atan2(self.re);
Self {
re: r.ln(),
im: theta,
}
}
fn log10(self) -> Self
{
self.ln() * T::log10_e()
}
fn sqrt(self) -> Self {
let r = self.re.clone().hypot(self.im.clone());
let half = T::from_f64(0.5);
let re = ((r.clone() + self.re.clone()) * half.clone()).sqrt();
let im = ((r - self.re) * half).sqrt();
let im = if self.im >= T::zero() { im } else { -im };
Complex::new(re, im)
}
fn abs(self) -> Self { Complex::new(self.re.hypot(self.im), T::zero()) }
fn conj(self) -> Self { Complex::new(self.re, -self.im) }
fn powc(self, rhs: Self) -> Self
{
(rhs * self.ln()).exp()
}
fn powi(self, n: i32) -> Self {
if n == 0 {
return Complex::new(T::one(), T::zero());
}
if n < 0 {
return Complex::new(T::one(), T::zero()) / self.powi(-n);
}
let mut result = Complex::new(T::one(), T::zero());
let mut base = self;
let mut exp = n;
while exp > 1 {
if exp & 1 == 1 {
result = result * base.clone();
}
base = base.clone() * base;
exp >>=1;
}
result * base
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use num_complex::Complex;

    // `num_complex` defines inherent methods on `Complex<T>` for float `T`
    // with the same names as `ComplexMath`'s (`sin`, `exp`, `sqrt`, ...).
    // Method-call syntax prefers inherent methods over trait methods, so
    // `z.sin()` would test `num_complex`, not this module. Every call under
    // test therefore goes through `ComplexMath::...` explicitly.

    const EPS: f64 = 1e-10;

    /// Shorthand for a `Complex<f64>`.
    fn c(re: f64, im: f64) -> Complex<f64> {
        Complex::new(re, im)
    }

    /// Asserts that `a` and `b` agree component-wise to within `EPS`.
    #[track_caller]
    fn approx_eq(a: Complex<f64>, b: Complex<f64>) {
        assert!(
            (a.re - b.re).abs() < EPS,
            "re mismatch: {} vs {}",
            a.re,
            b.re
        );
        assert!(
            (a.im - b.im).abs() < EPS,
            "im mismatch: {} vs {}",
            a.im,
            b.im
        );
    }

    #[test]
    fn test_sin_cos_identity() {
        let z = c(1.2, -0.7);
        let s = ComplexMath::sin(z);
        let co = ComplexMath::cos(z);
        approx_eq(s * s + co * co, c(1.0, 0.0));
    }

    #[test]
    fn test_exp_ln_identity() {
        let z = c(0.5, -1.3);
        approx_eq(ComplexMath::exp(ComplexMath::ln(z)), z);
    }

    #[test]
    fn test_exp_i_pi() {
        let z = c(0.0, std::f64::consts::PI);
        approx_eq(ComplexMath::exp(z), c(-1.0, 0.0));
    }

    #[test]
    fn test_sin_i() {
        approx_eq(ComplexMath::sin(c(0.0, 1.0)), c(0.0, 1.0_f64.sinh()));
    }

    #[test]
    fn test_real_consistency() {
        // The annotation is required: calling `x.sin()` on an unsuffixed float
        // binding is a compile error (E0689, ambiguous numeric type).
        let x: f64 = 0.7;
        let z = c(x, 0.0);
        approx_eq(ComplexMath::sin(z), c(x.sin(), 0.0));
        approx_eq(ComplexMath::cos(z), c(x.cos(), 0.0));
        approx_eq(ComplexMath::tan(z), c(x.tan(), 0.0));
        approx_eq(ComplexMath::exp(z), c(x.exp(), 0.0));
        approx_eq(ComplexMath::ln(z), c(x.ln(), 0.0));
        approx_eq(ComplexMath::sqrt(z), c(x.sqrt(), 0.0));
        approx_eq(ComplexMath::sinh(z), c(x.sinh(), 0.0));
        approx_eq(ComplexMath::cosh(z), c(x.cosh(), 0.0));
        approx_eq(ComplexMath::tanh(z), c(x.tanh(), 0.0));
    }

    #[test]
    fn test_sqrt() {
        let z = c(3.0, 4.0);
        let root = ComplexMath::sqrt(z);
        approx_eq(root, c(2.0, 1.0));
        approx_eq(root * root, z);
    }

    #[test]
    fn test_powc() {
        let z = c(1.2, 0.7);
        let w = c(-0.3, 0.5);
        // Independent reference: z^w = exp(w·(ln|z| + i·arg z)), computed with
        // scalar f64 arithmetic only.
        let (ln_r, arg) = (z.re.hypot(z.im).ln(), z.im.atan2(z.re));
        let a = w.re * ln_r - w.im * arg;
        let b = w.re * arg + w.im * ln_r;
        approx_eq(
            ComplexMath::powc(z, w),
            c(a.exp() * b.cos(), a.exp() * b.sin()),
        );
    }

    #[test]
    fn test_powi() {
        let z = c(1.1, -0.4);
        approx_eq(ComplexMath::powi(z, 5), z * z * z * z * z);
        approx_eq(ComplexMath::powi(z, -3), c(1.0, 0.0) / (z * z * z));
        approx_eq(ComplexMath::powi(z, 0), c(1.0, 0.0));
    }

    #[test]
    fn test_tan_identity() {
        let z = c(0.8, -0.3);
        approx_eq(
            ComplexMath::tan(z),
            ComplexMath::sin(z) / ComplexMath::cos(z),
        );
    }

    #[test]
    fn test_hyperbolic_identities() {
        let z = c(0.9, -1.4);
        let (sh, ch) = (ComplexMath::sinh(z), ComplexMath::cosh(z));
        approx_eq(ch * ch - sh * sh, c(1.0, 0.0));
        approx_eq(ComplexMath::tanh(z), sh / ch);
    }

    #[test]
    fn test_log10() {
        let z = c(1.3, 0.4);
        approx_eq(
            ComplexMath::log10(z),
            ComplexMath::ln(z) / std::f64::consts::LN_10,
        );
    }

    #[test]
    fn test_abs_conj() {
        approx_eq(ComplexMath::abs(c(3.0, 4.0)), c(5.0, 0.0));
        approx_eq(ComplexMath::conj(c(1.5, -2.5)), c(1.5, 2.5));
    }

    #[test]
    fn test_inverse_round_trips() {
        // f(f⁻¹(z)) = z holds for any choice of branch, so this checks the
        // inverse functions without depending on branch-cut conventions.
        for &(re, im) in &[(0.3, 0.4), (-0.5, 0.2), (0.7, -0.6), (-0.2, -0.9)] {
            let z = c(re, im);
            approx_eq(ComplexMath::sin(ComplexMath::asin(z)), z);
            approx_eq(ComplexMath::cos(ComplexMath::acos(z)), z);
            approx_eq(ComplexMath::tan(ComplexMath::atan(z)), z);
            approx_eq(ComplexMath::sinh(ComplexMath::asinh(z)), z);
            approx_eq(ComplexMath::cosh(ComplexMath::acosh(z)), z);
            approx_eq(ComplexMath::tanh(ComplexMath::atanh(z)), z);
        }
    }
}