#version 450
#extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_control_flow_attributes : require
#include "types.glsl"
// Push-constant parameters for the im2col kernel, accessed as `p`.
// NOTE(review): the field order and types presumably have to mirror the
// host-side push-constant struct, which is not visible here - confirm before
// reordering or resizing anything.
layout (push_constant) uniform parameter
{
// Base destination address; only read on the BDA path, where byte offsets are added to it.
BDA_STORAGE_T dst_addr;
// Source offsets in elements of data_a: batch_offset is scaled by the batch
// index, offset_delta by the input-channel index.
uint batch_offset; uint offset_delta;
// Not read anywhere in this file (presumably the input channel count).
uint IC;
// Input extent: IW bounds the x coordinate and is the source row pitch, IH bounds the y coordinate.
uint IW; uint IH;
// Output extent: OW bounds the `ow` loop in main(), OH splits z into (batch, oh).
uint OW; uint OH;
// Kernel extent: a flat kernel index is split into ky = idx / KW and kx = idx % KW.
uint KW; uint KH;
// Upper bound of the z loop in main(); z is decoded as batch * OH + oh,
// so this is presumably OH * batch count - confirm against the host code.
uint OH_batch;
// Number of elements written per (batch, oh, ow) destination row; the flat
// index in [0, CHW) is decoded as (ic, ky, kx).
uint CHW;
// Multipliers applied to ow (s0) and oh (s1) when deriving input coordinates.
int s0; int s1;
// Amounts subtracted from the x (p0) and y (p1) input coordinates.
int p0; int p1;
// Multipliers applied to kx (d0) and ky (d1) when deriving input coordinates.
int d0; int d1;
// Not read anywhere in this file.
uint batch_IC;
} p;
// Specialization constant 0 (default 32). The same id also sets the
// workgroup's x size via local_size_x_id below.
layout(constant_id = 0) const uint BLOCK_SIZE = 32;
// Inner-loop trip count per invocation: im2col() handles CHW indices in
// tiles of 512, and each invocation strides through a tile by BLOCK_SIZE.
const uint NUM_ITER = 512 / BLOCK_SIZE;
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
// Source buffer, read-only. A_TYPE is not defined in this file
// (presumably supplied by types.glsl or a build-time define).
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
// Destination buffer, write-only; used when BDA is disabled.
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
#if BDA
// Buffer-reference view of a single D_TYPE element; lets im2col() store
// through an address computed from p.dst_addr.
layout (buffer_reference) buffer D_ptr {D_TYPE d;};
#endif
// Fills the destination row for output column `ow` and packed index `z_idx`
// (z_idx = batch * OH + oh). The row holds p.CHW elements; flat element index
// e decodes as ic = e / (KH*KW), ky = (e % (KH*KW)) / KW, kx = e % KW.
// Elements whose input coordinate falls outside [0, IW) x [0, IH) are written
// as zero. Workgroups walk the row in tiles of 512 elements, and inside a tile
// each invocation visits every BLOCK_SIZE-th element, updating (ic, ky, kx)
// incrementally instead of re-dividing at every step.
void im2col(const uint ow, const uint z_idx) {
    const uint batch_idx = z_idx / p.OH;
    const uint oh = z_idx % p.OH;
    const uint lane = gl_LocalInvocationID.x;
    const uint kernel_area = p.KH * p.KW;

    // Source offset of this batch, and destination offset of this row.
    const uint batch_base = batch_idx * p.batch_offset;
    const BDA_OFFSET_T row_base = ((BDA_OFFSET_T(batch_idx) * p.OH + oh) * p.OW + ow) * p.CHW;

    // Input coordinate hit by kernel position (0, 0); negative values are
    // rejected later by the unsigned bounds test.
    const int origin_x = int(ow * p.s0) - p.p0;
    const int origin_y = int(oh * p.s1) - p.p1;

    // Break one BLOCK_SIZE stride of the flat index into channel / ky / kx parts.
    const uint step_ic = BLOCK_SIZE / kernel_area;
    const uint step_rem = BLOCK_SIZE - step_ic * kernel_area;
    const uint step_ky = step_rem / p.KW;
    const uint step_kx = step_rem - step_ky * p.KW;
    const uint step_src = step_ic * p.offset_delta;

#if BDA
    // Byte address of the row start and byte stride of one BLOCK_SIZE step.
    const BDA_STORAGE_T row_addr = p.dst_addr + D_SIZE * row_base;
    const uint addr_stride = D_SIZE * BLOCK_SIZE;
#endif

    uint tile = gl_WorkGroupID.x;
    do {
        // First flat element handled by this invocation inside the current tile.
        uint elem = tile * 512 + lane;

        // Decode the starting element once; later steps only add deltas.
        const uint chan = elem / kernel_area;
        const uint tap = elem - chan * kernel_area;
        uint tap_y = tap / p.KW;
        uint tap_x = tap - tap_y * p.KW;
        uint src_base = batch_base + chan * p.offset_delta;

#if BDA
        BDA_STORAGE_T out_addr = row_addr + D_SIZE * elem;
#else
        uint out_idx = row_base + elem;
#endif

        [[unroll]] for (uint iter = 0; iter < NUM_ITER; ++iter) {
            // elem only grows from here on, so nothing is left for this invocation.
            if (elem >= p.CHW) {
                return;
            }

            const int src_x = origin_x + int(tap_x * p.d0);
            const int src_y = origin_y + int(tap_y * p.d1);

            // The unsigned compare also rejects negative coordinates.
            A_TYPE src_val = A_TYPE(0);
            if (uint(src_y) < p.IH && uint(src_x) < p.IW) {
                src_val = data_a[src_base + uint(src_y) * p.IW + uint(src_x)];
            }

#if BDA
            D_ptr(out_addr).d = D_TYPE(src_val);
            out_addr += addr_stride;
#else
            data_d[out_idx] = D_TYPE(src_val);
            out_idx += BLOCK_SIZE;
#endif

            // Advance by BLOCK_SIZE elements in decomposed form.
            elem += BLOCK_SIZE;
            src_base += step_src;
            tap_x += step_kx;
            tap_y += step_ky;

            // kx overflow carries into ky (at most one wrap is possible).
            const bool carry_x = tap_x >= p.KW;
            tap_x -= carry_x ? p.KW : 0u;
            tap_y += carry_x ? 1u : 0u;

            // ky overflow carries into the channel offset (at most one wrap).
            const bool carry_y = tap_y >= p.KH;
            tap_y -= carry_y ? p.KH : 0u;
            src_base += carry_y ? p.offset_delta : 0u;
        }

        tile += gl_NumWorkGroups.x;
    } while (tile * 512 < p.CHW);
}
// Entry point. Each invocation starts at its own y/z global id and steps by
// the dispatched workgroup count in that dimension, calling im2col() for every
// (ow, z) pair with ow < p.OW and z < p.OH_batch that it reaches.
void main() {
    for (uint col = gl_GlobalInvocationID.y; col < p.OW; col += gl_NumWorkGroups.y) {
        for (uint plane = gl_GlobalInvocationID.z; plane < p.OH_batch; plane += gl_NumWorkGroups.z) {
            im2col(col, plane);
        }
    }
}