Skip to main content

mlx_native/ops/
take_along_axis.rs

1//! ADR-020 iter-11h-e1 — take_along_axis (gather) + scatter-backward
2//! for the GpuTape autograd pipeline.  Forward gathers values along
3//! the last axis using a precomputed (non-differentiable) index
4//! buffer; backward scatters gradients back into a zero-initialised
5//! dx buffer.
6//!
7//! Used by MoE router on GpuTape (iter-11h-e):
8//!   y = take_along_axis(softmax(gate(x)), top_k_indices, axis=-1)
9
10use metal::MTLSize;
11
12use crate::buffer::MlxBuffer;
13use crate::dtypes::DType;
14use crate::encoder::CommandEncoder;
15use crate::error::{MlxError, Result};
16use crate::kernel_registry::KernelRegistry;
17
/// Metal source shared by both kernels in this file: `take_along_axis_f32`
/// (forward gather) and `take_along_axis_backward_f32` (backward scatter).
/// Embedded at compile time via `include_str!` from
/// `../shaders/take_along_axis.metal`, relative to this source file.
pub static TAKE_ALONG_AXIS_SHADER_SOURCE: &str =
    include_str!("../shaders/take_along_axis.metal");
20
21pub fn register(registry: &mut KernelRegistry) {
22    registry.register_source("take_along_axis_f32", TAKE_ALONG_AXIS_SHADER_SOURCE);
23    registry.register_source(
24        "take_along_axis_backward_f32",
25        TAKE_ALONG_AXIS_SHADER_SOURCE,
26    );
27}
28
29fn validate(
30    op: &str,
31    rows: u32,
32    cols: u32,
33    k: u32,
34    a: &MlxBuffer,
35    indices: &MlxBuffer,
36    out: &MlxBuffer,
37    params: &MlxBuffer,
38    expected_a: usize,
39    expected_out: usize,
40) -> Result<()> {
41    if rows == 0 || cols == 0 || k == 0 {
42        return Err(MlxError::InvalidArgument(format!(
43            "{op}: rows, cols, k must all be > 0 (got {rows}, {cols}, {k})"
44        )));
45    }
46    if k > cols {
47        return Err(MlxError::InvalidArgument(format!(
48            "{op}: k ({k}) > cols ({cols})"
49        )));
50    }
51    if a.dtype() != DType::F32 || out.dtype() != DType::F32 {
52        return Err(MlxError::InvalidArgument(format!(
53            "{op}: a/out must be f32"
54        )));
55    }
56    if indices.dtype() != DType::U32 {
57        return Err(MlxError::InvalidArgument(format!(
58            "{op}: indices dtype {} not u32",
59            indices.dtype()
60        )));
61    }
62    if a.element_count() != expected_a {
63        return Err(MlxError::InvalidArgument(format!(
64            "{op}: a element_count {} != {expected_a}",
65            a.element_count()
66        )));
67    }
68    if indices.element_count() != (rows as usize) * (k as usize) {
69        return Err(MlxError::InvalidArgument(format!(
70            "{op}: indices element_count {} != rows*k = {}",
71            indices.element_count(),
72            (rows as usize) * (k as usize)
73        )));
74    }
75    if out.element_count() != expected_out {
76        return Err(MlxError::InvalidArgument(format!(
77            "{op}: out element_count {} != {expected_out}",
78            out.element_count()
79        )));
80    }
81    if params.byte_len() < 12 {
82        return Err(MlxError::InvalidArgument(format!(
83            "{op}: params < 12 bytes (need 3 × u32)"
84        )));
85    }
86    Ok(())
87}
88
89#[allow(clippy::too_many_arguments)]
90pub fn dispatch_take_along_axis_f32(
91    encoder: &mut CommandEncoder,
92    registry: &mut KernelRegistry,
93    device: &metal::DeviceRef,
94    x: &MlxBuffer,
95    indices: &MlxBuffer,
96    y: &MlxBuffer,
97    params: &MlxBuffer,
98    rows: u32,
99    cols: u32,
100    k: u32,
101) -> Result<()> {
102    const OP: &str = "take_along_axis_f32";
103    let r = rows as usize;
104    let c = cols as usize;
105    let k_us = k as usize;
106    validate(OP, rows, cols, k, x, indices, y, params, r * c, r * k_us)?;
107
108    let pipeline = registry.get_pipeline(OP, device)?;
109    encoder.encode(
110        pipeline,
111        &[(0, x), (1, indices), (2, y), (3, params)],
112        MTLSize::new(rows as u64, k as u64, 1),
113        MTLSize::new(
114            std::cmp::min(16, rows as u64),
115            std::cmp::min(16, k as u64),
116            1,
117        ),
118    );
119    Ok(())
120}
121
122#[allow(clippy::too_many_arguments)]
123pub fn dispatch_take_along_axis_backward_f32(
124    encoder: &mut CommandEncoder,
125    registry: &mut KernelRegistry,
126    device: &metal::DeviceRef,
127    dy: &MlxBuffer,
128    indices: &MlxBuffer,
129    dx: &MlxBuffer,
130    params: &MlxBuffer,
131    rows: u32,
132    cols: u32,
133    k: u32,
134) -> Result<()> {
135    const OP: &str = "take_along_axis_backward_f32";
136    let r = rows as usize;
137    let c = cols as usize;
138    let k_us = k as usize;
139    validate(OP, rows, cols, k, dx, indices, dy, params, r * c, r * k_us)?;
140
141    let pipeline = registry.get_pipeline(OP, device)?;
142    encoder.encode(
143        pipeline,
144        &[(0, dy), (1, indices), (2, dx), (3, params)],
145        MTLSize::new(rows as u64, k as u64, 1),
146        MTLSize::new(
147            std::cmp::min(16, rows as u64),
148            std::cmp::min(16, k as u64),
149            1,
150        ),
151    );
152    Ok(())
153}
154
#[cfg(test)]
mod tests {
    use super::*;
    use crate::device::MlxDevice;

