use crate::tensor::Tensor;
use approx::{AbsDiffEq, RelativeEq, UlpsEq};
use num_traits::{Float, Num, NumCast, One, ToPrimitive, Zero};
use std::fmt::{Debug, Display};
use std::ops::{Add, Div, Mul, Neg, Rem, Sub};
/// A complex number `re + im·i` whose two parts share the floating-point type `T`.
///
/// Both parts are public, so values can be built or inspected directly. The
/// derived `PartialEq` compares the parts component-wise.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Complex<T: Float> {
    /// Real part.
    pub re: T,
    /// Imaginary part.
    pub im: T,
}
impl<T: Float> Complex<T> {
    /// Builds `re + im·i`.
    pub fn new(re: T, im: T) -> Self {
        Self { re, im }
    }

    /// Builds the purely real value `re + 0i`.
    pub fn from_real(re: T) -> Self {
        Self::new(re, T::zero())
    }

    /// Builds the purely imaginary value `0 + im·i`.
    pub fn from_imag(im: T) -> Self {
        Self::new(T::zero(), im)
    }

    /// Real part.
    pub fn real(&self) -> T {
        self.re
    }

    /// Imaginary part.
    pub fn imag(&self) -> T {
        self.im
    }

    /// Complex conjugate `re - im·i`.
    pub fn conj(&self) -> Self {
        Self::new(self.re, -self.im)
    }

    /// Modulus `|z|`.
    ///
    /// Computed with `hypot` instead of `sqrt(re² + im²)`. The squares would
    /// overflow to infinity (or underflow to zero) for very large (or very
    /// small) parts even when `|z|` itself is representable, e.g. `1e200 + 1e200i`.
    pub fn abs(&self) -> T {
        self.re.hypot(self.im)
    }

    /// Squared modulus `re² + im²` (no square root; can overflow for huge parts).
    pub fn abs_sq(&self) -> T {
        self.re * self.re + self.im * self.im
    }

    /// Argument (phase angle) computed as `im.atan2(re)`.
    pub fn arg(&self) -> T {
        self.im.atan2(self.re)
    }

    /// Returns `(modulus, argument)`.
    pub fn to_polar(&self) -> (T, T) {
        (self.abs(), self.arg())
    }

    /// Builds `r·(cos θ + i·sin θ)`.
    pub fn from_polar(r: T, theta: T) -> Self {
        Self::new(r * theta.cos(), r * theta.sin())
    }

    /// `true` when both parts are finite.
    pub fn is_finite(&self) -> bool {
        self.re.is_finite() && self.im.is_finite()
    }

    /// `true` when either part is infinite.
    pub fn is_infinite(&self) -> bool {
        self.re.is_infinite() || self.im.is_infinite()
    }

    /// `true` when either part is NaN.
    pub fn is_nan(&self) -> bool {
        self.re.is_nan() || self.im.is_nan()
    }

    /// `true` when the imaginary part compares equal to zero.
    pub fn is_real(&self) -> bool {
        self.im == T::zero()
    }

