use metaltile::kernel;
use metaltile_core::ir::KernelMode;
use crate::{
bench_types::DType,
spec::{BenchDispatch, BenchSpec},
};
// Generates one fixed-size FFT kernel plus its benchmark registration.
//
// Macro arguments:
//   $name  - identifier of the generated `#[kernel]` function.
//   $n     - number of complex points handled per program (row stride in the tensors).
//   $log_n - number of bit-reversal iterations and butterfly stages.
//   $inv_n - scale factor applied to the result when `inv != 0`.
//   $subop - `subop` string recorded in the `BenchSpec`.
//
// NOTE(review): the body only makes sense when `$n == 1 << $log_n`, `$inv_n == 1 / $n`,
// and each program runs with `tid` covering 0..$n - the macro does not check any of
// this; confirm against the launch configuration.
macro_rules! fft_kernel {
($name:ident, $n:literal, $log_n:literal, $inv_n:literal, $subop:literal) => {
// Kernel parameters: real/imaginary input planes, real/imaginary output planes, and
// `inv`, which selects the twiddle sign and whether the `$inv_n` scale is applied.
#[kernel]
pub fn $name<T>(
in_re: Tensor<T>,
in_im: Tensor<T>,
mut out_re: Tensor<T>,
mut out_im: Tensor<T>,
#[constexpr] inv: u32,
) {
// Each program handles one row of `$n` elements starting at `base`.
let row = program_id::<0>();
let base = row * $n;
// Two named scratch buffers of `$n` f32 values hold the working real/imaginary data.
// (Presumably threadgroup-shared memory, judging by the helper names - TODO confirm.)
threadgroup_alloc("re", $n, "f32");
threadgroup_alloc("im", $n, "f32");
// Reverse the low `$log_n` bits of `tid` into `src`.
// NOTE(review): `tid` is not declared in this file; presumably a built-in injected by
// `#[kernel]` - confirm.
let mut src = 0u32;
let mut rem = tid;
for _b in range(0u32, $log_n, 1u32) {
src = (src << 1u32) | (rem & 1u32);
rem = rem >> 1u32;
}
// Load the input in bit-reversed order, widening to f32 for the computation.
threadgroup_store("re", tid, load(in_re[base + src]).cast::<f32>());
threadgroup_store("im", tid, load(in_im[base + src]).cast::<f32>());
threadgroup_barrier();
let pi = 3.141592653589793f32;
// Twiddle angle sign: -1 when `inv == 0`, +1 otherwise
// (assuming `select(c, a, b)` yields `a` when `c` holds).
let angle_sign = select(inv == 0u32, -1.0f32, 1.0f32);
// One butterfly stage per iteration; `h` is the distance between paired elements.
for s in range(0u32, $log_n, 1u32) {
let h = 1u32 << s;
// Only threads whose `h` bit is clear do work; each one owns the pair (tid, tid + h).
if (tid & h) == 0u32 {
// `k` is the position within the half-block and picks the twiddle angle
// angle_sign * pi * k / h.
let k = tid & (h - 1u32);
let h_f = h.cast::<f32>();
let angle = angle_sign * pi * k.cast::<f32>() / h_f;
let wr = cos(angle);
let wi = sin(angle);
let ar = threadgroup_load("re", tid);
let ai = threadgroup_load("im", tid);
let br = threadgroup_load("re", tid + h);
let bi = threadgroup_load("im", tid + h);
// (tr, ti) = twiddle * b as a complex product.
let tr = wr * br - wi * bi;
let ti = wr * bi + wi * br;
// Write a + t to the lower slot and a - t to the upper slot.
threadgroup_store("re", tid, ar + tr);
threadgroup_store("im", tid, ai + ti);
threadgroup_store("re", tid + h, ar - tr);
threadgroup_store("im", tid + h, ai - ti);
}
// Barrier sits outside the `if`, so every thread reaches it once per stage.
threadgroup_barrier();
}
// Scale by `$inv_n` only when `inv != 0`; otherwise the result is left unscaled.
let scale = select(inv == 0u32, 1.0f32, $inv_n);
let res_re = threadgroup_load("re", tid) * scale;
let res_im = threadgroup_load("im", tid) * scale;
// Narrow back to the tensor element type and write in natural order.
store(out_re[base + tid], res_re.cast::<T>());
store(out_im[base + tid], res_im.cast::<T>());
}
// Submit a `BenchSpec` for the generated kernel via `inventory::submit!`.
inventory::submit! {
BenchSpec {
op: "fft",
subop: $subop,
kernel_name: stringify!($name),
kernel_ir: $name::kernel_ir_for,
dtypes: &[DType::F32, DType::F16, DType::BF16],
tol: 1e-3,
mlx_src: None,
mlx_pattern: None,
shapes: &[],
dispatch: BenchDispatch::Generic,
kernel_mode: Some(KernelMode::Reduction),
}
}
};
}
// Fixed-size instantiations for 32 to 1024 points.
// Arguments are (kernel name, n, log2(n), 1/n, subop label); in every line the third
// argument is log2 of the second and the fourth is its exact reciprocal.
fft_kernel!(mt_fft_n32, 32u32, 5u32, 0.031_25f32, "n32");
fft_kernel!(mt_fft_n64, 64u32, 6u32, 0.015_625f32, "n64");
fft_kernel!(mt_fft_n128, 128u32, 7u32, 0.007_812_5f32, "n128");
fft_kernel!(mt_fft_n256, 256u32, 8u32, 0.003_906_25f32, "n256");
fft_kernel!(mt_fft_n512, 512u32, 9u32, 0.001_953_125f32, "n512");
fft_kernel!(mt_fft_n1024, 1024u32, 10u32, 0.000_976_562_5f32, "n1024");
// Chirp-multiplies each input row and zero-pads it from `n_len` to `m_len` columns.
//
// For every output element (row, col) with col < n_len the kernel writes
//   x[row, col] * (cos(a) + i sin(a)),  a = sign * pi * col^2 / n_len,
// where sign is -1 when `inv == 0` and +1 otherwise (assuming `select(c, a, b)` yields
// `a` when `c` holds). Columns in n_len..m_len are written as zero.
//
// Inputs are read with row stride `n_len`; outputs are written with row stride `m_len`.
//
// NOTE(review): col^2 is formed in f32, which is exact only while it fits the 24-bit
// significand - confirm the expected range of `n_len`.
#[kernel]
// The parameter count is dictated by the kernel's tensors and constexpr sizes.
#[allow(clippy::too_many_arguments)]
pub fn mt_fft_bluestein_preprocess<T>(
    in_re: Tensor<T>,
    in_im: Tensor<T>,
    mut out_re: Tensor<T>,
    mut out_im: Tensor<T>,
    #[constexpr] n_len: u32,
    #[constexpr] m_len: u32,
    #[constexpr] rows: u32,
    #[constexpr] inv: u32,
) {
    let flat_idx = program_id::<0>();
    // Ignore programs beyond the rows * m_len padded elements.
    if flat_idx < rows * m_len {
        let row = flat_idx / m_len;
        let col = flat_idx % m_len;
        let dst = row * m_len + col;
        if col < n_len {
            // Real sample: rotate it by the chirp.
            let pi = 3.141592653589793f32;
            let direction = select(inv == 0u32, -1.0f32, 1.0f32);
            let col_f = col.cast::<f32>();
            let len_f = n_len.cast::<f32>();
            let theta = direction * pi * col_f * col_f / len_f;
            let chirp_re = cos(theta);
            let chirp_im = sin(theta);
            let src = row * n_len + col;
            let x_re = load(in_re[src]).cast::<f32>();
            let x_im = load(in_im[src]).cast::<f32>();
            // Complex product x * chirp.
            let prod_re = x_re * chirp_re - x_im * chirp_im;
            let prod_im = x_re * chirp_im + x_im * chirp_re;
            store(out_re[dst], prod_re.cast::<T>());
            store(out_im[dst], prod_im.cast::<T>());
        } else {
            // Padding region.
            store(out_re[dst], 0.0f32.cast::<T>());
            store(out_im[dst], 0.0f32.cast::<T>());
        }
    }
}
// Benchmark registration for `mt_fft_bluestein_preprocess`.
inventory::submit! {
    BenchSpec {
        // Identity.
        op: "fft",
        subop: "bluestein_preprocess",
        kernel_name: "mt_fft_bluestein_preprocess",
        // Kernel IR source and launch configuration.
        kernel_ir: mt_fft_bluestein_preprocess::kernel_ir_for,
        kernel_mode: Some(KernelMode::Grid3D),
        dispatch: BenchDispatch::Generic,
        // Element types, shapes and comparison tolerance.
        dtypes: &[DType::F32, DType::F16, DType::BF16],
        shapes: &[],
        tol: 1e-3,
        // No MLX counterpart is configured.
        mlx_src: None,
        mlx_pattern: None,
    }
}
// Fills the length-`m_len` chirp filter used by the Bluestein convolution.
//
// For index m the kernel writes cos(a) + i sin(a) with a = pi * t^2 / n_len, where
//   t = m          if m < n_len,
//   t = m_len - m  if m_len - m < n_len and m > 0 (the wrapped-around taps),
// and writes zero for every other index.
//
// Parameters:
//   filter_re / filter_im - output planes, indexed by m.
//   n_len                 - transform length.
//   m_len                 - padded convolution length.
//
// NOTE(review): t^2 is formed in f32, which is exact only while it fits the 24-bit
// significand - confirm the expected range of `n_len`.
#[kernel]
pub fn mt_fft_bluestein_chirp_filter(
    mut filter_re: Tensor<f32>,
    mut filter_im: Tensor<f32>,
    #[constexpr] n_len: u32,
    #[constexpr] m_len: u32,
) {
    let m = program_id::<0>();
    // Bounds guard, matching the other Bluestein kernels: programs at or beyond `m_len`
    // must not store, and skipping them also keeps `m_len - m` from underflowing.
    if m < m_len {
        let pi = 3.141592653589793f32;
        let m_minus = m_len - m;
        // Tap index: m in the leading segment, m_len - m in the trailing segment;
        // the `n_len` fallback is never used because `in_range` is false there.
        let n_tap = select(m < n_len, m, select(m_minus < n_len, m_minus, n_len));
        let in_range = (m < n_len) | ((m_minus < n_len) & (m > 0u32));
        if in_range {
            let n_f = n_tap.cast::<f32>();
            let n_len_f = n_len.cast::<f32>();
            let angle = pi * n_f * n_f / n_len_f;
            let wr = cos(angle);
            let wi = sin(angle);
            store(filter_re[m], wr);
            store(filter_im[m], wi);
        } else {
            // Gap between the two chirp segments stays zero.
            store(filter_re[m], 0.0f32);
            store(filter_im[m], 0.0f32);
        }
    }
}
// Adapter for the `BenchSpec::kernel_ir` slot: the chirp-filter kernel is not generic,
// so its `kernel_ir_for` takes no arguments and the requested dtype is ignored.
fn mt_fft_bluestein_chirp_filter_ir(_dtype: DType) -> metaltile_core::ir::Kernel {
    mt_fft_bluestein_chirp_filter::kernel_ir_for()
}
// Benchmark registration for `mt_fft_bluestein_chirp_filter` (f32 only).
inventory::submit! {
    BenchSpec {
        // Identity.
        op: "fft",
        subop: "bluestein_chirp_filter",
        kernel_name: "mt_fft_bluestein_chirp_filter",
        // Kernel IR source (via the dtype-ignoring adapter) and launch configuration.
        kernel_ir: mt_fft_bluestein_chirp_filter_ir,
        kernel_mode: Some(KernelMode::Grid3D),
        dispatch: BenchDispatch::Generic,
        // Element types, shapes and comparison tolerance.
        dtypes: &[DType::F32],
        shapes: &[],
        tol: 1e-3,
        // No MLX counterpart is configured.
        mlx_src: None,
        mlx_pattern: None,
    }
}
// Pointwise complex multiply of `rows` rows of length `m_len` by a shared filter.
//
// Element `idx` of (y_re, y_im) is multiplied by filter element `idx % m_len`, so the
// same f32 filter row is reused for every data row. The product is computed in f32 and
// stored back as `T` at the same flat index.
#[kernel]
// The parameter count is dictated by the kernel's tensors and constexpr sizes.
#[allow(clippy::too_many_arguments)]
pub fn mt_fft_bluestein_cmul<T>(
    y_re: Tensor<T>,
    y_im: Tensor<T>,
    filter_re: Tensor<f32>,
    filter_im: Tensor<f32>,
    mut out_re: Tensor<T>,
    mut out_im: Tensor<T>,
    #[constexpr] m_len: u32,
    #[constexpr] rows: u32,
) {
    let flat_idx = program_id::<0>();
    // Ignore programs beyond the rows * m_len elements.
    if flat_idx < rows * m_len {
        // Filter tap shared by every row at this column.
        let tap = flat_idx % m_len;
        let h_re = load(filter_re[tap]);
        let h_im = load(filter_im[tap]);
        let a_re = load(y_re[flat_idx]).cast::<f32>();
        let a_im = load(y_im[flat_idx]).cast::<f32>();
        // Complex product a * h.
        let prod_re = a_re * h_re - a_im * h_im;
        let prod_im = a_re * h_im + a_im * h_re;
        store(out_re[flat_idx], prod_re.cast::<T>());
        store(out_im[flat_idx], prod_im.cast::<T>());
    }
}
// Benchmark registration for `mt_fft_bluestein_cmul`.
inventory::submit! {
    BenchSpec {
        // Identity.
        op: "fft",
        subop: "bluestein_cmul",
        kernel_name: "mt_fft_bluestein_cmul",
        // Kernel IR source and launch configuration.
        kernel_ir: mt_fft_bluestein_cmul::kernel_ir_for,
        kernel_mode: Some(KernelMode::Grid3D),
        dispatch: BenchDispatch::Generic,
        // Element types, shapes and comparison tolerance.
        dtypes: &[DType::F32, DType::F16, DType::BF16],
        shapes: &[],
        tol: 1e-3,
        // No MLX counterpart is configured.
        mlx_src: None,
        mlx_pattern: None,
    }
}
// Extracts the first `n_len` columns of each convolved row, multiplies them by the
// chirp, and applies the optional 1/n_len scale.
//
// For output element (row, k) the kernel reads conv[row * m_len + k], multiplies by
//   cos(a) + i sin(a),  a = sign * pi * k^2 / n_len,
// and writes the result at row * n_len + k. With `inv == 0` the sign is -1 and the
// scale is 1; otherwise the sign is +1 and the scale is 1 / n_len (assuming
// `select(c, a, b)` yields `a` when `c` holds).
//
// NOTE(review): k^2 is formed in f32, which is exact only while it fits the 24-bit
// significand - confirm the expected range of `n_len`.
#[kernel]
// The parameter count is dictated by the kernel's tensors and constexpr sizes.
#[allow(clippy::too_many_arguments)]
pub fn mt_fft_bluestein_postprocess<T>(
    conv_re: Tensor<T>,
    conv_im: Tensor<T>,
    mut out_re: Tensor<T>,
    mut out_im: Tensor<T>,
    #[constexpr] n_len: u32,
    #[constexpr] m_len: u32,
    #[constexpr] rows: u32,
    #[constexpr] inv: u32,
) {
    let flat_idx = program_id::<0>();
    // Ignore programs beyond the rows * n_len output elements.
    if flat_idx < rows * n_len {
        let row = flat_idx / n_len;
        let bin = flat_idx % n_len;
        let pi = 3.141592653589793f32;
        let direction = select(inv == 0u32, -1.0f32, 1.0f32);
        let bin_f = bin.cast::<f32>();
        let len_f = n_len.cast::<f32>();
        let theta = direction * pi * bin_f * bin_f / len_f;
        let chirp_re = cos(theta);
        let chirp_im = sin(theta);
        // The convolution result is laid out with the padded row stride.
        let src = row * m_len + bin;
        let c_re = load(conv_re[src]).cast::<f32>();
        let c_im = load(conv_im[src]).cast::<f32>();
        // Complex product c * chirp.
        let prod_re = c_re * chirp_re - c_im * chirp_im;
        let prod_im = c_re * chirp_im + c_im * chirp_re;
        let gain = select(inv == 0u32, 1.0f32, 1.0f32 / len_f);
        store(out_re[flat_idx], (prod_re * gain).cast::<T>());
        store(out_im[flat_idx], (prod_im * gain).cast::<T>());
    }
}
// Benchmark registration for `mt_fft_bluestein_postprocess`.
inventory::submit! {
    BenchSpec {
        // Identity.
        op: "fft",
        subop: "bluestein_postprocess",
        kernel_name: "mt_fft_bluestein_postprocess",
        // Kernel IR source and launch configuration.
        kernel_ir: mt_fft_bluestein_postprocess::kernel_ir_for,
        kernel_mode: Some(KernelMode::Grid3D),
        dispatch: BenchDispatch::Generic,
        // Element types, shapes and comparison tolerance.
        dtypes: &[DType::F32, DType::F16, DType::BF16],
        shapes: &[],
        tol: 1e-3,
        // No MLX counterpart is configured.
        mlx_src: None,
        mlx_pattern: None,
    }
}