// mlx_native/ops/hadamard.rs

1//! Fast Walsh-Hadamard Transform (FWHT) GPU kernel dispatch.
2//!
3//! Applies an in-place, normalized FWHT to a flat buffer shaped
4//! `[num_heads, head_dim]`.  One Metal threadgroup is dispatched per head;
5//! each threadgroup has `head_dim` threads that cooperate through shared
6//! memory using the standard butterfly pattern.
7//!
8//! The transform is normalized so that H·H = I (applying it twice returns
9//! the original vector), which is required for the random-feature / scrambled
10//! Hadamard use-case in Gemma-4 attention.
11
12use metal::MTLSize;
13
14use crate::buffer::MlxBuffer;
15use crate::encoder::CommandEncoder;
16use crate::error::{MlxError, Result};
17use crate::kernel_registry::KernelRegistry;
18
19use super::encode_helpers::{encode_threadgroups_with_args_and_shared, KernelArg};
20
/// MSL source for the Hadamard transform kernel (embedded at compile time).
///
/// The text is baked into the binary via `include_str!` and handed to the
/// [`KernelRegistry`] by [`register`]; no file I/O happens at runtime.
pub static HADAMARD_SHADER_SOURCE: &str = include_str!("../shaders/hadamard.metal");
24/// Register the Hadamard transform shader source with the given kernel registry.
25pub fn register(registry: &mut KernelRegistry) {
26    registry.register_source("hadamard_transform", HADAMARD_SHADER_SOURCE);
27}
28
29/// Dispatch an in-place normalized Fast Walsh-Hadamard Transform on the GPU.
30///
31/// Transforms `data` in-place.  After this call the GPU contains
32/// `H(data)` normalized by `1/sqrt(head_dim)`.
33///
34/// # Arguments
35///
36/// * `encoder`   — Command encoder to record the dispatch into.
37/// * `registry`  — Kernel registry (must have `hadamard_transform` registered).
38/// * `device`    — Metal device for pipeline compilation.
39/// * `data`      — F32 buffer of shape `[num_heads, head_dim]`, modified in-place.
40/// * `head_dim`  — Number of elements per head.  Must be a power of two and ≤ 8192.
41/// * `num_heads` — Number of heads (threadgroups dispatched).
42///
43/// # Errors
44///
45/// Returns `MlxError::InvalidArgument` if `head_dim` is not a power of two,
46/// if the buffer is too small, or if `head_dim > 8192` (exceeds Metal's 32 KB
47/// threadgroup memory limit for f32).
48pub fn dispatch_hadamard_transform(
49    encoder: &mut CommandEncoder,
50    registry: &mut KernelRegistry,
51    device: &metal::DeviceRef,
52    data: &MlxBuffer,
53    head_dim: u32,
54    num_heads: u32,
55) -> Result<()> {
56    if num_heads == 0 || head_dim == 0 {
57        return Ok(());
58    }
59
60    // head_dim must be a power of two (butterfly pattern requirement).
61    if !head_dim.is_power_of_two() {
62        return Err(MlxError::InvalidArgument(format!(
63            "hadamard_transform: head_dim must be a power of two, got {}",
64            head_dim
65        )));
66    }
67
68    // 32 KB threadgroup memory limit: head_dim * 4 bytes ≤ 32768 bytes → head_dim ≤ 8192
69    if head_dim > 8192 {
70        return Err(MlxError::InvalidArgument(format!(
71            "hadamard_transform: head_dim {} exceeds Metal 32 KB threadgroup memory limit \
72             (max 8192 for f32)",
73            head_dim
74        )));
75    }
76
77    let required_elements = (num_heads as u64) * (head_dim as u64);
78    if (data.element_count() as u64) < required_elements {
79        return Err(MlxError::InvalidArgument(format!(
80            "hadamard_transform: data has {} elements but need {} \
81             (num_heads={} * head_dim={})",
82            data.element_count(),
83            required_elements,
84            num_heads,
85            head_dim,
86        )));
87    }
88
89    let pipeline = registry.get_pipeline("hadamard_transform", device)?;
90
91    let head_dim_bytes = head_dim.to_ne_bytes();
92    let num_heads_bytes = num_heads.to_ne_bytes();
93
94    // Shared memory: head_dim floats (4 bytes each) per threadgroup.
95    let shared_mem_bytes = (head_dim as u64) * 4;
96
97    encode_threadgroups_with_args_and_shared(
98        encoder,
99        pipeline,
100        &[
101            (0, KernelArg::Buffer(data)),
102            (1, KernelArg::Bytes(&head_dim_bytes)),
103            (2, KernelArg::Bytes(&num_heads_bytes)),
104        ],
105        &[(0, shared_mem_bytes)],
106        MTLSize::new(num_heads as u64, 1, 1),
107        MTLSize::new(head_dim as u64, 1, 1),
108    );
109
110    Ok(())
111}