// nnl 0.1.6
//
// A high-performance neural network library for Rust with CPU and GPU support.
// Documentation
#version 450

// Optimized tiled matrix multiplication compute shader.
// Computes C = A * B, where A is MxK, B is KxN and C is MxN. main() indexes
// all three as row-major, tightly packed float arrays (element (r, c) of an
// R x C matrix lives at index r * C + c).
// Each 16x16 workgroup stages matching 16x16 tiles of A and B in shared
// memory, so every element of A and B the workgroup needs is loaded from the
// storage buffers by exactly one invocation per tile pass.

// The workgroup size must equal the shared tile dimension (16): main() uses
// gl_LocalInvocationID directly as the (row, col) slot in each tile.
layout(local_size_x = 16, local_size_y = 16, local_size_z = 1) in;

// Left operand A (MxK).
// std430 is spelled out so that float[] has a 4-byte array stride, matching
// the tight packing main() assumes, instead of relying on the compiler's
// default block layout (std140 would use a 16-byte stride).
layout(std430, set = 0, binding = 0) buffer InputBufferA {
    float a[];
};

// Right operand B (KxN). Same tight std430 packing as A.
layout(std430, set = 0, binding = 1) buffer InputBufferB {
    float b[];
};

// Output C (MxN). Same tight std430 packing as A.
layout(std430, set = 0, binding = 2) buffer OutputBuffer {
    float result[];
};

// Matrix dimensions. std140 is the standard layout for uniform blocks; three
// consecutive uints sit at byte offsets 0, 4 and 8.
layout(std140, set = 0, binding = 3) uniform UniformBuffer {
    uint M;    // Rows of A, rows of C
    uint N;    // Cols of B, cols of C
    uint K;    // Cols of A, rows of B
};

// Shared-memory tiles filled cooperatively by the workgroup on each pass over
// K. Their dimensions must match the workgroup size declared above.
shared float tileA[16][16];
shared float tileB[16][16];

void main() {
    // Element of C owned by this invocation: the global x index selects the
    // column and the global y index selects the row.
    uint row = gl_GlobalInvocationID.y;
    uint col = gl_GlobalInvocationID.x;

    // This invocation's position inside its workgroup. It is also the slot the
    // invocation fills in both shared tiles on every pass.
    uint ty = gl_LocalInvocationID.y;
    uint tx = gl_LocalInvocationID.x;

    // Running dot product of row `row` of A with column `col` of B.
    float acc = 0.0;

    // K is consumed in 16-wide slices; the last slice may be partial.
    uint sliceCount = (K + 15u) / 16u;

    for (uint slice = 0u; slice < sliceCount; ++slice) {
        uint kOffset = slice * 16u;

        // Stage A[row][kOffset + tx]. Positions outside A are staged as 0.0
        // so they contribute nothing to the dot product.
        uint aK = kOffset + tx;
        float stagedA = 0.0;
        if (row < M && aK < K) {
            stagedA = a[row * K + aK];
        }
        tileA[ty][tx] = stagedA;

        // Stage B[kOffset + ty][col], zero-padded in the same way.
        uint bK = kOffset + ty;
        float stagedB = 0.0;
        if (bK < K && col < N) {
            stagedB = b[bK * N + col];
        }
        tileB[ty][tx] = stagedB;

        // No invocation may read the tiles until every invocation has filled
        // its slot.
        barrier();

        for (uint i = 0u; i < 16u; ++i) {
            acc += tileA[ty][i] * tileB[i][tx];
        }

        // No invocation may overwrite the tiles on the next pass until every
        // invocation has finished reading them.
        barrier();
    }

    // Invocations past the bottom or right edge of C took part in staging and
    // in every barrier above, but have no output element to write.
    if (row >= M || col >= N) {
        return;
    }
    result[row * N + col] = acc;
}