
Function dispatch_rope_multi 

pub fn dispatch_rope_multi(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &DeviceRef,
    input: &MlxBuffer,
    output: &MlxBuffer,
    positions: &MlxBuffer,
    params_buf: &MlxBuffer,
    rope_params_buf: &MlxBuffer,
    sections_buf: &MlxBuffer,
    p: RopeMultiParams,
) -> Result<()>

Dispatch a rope_multi operation.

The caller must upload:

  • params_buf: float4 [freq_base, head_dim, rope_dim, 0].
  • rope_params_buf: uint4 [n_heads, mode_code, seq_len, 0], where mode_code is the u32 value underlying RopeMultiMode.
  • sections_buf: uint4 [s0, s1, s2, s3].
  • positions: int32 array of length 4 * seq_len.

The helper build_rope_multi_buffers constructs the three small buffers (params_buf, rope_params_buf, sections_buf) in one call, for callers that do not already keep them pooled.
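
The buffer layouts listed above can be illustrated with a host-side packing sketch. This is only an illustration: `RopeMultiHostData` and `pack_rope_multi_host_data` are hypothetical names that are not part of this crate, the sketch does not call `build_rope_multi_buffers` (whose signature is not shown here), and it does not create any `MlxBuffer`. It also assumes that `positions` is laid out axis-major (four consecutive runs of `seq_len` entries) and that a text-only caller repeats the same token position on every axis. Neither assumption is confirmed by the documentation above.

```rust
/// Host-side staging data mirroring the four uploads described above.
/// Hypothetical type for illustration; not part of the crate.
#[derive(Debug, PartialEq)]
pub struct RopeMultiHostData {
    /// float4: [freq_base, head_dim, rope_dim, 0]
    pub params: [f32; 4],
    /// uint4: [n_heads, mode_code, seq_len, 0]
    pub rope_params: [u32; 4],
    /// uint4: [s0, s1, s2, s3]
    pub sections: [u32; 4],
    /// int32 array of length 4 * seq_len
    pub positions: Vec<i32>,
}

/// Packs the values into the documented lane order.
/// Hypothetical helper; `mode_code` is passed as a raw u32 because the
/// numeric values of `RopeMultiMode` are not listed in the text above.
#[allow(clippy::too_many_arguments)]
pub fn pack_rope_multi_host_data(
    freq_base: f32,
    head_dim: u32,
    rope_dim: u32,
    n_heads: u32,
    mode_code: u32,
    seq_len: u32,
    sections: [u32; 4],
    token_positions: &[i32],
) -> Result<RopeMultiHostData, String> {
    let n = seq_len as usize;
    if token_positions.len() != n {
        return Err(format!(
            "expected {} token positions, got {}",
            n,
            token_positions.len()
        ));
    }

    // ASSUMPTION: axis-major layout, with the same position on all four
    // axes (a text-only case). Adjust if the kernel expects otherwise.
    let mut positions = Vec::with_capacity(4 * n);
    for _axis in 0..4 {
        positions.extend_from_slice(token_positions);
    }

    Ok(RopeMultiHostData {
        // Integer dimensions are stored as floats in the float4 buffer.
        params: [freq_base, head_dim as f32, rope_dim as f32, 0.0],
        rope_params: [n_heads, mode_code, seq_len, 0],
        sections,
        positions,
    })
}
```

The resulting arrays would then be copied into `params_buf`, `rope_params_buf`, `sections_buf`, and `positions` by whatever upload path the caller already uses. The length check is there because a `positions` array shorter than `4 * seq_len` would violate the contract stated above.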