use realfft::ComplexToReal;
use realfft::RealToComplex;
use realfft::num_traits::NumAssign;
use realfft::num_traits::Zero;
use rustfft::FftNum;
use rustfft::num_complex::Complex;
#[cfg(feature = "serde")]
use serde::Deserialize;
#[cfg(feature = "serde")]
use serde::Serialize;
use std::cell::RefCell;
use std::collections::HashMap;
use std::fmt::Debug;
#[cfg(feature = "fallible")]
use std::ops::Add;
#[cfg(feature = "fallible")]
use std::ops::AddAssign;
use std::ops::Deref;
use std::ops::Mul;
use std::ops::MulAssign;
use crate::with_inverse_real_fft_algorithm;
use crate::with_real_fft_algorithm;
/// Forward real-to-complex FFT for real input whose length is only known at runtime.
pub trait DynRealFft<T> {
    /// Computes the DFT of `self`, returning a freshly allocated [`DynRealDft`]
    /// holding `self.len() / 2 + 1` bins.
    fn real_fft(&self) -> DynRealDft<T>;
    /// Computes the DFT of `self` into an existing `output`, reusing its storage.
    ///
    /// # Panics
    /// The implementation for `[T]` panics if `output` was not created for a
    /// signal of `self.len()` samples.
    #[cfg(feature = "fallible")]
    fn real_fft_using(&self, output: &mut DynRealDft<T>);
}
/// Inverse complex-to-real FFT producing real output whose length is only known at runtime.
pub trait DynRealIfft<T> {
    /// Computes the inverse DFT, returning a freshly allocated buffer of the
    /// original signal length.
    ///
    /// The output is not normalised: divide each sample by the signal length to
    /// recover the samples that were originally transformed.
    fn real_ifft(&self) -> Box<[T]>;
    /// Computes the inverse DFT into the caller-provided `output`.
    ///
    /// # Panics
    /// The implementation for [`DynRealDft`] panics if `output.len()` differs
    /// from the original signal length.
    #[cfg(feature = "fallible")]
    fn real_ifft_using(&self, output: &mut [T]);
}
/// Complex-to-real inverse transform that supplies its own cached input-copy and
/// scratch buffers (via `generic_singleton` thread-local storage) instead of
/// taking them from the caller.
trait StaticScratchComplexToReal<T: FftNum>: ComplexToReal<T> {
    /// Runs the inverse transform of `input` into `output`.
    ///
    /// # Safety
    /// `input.len()` must equal `output.len() / 2 + 1`, and `self` must be an
    /// algorithm planned for `output.len()` samples (assumed from how callers
    /// pair it with `with_inverse_real_fft_algorithm`; confirm). Any error from
    /// `process_with_scratch` is discarded with `unwrap_unchecked`, so violating
    /// these requirements is undefined behaviour.
    unsafe fn process_with_static_scratch(&self, input: &[Complex<T>], output: &mut [T]);
}
/// Module-private stand-in for `DynRealFft::real_fft_using` when the `fallible`
/// feature is disabled, so that `real_fft` can still delegate to a
/// buffer-reusing implementation without exposing it publicly.
#[cfg(not(feature = "fallible"))]
trait PrivateRealFftUsing<T> {
    /// Computes the DFT of `self` into the already-sized `output`.
    fn real_fft_using(&self, output: &mut DynRealDft<T>);
}
impl<T: FftNum + Default, U: ?Sized + ComplexToReal<T>> StaticScratchComplexToReal<T> for U {
    /// Inverse-transforms `input` into `output` using per-thread cached buffers.
    ///
    /// # Safety
    /// `self` must be an algorithm planned for `output.len()` samples. The
    /// `input`/`output` length relation is asserted here, and the imaginary parts
    /// realfft requires to be zero are forced to zero. Together these leave
    /// `process_with_scratch` no reason to fail, which is what the
    /// `unwrap_unchecked` below relies on.
    unsafe fn process_with_static_scratch(&self, input: &[Complex<T>], output: &mut [T]) {
        // Checked in release builds too. A mismatch (e.g. from a deserialized
        // `DynRealDft` with an inconsistent bin count) would make
        // `process_with_scratch` return `Err`, and `unwrap_unchecked` would turn
        // that into undefined behaviour.
        assert_eq!(input.len(), output.len() / 2 + 1);
        generic_singleton::get_or_init_thread_local!(
            // Cache of input copies, keyed by bin count.
            || RefCell::new(HashMap::<usize, Box<[Complex<T>]>>::new()),
            |input_clone_map| {
                generic_singleton::get_or_init_thread_local!(
                    // Cache of scratch buffers, keyed by scratch length.
                    || RefCell::new(HashMap::<usize, Box<[Complex<T>]>>::new()),
                    |scratch_buffer_map| {
                        let scratch_buffer_len = self.get_scratch_len();
                        let mut scratch_buffer_map = scratch_buffer_map.borrow_mut();
                        let scratch = scratch_buffer_map
                            .entry(scratch_buffer_len)
                            .or_insert_with(|| {
                                vec![Complex::default(); scratch_buffer_len].into_boxed_slice()
                            });
                        // The transform is handed a mutable input buffer (it may
                        // overwrite it), so work on a cached copy and leave the
                        // caller's `input` untouched.
                        let mut input_clone_map = input_clone_map.borrow_mut();
                        let input_clone = input_clone_map.entry(input.len()).or_insert_with(|| {
                            vec![Complex::default(); input.len()].into_boxed_slice()
                        });
                        input_clone.copy_from_slice(input);
                        // realfft treats the imaginary part of bin 0 (and of the last
                        // bin for even output lengths) as zero, but reports
                        // `Err(FftError::InputValues)` when it is not. Such values
                        // are reachable from safe code (e.g. scaling by NaN/inf turns
                        // a zero imaginary part into NaN). Zeroing them here matches
                        // realfft's own treatment and keeps the call infallible.
                        input_clone[0].im = T::zero();
                        if output.len() % 2 == 0 {
                            if let Some(last_bin) = input_clone.last_mut() {
                                last_bin.im = T::zero();
                            }
                        }
                        // SAFETY: buffer lengths were asserted above, the scratch
                        // buffer has exactly `get_scratch_len()` elements, the
                        // imaginary parts realfft validates were zeroed, and the
                        // caller guarantees `self` matches `output.len()`.
                        unsafe {
                            self.process_with_scratch(input_clone, output, scratch)
                                .unwrap_unchecked();
                        }
                    }
                );
            }
        );
    }
}
/// Real-to-complex forward transform that supplies its own cached input-copy and
/// scratch buffers (via `generic_singleton` thread-local storage) instead of
/// taking them from the caller.
trait StaticScratchRealToComplex<T: FftNum>: RealToComplex<T> {
    /// Runs the forward transform of `input` into `output`.
    ///
    /// # Safety
    /// `output.len()` must equal `input.len() / 2 + 1`, and `self` must be an
    /// algorithm planned for `input.len()` samples (assumed from how callers pair
    /// it with `with_real_fft_algorithm`; confirm). Any error from
    /// `process_with_scratch` is discarded with `unwrap_unchecked`, so violating
    /// these requirements is undefined behaviour.
    unsafe fn process_with_static_scratch(&self, input: &[T], output: &mut [Complex<T>]);
}
impl<T: FftNum + Default, U: ?Sized + RealToComplex<T>> StaticScratchRealToComplex<T> for U {
    /// Forward-transforms `input` into `output` using per-thread cached buffers.
    ///
    /// # Safety
    /// `self` must be an algorithm planned for `input.len()` samples. The
    /// `input`/`output` length relation is asserted here, so together with that
    /// requirement `process_with_scratch` has no reason to fail, which is what
    /// the `unwrap_unchecked` below relies on.
    unsafe fn process_with_static_scratch(&self, input: &[T], output: &mut [Complex<T>]) {
        // Checked in release builds too. A mis-sized `output` (e.g. a deserialized
        // `DynRealDft` whose bin count disagrees with its recorded length) would
        // make `process_with_scratch` return `Err`, and `unwrap_unchecked` would
        // turn that into undefined behaviour.
        assert_eq!(input.len() / 2 + 1, output.len());
        generic_singleton::get_or_init_thread_local!(
            // Cache of input copies, keyed by sample count.
            || RefCell::new(HashMap::<usize, Box<[T]>>::new()),
            |input_clone_map| {
                generic_singleton::get_or_init_thread_local!(
                    // Cache of scratch buffers, keyed by scratch length.
                    || RefCell::new(HashMap::<usize, Box<[Complex<T>]>>::new()),
                    |scratch_buffer_map| {
                        let scratch_buffer_len = self.get_scratch_len();
                        let mut scratch_buffer_map = scratch_buffer_map.borrow_mut();
                        let scratch = scratch_buffer_map
                            .entry(scratch_buffer_len)
                            .or_insert_with(|| {
                                vec![Complex::default(); scratch_buffer_len].into_boxed_slice()
                            });
                        // The transform is handed a mutable input buffer (it may
                        // overwrite it), so work on a cached copy and leave the
                        // caller's `input` untouched.
                        let mut input_clone_map = input_clone_map.borrow_mut();
                        let input_clone = input_clone_map
                            .entry(input.len())
                            .or_insert_with(|| vec![T::default(); input.len()].into_boxed_slice());
                        input_clone.copy_from_slice(input);
                        // SAFETY: buffer lengths were asserted above, the scratch
                        // buffer has exactly `get_scratch_len()` elements, and the
                        // caller guarantees `self` matches `input.len()`.
                        unsafe {
                            self.process_with_scratch(input_clone, output, scratch)
                                .unwrap_unchecked();
                        }
                    }
                );
            }
        );
    }
}
/// Spectrum produced by a real-input FFT of a signal whose length is only known
/// at runtime.
///
/// Only the non-redundant half of the spectrum is stored. The constructors in
/// this module allocate `original_length / 2 + 1` bins, with bin 0 first.
/// NOTE(review): with the `serde` feature, the derived `Deserialize` does not
/// re-check that `inner.len() == original_length / 2 + 1`; confirm that
/// downstream code tolerates inconsistent deserialized values.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct DynRealDft<T> {
    /// Number of samples in the real signal this spectrum describes.
    original_length: usize,
    /// The stored bins, zeroth bin first.
    inner: Box<[Complex<T>]>,
}
#[cfg(feature = "fallible")]
impl<T> DynRealDft<T>
where
    T: Zero + Copy + PartialEq + Debug,
{
    /// Builds a spectrum for a real signal of `original_length` samples from its parts.
    ///
    /// `zeroth_bin` becomes the real part of bin 0, whose imaginary part is set
    /// to zero. `frequency_bins` supplies the remaining `original_length / 2`
    /// bins, in order.
    ///
    /// # Panics
    /// - if `frequency_bins.len() != original_length / 2`;
    /// - if `original_length` is even and the last entry of `frequency_bins` has
    ///   a non-zero imaginary part.
    pub fn new(zeroth_bin: T, frequency_bins: &[Complex<T>], original_length: usize) -> Self {
        assert_eq!(original_length / 2 + 1, frequency_bins.len() + 1);
        // `last()` instead of `len() - 1`: for `original_length == 0` there are no
        // frequency bins and the subtraction would underflow.
        if original_length % 2 == 0 {
            if let Some(last_bin) = frequency_bins.last() {
                assert_eq!(last_bin.im, T::zero());
            }
        }
        let inner = [&[Complex::new(zeroth_bin, T::zero())], frequency_bins].concat();
        Self {
            original_length,
            inner: inner.into_boxed_slice(),
        }
    }
}
impl<T> Deref for DynRealDft<T> {
    type Target = [Complex<T>];

