use crate::DType;
use crate::signal::impl_generic::helpers::{
DetrendMode, detrend_tensor_impl, extract_segments_impl, power_spectrum_impl,
};
use crate::signal::impl_generic::spectral::helpers::generate_window;
use crate::signal::traits::spectral::{CoherenceResult, Detrend, WelchParams};
use numr::algorithm::fft::{FftAlgorithms, FftNormalization};
use numr::error::{Error, Result};
use numr::ops::{ComplexOps, ReduceOps, ScalarOps, ShapeOps, TensorOps, UtilityOps};
use numr::runtime::{Runtime, RuntimeClient};
use numr::tensor::Tensor;
pub fn coherence_impl<R, C>(
client: &C,
x: &Tensor<R>,
y: &Tensor<R>,
params: WelchParams<R>,
) -> Result<CoherenceResult<R>>
where
R: Runtime<DType = DType>,
C: FftAlgorithms<R>
+ ComplexOps<R>
+ ScalarOps<R>
+ TensorOps<R>
+ ReduceOps<R>
+ ShapeOps<R>
+ UtilityOps<R>
+ RuntimeClient<R>,
{
let nx = x.shape()[0];
let ny = y.shape()[0];
if nx != ny {
return Err(Error::InvalidArgument {
arg: "y",
reason: "x and y must have the same length".to_string(),
});
}
let n = nx;
if n == 0 {
return Err(Error::InvalidArgument {
arg: "x",
reason: "Input signals cannot be empty".to_string(),
});
}
let nperseg = params.nperseg.unwrap_or(256.min(n));
let noverlap = params.noverlap.unwrap_or(nperseg / 2);
let nfft = params.nfft.unwrap_or(nperseg).max(nperseg);
let nfft = nfft.next_power_of_two();
if nperseg > n || noverlap >= nperseg {
return Err(Error::InvalidArgument {
arg: "nperseg",
reason: "Invalid segment parameters".to_string(),
});
}
let window = generate_window(¶ms.window, nperseg, ¶ms.device);
let x_segments = extract_segments_impl(client, x, nperseg, noverlap)?;
let y_segments = extract_segments_impl(client, y, nperseg, noverlap)?;
let detrend_mode = match params.detrend {
Detrend::None => DetrendMode::None,
Detrend::Constant => DetrendMode::Constant,
Detrend::Linear => DetrendMode::Linear,
};
let x_detrended = detrend_tensor_impl(client, &x_segments, detrend_mode)?;
let y_detrended = detrend_tensor_impl(client, &y_segments, detrend_mode)?;
let window_broadcast = window.reshape(&[1, nperseg])?;
let x_windowed = client.mul(&x_detrended, &window_broadcast)?;
let y_windowed = client.mul(&y_detrended, &window_broadcast)?;
let x_padded = if nfft > nperseg {
let pad_amount = nfft - nperseg;
client.pad(&x_windowed, &[0, pad_amount], 0.0)?
} else {
x_windowed
};
let y_padded = if nfft > nperseg {
let pad_amount = nfft - nperseg;
client.pad(&y_windowed, &[0, pad_amount], 0.0)?
} else {
y_windowed
};
let x_fft = client.rfft(&x_padded, FftNormalization::None)?;
let y_fft = client.rfft(&y_padded, FftNormalization::None)?;
let pxx = power_spectrum_impl(client, &x_fft)?;
let pxx_sum = client.sum(&pxx, &[0], false)?;
let pyy = power_spectrum_impl(client, &y_fft)?;
let pyy_sum = client.sum(&pyy, &[0], false)?;
let x_conj = client.conj(&x_fft)?;
let pxy_complex = client.mul(&x_conj, &y_fft)?;
let pxy_sum = client.sum(&pxy_complex, &[0], false)?;
let pxy_conj = client.conj(&pxy_sum)?;
let pxy_mag_sq_complex = client.mul(&pxy_conj, &pxy_sum)?;
let pxy_mag_sq = client.real(&pxy_mag_sq_complex)?;
let pxx_pyy = client.mul(&pxx_sum, &pyy_sum)?;
let epsilon = 1e-30;
let pxx_pyy_safe = client.add_scalar(&pxx_pyy, epsilon)?;
let cxy = client.div(&pxy_mag_sq, &pxx_pyy_safe)?;
let zeros = Tensor::zeros(cxy.shape(), cxy.dtype(), ¶ms.device);
let ones = Tensor::ones(cxy.shape(), cxy.dtype(), ¶ms.device);
let cxy_clamped = client.maximum(&zeros, &cxy)?;
let cxy_final = client.minimum(&ones, &cxy_clamped)?;
let n_freqs = nfft / 2 + 1;
let freqs = client.rfftfreq(
n_freqs * 2 - 2,
1.0 / params.fs,
cxy_final.dtype(),
¶ms.device,
)?;
let freqs_final = freqs.narrow(0, 0, n_freqs)?;
Ok(CoherenceResult {
freqs: freqs_final,
cxy: cxy_final,
})
}