// TQ KV dequantize kernel — iter-21 Track A fix.
//
// Reads nibble-packed TurboQuant K or V cache at a given position and
// writes a dense F32 buffer of shape [num_kv_heads, head_dim].
//
// This enables the Leg F ablation: encode K/V as TQ, then dequantize back
// to F32 and dispatch the dense flash_attn_vec kernel. The output is in the
// FWHT-rotated domain (same as the TQ SDPA readpath), NOT the original F32.
// This isolates the SDPA kernel math from TQ representation noise.
//
// Dequant formula MUST match flash_attn_vec_tq.metal inline dequant exactly
// (see flash_attn_vec_tq.metal:305-348):
// D=256: scale_norm = norm * inv_sqrt(256) — single-norm convention (unchanged)
// D=512: scale_norm = norm / scale_factor_d512 — per-block norm, NO inv_sqrt factor
// (iter-16 fix: encoder stores raw blk_norm; decoder uses blk_norm directly)
//
// IMPORTANT (iter-21 Track A): iter-20 erroneously applied inv_sqrt_hd to BOTH
// D=256 and D=512. For D=512 this introduces a sqrt(512)≈22.6x scale error vs
// the TQ SDPA kernel, corrupting attention scores for global-attention layers
// even when the prefill shadow cache is correctly populated.
//
// Packed layout: [num_kv_heads, cache_capacity, head_dim/2] u8 (nibble-packed;
//   coord 2i is the low nibble of byte i, coord 2i+1 is the high nibble)
// Norms layout:
//   D=256: [num_kv_heads, cache_capacity] f32 — 1 norm per position
//   D=512: [num_kv_heads, cache_capacity, 2] f32 — 2 per-block norms per position
// Output: [num_kv_heads, head_dim] f32 — dense dequantized values for ONE position
#include <metal_stdlib>
using namespace metal;
// 4-bit reconstruction centroids (16 levels), indexed by the unpacked nibble in
// tq_dequantize_kv. Sorted ascending and symmetric about zero; each entry is
// multiplied by the per-position scale_norm to produce the output value.
// NOTE(review): presumably Lloyd-Max centroids for N(0,1), like the 8-bit table
// below — confirm these match the encoder's 4-bit codebook exactly.
constant float CODEBOOK_4BIT_DQ[16] = {
-2.7325896f, -2.0690172f, -1.6180464f, -1.2562312f,
-0.9423405f, -0.6567591f, -0.3880483f, -0.1283950f,
0.1283950f, 0.3880483f, 0.6567591f, 0.9423405f,
1.2562312f, 1.6180464f, 2.0690172f, 2.7325896f,
};
// Per-dispatch parameters for tq_dequantize_kv, bound as a constant buffer at
// buffer(3). NOTE(review): field order and types define the buffer layout and
// must match whatever host-side struct fills it — keep them in sync.
struct TqDequantizeKvParams {
uint head_dim; // 256 or 512; packed rows are head_dim/2 bytes, so must be even
uint num_kv_heads; // number of KV heads; threadgroups with index >= this exit early
uint read_pos; // cache position to read from (already wrapped for ring buffers)
uint cache_capacity; // KV cache capacity (stride in packed/norms buffers)
uint norms_per_pos; // 1 for D=256, 2 for D=512 (f32 norms stored per position)
float scale_factor_d512; // iter-18 S2B: divisor applied to the block norm when head_dim > 256 (1.0=bare)
};
// Dequantize ONE cache position of a nibble-packed (4-bit) TurboQuant K or V
// cache into a dense F32 buffer. Output stays in the FWHT-rotated domain.
//
// Launch layout:
//   grid        : one threadgroup per KV head (threadgroup_position_in_grid.x is
//                 the KV head). Threadgroups with x >= num_kv_heads exit at once.
//   threadgroup : any size. Coordinates are strided by the total threadgroup
//                 size, so every coord in [0, head_dim) is written exactly once.
//                 A threadgroup of head_dim threads gives one coord per thread.
//
// Buffers:
//   packed (0) : [num_kv_heads, cache_capacity, head_dim/2] u8; coord 2i is the
//                low nibble of byte i, coord 2i+1 the high nibble.
//   norms  (1) : [num_kv_heads, cache_capacity, norms_per_pos] f32.
//   dst    (2) : [num_kv_heads, head_dim] f32 output.
//   params (3) : TqDequantizeKvParams.
//
// Scale convention (must match the TQ SDPA inline dequant exactly):
//   head_dim <= 256 : value = centroid * norm * rsqrt(head_dim)
//   head_dim  > 256 : value = centroid * norm[block] / scale_factor_d512,
//                     block = coord / 256, clamped to norms_per_pos - 1.
//   (iter-21 Track A: iter-20 wrongly applied rsqrt(head_dim) to D=512 as well,
//    a sqrt(512) ~= 22.6x scale error for global-attention layers.)
kernel void tq_dequantize_kv(
        device const uint8_t *packed [[buffer(0)]],   // [nkv, capacity, hd/2] u8
        device const float *norms [[buffer(1)]],      // [nkv, capacity, norms_per_pos] f32
        device float *dst [[buffer(2)]],              // [nkv, hd] f32 — OUTPUT
        constant TqDequantizeKvParams &params [[buffer(3)]],
        uint3 tgpig [[threadgroup_position_in_grid]], // threadgroup = kv head index
        uint tiitg [[thread_index_in_threadgroup]],   // linear thread index in threadgroup
        uint3 ntg [[threads_per_threadgroup]])        // supplied by the runtime, not the host
{
    const uint kv_head = tgpig[0];
    if (kv_head >= params.num_kv_heads) return;

    const uint hd = params.head_dim;
    const uint cap = params.cache_capacity;
    const uint pos = params.read_pos;
    const uint npp = params.norms_per_pos;
    // Highest valid norm index; max() keeps npp == 0 from wrapping to UINT_MAX.
    const uint last_block = max(npp, 1u) - 1u;
    const bool is_d512 = (hd > 256);
    // rsqrt(hd) is only used for D=256 (single-norm convention).
    const float inv_sqrt_hd = rsqrt(float(hd));
    const float sf = params.scale_factor_d512;

    // 64-bit offsets: (kv_head * cap + pos) * row can exceed 2^32 for large caches.
    const ulong head_pos = ulong(kv_head) * ulong(cap) + ulong(pos);
    device const uint8_t *packed_pos = packed + head_pos * ulong(hd / 2);
    device const float *norms_pos = norms + head_pos * ulong(npp);
    device float *dst_head = dst + ulong(kv_head) * ulong(hd);

    // Stride by the real threadgroup size (not hd) so every coord is still
    // covered if the host launches fewer than head_dim threads per threadgroup.
    const uint tg_size = ntg.x * ntg.y * ntg.z;
    for (uint coord = tiitg; coord < hd; coord += tg_size) {
        // D=512 stores one norm per 256-wide block; D=256 has a single norm.
        const uint block_idx = is_d512 ? min(coord / 256u, last_block) : 0u;
        const float norm = norms_pos[block_idx];
        const float scale_norm = is_d512 ? (norm / sf) : (norm * inv_sqrt_hd);

        const uint8_t packed_byte = packed_pos[coord / 2];
        const uint nibble = (coord & 1u) ? ((packed_byte >> 4u) & 0xFu)
                                         : (packed_byte & 0xFu);
        dst_head[coord] = CODEBOOK_4BIT_DQ[nibble] * scale_norm;
    }
}
// ============================================================================
// Track B (iter-21): codebooks for the higher-bit dequantize kernel
// (tq_dequantize_hb_kv, defined further below). That kernel reads byte-packed
// 5-, 6-, or 8-bit indices from the higher-bit KV cache and writes dense F32
// values in the FWHT-rotated domain (same as tq_dequantize_kv).
//
// Packed layout: [num_kv_heads, capacity, head_dim] u8 (byte-packed, 1 byte/elem)
// Norms layout: same as 4-bit (D=256: 1 norm/pos, D=512: 2 per-block norms/pos)
// ============================================================================
// 5-bit reconstruction centroids (32 levels), indexed by (byte & 0x1F) in
// tq_dequantize_hb_kv when codebook_bits == 5. Sorted ascending and symmetric
// about zero; each entry is multiplied by the per-position scale_norm.
// NOTE(review): must match the encoder's 5-bit codebook — confirm source.
constant float CODEBOOK_5BIT_DQ[32] = {
-3.2606790f, -2.6910589f, -2.3176743f, -2.0286608f,
-1.7871646f, -1.5761599f, -1.3862739f, -1.2117410f,
-1.0487242f, -0.8945114f, -0.7470884f, -0.6048936f,
-0.4666676f, -0.3313550f, -0.1980377f, -0.0658849f,
0.0658849f, 0.1980377f, 0.3313550f, 0.4666676f,
0.6048936f, 0.7470884f, 0.8945114f, 1.0487242f,
1.2117410f, 1.3862739f, 1.5761599f, 1.7871646f,
2.0286608f, 2.3176743f, 2.6910589f, 3.2606790f,
};
// 6-bit reconstruction centroids (64 levels), indexed by (byte & 0x3F) in
// tq_dequantize_hb_kv when codebook_bits == 6. Sorted ascending and symmetric
// about zero; each entry is multiplied by the per-position scale_norm.
// NOTE(review): must match the encoder's 6-bit codebook — confirm source.
constant float CODEBOOK_6BIT_DQ[64] = {
-3.6996161f, -3.1907215f, -2.8640626f, -2.6161277f,
-2.4129324f, -2.2388464f, -2.0853192f, -1.9471373f,
-1.8208742f, -1.7041502f, -1.5952401f, -1.4928497f,
-1.3959804f, -1.3038428f, -1.2157998f, -1.1313277f,
-1.0499889f, -0.9714118f, -0.8952766f, -0.8213046f,
-0.7492492f, -0.6788902f, -0.6100285f, -0.5424819f,
-0.4760822f, -0.4106724f, -0.3461048f, -0.2822386f,
-0.2189392f, -0.1560761f, -0.0935225f, -0.0311537f,
0.0311537f, 0.0935225f, 0.1560761f, 0.2189392f,
0.2822386f, 0.3461048f, 0.4106724f, 0.4760822f,
0.5424819f, 0.6100285f, 0.6788902f, 0.7492492f,
0.8213046f, 0.8952766f, 0.9714118f, 1.0499889f,
1.1313277f, 1.2157998f, 1.3038428f, 1.3959804f,
1.4928497f, 1.5952401f, 1.7041502f, 1.8208742f,
1.9471373f, 2.0853192f, 2.2388464f, 2.4129324f,
2.6161277f, 2.8640626f, 3.1907215f, 3.6996161f,
};
// ============================================================================
// 8-bit Lloyd-Max codebook for N(0,1): 256 reconstruction levels,
// range [-5.065, +5.065], indexed by the full packed byte in
// tq_dequantize_hb_kv (used for codebook_bits == 8).
// iter-24 fix: must match CODEBOOK_8BIT in hadamard_quantize_kv_fast.metal
// exactly — any drift here shows up directly as dequantization error.
// ============================================================================
constant float CODEBOOK_8BIT_DQ[256] = {
-5.0652659f, -4.6836997f, -4.4467193f, -4.2715508f,
-4.1311907f, -4.0132856f, -3.9111092f, -3.8205780f,
-3.7390194f, -3.6645851f, -3.5959415f, -3.5320936f,
-3.4722785f, -3.4158977f, -3.3624729f, -3.3116156f,
-3.2630056f, -3.2163758f, -3.1715011f, -3.1281899f,
-3.0862780f, -3.0456229f, -3.0061011f, -2.9676040f,
-2.9300362f, -2.8933131f, -2.8573596f, -2.8221086f,
-2.7874999f, -2.7534795f, -2.7199985f, -2.6870129f,
-2.6544825f, -2.6223710f, -2.5906452f, -2.5592748f,
-2.5282321f, -2.4974918f, -2.4670306f, -2.4368270f,
-2.4068614f, -2.3771157f, -2.3475732f, -2.3182184f,
-2.2890372f, -2.2600165f, -2.2311440f, -2.2024086f,
-2.1737998f, -2.1453081f, -2.1169245f, -2.0886408f,
-2.0604493f, -2.0323430f, -2.0043154f, -1.9763603f,
-1.9484722f, -1.9206458f, -1.8928763f, -1.8651592f,
-1.8374904f, -1.8098662f, -1.7822828f, -1.7547372f,
-1.7272261f, -1.6997469f, -1.6722970f, -1.6448739f,
-1.6174755f, -1.5900996f, -1.5627445f, -1.5354084f,
-1.5080897f, -1.4807869f, -1.4534986f, -1.4262237f,
-1.3989610f, -1.3717093f, -1.3444678f, -1.3172356f,
-1.2900118f, -1.2627956f, -1.2355865f, -1.2083838f,
-1.1811868f, -1.1539951f, -1.1268081f, -1.0996255f,
-1.0724469f, -1.0452718f, -1.0180999f, -0.9909310f,
-0.9637647f, -0.9366008f, -0.9094390f, -0.8822793f,
-0.8551212f, -0.8279648f, -0.8008098f, -0.7736561f,
-0.7465035f, -0.7193520f, -0.6922014f, -0.6650517f,
-0.6379027f, -0.6107544f, -0.5836067f, -0.5564596f,
-0.5293129f, -0.5021667f, -0.4750208f, -0.4478753f,
-0.4207301f, -0.3935852f, -0.3664405f, -0.3392960f,
-0.3121517f, -0.2850076f, -0.2578636f, -0.2307198f,
-0.2035761f, -0.1764324f, -0.1492888f, -0.1221453f,
-0.0950019f, -0.0678584f, -0.0407151f, -0.0135717f,
0.0135717f, 0.0407151f, 0.0678584f, 0.0950019f,
0.1221453f, 0.1492888f, 0.1764324f, 0.2035761f,
0.2307198f, 0.2578636f, 0.2850076f, 0.3121517f,
0.3392960f, 0.3664405f, 0.3935852f, 0.4207301f,
0.4478753f, 0.4750208f, 0.5021667f, 0.5293129f,
0.5564596f, 0.5836067f, 0.6107544f, 0.6379027f,
0.6650517f, 0.6922014f, 0.7193520f, 0.7465035f,
0.7736561f, 0.8008098f, 0.8279648f, 0.8551212f,
0.8822793f, 0.9094390f, 0.9366008f, 0.9637647f,
0.9909310f, 1.0180999f, 1.0452718f, 1.0724469f,
1.0996255f, 1.1268081f, 1.1539951f, 1.1811868f,
1.2083838f, 1.2355865f, 1.2627956f, 1.2900118f,
1.3172356f, 1.3444678f, 1.3717093f, 1.3989610f,
1.4262237f, 1.4534986f, 1.4807869f, 1.5080897f,
1.5354084f, 1.5627445f, 1.5900996f, 1.6174755f,
1.6448739f, 1.6722970f, 1.6997469f, 1.7272261f,
1.7547372f, 1.7822828f, 1.8098662f, 1.8374904f,
1.8651592f, 1.8928763f, 1.9206458f, 1.9484722f,
1.9763603f, 2.0043154f, 2.0323430f, 2.0604493f,
2.0886408f, 2.1169245f, 2.1453081f, 2.1737998f,
2.2024086f, 2.2311440f, 2.2600165f, 2.2890372f,
2.3182184f, 2.3475732f, 2.3771157f, 2.4068614f,
2.4368270f, 2.4670306f, 2.4974918f, 2.5282321f,
2.5592748f, 2.5906452f, 2.6223710f, 2.6544825f,
2.6870129f, 2.7199985f, 2.7534795f, 2.7874999f,
2.8221086f, 2.8573596f, 2.8933131f, 2.9300362f,
2.9676040f, 3.0061011f, 3.0456229f, 3.0862780f,
3.1281899f, 3.1715011f, 3.2163758f, 3.2630056f,
3.3116156f, 3.3624729f, 3.4158977f, 3.4722785f,
3.5320936f, 3.5959415f, 3.6645851f, 3.7390194f,
3.8205780f, 3.9111092f, 4.0132856f, 4.1311907f,
4.2715508f, 4.4467193f, 4.6836997f, 5.0652659f,
};
// Per-dispatch parameters for tq_dequantize_hb_kv, bound as a constant buffer
// at buffer(3). The first six fields mirror TqDequantizeKvParams.
// NOTE(review): field order and types define the buffer layout and must match
// whatever host-side struct fills it — keep them in sync.
struct TqDequantizeHbKvParams {
uint head_dim; // 256 or 512; packed rows are head_dim bytes (1 byte per element)
uint num_kv_heads; // threadgroups with index >= this exit early
uint read_pos; // cache position to read from
uint cache_capacity; // stride (in positions) of the packed/norms buffers
uint norms_per_pos; // f32 norms stored per position (1 for D=256, 2 for D=512)
float scale_factor_d512; // divisor applied to the block norm when head_dim > 256
uint codebook_bits; // 5, 6, or 8 (the kernel treats any value other than 5/6 as 8)
};
// ============================================================================
// Track B: higher-bit dequantize kernel (tq_dequantize_hb_kv, below).
// Reads one cache position of a byte-packed 5/6/8-bit index cache and writes
// dense F32 values in the FWHT-rotated domain, using the same scale convention
// as tq_dequantize_kv. This is the kernel used by the primary Track B ablation.
//
// Alternate Track B approach (NOT defined in this file; kept for reference):
// a requantize-to-F32 kernel (requantize_to_f32_hb) would take FWHT-rotated +
// normalized F32 K/V (from attn_k_normed after FWHT+norm), snap each value to
// the nearest 5-bit or 6-bit centroid, and write the centroid value back as
// F32. This simulates the quantize->dequantize round-trip without allocating a
// separate packed buffer.
//   Input:  [num_kv_heads, head_dim] f32 — post-FWHT normalized values
//   Output: same shape — centroid-snapped F32 values
// Intended workflow for that approach:
//   1. Encode K/V to TQ 4-bit (existing hadamard_quantize_kv_fast path)
//   2. Dequantize K/V to F32 (tq_dequantize_kv) into attn_k_normed
//   3. Apply requantize_to_f32_hb to snap values to 5/6-bit centroids
//   4. Copy requantized F32 into leg_hb_kvs shadow cache
// ============================================================================
// Dequantize ONE cache position of a byte-packed higher-bit (5/6/8-bit)
// TurboQuant K or V cache into dense F32. Output stays in the FWHT-rotated domain.
//
// Launch layout:
//   grid        : one threadgroup per KV head (threadgroup_position_in_grid.x is
//                 the KV head). Threadgroups with x >= num_kv_heads exit at once.
//   threadgroup : any size. Coordinates are strided by the total threadgroup
//                 size, so every coord in [0, head_dim) is written exactly once.
//
// Buffers:
//   packed (0) : [num_kv_heads, cache_capacity, head_dim] u8, one index per byte.
//   norms  (1) : [num_kv_heads, cache_capacity, norms_per_pos] f32.
//   dst    (2) : [num_kv_heads, head_dim] f32 output.
//   params (3) : TqDequantizeHbKvParams.
//
// Codebook selection by params.codebook_bits:
//   5             -> CODEBOOK_5BIT_DQ[idx & 0x1F]
//   6             -> CODEBOOK_6BIT_DQ[idx & 0x3F]
//   anything else -> CODEBOOK_8BIT_DQ[idx] (full byte; idx <= 255 is always in range)
//
// Scale convention (iter-21 Track A fix, must match the TQ SDPA readpath):
//   head_dim <= 256 : value = centroid * norm * rsqrt(head_dim)
//   head_dim  > 256 : value = centroid * norm[coord / 256] / scale_factor_d512
//                     (block index clamped to norms_per_pos - 1).
kernel void tq_dequantize_hb_kv(
        device const uint8_t *packed [[buffer(0)]],   // [nkv, capacity, hd] u8 byte-packed
        device const float *norms [[buffer(1)]],      // [nkv, capacity, norms_per_pos] f32
        device float *dst [[buffer(2)]],              // [nkv, hd] f32 — OUTPUT
        constant TqDequantizeHbKvParams &params [[buffer(3)]],
        uint3 tgpig [[threadgroup_position_in_grid]], // threadgroup = kv head index
        uint tiitg [[thread_index_in_threadgroup]],   // linear thread index in threadgroup
        uint3 ntg [[threads_per_threadgroup]])        // supplied by the runtime, not the host
{
    const uint kv_head = tgpig[0];
    if (kv_head >= params.num_kv_heads) return;

    const uint hd = params.head_dim;
    const uint cap = params.cache_capacity;
    const uint pos = params.read_pos;
    const uint npp = params.norms_per_pos;
    const uint bits = params.codebook_bits;
    // Highest valid norm index; max() keeps npp == 0 from wrapping to UINT_MAX.
    const uint last_block = max(npp, 1u) - 1u;
    const bool is_d512 = (hd > 256);
    const float inv_sqrt_hd = rsqrt(float(hd));
    const float sf = params.scale_factor_d512;

    // 64-bit offsets: byte-packed rows are hd bytes, so (kv_head*cap + pos) * hd
    // overflows 32 bits sooner than the nibble-packed cache does.
    const ulong head_pos = ulong(kv_head) * ulong(cap) + ulong(pos);
    device const uint8_t *packed_pos = packed + head_pos * ulong(hd);
    device const float *norms_pos = norms + head_pos * ulong(npp);
    device float *dst_head = dst + ulong(kv_head) * ulong(hd);

    // Stride by the real threadgroup size (not hd) so every coord is still
    // covered if the host launches fewer than head_dim threads per threadgroup.
    const uint tg_size = ntg.x * ntg.y * ntg.z;
    for (uint coord = tiitg; coord < hd; coord += tg_size) {
        const uint block_idx = is_d512 ? min(coord / 256u, last_block) : 0u;
        const float norm = norms_pos[block_idx];
        const float scale_norm = is_d512 ? (norm / sf) : (norm * inv_sqrt_hd);

        const uint idx = packed_pos[coord]; // one index per byte
        float centroid;
        if (bits == 5u) {
            centroid = CODEBOOK_5BIT_DQ[idx & 0x1Fu];
        } else if (bits == 6u) {
            centroid = CODEBOOK_6BIT_DQ[idx & 0x3Fu];
        } else {
            centroid = CODEBOOK_8BIT_DQ[idx]; // 8-bit (and fallback): full byte is the index
        }
        dst_head[coord] = centroid * scale_norm;
    }
}