    /// `true` when the real part compares equal to zero (so `0 + 0i` is both real and imaginary).
    pub fn is_imag(&self) -> bool {
        self.re == T::zero()
    }
}
impl<T: Float> Add for Complex<T> {
type Output = Self;
fn add(self, other: Self) -> Self::Output {
Self::new(self.re + other.re, self.im + other.im)
}
}
impl<T: Float> Sub for Complex<T> {
type Output = Self;
fn sub(self, other: Self) -> Self::Output {
Self::new(self.re - other.re, self.im - other.im)
}
}
/// Complex multiplication: `(a + bi)(c + di) = (ac - bd) + (ad + bc)i`.
impl<T: Float> Mul for Complex<T> {
    type Output = Self;
    fn mul(self, other: Self) -> Self::Output {
        let (a, b) = (self.re, self.im);
        let (c, d) = (other.re, other.im);
        let real = a * c - b * d;
        let imag = a * d + b * c;
        Self::new(real, imag)
    }
}
/// Complex division `self / other`.
///
/// Uses Smith's algorithm. Rather than forming `|other|² = c² + d²`, it scales
/// by the ratio of the divisor's smaller part to its larger part. `|other|²`
/// overflows to `inf` for `f64` parts near 1e154 and underflows to `0` for parts
/// near 1e-154. Smith's form keeps the quotient accurate whenever it is
/// representable, and a tiny-but-nonzero divisor is no longer mistaken for zero.
///
/// Division by an exact zero (`0 + 0i`) returns `inf + inf·i`, as before.
impl<T: Float> Div for Complex<T> {
    type Output = Self;
    fn div(self, other: Self) -> Self::Output {
        let (a, b) = (self.re, self.im);
        let (c, d) = (other.re, other.im);
        if c == T::zero() && d == T::zero() {
            return Self::new(T::infinity(), T::infinity());
        }
        if c.abs() >= d.abs() {
            // Divide numerator and denominator by c: |d/c| <= 1, so no overflow.
            let ratio = d / c;
            let denom = c + d * ratio;
            Self::new((a + b * ratio) / denom, (b - a * ratio) / denom)
        } else {
            // Divide numerator and denominator by d: |c/d| < 1.
            let ratio = c / d;
            let denom = c * ratio + d;
            Self::new((a * ratio + b) / denom, (b * ratio - a) / denom)
        }
    }
}
/// Negates both parts.
impl<T: Float> Neg for Complex<T> {
    type Output = Self;
    fn neg(self) -> Self::Output {
        Self {
            re: -self.re,
            im: -self.im,
        }
    }
}
/// Adds a real scalar to the real part; the imaginary part is unchanged.
impl<T: Float> Add<T> for Complex<T> {
    type Output = Self;
    fn add(self, scalar: T) -> Self::Output {
        Self {
            re: self.re + scalar,
            im: self.im,
        }
    }
}
/// Subtracts a real scalar from the real part; the imaginary part is unchanged.
impl<T: Float> Sub<T> for Complex<T> {
    type Output = Self;
    fn sub(self, scalar: T) -> Self::Output {
        Self {
            re: self.re - scalar,
            im: self.im,
        }
    }
}
impl<T: Float> Mul<T> for Complex<T> {
type Output = Self;
fn mul(self, scalar: T) -> Self::Output {
Self::new(self.re * scalar, self.im * scalar)
}
}
/// Divides both parts by a real scalar.
///
/// A scalar that compares equal to zero yields `inf + inf·i` regardless of the
/// numerator; any other scalar (including NaN) divides each part directly.
impl<T: Float> Div<T> for Complex<T> {
    type Output = Self;
    fn div(self, scalar: T) -> Self::Output {
        if scalar != T::zero() {
            return Self::new(self.re / scalar, self.im / scalar);
        }
        Self::new(T::infinity(), T::infinity())
    }
}
/// Component-wise remainder: `(re % other.re) + (im % other.im)·i`.
///
/// Each part is reduced independently with `T`'s `%`; this is not a
/// Gaussian-integer style remainder.
impl<T: Float> Rem for Complex<T> {
    type Output = Self;
    fn rem(self, other: Self) -> Self::Output {
        let re = self.re % other.re;
        let im = self.im % other.im;
        Self { re, im }
    }
}
/// Reduces each part independently by the same real scalar with `T`'s `%`.
impl<T: Float> Rem<T> for Complex<T> {
    type Output = Self;
    fn rem(self, scalar: T) -> Self::Output {
        let re = self.re % scalar;
        let im = self.im % scalar;
        Self { re, im }
    }
}
/// Additive identity `0 + 0i`.
impl<T: Float> Zero for Complex<T> {
    fn zero() -> Self {
        Self {
            re: T::zero(),
            im: T::zero(),
        }
    }
    /// `true` only when both parts are zero (`-0.0` counts as zero).
    fn is_zero(&self) -> bool {
        [self.re, self.im].iter().all(|part| part.is_zero())
    }
}
/// Multiplicative identity `1 + 0i`.
impl<T: Float> One for Complex<T> {
    fn one() -> Self {
        Self {
            re: T::one(),
            im: T::zero(),
        }
    }
}
/// Orders complex numbers by modulus, breaking ties lexicographically on `(re, im)`.
///
/// The tie-break keeps the ordering consistent with the derived `PartialEq`:
/// `partial_cmp` returns `Some(Equal)` only when both parts are equal. Comparing
/// the modulus alone reported e.g. `1` and `i` as `Equal` even though `==` says
/// they differ. Any value with a NaN part is unordered (`None`).
impl<T: Float> PartialOrd for Complex<T> {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        use std::cmp::Ordering;
        if self.re.is_nan() || self.im.is_nan() || other.re.is_nan() || other.im.is_nan() {
            return None;
        }
        // `hypot` avoids the overflow of re² + im² for very large parts.
        let by_modulus = self
            .re
            .hypot(self.im)
            .partial_cmp(&other.re.hypot(other.im))?;
        if by_modulus != Ordering::Equal {
            return Some(by_modulus);
        }
        match self.re.partial_cmp(&other.re)? {
            Ordering::Equal => self.im.partial_cmp(&other.im),
            unequal => Some(unequal),
        }
    }
}
/// Parsing accepts only a real number in the given radix; the result is `x + 0i`.
impl<T: Float> Num for Complex<T> {
    type FromStrRadixErr = <T as Num>::FromStrRadixErr;
    fn from_str_radix(str: &str, radix: u32) -> Result<Self, Self::FromStrRadixErr> {
        T::from_str_radix(str, radix).map(Self::from_real)
    }
}
impl<T: Float + ToPrimitive> ToPrimitive for Complex<T> {
fn to_i64(&self) -> Option<i64> {
self.abs().to_i64()
}
fn to_u64(&self) -> Option<u64> {
self.abs().to_u64()
}
fn to_f64(&self) -> Option<f64> {
self.abs().to_f64()
}
fn to_f32(&self) -> Option<f32> {
self.abs().to_f32()
}
}
/// Casts a primitive into `T` and embeds it as the real part (`x + 0i`).
impl<T: Float + NumCast> NumCast for Complex<T> {
    fn from<N: ToPrimitive>(n: N) -> Option<Self> {
        let re: T = T::from(n)?;
        Some(Self::from_real(re))
    }
}
/// `num_traits::Float` for complex numbers, so `Complex<T>` can be used where a
/// generic `Float` element is expected.
///
/// Conventions:
/// * Constants (`infinity`, `max_value`, ...) are purely real.
/// * `abs` returns the modulus as a real-valued complex (`|z| + 0i`).
/// * `classify`, `is_sign_*` and `integer_decode` look only at the real part.
/// * `max`, `min` and `abs_sub` compare by modulus.
/// * Transcendental functions are built from `ln`, `exp` and `sqrt`, whose
///   phase comes from `atan2`.
impl<T: Float> Float for Complex<T> {
    fn nan() -> Self {
        Self::new(T::nan(), T::nan())
    }
    fn infinity() -> Self {
        Self::new(T::infinity(), T::zero())
    }
    fn neg_infinity() -> Self {
        Self::new(T::neg_infinity(), T::zero())
    }
    fn neg_zero() -> Self {
        Self::new(T::neg_zero(), T::zero())
    }
    fn min_value() -> Self {
        Self::new(T::min_value(), T::zero())
    }
    fn min_positive_value() -> Self {
        Self::new(T::min_positive_value(), T::zero())
    }
    fn max_value() -> Self {
        Self::new(T::max_value(), T::zero())
    }
    fn is_nan(self) -> bool {
        self.re.is_nan() || self.im.is_nan()
    }
    fn is_infinite(self) -> bool {
        self.re.is_infinite() || self.im.is_infinite()
    }
    fn is_finite(self) -> bool {
        self.re.is_finite() && self.im.is_finite()
    }
    /// `true` when the value is nonzero and each part is either normal or exactly zero.
    ///
    /// Requiring *both* parts to be normal made every purely real or purely
    /// imaginary value (e.g. `1 + 0i`) report `false`, since zero is not normal.
    fn is_normal(self) -> bool {
        let part_ok = |p: T| p.is_normal() || p.is_zero();
        part_ok(self.re) && part_ok(self.im) && !(self.re.is_zero() && self.im.is_zero())
    }
    /// Classification of the real part only.
    fn classify(self) -> std::num::FpCategory {
        self.re.classify()
    }
    fn floor(self) -> Self {
        Self::new(self.re.floor(), self.im.floor())
    }
    fn ceil(self) -> Self {
        Self::new(self.re.ceil(), self.im.ceil())
    }
    fn round(self) -> Self {
        Self::new(self.re.round(), self.im.round())
    }
    fn trunc(self) -> Self {
        Self::new(self.re.trunc(), self.im.trunc())
    }
    fn fract(self) -> Self {
        Self::new(self.re.fract(), self.im.fract())
    }
    /// Modulus as `|z| + 0i` (via `hypot`, so huge parts do not overflow).
    fn abs(self) -> Self {
        Self::new(self.re.hypot(self.im), T::zero())
    }
    /// Unit vector `z / |z|`, or zero for a zero input.
    fn signum(self) -> Self {
        let magnitude = self.re.hypot(self.im);
        if magnitude == T::zero() {
            Self::zero()
        } else {
            Self::new(self.re / magnitude, self.im / magnitude)
        }
    }
    fn is_sign_positive(self) -> bool {
        self.re.is_sign_positive()
    }
    fn is_sign_negative(self) -> bool {
        self.re.is_sign_negative()
    }
    fn mul_add(self, a: Self, b: Self) -> Self {
        self * a + b
    }
    fn recip(self) -> Self {
        Self::one() / self
    }
    /// Integer power by repeated squaring.
    ///
    /// This uses only multiplications (and one division for negative `n`). It is
    /// more accurate than `exp(n·ln z)` and well defined for a zero base: `0^0 = 1`,
    /// `0^n = 0` for `n > 0`, and `0^n` diverges (`1 / 0`) for `n < 0`.
    fn powi(self, n: i32) -> Self {
        let mut base = self;
        let mut remaining = n.unsigned_abs();
        let mut acc = Self::one();
        while remaining > 0 {
            if remaining & 1 == 1 {
                acc = acc * base;
            }
            remaining >>= 1;
            if remaining > 0 {
                base = base * base;
            }
        }
        if n < 0 {
            Self::one() / acc
        } else {
            acc
        }
    }
    /// Principal complex power `exp(n · ln self)`.
    ///
    /// For a zero base: `0^0 = 1`; `0^w = 0` when `Re(w) > 0`; `+inf` for a
    /// negative real exponent (matching real `powf`); NaN otherwise (no limit).
    fn powf(self, n: Self) -> Self {
        if self.is_zero() {
            return if n.is_zero() {
                Self::one()
            } else if n.re > T::zero() {
                Self::zero()
            } else if n.re < T::zero() && n.im == T::zero() {
                Self::new(T::infinity(), T::zero())
            } else {
                Self::new(T::nan(), T::nan())
            };
        }
        (self.ln() * n).exp()
    }
    /// Principal square root via polar form: `sqrt(|z|) · e^{i·arg(z)/2}`.
    fn sqrt(self) -> Self {
        let r = self.re.hypot(self.im);
        let theta = self.im.atan2(self.re);
        let sqrt_r = r.sqrt();
        let half_theta = theta / (T::one() + T::one());
        Self::new(sqrt_r * half_theta.cos(), sqrt_r * half_theta.sin())
    }
    fn exp(self) -> Self {
        let exp_re = self.re.exp();
        Self::new(exp_re * self.im.cos(), exp_re * self.im.sin())
    }
    fn exp2(self) -> Self {
        let ln2 = T::from(std::f64::consts::LN_2).unwrap();
        (self * ln2).exp()
    }
    /// Principal logarithm `ln|z| + i·atan2(im, re)`.
    fn ln(self) -> Self {
        let magnitude = self.re.hypot(self.im);
        let phase = self.im.atan2(self.re);
        Self::new(magnitude.ln(), phase)
    }
    fn log(self, base: Self) -> Self {
        self.ln() / base.ln()
    }
    fn log2(self) -> Self {
        let ln2 = T::from(std::f64::consts::LN_2).unwrap();
        self.ln() / Complex::from_real(ln2)
    }
    fn log10(self) -> Self {
        let ln10 = T::from(std::f64::consts::LN_10).unwrap();
        self.ln() / Complex::from_real(ln10)
    }
    /// Larger modulus wins; ties (and NaN comparisons) keep `self` / fall to `other`
    /// exactly as `>=` on the moduli decides.
    fn max(self, other: Self) -> Self {
        if self.re.hypot(self.im) >= other.re.hypot(other.im) {
            self
        } else {
            other
        }
    }
    /// Smaller modulus wins; ties keep `self`.
    fn min(self, other: Self) -> Self {
        if self.re.hypot(self.im) <= other.re.hypot(other.im) {
            self
        } else {
            other
        }
    }
    /// `self - other` when `|self| >= |other|`, otherwise zero.
    fn abs_sub(self, other: Self) -> Self {
        if self.re.hypot(self.im) >= other.re.hypot(other.im) {
            self - other
        } else {
            Self::zero()
        }
    }
    /// Principal cube root (a negative real input yields a complex result).
    fn cbrt(self) -> Self {
        let one_third = Complex::from_real(T::one() / T::from(3.0).unwrap());
        self.powf(one_third)
    }
    fn hypot(self, other: Self) -> Self {
        (self * self + other * other).sqrt()
    }
    fn sin(self) -> Self {
        Self::new(
            self.re.sin() * self.im.cosh(),
            self.re.cos() * self.im.sinh(),
        )
    }
    fn cos(self) -> Self {
        Self::new(
            self.re.cos() * self.im.cosh(),
            -self.re.sin() * self.im.sinh(),
        )
    }
    fn tan(self) -> Self {
        let sin_val = self.sin();
        let cos_val = self.cos();
        sin_val / cos_val
    }
    fn asin(self) -> Self {
        let i: Self = Complex::i();
        -i * (i * self + (Self::one() - self * self).sqrt()).ln()
    }
    fn acos(self) -> Self {
        let i: Self = Complex::i();
        -i * (self + i * (Self::one() - self * self).sqrt()).ln()
    }
    fn atan(self) -> Self {
        let i: Self = Complex::i();
        let two = T::from(2.0).unwrap();
        (i / two) * ((i + self) / (i - self)).ln()
    }
    /// Two-argument arctangent of `self` (y) over `other` (x).
    ///
    /// For real-valued inputs this is exactly `T::atan2`, which keeps the
    /// quadrant. `atan(y / x)` folded e.g. `atan2(1, -1)` to `-π/4` instead of
    /// `3π/4`, and failed for `x = 0`. Genuinely complex inputs fall back to
    /// `atan(y / x)`.
    fn atan2(self, other: Self) -> Self {
        if self.im == T::zero() && other.im == T::zero() {
            Self::from_real(self.re.atan2(other.re))
        } else {
            (self / other).atan()
        }
    }
    fn sin_cos(self) -> (Self, Self) {
        (self.sin(), self.cos())
    }
    fn exp_m1(self) -> Self {
        self.exp() - Self::one()
    }
    fn ln_1p(self) -> Self {
        (self + Self::one()).ln()
    }
    fn sinh(self) -> Self {
        Self::new(
            self.re.sinh() * self.im.cos(),
            self.re.cosh() * self.im.sin(),
        )
    }
    fn cosh(self) -> Self {
        Self::new(
            self.re.cosh() * self.im.cos(),
            self.re.sinh() * self.im.sin(),
        )
    }
    fn tanh(self) -> Self {
        let sinh_val = self.sinh();
        let cosh_val = self.cosh();
        sinh_val / cosh_val
    }
    fn asinh(self) -> Self {
        (self + (self * self + Self::one()).sqrt()).ln()
    }
    fn acosh(self) -> Self {
        (self + (self * self - Self::one()).sqrt()).ln()
    }
    fn atanh(self) -> Self {
        let two = T::from(2.0).unwrap();
        ((Self::one() + self) / (Self::one() - self)).ln() / two
    }
    /// Decodes the real part only.
    fn integer_decode(self) -> (u64, i16, i8) {
        self.re.integer_decode()
    }
}
/// Absolute-difference approximate equality, applied to each part with the same epsilon.
impl<T: Float + AbsDiffEq> AbsDiffEq for Complex<T>
where
    T::Epsilon: Clone,
{
    type Epsilon = T::Epsilon;
    fn default_epsilon() -> Self::Epsilon {
        T::default_epsilon()
    }
    fn abs_diff_eq(&self, other: &Self, epsilon: Self::Epsilon) -> bool {
        if !self.re.abs_diff_eq(&other.re, epsilon.clone()) {
            return false;
        }
        self.im.abs_diff_eq(&other.im, epsilon)
    }
}
/// Relative approximate equality, applied to each part with the same tolerances.
impl<T: Float + RelativeEq> RelativeEq for Complex<T>
where
    T::Epsilon: Clone,
{
    fn default_max_relative() -> Self::Epsilon {
        T::default_max_relative()
    }
    fn relative_eq(
        &self,
        other: &Self,
        epsilon: Self::Epsilon,
        max_relative: Self::Epsilon,
    ) -> bool {
        let re_close = self
            .re
            .relative_eq(&other.re, epsilon.clone(), max_relative.clone());
        if !re_close {
            return false;
        }
        self.im.relative_eq(&other.im, epsilon, max_relative)
    }
}
/// ULP-based approximate equality, applied to each part with the same tolerances.
impl<T: Float + UlpsEq> UlpsEq for Complex<T>
where
    T::Epsilon: Clone,
{
    fn default_max_ulps() -> u32 {
        T::default_max_ulps()
    }
    fn ulps_eq(&self, other: &Self, epsilon: Self::Epsilon, max_ulps: u32) -> bool {
        if !self.re.ulps_eq(&other.re, epsilon.clone(), max_ulps) {
            return false;
        }
        self.im.ulps_eq(&other.im, epsilon, max_ulps)
    }
}
/// Formats as `re+imi` or `re-imi`, e.g. `3+4i` / `3-4i`.
///
/// A precision flag (`{:.2}`) is applied to both parts. The separator is chosen
/// from the imaginary part's sign bit, so `-0.0` renders as `1-0i` (previously
/// `1+-0i`) and a NaN imaginary part renders as `1+NaNi` (previously `1NaNi`).
impl<T: Float + Display> Display for Complex<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // A negative `im` prints its own '-', so only non-negative parts need '+'.
        // NaN's Display carries no sign, so it always gets '+'.
        let separator = if self.im.is_sign_negative() && !self.im.is_nan() {
            ""
        } else {
            "+"
        };
        match f.precision() {
            Some(p) => write!(f, "{:.*}{}{:.*}i", p, self.re, separator, p, self.im),
            None => write!(f, "{}{}{}i", self.re, separator, self.im),
        }
    }
}
impl<T: Float> Complex<T> {
    /// The constant `0 + 0i`.
    pub fn zero_const() -> Self {
        Self::new(T::zero(), T::zero())
    }

