use ndarray::{Array2, ArrayView2, ArrayViewMut2};
use rustfft::{num_complex::Complex, Fft};
use std::sync::Arc;
use crate::float_trait::Bm3dFloat;
/// Transposes, in place, the `n x n` row-major matrix stored in the first
/// `n * n` elements of `data`.
///
/// Every element of the strictly upper triangle is swapped with its mirror
/// below the diagonal, so each off-diagonal pair is exchanged exactly once and
/// the diagonal is left untouched. Elements past `n * n` are not touched.
#[inline(always)]
fn transpose_square_in_place<T>(data: &mut [T], n: usize) {
    let upper_triangle = (0..n).flat_map(|row| ((row + 1)..n).map(move |col| (row, col)));
    for (row, col) in upper_triangle {
        data.swap(row * n + col, col * n + row);
    }
}
/// Writes the transpose of the `n x n` row-major matrix in `src` into `dst`.
///
/// Only the first `n * n` elements of each slice are used: element
/// `(row, col)` of `src` lands at position `(col, row)` of `dst`.
#[inline(always)]
fn transpose_square_copy<T: Copy>(src: &[T], dst: &mut [T], n: usize) {
    for row in 0..n {
        let start = row * n;
        let src_row = &src[start..start + n];
        for (col, &value) in src_row.iter().enumerate() {
            dst[col * n + row] = value;
        }
    }
}
/// Computes the forward 2-D FFT of a real matrix and returns a freshly
/// allocated complex spectrum of the same shape.
///
/// `fft_row_plan` must be a forward plan of length `cols` and `fft_col_plan`
/// a forward plan of length `rows`. No normalization is applied. All work
/// buffers are allocated here; use [`fft2d_into`] or
/// [`fft2d_into_with_plan_scratch`] to reuse buffers across calls.
pub fn fft2d<F: Bm3dFloat>(
    input: ArrayView2<F>,
    fft_row_plan: &Arc<dyn Fft<F>>,
    fft_col_plan: &Arc<dyn Fft<F>>,
) -> Array2<Complex<F>> {
    let shape = input.dim();
    let line_len = shape.0.max(shape.1);
    let mut work = Array2::<Complex<F>>::zeros(shape);
    let mut spectrum = Array2::<Complex<F>>::zeros(shape);
    // One line buffer long enough for either a row or a column.
    let mut line_buf = vec![Complex::new(F::zero(), F::zero()); line_len];
    fft2d_into(
        input,
        fft_row_plan,
        fft_col_plan,
        &mut work,
        spectrum.view_mut(),
        &mut line_buf,
    );
    spectrum
}
/// Forward 2-D FFT into caller-provided buffers.
///
/// Plan-specific in-place scratch is allocated here on every call, sized to
/// `get_inplace_scratch_len()` of each plan (empty when a plan needs none).
/// The computation itself is delegated to [`fft2d_into_with_plan_scratch`];
/// see that function for the buffer shape requirements.
pub fn fft2d_into<F: Bm3dFloat>(
    input: ArrayView2<F>,
    fft_row_plan: &Arc<dyn Fft<F>>,
    fft_col_plan: &Arc<dyn Fft<F>>,
    work_complex: &mut Array2<Complex<F>>,
    output: ArrayViewMut2<Complex<F>>,
    scratch: &mut [Complex<F>],
) {
    let zero = Complex::new(F::zero(), F::zero());
    let mut row_plan_scratch = vec![zero; fft_row_plan.get_inplace_scratch_len()];
    let mut col_plan_scratch = vec![zero; fft_col_plan.get_inplace_scratch_len()];
    fft2d_into_with_plan_scratch(
        input,
        fft_row_plan,
        fft_col_plan,
        work_complex,
        output,
        scratch,
        &mut row_plan_scratch,
        &mut col_plan_scratch,
    );
}
/// Forward 2-D FFT of a real `rows x cols` matrix, written into `output`
/// without allocating.
///
/// Rows are transformed with `fft_row_plan` (length `cols`), then columns
/// with `fft_col_plan` (length `rows`). No normalization is applied.
///
/// # Arguments
/// * `input` - real input of shape `(rows, cols)`; any memory layout.
/// * `fft_row_plan` / `fft_col_plan` - forward plans of length `cols` / `rows`.
/// * `work_complex` - work matrix of shape `(rows, cols)`; overwritten.
/// * `output` - destination spectrum of shape `(rows, cols)`; any memory layout.
/// * `scratch` - line buffer with at least `max(rows, cols)` elements.
/// * `row_fft_scratch` / `col_fft_scratch` - in-place scratch for the two plans,
///   each at least `get_inplace_scratch_len()` long (may be empty when that is 0).
///
/// Shape and buffer-size requirements are only checked with `debug_assert!`.
/// rustfft itself panics if a plan scratch buffer is too small.
pub fn fft2d_into_with_plan_scratch<F: Bm3dFloat>(
    input: ArrayView2<F>,
    fft_row_plan: &Arc<dyn Fft<F>>,
    fft_col_plan: &Arc<dyn Fft<F>>,
    work_complex: &mut Array2<Complex<F>>,
    mut output: ArrayViewMut2<Complex<F>>,
    scratch: &mut [Complex<F>],
    row_fft_scratch: &mut [Complex<F>],
    col_fft_scratch: &mut [Complex<F>],
) {
    let (rows, cols) = input.dim();
    debug_assert_eq!(work_complex.dim(), (rows, cols));
    debug_assert_eq!(output.dim(), (rows, cols));
    debug_assert!(scratch.len() >= rows.max(cols));
    debug_assert!(row_fft_scratch.len() >= fft_row_plan.get_inplace_scratch_len());
    debug_assert!(col_fft_scratch.len() >= fft_col_plan.get_inplace_scratch_len());
    // An empty matrix has an empty spectrum. Returning early also keeps the
    // batched fast path below from handing rustfft a zero-length buffer with a
    // non-zero plan length, which it rejects with a panic.
    if rows == 0 || cols == 0 {
        return;
    }
    // rustfft accepts any scratch slice at least `get_inplace_scratch_len()`
    // long (an empty slice when that is 0), so the caller's buffers are passed
    // straight through below.
    //
    // The fast path treats every buffer as a flat row-major matrix
    // (`r * cols + c`). `as_slice`/`as_slice_mut` only succeed for standard
    // (C-order) layout. `as_slice_memory_order` would also accept Fortran-order
    // or transposed arrays and silently compute the transform of the transpose.
    if let (Some(input_data), Some(work_data), Some(output_data)) = (
        input.as_slice(),
        work_complex.as_slice_mut(),
        output.as_slice_mut(),
    ) {
        for (dst, &src) in work_data.iter_mut().zip(input_data) {
            *dst = Complex::new(src, F::zero());
        }
        // All rows are contiguous, so a single call transforms every row.
        fft_row_plan.process_with_scratch(work_data, row_fft_scratch);
        if rows == cols {
            // Transpose so columns become contiguous rows, transform them in
            // one call, then transpose back into natural order.
            transpose_square_copy(work_data, output_data, rows);
            fft_col_plan.process_with_scratch(output_data, col_fft_scratch);
            transpose_square_in_place(output_data, rows);
        } else {
            let line = &mut scratch[..rows];
            for c in 0..cols {
                for (r, slot) in line.iter_mut().enumerate() {
                    *slot = work_data[r * cols + c];
                }
                fft_col_plan.process_with_scratch(line, col_fft_scratch);
                for (r, &value) in line.iter().enumerate() {
                    output_data[r * cols + c] = value;
                }
            }
        }
    } else {
        // General layout: gather each line into `scratch`, transform it, and
        // scatter it back using logical (layout-independent) indexing.
        let row_line = &mut scratch[..cols];
        for r in 0..rows {
            for (slot, &value) in row_line.iter_mut().zip(input.row(r).iter()) {
                *slot = Complex::new(value, F::zero());
            }
            fft_row_plan.process_with_scratch(row_line, row_fft_scratch);
            for (c, &value) in row_line.iter().enumerate() {
                work_complex[[r, c]] = value;
            }
        }
        let col_line = &mut scratch[..rows];
        for c in 0..cols {
            for (r, slot) in col_line.iter_mut().enumerate() {
                *slot = work_complex[[r, c]];
            }
            fft_col_plan.process_with_scratch(col_line, col_fft_scratch);
            for (r, &value) in col_line.iter().enumerate() {
                output[[r, c]] = value;
            }
        }
    }
}
/// Inverse 2-D FFT of an owned complex spectrum, returning the real part
/// scaled by `1 / (rows * cols)`.
///
/// Thin convenience wrapper that borrows `input` as a view and forwards to
/// [`ifft2d_view`].
pub fn ifft2d<F: Bm3dFloat>(
    input: &Array2<Complex<F>>,
    ifft_row_plan: &Arc<dyn Fft<F>>,
    ifft_col_plan: &Arc<dyn Fft<F>>,
) -> Array2<F> {
    let spectrum = input.view();
    ifft2d_view(spectrum, ifft_row_plan, ifft_col_plan)
}
/// Inverse 2-D FFT of a complex spectrum view, returning a newly allocated
/// real matrix.
///
/// Columns are inverse-transformed first with `ifft_col_plan` (length `rows`),
/// then rows with `ifft_row_plan` (length `cols`). The real part of the result
/// is scaled by `1 / (rows * cols)`, and the imaginary part is dropped.
pub fn ifft2d_view<F: Bm3dFloat>(
    input: ArrayView2<Complex<F>>,
    ifft_row_plan: &Arc<dyn Fft<F>>,
    ifft_col_plan: &Arc<dyn Fft<F>>,
) -> Array2<F> {
    let (rows, cols) = input.dim();
    let zero = Complex::new(F::zero(), F::zero());
    // Plan scratch sized exactly as each plan requests (empty when none is needed).
    let mut col_plan_scratch = vec![zero; ifft_col_plan.get_inplace_scratch_len()];
    let mut row_plan_scratch = vec![zero; ifft_row_plan.get_inplace_scratch_len()];

    // Pass 1: inverse-transform every column of a private copy of the spectrum.
    let mut spectrum = input.to_owned();
    let mut column = vec![zero; rows];
    for c in 0..cols {
        for (r, slot) in column.iter_mut().enumerate() {
            *slot = spectrum[[r, c]];
        }
        ifft_col_plan.process_with_scratch(&mut column, &mut col_plan_scratch);
        for (r, &value) in column.iter().enumerate() {
            spectrum[[r, c]] = value;
        }
    }

    // Pass 2: inverse-transform every row and keep the scaled real part.
    let scale = F::one() / F::usize_as(rows * cols);
    let mut spatial = Array2::<F>::zeros((rows, cols));
    let mut row = vec![zero; cols];
    for r in 0..rows {
        for (c, slot) in row.iter_mut().enumerate() {
            *slot = spectrum[[r, c]];
        }
        ifft_row_plan.process_with_scratch(&mut row, &mut row_plan_scratch);
        for (c, value) in row.iter().enumerate() {
            spatial[[r, c]] = value.re * scale;
        }
    }
    spatial
}
/// Inverse 2-D FFT into caller-provided buffers, producing the real part
/// scaled by `1 / (rows * cols)`.
///
/// Plan-specific in-place scratch is allocated here on every call, sized to
/// each plan's `get_inplace_scratch_len()` (empty when a plan needs none).
/// The work is delegated to [`ifft2d_into_with_plan_scratch`].
pub fn ifft2d_into<F: Bm3dFloat>(
    input: ArrayView2<Complex<F>>,
    ifft_row_plan: &Arc<dyn Fft<F>>,
    ifft_col_plan: &Arc<dyn Fft<F>>,
    work_complex: &mut Array2<Complex<F>>,
    output: &mut Array2<F>,
    scratch: &mut [Complex<F>],
) {
    let zero = Complex::new(F::zero(), F::zero());
    let mut col_plan_scratch = vec![zero; ifft_col_plan.get_inplace_scratch_len()];
    let mut row_plan_scratch = vec![zero; ifft_row_plan.get_inplace_scratch_len()];
    ifft2d_into_with_plan_scratch(
        input,
        ifft_row_plan,
        ifft_col_plan,
        work_complex,
        output,
        scratch,
        &mut row_plan_scratch,
        &mut col_plan_scratch,
    );
}
/// Inverse 2-D FFT of a complex spectrum into a real matrix, without
/// allocating.
///
/// Columns are inverse-transformed with `ifft_col_plan` (length `rows`), then
/// rows with `ifft_row_plan` (length `cols`). `output` receives the real part
/// scaled by `1 / (rows * cols)`.
///
/// # Arguments
/// * `input` - complex spectrum of shape `(rows, cols)`; any memory layout.
/// * `work_complex` - work matrix of shape `(rows, cols)`; overwritten.
/// * `output` - real destination of shape `(rows, cols)`; any memory layout.
/// * `scratch` - line buffer with at least `max(rows, cols)` elements.
/// * `row_fft_scratch` / `col_fft_scratch` - in-place scratch for the two plans,
///   each at least `get_inplace_scratch_len()` long (may be empty when that is 0).
///
/// Shape and buffer-size requirements are only checked with `debug_assert!`.
/// rustfft itself panics if a plan scratch buffer is too small.
pub fn ifft2d_into_with_plan_scratch<F: Bm3dFloat>(
    input: ArrayView2<Complex<F>>,
    ifft_row_plan: &Arc<dyn Fft<F>>,
    ifft_col_plan: &Arc<dyn Fft<F>>,
    work_complex: &mut Array2<Complex<F>>,
    output: &mut Array2<F>,
    scratch: &mut [Complex<F>],
    row_fft_scratch: &mut [Complex<F>],
    col_fft_scratch: &mut [Complex<F>],
) {
    let (rows, cols) = input.dim();
    debug_assert_eq!(work_complex.dim(), (rows, cols));
    debug_assert_eq!(output.dim(), (rows, cols));
    debug_assert!(scratch.len() >= rows.max(cols));
    debug_assert!(row_fft_scratch.len() >= ifft_row_plan.get_inplace_scratch_len());
    debug_assert!(col_fft_scratch.len() >= ifft_col_plan.get_inplace_scratch_len());
    // Nothing to write for an empty matrix. This also keeps the batched row
    // pass from handing rustfft a zero-length buffer with a non-zero plan length.
    if rows == 0 || cols == 0 {
        return;
    }
    let norm_factor = F::one() / F::usize_as(rows * cols);
    // The fast path copies and indexes the flat buffers as row-major
    // (`r * cols + c`), which is only valid for standard (C-order) layout.
    // `as_slice_memory_order` would also accept Fortran-order or transposed
    // arrays, and `copy_from_slice` between arrays of different memory orders
    // would transpose the data silently.
    if let (Some(input_data), Some(work_data), Some(output_data)) = (
        input.as_slice(),
        work_complex.as_slice_mut(),
        output.as_slice_mut(),
    ) {
        work_data.copy_from_slice(input_data);
        if rows == cols {
            // Make columns contiguous, transform them in one call, restore order.
            transpose_square_in_place(work_data, rows);
            ifft_col_plan.process_with_scratch(work_data, col_fft_scratch);
            transpose_square_in_place(work_data, rows);
        } else {
            let line = &mut scratch[..rows];
            for c in 0..cols {
                for (r, slot) in line.iter_mut().enumerate() {
                    *slot = work_data[r * cols + c];
                }
                ifft_col_plan.process_with_scratch(line, col_fft_scratch);
                for (r, &value) in line.iter().enumerate() {
                    work_data[r * cols + c] = value;
                }
            }
        }
        // All rows are contiguous, so one call transforms every row.
        ifft_row_plan.process_with_scratch(work_data, row_fft_scratch);
        for (dst, src) in output_data.iter_mut().zip(work_data.iter()) {
            *dst = src.re * norm_factor;
        }
    } else {
        // General layout: logical indexing through `scratch` line buffers.
        work_complex.assign(&input);
        let col_line = &mut scratch[..rows];
        for c in 0..cols {
            for (r, slot) in col_line.iter_mut().enumerate() {
                *slot = work_complex[[r, c]];
            }
            ifft_col_plan.process_with_scratch(col_line, col_fft_scratch);
            for (r, &value) in col_line.iter().enumerate() {
                work_complex[[r, c]] = value;
            }
        }
        let row_line = &mut scratch[..cols];
        for r in 0..rows {
            for (c, slot) in row_line.iter_mut().enumerate() {
                *slot = work_complex[[r, c]];
            }
            ifft_row_plan.process_with_scratch(row_line, row_fft_scratch);
            for (c, value) in row_line.iter().enumerate() {
                output[[r, c]] = value.re * norm_factor;
            }
        }
    }
}
/// Inverse 2-D FFT that uses `input_output` itself as the complex work buffer
/// and writes the scaled real result to `output`.
///
/// On return, `input_output` holds intermediate complex data (fully
/// inverse-transformed on the contiguous fast path, column-transformed only on
/// the general path), so callers should treat it as clobbered. `output`
/// receives the real part scaled by `1 / (rows * cols)`.
///
/// # Arguments
/// * `input_output` - complex spectrum of shape `(rows, cols)`; overwritten.
/// * `ifft_row_plan` / `ifft_col_plan` - inverse plans of length `cols` / `rows`.
/// * `output` - real destination of shape `(rows, cols)`.
/// * `scratch` - line buffer with at least `max(rows, cols)` elements.
/// * `row_fft_scratch` / `col_fft_scratch` - in-place scratch for the two plans,
///   each at least `get_inplace_scratch_len()` long (may be empty when that is 0).
pub fn ifft2d_inplace_to_real_with_plan_scratch<F: Bm3dFloat>(
    mut input_output: ArrayViewMut2<Complex<F>>,
    ifft_row_plan: &Arc<dyn Fft<F>>,
    ifft_col_plan: &Arc<dyn Fft<F>>,
    output: &mut Array2<F>,
    scratch: &mut [Complex<F>],
    row_fft_scratch: &mut [Complex<F>],
    col_fft_scratch: &mut [Complex<F>],
) {
    let (rows, cols) = input_output.dim();
    debug_assert_eq!(output.dim(), (rows, cols));
    debug_assert!(scratch.len() >= rows.max(cols));
    debug_assert!(row_fft_scratch.len() >= ifft_row_plan.get_inplace_scratch_len());
    debug_assert!(col_fft_scratch.len() >= ifft_col_plan.get_inplace_scratch_len());
    // Nothing to write for an empty matrix. This also keeps the batched row
    // pass from handing rustfft a zero-length buffer with a non-zero plan length.
    if rows == 0 || cols == 0 {
        return;
    }
    let norm_factor = F::one() / F::usize_as(rows * cols);
    // Row-major flat indexing below is only valid for standard (C-order)
    // layout. `as_slice_mut` guarantees that, unlike `as_slice_memory_order_mut`,
    // which also accepts Fortran-order/transposed arrays.
    if let (Some(data), Some(output_data)) = (
        input_output.as_slice_mut(),
        output.as_slice_mut(),
    ) {
        if rows == cols {
            // Make columns contiguous, transform them in one call, restore order.
            transpose_square_in_place(data, rows);
            ifft_col_plan.process_with_scratch(data, col_fft_scratch);
            transpose_square_in_place(data, rows);
        } else {
            let line = &mut scratch[..rows];
            for c in 0..cols {
                for (r, slot) in line.iter_mut().enumerate() {
                    *slot = data[r * cols + c];
                }
                ifft_col_plan.process_with_scratch(line, col_fft_scratch);
                for (r, &value) in line.iter().enumerate() {
                    data[r * cols + c] = value;
                }
            }
        }
        // All rows are contiguous, so one call transforms every row.
        ifft_row_plan.process_with_scratch(data, row_fft_scratch);
        for (dst, src) in output_data.iter_mut().zip(data.iter()) {
            *dst = src.re * norm_factor;
        }
    } else {
        // General layout: logical indexing through `scratch` line buffers.
        let col_line = &mut scratch[..rows];
        for c in 0..cols {
            for (r, slot) in col_line.iter_mut().enumerate() {
                *slot = input_output[[r, c]];
            }
            ifft_col_plan.process_with_scratch(col_line, col_fft_scratch);
            for (r, &value) in col_line.iter().enumerate() {
                input_output[[r, c]] = value;
            }
        }
        let row_line = &mut scratch[..cols];
        for r in 0..rows {
            for (c, slot) in row_line.iter_mut().enumerate() {
                *slot = input_output[[r, c]];
            }
            ifft_row_plan.process_with_scratch(row_line, row_fft_scratch);
            for (c, value) in row_line.iter().enumerate() {
                output[[r, c]] = value.re * norm_factor;
            }
        }
    }
}
/// Unnormalized 8-point fast Walsh–Hadamard transform, in place.
///
/// Runs three butterfly stages with spans 1, 2 and 4. In each stage, index
/// `lo` (with the span bit clear) is paired with `lo + span`; the sum goes to
/// the lower slot and the difference to the upper slot. Applying it twice
/// multiplies the input by 8.
#[inline(always)]
fn fwht8<F: Bm3dFloat>(buf: &mut [F; 8]) {
    for span in [1usize, 2, 4] {
        for lo in 0..8 {
            if (lo & span) == 0 {
                let hi = lo + span;
                let (a, b) = (buf[lo], buf[hi]);
                buf[lo] = a + b;
                buf[hi] = a - b;
            }
        }
    }
}
pub fn wht2d_8x8_forward<F: Bm3dFloat>(input: ArrayView2<F>) -> Array2<Complex<F>> {
let mut output = Array2::<Complex<F>>::zeros((8, 8));
wht2d_8x8_forward_into_view(input, output.view_mut());
output
}
/// Forward (unnormalized) 8x8 2-D Walsh–Hadamard transform written into
/// `output` as complex values with zero imaginary part.
///
/// Reads `input[[r, c]]` for `r, c` in `0..8`. Applies [`fwht8`] to each row,
/// then to each column.
///
/// # Panics
/// Panics if `output` is not 8x8, or if `input` is smaller than 8x8.
pub fn wht2d_8x8_forward_into_view<F: Bm3dFloat>(
    input: ArrayView2<F>,
    mut output: ArrayViewMut2<Complex<F>>,
) {
    assert_eq!(output.dim(), (8, 8));
    // Stack-local copy of the block, one array per row.
    let mut block = [[F::zero(); 8]; 8];
    for (r, row) in block.iter_mut().enumerate() {
        for (c, cell) in row.iter_mut().enumerate() {
            *cell = input[[r, c]];
        }
    }
    // The transform is separable: rows first...
    for row in block.iter_mut() {
        fwht8(row);
    }
    // ...then columns.
    for c in 0..8 {
        let mut column = [F::zero(); 8];
        for (r, slot) in column.iter_mut().enumerate() {
            *slot = block[r][c];
        }
        fwht8(&mut column);
        for (r, &value) in column.iter().enumerate() {
            block[r][c] = value;
        }
    }
    for ((r, c), dst) in output.indexed_iter_mut() {
        *dst = Complex::new(block[r][c], F::zero());
    }
}
/// Inverse 8x8 2-D Walsh–Hadamard transform of an owned coefficient array.
///
/// Borrows `input` as a view and forwards to [`wht2d_8x8_inverse_view`];
/// only the real parts of the coefficients are used.
pub fn wht2d_8x8_inverse<F: Bm3dFloat>(input: &Array2<Complex<F>>) -> Array2<F> {
    let coeffs = input.view();
    wht2d_8x8_inverse_view(coeffs)
}
/// Inverse 8x8 2-D Walsh–Hadamard transform of a coefficient view, returning
/// a newly allocated 8x8 real block.
pub fn wht2d_8x8_inverse_view<F: Bm3dFloat>(input: ArrayView2<Complex<F>>) -> Array2<F> {
    let mut spatial = Array2::from_elem((8, 8), F::zero());
    wht2d_8x8_inverse_into_view(input, &mut spatial);
    spatial
}
/// Inverse 8x8 2-D Walsh–Hadamard transform written into `output`.
///
/// Uses only the real part of `input[[r, c]]` for `r, c` in `0..8`. Applies
/// [`fwht8`] to rows, then to columns, and scales by `1/64`. The unnormalized
/// forward transform applied twice multiplies by 64, so this undoes it.
///
/// # Panics
/// Panics if `output` is not 8x8, or if `input` is smaller than 8x8.
pub fn wht2d_8x8_inverse_into_view<F: Bm3dFloat>(
    input: ArrayView2<Complex<F>>,
    output: &mut Array2<F>,
) {
    assert_eq!(output.dim(), (8, 8));
    let mut block = [[F::zero(); 8]; 8];
    for (r, row) in block.iter_mut().enumerate() {
        for (c, cell) in row.iter_mut().enumerate() {
            *cell = input[[r, c]].re;
        }
    }
    for row in block.iter_mut() {
        fwht8(row);
    }
    for c in 0..8 {
        let mut column = [F::zero(); 8];
        for (r, slot) in column.iter_mut().enumerate() {
            *slot = block[r][c];
        }
        fwht8(&mut column);
        for (r, &value) in column.iter().enumerate() {
            block[r][c] = value;
        }
    }
    let norm_scale = F::one() / F::usize_as(64);
    for ((r, c), dst) in output.indexed_iter_mut() {
        *dst = block[r][c] * norm_scale;
    }
}
// Unit tests for the 2-D FFT/IFFT helpers and the 8x8 Walsh-Hadamard transform.
// All random inputs come from a seeded LCG, so every test is deterministic.
#[cfg(test)]
mod tests {
use super::*;
use ndarray::Array2;
use rustfft::FftPlanner;
use std::sync::Arc;
/// Minimal deterministic linear congruential generator used to build
/// reproducible test inputs.
struct SimpleLcg {
state: u64,
}
impl SimpleLcg {
fn new(seed: u64) -> Self {
Self { state: seed }
}
/// Advances the state with a wrapping multiply-add and returns the new state.
fn next_u64(&mut self) -> u64 {
self.state = self.state.wrapping_mul(6364136223846793005).wrapping_add(1);
self.state
}
/// Value in `[-1, 1)` built from the top 24 bits of the next state.
fn next_f32(&mut self) -> f32 {
let u = self.next_u64();
((u >> 40) as f32 / (1u64 << 24) as f32) * 2.0 - 1.0
}
/// Value in `[-1, 1)` built from the top 53 bits of the next state.
fn next_f64(&mut self) -> f64 {
let u = self.next_u64();
((u >> 11) as f64 / (1u64 << 53) as f64) * 2.0 - 1.0
}
}
// (fft_row, fft_col, ifft_row, ifft_col) plans for f32.
type FftPlanQuad32 = (
Arc<dyn Fft<f32>>,
Arc<dyn Fft<f32>>,
Arc<dyn Fft<f32>>,
Arc<dyn Fft<f32>>,
);
/// Builds forward/inverse plans for a `rows x cols` matrix: row plans have
/// length `cols`, column plans have length `rows`.
fn create_fft_plans_f32(rows: usize, cols: usize) -> FftPlanQuad32 {
let mut planner = FftPlanner::<f32>::new();
let fft_row = planner.plan_fft_forward(cols);
let fft_col = planner.plan_fft_forward(rows);
let ifft_row = planner.plan_fft_inverse(cols);
let ifft_col = planner.plan_fft_inverse(rows);
(fft_row, fft_col, ifft_row, ifft_col)
}
// (fft_row, fft_col, ifft_row, ifft_col) plans for f64.
type FftPlanQuad64 = (
Arc<dyn Fft<f64>>,
Arc<dyn Fft<f64>>,
Arc<dyn Fft<f64>>,
Arc<dyn Fft<f64>>,
);
/// f64 counterpart of `create_fft_plans_f32`.
fn create_fft_plans_f64(rows: usize, cols: usize) -> FftPlanQuad64 {
let mut planner = FftPlanner::<f64>::new();
let fft_row = planner.plan_fft_forward(cols);
let fft_col = planner.plan_fft_forward(rows);
let ifft_row = planner.plan_fft_inverse(cols);
let ifft_col = planner.plan_fft_inverse(rows);
(fft_row, fft_col, ifft_row, ifft_col)
}
/// True when the shapes match and every element pair differs by strictly
/// less than `epsilon` (absolute tolerance).
fn arrays_approx_equal_f32(a: &Array2<f32>, b: &Array2<f32>, epsilon: f32) -> bool {
if a.dim() != b.dim() {
return false;
}
a.iter().zip(b.iter()).all(|(x, y)| (x - y).abs() < epsilon)
}
/// f64 counterpart of `arrays_approx_equal_f32`.
fn arrays_approx_equal_f64(a: &Array2<f64>, b: &Array2<f64>, epsilon: f64) -> bool {
if a.dim() != b.dim() {
return false;
}
a.iter().zip(b.iter()).all(|(x, y)| (x - y).abs() < epsilon)
}
/// `rows x cols` matrix with entries in `[-1, 1)` drawn from a `SimpleLcg` seeded with `seed`.
fn random_matrix_f32(rows: usize, cols: usize, seed: u64) -> Array2<f32> {
let mut rng = SimpleLcg::new(seed);
Array2::from_shape_fn((rows, cols), |_| rng.next_f32())
}
/// f64 counterpart of `random_matrix_f32`.
fn random_matrix_f64(rows: usize, cols: usize, seed: u64) -> Array2<f64> {
let mut rng = SimpleLcg::new(seed);
Array2::from_shape_fn((rows, cols), |_| rng.next_f64())
}
// Forward then inverse FFT must reproduce the input (8x8, f32).
#[test]
fn test_fft2d_roundtrip_8x8() {
let input = random_matrix_f32(8, 8, 12345);
let (fft_row, fft_col, ifft_row, ifft_col) = create_fft_plans_f32(8, 8);
let freq = fft2d(input.view(), &fft_row, &fft_col);
let output = ifft2d(&freq, &ifft_row, &ifft_col);
assert!(
arrays_approx_equal_f32(&input, &output, 1e-5),
"FFT roundtrip failed: max diff = {}",
input
.iter()
.zip(output.iter())
.map(|(a, b)| (a - b).abs())
.fold(0.0f32, f32::max)
);
}
// Roundtrip over square and non-square shapes (exercises both the
// transpose-based square path and the per-column path).
#[test]
fn test_fft2d_roundtrip_various_sizes() {
let sizes = [(4, 4), (8, 8), (16, 16), (32, 32), (4, 8), (8, 16)];
for (rows, cols) in sizes {
let input = random_matrix_f32(rows, cols, (rows * 1000 + cols) as u64);
let (fft_row, fft_col, ifft_row, ifft_col) = create_fft_plans_f32(rows, cols);
let freq = fft2d(input.view(), &fft_row, &fft_col);
let output = ifft2d(&freq, &ifft_row, &ifft_col);
let max_diff = input
.iter()
.zip(output.iter())
.map(|(a, b)| (a - b).abs())
.fold(0.0f32, f32::max);
assert!(
arrays_approx_equal_f32(&input, &output, 1e-5),
"FFT roundtrip failed for {}x{}: max diff = {}",
rows,
cols,
max_diff
);
}
}
// Roundtrip for several different random inputs.
#[test]
fn test_fft2d_roundtrip_multiple_seeds() {
for seed in 0..10u64 {
let input = random_matrix_f32(8, 8, seed * 7919); let (fft_row, fft_col, ifft_row, ifft_col) = create_fft_plans_f32(8, 8);
let freq = fft2d(input.view(), &fft_row, &fft_col);
let output = ifft2d(&freq, &ifft_row, &ifft_col);
assert!(
arrays_approx_equal_f32(&input, &output, 1e-5),
"FFT roundtrip failed for seed {}",
seed
);
}
}
// f64 roundtrip with a tighter tolerance.
#[test]
fn test_fft2d_roundtrip_8x8_f64() {
let input = random_matrix_f64(8, 8, 12345);
let (fft_row, fft_col, ifft_row, ifft_col) = create_fft_plans_f64(8, 8);
let freq = fft2d(input.view(), &fft_row, &fft_col);
let output = ifft2d(&freq, &ifft_row, &ifft_col);
assert!(
arrays_approx_equal_f64(&input, &output, 1e-12),
"FFT f64 roundtrip failed: max diff = {}",
input
.iter()
.zip(output.iter())
.map(|(a, b)| (a - b).abs())
.fold(0.0f64, f64::max)
);
}
// f64 roundtrip over square sizes.
#[test]
fn test_fft2d_roundtrip_various_sizes_f64() {
let sizes = [(4, 4), (8, 8), (16, 16), (32, 32)];
for (rows, cols) in sizes {
let input = random_matrix_f64(rows, cols, (rows * 1000 + cols) as u64);
let (fft_row, fft_col, ifft_row, ifft_col) = create_fft_plans_f64(rows, cols);
let freq = fft2d(input.view(), &fft_row, &fft_col);
let output = ifft2d(&freq, &ifft_row, &ifft_col);
let max_diff = input
.iter()
.zip(output.iter())
.map(|(a, b)| (a - b).abs())
.fold(0.0f64, f64::max);
assert!(
arrays_approx_equal_f64(&input, &output, 1e-12),
"FFT f64 roundtrip failed for {}x{}: max diff = {}",
rows,
cols,
max_diff
);
}
}
// The spectrum of an all-zero input is all zeros.
#[test]
fn test_fft2d_zeros() {
let input = Array2::<f32>::zeros((8, 8));
let (fft_row, fft_col, _, _) = create_fft_plans_f32(8, 8);
let output = fft2d(input.view(), &fft_row, &fft_col);
for val in output.iter() {
assert!(
val.norm() < 1e-10,
"FFT of zeros should be zeros, got magnitude {}",
val.norm()
);
}
}
// A constant input concentrates all energy in DC. The forward transform is
// unnormalized, so DC equals the element sum (64 for 8x8 ones).
#[test]
fn test_fft2d_constant() {
let input = Array2::<f32>::ones((8, 8));
let (fft_row, fft_col, _, _) = create_fft_plans_f32(8, 8);
let output = fft2d(input.view(), &fft_row, &fft_col);
let dc = output[[0, 0]];
assert!(
(dc.re - 64.0).abs() < 1e-5 && dc.im.abs() < 1e-5,
"DC component should be 64+0i, got {:?}",
dc
);
for r in 0..8 {
for c in 0..8 {
if r != 0 || c != 0 {
let val = output[[r, c]];
assert!(
val.norm() < 1e-5,
"Non-DC component [{},{}] should be ~0, got magnitude {}",
r,
c,
val.norm()
);
}
}
}
}
// A unit impulse at the origin has a flat unit-magnitude spectrum.
#[test]
fn test_fft2d_impulse() {
let mut input = Array2::<f32>::zeros((8, 8));
input[[0, 0]] = 1.0;
let (fft_row, fft_col, _, _) = create_fft_plans_f32(8, 8);
let output = fft2d(input.view(), &fft_row, &fft_col);
for r in 0..8 {
for c in 0..8 {
let mag = output[[r, c]].norm();
assert!(
(mag - 1.0).abs() < 1e-5,
"Impulse FFT at [{},{}] should have magnitude 1, got {}",
r,
c,
mag
);
}
}
}
// Parseval for the unnormalized DFT: sum |X|^2 == N * sum |x|^2 with N = 64.
#[test]
fn test_fft2d_parseval() {
let input = random_matrix_f32(8, 8, 42);
let (fft_row, fft_col, _, _) = create_fft_plans_f32(8, 8);
let output = fft2d(input.view(), &fft_row, &fft_col);
let energy_spatial: f32 = input.iter().map(|x| x * x).sum();
let energy_freq: f32 = output.iter().map(|x| x.norm_sqr()).sum();
let expected_freq_energy = energy_spatial * 64.0;
assert!(
(energy_freq - expected_freq_energy).abs() / expected_freq_energy < 1e-4,
"Parseval's theorem violated: spatial={}, freq={}, expected={}",
energy_spatial,
energy_freq,
expected_freq_energy
);
}
// 1x1 transform is the identity, in both directions.
#[test]
fn test_fft2d_single_element() {
let mut input = Array2::<f32>::zeros((1, 1));
input[[0, 0]] = 2.71; let (fft_row, fft_col, ifft_row, ifft_col) = create_fft_plans_f32(1, 1);
let freq = fft2d(input.view(), &fft_row, &fft_col);
assert!(
(freq[[0, 0]].re - 2.71).abs() < 1e-5,
"1x1 FFT should preserve value"
);
let output = ifft2d(&freq, &ifft_row, &ifft_col);
assert!((output[[0, 0]] - 2.71).abs() < 1e-5, "1x1 roundtrip failed");
}
// Roundtrip for wide and tall non-square shapes.
#[test]
fn test_fft2d_non_square() {
let sizes = [(4, 8), (8, 4), (2, 16), (16, 2)];
for (rows, cols) in sizes {
let input = random_matrix_f32(rows, cols, (rows * 100 + cols) as u64);
let (fft_row, fft_col, ifft_row, ifft_col) = create_fft_plans_f32(rows, cols);
let freq = fft2d(input.view(), &fft_row, &fft_col);
let output = ifft2d(&freq, &ifft_row, &ifft_col);
assert!(
arrays_approx_equal_f32(&input, &output, 1e-5),
"Non-square {}x{} roundtrip failed",
rows,
cols
);
}
}
// Roundtrip with large magnitudes, checked with a relative tolerance.
#[test]
fn test_fft2d_large_values() {
let mut input = Array2::<f32>::zeros((8, 8));
for r in 0..8 {
for c in 0..8 {
input[[r, c]] = ((r * 8 + c) as f32) * 1e5;
}
}
let (fft_row, fft_col, ifft_row, ifft_col) = create_fft_plans_f32(8, 8);
let freq = fft2d(input.view(), &fft_row, &fft_col);
let output = ifft2d(&freq, &ifft_row, &ifft_col);
for r in 0..8 {
for c in 0..8 {
let diff = (input[[r, c]] - output[[r, c]]).abs();
let rel_err = diff / (input[[r, c]].abs() + 1e-10);
assert!(
rel_err < 1e-4,
"Large value roundtrip failed at [{},{}]: input={}, output={}, rel_err={}",
r,
c,
input[[r, c]],
output[[r, c]],
rel_err
);
}
}
}
// Roundtrip with tiny magnitudes (1e-6 .. 6.4e-5) and a tight absolute tolerance.
#[test]
fn test_fft2d_small_values() {
let mut input = Array2::<f32>::zeros((8, 8));
for r in 0..8 {
for c in 0..8 {
input[[r, c]] = ((r * 8 + c) as f32 + 1.0) * 1e-6;
}
}
let (fft_row, fft_col, ifft_row, ifft_col) = create_fft_plans_f32(8, 8);
let freq = fft2d(input.view(), &fft_row, &fft_col);
let output = ifft2d(&freq, &ifft_row, &ifft_col);
assert!(
arrays_approx_equal_f32(&input, &output, 1e-10),
"Small value roundtrip failed"
);
}
// Forward then inverse 8x8 WHT must reproduce the input.
#[test]
fn test_wht_8x8_roundtrip() {
let input = random_matrix_f32(8, 8, 54321);
let freq = wht2d_8x8_forward(input.view());
let output = wht2d_8x8_inverse(&freq);
assert!(
arrays_approx_equal_f32(&input, &output, 1e-6),
"WHT roundtrip failed: max diff = {}",
input
.iter()
.zip(output.iter())
.map(|(a, b)| (a - b).abs())
.fold(0.0f32, f32::max)
);
}
// WHT roundtrip across several random inputs.
#[test]
fn test_wht_8x8_roundtrip_multiple() {
for seed in 0..10u64 {
let input = random_matrix_f32(8, 8, seed * 13331);
let freq = wht2d_8x8_forward(input.view());
let output = wht2d_8x8_inverse(&freq);
let max_diff = input
.iter()
.zip(output.iter())
.map(|(a, b)| (a - b).abs())
.fold(0.0f32, f32::max);
assert!(
arrays_approx_equal_f32(&input, &output, 1e-6),
"WHT roundtrip failed for seed {}: max diff = {}",
seed,
max_diff
);
}
}
// f64 WHT roundtrip with a tighter tolerance.
#[test]
fn test_wht_8x8_roundtrip_f64() {
let input = random_matrix_f64(8, 8, 54321);
let freq = wht2d_8x8_forward(input.view());
let output = wht2d_8x8_inverse(&freq);
assert!(
arrays_approx_equal_f64(&input, &output, 1e-14),
"WHT f64 roundtrip failed: max diff = {}",
input
.iter()
.zip(output.iter())
.map(|(a, b)| (a - b).abs())
.fold(0.0f64, f64::max)
);
}
// f64 WHT roundtrip across several random inputs.
#[test]
fn test_wht_8x8_roundtrip_multiple_f64() {
for seed in 0..10u64 {
let input = random_matrix_f64(8, 8, seed * 13331);
let freq = wht2d_8x8_forward(input.view());
let output = wht2d_8x8_inverse(&freq);
let max_diff = input
.iter()
.zip(output.iter())
.map(|(a, b)| (a - b).abs())
.fold(0.0f64, f64::max);
assert!(
arrays_approx_equal_f64(&input, &output, 1e-14),
"WHT f64 roundtrip failed for seed {}: max diff = {}",
seed,
max_diff
);
}
}
// The WHT of an all-zero block is all zeros.
#[test]
fn test_wht_zeros() {
let input = Array2::<f32>::zeros((8, 8));
let output = wht2d_8x8_forward(input.view());
for val in output.iter() {
assert!(
val.norm() < 1e-10,
"WHT of zeros should be zeros, got magnitude {}",
val.norm()
);
}
}
// A constant block maps to DC only. The forward WHT is unnormalized, so DC = 64.
#[test]
fn test_wht_constant() {
let input = Array2::<f32>::ones((8, 8));
let output = wht2d_8x8_forward(input.view());
let dc = output[[0, 0]];
assert!(
(dc.re - 64.0).abs() < 1e-6,
"WHT DC component should be 64, got {}",
dc.re
);
for r in 0..8 {
for c in 0..8 {
if r != 0 || c != 0 {
let val = output[[r, c]];
assert!(
val.norm() < 1e-6,
"Non-DC component [{},{}] should be 0, got {}",
r,
c,
val.norm()
);
}
}
}
}
// An impulse at the origin produces identical coefficients everywhere.
#[test]
fn test_wht_impulse() {
let mut input = Array2::<f32>::zeros((8, 8));
input[[0, 0]] = 1.0;
let output = wht2d_8x8_forward(input.view());
let expected = output[[0, 0]].re;
for r in 0..8 {
for c in 0..8 {
assert!(
(output[[r, c]].re - expected).abs() < 1e-6,
"WHT of impulse should have uniform coefficients, got [{},{}]={}",
r,
c,
output[[r, c]].re
);
}
}
}
// Applying the unnormalized forward WHT twice scales the input by 64
// (8 per axis).
#[test]
fn test_wht_symmetry() {
let input = random_matrix_f32(8, 8, 99999);
let once = wht2d_8x8_forward(input.view());
let mut once_real = Array2::<f32>::zeros((8, 8));
for r in 0..8 {
for c in 0..8 {
once_real[[r, c]] = once[[r, c]].re;
}
}
let twice = wht2d_8x8_forward(once_real.view());
for r in 0..8 {
for c in 0..8 {
let expected = input[[r, c]] * 64.0;
let actual = twice[[r, c]].re;
assert!(
(expected - actual).abs() < 1e-4,
"WHT symmetry failed at [{},{}]: expected {}, got {}",
r,
c,
expected,
actual
);
}
}
}
// WHT roundtrip with values around 1e-10.
#[test]
fn test_wht_near_zero() {
let mut input = Array2::<f32>::zeros((8, 8));
let mut rng = SimpleLcg::new(11111);
for r in 0..8 {
for c in 0..8 {
input[[r, c]] = rng.next_f32() * 1e-10;
}
}
let freq = wht2d_8x8_forward(input.view());
let output = wht2d_8x8_inverse(&freq);
assert!(
arrays_approx_equal_f32(&input, &output, 1e-14),
"WHT near-zero roundtrip failed"
);
}
// WHT roundtrip with large magnitudes, checked with a relative tolerance.
#[test]
fn test_wht_large_values() {
let mut input = Array2::<f32>::zeros((8, 8));
for r in 0..8 {
for c in 0..8 {
input[[r, c]] = ((r * 8 + c) as f32 + 1.0) * 1e5;
}
}
let freq = wht2d_8x8_forward(input.view());
let output = wht2d_8x8_inverse(&freq);
for r in 0..8 {
for c in 0..8 {
let diff = (input[[r, c]] - output[[r, c]]).abs();
let rel_err = diff / (input[[r, c]].abs() + 1e-10);
assert!(
rel_err < 1e-5,
"WHT large value roundtrip failed at [{},{}]: rel_err={}",
r,
c,
rel_err
);
}
}
}
// WHT roundtrip of a +1/-1 checkerboard.
#[test]
fn test_wht_alternating_pattern() {
let mut input = Array2::<f32>::zeros((8, 8));
for r in 0..8 {
for c in 0..8 {
input[[r, c]] = if (r + c) % 2 == 0 { 1.0 } else { -1.0 };
}
}
let freq = wht2d_8x8_forward(input.view());
let output = wht2d_8x8_inverse(&freq);
assert!(
arrays_approx_equal_f32(&input, &output, 1e-6),
"WHT alternating pattern roundtrip failed"
);
}
// The inverse WHT reads only `.re`, so any imaginary part must not affect the output.
#[test]
fn test_wht_imaginary_part_ignored() {
let input = random_matrix_f32(8, 8, 77777);
let freq = wht2d_8x8_forward(input.view());
let mut freq_with_imag = freq.clone();
for r in 0..8 {
for c in 0..8 {
freq_with_imag[[r, c]] = Complex::new(freq[[r, c]].re, 999.0);
}
}
let output_clean = wht2d_8x8_inverse(&freq);
let output_imag = wht2d_8x8_inverse(&freq_with_imag);
assert!(
arrays_approx_equal_f32(&output_clean, &output_imag, 1e-10),
"WHT inverse should ignore imaginary part"
);
}
// The buffer-reusing inverse must agree with the allocating one. The 8-element
// scratch satisfies the `max(rows, cols)` requirement.
#[test]
fn test_ifft2d_into_matches_ifft2d() {
let input = random_matrix_f32(8, 8, 88888);
let mut planner = FftPlanner::new();
let fft_row = planner.plan_fft_forward(8);
let fft_col = planner.plan_fft_forward(8);
let ifft_row = planner.plan_fft_inverse(8);
let ifft_col = planner.plan_fft_inverse(8);
let freq = fft2d(input.view(), &fft_row, &fft_col);
let expected = ifft2d(&freq, &ifft_row, &ifft_col);
let mut work = Array2::<Complex<f32>>::zeros((8, 8));
let mut output = Array2::<f32>::zeros((8, 8));
let mut scratch = vec![Complex::new(0.0f32, 0.0f32); 8];
ifft2d_into(
freq.view(),
&ifft_row,
&ifft_col,
&mut work,
&mut output,
&mut scratch,
);
assert!(
arrays_approx_equal_f32(&expected, &output, 1e-6),
"ifft2d_into should match ifft2d"
);
}
// The in-place variant (which clobbers `freq`) must agree with `ifft2d`.
// `expected` is computed before `freq` is overwritten.
#[test]
fn test_ifft2d_inplace_to_real_matches_ifft2d() {
let input = random_matrix_f32(8, 8, 987654);
let mut planner = FftPlanner::new();
let fft_row = planner.plan_fft_forward(8);
let fft_col = planner.plan_fft_forward(8);
let ifft_row = planner.plan_fft_inverse(8);
let ifft_col = planner.plan_fft_inverse(8);
let mut freq = fft2d(input.view(), &fft_row, &fft_col);
let expected = ifft2d(&freq, &ifft_row, &ifft_col);
let mut output = Array2::<f32>::zeros((8, 8));
let mut scratch = vec![Complex::new(0.0f32, 0.0f32); 8];
let row_scratch_len = ifft_row.get_inplace_scratch_len();
let col_scratch_len = ifft_col.get_inplace_scratch_len();
let mut row_fft_scratch = vec![Complex::new(0.0f32, 0.0f32); row_scratch_len];
let mut col_fft_scratch = vec![Complex::new(0.0f32, 0.0f32); col_scratch_len];
ifft2d_inplace_to_real_with_plan_scratch(
freq.view_mut(),
&ifft_row,
&ifft_col,
&mut output,
&mut scratch,
&mut row_fft_scratch,
&mut col_fft_scratch,
);
assert!(
arrays_approx_equal_f32(&expected, &output, 1e-6),
"ifft2d_inplace_to_real_with_plan_scratch should match ifft2d"
);
}
// The buffer-reusing forward FFT must agree element-wise with `fft2d`.
#[test]
fn test_fft2d_into_matches_fft2d() {
let input = random_matrix_f32(8, 8, 123456);
let mut planner = FftPlanner::new();
let fft_row = planner.plan_fft_forward(8);
let fft_col = planner.plan_fft_forward(8);
let expected = fft2d(input.view(), &fft_row, &fft_col);
let mut work = Array2::<Complex<f32>>::zeros((8, 8));
let mut output = Array2::<Complex<f32>>::zeros((8, 8));
let mut scratch = vec![Complex::new(0.0f32, 0.0f32); 8];
fft2d_into(
input.view(),
&fft_row,
&fft_col,
&mut work,
output.view_mut(),
&mut scratch,
);
for r in 0..8 {
for c in 0..8 {
let diff = (expected[[r, c]] - output[[r, c]]).norm();
assert!(diff < 1e-6, "fft2d_into mismatch at ({}, {})", r, c);
}
}
}
// The into-variant of the inverse WHT must match the allocating one.
#[test]
fn test_wht_inverse_into_matches_inverse() {
let input = random_matrix_f32(8, 8, 99999);
let freq = wht2d_8x8_forward(input.view());
let expected = wht2d_8x8_inverse(&freq);
let mut output = Array2::<f32>::zeros((8, 8));
wht2d_8x8_inverse_into_view(freq.view(), &mut output);
assert!(
arrays_approx_equal_f32(&expected, &output, 1e-6),
"wht2d_8x8_inverse_into_view should match wht2d_8x8_inverse"
);
}
// The into-variant of the forward WHT must match the allocating one.
#[test]
fn test_wht_forward_into_matches_forward() {
let input = random_matrix_f32(8, 8, 424242);
let expected = wht2d_8x8_forward(input.view());
let mut output = Array2::<Complex<f32>>::zeros((8, 8));
wht2d_8x8_forward_into_view(input.view(), output.view_mut());
for r in 0..8 {
for c in 0..8 {
let diff = (expected[[r, c]] - output[[r, c]]).norm();
assert!(
diff < 1e-6,
"wht2d_8x8_forward_into_view mismatch at ({}, {})",
r,
c
);
}
}
}
}