Fused scale-mask-softmax for non-flash-attention prefill.
Replaces three sequential dispatches (scale by 1/sqrt(hd), add bf16 mask, row-softmax) with one kernel. Intended for hf2q’s HF2Q_NO_FA path; not used elsewhere.
See src/shaders/scale_mask_softmax.metal for the kernel contract.
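As an illustration of what the fused pass computes, here is a minimal CPU reference sketch. It is not the crate's kernel or API: the function name `scale_mask_softmax_ref`, its signature, and the row-major layout are hypothetical. It also assumes that `hd` means the attention head dimension, that the mask has the same shape as the scores, and it uses `f32` for the mask where the kernel takes bf16.

```rust
/// Hypothetical CPU reference for a fused scale-mask-softmax pass.
/// `scores` and `mask` are row-major with `cols` entries per row
/// (assumed layout); the result is written back into `scores`.
fn scale_mask_softmax_ref(scores: &mut [f32], mask: &[f32], cols: usize, head_dim: usize) {
    assert!(cols > 0 && scores.len() % cols == 0);
    assert_eq!(scores.len(), mask.len());

    // Scale factor 1/sqrt(hd), assuming hd is the head dimension.
    let scale = 1.0 / (head_dim as f32).sqrt();

    for (row, mask_row) in scores.chunks_mut(cols).zip(mask.chunks(cols)) {
        // Scale and add the mask in one sweep, tracking the row maximum.
        let mut max = f32::NEG_INFINITY;
        for (s, m) in row.iter_mut().zip(mask_row) {
            *s = *s * scale + *m;
            max = max.max(*s);
        }

        // Numerically stable softmax: subtract the row maximum before exp.
        // Assumes each row has at least one entry that is not masked to -inf.
        let mut sum = 0.0f32;
        for s in row.iter_mut() {
            *s = (*s - max).exp();
            sum += *s;
        }
        for s in row.iter_mut() {
            *s /= sum;
        }
    }
}
```

A reference like this can be useful for checking a GPU result against a known-good value on small inputs. The actual kernel contract, including mask layout and broadcast rules, is defined in the shader file named above.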
Structs
- ScaleMaskSoftmaxParams - Host-side parameters for scale_mask_softmax_f32.
Functions
- dispatch_scale_mask_softmax_f32 - Dispatches scale_mask_softmax_f32.