use crate::buffer::MlxBuffer;
use crate::device::MlxDevice;
use crate::dtypes::DType;
use crate::encoder::{CommandEncoder, KernelArg, as_bytes};
use crate::error::{MlxError, Result};
use crate::kernel_registry::KernelRegistry;
/// Problem dimensions for [`dense_matmul_f16_f32_tensor`].
///
/// The dispatcher validates buffers against these shapes:
/// `src0` is `[src0_batch, n, k]` f16, `src1` is `[src1_batch, m, k]` f32 and
/// `dst` is `[src1_batch, m, n]` f32.
#[derive(Clone, Copy, Debug)]
pub struct DenseMmF16F32Params {
    /// Number of `src1` rows per batch (and `dst` rows per batch). Must be > 0.
    pub m: u32,
    /// Number of `src0` rows per batch (and `dst` columns). Must be > 0.
    pub n: u32,
    /// Shared inner dimension: row length of both `src0` and `src1`.
    /// The dispatcher rejects values below 32.
    pub k: u32,
    /// Batch count of `src0`. Must be > 0.
    pub src0_batch: u32,
    /// Batch count of `src1` and `dst`. Must be a non-zero multiple of
    /// `src0_batch`.
    pub src1_batch: u32,
}
/// Argument block passed by value (as raw bytes) to the
/// `hf2q_dense_mm_f16_f32_tensor` kernel at binding index 0.
///
/// The struct is `#[repr(C)]` and `Pod`, so field order, field types and the
/// explicit `_pad*` members are all part of the binary contract: the two pads
/// occupy the slots where C layout rules would otherwise insert implicit
/// padding, giving a total size of 88 bytes with no hidden gaps.
///
/// NOTE(review): the field names look like they follow ggml's `ne*`/`nb*`
/// (element count / byte stride) convention — confirm against the shader's
/// argument struct before changing anything here.
#[repr(C)]
#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct DenseMmF16F32TensorGpuParams {
    /// Offset 0: set to `k` (row length of `src0`).
    ne00: i32,
    /// Offset 4: set to `src0_batch`.
    ne02: i32,
    /// Offset 8: byte stride between consecutive `src0` rows.
    nb01: u64,
    /// Offset 16: byte stride between consecutive `src0` batches.
    nb02: u64,
    /// Offset 24: always written as 0 by the dispatcher.
    nb03: u64,
    /// Offset 32: set to `src1_batch`.
    ne12: i32,
    /// Offset 36: explicit padding so `nb10` lands on an 8-byte boundary.
    _pad0: u32,
    /// Offset 40: byte size of one `src1` element (f32).
    nb10: u64,
    /// Offset 48: byte stride between consecutive `src1` rows.
    nb11: u64,
    /// Offset 56: byte stride between consecutive `src1` batches.
    nb12: u64,
    /// Offset 64: always written as 0 by the dispatcher.
    nb13: u64,
    /// Offset 72: set to `n`.
    ne0: i32,
    /// Offset 76: set to `m`.
    ne1: i32,
    /// Offset 80: `src1_batch / src0_batch` broadcast ratio.
    r2: i16,
    /// Offset 82: always written as 1 by the dispatcher.
    r3: i16,
    /// Offset 84: explicit tail padding up to the 8-byte struct alignment.
    _pad1: u32,
}
/// Encodes one dispatch of the `hf2q_dense_mm_f16_f32_tensor` kernel on
/// `encoder`.
///
/// Buffer shapes, as enforced by the checks below (all contiguous, innermost
/// dimension last):
///
/// * `src0`: `[src0_batch, n, k]` f16
/// * `src1`: `[src1_batch, m, k]` f32
/// * `dst`:  `[src1_batch, m, n]` f32
///
/// `src1_batch` must be a multiple of `src0_batch`; the ratio is handed to the
/// kernel as `r2` so that several `src1` batches can share one `src0` batch.
///
/// The dispatch grid is `ceil(m / 32) x ceil(n / 64) x src1_batch`
/// threadgroups of 128 threads, with 8192 bytes of threadgroup memory bound at
/// index 0.
///
/// # Errors
///
/// Returns [`MlxError::InvalidArgument`] when
/// * any of `m`, `n`, `k` is zero, or `k < 32`;
/// * either batch count is zero, or `src1_batch % src0_batch != 0`;
/// * a dimension or batch count does not fit the kernel's `i32` fields, or the
///   broadcast ratio does not fit its `i16` field;
/// * a tensor's byte size overflows `usize`;
/// * `src0`, `src1` or `dst` is smaller than its shape requires.
///
/// Errors from the pipeline lookup in `registry` are propagated unchanged.
pub fn dense_matmul_f16_f32_tensor(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &MlxDevice,
    src0: &MlxBuffer,
    src1: &MlxBuffer,
    dst: &mut MlxBuffer,
    params: &DenseMmF16F32Params,
) -> Result<()> {
    // Narrows a `u32` extent to the kernel's signed 32-bit field. A plain
    // `as i32` would silently wrap large values to negative numbers.
    fn gpu_i32(value: u32, what: &str) -> Result<i32> {
        if value > i32::MAX as u32 {
            return Err(MlxError::InvalidArgument(format!(
                "dense_matmul_f16_f32_tensor: {} ({}) exceeds the kernel's i32 range",
                what, value
            )));
        }
        Ok(value as i32)
    }

    // Byte size of a `[batch, rows, cols]` tensor, computed with overflow
    // checks: three u32 extents times an element size can exceed 64 bits, and
    // a wrapped product would let an undersized buffer pass validation.
    fn tensor_bytes(
        batch: u32,
        rows: u32,
        cols: u32,
        elem_size: usize,
        what: &str,
    ) -> Result<usize> {
        (batch as usize)
            .checked_mul(rows as usize)
            .and_then(|v| v.checked_mul(cols as usize))
            .and_then(|v| v.checked_mul(elem_size))
            .ok_or_else(|| {
                MlxError::InvalidArgument(format!(
                    "dense_matmul_f16_f32_tensor: {} byte size overflows usize for [{}x{}x{}]",
                    what, batch, rows, cols
                ))
            })
    }

    // ---- Shape validation -------------------------------------------------
    if params.m == 0 || params.n == 0 || params.k == 0 {
        return Err(MlxError::InvalidArgument(
            "dense_matmul_f16_f32_tensor: M, N, K must all be > 0".into(),
        ));
    }
    if params.k < 32 {
        return Err(MlxError::InvalidArgument(format!(
            "dense_matmul_f16_f32_tensor: K ({}) must be >= 32",
            params.k
        )));
    }
    if params.src0_batch == 0 || params.src1_batch == 0 {
        return Err(MlxError::InvalidArgument(
            "dense_matmul_f16_f32_tensor: batch counts must be > 0".into(),
        ));
    }
    if params.src1_batch % params.src0_batch != 0 {
        return Err(MlxError::InvalidArgument(format!(
            "dense_matmul_f16_f32_tensor: src1_batch ({}) must be a \
             multiple of src0_batch ({}) for GQA broadcast",
            params.src1_batch, params.src0_batch
        )));
    }

    // ---- Range validation for the narrower GPU-side fields ----------------
    let ne00 = gpu_i32(params.k, "K")?;
    let ne02 = gpu_i32(params.src0_batch, "src0_batch")?;
    let ne12 = gpu_i32(params.src1_batch, "src1_batch")?;
    let ne0 = gpu_i32(params.n, "N")?;
    let ne1 = gpu_i32(params.m, "M")?;

    let broadcast = params.src1_batch / params.src0_batch;
    if broadcast > i16::MAX as u32 {
        return Err(MlxError::InvalidArgument(format!(
            "dense_matmul_f16_f32_tensor: broadcast ratio src1_batch/src0_batch ({}) \
             exceeds the kernel's i16 range",
            broadcast
        )));
    }
    let r2 = broadcast as i16;

    // ---- Buffer size validation -------------------------------------------
    let f16_sz = DType::F16.size_of();
    let f32_sz = DType::F32.size_of();

    let expected_src0_bytes =
        tensor_bytes(params.src0_batch, params.n, params.k, f16_sz, "src0")?;
    if src0.byte_len() < expected_src0_bytes {
        return Err(MlxError::InvalidArgument(format!(
            "dense_matmul_f16_f32_tensor: src0 too small: expected {} bytes for \
             [{}x{}x{}] f16, got {}",
            expected_src0_bytes, params.src0_batch, params.n, params.k, src0.byte_len()
        )));
    }

    let expected_src1_bytes =
        tensor_bytes(params.src1_batch, params.m, params.k, f32_sz, "src1")?;
    if src1.byte_len() < expected_src1_bytes {
        return Err(MlxError::InvalidArgument(format!(
            "dense_matmul_f16_f32_tensor: src1 too small: expected {} bytes for \
             [{}x{}x{}] f32, got {}",
            expected_src1_bytes, params.src1_batch, params.m, params.k, src1.byte_len()
        )));
    }

    let expected_dst_bytes =
        tensor_bytes(params.src1_batch, params.m, params.n, f32_sz, "dst")?;
    if dst.byte_len() < expected_dst_bytes {
        return Err(MlxError::InvalidArgument(format!(
            "dense_matmul_f16_f32_tensor: dst too small: expected {} bytes for \
             [{}x{}x{}] f32, got {}",
            expected_dst_bytes, params.src1_batch, params.m, params.n, dst.byte_len()
        )));
    }

    let pipeline = registry
        .get_pipeline("hf2q_dense_mm_f16_f32_tensor", device.metal_device())?;

    // ---- Byte strides -----------------------------------------------------
    // Each stride is bounded by the corresponding tensor byte size validated
    // above (batch counts are >= 1), so these u64 products cannot overflow.
    let nb01 = (params.k as u64) * (f16_sz as u64); // src0 row stride
    let nb02 = (params.n as u64) * nb01; // src0 batch stride
    let nb11 = (params.k as u64) * (f32_sz as u64); // src1 row stride
    let nb12 = (params.m as u64) * nb11; // src1 batch stride

    let gpu_params = DenseMmF16F32TensorGpuParams {
        ne00,
        ne02,
        nb01,
        nb02,
        nb03: 0,
        ne12,
        _pad0: 0,
        nb10: f32_sz as u64,
        nb11,
        nb12,
        nb13: 0,
        ne0,
        ne1,
        r2,
        r3: 1,
        _pad1: 0,
    };

    // ---- Dispatch geometry ------------------------------------------------
    // Each threadgroup covers an NR1-wide slice of M and an NR0-wide slice of
    // N; partial tiles at the edges are covered by rounding the grid up.
    const NR0: u64 = 64;
    const NR1: u64 = 32;
    const THREADS_PER_TG: u64 = 128;
    const SHMEM_BYTES: u64 = 8192;

    let threadgroups = metal::MTLSize::new(
        (params.m as u64 + NR1 - 1) / NR1,
        (params.n as u64 + NR0 - 1) / NR0,
        params.src1_batch as u64,
    );
    let threads_per_tg = metal::MTLSize::new(THREADS_PER_TG, 1, 1);

    encoder.encode_threadgroups_with_args_and_shared(
        pipeline,
        &[
            (0, KernelArg::Bytes(as_bytes(&gpu_params))),
            (1, KernelArg::Buffer(src0)),
            (2, KernelArg::Buffer(src1)),
            (3, KernelArg::Buffer(dst)),
        ],
        &[(0, SHMEM_BYTES)],
        threadgroups,
        threads_per_tg,
    );
    Ok(())
}