// Stride-aware buffer copy. Used by the Vulkan backend to pad / strip
// innermost rows for R2C / C2R without issuing N_outer * N_middle
// vkCmdCopyBuffer regions.
//
// The shader is type-agnostic: buffers are viewed as arrays of uint
// (4 bytes each). Callers express row sizes and strides in uint-words,
// so one dispatch works for every scalar precision.
//
// Push-constant settings, with n = innermost length (in reals) and
// W = uint-words per scalar (W = 1 for f32, W = 2 for f64):
//
//   R2C pad  : src_stride_uints = W * n              (tight reals)
//              dst_stride_uints = W * 2*(n/2 + 1)    (padded reals)
//              row_uints        = W * n
//   C2R strip: src_stride_uints = W * 2*(n/2 + 1)    (padded reals)
//              dst_stride_uints = W * n              (tight reals)
//              row_uints        = W * n
//
// All sizes and strides are in uint-words, not scalars: for f32 there
// is one uint per scalar, for f64 there are two.
#version 450

// One-dimensional workgroups of 64 invocations (y and z spelled out at
// their default of 1).
layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;

// Source rows, viewed as raw 32-bit words; bound at binding 0, read only.
layout(binding = 0, std430) readonly buffer Src {
    uint src[];
};

// Destination rows, viewed as raw 32-bit words; bound at binding 1, write only.
layout(binding = 1, std430) writeonly buffer Dst {
    uint dst[];
};

// Copy geometry. Every field is measured in uint-words.
layout(push_constant) uniform Push {
    uint row_uints;         // words copied out of each row
    uint src_stride_uints;  // distance between consecutive row starts in src
    uint dst_stride_uints;  // distance between consecutive row starts in dst
    uint n_rows;            // number of rows (= outer-axes product * batch)
} pc;
// Copies n_rows rows of row_uints words each from src to dst, re-striding
// them from src_stride_uints (in src) to dst_stride_uints (in dst). Words in
// dst between the end of one row and the start of the next are not written.
//
// Launch: 1-D dispatch, vkCmdDispatch(groups, 1, 1). Each invocation walks
// the flattened (row, col) index space in steps of the total invocation
// count along X, so any groups >= 1 covers every word. This removes the
// need for groups >= ceil(row_uints * n_rows / 64), which can exceed
// maxComputeWorkGroupCount[0] (only 65535 is guaranteed by Vulkan). A host
// can therefore clamp the group count. With groups = ceil(total / 64),
// every invocation copies at most one word, as with a one-word-per-thread
// mapping.
//
// Preconditions (not checked here):
//   - src and dst do not alias.
//   - row_uints <= dst_stride_uints; otherwise rows overlap in dst and
//     different invocations write the same word.
//   - row_uints * n_rows fits in 32 bits.
//   - Plain vkCmdDispatch with no dispatch base. With vkCmdDispatchBase,
//     split dispatches would redo each other's rows.
void main() {
    // Empty copy. This also guarantees row_uints != 0 for the division
    // below (GLSL integer division by zero is undefined).
    if (pc.row_uints == 0u || pc.n_rows == 0u) return;

    uint total  = pc.row_uints * pc.n_rows;
    uint stride = gl_NumWorkGroups.x * gl_WorkGroupSize.x;

    for (uint i = gl_GlobalInvocationID.x; i < total; i += stride) {
        uint row = i / pc.row_uints;
        uint col = i - row * pc.row_uints;  // equivalent to i % row_uints
        dst[row * pc.dst_stride_uints + col] = src[row * pc.src_stride_uints + col];
        // i < total here, so total - i cannot underflow. Leaving as soon as
        // the next index would reach total keeps i + stride from wrapping
        // past 2^32.
        if (stride >= total - i) break;
    }
}