pub fn embedding_gather_scale_batch_f32(
encoder: &mut CommandEncoder,
registry: &mut KernelRegistry,
device: &DeviceRef,
embed_table: &MlxBuffer,
token_ids: &MlxBuffer,
output: &MlxBuffer,
hidden_size: usize,
n_tokens: usize,
scale: f32,
) -> Result<()>
Batched embedding gather + scale for prefill (f32).
For each tok in 0..n_tokens, reads token_ids[tok], gathers the
corresponding embedding row from embed_table, multiplies each element
by scale, and writes element i (for i in 0..hidden_size) to
output[tok * hidden_size + i].
- embed_table: f32[vocab_size * hidden_size]
- token_ids: u32[n_tokens]
- output: f32[n_tokens * hidden_size]
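
The indexing described above can be sketched as a plain CPU loop. This is only an illustrative reference for the semantics, not the GPU kernel this function dispatches. The helper name `embedding_gather_scale_reference` and its use of slices in place of `MlxBuffer` are assumptions made for the example. It also assumes the row for token id `t` starts at offset `t * hidden_size` in `embed_table`, which is what the row-major layout `f32[vocab_size * hidden_size]` suggests.

```rust
/// CPU reference for the gather + scale semantics.
/// Hypothetical helper for illustration; not part of the crate's API.
fn embedding_gather_scale_reference(
    embed_table: &[f32], // f32[vocab_size * hidden_size], assumed row-major
    token_ids: &[u32],   // u32[n_tokens]
    output: &mut [f32],  // f32[n_tokens * hidden_size]
    hidden_size: usize,
    scale: f32,
) {
    assert_eq!(output.len(), token_ids.len() * hidden_size);
    for (tok, &id) in token_ids.iter().enumerate() {
        // Start of the embedding row for this token id.
        let src = id as usize * hidden_size;
        // Slicing panics if the token id is outside the table.
        let row = &embed_table[src..src + hidden_size];
        let dst = &mut output[tok * hidden_size..(tok + 1) * hidden_size];
        // output[tok * hidden_size + i] = embed_table[id * hidden_size + i] * scale
        for (o, &e) in dst.iter_mut().zip(row) {
            *o = e * scale;
        }
    }
}
```

For example, with `hidden_size = 2`, a table of three rows `[0, 1]`, `[2, 3]`, `[4, 5]`, token ids `[2, 0]`, and `scale = 0.5`, the sketch writes rows 2 and 0 scaled by one half, in that order, into `output`.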