use super::{atan2, cossin};
use core::num::Wrapping;
use core::ops::{Add, Deref, DerefMut, Div, Mul, Sub};
use dsp_fixedpoint::{Accu, Q, Shift};
use num_traits::AsPrimitive;
#[cfg(not(feature = "std"))]
#[allow(unused_imports)]
use num_traits::float::Float as _;
/// A complex number stored as the two-element array `[re, im]`.
///
/// `#[repr(transparent)]` gives `Complex<T>` the same layout as `[T; 2]`; the public field
/// and the `Deref`/`DerefMut` impls expose that array directly. With the `serde` feature the
/// type (de)serializes as its inner array (`serde(transparent)`).
///
/// The derived `Default` sets both components to `T::default()`. The derived
/// `PartialOrd`/`Ord` compare the underlying arrays lexicographically (real part first, then
/// imaginary part); this is a storage ordering, not a mathematical ordering of complex numbers.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
#[cfg_attr(feature = "serde", derive(::serde::Serialize, ::serde::Deserialize))]
#[cfg_attr(feature = "bytemuck", derive(bytemuck::Zeroable, bytemuck::Pod))]
#[repr(transparent)]
#[cfg_attr(feature = "serde", serde(transparent))]
pub struct Complex<T>(
/// Components in the order `[re, im]`.
pub [T; 2],
);
impl<T> Deref for Complex<T> {
type Target = [T; 2];
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl<T> DerefMut for Complex<T> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
impl<T: Copy> Complex<T> {
pub const fn new(re: T, im: T) -> Self {
Self([re, im])
}
pub fn re(&self) -> T {
self.0[0]
}
pub fn im(&self) -> T {
self.0[1]
}
}
impl<T: Copy + core::ops::Neg<Output = T>> Complex<T> {
    /// Complex conjugate: keeps the real part and negates the imaginary part.
    pub fn conj(self) -> Self {
        let [re, im] = self.0;
        Self([re, -im])
    }
}
// Implements `Trait<Complex<T>> for Complex<T>` for each listed operator by applying the
// scalar operator to the real parts and to the imaginary parts independently.
macro_rules! fwd_binop {
    ($($tr:ident::$meth:ident),+ $(,)?) => {
        $(
            impl<T: Copy + core::ops::$tr<Output = T>> core::ops::$tr for Complex<T> {
                type Output = Self;

                /// Component-wise application of the scalar operator.
                fn $meth(self, rhs: Self) -> Self {
                    let [lhs_re, lhs_im] = self.0;
                    let [rhs_re, rhs_im] = rhs.0;
                    Self([lhs_re.$meth(rhs_re), lhs_im.$meth(rhs_im)])
                }
            }
        )+
    };
}
// `Mul`/`Div` are deliberately absent: complex multiplication and division are not
// component-wise and are implemented separately.
fwd_binop!(Add::add, Sub::sub, BitAnd::bitand, BitOr::bitor, BitXor::bitxor);
// Implements `Trait<T> for Complex<T>` for each listed operator: the scalar right-hand side
// is applied to both the real and the imaginary component.
macro_rules! fwd_binop_inner {
    ($($tr:ident::$meth:ident),+ $(,)?) => {
        $(
            impl<T: Copy + core::ops::$tr<Output = T>> core::ops::$tr<T> for Complex<T> {
                type Output = Self;

                /// Applies the scalar operator between each component and `rhs`.
                fn $meth(self, rhs: T) -> Self {
                    let [re, im] = self.0;
                    Self([re.$meth(rhs), im.$meth(rhs)])
                }
            }
        )+
    };
}
fwd_binop_inner!(
    Mul::mul,
    Div::div,
    Rem::rem,
    BitAnd::bitand,
    BitOr::bitor,
    BitXor::bitxor,
);
// Implements each listed unary operator on `Complex<T>` by applying it to both components.
macro_rules! fwd_unop {
    ($($tr:ident::$meth:ident),+ $(,)?) => {
        $(
            impl<T: Copy + core::ops::$tr<Output = T>> core::ops::$tr for Complex<T> {
                type Output = Self;

                /// Component-wise application of the scalar unary operator.
                fn $meth(self) -> Self {
                    let [re, im] = self.0;
                    Self([re.$meth(), im.$meth()])
                }
            }
        )+
    };
}
fwd_unop!(Not::not, Neg::neg);
impl<T: Copy + Mul<Output = T> + Add<Output = T> + Sub<Output = T>> Mul for Complex<T> {
    type Output = Self;

