
Function dispatch_repeat_tiled_f32 

pub fn dispatch_repeat_tiled_f32(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &DeviceRef,
    src: &MlxBuffer,
    dst: &MlxBuffer,
    params: &RepeatTiledParams,
) -> Result<()>

Dispatch a tiled-GQA broadcast on the GPU.

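Expands a [seq, hg, k] f32 input to a [seq, h, k] f32 output, where output head j is a copy of input head j % hg: dst[t, j, c] = src[t, j % hg, c] for every position t, head j in 0..h, and element c in 0..k. The input heads are therefore tiled (0, 1, ..., hg-1, 0, 1, ...) rather than interleaved. Everything happens in a single dispatch: the values are only copied, never transformed, and nothing round-trips through the host.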
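The index mapping can be pinned down with a small CPU reference. This is a minimal sketch, not the crate's kernel. It assumes contiguous row-major [seq, heads, k] layout, and repeat_tiled_ref is a hypothetical helper, not part of this API. It is mainly useful for checking GPU output in tests.

```rust
/// Hypothetical CPU reference for the tiled-GQA broadcast (not part of the crate).
/// Assumes contiguous row-major layout: src is [seq, hg, k], dst is [seq, h, k].
fn repeat_tiled_ref(src: &[f32], seq: usize, hg: usize, h: usize, k: usize) -> Vec<f32> {
    assert!(hg > 0 && h % hg == 0, "h must be a multiple of hg");
    assert!(src.len() >= seq * hg * k, "src too small");
    let mut dst = vec![0.0f32; seq * h * k];
    for t in 0..seq {
        for j in 0..h {
            // Output head j reads input head j % hg, so heads tile as 0, 1, .., hg-1, 0, 1, ..
            let s = (t * hg + j % hg) * k;
            let d = (t * h + j) * k;
            dst[d..d + k].copy_from_slice(&src[s..s + k]);
        }
    }
    dst
}

fn main() {
    // seq = 1, hg = 2, h = 4, k = 1: input heads [10, 20] tile to [10, 20, 10, 20].
    let out = repeat_tiled_ref(&[10.0, 20.0], 1, 2, 4, 1);
    println!("{:?}", out); // prints [10.0, 20.0, 10.0, 20.0]
}
```

With interleaved (repeat_interleave-style) expansion, the same input would instead produce [10, 10, 20, 20]. The tiled layout only matches a model whose query heads map to KV heads by j % hg.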

§Arguments

  • encoder - Command encoder to record the dispatch into.
  • registry - Kernel registry (repeat_tiled_f32 is auto-registered).
  • device - Metal device for pipeline compilation.
  • src - Input buffer, f32, contiguous, ≥ seq*hg*k elements.
  • dst - Output buffer, f32, contiguous, ≥ seq*h*k elements.
  • params - Shape parameters.

§Errors

Returns MlxError::InvalidArgument if any dimension is zero, if h % hg != 0, or if either buffer is too small for the declared shapes.
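The documented error conditions can be mirrored host-side before encoding. This is a sketch under stated assumptions. ShapeSketch is a hypothetical stand-in for RepeatTiledParams, whose real field names are not shown on this page. Buffer sizes are taken as f32 element counts rather than bytes. Errors are plain Strings rather than MlxError.

```rust
/// Hypothetical stand-in for RepeatTiledParams; the real field names may differ.
struct ShapeSketch {
    seq: usize,
    hg: usize,
    h: usize,
    k: usize,
}

/// Mirrors the documented checks. src_len and dst_len are assumed to be
/// element counts (f32), not byte lengths.
fn validate(p: &ShapeSketch, src_len: usize, dst_len: usize) -> Result<(), String> {
    if p.seq == 0 || p.hg == 0 || p.h == 0 || p.k == 0 {
        return Err("every dimension must be non-zero".into());
    }
    if p.h % p.hg != 0 {
        return Err(format!("h ({}) is not a multiple of hg ({})", p.h, p.hg));
    }
    let need_src = p.seq * p.hg * p.k;
    let need_dst = p.seq * p.h * p.k;
    if src_len < need_src {
        return Err(format!("src holds {} elements, needs {}", src_len, need_src));
    }
    if dst_len < need_dst {
        return Err(format!("dst holds {} elements, needs {}", dst_len, need_dst));
    }
    Ok(())
}

fn main() {
    let p = ShapeSketch { seq: 2, hg: 2, h: 4, k: 8 };
    // Needs 2*2*8 = 32 src elements and 2*4*8 = 64 dst elements.
    println!("{:?}", validate(&p, 32, 64));
    println!("{:?}", validate(&p, 32, 63));
}
```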