// nnl 0.1.6
//
// A high-performance neural network library for Rust with CPU and GPU support.
// Documentation
#version 450

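// 2D convolution compute shader.
// Computes, for every batch sample n and output channel oc:
//   result[n][oc][y][x] = bias[oc] + sum over (ic, ky, kx) of
//       input[n][ic][y*stride_y + ky - pad_y][x*stride_x + kx - pad_x] * kernel[oc][ic][ky][kx]
// Input positions that fall outside the feature map count as zero (zero padding);
// the bias term is added only when use_bias is non-zero.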

// Each workgroup covers a 16x16 tile of the output's x/y plane. main() takes the
// output channel from gl_GlobalInvocationID.z, so with local_size_z = 1 every
// workgroup works on a single output channel.
layout(local_size_x = 16, local_size_y = 16, local_size_z = 1) in;

// Input feature maps as a flat array, indexed by main() as
// [batch][input_channel][y][x]. Only ever read by this shader.
layout(set = 0, binding = 0) readonly buffer InputBuffer {
    float input_data[];
};

// Convolution weights as a flat array, indexed by main() as
// [output_channel][input_channel][ky][kx]. Only ever read by this shader.
layout(set = 0, binding = 1) readonly buffer KernelBuffer {
    float kernel_data[];
};

// One bias value per output channel. Only read when use_bias is non-zero.
layout(set = 0, binding = 2) readonly buffer BiasBuffer {
    float bias_data[];
};

// Output feature maps as a flat array, written by main() as
// [batch][output_channel][y][x]. Never read back by this shader.
layout(set = 0, binding = 3) writeonly buffer OutputBuffer {
    float result[];
};

// Convolution parameters shared by every invocation. All sizes are element
// counts (they are used directly in the flat-array index math in main()).
// NOTE(review): the host-side parameter struct presumably has to provide these
// fields in exactly this order — confirm against the code that fills this buffer.
layout(set = 0, binding = 4) uniform UniformBuffer {
    uint batch_size;        // Number of samples in the batch; main() loops over all of them
    uint input_channels;    // Number of input channels
    uint input_height;      // Height of each input feature map
    uint input_width;       // Width of each input feature map
    uint output_channels;   // Number of output channels (filters)
    uint output_height;     // Height of each output feature map
    uint output_width;      // Width of each output feature map
    uint kernel_height;     // Height of the convolution kernel
    uint kernel_width;      // Width of the convolution kernel
    uint stride_y;          // Vertical stride, in input rows per output row
    uint stride_x;          // Horizontal stride, in input columns per output column
    uint pad_y;             // Vertical zero padding, subtracted from the input row coordinate
    uint pad_x;             // Horizontal zero padding, subtracted from the input column coordinate
    uint use_bias;          // Non-zero: add bias_data[output channel] to each output; 0: no bias
};

// Entry point. Each invocation owns one output position (x, y) of one output
// channel (gl_GlobalInvocationID.z) and computes it for every sample in the
// batch, writing batch_size values into result[].
void main() {
    uint output_x = gl_GlobalInvocationID.x;
    uint output_y = gl_GlobalInvocationID.y;
    uint output_z = gl_GlobalInvocationID.z;  // output channel index

    // Invocations that fall outside the output volume have nothing to write.
    if (output_x >= output_width || output_y >= output_height || output_z >= output_channels) {
        return;
    }

    // Sizes and offsets that do not change inside the accumulation loops are
    // computed once here instead of once per kernel tap.
    uint input_plane  = input_height * input_width;
    uint kernel_plane = kernel_height * kernel_width;
    uint output_plane = output_height * output_width;

    // Input coordinate of kernel tap (0, 0) for this output element, before the
    // padding is subtracted.
    uint origin_y = output_y * stride_y;
    uint origin_x = output_x * stride_x;

    // Start of this output channel's filter, and this element's offset inside
    // one sample's output block.
    uint filter_base   = output_z * input_channels * kernel_plane;
    uint output_offset = output_z * output_plane + output_y * output_width + output_x;

    // Process each sample in the batch
    for (uint batch = 0u; batch < batch_size; batch++) {
        float sum = 0.0;

        // Accumulate over all input channels
        for (uint in_ch = 0u; in_ch < input_channels; in_ch++) {
            uint input_base  = (batch * input_channels + in_ch) * input_plane;
            uint kernel_base = filter_base + in_ch * kernel_plane;

            for (uint ky = 0u; ky < kernel_height; ky++) {
                // Signed, because subtracting the padding can go negative.
                int input_y = int(origin_y + ky) - int(pad_y);

                // A kernel row that lands entirely in the vertical padding only
                // multiplies zeros, so the whole row is skipped.
                if (input_y < 0 || input_y >= int(input_height)) {
                    continue;
                }

                uint input_row  = input_base + uint(input_y) * input_width;
                uint kernel_row = kernel_base + ky * kernel_width;

                for (uint kx = 0u; kx < kernel_width; kx++) {
                    int input_x = int(origin_x + kx) - int(pad_x);

                    // Columns outside the input are zero padding and add nothing.
                    if (input_x >= 0 && input_x < int(input_width)) {
                        sum += input_data[input_row + uint(input_x)] * kernel_data[kernel_row + kx];
                    }
                }
            }
        }

        // Add bias if enabled
        if (use_bias != 0u) {
            sum += bias_data[output_z];
        }

        // Store result at [batch][output_z][output_y][output_x]
        result[batch * output_channels * output_plane + output_offset] = sum;
    }
}