use crate::sinc::make_sincs;
use crate::windows::WindowFunction;
use crate::Sample;
// Declares an optional, target-specific interpolator module plus a trait
// named `$trait`, and re-exports that trait from the current module.
//
// - When the `cfg` predicate holds, `pub mod $mod;` loads the module from its
//   own source file, which is expected to define `$trait` (contents not
//   visible here).
// - When the predicate does not hold, an inline stub module is generated
//   instead. In it, `$trait` is an empty trait with a blanket impl for every
//   `Sample`, so a `T: $trait` bound is satisfied by any `Sample` type on
//   those targets.
//
// The matcher forwards the predicate as raw tokens (`$($cond:tt)*`) so it can
// be reused verbatim inside both `cfg(...)` and `cfg(not(...))`.
macro_rules! interpolator {
(
#[cfg($($cond:tt)*)]
mod $mod:ident;
trait $trait:ident;
) => {
// Real implementation, loaded from `$mod`'s own file when the predicate holds.
#[cfg($($cond)*)]
pub mod $mod;
// Fallback stub for targets where the predicate is false.
#[cfg(not($($cond)*))]
pub mod $mod {
use crate::Sample;
// Marker trait with no methods.
pub trait $trait {
}
// Blanket impl: every `Sample` type implements the marker trait.
impl<T> $trait for T where T: Sample {
}
}
// Expose the trait at this module level regardless of which branch was compiled.
pub use self::$mod::$trait;
}
}
// Module `sinc_interpolator_avx` / trait `AvxSample`: the module file is only
// compiled on x86_64. On other targets `AvxSample` is a marker trait that
// every `Sample` implements.
interpolator! {
#[cfg(target_arch = "x86_64")]
mod sinc_interpolator_avx;
trait AvxSample;
}
// Module `sinc_interpolator_sse` / trait `SseSample`: same x86_64-only gating
// as above.
interpolator! {
#[cfg(target_arch = "x86_64")]
mod sinc_interpolator_sse;
trait SseSample;
}
// Module `sinc_interpolator_neon` / trait `NeonSample`: the module file is
// only compiled on aarch64. Elsewhere `NeonSample` is a marker trait that
// every `Sample` implements.
interpolator! {
#[cfg(target_arch = "aarch64")]
mod sinc_interpolator_neon;
trait NeonSample;
}
/// Common interface for sinc interpolators.
///
/// An interpolator holds a set of precomputed sincs, selected by `subindex`,
/// and produces one output value per call from a window of input samples.
/// Implementors must be [`Send`] so they can be moved to another thread.
pub trait SincInterpolator<T>: Send {
    /// Returns the interpolated value for the input window starting at
    /// `index` in `wave`, using the sinc selected by `subindex`.
    ///
    /// The window spans [`len`](Self::len) samples. `subindex` must be below
    /// [`nbr_sincs`](Self::nbr_sincs).
    fn get_sinc_interpolated(&self, wave: &[T], index: usize, subindex: usize) -> T;

    /// Number of input samples consumed per interpolated value (the sinc length).
    fn len(&self) -> usize;

    /// Returns `true` when [`len`](Self::len) reports zero.
    fn is_empty(&self) -> bool {
        matches!(self.len(), 0)
    }

