use scirs2_core::numeric::Float;
use std::sync::Arc;
use torsh_core::{
dtype::{ComplexElement, TensorElement},
error::{Result, TorshError},
};
use crate::core_ops::{Operation, Tensor};
impl<T: ComplexElement + Copy> Tensor<T> {
    /// Element-wise complex conjugate.
    ///
    /// Copies `requires_grad` from `self`; when gradients are required, a
    /// `"complex_conj"` custom operation is recorded so the backward pass can
    /// route gradients through the conjugation.
    pub fn complex_conj(&self) -> Result<Self>
    where
        T: Copy,
    {
        let data = self.to_vec()?;
        let conj_data: Vec<T> = data.iter().map(|&z| z.conj()).collect();
        let mut result = Self::from_data(conj_data, self.shape().dims().to_vec(), self.device)?;
        result.requires_grad = self.requires_grad;
        if self.requires_grad {
            // FIXME(review): `Arc::new(self.clone())` is a temporary whose only
            // strong reference is dropped at the end of this expression, so the
            // stored `Weak` can never be upgraded. The `upgrade()` calls in
            // `backward_complex_impl` will return `None` and no gradient
            // actually flows back through this op. A strong reference must be
            // kept alive somewhere (e.g. `Operation::Custom` holding `Arc`
            // instead of `Weak`) for autograd through `complex_conj` to work.
            result.operation = Operation::Custom(
                "complex_conj".to_string(),
                vec![Arc::downgrade(&Arc::new(self.clone()))],
            );
        }
        Ok(result)
    }
    /// Extracts the real part of every element into a new real-valued tensor.
    ///
    /// The result does not participate in autograd (no operation is recorded).
    pub fn real(&self) -> Result<Tensor<T::Real>>
    where
        T::Real: TensorElement + Copy,
    {
        let data = self.to_vec()?;
        let real_data: Vec<T::Real> = data.iter().map(|x| x.real()).collect();
        Tensor::from_data(real_data, self.shape().dims().to_vec(), self.device)
    }
    /// Extracts the imaginary part of every element into a new real-valued tensor.
    ///
    /// The result does not participate in autograd (no operation is recorded).
    pub fn imag(&self) -> Result<Tensor<T::Real>>
    where
        T::Real: TensorElement + Copy,
    {
        let data = self.to_vec()?;
        let imag_data: Vec<T::Real> = data.iter().map(|x| x.imag()).collect();
        Tensor::from_data(imag_data, self.shape().dims().to_vec(), self.device)
    }
    /// Element-wise complex magnitude |z|, returned as a real-valued tensor.
    ///
    /// The result does not participate in autograd (no operation is recorded;
    /// contrast with the `"complex_abs"` case handled in
    /// `backward_complex_impl`, which this method never produces).
    pub fn abs(&self) -> Result<Tensor<T::Real>>
    where
        T::Real: TensorElement + Copy + num_traits::Float,
    {
        let data = self.to_vec()?;
        let abs_data: Vec<T::Real> = data.iter().map(|x| x.abs()).collect();
        Tensor::from_data(abs_data, self.shape().dims().to_vec(), self.device)
    }
    /// Element-wise argument (phase angle) of each complex element, as a
    /// real-valued tensor. Delegates to `ComplexElement::arg`; presumably the
    /// result is in radians in (-pi, pi] like `atan2` — confirm against the
    /// `ComplexElement` implementation.
    pub fn angle(&self) -> Result<Tensor<T::Real>>
    where
        T::Real: TensorElement + Copy + num_traits::Float,
    {
        let data = self.to_vec()?;
        let angle_data: Vec<T::Real> = data.iter().map(|x| x.arg()).collect();
        Tensor::from_data(angle_data, self.shape().dims().to_vec(), self.device)
    }
    /// Builds a complex tensor from separate real and imaginary component
    /// tensors of identical shape.
    ///
    /// # Errors
    /// Returns `TorshError::ShapeMismatch` when the shapes differ. The result
    /// inherits `real`'s shape and device.
    pub fn complex(real: &Tensor<T::Real>, imag: &Tensor<T::Real>) -> Result<Self>
    where
        T::Real: TensorElement + Copy,
    {
        if real.shape() != imag.shape() {
            return Err(TorshError::ShapeMismatch {
                expected: real.shape().dims().to_vec(),
                got: imag.shape().dims().to_vec(),
            });
        }
        let real_data = real.to_vec()?;
        let imag_data = imag.to_vec()?;
        let complex_data: Vec<T> = real_data
            .iter()
            .zip(imag_data.iter())
            .map(|(&r, &i)| T::new(r, i))
            .collect();
        Self::from_data(complex_data, real.shape().dims().to_vec(), real.device)
    }
    /// Builds a complex tensor from polar coordinates:
    /// `z = magnitude * (cos(phase) + i*sin(phase))`, element-wise.
    ///
    /// # Errors
    /// Returns `TorshError::ShapeMismatch` when the shapes differ. The result
    /// inherits `magnitude`'s shape and device.
    pub fn polar(magnitude: &Tensor<T::Real>, phase: &Tensor<T::Real>) -> Result<Self>
    where
        T::Real: TensorElement + Copy + num_traits::Float,
    {
        if magnitude.shape() != phase.shape() {
            return Err(TorshError::ShapeMismatch {
                expected: magnitude.shape().dims().to_vec(),
                got: phase.shape().dims().to_vec(),
            });
        }
        let mag_data = magnitude.to_vec()?;
        let phase_data = phase.to_vec()?;
        let complex_data: Vec<T> = mag_data
            .iter()
            .zip(phase_data.iter())
            .map(|(&mag, &phase)| {
                let real = mag * phase.cos();
                let imag = mag * phase.sin();
                T::new(real, imag)
            })
            .collect();
        Self::from_data(
            complex_data,
            magnitude.shape().dims().to_vec(),
            magnitude.device,
        )
    }
    /// Starts a backward pass from this (scalar) complex tensor, seeding the
    /// output gradient with `1 + 0i`.
    ///
    /// # Errors
    /// - `AutogradError` when `requires_grad` is false.
    /// - `AutogradError` when the tensor is not a scalar (`numel() != 1`).
    pub fn backward_complex(&self) -> Result<()>
    where
        T: Copy
            + Default
            + std::ops::Add<Output = T>
            + std::ops::Sub<Output = T>
            + std::ops::Mul<Output = T>
            + std::ops::Div<Output = T>,
    {
        if !self.requires_grad {
            return Err(TorshError::AutogradError(
                "Called backward on tensor that doesn't require grad".to_string(),
            ));
        }
        if self.shape().numel() != 1 {
            return Err(TorshError::AutogradError(
                "Gradient can only be computed for scalar outputs".to_string(),
            ));
        }
        // Seed gradient d(out)/d(out) = 1 + 0i, stored as a 0-dimensional
        // (empty-shape) tensor on the same device.
        let output_grad_data = vec![T::new(
            <T::Real as TensorElement>::one(),
            <T::Real as TensorElement>::zero(),
        )];
        let output_grad = Self::from_data(output_grad_data, vec![], self.device)?;
        self.backward_complex_impl(&output_grad)?;
        Ok(())
    }
    /// Recursive reverse-mode traversal: routes `grad_output` to this node's
    /// inputs according to `self.operation`, accumulating into `grad` at
    /// leaves.
    fn backward_complex_impl(&self, grad_output: &Self) -> Result<()>
    where
        T: Copy
            + Default
            + std::ops::Add<Output = T>
            + std::ops::Sub<Output = T>
            + std::ops::Mul<Output = T>
            + std::ops::Div<Output = T>,
    {
        match &self.operation {
            Operation::Leaf => {
                // Accumulate into the existing gradient if one is present,
                // otherwise install `grad_output` as the first contribution.
                let mut grad_lock = self.grad.write().expect("lock should not be poisoned");
                if let Some(existing_grad) = grad_lock.as_ref() {
                    let new_grad = existing_grad.add_op(grad_output)?;
                    *grad_lock = Some(new_grad);
                } else {
                    *grad_lock = Some(grad_output.clone());
                }
            }
            Operation::Add { lhs, rhs } => {
                // d(l + r)/dl = d(l + r)/dr = 1: pass the gradient through.
                if lhs.requires_grad {
                    lhs.backward_complex_impl(grad_output)?;
                }
                if rhs.requires_grad {
                    rhs.backward_complex_impl(grad_output)?;
                }
            }
            Operation::Mul { lhs, rhs } => {
                // Product rule: grad_lhs = rhs * grad, grad_rhs = lhs * grad.
                // NOTE(review): no conjugation is applied here, i.e. this is the
                // plain holomorphic convention rather than the conjugate
                // (Wirtinger) convention some frameworks use for complex
                // autograd — confirm which convention this project intends.
                if lhs.requires_grad {
                    let lhs_grad = (**rhs).mul_op(grad_output)?;
                    lhs.backward_complex_impl(&lhs_grad)?;
                }
                if rhs.requires_grad {
                    let rhs_grad = (**lhs).mul_op(grad_output)?;
                    rhs.backward_complex_impl(&rhs_grad)?;
                }
            }
            Operation::Custom(op_name, inputs) => {
                // NOTE(review): all branches below depend on upgrading the
                // stored Weak references; see the FIXME in `complex_conj` —
                // with the current recording scheme `upgrade()` yields `None`
                // and these branches are effectively dead.
                match op_name.as_str() {
                    "complex_conj" => {
                        // d/dz of conj is routed by conjugating the incoming
                        // gradient.
                        if let Some(weak_input) = inputs.first() {
                            if let Some(input) = weak_input.upgrade() {
                                if input.requires_grad {
                                    let conj_grad = grad_output.complex_conj()?;
                                    input.backward_complex_impl(&conj_grad)?;
                                }
                            }
                        }
                    }
                    "complex_abs" => {
                        if let Some(weak_input) = inputs.first() {
                            if let Some(input) = weak_input.upgrade() {
                                if input.requires_grad {
                                    let input_data = input.to_vec()?;
                                    let grad_data = grad_output.to_vec()?;
                                    // grad_in = (z / |z|) * grad, computed as an
                                    // explicit complex multiply; at z == 0 the
                                    // magnitude is not differentiable, so a zero
                                    // subgradient is used.
                                    let input_grad_data: Vec<T> = input_data
                                        .iter()
                                        .zip(grad_data.iter())
                                        .map(|(&z, &grad)| {
                                            let abs_z = z.abs();
                                            if abs_z > T::Real::zero() {
                                                let z_normalized =
                                                    T::new(z.real() / abs_z, z.imag() / abs_z);
                                                T::new(
                                                    z_normalized.real() * grad.real()
                                                        - z_normalized.imag() * grad.imag(),
                                                    z_normalized.real() * grad.imag()
                                                        + z_normalized.imag() * grad.real(),
                                                )
                                            } else {
                                                T::new(T::Real::zero(), T::Real::zero())
                                            }
                                        })
                                        .collect();
                                    let input_grad = Self::from_data(
                                        input_grad_data,
                                        input.shape().dims().to_vec(),
                                        input.device,
                                    )?;
                                    input.backward_complex_impl(&input_grad)?;
                                }
                            }
                        }
                    }
                    _ => {
                        // NOTE(review): unknown custom ops pass the gradient
                        // through unchanged, which is only correct for
                        // identity-like operations — silent for everything else.
                        for weak_input in inputs {
                            if let Some(input) = weak_input.upgrade() {
                                if input.requires_grad {
                                    input.backward_complex_impl(grad_output)?;
                                }
                            }
                        }
                    }
                }
            }
            _ => {
                // NOTE(review): every other Operation variant is silently
                // skipped — gradients stop here without error. Confirm this is
                // intentional for the remaining variants.
            }
        }
        Ok(())
    }
    /// Element-wise complex multiplication:
    /// `(a+bi)(c+di) = (ac - bd) + (ad + bc)i`.
    ///
    /// Records an `Operation::Mul` node (with strong `Arc` references, unlike
    /// the custom ops above) when either operand requires gradients.
    ///
    /// # Errors
    /// Returns `TorshError::ShapeMismatch` when the shapes differ.
    pub fn complex_mul(&self, other: &Self) -> Result<Self>
    where
        T: std::ops::Mul<Output = T> + std::ops::Add<Output = T> + std::ops::Sub<Output = T>,
    {
        if self.shape() != other.shape() {
            return Err(TorshError::ShapeMismatch {
                expected: self.shape().dims().to_vec(),
                got: other.shape().dims().to_vec(),
            });
        }
        let self_data = self.to_vec()?;
        let other_data = other.to_vec()?;
        let result_data: Vec<T> = self_data
            .iter()
            .zip(other_data.iter())
            .map(|(&a, &b)| {
                T::new(
                    a.real() * b.real() - a.imag() * b.imag(),
                    a.real() * b.imag() + a.imag() * b.real(),
                )
            })
            .collect();
        let mut result = Self::from_data(result_data, self.shape().dims().to_vec(), self.device)?;
        if self.requires_grad || other.requires_grad {
            result.requires_grad = true;
            result.operation = Operation::Mul {
                lhs: Arc::new(self.clone()),
                rhs: Arc::new(other.clone()),
            };
        }
        Ok(result)
    }
    /// Element-wise complex addition.
    ///
    /// Records an `Operation::Add` node when either operand requires gradients.
    ///
    /// # Errors
    /// Returns `TorshError::ShapeMismatch` when the shapes differ.
    pub fn complex_add(&self, other: &Self) -> Result<Self>
    where
        T: std::ops::Add<Output = T>,
    {
        if self.shape() != other.shape() {
            return Err(TorshError::ShapeMismatch {
                expected: self.shape().dims().to_vec(),
                got: other.shape().dims().to_vec(),
            });
        }
        let self_data = self.to_vec()?;
        let other_data = other.to_vec()?;
        let result_data: Vec<T> = self_data
            .iter()
            .zip(other_data.iter())
            .map(|(&a, &b)| T::new(a.real() + b.real(), a.imag() + b.imag()))
            .collect();
        let mut result = Self::from_data(result_data, self.shape().dims().to_vec(), self.device)?;
        if self.requires_grad || other.requires_grad {
            result.requires_grad = true;
            result.operation = Operation::Add {
                lhs: Arc::new(self.clone()),
                rhs: Arc::new(other.clone()),
            };
        }
        Ok(result)
    }
    /// Returns true when every element has an exactly-zero imaginary part.
    ///
    /// Uses exact floating-point comparison — values with tiny non-zero
    /// imaginary parts (e.g. from rounding) report false.
    pub fn is_real(&self) -> Result<bool>
    where
        T::Real: PartialEq + num_traits::Zero,
    {
        let data = self.to_vec()?;
        Ok(data.iter().all(|&z| z.imag() == T::Real::zero()))
    }
    /// Returns true when at least one element has a non-zero imaginary part
    /// (the logical negation of `is_real`).
    pub fn is_complex(&self) -> Result<bool>
    where
        T::Real: PartialEq + num_traits::Zero,
    {
        Ok(!self.is_real()?)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use num_complex::Complex32;
    use torsh_core::device::DeviceType;
    type C32 = Complex32;

    /// Conjugation flips the sign of the imaginary part, element-wise.
    #[test]
    fn test_complex_conjugate() {
        let input = vec![C32::new(1.0, 2.0), C32::new(3.0, -4.0), C32::new(-1.0, 1.0)];
        let t = Tensor::from_data(input, vec![3], DeviceType::Cpu)
            .expect("operation should succeed");
        let got = t
            .complex_conj()
            .expect("complex conjugate should succeed")
            .to_vec()
            .expect("to_vec should succeed");
        let want = [C32::new(1.0, -2.0), C32::new(3.0, 4.0), C32::new(-1.0, -1.0)];
        for (g, w) in got.iter().zip(want.iter()) {
            assert_eq!(g, w);
        }
    }

    /// `real()` and `imag()` split a complex tensor into its components.
    #[test]
    fn test_real_imag_extraction() {
        let t = Tensor::from_data(
            vec![C32::new(1.0, 2.0), C32::new(3.0, -4.0)],
            vec![2],
            DeviceType::Cpu,
        )
        .expect("operation should succeed");
        let re = t.real().expect("real extraction should succeed");
        let im = t.imag().expect("imag extraction should succeed");
        assert_eq!(re.to_vec().expect("to_vec should succeed"), vec![1.0, 3.0]);
        assert_eq!(im.to_vec().expect("to_vec should succeed"), vec![2.0, -4.0]);
    }

    /// |3+4i| = 5 and arg(1+0i) = 0.
    #[test]
    fn test_magnitude_and_phase() {
        let t = Tensor::from_data(
            vec![C32::new(3.0, 4.0), C32::new(1.0, 0.0)],
            vec![2],
            DeviceType::Cpu,
        )
        .expect("operation should succeed");
        let mags = t
            .abs()
            .expect("abs computation should succeed")
            .to_vec()
            .expect("to_vec should succeed");
        let phases = t
            .angle()
            .expect("angle computation should succeed")
            .to_vec()
            .expect("to_vec should succeed");
        assert!((mags[0] - 5.0).abs() < 1e-6);
        assert!((mags[1] - 1.0).abs() < 1e-6);
        assert!((phases[1] - 0.0).abs() < 1e-6);
    }

    /// `complex(re, im)` zips component tensors into complex values.
    #[test]
    fn test_complex_from_components() {
        let re = Tensor::from_data(vec![1.0f32, 2.0, 3.0], vec![3], DeviceType::Cpu)
            .expect("operation should succeed");
        let im = Tensor::from_data(vec![4.0f32, 5.0, 6.0], vec![3], DeviceType::Cpu)
            .expect("operation should succeed");
        let z = Tensor::<C32>::complex(&re, &im).expect("operation should succeed");
        let got = z.to_vec().expect("to_vec should succeed");
        // Element k should combine re[k] with im[k].
        for (k, g) in got.iter().enumerate() {
            assert_eq!(*g, C32::new((k + 1) as f32, (k + 4) as f32));
        }
    }

    /// Element-wise complex addition and multiplication.
    #[test]
    fn test_complex_arithmetic() {
        let a = Tensor::from_data(
            vec![C32::new(1.0, 2.0), C32::new(3.0, 4.0)],
            vec![2],
            DeviceType::Cpu,
        )
        .expect("operation should succeed");
        let b = Tensor::from_data(
            vec![C32::new(2.0, 1.0), C32::new(1.0, -1.0)],
            vec![2],
            DeviceType::Cpu,
        )
        .expect("operation should succeed");
        let sums = a
            .complex_add(&b)
            .expect("operation should succeed")
            .to_vec()
            .expect("to_vec should succeed");
        assert_eq!(sums, vec![C32::new(3.0, 3.0), C32::new(4.0, 3.0)]);
        // (1+2i)(2+i) = 5i and (3+4i)(1-i) = 7+i.
        let products = a
            .complex_mul(&b)
            .expect("operation should succeed")
            .to_vec()
            .expect("to_vec should succeed");
        assert_eq!(products, vec![C32::new(0.0, 5.0), C32::new(7.0, 1.0)]);
    }

    /// polar(1, 0) = 1 and polar(2, pi/2) = 2i.
    #[test]
    fn test_polar_construction() {
        let mag = Tensor::from_data(vec![1.0f32, 2.0], vec![2], DeviceType::Cpu)
            .expect("operation should succeed");
        let phase = Tensor::from_data(
            vec![0.0f32, std::f32::consts::PI / 2.0],
            vec![2],
            DeviceType::Cpu,
        )
        .expect("operation should succeed");
        let z = Tensor::<C32>::polar(&mag, &phase).expect("operation should succeed");
        let got = z.to_vec().expect("to_vec should succeed");
        let want = [C32::new(1.0, 0.0), C32::new(0.0, 2.0)];
        for (g, w) in got.iter().zip(want.iter()) {
            assert!((g.re - w.re).abs() < 1e-6);
            assert!((g.im - w.im).abs() < 1e-6);
        }
    }

    /// `is_real` requires every imaginary part to be exactly zero.
    #[test]
    fn test_is_real_complex() {
        let purely_real = Tensor::from_data(
            vec![C32::new(1.0, 0.0), C32::new(2.0, 0.0)],
            vec![2],
            DeviceType::Cpu,
        )
        .expect("operation should succeed");
        let with_imag = Tensor::from_data(
            vec![C32::new(1.0, 1.0), C32::new(2.0, 0.0)],
            vec![2],
            DeviceType::Cpu,
        )
        .expect("operation should succeed");
        assert!(purely_real.is_real().expect("is_real check should succeed"));
        assert!(!purely_real
            .is_complex()
            .expect("is_complex check should succeed"));
        assert!(!with_imag.is_real().expect("is_real check should succeed"));
        assert!(with_imag
            .is_complex()
            .expect("is_complex check should succeed"));
    }

    /// Mismatched shapes must be rejected by the binary ops and `complex`.
    #[test]
    fn test_shape_mismatch_errors() {
        let a = Tensor::<C32>::zeros(&[2], DeviceType::Cpu).expect("operation should succeed");
        let b = Tensor::<C32>::zeros(&[3], DeviceType::Cpu).expect("operation should succeed");
        assert!(a.complex_add(&b).is_err());
        assert!(a.complex_mul(&b).is_err());
        let re2 = Tensor::<f32>::zeros(&[2], DeviceType::Cpu).expect("operation should succeed");
        let im3 = Tensor::<f32>::zeros(&[3], DeviceType::Cpu).expect("operation should succeed");
        assert!(Tensor::<C32>::complex(&re2, &im3).is_err());
    }
}