//! Copyright 2026 0xClandestine, Ekryski, TheTom, Ambisphaeric
//! SPDX-License-Identifier: Apache-2.0
//! Strided copy benchmark — #[kernel] DSL vs MLX metal/copy.metal
use ;
use KernelMode;
use crate::;
// ─── mt_strided_copy_nd ──────────────────────────────────────────────────
//
// General N-D strided copy — the MLX `copy_g` / `copy_g_nd{1,2,3}`
// counterpart. The 2-D `mt_strided_copy` above only handles a
// row-major-padded `[rows, cols]` source; this kernel copies an
// arbitrary-rank logical tensor out of a source buffer whose physical
// layout is described by per-dimension `shape` + `strides` arrays.
//
// The destination is always contiguous row-major. Output element `p`
// (a flat index in `[0, n_out)`) is unravelled against `out_shape`
// (== logical `shape`) into a multi-index `coord`; the source element
// offset is then `Σ_d coord_d · strides[d]`. This is exactly MLX's
// `elem_to_loc` (`mlx/backend/metal/kernels/utils.h`).
//
// Because the source strides are *arbitrary* (not necessarily a
// padded row-major view), this generalises:
// - padded copies (the 2-D `mt_strided_copy` case),
// - transposes (strides permuted vs shape),
// - broadcasts (a stride of 0 on a broadcast axis),
// - any slice / dilation (non-unit innermost stride).
//
// Inputs:
// src — source data buffer (raw, physically strided)
// shape — [rank] u32 logical extent of each dimension
// strides — [rank] u32 element stride of each source dimension
// out — [n_out] contiguous row-major output
//
// Constexpr:
// rank — number of dimensions (logical). Compile-time constant so
// the unravel loop is fully unrolled — no dynamic trip count.
//
// ## DISPATCH INVARIANTS
//
// - **Mode: Grid3D** — one thread per output element, no cross-thread
// cooperation. `program_id::<0>()` is the flat output index.
// - **Grid: `[n_out, 1, 1]`, TPG: `[1, 1, 1]`** (or any
// `grid·tpg == n_out` split). `n_out == Π shape[d]`.
// - **`rank >= 1`.** `shape` and `strides` must each hold exactly
// `rank` u32 entries; a short buffer reads out of bounds.
// - The unravel walks dimensions **last → first**: for `d = rank-1`
// down to `0`, `coord_d` is the running remainder modulo `shape[d]`,
// and the remainder is then divided by `shape[d]`. `strides` is
// therefore interpreted in the same major-to-minor order as `shape`
// (row-major logical indexing).
submit!
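
// Host-side sketch of the index mapping documented above, intended only
// for checking kernel output on small cases. It is an illustration under
// stated assumptions, not the kernel itself: `elem_to_loc_ref` and
// `strided_copy_nd_ref` are hypothetical names, strides are taken to be
// in elements (as the `strides` input is described), and `rank` is
// simply `shape.len()` rather than a constexpr.

```rust
/// Hypothetical reference: map flat output index `p` to a source element
/// offset by unravelling `p` against `shape`, last dimension first.
fn elem_to_loc_ref(mut p: u32, shape: &[u32], strides: &[u32]) -> u32 {
    assert_eq!(shape.len(), strides.len(), "shape and strides must share rank");
    let mut loc = 0u32;
    for d in (0..shape.len()).rev() {
        // coord_d = p % shape[d]; a stride of 0 makes this axis a broadcast.
        loc += (p % shape[d]) * strides[d];
        p /= shape[d];
    }
    loc
}

/// Hypothetical reference: gather a contiguous row-major output from a
/// strided source, one "thread" (loop iteration) per output element.
fn strided_copy_nd_ref<T: Copy>(src: &[T], shape: &[u32], strides: &[u32]) -> Vec<T> {
    let n_out: u32 = shape.iter().product();
    (0..n_out)
        .map(|p| src[elem_to_loc_ref(p, shape, strides) as usize])
        .collect()
}
```

// Example: `src = [0, 1, 2, 3, 4, 5]` read as shape `[3, 2]` with strides
// `[1, 3]` (a transpose of a row-major `[2, 3]` buffer) yields
// `[0, 3, 1, 4, 2, 5]`. Shape `[2, 3]` with strides `[0, 1]` over
// `[7, 8, 9]` broadcasts the row twice.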