    /// Views every stored bin, the zeroth bin included, as a plain slice.
    fn deref(&self) -> &Self::Target {
        &self.inner[..]
    }
}
#[cfg(feature = "fallible")]
impl<T: Default + FftNum> Add for &DynRealDft<T> {
    type Output = DynRealDft<T>;

    /// Bin-wise sum of two spectra; the result keeps `self`'s original length.
    ///
    /// # Panics
    /// Panics if the operands hold a different number of bins.
    fn add(self, rhs: Self) -> Self::Output {
        assert_eq!(self.len(), rhs.len());
        let summed: Box<[Complex<T>]> = self
            .iter()
            .zip(rhs.iter())
            .map(|(lhs_bin, rhs_bin)| *lhs_bin + *rhs_bin)
            .collect();
        DynRealDft {
            original_length: self.original_length,
            inner: summed,
        }
    }
}
#[cfg(feature = "fallible")]
impl<T: Default + FftNum> AddAssign<&Self> for DynRealDft<T> {
    /// Adds `rhs` into `self` bin by bin.
    ///
    /// # Panics
    /// Panics if the operands hold a different number of bins.
    fn add_assign(&mut self, rhs: &Self) {
        assert_eq!(self.len(), rhs.len());
        self.inner
            .iter_mut()
            .zip(rhs.iter())
            .for_each(|(accumulator, addend)| *accumulator = *accumulator + *addend);
    }
}
#[cfg(feature = "fallible")]
impl<T: Default + FftNum> Mul for &DynRealDft<T> {
    type Output = DynRealDft<T>;

