Function embedding_gather_scale_batch_f32 

pub fn embedding_gather_scale_batch_f32(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &DeviceRef,
    embed_table: &MlxBuffer,
    token_ids: &MlxBuffer,
    output: &MlxBuffer,
    hidden_size: usize,
    n_tokens: usize,
    scale: f32,
) -> Result<()>

Batched embedding gather + scale for prefill (f32).

For each tok in 0..n_tokens, reads token_ids[tok], gathers the corresponding hidden_size-element row from embed_table, multiplies each element by scale, and writes element i of the result to output[tok * hidden_size + i], for i in 0..hidden_size.

  • embed_table — f32 [vocab_size * hidden_size]
  • token_ids — u32 [n_tokens]
  • output — f32 [n_tokens * hidden_size]
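
The indexing described above can be illustrated with a CPU reference. This is a minimal sketch, not the crate's GPU kernel: the helper name `embedding_gather_scale_ref` is hypothetical, it works on plain slices instead of `MlxBuffer`, and it assumes `embed_table` is laid out row-major so that row `id` starts at `id * hidden_size`. Its behavior for an out-of-range token id (a panic) is a property of the sketch only; the documentation above does not say how the real kernel handles that case.

```rust
/// CPU reference sketch of the gather + scale (hypothetical helper, not part of the crate).
/// Assumes `embed_table` is row-major with `hidden_size` elements per row.
fn embedding_gather_scale_ref(
    embed_table: &[f32],
    token_ids: &[u32],
    hidden_size: usize,
    scale: f32,
) -> Vec<f32> {
    // One output row of `hidden_size` elements per token.
    let mut output = vec![0.0f32; token_ids.len() * hidden_size];
    for (tok, &id) in token_ids.iter().enumerate() {
        // Source row: embed_table[id * hidden_size .. (id + 1) * hidden_size].
        // Slicing panics here if `id` is outside the table (sketch behavior only).
        let src = id as usize * hidden_size;
        let row = &embed_table[src..src + hidden_size];
        // Destination row: output[tok * hidden_size + i] for i in 0..hidden_size.
        let dst = &mut output[tok * hidden_size..(tok + 1) * hidden_size];
        for (o, &e) in dst.iter_mut().zip(row) {
            *o = e * scale;
        }
    }
    output
}
```

For example, with a 3-row table of `hidden_size` 2 holding `[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]`, token ids `[2, 0]`, and `scale` 0.5, the sketch returns `[2.5, 3.0, 0.5, 1.0]`: row 2 scaled, followed by row 0 scaled. A reference like this can be useful for checking GPU output in tests.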