1use metal::MTLSize;
11
12use crate::buffer::MlxBuffer;
13use crate::dtypes::DType;
14use crate::encoder::CommandEncoder;
15use crate::error::{MlxError, Result};
16use crate::kernel_registry::KernelRegistry;
17
18pub static TAKE_ALONG_AXIS_SHADER_SOURCE: &str =
19 include_str!("../shaders/take_along_axis.metal");
20
21pub fn register(registry: &mut KernelRegistry) {
22 registry.register_source("take_along_axis_f32", TAKE_ALONG_AXIS_SHADER_SOURCE);
23 registry.register_source(
24 "take_along_axis_backward_f32",
25 TAKE_ALONG_AXIS_SHADER_SOURCE,
26 );
27}
28
29fn validate(
30 op: &str,
31 rows: u32,
32 cols: u32,
33 k: u32,
34 a: &MlxBuffer,
35 indices: &MlxBuffer,
36 out: &MlxBuffer,
37 params: &MlxBuffer,
38 expected_a: usize,
39 expected_out: usize,
40) -> Result<()> {
41 if rows == 0 || cols == 0 || k == 0 {
42 return Err(MlxError::InvalidArgument(format!(
43 "{op}: rows, cols, k must all be > 0 (got {rows}, {cols}, {k})"
44 )));
45 }
46 if k > cols {
47 return Err(MlxError::InvalidArgument(format!(
48 "{op}: k ({k}) > cols ({cols})"
49 )));
50 }
51 if a.dtype() != DType::F32 || out.dtype() != DType::F32 {
52 return Err(MlxError::InvalidArgument(format!(
53 "{op}: a/out must be f32"
54 )));
55 }
56 if indices.dtype() != DType::U32 {
57 return Err(MlxError::InvalidArgument(format!(
58 "{op}: indices dtype {} not u32",
59 indices.dtype()
60 )));
61 }
62 if a.element_count() != expected_a {
63 return Err(MlxError::InvalidArgument(format!(
64 "{op}: a element_count {} != {expected_a}",
65 a.element_count()
66 )));
67 }
68 if indices.element_count() != (rows as usize) * (k as usize) {
69 return Err(MlxError::InvalidArgument(format!(
70 "{op}: indices element_count {} != rows*k = {}",
71 indices.element_count(),
72 (rows as usize) * (k as usize)
73 )));
74 }
75 if out.element_count() != expected_out {
76 return Err(MlxError::InvalidArgument(format!(
77 "{op}: out element_count {} != {expected_out}",
78 out.element_count()
79 )));
80 }
81 if params.byte_len() < 12 {
82 return Err(MlxError::InvalidArgument(format!(
83 "{op}: params < 12 bytes (need 3 × u32)"
84 )));
85 }
86 Ok(())
87}
88
89#[allow(clippy::too_many_arguments)]
90pub fn dispatch_take_along_axis_f32(
91 encoder: &mut CommandEncoder,
92 registry: &mut KernelRegistry,
93 device: &metal::DeviceRef,
94 x: &MlxBuffer,
95 indices: &MlxBuffer,
96 y: &MlxBuffer,
97 params: &MlxBuffer,
98 rows: u32,
99 cols: u32,
100 k: u32,
101) -> Result<()> {
102 const OP: &str = "take_along_axis_f32";
103 let r = rows as usize;
104 let c = cols as usize;
105 let k_us = k as usize;
106 validate(OP, rows, cols, k, x, indices, y, params, r * c, r * k_us)?;
107
108 let pipeline = registry.get_pipeline(OP, device)?;
109 encoder.encode(
110 pipeline,
111 &[(0, x), (1, indices), (2, y), (3, params)],
112 MTLSize::new(rows as u64, k as u64, 1),
113 MTLSize::new(
114 std::cmp::min(16, rows as u64),
115 std::cmp::min(16, k as u64),
116 1,
117 ),
118 );
119 Ok(())
120}
121
122#[allow(clippy::too_many_arguments)]
123pub fn dispatch_take_along_axis_backward_f32(
124 encoder: &mut CommandEncoder,
125 registry: &mut KernelRegistry,
126 device: &metal::DeviceRef,
127 dy: &MlxBuffer,
128 indices: &MlxBuffer,
129 dx: &MlxBuffer,
130 params: &MlxBuffer,
131 rows: u32,
132 cols: u32,
133 k: u32,
134) -> Result<()> {
135 const OP: &str = "take_along_axis_backward_f32";
136 let r = rows as usize;
137 let c = cols as usize;
138 let k_us = k as usize;
139 validate(OP, rows, cols, k, dx, indices, dy, params, r * c, r * k_us)?;
140
141 let pipeline = registry.get_pipeline(OP, device)?;
142 encoder.encode(
143 pipeline,
144 &[(0, dy), (1, indices), (2, dx), (3, params)],
145 MTLSize::new(rows as u64, k as u64, 1),
146 MTLSize::new(
147 std::cmp::min(16, rows as u64),
148 std::cmp::min(16, k as u64),
149 1,
150 ),
151 );
152 Ok(())
153}
154
#[cfg(test)]
mod tests {
    use super::*;
    use crate::device::MlxDevice;

