use crate::DType;
use crate::signal::impl_generic::helpers::{
DetrendMode, detrend_tensor_impl, extract_segments_impl, power_spectrum_impl,
};
use crate::signal::impl_generic::spectral::helpers::generate_window;
use crate::signal::traits::spectral::{Detrend, PsdScaling, WelchParams, WelchResult};
use numr::algorithm::fft::{FftAlgorithms, FftNormalization};
use numr::error::{Error, Result};
use numr::ops::{ComplexOps, ReduceOps, ScalarOps, ShapeOps, TensorOps, UtilityOps};
use numr::runtime::{Runtime, RuntimeClient};
use numr::tensor::Tensor;
pub fn welch_impl<R, C>(client: &C, x: &Tensor<R>, params: WelchParams<R>) -> Result<WelchResult<R>>
where
R: Runtime<DType = DType>,
C: FftAlgorithms<R>
+ ComplexOps<R>
+ ScalarOps<R>
+ TensorOps<R>
+ ReduceOps<R>
+ ShapeOps<R>
+ UtilityOps<R>
+ RuntimeClient<R>,
{
let n = x.shape()[0];
if n == 0 {
return Err(Error::InvalidArgument {
arg: "x",
reason: "Input signal cannot be empty".to_string(),
});
}
let nperseg = params.nperseg.unwrap_or(256.min(n));
let noverlap = params.noverlap.unwrap_or(nperseg / 2);
let nfft = params.nfft.unwrap_or(nperseg).max(nperseg);
let nfft = nfft.next_power_of_two();
if nperseg > n {
return Err(Error::InvalidArgument {
arg: "nperseg",
reason: format!(
"nperseg ({}) cannot be greater than signal length ({})",
nperseg, n
),
});
}
if noverlap >= nperseg {
return Err(Error::InvalidArgument {
arg: "noverlap",
reason: "noverlap must be less than nperseg".to_string(),
});
}
let window = generate_window(¶ms.window, nperseg, ¶ms.device);
let win_sq = client.mul(&window, &window)?;
let win_sum_sq_tensor = client.sum(&win_sq, &[0], false)?;
let win_sum_sq: f64 = win_sum_sq_tensor.item()?;
let segments = extract_segments_impl(client, x, nperseg, noverlap)?;
let num_segments = segments.shape()[0];
let detrend_mode = match params.detrend {
Detrend::None => DetrendMode::None,
Detrend::Constant => DetrendMode::Constant,
Detrend::Linear => DetrendMode::Linear,
};
let segments_detrended = detrend_tensor_impl(client, &segments, detrend_mode)?;
let window_broadcast = window.reshape(&[1, nperseg])?;
let segments_windowed = client.mul(&segments_detrended, &window_broadcast)?;
let segments_padded = if nfft > nperseg {
let pad_amount = nfft - nperseg;
client.pad(&segments_windowed, &[0, pad_amount], 0.0)?
} else {
segments_windowed
};
let fft_result = client.rfft(&segments_padded, FftNormalization::None)?;
let power = power_spectrum_impl(client, &fft_result)?;
let psd_sum = client.sum(&power, &[0], false)?;
let psd_avg = client.div_scalar(&psd_sum, num_segments as f64)?;
let n_freqs = nfft / 2 + 1;
let scale = match params.scaling {
PsdScaling::Density => 1.0 / (params.fs * win_sum_sq),
PsdScaling::Spectrum => 1.0 / win_sum_sq,
};
let psd_scaled = client.mul_scalar(&psd_avg, scale)?;
let psd_final = if params.onesided && n_freqs > 2 {
let mut scale_factors = vec![2.0f64; n_freqs];
scale_factors[0] = 1.0; if n_freqs > 1 {
scale_factors[n_freqs - 1] = 1.0; }
let scale_tensor = Tensor::from_slice(&scale_factors, &[n_freqs], ¶ms.device);
client.mul(&psd_scaled, &scale_tensor)?
} else {
psd_scaled
};
let freqs = client.rfftfreq(nfft, 1.0 / params.fs, psd_final.dtype(), ¶ms.device)?;
Ok(WelchResult {
freqs,
psd: psd_final,
})
}