1use metal::MTLSize;
18
19use crate::buffer::MlxBuffer;
20use crate::dtypes::DType;
21use crate::encoder::CommandEncoder;
22use crate::error::{MlxError, Result};
23use crate::kernel_registry::KernelRegistry;
24
25pub static EMBEDDING_AUTOGRAD_SHADER_SOURCE: &str =
26 include_str!("../shaders/embedding_autograd.metal");
27
28pub fn register(registry: &mut KernelRegistry) {
29 registry.register_source("embedding_lookup_f32", EMBEDDING_AUTOGRAD_SHADER_SOURCE);
30 registry.register_source(
31 "embedding_scatter_add_f32",
32 EMBEDDING_AUTOGRAD_SHADER_SOURCE,
33 );
34}
35
36#[allow(clippy::too_many_arguments)]
43pub fn dispatch_embedding_lookup_f32(
44 encoder: &mut CommandEncoder,
45 registry: &mut KernelRegistry,
46 device: &metal::DeviceRef,
47 embedding: &MlxBuffer,
48 ids: &MlxBuffer,
49 output: &MlxBuffer,
50 params_buf: &MlxBuffer,
51 vocab: u32,
52 hidden: u32,
53 batch: u32,
54) -> Result<()> {
55 if vocab == 0 || hidden == 0 || batch == 0 {
56 return Err(MlxError::InvalidArgument(
57 "embedding_lookup_f32: vocab/hidden/batch must all be > 0".into(),
58 ));
59 }
60 if embedding.element_count() != (vocab as usize) * (hidden as usize) {
61 return Err(MlxError::InvalidArgument(format!(
62 "embedding_lookup_f32: embedding element count {} != vocab({vocab}) * hidden({hidden})",
63 embedding.element_count(),
64 )));
65 }
66 if ids.element_count() != batch as usize {
68 return Err(MlxError::InvalidArgument(format!(
69 "embedding_lookup_f32: ids element count {} != batch ({batch})",
70 ids.element_count()
71 )));
72 }
73 if output.element_count() != (batch as usize) * (hidden as usize) {
74 return Err(MlxError::InvalidArgument(format!(
75 "embedding_lookup_f32: output element count {} != batch({batch}) * hidden({hidden})",
76 output.element_count(),
77 )));
78 }
79 if embedding.dtype() != DType::F32 || output.dtype() != DType::F32 {
80 return Err(MlxError::InvalidArgument(format!(
81 "embedding_lookup_f32: embedding/output dtype must be f32; got {} / {}",
82 embedding.dtype(),
83 output.dtype()
84 )));
85 }
86 if params_buf.byte_len() < 8 {
87 return Err(MlxError::InvalidArgument(format!(
88 "embedding_lookup_f32: params_buf too small (need 8 bytes for 2×u32, got {})",
89 params_buf.byte_len()
90 )));
91 }
92
93 let pipeline = registry.get_pipeline("embedding_lookup_f32", device)?;
94 encoder.encode(
95 pipeline,
96 &[(0, embedding), (1, ids), (2, output), (3, params_buf)],
97 MTLSize::new(hidden as u64, batch as u64, 1),
98 MTLSize::new(
99 std::cmp::min(hidden as u64, 32),
100 std::cmp::min(batch as u64, 8),
101 1,
102 ),
103 );
104 Ok(())
105}
106
107#[allow(clippy::too_many_arguments)]
114pub fn dispatch_embedding_scatter_add_f32(
115 encoder: &mut CommandEncoder,
116 registry: &mut KernelRegistry,
117 device: &metal::DeviceRef,
118 dy: &MlxBuffer,
119 ids: &MlxBuffer,
120 d_embedding: &MlxBuffer,
121 params_buf: &MlxBuffer,
122 vocab: u32,
123 hidden: u32,
124 batch: u32,
125) -> Result<()> {
126 if vocab == 0 || hidden == 0 || batch == 0 {
127 return Err(MlxError::InvalidArgument(
128 "embedding_scatter_add_f32: vocab/hidden/batch must all be > 0".into(),
129 ));
130 }
131 if dy.element_count() != (batch as usize) * (hidden as usize) {
132 return Err(MlxError::InvalidArgument(format!(
133 "embedding_scatter_add_f32: dy element count {} != batch({batch}) * hidden({hidden})",
134 dy.element_count(),
135 )));
136 }
137 if ids.element_count() != batch as usize {
138 return Err(MlxError::InvalidArgument(format!(
139 "embedding_scatter_add_f32: ids element count {} != batch ({batch})",
140 ids.element_count()
141 )));
142 }
143 if d_embedding.element_count() != (vocab as usize) * (hidden as usize) {
144 return Err(MlxError::InvalidArgument(format!(
145 "embedding_scatter_add_f32: d_embedding element count {} != vocab({vocab}) * hidden({hidden})",
146 d_embedding.element_count(),
147 )));
148 }
149 if dy.dtype() != DType::F32 || d_embedding.dtype() != DType::F32 {
150 return Err(MlxError::InvalidArgument(format!(
151 "embedding_scatter_add_f32: dy/d_embedding dtype must be f32; got {} / {}",
152 dy.dtype(),
153 d_embedding.dtype()
154 )));
155 }
156 if params_buf.byte_len() < 12 {
157 return Err(MlxError::InvalidArgument(format!(
158 "embedding_scatter_add_f32: params_buf too small (need 12 bytes for 3×u32, got {})",
159 params_buf.byte_len()
160 )));
161 }
162
163 let pipeline = registry.get_pipeline("embedding_scatter_add_f32", device)?;
164 encoder.encode(
165 pipeline,
166 &[(0, dy), (1, ids), (2, d_embedding), (3, params_buf)],
167 MTLSize::new(hidden as u64, vocab as u64, 1),
168 MTLSize::new(
169 std::cmp::min(hidden as u64, 32),
170 std::cmp::min(vocab as u64, 8),
171 1,
172 ),
173 );
174 Ok(())
175}
176
#[cfg(test)]
mod tests {
    use super::*;
    use crate::device::MlxDevice;