    /// Allocates an f32 buffer of `n` elements with shape `sh`, zero-filled.
    fn alloc_f32(d: &MlxDevice, n: usize, sh: Vec<usize>) -> MlxBuffer {
        let mut b = d.alloc_buffer(n * 4, DType::F32, sh).unwrap();
        b.as_mut_slice::<f32>().unwrap().fill(0.0);
        b
    }

    /// Allocates a u32 buffer of `n` elements with shape `sh`, zero-filled.
    fn alloc_u32(d: &MlxDevice, n: usize, sh: Vec<usize>) -> MlxBuffer {
        let mut b = d.alloc_buffer(n * 4, DType::U32, sh).unwrap();
        b.as_mut_slice::<u32>().unwrap().fill(0);
        b
    }

    /// Builds the 3 × u32 params buffer `[rows, cols, k]`.
    fn make_params(d: &MlxDevice, rows: u32, cols: u32, k: u32) -> MlxBuffer {
        let mut p = d.alloc_buffer(12, DType::U32, vec![3]).unwrap();
        p.as_mut_slice::<u32>().unwrap().copy_from_slice(&[rows, cols, k]);
        p
    }

    /// Asserts that a dispatch was refused by host-side argument validation.
    fn assert_invalid(res: Result<()>, what: &str) {
        assert!(
            matches!(res, Err(MlxError::InvalidArgument(_))),
            "{what}: expected MlxError::InvalidArgument"
        );
    }

    /// Runs the forward dispatcher with a throwaway registry and encoder.
    /// Used by the rejection tests, which never commit any work.
    #[allow(clippy::too_many_arguments)]
    fn run_forward(
        device: &MlxDevice,
        x: &MlxBuffer,
        i: &MlxBuffer,
        y: &MlxBuffer,
        p: &MlxBuffer,
        rows: u32,
        cols: u32,
        k: u32,
    ) -> Result<()> {
        let mut registry = KernelRegistry::new();
        let mut encoder = device.command_encoder().unwrap();
        dispatch_take_along_axis_f32(
            &mut encoder, &mut registry, device.metal_device(),
            x, i, y, p, rows, cols, k,
        )
    }

    #[test]
    fn forward_matches_cpu_oracle() {
        let device = MlxDevice::new().unwrap();
        let mut registry = KernelRegistry::new();
        let rows = 4;
        let cols = 8;
        let k = 3;
        let x: Vec<f32> = (0..(rows * cols))
            .map(|i| ((i as f32) * 0.137 - 0.4).sin() * 0.7)
            .collect();
        // Per-row top-K indices (distinct within a row, and < cols).
        // Hand-picked non-trivial values.
        let indices: Vec<u32> = vec![
            0, 3, 7,
            1, 4, 6,
            2, 5, 0,
            7, 0, 4,
        ];

        let mut x_buf = alloc_f32(&device, rows * cols, vec![rows, cols]);
        x_buf.as_mut_slice::<f32>().unwrap().copy_from_slice(&x);
        let mut idx_buf = alloc_u32(&device, rows * k, vec![rows, k]);
        idx_buf.as_mut_slice::<u32>().unwrap().copy_from_slice(&indices);
        let y_buf = alloc_f32(&device, rows * k, vec![rows, k]);
        let params = make_params(&device, rows as u32, cols as u32, k as u32);

        let mut encoder = device.command_encoder().unwrap();
        dispatch_take_along_axis_f32(
            &mut encoder, &mut registry, device.metal_device(),
            &x_buf, &idx_buf, &y_buf, &params,
            rows as u32, cols as u32, k as u32,
        ).unwrap();
        encoder.commit_and_wait().unwrap();

        let gpu = y_buf.as_slice::<f32>().unwrap();
        for r in 0..rows {
            for j in 0..k {
                let idx = indices[r * k + j] as usize;
                let expected = x[r * cols + idx];
                assert!(
                    (gpu[r * k + j] - expected).abs() < 1e-6 * expected.abs().max(1.0),
                    "y[{r},{j}]: gpu={} expected={} (idx={})",
                    gpu[r * k + j], expected, idx
                );
            }
        }
    }

