pub fn dispatch_rope(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &DeviceRef,
    input: &MlxBuffer,
    output: &MlxBuffer,
    params_buf: &MlxBuffer,
    positions_buf: &MlxBuffer,
    seq_len: u32,
    head_dim: u32,
) -> Result<()>
Dispatch a RoPE operation on the GPU.
§Arguments
- encoder - Command encoder to record the dispatch into.
- registry - Kernel registry (must have RoPE sources registered).
- device - Metal device for pipeline compilation.
- input - Input buffer of shape [seq_len, head_dim] (f32 or f16).
- output - Output buffer (same dtype and shape as input).
- params_buf - Params buffer containing [theta, head_dim, 0, 0] as f32.
- positions_buf - Positions buffer containing [pos_0, pos_1, ...] as u32.
- seq_len - Number of sequence positions.
- head_dim - Dimension of each head (must be even).
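The sketch below shows the host-side data these arguments describe, plus a CPU reference that can help when checking GPU output. The helper names rope_params and rope_reference are hypothetical and not part of this crate. The reference assumes the common RoPE formulation, where adjacent elements (2i, 2i+1) form a rotation pair with frequency theta^(-2i/head_dim); the actual kernel may pair elements differently (for example, split halves), so treat this as an illustration rather than a specification.

```rust
/// Hypothetical helper: the four f32 values described for params_buf,
/// laid out as [theta, head_dim, 0, 0].
fn rope_params(theta: f32, head_dim: u32) -> [f32; 4] {
    [theta, head_dim as f32, 0.0, 0.0]
}

/// Hypothetical CPU reference over a row-major [seq_len, head_dim] f32 slice.
/// `positions` plays the role of positions_buf: one u32 position per row.
/// Assumption: adjacent elements (2i, 2i+1) are rotated together; the GPU
/// kernel may use a different pairing convention.
fn rope_reference(input: &[f32], positions: &[u32], theta: f32, head_dim: usize) -> Vec<f32> {
    assert!(head_dim % 2 == 0, "head_dim must be even");
    assert_eq!(input.len(), positions.len() * head_dim);

    let mut out = vec![0.0f32; input.len()];
    for (s, &pos) in positions.iter().enumerate() {
        let row = s * head_dim;
        for i in 0..head_dim / 2 {
            // Frequency for pair i: theta^(-2i / head_dim).
            let freq = theta.powf(-(2.0 * i as f32) / head_dim as f32);
            let (sin, cos) = (pos as f32 * freq).sin_cos();
            let x0 = input[row + 2 * i];
            let x1 = input[row + 2 * i + 1];
            // Standard 2-D rotation of the pair by the position-dependent angle.
            out[row + 2 * i] = x0 * cos - x1 * sin;
            out[row + 2 * i + 1] = x0 * sin + x1 * cos;
        }
    }
    out
}
```

Under these assumptions, a row at position 0 is returned unchanged, since every rotation angle is zero.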
§Errors
Returns MlxError::InvalidArgument if:
- Input dtype is not f32 or f16.
- head_dim is not even.
- Input and output element counts do not match.
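The three conditions above can be sketched as a standalone pre-flight check. Everything here is a hypothetical stand-in: Dtype, InvalidArgument, and validate_rope_args are illustrative local definitions, not the crate's own types or its MlxError, and the order of the checks is an assumption.

```rust
/// Hypothetical stand-in for a buffer's element type.
#[derive(Debug, Clone, Copy, PartialEq)]
enum Dtype {
    F32,
    F16,
    U32,
}

/// Hypothetical stand-in for the invalid-argument error case.
#[derive(Debug, PartialEq)]
struct InvalidArgument(String);

/// Mirrors the documented error conditions; the check order is an assumption.
fn validate_rope_args(
    dtype: Dtype,
    input_elems: usize,
    output_elems: usize,
    head_dim: u32,
) -> Result<(), InvalidArgument> {
    if !matches!(dtype, Dtype::F32 | Dtype::F16) {
        return Err(InvalidArgument(format!("unsupported dtype {:?}", dtype)));
    }
    if head_dim % 2 != 0 {
        return Err(InvalidArgument(format!("head_dim {} is not even", head_dim)));
    }
    if input_elems != output_elems {
        return Err(InvalidArgument(format!(
            "element count mismatch: input {} vs output {}",
            input_elems, output_elems
        )));
    }
    Ok(())
}
```

Running a check like this before dispatch surfaces argument mistakes on the host, before any GPU work is recorded.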