use numr::dtype::complex::Complex64;
use numr::prelude::*;
fn main() -> Result<()> {
let device = CpuDevice::new();
let client = CpuRuntime::default_client(&device);
let n = 64;
let signal: Vec<Complex64> = (0..n)
.map(|i| {
let t = i as f32 / n as f32;
let val = (2.0 * std::f32::consts::PI * 3.0 * t).sin()
+ 0.5 * (2.0 * std::f32::consts::PI * 10.0 * t).cos();
Complex64::new(val, 0.0)
})
.collect();
let input = Tensor::<CpuRuntime>::from_slice(&signal, &[n], &device);
let freq = client.fft(&input, FftDirection::Forward, FftNormalization::Backward)?;
let freq_data: Vec<Complex64> = freq.to_vec();
let mut magnitudes: Vec<(usize, f32)> = freq_data
.iter()
.enumerate()
.map(|(i, c)| (i, c.magnitude()))
.collect();
magnitudes.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
println!("Top 5 frequency bins by magnitude:");
for &(bin, mag) in magnitudes.iter().take(5) {
println!(" bin {bin:>3}: {mag:.4}");
}
let recovered = client.fft(&freq, FftDirection::Inverse, FftNormalization::Backward)?;
let recovered_data: Vec<Complex64> = recovered.to_vec();
let max_err: f32 = signal
.iter()
.zip(recovered_data.iter())
.map(|(a, b)| {
let dr = a.re - b.re;
let di = a.im - b.im;
(dr * dr + di * di).sqrt()
})
.fold(0.0f32, f32::max);
println!("\nComplex FFT round-trip max error: {max_err:.2e}");
assert!(max_err < 1e-4, "Round-trip error should be small");
let real_signal: Vec<f32> = (0..n)
.map(|i| {
let t = i as f32 / n as f32;
(2.0 * std::f32::consts::PI * 5.0 * t).sin()
})
.collect();
let real_input = Tensor::<CpuRuntime>::from_slice(&real_signal, &[n], &device);
let real_freq = client.rfft(&real_input, FftNormalization::Backward)?;
println!(
"\nrfft: input length = {n}, output length = {} (N/2+1 complex)",
real_freq.shape()[0],
);
let real_recovered = client.irfft(&real_freq, Some(n), FftNormalization::Backward)?;
let real_recovered_data: Vec<f32> = real_recovered.to_vec();
let real_max_err: f32 = real_signal
.iter()
.zip(real_recovered_data.iter())
.map(|(a, b)| (a - b).abs())
.fold(0.0f32, f32::max);
println!("Real FFT round-trip max error: {real_max_err:.2e}");
assert!(real_max_err < 1e-4, "Real round-trip error should be small");
println!("\nFFT round-trip example completed successfully!");
Ok(())
}