#version 450
#include "types.glsl"
// Input columns, read-only. Addressed in main() as col + t_in * K_OC, i.e. the
// K_OC axis is contiguous and T_in is the outer axis.
layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; // columns: [K_OC, T_in]
// Output signal, write-only. Addressed in main() as t_out + oc * T_out, i.e. the
// T_out axis is contiguous and OC is the outer axis.
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; // output: [T_out, OC]
// 256 invocations per workgroup along x; main() maps x to t_out and y to oc.
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
// Push constants consumed by main().
layout (push_constant) uniform parameter {
uint32_t T_out;  // number of output positions; bound on t_out and row pitch of data_d
uint32_t OC;     // number of output channels; bound on oc
uint32_t K_OC;   // pitch between consecutive t_in columns in data_a (presumably K * OC -- confirm with host code)
uint32_t T_in;   // number of input columns; upper bound (exclusive) on t_in
uint32_t K;      // taps per output channel; column row index is oc * K + k
int32_t stride;  // step between consecutive input columns in the uncropped signal; used as a divisor, so must be non-zero
int32_t p0;      // offset added to t_out to obtain the position in the uncropped signal
} p;
// Fetch element `idx` of the column buffer and widen it to a 32-bit float.
// In the non-BF16 build the element is converted with a float() constructor;
// in the BF16 build its bits are widened to uint32_t and handed to bf16_to_fp32().
float load_col(uint32_t idx) {
#ifndef DATA_A_BF16
    const float value = float(data_a[idx]);
    return value;
#else
    const uint32_t bits = uint32_t(data_a[idx]);
    return bf16_to_fp32(bits);
#endif
}
// Convert the float `v` to D_TYPE and write it to element `idx` of the output buffer.
// NOTE(review): the BF16 path is selected by DATA_A_BF16 (the *input* type macro);
// presumably A_TYPE and D_TYPE always match for this shader -- confirm with the build setup.
void store_dst(uint32_t idx, float v) {
#ifndef DATA_A_BF16
    data_d[idx] = D_TYPE(v);
#else
    data_d[idx] = D_TYPE(fp32_to_bf16(v));
#endif
}
// Kernel entry point: one invocation per output element (t_out, oc).
//
// Each invocation gathers every column tap that lands on its output position
// and writes the sum to data_d. A tap (t_in, k) contributes to position
// t_abs = t_out + p0 when t_in * stride + k == t_abs with 0 <= k < K, so only
// t_in in [ceil((t_abs - K + 1) / stride), floor(t_abs / stride)], clamped to
// [0, T_in - 1], needs to be visited. No modulo is required.
//
// Invocations outside [0, T_out) x [0, OC) exit without writing.
// Assumes p.stride > 0 (it is used as a divisor) -- to be guaranteed by the host.
void main() {
    const uint32_t t_out = gl_GlobalInvocationID.x;
    const uint32_t oc    = gl_GlobalInvocationID.y;
    if (t_out >= p.T_out || oc >= p.OC) {
        return;
    }

    // Position in the uncropped signal.
    const int32_t t_abs = int32_t(t_out) + p.p0;

    float val = 0.0;

    // No tap can land on a negative position (t_in >= 0 and k >= 0). Skipping the
    // gather here matters: for t_abs in (-stride, 0) a truncating t_abs / stride
    // yields 0, which would admit t_in = 0 with a negative k and turn into a
    // wrapped, out-of-range column index after the uint32_t conversion.
    if (t_abs >= 0) {
        // Smallest t_in with k <= K - 1; the numerator may be negative, hence the clamp.
        int32_t t_in_min = (t_abs - int32_t(p.K) + p.stride) / p.stride;
        if (t_in_min < 0) {
            t_in_min = 0;
        }

        // Largest t_in with k >= 0, limited to the last available column.
        int32_t t_in_max = t_abs / p.stride;
        if (t_in_max >= int32_t(p.T_in)) {
            t_in_max = int32_t(p.T_in) - 1;
        }

        // Row offset of this output channel inside a column; invariant across the loop.
        const uint32_t col_base = oc * p.K;

        for (int32_t t_in = t_in_min; t_in <= t_in_max; t_in++) {
            // Kernel tap of column t_in that falls on t_abs; in [0, K) by the bounds above.
            const int32_t k = t_abs - t_in * p.stride;
            // col layout: [K_OC, T_in], row index = oc * K + k, column pitch = K_OC
            const uint32_t col_idx = (col_base + uint32_t(k)) + uint32_t(t_in) * p.K_OC;
            val += load_col(col_idx);
        }
    }

    // dst layout: [T_out, OC], element (t_out, oc) = t_out + oc * T_out
    store_dst(t_out + oc * p.T_out, val);
}