use metal::MTLSize;
use crate::buffer::MlxBuffer;
use crate::dtypes::DType;
use crate::encoder::CommandEncoder;
use crate::error::{MlxError, Result};
use crate::kernel_registry::KernelRegistry;
/// Contents of `../shaders/take_along_axis.metal`, embedded into the binary at
/// compile time via `include_str!`.
///
/// [`register`] associates this one source string with both the
/// `take_along_axis_f32` and `take_along_axis_backward_f32` kernel names.
pub static TAKE_ALONG_AXIS_SHADER_SOURCE: &str =
include_str!("../shaders/take_along_axis.metal");
/// Registers the take-along-axis shader source with `registry`.
///
/// The same source string ([`TAKE_ALONG_AXIS_SHADER_SOURCE`]) is registered
/// twice: once under the forward kernel name and once under the backward one.
pub fn register(registry: &mut KernelRegistry) {
    // Forward first, then backward — same order as the dispatch functions below.
    const KERNEL_NAMES: [&str; 2] = ["take_along_axis_f32", "take_along_axis_backward_f32"];

    for kernel_name in KERNEL_NAMES {
        registry.register_source(kernel_name, TAKE_ALONG_AXIS_SHADER_SOURCE);
    }
}
/// Shared argument validation for the forward and backward dispatch functions.
///
/// Checks are performed in this order, returning on the first failure:
///
/// 1. `rows`, `cols` and `k` are all non-zero.
/// 2. `k <= cols`.
/// 3. `a` and `out` both have dtype `F32`.
/// 4. `indices` has dtype `U32`.
/// 5. `a.element_count() == expected_a`.
/// 6. `indices.element_count() == rows * k`.
/// 7. `out.element_count() == expected_out`.
/// 8. `params` is at least 12 bytes long (room for three `u32` values).
///
/// `op` is only used as the prefix of the error messages. The contents of
/// `indices` and `params` are not inspected here.
///
/// # Errors
///
/// Returns `MlxError::InvalidArgument` describing the first check that failed.
// Every buffer and its expected size has to be passed in, which puts this
// helper over clippy's default argument-count threshold — same allowance as
// the public dispatch functions in this file.
#[allow(clippy::too_many_arguments)]
fn validate(
    op: &str,
    rows: u32,
    cols: u32,
    k: u32,
    a: &MlxBuffer,
    indices: &MlxBuffer,
    out: &MlxBuffer,
    params: &MlxBuffer,
    expected_a: usize,
    expected_out: usize,
) -> Result<()> {
    if rows == 0 || cols == 0 || k == 0 {
        return Err(MlxError::InvalidArgument(format!(
            "{op}: rows, cols, k must all be > 0 (got {rows}, {cols}, {k})"
        )));
    }
    if k > cols {
        return Err(MlxError::InvalidArgument(format!(
            "{op}: k ({k}) > cols ({cols})"
        )));
    }
    if a.dtype() != DType::F32 || out.dtype() != DType::F32 {
        return Err(MlxError::InvalidArgument(format!(
            "{op}: a/out must be f32"
        )));
    }
    if indices.dtype() != DType::U32 {
        return Err(MlxError::InvalidArgument(format!(
            "{op}: indices dtype {} not u32",
            indices.dtype()
        )));
    }
    if a.element_count() != expected_a {
        return Err(MlxError::InvalidArgument(format!(
            "{op}: a element_count {} != {expected_a}",
            a.element_count()
        )));
    }
    // One index per (row, selected column) pair; computed once and reused in
    // the error message.
    let expected_indices = (rows as usize) * (k as usize);
    if indices.element_count() != expected_indices {
        return Err(MlxError::InvalidArgument(format!(
            "{op}: indices element_count {} != rows*k = {}",
            indices.element_count(),
            expected_indices
        )));
    }
    if out.element_count() != expected_out {
        return Err(MlxError::InvalidArgument(format!(
            "{op}: out element_count {} != {expected_out}",
            out.element_count()
        )));
    }
    // 12 bytes = 3 × u32.
    if params.byte_len() < 12 {
        return Err(MlxError::InvalidArgument(format!(
            "{op}: params < 12 bytes (need 3 × u32)"
        )));
    }
    Ok(())
}
/// Validates the arguments and encodes one dispatch of the
/// `take_along_axis_f32` pipeline onto `encoder`.
///
/// Buffer bindings: 0 = `x`, 1 = `indices`, 2 = `y`, 3 = `params`.
/// `x` must hold `rows * cols` f32 elements, `indices` `rows * k` u32
/// elements, `y` `rows * k` f32 elements, and `params` at least 12 bytes.
///
/// The kernel body lives in the shader source and is not visible from this
/// file; the unit tests here expect `y[r, j] == x[r, indices[r, j]]`.
/// This function only encodes work — the tests commit the encoder afterwards.
///
/// # Errors
///
/// Propagates `MlxError::InvalidArgument` from `validate` and any error
/// returned by `registry.get_pipeline`.
#[allow(clippy::too_many_arguments)]
pub fn dispatch_take_along_axis_f32(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    x: &MlxBuffer,
    indices: &MlxBuffer,
    y: &MlxBuffer,
    params: &MlxBuffer,
    rows: u32,
    cols: u32,
    k: u32,
) -> Result<()> {
    const OP: &str = "take_along_axis_f32";

    let input_len = (rows as usize) * (cols as usize);
    let output_len = (rows as usize) * (k as usize);
    validate(OP, rows, cols, k, x, indices, y, params, input_len, output_len)?;

    let pipeline = registry.get_pipeline(OP, device)?;

    // One element per (row, selected column); each of the first two
    // threadgroup dimensions is capped at 16.
    // NOTE(review): presumably `encode` takes (grid size, threadgroup size) — confirm against CommandEncoder.
    let (extent_rows, extent_k) = (u64::from(rows), u64::from(k));
    let grid = MTLSize::new(extent_rows, extent_k, 1);
    let threadgroup = MTLSize::new(extent_rows.min(16), extent_k.min(16), 1);

    encoder.encode(
        pipeline,
        &[(0, x), (1, indices), (2, y), (3, params)],
        grid,
        threadgroup,
    );
    Ok(())
}
/// Validates the arguments and encodes one dispatch of the
/// `take_along_axis_backward_f32` pipeline onto `encoder`.
///
/// Buffer bindings: 0 = `dy`, 1 = `indices`, 2 = `dx`, 3 = `params`.
/// `dx` must hold `rows * cols` f32 elements, `indices` `rows * k` u32
/// elements, `dy` `rows * k` f32 elements, and `params` at least 12 bytes.
///
/// The kernel body lives in the shader source and is not visible from this
/// file; the unit tests here expect `dx[r, indices[r, j]] == dy[r, j]` with
/// every other entry of `dx` left at zero.
/// NOTE(review): the tests zero `dx` before dispatch and never repeat an index
/// within a row — whether the kernel overwrites or accumulates is not
/// established here; confirm against the shader.
///
/// # Errors
///
/// Propagates `MlxError::InvalidArgument` from `validate` and any error
/// returned by `registry.get_pipeline`.
#[allow(clippy::too_many_arguments)]
pub fn dispatch_take_along_axis_backward_f32(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    dy: &MlxBuffer,
    indices: &MlxBuffer,
    dx: &MlxBuffer,
    params: &MlxBuffer,
    rows: u32,
    cols: u32,
    k: u32,
) -> Result<()> {
    const OP: &str = "take_along_axis_backward_f32";

    // `validate` takes the rows*cols buffer first, so `dx` goes in the `a`
    // slot and `dy` in the `out` slot.
    let grad_in_len = (rows as usize) * (cols as usize);
    let grad_out_len = (rows as usize) * (k as usize);
    validate(OP, rows, cols, k, dx, indices, dy, params, grad_in_len, grad_out_len)?;

    let pipeline = registry.get_pipeline(OP, device)?;

    // One element per (row, selected column); each of the first two
    // threadgroup dimensions is capped at 16.
    // NOTE(review): presumably `encode` takes (grid size, threadgroup size) — confirm against CommandEncoder.
    let (extent_rows, extent_k) = (u64::from(rows), u64::from(k));
    let grid = MTLSize::new(extent_rows, extent_k, 1);
    let threadgroup = MTLSize::new(extent_rows.min(16), extent_k.min(16), 1);

    encoder.encode(
        pipeline,
        &[(0, dy), (1, indices), (2, dx), (3, params)],
        grid,
        threadgroup,
    );
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::device::MlxDevice;

    /// Allocates an f32 buffer of `n` elements with shape `sh`, filled with 0.0.
    fn alloc_f32(d: &MlxDevice, n: usize, sh: Vec<usize>) -> MlxBuffer {
        let mut b = d.alloc_buffer(n * 4, DType::F32, sh).unwrap();
        b.as_mut_slice::<f32>().unwrap().fill(0.0);
        b
    }

    /// Allocates a u32 buffer of `n` elements with shape `sh`, filled with 0.
    fn alloc_u32(d: &MlxDevice, n: usize, sh: Vec<usize>) -> MlxBuffer {
        let mut b = d.alloc_buffer(n * 4, DType::U32, sh).unwrap();
        b.as_mut_slice::<u32>().unwrap().fill(0);
        b
    }

    /// Builds the 12-byte params buffer holding `[rows, cols, k]` as u32.
    fn make_params(d: &MlxDevice, rows: u32, cols: u32, k: u32) -> MlxBuffer {
        let mut p = d.alloc_buffer(12, DType::U32, vec![3]).unwrap();
        p.as_mut_slice::<u32>().unwrap().copy_from_slice(&[rows, cols, k]);
        p
    }

    /// Forward pass: GPU output must equal `x[r, indices[r, j]]` computed on the CPU.
    #[test]
    fn forward_matches_cpu_oracle() {
        let device = MlxDevice::new().unwrap();
        let mut registry = KernelRegistry::new();
        let rows = 4;
        let cols = 8;
        let k = 3;
        let x: Vec<f32> = (0..(rows * cols))
            .map(|i| ((i as f32) * 0.137 - 0.4).sin() * 0.7)
            .collect();
        // One line per row, `k` column indices each.
        let indices: Vec<u32> = vec![
            0, 3, 7,
            1, 4, 6,
            2, 5, 0,
            7, 0, 4,
        ];
        let mut x_buf = alloc_f32(&device, rows * cols, vec![rows, cols]);
        x_buf.as_mut_slice::<f32>().unwrap().copy_from_slice(&x);
        let mut idx_buf = alloc_u32(&device, rows * k, vec![rows, k]);
        idx_buf.as_mut_slice::<u32>().unwrap().copy_from_slice(&indices);
        let y_buf = alloc_f32(&device, rows * k, vec![rows, k]);
        let params = make_params(&device, rows as u32, cols as u32, k as u32);
        let mut encoder = device.command_encoder().unwrap();
        dispatch_take_along_axis_f32(
            &mut encoder, &mut registry, device.metal_device(),
            &x_buf, &idx_buf, &y_buf, &params,
            rows as u32, cols as u32, k as u32,
        ).unwrap();
        encoder.commit_and_wait().unwrap();
        let gpu = y_buf.as_slice::<f32>().unwrap();
        for r in 0..rows {
            for j in 0..k {
                let idx = indices[r * k + j] as usize;
                let expected = x[r * cols + idx];
                assert!(
                    (gpu[r * k + j] - expected).abs() < 1e-6 * expected.abs().max(1.0),
                    "y[{r},{j}]: gpu={} expected={} (idx={})",
                    gpu[r * k + j], expected, idx
                );
            }
        }
    }

    /// Backward pass: GPU `dx` must equal a CPU scatter of `dy` into a zeroed
    /// `[rows, cols]` array. Indices are unique within each row.
    #[test]
    fn backward_scatter_matches_cpu_oracle() {
        let device = MlxDevice::new().unwrap();
        let mut registry = KernelRegistry::new();
        let rows = 3;
        let cols = 6;
        let k = 2;
        let dy: Vec<f32> = (0..(rows * k))
            .map(|i| ((i as f32) * 0.231 + 0.1).sin() * 0.6)
            .collect();
        let indices: Vec<u32> = vec![
            0, 4,
            1, 5,
            2, 3,
        ];
        let mut dy_buf = alloc_f32(&device, rows * k, vec![rows, k]);
        dy_buf.as_mut_slice::<f32>().unwrap().copy_from_slice(&dy);
        let mut idx_buf = alloc_u32(&device, rows * k, vec![rows, k]);
        idx_buf.as_mut_slice::<u32>().unwrap().copy_from_slice(&indices);
        // `alloc_f32` zero-fills, so untouched entries of dx stay 0.0.
        let dx_buf = alloc_f32(&device, rows * cols, vec![rows, cols]);
        let params = make_params(&device, rows as u32, cols as u32, k as u32);
        let mut encoder = device.command_encoder().unwrap();
        dispatch_take_along_axis_backward_f32(
            &mut encoder, &mut registry, device.metal_device(),
            &dy_buf, &idx_buf, &dx_buf, &params,
            rows as u32, cols as u32, k as u32,
        ).unwrap();
        encoder.commit_and_wait().unwrap();
        let gpu = dx_buf.as_slice::<f32>().unwrap();
        let mut expected = vec![0.0f32; rows * cols];
        for r in 0..rows {
            for j in 0..k {
                let idx = indices[r * k + j] as usize;
                expected[r * cols + idx] = dy[r * k + j];
            }
        }
        for i in 0..(rows * cols) {
            assert!(
                (gpu[i] - expected[i]).abs() < 1e-6,
                "dx[{i}]: gpu={} expected={}",
                gpu[i], expected[i]
            );
        }
    }

    /// Cross-checks the backward kernel against a central finite difference of
    /// the scalar loss `sum(y)`, where `y` is the CPU gather of `x`.
    #[test]
    fn backward_finite_difference_falsifier() {
        let device = MlxDevice::new().unwrap();
        let mut registry = KernelRegistry::new();
        let rows = 4;
        let cols = 6;
        let k = 2;
        let x: Vec<f32> = (0..(rows * cols))
            .map(|i| 0.3 + (i as f32) * 0.013)
            .collect();
        let indices: Vec<u32> = vec![
            0, 3,
            1, 5,
            2, 4,
            0, 4,
        ];
        let mut x_buf = alloc_f32(&device, rows * cols, vec![rows, cols]);
        x_buf.as_mut_slice::<f32>().unwrap().copy_from_slice(&x);
        let mut idx_buf = alloc_u32(&device, rows * k, vec![rows, k]);
        idx_buf.as_mut_slice::<u32>().unwrap().copy_from_slice(&indices);
        let y_buf = alloc_f32(&device, rows * k, vec![rows, k]);
        let params = make_params(&device, rows as u32, cols as u32, k as u32);

        // Forward dispatch; its output is not compared here, only required to succeed.
        let mut encoder = device.command_encoder().unwrap();
        dispatch_take_along_axis_f32(
            &mut encoder, &mut registry, device.metal_device(),
            &x_buf, &idx_buf, &y_buf, &params,
            rows as u32, cols as u32, k as u32,
        ).unwrap();
        encoder.commit_and_wait().unwrap();

        // Upstream gradient of all ones corresponds to loss = sum(y).
        let dy_ones = vec![1.0f32; rows * k];
        let mut dy_buf = alloc_f32(&device, rows * k, vec![rows, k]);
        dy_buf.as_mut_slice::<f32>().unwrap().copy_from_slice(&dy_ones);
        let dx_buf = alloc_f32(&device, rows * cols, vec![rows, cols]);
        let mut encoder = device.command_encoder().unwrap();
        dispatch_take_along_axis_backward_f32(
            &mut encoder, &mut registry, device.metal_device(),
            &dy_buf, &idx_buf, &dx_buf, &params,
            rows as u32, cols as u32, k as u32,
        ).unwrap();
        encoder.commit_and_wait().unwrap();
        let dx = dx_buf.as_slice::<f32>().unwrap().to_vec();

        let h = 1e-3f64;
        // CPU reference loss: sum of the gathered elements, accumulated in f64.
        let loss = |x_in: &[f32]| -> f64 {
            let mut s = 0.0f64;
            for r in 0..rows {
                for j in 0..k {
                    s += x_in[r * cols + indices[r * k + j] as usize] as f64;
                }
            }
            s
        };
        for i in 0..(rows * cols) {
            let mut xp = x.clone(); xp[i] += h as f32;
            let mut xm = x.clone(); xm[i] -= h as f32;
            let fd = (loss(&xp) - loss(&xm)) / (2.0 * h);
            let tol = 1e-3 * fd.abs().max(1.0);
            assert!(
                (dx[i] as f64 - fd).abs() < tol,
                "FD x[{i}]: analytic={} fd={}", dx[i], fd
            );
        }
    }

    /// `k > cols` must be rejected by validation before anything is encoded.
    #[test]
    fn rejects_k_greater_than_cols() {
        let device = MlxDevice::new().unwrap();
        let mut registry = KernelRegistry::new();
        let x = alloc_f32(&device, 4, vec![1, 4]);
        let i = alloc_u32(&device, 5, vec![1, 5]);
        let y = alloc_f32(&device, 5, vec![1, 5]);
        let p = make_params(&device, 1, 4, 5);
        let mut encoder = device.command_encoder().unwrap();
        let res = dispatch_take_along_axis_f32(
            &mut encoder, &mut registry, device.metal_device(),
            &x, &i, &y, &p, 1, 4, 5,
        );
        assert!(res.is_err());
    }
}