use crate::err::try_vec;
use crate::sample::CwtSample;
use crate::spetrum_arith::SpectrumArithmetic;
use crate::{CwtExecutor, CwtWavelet, ScaletError};
use num_complex::Complex;
use num_traits::{AsPrimitive, Zero};
use std::sync::Arc;
use zaft::FftExecutor;
/// Shared continuous-wavelet-transform executor used for both real and complex input.
///
/// The transform is evaluated once per entry of `scales`, producing one complex output row of
/// `execution_length` samples per scale.
pub(crate) struct CommonCwtExecutor<T> {
    // --- configuration -------------------------------------------------------------------
    /// Scales at which the transform is evaluated; one output row is produced per scale.
    pub(crate) scales: Vec<T>,
    /// Base grid that is multiplied by each scale before being handed to
    /// `CwtWavelet::make_wavelet`.
    /// NOTE(review): presumably the angular-frequency grid of the FFT bins — confirm.
    pub(crate) psi: Vec<T>,
    /// Required length of every input signal. Each wavelet spectrum and each output row
    /// must also have this length.
    pub(crate) execution_length: usize,
    /// Selects the per-scale normalization: `true` multiplies by `1 / N`,
    /// `false` multiplies by `1 / (N * sqrt(scale))`, where `N` is `execution_length`.
    pub(crate) l1_norm: bool,

