use std::ops::{Add, Div, Mul, Neg, Sub};
use crate::{Abs, Cos, Inv, One, Sin, Zero};
/// Single-field newtype around a scalar `T`.
///
/// Every arithmetic operator below simply applies the same operator to the
/// wrapped value `x` and re-wraps the result. Presumably the wrapper exists so
/// that crate-local or foreign traits (e.g. `complex_division::Number`) can be
/// implemented for a generic scalar — TODO confirm.
#[derive(Clone, Copy, Debug, PartialEq, PartialOrd)]
pub(crate) struct Wrapper<T> {
    /// The wrapped scalar value.
    pub(crate) x: T,
}

/// `-w` negates the wrapped value.
impl<T> Neg for Wrapper<T>
where
    T: Neg<Output = T>,
{
    type Output = Self;

    fn neg(self) -> Self {
        let Wrapper { x: value } = self;
        Wrapper { x: -value }
    }
}

/// `a + b` adds the wrapped values.
impl<T> Add for Wrapper<T>
where
    T: Add<Output = T>,
{
    type Output = Self;

    fn add(self, rhs: Self) -> Self {
        let (Wrapper { x: lhs }, Wrapper { x: rhs }) = (self, rhs);
        Wrapper { x: lhs + rhs }
    }
}

/// `a * b` multiplies the wrapped values.
impl<T> Mul for Wrapper<T>
where
    T: Mul<Output = T>,
{
    type Output = Self;

    fn mul(self, rhs: Self) -> Self {
        let (Wrapper { x: lhs }, Wrapper { x: rhs }) = (self, rhs);
        Wrapper { x: lhs * rhs }
    }
}

/// `a / b` divides the wrapped values (semantics, e.g. for a zero divisor, are `T`'s).
impl<T> Div for Wrapper<T>
where
    T: Div<Output = T>,
{
    type Output = Self;

    fn div(self, rhs: Self) -> Self {
        let (Wrapper { x: lhs }, Wrapper { x: rhs }) = (self, rhs);
        Wrapper { x: lhs / rhs }
    }
}

