1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
//! Copyright 2026 0xClandestine, Ekryski, TheTom, Ambisphaeric
//! SPDX-License-Identifier: Apache-2.0
//! Scatter along an axis — contiguous form of MLX's `scatter_axis`.
//!
//! `out[o, indices[o, a, i], i] = updates[o, a, i]` — each update
//! element is written to the slot of `out` whose coordinate along the
//! scatter axis is given by the matching entry of `indices`. One thread
//! per update element. `out` is pre-initialized by the caller (typically
//! a copy of the source) and the kernel overwrites only the scattered
//! slots.
//!
//! Layout (row-contiguous):
//! updates: [outer, axis_upd, inner] T
//! indices: [outer, axis_upd, inner] u32
//! out: [outer, axis_size, inner] T (pre-initialized)
//!
//! Assignment (no-reduce) form: distinct `indices` are required for a
//! deterministic result — colliding indices race, matching MLX
//! `scatter_axis` with `reduce = None`. The general strided + reducing
//! kernel is a follow-up.
//!
//! ## DISPATCH INVARIANTS
//!
//! - **Grid3D**, one thread per update element over `outer*axis_upd*inner`.
//!
//! Codegen-only; correctness pinned by
//! `tests/scatter_axis_gpu_correctness.rs`.
use metaltile::kernel;
use metaltile_core::ir::KernelMode;
use crate::{
bench_types::DType,
spec::{BenchDispatch, BenchSpec},
};
/// Scatters `updates` into `out` along the middle axis (plain assignment, no reduction).
///
/// Each thread takes one flat update element `program_id::<0>()`, interprets it as
/// `(o, a, i)` in a row-contiguous `[outer, axis_upd, inner]` layout, and stores the
/// update value at `out[o, indices[o, a, i], i]`, where `out` is addressed as a
/// row-contiguous `[outer, axis_size, inner]` buffer.
///
/// * `updates` — values to write; one element per thread.
/// * `indices` — destination coordinate along the scatter axis for each update element.
/// * `out` — destination tensor; slots that no index selects are not written.
/// * `axis_upd` — extent of the scatter axis in `updates` / `indices`.
/// * `axis_size` — extent of the scatter axis in `out`.
/// * `inner` — number of elements trailing the scatter axis.
///
/// NOTE(review): neither the thread id nor the loaded index is range-checked here —
/// presumably the dispatch covers exactly `outer * axis_upd * inner` threads and every
/// index is `< axis_size`; confirm at the call site.
#[kernel]
pub fn mt_scatter_axis<T>(
    updates: Tensor<T>,
    indices: Tensor<u32>,
    mut out: Tensor<T>,
    #[constexpr] axis_upd: u32,
    #[constexpr] axis_size: u32,
    #[constexpr] inner: u32,
) {
    let flat = program_id::<0>();

    // Split the flat update index into its outer slice and its position
    // within the trailing `inner` run (remainder written as subtract-multiply).
    let slab = axis_upd * inner;
    let outer_idx = flat / slab;
    let inner_idx = flat - (flat / inner) * inner;

    // Fetch where this element goes along the scatter axis, and what to write.
    let target = load(indices[flat]);
    let value = load(updates[flat]);

    // Re-linearize `(outer_idx, target, inner_idx)` against `out`'s axis extent.
    let dst_row = outer_idx * axis_size + target;
    let dst = dst_row * inner + inner_idx;
    store(out[dst], value);
}
// Registers the `mt_scatter_axis` kernel in the `BenchSpec` inventory.
inventory::submit! {
    BenchSpec {
        // Which kernel this entry describes and how its IR is produced.
        kernel_name: "mt_scatter_axis",
        kernel_ir: mt_scatter_axis::kernel_ir_for,
        op: "indexing",
        subop: "scatter_axis",

        // Launch configuration.
        kernel_mode: Some(KernelMode::Grid3D),
        dispatch: BenchDispatch::Generic,

        // Element types the kernel is instantiated for, with a zero tolerance
        // (presumably because a scatter copies values without arithmetic).
        dtypes: &[DType::F32, DType::F16, DType::BF16],
        tol: 0.0,

        // No shapes and no MLX reference source/pattern are attached.
        shapes: &[],
        mlx_src: None,
        mlx_pattern: None,
    }
}