use crate::sinc::make_sincs;
use crate::windows::WindowFunction;
use crate::Sample;
pub(crate) mod aligned_buf;
pub(crate) use aligned_buf::AlignedBuf;
/// Declares an architecture-specific backend module together with its marker trait.
///
/// When the `cfg` predicate holds, `pub mod <module>;` is loaded from its own file.
/// Otherwise an inline stand-in module is generated in which the marker trait has no
/// items and is implemented for every `T: Sample`. Either way the trait is re-exported
/// here, so bounds that mention it resolve on every target.
macro_rules! interpolator {
    (
        #[cfg($($predicate:tt)*)]
        mod $module:ident;
        trait $marker:ident;
    ) => {
        #[cfg(not($($predicate)*))]
        pub mod $module {
            use crate::Sample;

            pub trait $marker {}

            impl<T: Sample> $marker for T {}
        }

        #[cfg($($predicate)*)]
        pub mod $module;

        pub use self::$module::$marker;
    };
}
// Backend module `sinc_interpolator_avx`, compiled only on x86_64. On other targets
// `AvxSample` becomes an empty trait implemented for every `Sample`.
interpolator! {
#[cfg(target_arch = "x86_64")]
mod sinc_interpolator_avx;
trait AvxSample;
}
// Backend module `sinc_interpolator_sse`, compiled only on x86_64. On other targets
// `SseSample` becomes an empty trait implemented for every `Sample`.
interpolator! {
#[cfg(target_arch = "x86_64")]
mod sinc_interpolator_sse;
trait SseSample;
}
// Backend module `sinc_interpolator_neon`, compiled only on aarch64. On other targets
// `NeonSample` becomes an empty trait implemented for every `Sample`.
interpolator! {
#[cfg(target_arch = "aarch64")]
mod sinc_interpolator_neon;
trait NeonSample;
}
/// Common interface implemented by every sinc interpolation backend.
///
/// An implementor supplies a table of `nbr_sincs()` precomputed sinc rows (via
/// `get_sincs`) and a dot-product kernel. The provided methods build single-row
/// interpolation and multi-row kernel blending on top of those primitives.
#[cfg_attr(feature = "bench_asyncro", visibility::make(pub))]
pub(crate) trait SincInterpolator<T>: Send {
    /// Returns the dot product of `sinc` with the samples of `wave` starting at `index`.
    fn get_sinc_dot_product(&self, wave: &[T], index: usize, sinc: &[T]) -> T;

    /// Returns the table of precomputed sinc rows, indexed by `subindex`.
    fn get_sincs(&self) -> &[AlignedBuf<T>];

    /// Returns the sinc length, i.e. how many samples of `wave` one interpolation reads.
    fn nbr_points(&self) -> usize;

    /// Returns the number of sinc rows, i.e. the exclusive upper bound for `subindex`.
    fn nbr_sincs(&self) -> usize;

    /// Optional per-row hook, presumably a cache prefetch hint for sinc row `subindex`.
    /// The default does nothing.
    #[inline]
    fn prefetch_sinc(&self, _subindex: usize) {}

    /// Evaluates sinc row `subindex` against `wave`, starting at sample `index`.
    ///
    /// # Panics
    ///
    /// Panics if `index + nbr_points() >= wave.len()` or if `subindex >= nbr_sincs()`.
    #[inline]
    fn get_sinc_interpolated(&self, wave: &[T], index: usize, subindex: usize) -> T {
        // The "max" values are computed in signed arithmetic, so building the panic
        // message cannot itself underflow when `wave` is shorter than
        // `nbr_points() + 1` or the table is empty. A negative max means no valid value.
        assert!(
            (index + self.nbr_points()) < wave.len(),
            "Tried to interpolate for index {}, max for the given input is {}",
            index,
            wave.len() as isize - self.nbr_points() as isize - 1
        );
        assert!(
            subindex < self.nbr_sincs(),
            "Tried to use sinc subindex {}, max is {}",
            subindex,
            self.nbr_sincs() as isize - 1
        );
        self.get_sinc_dot_product(wave, index, &self.get_sincs()[subindex])
    }