    /// Complex product `(a + bi)(c + di) = (ac - bd) + (ad + bc)i`.
    fn mul(self, rhs: Self) -> Self {
        let [a, b] = self.0;
        let [c, d] = rhs.0;
        let re = a * c - b * d;
        let im = a * d + b * c;
        Self([re, im])
    }
}
impl<
        T: 'static + Copy,
        A: Copy + Add<Output = A> + Sub<Output = A> + AsPrimitive<T>,
        B,
        const F: i8,
    > Mul<Complex<T>> for Complex<Q<T, B, F>>
where
    Q<T, B, F>: Mul<T, Output = A>,
{
    type Output = Complex<T>;

    /// Complex product with a fixed-point left operand.
    ///
    /// The partial products and their sum/difference are formed in the `Mul` output type `A`,
    /// and only the final real and imaginary parts are converted back to `T` via `as_()`.
    fn mul(self, rhs: Complex<T>) -> Complex<T> {
        let [a, b] = self.0;
        let [c, d] = rhs.0;
        let re: A = a * c - b * d;
        let im: A = a * d + b * c;
        Complex([re.as_(), im.as_()])
    }
}
impl<
        T: 'static + Copy,
        A: Copy + Add<Output = A> + Sub<Output = A> + AsPrimitive<T>,
        B,
        const F: i8,
    > Mul<Complex<Q<T, B, F>>> for Complex<T>
where
    T: Mul<Q<T, B, F>, Output = A>,
{
    type Output = Complex<T>;

    /// Complex product with a fixed-point right operand.
    ///
    /// The partial products and their sum/difference are formed in the `Mul` output type `A`,
    /// and only the final real and imaginary parts are converted back to `T` via `as_()`.
    fn mul(self, rhs: Complex<Q<T, B, F>>) -> Complex<T> {
        let [a, b] = self.0;
        let [c, d] = rhs.0;
        let re: A = a * c - b * d;
        let im: A = a * d + b * c;
        Complex([re.as_(), im.as_()])
    }
}
impl<T> core::iter::Sum for Complex<T>
where
    Self: Default + Add<Output = Self>,
{
    /// Adds all items, starting from `Self::default()` (for numeric `T`, `0 + 0i`).
    fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
        iter.fold(Self::default(), Add::add)
    }
}
impl<T> core::iter::Product for Complex<T>
where
    Self: Default + Mul<Output = Self>,
{
    /// Multiplies all items together.
    ///
    /// The running product is seeded with the first item. Seeding it with `Self::default()`
    /// (for numeric `T`, `0 + 0i`) would make every product zero, because zero absorbs
    /// multiplication.
    ///
    /// An empty iterator returns `Self::default()`: these bounds provide no way to construct
    /// the multiplicative identity `1 + 0i`.
    fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
        iter.reduce(|acc, z| acc * z).unwrap_or_default()
    }
}
// Floating-point helpers for `Complex<f32>` / `Complex<f64>`.
macro_rules! impl_float {
    ($t:ty) => {
        impl Complex<$t> {
            /// Unit-magnitude complex number `cos(angle) + i·sin(angle)`; `angle` is in radians.
            pub fn from_angle(angle: $t) -> Self {
                let (s, c) = angle.sin_cos();
                Self::new(c, s)
            }

            /// Argument (phase) in radians, as returned by `atan2`, i.e. within `[-π, π]`.
            ///
            /// `y.atan2(x)` computes `atan2(y, x)`, so the imaginary part must be the receiver.
            /// This makes `arg` the inverse of [`Self::from_angle`] on `(-π, π]`.
            pub fn arg(&self) -> $t {
                self.im().atan2(self.re())
            }

            /// Squared magnitude `re² + im²`.
            pub fn norm_sqr(&self) -> $t {
                self.re() * self.re() + self.im() * self.im()
            }
        }

        impl Div for Complex<$t> {
            type Output = Self;

            /// Complex quotient `self · conj(rhs) / |rhs|²`.
            ///
            /// Follows IEEE float semantics: dividing by `0 + 0i` yields NaN/infinite components.
            fn div(self, rhs: Self) -> Self::Output {
                (self * rhs.conj()) / rhs.norm_sqr()
            }
        }
    };
}
impl_float!(f32);
impl_float!(f64);
impl Complex<i32> {
    /// Squared magnitude `re² + im²`, with each square computed in `i64` so it cannot overflow.
    ///
    /// # Panics
    ///
    /// The sum exceeds `i64::MAX` only for `re == im == i32::MIN` (where it equals `2^63`).
    /// That single input overflows: it panics in debug builds and wraps in release builds.
    /// Use [`Self::ilog2`] if only the magnitude's order is needed; it handles that input.
    pub fn norm_sqr(&self) -> i64 {
        let [x, y] = self.0.map(|x| x as i64 * x as i64);
        x + y
    }

