#[cfg(not(feature = "std"))]
use alloc::sync::Arc;
#[cfg(feature = "std")]
use std::sync::Arc;
use num_traits::Float;
#[cfg(feature = "rustfft-backend")]
pub use rustfft::num_complex::Complex;
#[cfg(feature = "microfft-backend")]
/// Minimal complex number used when the `microfft-backend` feature is enabled.
///
/// The layout is pinned with `#[repr(C)]` (`re` first, then `im`) because the microfft
/// backend reinterprets `&mut [Complex<f32>]` as `&mut [microfft::Complex32]`. With the
/// default Rust representation the field order would be unspecified and that cast unsound.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
#[repr(C)]
pub struct Complex<T> {
    /// Real part.
    pub re: T,
    /// Imaginary part.
    pub im: T,
}
#[cfg(feature = "microfft-backend")]
impl<T: Float> Complex<T> {
    /// Creates a complex number from its real and imaginary parts.
    pub fn new(re: T, im: T) -> Self {
        Self { re, im }
    }

    /// Returns the complex conjugate `re - i*im`.
    pub fn conj(&self) -> Self {
        Self::new(self.re, -self.im)
    }

    /// Returns the squared magnitude `re^2 + im^2` (no square root taken).
    pub fn norm_sqr(&self) -> T {
        self.re * self.re + self.im * self.im
    }

    /// Returns the magnitude `|self|`.
    ///
    /// Computed with `hypot`, as `num_complex::Complex::norm` does. This avoids overflow
    /// and underflow in the intermediate squares, so large components (e.g. `3e30f32`)
    /// still yield a finite result.
    pub fn norm(&self) -> T {
        self.re.hypot(self.im)
    }
}
#[cfg(feature = "microfft-backend")]
impl<T: Float> core::ops::Mul<T> for Complex<T> {
    type Output = Self;

    /// Scales both components by the real factor `rhs`.
    fn mul(self, rhs: T) -> Self::Output {
        let Self { re, im } = self;
        Self {
            re: re * rhs,
            im: im * rhs,
        }
    }
}
#[cfg(feature = "microfft-backend")]
impl<T: Float> core::ops::Add for Complex<T> {
    type Output = Self;

    /// Component-wise sum.
    fn add(self, rhs: Self) -> Self::Output {
        let Self { re: a, im: b } = self;
        let Self { re: c, im: d } = rhs;
        Self { re: a + c, im: b + d }
    }
}
#[cfg(feature = "microfft-backend")]
impl<T: Float> core::ops::Sub for Complex<T> {
    type Output = Self;

    /// Component-wise difference `self - rhs`.
    fn sub(self, rhs: Self) -> Self::Output {
        let Self { re: a, im: b } = self;
        let Self { re: c, im: d } = rhs;
        Self { re: a - c, im: b - d }
    }
}
#[cfg(feature = "microfft-backend")]
impl<T: Float> core::ops::Mul for Complex<T> {
    type Output = Self;

    /// Complex product: `(a + bi)(c + di) = (ac - bd) + (ad + bc)i`.
    fn mul(self, rhs: Self) -> Self::Output {
        let (a, b) = (self.re, self.im);
        let (c, d) = (rhs.re, rhs.im);
        Self {
            re: a * c - b * d,
            im: a * d + b * c,
        }
    }
}
#[cfg(feature = "microfft-backend")]
impl<T: Float> core::ops::Div for Complex<T> {
    type Output = Self;

    /// Complex quotient: `(a + bi)/(c + di) = ((ac + bd) + (bc - ad)i) / (c^2 + d^2)`.
    ///
    /// Dividing by zero follows IEEE float semantics (inf/NaN components), no panic.
    fn div(self, rhs: Self) -> Self::Output {
        let (a, b) = (self.re, self.im);
        let (c, d) = (rhs.re, rhs.im);
        let denom = c * c + d * d;
        Self {
            re: (a * c + b * d) / denom,
            im: (b * c - a * d) / denom,
        }
    }
}
/// Scalar types usable as the components of FFT samples.
///
/// With the `rustfft-backend` feature this also requires [`rustfft::FftNum`], so any
/// type accepted here can be handed directly to rustfft.
#[cfg(feature = "rustfft-backend")]
pub trait FftNum: 'static + Send + Sync + Float + rustfft::FftNum {}

