use metaltile::{bench_kernel, kernel};
#[bench_kernel(
op="kv_cache",
subop="update",
class=GenericEmpty,
tol=0.0,
kernel_mode=Grid3D,
)]
#[kernel]
pub fn kv_cache_update<T>(
src: Tensor<T>,
out: Tensor<T>,
#[constexpr] head_dim: u32,
#[constexpr] max_seq: u32,
#[constexpr] position: u32,
) {
let idx = program_id::<0>();
let h = idx / head_dim;
let d = idx - h * head_dim;
let dst_idx = h * max_seq * head_dim + position * head_dim + d;
store(out[dst_idx], load(src[idx]));
}
macro_rules! quantize_kv_kernel {
($name:ident, $bits:literal, $subop:literal) => {
#[bench_kernel(op="kv_cache", subop=$subop, class=GenericEmpty, tol=0.0, kernel_mode=Grid3D,)]
#[kernel]
pub fn $name<T>(
src: Tensor<T>,
mut out_w: Tensor<u32>,
mut out_s: Tensor<T>,
mut out_b: Tensor<T>,
#[constexpr] head_dim: u32,
#[constexpr] max_seq: u32,
#[constexpr] group_size: u32,
#[constexpr] position: u32,
) {
let vals_per_pack = 32u32 / $bits;
let max_quant_u = (1u32 << $bits) - 1u32;
let max_quant_f = max_quant_u.cast::<f32>();
let g_global = program_id::<0>();
let groups_per_head = head_dim / group_size;
let h = g_global / groups_per_head;
let g_in_h = g_global - h * groups_per_head;
let d_start = g_in_h * group_size;
let src_base = h * head_dim;
let mut mn = load(src[src_base + d_start]).cast::<f32>();
let mut mx = mn;
for i in range(1u32, group_size, 1u32) {
let v = load(src[src_base + d_start + i]).cast::<f32>();
mn = select(v < mn, v, mn);
mx = select(v > mx, v, mx);
}
let range = mx - mn;
let safe_scale = select(range == 0.0f32, 1.0f32, range / max_quant_f);
let inv_scale = 1.0f32 / safe_scale;
let dst_sb_idx = (h * max_seq + position) * groups_per_head + g_in_h;
store(out_s[dst_sb_idx], safe_scale.cast::<T>());
store(out_b[dst_sb_idx], mn.cast::<T>());
let dst_w_base =
(h * max_seq + position) * (head_dim / vals_per_pack) + d_start / vals_per_pack;
for p in range(0u32, group_size / vals_per_pack, 1u32) {
let mut packed = 0u32;
for i in range(0u32, vals_per_pack, 1u32) {
let v = load(src[src_base + d_start + p * vals_per_pack + i]).cast::<f32>();
let q_f = (v - mn) * inv_scale + 0.5f32;
let q_clamped_f =
select(q_f > max_quant_f, max_quant_f, select(q_f < 0.0f32, 0.0f32, q_f));
let q = q_clamped_f.cast::<u32>();
packed = packed | (q << (i * $bits));
}
store(out_w[dst_w_base + p], packed);
}
}
};
}
macro_rules! bulk_dequant_kv_kernel {
($name:ident, $bits:literal, $subop:literal) => {
#[bench_kernel(op="kv_cache", subop=$subop, class=GenericEmpty, tol=0.0, kernel_mode=Grid3D,)]
#[kernel]
pub fn $name<T>(
in_w: Tensor<u32>,
in_s: Tensor<T>,
in_b: Tensor<T>,
mut out: Tensor<T>,
#[constexpr] head_dim: u32,
#[constexpr] max_seq: u32,
#[constexpr] group_size: u32,
#[constexpr] n_positions: u32,
) {
let vals_per_pack = 32u32 / $bits;
let mask = (1u32 << $bits) - 1u32;
let idx = program_id::<0>();
let total_per_head = n_positions * head_dim;
let h = idx / total_per_head;
let rest = idx - h * total_per_head;
let pos = rest / head_dim;
let d = rest - pos * head_dim;
let groups_per_head = head_dim / group_size;
let g = d / group_size;
let scale = load(in_s[(h * max_seq + pos) * groups_per_head + g]).cast::<f32>();
let bias = load(in_b[(h * max_seq + pos) * groups_per_head + g]).cast::<f32>();
let pack_idx = (h * max_seq + pos) * (head_dim / vals_per_pack) + d / vals_per_pack;
let lane = d & (vals_per_pack - 1u32);
let packed = load(in_w[pack_idx]);
let q = (packed >> (lane * $bits)) & mask;
let w_real = q.cast::<f32>() * scale + bias;
let dst_idx = h * max_seq * head_dim + pos * head_dim + d;
store(out[dst_idx], w_real.cast::<T>());
}
};
}
quantize_kv_kernel!(quantize_kv_int4, 4u32, "quantize_int4");
quantize_kv_kernel!(quantize_kv_int8, 8u32, "quantize_int8");
bulk_dequant_kv_kernel!(bulk_dequant_kv_int4, 4u32, "bulk_dequant_int4");
bulk_dequant_kv_kernel!(bulk_dequant_kv_int8, 8u32, "bulk_dequant_int8");
macro_rules! quantize_kv_fp8 {
($name:ident, $subop:literal, $mant_f:literal, $mant_i:literal, $emin:literal, $emax:literal, $fp8max:literal) => {
#[bench_kernel(op="kv_cache", subop=$subop, class=GenericEmpty, tol=0.0, kernel_mode=Grid3D,)]
#[kernel]
pub fn $name<T>(
src: Tensor<T>,
mut out_w: Tensor<u32>,
mut out_s: Tensor<T>,
#[constexpr] head_dim: u32,
#[constexpr] max_seq: u32,
#[constexpr] group_size: u32,
#[constexpr] position: u32,
) {
let vals_per_pack = 4u32;
let g_global = program_id::<0>();
let groups_per_head = head_dim / group_size;
let h = g_global / groups_per_head;
let g_in_h = g_global - h * groups_per_head;
let d_start = g_in_h * group_size;
let src_base = h * head_dim;
let mut mx = 0.0f32;
for i in range(0u32, group_size, 1u32) {
let v = abs(load(src[src_base + d_start + i]).cast::<f32>());
mx = select(v > mx, v, mx);
}
let inv_scale = select(mx > 0.0f32, $fp8max / mx, 0.0f32);
let scale = select(mx > 0.0f32, mx / $fp8max, 0.0f32);
let dst_s_idx = (h * max_seq + position) * groups_per_head + g_in_h;
store(out_s[dst_s_idx], scale.cast::<T>());
let dst_w_base =
(h * max_seq + position) * (head_dim / vals_per_pack) + d_start / vals_per_pack;
for p in range(0u32, group_size / vals_per_pack, 1u32) {
let mut packed = 0u32;
for i in range(0u32, vals_per_pack, 1u32) {
let v = load(src[src_base + d_start + p * vals_per_pack + i]).cast::<f32>();
let sign = select(v < 0.0f32, 1u32, 0u32);
let ax = abs(v);
let norm = ax * inv_scale;
let raw_e = floor(log2(norm));
let e_lo = select(raw_e < $emin, $emin, raw_e);
let e = select(e_lo > $emax, $emax, e_lo);
let quantum = exp2(e - $mant_f);
let q_snapped = select(norm > 0.0f32, round(norm / quantum), 0.0f32);
let mant_offset = exp2($mant_f);
let m_unbiased =
select(q_snapped > mant_offset, q_snapped - mant_offset, 0.0f32);
let max_m = mant_offset - 1.0f32;
let q_m = select(m_unbiased > max_m, max_m, m_unbiased);
let e_int = (e + $emax).cast::<u32>();
let m_int = q_m.cast::<u32>();
let code7 = (e_int << $mant_i) | m_int;
let code = (sign << 7u32) | (code7 & 127u32);
packed = packed | (code << (i * 8u32));
}
store(out_w[dst_w_base + p], packed);
}
}
};
}
macro_rules! bulk_dequant_kv_fp8 {
($name:ident, $subop:literal, $mant_f:literal, $mant_i:literal, $emin:literal, $emax:literal, $fp8max:literal) => {
#[bench_kernel(op="kv_cache", subop=$subop, class=GenericEmpty, tol=0.0, kernel_mode=Grid3D,)]
#[kernel]
pub fn $name<T>(
in_w: Tensor<u32>,
in_s: Tensor<T>,
mut out: Tensor<T>,
#[constexpr] head_dim: u32,
#[constexpr] max_seq: u32,
#[constexpr] group_size: u32,
#[constexpr] n_positions: u32,
) {
let vals_per_pack = 4u32;
let idx = program_id::<0>();
let total_per_head = n_positions * head_dim;
let h = idx / total_per_head;
let rest = idx - h * total_per_head;
let pos = rest / head_dim;
let d = rest - pos * head_dim;
let groups_per_head = head_dim / group_size;
let g = d / group_size;
let scale = load(in_s[(h * max_seq + pos) * groups_per_head + g]).cast::<f32>();
let pack_idx = (h * max_seq + pos) * (head_dim / vals_per_pack) + d / vals_per_pack;
let lane = d & (vals_per_pack - 1u32);
let packed = load(in_w[pack_idx]);
let code = (packed >> (lane * 8u32)) & 255u32;
let sign = 1.0f32 - 2.0f32 * (code >> 7u32).cast::<f32>();
let code7 = code & 127u32;
let e_raw = code7 >> $mant_i;
let m_mask = (1u32 << $mant_i) - 1u32;
let m_raw = code7 & m_mask;
let is_normal = select(e_raw > 0u32, 1u32, 0u32);
let e_f = e_raw.cast::<f32>() - $emax;
let norm_mag = exp2(e_f) * (1.0f32 + m_raw.cast::<f32>() * exp2(-($mant_f)));
let sub_mag = exp2($emin) * m_raw.cast::<f32>() * exp2(-($mant_f));
let mag = select(is_normal == 1u32, norm_mag, sub_mag);
let w_real = scale * sign * mag;
let dst_idx = h * max_seq * head_dim + pos * head_dim + d;
store(out[dst_idx], w_real.cast::<T>());
}
};
}
quantize_kv_fp8!(
quantize_kv_fp8_e4m3,
"quantize_fp8_e4m3",
3.0f32,
3u32,
-6.0f32,
7.0f32,
240.0f32
);
quantize_kv_fp8!(
quantize_kv_fp8_e5m2,
"quantize_fp8_e5m2",
2.0f32,
2u32,
-14.0f32,
15.0f32,
57344.0f32
);
bulk_dequant_kv_fp8!(
bulk_dequant_kv_fp8_e4m3,
"bulk_dequant_fp8_e4m3",
3.0f32,
3u32,
-6.0f32,
7.0f32,
240.0f32
);
bulk_dequant_kv_fp8!(
bulk_dequant_kv_fp8_e5m2,
"bulk_dequant_fp8_e5m2",
2.0f32,
2u32,
-14.0f32,
15.0f32,
57344.0f32
);