    // --- operators -----------------------------------------------------------------------
    /// Wavelet that produces the per-scale spectrum from the scaled `psi` grid.
    pub(crate) wavelet: Arc<dyn CwtWavelet<T> + Send + Sync>,
    /// Forward FFT, applied in place to the input signal.
    pub(crate) fft_forward: Arc<dyn FftExecutor<T> + Send + Sync>,
    /// Inverse FFT, applied in place to each output row.
    pub(crate) fft_inverse: Arc<dyn FftExecutor<T> + Send + Sync>,
    /// Combines the signal spectrum with the wavelet spectrum and applies the normalization
    /// factor (via `mul_by_b_conj_normalize`; presumably multiplication by the conjugate).
    pub(crate) spectrum_arithmetic: Arc<dyn SpectrumArithmetic<T> + Send + Sync>,
}
impl<T: CwtSample> CommonCwtExecutor<T>
where
    f64: AsPrimitive<T>,
    usize: AsPrimitive<T>,
{
    /// Runs the transform on a complex buffer, transforming it in place.
    ///
    /// On entry `signal_fft` holds the time-domain signal. It is overwritten with its forward
    /// FFT, and that spectrum is then reused (read-only) for every scale. For each scale:
    ///
    /// 1. the `psi` grid is multiplied by the scale,
    /// 2. the wavelet builds its spectrum from that scaled grid,
    /// 3. the signal and wavelet spectra are combined with a per-scale normalization factor,
    /// 4. the row is brought back with the inverse FFT.
    ///
    /// Returns one row of `execution_length` complex samples per scale.
    ///
    /// # Errors
    /// - `ScaletError::InvalidInputSize(expected, actual)` if `signal_fft.len()` differs from
    ///   `execution_length`.
    /// - `ScaletError::FftError` if the forward or inverse FFT fails.
    /// - `ScaletError::WaveletInvalidSize(expected, actual)` if the wavelet spectrum does not
    ///   have `execution_length` samples.
    /// - Any error produced by `make_wavelet`.
    fn execute_impl(
        &self,
        signal_fft: &mut [Complex<T>],
    ) -> Result<Vec<Vec<Complex<T>>>, ScaletError> {
        if self.execution_length != signal_fft.len() {
            return Err(ScaletError::InvalidInputSize(
                self.execution_length,
                signal_fft.len(),
            ));
        }
        self.fft_forward
            .execute(signal_fft)
            .map_err(|x| ScaletError::FftError(x.to_string()))?;

        let scales = self.view_scales();
        // Scratch buffer for the scaled grid, reused across scales.
        let mut current_psi = try_vec![T::zero(); self.execution_length];
        let mut result = try_vec![try_vec![Complex::new(T::zero(), T::zero()); self.execution_length]; scales.len()];

        // Loop-invariant pieces of the normalization factor. Every output row has exactly
        // `execution_length` samples (see `result` above), so this equals each row's length.
        let one: T = 1.0f64.as_();
        let length: T = self.execution_length.as_();

        for (&scale, v_dst) in scales.iter().zip(result.iter_mut()) {
            for (dst, &psi) in current_psi.iter_mut().zip(self.psi.iter()) {
                *dst = psi * scale;
            }
            let wavelet_fft = self.wavelet.make_wavelet(&current_psi)?;
            if wavelet_fft.len() != self.execution_length {
                return Err(ScaletError::WaveletInvalidSize(
                    self.execution_length,
                    wavelet_fft.len(),
                ));
            }
            // 1/N undoes the unnormalized inverse FFT. The non-L1 mode additionally divides by
            // sqrt(scale).
            // NOTE(review): confirm this matches the intended L2 convention for the spectrum
            // returned by `make_wavelet`.
            let norm_factor: T = if self.l1_norm {
                one / length
            } else {
                one / (length * scale.sqrt())
            };
            self.spectrum_arithmetic.mul_by_b_conj_normalize(
                v_dst,
                signal_fft,
                &wavelet_fft,
                norm_factor,
            );
            self.fft_inverse
                .execute(v_dst)
                .map_err(|x| ScaletError::FftError(x.to_string()))?;
        }
        Ok(result)
    }
}
impl<T: CwtSample> CwtExecutor<T> for CommonCwtExecutor<T>
where
    f64: AsPrimitive<T>,
    usize: AsPrimitive<T>,
{
    /// Computes the transform of a real-valued signal.
    ///
    /// Each sample is promoted to a complex number with a zero imaginary part, then passed to
    /// the shared complex pipeline.
    ///
    /// # Errors
    /// Returns `ScaletError::InvalidInputSize(expected, actual)` when `input.len()` differs from
    /// `execution_length`. It also propagates any error from the shared pipeline (FFT failures,
    /// wavelet size mismatch).
    fn execute(&self, input: &[T]) -> Result<Vec<Vec<Complex<T>>>, ScaletError> {
        // Checked up front so that no buffer is allocated for an input of the wrong size.
        if self.execution_length != input.len() {
            return Err(ScaletError::InvalidInputSize(
                self.execution_length,
                input.len(),
            ));
        }
        let mut signal_fft: Vec<Complex<T>> = try_vec![Complex::<T>::default(); input.len()];
        for (dst, &src) in signal_fft.iter_mut().zip(input.iter()) {
            *dst = Complex::new(src, Zero::zero());
        }
        self.execute_impl(&mut signal_fft)
    }

    /// Computes the transform of a complex-valued signal.
    ///
    /// The input is copied into a scratch buffer, because the pipeline transforms its buffer
    /// in place.
    ///
    /// # Errors
    /// Same as `execute`.
    fn execute_complex(&self, input: &[Complex<T>]) -> Result<Vec<Vec<Complex<T>>>, ScaletError> {
        if self.execution_length != input.len() {
            return Err(ScaletError::InvalidInputSize(
                self.execution_length,
                input.len(),
            ));
        }
        // Allocate through `try_vec!` like every other signal-sized buffer in this executor,
        // instead of `to_vec()`, which aborts the process on allocation failure.
        let mut signal_fft: Vec<Complex<T>> = try_vec![Complex::<T>::default(); input.len()];
        signal_fft.copy_from_slice(input);
        self.execute_impl(&mut signal_fft)
    }

    /// Number of samples every input signal must have.
    fn length(&self) -> usize {
        self.execution_length
    }

    /// Scales evaluated by this executor, in output-row order.
    fn view_scales(&self) -> &[T] {
        &self.scales
    }
}