    /// Allocates a zero-filled f32 buffer of `n` elements with shape `sh`.
    fn alloc_f32(d: &MlxDevice, n: usize, sh: Vec<usize>) -> MlxBuffer {
        let mut b = d.alloc_buffer(n * 4, DType::F32, sh).unwrap();
        b.as_mut_slice::<f32>().unwrap().fill(0.0);
        b
    }

    /// Allocates a zero-filled u32 buffer of `n` elements with shape `sh`.
    fn alloc_u32(d: &MlxDevice, n: usize, sh: Vec<usize>) -> MlxBuffer {
        let mut b = d.alloc_buffer(n * 4, DType::U32, sh).unwrap();
        b.as_mut_slice::<u32>().unwrap().fill(0);
        b
    }

    /// Builds the 3 × u32 params buffer `[rows, cols, k]`.
    fn make_params(d: &MlxDevice, rows: u32, cols: u32, k: u32) -> MlxBuffer {
        let mut p = d.alloc_buffer(12, DType::U32, vec![3]).unwrap();
        p.as_mut_slice::<u32>().unwrap().copy_from_slice(&[rows, cols, k]);
        p
    }

    /// Forward gather must equal `x[r, indices[r, j]]` computed on the CPU.
    #[test]
    fn forward_matches_cpu_oracle() {
        let device = MlxDevice::new().unwrap();
        let mut registry = KernelRegistry::new();
        let rows = 4;
        let cols = 8;
        let k = 3;
        let x: Vec<f32> = (0..(rows * cols))
            .map(|i| ((i as f32) * 0.137 - 0.4).sin() * 0.7)
            .collect();
        // One row of k column indices per input row.
        let indices: Vec<u32> = vec![
            0, 3, 7, //
            1, 4, 6, //
            2, 5, 0, //
            7, 0, 4,
        ];

        let mut x_buf = alloc_f32(&device, rows * cols, vec![rows, cols]);
        x_buf.as_mut_slice::<f32>().unwrap().copy_from_slice(&x);
        let mut idx_buf = alloc_u32(&device, rows * k, vec![rows, k]);
        idx_buf.as_mut_slice::<u32>().unwrap().copy_from_slice(&indices);
        let y_buf = alloc_f32(&device, rows * k, vec![rows, k]);
        let params = make_params(&device, rows as u32, cols as u32, k as u32);

        let mut encoder = device.command_encoder().unwrap();
        dispatch_take_along_axis_f32(
            &mut encoder,
            &mut registry,
            device.metal_device(),
            &x_buf,
            &idx_buf,
            &y_buf,
            &params,
            rows as u32,
            cols as u32,
            k as u32,
        )
        .unwrap();
        encoder.commit_and_wait().unwrap();

        let gpu = y_buf.as_slice::<f32>().unwrap();
        for r in 0..rows {
            for j in 0..k {
                let idx = indices[r * k + j] as usize;
                let expected = x[r * cols + idx];
                assert!(
                    (gpu[r * k + j] - expected).abs() < 1e-6 * expected.abs().max(1.0),
                    "y[{r},{j}]: gpu={} expected={} (idx={})",
                    gpu[r * k + j],
                    expected,
                    idx
                );
            }
        }
    }