    /// CPU reference for the forward lookup: copies row `ids[b]` of the
    /// row-major `embedding` table into row `b` of the result.
    fn cpu_lookup(embedding: &[f32], ids: &[u32], hidden: usize) -> Vec<f32> {
        let mut out = vec![0f32; ids.len() * hidden];
        for (b, &id) in ids.iter().enumerate() {
            let src = id as usize * hidden;
            out[b * hidden..(b + 1) * hidden].copy_from_slice(&embedding[src..src + hidden]);
        }
        out
    }

    /// CPU reference for the backward scatter-add: adds row `b` of `dy` into
    /// row `ids[b]` of a zero-initialised `vocab x hidden` gradient table.
    ///
    /// Contributions to a given element are added in increasing `b`; the
    /// bit-exact comparison against the GPU depends on that order.
    fn cpu_scatter_add(dy: &[f32], ids: &[u32], vocab: usize, hidden: usize) -> Vec<f32> {
        let mut d_embed = vec![0f32; vocab * hidden];
        for (b, &id) in ids.iter().enumerate() {
            let dst = id as usize * hidden;
            let grad_row = &dy[b * hidden..(b + 1) * hidden];
            for (acc, &g) in d_embed[dst..dst + hidden].iter_mut().zip(grad_row) {
                *acc += g;
            }
        }
        d_embed
    }

