#version 450
// Dequantize Q8_0: 34-byte blocks, 32 elements each
// Block layout: [f16 d (2 bytes), i8 qs[32] (32 bytes)]
// Formula: out[i] = float(d) * float(qs[i])
// 32 invocations per workgroup: main() maps one workgroup to one Q8_0 block
// and one invocation to one of its 32 elements.
layout(local_size_x = 32) in;
// Packed Q8_0 blocks viewed as 32-bit words. main() locates byte n of the
// stream at bits [8*(n%4), 8*(n%4)+8) of raw[n/4].
layout(set = 0, binding = 0) readonly buffer RawData { uint raw[]; };
// Dequantized output: 32 consecutive floats per block, in block order.
layout(set = 0, binding = 1) writeonly buffer Output { float result[]; };
// num_blocks: count of 34-byte blocks to process; main() returns early for
// workgroups at or beyond this count.
layout(push_constant) uniform Params {
int num_blocks;
};
// Convert the IEEE 754 binary16 value held in the low 16 bits of `h` into a
// 32-bit float by assembling the binary32 bit pattern directly.
//  - exponent 31  -> Inf/NaN (fraction payload preserved)
//  - exponent 0   -> signed zero, or a subnormal that is renormalized
//  - otherwise    -> normal number, exponent rebiased from 15 to 127
float half_to_float(uint h) {
    uint signBit  = (h & 0x8000u) << 16;
    uint exponent = bitfieldExtract(h, 10, 5);
    uint fraction = h & 0x3FFu;

    if (exponent == 31u) {
        // Saturate the exponent; keep any NaN payload bits.
        return uintBitsToFloat(signBit | 0x7F800000u | (fraction << 13));
    }

    if (exponent != 0u) {
        // Normal: rebias exponent (127 - 15 = 112) and widen the fraction.
        return uintBitsToFloat(signBit | ((exponent + 112u) << 23) | (fraction << 13));
    }

    if (fraction == 0u) {
        // +0.0 or -0.0.
        return uintBitsToFloat(signBit);
    }

    // Subnormal: value = fraction * 2^-24. Shift the leading 1 up to bit 10
    // (the implicit-one position), drop it, and set the matching exponent.
    int topBit = findMSB(fraction);                 // 0..9 here
    uint normFraction = (fraction << uint(10 - topBit)) & 0x3FFu;
    uint biasedExp = uint(topBit + 103);            // 2^(topBit-24) -> 127 + topBit - 24
    return uintBitsToFloat(signBit | (biasedExp << 23) | (normFraction << 13));
}
// Dequantize one Q8_0 block per workgroup; invocation i writes element i.
// Block layout (34 bytes): [f16 d][i8 qs[32]], out[i] = d * qs[i].
//
// Block index = gl_WorkGroupID.y * gl_NumWorkGroups.x + gl_WorkGroupID.x.
// A 1-D dispatch of num_blocks workgroups behaves exactly as a plain
// gl_WorkGroupID.x index would. A 2-D dispatch can cover tensors with more
// blocks than maxComputeWorkGroupCount[0], which is only guaranteed to be
// >= 65535. Surplus workgroups past num_blocks exit immediately.
void main() {
    uint block_idx = gl_WorkGroupID.y * gl_NumWorkGroups.x + gl_WorkGroupID.x;
    uint elem_idx = gl_LocalInvocationID.x;

    // Check the sign first: uint(negative) would wrap to a huge bound and let
    // every workgroup read/write out of range.
    if (num_blocks <= 0 || block_idx >= uint(num_blocks)) return;

    // Block starts at byte offset block_idx * 34.
    uint byte_base = block_idx * 34u;

    // f16 scale d occupies bytes [byte_base, byte_base + 2). Since 34 is even,
    // byte_base % 4 is always 0 or 2, so the 16 bits sit entirely in the low or
    // high half of one 32-bit word and never straddle two words.
    uint d_bits = bitfieldExtract(raw[byte_base / 4u], int((byte_base % 4u) * 8u), 16);
    float d = half_to_float(d_bits);

    // i8 qs[elem_idx] lives at byte offset 2 + elem_idx. bitfieldExtract on a
    // signed int sign-extends the extracted 8 bits, giving the range [-128, 127].
    uint qs_byte = byte_base + 2u + elem_idx;
    int qs_val = bitfieldExtract(int(raw[qs_byte / 4u]), int((qs_byte % 4u) * 8u), 8);

    result[block_idx * 32u + elem_idx] = d * float(qs_val);
}