use metaltile::kernel;
use metaltile_core::{dtype::DType, ir::Kernel};
use crate::bench_types::DType as BenchDType;
/// Test-only copy of the 12 bit-packed sign rows used by `mt_hadamard_m12`.
///
/// Entry `i` encodes row `i`: bit `k` set means `+1` in column `k`, bit clear
/// means `-1` (this is how `h12_is_orthogonal` decodes it). The same twelve
/// literals are written into the kernel's `signs` stack array, so the two
/// lists must be edited together.
#[cfg(test)]
const H12_SIGNS: [u32; 12] = [4093, 1364, 3127, 1681, 223, 2629, 883, 2329, 3523, 1129, 1807, 421];
/// Test-only copy of the 20 bit-packed sign rows used by `mt_hadamard_m20`.
///
/// Entry `i` encodes row `i`: bit `k` set means `+1` in column `k`, bit clear
/// means `-1` (this is how `h20_is_orthogonal` decodes it). The same twenty
/// literals are written into the kernel's `signs` stack array, so the two
/// lists must be edited together.
#[cfg(test)]
const H20_SIGNS: [u32; 20] = [
445473, 859202, 702596, 389384, 747024, 641086, 234589, 469147, 938263, 828943, 984492, 953176,
889521, 762211, 508614, 34194, 68357, 135722, 270452, 540873,
];
/// Test-only copy of the 28 bit-packed sign rows used by `mt_hadamard_m28`.
///
/// Entry `i` encodes row `i`: bit `k` set means `+1` in column `k`, bit clear
/// means `-1` (this is how `h28_is_orthogonal` decodes it). The same
/// twenty-eight literals are written into the kernel's `signs` stack array,
/// so the two lists must be edited together.
#[cfg(test)]
const H28_SIGNS: [u32; 28] = [
53043585, 106070914, 210061060, 153783816, 41229328, 80377888, 160739520, 79265980, 156451192,
44483185, 88966243, 177932359, 87445519, 172810270, 125848794, 251697461, 237056618, 207758549,
149162411, 31986518, 63972909, 3206502, 4315853, 8631579, 17246902, 34477548, 68954969,
135812787,
];
// Order-12 sign-matrix transform, written in the `#[kernel]` DSL.
//
// For lane `t` of threadgroup `row` the body computes
//     out[row * 12 + t] = scale * sum_j s(t, j) * inp[row * 12 + j]
// where `s(t, j)` is +1 when bit `j` of `signs[t]` is set and -1 otherwise.
// Accumulation is done in f32 regardless of `T`; the result is cast back to `T`.
//
// Parameters:
//   inp   - input tensor, read at index `row * 12 + t`.
//   out   - output tensor, written at the same index.
//   scale - constexpr f32 multiplier applied to the accumulated sum.
//
// NOTE(review): there is no bounds guard on `t`; this appears to assume each
// threadgroup runs exactly 12 lanes with `simd_lane` in 0..12 - confirm
// against the dispatch code.
#[kernel]
pub fn mt_hadamard_m12<T>(inp: Tensor<T>, mut out: Tensor<T>, #[constexpr] scale: f32) {
// 12-slot f32 scratch buffer "buf" (presumably threadgroup memory, going by
// the intrinsic's name) holding this row's inputs.
threadgroup_alloc("buf", 12u32, f32);
// Per-thread table of 12 bit-packed sign rows. The literals below duplicate
// the test-only `H12_SIGNS` constant; the tests validate that constant, not
// these stores, so keep the two lists in sync by hand.
stack_alloc("signs", 12u32, "u32");
stack_store("signs", 0u32, 4093u32);
stack_store("signs", 1u32, 1364u32);
stack_store("signs", 2u32, 3127u32);
stack_store("signs", 3u32, 1681u32);
stack_store("signs", 4u32, 223u32);
stack_store("signs", 5u32, 2629u32);
stack_store("signs", 6u32, 883u32);
stack_store("signs", 7u32, 2329u32);
stack_store("signs", 8u32, 3523u32);
stack_store("signs", 9u32, 1129u32);
stack_store("signs", 10u32, 1807u32);
stack_store("signs", 11u32, 421u32);
// `t` selects the output element within the row; `row` selects the row.
let t = simd_lane;
let row = tgid_x;
let base = row * 12u32;
// Flat element index shared by the input read and the output write.
let tg = base + t;
// Stage this lane's input element (widened to f32) into the scratch buffer.
let inp_f = load(inp[tg]).cast::<f32>();
threadgroup_store("buf", t, inp_f);
// Barrier between the staging stores and the loads in the loop below.
threadgroup_barrier();
// Packed sign row for this lane's output element.
let signs_t = stack_load("signs", t);
let mut acc = 0.0f32;
// Presumably iterates j = 0, 1, ..., 11 (start, end, step) - one term per
// staged input element.
for j in range(0u32, 12u32, 1u32) {
// Extract bit j as 0.0 or 1.0 ...
let bit = ((signs_t >> j) & 1u32).cast::<f32>();
// ... and map it to -1.0 / +1.0 before weighting the staged input.
let sign = 2.0f32 * bit - 1.0f32; let buf_j = threadgroup_load("buf", j);
acc = acc + sign * buf_j;
}
let scaled = acc * scale;
store(out[tg], scaled.cast::<T>());
}
// Order-20 sign-matrix transform, written in the `#[kernel]` DSL.
//
// For lane `t` of threadgroup `row` the body computes
//     out[row * 20 + t] = scale * sum_j s(t, j) * inp[row * 20 + j]
// where `s(t, j)` is +1 when bit `j` of `signs[t]` is set and -1 otherwise.
// Accumulation is done in f32 regardless of `T`; the result is cast back to `T`.
//
// Parameters:
//   inp   - input tensor, read at index `row * 20 + t`.
//   out   - output tensor, written at the same index.
//   scale - constexpr f32 multiplier applied to the accumulated sum.
//
// NOTE(review): there is no bounds guard on `t`; this appears to assume each
// threadgroup runs exactly 20 lanes with `simd_lane` in 0..20 - confirm
// against the dispatch code.
#[kernel]
pub fn mt_hadamard_m20<T>(inp: Tensor<T>, mut out: Tensor<T>, #[constexpr] scale: f32) {
// 20-slot f32 scratch buffer "buf" (presumably threadgroup memory, going by
// the intrinsic's name) holding this row's inputs.
threadgroup_alloc("buf", 20u32, f32);
// Per-thread table of 20 bit-packed sign rows. The literals below duplicate
// the test-only `H20_SIGNS` constant; the tests validate that constant, not
// these stores, so keep the two lists in sync by hand.
stack_alloc("signs", 20u32, "u32");
stack_store("signs", 0u32, 445473u32);
stack_store("signs", 1u32, 859202u32);
stack_store("signs", 2u32, 702596u32);
stack_store("signs", 3u32, 389384u32);
stack_store("signs", 4u32, 747024u32);
stack_store("signs", 5u32, 641086u32);
stack_store("signs", 6u32, 234589u32);
stack_store("signs", 7u32, 469147u32);
stack_store("signs", 8u32, 938263u32);
stack_store("signs", 9u32, 828943u32);
stack_store("signs", 10u32, 984492u32);
stack_store("signs", 11u32, 953176u32);
stack_store("signs", 12u32, 889521u32);
stack_store("signs", 13u32, 762211u32);
stack_store("signs", 14u32, 508614u32);
stack_store("signs", 15u32, 34194u32);
stack_store("signs", 16u32, 68357u32);
stack_store("signs", 17u32, 135722u32);
stack_store("signs", 18u32, 270452u32);
stack_store("signs", 19u32, 540873u32);
// `t` selects the output element within the row; `row` selects the row.
let t = simd_lane;
let row = tgid_x;
let base = row * 20u32;
// Flat element index shared by the input read and the output write.
let tg = base + t;
// Stage this lane's input element (widened to f32) into the scratch buffer.
let inp_f = load(inp[tg]).cast::<f32>();
threadgroup_store("buf", t, inp_f);
// Barrier between the staging stores and the loads in the loop below.
threadgroup_barrier();
// Packed sign row for this lane's output element.
let signs_t = stack_load("signs", t);
let mut acc = 0.0f32;
// Presumably iterates j = 0, 1, ..., 19 (start, end, step) - one term per
// staged input element.
for j in range(0u32, 20u32, 1u32) {
// Extract bit j as 0.0 or 1.0 ...
let bit = ((signs_t >> j) & 1u32).cast::<f32>();
// ... and map it to -1.0 / +1.0.
let sign = 2.0f32 * bit - 1.0f32;
let buf_j = threadgroup_load("buf", j);
acc = acc + sign * buf_j;
}
let scaled = acc * scale;
store(out[tg], scaled.cast::<T>());
}
// Order-28 sign-matrix transform, written in the `#[kernel]` DSL.
//
// For lane `t` of threadgroup `row` the body computes
//     out[row * 28 + t] = scale * sum_j s(t, j) * inp[row * 28 + j]
// where `s(t, j)` is +1 when bit `j` of `signs[t]` is set and -1 otherwise.
// Accumulation is done in f32 regardless of `T`; the result is cast back to `T`.
//
// Parameters:
//   inp   - input tensor, read at index `row * 28 + t`.
//   out   - output tensor, written at the same index.
//   scale - constexpr f32 multiplier applied to the accumulated sum.
//
// NOTE(review): there is no bounds guard on `t`; this appears to assume each
// threadgroup runs exactly 28 lanes with `simd_lane` in 0..28 - confirm
// against the dispatch code.
#[kernel]
pub fn mt_hadamard_m28<T>(inp: Tensor<T>, mut out: Tensor<T>, #[constexpr] scale: f32) {
// 28-slot f32 scratch buffer "buf" (presumably threadgroup memory, going by
// the intrinsic's name) holding this row's inputs.
threadgroup_alloc("buf", 28u32, f32);
// Per-thread table of 28 bit-packed sign rows. The literals below duplicate
// the test-only `H28_SIGNS` constant; the tests validate that constant, not
// these stores, so keep the two lists in sync by hand.
stack_alloc("signs", 28u32, "u32");
stack_store("signs", 0u32, 53043585u32);
stack_store("signs", 1u32, 106070914u32);
stack_store("signs", 2u32, 210061060u32);
stack_store("signs", 3u32, 153783816u32);
stack_store("signs", 4u32, 41229328u32);
stack_store("signs", 5u32, 80377888u32);
stack_store("signs", 6u32, 160739520u32);
stack_store("signs", 7u32, 79265980u32);
stack_store("signs", 8u32, 156451192u32);
stack_store("signs", 9u32, 44483185u32);
stack_store("signs", 10u32, 88966243u32);
stack_store("signs", 11u32, 177932359u32);
stack_store("signs", 12u32, 87445519u32);
stack_store("signs", 13u32, 172810270u32);
stack_store("signs", 14u32, 125848794u32);
stack_store("signs", 15u32, 251697461u32);
stack_store("signs", 16u32, 237056618u32);
stack_store("signs", 17u32, 207758549u32);
stack_store("signs", 18u32, 149162411u32);
stack_store("signs", 19u32, 31986518u32);
stack_store("signs", 20u32, 63972909u32);
stack_store("signs", 21u32, 3206502u32);
stack_store("signs", 22u32, 4315853u32);
stack_store("signs", 23u32, 8631579u32);
stack_store("signs", 24u32, 17246902u32);
stack_store("signs", 25u32, 34477548u32);
stack_store("signs", 26u32, 68954969u32);
stack_store("signs", 27u32, 135812787u32);
// `t` selects the output element within the row; `row` selects the row.
let t = simd_lane;
let row = tgid_x;
let base = row * 28u32;
// Flat element index shared by the input read and the output write.
let tg = base + t;
// Stage this lane's input element (widened to f32) into the scratch buffer.
let inp_f = load(inp[tg]).cast::<f32>();
threadgroup_store("buf", t, inp_f);
// Barrier between the staging stores and the loads in the loop below.
threadgroup_barrier();
// Packed sign row for this lane's output element.
let signs_t = stack_load("signs", t);
let mut acc = 0.0f32;
// Presumably iterates j = 0, 1, ..., 27 (start, end, step) - one term per
// staged input element.
for j in range(0u32, 28u32, 1u32) {
// Extract bit j as 0.0 or 1.0 ...
let bit = ((signs_t >> j) & 1u32).cast::<f32>();
// ... and map it to -1.0 / +1.0.
let sign = 2.0f32 * bit - 1.0f32;
let buf_j = threadgroup_load("buf", j);
acc = acc + sign * buf_j;
}
let scaled = acc * scale;
store(out[tg], scaled.cast::<T>());
}
/// Builds the kernel IR for the order-`m` transform specialised to dtype `dt`.
///
/// Both arguments are validated up front, then the call is forwarded to the
/// matching `mt_hadamard_m{12,20,28}::kernel_ir_for`.
///
/// # Panics
///
/// * if `m` is not one of 12, 20 or 28 (checked first);
/// * if `dt` is not `F32`, `F16` or `BF16`.
pub fn kernel_ir_for(m: u32, dt: DType) -> Kernel {
    let size_supported = [12u32, 20, 28].contains(&m);
    assert!(size_supported, "mt_hadamard_m only supports M ∈ {{12, 20, 28}}, got {m}");

    let dtype_supported = matches!(dt, DType::F32 | DType::F16 | DType::BF16);
    assert!(dtype_supported, "mt_hadamard_m only supports F32/F16/BF16, got {dt:?}");

    match m {
        28 => mt_hadamard_m28::kernel_ir_for(dt),
        20 => mt_hadamard_m20::kernel_ir_for(dt),
        12 => mt_hadamard_m12::kernel_ir_for(dt),
        // The size assertion above has already rejected every other value.
        _ => unreachable!(),
    }
}
// Anonymous constant listing the bench dtypes F32, F16 and BF16. It is never
// read; presumably it exists to keep the `BenchDType` import referenced and to
// mirror the dtypes accepted by `kernel_ir_for` - confirm before removing.
const _: &[BenchDType] = &[BenchDType::F32, BenchDType::F16, BenchDType::BF16];
#[cfg(test)]
mod tests {
    use metaltile_codegen::msl::MslGenerator;
    use metaltile_core::ir::Op;

