// mlx_native/ops/embedding.rs
1//! GPU-accelerated quantized embedding table lookup.
2//!
3//! Supports 4-bit and 6-bit quantized embedding tables, performing
4//! on-the-fly dequantization during gather.  The dequantization formula
5//! is `float_val = uint_val * scale + bias` with bf16 scales and biases.
6
7use metal::MTLSize;
8
9use crate::buffer::MlxBuffer;
10use crate::encoder::CommandEncoder;
11use crate::error::{MlxError, Result};
12use crate::kernel_registry::KernelRegistry;
13
14use super::encode_helpers::{as_bytes, encode_with_args, KernelArg};
15
/// Parameters for quantized embedding gather.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct EmbeddingGatherParams {
    /// Embedding dimension (number of float values per token).
    pub embed_dim: usize,
    /// Number of elements per quantization group (typically 64).
    pub group_size: usize,
    /// Quantization bit width: 4 or 6.
    pub bits: u8,
    /// Number of tokens to gather.
    pub n_tokens: usize,
}
27
/// MSL-compatible parameter struct for the embedding kernel.
///
/// Must match the `EmbeddingParams` struct in `embedding.metal`.
/// `#[repr(C)]` fixes the field layout and `bytemuck::Pod` allows the
/// struct to be passed to the GPU as raw bytes (via `as_bytes`); do not
/// reorder or resize fields without updating the Metal side.
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuEmbeddingParams {
    // Float values per embedding row.
    embed_dim: u32,
    // Quantized values per scale/bias group.
    group_size: u32,
    // Row stride of the packed weight table; unit depends on bit width
    // (uint32 count for 4-bit, byte count for 6-bit — see the caller).
    packed_row_stride: u32,
    // Scale/bias groups per row (= embed_dim / group_size).
    n_groups_per_row: u32,
}
39
40/// Encode a quantized embedding gather operation into the command buffer.
41///
42/// Looks up `n_tokens` rows from a quantized embedding table, dequantizing
43/// each row on-the-fly on the GPU.
44///
45/// # Buffer expectations
46///
47/// * `weight_packed` — Packed quantized embedding table.
48///   - 4-bit: `[vocab_size, embed_dim / 8]` uint32 values (8 values per uint32).
49///   - 6-bit: `[vocab_size, embed_dim * 3 / 4]` uint8 bytes (4 values per 3 bytes).
50/// * `scales` — bf16 scales, `[vocab_size, n_groups_per_row]`.
51/// * `biases` — bf16 biases, `[vocab_size, n_groups_per_row]`.
52/// * `token_ids` — uint32 token IDs, `[n_tokens]`.
53/// * `output` — f32 output buffer, `[n_tokens, embed_dim]`.
54///
55/// # Errors
56///
57/// Returns `MlxError::InvalidArgument` if:
58/// * `bits` is not 4 or 6
59/// * `embed_dim` is zero
60/// * `group_size` is zero
61/// * `embed_dim` is not divisible by `group_size`
62/// * `n_tokens` is zero
63/// * Output buffer is too small
64#[allow(clippy::too_many_arguments)]
65pub fn embedding_gather(
66    encoder: &mut CommandEncoder,
67    registry: &mut KernelRegistry,
68    device: &metal::DeviceRef,
69    weight_packed: &MlxBuffer,
70    scales: &MlxBuffer,
71    biases: &MlxBuffer,
72    token_ids: &MlxBuffer,
73    output: &MlxBuffer,
74    params: &EmbeddingGatherParams,
75) -> Result<()> {
76    // --- Validation ---
77    if params.bits != 4 && params.bits != 6 {
78        return Err(MlxError::InvalidArgument(format!(
79            "embedding_gather: bits must be 4 or 6, got {}",
80            params.bits
81        )));
82    }
83    if params.embed_dim == 0 {
84        return Err(MlxError::InvalidArgument(
85            "embedding_gather: embed_dim must be > 0".into(),
86        ));
87    }
88    if params.group_size == 0 {
89        return Err(MlxError::InvalidArgument(
90            "embedding_gather: group_size must be > 0".into(),
91        ));
92    }
93    if params.embed_dim % params.group_size != 0 {
94        return Err(MlxError::InvalidArgument(format!(
95            "embedding_gather: embed_dim ({}) must be divisible by group_size ({})",
96            params.embed_dim, params.group_size
97        )));
98    }
99    if params.n_tokens == 0 {
100        return Err(MlxError::InvalidArgument(
101            "embedding_gather: n_tokens must be > 0".into(),
102        ));
103    }
104
105    let expected_output_bytes = params.n_tokens * params.embed_dim * std::mem::size_of::<f32>();
106    if output.byte_len() < expected_output_bytes {
107        return Err(MlxError::InvalidArgument(format!(
108            "embedding_gather: output buffer too small: need {} bytes, have {}",
109            expected_output_bytes,
110            output.byte_len()
111        )));
112    }
113
114    // --- Compute layout parameters ---
115    let n_groups_per_row = params.embed_dim / params.group_size;
116
117    let packed_row_stride: u32 = match params.bits {
118        4 => {
119            // 8 values per uint32; stride in uint32 count
120            (params.embed_dim / 8) as u32
121        }
122        6 => {
123            // 4 values per 3 bytes; stride in bytes
124            (params.embed_dim * 3 / 4) as u32
125        }
126        _ => unreachable!(), // validated above
127    };
128
129    let gpu_params = GpuEmbeddingParams {
130        embed_dim: params.embed_dim as u32,
131        group_size: params.group_size as u32,
132        packed_row_stride,
133        n_groups_per_row: n_groups_per_row as u32,
134    };
135
136    // --- Select kernel ---
137    let kernel_name = match params.bits {
138        4 => "embedding_gather_4bit",
139        6 => "embedding_gather_6bit",
140        _ => unreachable!(),
141    };
142
143    let pipeline = registry.get_pipeline(kernel_name, device)?;
144
145    // --- Encode dispatch ---
146    let grid = MTLSize::new(params.embed_dim as u64, params.n_tokens as u64, 1);
147    let tg_size = MTLSize::new(
148        std::cmp::min(256, params.embed_dim as u64),
149        1,
150        1,
151    );
152
153    let params_bytes = as_bytes(&gpu_params);
154
155    encode_with_args(
156        encoder,
157        pipeline,
158        &[
159            (0, KernelArg::Buffer(weight_packed)),
160            (1, KernelArg::Buffer(scales)),
161            (2, KernelArg::Buffer(biases)),
162            (3, KernelArg::Buffer(token_ids)),
163            (4, KernelArg::Buffer(output)),
164            (5, KernelArg::Bytes(params_bytes)),
165        ],
166        grid,
167        tg_size,
168    );
169
170    Ok(())
171}