use std::{f64::consts::E, fmt, ops::*};
use bytemuck::{Pod, Zeroable};
use serde::*;
/// A complex number stored as a pair of `f64` components.
///
/// The layout is `#[repr(C)]` with `re` first and `im` second. The serde
/// attributes (de)serialize the value through an `(f64, f64)` tuple, using the
/// `From` conversions between `Complex` and `(f64, f64)`.
///
/// The derived `PartialOrd` compares fields in declaration order: `re` first,
/// then `im`.
#[derive(Debug, Clone, Copy, PartialOrd, Default, Serialize, Deserialize, Pod, Zeroable)]
#[serde(from = "(f64, f64)", into = "(f64, f64)")]
#[repr(C)]
pub struct Complex {
/// The real part.
pub re: f64,
/// The imaginary part.
pub im: f64,
}
impl From<(f64, f64)> for Complex {
fn from((re, im): (f64, f64)) -> Self {
Self { re, im }
}
}
impl From<Complex> for (f64, f64) {
fn from(c: Complex) -> Self {
(c.re, c.im)
}
}
impl PartialEq for Complex {
    /// Two complex numbers are equal when both their real and imaginary
    /// parts compare equal as `f64`s.
    fn eq(&self, other: &Self) -> bool {
        (self.re, self.im) == (other.re, other.im)
    }
}
// Marker impl declaring `==` on `Complex` to be a full equivalence relation.
// NOTE(review): `PartialEq::eq` compares the `f64` fields with `==`, so a value
// with a NaN component is not equal to itself — confirm no caller relies on
// reflexivity for NaN values.
impl Eq for Complex {}
impl Complex {
    /// The additive identity, `0 + 0i`.
    pub const ZERO: Self = Self { re: 0.0, im: 0.0 };
    /// The multiplicative identity, `1 + 0i`.
    pub const ONE: Self = Self { re: 1.0, im: 0.0 };
    /// The imaginary unit, `0 + 1i`.
    pub const I: Self = Self { re: 0.0, im: 1.0 };

    /// Creates a complex number from its real and imaginary parts.
    pub fn new(re: f64, im: f64) -> Self {
        Self { re, im }
    }

    /// Component-wise minimum: the smaller real part paired with the smaller
    /// imaginary part.
    pub fn min(self, rhs: impl Into<Self>) -> Self {
        let rhs = rhs.into();
        Self {
            re: self.re.min(rhs.re),
            im: self.im.min(rhs.im),
        }
    }

    /// Component-wise maximum: the larger real part paired with the larger
    /// imaginary part.
    pub fn max(self, rhs: impl Into<Self>) -> Self {
        let rhs = rhs.into();
        Self {
            re: self.re.max(rhs.re),
            im: self.im.max(rhs.im),
        }
    }

    /// Rounds both components down to the nearest integer.
    pub fn floor(self) -> Self {
        Self {
            re: self.re.floor(),
            im: self.im.floor(),
        }
    }

    /// Rounds both components up to the nearest integer.
    pub fn ceil(self) -> Self {
        Self {
            re: self.re.ceil(),
            im: self.im.ceil(),
        }
    }

    /// Rounds both components to the nearest integer (`f64::round` semantics).
    pub fn round(self) -> Self {
        Self {
            re: self.re.round(),
            im: self.im.round(),
        }
    }

    /// Returns the modulus `|z|`.
    ///
    /// Computed with [`f64::hypot`] rather than `sqrt(re² + im²)` so that
    /// components whose squares would overflow (or underflow) still yield the
    /// correct finite magnitude.
    pub fn abs(self) -> f64 {
        self.re.hypot(self.im)
    }

    /// Complex two-argument arctangent with `self` as `y`, computed as
    /// `-i * ln((x + i*y) / sqrt(x² + y²))`.
    pub fn atan2(self, x: impl Into<Self>) -> Complex {
        let y = self;
        let x = x.into();
        -Complex::I * ((x + Complex::I * y) / (y * y + x * x).sqrt()).ln()
    }

    /// Scales the number to unit modulus. Zero stays zero instead of
    /// producing NaN from a division by zero.
    pub fn normalize(self) -> Self {
        let len = self.abs();
        if len == 0.0 {
            Self::ZERO
        } else {
            Self::new(self.re / len, self.im / len)
        }
    }

