Module dense_gemm

Dense F16 matrix multiply for the lm_head vocabulary projection.

Computes C = A * B^T where A is [M, K] f16, B is [N, K] f16, and C is [M, N] f16.
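To make the shape convention concrete, here is a minimal CPU reference for C = A * B^T. It is a sketch, not this module's implementation: the function name `reference_gemm_bt` is hypothetical, row-major storage is assumed, and `f32` stands in for f16 because stable Rust has no built-in half-precision type.

```rust
/// CPU reference for C = A * B^T (illustrative only, not the GPU path).
/// Assumes row-major storage: `a` is [m, k], `b` is [n, k], result is [m, n].
/// Uses f32 as a stand-in for the f16 data the GPU kernels operate on.
fn reference_gemm_bt(a: &[f32], b: &[f32], m: usize, n: usize, k: usize) -> Vec<f32> {
    assert_eq!(a.len(), m * k, "A must hold m * k elements");
    assert_eq!(b.len(), n * k, "B must hold n * k elements");

    let mut c = vec![0.0f32; m * n];
    for i in 0..m {
        let a_row = &a[i * k..(i + 1) * k];
        for j in 0..n {
            // Because B is stored as [N, K], output column j is a dot product
            // of row i of A with row j of B; no explicit transpose is needed.
            let b_row = &b[j * k..(j + 1) * k];
            c[i * n + j] = a_row.iter().zip(b_row).map(|(x, y)| x * y).sum::<f32>();
        }
    }
    c
}
```

Storing B as [N, K] means both operands are read along contiguous rows, which is the usual reason a weight matrix is kept in this layout for a vocabulary projection.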

Two GPU kernels:

  • dense_matvec_f16 — specialised M=1 mat-vec (decode hot path). Uses vectorised half4 loads + simd_sum, modelled after the llama.cpp kernel_mul_mv_f16_f32 pattern.

  • dense_gemm_f16 — tiled GEMM for M>1 with simdgroup_matrix MMA.
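The mat-vec pattern named above (vectorised 4-element loads followed by a SIMD-group sum) can be emulated on the CPU for intuition. This is a sketch under stated assumptions, not the shader itself: the lane count of 32, the lane-strided traversal, and the name `lane_strided_dot` are all assumptions made for illustration, and `f32` stands in for f16.

```rust
const LANES: usize = 32; // assumed SIMD-group width (hypothetical)
const VEC: usize = 4; // elements per vectorised load, mirroring half4

/// Emulates one output element of a lane-parallel mat-vec on the CPU.
/// Each "lane" accumulates every LANES-th chunk of VEC elements, and the
/// final sum over lane partials plays the role of simd_sum.
fn lane_strided_dot(w_row: &[f32], x: &[f32]) -> f32 {
    assert_eq!(w_row.len(), x.len(), "row and input must have equal length");
    let k = x.len();

    let mut partial = [0.0f32; LANES];
    for lane in 0..LANES {
        let mut base = lane * VEC;
        while base < k {
            // The last chunk may be shorter than VEC when k is not a multiple of 4.
            let end = (base + VEC).min(k);
            for i in base..end {
                partial[lane] += w_row[i] * x[i];
            }
            base += LANES * VEC;
        }
    }

    // Cross-lane reduction (the CPU analogue of simd_sum).
    partial.iter().sum()
}
```

Adjacent lanes read adjacent chunks, so on a GPU the loads issued by a SIMD group in one step cover a contiguous span of the weight row.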

Structs§

DenseGemmF16Params
Parameters for a dense GEMM operation.

Statics§

DENSE_GEMM_SHADER_SOURCE
Metal Shading Language (MSL) source for the dense GEMM kernel (embedded at compile time).

Functions§

dispatch_dense_gemm_f16
Dispatch a dense F16 matrix multiply on the GPU: C = A * B^T.
dispatch_dense_matvec_f16w_f32io
Dispatch a mixed-precision mat-vec: F32 input × F16 weights → F32 output.
register
Register dense GEMM shader source with the given kernel registry.
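The mixed-precision contract of dispatch_dense_matvec_f16w_f32io (F32 input, F16 weights, F32 output) can be illustrated with a CPU reference. This is a sketch only: the helper names are hypothetical, weights are assumed to be raw IEEE 754 binary16 bit patterns stored as `u16` in row-major [N, K] order, and accumulation in f32 is an assumption rather than a documented property of the kernel.

```rust
/// Decodes an IEEE 754 binary16 bit pattern into f32 (hypothetical helper).
fn f16_bits_to_f32(h: u16) -> f32 {
    let sign = ((h >> 15) & 1) as u32;
    let exp = ((h >> 10) & 0x1f) as u32;
    let frac = (h & 0x3ff) as u32;

    let bits = match (exp, frac) {
        // Signed zero.
        (0, 0) => sign << 31,
        // Subnormal: value is frac * 2^-24.
        (0, _) => {
            let v = frac as f32 * (1.0 / 16_777_216.0);
            return if sign == 1 { -v } else { v };
        }
        // Infinity (frac == 0) or NaN (frac != 0).
        (0x1f, _) => (sign << 31) | 0x7f80_0000 | (frac << 13),
        // Normal: rebias the exponent from 15 to 127 and widen the mantissa.
        _ => (sign << 31) | ((exp + 112) << 23) | (frac << 13),
    };
    f32::from_bits(bits)
}

/// CPU reference for y = W * x with f16 weights and f32 input/output.
/// Assumes `w_bits` is [n, k] row-major and accumulates in f32.
fn reference_matvec_f16w_f32io(x: &[f32], w_bits: &[u16], n: usize, k: usize) -> Vec<f32> {
    assert_eq!(x.len(), k, "x must hold k elements");
    assert_eq!(w_bits.len(), n * k, "W must hold n * k elements");

    (0..n)
        .map(|j| {
            w_bits[j * k..(j + 1) * k]
                .iter()
                .zip(x)
                .map(|(&w, &xi)| f16_bits_to_f32(w) * xi)
                .sum::<f32>()
        })
        .collect()
}
```

A reference like this is mainly useful for checking GPU output within a tolerance, since a GPU kernel may sum the products in a different order.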