    #[test]
    fn backward_scatter_matches_cpu_oracle() {
        let device = MlxDevice::new().unwrap();
        let mut registry = KernelRegistry::new();
        let rows = 3;
        let cols = 6;
        let k = 2;
        let dy: Vec<f32> = (0..(rows * k))
            .map(|i| ((i as f32) * 0.231 + 0.1).sin() * 0.6)
            .collect();
        // Distinct indices per row, so the oracle can assign instead of sum.
        let indices: Vec<u32> = vec![
            0, 4,
            1, 5,
            2, 3,
        ];

        let mut dy_buf = alloc_f32(&device, rows * k, vec![rows, k]);
        dy_buf.as_mut_slice::<f32>().unwrap().copy_from_slice(&dy);
        let mut idx_buf = alloc_u32(&device, rows * k, vec![rows, k]);
        idx_buf.as_mut_slice::<u32>().unwrap().copy_from_slice(&indices);
        // dx starts zeroed: the dispatcher only scatters into it.
        let dx_buf = alloc_f32(&device, rows * cols, vec![rows, cols]);
        let params = make_params(&device, rows as u32, cols as u32, k as u32);

        let mut encoder = device.command_encoder().unwrap();
        dispatch_take_along_axis_backward_f32(
            &mut encoder, &mut registry, device.metal_device(),
            &dy_buf, &idx_buf, &dx_buf, &params,
            rows as u32, cols as u32, k as u32,
        ).unwrap();
        encoder.commit_and_wait().unwrap();

        let gpu = dx_buf.as_slice::<f32>().unwrap();
        // Build CPU oracle.
        let mut expected = vec![0.0f32; rows * cols];
        for r in 0..rows {
            for j in 0..k {
                let idx = indices[r * k + j] as usize;
                expected[r * cols + idx] = dy[r * k + j];
            }
        }
        for i in 0..(rows * cols) {
            assert!(
                (gpu[i] - expected[i]).abs() < 1e-6,
                "dx[{i}]: gpu={} expected={}",
                gpu[i], expected[i]
            );
        }
    }

    /// FD falsifier: loss = sum(take_along_axis(x, indices)).  Analytic
    /// dx[r, c] = 1 if c is in row r's top-K else 0.  FD must match.
    /// The forward output y is also checked, since the loss is sum(y).
    #[test]
    fn backward_finite_difference_falsifier() {
        let device = MlxDevice::new().unwrap();
        let mut registry = KernelRegistry::new();
        let rows = 4;
        let cols = 6;
        let k = 2;
        let x: Vec<f32> = (0..(rows * cols))
            .map(|i| 0.3 + (i as f32) * 0.013)
            .collect();
        let indices: Vec<u32> = vec![
            0, 3,
            1, 5,
            2, 4,
            0, 4,
        ];

        // Forward to get y.
        let mut x_buf = alloc_f32(&device, rows * cols, vec![rows, cols]);
        x_buf.as_mut_slice::<f32>().unwrap().copy_from_slice(&x);
        let mut idx_buf = alloc_u32(&device, rows * k, vec![rows, k]);
        idx_buf.as_mut_slice::<u32>().unwrap().copy_from_slice(&indices);
        let y_buf = alloc_f32(&device, rows * k, vec![rows, k]);
        let params = make_params(&device, rows as u32, cols as u32, k as u32);
        let mut encoder = device.command_encoder().unwrap();
        dispatch_take_along_axis_f32(
            &mut encoder, &mut registry, device.metal_device(),
            &x_buf, &idx_buf, &y_buf, &params,
            rows as u32, cols as u32, k as u32,
        ).unwrap();
        encoder.commit_and_wait().unwrap();

        // y must be the plain gather; otherwise the loss below is not sum(y).
        let y = y_buf.as_slice::<f32>().unwrap();
        for r in 0..rows {
            for j in 0..k {
                let expected = x[r * cols + indices[r * k + j] as usize];
                assert!(
                    (y[r * k + j] - expected).abs() < 1e-6,
                    "y[{r},{j}]: gpu={} expected={}",
                    y[r * k + j], expected
                );
            }
        }

        // Analytic dx via dy=ones.
        let dy_ones = vec![1.0f32; rows * k];
        let mut dy_buf = alloc_f32(&device, rows * k, vec![rows, k]);
        dy_buf.as_mut_slice::<f32>().unwrap().copy_from_slice(&dy_ones);
        let dx_buf = alloc_f32(&device, rows * cols, vec![rows, cols]);
        let mut encoder = device.command_encoder().unwrap();
        dispatch_take_along_axis_backward_f32(
            &mut encoder, &mut registry, device.metal_device(),
            &dy_buf, &idx_buf, &dx_buf, &params,
            rows as u32, cols as u32, k as u32,
        ).unwrap();
        encoder.commit_and_wait().unwrap();
        let dx = dx_buf.as_slice::<f32>().unwrap().to_vec();

        // FD on every x[i].
        let h = 1e-3f64;
        let loss = |x_in: &[f32]| -> f64 {
            let mut s = 0.0f64;
            for r in 0..rows {
                for j in 0..k {
                    s += x_in[r * cols + indices[r * k + j] as usize] as f64;
                }
            }
            s
        };
        for i in 0..(rows * cols) {
            let mut xp = x.clone(); xp[i] += h as f32;
            let mut xm = x.clone(); xm[i] -= h as f32;
            let fd = (loss(&xp) - loss(&xm)) / (2.0 * h);
            let tol = 1e-3 * fd.abs().max(1.0);
            assert!(
                (dx[i] as f64 - fd).abs() < tol,
                "FD x[{i}]: analytic={} fd={}", dx[i], fd
            );
        }
    }