    /// Returns the argument (phase angle) in radians, via `atan2(im, re)`.
    pub fn arg(self) -> f64 {
        self.im.atan2(self.re)
    }

    /// Returns the polar form `(modulus, argument)`.
    pub fn to_polar(self) -> (f64, f64) {
        (self.abs(), self.arg())
    }

    /// Builds a complex number from a modulus `r` and an angle `theta` in
    /// radians.
    pub fn from_polar(r: f64, theta: f64) -> Self {
        Self::new(r * theta.cos(), r * theta.sin())
    }

    /// Raises `self` to a complex power.
    ///
    /// A power with a zero imaginary part is delegated to [`Self::powf`];
    /// otherwise the result is `exp(ln(self) * power)`.
    pub fn powc(self, power: impl Into<Self>) -> Self {
        let power = power.into();
        if power.im == 0.0 {
            return self.powf(power.re);
        }
        (self.ln() * power).exp()
    }

    /// Raises `self` to a real power.
    ///
    /// Any value to the power `0` is `1`. A real base with an integral power
    /// uses `f64::powf` directly so the result stays exactly real; everything
    /// else goes through the polar form.
    pub fn powf(self, power: f64) -> Self {
        if power == 0.0 {
            return Self::ONE;
        }
        if power.fract() == 0.0 && self.im == 0.0 {
            return Self::new(self.re.powf(power), 0.0);
        }
        let (r, theta) = self.to_polar();
        Self::from_polar(r.powf(power), theta * power)
    }

    /// Returns `e` raised to `self`.
    pub fn exp(self) -> Self {
        let magnitude = E.powf(self.re);
        if self.im == 0.0 {
            // Purely real input: skip the polar form. When the magnitude
            // overflows to infinity, `inf * sin(0)` would otherwise turn the
            // imaginary part into NaN.
            return Self::new(magnitude, self.im);
        }
        Self::from_polar(magnitude, self.im)
    }

    /// Returns the natural logarithm `ln|z| + i*arg(z)`, with the argument as
    /// returned by [`Self::arg`].
    pub fn ln(self) -> Self {
        let (r, theta) = self.to_polar();
        Self::new(r.ln(), theta)
    }

    /// Returns the logarithm of `self` in the given base, `ln(self) / ln(base)`.
    pub fn log(self, base: impl Into<Self>) -> Self {
        let base = base.into();
        self.ln() / base.ln()
    }

    /// Returns a square root of `self`.
    ///
    /// Values with a zero imaginary part are handled directly so that real
    /// inputs give exactly real (or exactly imaginary) results; other values
    /// use the polar form with half the argument.
    pub fn sqrt(self) -> Self {
        if self.im == 0.0 {
            return if self.re >= 0.0 {
                Self::new(self.re.sqrt(), 0.0)
            } else {
                Self::new(0.0, self.re.abs().sqrt())
            };
        }
        let (r, theta) = self.to_polar();
        Self::from_polar(r.sqrt(), theta / 2.0)
    }

    /// Complex sine: `sin(re)cosh(im) + i*cos(re)sinh(im)`.
    pub fn sin(self) -> Self {
        Self::new(
            self.re.sin() * self.im.cosh(),
            self.re.cos() * self.im.sinh(),
        )
    }

    /// Complex cosine: `cos(re)cosh(im) - i*sin(re)sinh(im)`.
    pub fn cos(self) -> Self {
        Self::new(
            self.re.cos() * self.im.cosh(),
            -self.re.sin() * self.im.sinh(),
        )
    }

    /// Complex arcsine, computed as `-i * ln(sqrt(1 - z²) + i*z)`.
    pub fn asin(self) -> Self {
        -Self::I * ((Self::ONE - self * self).sqrt() + Self::I * self).ln()
    }

    /// Complex arccosine, computed as `-i * ln(i*sqrt(1 - z²) + z)`.
    pub fn acos(self) -> Self {
        -Self::I * (Self::I * (Self::ONE - self * self).sqrt() + self).ln()
    }

    /// Returns `true` if either component is NaN.
    pub fn is_nan(&self) -> bool {
        self.re.is_nan() || self.im.is_nan()
    }

