use ndarray::{ArrayBase, AsArray, Ix1, ViewRepr};
use rustfft::{FftPlanner, num_complex::Complex, num_traits::Zero};
use crate::traits::numeric::AsNumeric;
/// Convolves two 1-D sequences with an FFT and returns the first `data_a.len()` samples.
///
/// Both inputs are converted to `f64` with `AsNumeric::to_f64`, then zero-padded to the next
/// power of two that can hold the full linear convolution (`data_a.len() + data_b.len() - 1`
/// samples). They are multiplied in the frequency domain and transformed back.
///
/// # Returns
///
/// A vector of length `data_a.len()` holding the leading samples of the linear convolution
/// `a * b`. The remaining `data_b.len() - 1` tail samples are discarded. An empty `data_a`
/// yields an empty vector, and an empty `data_b` yields all zeros.
pub fn fft_convolve_1d<'a, T, A>(data_a: A, data_b: A) -> Vec<f64>
where
    A: AsArray<'a, T, Ix1>,
    T: 'a + AsNumeric,
{
    let view_a: ArrayBase<ViewRepr<&'a T>, Ix1> = data_a.into();
    let view_b: ArrayBase<ViewRepr<&'a T>, Ix1> = data_b.into();
    let n_a = view_a.len();
    let n_b = view_b.len();
    // Convolution with an empty sequence is identically zero. Returning early also avoids two
    // failures: the `n_a + n_b - 1` underflow when both inputs are empty, and a padded buffer
    // shorter than `data_a` (e.g. `n_a == 3`, `n_b == 0` would give an FFT size of 2 and a
    // slice panic).
    if n_a == 0 || n_b == 0 {
        return vec![0.0; n_a];
    }
    // The FFT computes a circular convolution. Padding to at least the full linear length
    // keeps the tail from wrapping around onto the samples we return.
    let fft_size = (n_a + n_b - 1).next_power_of_two();

    let mut spec_a = vec![Complex::<f64>::zero(); fft_size];
    let mut spec_b = vec![Complex::<f64>::zero(); fft_size];
    for (slot, x) in spec_a.iter_mut().zip(view_a.iter()) {
        *slot = Complex::new(x.to_f64(), 0.0);
    }
    for (slot, x) in spec_b.iter_mut().zip(view_b.iter()) {
        *slot = Complex::new(x.to_f64(), 0.0);
    }

    let mut planner = FftPlanner::new();
    let fft = planner.plan_fft_forward(fft_size);
    fft.process(&mut spec_a);
    fft.process(&mut spec_b);
    // A pointwise product in the frequency domain is a convolution in the time domain.
    for (a, b) in spec_a.iter_mut().zip(&spec_b) {
        *a *= *b;
    }
    planner.plan_fft_inverse(fft_size).process(&mut spec_a);

    // rustfft's inverse transform is unnormalised, so divide by the transform length.
    let scale = 1.0 / fft_size as f64;
    spec_a[..n_a].iter().map(|c| c.re * scale).collect()
}
/// Deconvolves `data_b` out of `data_a` by spectral division and returns `data_a.len()` samples.
///
/// Both inputs are converted to `f64` with `AsNumeric::to_f64` and zero-padded to the next power
/// of two that is at least `data_a.len() + data_b.len() - 1`. The spectrum of `data_a` is then
/// divided bin by bin by the spectrum of `data_b`, and the result is transformed back.
///
/// # Parameters
///
/// * `epsilon` — threshold on the *squared* magnitude `|B(k)|²` of `data_b`'s spectrum (default
///   `1e-8`). Bins where `|B(k)|² <= epsilon` are set to zero instead of being divided. This
///   prevents ±inf/NaN from near-zero divisors.
///
/// # Returns
///
/// The first `data_a.len()` samples of the inverse-filtered signal. Suppose `data_a` is the full
/// linear convolution of some `x` with `data_b` and no bins are suppressed. Then the leading
/// `x.len()` samples recover `x`, up to floating-point error. An empty `data_a` yields an empty
/// vector. An empty `data_b` has an all-zero spectrum, so every bin is suppressed and the
/// result is all zeros.
pub fn fft_deconvolve_1d<'a, T, A>(data_a: A, data_b: A, epsilon: Option<f64>) -> Vec<f64>
where
    A: AsArray<'a, T, Ix1>,
    T: 'a + AsNumeric,
{
    let view_a: ArrayBase<ViewRepr<&'a T>, Ix1> = data_a.into();
    let view_b: ArrayBase<ViewRepr<&'a T>, Ix1> = data_b.into();
    let epsilon = epsilon.unwrap_or(1e-8);
    let n_a = view_a.len();
    let n_b = view_b.len();
    // Return early for empty inputs. This avoids the `n_a + n_b - 1` underflow and an FFT
    // buffer shorter than `data_a`, and matches what the zero-spectrum gating below would give.
    if n_a == 0 || n_b == 0 {
        return vec![0.0; n_a];
    }
    // Pad so that `data_a` (assumed to be a full linear convolution) does not wrap
    // circularly inside the transform.
    let fft_size = (n_a + n_b - 1).next_power_of_two();

    let mut spec_a = vec![Complex::<f64>::zero(); fft_size];
    let mut spec_b = vec![Complex::<f64>::zero(); fft_size];
    for (slot, x) in spec_a.iter_mut().zip(view_a.iter()) {
        *slot = Complex::new(x.to_f64(), 0.0);
    }
    for (slot, x) in spec_b.iter_mut().zip(view_b.iter()) {
        *slot = Complex::new(x.to_f64(), 0.0);
    }

    let mut planner = FftPlanner::new();
    let fft = planner.plan_fft_forward(fft_size);
    fft.process(&mut spec_a);
    fft.process(&mut spec_b);
    // The guard must test the divisor. Checking the numerator (as before) let near-zero bins
    // of B through and produced ±inf/NaN. Suppressed bins are zeroed (a crude regularised
    // inverse filter).
    for (a, b) in spec_a.iter_mut().zip(&spec_b) {
        *a = if b.norm_sqr() > epsilon {
            *a / *b
        } else {
            Complex::zero()
        };
    }
    planner.plan_fft_inverse(fft_size).process(&mut spec_a);

    // rustfft's inverse transform is unnormalised, so divide by the transform length.
    let scale = 1.0 / fft_size as f64;
    spec_a[..n_a].iter().map(|c| c.re * scale).collect()
}