use vyre::ir::{BufferAccess, BufferDecl, DataType, Expr, Node, Program};
use crate::region::wrap_anonymous;
/// Identifier of this op: handed to `wrap_anonymous` when the program body is wrapped and
/// used as the `id` of the `OpEntry` registered with the harness.
const OP_ID: &str = "vyre-libs::math::fft::fft4_complex";
#[must_use]
pub fn fft4_complex(input: &str, output: &str) -> Program {
let body = vec![
Node::let_bind("x0r", Expr::load(input, Expr::u32(0))),
Node::let_bind("x0i", Expr::load(input, Expr::u32(1))),
Node::let_bind("x1r", Expr::load(input, Expr::u32(2))),
Node::let_bind("x1i", Expr::load(input, Expr::u32(3))),
Node::let_bind("x2r", Expr::load(input, Expr::u32(4))),
Node::let_bind("x2i", Expr::load(input, Expr::u32(5))),
Node::let_bind("x3r", Expr::load(input, Expr::u32(6))),
Node::let_bind("x3i", Expr::load(input, Expr::u32(7))),
Node::Store {
buffer: output.into(),
index: Expr::u32(0),
value: Expr::add(
Expr::add(Expr::var("x0r"), Expr::var("x1r")),
Expr::add(Expr::var("x2r"), Expr::var("x3r")),
),
},
Node::Store {
buffer: output.into(),
index: Expr::u32(1),
value: Expr::add(
Expr::add(Expr::var("x0i"), Expr::var("x1i")),
Expr::add(Expr::var("x2i"), Expr::var("x3i")),
),
},
Node::Store {
buffer: output.into(),
index: Expr::u32(2),
value: Expr::sub(
Expr::sub(
Expr::add(Expr::var("x0r"), Expr::var("x1i")),
Expr::var("x2r"),
),
Expr::var("x3i"),
),
},
Node::Store {
buffer: output.into(),
index: Expr::u32(3),
value: Expr::add(
Expr::sub(
Expr::sub(Expr::var("x0i"), Expr::var("x1r")),
Expr::var("x2i"),
),
Expr::var("x3r"),
),
},
Node::Store {
buffer: output.into(),
index: Expr::u32(4),
value: Expr::add(
Expr::sub(Expr::var("x0r"), Expr::var("x1r")),
Expr::sub(Expr::var("x2r"), Expr::var("x3r")),
),
},
Node::Store {
buffer: output.into(),
index: Expr::u32(5),
value: Expr::add(
Expr::sub(Expr::var("x0i"), Expr::var("x1i")),
Expr::sub(Expr::var("x2i"), Expr::var("x3i")),
),
},
Node::Store {
buffer: output.into(),
index: Expr::u32(6),
value: Expr::add(
Expr::sub(
Expr::sub(Expr::var("x0r"), Expr::var("x1i")),
Expr::var("x2r"),
),
Expr::var("x3i"),
),
},
Node::Store {
buffer: output.into(),
index: Expr::u32(7),
value: Expr::sub(
Expr::sub(
Expr::add(Expr::var("x0i"), Expr::var("x1r")),
Expr::var("x2i"),
),
Expr::var("x3r"),
),
},
];
Program::wrapped(
vec![
BufferDecl::storage(input, 0, BufferAccess::ReadOnly, DataType::F32).with_count(8),
BufferDecl::output(output, 1, DataType::F32).with_count(8),
],
[1, 1, 1],
vec![wrap_anonymous(OP_ID, body)],
)
}
// Registers this op as a `crate::harness::OpEntry` through `inventory::submit!`.
// The fixture is a unit impulse: sample 0 is `1 + 0i` and the other three samples are zero,
// so each of the four output bins is expected to be `1 + 0i`.
// NOTE(review): the nesting of the returned `Vec`s presumably means cases -> buffers -> bytes;
// confirm against the `OpEntry` definition.
inventory::submit! {
crate::harness::OpEntry {
id: OP_ID,
// Builds the program with buffers named "input" and "output".
build: || fft4_complex("input", "output"),
test_inputs: Some(|| {
// Interleaved re/im values of the impulse, packed with `f32_bytes`.
let input = crate::test_support::byte_pack::f32_bytes(&[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]);
vec![vec![input]]
}),
expected_output: Some(|| {
// Four bins, each with real part 1.0 and imaginary part 0.0.
vec![vec![crate::test_support::byte_pack::f32_bytes(&[1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0])]]
}),
category: Some("math"),
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::test_support::byte_pack::f32_bytes;
use vyre_reference::value::Value;
/// Reinterprets `bytes` as consecutive little-endian `f32` values.
///
/// Trailing bytes that do not fill a whole 4-byte word are ignored.
fn decode(bytes: &[u8]) -> Vec<f32> {
    let mut values = Vec::with_capacity(bytes.len() / 4);
    for word in bytes.chunks_exact(4) {
        // `chunks_exact(4)` only yields 4-byte slices, so these indices are always valid.
        values.push(f32::from_le_bytes([word[0], word[1], word[2], word[3]]));
    }
    values
}
/// Reference 4-point DFT computed straight from the definition
/// `X[k] = sum_n x[n] * exp(-2*pi*i*n*k/4)`.
///
/// `input` holds four complex samples interleaved as `[re, im, re, im, ...]`; only the first
/// eight elements are read. The result uses the same layout and always has eight elements.
///
/// # Panics
///
/// Panics on an out-of-bounds index if `input` has fewer than eight elements.
fn naive_dft4(input: &[f32]) -> Vec<f32> {
    const POINTS: usize = 4;

    let mut spectrum = Vec::with_capacity(2 * POINTS);
    for bin in 0..POINTS {
        // Accumulate the real and imaginary parts of this bin over all four samples.
        let (re, im) = (0..POINTS).fold((0.0_f32, 0.0_f32), |(re, im), idx| {
            let xr = input[2 * idx];
            let xi = input[2 * idx + 1];
            let theta =
                -2.0 * std::f32::consts::PI * (idx as f32) * (bin as f32) / (POINTS as f32);
            let cos_t = theta.cos();
            let sin_t = theta.sin();
            // Complex multiply (xr + i*xi) * (cos_t + i*sin_t), added to the running sum.
            (
                re + (xr * cos_t - xi * sin_t),
                im + (xr * sin_t + xi * cos_t),
            )
        });
        spectrum.push(re);
        spectrum.push(im);
    }
    spectrum
}
/// Builds `fft4_complex("input", "output")`, evaluates it with the reference interpreter on
/// `input`, and returns the first result decoded as `f32` values.
fn run(input: &[f32]) -> Vec<f32> {
    let program = fft4_complex("input", "output");
    // Second value: 32 zero bytes, i.e. room for eight `f32`s.
    // NOTE(review): presumably the initial contents of the output buffer; confirm against
    // `reference_eval`.
    let buffers = [Value::from(f32_bytes(input)), Value::from(vec![0u8; 32])];
    let results = vyre_reference::reference_eval(&program, &buffers)
        .expect("Fix: fft4_complex must execute in the reference interpreter.");
    decode(&results[0].to_bytes())
}
/// A unit impulse at sample 0 must give `1 + 0i` in every bin.
#[test]
fn fft4_impulse_yields_uniform_bins() {
    let expected = [1.0_f32, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0];
    let actual = run(&[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]);
    actual.iter().zip(&expected).for_each(|(a, e)| {
        assert!((a - e).abs() <= 1.0e-5, "{a} != {e}");
    });
}
/// A constant signal of `1 + 0i` must put all energy (`4 + 0i`) in bin 0 and zero elsewhere.
#[test]
fn fft4_dc_signal_concentrates_in_dc_bin() {
    let expected = [4.0_f32, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0];
    let actual = run(&[1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0]);
    actual.iter().zip(&expected).for_each(|(a, e)| {
        assert!((a - e).abs() <= 1.0e-5, "{a} != {e}");
    });
}
/// The real sequence `1, 0, -1, 0` must give `2 + 0i` in bins 1 and 3 and zero in bins 0 and 2.
#[test]
fn fft4_freq1_cosine_concentrates_in_bins_1_and_3() {
    let expected = [0.0_f32, 0.0, 2.0, 0.0, 0.0, 0.0, 2.0, 0.0];
    let actual = run(&[1.0, 0.0, 0.0, 0.0, -1.0, 0.0, 0.0, 0.0]);
    actual.iter().zip(&expected).for_each(|(a, e)| {
        assert!((a - e).abs() <= 1.0e-5, "{a} != {e}");
    });
}
/// Cross-checks `run` against `naive_dft4` on 50 pseudo-random inputs drawn from `[-1, 1]`.
#[test]
fn fft4_matches_naive_dft_on_random_fuzz() {
    // Fixed seed so every run sees the same sequence.
    let mut state = 0xCAFEBABE_u64;
    let mut next = || {
        // 64-bit linear congruential step.
        state = state
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407);
        // Keep the top 32 bits so the quotient spans [0, 2] and the result spans [-1, 1].
        // Shifting by 33 would keep only 31 bits and confine every sample to [-1, 0],
        // so the fuzz would never exercise positive inputs.
        ((state >> 32) as f32 / (u32::MAX as f32 / 2.0)) - 1.0
    };
    for _ in 0..50 {
        let input: Vec<f32> = (0..8).map(|_| next()).collect();
        let actual = run(&input);
        let expected = naive_dft4(&input);
        for (i, (a, e)) in actual.iter().zip(expected.iter()).enumerate() {
            assert!(
                (a - e).abs() <= 1.0e-4,
                "lane {i}: fft={a} naive={e} diff={}",
                (a - e).abs()
            );
        }
    }
}
/// `x0r` appears in the real part of every bin, so a NaN there must make all four real parts NaN.
#[test]
fn fft4_nan_input_propagates_to_real_parts() {
    let mut input = [0.0_f32; 8];
    input[0] = f32::NAN;
    let actual = run(&input);
    (0..4).for_each(|k| {
        let real_part = actual[2 * k];
        assert!(
            real_part.is_nan(),
            "FFT bin {k} real part must be NaN when x0r is NaN"
        );
    });
}
/// Every output lane uses either `x0r` or `x0i`, so NaN in both must make all eight lanes NaN.
#[test]
fn fft4_nan_both_components_propagates_everywhere() {
    let mut input = [0.0_f32; 8];
    input[0] = f32::NAN;
    input[1] = f32::NAN;
    let actual = run(&input);
    actual.iter().copied().enumerate().for_each(|(i, v)| {
        assert!(
            v.is_nan(),
            "FFT lane {i} must be NaN when both re/im inputs are NaN"
        );
    });
}
/// With `x0r` infinite and every other input zero, all four real parts must be infinite.
#[test]
fn fft4_inf_input_propagates_to_real_parts() {
    let mut input = [0.0_f32; 8];
    input[0] = f32::INFINITY;
    let actual = run(&input);
    (0..4).for_each(|k| {
        let real_part = actual[2 * k];
        assert!(
            real_part.is_infinite(),
            "FFT bin {k} real part must be Inf when x0r is Inf"
        );
    });
}
}