    /// Bin-wise complex product of two spectra; the result keeps `self`'s
    /// original length.
    ///
    /// # Panics
    /// Panics if the operands hold a different number of bins.
    fn mul(self, rhs: Self) -> Self::Output {
        assert_eq!(self.len(), rhs.len());
        let product: Vec<Complex<T>> = self
            .iter()
            .zip(rhs.iter())
            .map(|(&lhs_bin, &rhs_bin)| lhs_bin * rhs_bin)
            .collect();
        DynRealDft {
            original_length: self.original_length,
            inner: product.into_boxed_slice(),
        }
    }
}
#[cfg(feature = "fallible")]
impl<T: Default + FftNum + NumAssign> MulAssign<&Self> for DynRealDft<T> {
    /// Multiplies `self` by `rhs` bin by bin, in place.
    ///
    /// # Panics
    /// Panics if the operands hold a different number of bins.
    fn mul_assign(&mut self, rhs: &Self) {
        assert_eq!(self.len(), rhs.len());
        self.inner
            .iter_mut()
            .zip(rhs.iter())
            .for_each(|(target, factor)| *target *= factor);
    }
}
impl<T: Default + FftNum> Mul<T> for &DynRealDft<T> {
type Output = DynRealDft<T>;
fn mul(self, rhs: T) -> Self::Output {
let mut inner = Vec::with_capacity(self.len());
for index in 0..self.len() {
inner.push(self[index] * rhs);
}
DynRealDft {
inner: inner.into_boxed_slice(),
original_length: self.original_length,
}
}
}
impl<T: Default + FftNum + NumAssign> MulAssign<T> for DynRealDft<T> {
    /// Scales every bin by the real factor `rhs`, in place.
    fn mul_assign(&mut self, rhs: T) {
        self.inner.iter_mut().for_each(|bin| *bin *= rhs);
    }
}
#[cfg(feature = "fallible")]
impl<T: Default + FftNum> Mul<&[T]> for &DynRealDft<T> {
    type Output = DynRealDft<T>;

