pub fn embedding_gather_scale_f32(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &DeviceRef,
    embed_table: &MlxBuffer,
    output: &MlxBuffer,
    token_id: u32,
    hidden_size: usize,
    scale: f32,
) -> Result<()>
Encodes an embedding gather followed by a scale: output[i] = embed_table[token_id * hidden_size + i] * scale, for i in 0..hidden_size.
Arguments
- encoder - Command encoder.
- registry - Kernel registry.
- device - Metal device.
- embed_table - f32[vocab_size * hidden_size].
- output - f32[hidden_size].
- token_id - Token index into the embedding table.
- hidden_size - Embedding dimension.
- scale - Scale factor (e.g. sqrt(hidden_size)).
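
The formula above can be sketched on the CPU as a reference for checking the kernel's output. This is a minimal sketch and not part of the crate: the helper name `embedding_gather_scale_ref` is hypothetical, it works on plain `f32` slices rather than `MlxBuffer`, and returning `None` for an out-of-range `token_id` is an assumption, since the documentation does not say how the kernel treats that case.

```rust
/// CPU reference for the gather + scale (hypothetical helper, not part of the crate).
/// `embed_table` is the flattened row-major table of length vocab_size * hidden_size.
fn embedding_gather_scale_ref(
    embed_table: &[f32],
    token_id: u32,
    hidden_size: usize,
    scale: f32,
) -> Option<Vec<f32>> {
    // Offset of the first element of the token's row in the flattened table.
    let start = (token_id as usize).checked_mul(hidden_size)?;
    let end = start.checked_add(hidden_size)?;
    // `get` yields None when the row falls outside the table
    // (assumed handling; the real kernel's behavior is not documented here).
    let row = embed_table.get(start..end)?;
    // output[i] = embed_table[token_id * hidden_size + i] * scale
    Some(row.iter().map(|&x| x * scale).collect())
}

fn main() {
    // Toy table with vocab_size = 3 and hidden_size = 4, holding 0.0..=11.0.
    let table: Vec<f32> = (0..12).map(|v| v as f32).collect();

    // Row 1 is [4.0, 5.0, 6.0, 7.0]; scaling by 2.0 doubles each entry.
    let out = embedding_gather_scale_ref(&table, 1, 4, 2.0).expect("token_id in range");
    println!("{:?}", out); // prints [8.0, 10.0, 12.0, 14.0]

    // token_id 3 is past the last row, so the sketch reports None.
    assert!(embedding_gather_scale_ref(&table, 3, 4, 2.0).is_none());
}
```

Comparing the contents of `output` against such a reference after the encoded commands have run is one way to validate the GPU path on small inputs.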