

Function embedding_gather 

pub fn embedding_gather(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &DeviceRef,
    weight_packed: &MlxBuffer,
    scales: &MlxBuffer,
    biases: &MlxBuffer,
    token_ids: &MlxBuffer,
    output: &MlxBuffer,
    params: &EmbeddingGatherParams,
) -> Result<()>

Encode a quantized embedding gather operation into the command buffer.

Looks up one row of a quantized embedding table for each of the n_tokens IDs in token_ids, dequantizes each row on the fly on the GPU, and writes the f32 results to output.
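The following is a minimal CPU reference for the 4-bit case. It is meant to make the gather-and-dequantize step concrete and is not the GPU kernel. It rests on several assumptions that this page does not state. First, nibbles are packed lowest-bits-first within each uint32. Second, dequantization is affine per group, `w = scale * q + bias`, following the MLX convention. Third, scales and biases are passed as raw bf16 bit patterns. The function names are hypothetical.

```rust
/// Convert a raw bf16 bit pattern to f32 (bf16 is the upper 16 bits of an f32).
fn bf16_to_f32(bits: u16) -> f32 {
    f32::from_bits((bits as u32) << 16)
}

/// Hypothetical CPU reference for a 4-bit quantized embedding gather.
/// Assumptions: 8 nibbles per u32, lowest nibble first; affine dequantization
/// `w = scale * q + bias` per group; embed_dim is a multiple of 8 and of group_size.
fn embedding_gather_4bit_ref(
    weight_packed: &[u32], // [vocab_size, embed_dim / 8]
    scales: &[u16],        // [vocab_size, n_groups_per_row], raw bf16 bits
    biases: &[u16],        // [vocab_size, n_groups_per_row], raw bf16 bits
    token_ids: &[u32],     // [n_tokens]
    embed_dim: usize,
    group_size: usize,
) -> Vec<f32> {
    let words_per_row = embed_dim / 8;
    let groups_per_row = embed_dim / group_size;
    let mut out = Vec::with_capacity(token_ids.len() * embed_dim);
    for &tok in token_ids {
        // Out-of-range IDs panic here via slice indexing; the GPU kernel's
        // behavior for such IDs is not documented.
        let row = tok as usize;
        for d in 0..embed_dim {
            // Select the u32 word holding element d, then shift out its nibble.
            let word = weight_packed[row * words_per_row + d / 8];
            let q = (word >> (4 * (d % 8))) & 0xF;
            // Each group of `group_size` elements shares one scale and one bias.
            let g = row * groups_per_row + d / group_size;
            let scale = bf16_to_f32(scales[g]);
            let bias = bf16_to_f32(biases[g]);
            out.push(scale * q as f32 + bias);
        }
    }
    out
}
```

Output rows follow the order of token_ids, so repeated IDs simply produce repeated rows.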

§Buffer expectations

  • weight_packed — Packed quantized embedding table.
    • 4-bit: [vocab_size, embed_dim / 8] uint32 values (8 values per uint32).
    • 6-bit: [vocab_size, embed_dim * 3 / 4] uint8 bytes (4 values per 3 bytes).
  • scales — bf16 per-group scales, [vocab_size, n_groups_per_row], where n_groups_per_row = embed_dim / group_size.
  • biases — bf16 per-group biases, [vocab_size, n_groups_per_row].
  • token_ids — uint32 token IDs, [n_tokens].
  • output — f32 output buffer, [n_tokens, embed_dim].
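The shapes above can be turned into expected byte sizes, which is useful when allocating buffers. The sketch below does that and also shows one way to unpack a 6-bit triple ("4 values per 3 bytes"). The helper names are hypothetical. The 6-bit bit order is an assumption: the three bytes are read as a little-endian 24-bit word, and value 0 sits in the low 6 bits. Check this against the actual kernel before relying on it.

```rust
/// Hypothetical summary of the byte sizes implied by the documented shapes.
struct BufferSizes {
    weight_packed: usize,
    scales: usize,
    biases: usize,
    token_ids: usize,
    output: usize,
}

/// Returns None for unsupported bit widths or an invalid group_size.
/// Assumes embed_dim is a multiple of 8 (4-bit) or 4 (6-bit) so rows pack evenly.
fn expected_buffer_sizes(
    bits: u32,
    vocab_size: usize,
    embed_dim: usize,
    group_size: usize,
    n_tokens: usize,
) -> Option<BufferSizes> {
    if group_size == 0 || embed_dim % group_size != 0 {
        return None;
    }
    let packed_row_bytes = match bits {
        4 => (embed_dim / 8) * 4, // embed_dim / 8 u32 words, 4 bytes each
        6 => embed_dim * 3 / 4,   // 4 six-bit values per 3 bytes
        _ => return None,
    };
    let n_groups_per_row = embed_dim / group_size;
    Some(BufferSizes {
        weight_packed: vocab_size * packed_row_bytes,
        scales: vocab_size * n_groups_per_row * 2, // bf16 = 2 bytes
        biases: vocab_size * n_groups_per_row * 2, // bf16 = 2 bytes
        token_ids: n_tokens * 4,                   // u32 = 4 bytes
        output: n_tokens * embed_dim * 4,          // f32 = 4 bytes
    })
}

/// Unpack four 6-bit values from three bytes (assumed little-endian, low bits first).
fn unpack_6bit(b: [u8; 3]) -> [u8; 4] {
    let w = b[0] as u32 | (b[1] as u32) << 8 | (b[2] as u32) << 16;
    [
        (w & 0x3F) as u8,
        ((w >> 6) & 0x3F) as u8,
        ((w >> 12) & 0x3F) as u8,
        ((w >> 18) & 0x3F) as u8,
    ]
}
```

For example, with embed_dim = 64 a 4-bit row occupies 32 bytes and a 6-bit row occupies 48 bytes.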

§Errors

Returns MlxError::InvalidArgument if:

  • bits is not 4 or 6
  • embed_dim is zero
  • group_size is zero
  • embed_dim is not divisible by group_size
  • n_tokens is zero
  • output is too small to hold [n_tokens, embed_dim] f32 values
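The error conditions above can be mirrored in a caller-side pre-check. The sketch below is hypothetical. `ParamsSketch` stands in for EmbeddingGatherParams, whose real field names and types may differ, and a String error replaces MlxError::InvalidArgument. It also assumes that "too small" means fewer than n_tokens * embed_dim * 4 bytes.

```rust
/// Hypothetical stand-in for EmbeddingGatherParams (real fields may differ).
#[derive(Clone, Copy)]
struct ParamsSketch {
    bits: u32,
    embed_dim: u32,
    group_size: u32,
    n_tokens: u32,
}

/// Checks the documented InvalidArgument conditions before encoding.
fn validate_params(p: &ParamsSketch, output_len_bytes: usize) -> Result<(), String> {
    if p.bits != 4 && p.bits != 6 {
        return Err(format!("bits must be 4 or 6, got {}", p.bits));
    }
    if p.embed_dim == 0 {
        return Err("embed_dim must be non-zero".into());
    }
    // Checked before the modulo below, which would otherwise divide by zero.
    if p.group_size == 0 {
        return Err("group_size must be non-zero".into());
    }
    if p.embed_dim % p.group_size != 0 {
        return Err(format!(
            "embed_dim {} is not divisible by group_size {}",
            p.embed_dim, p.group_size
        ));
    }
    if p.n_tokens == 0 {
        return Err("n_tokens must be non-zero".into());
    }
    // Output holds [n_tokens, embed_dim] f32 values.
    let needed = p.n_tokens as usize * p.embed_dim as usize * std::mem::size_of::<f32>();
    if output_len_bytes < needed {
        return Err(format!("output needs {needed} bytes, got {output_len_bytes}"));
    }
    Ok(())
}
```

Checking group_size == 0 before the divisibility test keeps the validator itself from panicking on bad input.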