use crate::complex_arith::ComplexArithmetic;
use crate::err::try_vec;
use crate::sample::CwtSample;
use crate::{BufferStoreMut, CwtExecutor, CwtWavelet, ScaletError, ScaletFrameMut};
use num_complex::Complex;
use num_traits::{AsPrimitive, Zero};
use std::sync::Arc;
use zaft::FftExecutor;
/// FFT-based continuous wavelet transform executor shared by the public CWT front-ends.
///
/// For each scale it multiplies the forward spectrum of the input by a wavelet spectrum
/// (precomputed in `built_wavelets`, or built on demand from `psi`) and inverse-transforms
/// the product into one row of the output frame.
pub(crate) struct CommonCwtExecutor<T> {
/// Wavelet whose `make_wavelet` is called with `psi * scale` when `built_wavelets` is empty.
pub(crate) wavelet: Arc<dyn CwtWavelet<T> + Send + Sync>,
/// Forward FFT applied in place to the input signal before the per-scale products.
pub(crate) fft_forward: Arc<dyn FftExecutor<T> + Send + Sync>,
/// Inverse FFT applied in place to every output row after its product is written.
pub(crate) fft_inverse: Arc<dyn FftExecutor<T> + Send + Sync>,
/// Provides `mul_by_b_conj_normalize`, used to combine the signal spectrum with each
/// wavelet spectrum (presumably `a * conj(b) * norm`, per the method name — confirm).
pub(crate) complex_arithmetic: Arc<dyn ComplexArithmetic<T> + Send + Sync>,
/// Scales to analyse; the output frame has one row per entry.
pub(crate) scales: Vec<T>,
/// Base grid multiplied element-wise by each scale to form `make_wavelet`'s input
/// (presumably frequency samples — TODO confirm).
pub(crate) psi: Vec<T>,
/// Required input length, which is also the width of every output row.
pub(crate) execution_length: usize,
/// Normalization selector: `1 / n` when true, `1 / (n * sqrt(scale))` when false.
pub(crate) l1_norm: bool,
/// Minimum scratch length handed to the FFT executors.
pub(crate) scratch_length: usize,
/// Optional precomputed wavelet spectra, `execution_length` values per scale in scale order.
/// When empty, spectra are rebuilt from `psi` on every execution.
pub(crate) built_wavelets: Vec<Complex<T>>,
}
impl<T: CwtSample> CommonCwtExecutor<T>
where
    f64: AsPrimitive<T>,
    usize: AsPrimitive<T>,
{
    /// Normalization factor applied to the row produced for `scale`.
    ///
    /// With `n = execution_length`, this is `1 / n` when `l1_norm` is set and
    /// `1 / (n * sqrt(scale))` otherwise. The `1 / n` term compensates for the
    /// unnormalized inverse FFT.
    fn norm_factor(&self, scale: T) -> T {
        let one: T = 1.0f64.as_();
        let len: T = self.execution_length.as_();
        if self.l1_norm {
            one / len
        } else {
            one / (len * scale.sqrt())
        }
    }

    /// Core transform: fills `into` with one inverse-transformed row per scale.
    ///
    /// `signal_fft` holds the time-domain signal on entry and is transformed in place by
    /// `fft_forward`. Only the first `scratch_length` elements of `scratch` are used.
    ///
    /// # Errors
    /// - `InvalidInputSize` if `signal_fft.len() != execution_length`.
    /// - `InvalidScratchSize` if `scratch` is shorter than `scratch_length`.
    /// - `InvalidFrame` if the frame's width, height or backing length do not match.
    /// - `WaveletInvalidSize` if a wavelet spectrum (built or precomputed) has the wrong size.
    /// - `FftError` if either FFT fails; errors from `make_wavelet` are propagated.
    fn execute_impl(
        &self,
        into: &mut ScaletFrameMut<'_, Complex<T>>,
        signal_fft: &mut [Complex<T>],
        scratch: &mut [Complex<T>],
    ) -> Result<(), ScaletError> {
        if self.execution_length != signal_fft.len() {
            return Err(ScaletError::InvalidInputSize(
                self.execution_length,
                signal_fft.len(),
            ));
        }
        if scratch.len() < self.scratch_length {
            return Err(ScaletError::InvalidScratchSize(
                self.scratch_length,
                scratch.len(),
            ));
        }

        // Validate the destination before doing any work, so a bad frame costs neither an
        // FFT nor an allocation.
        let scales = self.view_scales();
        let frame_len = scales.len() * self.execution_length;
        if into.width != self.execution_length {
            return Err(ScaletError::InvalidFrame(format!(
                "Invalid frame width, expected {} but it was {}",
                self.execution_length, into.width
            )));
        }
        if into.height != scales.len() {
            return Err(ScaletError::InvalidFrame(format!(
                "Invalid frame height, expected {} but it was {}",
                scales.len(),
                into.height
            )));
        }
        if into.data.borrow().len() != frame_len {
            return Err(ScaletError::InvalidFrame(format!(
                "Invalid frame size, expected {} but it was {}",
                frame_len,
                into.data.borrow().len()
            )));
        }
        // A mis-sized precomputed bank would make the zip below stop early and silently leave
        // trailing rows unwritten.
        if !self.built_wavelets.is_empty() && self.built_wavelets.len() != frame_len {
            return Err(ScaletError::WaveletInvalidSize(
                frame_len,
                self.built_wavelets.len(),
            ));
        }

        let scratch = &mut scratch[..self.scratch_length];
        self.fft_forward
            .execute_with_scratch(signal_fft, scratch)
            .map_err(|x| ScaletError::FftError(x.to_string()))?;

        let rows = into
            .data
            .borrow_mut()
            .chunks_exact_mut(self.execution_length);

        if self.built_wavelets.is_empty() {
            // No precomputed bank: build each scale's wavelet spectrum from `psi * scale`.
            let mut scaled_psi = try_vec![T::zero(); self.execution_length];
            for (&scale, row) in scales.iter().zip(rows) {
                for (dst, &psi) in scaled_psi.iter_mut().zip(self.psi.iter()) {
                    *dst = psi * scale;
                }
                let wavelet_fft = self.wavelet.make_wavelet(&scaled_psi)?;
                if wavelet_fft.len() != self.execution_length {
                    return Err(ScaletError::WaveletInvalidSize(
                        self.execution_length,
                        wavelet_fft.len(),
                    ));
                }
                self.complex_arithmetic.mul_by_b_conj_normalize(
                    row,
                    signal_fft,
                    &wavelet_fft,
                    self.norm_factor(scale),
                );
                self.fft_inverse
                    .execute_with_scratch(row, scratch)
                    .map_err(|x| ScaletError::FftError(x.to_string()))?;
            }
        } else {
            let bank = self.built_wavelets.chunks_exact(self.execution_length);
            for ((&scale, wavelet_fft), row) in scales.iter().zip(bank).zip(rows) {
                self.complex_arithmetic.mul_by_b_conj_normalize(
                    row,
                    signal_fft,
                    wavelet_fft,
                    self.norm_factor(scale),
                );
                self.fft_inverse
                    .execute_with_scratch(row, scratch)
                    .map_err(|x| ScaletError::FftError(x.to_string()))?;
            }
        }
        Ok(())
    }
}
impl<T: CwtSample> CwtExecutor<T> for CommonCwtExecutor<T>
where
    f64: AsPrimitive<T>,
    usize: AsPrimitive<T>,
{
    /// Transforms a real-valued signal into a newly allocated frame.
    ///
    /// The frame has `view_scales().len()` rows of `length()` coefficients. A scratch buffer of
    /// `scratch_length()` elements is allocated for the call.
    ///
    /// # Errors
    /// Returns `InvalidInputSize` if `input.len() != length()`. Also propagates errors from
    /// `try_vec!` buffer allocation and from the transform itself.
    fn execute(&self, input: &[T]) -> Result<ScaletFrameMut<'_, Complex<T>>, ScaletError> {
        // Reject a mis-sized input before allocating the (scales x length) output frame.
        if self.execution_length != input.len() {
            return Err(ScaletError::InvalidInputSize(
                self.execution_length,
                input.len(),
            ));
        }
        let rows = self.view_scales().len();
        let mut frame = ScaletFrameMut {
            data: BufferStoreMut::Owned(
                try_vec![Complex::zero(); rows * self.execution_length],
            ),
            height: rows,
            width: self.execution_length,
        };
        let mut scratch = try_vec![Complex::zero(); self.scratch_length];
        self.execute_with_scratch(input, &mut frame, &mut scratch)?;
        Ok(frame)
    }

    /// Transforms a real-valued signal into a caller-provided frame using caller scratch.
    ///
    /// The signal is widened to complex (zero imaginary part) in a temporary buffer.
    /// The forward FFT then runs in place on that buffer, leaving `input` untouched.
    ///
    /// # Errors
    /// Returns `InvalidInputSize` if `input.len() != length()`. Also propagates frame, scratch,
    /// wavelet and FFT errors from the core transform.
    fn execute_with_scratch(
        &self,
        input: &[T],
        into_frame: &mut ScaletFrameMut<'_, Complex<T>>,
        scratch: &mut [Complex<T>],
    ) -> Result<(), ScaletError> {
        if self.execution_length != input.len() {
            return Err(ScaletError::InvalidInputSize(
                self.execution_length,
                input.len(),
            ));
        }
        let mut signal_fft: Vec<Complex<T>> = try_vec![Complex::<T>::default(); input.len()];
        for (dst, &re) in signal_fft.iter_mut().zip(input.iter()) {
            *dst = Complex::new(re, Zero::zero());
        }
        self.execute_impl(into_frame, &mut signal_fft, scratch)
    }

    /// Transforms a complex signal into a newly allocated frame.
    ///
    /// The output shape and errors are the same as for [`CwtExecutor::execute`].
    fn execute_complex(
        &self,
        input: &[Complex<T>],
    ) -> Result<ScaletFrameMut<'_, Complex<T>>, ScaletError> {
        if self.execution_length != input.len() {
            return Err(ScaletError::InvalidInputSize(
                self.execution_length,
                input.len(),
            ));
        }
        let rows = self.view_scales().len();
        let mut frame = ScaletFrameMut {
            data: BufferStoreMut::Owned(
                try_vec![Complex::zero(); rows * self.execution_length],
            ),
            height: rows,
            width: self.execution_length,
        };
        let mut scratch = try_vec![Complex::zero(); self.scratch_length];
        self.execute_complex_with_scratch(input, &mut frame, &mut scratch)?;
        Ok(frame)
    }

    /// Transforms a complex signal into a caller-provided frame using caller scratch.
    ///
    /// The input is copied into a buffer allocated with `try_vec!`, like every other
    /// allocation here, rather than with `to_vec()`. The copy is needed because the core
    /// transform overwrites the buffer it is given.
    fn execute_complex_with_scratch(
        &self,
        input: &[Complex<T>],
        into: &mut ScaletFrameMut<'_, Complex<T>>,
        scratch: &mut [Complex<T>],
    ) -> Result<(), ScaletError> {
        if self.execution_length != input.len() {
            return Err(ScaletError::InvalidInputSize(
                self.execution_length,
                input.len(),
            ));
        }
        let mut signal_fft: Vec<Complex<T>> = try_vec![Complex::<T>::default(); input.len()];
        signal_fft.copy_from_slice(input);
        self.execute_impl(into, &mut signal_fft, scratch)
    }

    /// Required input length, which is also the width of every output row.
    fn length(&self) -> usize {
        self.execution_length
    }

    /// Scales analysed by this executor; one output row per entry.
    fn view_scales(&self) -> &[T] {
        &self.scales
    }

    /// Minimum scratch length accepted by the `*_with_scratch` entry points.
    fn scratch_length(&self) -> usize {
        self.scratch_length
    }
}