    #[test]
    fn rejects_k_greater_than_cols() {
        let device = MlxDevice::new().unwrap();
        let x = alloc_f32(&device, 4, vec![1, 4]);
        let i = alloc_u32(&device, 5, vec![1, 5]);
        let y = alloc_f32(&device, 5, vec![1, 5]);
        let p = make_params(&device, 1, 4, 5);
        assert_invalid(run_forward(&device, &x, &i, &y, &p, 1, 4, 5), "k > cols");
    }

    #[test]
    fn rejects_zero_k() {
        // Buffers are non-empty so allocation succeeds; k == 0 is caught
        // before any element-count check.
        let device = MlxDevice::new().unwrap();
        let x = alloc_f32(&device, 4, vec![1, 4]);
        let i = alloc_u32(&device, 1, vec![1, 1]);
        let y = alloc_f32(&device, 1, vec![1, 1]);
        let p = make_params(&device, 1, 4, 0);
        assert_invalid(run_forward(&device, &x, &i, &y, &p, 1, 4, 0), "k == 0");
    }

    #[test]
    fn rejects_non_f32_data() {
        let device = MlxDevice::new().unwrap();
        let x = alloc_u32(&device, 4, vec![1, 4]);
        let i = alloc_u32(&device, 2, vec![1, 2]);
        let y = alloc_f32(&device, 2, vec![1, 2]);
        let p = make_params(&device, 1, 4, 2);
        assert_invalid(run_forward(&device, &x, &i, &y, &p, 1, 4, 2), "u32 x");
    }

    #[test]
    fn rejects_non_u32_indices() {
        let device = MlxDevice::new().unwrap();
        let x = alloc_f32(&device, 4, vec![1, 4]);
        let i = alloc_f32(&device, 2, vec![1, 2]);
        let y = alloc_f32(&device, 2, vec![1, 2]);
        let p = make_params(&device, 1, 4, 2);
        assert_invalid(run_forward(&device, &x, &i, &y, &p, 1, 4, 2), "f32 indices");
    }

    #[test]
    fn rejects_short_params() {
        let device = MlxDevice::new().unwrap();
        let x = alloc_f32(&device, 4, vec![1, 4]);
        let i = alloc_u32(&device, 2, vec![1, 2]);
        let y = alloc_f32(&device, 2, vec![1, 2]);
        // Only 8 bytes: one u32 short of [rows, cols, k].
        let p = device.alloc_buffer(8, DType::U32, vec![2]).unwrap();
        assert_invalid(run_forward(&device, &x, &i, &y, &p, 1, 4, 2), "8-byte params");
    }

    #[test]
    fn backward_rejects_wrong_dy_len() {
        // In the backward path dy is validated in the `out` role, so a dy of
        // the wrong length must still be refused.
        let device = MlxDevice::new().unwrap();
        let mut registry = KernelRegistry::new();
        let dy = alloc_f32(&device, 3, vec![3]);
        let idx = alloc_u32(&device, 4, vec![2, 2]);
        let dx = alloc_f32(&device, 8, vec![2, 4]);
        let p = make_params(&device, 2, 4, 2);
        let mut encoder = device.command_encoder().unwrap();
        let res = dispatch_take_along_axis_backward_f32(
            &mut encoder, &mut registry, device.metal_device(),
            &dy, &idx, &dx, &p, 2, 4, 2,
        );
        assert_invalid(res, "dy with 3 elements for rows*k = 4");
    }
}