Module scale_mask_softmax

Fused scale-mask-softmax for non-flash-attention prefill.

Replaces three sequential dispatches (scale by 1/sqrt(hd), add the bf16 mask, row-wise softmax) with a single kernel. Intended for hf2q’s HF2Q_NO_FA path; not used elsewhere.

See src/shaders/scale_mask_softmax.metal for the kernel contract.
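
The three steps named in the description can be sketched as a CPU reference, which may help when checking the fused kernel's output. This is a minimal sketch under stated assumptions, not the crate's implementation: `scale_mask_softmax_ref` is a hypothetical helper, the scores are assumed to be a row-major f32 buffer, the mask is taken as f32 here rather than bf16, and the caller is assumed to pass `scale` as 1/sqrt(hd), with hd presumably the attention head dimension. The actual kernel contract is the one in the Metal source.

```rust
/// CPU reference for the fused operation, applied in place:
/// scores[r][c] = softmax over c of (scores[r][c] * scale + mask[r][c]).
///
/// Hypothetical helper for illustration only; not part of this module's API.
/// Assumes row-major layout and an f32 mask (the real kernel reads bf16).
fn scale_mask_softmax_ref(scores: &mut [f32], mask: &[f32], cols: usize, scale: f32) {
    assert!(cols > 0 && scores.len() % cols == 0, "buffer must hold whole rows");
    assert_eq!(scores.len(), mask.len(), "mask must match scores in shape");

    for (row, mask_row) in scores.chunks_mut(cols).zip(mask.chunks(cols)) {
        // Steps 1 and 2: scale each score and add the additive mask,
        // tracking the row maximum for a numerically stable softmax.
        let mut max = f32::NEG_INFINITY;
        for (x, &m) in row.iter_mut().zip(mask_row) {
            *x = *x * scale + m;
            max = max.max(*x);
        }

        // Step 3: row softmax. Subtracting the maximum keeps exp() from
        // overflowing. A row whose entries are all -inf would yield NaN,
        // so this sketch assumes at least one unmasked entry per row.
        let mut sum = 0.0f32;
        for x in row.iter_mut() {
            *x = (*x - max).exp();
            sum += *x;
        }
        for x in row.iter_mut() {
            *x /= sum;
        }
    }
}
```

Calling it with `cols` set to the key length and `scale` set to `1.0 / (hd as f32).sqrt()` overwrites the scores with attention probabilities. Entries whose mask value is negative infinity come out as exactly zero, provided the row has at least one unmasked entry.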

Structs§

ScaleMaskSoftmaxParams
Host-side parameters for scale_mask_softmax_f32.

Functions§

dispatch_scale_mask_softmax_f32
Dispatches scale_mask_softmax_f32.