/// Scalar types usable as the components of FFT samples.
#[cfg(not(feature = "rustfft-backend"))]
pub trait FftNum: 'static + Send + Sync + Float {}

impl FftNum for f64 {}
impl FftNum for f32 {}
/// A planned, reusable FFT of a fixed length.
///
/// Planners hand these out as `Arc<dyn FftBackend<T>>`, hence the `Send + Sync` bound.
pub trait FftBackend<T: FftNum>: core::fmt::Debug + Send + Sync {
    /// Runs the planned transform in place over `buffer`.
    fn process(&self, buffer: &mut [Complex<T>]);

    /// Number of points in the planned transform.
    fn len(&self) -> usize;

    /// Returns `true` when the planned transform has zero points.
    fn is_empty(&self) -> bool {
        matches!(self.len(), 0)
    }
}
/// Factory for [`FftBackend`] instances of a given size and direction.
pub trait FftPlannerTrait<T: FftNum> {
    /// Creates a new planner.
    fn new() -> Self;

    /// Plans a forward transform of `size` points.
    fn plan_fft_forward(&mut self, size: usize) -> Arc<dyn FftBackend<T>>;

    /// Plans an inverse transform of `size` points.
    ///
    /// NOTE(review): neither backend in this file divides the inverse result by `size`;
    /// callers wanting a normalized inverse must scale themselves.
    fn plan_fft_inverse(&mut self, size: usize) -> Arc<dyn FftBackend<T>>;
}
#[cfg(feature = "rustfft-backend")]
mod rustfft_impl {
    use super::*;
    use rustfft::{Fft, FftPlanner as RustFftPlanner};

    /// Adapts a planned rustfft transform to this crate's [`FftBackend`] trait.
    struct RustFftWrapper<T: rustfft::FftNum> {
        fft: Arc<dyn Fft<T>>,
    }

    impl<T: rustfft::FftNum> core::fmt::Debug for RustFftWrapper<T> {
        fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
            f.debug_struct("RustFftWrapper")
                .field("fft_size", &self.fft.len())
                .finish()
        }
    }

    impl<T: FftNum> FftBackend<T> for RustFftWrapper<T> {
        /// Delegates to `rustfft::Fft::process`, which transforms every `len()`-sized chunk
        /// of `buffer` in place.
        ///
        /// # Panics
        /// Panics (inside rustfft) if `buffer.len()` is smaller than, or not a multiple of,
        /// the planned length.
        fn process(&self, buffer: &mut [Complex<T>]) {
            // With this backend `Complex` is the re-export of `rustfft::num_complex::Complex`,
            // so the buffer already has rustfft's element type; no pointer cast is needed.
            self.fft.process(buffer);
        }

        fn len(&self) -> usize {
            self.fft.len()
        }
    }

    /// Planner backed by `rustfft::FftPlanner`.
    pub struct FftPlanner<T: rustfft::FftNum> {
        planner: RustFftPlanner<T>,
    }

    impl<T: FftNum> FftPlannerTrait<T> for FftPlanner<T> {
        fn new() -> Self {
            Self {
                planner: RustFftPlanner::new(),
            }
        }

        fn plan_fft_forward(&mut self, size: usize) -> Arc<dyn FftBackend<T>> {
            Arc::new(RustFftWrapper {
                fft: self.planner.plan_fft_forward(size),
            })
        }

        fn plan_fft_inverse(&mut self, size: usize) -> Arc<dyn FftBackend<T>> {
            Arc::new(RustFftWrapper {
                fft: self.planner.plan_fft_inverse(size),
            })
        }
    }
}
#[cfg(feature = "rustfft-backend")]
pub use rustfft_impl::FftPlanner;
#[cfg(feature = "microfft-backend")]
mod microfft_impl {
    // FFT backend built on microfft: f32-only, in-place complex FFTs for
    // power-of-two lengths from MIN_SIZE to MAX_SIZE.
    use super::*;

