use ndarray::{Array, ArrayBase, Data, DataMut, Ix1, Ix2, IxDyn};
use crate::kernel::{Complex, Float};
use crate::{Direction, Flags, Plan, RealPlan};
/// Errors produced by the ndarray FFT adapters in this module.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum NdarrayFftError {
    /// The underlying plan constructor (`dft_1d` / `r2c_1d`) returned `None`.
    PlanCreationFailed,
    /// A transform of length zero was requested (some axis has no elements).
    EmptyArray,
    /// `Array::from_shape_vec` rejected the output buffer for the requested shape.
    ShapeError,
}

impl core::fmt::Display for NdarrayFftError {
    /// Writes a fixed, human-readable message for each variant.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let message = match self {
            NdarrayFftError::PlanCreationFailed => "FFT plan creation failed",
            NdarrayFftError::EmptyArray => "Cannot FFT an empty array",
            NdarrayFftError::ShapeError => "Failed to construct ndarray with given shape",
        };
        f.write_str(message)
    }
}
/// Builds a complex 1-D DFT plan of length `n` running in direction `dir`.
///
/// # Errors
/// - [`NdarrayFftError::EmptyArray`] when `n == 0`; this is checked before any
///   planning is attempted.
/// - [`NdarrayFftError::PlanCreationFailed`] when `Plan::dft_1d` yields `None`.
fn make_plan<T: Float>(n: usize, dir: Direction) -> Result<Plan<T>, NdarrayFftError> {
    match n {
        0 => Err(NdarrayFftError::EmptyArray),
        len => Plan::<T>::dft_1d(len, dir, Flags::ESTIMATE)
            .ok_or(NdarrayFftError::PlanCreationFailed),
    }
}
/// Builds a real-to-complex 1-D plan for `n` real input samples.
///
/// # Errors
/// - [`NdarrayFftError::EmptyArray`] when `n == 0`.
/// - [`NdarrayFftError::PlanCreationFailed`] when `RealPlan::r2c_1d` yields `None`.
fn make_real_plan<T: Float>(n: usize) -> Result<RealPlan<T>, NdarrayFftError> {
    if n > 0 {
        RealPlan::<T>::r2c_1d(n, Flags::ESTIMATE).ok_or(NdarrayFftError::PlanCreationFailed)
    } else {
        Err(NdarrayFftError::EmptyArray)
    }
}
/// FFT extension methods for complex-valued `ndarray` arrays.
///
/// This module implements it for 1-D (`Ix1`) and 2-D (`Ix2`) arrays. The
/// out-of-place methods return an `IxDyn` array with the same shape as the input.
///
/// # Errors
/// For the implementations in this module:
/// - [`NdarrayFftError::EmptyArray`] if any axis has length zero.
/// - [`NdarrayFftError::PlanCreationFailed`] if a plan cannot be built.
/// - [`NdarrayFftError::ShapeError`] if the output array cannot be assembled
///   (out-of-place methods only).
pub trait FftExt<T: Float> {
    /// Forward transform, returned as a new array.
    fn fft(&self) -> Result<Array<Complex<T>, IxDyn>, NdarrayFftError>;
    /// Backward (inverse) transform, returned as a new array.
    ///
    /// NOTE(review): the implementations in this module apply no `1/n` scaling
    /// themselves; whether a `Direction::Backward` plan normalizes is up to
    /// `Plan` — confirm before relying on `ifft(fft(x)) == x`.
    fn ifft(&self) -> Result<Array<Complex<T>, IxDyn>, NdarrayFftError>;
    /// Forward transform, overwriting `self` with the result.
    fn fft_inplace(&mut self) -> Result<(), NdarrayFftError>;
    /// Backward (inverse) transform, overwriting `self` with the result.
    fn ifft_inplace(&mut self) -> Result<(), NdarrayFftError>;
}
impl<T: Float, S: DataMut<Elem = Complex<T>>> FftExt<T> for ArrayBase<S, Ix1> {
fn fft(&self) -> Result<Array<Complex<T>, IxDyn>, NdarrayFftError> {
let n = self.len();
let plan = make_plan::<T>(n, Direction::Forward)?;
let input: Vec<Complex<T>> = self.iter().copied().collect();
let mut output = vec![Complex::<T>::zero(); n];
plan.execute(&input, &mut output);
Array::from_shape_vec(IxDyn(&[n]), output).map_err(|_| NdarrayFftError::ShapeError)
}
fn ifft(&self) -> Result<Array<Complex<T>, IxDyn>, NdarrayFftError> {
let n = self.len();
let plan = make_plan::<T>(n, Direction::Backward)?;
let input: Vec<Complex<T>> = self.iter().copied().collect();
let mut output = vec![Complex::<T>::zero(); n];
plan.execute(&input, &mut output);
Array::from_shape_vec(IxDyn(&[n]), output).map_err(|_| NdarrayFftError::ShapeError)
}
fn fft_inplace(&mut self) -> Result<(), NdarrayFftError> {
let n = self.len();
let plan = make_plan::<T>(n, Direction::Forward)?;
let input: Vec<Complex<T>> = self.iter().copied().collect();
let mut output = vec![Complex::<T>::zero(); n];
plan.execute(&input, &mut output);
for (dst, src) in self.iter_mut().zip(output.iter()) {
*dst = *src;
}
Ok(())
}
fn ifft_inplace(&mut self) -> Result<(), NdarrayFftError> {
let n = self.len();
let plan = make_plan::<T>(n, Direction::Backward)?;
let input: Vec<Complex<T>> = self.iter().copied().collect();
let mut output = vec![Complex::<T>::zero(); n];
plan.execute(&input, &mut output);
for (dst, src) in self.iter_mut().zip(output.iter()) {
*dst = *src;
}
Ok(())
}
}
impl<T: Float, S: DataMut<Elem = Complex<T>>> FftExt<T> for ArrayBase<S, Ix2> {
fn fft(&self) -> Result<Array<Complex<T>, IxDyn>, NdarrayFftError> {
let (rows, cols) = self.dim();
let row_plan = make_plan::<T>(cols, Direction::Forward)?;
let col_plan = make_plan::<T>(rows, Direction::Forward)?;
let mut buf: Vec<Complex<T>> = self.iter().copied().collect();
let mut row_out = vec![Complex::<T>::zero(); cols];
for row_idx in 0..rows {
let start = row_idx * cols;
row_plan.execute(&buf[start..start + cols], &mut row_out);
buf[start..start + cols].copy_from_slice(&row_out);
}
let mut col_in = vec![Complex::<T>::zero(); rows];
let mut col_out = vec![Complex::<T>::zero(); rows];
for col_idx in 0..cols {
for (r, val) in col_in.iter_mut().enumerate() {
*val = buf[r * cols + col_idx];
}
col_plan.execute(&col_in, &mut col_out);
for (r, val) in col_out.iter().enumerate() {
buf[r * cols + col_idx] = *val;
}
}
Array::from_shape_vec(IxDyn(&[rows, cols]), buf).map_err(|_| NdarrayFftError::ShapeError)
}
fn ifft(&self) -> Result<Array<Complex<T>, IxDyn>, NdarrayFftError> {
let (rows, cols) = self.dim();
let row_plan = make_plan::<T>(cols, Direction::Backward)?;
let col_plan = make_plan::<T>(rows, Direction::Backward)?;
let mut buf: Vec<Complex<T>> = self.iter().copied().collect();
let mut row_out = vec![Complex::<T>::zero(); cols];
for row_idx in 0..rows {
let start = row_idx * cols;
row_plan.execute(&buf[start..start + cols], &mut row_out);
buf[start..start + cols].copy_from_slice(&row_out);
}
let mut col_in = vec![Complex::<T>::zero(); rows];
let mut col_out = vec![Complex::<T>::zero(); rows];
for col_idx in 0..cols {
for (r, val) in col_in.iter_mut().enumerate() {
*val = buf[r * cols + col_idx];
}
col_plan.execute(&col_in, &mut col_out);
for (r, val) in col_out.iter().enumerate() {
buf[r * cols + col_idx] = *val;
}
}
Array::from_shape_vec(IxDyn(&[rows, cols]), buf).map_err(|_| NdarrayFftError::ShapeError)
}
fn fft_inplace(&mut self) -> Result<(), NdarrayFftError> {
let (rows, cols) = self.dim();
let row_plan = make_plan::<T>(cols, Direction::Forward)?;
let col_plan = make_plan::<T>(rows, Direction::Forward)?;
let mut row_in = vec![Complex::<T>::zero(); cols];
let mut row_out = vec![Complex::<T>::zero(); cols];
for row_idx in 0..rows {
for (c, val) in row_in.iter_mut().enumerate() {
*val = self[[row_idx, c]];
}
row_plan.execute(&row_in, &mut row_out);
for (c, val) in row_out.iter().enumerate() {
self[[row_idx, c]] = *val;
}
}
let mut col_in = vec![Complex::<T>::zero(); rows];
let mut col_out = vec![Complex::<T>::zero(); rows];
for col_idx in 0..cols {
for (r, val) in col_in.iter_mut().enumerate() {
*val = self[[r, col_idx]];
}
col_plan.execute(&col_in, &mut col_out);
for (r, val) in col_out.iter().enumerate() {
self[[r, col_idx]] = *val;
}
}
Ok(())
}
fn ifft_inplace(&mut self) -> Result<(), NdarrayFftError> {
let (rows, cols) = self.dim();
let row_plan = make_plan::<T>(cols, Direction::Backward)?;
let col_plan = make_plan::<T>(rows, Direction::Backward)?;
let mut row_in = vec![Complex::<T>::zero(); cols];
let mut row_out = vec![Complex::<T>::zero(); cols];
for row_idx in 0..rows {
for (c, val) in row_in.iter_mut().enumerate() {
*val = self[[row_idx, c]];
}
row_plan.execute(&row_in, &mut row_out);
for (c, val) in row_out.iter().enumerate() {
self[[row_idx, c]] = *val;
}
}
let mut col_in = vec![Complex::<T>::zero(); rows];
let mut col_out = vec![Complex::<T>::zero(); rows];
for col_idx in 0..cols {
for (r, val) in col_in.iter_mut().enumerate() {
*val = self[[r, col_idx]];
}
col_plan.execute(&col_in, &mut col_out);
for (r, val) in col_out.iter().enumerate() {
self[[r, col_idx]] = *val;
}
}
Ok(())
}
}
/// FFT extension for real-valued `ndarray` arrays.
pub trait RealFftExt<T: Float> {
    /// Real-to-complex forward transform, returned as a new 1-D `IxDyn` array.
    ///
    /// The 1-D implementation in this module returns `complex_size()` coefficients
    /// as reported by the real plan. Presumably this is `n / 2 + 1`, the
    /// non-redundant half spectrum — TODO confirm against `RealPlan`.
    ///
    /// # Errors
    /// For the implementation in this module:
    /// - [`NdarrayFftError::EmptyArray`] for an empty input.
    /// - [`NdarrayFftError::PlanCreationFailed`] if the plan cannot be built.
    /// - [`NdarrayFftError::ShapeError`] if the output array cannot be assembled.
    fn fft_real(&self) -> Result<Array<Complex<T>, IxDyn>, NdarrayFftError>;
}
impl<T: Float, S: Data<Elem = T>> RealFftExt<T> for ArrayBase<S, Ix1> {
    /// Copies the real samples into a contiguous buffer, runs the r2c plan, and
    /// wraps the `complex_size()` output coefficients in a 1-D `IxDyn` array.
    fn fft_real(&self) -> Result<Array<Complex<T>, IxDyn>, NdarrayFftError> {
        let plan = make_real_plan::<T>(self.len())?;
        let samples: Vec<T> = self.iter().copied().collect();
        let spectrum_len = plan.complex_size();
        let mut spectrum = vec![Complex::<T>::zero(); spectrum_len];
        plan.execute_r2c(&samples, &mut spectrum);
        Array::from_shape_vec(IxDyn(&[spectrum_len]), spectrum)
            .map_err(|_| NdarrayFftError::ShapeError)
    }
}