use ndarray::{Array2, ArrayView2};
use num_complex::Complex;
use crate::fft_backend::{C2rPlan2d, C2rPlanner2d, R2cPlan2d, R2cPlanner2d, r2c_output_size_2d};
use crate::{SpectrogramError, SpectrogramResult};
#[cfg(feature = "fftw")]
use crate::fft_backend::fftw_backend::FftwPlanner;
#[cfg(feature = "realfft")]
use crate::fft_backend::realfft_backend::RealFftPlanner;
/// Computes the 2-D real-to-complex FFT of `data` with a freshly constructed backend planner.
///
/// The returned spectrum has the shape reported by `r2c_output_size_2d(nrows, ncols)`
/// (for example, a `32 x 32` input yields a `32 x 17` spectrum).
///
/// # Errors
/// Returns an error if either dimension is zero, if `data` is not in standard
/// (contiguous, row-major) layout, if planning or executing the transform fails, or if
/// the output buffer cannot be reshaped.
///
/// NOTE(review): this mirrors [`Fft2dPlanner::fft2d`] but builds a new backend planner on
/// every call; callers doing repeated transforms can keep one [`Fft2dPlanner`] instead.
#[inline]
pub fn fft2d(data: &ArrayView2<f64>) -> SpectrogramResult<Array2<Complex<f64>>> {
    let (rows, cols) = data.dim();
    if rows == 0 || cols == 0 {
        return Err(SpectrogramError::invalid_input("array dimensions must be > 0"));
    }
    if !data.is_standard_layout() {
        return Err(SpectrogramError::invalid_input(
            "array must be contiguous and row-major (standard layout)",
        ));
    }
    let out_shape = r2c_output_size_2d(rows, cols);

    // Exactly one backend feature is expected to be enabled.
    #[cfg(feature = "fftw")]
    let mut planner = FftwPlanner::new();
    #[cfg(feature = "realfft")]
    let mut planner = RealFftPlanner::new();
    let mut plan = planner.plan_r2c_2d(rows, cols)?;

    let Some(input) = data.as_slice() else {
        return Err(SpectrogramError::invalid_input("array must be contiguous"));
    };
    let mut buffer = vec![Complex::new(0.0, 0.0); out_shape.0 * out_shape.1];
    plan.process(input, &mut buffer)?;
    Array2::from_shape_vec(out_shape, buffer)
        .map_err(|e| SpectrogramError::invalid_input(format!("failed to reshape output: {e}")))
}
/// Inverts [`fft2d`]: reconstructs a real `(nrows, output_ncols)` array from its
/// half-width complex spectrum, using a freshly constructed backend planner.
///
/// `output_ncols` must be given explicitly because an even width `2k` and an odd width
/// `2k + 1` both produce a spectrum `k + 1` columns wide. The spectrum must therefore have
/// exactly `output_ncols / 2 + 1` columns.
///
/// NOTE(review): any output scaling is left to the backend's c2r plan; this module's
/// round-trip test expects `ifft2d(&fft2d(x)?, ncols)` to reproduce `x`.
///
/// # Errors
/// Returns an error if `spectrum` has no rows or `output_ncols` is zero, if the spectrum
/// width is not `output_ncols / 2 + 1`, if `spectrum` is not in standard layout, if
/// planning or executing the transform fails, or if the output cannot be reshaped.
#[inline]
pub fn ifft2d(
    spectrum: &Array2<Complex<f64>>,
    output_ncols: usize,
) -> SpectrogramResult<Array2<f64>> {
    let rows = spectrum.nrows();
    let width = spectrum.ncols();
    if rows == 0 || output_ncols == 0 {
        return Err(SpectrogramError::invalid_input("dimensions must be > 0"));
    }
    let expected_width = output_ncols / 2 + 1;
    if width != expected_width {
        return Err(SpectrogramError::dimension_mismatch(expected_width, width));
    }
    if !spectrum.is_standard_layout() {
        return Err(SpectrogramError::invalid_input(
            "array must be contiguous and row-major (standard layout)",
        ));
    }

    // Exactly one backend feature is expected to be enabled.
    #[cfg(feature = "fftw")]
    let mut planner = FftwPlanner::new();
    #[cfg(feature = "realfft")]
    let mut planner = RealFftPlanner::new();
    let mut plan = planner.plan_c2r_2d(rows, output_ncols)?;

    let Some(input) = spectrum.as_slice() else {
        return Err(SpectrogramError::invalid_input("array must be contiguous"));
    };
    let mut signal = vec![0.0; rows * output_ncols];
    plan.process(input, &mut signal)?;
    Array2::from_shape_vec((rows, output_ncols), signal)
        .map_err(|e| SpectrogramError::invalid_input(format!("failed to reshape output: {e}")))
}
/// Computes the 2-D power spectrum `|X|^2` (squared magnitude of every bin) of a real array.
///
/// The bins come from [`fft2d`], so the output has the same half-spectrum shape.
///
/// # Errors
/// Propagates any error returned by [`fft2d`].
#[inline]
pub fn power_spectrum_2d(data: &ArrayView2<f64>) -> SpectrogramResult<Array2<f64>> {
    fft2d(data).map(|spectrum| spectrum.mapv(|bin| bin.norm_sqr()))
}
#[inline]
pub fn magnitude_spectrum_2d(data: &ArrayView2<f64>) -> SpectrogramResult<Array2<f64>> {
let spectrum = fft2d(data)?;
let magnitude = spectrum.mapv(num_complex::Complex::norm);
Ok(magnitude)
}
/// Moves the zero-frequency component of a 2-D spectrum to the centre of the array.
///
/// This matches NumPy's `np.fft.fftshift` applied to both axes. Each axis of length `n` is
/// rolled forward by `n / 2`, so the element at index 0 ends up at index `n / 2`. For
/// example, one axis `[0, 1, 2, 3, 4]` becomes `[3, 4, 0, 1, 2]`.
///
/// Empty arrays are returned unchanged. [`ifftshift`] undoes this shift for every length.
#[inline]
#[must_use]
pub fn fftshift<T: Clone>(arr: Array2<T>) -> Array2<T> {
    let (nrows, ncols) = arr.dim();
    if nrows == 0 || ncols == 0 {
        return arr;
    }
    // `shift_2d` produces `out[i] = in[(i + shift) % n]`, which rolls towards lower indices.
    // Rolling forward by `floor(n / 2)` equals rolling backward by `n - floor(n / 2)`,
    // which is `ceil(n / 2)`. Using `n / 2` here would be ifftshift for odd `n`.
    let row_shift = nrows.div_ceil(2);
    let col_shift = ncols.div_ceil(2);
    shift_2d(arr, row_shift, col_shift)
}
/// Inverse of [`fftshift`]: moves a centred zero-frequency component back to index 0.
///
/// This matches NumPy's `np.fft.ifftshift` applied to both axes. Each axis of length `n` is
/// rolled backward by `n / 2`, so the element at index `n / 2` ends up at index 0. For
/// example, one axis `[0, 1, 2, 3, 4]` becomes `[2, 3, 4, 0, 1]`.
///
/// Empty arrays are returned unchanged.
#[inline]
#[must_use]
pub fn ifftshift<T: Clone>(arr: Array2<T>) -> Array2<T> {
    let (nrows, ncols) = arr.dim();
    if nrows == 0 || ncols == 0 {
        return arr;
    }
    // `shift_2d` produces `out[i] = in[(i + shift) % n]`, so a shift of `floor(n / 2)` pulls
    // the centre element to index 0. Using `ceil(n / 2)` would be fftshift for odd `n`.
    let row_shift = nrows / 2;
    let col_shift = ncols / 2;
    shift_2d(arr, row_shift, col_shift)
}
/// Circularly shifts a 2-D array towards lower indices on both axes.
///
/// The result satisfies `out[[i, j]] == arr[[(i + row_shift) % nrows, (j + col_shift) % ncols]]`.
/// Shifts are reduced modulo the axis length. An empty array, or shifts that reduce to zero
/// on both axes, is returned as-is without building a new array.
fn shift_2d<T: Clone>(arr: Array2<T>, row_shift: usize, col_shift: usize) -> Array2<T> {
    let (rows, cols) = arr.dim();
    if rows == 0 || cols == 0 {
        return arr;
    }
    let (dr, dc) = (row_shift % rows, col_shift % cols);
    if (dr, dc) == (0, 0) {
        return arr;
    }
    Array2::from_shape_fn((rows, cols), |(r, c)| {
        let source = [(r + dr) % rows, (c + dc) % cols];
        arr[source].clone()
    })
}
/// Moves the zero-frequency entry of a 1-D spectrum to the centre of the vector.
///
/// This matches NumPy's `np.fft.fftshift`. The vector is rolled forward by `len / 2`, so
/// index 0 ends up at index `len / 2`. For example, `[0, 1, 2, 3, 4]` becomes
/// `[3, 4, 0, 1, 2]`, and `fftshift_1d(fftfreq(n, d))` is sorted ascending for every `n`.
/// An empty vector is returned unchanged.
#[inline]
#[must_use]
pub fn fftshift_1d<T: Clone>(arr: Vec<T>) -> Vec<T> {
    let len = arr.len();
    // Rolling forward by `floor(len / 2)` equals rotating left by `ceil(len / 2)`. Rotating
    // left by `len / 2` would be ifftshift for odd lengths.
    rotate_left(arr, len.div_ceil(2))
}
/// Inverse of [`fftshift_1d`]: moves a centred zero-frequency entry back to index 0.
///
/// This matches NumPy's `np.fft.ifftshift`. The vector is rolled backward by `len / 2`, so
/// index `len / 2` ends up at index 0. For example, `[0, 1, 2, 3, 4]` becomes
/// `[2, 3, 4, 0, 1]`. An empty vector is returned unchanged.
#[inline]
#[must_use]
pub fn ifftshift_1d<T: Clone>(arr: Vec<T>) -> Vec<T> {
    let len = arr.len();
    // Rotating left by `floor(len / 2)` pulls the centre element to index 0. Using
    // `ceil(len / 2)` would be fftshift for odd lengths.
    rotate_left(arr, len / 2)
}
/// Rotates `arr` left by `shift` positions, reduced modulo its length, so that
/// `result[i] == arr[(i + shift) % len]`.
///
/// The vector is owned, so the rotation happens in place. No new allocation and no element
/// clones are needed. An empty vector is returned unchanged.
fn rotate_left<T: Clone>(mut arr: Vec<T>, shift: usize) -> Vec<T> {
    let len = arr.len();
    if len != 0 {
        // `slice::rotate_left` panics when `mid > len`, so reduce the shift first.
        arr.rotate_left(shift % len);
    }
    arr
}
/// Returns the sample frequencies of a length-`n` complex FFT, where `d` is the sample spacing.
///
/// The layout is the same as NumPy's `np.fft.fftfreq`. The first `ceil(n / 2)` entries are
/// `0, 1, 2, ...` and the rest are the negative bins `..., -2, -1`. Every entry is divided
/// by `n * d`. An `n` of zero yields an empty vector.
#[inline]
#[must_use]
pub fn fftfreq(n: usize, d: f64) -> Vec<f64> {
    let len = n as f64;
    let denom = len * d;
    let non_negative = n.div_ceil(2);
    (0..n)
        .map(|k| {
            let bin = if k < non_negative { k as f64 } else { k as f64 - len };
            bin / denom
        })
        .collect()
}
/// Returns the sample frequencies of a length-`n` real-input FFT, where `d` is the sample spacing.
///
/// The layout is the same as NumPy's `np.fft.rfftfreq`. The result has `n / 2 + 1` entries
/// `0, 1, ..., n / 2`, each divided by `n * d`. A zero-length signal has no frequency bins,
/// so `n == 0` yields an empty vector, consistent with `fftfreq(0, d)`. Without this guard
/// it would return a single `0.0 / 0.0 = NaN` entry.
#[inline]
#[must_use]
pub fn rfftfreq(n: usize, d: f64) -> Vec<f64> {
    if n == 0 {
        return Vec::new();
    }
    let denom = n as f64 * d;
    (0..=n / 2).map(|k| k as f64 / denom).collect()
}
/// Reusable 2-D FFT helper that owns a single backend planner across calls.
///
/// Exactly one backend feature is expected to be enabled. With both `fftw` and `realfft`
/// enabled, the struct would declare two `inner` fields and fail to compile.
/// NOTE(review): with neither feature enabled the struct has no fields, and the methods
/// that use `self.inner` cannot compile. Presumably a backend feature is always
/// selected; confirm in Cargo.toml.
pub struct Fft2dPlanner {
    /// FFTW-backed planner, present when the `fftw` feature is enabled.
    #[cfg(feature = "fftw")]
    inner: FftwPlanner,
    /// `realfft`-backed planner, present when the `realfft` feature is enabled.
    #[cfg(feature = "realfft")]
    inner: RealFftPlanner,
}
impl Fft2dPlanner {
    /// Creates a planner for the FFT backend selected by the enabled cargo feature
    /// (`fftw` or `realfft`).
    #[inline]
    #[must_use]
    pub fn new() -> Self {
        Self {
            #[cfg(feature = "fftw")]
            inner: FftwPlanner::new(),
            #[cfg(feature = "realfft")]
            inner: RealFftPlanner::new(),
        }
    }