    /// Integer base-2 logarithm of the squared magnitude, `floor(log2(re² + im²))`.
    ///
    /// The squared magnitude is computed in `u64`, where the maximum value `2^63`
    /// (`re == im == i32::MIN`) fits, so every non-zero input is handled.
    ///
    /// # Panics
    ///
    /// Panics if `self` is `0 + 0i`, because the logarithm of zero is undefined.
    pub fn ilog2(&self) -> u32 {
        let [x, y] = self.0.map(|v| {
            let m = v.unsigned_abs() as u64;
            m * m
        });
        (x + y).ilog2()
    }

    /// Builds the complex number `(c, s)` returned by the parent module's `cossin` for the
    /// given wrapping phase.
    pub fn from_angle(angle: Wrapping<i32>) -> Self {
        let (c, s) = cossin(angle.0);
        Self::new(c, s)
    }

    /// Phase as computed by the parent module's integer `atan2(im, re)`, wrapped in `Wrapping`.
    pub fn arg(&self) -> Wrapping<i32> {
        Wrapping(atan2(self.im(), self.re()))
    }
}
impl<A: Shift, T: Accu<A> + Copy, const F: i8> Complex<Q<A, T, F>> {
pub fn quantize(self) -> Complex<T> {
Complex::new(self.re().quantize(), self.im().quantize())
}
}
impl<T: Copy, A, const F: i8> Complex<Q<T, A, F>> {
pub fn into_bits(self) -> Complex<T> {
Complex::new(self.re().into_bits(), self.im().into_bits())
}
pub fn from_bits(value: Complex<T>) -> Self {
Self::new(Q::from_bits(value.re()), Q::from_bits(value.im()))
}
}
#[cfg(test)]
mod test {
    use super::*;
    use dsp_fixedpoint::Q32;

    /// `into_bits` must hand back the stored integers untouched.
    #[test]
    fn fixedpoint_into_bits_exposes_raw_representation() {
        let (re_bits, im_bits) = (0x1234_5678, -0x2345_6789);
        let x = Complex::new(Q32::<32>::from_bits(re_bits), Q32::<32>::from_bits(im_bits));
        assert_eq!(x.into_bits(), Complex::new(re_bits, im_bits));
    }

    /// Quantizing a wide fixed-point value and taking its phase must agree with the plain
    /// `i32` phase of the same point.
    #[test]
    fn fixedpoint_arg_matches_i32_arg() {
        // Raw value `1 << 34`; presumably 4.0 given F = 32 fractional bits (the expected
        // value below compares against the integer point (4, 4)).
        let four = Q::<i64, i32, 32>::from_bits(1 << 34);
        let z = Complex::new(four, four);
        assert_eq!(z.quantize().arg(), Complex::<i32>::new(4, 4).arg());
    }
}