#version 450
// Compute shader that transforms rows of N floats with scaled add/subtract
// butterflies. The FWHT_SHMEM macro name suggests a fast Walsh-Hadamard
// transform. Two build variants exist:
//   - FWHT_SHMEM defined:  invocations exchange values through shared memory.
//   - FWHT_SHMEM undefined: invocations exchange values with subgroupShuffleXor,
//     which is why the subgroup extensions are only enabled in that case.

// Required for the [[unroll]] loop attributes used in main().
#extension GL_EXT_control_flow_attributes : require
#ifndef FWHT_SHMEM
#extension GL_KHR_shader_subgroup_basic : enable
#extension GL_KHR_shader_subgroup_shuffle : enable
#endif
// Specialization constant 0. It is also the workgroup x size (local_size_x_id = 0
// below), i.e. the number of invocations that cooperate on one row.
layout(constant_id = 0) const uint BLOCK_SIZE = 32;
// Specialization constant 1: number of floats in one row.
// NOTE(review): the doubling loops in main() appear to assume BLOCK_SIZE and N
// are powers of two with N >= BLOCK_SIZE -- confirm against the host code.
layout(constant_id = 1) const uint N = 128;
// Workgroup shape is BLOCK_SIZE x 4 x 1; main() uses the y size as the number of
// rows handled per workgroup per loop iteration.
// NOTE(review): the shuffle variant indexes by gl_SubgroupInvocationID and
// gl_SubgroupID, so it presumably requires subgroup size == BLOCK_SIZE -- confirm.
layout(local_size_x_id = 0, local_size_y = 4, local_size_z = 1) in;
layout(push_constant) uniform parameter
{
// Number of rows to process; rows at or beyond this index are never written.
uint n_rows;
// Offset, in floats, added to every index into data_a.
uint src_offset;
// Offset, in floats, added to every index into data_d.
uint dst_offset;
// Factor each input value is multiplied by as it is loaded.
float scale;
};
// Input rows (read only).
layout(binding = 0, std430) readonly buffer A { float data_a[]; };
// Output rows (write only).
layout(binding = 1, std430) writeonly buffer D { float data_d[]; };
// Number of row elements each invocation keeps in its local array in main().
const uint EL_W = N / BLOCK_SIZE;
#ifdef FWHT_SHMEM
// Exchange buffer: one N-float region per local y index. The 4 matches
// local_size_y above and must be kept in sync with it.
shared float shmem[4 * N];
#endif
// Entry point. Each group of BLOCK_SIZE invocations loads one row of N floats
// (scaled by `scale`), applies add/subtract butterflies at distances
// 1, 2, 4, ... < N, and writes the row to data_d. Every invocation holds EL_W
// elements; element e of lane t is row element e * BLOCK_SIZE + t. Workgroups
// stride over the rows until n_rows is covered.
void main() {
    // lane: position of this invocation inside its row group.
    // slot: which of the workgroup's rows this invocation works on.
#ifdef FWHT_SHMEM
    const uint lane = gl_LocalInvocationID.x;
    const uint slot = gl_LocalInvocationID.y;
    // Start of this row's private region inside shmem.
    const uint scratch = slot * N;
#else
    const uint lane = gl_SubgroupInvocationID;
    const uint slot = gl_SubgroupID;
#endif

    const uint rows_per_group = gl_WorkGroupSize.y;
    const uint row_stride = gl_NumWorkGroups.x * rows_per_group;

    // The loop condition depends only on workgroup-wide values, so every
    // invocation of a workgroup runs the same number of iterations. The
    // barrier() calls below rely on that.
    for (uint first_row = gl_WorkGroupID.x * rows_per_group; first_row < n_rows; first_row += row_stride) {
        const uint row = first_row + slot;
        const bool row_valid = row < n_rows;
        // Index of this lane's first element relative to the buffer offsets.
        const uint elem0 = row * N + lane;

#ifndef FWHT_SHMEM
        // No barriers in the shuffle variant, so out-of-range rows can skip.
        if (!row_valid) {
            continue;
        }
#endif

        // Load and scale. Out-of-range rows (shared-memory variant only) are
        // filled with zeros so they can still take part in the barriers.
        float acc[EL_W];
        [[unroll]]
        for (uint e = 0; e < EL_W; ++e) {
            if (row_valid) {
                acc[e] = data_a[src_offset + elem0 + e * BLOCK_SIZE] * scale;
            } else {
                acc[e] = 0.0;
            }
        }

        // Stage 1: butterflies between lanes, partner distance `bit` < BLOCK_SIZE.
#ifdef FWHT_SHMEM
        [[unroll]]
        for (uint bit = 1; bit < BLOCK_SIZE; bit <<= 1) {
            // Publish the current values so the partner lane can read them.
            [[unroll]]
            for (uint e = 0; e < EL_W; ++e) {
                shmem[scratch + e * BLOCK_SIZE + lane] = acc[e];
            }
            barrier();

            const uint partner = lane ^ bit;
            const bool is_low = (lane & bit) == 0;
            [[unroll]]
            for (uint e = 0; e < EL_W; ++e) {
                const float mine = acc[e];
                const float theirs = shmem[scratch + e * BLOCK_SIZE + partner];
                // Low lane of the pair keeps the sum, high lane keeps low - high.
                acc[e] = is_low ? mine + theirs : theirs - mine;
            }
            // All reads must finish before the next round overwrites shmem.
            barrier();
        }
#else
        [[unroll]]
        for (uint bit = 1; bit < BLOCK_SIZE; bit <<= 1) {
            const bool is_low = (lane & bit) == 0;
            [[unroll]]
            for (uint e = 0; e < EL_W; ++e) {
                const float mine = acc[e];
                const float theirs = subgroupShuffleXor(mine, bit);
                // Low lane of the pair keeps the sum, high lane keeps low - high.
                acc[e] = is_low ? mine + theirs : theirs - mine;
            }
        }
#endif

        // Stage 2: butterflies between elements held by the same invocation.
        // A row distance of `span` is a distance of span / BLOCK_SIZE in acc[].
        [[unroll]]
        for (uint span = BLOCK_SIZE; span < N; span <<= 1) {
            const uint gap = span / BLOCK_SIZE;
            [[unroll]]
            for (uint grp = 0; grp < EL_W; grp += 2 * gap) {
                [[unroll]]
                for (uint k = 0; k < gap; ++k) {
                    const uint lo = grp + k;
                    const uint hi = lo + gap;
                    const float a = acc[lo];
                    const float b = acc[hi];
                    acc[lo] = a + b;
                    acc[hi] = a - b;
                }
            }
        }

        // Store the row. In the shuffle variant row_valid is always true here
        // because invalid rows already hit the `continue` above.
        if (row_valid) {
            [[unroll]]
            for (uint e = 0; e < EL_W; ++e) {
                data_d[dst_offset + elem0 + e * BLOCK_SIZE] = acc[e];
            }
        }
#ifdef FWHT_SHMEM
        barrier();
#endif
    }
}