    use super::*;

    /// Decodes column `col` of a bit-packed sign row: set bit → `+1`, clear bit → `-1`.
    fn sign_at(row: u32, col: usize) -> i32 {
        if (row >> col) & 1 == 1 {
            1
        } else {
            -1
        }
    }

    /// Asserts that the square ±1 matrix encoded by `rows` has mutually
    /// orthogonal rows: every row dotted with itself equals the matrix order,
    /// and every pair of distinct rows dots to zero.
    ///
    /// `label` only affects the failure message (e.g. `"H12"`).
    fn assert_rows_orthogonal(label: &str, rows: &[u32]) {
        let order = rows.len();
        let diagonal = i32::try_from(order).expect("matrix order fits in i32");
        for (i, &row_i) in rows.iter().enumerate() {
            for (j, &row_j) in rows.iter().enumerate() {
                let dot: i32 = (0..order).map(|k| sign_at(row_i, k) * sign_at(row_j, k)).sum();
                let expected = if i == j { diagonal } else { 0 };
                assert_eq!(dot, expected, "{label}[{i}]·{label}[{j}] = {dot}, expected {expected}");
            }
        }
    }

    /// Every supported (M, dtype) pair yields IR with the expected name,
    /// parameters and constexpr, uses the stack ops, and contains no inline MSL.
    #[test]
    fn kernel_ir_constructs_for_all_m_and_dtypes() {
        for m in [12u32, 20, 28] {
            for dt in [DType::F32, DType::F16, DType::BF16] {
                let k = kernel_ir_for(m, dt);
                assert_eq!(k.name, format!("mt_hadamard_m{m}"));

                assert_eq!(k.params.len(), 2);
                assert_eq!(k.params[0].name, "inp");
                assert!(!k.params[0].is_output);
                assert_eq!(k.params[1].name, "out");
                assert!(k.params[1].is_output);

                assert_eq!(k.constexprs.len(), 1);
                assert_eq!(k.constexprs[0].name.name(), "scale");

                // Walk the ops of the top-level body followed by every nested block.
                let all_ops =
                    || std::iter::once(&k.body).chain(k.blocks.values()).flat_map(|b| b.ops.iter());
                assert!(!all_ops().any(|op| matches!(op, Op::InlineMsl { .. })));
                assert!(all_ops().any(|op| matches!(op, Op::StackAlloc { .. })));
                assert!(all_ops().any(|op| matches!(op, Op::StackLoad { .. })));
            }
        }
    }

