use serde::{Deserialize, Serialize};
use std::ops::{Add, Div, Mul, Sub};
/// A complex number `re + im·i` whose components are both `f64`.
///
/// The type is plain data: it is `Copy`, compares component-wise via `PartialEq`,
/// and can be (de)serialized with serde.
#[derive(Clone, Copy, Debug, PartialEq)]
#[derive(Serialize, Deserialize)]
pub struct Complex {
    /// Real component.
    pub re: f64,
    /// Imaginary component.
    pub im: f64,
}
impl Complex {
    /// Creates the complex number `re + im·i`.
    #[inline(always)]
    pub fn new(re: f64, im: f64) -> Self {
        Self { re, im }
    }

    /// Raises `self` to the non-negative integer power `p`.
    ///
    /// Uses exponentiation by squaring, so the cost is O(log p) complex
    /// multiplications. A left-to-right product would need `p` of them, which is
    /// billions of iterations for large `p`. `pow(0)` returns `1 + 0i` for every
    /// input, because no multiplication is performed.
    ///
    /// Results may differ in the last bits from a naive left-to-right product,
    /// because the multiplications are grouped differently.
    #[inline]
    pub fn pow(self, p: u32) -> Self {
        let mut result = Self::new(1.0, 0.0);
        let mut base = self;
        let mut exp = p;
        while exp > 0 {
            if exp & 1 == 1 {
                result = result * base;
            }
            exp >>= 1;
            // Skip the final squaring: its value would never be used.
            if exp > 0 {
                base = base * base;
            }
        }
        result
    }

    /// Squared modulus `re² + im²`.
    ///
    /// Cheaper than [`Complex::norm`] (no square root). It overflows to `inf` once
    /// the components exceed roughly `1e154` in magnitude.
    #[inline(always)]
    pub fn norm_sq(self) -> f64 {
        self.re * self.re + self.im * self.im
    }

    /// Modulus `|z| = sqrt(re² + im²)`.
    ///
    /// Computed with `f64::hypot`, which avoids squaring the components. The
    /// expression `norm_sq().sqrt()` returns `inf` for `|z|` above about `1e154`.
    /// It also loses precision, or returns `0`, for very small `|z|`.
    #[inline(always)]
    pub fn norm(self) -> f64 {
        self.re.hypot(self.im)
    }

    /// Argument (phase angle) in radians, computed as `atan2(im, re)`; lies in `[-π, π]`.
    #[inline(always)]
    pub fn arg(self) -> f64 {
        self.im.atan2(self.re)
    }

    /// Complex conjugate `re - im·i`.
    #[inline(always)]
    pub fn conj(self) -> Self {
        Self::new(self.re, -self.im)
    }
}
impl Add for Complex {
type Output = Self;
#[inline(always)]
fn add(self, other: Self) -> Self::Output {
Self::new(self.re + other.re, self.im + other.im)
}
}
impl Sub for Complex {
type Output = Self;
#[inline(always)]
fn sub(self, other: Self) -> Self::Output {
Self::new(self.re - other.re, self.im - other.im)
}
}
impl Mul for Complex {
type Output = Self;
#[inline(always)]
fn mul(self, other: Self) -> Self::Output {
Self::new(
self.re * other.re - self.im * other.im,
self.re * other.im + self.im * other.re,
)
}
}
impl Div for Complex {
    type Output = Self;

    /// Complex quotient `(a + bi) / (c + di)`.
    ///
    /// Uses Smith's algorithm. The divisor is rescaled by the ratio of its smaller
    /// to its larger component, so `c² + d²` is never formed. The textbook formula
    /// `(a + bi)(c - di) / (c² + d²)` overflows to `inf` when `|c + di|` exceeds
    /// about `1e154`, turning e.g. `z / z` into NaN. It also underflows to exactly
    /// zero for very small nonzero divisors, which wrongly triggered the
    /// zero-divisor branch.
    ///
    /// Dividing by an exact zero (`0 + 0i`, either sign of zero) returns `0 + 0i`
    /// instead of infinities/NaN.
    #[inline]
    fn div(self, other: Self) -> Self::Output {
        if other.re == 0.0 && other.im == 0.0 {
            return Self::new(0.0, 0.0);
        }
        let (a, b) = (self.re, self.im);
        let (c, d) = (other.re, other.im);
        if c.abs() >= d.abs() {
            // |r| <= 1, so `den` stays on the scale of `c` instead of `c²`.
            let r = d / c;
            let den = c + d * r;
            Self::new((a + b * r) / den, (b - a * r) / den)
        } else {
            let r = c / d;
            let den = d + c * r;
            Self::new((a * r + b) / den, (b * r - a) / den)
        }
    }
}
impl Mul<f64> for Complex {
type Output = Self;
#[inline(always)]
fn mul(self, scalar: f64) -> Self {
Self::new(self.re * scalar, self.im * scalar)
}
}
impl Mul<Complex> for f64 {
type Output = Complex;
#[inline(always)]
fn mul(self, other: Complex) -> Complex {
Complex::new(self * other.re, self * other.im)
}
}
impl Div<f64> for Complex {
type Output = Self;
#[inline(always)]
fn div(self, scalar: f64) -> Self {
if scalar == 0.0 {
Self::new(0.0, 0.0)
} else {
Self::new(self.re / scalar, self.im / scalar)
}
}
}
impl Div<Complex> for f64 {
    type Output = Complex;

