// mlx_native/ops/hadamard.rs
1//! Fast Walsh-Hadamard Transform (FWHT) GPU kernel dispatch.
2//!
3//! Applies an in-place, normalized FWHT to a flat buffer shaped
4//! `[num_heads, head_dim]`. One Metal threadgroup is dispatched per head;
5//! each threadgroup has `head_dim` threads that cooperate through shared
6//! memory using the standard butterfly pattern.
7//!
8//! The transform is normalized so that H·H = I (applying it twice returns
9//! the original vector), which is required for the random-feature / scrambled
10//! Hadamard use-case in Gemma-4 attention.
11
12use metal::MTLSize;
13
14use crate::buffer::MlxBuffer;
15use crate::encoder::CommandEncoder;
16use crate::error::{MlxError, Result};
17use crate::kernel_registry::KernelRegistry;
18
19use super::encode_helpers::{encode_threadgroups_with_args_and_shared, KernelArg};
20
/// MSL source for the Hadamard transform kernel (embedded at compile time).
///
/// The text is registered under the name `hadamard_transform` (see [`register`])
/// and compiled into a compute pipeline on demand through the kernel registry.
pub static HADAMARD_SHADER_SOURCE: &str = include_str!("../shaders/hadamard.metal");
23
24/// Register the Hadamard transform shader source with the given kernel registry.
25pub fn register(registry: &mut KernelRegistry) {
26 registry.register_source("hadamard_transform", HADAMARD_SHADER_SOURCE);
27}
28
29/// Dispatch an in-place normalized Fast Walsh-Hadamard Transform on the GPU.
30///
31/// Transforms `data` in-place. After this call the GPU contains
32/// `H(data)` normalized by `1/sqrt(head_dim)`.
33///
34/// # Arguments
35///
36/// * `encoder` — Command encoder to record the dispatch into.
37/// * `registry` — Kernel registry (must have `hadamard_transform` registered).
38/// * `device` — Metal device for pipeline compilation.
39/// * `data` — F32 buffer of shape `[num_heads, head_dim]`, modified in-place.
40/// * `head_dim` — Number of elements per head. Must be a power of two and ≤ 8192.
41/// * `num_heads` — Number of heads (threadgroups dispatched).
42///
43/// # Errors
44///
45/// Returns `MlxError::InvalidArgument` if `head_dim` is not a power of two,
46/// if the buffer is too small, or if `head_dim > 8192` (exceeds Metal's 32 KB
47/// threadgroup memory limit for f32).
48pub fn dispatch_hadamard_transform(
49 encoder: &mut CommandEncoder,
50 registry: &mut KernelRegistry,
51 device: &metal::DeviceRef,
52 data: &MlxBuffer,
53 head_dim: u32,
54 num_heads: u32,
55) -> Result<()> {
56 if num_heads == 0 || head_dim == 0 {
57 return Ok(());
58 }
59
60 // head_dim must be a power of two (butterfly pattern requirement).
61 if !head_dim.is_power_of_two() {
62 return Err(MlxError::InvalidArgument(format!(
63 "hadamard_transform: head_dim must be a power of two, got {}",
64 head_dim
65 )));
66 }
67
68 // 32 KB threadgroup memory limit: head_dim * 4 bytes ≤ 32768 bytes → head_dim ≤ 8192
69 if head_dim > 8192 {
70 return Err(MlxError::InvalidArgument(format!(
71 "hadamard_transform: head_dim {} exceeds Metal 32 KB threadgroup memory limit \
72 (max 8192 for f32)",
73 head_dim
74 )));
75 }
76
77 let required_elements = (num_heads as u64) * (head_dim as u64);
78 if (data.element_count() as u64) < required_elements {
79 return Err(MlxError::InvalidArgument(format!(
80 "hadamard_transform: data has {} elements but need {} \
81 (num_heads={} * head_dim={})",
82 data.element_count(),
83 required_elements,
84 num_heads,
85 head_dim,
86 )));
87 }
88
89 let pipeline = registry.get_pipeline("hadamard_transform", device)?;
90
91 let head_dim_bytes = head_dim.to_ne_bytes();
92 let num_heads_bytes = num_heads.to_ne_bytes();
93
94 // Shared memory: head_dim floats (4 bytes each) per threadgroup.
95 let shared_mem_bytes = (head_dim as u64) * 4;
96
97 encode_threadgroups_with_args_and_shared(
98 encoder,
99 pipeline,
100 &[
101 (0, KernelArg::Buffer(data)),
102 (1, KernelArg::Bytes(&head_dim_bytes)),
103 (2, KernelArg::Bytes(&num_heads_bytes)),
104 ],
105 &[(0, shared_mem_bytes)],
106 MTLSize::new(num_heads as u64, 1, 1),
107 MTLSize::new(head_dim as u64, 1, 1),
108 );
109
110 Ok(())
111}