    /// An unsupported matrix order must trip the size assertion.
    #[test]
    #[should_panic(expected = "only supports M")]
    fn kernel_ir_rejects_invalid_m() {
        let _ = kernel_ir_for(16, DType::F32);
    }

    /// MSL generation succeeds for every supported (M, dtype) pair and emits
    /// a kernel declaration carrying the renamed, dtype-suffixed kernel name.
    #[test]
    fn codegen_emits_kernel_decl() {
        for m in [12u32, 20, 28] {
            for dt in [DType::F32, DType::F16, DType::BF16] {
                let mut k = kernel_ir_for(m, dt);
                let suffix = match dt {
                    DType::F32 => "f32",
                    DType::F16 => "f16",
                    DType::BF16 => "bf16",
                    // Only the three dtypes listed in the loop above reach here.
                    _ => unreachable!(),
                };
                k.name = format!("mt_hadamard_m{m}_{suffix}");
                let msl = MslGenerator::default().generate(&k).expect("codegen");
                assert!(msl.contains(&format!("kernel void mt_hadamard_m{m}_{suffix}")));
                assert!(!msl.contains("InlineMsl"));
            }
        }
    }

    #[test]
    fn h12_is_orthogonal() {
        assert_rows_orthogonal("H12", &H12_SIGNS);
    }

    #[test]
    fn h20_is_orthogonal() {
        assert_rows_orthogonal("H20", &H20_SIGNS);
    }

    #[test]
    fn h28_is_orthogonal() {
        assert_rows_orthogonal("H28", &H28_SIGNS);
    }
}