    /// Runs the GPU lookup kernel on the given inputs and returns the
    /// `batch x hidden` output, with `batch = ids.len()`.
    fn run_lookup(embedding: &[f32], ids: &[u32], vocab: usize, hidden: usize) -> Vec<f32> {
        let device = MlxDevice::new().expect("device");
        let batch = ids.len();
        let mut e_buf = device
            .alloc_buffer(vocab * hidden * 4, DType::F32, vec![vocab, hidden])
            .expect("alloc embedding");
        e_buf.as_mut_slice::<f32>().unwrap().copy_from_slice(embedding);
        let mut id_buf = device
            .alloc_buffer(batch * 4, DType::U32, vec![batch])
            .expect("alloc ids");
        id_buf.as_mut_slice::<u32>().unwrap().copy_from_slice(ids);
        let out_buf = device
            .alloc_buffer(batch * hidden * 4, DType::F32, vec![batch, hidden])
            .expect("alloc out");
        // The params buffer holds two u32 values, so tag it as U32.
        let mut params = device
            .alloc_buffer(8, DType::U32, vec![2])
            .expect("alloc params");
        params.as_mut_slice::<u32>().unwrap()[..2]
            .copy_from_slice(&[vocab as u32, hidden as u32]);

        let mut registry = KernelRegistry::new();
        register(&mut registry);
        let mut encoder = device.command_encoder().expect("encoder");
        dispatch_embedding_lookup_f32(
            &mut encoder,
            &mut registry,
            device.metal_device(),
            &e_buf,
            &id_buf,
            &out_buf,
            &params,
            vocab as u32,
            hidden as u32,
            batch as u32,
        )
        .expect("dispatch lookup");
        encoder.commit_and_wait().expect("commit");
        out_buf.as_slice::<f32>().unwrap().to_vec()
    }

    /// Runs the GPU scatter-add kernel on the given inputs and returns the
    /// `vocab x hidden` gradient table, with `batch = ids.len()`.
    fn run_scatter_add(dy: &[f32], ids: &[u32], vocab: usize, hidden: usize) -> Vec<f32> {
        let device = MlxDevice::new().expect("device");
        let batch = ids.len();
        let mut dy_buf = device
            .alloc_buffer(batch * hidden * 4, DType::F32, vec![batch, hidden])
            .expect("alloc dy");
        dy_buf.as_mut_slice::<f32>().unwrap().copy_from_slice(dy);
        let mut id_buf = device
            .alloc_buffer(batch * 4, DType::U32, vec![batch])
            .expect("alloc ids");
        id_buf.as_mut_slice::<u32>().unwrap().copy_from_slice(ids);
        let de_buf = device
            .alloc_buffer(vocab * hidden * 4, DType::F32, vec![vocab, hidden])
            .expect("alloc d_embedding");
        // The params buffer holds three u32 values, so tag it as U32.
        let mut params = device
            .alloc_buffer(12, DType::U32, vec![3])
            .expect("alloc params");
        params.as_mut_slice::<u32>().unwrap()[..3]
            .copy_from_slice(&[vocab as u32, hidden as u32, batch as u32]);

        let mut registry = KernelRegistry::new();
        register(&mut registry);
        let mut encoder = device.command_encoder().expect("encoder");
        dispatch_embedding_scatter_add_f32(
            &mut encoder,
            &mut registry,
            device.metal_device(),
            &dy_buf,
            &id_buf,
            &de_buf,
            &params,
            vocab as u32,
            hidden as u32,
            batch as u32,
        )
        .expect("dispatch scatter_add");
        encoder.commit_and_wait().expect("commit");
        de_buf.as_slice::<f32>().unwrap().to_vec()
    }

    #[test]
    fn embedding_lookup_byte_identical_to_cpu() {
        let vocab = 16;
        let hidden = 8;
        let embedding: Vec<f32> = (0..vocab * hidden)
            .map(|i| (i as f32) * 0.13 - 0.5)
            .collect();
        let ids: Vec<u32> = vec![3, 7, 0, 15, 5, 5, 12, 1];
        let gpu = run_lookup(&embedding, &ids, vocab, hidden);
        let cpu = cpu_lookup(&embedding, &ids, hidden);
        // `zip` stops at the shorter side, so pin the lengths first.
        assert_eq!(gpu.len(), cpu.len(), "output length mismatch");
        for (i, (g, c)) in gpu.iter().zip(cpu.iter()).enumerate() {
            assert_eq!(g.to_bits(), c.to_bits(), "mismatch at {i}");
        }
    }

