pub fn dispatch_repeat_tiled_f32(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &DeviceRef,
    src: &MlxBuffer,
    dst: &MlxBuffer,
    params: &RepeatTiledParams,
) -> Result<()>
Dispatch a tiled-GQA broadcast on the GPU.
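Expands a [seq, hg, k] f32 input to a [seq, h, k] f32 output in a single dispatch. Output head i reads input head i % hg, so the hg input heads are tiled h / hg times:

    dst[t, i, j] = src[t, i % hg, j]    for t < seq, i < h, j < k

The kernel only copies values (no arithmetic on them) and needs no host round-trip.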
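A host-side reference can help when checking the GPU output. The sketch below assumes row-major (contiguous) buffers as described above; the function name repeat_tiled_ref is hypothetical and is not part of this crate. It copies each k-element row from input head i % hg to output head i.

```rust
/// Hypothetical CPU reference for the tiled-GQA broadcast.
/// Assumes contiguous `[seq, hg, k]` input and `[seq, h, k]` output.
fn repeat_tiled_ref(src: &[f32], seq: usize, hg: usize, h: usize, k: usize) -> Vec<f32> {
    assert!(hg > 0 && h % hg == 0, "h must be a multiple of hg");
    assert!(src.len() >= seq * hg * k, "src too small");
    let mut dst = vec![0.0f32; seq * h * k];
    for t in 0..seq {
        for i in 0..h {
            // Tiled mapping: output head i reads input head i % hg.
            let src_row = (t * hg + i % hg) * k;
            let dst_row = (t * h + i) * k;
            dst[dst_row..dst_row + k].copy_from_slice(&src[src_row..src_row + k]);
        }
    }
    dst
}

fn main() {
    // seq = 1, hg = 2, h = 4, k = 2: input heads [A, B] become [A, B, A, B].
    let src = [1.0, 2.0, 3.0, 4.0];
    let dst = repeat_tiled_ref(&src, 1, 2, 4, 2);
    println!("{:?}", dst); // [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]
}
```

Note that the tiled mapping differs from an interleaved repeat (i / (h / hg)), which would instead yield [A, A, B, B].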
§Arguments
- encoder - Command encoder to record the dispatch into.
- registry - Kernel registry (repeat_tiled_f32 is auto-registered).
- device - Metal device used for pipeline compilation.
- src - Input buffer: f32, contiguous, at least seq * hg * k elements.
- dst - Output buffer: f32, contiguous, at least seq * h * k elements.
- params - Shape parameters.
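The size requirements for src and dst are given in f32 elements, whereas a Metal buffer's length is measured in bytes. A minimal sketch of the conversion, using a hypothetical helper required_bytes that guards against usize overflow with checked_mul:

```rust
/// Hypothetical helper: minimum byte lengths of the contiguous f32
/// src and dst buffers. Returns None if any product overflows usize.
fn required_bytes(seq: usize, hg: usize, h: usize, k: usize) -> Option<(usize, usize)> {
    let f32_size = std::mem::size_of::<f32>(); // 4 bytes
    let src = seq.checked_mul(hg)?.checked_mul(k)?.checked_mul(f32_size)?;
    let dst = seq.checked_mul(h)?.checked_mul(k)?.checked_mul(f32_size)?;
    Some((src, dst))
}

fn main() {
    // Illustrative shape: seq = 128, hg = 8, h = 32, k = 128.
    let (src, dst) = required_bytes(128, 8, 32, 128).unwrap();
    println!("src >= {src} bytes, dst >= {dst} bytes"); // src >= 524288 bytes, dst >= 2097152 bytes
}
```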
§Errors
Returns MlxError::InvalidArgument if any dimension is zero, if
h % hg != 0, or if either buffer is too small for the declared shapes.
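The three error conditions can be mirrored by a host-side check before encoding. This is a sketch only: the ShapeError enum and validate function are hypothetical stand-ins (the real function returns MlxError::InvalidArgument), and buffer lengths are taken in f32 elements.

```rust
#[derive(Debug, PartialEq)]
enum ShapeError {
    ZeroDim,
    NotDivisible { h: usize, hg: usize },
    BufferTooSmall { needed: usize, got: usize },
}

/// Hypothetical validation mirroring the documented errors.
/// `src_len` and `dst_len` are buffer lengths in f32 elements.
fn validate(seq: usize, hg: usize, h: usize, k: usize, src_len: usize, dst_len: usize) -> Result<(), ShapeError> {
    if seq == 0 || hg == 0 || h == 0 || k == 0 {
        return Err(ShapeError::ZeroDim);
    }
    if h % hg != 0 {
        return Err(ShapeError::NotDivisible { h, hg });
    }
    // saturating_mul: an overflowing product can never fit in a real buffer.
    let need_src = seq.saturating_mul(hg).saturating_mul(k);
    if src_len < need_src {
        return Err(ShapeError::BufferTooSmall { needed: need_src, got: src_len });
    }
    let need_dst = seq.saturating_mul(h).saturating_mul(k);
    if dst_len < need_dst {
        return Err(ShapeError::BufferTooSmall { needed: need_dst, got: dst_len });
    }
    Ok(())
}

fn main() {
    println!("{:?}", validate(4, 8, 32, 64, 2048, 8192)); // Ok(())
    println!("{:?}", validate(4, 3, 32, 64, 2048, 8192)); // Err(NotDivisible { h: 32, hg: 3 })
}
```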