use ariadnetor_tensor::{DenseTensorData, MemoryOrder, Scalar};
use num_complex::Complex;
/// Exercises the `Scalar` helpers on a `Complex<f64>` value: magnitude (`abs`),
/// scaling by a real factor (`scale_real`) and complex conjugation (`conj`).
#[test]
fn test_complex_f64_scalar_trait() {
    let sample = Complex::new(3.0, 4.0);

    // |3 + 4i| = 5
    let magnitude = sample.abs();
    assert_eq!(magnitude, 5.0);

    // Both components are multiplied by the real factor.
    assert_eq!(sample.scale_real(2.0), Complex::new(6.0, 8.0));

    // Conjugation flips the sign of the imaginary part only.
    assert_eq!(sample.conj(), Complex::new(3.0, -4.0));
}
/// Checks that `norm` on a `Complex<f64>` tensor yields the real magnitude as an `f64`.
///
/// Expected value: sqrt(|3+4i|^2 + |0|^2) = sqrt(25) = 5.
#[test]
fn test_complex_f64_norm() {
    let data: Vec<Complex<f64>> = vec![Complex::new(3.0, 4.0), Complex::new(0.0, 0.0)];
    let tensor = DenseTensorData::from_raw_parts(data, vec![2], MemoryOrder::ColumnMajor);

    // The `f64` annotation is a compile-time proof that `norm` returns the real type.
    // A runtime `type_name_of_val(&norm).contains("f64")` check cannot do this job:
    // the name of `Complex<f64>` also contains "f64", and `std::any::type_name`
    // output is explicitly unspecified by the standard library.
    let norm: f64 = tensor.norm();
    assert_eq!(norm, 5.0);
}
/// Normalizing `[1, i]` must return its norm `sqrt(2)` and leave the tensor holding
/// `[1/sqrt(2), i/sqrt(2)]`, whose norm is then 1.
#[test]
fn test_complex_f64_normalize() {
    const TOL: f64 = 1e-10;

    let entries: Vec<Complex<f64>> = vec![Complex::new(1.0, 0.0), Complex::new(0.0, 1.0)];
    let mut tensor =
        DenseTensorData::from_raw_parts(entries, vec![2], MemoryOrder::ColumnMajor);

    let returned_norm: f64 = tensor.normalize();
    let sqrt_two = 2.0f64.sqrt();
    assert!((returned_norm - sqrt_two).abs() < TOL);

    // Copy both entries out so the shared borrow taken by `data()` ends right here.
    let (first, second) = {
        let values = tensor.data();
        (values[0], values[1])
    };
    let inv_sqrt_two = 1.0 / sqrt_two;
    assert!((first.re - inv_sqrt_two).abs() < TOL);
    assert!(first.im.abs() < TOL);
    assert!(second.re.abs() < TOL);
    assert!((second.im - inv_sqrt_two).abs() < TOL);

    // After normalization the tensor has unit norm.
    assert!((tensor.norm() - 1.0).abs() < TOL);
}
/// Checks that `norm` on a `Complex<f32>` tensor yields the real magnitude as an `f32`.
///
/// Expected value: sqrt(|1+i|^2 + |1-i|^2) = sqrt(4) = 2.
#[test]
fn test_complex_f32_norm() {
    let data: Vec<Complex<f32>> = vec![Complex::new(1.0, 1.0), Complex::new(1.0, -1.0)];
    let tensor = DenseTensorData::from_raw_parts(data, vec![2], MemoryOrder::ColumnMajor);

    // The `f32` annotation is a compile-time proof that `norm` returns the real type.
    // Without it, a `Complex<f32>` result would still satisfy both the tolerance check
    // (`Complex<f32> - f32` is defined) and a `type_name(..).contains("f32")` check,
    // so the "returns a real type" property would go untested.
    let norm: f32 = tensor.norm();
    assert!((norm - 2.0f32).abs() < 1e-6);
}
/// Scaling `2 + 2i` by the real factor `0.5` via `scale_real` must give `1 + 1i`.
#[test]
fn test_complex_scale_real_in_normalize() {
    let halved = Complex::new(2.0, 2.0).scale_real(0.5);
    assert_eq!(halved, Complex::new(1.0, 1.0));
}
/// `norm` on a complex tensor must produce a real `f64`, not a complex value.
#[test]
fn test_norm_returns_real_type() {
    let single = DenseTensorData::from_raw_parts(
        vec![Complex::new(3.0, 4.0)],
        vec![1],
        MemoryOrder::ColumnMajor,
    );

    // The `f64` annotation is the type check: this line fails to compile otherwise.
    let magnitude: f64 = single.norm();
    assert_eq!(magnitude, 5.0);
}
/// A caller generic over `Scalar` receives `T::Real` from `norm`, which is `f64` for
/// both `f64` and `Complex<f64>` element types.
#[test]
fn test_generic_function_with_scalar() {
    fn norm_of<S: Scalar>(t: &DenseTensorData<S>) -> S::Real {
        t.norm()
    }

    // Real elements: the all-ones 2x2 tensor has norm sqrt(4) = 2.
    let ones = DenseTensorData::<f64>::ones_in_order(vec![2, 2], MemoryOrder::ColumnMajor);
    let from_real: f64 = norm_of(&ones);
    assert_eq!(from_real, 2.0);

    // Complex elements: |3 + 4i| = 5, still reported as a real f64.
    let single = DenseTensorData::from_raw_parts(
        vec![Complex::<f64>::new(3.0, 4.0)],
        vec![1],
        MemoryOrder::ColumnMajor,
    );
    let from_complex: f64 = norm_of(&single);
    assert_eq!(from_complex, 5.0);
}