

Function embedding_gather_scale_f32 

pub fn embedding_gather_scale_f32(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &DeviceRef,
    embed_table: &MlxBuffer,
    output: &MlxBuffer,
    token_id: u32,
    hidden_size: usize,
    scale: f32,
) -> Result<()>

Encode an embedding gather + scale: output[i] = embed_table[token_id * hidden_size + i] * scale, for i in 0..hidden_size.
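For reference, the computation the kernel encodes can be written as a plain CPU loop. The sketch below is an illustration under stated assumptions, not the crate's implementation: `embedding_gather_scale_cpu` is a hypothetical helper that works on ordinary slices instead of `MlxBuffer`s, and it returns `None` when the requested row falls outside the table.

```rust
/// Hypothetical CPU reference for the gather + scale (not part of the crate's API).
/// Copies row `token_id` of a row-major `[vocab_size * hidden_size]` table
/// and multiplies every element by `scale`.
fn embedding_gather_scale_cpu(
    embed_table: &[f32],
    token_id: u32,
    hidden_size: usize,
    scale: f32,
) -> Option<Vec<f32>> {
    // Start offset of the row: token_id * hidden_size, guarding against overflow.
    let start = (token_id as usize).checked_mul(hidden_size)?;
    let end = start.checked_add(hidden_size)?;
    // `get` returns None (instead of panicking) if the row lies past the end of the table.
    let row = embed_table.get(start..end)?;
    // output[i] = embed_table[token_id * hidden_size + i] * scale
    Some(row.iter().map(|&x| x * scale).collect())
}
```

With a 3 x 4 table holding 0.0..=11.0, token 1 selects the row [4.0, 5.0, 6.0, 7.0], so a scale of 2.0 yields [8.0, 10.0, 12.0, 14.0].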

Arguments

  • encoder - Command encoder.
  • registry - Kernel registry.
  • device - Metal device.
  • embed_table - f32 [vocab_size * hidden_size], row-major (row t holds the embedding of token t).
  • output - f32 [hidden_size].
  • token_id - Token index into the embedding table (selects one row; expected to be less than vocab_size).
  • hidden_size - Embedding dimension.
  • scale - Scale factor (e.g. sqrt(hidden_size)).
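The argument list implies size requirements on the two buffers. A minimal sketch of checking them on the CPU before encoding, assuming the caller knows both buffer lengths in f32 elements; `check_gather_args`, `GatherArgError`, and `sqrt_hidden_scale` are hypothetical names, and this does not claim to mirror any validation the crate performs itself.

```rust
/// Hypothetical errors for the size contract implied by the argument list.
#[derive(Debug, PartialEq)]
enum GatherArgError {
    ZeroHiddenSize,
    TokenOutOfRange { token_id: u32, vocab_size: usize },
    OutputTooSmall { len: usize, hidden_size: usize },
}

/// Checks that a table of `table_len` f32 elements contains row `token_id`
/// and that an output of `output_len` elements can hold `hidden_size` values.
/// On success, returns the number of complete rows (vocab_size) in the table.
fn check_gather_args(
    table_len: usize,
    output_len: usize,
    token_id: u32,
    hidden_size: usize,
) -> Result<usize, GatherArgError> {
    if hidden_size == 0 {
        return Err(GatherArgError::ZeroHiddenSize);
    }
    // Table is [vocab_size * hidden_size]; integer division counts complete rows only.
    let vocab_size = table_len / hidden_size;
    if (token_id as usize) >= vocab_size {
        return Err(GatherArgError::TokenOutOfRange { token_id, vocab_size });
    }
    if output_len < hidden_size {
        return Err(GatherArgError::OutputTooSmall { len: output_len, hidden_size });
    }
    Ok(vocab_size)
}

/// The example scale from the argument docs: sqrt(hidden_size).
fn sqrt_hidden_scale(hidden_size: usize) -> f32 {
    (hidden_size as f32).sqrt()
}
```

Running such a check before calling `embedding_gather_scale_f32` would surface an out-of-range token as an ordinary error on the CPU side, rather than relying on the kernel to handle it.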