    #[test]
    fn embedding_lookup_handles_repeated_ids() {
        let vocab = 8;
        let hidden = 4;
        let embedding: Vec<f32> = (0..vocab * hidden)
            .map(|i| (i as f32) * 0.7)
            .collect();
        let ids: Vec<u32> = vec![5, 5, 5, 5];
        let gpu = run_lookup(&embedding, &ids, vocab, hidden);
        let row5 = &embedding[5 * hidden..6 * hidden];
        assert_eq!(gpu.len(), ids.len() * hidden, "output length mismatch");
        // Every output row must be a bit-exact copy of embedding row 5.
        for (b, row) in gpu.chunks_exact(hidden).enumerate() {
            for (h, (g, e)) in row.iter().zip(row5).enumerate() {
                assert_eq!(g.to_bits(), e.to_bits(), "mismatch at b={b} h={h}");
            }
        }
    }

    #[test]
    fn embedding_scatter_add_byte_identical_to_cpu() {
        let vocab = 16;
        let hidden = 8;
        let batch = 12;
        let dy: Vec<f32> = (0..batch * hidden)
            .map(|i| (i as f32) * 0.011 - 0.05)
            .collect();
        let ids: Vec<u32> = vec![3, 7, 0, 15, 5, 5, 12, 1, 5, 0, 7, 11];
        let gpu = run_scatter_add(&dy, &ids, vocab, hidden);
        let cpu = cpu_scatter_add(&dy, &ids, vocab, hidden);
        // `zip` stops at the shorter side, so pin the lengths first.
        assert_eq!(gpu.len(), cpu.len(), "output length mismatch");
        for (i, (g, c)) in gpu.iter().zip(cpu.iter()).enumerate() {
            assert_eq!(g.to_bits(), c.to_bits(), "scatter-add mismatch at {i}");
        }
    }

    #[test]
    fn embedding_scatter_add_unused_ids_are_zero() {
        let vocab = 16;
        let hidden = 4;
        let batch = 6;
        let dy: Vec<f32> = (0..batch * hidden).map(|i| (i as f32) + 1.0).collect();
        let ids: Vec<u32> = vec![1, 2, 3, 5, 7, 11];
        let gpu = run_scatter_add(&dy, &ids, vocab, hidden);
        assert_eq!(gpu.len(), vocab * hidden, "output length mismatch");
        // Rows whose id never appears in `ids` must receive no gradient.
        for &unused_id in &[0u32, 4, 6, 8, 9, 10, 12, 13, 14, 15] {
            for h in 0..hidden {
                assert_eq!(
                    gpu[unused_id as usize * hidden + h], 0.0,
                    "unused id {unused_id} row should be zero at h={h}"
                );
            }
        }
    }

    #[test]
    fn embedding_round_trip_lookup_then_scatter_add() {
        let vocab = 8;
        let hidden = 4;
        let embedding: Vec<f32> = (0..vocab * hidden).map(|i| (i as f32) * 0.5).collect();
        let ids: Vec<u32> = vec![2, 5, 2, 7, 5, 5, 2];
        let lookup_out = run_lookup(&embedding, &ids, vocab, hidden);
        // Scattering the looked-up rows back must yield each embedding row
        // scaled by the number of times its id occurs.
        let scatter = run_scatter_add(&lookup_out, &ids, vocab, hidden);
        for id in 0..vocab {
            let count = ids.iter().filter(|&&i| i as usize == id).count();
            for h in 0..hidden {
                let expected = embedding[id * hidden + h] * (count as f32);
                let actual = scatter[id * hidden + h];
                assert!(
                    (actual - expected).abs() < 1e-5,
                    "id={id} h={h}: expected {expected} (count={count}), got {actual}"
                );
            }
        }
    }
}