use metal::MTLSize;
use crate::buffer::MlxBuffer;
use crate::encoder::CommandEncoder;
use crate::error::{MlxError, Result};
use crate::kernel_registry::KernelRegistry;
use super::encode_helpers::{as_bytes, encode_threadgroups_with_args, KernelArg};
/// Metal shader source for the dense GEMM / matvec kernels; all three kernel
/// names registered in [`register`] are compiled from this single file.
pub static DENSE_GEMM_SHADER_SOURCE: &str = include_str!("../shaders/dense_gemm.metal");
/// Register every kernel provided by `dense_gemm.metal` with the registry.
///
/// All three entry points share the same shader translation unit, so each
/// name maps to the same source string.
pub fn register(registry: &mut KernelRegistry) {
    let kernel_names = [
        "dense_gemm_f16",
        "dense_matvec_f16",
        "dense_matvec_f16w_f32io",
    ];
    for name in kernel_names {
        registry.register_source(name, DENSE_GEMM_SHADER_SOURCE);
    }
}
/// Parameter block uploaded to the GPU at kernel buffer index 3.
///
/// `#[repr(C)]` plus `Pod`/`Zeroable` make it safe to reinterpret as raw
/// bytes via `as_bytes`; the field order and types must match the
/// corresponding struct in `dense_gemm.metal` — confirm against the shader
/// when changing.
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuDenseGemmParams {
    // Rows of A / output (from the a_bytes = m*k and c_bytes = m*n checks below).
    m: u32,
    // Columns of the output; B holds n*k elements (see b_bytes check).
    n: u32,
    // Shared inner dimension of A and B.
    k: u32,
}
/// Host-side GEMM problem dimensions: output is M x N, inner dimension K.
/// Mirrors `GpuDenseGemmParams` but is the public, non-`repr(C)` API type.
pub struct DenseGemmF16Params {
    /// Rows of A and of the output.
    pub m: u32,
    /// Columns of the output.
    pub n: u32,
    /// Inner (reduction) dimension shared by A and B.
    pub k: u32,
}
/// Encode a dense f16 GEMM: output (M x N) = A (M x K) · B, all operands f16.
///
/// Validates that every dimension is non-zero and that each buffer holds at
/// least its required number of f16 elements, then routes M == 1 to the
/// specialized matvec kernel and everything else to the tiled GEMM kernel.
///
/// # Errors
/// Returns `MlxError::InvalidArgument` when a dimension is zero or a buffer
/// is smaller than the dimensions require.
pub fn dispatch_dense_gemm_f16(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    a: &MlxBuffer,
    b: &MlxBuffer,
    output: &MlxBuffer,
    params: &DenseGemmF16Params,
) -> Result<()> {
    if params.m == 0 || params.n == 0 || params.k == 0 {
        return Err(MlxError::InvalidArgument(
            "dense_gemm_f16: M, N, and K must all be > 0".into(),
        ));
    }
    // All three operands are half precision: 2 bytes per element.
    const F16_BYTES: usize = 2;
    ensure_min_len(a, params.m as usize * params.k as usize * F16_BYTES, "A")?;
    ensure_min_len(b, params.n as usize * params.k as usize * F16_BYTES, "B")?;
    ensure_min_len(output, params.m as usize * params.n as usize * F16_BYTES, "output")?;
    // M == 1 is the single-row (decode-style) case; use the matvec kernel.
    if params.m == 1 {
        dispatch_matvec_f16(encoder, registry, device, a, b, output, params)
    } else {
        dispatch_gemm_tiled_f16(encoder, registry, device, a, b, output, params)
    }
}

