#ifdef USE_SUBGROUP_REDUCTION
enable subgroups;
#endif
enable f16;
requires packed_4x8_integer_dot_product;
#include "common_decls.tmpl"
// Parameters for the quantization pass, bound as a uniform buffer.
// Field order and types define the buffer layout and must match whatever
// fills the buffer on the host side.
struct Params {
// Added to the src1 base index before main() divides that base by 4 to get a
// vec4 index (presumably expressed in scalar elements — confirm with the host code).
offset_src1: u32,
// Multiplied by the dim-2 batch index when forming the src1 base index.
stride_12: u32,
// Multiplied by the dim-3 batch index when forming the src1 base index.
stride_13: u32,
// Row length: main() processes ne0 / 4 vec4s and writes ne0 / 32 blocks per row.
ne0: u32,
// Batch extents. main() splits the linear workgroup index into a dim-3 index,
// a dim-2 index and a position within the row using ne2; ne3 only bounds the
// total number of workgroups that do work.
ne2: u32,
ne3: u32,
};
// Element type of the input buffer: one load yields four SRC1_INNER_TYPE values.
// SRC1_INNER_TYPE is not defined in this file (presumably injected when the
// template is instantiated — confirm).
#define SRC1_TYPE vec4<SRC1_INNER_TYPE>
// Input values; main() only reads from this buffer.
@group(0) @binding(0) var<storage, read_write> src1: array<SRC1_TYPE>;
// Output blocks. main() writes the .d and .qs members, and .s when one of the
// MUL_ACC_Q4_* variants is defined. The q8_1 type is declared outside this file
// (presumably in common_decls.tmpl — confirm).
@group(0) @binding(1) var<storage, read_write> src1q: array<q8_1>;
// Uniform parameters, see struct Params.
@group(0) @binding(2) var<uniform> params: Params;
#ifdef USE_SUBGROUP_REDUCTION
// Maximum of `v` over the 8 invocations whose subgroup invocation ids differ
// only in their three low bits. Implemented as an XOR butterfly (masks 1, 2, 4),
// so every one of those 8 invocations ends up with the same result.
// NOTE(review): presumes the subgroup holds at least 8 invocations — confirm.
fn cluster_max_8(v: f32) -> f32 {
    // Stage 1: combine with the neighbour at distance 1.
    let pair_max = max(v, subgroupShuffleXor(v, 1u));
    // Stage 2: combine pairs at distance 2.
    let quad_max = max(pair_max, subgroupShuffleXor(pair_max, 2u));
    // Stage 3: combine quads at distance 4.
    return max(quad_max, subgroupShuffleXor(quad_max, 4u));
}
#if defined(MUL_ACC_Q4_0) || defined(MUL_ACC_Q4_1) || defined(MUL_ACC_Q4_K)
// Sum of `v` over the 8 invocations whose subgroup invocation ids differ only
// in their three low bits. Uses the same XOR butterfly (masks 1, 2, 4) as the
// max reduction, so all 8 invocations receive the full sum.
// NOTE(review): presumes the subgroup holds at least 8 invocations — confirm.
fn cluster_add_i4x8(v: i32) -> i32 {
    // Stage 1: add the neighbour at distance 1.
    let pair_sum = v + subgroupShuffleXor(v, 1u);
    // Stage 2: add the pair at distance 2.
    let quad_sum = pair_sum + subgroupShuffleXor(pair_sum, 2u);
    // Stage 3: add the quad at distance 4.
    return quad_sum + subgroupShuffleXor(quad_sum, 4u);
}
#endif
#endif
#ifdef USE_WORKGROUP_REDUCTION
// Number of consecutive invocations that cooperate on one output block.
// main() indexes the arrays below with the literal 8u (thread_id / 8u and
// thread_id % 8u), so this value has to stay in sync with those literals.
#define CLUSTER_SIZE 8
// Scratch for the max reduction: one absolute-maximum slot per invocation,
// indexed [cluster][lane within cluster].
var<workgroup> partial_amaxs: array<array<f32, CLUSTER_SIZE>, WG_SIZE / CLUSTER_SIZE>;
// Scratch for the sum reduction: one per-invocation sum of its four packed
// quants, indexed [cluster][lane within cluster].
var<workgroup> partial_sums: array<array<i32, CLUSTER_SIZE>, WG_SIZE / CLUSTER_SIZE>;
#endif
// Quantizes src1 into src1q blocks.
//
// Work split: each invocation loads one vec4 (4 values). Eight consecutive
// invocations form a cluster that produces one output block: 8 packed u32
// words in .qs (32 signed 8-bit quants), the scale .d, and, for the
// MUL_ACC_Q4_* variants, .s = d * (sum of the block's quants).
// Each workgroup covers WG_SIZE consecutive vec4s of one row; workgroups are
// laid out on a 2D dispatch that is flattened to a linear index here.
//
// The per-cluster reduction is done either with subgroup shuffles
// (USE_SUBGROUP_REDUCTION) or through workgroup memory (USE_WORKGROUP_REDUCTION).
// NOTE(review): the block indexing assumes ne0 is a multiple of 32 — confirm.
@compute @workgroup_size(WG_SIZE)
fn main(
@builtin(local_invocation_id) local_id: vec3<u32>,
@builtin(workgroup_id) wg_id: vec3<u32>,
@builtin(num_workgroups) num_wg: vec3<u32>
) {
let thread_id = local_id.x;
// Row length in vec4s, and the number of workgroups needed to cover one row.
let num_vec4 = params.ne0 / 4u;
let wg_per_vec = (num_vec4 + (WG_SIZE - 1u)) / WG_SIZE;
let total_batches = wg_per_vec * params.ne2 * params.ne3;
// Flatten the 2D dispatch into a linear workgroup index.
let wg_linear = wg_id.y * num_wg.x + wg_id.x;
// Surplus workgroups exit. The condition depends only on workgroup-level
// values, so every invocation of a workgroup takes the same branch.
if (wg_linear >= total_batches) {
return;
}
// Decompose the linear index into (dim-3 index, dim-2 index, workgroup within row).
let src13_idx = wg_linear / (params.ne2 * wg_per_vec);
let src12_idx = (wg_linear - src13_idx * (params.ne2 * wg_per_vec)) / wg_per_vec;
let src11_wg_idx = wg_linear % wg_per_vec;
// Base index of the row in src1, then converted to a vec4 index.
let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12;
let src1_idx_vec4_base = src1_idx_base / 4u;
// Output is densely packed: 32 values per block, rows stored back to back.
let blocks_per_row = params.ne0 / 32u;
let blocks_per_wg = (WG_SIZE * 4u) / 32u;
let src1q_idx_base = (src13_idx * params.ne2 + src12_idx) * blocks_per_row;
// Block written by this invocation's cluster, and this invocation's word within it.
let src1q_idx = src1q_idx_base + src11_wg_idx * blocks_per_wg + thread_id / 8u;
let qs_idx = thread_id % 8u;
// reduction
// Defaults below are what invocations past the end of the row contribute:
// zero values, zero quants, zero maximum.
var q4 = vec4<f32>(0.0);
var q4_quants = 0u;
var thread_amax = 0.0;
let src11_vec4_idx = src11_wg_idx * WG_SIZE + thread_id;
// False for invocations in the last workgroup of a row that fall past its end.
let is_valid = src11_vec4_idx < num_vec4;
#ifdef USE_SUBGROUP_REDUCTION
var d = 0.0;
if (is_valid) {
q4 = src1[src1_idx_vec4_base + src11_vec4_idx];
let abs_q4 = abs(q4);
thread_amax = max(max(abs_q4[0u], abs_q4[1u]), max(abs_q4[2], abs_q4[3]));
}
// The shuffle-based reduction is called outside the is_valid branch so that
// all invocations take part. Scale maps the cluster maximum onto 127.
d = cluster_max_8(thread_amax) / 127.0;
if (is_valid) {
// Inverse scale; 0 when the whole block is zero, which avoids dividing by zero.
let id = select(0.0, 1.0 / d, d > 0.0);
// Round to integers and pack the four values as signed bytes into one u32.
q4_quants = pack4xI8(vec4<i32>(round(q4 * id)));
// Only the first invocation of the cluster writes the per-block scale.
if (qs_idx == 0u) {
src1q[src1q_idx].d = f16(d);
}
src1q[src1q_idx].qs[qs_idx] = q4_quants;
}
#if defined(MUL_ACC_Q4_0) || defined(MUL_ACC_Q4_1) || defined(MUL_ACC_Q4_K)
// Dot product with four packed 1s sums this invocation's four signed bytes.
let q4_quants_sum = dot4I8Packed(q4_quants, 0x01010101u);
// s = d * (sum of the block's quants); reduction again runs for all invocations.
let s = f16(d * f32(cluster_add_i4x8(q4_quants_sum)));
if (is_valid) {
if (qs_idx == 0u) {
src1q[src1q_idx].s = s;
}
}
#endif
#endif
#ifdef USE_WORKGROUP_REDUCTION
var d = 0.0;
let cluster_id = thread_id / 8u;
if (is_valid) {
q4 = src1[src1_idx_vec4_base + src11_vec4_idx];
let abs_q4 = abs(q4);
thread_amax = max(max(abs_q4[0], abs_q4[1]), max(abs_q4[2], abs_q4[3]));
// Publish this invocation's maximum; invalid invocations never write their slot.
partial_amaxs[cluster_id][qs_idx] = thread_amax;
}
// The barrier sits outside the is_valid branch so every invocation reaches it.
workgroupBarrier();
if (is_valid) {
// Each invocation recomputes the maximum over its cluster's 8 slots.
let amax = max(
max(
max(partial_amaxs[cluster_id][0], partial_amaxs[cluster_id][1]), max(partial_amaxs[cluster_id][2], partial_amaxs[cluster_id][3])),
max(
max(partial_amaxs[cluster_id][4], partial_amaxs[cluster_id][5]), max(partial_amaxs[cluster_id][6], partial_amaxs[cluster_id][7]))
);
// Scale maps the cluster maximum onto 127.
d = amax / 127.0;
// Inverse scale; 0 when the whole block is zero, which avoids dividing by zero.
let id = select(0.0f, 1.0f / d, d > 0.0f);
// Round to integers and pack the four values as signed bytes into one u32.
q4_quants = pack4xI8(vec4<i32>(round(q4 * id)));
src1q[src1q_idx].qs[qs_idx] = q4_quants;
// Only the first invocation of the cluster writes the per-block scale.
if (qs_idx == 0u) {
src1q[src1q_idx].d = f16(d);
}
}
#if defined(MUL_ACC_Q4_0) || defined(MUL_ACC_Q4_1) || defined(MUL_ACC_Q4_K)
// Every invocation publishes the sum of its four signed bytes (0 for invalid
// ones, whose q4_quants is still 0), then all meet at the barrier.
partial_sums[cluster_id][qs_idx] = dot4I8Packed(q4_quants, 0x01010101u);
workgroupBarrier();
if (is_valid) {
if (qs_idx == 0u) {
// s = d * (sum of the block's quants), written by the cluster's first invocation.
let s = d * f32(partial_sums[cluster_id][0] + partial_sums[cluster_id][1] + partial_sums[cluster_id][2] + partial_sums[cluster_id][3]
+ partial_sums[cluster_id][4] + partial_sums[cluster_id][5] + partial_sums[cluster_id][6] + partial_sums[cluster_id][7]);
src1q[src1q_idx].s = f16(s);
}
}
#endif
#endif
}