mlx_native/ops/
embedding.rs1use metal::MTLSize;
8
9use crate::buffer::MlxBuffer;
10use crate::encoder::CommandEncoder;
11use crate::error::{MlxError, Result};
12use crate::kernel_registry::KernelRegistry;
13
14use super::encode_helpers::{as_bytes, encode_with_args, KernelArg};
15
/// Host-side description of one quantized embedding gather.
///
/// All fields are validated by `embedding_gather` before anything is encoded.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct EmbeddingGatherParams {
    /// Number of values in one embedding row. Must be non-zero and a
    /// multiple of `group_size`.
    pub embed_dim: usize,
    /// Quantization group size. Must be non-zero and divide `embed_dim`.
    /// Presumably each group shares one scale/bias pair — confirm against
    /// the quantization format.
    pub group_size: usize,
    /// Bit width of the packed weights; only 4 and 6 are accepted.
    pub bits: u8,
    /// Number of token ids to gather (one output row each). Must be non-zero.
    pub n_tokens: usize,
}
27
28#[repr(C)]
32#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
33struct GpuEmbeddingParams {
34 embed_dim: u32,
35 group_size: u32,
36 packed_row_stride: u32,
37 n_groups_per_row: u32,
38}
39
40#[allow(clippy::too_many_arguments)]
65pub fn embedding_gather(
66 encoder: &mut CommandEncoder,
67 registry: &mut KernelRegistry,
68 device: &metal::DeviceRef,
69 weight_packed: &MlxBuffer,
70 scales: &MlxBuffer,
71 biases: &MlxBuffer,
72 token_ids: &MlxBuffer,
73 output: &MlxBuffer,
74 params: &EmbeddingGatherParams,
75) -> Result<()> {
76 if params.bits != 4 && params.bits != 6 {
78 return Err(MlxError::InvalidArgument(format!(
79 "embedding_gather: bits must be 4 or 6, got {}",
80 params.bits
81 )));
82 }
83 if params.embed_dim == 0 {
84 return Err(MlxError::InvalidArgument(
85 "embedding_gather: embed_dim must be > 0".into(),
86 ));
87 }
88 if params.group_size == 0 {
89 return Err(MlxError::InvalidArgument(
90 "embedding_gather: group_size must be > 0".into(),
91 ));
92 }
93 if params.embed_dim % params.group_size != 0 {
94 return Err(MlxError::InvalidArgument(format!(
95 "embedding_gather: embed_dim ({}) must be divisible by group_size ({})",
96 params.embed_dim, params.group_size
97 )));
98 }
99 if params.n_tokens == 0 {
100 return Err(MlxError::InvalidArgument(
101 "embedding_gather: n_tokens must be > 0".into(),
102 ));
103 }
104
105 let expected_output_bytes = params.n_tokens * params.embed_dim * std::mem::size_of::<f32>();
106 if output.byte_len() < expected_output_bytes {
107 return Err(MlxError::InvalidArgument(format!(
108 "embedding_gather: output buffer too small: need {} bytes, have {}",
109 expected_output_bytes,
110 output.byte_len()
111 )));
112 }
113
114 let n_groups_per_row = params.embed_dim / params.group_size;
116
117 let packed_row_stride: u32 = match params.bits {
118 4 => {
119 (params.embed_dim / 8) as u32
121 }
122 6 => {
123 (params.embed_dim * 3 / 4) as u32
125 }
126 _ => unreachable!(), };
128
129 let gpu_params = GpuEmbeddingParams {
130 embed_dim: params.embed_dim as u32,
131 group_size: params.group_size as u32,
132 packed_row_stride,
133 n_groups_per_row: n_groups_per_row as u32,
134 };
135
136 let kernel_name = match params.bits {
138 4 => "embedding_gather_4bit",
139 6 => "embedding_gather_6bit",
140 _ => unreachable!(),
141 };
142
143 let pipeline = registry.get_pipeline(kernel_name, device)?;
144
145 let grid = MTLSize::new(params.embed_dim as u64, params.n_tokens as u64, 1);
147 let tg_size = MTLSize::new(
148 std::cmp::min(256, params.embed_dim as u64),
149 1,
150 1,
151 );
152
153 let params_bytes = as_bytes(&gpu_params);
154
155 encode_with_args(
156 encoder,
157 pipeline,
158 &[
159 (0, KernelArg::Buffer(weight_packed)),
160 (1, KernelArg::Buffer(scales)),
161 (2, KernelArg::Buffer(biases)),
162 (3, KernelArg::Buffer(token_ids)),
163 (4, KernelArg::Buffer(output)),
164 (5, KernelArg::Bytes(params_bytes)),
165 ],
166 grid,
167 tg_size,
168 );
169
170 Ok(())
171}