use metaltile::{bench_kernel, kernel};
// FP4 quantize -> dequantize round trip, one element per program instance.
//
// Each element's magnitude is rescaled so that the `simd_max` of the
// magnitudes maps to 6.0, snapped to one of the eight levels
// {0, 0.5, 1, 1.5, 2, 3, 4, 6}, then scaled back and re-signed.
// NOTE(review): `simd_max` is presumably a reduction over the SIMD group the
// thread belongs to - confirm the group width against the `gs_16` in the
// `mlx` reference name.
//
// `inp` is read at the program id, `out` is written at the same index.
// `n` is declared as a constexpr but is not read inside the body.
#[bench_kernel(
    op = "fp_quantized",
    subop = "fp4_quant_dequant",
    class = FpQuantized,
    n = 1048576,
    tpg = 32,
    tol = 0.5,
    mlx = "nvfp4_quantize_dequantize_float_gs_16_b_4",
    metal_file = "fp_quantized.metal",
    dtypes = crate::spec::F32_ONLY,
)]
#[kernel]
pub fn mt_fp4_quant_dequant(inp: Tensor<f32>, out: Tensor<f32>, #[constexpr] n: u32) {
    let idx = program_id::<0>();
    let value = load(inp[idx]);

    // Sign is handled separately; quantization works on the magnitude only.
    let signum = select(value < 0.0f32, -1.0f32, 1.0f32);
    let magnitude = abs(value);

    // An all-zero group gets a zero scale so no division by zero occurs.
    let peak = simd_max(magnitude);
    let to_grid = select(peak > 0.0f32, 6.0f32 / peak, 0.0f32);
    let scaled = magnitude * to_grid;

    // Level selection, built from the largest level downwards. Every
    // threshold is the midpoint between two neighbouring levels; a value
    // exactly on a threshold takes the larger level because the tests are
    // strict `<`.
    let lvl_ge_3_5 = select(scaled < 5.0f32, 4.0f32, 6.0f32);
    let lvl_ge_2_5 = select(scaled < 3.5f32, 3.0f32, lvl_ge_3_5);
    let lvl_ge_1_75 = select(scaled < 2.5f32, 2.0f32, lvl_ge_2_5);
    let lvl_ge_1_25 = select(scaled < 1.75f32, 1.5f32, lvl_ge_1_75);
    let lvl_ge_0_75 = select(scaled < 1.25f32, 1.0f32, lvl_ge_1_25);
    let lvl_ge_0_25 = select(scaled < 0.75f32, 0.5f32, lvl_ge_0_75);
    let level = select(scaled < 0.25f32, 0.0f32, lvl_ge_0_25);

    // Undo the scaling and restore the sign.
    let from_grid = peak / 6.0f32;
    let restored = signum * level * from_grid;
    store(out[idx], restored);
}
// Generates a benchmarked quantize -> dequantize kernel for an FP8-like
// format.
//
// Arguments, in order:
//   $name    identifier of the generated kernel function
//   $subop   string literal forwarded as `subop` to `bench_kernel`
//   $mant    subtracted from the clamped exponent to get the rounding
//            quantum, i.e. quantum = 2^(e - $mant)
//   $emin    lower clamp applied to floor(log2(scaled magnitude))
//   $emax    upper clamp applied to the same exponent
//   $fp8max  value the group maximum is scaled to; also the cap applied to
//            the quantized magnitude
//
// The generated kernel reads `inp` at the program id and writes `out` at the
// same index. `n` is declared as a constexpr but is not read inside the body.
macro_rules! fp8_kernel {
    ($name:ident, $subop:literal, $mant:literal, $emin:literal, $emax:literal, $fp8max:literal) => {
        #[bench_kernel(op = "fp_quantized", subop = $subop, class = FpQuantized, n = 1048576, tpg = 32, tol = 0.05, dtypes = crate::spec::F32_ONLY)]
        #[kernel]
        pub fn $name(inp: Tensor<f32>, out: Tensor<f32>, #[constexpr] n: u32) {
            let idx = program_id::<0>();
            let value = load(inp[idx]);

            // Sign is re-applied at the end; everything else uses |value|.
            let signum = select(value < 0.0f32, -1.0f32, 1.0f32);
            let magnitude = abs(value);

            // Scale so the `simd_max` of the magnitudes lands on $fp8max;
            // a zero maximum yields a zero scale instead of dividing by zero.
            let peak = simd_max(magnitude);
            let to_grid = select(peak > 0.0f32, $fp8max / peak, 0.0f32);
            let scaled = magnitude * to_grid;

            // Exponent of the scaled magnitude, clamped into [$emin, $emax].
            let exponent_raw = floor(log2(scaled));
            let exponent_floor = select(exponent_raw < $emin, $emin, exponent_raw);
            let exponent = select(exponent_floor > $emax, $emax, exponent_floor);

            // Snap to the nearest multiple of the quantum for that exponent.
            let ulp = exp2(exponent - $mant);
            let on_grid = round(scaled / ulp) * ulp;

            // A zero input is forced to zero rather than going through the
            // log2 path, then the result is capped at $fp8max.
            let level = select(scaled > 0.0f32, on_grid, 0.0f32);
            let capped = select(level > $fp8max, $fp8max, level);

            // Undo the scaling and restore the sign.
            let from_grid = peak / $fp8max;
            let restored = signum * capped * from_grid;
            store(out[idx], restored);
        }
    };
}
// Instantiations of `fp8_kernel!`. Positional arguments after the kernel name
// and subop label are: $mant, $emin, $emax, $fp8max.
//
// "fp8_e4m3": $mant = 3, exponent clamped to [-6, 8], maximum 448 (= 1.75 * 2^8).
// NOTE(review): these look like the E4M3 format parameters - confirm against the reference.
fp8_kernel!(mt_fp8_e4m3_quant_dequant, "fp8_e4m3", 3.0f32, -6.0f32, 8.0f32, 448.0f32);
// "fp8_e5m2": $mant = 2, exponent clamped to [-14, 15], maximum 57344 (= 1.75 * 2^15).
// NOTE(review): these look like the E5M2 format parameters - confirm against the reference.
fp8_kernel!(mt_fp8_e5m2_quant_dequant, "fp8_e5m2", 2.0f32, -14.0f32, 15.0f32, 57344.0f32);