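// Shader source from nnl 0.1.6, a high-performance neural network library for
// Rust with CPU and GPU support (see the crate documentation).
//
// NOTE: GLSL requires `#version` to precede everything except comments and
// whitespace, so this header must remain a comment.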
#version 450

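// Fused matrix multiplication + bias addition + ReLU activation.
// Computes: result = ReLU(A * B + bias), where A is M x K, B is K x N, and the
// bias vector (length N) is broadcast across every row of the M x N result.
// Fusing the three operations into a single dispatch avoids writing and
// re-reading intermediate matrices in global memory.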

// 16x16 invocations per workgroup: x indexes output columns, y indexes output
// rows (see main()). The 16 here must match the shared tile dimensions below
// and the tile width hard-coded in main().
layout(local_size_x = 16, local_size_y = 16, local_size_z = 1) in;

// Matrix A, M x K, read as row-major: element (r, c) is a[r * K + c].
layout(set = 0, binding = 0) buffer InputBufferA {
    float a[];
};

// Matrix B, K x N, read as row-major: element (r, c) is b[r * N + c].
layout(set = 0, binding = 1) buffer InputBufferB {
    float b[];
};

// Bias vector, indexed by output column and broadcast across all rows.
layout(set = 0, binding = 2) buffer BiasBuffer {
    float bias[];
};

// Output C = ReLU(A * B + bias), M x N, written row-major: result[r * N + c].
layout(set = 0, binding = 3) buffer OutputBuffer {
    float result[];
};

// Problem dimensions supplied by the host.
layout(set = 0, binding = 4) uniform UniformBuffer {
    uint M;    // Rows of A, rows of C (the result)
    uint N;    // Cols of B, cols of C (the result)
    uint K;    // Cols of A, rows of B (the shared/reduction dimension)
    uint bias_size; // Number of valid bias entries; expected to equal N. main() treats columns >= bias_size as bias 0.0
};

// Workgroup-shared tiles for cooperative loading: each invocation loads one
// element of A and one of B per K-tile, then all 16x16 invocations reuse them.
shared float tileA[16][16];
shared float tileB[16][16];
shared float tileBias[16]; // Bias values for this workgroup's 16 output columns (written by local row 0)

// Entry point. Each invocation produces one element of the output:
//   result[row * N + col] = max(0, sum_k a[row * K + k] * b[k * N + col] + bias[col])
// where col = gl_GlobalInvocationID.x and row = gl_GlobalInvocationID.y.
//
// Invocations that fall outside the M x N output still execute the whole body:
// they load zeros into the shared tiles and must reach every barrier(), so the
// bounds checks guard only global loads and the final store, never a barrier.
void main() {
    // Global thread coordinates: x -> output column, y -> output row
    uint globalRow = gl_GlobalInvocationID.y;
    uint globalCol = gl_GlobalInvocationID.x;

    // Local thread coordinates within the 16x16 workgroup
    uint localRow = gl_LocalInvocationID.y;
    uint localCol = gl_LocalInvocationID.x;

    // Local row 0 caches this workgroup's 16 bias values in shared memory.
    // Columns at or beyond bias_size get a bias of 0.0.
    if (localRow == 0 && globalCol < N) {
        tileBias[localCol] = (globalCol < bias_size) ? bias[globalCol] : 0.0;
    }

    // Publish tileBias to the whole workgroup before it is read in the epilogue.
    // The barriers inside the tile loop also order this write, but when K == 0
    // the loop runs zero times, and without this barrier invocations with
    // localRow != 0 would read tileBias before row 0 wrote it (a data race).
    barrier();

    // Accumulated result for matrix multiplication
    float matmul_sum = 0.0;

    // Number of 16-wide tiles needed to cover the K dimension
    uint numTiles = (K + 15) / 16;

    // Process tiles across the K dimension for matrix multiplication
    for (uint tile = 0; tile < numTiles; tile++) {
        // Cooperative loading of tile A (zero-padded past M rows / K cols so the
        // fixed 16-step inner loop below needs no bounds checks)
        uint aRow = globalRow;
        uint aCol = tile * 16 + localCol;

        if (aRow < M && aCol < K) {
            tileA[localRow][localCol] = a[aRow * K + aCol];
        } else {
            tileA[localRow][localCol] = 0.0;
        }

        // Cooperative loading of tile B (zero-padded past K rows / N cols)
        uint bRow = tile * 16 + localRow;
        uint bCol = globalCol;

        if (bRow < K && bCol < N) {
            tileB[localRow][localCol] = b[bRow * N + bCol];
        } else {
            tileB[localRow][localCol] = 0.0;
        }

        // Synchronize to ensure all threads have loaded their data
        barrier();

        // Compute partial dot product using shared memory
        for (uint k = 0; k < 16; k++) {
            matmul_sum += tileA[localRow][k] * tileB[k][localCol];
        }

        // Synchronize before the next iteration overwrites the tiles
        barrier();
    }

    // Fused bias addition and ReLU activation
    if (globalRow < M && globalCol < N) {
        // Add bias (broadcast across rows)
        float biased_result = matmul_sum + tileBias[localCol];

        // Apply ReLU activation: max(0, x)
        float activated_result = max(0.0, biased_result);

        // Store final result
        result[globalRow * N + globalCol] = activated_result;
    }
}