use mlxrs::{
Array, Dtype,
ops::fft::{self, FftNorm},
};
/// Absolute tolerance used when comparing `f32` results of FFT round-trips.
const TOL: f32 = 1e-4;
/// Returns `true` when `a` and `b` differ by at most [`TOL`] in absolute value.
///
/// A NaN operand always yields `false`, since any `<=` comparison with NaN is false.
fn close(a: f32, b: f32) -> bool {
    let diff = (a - b).abs();
    diff <= TOL
}
/// Checks that `fft`/`ifft` and `rfft`/`irfft` invert each other on a real 1-D signal.
///
/// The complex pair is checked for shape and dtype only; the real pair is compared
/// element-wise against the input within [`TOL`].
#[test]
fn fft_then_ifft_round_trips_real_signal() {
    let data = [1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
    let a = Array::from_slice::<f32>(&data, &[8i32]).unwrap();

    // Complex round-trip: the inverse comes back as a complex array.
    let f = fft::fft(&a, 8, 0, FftNorm::Backward).unwrap();
    let back = fft::ifft(&f, 8, 0, FftNorm::Backward).unwrap();
    assert_eq!(back.shape(), vec![8]);
    assert_eq!(back.dtype().unwrap(), Dtype::Complex64);

    // Real round-trip: `irfft` with n = 8 must reproduce the original samples.
    let rf = fft::rfft(&a, 8, 0, FftNorm::Backward).unwrap();
    let mut rb = fft::irfft(&rf, 8, 0, FftNorm::Backward).unwrap();
    assert_eq!(rb.shape(), vec![8]);
    let v = rb.to_vec::<f32>().unwrap();
    // `zip` stops at the shorter side, so pin the length before comparing element-wise.
    assert_eq!(v.len(), data.len());
    for (got, want) in v.iter().zip(data.iter()) {
        assert!(close(*got, *want), "rfft round-trip got={got} want={want}");
    }
}
/// Round-trips a 4x4 real ramp through `rfft2`/`irfft2` over both axes and
/// compares every recovered sample to the input.
#[test]
fn fft2_then_ifft2_round_trips_real_2d() {
    let ramp: Vec<f32> = (0..16).map(|i| i as f32).collect();
    let input = Array::from_slice::<f32>(&ramp, &(4, 4)).unwrap();

    let spectrum = fft::rfft2(&input, &[4, 4], &[0, 1], FftNorm::Backward).unwrap();
    // Only the non-negative half of the last axis is kept: 4 / 2 + 1 = 3.
    assert_eq!(spectrum.shape(), vec![4, 3]);

    let mut restored = fft::irfft2(&spectrum, &[4, 4], &[0, 1], FftNorm::Backward).unwrap();
    assert_eq!(restored.shape(), vec![4, 4]);

    let recovered = restored.to_vec::<f32>().unwrap();
    recovered.iter().zip(&ramp).for_each(|(got, want)| {
        assert!(close(*got, *want), "rfft2 round-trip got={got} want={want}");
    });
}
/// Passing empty `s` and `axes` to `rfftn`/`irfftn` transforms every axis of a
/// 2x3x4 real array, and the inverse recovers the original samples.
///
/// The forward transform keeps only the non-negative half of the last axis
/// (4 -> 4 / 2 + 1 = 3).
#[test]
fn fftn_empty_axes_expands_to_all_dims() {
    let data: Vec<f32> = (0..24).map(|x| x as f32).collect();
    let a = Array::from_slice::<f32>(&data, &(2, 3, 4)).unwrap();
    let f = fft::rfftn(&a, &[], &[], FftNorm::Backward).unwrap();
    assert_eq!(f.shape(), vec![2, 3, 3]);
    assert_eq!(f.dtype().unwrap(), Dtype::Complex64);

    let mut back = fft::irfftn(&f, &[], &[], FftNorm::Backward).unwrap();
    assert_eq!(back.dtype().unwrap(), Dtype::F32);
    // NOTE(review): relies on irfftn's default last-axis length being 2 * (n - 1)
    // (MLX's convention), which restores 3 -> 4; confirm if this ever fails.
    assert_eq!(back.shape(), vec![2, 3, 4]);
    let v = back.to_vec::<f32>().unwrap();
    // `zip` stops at the shorter side, so pin the length before comparing element-wise.
    assert_eq!(v.len(), data.len());
    for (got, want) in v.iter().zip(data.iter()) {
        assert!(close(*got, *want), "rfftn round-trip got={got} want={want}");
    }
}
/// A 1-D `fftn` over axis 1 of a 3x4 real matrix keeps the shape and produces a
/// `Complex64` result.
#[test]
fn fftn_explicit_axes() {
    let values: Vec<f32> = (0..12).map(|v| v as f32).collect();
    let matrix = Array::from_slice::<f32>(&values, &(3, 4)).unwrap();

    let spectrum = fft::fftn(&matrix, &[4], &[1], FftNorm::Backward).unwrap();

    assert_eq!(spectrum.shape(), vec![3, 4]);
    assert_eq!(spectrum.dtype().unwrap(), Dtype::Complex64);
}
/// Sends a real 4x4 ramp through `fft2` and then `ifft2`. Checks that the inverse
/// keeps the 4x4 shape and is `Complex64`; values are not compared here.
#[test]
fn fft2_complex_round_trips() {
    let ramp: Vec<f32> = (0..16).map(|n| n as f32).collect();
    let input = Array::from_slice::<f32>(&ramp, &(4, 4)).unwrap();

    let forward = fft::fft2(&input, &[4, 4], &[0, 1], FftNorm::Backward).unwrap();
    let inverse = fft::ifft2(&forward, &[4, 4], &[0, 1], FftNorm::Backward).unwrap();

    assert_eq!(inverse.shape(), vec![4, 4]);
    assert_eq!(inverse.dtype().unwrap(), Dtype::Complex64);
}
/// For a 4-point unit impulse, the `Array::fft` method and the free `fft::fft`
/// function produce results with the same shape and dtype.
#[test]
fn fft_method_form_matches_freefn() {
    let impulse = [1.0_f32, 0.0, 0.0, 0.0];
    let a = Array::from_slice::<f32>(&impulse, &[4i32]).unwrap();

    let via_free_fn = fft::fft(&a, 4, 0, FftNorm::Backward).unwrap();
    let via_method = a.fft(4, 0, FftNorm::Backward).unwrap();

    assert_eq!(via_free_fn.shape(), via_method.shape());
    assert_eq!(via_free_fn.dtype().unwrap(), via_method.dtype().unwrap());
}
/// Transforms an 8-point unit impulse with `Backward` and with `Ortho`
/// normalisation. Checks that both spectra have the same shape and evaluate
/// without error.
///
/// NOTE(review): the magnitudes themselves are not compared yet (presumably all
/// ones for `Backward` and `1 / sqrt(8)` for `Ortho`). That needs a way to read
/// complex values back from an `Array`.
#[test]
fn fft_norm_ortho_changes_magnitude() {
    let mut data = vec![0.0_f32; 8];
    data[0] = 1.0;
    let a = Array::from_slice::<f32>(&data, &[8i32]).unwrap();
    let mut back_f = fft::fft(&a, 8, 0, FftNorm::Backward).unwrap();
    let mut ortho_f = fft::fft(&a, 8, 0, FftNorm::Ortho).unwrap();
    assert_eq!(back_f.shape(), ortho_f.shape());
    // An evaluation failure must fail the test instead of being silently discarded.
    back_f.eval().expect("evaluating Backward-normalised FFT");
    ortho_f.eval().expect("evaluating Ortho-normalised FFT");
}
/// `fftfreq(8, 1.0)` yields 8 sample frequencies in FFT order: the non-negative
/// bins first, then the negative ones.
#[test]
fn fftfreq_yields_n_samples() {
    let expected = [0.0_f32, 0.125, 0.25, 0.375, -0.5, -0.375, -0.25, -0.125];

    let mut freqs = fft::fftfreq(8, 1.0).unwrap();
    assert_eq!(freqs.shape(), vec![8]);

    let actual = freqs.to_vec::<f32>().unwrap();
    actual.iter().zip(&expected).for_each(|(got, w)| {
        assert!(close(*got, *w), "fftfreq got={got} want={w}");
    });
}
/// `rfftfreq(8, 1.0)` yields the 8 / 2 + 1 = 5 non-negative sample frequencies.
#[test]
fn rfftfreq_yields_n_over_2_plus_one_samples() {
    let expected = [0.0_f32, 0.125, 0.25, 0.375, 0.5];

    let mut freqs = fft::rfftfreq(8, 1.0).unwrap();
    assert_eq!(freqs.shape(), vec![5]);

    let actual = freqs.to_vec::<f32>().unwrap();
    actual.iter().zip(&expected).for_each(|(got, w)| {
        assert!(close(*got, *w), "rfftfreq got={got} want={w}");
    });
}
/// Applying `fftshift` and then `ifftshift` to an 8-element ramp, both with an
/// empty axes list, gives back the original values exactly.
#[test]
fn fftshift_then_ifftshift_round_trips() {
    let mut ramp = Array::arange::<f32>(0.0, 8.0, 1.0).unwrap();
    let original = ramp.to_vec::<f32>().unwrap();

    let shifted = fft::fftshift(&ramp, &[]).unwrap();
    let mut unshifted = fft::ifftshift(&shifted, &[]).unwrap();
    assert_eq!(unshifted.shape(), vec![8]);

    let recovered = unshifted.to_vec::<f32>().unwrap();
    assert_eq!(recovered, original);
}
/// Calling the method form `Array::fftshift` on axis 0 keeps an 8-element ramp's
/// length and rotates its values by half the length.
#[test]
fn fftshift_axes_specific_axis() {
    let a = Array::arange::<f32>(0.0, 8.0, 1.0).unwrap();
    let mut s = a.fftshift(&[0]).unwrap();
    assert_eq!(s.shape(), vec![8]);
    // fftshift moves the zero-frequency bin to the centre; for n = 8 that is a
    // rotation by n / 2 = 4. Checking values (not just shape) catches a no-op shift.
    let want = vec![4.0_f32, 5.0, 6.0, 7.0, 0.0, 1.0, 2.0, 3.0];
    assert_eq!(s.to_vec::<f32>().unwrap(), want);
}