#version 450
// 2D strided block copy: copy `d1` rows of `d2` contiguous 4-byte elements (f32 or u32,
// moved as raw words) from `inp` to `outp`, with independent per-row source/dest strides
// and base offsets. This is
// candle's copy2d primitive, which powers `Tensor::cat` along the inner dims and
// `slice_set` (the KV-cache append + GQA repeat_kv path in a transformer forward).
// Replaces the previous GPU->CPU->GPU round-trip with one on-GPU dispatch.
//
// One invocation per copied element (total = d1*d2): decode the linear id `gid` into
// (row, col), then outp[dst_offset + row*dst_stride1 + col] = inp[src_offset +
// row*src_stride1 + col]. Invocations past the end (gid >= d1*d2) do nothing. The output
// buffer keeps any elements this copy does not address (we never zero it), matching
// the CPU implementation, so successive cat/slice_set writes into one buffer compose.
//
// Buffers are typed `uint`, not `float`: this kernel only moves 4-byte words and must be
// bit-exact for BOTH f32 and u32 storage (Vulkan represents u8/u32 ids and the dtype-as-
// f32 reps in 4-byte buffers). A `uint` load/store copies the raw bits verbatim, whereas a
// `float` copy can flush denormals / canonicalize NaN payloads on load — which would corrupt
// u32 values that happen to encode denormal-looking floats. Same reasoning as `contiguous_u32`.
// Workgroup shape: 64 invocations along x (y and z default to 1).
layout(local_size_x = 64) in;
// Source words. Only ever loaded from, hence `readonly`.
layout(set = 0, binding = 0) readonly buffer In { uint inp[]; };
// Destination words. This shader only stores to `outp` and never loads from it, so the
// block is declared `writeonly` to state that contract (mirroring `readonly` on `In`).
// Deliberately NOT `restrict`: nothing here rules out both bindings aliasing one buffer.
layout(set = 0, binding = 1) writeonly buffer Out { uint outp[]; };
// Copy geometry, supplied per dispatch. All quantities are in 4-byte elements, not bytes.
layout(push_constant) uniform Pc {
    uint d1;          // number of rows
    uint d2;          // contiguous elements per row
    uint src_stride1; // elements between consecutive source rows
    uint dst_stride1; // elements between consecutive dest rows
    uint src_offset;  // base element offset into inp
    uint dst_offset;  // base element offset into outp
};
// Entry point: each invocation copies at most one 4-byte word from `inp` to `outp`.
// The x component of the global invocation id is treated as a linear index over the
// d1*d2 copied elements and decoded into a (row, column) pair.
void main() {
    uint linear = gl_GlobalInvocationID.x;

    // Invocations at or past the element count have nothing to copy. The product is
    // computed in 32-bit unsigned arithmetic.
    if (linear >= d1 * d2) {
        return;
    }

    // Past the guard, d1*d2 > linear >= 0, so d2 is non-zero and the division and
    // remainder below are well defined.
    uint r = linear / d2;
    uint c = linear % d2;

    // Source and destination use the same (row, column) but independent row strides
    // and base offsets.
    uint src_index = src_offset + r * src_stride1 + c;
    uint dst_index = dst_offset + r * dst_stride1 + c;

    outp[dst_index] = inp[src_index];
}