web-rwkv 0.10.20

An implementation of the RWKV language model in pure WebGPU.
// Built-in values received by the compute entry points in this file.
struct Input {
    // Global invocation id of the current invocation.
    @builtin(global_invocation_id) uid: vec3<u32>,
    // Number of workgroups dispatched along each axis; used to flatten `uid`
    // into a linear index.
    @builtin(num_workgroups) b: vec3<u32>,
};

// Length used to bound the number of active invocations; only `.x` is read.
// `compute_minmax` compares against `len.x >> 1u`, `quantize` against `len.x >> 2u`.
@group(0) @binding(0) var<uniform> len: vec4<u32>;                          // [L]
// Tensor shape; the product of all four components, divided by 4, is used as
// the number of `input` entries in `compute_minmax`.
@group(0) @binding(1) var<uniform> shape: vec4<u32>;                        // [C, R]

// Source values: each `vec2<u32>` entry is unpacked as four f16 numbers.
@group(0) @binding(2) var<storage, read> input: array<vec2<u32>>;           // (R, C)

// Per-block range: each `u32` packs (min, max) as two f16 values.
@group(0) @binding(3) var<storage, read_write> minmax: array<u32>;          // (R, C / S)
// Quantized result: each `u32` packs four 8-bit unorm values, one `u32` per
// `input` entry.
// NOTE(review): the code writes one u32 per four input values, so the
// "(R, C / 2)" shape comment may be stale — confirm against the host code.
@group(0) @binding(4) var<storage, read_write> output: array<u32>;          // (R, C / 2)

// Number of `input` entries (four values each) that make up one quantization block.
// INT8_BLOCK_SIZE is not declared in this file section; presumably it is supplied
// by the host before compilation — TODO confirm.
const INT8_BLOCK_STEP: u32 = INT8_BLOCK_SIZE / 4u;

// Converts four f32 values to f16 and packs them into two u32 words:
// `x.xy` goes into the first word, `x.zw` into the second.
fn pack4x16float(x: vec4<f32>) -> vec2<u32> {
    let first = pack2x16float(x.xy);
    let second = pack2x16float(x.zw);
    return vec2<u32>(first, second);
}

// Inverse of `pack4x16float`: expands two u32 words, each holding a pair of
// f16 values, into four f32 values.
fn unpack4x16float(x: vec2<u32>) -> vec4<f32> {
    let first = unpack2x16float(x.x);
    let second = unpack2x16float(x.y);
    return vec4<f32>(first, second);
}

// Computes the minimum and maximum of one quantization block of `input` and
// stores them in `minmax` as a packed f16 pair (min in the low 16 bits, max in
// the high 16 bits).
//
// Each invocation handles one block: INT8_BLOCK_STEP consecutive `input`
// entries of four values each. Invocations whose flattened index is not below
// `len.x >> 1u` do nothing.
@compute @workgroup_size(BLOCK_SIZE, 1, 1)
fn compute_minmax(in: Input) {
    // Flatten the 3-D global invocation id into a linear block index.
    let bti = dot(in.uid, vec3<u32>(1u, BLOCK_SIZE * in.b.x, BLOCK_SIZE * in.b.x * in.b.y));
    if bti >= (len.x >> 1u) {
        return;
    }

    // Total number of `input` entries (four values per entry).
    // The parentheses are required: the WGSL grammar does not allow `*` and
    // `>>` to be mixed without them, and strict parsers reject the bare form.
    let count = (shape.w * shape.z * shape.y * shape.x) >> 2u;

    // Start from the extremes of the finite f16 range so any real value replaces them.
    var _min = vec4<f32>(65504.0);
    var _max = vec4<f32>(-65504.0);
    for (var i = 0u; i < INT8_BLOCK_STEP; i += 1u) {
        let index = bti * INT8_BLOCK_STEP + i;
        // Skip entries past the end of the input (partial trailing block).
        if index < count {
            let v = unpack4x16float(input[index]);
            _min = min(v, _min);
            _max = max(v, _max);
        }
    }

    // Reduce the four lanes to a single scalar min and max.
    let lo = min(min(_min[0], _min[1]), min(_min[2], _min[3]));
    let hi = max(max(_max[0], _max[1]), max(_max[2], _max[3]));
    minmax[bti] = pack2x16float(vec2<f32>(lo, hi));
}

// Quantizes one `input` entry (four f16 values) to four 8-bit unorm values,
// using the (min, max) pair of the block the entry belongs to, and writes the
// packed result to `output`.
//
// Each value is mapped to `(v - min) / (max - min)`, clamped to [0, 1].
// Invocations whose flattened index is not below `len.x >> 2u` do nothing.
@compute @workgroup_size(BLOCK_SIZE, 1, 1)
fn quantize(in: Input) {
    // Flatten the 3-D global invocation id into a linear entry index.
    let bti = dot(in.uid, vec3<u32>(1u, BLOCK_SIZE * in.b.x, BLOCK_SIZE * in.b.x * in.b.y));
    if bti >= (len.x >> 2u) {
        return;
    }

    // m[0] is the block minimum, m[1] the block maximum.
    let m = unpack2x16float(minmax[bti / INT8_BLOCK_STEP]);
    let v = unpack4x16float(input[bti]);

    // A block whose values are all equal has zero range; dividing by it would
    // be 0/0, whose result is indeterminate in WGSL. Use a divisor of 1 in that
    // case so the quantized value is deterministically 0.
    let range = m[1] - m[0];
    let divisor = select(range, 1.0, range == 0.0);

    let x = saturate((v - m[0]) / divisor);
    output[bti] = pack4x8unorm(x);
}