/// Check that `buf` holds at least `needed` bytes; `label` names the operand
/// ("A", "B", or "output") in the resulting error message.
fn ensure_min_len(buf: &MlxBuffer, needed: usize, label: &str) -> Result<()> {
    if buf.byte_len() < needed {
        return Err(MlxError::InvalidArgument(format!(
            "dense_gemm_f16: {} buffer too small: need {} bytes, have {}",
            label,
            needed,
            buf.byte_len()
        )));
    }
    Ok(())
}
/// Encode the M == 1 f16 matvec kernel.
///
/// Each threadgroup covers `ROWS_PER_TG` output rows using `N_SIMDGROUP`
/// simdgroups of 32 lanes; these tiling constants presumably must match the
/// ones compiled into `dense_gemm.metal` — confirm against the shader before
/// changing them.
fn dispatch_matvec_f16(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    a: &MlxBuffer,
    b: &MlxBuffer,
    output: &MlxBuffer,
    params: &DenseGemmF16Params,
) -> Result<()> {
    // Output rows produced per simdgroup, and simdgroups per threadgroup.
    const N_DST: u64 = 4;
    const N_SIMDGROUP: u64 = 2;
    const ROWS_PER_TG: u64 = N_DST * N_SIMDGROUP;

    let pipeline = registry.get_pipeline("dense_matvec_f16", device)?;
    let gpu_params = GpuDenseGemmParams {
        m: params.m,
        n: params.n,
        k: params.k,
    };

    // One threadgroup per ROWS_PER_TG rows of the output (ceiling division).
    let grid = MTLSize::new((params.n as u64).div_ceil(ROWS_PER_TG), 1, 1);
    let group = MTLSize::new(32, N_SIMDGROUP, 1);

    encode_threadgroups_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(a)),
            (1, KernelArg::Buffer(b)),
            (2, KernelArg::Buffer(output)),
            (3, KernelArg::Bytes(as_bytes(&gpu_params))),
        ],
        grid,
        group,
    );
    Ok(())
}
/// Encode the decode-path matvec with f16 weights — from the kernel name the
/// activations/output are presumably f32; the actual dtypes live in
/// `dense_gemm.metal`.
///
/// Uses the same tiling as `dispatch_matvec_f16`: 8 output rows per
/// threadgroup (4 rows per simdgroup x 2 simdgroups of 32 lanes).
///
/// # Errors
/// Returns `MlxError::InvalidArgument` if `M != 1` (this kernel only handles
/// the single-row decode case) or if `N` or `K` is zero.
pub fn dispatch_dense_matvec_f16w_f32io(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    a: &MlxBuffer,
    b: &MlxBuffer,
    output: &MlxBuffer,
    params: &DenseGemmF16Params,
) -> Result<()> {
    if params.m != 1 {
        return Err(MlxError::InvalidArgument(
            "dense_matvec_f16w_f32io: M must be 1 (decode only)".into(),
        ));
    }
    // Mirror dispatch_dense_gemm_f16's validation: a zero dimension would
    // silently encode a useless (N == 0) or ill-defined (K == 0) dispatch.
    if params.n == 0 || params.k == 0 {
        return Err(MlxError::InvalidArgument(
            "dense_matvec_f16w_f32io: N and K must be > 0".into(),
        ));
    }
    let pipeline = registry.get_pipeline("dense_matvec_f16w_f32io", device)?;
    let gpu_params = GpuDenseGemmParams {
        m: params.m,
        n: params.n,
        k: params.k,
    };
    // 4 output rows per simdgroup, 2 simdgroups per threadgroup => 8 rows/TG.
    let n_dst: u64 = 4;
    let n_simdgroup: u64 = 2;
    let rows_per_tg = n_dst * n_simdgroup;
    // Ceiling division: enough threadgroups to cover all N output rows.
    let threadgroups = MTLSize::new(
        (params.n as u64 + rows_per_tg - 1) / rows_per_tg,
        1,
        1,
    );
    let threads_per_tg = MTLSize::new(32, n_simdgroup, 1);
    encode_threadgroups_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(a)),
            (1, KernelArg::Buffer(b)),
            (2, KernelArg::Buffer(output)),
            (3, KernelArg::Bytes(as_bytes(&gpu_params))),
        ],
        threadgroups,
        threads_per_tg,
    );
    Ok(())
}
/// Encode the tiled f16 GEMM kernel for the general M > 1 case.
///
/// The output is covered by `BLOCK_M` x `BLOCK_N` tiles, one threadgroup of
/// `TGP_SIZE` threads per tile; these constants presumably must match the
/// tiling compiled into `dense_gemm.metal` — confirm against the shader
/// before changing them.
fn dispatch_gemm_tiled_f16(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    a: &MlxBuffer,
    b: &MlxBuffer,
    output: &MlxBuffer,
    params: &DenseGemmF16Params,
) -> Result<()> {
    // Tile height, tile width, and threads per threadgroup.
    const BLOCK_M: u64 = 32;
    const BLOCK_N: u64 = 32;
    const TGP_SIZE: u64 = 128;

    let pipeline = registry.get_pipeline("dense_gemm_f16", device)?;
    let gpu_params = GpuDenseGemmParams {
        m: params.m,
        n: params.n,
        k: params.k,
    };

    // Grid: x covers the N dimension, y covers M (ceiling division per axis).
    let grid = MTLSize::new(
        (params.n as u64).div_ceil(BLOCK_N),
        (params.m as u64).div_ceil(BLOCK_M),
        1,
    );
    let group = MTLSize::new(TGP_SIZE, 1, 1);

    encode_threadgroups_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(a)),
            (1, KernelArg::Buffer(b)),
            (2, KernelArg::Buffer(output)),
            (3, KernelArg::Bytes(as_bytes(&gpu_params))),
        ],
        grid,
        group,
    );
    Ok(())
}