    /// Scales each bin by the matching real weight in `rhs`, returning a new
    /// spectrum with the same original length.
    ///
    /// # Panics
    /// Panics if `rhs` does not have exactly one weight per bin.
    fn mul(self, rhs: &[T]) -> Self::Output {
        assert_eq!(self.len(), rhs.len());
        let weighted: Box<[Complex<T>]> = self
            .iter()
            .zip(rhs)
            .map(|(&bin, &weight)| bin * weight)
            .collect();
        DynRealDft {
            original_length: self.original_length,
            inner: weighted,
        }
    }
}
#[cfg(feature = "fallible")]
impl<T: Default + FftNum + NumAssign> MulAssign<&[T]> for DynRealDft<T> {
    /// Scales each bin by the matching real weight in `rhs`, in place.
    ///
    /// # Panics
    /// Panics if `rhs` does not have exactly one weight per bin.
    fn mul_assign(&mut self, rhs: &[T]) {
        assert_eq!(self.len(), rhs.len());
        self.inner
            .iter_mut()
            .zip(rhs)
            .for_each(|(bin, &weight)| *bin *= weight);
    }
}
impl<T> DynRealDft<T> {
    /// Number of bins after bin 0 whose real and imaginary parts are both free:
    /// `(original_length - 1) / 2`.
    ///
    /// For odd lengths this is every bin after bin 0. For even lengths it stops
    /// short of the last (Nyquist) bin, whose imaginary part `new` and
    /// `copy_from_slice` require to be zero.
    fn free_frequency_bin_count(&self) -> usize {
        // `saturating_sub` keeps a zero-length spectrum from underflowing.
        self.original_length.saturating_sub(1) / 2
    }