    /// Blends several weighted sinc rows into a single kernel stored in `combined`.
    ///
    /// Each `(int_idx, sub_idx)` entry of `nearest` selects sinc row `sub_idx`. That row
    /// is scaled by the matching entry of `weights` and accumulated into `combined` at
    /// offset `int_idx - min_idx`, where `min_idx` is the smallest `int_idx` in `nearest`.
    /// Any previous contents of `combined` are discarded. Returns `min_idx`.
    ///
    /// # Panics
    ///
    /// - Panics if `nearest` is empty.
    /// - Panics on out-of-bounds indexing if a shifted row does not fit in `combined`,
    ///   or if `sub_idx` is not a valid row.
    fn make_combined_sinc(
        &self,
        nearest: &[(isize, isize)],
        weights: &[T],
        combined: &mut [T],
    ) -> isize
    where
        T: Sample,
    {
        debug_assert_eq!(
            combined.len(),
            self.nbr_points() + 1,
            "combined must be nbr_points()+1: the extra element holds any spillover \
             from nearest points at the higher integer index"
        );
        let min_idx = nearest
            .iter()
            .map(|n| n.0)
            .min()
            .expect("make_combined_sinc requires at least one nearest point");
        // Start from an all-zero kernel (safe replacement for a raw byte-level memset).
        combined.fill(T::zero());
        let sincs = self.get_sincs();
        for (n, &w) in nearest.iter().zip(weights.iter()) {
            let shift = (n.0 - min_idx) as usize;
            for (k, &s) in sincs[n.1 as usize].iter().enumerate() {
                combined[shift + k] += w * s;
            }
        }
        min_idx
    }
}
/// Enum over the available interpolation backends.
///
/// The SIMD variants exist only on the target architectures whose backend modules are
/// compiled in. The portable `Scalar` variant is always present.
#[cfg_attr(feature = "bench_asyncro", visibility::make(pub))]
pub(crate) enum AnyInterpolator<T: AvxSample + SseSample + NeonSample + Sample> {
    /// Backend from `sinc_interpolator_avx` (x86_64 only).
    #[cfg(target_arch = "x86_64")]
    Avx(sinc_interpolator_avx::AvxInterpolator<T>),
    /// Backend from `sinc_interpolator_sse` (x86_64 only).
    #[cfg(target_arch = "x86_64")]
    Sse(sinc_interpolator_sse::SseInterpolator<T>),
    /// Backend from `sinc_interpolator_neon` (aarch64 only).
    #[cfg(target_arch = "aarch64")]
    Neon(sinc_interpolator_neon::NeonInterpolator<T>),
    /// Portable fallback available on every target.
    Scalar(ScalarInterpolator<T>),
}
/// Forwards a method call to the interpolator wrapped by the active `AnyInterpolator`
/// variant.
///
/// `dispatch!(expr, method(a, b))` expands to a `match` on `expr` that calls
/// `method(a, b)` on the inner value. Arms for architecture-specific variants are
/// compiled only on the targets where those variants exist.
macro_rules! dispatch {
    ($target:expr, $call:ident ($($param:expr),*)) => {
        match $target {
            #[cfg(target_arch = "x86_64")]
            AnyInterpolator::Avx(inner) => inner.$call($($param),*),
            #[cfg(target_arch = "x86_64")]
            AnyInterpolator::Sse(inner) => inner.$call($($param),*),
            #[cfg(target_arch = "aarch64")]
            AnyInterpolator::Neon(inner) => inner.$call($($param),*),
            AnyInterpolator::Scalar(inner) => inner.$call($($param),*),
        }
    };
}
/// Forwards each `SincInterpolator` method to the active backend via `dispatch!`.
///
/// `get_sinc_interpolated` is not overridden here. The trait's provided version runs
/// on the enum and reaches the backend through the forwarded primitives below.
impl<T> SincInterpolator<T> for AnyInterpolator<T>
where
    T: AvxSample + SseSample + NeonSample + Sample,
{
    // --- table metadata ---

    #[inline]
    fn nbr_points(&self) -> usize {
        dispatch!(self, nbr_points())
    }

    #[inline]
    fn nbr_sincs(&self) -> usize {
        dispatch!(self, nbr_sincs())
    }

    #[inline]
    fn get_sincs(&self) -> &[AlignedBuf<T>] {
        dispatch!(self, get_sincs())
    }

    // --- computation ---

    #[inline]
    fn prefetch_sinc(&self, subindex: usize) {
        dispatch!(self, prefetch_sinc(subindex))
    }

    #[inline]
    fn get_sinc_dot_product(&self, wave: &[T], index: usize, sinc: &[T]) -> T {
        dispatch!(self, get_sinc_dot_product(wave, index, sinc))
    }

    #[inline]
    fn make_combined_sinc(
        &self,
        nearest: &[(isize, isize)],
        weights: &[T],
        combined: &mut [T],
    ) -> isize
    where
        T: Sample,
    {
        dispatch!(self, make_combined_sinc(nearest, weights, combined))
    }
}
/// Portable, non-SIMD sinc interpolator backend.
#[cfg_attr(feature = "bench_asyncro", visibility::make(pub))]
pub(crate) struct ScalarInterpolator<T> {
    /// Sinc length, reported by `nbr_points`.
    length: usize,
    /// Number of rows in `sincs`' table, reported by `nbr_sincs`.
    nbr_sincs: usize,
    /// Precomputed sinc rows, indexed by `subindex`.
    sincs: Vec<AlignedBuf<T>>,
}
impl<T> SincInterpolator<T> for ScalarInterpolator<T>
where
    T: Sample,
{
    /// Portable dot product of `sinc` with `wave[index..index + sinc.len()]`.
    ///
    /// The bulk is processed in blocks of eight using eight independent accumulators,
    /// which avoids one long dependency chain through a single running sum. Any trailing
    /// `sinc.len() % 8` taps are added afterwards, so no element is dropped. That matters
    /// for kernels such as the `nbr_points() + 1`-long output of `make_combined_sinc`.
    /// For lengths that are multiples of 8 the summation order (and thus the result) is
    /// the same as an eight-way unrolled loop reduced as `acc0 + acc1 + ... + acc7`.
    ///
    /// # Panics
    ///
    /// Panics if `index + sinc.len() > wave.len()`.
    fn get_sinc_dot_product(&self, wave: &[T], index: usize, sinc: &[T]) -> T {
        let wave_cut = &wave[index..(index + sinc.len())];
        let wave_blocks = wave_cut.chunks_exact(8);
        let sinc_blocks = sinc.chunks_exact(8);
        let wave_tail = wave_blocks.remainder();
        let sinc_tail = sinc_blocks.remainder();

        let mut acc = [T::zero(); 8];
        for (wave_block, sinc_block) in wave_blocks.zip(sinc_blocks) {
            for ((a, &x), &y) in acc.iter_mut().zip(wave_block).zip(sinc_block) {
                *a += x * y;
            }
        }
        // Left-to-right reduction: ((acc[0] + acc[1]) + acc[2]) + ... + acc[7].
        let mut total = acc[1..].iter().fold(acc[0], |sum, &a| sum + a);
        for (&x, &y) in wave_tail.iter().zip(sinc_tail) {
            total += x * y;
        }
        total
    }

    fn get_sincs(&self) -> &[AlignedBuf<T>] {
        &self.sincs
    }

    fn nbr_points(&self) -> usize {
        self.length
    }

    fn nbr_sincs(&self) -> usize {
        self.nbr_sincs
    }
}
impl<T> ScalarInterpolator<T>
where
T: Sample,
{
pub fn new(
sinc_len: usize,
oversampling_factor: usize,
f_cutoff: f32,
window: WindowFunction,
) -> Self {
assert!(sinc_len % 8 == 0, "Sinc length must be a multiple of 8");
let raw_sincs: Vec<Vec<T>> = make_sincs(sinc_len, oversampling_factor, f_cutoff, window);
let sincs = raw_sincs
.into_iter()
.map(|row| AlignedBuf::from_slice(&row))
.collect();
Self {
sincs,
length: sinc_len,
nbr_sincs: oversampling_factor,
}
}
}
#[cfg(test)]
mod tests {
    use super::ScalarInterpolator;
    use super::SincInterpolator;
    use crate::WindowFunction;
    use num_traits::Float;
    use test_log::test;

