diagnostic(off, subgroup_uniformity);
enable f16;
#define KV_TILE 32
#define WG_SIZE 32
struct Params {
offset_mask: u32,
seq_len_q: u32,
seq_len_kv: u32,
stride_mask3: u32,
// Number of KV blocks and Q blocks per batch.
// nblk0 = ceil(seq_len_kv / KV_TILE), nblk1 = seq_len_q.
nblk0: u32,
nblk1: u32,
};
@group(0) @binding(0) var<storage, read> mask: array<f16>;
@group(0) @binding(1) var<storage, read_write> blk: array<u32>;
@group(0) @binding(2) var<uniform> params: Params;
const MASK_MIN: f32 = -65504.0;
const MASK_MAX: f32 = 65504.0;
var<workgroup> wg_min: array<f32, WG_SIZE>;
var<workgroup> wg_max: array<f32, WG_SIZE>;
var<workgroup> wg_any: array<u32, WG_SIZE>;
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
@builtin(local_invocation_id) local_id: vec3<u32>) {
// Dispatch mapping:
// - x indexes KV blocks
// - y flattens (batch_idx, q_blk) as y = batch_idx * nblk1 + q_blk
let kv_blk = wg_id.x;
let y = wg_id.y;
let q_blk = y % params.nblk1;
let batch_idx = y / params.nblk1;
if (kv_blk >= params.nblk0) {
return;
}
let q_start = q_blk;
let k_start = kv_blk * KV_TILE;
let mask_batch = select(0u, batch_idx, params.stride_mask3 > 0u);
let mask_batch_base = params.offset_mask + mask_batch * params.stride_mask3;
// We keep min/max to classify:
// - fully masked (max <= MASK_MIN)
// - all-zero mask (min == 0 && max == 0)
// - mixed/general mask
var local_min = MASK_MAX;
var local_max = -MASK_MAX;
var local_any = 0u;
let q_row = q_start;
if (q_row < params.seq_len_q) {
let row_base = mask_batch_base + q_row * params.seq_len_kv;
for (var k_rel = local_id.x; k_rel < KV_TILE; k_rel += WG_SIZE) {
let k_col = k_start + k_rel;
if (k_col >= params.seq_len_kv) {
continue;
}
let mv = f32(mask[row_base + k_col]);
local_min = min(local_min, mv);
local_max = max(local_max, mv);
local_any = 1u;
}
}
wg_min[local_id.x] = local_min;
wg_max[local_id.x] = local_max;
wg_any[local_id.x] = local_any;
workgroupBarrier();
// Thread 0 writes one state per block.
if (local_id.x == 0u) {
var mmin = wg_min[0];
var mmax = wg_max[0];
var many = wg_any[0];
for (var i = 1u; i < WG_SIZE; i += 1u) {
mmin = min(mmin, wg_min[i]);
mmax = max(mmax, wg_max[i]);
many = max(many, wg_any[i]);
}
var state = 0u;
if (many != 0u) {
if (mmax <= MASK_MIN) {
state = 0u;
} else if (mmin == 0.0 && mmax == 0.0) {
state = 2u;
} else {
state = 1u;
}
}
let blk_idx = (batch_idx * params.nblk1 + q_blk) * params.nblk0 + kv_blk;
blk[blk_idx] = state;
}
}