#![allow(clippy::cast_precision_loss)] #![allow(clippy::similar_names)]
use oxifft::api::{fft2d, fft_nd, ifft2d, ifft_nd};
use oxifft::Complex;
/// Demonstrates 2D and N-dimensional forward/inverse FFTs with `oxifft`.
///
/// For each case the example prints the input, the magnitudes of the forward
/// transform, and the largest element-wise distance between the original input
/// and the result of running the inverse transform on the spectrum.
fn main() {
    println!("=== 2D FFT Example ===\n");

    let rows = 4;
    let cols = 4;

    // Row-major rows x cols buffer holding a single unit impulse at (row 1, col 1).
    let input_2d: Vec<Complex<f64>> = (0..(rows * cols))
        .map(|idx| {
            let (row, col) = (idx / cols, idx % cols);
            if row == 1 && col == 1 {
                Complex::new(1.0, 0.0)
            } else {
                Complex::new(0.0, 0.0)
            }
        })
        .collect();

    println!("2D Input ({rows}x{cols}):");
    print_grid(&input_2d, cols, |c| c.re);

    let spectrum_2d = fft2d(&input_2d, rows, cols);
    println!("\n2D FFT Output (magnitudes):");
    print_grid(&spectrum_2d, cols, magnitude);

    let recovered_2d = ifft2d(&spectrum_2d, rows, cols);
    let max_error_2d = max_abs_error(&input_2d, &recovered_2d);
    println!("\n2D roundtrip error: {max_error_2d:.2e}");

    println!("\n=== 3D FFT Example ===\n");

    let dims = [2, 2, 2];
    let total = dims.iter().product::<usize>();
    // Flat buffer of `total` elements with real parts 0, 1, 2, ... and zero imaginary parts.
    let input_3d: Vec<Complex<f64>> = (0..total)
        .map(|idx| Complex::new(idx as f64, 0.0))
        .collect();
    println!(
        "3D Input values: {:?}",
        input_3d.iter().map(|c| c.re).collect::<Vec<_>>()
    );

    let spectrum_nd = fft_nd(&input_3d, &dims);
    println!(
        "3D FFT output magnitudes: {:?}",
        spectrum_nd.iter().map(magnitude).collect::<Vec<_>>()
    );

    let recovered_nd = ifft_nd(&spectrum_nd, &dims);
    let max_error_nd = max_abs_error(&input_3d, &recovered_nd);
    println!("3D roundtrip error: {max_error_nd:.2e}");
}

/// Returns the modulus `sqrt(re^2 + im^2)` of a complex value.
///
/// Uses `hypot` to avoid intermediate overflow/underflow when squaring.
fn magnitude(c: &Complex<f64>) -> f64 {
    c.re.hypot(c.im)
}

/// Prints a row-major buffer as a grid with `cols` values per line.
///
/// Each line starts with a single space, and each value is rendered as
/// `{:+.2} ` after being mapped through `value_of`.
///
/// # Panics
///
/// Panics if `cols` is zero (via `chunks_exact`).
fn print_grid(values: &[Complex<f64>], cols: usize, value_of: impl Fn(&Complex<f64>) -> f64) {
    for row in values.chunks_exact(cols) {
        print!(" ");
        for c in row {
            print!("{:+.2} ", value_of(c));
        }
        println!();
    }
}

/// Returns the largest element-wise distance `|expected[i] - actual[i]|`.
///
/// A NaN distance is propagated instead of being discarded. With
/// `fold(0.0, f64::max)` a NaN would be dropped, and a broken transform could
/// be reported as an exact roundtrip.
///
/// # Panics
///
/// Panics if the slices differ in length. Without this check, `zip` would
/// silently compare only their common prefix.
fn max_abs_error(expected: &[Complex<f64>], actual: &[Complex<f64>]) -> f64 {
    assert_eq!(
        expected.len(),
        actual.len(),
        "roundtrip changed the buffer length"
    );
    expected
        .iter()
        .zip(actual)
        .map(|(a, b)| (a.re - b.re).hypot(a.im - b.im))
        // Once `worst` is NaN, `err > worst` is always false, so NaN sticks.
        .fold(0.0, |worst, err| {
            if err.is_nan() || err > worst {
                err
            } else {
                worst
            }
        })
}