    /// Smallest transform length this backend dispatches to.
    const MIN_SIZE: usize = 2;
    /// Largest transform length this backend dispatches to.
    const MAX_SIZE: usize = 4096;

    /// Returns `true` for the power-of-two lengths in `MIN_SIZE..=MAX_SIZE`.
    fn is_supported_size(size: usize) -> bool {
        size.is_power_of_two() && (MIN_SIZE..=MAX_SIZE).contains(&size)
    }

    /// Panics with the backend's standard message unless `size` is supported.
    fn assert_supported_size(size: usize) {
        if !is_supported_size(size) {
            panic!(
                "microfft only supports power-of-2 sizes from 2 to 4096, got {}",
                size
            );
        }
    }

    /// Panics unless `buffer_len` is a non-zero multiple of `size`.
    ///
    /// This mirrors the length contract of `rustfft::Fft::process`, so both backends accept
    /// the same buffers.
    fn assert_buffer_len(buffer_len: usize, size: usize) {
        assert!(
            buffer_len >= size && buffer_len % size == 0,
            "buffer length {} must be a non-zero multiple of the FFT size {}",
            buffer_len,
            size
        );
    }

    /// Runs microfft's in-place forward CFFT on `chunk`.
    ///
    /// `chunk.len()` selects the `cfft_N` routine.
    ///
    /// # Panics
    /// Panics if `chunk.len()` is not one of the supported sizes.
    fn cfft_chunk(chunk: &mut [microfft::Complex32]) {
        macro_rules! dispatch {
            ($buf:ident; $($n:literal => $func:ident),+ $(,)?) => {
                match $buf.len() {
                    $(
                        $n => {
                            // Safe slice -> array conversion. It cannot fail because this arm
                            // only runs when the length is exactly `$n`.
                            let array: &mut [microfft::Complex32; $n] =
                                core::convert::TryInto::try_into($buf)
                                    .expect("slice length equals array length");
                            let _ = microfft::complex::$func(array);
                        }
                    )+
                    _ => panic!("microfft only supports power-of-2 sizes from 2 to 4096"),
                }
            };
        }
        dispatch!(chunk;
            2 => cfft_2,
            4 => cfft_4,
            8 => cfft_8,
            16 => cfft_16,
            32 => cfft_32,
            64 => cfft_64,
            128 => cfft_128,
            256 => cfft_256,
            512 => cfft_512,
            1024 => cfft_1024,
            2048 => cfft_2048,
            4096 => cfft_4096,
        );
    }

    /// Applies the forward CFFT of length `size` to each consecutive `size`-point chunk of
    /// `buffer`.
    ///
    /// Callers must have validated `buffer.len()` with `assert_buffer_len` first.
    fn forward_chunks(buffer: &mut [Complex<f32>], size: usize) {
        let len = buffer.len();
        // SAFETY: reinterprets `&mut [Complex<f32>]` as `&mut [microfft::Complex32]` over the
        // same memory and element count, while the exclusive borrow of `buffer` is held.
        // This requires `Complex<f32>` to have the same layout as `microfft::Complex32`
        // (a `#[repr(C)]` pair of `f32` fields, `re` then `im`).
        let raw = unsafe {
            core::slice::from_raw_parts_mut(buffer.as_mut_ptr() as *mut microfft::Complex32, len)
        };
        for chunk in raw.chunks_exact_mut(size) {
            cfft_chunk(chunk);
        }
    }

