briny_ai 0.2.2

Minimal, secure autodiff + tensor engine with serialization
Documentation
// Uniform parameters for the matrix product C = op(A) * op(B) computed by
// `main`, where C is m x n and k is the inner (shared) dimension.
// Four u32 fields => a 16-byte uniform block (offsets 0, 4, 8, 12).
// NOTE(review): field order is part of the buffer layout the host writes —
// confirm against the host-side encoder before reordering or adding fields.
struct MatDims {
    m: u32, // rows of C (and of op(A))
    k: u32, // inner dimension: columns of op(A) == rows of op(B)
    n: u32, // columns of C (and of op(B))
    flags: u32, // bit 0 = transpose A, bit 1 = transpose B
};

// Resources for `main`, all in bind group 0. Layouts below are those implied
// by the indexing in `main`; every matrix is stored row-major as a flat array.
// Matrix dimensions and transpose flags (see MatDims).
@group(0) @binding(0) var<uniform> dims: MatDims;
// Left operand: m x k, or k x m when dims.flags bit 0 is set.
@group(0) @binding(1) var<storage, read> A: array<f32>;
// Right operand: k x n, or n x k when dims.flags bit 1 is set.
@group(0) @binding(2) var<storage, read> B: array<f32>;
// Output: m x n, element (row, col) at C[row * n + col].
@group(0) @binding(3) var<storage, read_write> C: array<f32>;

// Compute-shader entry point: each invocation produces one element of
// C = op(A) * op(B), where op(X) is X or X transposed according to
// dims.flags (bit 0 selects A^T, bit 1 selects B^T).
//
// Invocation (gid.x, gid.y) owns C[row = gid.y][col = gid.x]. Invocations
// that fall outside the m x n result return without writing, so the dispatch
// is allowed to over-cover C.
//
// Buffer layouts implied by the indexing below (row-major throughout):
//   A : m x k, or k x m when transposed
//   B : k x n, or n x k when transposed
//   C : m x n
//
// The inner product covers the largest multiple of four not exceeding k with
// vec4 dot products, then adds the remaining k % 4 terms one at a time.
@compute @workgroup_size(16, 16)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let row = gid.y;
    let col = gid.x;

    let m = dims.m;
    let k = dims.k;
    let n = dims.n;

    // Nothing to do for invocations beyond the edge of C.
    if row >= m || col >= n {
        return;
    }

    let transpose_a = (dims.flags & 1u) != 0u;
    let transpose_b = (dims.flags & 2u) != 0u;

    // Every element read at inner index t lives at base + t * step:
    //   op(A)[row][t]: plain A[row * k + t]  -> base = row * k, step = 1
    //                  trans A[t * m + row]  -> base = row,     step = m
    //   op(B)[t][col]: plain B[t * n + col]  -> base = col,     step = n
    //                  trans B[col * k + t]  -> base = col * k, step = 1
    let a_base = select(row * k, row, transpose_a);
    let a_step = select(1u, m, transpose_a);
    let b_base = select(col, col * k, transpose_b);
    let b_step = select(n, 1u, transpose_b);

    // Inner indices [0, vec_end) are consumed four at a time; [vec_end, k)
    // is the scalar tail.
    let vec_end = k - (k % 4u);

    var sum: f32 = 0.0;

    for (var t = 0u; t < vec_end; t = t + 4u) {
        let a_at = a_base + t * a_step;
        let b_at = b_base + t * b_step;

        let lhs = vec4<f32>(
            A[a_at],
            A[a_at + a_step],
            A[a_at + 2u * a_step],
            A[a_at + 3u * a_step]
        );
        let rhs = vec4<f32>(
            B[b_at],
            B[b_at + b_step],
            B[b_at + 2u * b_step],
            B[b_at + 3u * b_step]
        );

        sum = sum + dot(lhs, rhs);
    }

    // Scalar tail: at most three terms, none when k is a multiple of four.
    var tail = vec_end;
    while tail < k {
        sum = sum + A[a_base + tail * a_step] * B[b_base + tail * b_step];
        tail = tail + 1u;
    }

    C[row * n + col] = sum;
}