// Workgroup (thread-group) dimensions used by [numthreads] on the entry point.
// Each can be overridden at pipeline creation through Vulkan specialization
// constants 0, 1 and 2 respectively; all default to 1.
[vk::constant_id(0)] const int WORKGROUP_SIZE_X = 1;
[vk::constant_id(1)] const int WORKGROUP_SIZE_Y = 1;
[vk::constant_id(2)] const int WORKGROUP_SIZE_Z = 1;
// Per-dispatch pooling parameters, supplied as Vulkan push constants.
// NOTE(review): field order defines the push-constant layout; presumably it
// must match a host-side struct exactly — do not reorder. Confirm with caller.
struct PushConstants
{
// Number of batches; the kernel checks z against n * c.
uint n;
// Number of channels per batch.
uint c;
// Input spatial size (rows, columns) of each (batch, channel) plane.
uint in_h;
uint in_w;
// Output spatial size; invocations outside [0, out_w) x [0, out_h) exit early.
uint out_h;
uint out_w;
// Pooling window size in rows / columns (number of taps per axis).
uint k_h;
uint k_w;
// Stride: multiplies the output coordinate to locate the window origin.
uint s_h;
uint s_w;
// Dilation: spacing between consecutive taps inside the window.
uint d_h;
uint d_w;
// Padding: subtracted from the window origin; taps landing outside the
// input are skipped rather than read.
uint pad_h;
uint pad_w;
}
// Single push-constant block read by the entry point.
[[vk::push_constant]]
PushConstants pc;
// Max pooling compute kernel.
//
// Each invocation produces one output element:
//   threadId.x -> output column, threadId.y -> output row,
//   threadId.z -> flattened (batch, channel) plane index (batch * c + ch).
// src and dst are indexed as dense [n][c][h][w] arrays (see offset math below).
//
// For every tap of the k_h x k_w window (strided by s_*, dilated by d_*,
// shifted by -pad_*), in-range input values are compared and the largest is
// kept. Taps falling outside the input (padding) are ignored. If no tap of the
// window is in range, T(0) is written.
[shader("compute")]
[numthreads(WORKGROUP_SIZE_X, WORKGROUP_SIZE_Y, WORKGROUP_SIZE_Z)]
void main<T : IArithmetic>(
    StructuredBuffer<T> src,
    RWStructuredBuffer<T> dst,
    uint3 threadId: SV_DispatchThreadID)
{
    uint out_x = threadId.x;
    uint out_y = threadId.y;
    uint z = threadId.z; // encodes batch * channel
    // The dispatch may be rounded up to whole workgroups; drop the excess.
    if (out_x >= pc.out_w || out_y >= pc.out_h)
        return;
    if (z >= pc.n * pc.c)
        return;

    uint batch = z / pc.c;
    uint ch = z % pc.c;

    // Offsets that are constant for this invocation, hoisted out of the window loops.
    uint plane = batch * pc.c + ch;
    uint src_plane = plane * pc.in_h * pc.in_w;
    // Window origin in input coordinates; negative when it starts in the padding.
    int origin_y = int(out_y) * int(pc.s_h) - int(pc.pad_h);
    int origin_x = int(out_x) * int(pc.s_w) - int(pc.pad_w);

    T acc = T(0);
    bool found = false;
    for (uint ky = 0u; ky < pc.k_h; ++ky)
    {
        int in_y_i = origin_y + int(ky) * int(pc.d_h);
        // The input row does not depend on kx: skip the whole row when it lies
        // in the padding instead of re-testing it for every column tap.
        if (in_y_i < 0 || uint(in_y_i) >= pc.in_h)
            continue;
        uint row_off = src_plane + uint(in_y_i) * pc.in_w;

        for (uint kx = 0u; kx < pc.k_w; ++kx)
        {
            int in_x_i = origin_x + int(kx) * int(pc.d_w);
            if (in_x_i < 0 || uint(in_x_i) >= pc.in_w)
                continue;
            T val = src[row_off + uint(in_x_i)];
            // The first in-range tap seeds the maximum; later taps replace it only if larger.
            if (!found || val > acc)
            {
                acc = val;
                found = true;
            }
        }
    }

    // When no tap was in range, acc still holds its initial T(0), which is the
    // documented fallback, so it can be stored unconditionally.
    uint dst_off = (plane * pc.out_h + out_y) * pc.out_w + out_x;
    dst[dst_off] = acc;
}