1use metal::MTLSize;
15
16use crate::buffer::MlxBuffer;
17use crate::encoder::CommandEncoder;
18use crate::error::{MlxError, Result};
19use crate::kernel_registry::KernelRegistry;
20
21use super::encode_helpers::{as_bytes, encode_threadgroups_with_args, KernelArg};
22
/// Metal shader source for the dense GEMM / matvec kernels, embedded into the
/// binary at compile time by `include_str!` (the path is resolved relative to
/// this Rust source file). `register` associates this one source string with
/// all three kernel names used in this module.
pub static DENSE_GEMM_SHADER_SOURCE: &str = include_str!("../shaders/dense_gemm.metal");
25
26pub fn register(registry: &mut KernelRegistry) {
28 registry.register_source("dense_gemm_f16", DENSE_GEMM_SHADER_SOURCE);
29 registry.register_source("dense_matvec_f16", DENSE_GEMM_SHADER_SOURCE);
30 registry.register_source("dense_matvec_f16w_f32io", DENSE_GEMM_SHADER_SOURCE);
31}
32
33#[repr(C)]
37#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
38struct GpuDenseGemmParams {
39 m: u32,
40 n: u32,
41 k: u32,
42}
43
/// Problem dimensions for the dense GEMM / matvec dispatch functions.
///
/// The buffer-size checks in `dispatch_dense_gemm_f16` treat A as holding
/// M×K elements, B as N×K elements and the output as M×N elements.
/// `K` therefore appears in both inputs but not in the output, so it is the
/// contracted dimension.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct DenseGemmF16Params {
    /// Rows of A and of the output. `dispatch_dense_gemm_f16` selects the
    /// matvec kernel when this is `1`.
    pub m: u32,
    /// Output columns; B holds N×K elements.
    pub n: u32,
    /// Shared (contracted) dimension between A and B.
    pub k: u32,
}
53
54pub fn dispatch_dense_gemm_f16(
77 encoder: &mut CommandEncoder,
78 registry: &mut KernelRegistry,
79 device: &metal::DeviceRef,
80 a: &MlxBuffer,
81 b: &MlxBuffer,
82 output: &MlxBuffer,
83 params: &DenseGemmF16Params,
84) -> Result<()> {
85 if params.m == 0 || params.n == 0 || params.k == 0 {
86 return Err(MlxError::InvalidArgument(
87 "dense_gemm_f16: M, N, and K must all be > 0".into(),
88 ));
89 }
90
91 let a_bytes = params.m as usize * params.k as usize * 2; if a.byte_len() < a_bytes {
93 return Err(MlxError::InvalidArgument(format!(
94 "dense_gemm_f16: A buffer too small: need {} bytes, have {}",
95 a_bytes,
96 a.byte_len()
97 )));
98 }
99 let b_bytes = params.n as usize * params.k as usize * 2;
100 if b.byte_len() < b_bytes {
101 return Err(MlxError::InvalidArgument(format!(
102 "dense_gemm_f16: B buffer too small: need {} bytes, have {}",
103 b_bytes,
104 b.byte_len()
105 )));
106 }
107 let c_bytes = params.m as usize * params.n as usize * 2;
108 if output.byte_len() < c_bytes {
109 return Err(MlxError::InvalidArgument(format!(
110 "dense_gemm_f16: output buffer too small: need {} bytes, have {}",
111 c_bytes,
112 output.byte_len()
113 )));
114 }
115
116 if params.m == 1 {
117 dispatch_matvec_f16(encoder, registry, device, a, b, output, params)
118 } else {
119 dispatch_gemm_tiled_f16(encoder, registry, device, a, b, output, params)
120 }
121}
122
123fn dispatch_matvec_f16(
134 encoder: &mut CommandEncoder,
135 registry: &mut KernelRegistry,
136 device: &metal::DeviceRef,
137 a: &MlxBuffer,
138 b: &MlxBuffer,
139 output: &MlxBuffer,
140 params: &DenseGemmF16Params,
141) -> Result<()> {
142 let pipeline = registry.get_pipeline("dense_matvec_f16", device)?;
143
144 let gpu_params = GpuDenseGemmParams {
145 m: params.m,
146 n: params.n,
147 k: params.k,
148 };
149
150 let n_dst: u64 = 4;
151 let n_simdgroup: u64 = 2;
152 let rows_per_tg = n_dst * n_simdgroup; let threadgroups = MTLSize::new(
155 (params.n as u64 + rows_per_tg - 1) / rows_per_tg,
156 1,
157 1,
158 );
159 let threads_per_tg = MTLSize::new(32, n_simdgroup, 1);
160
161 encode_threadgroups_with_args(
162 encoder,
163 pipeline,
164 &[
165 (0, KernelArg::Buffer(a)),
166 (1, KernelArg::Buffer(b)),
167 (2, KernelArg::Buffer(output)),
168 (3, KernelArg::Bytes(as_bytes(&gpu_params))),
169 ],
170 threadgroups,
171 threads_per_tg,
172 );
173
174 Ok(())
175}
176
177pub fn dispatch_dense_matvec_f16w_f32io(
186 encoder: &mut CommandEncoder,
187 registry: &mut KernelRegistry,
188 device: &metal::DeviceRef,
189 a: &MlxBuffer,
190 b: &MlxBuffer,
191 output: &MlxBuffer,
192 params: &DenseGemmF16Params,
193) -> Result<()> {
194 if params.m != 1 {
195 return Err(MlxError::InvalidArgument(
196 "dense_matvec_f16w_f32io: M must be 1 (decode only)".into(),
197 ));
198 }
199 let pipeline = registry.get_pipeline("dense_matvec_f16w_f32io", device)?;
200
201 let gpu_params = GpuDenseGemmParams {
202 m: params.m,
203 n: params.n,
204 k: params.k,
205 };
206
207 let n_dst: u64 = 4;
208 let n_simdgroup: u64 = 2;
209 let rows_per_tg = n_dst * n_simdgroup;
210
211 let threadgroups = MTLSize::new(
212 (params.n as u64 + rows_per_tg - 1) / rows_per_tg,
213 1,
214 1,
215 );
216 let threads_per_tg = MTLSize::new(32, n_simdgroup, 1);
217
218 encode_threadgroups_with_args(
219 encoder,
220 pipeline,
221 &[
222 (0, KernelArg::Buffer(a)),
223 (1, KernelArg::Buffer(b)),
224 (2, KernelArg::Buffer(output)),
225 (3, KernelArg::Bytes(as_bytes(&gpu_params))),
226 ],
227 threadgroups,
228 threads_per_tg,
229 );
230
231 Ok(())
232}
233
234fn dispatch_gemm_tiled_f16(
238 encoder: &mut CommandEncoder,
239 registry: &mut KernelRegistry,
240 device: &metal::DeviceRef,
241 a: &MlxBuffer,
242 b: &MlxBuffer,
243 output: &MlxBuffer,
244 params: &DenseGemmF16Params,
245) -> Result<()> {
246 let pipeline = registry.get_pipeline("dense_gemm_f16", device)?;
247
248 let gpu_params = GpuDenseGemmParams {
249 m: params.m,
250 n: params.n,
251 k: params.k,
252 };
253
254 let bm: u64 = 32;
255 let bn: u64 = 32;
256 let tgp_size: u64 = 128; let threadgroups = MTLSize::new(
259 (params.n as u64 + bn - 1) / bn,
260 (params.m as u64 + bm - 1) / bm,
261 1,
262 );
263 let threads_per_tg = MTLSize::new(tgp_size, 1, 1);
264
265 encode_threadgroups_with_args(
266 encoder,
267 pipeline,
268 &[
269 (0, KernelArg::Buffer(a)),
270 (1, KernelArg::Buffer(b)),
271 (2, KernelArg::Buffer(output)),
272 (3, KernelArg::Bytes(as_bytes(&gpu_params))),
273 ],
274 threadgroups,
275 threads_per_tg,
276 );
277
278 Ok(())
279}