#version 450
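// 2D convolution compute shader (direct convolution with zero padding).
// For every batch sample b, output channel oc and output pixel (y, x) it computes:
//   result[b][oc][y][x] = (use_bias ? bias[oc] : 0)
//       + sum over ic, ky, kx of
//           input[b][ic][y*stride_y + ky - pad_y][x*stride_x + kx - pad_x] * kernel[oc][ic][ky][kx]
// Input taps that fall outside the input image contribute zero.
// Supports stride, padding, and multiple input/output channels.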
// Workgroup size: 16 x 16 x 1 = 256 invocations per workgroup.
// main() maps the global invocation ID's x/y components to the output
// column/row and its z component to the output channel.
layout(local_size_x = 16, local_size_y = 16, local_size_z = 1) in;
// Input feature maps as a flat float array. main() indexes it as
//   ((batch * input_channels + in_ch) * input_height + y) * input_width + x
// i.e. batch-major, then channel, then row, then column. Only read by main().
layout(set = 0, binding = 0) buffer InputBuffer {
float input_data[];
};
// Convolution weights as a flat float array. main() indexes it as
//   ((out_ch * input_channels + in_ch) * kernel_height + ky) * kernel_width + kx
// Only read by main().
layout(set = 0, binding = 1) buffer KernelBuffer {
float kernel_data[];
};
// Per-output-channel bias, indexed by output channel.
// main() reads it only when use_bias is non-zero.
layout(set = 0, binding = 2) buffer BiasBuffer {
float bias_data[];
};
// Output feature maps as a flat float array. main() writes element
//   ((batch * output_channels + out_ch) * output_height + y) * output_width + x
// Each in-range invocation writes one element per batch sample; main() never reads it.
layout(set = 0, binding = 3) buffer OutputBuffer {
float result[];
};
// Convolution parameters supplied by the host. All members are 32-bit unsigned
// integers; main() combines them with 32-bit arithmetic when building buffer
// indices, so element counts are assumed to fit in 32 bits.
// NOTE(review): output_height/output_width are taken as given - the shader does
// not derive them from input size, padding and stride. Presumably the host
// computes them consistently; confirm against the caller.
layout(set = 0, binding = 4) uniform UniformBuffer {
uint batch_size; // Number of samples in batch (main() loops over all of them)
uint input_channels; // Number of input channels
uint input_height; // Height of input feature maps
uint input_width; // Width of input feature maps
uint output_channels; // Number of output channels (filters); bounds the z invocation ID
uint output_height; // Height of output feature maps; bounds the y invocation ID
uint output_width; // Width of output feature maps; bounds the x invocation ID
uint kernel_height; // Height of convolution kernel
uint kernel_width; // Width of convolution kernel
uint stride_y; // Vertical stride (input rows advanced per output row)
uint stride_x; // Horizontal stride (input columns advanced per output column)
uint pad_y; // Vertical padding (rows of implicit zeros above the input)
uint pad_x; // Horizontal padding (columns of implicit zeros left of the input)
uint use_bias; // Whether to add bias (non-zero = yes, 0 = no)
};
// Entry point. Each invocation owns one output location (column, row, output
// channel) taken from gl_GlobalInvocationID.xyz and computes that location for
// every sample in the batch:
//
//   result[b][oc][y][x] = sum over ic, ky, kx of
//       input[b][ic][y*stride_y + ky - pad_y][x*stride_x + kx - pad_x] * kernel[oc][ic][ky][kx]
//       (+ bias[oc] when use_bias != 0)
//
// Taps that land outside the input image are skipped, which is equivalent to
// zero padding. Invocations outside the output extent return immediately, so
// the dispatch may be rounded up to whole workgroups.
//
// Preconditions: every flattened index is built with 32-bit unsigned
// arithmetic, so the element counts of the input, kernel and output tensors
// must fit in 32 bits.
//
// Loop-invariant index terms and the bias value are computed once up front.
// The accumulation order (in_ch -> ky -> kx, bias last) is kept fixed so the
// floating-point result does not depend on these hoists.
void main() {
    uint output_x = gl_GlobalInvocationID.x;
    uint output_y = gl_GlobalInvocationID.y;
    uint output_z = gl_GlobalInvocationID.z; // output channel

    // Invocations past the output extent have nothing to compute.
    if (output_x >= output_width || output_y >= output_height || output_z >= output_channels) {
        return;
    }

    // Sizes of one 2D plane / one batch sample, shared by every index below.
    uint input_plane = input_height * input_width;
    uint input_sample = input_channels * input_plane;
    uint kernel_plane = kernel_height * kernel_width;
    uint output_plane = output_height * output_width;
    uint output_sample = output_channels * output_plane;

    // Start of this output channel's filter, and this invocation's offset
    // inside one output sample.
    uint kernel_filter_base = output_z * input_channels * kernel_plane;
    uint output_offset = output_z * output_plane + output_y * output_width + output_x;

    // Input coordinate hit by kernel tap (0, 0); negative inside the padding.
    int origin_y = int(output_y * stride_y) - int(pad_y);
    int origin_x = int(output_x * stride_x) - int(pad_x);
    int in_h = int(input_height);
    int in_w = int(input_width);

    // The bias depends only on the output channel, so fetch it once. The
    // buffer is still only touched when the bias is enabled.
    bool add_bias = (use_bias != 0u);
    float bias_value = 0.0;
    if (add_bias) {
        bias_value = bias_data[output_z];
    }

    // Process each sample in the batch.
    for (uint batch = 0u; batch < batch_size; batch++) {
        float sum = 0.0;
        uint input_batch_base = batch * input_sample;

        // Convolve over all input channels.
        for (uint in_ch = 0u; in_ch < input_channels; in_ch++) {
            uint input_ch_base = input_batch_base + in_ch * input_plane;
            uint kernel_ch_base = kernel_filter_base + in_ch * kernel_plane;

            for (uint ky = 0u; ky < kernel_height; ky++) {
                int input_y = origin_y + int(ky);
                // A kernel row that lies entirely in the vertical padding
                // contributes nothing; skip it without visiting its columns.
                if (input_y < 0 || input_y >= in_h) {
                    continue;
                }
                uint input_row_base = input_ch_base + uint(input_y) * input_width;
                uint kernel_row_base = kernel_ch_base + ky * kernel_width;

                for (uint kx = 0u; kx < kernel_width; kx++) {
                    int input_x = origin_x + int(kx);
                    // Columns in the horizontal padding are treated as zero.
                    if (input_x >= 0 && input_x < in_w) {
                        sum += input_data[input_row_base + uint(input_x)] *
                               kernel_data[kernel_row_base + kx];
                    }
                }
            }
        }

        // Add bias if enabled (kept as the last addition, as before).
        if (add_bias) {
            sum += bias_value;
        }

        // Store result.
        result[batch * output_sample + output_offset] = sum;
    }
}