    /// Returns the shared "not in standard layout" error unless `is_standard` is true.
    fn require_standard_layout(is_standard: bool) -> SpectrogramResult<()> {
        if is_standard {
            return Ok(());
        }
        Err(SpectrogramError::invalid_input(
            "array must be contiguous and row-major (standard layout)",
        ))
    }

    /// Wraps a flat output buffer as an array of `shape`, mapping shape errors to
    /// `invalid_input`.
    fn reshape_output<T>(shape: (usize, usize), buffer: Vec<T>) -> SpectrogramResult<Array2<T>> {
        Array2::from_shape_vec(shape, buffer)
            .map_err(|e| SpectrogramError::invalid_input(format!("failed to reshape output: {e}")))
    }

    /// Computes the 2-D real-to-complex FFT of `data` using this planner's backend.
    ///
    /// The spectrum has the shape reported by `r2c_output_size_2d(nrows, ncols)`.
    ///
    /// # Errors
    /// Returns an error if either dimension is zero, if `data` is not in standard
    /// (contiguous, row-major) layout, if the backend fails to plan or execute the
    /// transform, or if the output cannot be reshaped.
    #[inline]
    pub fn fft2d(&mut self, data: &ArrayView2<f64>) -> SpectrogramResult<Array2<Complex<f64>>> {
        let (rows, cols) = data.dim();
        if rows == 0 || cols == 0 {
            return Err(SpectrogramError::invalid_input("array dimensions must be > 0"));
        }
        Self::require_standard_layout(data.is_standard_layout())?;
        let out_shape = r2c_output_size_2d(rows, cols);
        let mut plan = self.inner.plan_r2c_2d(rows, cols)?;
        let Some(input) = data.as_slice() else {
            return Err(SpectrogramError::invalid_input("array must be contiguous"));
        };
        let mut spectrum = vec![Complex::new(0.0, 0.0); out_shape.0 * out_shape.1];
        plan.process(input, &mut spectrum)?;
        Self::reshape_output(out_shape, spectrum)
    }

