// nnl 0.1.6
//
// A high-performance neural network library for Rust with CPU and GPU support.
// Documentation
#version 450

// Softmax activation compute shader using shared memory
// Computes: result[i] = exp(input[i] - max) / sum(exp(input[j] - max))
// Uses traditional shared memory approach for maximum compatibility

// Workgroup shape: 64 invocations along x, 1 along y and z.
// The length of shared_data below is the same value (64).
layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;

// Set 0, binding 0: source values as a runtime-sized float array.
// main() only reads from this buffer.
layout(set = 0, binding = 0) buffer InputBuffer {
    float input_data[];
};

// Set 0, binding 1: destination for the softmax values.
// main() writes result[i] for the same global index i it reads from input_data.
layout(set = 0, binding = 1) buffer OutputBuffer {
    float result[];
};

// Set 0, binding 2: uniform parameters.
// size is the element count main() uses to bounds-check the global index.
layout(set = 0, binding = 2) uniform UniformBuffer {
    uint size;
};

// Workgroup-shared scratch array used by the reductions in main().
// It is indexed by gl_LocalInvocationID.x, so it needs one slot per invocation
// in the workgroup (64, matching local_size_x above).
shared float shared_data[64];

// Numerically stable softmax over the first `size` elements of input_data:
//
//     result[i] = exp(input_data[i] - M) / S
//     M = max_j input_data[j],  S = sum_j exp(input_data[j] - M),  j in [0, size)
//
// Each invocation writes at most one output element (its global index).
// M and S are reduced over the WHOLE input, not only over the 64 elements
// owned by the current workgroup; otherwise every 64-element chunk would be
// normalised independently whenever more than one workgroup is dispatched.
//
// Every workgroup recomputes M and S by itself (workgroups cannot synchronise
// with each other inside one dispatch), so the cost is O(size) reads per
// workgroup. This favours correctness over throughput for large inputs.
//
// NOTE(review): because every workgroup reads the entire input while other
// workgroups may already be writing results, this assumes the input and output
// bindings do not refer to the same memory - confirm against the host code.
void main() {
    uint index = gl_GlobalInvocationID.x;
    uint local_index = gl_LocalInvocationID.x;
    uint group_size = gl_WorkGroupSize.x;

    // Identity element for max(): the most negative finite float, so lanes
    // that own no element never win the reduction.
    const float LOWEST_FLOAT = -3.402823466e+38;

    // ---- Pass 1: global maximum (for numerical stability) ----
    // Each invocation folds a strided slice of the whole input into a
    // private partial maximum.
    float partial_max = LOWEST_FLOAT;
    for (uint i = local_index; i < size; i += group_size) {
        partial_max = max(partial_max, input_data[i]);
    }
    shared_data[local_index] = partial_max;

    barrier();

    // Tree reduction of the partial maxima; barrier() is reached by every
    // invocation on every iteration, keeping control flow uniform.
    for (uint stride = group_size / 2; stride > 0; stride >>= 1) {
        if (local_index < stride) {
            shared_data[local_index] = max(shared_data[local_index], shared_data[local_index + stride]);
        }
        barrier();
    }

    float max_val = shared_data[0];

    // Everyone must have read shared_data[0] before it is reused for the sum.
    barrier();

    // ---- Pass 2: global sum of exp(x - max) ----
    float partial_sum = 0.0;
    for (uint i = local_index; i < size; i += group_size) {
        partial_sum += exp(input_data[i] - max_val);
    }
    shared_data[local_index] = partial_sum;

    barrier();

    // Tree reduction of the partial sums.
    for (uint stride = group_size / 2; stride > 0; stride >>= 1) {
        if (local_index < stride) {
            shared_data[local_index] += shared_data[local_index + stride];
        }
        barrier();
    }

    float sum = shared_data[0];

    // ---- Pass 3: write this invocation's element ----
    // Invocations past the end of the data write nothing.
    if (index < size) {
        if (sum > 0.0) {
            result[index] = exp(input_data[index] - max_val) / sum;
        } else {
            // Fallback kept from the original: for finite inputs the maximum
            // element contributes exp(0) = 1, so the sum should always be
            // positive; if it is not, emit a uniform distribution.
            result[index] = 1.0 / float(size);
        }
    }
}