#version 450
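// Fused element-wise addition + ReLU activation.
// Computes result[i] = ReLU(a[i] + b[i]) = max(0, a[i] + b[i]) for every i < size.
// Fusing both operations into a single dispatch avoids writing and re-reading an
// intermediate (a + b) buffer. Each invocation handles 4 consecutive elements.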
// Workgroup shape: 64 invocations along x. Each invocation of main() processes
// 4 consecutive elements, so one workgroup covers 256 elements.
layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;

// First addend. Only read by this shader, hence readonly (same descriptor type
// as a plain storage buffer, so host-side bindings are unaffected).
layout(set = 0, binding = 0) readonly buffer InputBufferA {
    float a[];
};

// Second addend. Only read by this shader, hence readonly.
layout(set = 0, binding = 1) readonly buffer InputBufferB {
    float b[];
};

// Destination for max(0, a + b). Only written by this shader, hence writeonly.
layout(set = 0, binding = 2) writeonly buffer OutputBuffer {
    float result[];
};

// Number of elements to process. NOTE(review): presumably all three storage
// buffers hold at least `size` floats; the shader does not check their
// runtime lengths - confirm on the host side.
layout(set = 0, binding = 3) uniform UniformBuffer {
    uint size;
};
// Entry point. Invocation k handles elements [4k, 4k + 4) clipped to `size`,
// writing result[i] = max(0, a[i] + b[i]) for each of them.
// The host must dispatch at least ceil(size / 4) invocations along x, i.e.
// ceil(ceil(size / 4) / 64) workgroups. Any extra invocations exit immediately.
void main() {
    uint n = size;  // named `n` rather than `length` to avoid hiding the built-in length()

    // Number of 4-element chunks: ceil(n / 4), computed without the overflow
    // that (n + 3u) / 4u would hit when n is close to UINT_MAX.
    uint chunkCount = (n >> 2) + uint((n & 3u) != 0u);
    uint chunk = gl_GlobalInvocationID.x;

    // Invocations past the last chunk have nothing to do. Exiting here also
    // keeps chunk * 4u from wrapping for very large dispatches (x >= 2^30),
    // which would otherwise make two invocations write the same elements.
    if (chunk >= chunkCount) {
        return;
    }

    uint index = chunk * 4u;     // guaranteed < n, so no overflow
    uint remaining = n - index;  // >= 1 elements left from this chunk onward

    if (remaining >= 4u) {
        // Full chunk. Four scalar loads are packed into a vec4 so the add and
        // ReLU run as vector ops; whether the loads themselves are merged into
        // one wide access is up to the compiler/driver.
        vec4 a_vec = vec4(a[index], a[index + 1u], a[index + 2u], a[index + 3u]);
        vec4 b_vec = vec4(b[index], b[index + 1u], b[index + 2u], b[index + 3u]);

        // Fused addition + ReLU: max(0, a + b)
        vec4 result_vec = max(vec4(0.0), a_vec + b_vec);

        result[index]      = result_vec.x;
        result[index + 1u] = result_vec.y;
        result[index + 2u] = result_vec.z;
        result[index + 3u] = result_vec.w;
    } else {
        // Tail chunk: 1 to 3 elements left. The loop bound already keeps
        // every idx below n, so no per-element check is needed.
        for (uint i = 0u; i < remaining; ++i) {
            uint idx = index + i;
            result[idx] = max(0.0, a[idx] + b[idx]);
        }
    }
}