Skip to main content

mlx_native/ops/
embedding_autograd.rs

//! FP32 embedding-table lookup with reverse-mode autograd backward.
//!
//! Used by hf2q's ADR-020 Track 1 multi-layer model on GpuTape (iter-11d).
//!
//! Forward: `output[b, h] = embedding[ids[b], h]`
//! Backward: `d_embedding[id, h] = Σ_{b: ids[b] == id} dy[b, h]`
//!
//! The existing `shaders/embedding.metal` covers QUANTIZED 4-bit/6-bit
//! lookup for inference; this module is the FP32-everywhere variant
//! needed by the autograd tape.
//!
//! The backward kernel is O(vocab × hidden × batch) — fine for the
//! test fixtures (vocab ≤ a few hundred); production-scale
//! performance (vocab=150k+) is a follow-up optimization (atomic
//! float adds or sort-segment-sum).
16
17use metal::MTLSize;
18
19use crate::buffer::MlxBuffer;
20use crate::dtypes::DType;
21use crate::encoder::CommandEncoder;
22use crate::error::{MlxError, Result};
23use crate::kernel_registry::KernelRegistry;
24
25pub static EMBEDDING_AUTOGRAD_SHADER_SOURCE: &str =
26    include_str!("../shaders/embedding_autograd.metal");
27
28pub fn register(registry: &mut KernelRegistry) {
29    registry.register_source("embedding_lookup_f32", EMBEDDING_AUTOGRAD_SHADER_SOURCE);
30    registry.register_source(
31        "embedding_scatter_add_f32",
32        EMBEDDING_AUTOGRAD_SHADER_SOURCE,
33    );
34}
35
36/// Encode `output[b, h] = embedding[ids[b], h]`.
37///
38/// `ids` element type must be u32 (kernel reads as `uint32_t`).  Out-of-range
39/// IDs (≥ vocab) silently produce 0.0 instead of OOB reads.
40///
41/// `params_buf` must be at least 8 bytes (2 × u32: vocab, hidden).
42#[allow(clippy::too_many_arguments)]
43pub fn dispatch_embedding_lookup_f32(
44    encoder: &mut CommandEncoder,
45    registry: &mut KernelRegistry,
46    device: &metal::DeviceRef,
47    embedding: &MlxBuffer,
48    ids: &MlxBuffer,
49    output: &MlxBuffer,
50    params_buf: &MlxBuffer,
51    vocab: u32,
52    hidden: u32,
53    batch: u32,
54) -> Result<()> {
55    if vocab == 0 || hidden == 0 || batch == 0 {
56        return Err(MlxError::InvalidArgument(
57            "embedding_lookup_f32: vocab/hidden/batch must all be > 0".into(),
58        ));
59    }
60    if embedding.element_count() != (vocab as usize) * (hidden as usize) {
61        return Err(MlxError::InvalidArgument(format!(
62            "embedding_lookup_f32: embedding element count {} != vocab({vocab}) * hidden({hidden})",
63            embedding.element_count(),
64        )));
65    }
66    // ids buffer is u32; element_count counts u32 elements.
67    if ids.element_count() != batch as usize {
68        return Err(MlxError::InvalidArgument(format!(
69            "embedding_lookup_f32: ids element count {} != batch ({batch})",
70            ids.element_count()
71        )));
72    }
73    if output.element_count() != (batch as usize) * (hidden as usize) {
74        return Err(MlxError::InvalidArgument(format!(
75            "embedding_lookup_f32: output element count {} != batch({batch}) * hidden({hidden})",
76            output.element_count(),
77        )));
78    }
79    if embedding.dtype() != DType::F32 || output.dtype() != DType::F32 {
80        return Err(MlxError::InvalidArgument(format!(
81            "embedding_lookup_f32: embedding/output dtype must be f32; got {} / {}",
82            embedding.dtype(),
83            output.dtype()
84        )));
85    }
86    if params_buf.byte_len() < 8 {
87        return Err(MlxError::InvalidArgument(format!(
88            "embedding_lookup_f32: params_buf too small (need 8 bytes for 2×u32, got {})",
89            params_buf.byte_len()
90        )));
91    }
92
93    let pipeline = registry.get_pipeline("embedding_lookup_f32", device)?;
94    encoder.encode(
95        pipeline,
96        &[(0, embedding), (1, ids), (2, output), (3, params_buf)],
97        MTLSize::new(hidden as u64, batch as u64, 1),
98        MTLSize::new(
99            std::cmp::min(hidden as u64, 32),
100            std::cmp::min(batch as u64, 8),
101            1,
102        ),
103    );
104    Ok(())
105}
106
107/// Encode the embedding backward (scatter-add).
108///
109/// `d_embedding` MUST be pre-zeroed by caller — the kernel writes
110/// each cell exactly once with the accumulated upstream contribution.
111///
112/// `params_buf` must be at least 12 bytes (3 × u32: vocab, hidden, batch).
113#[allow(clippy::too_many_arguments)]
114pub fn dispatch_embedding_scatter_add_f32(
115    encoder: &mut CommandEncoder,
116    registry: &mut KernelRegistry,
117    device: &metal::DeviceRef,
118    dy: &MlxBuffer,
119    ids: &MlxBuffer,
120    d_embedding: &MlxBuffer,
121    params_buf: &MlxBuffer,
122    vocab: u32,
123    hidden: u32,
124    batch: u32,
125) -> Result<()> {
126    if vocab == 0 || hidden == 0 || batch == 0 {
127        return Err(MlxError::InvalidArgument(
128            "embedding_scatter_add_f32: vocab/hidden/batch must all be > 0".into(),
129        ));
130    }
131    if dy.element_count() != (batch as usize) * (hidden as usize) {
132        return Err(MlxError::InvalidArgument(format!(
133            "embedding_scatter_add_f32: dy element count {} != batch({batch}) * hidden({hidden})",
134            dy.element_count(),
135        )));
136    }
137    if ids.element_count() != batch as usize {
138        return Err(MlxError::InvalidArgument(format!(
139            "embedding_scatter_add_f32: ids element count {} != batch ({batch})",
140            ids.element_count()
141        )));
142    }
143    if d_embedding.element_count() != (vocab as usize) * (hidden as usize) {
144        return Err(MlxError::InvalidArgument(format!(
145            "embedding_scatter_add_f32: d_embedding element count {} != vocab({vocab}) * hidden({hidden})",
146            d_embedding.element_count(),
147        )));
148    }
149    if dy.dtype() != DType::F32 || d_embedding.dtype() != DType::F32 {
150        return Err(MlxError::InvalidArgument(format!(
151            "embedding_scatter_add_f32: dy/d_embedding dtype must be f32; got {} / {}",
152            dy.dtype(),
153            d_embedding.dtype()
154        )));
155    }
156    if params_buf.byte_len() < 12 {
157        return Err(MlxError::InvalidArgument(format!(
158            "embedding_scatter_add_f32: params_buf too small (need 12 bytes for 3×u32, got {})",
159            params_buf.byte_len()
160        )));
161    }
162
163    let pipeline = registry.get_pipeline("embedding_scatter_add_f32", device)?;
164    encoder.encode(
165        pipeline,
166        &[(0, dy), (1, ids), (2, d_embedding), (3, params_buf)],
167        MTLSize::new(hidden as u64, vocab as u64, 1),
168        MTLSize::new(
169            std::cmp::min(hidden as u64, 32),
170            std::cmp::min(vocab as u64, 8),
171            1,
172        ),
173    );
174    Ok(())
175}
176
177#[cfg(test)]
178mod tests {
179    use super::*;
180    use crate::device::MlxDevice;
181
182    fn cpu_lookup(embedding: &[f32], ids: &[u32], hidden: usize) -> Vec<f32> {
183        let mut out = vec![0f32; ids.len() * hidden];
184        for (b, &id) in ids.iter().enumerate() {
185            let id = id as usize;
186            for h in 0..hidden {
187                out[b * hidden + h] = embedding[id * hidden + h];
188            }
189        }
190        out
191    }
192
193    fn cpu_scatter_add(dy: &[f32], ids: &[u32], vocab: usize, hidden: usize) -> Vec<f32> {
194        let mut d_embed = vec![0f32; vocab * hidden];
195        for (b, &id) in ids.iter().enumerate() {
196            let id = id as usize;
197            for h in 0..hidden {
198                d_embed[id * hidden + h] += dy[b * hidden + h];
199            }
200        }
201        d_embed
202    }
203
204    fn run_lookup(embedding: &[f32], ids: &[u32], vocab: usize, hidden: usize) -> Vec<f32> {
205        let device = MlxDevice::new().expect("device");
206        let batch = ids.len();
207        let mut e_buf = device
208            .alloc_buffer(vocab * hidden * 4, DType::F32, vec![vocab, hidden])
209            .expect("alloc embedding");
210        e_buf.as_mut_slice::<f32>().unwrap().copy_from_slice(embedding);
211        let mut id_buf = device
212            .alloc_buffer(batch * 4, DType::U32, vec![batch])
213            .expect("alloc ids");
214        id_buf.as_mut_slice::<u32>().unwrap().copy_from_slice(ids);
215        let out_buf = device
216            .alloc_buffer(batch * hidden * 4, DType::F32, vec![batch, hidden])
217            .expect("alloc out");
218        let mut params = device
219            .alloc_buffer(8, DType::F32, vec![2])
220            .expect("alloc params");
221        params.as_mut_slice::<u32>().unwrap()[..2]
222            .copy_from_slice(&[vocab as u32, hidden as u32]);
223
224        let mut registry = KernelRegistry::new();
225        register(&mut registry);
226        let mut encoder = device.command_encoder().expect("encoder");
227        dispatch_embedding_lookup_f32(
228            &mut encoder,
229            &mut registry,
230            device.metal_device(),
231            &e_buf,
232            &id_buf,
233            &out_buf,
234            &params,
235            vocab as u32,
236            hidden as u32,
237            batch as u32,
238        )
239        .expect("dispatch lookup");
240        encoder.commit_and_wait().expect("commit");
241        out_buf.as_slice::<f32>().unwrap().to_vec()
242    }
243
244    fn run_scatter_add(
245        dy: &[f32],
246        ids: &[u32],
247        vocab: usize,
248        hidden: usize,
249    ) -> Vec<f32> {
250        let device = MlxDevice::new().expect("device");
251        let batch = ids.len();
252        let mut dy_buf = device
253            .alloc_buffer(batch * hidden * 4, DType::F32, vec![batch, hidden])
254            .expect("alloc dy");
255        dy_buf.as_mut_slice::<f32>().unwrap().copy_from_slice(dy);
256        let mut id_buf = device
257            .alloc_buffer(batch * 4, DType::U32, vec![batch])
258            .expect("alloc ids");
259        id_buf.as_mut_slice::<u32>().unwrap().copy_from_slice(ids);
260        // alloc_buffer is zero-fill (ADR-015 iter61a).
261        let de_buf = device
262            .alloc_buffer(vocab * hidden * 4, DType::F32, vec![vocab, hidden])
263            .expect("alloc d_embedding");
264        let mut params = device
265            .alloc_buffer(12, DType::F32, vec![3])
266            .expect("alloc params");
267        params.as_mut_slice::<u32>().unwrap()[..3]
268            .copy_from_slice(&[vocab as u32, hidden as u32, batch as u32]);
269
270        let mut registry = KernelRegistry::new();
271        register(&mut registry);
272        let mut encoder = device.command_encoder().expect("encoder");
273        dispatch_embedding_scatter_add_f32(
274            &mut encoder,
275            &mut registry,
276            device.metal_device(),
277            &dy_buf,
278            &id_buf,
279            &de_buf,
280            &params,
281            vocab as u32,
282            hidden as u32,
283            batch as u32,
284        )
285        .expect("dispatch scatter_add");
286        encoder.commit_and_wait().expect("commit");
287        de_buf.as_slice::<f32>().unwrap().to_vec()
288    }
289
290    #[test]
291    fn embedding_lookup_byte_identical_to_cpu() {
292        let vocab = 16;
293        let hidden = 8;
294        let embedding: Vec<f32> = (0..vocab * hidden)
295            .map(|i| (i as f32) * 0.13 - 0.5)
296            .collect();
297        let ids: Vec<u32> = vec![3, 7, 0, 15, 5, 5, 12, 1];
298        let gpu = run_lookup(&embedding, &ids, vocab, hidden);
299        let cpu = cpu_lookup(&embedding, &ids, hidden);
300        for (i, (g, c)) in gpu.iter().zip(cpu.iter()).enumerate() {
301            assert_eq!(g.to_bits(), c.to_bits(), "mismatch at {i}");
302        }
303    }
304
305    #[test]
306    fn embedding_lookup_handles_repeated_ids() {
307        // Same ID appearing multiple times in batch.  Output rows must
308        // all match the embedding row exactly.
309        let vocab = 8;
310        let hidden = 4;
311        let embedding: Vec<f32> = (0..vocab * hidden)
312            .map(|i| (i as f32) * 0.7)
313            .collect();
314        let ids: Vec<u32> = vec![5, 5, 5, 5];
315        let gpu = run_lookup(&embedding, &ids, vocab, hidden);
316        let row5 = &embedding[5 * hidden..6 * hidden];
317        for b in 0..ids.len() {
318            for h in 0..hidden {
319                assert_eq!(gpu[b * hidden + h].to_bits(), row5[h].to_bits());
320            }
321        }
322    }
323
324    #[test]
325    fn embedding_scatter_add_byte_identical_to_cpu() {
326        let vocab = 16;
327        let hidden = 8;
328        let batch = 12;
329        let dy: Vec<f32> = (0..batch * hidden)
330            .map(|i| (i as f32) * 0.011 - 0.05)
331            .collect();
332        let ids: Vec<u32> = vec![3, 7, 0, 15, 5, 5, 12, 1, 5, 0, 7, 11];
333        let gpu = run_scatter_add(&dy, &ids, vocab, hidden);
334        let cpu = cpu_scatter_add(&dy, &ids, vocab, hidden);
335        for (i, (g, c)) in gpu.iter().zip(cpu.iter()).enumerate() {
336            assert_eq!(g.to_bits(), c.to_bits(), "scatter-add mismatch at {i}");
337        }
338    }
339
340    #[test]
341    fn embedding_scatter_add_unused_ids_are_zero() {
342        // IDs 0, 4, 9, 13 are NEVER used in the batch — their rows in
343        // d_embedding must remain zero.
344        let vocab = 16;
345        let hidden = 4;
346        let batch = 6;
347        let dy: Vec<f32> = (0..batch * hidden).map(|i| (i as f32) + 1.0).collect();
348        let ids: Vec<u32> = vec![1, 2, 3, 5, 7, 11];
349        let gpu = run_scatter_add(&dy, &ids, vocab, hidden);
350        for &unused_id in &[0u32, 4, 6, 8, 9, 10, 12, 13, 14, 15] {
351            for h in 0..hidden {
352                assert_eq!(
353                    gpu[unused_id as usize * hidden + h], 0.0,
354                    "unused id {unused_id} row should be zero at h={h}"
355                );
356            }
357        }
358    }
359
360    #[test]
361    fn embedding_round_trip_lookup_then_scatter_add() {
362        // Lookup by ids; then scatter-add the lookup output back —
363        // the scatter sums all batch contributions back into the
364        // touched rows.  For each id `i` appearing `k` times,
365        // d_embedding[i] should equal `k * embedding[i]`.
366        let vocab = 8;
367        let hidden = 4;
368        let embedding: Vec<f32> = (0..vocab * hidden).map(|i| (i as f32) * 0.5).collect();
369        let ids: Vec<u32> = vec![2, 5, 2, 7, 5, 5, 2];
370        // counts: 2 appears 3x, 5 appears 3x, 7 appears 1x.
371        let lookup_out = run_lookup(&embedding, &ids, vocab, hidden);
372        let scatter = run_scatter_add(&lookup_out, &ids, vocab, hidden);
373        for id in 0..vocab {
374            let count = ids.iter().filter(|&&i| i as usize == id).count();
375            for h in 0..hidden {
376                let expected = embedding[id * hidden + h] * (count as f32);
377                let actual = scatter[id * hidden + h];
378                assert!(
379                    (actual - expected).abs() < 1e-5,
380                    "id={id} h={h}: expected {expected} (count={count}), got {actual}"
381                );
382            }
383        }
384    }
385}