    /// Reconstructs a real `(nrows, output_ncols)` array from its half-width spectrum.
    ///
    /// `output_ncols` disambiguates even and odd widths, which share the same spectrum
    /// width `output_ncols / 2 + 1`.
    ///
    /// # Errors
    /// Returns an error if `spectrum` has no rows or `output_ncols` is zero, if the
    /// spectrum width is not `output_ncols / 2 + 1`, if `spectrum` is not in standard
    /// layout, if the backend fails, or if the output cannot be reshaped.
    #[inline]
    pub fn ifft2d(
        &mut self,
        spectrum: &ArrayView2<Complex<f64>>,
        output_ncols: usize,
    ) -> SpectrogramResult<Array2<f64>> {
        let rows = spectrum.nrows();
        if rows == 0 || output_ncols == 0 {
            return Err(SpectrogramError::invalid_input("dimensions must be > 0"));
        }
        let half_width = output_ncols / 2 + 1;
        let actual_width = spectrum.ncols();
        if actual_width != half_width {
            return Err(SpectrogramError::dimension_mismatch(half_width, actual_width));
        }
        Self::require_standard_layout(spectrum.is_standard_layout())?;
        let mut plan = self.inner.plan_c2r_2d(rows, output_ncols)?;
        let Some(input) = spectrum.as_slice() else {
            return Err(SpectrogramError::invalid_input("array must be contiguous"));
        };
        let mut signal = vec![0.0; rows * output_ncols];
        plan.process(input, &mut signal)?;
        Self::reshape_output((rows, output_ncols), signal)
    }