    /// Returns the real part if the imaginary part's magnitude is below
    /// `f64::EPSILON`, and `None` otherwise.
    pub fn into_real(self) -> Option<f64> {
        if self.im.abs() < f64::EPSILON {
            Some(self.re)
        } else {
            None
        }
    }

    /// Complex multiplication in which each component product goes through
    /// the free function `safe_mul`, so `infinity * 0` contributes `0`
    /// instead of NaN.
    pub fn safe_mul(self, rhs: impl Into<Self>) -> Self {
        let rhs = rhs.into();
        Self {
            re: safe_mul(self.re, rhs.re) - safe_mul(self.im, rhs.im),
            im: safe_mul(self.re, rhs.im) + safe_mul(self.im, rhs.re),
        }
    }

    /// Returns the multiplicative inverse `1 / self`.
    pub fn recip(self) -> Self {
        Self::ONE / self
    }
}
/// Multiplies two floats, treating `infinity * 0` (in either order) as `0.0`
/// instead of the NaN that IEEE multiplication would give.
fn safe_mul(a: f64, b: f64) -> f64 {
    let inf_times_zero = (a.is_infinite() && b == 0.0) || (a == 0.0 && b.is_infinite());
    if !inf_times_zero {
        return a * b;
    }
    0.0
}
impl From<f64> for Complex {
fn from(re: f64) -> Self {
Self { re, im: 0.0 }
}
}
impl From<u8> for Complex {
    /// Widens a byte to `f64` and embeds it as a real-valued complex number.
    fn from(value: u8) -> Self {
        let real = f64::from(value);
        Self::from(real)
    }
}
impl fmt::Display for Complex {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
if self.im == 0.0 {
self.re.fmt(f)
} else {
write!(f, "{}r{}i", self.re, self.im)
}
}
}
impl Add for Complex {
    type Output = Self;

    /// Adds two complex numbers component-wise.
    fn add(self, rhs: Self) -> Self::Output {
        let re = self.re + rhs.re;
        let im = self.im + rhs.im;
        Self { re, im }
    }
}
impl Add<f64> for Complex {
    type Output = Self;

    /// Adds a real number to the real part; the imaginary part is unchanged.
    fn add(self, rhs: f64) -> Self::Output {
        let Self { re, im } = self;
        Self { re: re + rhs, im }
    }
}
impl Add<Complex> for f64 {
    type Output = Complex;

    /// Adds a complex number to a real one; only the real part changes.
    fn add(self, rhs: Complex) -> Self::Output {
        let Complex { re, im } = rhs;
        Complex { re: self + re, im }
    }
}
impl Sub for Complex {
    type Output = Self;

    /// Subtracts two complex numbers component-wise.
    fn sub(self, rhs: Self) -> Self::Output {
        let re = self.re - rhs.re;
        let im = self.im - rhs.im;
        Self { re, im }
    }
}
impl Sub<f64> for Complex {
    type Output = Self;

    /// Subtracts a real number from the real part; the imaginary part is
    /// unchanged.
    fn sub(self, rhs: f64) -> Self::Output {
        let Self { re, im } = self;
        Self { re: re - rhs, im }
    }
}
impl Sub<Complex> for f64 {
    type Output = Complex;

    /// Subtracts a complex number from a real one: the real parts are
    /// subtracted and the imaginary part is negated.
    fn sub(self, rhs: Complex) -> Self::Output {
        let Complex { re, im } = rhs;
        Complex {
            re: self - re,
            im: -im,
        }
    }
}
impl Mul for Complex {
    type Output = Self;

    /// Multiplies two complex numbers:
    /// `(a + bi)(c + di) = (ac - bd) + (ad + bc)i`.
    fn mul(self, rhs: Self) -> Self::Output {
        let (a, b) = (self.re, self.im);
        let (c, d) = (rhs.re, rhs.im);
        Self {
            re: a * c - b * d,
            im: a * d + b * c,
        }
    }
}
impl Mul<f64> for Complex {
    type Output = Self;

    /// Scales both components by a real factor.
    fn mul(self, rhs: f64) -> Self::Output {
        let Self { re, im } = self;
        Self {
            re: re * rhs,
            im: im * rhs,
        }
    }
}
impl Mul<Complex> for f64 {
    type Output = Complex;

