use metal::MTLSize;
use crate::buffer::MlxBuffer;
use crate::encoder::CommandEncoder;
use crate::error::{MlxError, Result};
use crate::kernel_registry::KernelRegistry;
/// Text of `../shaders/fwht_standalone.metal`, embedded into the binary at
/// compile time via `include_str!` (path is relative to this source file).
pub static FWHT_STANDALONE_SHADER_SOURCE: &str =
include_str!("../shaders/fwht_standalone.metal");
/// Parameter block handed to the FWHT kernel as inline bytes.
///
/// `#[repr(C)]` pins the field order and layout (two consecutive `u32`s), and
/// the `bytemuck` derives allow the value to be viewed as a plain byte slice.
/// NOTE(review): the layout presumably has to match a params struct declared
/// in the Metal shader — confirm against the shader source before reordering
/// or adding fields.
#[repr(C)]
#[derive(Clone, Copy)]
#[derive(bytemuck::Pod, bytemuck::Zeroable)]
struct GpuFwhtParams {
    /// Elements per head, as passed by the caller of the dispatch function.
    head_dim: u32,
    /// Number of heads, as passed by the caller of the dispatch function.
    num_heads: u32,
}
/// Encodes a dispatch of the standalone f32 FWHT kernel over `data`.
///
/// The kernel variant is selected by `head_dim`; one threadgroup of 32 threads
/// is launched per head. `data` is bound at buffer index 0 and a small
/// parameter block (`head_dim`, `num_heads`) is passed as inline bytes at
/// index 1.
///
/// NOTE(review): `data` is the only buffer bound, so the kernel presumably
/// transforms it in place and expects at least `num_heads * head_dim` f32
/// elements — this function does not verify the buffer size; confirm against
/// the shader and callers.
///
/// # Arguments
///
/// * `encoder` - command encoder the dispatch is recorded into.
/// * `registry` - kernel registry used to look up the compute pipeline.
/// * `device` - Metal device passed to the registry's pipeline lookup.
/// * `data` - buffer bound at index 0.
/// * `num_heads` - number of heads; also the number of threadgroups. Must be
///   non-zero.
/// * `head_dim` - elements per head; only `256` and `512` are supported.
///
/// # Errors
///
/// Returns [`MlxError::InvalidArgument`] when `head_dim` is not `256` or
/// `512`, or when `num_heads` is zero. Propagates any error from the
/// registry's pipeline lookup.
pub fn dispatch_fwht_f32(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    data: &MlxBuffer,
    num_heads: u32,
    head_dim: u32,
) -> Result<()> {
    use crate::ops::encode_helpers::{as_bytes, KernelArg};

    // Each supported head size has its own kernel entry point.
    let kernel_name = match head_dim {
        256 => "fwht_standalone_f32_d256",
        512 => "fwht_standalone_f32_d512",
        _ => {
            return Err(MlxError::InvalidArgument(format!(
                "fwht_standalone: unsupported head_dim={}",
                head_dim
            )))
        }
    };

    // A grid with zero threadgroups is not a valid Metal dispatch; reject it
    // up front instead of encoding an invalid command.
    if num_heads == 0 {
        return Err(MlxError::InvalidArgument(
            "fwht_standalone: num_heads must be greater than 0".to_string(),
        ));
    }

    let pipeline = registry.get_pipeline(kernel_name, device)?;
    let params = GpuFwhtParams { head_dim, num_heads };

    // One threadgroup per head, 32 threads per threadgroup.
    let threadgroups = MTLSize::new(u64::from(num_heads), 1, 1);
    let threads_per_tg = MTLSize::new(32, 1, 1);

    encoder.encode_threadgroups_with_args(
        pipeline,
        &[
            (0, KernelArg::Buffer(data)),
            (1, KernelArg::Bytes(as_bytes(&params))),
        ],
        threadgroups,
        threads_per_tg,
    );
    Ok(())
}