    /// Backward scatter into a zero-filled dx must match the CPU scatter.
    #[test]
    fn backward_scatter_matches_cpu_oracle() {
        let device = MlxDevice::new().unwrap();
        let mut registry = KernelRegistry::new();
        let rows = 3;
        let cols = 6;
        let k = 2;
        let dy: Vec<f32> = (0..(rows * k))
            .map(|i| ((i as f32) * 0.231 + 0.1).sin() * 0.6)
            .collect();
        // Indices are unique within each row, so the scatter is unambiguous.
        let indices: Vec<u32> = vec![
            0, 4, //
            1, 5, //
            2, 3,
        ];

        let mut dy_buf = alloc_f32(&device, rows * k, vec![rows, k]);
        dy_buf.as_mut_slice::<f32>().unwrap().copy_from_slice(&dy);
        let mut idx_buf = alloc_u32(&device, rows * k, vec![rows, k]);
        idx_buf.as_mut_slice::<u32>().unwrap().copy_from_slice(&indices);
        let dx_buf = alloc_f32(&device, rows * cols, vec![rows, cols]);
        let params = make_params(&device, rows as u32, cols as u32, k as u32);

        let mut encoder = device.command_encoder().unwrap();
        dispatch_take_along_axis_backward_f32(
            &mut encoder,
            &mut registry,
            device.metal_device(),
            &dy_buf,
            &idx_buf,
            &dx_buf,
            &params,
            rows as u32,
            cols as u32,
            k as u32,
        )
        .unwrap();
        encoder.commit_and_wait().unwrap();

        let gpu = dx_buf.as_slice::<f32>().unwrap();
        let mut expected = vec![0.0f32; rows * cols];
        for r in 0..rows {
            for j in 0..k {
                let idx = indices[r * k + j] as usize;
                expected[r * cols + idx] = dy[r * k + j];
            }
        }
        for i in 0..(rows * cols) {
            assert!(
                (gpu[i] - expected[i]).abs() < 1e-6,
                "dx[{i}]: gpu={} expected={}",
                gpu[i],
                expected[i]
            );
        }
    }

    /// With dy = 1, the analytic dx must match a central finite difference of
    /// `loss(x) = sum_{r,j} x[r, indices[r, j]]`.
    #[test]
    fn backward_finite_difference_falsifier() {
        let device = MlxDevice::new().unwrap();
        let mut registry = KernelRegistry::new();
        let rows = 4;
        let cols = 6;
        let k = 2;
        let x: Vec<f32> = (0..(rows * cols))
            .map(|i| 0.3 + (i as f32) * 0.013)
            .collect();
        let indices: Vec<u32> = vec![
            0, 3, //
            1, 5, //
            2, 4, //
            0, 4,
        ];

        let mut x_buf = alloc_f32(&device, rows * cols, vec![rows, cols]);
        x_buf.as_mut_slice::<f32>().unwrap().copy_from_slice(&x);
        let mut idx_buf = alloc_u32(&device, rows * k, vec![rows, k]);
        idx_buf.as_mut_slice::<u32>().unwrap().copy_from_slice(&indices);
        let y_buf = alloc_f32(&device, rows * k, vec![rows, k]);
        let params = make_params(&device, rows as u32, cols as u32, k as u32);
        let mut encoder = device.command_encoder().unwrap();
        dispatch_take_along_axis_f32(
            &mut encoder,
            &mut registry,
            device.metal_device(),
            &x_buf,
            &idx_buf,
            &y_buf,
            &params,
            rows as u32,
            cols as u32,
            k as u32,
        )
        .unwrap();
        encoder.commit_and_wait().unwrap();

        // dL/dy = 1 everywhere for the sum loss.
        let dy_ones = vec![1.0f32; rows * k];
        let mut dy_buf = alloc_f32(&device, rows * k, vec![rows, k]);
        dy_buf.as_mut_slice::<f32>().unwrap().copy_from_slice(&dy_ones);
        let dx_buf = alloc_f32(&device, rows * cols, vec![rows, cols]);
        let mut encoder = device.command_encoder().unwrap();
        dispatch_take_along_axis_backward_f32(
            &mut encoder,
            &mut registry,
            device.metal_device(),
            &dy_buf,
            &idx_buf,
            &dx_buf,
            &params,
            rows as u32,
            cols as u32,
            k as u32,
        )
        .unwrap();
        encoder.commit_and_wait().unwrap();
        let dx = dx_buf.as_slice::<f32>().unwrap().to_vec();

        let h = 1e-3f64;
        // Accumulate in f64 so the difference quotient isn't dominated by f32 sums.
        let loss = |x_in: &[f32]| -> f64 {
            let mut s = 0.0f64;
            for r in 0..rows {
                for j in 0..k {
                    s += x_in[r * cols + indices[r * k + j] as usize] as f64;
                }
            }
            s
        };
        for i in 0..(rows * cols) {
            let mut xp = x.clone();
            xp[i] += h as f32;
            let mut xm = x.clone();
            xm[i] -= h as f32;
            // Divide by the step actually representable in f32 rather than the
            // nominal 2h, so rounding of x ± h does not bias the estimate.
            let step = f64::from(xp[i]) - f64::from(xm[i]);
            let fd = (loss(&xp) - loss(&xm)) / step;
            let tol = 1e-3 * fd.abs().max(1.0);
            assert!(
                (dx[i] as f64 - fd).abs() < tol,
                "FD x[{i}]: analytic={} fd={}",
                dx[i],
                fd
            );
        }
    }

