Dense F16 matrix multiply for the lm_head vocabulary projection.
Computes C = A * B^T where A is [M, K] f16, B is [N, K] f16,
and C is [M, N] f16.
Two GPU kernels:
- dense_matvec_f16: specialised M=1 mat-vec (decode hot path). Uses vectorised half4 loads + simd_sum, modelled after the llama.cpp kernel_mul_mv_f16_f32 pattern.
- dense_gemm_f16: tiled GEMM for M>1 with simdgroup_matrix MMA.
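As a point of reference for the C = A * B^T contract described above, here is a minimal CPU sketch. It is not the crate's implementation: the function name gemm_a_bt_reference is hypothetical, f32 stands in for f16 (stable Rust has no built-in f16 type), and row-major storage is an assumption that the page does not state.

```rust
/// Hypothetical CPU reference for C = A * B^T (not part of the crate).
/// Assumes row-major storage: `a` is [m, k], `b` is [n, k], result is [m, n].
/// f32 is used in place of f16 purely for illustration.
fn gemm_a_bt_reference(a: &[f32], b: &[f32], m: usize, n: usize, k: usize) -> Vec<f32> {
    assert_eq!(a.len(), m * k, "A must hold m * k elements");
    assert_eq!(b.len(), n * k, "B must hold n * k elements");

    let mut c = vec![0.0f32; m * n];
    for i in 0..m {
        let a_row = &a[i * k..(i + 1) * k];
        for j in 0..n {
            // Because B is stored as [n, k], row j of B is column j of B^T,
            // so each output element is a dot product of two contiguous rows.
            let b_row = &b[j * k..(j + 1) * k];
            c[i * n + j] = a_row.iter().zip(b_row).map(|(x, w)| x * w).sum::<f32>();
        }
    }
    c
}

fn main() {
    // A is [2, 3] and B is [2, 3], so C is [2, 2].
    let a = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
    let b = [1.0, 0.0, 0.0, 0.0, 1.0, 1.0];
    let c = gemm_a_bt_reference(&a, &b, 2, 2, 3);
    println!("{:?}", c);
}
```

With M=1 the outer loop runs once and the computation reduces to a mat-vec over the rows of B, which is the case the specialised kernel targets.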
Structs§
- DenseGemmF16Params - Parameters for a dense GEMM operation.
Statics§
- DENSE_GEMM_SHADER_SOURCE - MSL source for the dense GEMM kernel (embedded at compile time).
Functions§
- dispatch_dense_gemm_f16 - Dispatch a dense F16 matrix multiply on the GPU: C = A * B^T.
- dispatch_dense_matvec_f16w_f32io - Dispatch a mixed-precision mat-vec: F32 input × F16 weights → F32 output.
- register - Register dense GEMM shader source with the given kernel registry.
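
To illustrate what "F32 input × F16 weights → F32 output" means numerically, the following CPU sketch decodes each F16 weight to F32 and accumulates in F32. Everything here is an assumption made for illustration: the helper names f16_bits_to_f32 and matvec_f16w_f32io_reference are hypothetical, weights are passed as raw IEEE 754 binary16 bit patterns in u16, and the [N, K] row-major weight layout is carried over from the module description rather than confirmed for this function. It says nothing about the real signature of dispatch_dense_matvec_f16w_f32io.

```rust
/// Decode an IEEE 754 binary16 bit pattern into f32 (hypothetical helper).
fn f16_bits_to_f32(h: u16) -> f32 {
    let sign = ((h as u32) & 0x8000) << 16;
    let exp = ((h >> 10) & 0x1f) as u32;
    let frac = (h & 0x03ff) as u32;
    match exp {
        0 => {
            // Zero or subnormal: magnitude is frac * 2^-24.
            let magnitude = frac as f32 / 16_777_216.0;
            if sign != 0 { -magnitude } else { magnitude }
        }
        // Infinity or NaN: keep the payload, widen the exponent field.
        0x1f => f32::from_bits(sign | 0x7f80_0000 | (frac << 13)),
        // Normal number: rebias the exponent from 15 to 127 (add 112).
        _ => f32::from_bits(sign | ((exp + 112) << 23) | (frac << 13)),
    }
}

/// Hypothetical CPU reference: y = W * x with F16 weights and F32 input/output.
/// Assumes `w_bits` is row-major [n, k] and `x` has length k.
fn matvec_f16w_f32io_reference(x: &[f32], w_bits: &[u16], n: usize, k: usize) -> Vec<f32> {
    assert_eq!(x.len(), k, "x must hold k elements");
    assert_eq!(w_bits.len(), n * k, "W must hold n * k elements");

    (0..n)
        .map(|row| {
            w_bits[row * k..(row + 1) * k]
                .iter()
                .zip(x)
                // Widen each weight to f32 before multiplying, accumulate in f32.
                .map(|(&w, &xi)| f16_bits_to_f32(w) * xi)
                .sum::<f32>()
        })
        .collect()
}

fn main() {
    let x = [1.0f32, 2.0];
    // Rows of W in binary16: [1.0, 0.5] and [-2.0, 0.0].
    let w_bits = [0x3C00u16, 0x3800, 0xC000, 0x0000];
    let y = matvec_f16w_f32io_reference(&x, &w_bits, 2, 2);
    println!("{:?}", y);
}
```

Keeping the accumulator in F32 avoids the precision loss that summing many F16 products over a long K dimension would otherwise introduce.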