//! Copyright 2026 0xClandestine, Ekryski, TheTom, Ambisphaeric
//! SPDX-License-Identifier: Apache-2.0
//! Strided indexing kernels — `gather_front`, `scatter`, `masked_scatter`.
//!
//! The contiguous along-an-axis forms (`gather_axis` / `scatter_axis`)
//! ship in their own modules. This file covers the three remaining
//! `indexing/` ops from MLX's `mlx/backend/metal/kernels/indexing/`:
//!
//! - **`gather_front`** — gather whole rows by a first-axis index:
//! `out[r, :] = src[indices[r], :]`. The embedding-table-style
//! row gather where the index selects which source row to copy.
//! MLX reference: `indexing/gather_front.h`.
//! - **`scatter`** — the inverse: write rows into index-selected slots
//! of a pre-initialized output, `out[indices[r], :] = updates[r, :]`.
//! Assignment form (`reduce = None`) — colliding indices race, so the
//! caller must supply distinct indices for a deterministic result,
//! matching MLX `scatter` with no reduction.
//! - **`masked_scatter`** — gather with a per-element mask:
//! `out[i] = mask[i] ? src[scatter_offsets[i]] : out[i]`. The masked
//! elements pull from a compacted `src` via a precomputed offset
//! table; unmasked elements keep `out`'s prior value. MLX reference:
//! `indexing/masked_scatter.h`.
//!
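//! All three are one-thread-per-element Grid3D kernels (one thread per
//! output element for `gather_front` and `masked_scatter`, one per update
//! element for `scatter`). Threads never cooperate, so the dispatch
//! hazards of the reduction-mode kernels do not apply.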
//! Indices / offsets / mask are `u32` tensors (a `0/1` mask rather than
//! a `bool` tensor — `u32` is the dtype the DSL exposes for index
//! buffers, and the caller packs the mask as `0u32` / `1u32`).
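//!
//! A minimal caller-side sketch of that packing follows. It assumes the
//! host holds the mask as a `&[bool]`, and `pack_mask` is a hypothetical
//! helper that is not part of this module:
//!
//! ```rust
//! // Hypothetical helper: pack a `bool` mask into the `0u32` / `1u32`
//! // layout the kernels read.
//! fn pack_mask(mask: &[bool]) -> Vec<u32> {
//!     // `u32::from(bool)` maps `false` to 0 and `true` to 1.
//!     mask.iter().map(|&b| u32::from(b)).collect()
//! }
//! ```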
//!
//! Codegen-only; correctness pinned by
//! `tests/indexing_gpu_correctness.rs`.
use kernel;
use KernelMode;
use crate::;
/// First-axis row gather — `out[r, i] = src[indices[r], i]`.
///
/// `src` is `[n_src_rows, row_width]`, `indices` is `[n_out_rows]`
/// (u32), `out` is `[n_out_rows, row_width]`. One thread per output
/// element; the output element `idx` decomposes into `(r, i)` and the
/// source row is looked up from `indices[r]`.
///
/// `n_elems = n_out_rows * row_width` is passed as a constexpr so
/// threads past the output (a Grid3D dispatch rounds the thread count
/// up to a multiple of TPG) early-out — they must not read `indices`
/// out of bounds or write a stray `out` slot.
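///
/// A CPU reference sketch of the per-thread logic above. It is
/// hypothetical and not this crate's API: `gather_front_ref` and its `tpg`
/// parameter (standing in for the TPG rounding) are illustrative. The loop
/// simulates a rounded-up dispatch so the early-out guard is exercised.
///
/// ```rust
/// // Hypothetical CPU model of `gather_front`: one loop iteration per
/// // dispatched thread, including the padding threads.
/// fn gather_front_ref<T: Copy + Default>(
///     src: &[T],
///     indices: &[u32],
///     row_width: usize,
///     tpg: usize, // assumed threads per threadgroup; illustrative only
/// ) -> Vec<T> {
///     assert!(tpg > 0, "threadgroup size must be positive");
///     let n_elems = indices.len() * row_width; // n_out_rows * row_width
///     let mut out = vec![T::default(); n_elems];
///     // Round the launch up to a multiple of `tpg`, as a Grid3D dispatch does.
///     let n_threads = (n_elems + tpg - 1) / tpg * tpg;
///     for idx in 0..n_threads {
///         // Early-out: padding threads must not read `indices` or write `out`.
///         if idx >= n_elems {
///             continue;
///         }
///         let r = idx / row_width; // output row
///         let i = idx % row_width; // column within the row
///         let src_row = indices[r] as usize;
///         out[idx] = src[src_row * row_width + i];
///     }
///     out
/// }
/// ```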
/// First-axis row scatter — `out[indices[r], i] = updates[r, i]`.
///
/// `updates` is `[n_upd_rows, row_width]`, `indices` is `[n_upd_rows]`
/// (u32), `out` is `[n_out_rows, row_width]` and is pre-initialized by
/// the caller (typically a copy of the source). One thread per update
/// element. Assignment (no-reduce) form: distinct `indices` are
/// required for a deterministic result, because colliding indices race.
/// This is the same contract as MLX `scatter` with `reduce = None`.
///
/// `n_elems = n_upd_rows * row_width` is passed as a constexpr so
/// threads past the update count early-out — without the guard a
/// stray thread reads `indices` / `updates` out of bounds and scatters
/// garbage into `out`.
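///
/// A CPU reference sketch of the assignment form, plus a check for the
/// distinct-indices contract. Both functions are hypothetical
/// illustrations, not this crate's API. A sequential CPU loop resolves a
/// collision deterministically (last write wins); on the GPU the winning
/// write is unspecified, which is why the kernel requires distinct indices.
///
/// ```rust
/// use std::collections::HashSet;
///
/// // Hypothetical CPU model of `scatter`: one iteration per update element.
/// fn scatter_ref<T: Copy>(out: &mut [T], updates: &[T], indices: &[u32], row_width: usize) {
///     let n_elems = indices.len() * row_width; // n_upd_rows * row_width
///     assert_eq!(updates.len(), n_elems, "updates must be [n_upd_rows, row_width]");
///     for idx in 0..n_elems {
///         let r = idx / row_width; // update row
///         let i = idx % row_width; // column within the row
///         let dst_row = indices[r] as usize;
///         out[dst_row * row_width + i] = updates[idx];
///     }
/// }
///
/// // True when no destination row would be written twice.
/// fn indices_are_distinct(indices: &[u32]) -> bool {
///     let mut seen = HashSet::new();
///     // `insert` returns false on a repeat, which stops `all` early.
///     indices.iter().all(|&ix| seen.insert(ix))
/// }
/// ```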
/// Masked gather-scatter — `out[i] = mask[i] ? src[offsets[i]] : out[i]`.
///
/// One thread per output element. `mask` is a `u32` `0/1` buffer the
/// same length as `out`; `offsets` (also `u32`, same length) is the
/// precomputed compacted-`src` index for each masked position. Where
/// the mask is `0` the thread re-reads `out`'s prior value and stores it
/// back (an unconditional no-op store rather than a branch, which keeps
/// the kernel free of branch divergence). `out` must be pre-initialized.
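///
/// A CPU sketch of the select form, written as a hypothetical helper
/// (`masked_scatter_ref` is not this crate's API). The store to `out[i]`
/// is unconditional, mirroring the no-op-store shape of the kernel; the
/// `src` read is guarded so an unmasked position's offset is never
/// dereferenced, since nothing here promises those offsets are in bounds.
///
/// ```rust
/// // Hypothetical CPU model of single-batch `masked_scatter`.
/// fn masked_scatter_ref<T: Copy>(out: &mut [T], src: &[T], mask: &[u32], offsets: &[u32]) {
///     assert_eq!(mask.len(), out.len(), "mask must match out");
///     assert_eq!(offsets.len(), out.len(), "offsets must match out");
///     for i in 0..out.len() {
///         let prior = out[i];
///         let value = if mask[i] != 0 { src[offsets[i] as usize] } else { prior };
///         out[i] = value; // always store: the gathered value or the prior one
///     }
/// }
/// ```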
///
/// MLX's reference compacts `src` to one batch's worth of rows and
/// derives `batch_idx` from a `mask_batch_size`; this port flattens to
/// the single-batch case (`offsets` already absolute into `src`),
/// which is what the FFAI masked-cache-update path needs.
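///
/// One way to build the absolute `offsets` table for this single-batch
/// case is an exclusive prefix sum over the `0/1` mask, so the k-th masked
/// position reads `src[k]`. This is an assumption about the caller side,
/// not something this file specifies, and `masked_offsets` is a
/// hypothetical helper:
///
/// ```rust
/// // Exclusive prefix sum of a 0/1 mask: each entry counts the masked
/// // elements strictly before it.
/// fn masked_offsets(mask: &[u32]) -> Vec<u32> {
///     let mut running = 0u32;
///     mask.iter()
///         .map(|&m| {
///             let off = running;
///             running += m;
///             off
///         })
///         .collect()
/// }
/// ```
///
/// For the mask `[0, 1, 0, 1, 1]` this yields `[0, 0, 1, 1, 2]`.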
submit!
submit!
submit!