    /// Negates the imaginary part of every element in place.
    fn conjugate_in_place(buffer: &mut [Complex<f32>]) {
        buffer.iter_mut().for_each(|z| z.im = -z.im);
    }

    /// Forward transform of a fixed size, validated by the planner.
    #[derive(Debug, Clone)]
    struct MicroFftForward {
        size: usize,
    }

    /// Inverse transform of a fixed size, validated by the planner.
    #[derive(Debug, Clone)]
    struct MicroFftInverse {
        size: usize,
    }

    impl FftBackend<f32> for MicroFftForward {
        /// Forward-transforms every `self.size`-point chunk of `buffer` in place.
        ///
        /// # Panics
        /// Panics if `buffer.len()` is not a non-zero multiple of `self.size`.
        fn process(&self, buffer: &mut [Complex<f32>]) {
            assert_buffer_len(buffer.len(), self.size);
            forward_chunks(buffer, self.size);
        }

        fn len(&self) -> usize {
            self.size
        }
    }

    impl FftBackend<f32> for MicroFftInverse {
        /// Inverse-transforms every `self.size`-point chunk of `buffer` in place.
        ///
        /// Uses `ifft(x) = conj(fft(conj(x)))`. The result is not scaled by `1 / size`.
        ///
        /// # Panics
        /// Panics if `buffer.len()` is not a non-zero multiple of `self.size`.
        fn process(&self, buffer: &mut [Complex<f32>]) {
            // Validate before touching the data so a rejected call leaves `buffer` unchanged.
            assert_buffer_len(buffer.len(), self.size);
            conjugate_in_place(buffer);
            forward_chunks(buffer, self.size);
            conjugate_in_place(buffer);
        }

        fn len(&self) -> usize {
            self.size
        }
    }

    /// Planner for the microfft backend.
    ///
    /// Only `f32` is functional; every method of the `f64` implementation panics.
    pub struct FftPlanner<T: FftNum> {
        _phantom: core::marker::PhantomData<T>,
    }

    impl FftPlannerTrait<f32> for FftPlanner<f32> {
        fn new() -> Self {
            Self {
                _phantom: core::marker::PhantomData,
            }
        }

        /// # Panics
        /// Panics unless `size` is a power of two between 2 and 4096.
        fn plan_fft_forward(&mut self, size: usize) -> Arc<dyn FftBackend<f32>> {
            assert_supported_size(size);
            Arc::new(MicroFftForward { size })
        }

        /// # Panics
        /// Panics unless `size` is a power of two between 2 and 4096.
        fn plan_fft_inverse(&mut self, size: usize) -> Arc<dyn FftBackend<f32>> {
            assert_supported_size(size);
            Arc::new(MicroFftInverse { size })
        }
    }

    impl FftPlannerTrait<f64> for FftPlanner<f64> {
        /// Always panics: microfft has no `f64` transforms.
        fn new() -> Self {
            panic!("microfft backend does not support f64, only f32");
        }

        /// Always panics: microfft has no `f64` transforms.
        fn plan_fft_forward(&mut self, _size: usize) -> Arc<dyn FftBackend<f64>> {
            panic!("microfft backend does not support f64, only f32");
        }

        /// Always panics: microfft has no `f64` transforms.
        fn plan_fft_inverse(&mut self, _size: usize) -> Arc<dyn FftBackend<f64>> {
            panic!("microfft backend does not support f64, only f32");
        }
    }
}
#[cfg(feature = "microfft-backend")]
pub use microfft_impl::FftPlanner;
// Exactly one FFT backend feature must be selected at build time.
#[cfg(all(feature = "rustfft-backend", feature = "microfft-backend"))]
compile_error!(
    "Cannot enable both 'rustfft-backend' and 'microfft-backend' at the same time. Choose one."
);
#[cfg(all(not(feature = "rustfft-backend"), not(feature = "microfft-backend")))]
compile_error!("At least one FFT backend must be enabled: 'rustfft-backend' or 'microfft-backend'");