    #[test]
    fn rejects_k_greater_than_cols() {
        let device = MlxDevice::new().unwrap();
        let mut registry = KernelRegistry::new();
        let x = alloc_f32(&device, 4, vec![1, 4]);
        let i = alloc_u32(&device, 5, vec![1, 5]);
        let y = alloc_f32(&device, 5, vec![1, 5]);
        let p = make_params(&device, 1, 4, 5);
        let mut encoder = device.command_encoder().unwrap();
        let res = dispatch_take_along_axis_f32(
            &mut encoder,
            &mut registry,
            device.metal_device(),
            &x,
            &i,
            &y,
            &p,
            1,
            4,
            5,
        );
        assert!(res.is_err());
    }

    #[test]
    fn backward_rejects_k_greater_than_cols() {
        let device = MlxDevice::new().unwrap();
        let mut registry = KernelRegistry::new();
        let dy = alloc_f32(&device, 5, vec![1, 5]);
        let i = alloc_u32(&device, 5, vec![1, 5]);
        let dx = alloc_f32(&device, 4, vec![1, 4]);
        let p = make_params(&device, 1, 4, 5);
        let mut encoder = device.command_encoder().unwrap();
        let res = dispatch_take_along_axis_backward_f32(
            &mut encoder,
            &mut registry,
            device.metal_device(),
            &dy,
            &i,
            &dx,
            &p,
            1,
            4,
            5,
        );
        assert!(res.is_err());
    }

    #[test]
    fn rejects_zero_dims() {
        let device = MlxDevice::new().unwrap();
        let mut registry = KernelRegistry::new();
        // Buffer sizes are irrelevant: the zero-dimension check runs first.
        let x = alloc_f32(&device, 4, vec![1, 4]);
        let i = alloc_u32(&device, 1, vec![1, 1]);
        let y = alloc_f32(&device, 1, vec![1, 1]);
        let p = make_params(&device, 0, 4, 1);
        let mut encoder = device.command_encoder().unwrap();
        let res = dispatch_take_along_axis_f32(
            &mut encoder,
            &mut registry,
            device.metal_device(),
            &x,
            &i,
            &y,
            &p,
            0,
            4,
            1,
        );
        assert!(res.is_err());
    }

    #[test]
    fn rejects_non_u32_indices() {
        let device = MlxDevice::new().unwrap();
        let mut registry = KernelRegistry::new();
        let x = alloc_f32(&device, 4, vec![1, 4]);
        let i = alloc_f32(&device, 1, vec![1, 1]);
        let y = alloc_f32(&device, 1, vec![1, 1]);
        let p = make_params(&device, 1, 4, 1);
        let mut encoder = device.command_encoder().unwrap();
        let res = dispatch_take_along_axis_f32(
            &mut encoder,
            &mut registry,
            device.metal_device(),
            &x,
            &i,
            &y,
            &p,
            1,
            4,
            1,
        );
        assert!(res.is_err());
    }
}
360}