    /// Computes the 2-D power spectrum `|X|^2` of `data` with this planner.
    ///
    /// # Errors
    /// Propagates any error from [`Fft2dPlanner::fft2d`].
    #[inline]
    pub fn power_spectrum_2d(&mut self, data: &ArrayView2<f64>) -> SpectrogramResult<Array2<f64>> {
        self.fft2d(data).map(|spectrum| spectrum.mapv(|bin| bin.norm_sqr()))
    }

    /// Computes the 2-D magnitude spectrum `|X|` of `data` with this planner.
    ///
    /// # Errors
    /// Propagates any error from [`Fft2dPlanner::fft2d`].
    #[inline]
    pub fn magnitude_spectrum_2d(
        &mut self,
        data: &ArrayView2<f64>,
    ) -> SpectrogramResult<Array2<f64>> {
        self.fft2d(data).map(|spectrum| spectrum.mapv(Complex::norm))
    }
}
impl Default for Fft2dPlanner {
#[inline]
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_fft2d_zeros() {
        let data = Array2::<f64>::zeros((32, 32));
        let result = fft2d(&data.view());
        assert!(result.is_ok());
        let spectrum = result.unwrap();
        assert_eq!(spectrum.shape(), &[32, 17]);
        for val in spectrum.iter() {
            assert!(val.norm() < 1e-10);
        }
    }