    /// Returns the bins that follow the zeroth (offset) bin, excluding the last
    /// bin when `original_length` is even.
    ///
    /// Previously the count was used as an exclusive end index. That dropped a
    /// bin and panicked for lengths 0, 1 and 2.
    #[must_use]
    pub fn get_frequency_bins(&self) -> &[Complex<T>] {
        let count = self.free_frequency_bin_count();
        &self.inner[1..=count]
    }

    /// Mutable view of the same bins as `get_frequency_bins`.
    ///
    /// The last bin of an even-length spectrum is deliberately not exposed, so
    /// its imaginary part cannot be made non-zero through this view.
    pub fn get_frequency_bins_mut(&mut self) -> &mut [Complex<T>] {
        let count = self.free_frequency_bin_count();
        &mut self.inner[1..=count]
    }

    /// Returns the real part of bin 0 (the offset).
    #[must_use]
    pub fn get_offset(&self) -> &T {
        &self.inner[0].re
    }

    /// Mutable access to the real part of bin 0 (the offset).
    pub fn get_offset_mut(&mut self) -> &mut T {
        &mut self.inner[0].re
    }
}
impl<T: FftNum + Default> DynRealDft<T> {
    /// Creates a spectrum for a real signal of `original_length` samples. All
    /// `original_length / 2 + 1` bins are set to `Complex::default()`.
    ///
    /// This is an inherent constructor taking a length, not an implementation
    /// of the `Default` trait.
    #[cfg(feature = "fallible")]
    #[must_use]
    pub fn default(original_length: usize) -> Self {
        let bin_count = original_length / 2 + 1;
        let bins: Box<[Complex<T>]> = vec![Complex::default(); bin_count].into();
        Self {
            inner: bins,
            original_length,
        }
    }
}
impl<T: FftNum + Zero> DynRealDft<T> {
    /// Overwrites every stored bin (bin 0 included) with the contents of `slice`.
    ///
    /// # Panics
    /// - if `slice.len() != original_length / 2 + 1`;
    /// - if the imaginary part of `slice[0]` is non-zero;
    /// - if `original_length` is even and the imaginary part of the last entry
    ///   is non-zero.
    #[cfg(feature = "fallible")]
    pub fn copy_from_slice(&mut self, slice: &[Complex<T>]) {
        assert_eq!(self.original_length / 2 + 1, slice.len());
        // The length check above guarantees `slice` is non-empty. `T::zero()`
        // matches the check in `new` and avoids a fallible `from_f32(..).unwrap()`.
        assert_eq!(slice[0].im, T::zero());
        if self.original_length % 2 == 0 {
            assert_eq!(slice[slice.len() - 1].im, T::zero());
        }
        self.inner.copy_from_slice(slice);
    }
}
impl<T> From<DynRealDft<T>> for Box<[Complex<T>]> {
fn from(dyn_real_dft: DynRealDft<T>) -> Self {
dyn_real_dft.inner
}
}
impl<T: FftNum + Default> DynRealFft<T> for [T] {
    /// Allocates a spectrum with `self.len() / 2 + 1` bins and fills it with the
    /// forward transform of `self`.
    fn real_fft(&self) -> DynRealDft<T> {
        let bin_count = self.len() / 2 + 1;
        let mut spectrum = DynRealDft {
            original_length: self.len(),
            inner: vec![Complex::default(); bin_count].into_boxed_slice(),
        };
        self.real_fft_using(&mut spectrum);
        spectrum
    }