/// `a - b` subtracts the wrapped values.
impl<T> Sub for Wrapper<T>
where
    T: Sub<Output = T>,
{
    type Output = Self;

    fn sub(self, rhs: Self) -> Self {
        let (Wrapper { x: lhs }, Wrapper { x: rhs }) = (self, rhs);
        Wrapper { x: lhs - rhs }
    }
}
impl<T> complex_division::Number for Wrapper<T>
where
T: Abs
+ Add<Output = T>
+ Clone
+ Div<Output = T>
+ Inv
+ Mul<Output = T>
+ Neg<Output = T>
+ PartialOrd
+ Sub<Output = T>
+ Zero,
{
fn abs(&self) -> Self {
Self { x: self.x.abs() }
}
fn inv(&self) -> Self {
Self { x: self.x.inv() }
}
fn is_zero(&self) -> bool {
self.x.is_zero()
}
}
/// A complex number `re + im·i` whose parts are stored as [`Wrapper`]s.
///
/// The raw scalars are reachable as `.re.x` and `.im.x`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub(crate) struct Complex<T> {
    /// Real component.
    pub(crate) re: Wrapper<T>,
    /// Imaginary component.
    pub(crate) im: Wrapper<T>,
}
impl<T> Complex<T>
where
Wrapper<T>: complex_division::Number,
{
pub(crate) fn new(re: T, im: T) -> Self {
Self {
re: Wrapper { x: re },
im: Wrapper { x: im },
}
}
pub(crate) fn new_internal(re: Wrapper<T>, im: Wrapper<T>) -> Self {
Self { re, im }
}
pub(crate) fn inv(&self) -> Self {
let (re, im) = complex_division::compinv(self.re.clone(), self.im.clone());
Self::new(re.x, im.x)
}
}
/// Polar-form construction, available when `T` provides `sin`/`cos`.
impl<T> Complex<T>
where
    Wrapper<T>: complex_division::Number,
    T: Clone + Cos + Mul<Output = T> + Sin,
{
    /// Builds `r·cos(theta) + r·sin(theta)·i` from modulus `r` and angle `theta`.
    pub(crate) fn from_polar(r: T, theta: T) -> Self {
        let real = r.clone() * theta.cos();
        let imag = r * theta.sin();
        Self::new(real, imag)
    }
}
impl<T> Neg for Complex<T>
where
T: Neg<Output = T>,
Wrapper<T>: complex_division::Number,
{
type Output = Self;
fn neg(self) -> Self::Output {
Self::Output::new_internal(-self.re, -self.im)
}
}
impl<T> Add for Complex<T>
where
T: Add<Output = T>,
Wrapper<T>: complex_division::Number,
{
type Output = Self;
fn add(self, rhs: Self) -> Self::Output {
Self::Output::new_internal(self.re + rhs.re, self.im + rhs.im)
}
}
impl<T> Add<&T> for Complex<T>
where
T: Add<Output = T> + Clone,
Wrapper<T>: complex_division::Number,
{
type Output = Self;
fn add(self, rhs: &T) -> Self::Output {
Self::Output::new(self.re.x + rhs.clone(), self.im.x)
}
}
impl<T> Sub for Complex<T>
where
T: Sub<Output = T>,
Wrapper<T>: complex_division::Number,
{
type Output = Self;
fn sub(self, rhs: Self) -> Self::Output {
Self::Output::new_internal(self.re - rhs.re, self.im - rhs.im)
}
}
impl<T> Sub<&Self> for Complex<T>
where
T: Sub<Output = T>,
Wrapper<T>: complex_division::Number,
{
type Output = Self;
fn sub(self, rhs: &Self) -> Self::Output {
Self::Output::new_internal(self.re - rhs.re.clone(), self.im - rhs.im.clone())
}
}
impl<T> Mul for Complex<T>
where
T: Add<Output = T> + Mul<Output = T> + Sub<Output = T>,
Wrapper<T>: complex_division::Number,
{
type Output = Self;
fn mul(self, rhs: Self) -> Self::Output {
Self::Output::new_internal(
self.re.clone() * rhs.re.clone() - self.im.clone() * rhs.im.clone(),
self.re * rhs.im + self.im * rhs.re,
)
}
}
impl<T> Mul<&Self> for Complex<T>
where
T: Add<Output = T> + Mul<Output = T> + Sub<Output = T>,
Wrapper<T>: complex_division::Number,
{
type Output = Self;
fn mul(self, rhs: &Self) -> Self::Output {
Self::Output::new_internal(
self.re.clone() * rhs.re.clone() - self.im.clone() * rhs.im.clone(),
self.re * rhs.im.clone() + self.im * rhs.re.clone(),
)
}
}
impl<T> Div for Complex<T>
where
T: Div<Output = T>,
Wrapper<T>: complex_division::Number,
{
type Output = Self;
fn div(self, rhs: Self) -> Self::Output {
let (re, im) = complex_division::compdiv(self.re, self.im, rhs.re, rhs.im);
Self::Output::new_internal(re, im)
}
}
impl<T> Div<T> for &Complex<T>
where
T: Clone + Div<Output = T>,
Wrapper<T>: complex_division::Number,
{
type Output = Complex<T>;
fn div(self, rhs: T) -> Self::Output {
Self::Output::new(self.re.x.clone() / rhs.clone(), self.im.x.clone() / rhs)
}
}
impl<T> Zero for Complex<T>
where
T: Zero,
{
fn zero() -> Self {
Self {
re: Wrapper { x: T::zero() },
im: Wrapper { x: T::zero() },
}
}
fn is_zero(&self) -> bool {
self.re.x.is_zero() && self.im.x.is_zero()
}
}
impl<T> One for Complex<T>
where
T: One + Zero,
{
fn one() -> Self {
Self {
re: Wrapper { x: T::one() },
im: Wrapper { x: T::zero() },
}
}
fn is_one(&self) -> bool {
self.re.x.is_one() && self.im.x.is_zero()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Negating twice must give back the original value.
    #[test]
    fn complex_neg() {
        let original = Complex::new(2.0_f64, -0.452);
        let twice_negated = original.neg().neg();
        assert_eq!(original, twice_negated);
    }

    /// `zero()` must report itself as zero.
    #[test]
    fn complex_zero() {
        let z = Complex::<f32>::zero();
        assert!(z.is_zero());
    }

    /// `one()` must report itself as one.
    #[test]
    fn complex_one() {
        let unit = Complex::<f64>::one();
        assert!(unit.is_one());
    }
}