

Function dispatch_rope 

pub fn dispatch_rope(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &DeviceRef,
    input: &MlxBuffer,
    output: &MlxBuffer,
    params_buf: &MlxBuffer,
    positions_buf: &MlxBuffer,
    seq_len: u32,
    head_dim: u32,
) -> Result<()>

Dispatch a rotary position embedding (RoPE) operation on the GPU by recording it into encoder.
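For orientation, here is a minimal CPU reference of what a RoPE pass computes over a [seq_len, head_dim] tensor. It is a sketch under stated assumptions, not this kernel's implementation: it assumes adjacent elements (2i, 2i+1) form each rotation pair and that pair i rotates by pos * theta^(-2i / head_dim), as in the original RoPE formulation. The name rope_reference is hypothetical, and the GPU kernel may pair elements differently (for example, first half against second half).

```rust
/// Hypothetical CPU reference for RoPE on a row-major [seq_len, head_dim] tensor.
/// Assumption: elements (2i, 2i+1) form a rotation pair, and pair i uses
/// frequency theta^(-2i / head_dim). The real kernel's convention may differ.
fn rope_reference(
    input: &[f32],
    positions: &[u32],
    theta: f32,
    seq_len: usize,
    head_dim: usize,
) -> Vec<f32> {
    assert_eq!(head_dim % 2, 0, "head_dim must be even");
    assert_eq!(input.len(), seq_len * head_dim, "input must be [seq_len, head_dim]");
    assert_eq!(positions.len(), seq_len, "one position per row");

    let mut out = vec![0.0f32; input.len()];
    for s in 0..seq_len {
        let pos = positions[s] as f32;
        let row = s * head_dim;
        for i in 0..head_dim / 2 {
            // Lower pair indices rotate faster; higher ones rotate slower.
            let freq = theta.powf(-2.0 * i as f32 / head_dim as f32);
            let (sin, cos) = (pos * freq).sin_cos();
            let x0 = input[row + 2 * i];
            let x1 = input[row + 2 * i + 1];
            // Standard 2-D rotation of the pair (x0, x1) by angle pos * freq.
            out[row + 2 * i] = x0 * cos - x1 * sin;
            out[row + 2 * i + 1] = x0 * sin + x1 * cos;
        }
    }
    out
}

fn main() {
    // Two rows at positions 0 and 1, head_dim = 4, theta = 10000 (assumed value).
    let input = vec![1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0];
    let out = rope_reference(&input, &[0, 1], 10_000.0, 2, 4);
    // Position 0 is the identity rotation, so the first row is unchanged.
    assert_eq!(&out[..4], &input[..4]);
    println!("{:?}", out);
}
```

Because each pair is only rotated, the L2 norm of every pair is preserved, which makes a quick sanity check when comparing GPU output against a reference like this.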

Arguments

  • encoder - Command encoder to record the dispatch into.
  • registry - Kernel registry (must have RoPE sources registered).
  • device - Metal device for pipeline compilation.
  • input - Input buffer of shape [seq_len, head_dim] (f32 or f16).
  • output - Output buffer (same dtype and shape as input).
  • params_buf - Params buffer holding four f32 values, [theta, head_dim, 0, 0], where theta is the RoPE base frequency.
  • positions_buf - Positions buffer containing [pos_0, pos_1, ...] as u32.
  • seq_len - Number of sequence positions.
  • head_dim - Dimension of each head (must be even).
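The two parameter buffers carry small host-side payloads. The sketch below only builds their contents as plain Rust values; it does not show how an MlxBuffer is allocated or filled, since that API is not documented on this page. The helper names rope_params and rope_positions are hypothetical, and contiguous positions starting at an offset are an assumption (any u32 sequence of length seq_len matches the documented layout).

```rust
/// Hypothetical helper: the four-f32 payload documented for params_buf,
/// laid out as [theta, head_dim, 0, 0].
fn rope_params(theta: f32, head_dim: u32) -> [f32; 4] {
    [theta, head_dim as f32, 0.0, 0.0]
}

/// Hypothetical helper: one u32 position per sequence row for positions_buf.
/// Assumption: positions are contiguous from `offset` (e.g. tokens already cached).
fn rope_positions(offset: u32, seq_len: u32) -> Vec<u32> {
    (offset..offset + seq_len).collect()
}

fn main() {
    let head_dim = 128;
    let seq_len = 4;
    let params = rope_params(10_000.0, head_dim);
    let positions = rope_positions(0, seq_len);

    assert_eq!(params, [10_000.0, 128.0, 0.0, 0.0]);
    assert_eq!(positions, vec![0, 1, 2, 3]);
    // positions.len() must equal seq_len for every row to get a position.
    assert_eq!(positions.len(), seq_len as usize);
    println!("params = {:?}, positions = {:?}", params, positions);
}
```

Passing positions as an explicit buffer, rather than deriving them from the row index, lets the same dispatch handle an offset sequence such as incremental decoding, where the first new row is not at position 0.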

Errors

Returns MlxError::InvalidArgument if:

  • Input dtype is not f32 or f16.
  • head_dim is not even.
  • Input and output element counts do not match.
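The documented argument checks can be mirrored on the host before encoding, for example in tests. This is a hypothetical sketch: DType, RopeArgError, and validate_rope_args are stand-ins rather than the crate's own types, and the order in which the real function performs its checks is not documented here.

```rust
/// Hypothetical dtype enum standing in for the crate's own dtype type.
#[derive(Debug, Clone, Copy, PartialEq)]
enum DType {
    F32,
    F16,
    BF16,
}

/// Hypothetical error mirroring the documented MlxError::InvalidArgument.
#[derive(Debug, PartialEq)]
enum RopeArgError {
    InvalidArgument(String),
}

/// Mirrors the three documented failure conditions (check order assumed).
fn validate_rope_args(
    input_dtype: DType,
    head_dim: u32,
    input_elems: usize,
    output_elems: usize,
) -> Result<(), RopeArgError> {
    if !matches!(input_dtype, DType::F32 | DType::F16) {
        return Err(RopeArgError::InvalidArgument(format!(
            "unsupported input dtype {:?}; expected f32 or f16",
            input_dtype
        )));
    }
    if head_dim % 2 != 0 {
        return Err(RopeArgError::InvalidArgument(format!(
            "head_dim must be even, got {}",
            head_dim
        )));
    }
    if input_elems != output_elems {
        return Err(RopeArgError::InvalidArgument(format!(
            "element count mismatch: input {} vs output {}",
            input_elems, output_elems
        )));
    }
    Ok(())
}

fn main() {
    assert!(validate_rope_args(DType::F32, 64, 256, 256).is_ok());
    if let Err(RopeArgError::InvalidArgument(msg)) = validate_rope_args(DType::BF16, 64, 256, 256) {
        println!("rejected: {msg}");
    }
}
```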