    /// Forward transform of `self` into the existing `output`.
    ///
    /// # Panics
    /// Panics if `output` was not created for a signal of `self.len()` samples.
    #[cfg(feature = "fallible")]
    fn real_fft_using(&self, output: &mut DynRealDft<T>) {
        let sample_count = self.len();
        assert_eq!(sample_count, output.original_length);
        with_real_fft_algorithm::<T>(sample_count, |r2c| {
            // SAFETY: `r2c` was requested for `sample_count` samples, which the
            // assertion above ties to `output`'s original length.
            unsafe { r2c.process_with_static_scratch(self, &mut output.inner) };
        });
    }
}
impl<T: FftNum + Default> DynRealIfft<T> for DynRealDft<T> {
    /// Allocates `original_length` samples and fills them with the inverse
    /// transform of this spectrum.
    fn real_ifft(&self) -> Box<[T]> {
        let mut samples = vec![T::default(); self.original_length].into_boxed_slice();
        self.real_ifft_using(&mut samples);
        samples
    }

    /// Inverse transform of this spectrum into the caller-provided `output`.
    ///
    /// # Panics
    /// Panics if `output.len()` differs from the original signal length.
    #[cfg(feature = "fallible")]
    fn real_ifft_using(&self, output: &mut [T]) {
        let sample_count = self.original_length;
        assert_eq!(sample_count, output.len());
        with_inverse_real_fft_algorithm::<T>(sample_count, |c2r| {
            // SAFETY: `c2r` was requested for `sample_count` samples and the
            // assertion above guarantees `output` has exactly that length.
            unsafe { c2r.process_with_static_scratch(self, output) };
        });
    }
}
#[cfg(not(feature = "fallible"))]
impl<T: FftNum + Default> PrivateRealFftUsing<T> for [T] {
    /// Forward transform into an already correctly sized spectrum. The length
    /// relation is checked in debug builds only; `real_fft` always builds a
    /// matching `output`.
    fn real_fft_using(&self, output: &mut DynRealDft<T>) {
        let sample_count = self.len();
        debug_assert_eq!(sample_count, output.original_length);
        let forward = with_real_fft_algorithm::<T>(sample_count);
        // SAFETY: `forward` was requested for `sample_count` samples, and the only
        // caller (`real_fft`) sizes `output` for that same length.
        unsafe { forward.process_with_static_scratch(self, &mut output.inner) };
    }
}
#[cfg(not(feature = "fallible"))]
impl<T: FftNum + Default> DynRealDft<T> {
    /// Inverse transform into `output`, which must hold `original_length`
    /// samples. This is checked in debug builds only; `real_ifft` always passes
    /// a correctly sized buffer.
    fn real_ifft_using(&self, output: &mut [T]) {
        let sample_count = self.original_length;
        debug_assert_eq!(sample_count, output.len());
        let inverse = with_inverse_real_fft_algorithm::<T>(sample_count);
        // SAFETY: `inverse` was requested for `original_length` samples and the
        // only caller (`real_ifft`) allocates `output` with exactly that length.
        unsafe { inverse.process_with_static_scratch(self, output) };
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    const ARBITRARY_EVEN_TEST_ARRAY: [f64; 6] = [1.5, 3.0, 2.1, 3.2, 2.2, 3.1];
    const ARBITRARY_ODD_TEST_ARRAY: [f64; 7] = [1.5, 3.0, 2.1, 3.2, 2.2, 3.1, 1.2];
    const ACCEPTABLE_ERROR: f64 = 0.000_000_000_000_01;

    /// Round-trips `samples` through `real_fft` and `real_ifft`. After dividing
    /// by the sample count (the inverse is not normalised), every sample must
    /// match the original within `ACCEPTABLE_ERROR`.
    fn real_fft_and_real_ifft_are_inverse_operations(samples: &[f64]) {
        let scale = samples.len() as f64;
        let round_tripped = samples.real_fft().real_ifft();
        let normalised: Vec<f64> = round_tripped.iter().map(|&value| value / scale).collect();
        assert_eq!(samples.len(), normalised.len());
        for (restored, expected) in normalised.iter().zip(samples) {
            approx::assert_ulps_eq!(restored, expected, epsilon = ACCEPTABLE_ERROR);
        }
    }

    #[test]
    fn real_fft_and_real_ifft_are_inverse_operations_even() {
        real_fft_and_real_ifft_are_inverse_operations(&ARBITRARY_EVEN_TEST_ARRAY);
    }

    #[test]
    fn real_fft_and_real_ifft_are_inverse_operations_odd() {
        real_fft_and_real_ifft_are_inverse_operations(&ARBITRARY_ODD_TEST_ARRAY);
    }
}