use crate::cwt_executor::CommonCwtExecutor;
use crate::err::try_vec;
use crate::sample::CwtSample;
use crate::scale_bounds::find_min_max_scales;
use crate::scales::{linear_scales, log_piecewise_scales};
use crate::{CwtExecutor, CwtOptions, CwtWavelet, ScaleType, ScaletError};
use num_traits::AsPrimitive;
use std::sync::Arc;
use zaft::FftDirection;
/// Builds the angular-frequency grid on which a wavelet is evaluated in the frequency domain.
///
/// Element `i` of the returned vector (length `points`) holds `k · T::TWO_PI / points`,
/// where `k` follows the standard FFT bin ordering (the same layout as NumPy's
/// `fft.fftfreq`):
///
/// `[0, 1, …, ceil(points/2) - 1, -floor(points/2), …, -1]`
///
/// For even `points` the Nyquist bin is placed on the negative side; for odd `points`
/// the positive and negative halves are symmetric. `points == 0` yields an empty vector.
///
/// # Errors
/// Propagates the error produced by `try_vec!` if the output buffer cannot be
/// allocated (presumably an allocation-failure variant of `ScaletError`).
pub(crate) fn gen_psi<T: CwtSample>(points: usize) -> Result<Vec<T>, ScaletError>
where
    usize: AsPrimitive<T>,
    f64: AsPrimitive<T>,
    isize: AsPrimitive<T>,
{
    let mut psih = try_vec![T::zero(); points];
    let recip_points = 1f64.as_() / points.as_();
    // Number of non-negative frequency bins: ceil(points / 2).
    // Using floor (`points / 2`) would push the largest positive bin of an odd-length
    // grid into the negative half, e.g. [0, 1, -3, -2, -1] instead of
    // [0, 1, 2, -2, -1] for points = 5, i.e. a frequency beyond Nyquist.
    let non_negative_bins = points - points / 2;
    for (i, v) in psih.iter_mut().enumerate() {
        let idx = if i < non_negative_bins {
            i as isize
        } else {
            // Upper half of the spectrum maps to negative frequencies.
            i as isize - points as isize
        };
        *v = idx.as_() * T::TWO_PI * recip_points;
    }
    Ok(psih)
}
/// Assembles a [`CommonCwtExecutor`] for inputs of length `filter_size` and returns it
/// behind a shared `CwtExecutor` trait object.
///
/// The executor is configured with:
/// * a scale grid spanning the bounds returned by `find_min_max_scales`, generated by
///   `log_piecewise_scales` for [`ScaleType::Log`] or `linear_scales` for
///   [`ScaleType::Linear`]. `options.nv` (converted via `as_()`) is forwarded to the
///   generator; presumably it is voices per octave, so confirm in the `scales` module;
/// * forward and inverse FFT plans of length `filter_size`;
/// * the frequency grid produced by `gen_psi`;
/// * a scratch length equal to the larger requirement of the two FFT plans;
/// * `options.l1_norm` and `T::spectrum_arithmetic()`.
///
/// # Errors
/// Returns [`ScaletError::ZeroBaseSized`] when `filter_size` is zero. Otherwise it
/// propagates the first failure from, in order:
/// 1. the scale-bound search;
/// 2. scale generation;
/// 3. forward FFT plan creation;
/// 4. inverse FFT plan creation;
/// 5. frequency-grid allocation.
pub(crate) fn create_cwt<T: CwtSample>(
    wavelet: Arc<dyn CwtWavelet<T> + Send + Sync>,
    filter_size: usize,
    scale_type: ScaleType,
    options: CwtOptions,
) -> Result<Arc<dyn CwtExecutor<T> + Send + Sync>, ScaletError>
where
    usize: AsPrimitive<T>,
    f64: AsPrimitive<T>,
    isize: AsPrimitive<T>,
{
    let execution_length = filter_size;
    if execution_length == 0 {
        return Err(ScaletError::ZeroBaseSized);
    }

    // Parentheses spell out how the expression parses: the literal is converted first,
    // then negated. What this argument controls is defined by `find_min_max_scales`.
    let bounds = find_min_max_scales(Arc::clone(&wavelet), -(0.5f64.as_()))?;
    let (lowest, highest) = (bounds.min, bounds.max);
    let scales = match scale_type {
        ScaleType::Linear => linear_scales(lowest, highest, options.nv.as_())?,
        ScaleType::Log => log_piecewise_scales(lowest, highest, options.nv.as_())?,
    };

    let forward_plan = T::make_fft(execution_length, FftDirection::Forward)?;
    let inverse_plan = T::make_fft(execution_length, FftDirection::Inverse)?;
    let psi = gen_psi(execution_length)?;

    // Reserve enough scratch for whichever plan needs more.
    let inverse_scratch = inverse_plan.scratch_length();
    let scratch_length = inverse_scratch.max(forward_plan.scratch_length());

    let executor = CommonCwtExecutor {
        wavelet,
        fft_forward: forward_plan,
        fft_inverse: inverse_plan,
        scales,
        psi,
        execution_length,
        l1_norm: options.l1_norm,
        spectrum_arithmetic: T::spectrum_arithmetic(),
        scratch_length,
    };
    Ok(Arc::new(executor))
}