    /// Straightforward sequential dot product used as the reference result.
    fn reference_dot<T: Float>(wave: &[T], index: usize, sinc: &[T]) -> T {
        let mut sum = T::zero();
        for (&x, &y) in wave[index..index + sinc.len()].iter().zip(sinc) {
            sum = sum + x * y;
        }
        sum
    }

    #[test]
    fn test_scalar_interpolator_64() {
        let wave: Vec<f64> = (0..2048).map(|_| rand::random::<f64>()).collect();
        let interpolator = ScalarInterpolator::<f64>::new(
            256,
            256,
            0.94733715,
            WindowFunction::BlackmanHarris2,
        );
        let value = interpolator.get_sinc_interpolated(&wave, 333, 123);
        let expected = reference_dot(&wave, 333, &interpolator.sincs[123]);
        assert!((value - expected).abs() < 1.0e-9);
    }

    #[test]
    fn test_scalar_interpolator_32() {
        let wave: Vec<f32> = (0..2048).map(|_| rand::random::<f32>()).collect();
        let interpolator = ScalarInterpolator::<f32>::new(
            256,
            256,
            0.94733715,
            WindowFunction::BlackmanHarris2,
        );
        let value = interpolator.get_sinc_interpolated(&wave, 333, 123);
        let expected = reference_dot(&wave, 333, &interpolator.sincs[123]);
        assert!((value - expected).abs() < 1.0e-6);
    }
}