    /// Number of sincs available; valid `subindex` values are `0..nbr_sincs()`.
    fn nbr_sincs(&self) -> usize;
}
/// Sinc interpolator implemented with plain scalar arithmetic (no explicit SIMD).
///
/// Stores the precomputed sincs together with their length and count.
pub struct ScalarInterpolator<T> {
    /// Number of sincs reported by `nbr_sincs()`; subindices must be below it.
    nbr_sincs: usize,
    /// Sinc length reported by `len()`; expected to match each entry of `sincs`.
    length: usize,
    /// Precomputed sincs, indexed by subindex.
    sincs: Vec<Vec<T>>,
}
impl<T> SincInterpolator<T> for ScalarInterpolator<T>
where
    T: Sample,
{
    /// Computes the dot product of the sinc at `subindex` with
    /// `wave[index..index + self.len()]`.
    ///
    /// The products are accumulated in eight independent partial sums (one per
    /// position within each 8-sample chunk), so consecutive additions do not
    /// depend on each other. The partial sums are combined left to right at
    /// the end.
    ///
    /// # Panics
    ///
    /// - Panics if `index + self.len() >= wave.len()`.
    /// - Panics if `subindex >= self.nbr_sincs()`.
    fn get_sinc_interpolated(&self, wave: &[T], index: usize, subindex: usize) -> T {
        // `checked_sub` keeps the panic message from underflowing when `wave`
        // is not longer than a sinc; `index <= max_index` is the same test as
        // `index + length < wave.len()` but cannot overflow.
        // NOTE(review): the slice below only needs `index + length <= wave.len()`;
        // the stricter bound is the existing contract and is intentionally kept.
        match wave.len().checked_sub(self.length + 1) {
            Some(max_index) => assert!(
                index <= max_index,
                "Tried to interpolate for index {}, max for the given input is {}",
                index,
                max_index
            ),
            None => panic!(
                "Tried to interpolate for index {}, but an input of length {} is too short for sincs of length {}",
                index,
                wave.len(),
                self.length
            ),
        }
        // Same acceptance rule as `subindex < nbr_sincs`, without underflowing
        // in the message when there are no sincs.
        match self.nbr_sincs.checked_sub(1) {
            Some(max_subindex) => assert!(
                subindex <= max_subindex,
                "Tried to use sinc subindex {}, max is {}",
                subindex,
                max_subindex
            ),
            None => panic!(
                "Tried to use sinc subindex {}, but the interpolator has no sincs",
                subindex
            ),
        }

        let sinc = &self.sincs[subindex];
        let wave_cut = &wave[index..(index + sinc.len())];

        // Safe replacement for unchecked indexing: `chunks_exact(8)` yields
        // `len / 8` full chunks and ignores any trailing remainder. `new`
        // requires the sinc length to be a multiple of 8, so nothing is dropped.
        let mut acc = [T::zero(); 8];
        for (samples, coeffs) in wave_cut.chunks_exact(8).zip(sinc.chunks_exact(8)) {
            for ((partial, &sample), &coeff) in acc.iter_mut().zip(samples).zip(coeffs) {
                *partial += sample * coeff;
            }
        }

        // Combine strictly left to right so floating-point results match the
        // fixed accumulation order.
        let [a0, a1, a2, a3, a4, a5, a6, a7] = acc;
        a0 + a1 + a2 + a3 + a4 + a5 + a6 + a7
    }

    /// Sinc length, i.e. the number of input samples used per output value.
    fn len(&self) -> usize {
        self.length
    }

    /// Number of available sincs; valid subindices are `0..nbr_sincs()`.
    fn nbr_sincs(&self) -> usize {
        self.nbr_sincs
    }
}
impl<T: Sample> ScalarInterpolator<T> {
    /// Builds a scalar interpolator whose sincs are generated by `make_sincs`.
    ///
    /// - `sinc_len` and `oversampling_factor` become `len()` and `nbr_sincs()`.
    /// - `f_cutoff` and `window` are passed straight through to `make_sincs`.
    ///
    /// # Panics
    ///
    /// Panics if `sinc_len` is not a multiple of 8. The interpolation loop
    /// processes whole 8-sample chunks, so any remainder would be skipped.
    pub fn new(
        sinc_len: usize,
        oversampling_factor: usize,
        f_cutoff: f32,
        window: WindowFunction,
    ) -> Self {
        if sinc_len % 8 != 0 {
            panic!("Sinc length must be a multiple of 8");
        }
        Self {
            nbr_sincs: oversampling_factor,
            length: sinc_len,
            sincs: make_sincs(sinc_len, oversampling_factor, f_cutoff, window),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::ScalarInterpolator;
    use super::SincInterpolator;
    use crate::WindowFunction;
    use num_traits::Float;
    use rand::Rng;

    /// Reference implementation: a plain sequential dot product of `sinc` with
    /// `wave[index..index + sinc.len()]`.
    fn reference_interpolation<T: Float>(wave: &[T], index: usize, sinc: &[T]) -> T {
        wave[index..index + sinc.len()]
            .iter()
            .zip(sinc)
            .map(|(&sample, &coeff)| sample * coeff)
            .fold(T::zero(), |sum, product| sum + product)
    }

    #[test]
    fn test_scalar_interpolator_64() {
        let mut rng = rand::thread_rng();
        let wave: Vec<f64> = (0..2048).map(|_| rng.gen::<f64>()).collect();
        let interpolator = ScalarInterpolator::<f64>::new(
            256,
            256,
            0.9473371669037001,
            WindowFunction::BlackmanHarris2,
        );
        let value = interpolator.get_sinc_interpolated(&wave, 333, 123);
        let expected = reference_interpolation(&wave, 333, &interpolator.sincs[123]);
        assert!((value - expected).abs() < 1.0e-9);
    }

    #[test]
    fn test_scalar_interpolator_32() {
        let mut rng = rand::thread_rng();
        let wave: Vec<f32> = (0..2048).map(|_| rng.gen::<f32>()).collect();
        let interpolator = ScalarInterpolator::<f32>::new(
            256,
            256,
            0.9473371669037001,
            WindowFunction::BlackmanHarris2,
        );
        let value = interpolator.get_sinc_interpolated(&wave, 333, 123);
        let expected = reference_interpolation(&wave, 333, &interpolator.sincs[123]);
        assert!((value - expected).abs() < 1.0e-6);
    }
}