    /// The constant `1 + 0i`.
    pub fn one_const() -> Self {
        Self::new(T::one(), T::zero())
    }

    /// The imaginary unit `0 + 1i`.
    pub fn i() -> Self {
        Self::new(T::zero(), T::one())
    }
}
impl<T: Float> Complex<T> {
    /// Complex exponential: `e^re · (cos im + i·sin im)`.
    pub fn exp(&self) -> Self {
        let exp_re = self.re.exp();
        Self::new(exp_re * self.im.cos(), exp_re * self.im.sin())
    }

    /// Natural logarithm `ln|z| + i·arg(z)`, where `arg` is `im.atan2(re)`.
    pub fn ln(&self) -> Self {
        Self::new(self.abs().ln(), self.arg())
    }

    /// Principal power `exp(exp · ln self)`.
    ///
    /// For a zero base: `0^0 = 1`; `0^w = 0` when `Re(w) > 0`; `+inf` for a
    /// negative real exponent (as with real `powf`; previously this returned
    /// `0`); NaN for any other exponent, where the limit does not exist.
    pub fn pow(&self, exp: Self) -> Self {
        if self.is_zero() {
            return if exp.is_zero() {
                Self::one()
            } else if exp.re > T::zero() {
                Self::zero()
            } else if exp.re < T::zero() && exp.im == T::zero() {
                Self::new(T::infinity(), T::zero())
            } else {
                Self::new(T::nan(), T::nan())
            };
        }
        (self.ln() * exp).exp()
    }

    /// Principal square root via polar form: `sqrt(|z|) · e^{i·arg(z)/2}`.
    pub fn sqrt(&self) -> Self {
        let r = self.abs();
        let theta = self.arg();
        Self::from_polar(r.sqrt(), theta / (T::one() + T::one()))
    }

    /// Complex sine: `sin re · cosh im + i·cos re · sinh im`.
    pub fn sin(&self) -> Self {
        Self::new(
            self.re.sin() * self.im.cosh(),
            self.re.cos() * self.im.sinh(),
        )
    }

    /// Complex cosine: `cos re · cosh im - i·sin re · sinh im`.
    pub fn cos(&self) -> Self {
        Self::new(
            self.re.cos() * self.im.cosh(),
            -self.re.sin() * self.im.sinh(),
        )
    }

