use scirs2_core::ndarray::{Array, Array2, Dimension, IxDyn};
use scirs2_core::numeric::Complex64;
use std::ops::Not;
#[allow(dead_code)]
pub fn enforce_hermitian_symmetry(array: &mut Array2<Complex64>) {
let (rows, cols) = array.dim();
if rows > 0 && cols > 0 {
array[[0, 0]] = Complex64::new(array[[0, 0]].re, 0.0);
}
if rows > 1 && cols > 1 {
for j in 1..cols / 2 + (cols % 2).not() {
let conj_val = array[[0, cols - j]].conj();
array[[0, j]] = conj_val;
}
for i in 1..rows / 2 + (rows % 2).not() {
let conj_val = array[[rows - i, 0]].conj();
array[[i, 0]] = conj_val;
}
if rows % 2 == 0 && rows > 0 {
array[[rows / 2, 0]] = Complex64::new(array[[rows / 2, 0]].re, 0.0);
}
if cols % 2 == 0 && cols > 0 {
array[[0, cols / 2]] = Complex64::new(array[[0, cols / 2]].re, 0.0);
}
}
}
#[allow(dead_code)]
pub fn enforce_hermitian_symmetry_nd(array: &mut Array<Complex64, IxDyn>) {
let shape = array.shape().to_vec();
let ndim = shape.len();
if ndim == 0 || array.is_empty() {
return;
}
if let Some(slice) = array.as_slice_mut() {
if !slice.is_empty() {
slice[0] = Complex64::new(slice[0].re, 0.0);
}
}
match ndim {
1 => {
if let Some(slice) = array.as_slice_mut() {
let n = slice.len();
for i in 1..n / 2 + 1 {
if i < n && (n - i) < n {
let avg = (slice[i] + slice[n - i].conj()) * 0.5;
slice[i] = avg;
slice[n - i] = avg.conj();
}
}
if n % 2 == 0 && n >= 2 {
slice[n / 2] = Complex64::new(slice[n / 2].re, 0.0);
}
}
}
2 => {
if let Ok(mut array2) = array
.clone()
.into_dimensionality::<scirs2_core::ndarray::Ix2>()
{
enforce_hermitian_symmetry(&mut array2);
let view2 = array2.view();
let flat = view2.as_slice().expect("Operation failed");
if let Some(target) = array.as_slice_mut() {
target.copy_from_slice(flat);
}
}
}
_ => {
if let Ok(mut view) = array
.view_mut()
.into_dimensionality::<scirs2_core::ndarray::Ix3>()
{
let (dim1, dim2, _) = view.dim();
for k in 0..view.dim().2 {
let mut slice = view.slice_mut(scirs2_core::ndarray::s![.., .., k]);
let mut array2 = Array2::zeros((dim1, dim2));
for i in 0..dim1 {
for j in 0..dim2 {
array2[[i, j]] = slice[[i, j]];
}
}
enforce_hermitian_symmetry(&mut array2);
for i in 0..dim1 {
for j in 0..dim2 {
slice[[i, j]] = array2[[i, j]];
}
}
}
}
}
}
}
/// Check (within a tolerance) whether a complex spectrum is Hermitian-symmetric.
///
/// `tolerance` defaults to `1e-10` and is applied separately to the real and imaginary
/// components of each comparison.
///
/// The checks performed depend on the rank:
/// - Any non-empty array with at least one axis: the element at `[0, .., 0]` must have
///   `|im| <= tol`.
/// - 1-D (length `n`): `x[i] ≈ conj(x[n - i])` for `1 <= i <= n / 2`.
/// - 2-D (`rows × cols`): the same pairwise condition along row 0 and along column 0 only;
///   entries with both indices non-zero are not inspected.
/// - 3-D and higher: only the `[0, .., 0]` element is checked.
///
/// 0-d arrays and arrays with a zero-length axis are reported as symmetric.
/// Works for any memory layout and never panics.
#[allow(dead_code)]
pub fn is_hermitian_symmetric<D>(array: &Array<Complex64, D>, tolerance: Option<f64>) -> bool
where
    D: Dimension,
{
    let tol = tolerance.unwrap_or(1e-10);
    if array.ndim() == 0 || array.is_empty() {
        return true;
    }

    // `iter` walks in logical index order, so the first element is index [0, .., 0].
    if let Some(dc) = array.iter().next() {
        if dc.im.abs() > tol {
            return false;
        }
    }

    // True when `a` differs from `conj(b)` by more than `tol` in either component.
    let mismatch = |a: Complex64, b: Complex64| {
        let b = b.conj();
        (a.re - b.re).abs() > tol || (a.im - b.im).abs() > tol
    };

    match array.ndim() {
        1 => match array.view().into_dimensionality::<scirs2_core::ndarray::Ix1>() {
            Ok(line) => {
                let n = line.len();
                (1..=n / 2).all(|i| !mismatch(line[i], line[n - i]))
            }
            Err(_) => true,
        },
        2 => match array.view().into_dimensionality::<scirs2_core::ndarray::Ix2>() {
            Ok(grid) => {
                let (rows, cols) = grid.dim();
                (1..=cols / 2).all(|j| !mismatch(grid[[0, j]], grid[[0, cols - j]]))
                    && (1..=rows / 2).all(|i| !mismatch(grid[[i, 0]], grid[[rows - i, 0]]))
            }
            Err(_) => true,
        },
        _ => true,
    }
}
/// Build a length-`amplitudes.len()` complex spectrum satisfying `X[k] == conj(X[n - k])`.
///
/// With `n = amplitudes.len()`:
/// - `X[0]` is `amplitudes[0]` (real).
/// - For every `1 <= k < n - k`, `X[k]` has magnitude `amplitudes[k]` and `X[n - k]` is its
///   conjugate. The phase is uniform in `[0, 2π)` when `randomize_phases` is true, else `0`.
/// - For even `n`, `X[n / 2]` is `amplitudes[n / 2]` (real).
///
/// Entries of `amplitudes` above index `n / 2` are not read. An empty slice yields an empty
/// vector.
#[allow(dead_code)]
pub fn create_hermitian_symmetric_signal(
    amplitudes: &[f64],
    randomize_phases: bool,
) -> Vec<Complex64> {
    use scirs2_core::random::{Rng, RngExt};
    let n = amplitudes.len();
    if n == 0 {
        return Vec::new();
    }
    let mut result = vec![Complex64::new(0.0, 0.0); n];
    result[0] = Complex64::new(amplitudes[0], 0.0);
    let mut rng = scirs2_core::random::rng();
    // Bins strictly below the middle: 1 <= k < n - k, each mirrored to index n - k.
    for (k, &amp) in amplitudes.iter().enumerate().take((n + 1) / 2).skip(1) {
        let phase = if randomize_phases {
            2.0 * std::f64::consts::PI * rng.random::<f64>()
        } else {
            0.0
        };
        let value = Complex64::from_polar(amp, phase);
        result[k] = value;
        result[n - k] = value.conj();
    }
    // The middle bin of an even-length spectrum is its own mirror, so it must be real.
    if n.is_multiple_of(2) {
        result[n / 2] = Complex64::new(amplitudes[n / 2], 0.0);
    }
    result
}