1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
//! Copyright 2026 0xClandestine, Ekryski, TheTom, Ambisphaeric
//! SPDX-License-Identifier: Apache-2.0
//! Walsh–Hadamard transform along the last axis (size N = 2^k) —
//! port of MLX's `hadamard_n`.
//!
//! Computes `y = H_N · x` where `H_N` is the order-N Hadamard matrix,
//! then scales by `scale`. Used by the Walsh–Hadamard quantization /
//! rotation path (relevant to AURA's rotation matrix).
//!
//! Expressed as the fast Walsh–Hadamard transform: `log2(N)` in-place
//! butterfly passes over a threadgroup buffer. The MLX kernel uses a
//! radix-decomposed multi-step form for register efficiency; this port
//! keeps the plain butterfly — the codegen handles the rest, and one
//! threadgroup per row covers any `N ≤ 1024`. The non-power-of-2
//! `hadamard_m` factor (M ∈ {12,20,28}) is a follow-up.
//!
//! ## DISPATCH INVARIANTS
//!
//! - **Reduction mode**, `grid = [rows, 1, 1]`, `tg = [N, 1, 1]`.
//! - `N` a power of two, `32 ≤ N ≤ 1024`; one thread per element.
//!
//! Codegen-only; correctness pinned by
//! `tests/hadamard_gpu_correctness.rs`.
use metaltile::kernel;
use metaltile_core::ir::KernelMode;
use crate::{
bench_types::DType,
spec::{BenchDispatch, BenchSpec},
};
/// Generates one fixed-size Walsh–Hadamard kernel and submits a matching
/// `BenchSpec` through `inventory`.
///
/// Arguments, in order:
/// - `$name`  — identifier of the generated kernel function (also
///   stringified into `kernel_name`).
/// - `$n`     — row length `N` as a `u32` literal; used for the row stride
///   and the threadgroup buffer length.
/// - `$log_n` — number of butterfly passes as a `u32` literal; callers are
///   expected to pass `log2(N)`, the macro does not check it.
/// - `$subop` — string literal stored in the spec's `subop` field.
macro_rules! hadamard_kernel {
    ($name:ident, $n:literal, $log_n:literal, $subop:literal) => {
        #[kernel]
        pub fn $name<T>(inp: Tensor<T>, out: Tensor<T>, #[constexpr] scale: f32) {
            // Each program instance owns one row of `$n` contiguous elements;
            // `elem` is this thread's element within that row.
            let row_id = program_id::<0>();
            let row_start = row_id * $n;
            let elem = row_start + tid;

            // Stage the row as f32 in the threadgroup buffer, one element
            // per thread, and wait until every slot has been written.
            threadgroup_alloc("buf", $n, "f32");
            let staged = load(inp[elem]).cast::<f32>();
            threadgroup_store("buf", tid, staged);
            threadgroup_barrier();

            // `$log_n` butterfly passes; the pair distance `gap` doubles on
            // every pass (1, 2, 4, ...).
            for level in range(0u32, $log_n, 1u32) {
                let gap = 1u32 << level;
                // Only the thread whose `gap` bit is clear works on a pair;
                // it rewrites both its own slot and the slot `gap` above it,
                // so no two active threads touch the same slot in a pass.
                if (tid & gap) == 0u32 {
                    let peer = tid + gap;
                    let upper = threadgroup_load("buf", tid);
                    let lower = threadgroup_load("buf", peer);
                    threadgroup_store("buf", tid, upper + lower);
                    threadgroup_store("buf", peer, upper - lower);
                }
                // Executed by all threads, outside the branch, so the next
                // pass sees every result of this one.
                threadgroup_barrier();
            }

            // Apply `scale` in f32, then convert back to the tensor dtype.
            store(out[elem], (threadgroup_load("buf", tid) * scale).cast::<T>());
        }

        // Bench-harness entry for the kernel generated above.
        inventory::submit! {
            BenchSpec {
                op: "hadamard",
                subop: $subop,
                kernel_name: stringify!($name),
                kernel_ir: $name::kernel_ir_for,
                dtypes: &[DType::F32, DType::F16, DType::BF16],
                tol: 1e-3,
                mlx_src: None,
                mlx_pattern: None,
                shapes: &[],
                dispatch: BenchDispatch::Generic,
                kernel_mode: Some(KernelMode::Reduction),
            }
        }
    };
}
// Kernel instantiations, one per supported row length.
// Arguments: (kernel name, N, number of butterfly passes = log2(N), bench subop).
// The second and third arguments must stay in sync: 2^passes == N.
hadamard_kernel!(mt_hadamard_n64, 64u32, 6u32, "n64");
hadamard_kernel!(mt_hadamard_n128, 128u32, 7u32, "n128");
hadamard_kernel!(mt_hadamard_n256, 256u32, 8u32, "n256");
hadamard_kernel!(mt_hadamard_n512, 512u32, 9u32, "n512");
hadamard_kernel!(mt_hadamard_n1024, 1024u32, 10u32, "n1024");