    /// Complex tangent `sin / cos`.
    pub fn tan(&self) -> Self {
        self.sin() / self.cos()
    }

    /// Complex hyperbolic sine: `sinh re · cos im + i·cosh re · sin im`.
    pub fn sinh(&self) -> Self {
        Self::new(
            self.re.sinh() * self.im.cos(),
            self.re.cosh() * self.im.sin(),
        )
    }

    /// Complex hyperbolic cosine: `cosh re · cos im + i·sinh re · sin im`.
    pub fn cosh(&self) -> Self {
        Self::new(
            self.re.cosh() * self.im.cos(),
            self.re.sinh() * self.im.sin(),
        )
    }

    /// Complex hyperbolic tangent `sinh / cosh`.
    pub fn tanh(&self) -> Self {
        self.sinh() / self.cosh()
    }
}
impl<T: Float> From<T> for Complex<T> {
fn from(re: T) -> Self {
Self::from_real(re)
}
}
impl<T: Float> From<(T, T)> for Complex<T> {
fn from((re, im): (T, T)) -> Self {
Self::new(re, im)
}
}
impl<T: Float> From<Complex<T>> for (T, T) {
fn from(z: Complex<T>) -> Self {
(z.re, z.im)
}
}
impl<T: Float + 'static> Complex<T> {
    /// Zips two equally shaped real tensors into one complex tensor.
    ///
    /// Elements are paired in the order `Tensor::data` iterates them.
    /// Returns `Err` when the shapes differ.
    pub fn from_tensors(real: &Tensor<T>, imag: &Tensor<T>) -> Result<Tensor<Complex<T>>, String> {
        if real.shape() != imag.shape() {
            return Err("Real and imaginary tensors must have the same shape".to_string());
        }
        let complex_data: Vec<Complex<T>> = real
            .data
            .iter()
            .zip(imag.data.iter())
            .map(|(&r, &i)| Complex::new(r, i))
            .collect();
        Ok(Tensor::from_vec(complex_data, real.shape().to_vec()))
    }

    /// Element-wise real parts, same shape as the input.
    pub fn tensor_real_part(tensor: &Tensor<Complex<T>>) -> Tensor<T> {
        let parts: Vec<T> = tensor.data.iter().map(Complex::real).collect();
        Tensor::from_vec(parts, tensor.shape().to_vec())
    }

    /// Element-wise imaginary parts, same shape as the input.
    pub fn tensor_imag_part(tensor: &Tensor<Complex<T>>) -> Tensor<T> {
        let parts: Vec<T> = tensor.data.iter().map(Complex::imag).collect();
        Tensor::from_vec(parts, tensor.shape().to_vec())
    }

    /// Element-wise moduli `|z|`, same shape as the input.
    pub fn tensor_abs(tensor: &Tensor<Complex<T>>) -> Tensor<T> {
        let moduli: Vec<T> = tensor.data.iter().map(Complex::abs).collect();
        Tensor::from_vec(moduli, tensor.shape().to_vec())
    }

    /// Element-wise arguments (phase angles), same shape as the input.
    pub fn tensor_arg(tensor: &Tensor<Complex<T>>) -> Tensor<T> {
        let angles: Vec<T> = tensor.data.iter().map(Complex::arg).collect();
        Tensor::from_vec(angles, tensor.shape().to_vec())
    }

    /// Element-wise complex conjugates, same shape as the input.
    pub fn tensor_conj(tensor: &Tensor<Complex<T>>) -> Tensor<Complex<T>> {
        let conjugates: Vec<Complex<T>> = tensor.data.iter().map(Complex::conj).collect();
        Tensor::from_vec(conjugates, tensor.shape().to_vec())
    }
}
impl<T: Float + 'static> Tensor<Complex<T>> {
    /// Tensor of the given shape filled with `0 + 0i`.
    pub fn complex_zeros(shape: &[usize]) -> Self {
        let element_count: usize = shape.iter().product();
        Tensor::from_vec(vec![Complex::zero(); element_count], shape.to_vec())
    }

    /// Tensor of the given shape filled with `1 + 0i`.
    pub fn complex_ones(shape: &[usize]) -> Self {
        let element_count: usize = shape.iter().product();
        Tensor::from_vec(vec![Complex::one(); element_count], shape.to_vec())
    }

    /// Tensor of the given shape filled with the imaginary unit `0 + 1i`.
    pub fn complex_i(shape: &[usize]) -> Self {
        let element_count: usize = shape.iter().product();
        Tensor::from_vec(vec![Complex::i(); element_count], shape.to_vec())
    }

    /// Builds `mag · e^{i·phase}` element-wise from two equally shaped tensors.
    ///
    /// Returns `Err` when the shapes differ.
    pub fn from_polar(magnitude: &Tensor<T>, phase: &Tensor<T>) -> Result<Self, String> {
        if magnitude.shape() != phase.shape() {
            return Err("Magnitude and phase tensors must have the same shape".to_string());
        }
        let complex_data: Vec<Complex<T>> = magnitude
            .data
            .iter()
            .zip(phase.data.iter())
            .map(|(&mag, &ph)| Complex::from_polar(mag, ph))
            .collect();
        Ok(Tensor::from_vec(complex_data, magnitude.shape().to_vec()))
    }

    /// Embeds each real element as `x + 0i`.
    pub fn from_real(real: &Tensor<T>) -> Self {
        let complex_data: Vec<Complex<T>> =
            real.data.iter().copied().map(Complex::from_real).collect();
        Tensor::from_vec(complex_data, real.shape().to_vec())
    }

    /// Embeds each real element as `0 + x·i`.
    pub fn from_imag(imag: &Tensor<T>) -> Self {
        let complex_data: Vec<Complex<T>> =
            imag.data.iter().copied().map(Complex::from_imag).collect();
        Tensor::from_vec(complex_data, imag.shape().to_vec())
    }
}
impl<T: Float + 'static> Tensor<Complex<T>> {
    /// Element-wise complex exponential.
    pub fn exp(&self) -> Self {
        let exp_data: Vec<Complex<T>> = self.data.iter().map(|z| z.exp()).collect();
        Tensor::from_vec(exp_data, self.shape().to_vec())
    }

    /// Element-wise complex natural logarithm.
    pub fn ln(&self) -> Self {
        let ln_data: Vec<Complex<T>> = self.data.iter().map(|z| z.ln()).collect();
        Tensor::from_vec(ln_data, self.shape().to_vec())
    }

    /// Element-wise `self[i] ^ exponent[i]`; shapes must match exactly.
    pub fn pow(&self, exponent: &Self) -> Result<Self, String> {
        if self.shape() != exponent.shape() {
            return Err("Shape mismatch for power operation".to_string());
        }
        let pow_data: Vec<Complex<T>> = self
            .data
            .iter()
            .zip(exponent.data.iter())
            .map(|(z, exp)| z.pow(*exp))
            .collect();
        Ok(Tensor::from_vec(pow_data, self.shape().to_vec()))
    }

    /// Raises every element to the same complex exponent.
    pub fn pow_scalar(&self, exponent: Complex<T>) -> Self {
        let pow_data: Vec<Complex<T>> = self.data.iter().map(|z| z.pow(exponent)).collect();
        Tensor::from_vec(pow_data, self.shape().to_vec())
    }

    /// Element-wise principal square root.
    pub fn sqrt(&self) -> Self {
        let sqrt_data: Vec<Complex<T>> = self.data.iter().map(|z| z.sqrt()).collect();
        Tensor::from_vec(sqrt_data, self.shape().to_vec())
    }

    /// Element-wise complex sine.
    pub fn sin(&self) -> Self {
        let sin_data: Vec<Complex<T>> = self.data.iter().map(|z| z.sin()).collect();
        Tensor::from_vec(sin_data, self.shape().to_vec())
    }

    /// Element-wise complex cosine.
    pub fn cos(&self) -> Self {
        let cos_data: Vec<Complex<T>> = self.data.iter().map(|z| z.cos()).collect();
        Tensor::from_vec(cos_data, self.shape().to_vec())
    }

    /// Element-wise complex tangent.
    pub fn tan(&self) -> Self {
        let tan_data: Vec<Complex<T>> = self.data.iter().map(|z| z.tan()).collect();
        Tensor::from_vec(tan_data, self.shape().to_vec())
    }

    /// Element-wise complex hyperbolic sine.
    pub fn sinh(&self) -> Self {
        let sinh_data: Vec<Complex<T>> = self.data.iter().map(|z| z.sinh()).collect();
        Tensor::from_vec(sinh_data, self.shape().to_vec())
    }

    /// Element-wise complex hyperbolic cosine.
    pub fn cosh(&self) -> Self {
        let cosh_data: Vec<Complex<T>> = self.data.iter().map(|z| z.cosh()).collect();
        Tensor::from_vec(cosh_data, self.shape().to_vec())
    }

    /// Element-wise complex hyperbolic tangent.
    pub fn tanh(&self) -> Self {
        let tanh_data: Vec<Complex<T>> = self.data.iter().map(|z| z.tanh()).collect();
        Tensor::from_vec(tanh_data, self.shape().to_vec())
    }

    /// Matrix product of two 2-D complex tensors (`m×k @ k×n -> m×n`).
    ///
    /// # Errors
    /// Returns `Err` if either operand is not 2-D or the inner dimensions differ.
    pub fn matmul(&self, other: &Self) -> Result<Self, String> {
        if self.ndim() != 2 || other.ndim() != 2 {
            return Err(format!(
                "Complex matmul currently supports only 2D matrices, got {}D @ {}D",
                self.ndim(),
                other.ndim()
            ));
        }
        let self_shape = self.shape();
        let other_shape = other.shape();
        if self_shape[1] != other_shape[0] {
            return Err(format!(
                "Complex matrix dimension mismatch: {}x{} @ {}x{}",
                self_shape[0], self_shape[1], other_shape[0], other_shape[1]
            ));
        }
        let m = self_shape[0];
        let n = other_shape[1];
        let k = self_shape[1];
        let mut result = vec![Complex::zero(); m * n];
        for i in 0..m {
            for j in 0..n {
                let mut sum = Complex::zero();
                for l in 0..k {
                    sum = sum + self.data[[i, l]] * other.data[[l, j]];
                }
                result[i * n + j] = sum;
            }
        }
        Ok(Tensor::from_vec(result, vec![m, n]))
    }

    /// Transpose of a 2-D complex tensor (no conjugation).
    pub fn transpose(&self) -> Result<Self, String> {
        if self.ndim() != 2 {
            return Err("Transpose currently supports only 2D matrices".to_string());
        }
        let shape = self.shape();
        let rows = shape[0];
        let cols = shape[1];
        let mut result = vec![Complex::zero(); rows * cols];
        for i in 0..rows {
            for j in 0..cols {
                // Output is cols×rows, stored row-major.
                let dst_idx = j * rows + i;
                result[dst_idx] = self.data[[i, j]];
            }
        }
        Ok(Tensor::from_vec(result, vec![cols, rows]))
    }

    /// Conjugate (Hermitian) transpose of a 2-D complex tensor.
    pub fn conj_transpose(&self) -> Result<Self, String> {
        if self.ndim() != 2 {
            return Err("Conjugate transpose currently supports only 2D matrices".to_string());
        }
        let shape = self.shape();
        let rows = shape[0];
        let cols = shape[1];
        let mut result = vec![Complex::zero(); rows * cols];
        for i in 0..rows {
            for j in 0..cols {
                let dst_idx = j * rows + i;
                result[dst_idx] = self.data[[i, j]].conj();
            }
        }
        Ok(Tensor::from_vec(result, vec![cols, rows]))
    }

    /// Sum of the main diagonal; for non-square matrices the first `min(rows, cols)` entries.
    pub fn trace(&self) -> Result<Complex<T>, String> {
        if self.ndim() != 2 {
            return Err("Trace requires a 2D matrix".to_string());
        }
        let shape = self.shape();
        let rows = shape[0];
        let cols = shape[1];
        let min_dim = rows.min(cols);
        let mut sum = Complex::zero();
        for i in 0..min_dim {
            sum = sum + self.data[[i, i]];
        }
        Ok(sum)
    }

    /// Determinant of a square 2-D complex tensor.
    ///
    /// 0×0 returns `1` (empty product). 1×1 and 2×2 use closed forms, so small
    /// results stay exact. Larger matrices use Gaussian elimination with partial
    /// pivoting by modulus: the determinant is the product of the pivots, with the
    /// sign flipped once per row swap. An exactly zero pivot column yields `0`.
    ///
    /// # Errors
    /// Returns `Err` if the tensor is not 2-D or not square.
    pub fn determinant(&self) -> Result<Complex<T>, String> {
        if self.ndim() != 2 {
            return Err("Determinant requires a 2D matrix".to_string());
        }
        let shape = self.shape();
        if shape[0] != shape[1] {
            return Err("Determinant requires a square matrix".to_string());
        }
        let n = shape[0];
        match n {
            0 => return Ok(Complex::one()),
            1 => return Ok(self.data[[0, 0]]),
            2 => {
                let a = self.data[[0, 0]];
                let b = self.data[[0, 1]];
                let c = self.data[[1, 0]];
                let d = self.data[[1, 1]];
                return Ok(a * d - b * c);
            }
            _ => {}
        }

        // Row-major working copy, built with the same [[i, j]] indexing as above.
        let mut m: Vec<Complex<T>> = Vec::with_capacity(n * n);
        for i in 0..n {
            for j in 0..n {
                m.push(self.data[[i, j]]);
            }
        }

        let mut det = Complex::one();
        for col in 0..n {
            // Partial pivoting: pick the largest-modulus entry at or below the diagonal.
            let mut pivot_row = col;
            let mut pivot_mag = m[col * n + col].abs_sq();
            for row in (col + 1)..n {
                let mag = m[row * n + col].abs_sq();
                if mag > pivot_mag {
                    pivot_row = row;
                    pivot_mag = mag;
                }
            }
            let pivot = m[pivot_row * n + col];
            if pivot.is_zero() {
                return Ok(Complex::zero());
            }
            if pivot_row != col {
                for j in 0..n {
                    m.swap(pivot_row * n + j, col * n + j);
                }
                det = -det;
            }
            det = det * pivot;
            for row in (col + 1)..n {
                let factor = m[row * n + col] / pivot;
                for j in (col + 1)..n {
                    let update = factor * m[col * n + j];
                    m[row * n + j] = m[row * n + j] - update;
                }
            }
        }
        Ok(det)
    }

    /// 1-D forward FFT, zero-padded or truncated to `n` (default: input length).
    ///
    /// Power-of-two lengths use radix-2 Cooley–Tukey; others use a direct DFT.
    /// `norm` is `None`/`"backward"` (unscaled), `"forward"` (scale by `1/n`) or
    /// `"ortho"` (scale by `1/sqrt(n)`). `_dim` is currently ignored.
    ///
    /// # Errors
    /// Returns `Err` for non-1-D input or an unrecognized `norm`.
    pub fn fft(
        &self,
        n: Option<usize>,
        _dim: Option<isize>,
        norm: Option<&str>,
    ) -> Result<Self, String>
    where
        T: Float + 'static + Default + Clone + std::fmt::Debug + num_traits::FromPrimitive,
    {
        if self.ndim() != 1 {
            return Err("Complex FFT currently supports only 1D tensors".to_string());
        }
        let input_len = self.shape()[0];
        let fft_len = n.unwrap_or(input_len);
        let mut fft_data: Vec<Complex<T>> = self.data.iter().cloned().collect();
        if fft_data.len() != fft_len {
            fft_data.resize(fft_len, Complex::zero());
        }
        let result = if fft_len.is_power_of_two() {
            self.cooley_tukey_complex(&mut fft_data, false)?
        } else {
            self.dft_complex(&fft_data, false)?
        };
        let normalized = self.apply_complex_normalization(result, fft_len, norm, false)?;
        Ok(Tensor::from_vec(normalized, vec![fft_len]))
    }

    /// 1-D inverse FFT; same padding and `norm` handling as [`fft`](Self::fft),
    /// with the default `"backward"` mode scaling by `1/n`.
    ///
    /// # Errors
    /// Returns `Err` for non-1-D input or an unrecognized `norm`.
    pub fn ifft(
        &self,
        n: Option<usize>,
        _dim: Option<isize>,
        norm: Option<&str>,
    ) -> Result<Self, String>
    where
        T: Float + 'static + Default + Clone + std::fmt::Debug + num_traits::FromPrimitive,
    {
        if self.ndim() != 1 {
            return Err("Complex IFFT currently supports only 1D tensors".to_string());
        }
        let input_len = self.shape()[0];
        let fft_len = n.unwrap_or(input_len);
        let mut fft_data: Vec<Complex<T>> = self.data.iter().cloned().collect();
        if fft_data.len() != fft_len {
            fft_data.resize(fft_len, Complex::zero());
        }
        let result = if fft_len.is_power_of_two() {
            self.cooley_tukey_complex(&mut fft_data, true)?
        } else {
            self.dft_complex(&fft_data, true)?
        };
        let normalized = self.apply_complex_normalization(result, fft_len, norm, true)?;
        Ok(Tensor::from_vec(normalized, vec![fft_len]))
    }

    /// Rotates a 1-D tensor so the element at index `ceil(len/2)` comes first
    /// (moves the zero-frequency bin to the centre). `_dim` is currently ignored.
    pub fn fftshift(&self, _dim: Option<&[isize]>) -> Result<Self, String> {
        if self.ndim() != 1 {
            return Err("Complex fftshift currently supports only 1D tensors".to_string());
        }
        let input_data: Vec<Complex<T>> = self.data.iter().cloned().collect();
        let input_len = input_data.len();
        let mid = input_len.div_ceil(2);
        let mut new_data = Vec::with_capacity(input_len);
        new_data.extend_from_slice(&input_data[mid..]);
        new_data.extend_from_slice(&input_data[..mid]);
        Ok(Tensor::from_vec(new_data, self.shape().to_vec()))
    }

    /// Inverse of [`fftshift`](Self::fftshift): rotates so index `len/2` comes first.
    pub fn ifftshift(&self, _dim: Option<&[isize]>) -> Result<Self, String> {
        if self.ndim() != 1 {
            return Err("Complex ifftshift currently supports only 1D tensors".to_string());
        }
        let input_data: Vec<Complex<T>> = self.data.iter().cloned().collect();
        let input_len = input_data.len();
        let mid = input_len / 2;
        let mut new_data = Vec::with_capacity(input_len);
        new_data.extend_from_slice(&input_data[mid..]);
        new_data.extend_from_slice(&input_data[..mid]);
        Ok(Tensor::from_vec(new_data, self.shape().to_vec()))
    }

    /// In-place iterative radix-2 Cooley–Tukey transform (unnormalized); returns a copy.
    ///
    /// `inverse` selects the `+` sign in the twiddle exponent.
    fn cooley_tukey_complex(
        &self,
        data: &mut [Complex<T>],
        inverse: bool,
    ) -> Result<Vec<Complex<T>>, String>
    where
        T: Float + 'static + Default + Clone + std::fmt::Debug + num_traits::FromPrimitive,
    {
        let n = data.len();
        if !n.is_power_of_two() {
            return Err("Cooley-Tukey algorithm requires power of two length".to_string());
        }
        // Bit-reversal permutation.
        let mut j = 0;
        for i in 1..n {
            let mut bit = n >> 1;
            while j & bit != 0 {
                j ^= bit;
                bit >>= 1;
            }
            j ^= bit;
            if i < j {
                data.swap(i, j);
            }
        }
        // Butterfly stages of doubling length.
        let mut length = 2;
        while length <= n {
            let half_len = length / 2;
            let angle = if inverse {
                T::from(2.0).unwrap() * T::from(std::f64::consts::PI).unwrap()
                    / T::from(length).unwrap()
            } else {
                -T::from(2.0).unwrap() * T::from(std::f64::consts::PI).unwrap()
                    / T::from(length).unwrap()
            };
            let wlen = Complex::new(angle.cos(), angle.sin());
            for i in (0..n).step_by(length) {
                let mut w = Complex::one();
                for j in 0..half_len {
                    let u = data[i + j];
                    let v = data[i + j + half_len] * w;
                    data[i + j] = u + v;
                    data[i + j + half_len] = u - v;
                    w = w * wlen;
                }
            }
            length *= 2;
        }
        Ok(data.to_vec())
    }

    /// Direct O(n²) DFT (unnormalized) for arbitrary lengths.
    ///
    /// The kernel `exp(±2πi·k·j/n)` depends only on `(k·j) mod n`, so the `n`
    /// twiddle factors are computed once and indexed by that residue. This means
    /// n trig evaluations instead of n². It also avoids the precision loss of
    /// evaluating cos/sin at angles up to `2π·(n-1)²/n`.
    fn dft_complex(&self, data: &[Complex<T>], inverse: bool) -> Result<Vec<Complex<T>>, String>
    where
        T: Float + 'static + Default + Clone + std::fmt::Debug + num_traits::FromPrimitive,
    {
        let n = data.len();
        if n == 0 {
            return Ok(Vec::new());
        }
        let sign = if inverse { T::one() } else { -T::one() };
        let pi2 = T::from(2.0).unwrap() * T::from(std::f64::consts::PI).unwrap();
        let n_t = T::from(n).unwrap();
        let twiddles: Vec<Complex<T>> = (0..n)
            .map(|m| {
                let angle = sign * pi2 * T::from(m).unwrap() / n_t;
                Complex::new(angle.cos(), angle.sin())
            })
            .collect();
        let mut result = Vec::with_capacity(n);
        for k in 0..n {
            let mut sum = Complex::zero();
            // Tracks (k * j) mod n incrementally, so the product never overflows.
            let mut idx = 0;
            for &x in data {
                sum = sum + x * twiddles[idx];
                idx += k;
                if idx >= n {
                    idx -= n;
                }
            }
            result.push(sum);
        }
        Ok(result)
    }

    /// Applies the FFT normalization selected by `norm` for a transform of length `n`.
    ///
    /// `None`/`"backward"`: scale the inverse by `1/n`. `"forward"`: scale the
    /// forward transform by `1/n`. `"ortho"`: scale both by `1/sqrt(n)`.
    /// Any other string is rejected (previously it silently behaved like "backward").
    fn apply_complex_normalization(
        &self,
        mut data: Vec<Complex<T>>,
        n: usize,
        norm: Option<&str>,
        inverse: bool,
    ) -> Result<Vec<Complex<T>>, String>
    where
        T: Float + 'static + Default + Clone + std::fmt::Debug + num_traits::FromPrimitive,
    {
        let n_t = T::from(n).unwrap();
        let scale = match norm {
            None | Some("backward") => {
                if inverse {
                    Some(T::one() / n_t)
                } else {
                    None
                }
            }
            Some("forward") => {
                if inverse {
                    None
                } else {
                    Some(T::one() / n_t)
                }
            }
            Some("ortho") => Some(T::one() / n_t.sqrt()),
            Some(other) => {
                return Err(format!(
                    "Invalid FFT norm mode '{other}': expected \"backward\", \"forward\" or \"ortho\""
                ));
            }
        };
        if let Some(scale) = scale {
            for x in &mut data {
                *x = *x * scale;
            }
        }
        Ok(data)
    }
}
#[cfg(test)]
mod tests {
use super::*;
use approx::assert_relative_eq;
#[test]
fn test_complex_creation() {
let z = Complex::new(3.0, 4.0);
assert_eq!(z.real(), 3.0);
assert_eq!(z.imag(), 4.0);
let real = Complex::from_real(5.0);
assert_eq!(real, Complex::new(5.0, 0.0));
let imag = Complex::from_imag(2.0);
assert_eq!(imag, Complex::new(0.0, 2.0));
}
#[test]
fn test_complex_arithmetic() {
let z1 = Complex::new(3.0, 4.0);
let z2 = Complex::new(1.0, 2.0);
assert_eq!(z1 + z2, Complex::new(4.0, 6.0));
assert_eq!(z1 - z2, Complex::new(2.0, 2.0));
assert_eq!(z1 * z2, Complex::new(-5.0, 10.0));
let div = z1 / z2;
assert_relative_eq!(div.real(), 2.2, epsilon = 1e-10);
assert_relative_eq!(div.imag(), -0.4, epsilon = 1e-10);
}
#[test]
fn test_complex_properties() {
let z = Complex::new(3.0, 4.0);
assert_relative_eq!(Complex::abs(&z), 5.0, epsilon = 1e-10);
assert_relative_eq!(z.abs_sq(), 25.0, epsilon = 1e-10);
assert_eq!(z.conj(), Complex::new(3.0, -4.0));
let expected_phase = 4.0_f64.atan2(3.0);
assert_relative_eq!(z.arg(), expected_phase, epsilon = 1e-10);
}
#[test]
fn test_complex_functions() {
let z = Complex::new(1.0, 1.0);
let exp_z = z.exp();
let expected_real = 1.0_f64.exp() * 1.0_f64.cos();
let expected_imag = 1.0_f64.exp() * 1.0_f64.sin();
assert_relative_eq!(exp_z.real(), expected_real, epsilon = 1e-10);
assert_relative_eq!(exp_z.imag(), expected_imag, epsilon = 1e-10);
let sqrt_z = z.sqrt();
assert_relative_eq!((sqrt_z * sqrt_z).real(), z.real(), epsilon = 1e-10);
assert_relative_eq!((sqrt_z * sqrt_z).imag(), z.imag(), epsilon = 1e-10);
}
#[test]
fn test_polar_conversion() {
let z = Complex::new(3.0, 4.0);
let (r, theta) = z.to_polar();
let z_converted = Complex::from_polar(r, theta);
assert_relative_eq!(z_converted.real(), z.real(), epsilon = 1e-10);
assert_relative_eq!(z_converted.imag(), z.imag(), epsilon = 1e-10);
}
#[test]
fn test_trigonometric_functions() {
let z = Complex::new(0.5, 0.3);
let sin_z = z.sin();
let cos_z = z.cos();
let identity = sin_z * sin_z + cos_z * cos_z;
assert_relative_eq!(identity.real(), 1.0, epsilon = 1e-10);
assert_relative_eq!(identity.imag(), 0.0, epsilon = 1e-10);
}
#[test]
fn test_constants() {
let zero = Complex::<f64>::zero_const();
assert_eq!(zero.real(), 0.0);
assert_eq!(zero.imag(), 0.0);
let one = Complex::<f64>::one_const();
assert_eq!(one.real(), 1.0);
assert_eq!(one.imag(), 0.0);
let i = Complex::<f64>::i();
assert_eq!(i.real(), 0.0);
assert_eq!(i.imag(), 1.0);
}
#[test]
fn test_complex_tensor_creation() {
let real = Tensor::from_vec(vec![1.0, 2.0, 3.0], vec![3]);
let imag = Tensor::from_vec(vec![4.0, 5.0, 6.0], vec![3]);
let complex_tensor = Complex::from_tensors(&real, &imag).unwrap();
assert_eq!(complex_tensor.shape(), &[3]);
assert_eq!(complex_tensor.data[0].real(), 1.0);
assert_eq!(complex_tensor.data[0].imag(), 4.0);
assert_eq!(complex_tensor.data[1].real(), 2.0);
assert_eq!(complex_tensor.data[1].imag(), 5.0);
assert_eq!(complex_tensor.data[2].real(), 3.0);
assert_eq!(complex_tensor.data[2].imag(), 6.0);
}
#[test]
fn test_complex_tensor_extraction() {
let complex_data = vec![
Complex::new(1.0, 2.0),
Complex::new(3.0, 4.0),
Complex::new(5.0, 6.0),
];
let complex_tensor = Tensor::from_vec(complex_data, vec![3]);
let real_part = Complex::tensor_real_part(&complex_tensor);
assert_eq!(real_part.data.as_slice().unwrap(), &[1.0, 3.0, 5.0]);
let imag_part = Complex::tensor_imag_part(&complex_tensor);
assert_eq!(imag_part.data.as_slice().unwrap(), &[2.0, 4.0, 6.0]);
let abs_part = Complex::tensor_abs(&complex_tensor);
assert_relative_eq!(abs_part.data[0], 5.0_f64.sqrt(), epsilon = 1e-10);
assert_relative_eq!(abs_part.data[1], 25.0_f64.sqrt(), epsilon = 1e-10);
assert_relative_eq!(abs_part.data[2], 61.0_f64.sqrt(), epsilon = 1e-10);
}
#[test]
fn test_complex_tensor_factory_functions() {
let zeros = Tensor::<Complex<f64>>::complex_zeros(&[2, 3]);
assert_eq!(zeros.shape(), &[2, 3]);
for z in zeros.data.iter() {
assert_eq!(z.real(), 0.0);
assert_eq!(z.imag(), 0.0);
}
let ones = Tensor::<Complex<f64>>::complex_ones(&[2, 2]);
assert_eq!(ones.shape(), &[2, 2]);
for z in ones.data.iter() {
assert_eq!(z.real(), 1.0);
assert_eq!(z.imag(), 0.0);
}
let i_tensor = Tensor::<Complex<f64>>::complex_i(&[1, 4]);
assert_eq!(i_tensor.shape(), &[1, 4]);
for z in i_tensor.data.iter() {
assert_eq!(z.real(), 0.0);
assert_eq!(z.imag(), 1.0);
}
}
#[test]
fn test_complex_tensor_polar_conversion() {
let magnitude = Tensor::from_vec(vec![1.0, 2.0], vec![2]);
let phase = Tensor::from_vec(vec![0.0, std::f64::consts::PI / 2.0], vec![2]);
let complex_tensor = Tensor::from_polar(&magnitude, &phase).unwrap();
assert_eq!(complex_tensor.shape(), &[2]);
assert_relative_eq!(complex_tensor.data[0].real(), 1.0, epsilon = 1e-10);
assert_relative_eq!(complex_tensor.data[0].imag(), 0.0, epsilon = 1e-10);
assert_relative_eq!(complex_tensor.data[1].real(), 0.0, epsilon = 1e-10);
assert_relative_eq!(complex_tensor.data[1].imag(), 2.0, epsilon = 1e-10);
}
#[test]
fn test_complex_tensor_conjugate() {
let complex_data = vec![Complex::new(1.0, 2.0), Complex::new(-3.0, 4.0)];
let complex_tensor = Tensor::from_vec(complex_data, vec![2]);
let conj_tensor = Complex::tensor_conj(&complex_tensor);
assert_eq!(conj_tensor.data[0].real(), 1.0);
assert_eq!(conj_tensor.data[0].imag(), -2.0);
assert_eq!(conj_tensor.data[1].real(), -3.0);
assert_eq!(conj_tensor.data[1].imag(), -4.0);
}
#[test]
fn test_complex_mathematical_functions() {
let complex_data = vec![
Complex::new(1.0, 0.0),
Complex::new(0.0, 1.0),
Complex::new(1.0, 1.0),
];
let complex_tensor = Tensor::from_vec(complex_data, vec![3]);
let exp_result = complex_tensor.exp();
assert_relative_eq!(
exp_result.data[0].real(),
std::f64::consts::E,
epsilon = 1e-10
);
assert_relative_eq!(exp_result.data[0].imag(), 0.0, epsilon = 1e-10);
let ln_result = complex_tensor.ln();
assert_relative_eq!(ln_result.data[0].real(), 0.0, epsilon = 1e-10);
assert_relative_eq!(ln_result.data[0].imag(), 0.0, epsilon = 1e-10);
let sqrt_result = complex_tensor.sqrt();
let sqrt_1_1 = sqrt_result.data[2];
assert_relative_eq!((sqrt_1_1 * sqrt_1_1).real(), 1.0, epsilon = 1e-10);
assert_relative_eq!((sqrt_1_1 * sqrt_1_1).imag(), 1.0, epsilon = 1e-10);
let sin_result = complex_tensor.sin();
let cos_result = complex_tensor.cos();
for i in 0..3 {
let sin_val = sin_result.data[i];
let cos_val = cos_result.data[i];
let identity = sin_val * sin_val + cos_val * cos_val;
assert_relative_eq!(identity.real(), 1.0, epsilon = 1e-10);
assert_relative_eq!(identity.imag(), 0.0, epsilon = 1e-10);
}
}
#[test]
fn test_complex_matrix_multiplication() {
let a_data = vec![
Complex::new(1.0, 1.0),
Complex::new(2.0, 0.0), Complex::new(0.0, 1.0),
Complex::new(1.0, -1.0), ];
let a = Tensor::from_vec(a_data, vec![2, 2]);
let b_data = vec![
Complex::new(1.0, 0.0),
Complex::new(0.0, 1.0), Complex::new(1.0, 1.0),
Complex::new(1.0, 0.0), ];
let b = Tensor::from_vec(b_data, vec![2, 2]);
let result = a.matmul(&b).unwrap();
assert_eq!(result.shape(), &[2, 2]);
assert_relative_eq!(result.data[[0, 0]].real(), 3.0, epsilon = 1e-10);
assert_relative_eq!(result.data[[0, 0]].imag(), 3.0, epsilon = 1e-10);
assert_relative_eq!(result.data[[0, 1]].real(), 1.0, epsilon = 1e-10);
assert_relative_eq!(result.data[[0, 1]].imag(), 1.0, epsilon = 1e-10);
}
#[test]
fn test_complex_matrix_transpose() {
    // Row-major 2x2 input: [[1+2i, 3+4i], [5+6i, 7+8i]].
    let original = Tensor::from_vec(
        vec![
            Complex::new(1.0, 2.0),
            Complex::new(3.0, 4.0),
            Complex::new(5.0, 6.0),
            Complex::new(7.0, 8.0),
        ],
        vec![2, 2],
    );

    let t = original.transpose().unwrap();
    assert_eq!(t.shape(), &[2, 2]);

    // A plain transpose swaps the off-diagonal entries and leaves every value
    // (including its imaginary part) untouched.
    assert_eq!(t.data[[0, 0]], Complex::new(1.0, 2.0));
    assert_eq!(t.data[[0, 1]], Complex::new(5.0, 6.0));
    assert_eq!(t.data[[1, 0]], Complex::new(3.0, 4.0));
    assert_eq!(t.data[[1, 1]], Complex::new(7.0, 8.0));
}
#[test]
fn test_complex_matrix_conjugate_transpose() {
    // Row-major 2x2 input: [[1+2i, 3+4i], [5+6i, 7+8i]].
    let original = Tensor::from_vec(
        vec![
            Complex::new(1.0, 2.0),
            Complex::new(3.0, 4.0),
            Complex::new(5.0, 6.0),
            Complex::new(7.0, 8.0),
        ],
        vec![2, 2],
    );

    let hermitian = original.conj_transpose().unwrap();
    assert_eq!(hermitian.shape(), &[2, 2]);

    // Conjugate transpose: swap the off-diagonal entries, then negate every
    // imaginary part.
    assert_eq!(hermitian.data[[0, 0]], Complex::new(1.0, -2.0));
    assert_eq!(hermitian.data[[0, 1]], Complex::new(5.0, -6.0));
    assert_eq!(hermitian.data[[1, 0]], Complex::new(3.0, -4.0));
    assert_eq!(hermitian.data[[1, 1]], Complex::new(7.0, -8.0));
}
#[test]
fn test_complex_matrix_trace() {
    // [[1+i, 2], [3, 4+2i]]; trace = (1+i) + (4+2i) = 5 + 3i.
    // The sums are exact in binary floating point, so assert_eq! is safe here.
    let m = Tensor::from_vec(
        vec![
            Complex::new(1.0, 1.0),
            Complex::new(2.0, 0.0),
            Complex::new(3.0, 0.0),
            Complex::new(4.0, 2.0),
        ],
        vec![2, 2],
    );
    let tr = m.trace().unwrap();
    assert_eq!(tr.real(), 5.0);
    assert_eq!(tr.imag(), 3.0);
}
#[test]
fn test_complex_matrix_determinant() {
    // [[1+i, 2], [i, 1-i]]: det = (1+i)(1-i) - 2*i = 2 - 2i.
    let data = vec![
        Complex::new(1.0, 1.0),
        Complex::new(2.0, 0.0),
        Complex::new(0.0, 1.0),
        Complex::new(1.0, -1.0),
    ];
    let matrix = Tensor::from_vec(data, vec![2, 2]);
    let det = matrix.determinant().unwrap();
    // Compare with a tolerance rather than exact equality. The determinant
    // implementation is not visible here and may go through elimination or
    // division, where rounding can perturb the last bits.
    assert_relative_eq!(det.real(), 2.0, epsilon = 1e-10);
    assert_relative_eq!(det.imag(), -2.0, epsilon = 1e-10);
}
#[test]
fn test_complex_fft_basic() {
    // x = [1, 0, 1, 0]. Its DFT is X[k] = 1 + (-1)^k, i.e. proportional to
    // [2, 0, 2, 0] regardless of normalization or exponent sign convention.
    let signal_data = vec![
        Complex::new(1.0, 0.0),
        Complex::new(0.0, 0.0),
        Complex::new(1.0, 0.0),
        Complex::new(0.0, 0.0),
    ];
    let signal = Tensor::from_vec(signal_data, vec![4]);
    let fft_result = signal.fft(None, None, None);
    assert!(
        fft_result.is_ok(),
        "Complex FFT should work on basic signal"
    );
    let fft_tensor = fft_result.unwrap();
    assert_eq!(fft_tensor.shape(), &[4]);

    // Check the forward spectrum itself. A round trip alone would also pass
    // if fft and ifft were both the identity.
    // Odd bins must vanish.
    for &k in &[1usize, 3] {
        assert_relative_eq!(fft_tensor.data[k].real(), 0.0, epsilon = 1e-6);
        assert_relative_eq!(fft_tensor.data[k].imag(), 0.0, epsilon = 1e-6);
    }
    // Even bins must be equal, real and strictly positive. Any normalization
    // scale is positive, so positivity holds under every convention.
    assert_relative_eq!(
        fft_tensor.data[0].real(),
        fft_tensor.data[2].real(),
        epsilon = 1e-6
    );
    assert_relative_eq!(fft_tensor.data[0].imag(), 0.0, epsilon = 1e-6);
    assert_relative_eq!(fft_tensor.data[2].imag(), 0.0, epsilon = 1e-6);
    assert!(fft_tensor.data[0].real() > 0.0);

    // ifft(fft(x)) must reproduce x.
    let ifft_result = fft_tensor.ifft(None, None, None).unwrap();
    for i in 0..4 {
        assert_relative_eq!(
            ifft_result.data[i].real(),
            signal.data[i].real(),
            epsilon = 1e-6
        );
        assert_relative_eq!(
            ifft_result.data[i].imag(),
            signal.data[i].imag(),
            epsilon = 1e-6
        );
    }
}
#[test]
fn test_complex_fft_shift() {
    // For an even-length input, fftshift rotates by n/2:
    // [1, 2, 3, 4] -> [3, 4, 1, 2].
    let tensor = Tensor::from_vec(
        vec![
            Complex::new(1.0, 0.0),
            Complex::new(2.0, 0.0),
            Complex::new(3.0, 0.0),
            Complex::new(4.0, 0.0),
        ],
        vec![4],
    );

    let shifted = tensor.fftshift(None).unwrap();
    let expected: [f64; 4] = [3.0, 4.0, 1.0, 2.0];
    for (pos, &want) in expected.iter().enumerate() {
        assert_eq!(shifted.data[pos].real(), want);
    }

    // ifftshift must undo fftshift exactly.
    let restored = shifted.ifftshift(None).unwrap();
    for pos in 0..4 {
        let got = restored.data[pos];
        let orig = tensor.data[pos];
        assert_relative_eq!(got.real(), orig.real(), epsilon = 1e-10);
        assert_relative_eq!(got.imag(), orig.imag(), epsilon = 1e-10);
    }
}
#[test]
fn test_complex_power_operations() {
    // Bases: 2, i and 1 + i.
    let base_data = vec![
        Complex::new(2.0, 0.0),
        Complex::new(0.0, 1.0),
        Complex::new(1.0, 1.0),
    ];
    let base = Tensor::from_vec(base_data, vec![3]);

    // Scalar exponent 2: 2^2 = 4, i^2 = -1, (1+i)^2 = 2i.
    let squared = base.pow_scalar(Complex::new(2.0, 0.0));
    assert_relative_eq!(squared.data[0].real(), 4.0, epsilon = 1e-10);
    assert_relative_eq!(squared.data[0].imag(), 0.0, epsilon = 1e-10);
    assert_relative_eq!(squared.data[1].real(), -1.0, epsilon = 1e-10);
    assert_relative_eq!(squared.data[1].imag(), 0.0, epsilon = 1e-10);
    // Only the general (non-axis-aligned) base gives a result with both a
    // real and an imaginary part in play.
    assert_relative_eq!(squared.data[2].real(), 0.0, epsilon = 1e-10);
    assert_relative_eq!(squared.data[2].imag(), 2.0, epsilon = 1e-10);

    // Element-wise exponents: 2^0.5 = sqrt(2), i^2 = -1, (1+i)^1 = 1+i.
    let exp_data = vec![
        Complex::new(0.5, 0.0),
        Complex::new(2.0, 0.0),
        Complex::new(1.0, 0.0),
    ];
    let exponent = Tensor::from_vec(exp_data, vec![3]);
    let powered = base.pow(&exponent).unwrap();
    assert_relative_eq!(powered.data[0].real(), 2.0_f64.sqrt(), epsilon = 1e-10);
    assert_relative_eq!(powered.data[0].imag(), 0.0, epsilon = 1e-10);
    assert_relative_eq!(powered.data[1].real(), -1.0, epsilon = 1e-10);
    assert_relative_eq!(powered.data[1].imag(), 0.0, epsilon = 1e-10);
    assert_relative_eq!(powered.data[2].real(), 1.0, epsilon = 1e-10);
    assert_relative_eq!(powered.data[2].imag(), 1.0, epsilon = 1e-10);
}
}