use ferray_core::Array;
use ferray_core::dimension::Dimension;
use ferray_core::dtype::Element;
use ferray_core::error::FerrayResult;
use num_complex::Complex;
use num_traits::Float;
/// Extracts the real part of every element of a complex array.
///
/// The output has the same dimension value as `input` (`input.dim()`), and its
/// elements follow the order produced by `input.iter()`.
///
/// # Errors
///
/// Propagates any error returned by `Array::from_vec` while building the output.
pub fn real<T, D>(input: &Array<Complex<T>, D>) -> FerrayResult<Array<T, D>>
where
    T: Element + Float,
    Complex<T>: Element,
    D: Dimension,
{
    let shape = input.dim().clone();
    let real_parts = input.iter().map(|z| z.re).collect::<Vec<T>>();
    Array::from_vec(shape, real_parts)
}
/// Extracts the imaginary part of every element of a complex array.
///
/// The output has the same dimension value as `input` (`input.dim()`), and its
/// elements follow the order produced by `input.iter()`.
///
/// # Errors
///
/// Propagates any error returned by `Array::from_vec` while building the output.
pub fn imag<T, D>(input: &Array<Complex<T>, D>) -> FerrayResult<Array<T, D>>
where
    T: Element + Float,
    Complex<T>: Element,
    D: Dimension,
{
    let imaginary_parts: Vec<T> = input
        .iter()
        .map(|value: &Complex<T>| value.im)
        .collect();
    let shape = input.dim().clone();
    Array::from_vec(shape, imaginary_parts)
}
/// Returns the element-wise complex conjugate of `input` (`a + bi` becomes `a - bi`).
///
/// The output has the same dimension value as `input` (`input.dim()`), and its
/// elements follow the order produced by `input.iter()`.
///
/// # Errors
///
/// Propagates any error returned by `Array::from_vec` while building the output.
pub fn conj<T, D>(input: &Array<Complex<T>, D>) -> FerrayResult<Array<Complex<T>, D>>
where
    T: Element + Float,
    Complex<T>: Element,
    D: Dimension,
{
    // `z.conj()` resolves to the inherent `Complex::conj` method, not this free function.
    let conjugated = input.iter().map(|z| z.conj()).collect::<Vec<Complex<T>>>();
    Array::from_vec(input.dim().clone(), conjugated)
}
pub fn conjugate<T, D>(input: &Array<Complex<T>, D>) -> FerrayResult<Array<Complex<T>, D>>
where
T: Element + Float,
Complex<T>: Element,
D: Dimension,
{
conj(input)
}
/// Returns the argument (phase angle, in radians) of every element of `input`.
///
/// Each value is `atan2(im, re)`, computed through `Complex::arg`, so results lie
/// in `[-pi, pi]`. The output has the same dimension value as `input`
/// (`input.dim()`), and its elements follow the order produced by `input.iter()`.
///
/// # Errors
///
/// Propagates any error returned by `Array::from_vec` while building the output.
pub fn angle<T, D>(input: &Array<Complex<T>, D>) -> FerrayResult<Array<T, D>>
where
    T: Element + Float,
    Complex<T>: Element,
    D: Dimension,
{
    let shape = input.dim().clone();
    let phases = input.iter().map(|z| z.arg()).collect::<Vec<T>>();
    Array::from_vec(shape, phases)
}
/// Computes the magnitude `|z|` of every element of a complex array.
///
/// The magnitude is evaluated with `hypot(re, im)` instead of the naive
/// `sqrt(re * re + im * im)`. The squares in the naive form overflow to infinity
/// once a component exceeds roughly `sqrt(T::max_value())` (about `1.3e154` for
/// `f64`). They also underflow to zero for very small components. Either way the
/// result is wrong even though `|z|` itself is representable; `hypot` avoids
/// forming the intermediate squares.
///
/// The output has the same dimension value as `input` (`input.dim()`), and its
/// elements follow the order produced by `input.iter()`.
///
/// # Errors
///
/// Propagates any error returned by `Array::from_vec` while building the output.
pub fn abs<T, D>(input: &Array<Complex<T>, D>) -> FerrayResult<Array<T, D>>
where
    T: Element + Float,
    Complex<T>: Element,
    D: Dimension,
{
    // `Float::hypot` computes sqrt(re^2 + im^2) without overflow/underflow in
    // the intermediate squares (e.g. 3e200 + 4e200i -> 5e200, not inf).
    let data: Vec<T> = input.iter().map(|c| c.re.hypot(c.im)).collect();
    Array::from_vec(input.dim().clone(), data)
}
#[cfg(test)]
mod tests {
    use super::*;
    use ferray_core::dimension::Ix1;
    use num_complex::Complex64;

    /// Builds a 1-D complex array whose length equals `values.len()`.
    fn vec_c64(values: Vec<Complex64>) -> Array<Complex64, Ix1> {
        let len = values.len();
        Array::from_vec(Ix1::new([len]), values).unwrap()
    }

    /// Asserts that `actual` is within 1e-12 of `expected`.
    fn assert_close(actual: f64, expected: f64) {
        assert!((actual - expected).abs() < 1e-12);
    }

    #[test]
    fn test_real() {
        let input = vec_c64(vec![Complex64::new(1.0, 2.0), Complex64::new(3.0, 4.0)]);
        let out = real(&input).unwrap();
        let expected = [1.0, 3.0];
        assert_eq!(out.as_slice().unwrap(), &expected);
    }

    #[test]
    fn test_imag() {
        let input = vec_c64(vec![Complex64::new(1.0, 2.0), Complex64::new(3.0, 4.0)]);
        let out = imag(&input).unwrap();
        let expected = [2.0, 4.0];
        assert_eq!(out.as_slice().unwrap(), &expected);
    }

    #[test]
    fn test_conj() {
        let input = vec_c64(vec![Complex64::new(1.0, 2.0), Complex64::new(3.0, -4.0)]);
        let out = conj(&input).unwrap();
        let values = out.as_slice().unwrap();
        assert_eq!(values[0], Complex64::new(1.0, -2.0));
        assert_eq!(values[1], Complex64::new(3.0, 4.0));
    }

    #[test]
    fn test_conjugate_alias() {
        let input = vec_c64(vec![Complex64::new(1.0, 2.0)]);
        let out = conjugate(&input).unwrap();
        let first = out.as_slice().unwrap()[0];
        assert_eq!(first, Complex64::new(1.0, -2.0));
    }

    #[test]
    fn test_angle() {
        let input = vec_c64(vec![
            Complex64::new(1.0, 0.0),
            Complex64::new(0.0, 1.0),
            Complex64::new(-1.0, 0.0),
        ]);
        let out = angle(&input).unwrap();
        let values = out.as_slice().unwrap();
        assert_close(values[0], 0.0);
        assert_close(values[1], std::f64::consts::FRAC_PI_2);
        assert_close(values[2], std::f64::consts::PI);
    }

    #[test]
    fn test_abs() {
        let input = vec_c64(vec![Complex64::new(3.0, 4.0), Complex64::new(0.0, 1.0)]);
        let out = abs(&input).unwrap();
        let values = out.as_slice().unwrap();
        assert_close(values[0], 5.0);
        assert_close(values[1], 1.0);
    }
}