#version 430
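// Vector operations compute shader for GPU acceleration
// Includes: vector add, a + scalar*b, a + b + scalar*c, and per-work-group partial sums
// for the dot product and the squared L2 norm (the host finishes those reductions)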
// 256 invocations per work group along x. The shared-memory tree reduction in main()
// halves the active range each step, so this size must stay a power of two.
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
// Input buffers (only read by this shader)
layout(std430, binding = 0) readonly buffer InputA {
float a[];
};
layout(std430, binding = 1) readonly buffer InputB {
float b[];
};
layout(std430, binding = 2) readonly buffer InputC {
float c[];
};
// Output buffer (only written by this shader). Elementwise ops write result[i];
// reduction ops write one partial sum per work group at result[gl_WorkGroupID.x].
layout(std430, binding = 3) writeonly buffer Output {
float result[];
};
// Scratch space for the per-work-group reduction: one slot per invocation.
// Sized from the declared work-group size so the two cannot drift apart.
shared float shared_data[gl_WorkGroupSize.x];
// 0=add (a+b), 1=a+scalar*b, 2=dot product (partial sums), 3=squared L2 norm (partial sums),
// 4=three-way (a+b+scalar*c)
uniform int operation_type;
uniform float scalar; // multiplier applied to b (op 1) or c (op 4)
uniform int n;        // number of valid elements in a/b/c
// Entry point. Each invocation handles one element at
//   index = gl_WorkGroupID.x * gl_WorkGroupSize.x + gl_LocalInvocationID.x.
// Presumably dispatched with ceil(n / 256) work groups along x — confirm with the host code.
//
// operation_type (a uniform, so every invocation takes the same path):
//   0: result[i] = a[i] + b[i]
//   1: result[i] = a[i] + scalar * b[i]
//   2: result[g] = sum over work group g of a[i] * b[i]   (partial dot product)
//   3: result[g] = sum over work group g of a[i] * a[i]   (partial squared L2 norm)
//   4: result[i] = a[i] + b[i] + scalar * c[i]
//   any other value: result[i] = 0.0 for i < n
// For ops 2 and 3 the caller must add the per-group partials in result[0 .. numGroups)
// (and take sqrt for op 3 if a true L2 norm is wanted).
void main() {
    uint tid = gl_LocalInvocationID.x;
    uint bid = gl_WorkGroupID.x;
    uint index = bid * gl_WorkGroupSize.x + tid;

    // Clamp before converting: a negative int n would otherwise wrap to a huge uint
    // bound and let every invocation read past the end of the buffers.
    uint count = uint(max(n, 0));

    // Only the dot product and norm are reductions. Op 4 is elementwise, so a
    // `operation_type >= 2` range test would wrongly reduce it.
    bool is_reduction = (operation_type == 2 || operation_type == 3);

    float value = 0.0;
    if (index < count) {
        switch (operation_type) {
        case 0: // Vector addition: result = a + b
            value = a[index] + b[index];
            break;
        case 1: // Multiply-add: result = a + scalar * b
            value = a[index] + scalar * b[index];
            break;
        case 2: // Dot product element: value = a * b
            value = a[index] * b[index];
            break;
        case 3: // Squared L2 norm element: value = a * a
            value = a[index] * a[index];
            break;
        case 4: // Three-way multiply-add: result = a + b + scalar * c
            value = a[index] + b[index] + scalar * c[index];
            break;
        }
    }

    if (is_reduction) {
        // Every invocation stores a slot; out-of-range invocations contribute 0.0.
        // The barriers below sit under control flow that depends only on a uniform,
        // so all invocations of the work group reach them.
        shared_data[tid] = value;
        memoryBarrierShared(); // defensive: make the shared write visible before the barrier
        barrier();

        // Tree reduction in shared memory; relies on gl_WorkGroupSize.x being a power of two.
        for (uint s = gl_WorkGroupSize.x / 2u; s > 0u; s >>= 1u) {
            if (tid < s) {
                shared_data[tid] += shared_data[tid + s];
            }
            memoryBarrierShared();
            barrier();
        }

        // One partial sum per work group.
        if (tid == 0u) {
            result[bid] = shared_data[0];
        }
    } else if (index < count) {
        // Elementwise operations write their own slot directly.
        result[index] = value;
    }
}