rumus 0.3.1

A native-Rust deep learning framework with explicit memory safety and hardware acceleration
Documentation
// Mixed-precision matmul: C = A @ B where A is scalar, B is Q8.
//
// A: [M, K] in scalar (F32 or F16) — activations
// B: [K, N] packed as Q8 in column-major block order
// C: [M, N] in scalar — output
//
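// B layout: each column n is split into ceil(K / BS) blocks of BS = block_size
// elements, and a column's blocks are stored contiguously. Each block is a 4-byte
// header (f16 scale in its low 16 bits) followed by BS packed i8 values:
//   col n: [header_0][i8 × BS][header_1][i8 × BS]...
//   col n+1: [header_0][i8 × BS]...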
//
// Each 16×16 workgroup covers a 16×16 tile of C; every invocation computes one
// element of C. The K dimension is traversed in steps of block_size.

// Uniform parameters for matmul_q8_kernel (bound at @binding(3)). Field order and
// size must match the buffer the host writes.
struct MatmulQ8Params {
    m: u32,                 // rows of A and of C
    k: u32,                 // reduction dimension: columns of A, rows of B
    n: u32,                 // columns of B and of C
    block_size: u32,        // BS: quantized i8 elements per Q8 block
    block_stride_u32: u32,  // (4 + block_size) / 4 — u32 words per block (1 header word + data)
    blocks_per_col: u32,    // ceil(K / block_size)
    _pad0: u32,             // padding, not read by the kernel
    _pad1: u32,             // padding, not read by the kernel
}
// 8 × u32 = 32 bytes.
// NOTE(review): the floor division in block_stride_u32 only reserves room for the
// 4-byte header plus block_size i8 bytes when block_size is a multiple of 4 —
// confirm the host enforces that.

// Resource bindings for matmul_q8_kernel. A and C are indexed row-major by the
// kernel (A[row * K + kk], C[row * N + col]). `scalar` is a type alias not declared
// in this snippet — presumably provided by a prelude as f32 or f16; verify.
@group(0) @binding(0) var<storage, read>       mq_a:      array<scalar>; // [M, K]
@group(0) @binding(1) var<storage, read>       mq_b:      array<u32>;    // Q8 packed [K, N]
@group(0) @binding(2) var<storage, read_write> mq_c:      array<scalar>; // [M, N]
@group(0) @binding(3) var<uniform>             mq_params: MatmulQ8Params;

/// Sign-extend byte `byte_pos` of `word` (0 = least-significant byte; callers pass 0..3) to i32.
fn extract_i8(word: u32, byte_pos: u32) -> i32 {
    // Isolate the requested byte as an unsigned value in 0..255.
    let byte_bits = (word >> (8u * byte_pos)) & 0xFFu;
    // Two's complement: if bit 7 is set, the byte encodes (byte_bits - 256).
    // (byte_bits & 0x80u) is either 0 or 128, so doubling it yields the 0 / 256 correction.
    return i32(byte_bits) - 2 * i32(byte_bits & 0x80u);
}

// Computes one element C[row, col] = sum_k A[row, k] * dequant(B[k, col]).
// gid.x selects the column of C (N), gid.y the row (M); out-of-range invocations exit.
//
// Accumulation is done in f32 regardless of `scalar`. When `scalar` is f16, the
// per-element products a * q (|q| <= 128) can overflow half precision, and a
// K-long running sum loses most of its precision. The block scale is applied once
// per block to the unscaled block dot product, and the result is converted to
// `scalar` only at the final store.
@compute @workgroup_size(16, 16)
fn matmul_q8_kernel(@builtin(global_invocation_id) gid: vec3<u32>) {
    let row = gid.y;  // M dimension
    let col = gid.x;  // N dimension
    if (row >= mq_params.m || col >= mq_params.n) { return; }

    let k = mq_params.k;
    let block_size = mq_params.block_size;
    let stride = mq_params.block_stride_u32;

    // First element of row `row` of A (row-major [M, K]); loop-invariant.
    let a_row_base = row * k;

    // Column `col` of B: its Q8 blocks start at
    // col * blocks_per_col * block_stride_u32 (in u32 words).
    let col_base = col * mq_params.blocks_per_col * stride;

    var acc: f32 = 0.0;

    // Traverse K in block-sized steps. The `k_offset < k` guard keeps `k - k_offset`
    // from underflowing if blocks_per_col exceeds ceil(K / block_size).
    var k_offset: u32 = 0u;
    for (var blk: u32 = 0u; blk < mq_params.blocks_per_col && k_offset < k; blk++) {
        let block_word_base = col_base + blk * stride;

        // Header word: the f16 block scale lives in its low 16 bits.
        let blk_scale = unpack2x16float(mq_b[block_word_base]).x;

        // Data words start right after the header.
        let data_base = block_word_base + 1u;

        // The last block may be partial when K is not a multiple of block_size.
        let bs = min(block_size, k - k_offset);
        let a_base = a_row_base + k_offset;

        // Unscaled dot product of this block; the scale is applied once below.
        var block_dot: f32 = 0.0;

        // Process 4 elements at a time: one u32 word holds 4 packed i8 values.
        let full_quads = bs / 4u;
        for (var q: u32 = 0u; q < full_quads; q++) {
            let packed = mq_b[data_base + q];
            let a_idx = a_base + q * 4u;

            block_dot += f32(mq_a[a_idx + 0u]) * f32(extract_i8(packed, 0u));
            block_dot += f32(mq_a[a_idx + 1u]) * f32(extract_i8(packed, 1u));
            block_dot += f32(mq_a[a_idx + 2u]) * f32(extract_i8(packed, 2u));
            block_dot += f32(mq_a[a_idx + 3u]) * f32(extract_i8(packed, 3u));
        }

        // Handle the remaining (< 4) elements, all packed in the next data word.
        let remainder_start = full_quads * 4u;
        if (remainder_start < bs) {
            let packed = mq_b[data_base + full_quads];
            for (var r: u32 = remainder_start; r < bs; r++) {
                block_dot += f32(mq_a[a_base + r]) * f32(extract_i8(packed, r - remainder_start));
            }
        }

        acc += block_dot * blk_scale;
        k_offset += block_size;
    }

    mq_c[row * mq_params.n + col] = scalar(acc);
}