    /// Scales both components of a complex number by this real factor.
    fn mul(self, rhs: Complex) -> Self::Output {
        let Complex { re, im } = rhs;
        Complex {
            re: self * re,
            im: self * im,
        }
    }
}
impl Div for Complex {
    type Output = Self;

    /// Divides two complex numbers by multiplying with the conjugate of the
    /// divisor and dividing by its squared modulus.
    fn div(self, rhs: Self) -> Self::Output {
        let (a, b) = (self.re, self.im);
        let (c, d) = (rhs.re, rhs.im);
        let norm_sq = c * c + d * d;
        let re = (a * c + b * d) / norm_sq;
        let im = (b * c - a * d) / norm_sq;
        Self { re, im }
    }
}
impl Div<f64> for Complex {
    type Output = Self;

    /// Divides both components by a real number.
    fn div(self, rhs: f64) -> Self::Output {
        let Self { re, im } = self;
        Self {
            re: re / rhs,
            im: im / rhs,
        }
    }
}
impl Div<Complex> for f64 {
    type Output = Complex;

    /// Divides a real number by a complex one: the real number times the
    /// divisor's conjugate, over the divisor's squared modulus.
    fn div(self, rhs: Complex) -> Self::Output {
        let (c, d) = (rhs.re, rhs.im);
        let norm_sq = c * c + d * d;
        let scaled_re = self * c;
        let scaled_im = -self * d;
        Complex {
            re: scaled_re / norm_sq,
            im: scaled_im / norm_sq,
        }
    }
}
impl Rem for Complex {
    type Output = Self;

    /// Floored complex remainder: subtracts from `self` the divisor times the
    /// component-wise floor of the quotient `self / rhs`.
    fn rem(self, rhs: Self) -> Self::Output {
        let quotient = (self / rhs).floor();
        self - quotient * rhs
    }
}
impl Rem<f64> for Complex {
    type Output = Self;

    /// Takes the Euclidean remainder (`f64::rem_euclid`) of each component by
    /// a real modulus.
    fn rem(self, rhs: f64) -> Self::Output {
        let wrap = |component: f64| component.rem_euclid(rhs);
        Self {
            re: wrap(self.re),
            im: wrap(self.im),
        }
    }
}
impl Rem<Complex> for f64 {
    type Output = Complex;

    /// Remainder of a real number by a complex modulus.
    ///
    /// The real number is promoted to a `Complex` with a zero imaginary part
    /// and the `Complex % Complex` remainder is used, matching how the other
    /// `f64 <op> Complex` operators behave. Taking `self % rhs.re` and
    /// `self % rhs.im` independently would give a NaN imaginary part for any
    /// real-valued modulus, because `x % 0.0` is NaN.
    fn rem(self, rhs: Complex) -> Self::Output {
        let lhs = Complex { re: self, im: 0.0 };
        lhs % rhs
    }
}
impl Neg for Complex {
    type Output = Self;

    /// Negates both the real and the imaginary part.
    fn neg(self) -> Self::Output {
        let Self { re, im } = self;
        Self { re: -re, im: -im }
    }
}
/// `+=` for every right-hand side type that `Complex` can be added to.
impl<Rhs> AddAssign<Rhs> for Complex
where
    Complex: Add<Rhs, Output = Complex>,
{
    fn add_assign(&mut self, rhs: Rhs) {
        let sum = *self + rhs;
        *self = sum;
    }
}
/// `-=` for every right-hand side type that can be subtracted from `Complex`.
impl<Rhs> SubAssign<Rhs> for Complex
where
    Complex: Sub<Rhs, Output = Complex>,
{
    fn sub_assign(&mut self, rhs: Rhs) {
        let difference = *self - rhs;
        *self = difference;
    }
}
/// `*=` for every right-hand side type that `Complex` can be multiplied by.
impl<Rhs> MulAssign<Rhs> for Complex
where
    Complex: Mul<Rhs, Output = Complex>,
{
    fn mul_assign(&mut self, rhs: Rhs) {
        let product = *self * rhs;
        *self = product;
    }
}
/// `/=` for every right-hand side type that `Complex` can be divided by.
impl<Rhs> DivAssign<Rhs> for Complex
where
    Complex: Div<Rhs, Output = Complex>,
{
    fn div_assign(&mut self, rhs: Rhs) {
        let quotient = *self / rhs;
        *self = quotient;
    }
}