    #[test]
    fn test_fft2d_ones() {
        // A constant input puts all of its energy in the DC bin, whose value is the sum of
        // the elements (32 * 32 = 1024). Every other bin of the half spectrum must vanish.
        let data = Array2::<f64>::ones((32, 32));
        let spectrum = fft2d(&data.view()).unwrap();
        let dc = spectrum[[0, 0]];
        assert!((dc.re - 1024.0).abs() < 1e-9, "DC real part was {}", dc.re);
        assert!(dc.im.abs() < 1e-9, "DC imaginary part was {}", dc.im);
        for ((i, j), val) in spectrum.indexed_iter() {
            if (i, j) != (0, 0) {
                assert!(val.norm() < 1e-9, "bin ({i}, {j}) should be ~0, got {val:?}");
            }
        }
    }

    #[test]
    fn test_roundtrip() {
        let original = Array2::<f64>::from_shape_fn((64, 64), |(i, j)| {
            (i as f64 * 0.1).sin() + (j as f64 * 0.2).cos()
        });
        let spectrum = fft2d(&original.view()).unwrap();
        let reconstructed = ifft2d(&spectrum, 64).unwrap();
        for ((i, j), &val) in original.indexed_iter() {
            assert!((reconstructed[[i, j]] - val).abs() < 1e-10);
        }
    }

    #[test]
    fn test_planner_reuse() {
        let mut planner = Fft2dPlanner::new();
        for _ in 0..5 {
            let data = Array2::<f64>::zeros((32, 32));
            let spectrum = planner.fft2d(&data.view()).unwrap();
            assert_eq!(spectrum.shape(), &[32, 17]);
        }
    }

    #[test]
    fn test_planner_roundtrip_non_square_odd_width() {
        // A non-square input catches row/column mix-ups. An odd width exercises the
        // `output_ncols / 2 + 1` bookkeeping that tells widths 2k and 2k + 1 apart.
        let mut planner = Fft2dPlanner::new();
        let original = Array2::<f64>::from_shape_fn((12, 9), |(i, j)| {
            (i as f64 * 0.3).cos() - (j as f64 * 0.7).sin()
        });
        let spectrum = planner.fft2d(&original.view()).unwrap();
        assert_eq!(spectrum.shape(), &[12, 5]);
        let reconstructed = planner.ifft2d(&spectrum.view(), 9).unwrap();
        assert_eq!(reconstructed.shape(), &[12, 9]);
        for ((i, j), &val) in original.indexed_iter() {
            assert!((reconstructed[[i, j]] - val).abs() < 1e-10);
        }
    }

    #[test]
    fn test_fft2d_rejects_invalid_input() {
        let empty = Array2::<f64>::zeros((0, 8));
        assert!(fft2d(&empty.view()).is_err());
        // A transposed view is not in standard (row-major, contiguous) layout.
        let tall = Array2::<f64>::zeros((4, 6));
        assert!(fft2d(&tall.t()).is_err());
    }

    #[test]
    fn test_ifft2d_rejects_mismatched_width() {
        // A 4 x 6 input yields a 4 x 4 spectrum, which is valid for widths 6 or 7 only.
        let spectrum = fft2d(&Array2::<f64>::zeros((4, 6)).view()).unwrap();
        assert_eq!(spectrum.shape(), &[4, 4]);
        assert!(ifft2d(&spectrum, 4).is_err());
        assert!(ifft2d(&spectrum, 0).is_err());
        assert!(ifft2d(&spectrum, 6).is_ok());
    }
}