use metal::MTLSize;
use crate::buffer::MlxBuffer;
use crate::encoder::CommandEncoder;
use crate::error::{MlxError, Result};
use crate::kernel_registry::KernelRegistry;
use super::encode_helpers::{encode_threadgroups_with_args_and_shared, KernelArg};
/// Metal shader source text for the Hadamard transform kernel, embedded at
/// compile time from `../shaders/hadamard.metal`.
pub static HADAMARD_SHADER_SOURCE: &str = include_str!("../shaders/hadamard.metal");
/// Registers [`HADAMARD_SHADER_SOURCE`] with `registry` under the kernel name
/// `"hadamard_transform"`, the same name `dispatch_hadamard_transform` uses
/// when it looks up its pipeline.
pub fn register(registry: &mut KernelRegistry) {
    let kernel_name = "hadamard_transform";
    registry.register_source(kernel_name, HADAMARD_SHADER_SOURCE);
}
/// Encodes a dispatch of the `hadamard_transform` kernel over `data`.
///
/// The launch uses `num_heads` threadgroups of `head_dim` threads each and
/// binds `head_dim * 4` bytes of threadgroup memory at index 0. `data` is the
/// only buffer bound (index 0), followed by `head_dim` and `num_heads` as
/// inline bytes at indices 1 and 2.
///
/// NOTE(review): since a single buffer is bound, the kernel presumably
/// transforms `data` in place — confirm against `hadamard.metal`.
///
/// If `num_heads` or `head_dim` is zero, this is a no-op that returns
/// `Ok(())` without touching the encoder or the registry.
///
/// # Errors
///
/// Returns [`MlxError::InvalidArgument`] when:
/// - `head_dim` is not a power of two,
/// - `head_dim` exceeds 8192 (threadgroup memory bound),
/// - `data` holds fewer than `num_heads * head_dim` elements,
/// - `head_dim` exceeds the pipeline's maximum threads per threadgroup.
///
/// Any error from `registry.get_pipeline` is propagated unchanged.
pub fn dispatch_hadamard_transform(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    data: &MlxBuffer,
    head_dim: u32,
    num_heads: u32,
) -> Result<()> {
    // Nothing to transform: succeed without encoding anything.
    if num_heads == 0 || head_dim == 0 {
        return Ok(());
    }

    if !head_dim.is_power_of_two() {
        return Err(MlxError::InvalidArgument(format!(
            "hadamard_transform: head_dim must be a power of two, got {}",
            head_dim
        )));
    }

    // One 4-byte slot of threadgroup memory is requested per element below.
    if head_dim > 8192 {
        return Err(MlxError::InvalidArgument(format!(
            "hadamard_transform: head_dim {} exceeds Metal 32 KB threadgroup memory limit \
             (max 8192 for f32)",
            head_dim
        )));
    }

    // Widen before multiplying so the product cannot overflow u32.
    let required_elements = u64::from(num_heads) * u64::from(head_dim);
    if (data.element_count() as u64) < required_elements {
        return Err(MlxError::InvalidArgument(format!(
            "hadamard_transform: data has {} elements but need {} \
             (num_heads={} * head_dim={})",
            data.element_count(),
            required_elements,
            num_heads,
            head_dim,
        )));
    }

    let pipeline = registry.get_pipeline("hadamard_transform", device)?;

    // The threadgroup is launched with `head_dim` threads, so the memory
    // bound above is not sufficient on its own: the pipeline's thread limit
    // is usually much lower than 8192. Reject oversized launches here rather
    // than encoding a dispatch Metal cannot execute.
    let max_threads = pipeline.max_total_threads_per_threadgroup();
    if u64::from(head_dim) > max_threads {
        return Err(MlxError::InvalidArgument(format!(
            "hadamard_transform: head_dim {} exceeds the pipeline limit of {} \
             threads per threadgroup",
            head_dim, max_threads
        )));
    }

    // Scalars are passed to the kernel as raw native-endian bytes.
    let head_dim_bytes = head_dim.to_ne_bytes();
    let num_heads_bytes = num_heads.to_ne_bytes();
    let shared_mem_bytes = u64::from(head_dim) * 4;

    encode_threadgroups_with_args_and_shared(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(data)),
            (1, KernelArg::Bytes(&head_dim_bytes)),
            (2, KernelArg::Bytes(&num_heads_bytes)),
        ],
        &[(0, shared_mem_bytes)],
        MTLSize::new(u64::from(num_heads), 1, 1),
        MTLSize::new(u64::from(head_dim), 1, 1),
    );

    Ok(())
}