    /// Divides a real scalar by a complex number: `s / (c + di)`.
    ///
    /// Uses Smith's algorithm, with the numerator's imaginary part equal to zero,
    /// so `c² + d²` is never formed. The textbook formula
    /// `s(c - di) / (c² + d²)` overflows to `inf` when `|c + di|` exceeds about
    /// `1e154`, producing NaN. It also underflows to exactly zero for very small
    /// nonzero divisors, which wrongly triggered the zero-divisor branch.
    ///
    /// Dividing by an exact zero (`0 + 0i`, either sign of zero) returns `0 + 0i`
    /// instead of infinities/NaN.
    #[inline]
    fn div(self, other: Complex) -> Complex {
        if other.re == 0.0 && other.im == 0.0 {
            return Complex::new(0.0, 0.0);
        }
        let (c, d) = (other.re, other.im);
        if c.abs() >= d.abs() {
            // |r| <= 1, so `den` stays on the scale of `c` instead of `c²`.
            let r = d / c;
            let den = c + d * r;
            Complex::new(self / den, -self * r / den)
        } else {
            let r = c / d;
            let den = d + c * r;
            Complex::new(self * r / den, -self / den)
        }
    }
}
#[cfg(test)]
mod tests_complex {
    use super::*;
    use std::f64::consts::{FRAC_PI_2, PI};

    /// Absolute tolerance for results whose last bits depend on rounding order.
    const EPS: f64 = 1e-9;

    /// Returns `true` when both components of `a` and `b` differ by less than `EPS`.
    fn approx_eq(a: Complex, b: Complex) -> bool {
        (a.re - b.re).abs() < EPS && (a.im - b.im).abs() < EPS
    }

    #[test]
    fn test_addition() {
        let c1 = Complex::new(1.0, 2.0);
        let c2 = Complex::new(3.0, 4.0);
        assert_eq!(c1 + c2, Complex::new(4.0, 6.0));
    }

    #[test]
    fn test_subtraction() {
        let c1 = Complex::new(5.0, 7.0);
        let c2 = Complex::new(2.0, 3.0);
        assert_eq!(c1 - c2, Complex::new(3.0, 4.0));
    }

    #[test]
    fn test_multiplication() {
        let c1 = Complex::new(1.0, 2.0);
        let c2 = Complex::new(3.0, 4.0);
        assert_eq!(c1 * c2, Complex::new(-5.0, 10.0));
    }

    #[test]
    fn test_division() {
        let result = Complex::new(1.0, 2.0) / Complex::new(3.0, 4.0);
        assert!(approx_eq(result, Complex::new(0.44, 0.08)));
    }

    /// Division by zero is defined to return `0 + 0i` rather than NaN/infinity.
    #[test]
    fn test_division_by_zero() {
        let result = Complex::new(1.0, 2.0) / Complex::new(0.0, 0.0);
        assert_eq!(result, Complex::new(0.0, 0.0));
    }

    #[test]
    fn test_division_round_trip() {
        let a = Complex::new(1.5, -2.5);
        let b = Complex::new(-0.75, 4.0);
        assert!(approx_eq((a / b) * b, a));
    }

    #[test]
    fn test_pow() {
        let c = Complex::new(1.0, 1.0);
        assert_eq!(c.pow(3), Complex::new(-2.0, 2.0));
    }

    #[test]
    fn test_pow_edge_exponents() {
        let c = Complex::new(2.0, 3.0);
        assert_eq!(c.pow(0), Complex::new(1.0, 0.0));
        assert_eq!(c.pow(1), c);
        assert_eq!(c.pow(2), Complex::new(-5.0, 12.0));
    }

    #[test]
    fn test_norm_sq() {
        assert_eq!(Complex::new(3.0, -4.0).norm_sq(), 25.0);
    }

    #[test]
    fn test_norm() {
        assert!((Complex::new(3.0, -4.0).norm() - 5.0).abs() < EPS);
    }

    #[test]
    fn test_arg() {
        assert!((Complex::new(0.0, 1.0).arg() - FRAC_PI_2).abs() < EPS);
        assert!((Complex::new(-1.0, 0.0).arg() - PI).abs() < EPS);
        assert_eq!(Complex::new(1.0, 0.0).arg(), 0.0);
    }

    #[test]
    fn test_conj() {
        let c = Complex::new(1.0, 2.0);
        assert_eq!(c.conj(), Complex::new(1.0, -2.0));
        // z * conj(z) is the real number |z|^2.
        assert_eq!(c * c.conj(), Complex::new(c.norm_sq(), 0.0));
    }

    #[test]
    fn test_scalar_multiplication_complex_f64() {
        let c = Complex::new(1.5, -2.0);
        assert_eq!(c * 2.0, Complex::new(3.0, -4.0));
    }

    #[test]
    fn test_scalar_multiplication_f64_complex() {
        let c = Complex::new(1.5, -2.0);
        assert_eq!(2.0 * c, Complex::new(3.0, -4.0));
    }

    #[test]
    fn test_scalar_division_complex_f64() {
        let c = Complex::new(3.0, -4.0);
        assert_eq!(c / 2.0, Complex::new(1.5, -2.0));
    }

    #[test]
    fn test_scalar_division_complex_f64_by_zero() {
        let c = Complex::new(3.0, -4.0);
        assert_eq!(c / 0.0, Complex::new(0.0, 0.0));
    }

    #[test]
    fn test_scalar_division_f64_complex() {
        let c = Complex::new(3.0, 4.0);
        assert_eq!(5.0 / c, Complex::new(0.6, -0.8));
    }

    #[test]
    fn test_scalar_division_f64_complex_round_trip() {
        let c = Complex::new(3.0, 4.0);
        assert!(approx_eq((5.0 / c) * c, Complex::new(5.0, 0.0)));
    }

    #[test]
    fn test_scalar_division_f64_complex_by_zero() {
        let c = Complex::new(0.0, 0.0);
        assert_eq!(5.0 / c, Complex::new(0.0, 0.0));
    }
}