// mlx_native/ops/dense_gemm.rs
//! Dense F16 matrix multiply for the lm_head vocabulary projection.
//!
//! Computes `C = A * B^T` where A is [M, K] f16, B is [N, K] f16,
//! and C is [M, N] f16.
//!
//! Two GPU kernels:
//!
//! - `dense_matvec_f16` — specialised M=1 mat-vec (decode hot path).
//!   Uses vectorised half4 loads + simd_sum, modelled after the llama.cpp
//!   `kernel_mul_mv_f16_f32` pattern.
//!
//! - `dense_gemm_f16` — tiled GEMM for M>1 with simdgroup_matrix MMA.

use metal::MTLSize;

use crate::buffer::MlxBuffer;
use crate::encoder::CommandEncoder;
use crate::error::{MlxError, Result};
use crate::kernel_registry::KernelRegistry;

use super::encode_helpers::{as_bytes, encode_threadgroups_with_args, KernelArg};

23/// MSL source for the dense GEMM kernel (embedded at compile time).
24pub static DENSE_GEMM_SHADER_SOURCE: &str = include_str!("../shaders/dense_gemm.metal");
25
26/// Register dense GEMM shader source with the given kernel registry.
27pub fn register(registry: &mut KernelRegistry) {
28    registry.register_source("dense_gemm_f16", DENSE_GEMM_SHADER_SOURCE);
29    registry.register_source("dense_matvec_f16", DENSE_GEMM_SHADER_SOURCE);
30    registry.register_source("dense_matvec_f16w_f32io", DENSE_GEMM_SHADER_SOURCE);
31}
32
33/// MSL-compatible params struct for dense GEMM.
34///
35/// Must match `DenseGemmParams` in `dense_gemm.metal`.
36#[repr(C)]
37#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
38struct GpuDenseGemmParams {
39    m: u32,
40    n: u32,
41    k: u32,
42}
43
/// Parameters for a dense GEMM operation.
#[derive(Debug, Clone, Copy)]
pub struct DenseGemmF16Params {
    /// Number of rows in A (and C).
    pub m: u32,
    /// Number of rows in B (columns in C).  C = A * B^T is [M, N].
    pub n: u32,
    /// Inner dimension (columns of A and B).
    pub k: u32,
}
53
54/// Dispatch a dense F16 matrix multiply on the GPU: `C = A * B^T`.
55///
56/// A is `[M, K]` f16, B is `[N, K]` f16, C is `[M, N]` f16.
57///
58/// For M=1 (decode path), dispatches the specialised `dense_matvec_f16`
59/// kernel which uses vectorised loads + SIMD reduction — typically 10-20x
60/// faster than the tiled GEMM for single-row inputs.
61///
62/// # Arguments
63///
64/// * `encoder`  - Command encoder to record the dispatch into.
65/// * `registry` - Kernel registry (must have `dense_gemm_f16` registered).
66/// * `device`   - Metal device for pipeline compilation.
67/// * `a`        - Matrix A buffer `[M, K]` (f16).
68/// * `b`        - Matrix B buffer `[N, K]` (f16).
69/// * `output`   - Output buffer C `[M, N]` (f16).
70/// * `params`   - GEMM dimensions.
71///
72/// # Errors
73///
74/// Returns `MlxError::InvalidArgument` if dimensions are 0 or buffers are
75/// too small.
76pub fn dispatch_dense_gemm_f16(
77    encoder: &mut CommandEncoder,
78    registry: &mut KernelRegistry,
79    device: &metal::DeviceRef,
80    a: &MlxBuffer,
81    b: &MlxBuffer,
82    output: &MlxBuffer,
83    params: &DenseGemmF16Params,
84) -> Result<()> {
85    if params.m == 0 || params.n == 0 || params.k == 0 {
86        return Err(MlxError::InvalidArgument(
87            "dense_gemm_f16: M, N, and K must all be > 0".into(),
88        ));
89    }
90
91    let a_bytes = params.m as usize * params.k as usize * 2; // f16 = 2 bytes
92    if a.byte_len() < a_bytes {
93        return Err(MlxError::InvalidArgument(format!(
94            "dense_gemm_f16: A buffer too small: need {} bytes, have {}",
95            a_bytes,
96            a.byte_len()
97        )));
98    }
99    let b_bytes = params.n as usize * params.k as usize * 2;
100    if b.byte_len() < b_bytes {
101        return Err(MlxError::InvalidArgument(format!(
102            "dense_gemm_f16: B buffer too small: need {} bytes, have {}",
103            b_bytes,
104            b.byte_len()
105        )));
106    }
107    let c_bytes = params.m as usize * params.n as usize * 2;
108    if output.byte_len() < c_bytes {
109        return Err(MlxError::InvalidArgument(format!(
110            "dense_gemm_f16: output buffer too small: need {} bytes, have {}",
111            c_bytes,
112            output.byte_len()
113        )));
114    }
115
116    if params.m == 1 {
117        dispatch_matvec_f16(encoder, registry, device, a, b, output, params)
118    } else {
119        dispatch_gemm_tiled_f16(encoder, registry, device, a, b, output, params)
120    }
121}
122
123/// Specialised M=1 mat-vec kernel dispatch.
124///
125/// Kernel constants (must match `dense_gemm.metal`):
126///   N_DST       = 4  (rows per simdgroup)
127///   N_SIMDGROUP = 2  (simdgroups per threadgroup)
128///   N_SIMDWIDTH = 32 (Apple SIMD width)
129///
130/// Dispatch geometry:
131///   threadgroups:     (ceil(N / 8), 1, 1)
132///   threads_per_tg:   (32, N_SIMDGROUP, 1)   — 32 lanes × 2 simdgroups = 64 threads
133fn dispatch_matvec_f16(
134    encoder: &mut CommandEncoder,
135    registry: &mut KernelRegistry,
136    device: &metal::DeviceRef,
137    a: &MlxBuffer,
138    b: &MlxBuffer,
139    output: &MlxBuffer,
140    params: &DenseGemmF16Params,
141) -> Result<()> {
142    let pipeline = registry.get_pipeline("dense_matvec_f16", device)?;
143
144    let gpu_params = GpuDenseGemmParams {
145        m: params.m,
146        n: params.n,
147        k: params.k,
148    };
149
150    let n_dst: u64 = 4;
151    let n_simdgroup: u64 = 2;
152    let rows_per_tg = n_dst * n_simdgroup; // 8
153
154    let threadgroups = MTLSize::new(
155        (params.n as u64 + rows_per_tg - 1) / rows_per_tg,
156        1,
157        1,
158    );
159    let threads_per_tg = MTLSize::new(32, n_simdgroup, 1);
160
161    encode_threadgroups_with_args(
162        encoder,
163        pipeline,
164        &[
165            (0, KernelArg::Buffer(a)),
166            (1, KernelArg::Buffer(b)),
167            (2, KernelArg::Buffer(output)),
168            (3, KernelArg::Bytes(as_bytes(&gpu_params))),
169        ],
170        threadgroups,
171        threads_per_tg,
172    );
173
174    Ok(())
175}
176
177/// Dispatch a mixed-precision mat-vec: F32 input × F16 weights → F32 output.
178///
179/// Eliminates the F32→F16 cast on input and F16→F32 cast on output compared
180/// to the pure-F16 path. M must be 1 (decode path only).
181///
182/// * `a`      - Input buffer `[1, K]` (f32)
183/// * `b`      - Weight buffer `[N, K]` (f16)
184/// * `output` - Output buffer `[1, N]` (f32)
185pub fn dispatch_dense_matvec_f16w_f32io(
186    encoder: &mut CommandEncoder,
187    registry: &mut KernelRegistry,
188    device: &metal::DeviceRef,
189    a: &MlxBuffer,
190    b: &MlxBuffer,
191    output: &MlxBuffer,
192    params: &DenseGemmF16Params,
193) -> Result<()> {
194    if params.m != 1 {
195        return Err(MlxError::InvalidArgument(
196            "dense_matvec_f16w_f32io: M must be 1 (decode only)".into(),
197        ));
198    }
199    let pipeline = registry.get_pipeline("dense_matvec_f16w_f32io", device)?;
200
201    let gpu_params = GpuDenseGemmParams {
202        m: params.m,
203        n: params.n,
204        k: params.k,
205    };
206
207    let n_dst: u64 = 4;
208    let n_simdgroup: u64 = 2;
209    let rows_per_tg = n_dst * n_simdgroup;
210
211    let threadgroups = MTLSize::new(
212        (params.n as u64 + rows_per_tg - 1) / rows_per_tg,
213        1,
214        1,
215    );
216    let threads_per_tg = MTLSize::new(32, n_simdgroup, 1);
217
218    encode_threadgroups_with_args(
219        encoder,
220        pipeline,
221        &[
222            (0, KernelArg::Buffer(a)),
223            (1, KernelArg::Buffer(b)),
224            (2, KernelArg::Buffer(output)),
225            (3, KernelArg::Bytes(as_bytes(&gpu_params))),
226        ],
227        threadgroups,
228        threads_per_tg,
229    );
230
231    Ok(())
232}
233
234/// Tiled GEMM dispatch for M>1 using simdgroup_matrix MMA.
235///
236/// Tile: BM=32, BN=32, BK=16, WM=2, WN=2 → 128 threads per threadgroup.
237fn dispatch_gemm_tiled_f16(
238    encoder: &mut CommandEncoder,
239    registry: &mut KernelRegistry,
240    device: &metal::DeviceRef,
241    a: &MlxBuffer,
242    b: &MlxBuffer,
243    output: &MlxBuffer,
244    params: &DenseGemmF16Params,
245) -> Result<()> {
246    let pipeline = registry.get_pipeline("dense_gemm_f16", device)?;
247
248    let gpu_params = GpuDenseGemmParams {
249        m: params.m,
250        n: params.n,
251        k: params.k,
252    };
253
254    let bm: u64 = 32;
255    let bn: u64 = 32;
256    let tgp_size: u64 = 128; // WM * WN * 32 = 2*2*32
257
258    let threadgroups = MTLSize::new(
259        (params.n as u64 + bn - 1) / bn,
260        (params.m as u64 + bm - 1) / bm,
261        1,
262    );
263    let threads_per_tg = MTLSize::new(tgp_size, 1, 1);
264
265    encode_threadgroups_with_args(
266        encoder,
267        pipeline,
268        &[
269            (0, KernelArg::Buffer(a)),
270            (1, KernelArg::Buffer(b)),
271            (2, KernelArg::Buffer(output)),
272            (3, KernelArg::Bytes(as_bytes(&gpu_params))),
273        ],
274        threadgroups,
275        threads_per_tg,
276    );
277
278    Ok(())
279}