// oxillama-gpu 0.1.3 — optional wgpu GPU compute backend for OxiLLaMa.
//
// gemm_f32.wgsl — tiled f32 GEMM with workgroup shared memory
//
// Computes: C[M × N] = A[M × K] × B[K × N]  (row-major)
//
// Tile parameters:
//   TILE_M = 32  (rows handled per workgroup in Y dimension)
//   TILE_N = 32  (cols handled per workgroup in X dimension)
//   TILE_K = 16  (K-dimension tile depth)
//
// Workgroup: 16 × 16 = 256 threads; each thread computes a 2 × 2 block of C.
//   local_id.x ∈ [0, 16)   → selects columns 2·x and 2·x+1 within the tile
//   local_id.y ∈ [0, 16)   → selects rows 2·y and 2·y+1 within the tile
//
// Shared memory:
//   A_tile[TILE_M × TILE_K] = 32 × 16 = 512 f32  (2 KiB)
//   B_tile[TILE_K × TILE_N] = 16 × 32 = 512 f32  (2 KiB)
//   Total = 4 KiB — well within the 16 KiB WebGPU minimum.
//
// Cooperative loading:
//   256 threads, 512 elements per tile → each thread loads 2 elements.
//   Thread tid = local_id.y * 16 + local_id.x  (∈ [0, 256))
//   Each thread writes flat indices tid and tid + 256 of A_tile, and the same
//   two indices of B_tile; source elements outside the matrix are stored as 0.0.

// GEMM problem size, read from the uniform binding declared below.
//   M: rows of A and of C
//   N: columns of B and of C
//   K: shared inner dimension (columns of A, rows of B)
// Member order and types fix the uniform layout (three consecutive u32s).
struct Params { M: u32, N: u32, K: u32 }

// Bind group 0. All matrices are flat, row-major arrays of f32.
// Left operand: M rows × K columns.
@group(0) @binding(0) var<storage, read> A: array<f32>;
// Right operand: K rows × N columns.
@group(0) @binding(1) var<storage, read> B: array<f32>;
// Result: M rows × N columns.
@group(0) @binding(2) var<storage, read_write> C: array<f32>;
// Matrix dimensions.
@group(0) @binding(3) var<uniform> params: Params;

// Workgroup-shared staging tiles (512 f32 = 2 KiB each, 4 KiB together).
// A_tile: 32 rows × 16 columns, row-major (TILE_M × TILE_K).
var<workgroup> A_tile: array<f32, 512>;
// B_tile: 16 rows × 32 columns, row-major (TILE_K × TILE_N).
var<workgroup> B_tile: array<f32, 512>;

// Writes `value` to C at (r, c), but only when (r, c) lies inside the
// M × N output, so invocations covering a ragged edge tile store nothing
// out of range.
fn store_c_checked(r: u32, c: u32, value: f32) {
    if r < params.M && c < params.N {
        C[r * params.N + c] = value;
    }
}

// Entry point. Each 16 × 16 workgroup produces one 32 × 32 block of C, and
// every invocation in it produces a 2 × 2 sub-block of that block.
// K is walked in steps of 16 through the shared tiles A_tile and B_tile.
@compute @workgroup_size(16, 16)
fn main(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id)  local_id:  vec3<u32>,
    @builtin(workgroup_id)         wg_id:     vec3<u32>,
) {
    let m = params.M;
    let n = params.N;
    let k_total = params.K;

    // Top-left corner of this workgroup's 32 × 32 output block.
    let block_row = wg_id.y * 32u;
    let block_col = wg_id.x * 32u;

    // Offset of this invocation's 2 × 2 sub-block inside the workgroup block.
    let sub_row = local_id.y * 2u;
    let sub_col = local_id.x * 2u;

    // Row-major linear index of the invocation within the workgroup, 0..255.
    let lane = local_id.x + local_id.y * 16u;

    // Running sums for the 2 × 2 sub-block, named c<row><col>.
    var c00: f32 = 0.0;
    var c01: f32 = 0.0;
    var c10: f32 = 0.0;
    var c11: f32 = 0.0;

    // Number of 16-wide K steps, rounded up so a partial last step is kept.
    let k_steps = (k_total + 15u) / 16u;

    for (var k_step: u32 = 0u; k_step < k_steps; k_step += 1u) {
        let k_base = k_step * 16u;

        // Stage this step's A (32 × 16) and B (16 × 32) tiles. Each tile holds
        // 512 values, so every one of the 256 lanes fills flat slots `lane`
        // and `lane + 256` of both tiles. Source coordinates outside the
        // matrices are staged as 0.0 so they contribute nothing below.
        for (var half_idx: u32 = 0u; half_idx < 2u; half_idx += 1u) {
            let slot = lane + half_idx * 256u;

            // A_tile slot -> (row, k), 16 values per tile row.
            let a_src_row = block_row + slot / 16u;
            let a_src_col = k_base + slot % 16u;
            var a_val: f32 = 0.0;
            if a_src_row < m && a_src_col < k_total {
                a_val = A[a_src_row * k_total + a_src_col];
            }
            A_tile[slot] = a_val;

            // B_tile slot -> (k, col), 32 values per tile row.
            let b_src_row = k_base + slot / 32u;
            let b_src_col = block_col + slot % 32u;
            var b_val: f32 = 0.0;
            if b_src_row < k_total && b_src_col < n {
                b_val = B[b_src_row * n + b_src_col];
            }
            B_tile[slot] = b_val;
        }

        // Every lane's stores must be visible before any lane reads the tiles.
        workgroupBarrier();

        // Rank-16 update of the 2 × 2 sub-block from the staged tiles.
        for (var kk: u32 = 0u; kk < 16u; kk += 1u) {
            let a_top = A_tile[sub_row * 16u + kk];
            let a_bot = A_tile[(sub_row + 1u) * 16u + kk];
            let b_lft = B_tile[kk * 32u + sub_col];
            let b_rgt = B_tile[kk * 32u + sub_col + 1u];
            c00 += a_top * b_lft;
            c01 += a_top * b_rgt;
            c10 += a_bot * b_lft;
            c11 += a_bot * b_rgt;
        }

        // Every lane must finish reading before the next step overwrites them.
        workgroupBarrier();
    }

    // Scatter the 2 × 2 result, skipping coordinates outside C.
    let out_row = block_row + sub_row;
    let out_col = block_col + sub_col;
    store_c_checked(out_row,      out_col,      c00);
    store_c_checked(out_row,      out_col + 1u, c01);
    store_c_checked(out_row + 1u, out_col,      c10);
    store_c_checked(out_row + 1u, out_col + 1u, c11);
}