pub fn embedding_gather(
encoder: &mut CommandEncoder,
registry: &mut KernelRegistry,
device: &DeviceRef,
weight_packed: &MlxBuffer,
scales: &MlxBuffer,
biases: &MlxBuffer,
token_ids: &MlxBuffer,
output: &MlxBuffer,
params: &EmbeddingGatherParams,
) -> Result<()>
Encode a quantized embedding gather operation into the command buffer.
Looks up the n_tokens rows selected by token_ids from a quantized embedding
table and dequantizes each row on the fly on the GPU.
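As a point of reference, the gather-plus-dequantize step can be written as a small CPU routine. This is a minimal sketch for the 4-bit case only, not the kernel itself. It assumes an affine scheme (w = scale * q + bias, per group of group_size elements), that the lowest nibble of each u32 word is the first element, and that bf16 values are handled as raw u16 bit patterns. The function names are hypothetical.

```rust
/// Convert a bf16 value, stored as its raw 16 bits, to f32.
/// bf16 is the upper half of an IEEE-754 f32, so the shift is exact.
fn bf16_to_f32(bits: u16) -> f32 {
    f32::from_bits((bits as u32) << 16)
}

/// Dequantize one 4-bit row (hypothetical helper).
/// `packed` holds embed_dim / 8 u32 words; `scales`/`biases` hold one bf16
/// value per group. Assumes w = scale * q + bias and low-nibble-first packing.
fn dequant_row_4bit(packed: &[u32], scales: &[u16], biases: &[u16], group_size: usize) -> Vec<f32> {
    let embed_dim = packed.len() * 8;
    let mut out = Vec::with_capacity(embed_dim);
    for i in 0..embed_dim {
        // Element i lives in word i / 8, at nibble position i % 8.
        let q = (packed[i / 8] >> (4 * (i % 8))) & 0xF;
        let g = i / group_size;
        out.push(bf16_to_f32(scales[g]) * q as f32 + bf16_to_f32(biases[g]));
    }
    out
}

/// CPU reference for the gather (hypothetical): appends one dequantized
/// row of embed_dim f32 values per token ID, in token order.
fn embedding_gather_ref_4bit(
    weight_packed: &[u32],
    scales: &[u16],
    biases: &[u16],
    token_ids: &[u32],
    embed_dim: usize,
    group_size: usize,
) -> Vec<f32> {
    let words_per_row = embed_dim / 8;
    let groups_per_row = embed_dim / group_size;
    let mut output = Vec::with_capacity(token_ids.len() * embed_dim);
    for &t in token_ids {
        let t = t as usize;
        // Slice out this token's packed row and its per-group parameters.
        let row = &weight_packed[t * words_per_row..(t + 1) * words_per_row];
        let s = &scales[t * groups_per_row..(t + 1) * groups_per_row];
        let b = &biases[t * groups_per_row..(t + 1) * groups_per_row];
        output.extend(dequant_row_4bit(row, s, b, group_size));
    }
    output
}
```

With a two-row table where row 1 packs the nibbles 0 through 7 (word 0x7654_3210), a scale of 1.0 and a bias of 0.5, looking up token 1 yields 0.5, 1.5, ..., 7.5. A reference like this is mainly useful for checking GPU output in tests.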
§Buffer expectations
- weight_packed: Packed quantized embedding table.
  - 4-bit: [vocab_size, embed_dim / 8] uint32 values (8 values per uint32).
  - 6-bit: [vocab_size, embed_dim * 3 / 4] uint8 bytes (4 values per 3 bytes).
- scales: bf16 scales, [vocab_size, n_groups_per_row], where n_groups_per_row = embed_dim / group_size.
- biases: bf16 biases, [vocab_size, n_groups_per_row].
- token_ids: uint32 token IDs, [n_tokens].
- output: f32 output buffer, [n_tokens, embed_dim].
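The shapes above translate directly into byte sizes, which is handy when allocating buffers. The following is a sketch with hypothetical names (expected_sizes, unpack_6bit). The 6-bit unpacking assumes a little-endian bit order in which value i occupies bits 6*i through 6*i + 5 of the 24-bit word formed by the three bytes. That order is an assumption and is not stated in this documentation.

```rust
/// Byte sizes implied by the buffer expectations (hypothetical helper type).
#[derive(Debug, PartialEq)]
struct BufferSizes {
    weight_packed: usize,
    scales: usize,
    biases: usize,
    token_ids: usize,
    output: usize,
}

/// Returns None for unsupported bit widths.
fn expected_sizes(bits: u32, vocab_size: usize, embed_dim: usize, group_size: usize, n_tokens: usize) -> Option<BufferSizes> {
    let row_bytes = match bits {
        // embed_dim / 8 u32 words per row, 4 bytes each.
        4 => (embed_dim / 8) * 4,
        // 4 six-bit values per 3 bytes.
        6 => embed_dim * 3 / 4,
        _ => return None,
    };
    let n_groups_per_row = embed_dim / group_size;
    Some(BufferSizes {
        weight_packed: vocab_size * row_bytes,
        scales: vocab_size * n_groups_per_row * 2, // bf16 = 2 bytes
        biases: vocab_size * n_groups_per_row * 2, // bf16 = 2 bytes
        token_ids: n_tokens * 4,                   // u32 = 4 bytes
        output: n_tokens * embed_dim * 4,          // f32 = 4 bytes
    })
}

/// Unpack four 6-bit values from three bytes (ASSUMED little-endian bit order).
fn unpack_6bit(bytes: [u8; 3]) -> [u8; 4] {
    let word = (bytes[0] as u32) | ((bytes[1] as u32) << 8) | ((bytes[2] as u32) << 16);
    [0u32, 1, 2, 3].map(|i| ((word >> (6 * i)) & 0x3F) as u8)
}
```

For example, a 4-bit table with vocab_size = 32000 and embed_dim = 4096 needs 512 words (2048 bytes) per row. With group_size = 64, each row also needs 64 bf16 scales and 64 bf16 biases.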
§Errors
Returns MlxError::InvalidArgument if:
- bits is not 4 or 6.
- embed_dim is zero.
- group_size is zero.
- embed_dim is not divisible by group_size.
- n_tokens is zero.
- The output buffer is too small to hold [n_tokens, embed_dim] f32 values.
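The error conditions can be checked up front, before any encoding work happens. The sketch below shows one way to do that. MlxError and GatherParams are hypothetical stand-ins: the real EmbeddingGatherParams fields are not shown here, and output_bytes is assumed to be the output buffer's length in bytes.

```rust
/// Hypothetical mirror of the error variant named in the docs.
#[derive(Debug, PartialEq)]
enum MlxError {
    InvalidArgument(String),
}

/// Hypothetical stand-in for EmbeddingGatherParams.
struct GatherParams {
    bits: u32,
    embed_dim: usize,
    group_size: usize,
    n_tokens: usize,
}

fn invalid(msg: &str) -> Result<(), MlxError> {
    Err(MlxError::InvalidArgument(msg.to_string()))
}

/// Checks the documented error conditions in order.
fn validate(p: &GatherParams, output_bytes: usize) -> Result<(), MlxError> {
    if p.bits != 4 && p.bits != 6 {
        return invalid("bits must be 4 or 6");
    }
    if p.embed_dim == 0 {
        return invalid("embed_dim must be non-zero");
    }
    // Checked before the modulo below, so it never divides by zero.
    if p.group_size == 0 {
        return invalid("group_size must be non-zero");
    }
    if p.embed_dim % p.group_size != 0 {
        return invalid("embed_dim must be divisible by group_size");
    }
    if p.n_tokens == 0 {
        return invalid("n_tokens must be non-zero");
    }
    // n_tokens * embed_dim f32 values, computed without overflow.
    let needed = p.n_tokens.checked_mul(p.embed_dim).and_then(|n| n.checked_mul(4));
    match needed {
        Some(n) if n <= output_bytes => Ok(()),
        _ => invalid("output buffer is too small"),
    }
}
```

Checking group_size before the divisibility test matters: computing embed_dim % 0 would panic in Rust rather than produce an error.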