use super::super::{Q4K_SUPER_BLOCK_BYTES, Q4K_SUPER_BLOCK_SIZE};
/// Size in bytes of one repacked tile: 8 sub-blocks × 16 cols × 2-byte f16
/// scales (256) + the same for mins (256) + 128 qs bytes × 16 cols (2048).
pub const W4A16_TILE_BYTES: usize = 2560;
/// Number of output columns packed side by side in one tile.
pub const W4A16_TILE_COLS: usize = 16;
/// Byte offset of the effective f16 scales region within a tile.
pub const W4A16_SCALE_OFFSET: usize = 0;
/// Byte offset of the effective f16 mins region within a tile.
pub const W4A16_MIN_OFFSET: usize = 256;
/// Byte offset of the column-interleaved 4-bit quant bytes within a tile.
pub const W4A16_QS_OFFSET: usize = 512;
// A Q4_K super-block covers 256 weights split into 8 sub-blocks of 32.
const NUM_SUB_BLOCKS: usize = 8;
const SUB_BLOCK_SIZE: usize = 32;
// Byte offsets within one source Q4_K super-block:
const SB_D_OFFSET: usize = 0; // f16 super-block scale `d`
const SB_DMIN_OFFSET: usize = 2; // f16 super-block min `dmin`
const SB_SCALES_OFFSET: usize = 4; // 12 bytes of packed 6-bit scale/min pairs
const SB_QS_OFFSET: usize = 16; // start of the 4-bit quants
const SB_QS_SIZE: usize = 128; // 256 weights × 4 bits = 128 bytes
/// Decodes a little-endian IEEE 754 half-precision (binary16) value to `f32`.
///
/// Handles all value classes: signed zero, subnormals, normals, infinities,
/// and NaN.
fn f16_to_f32(bytes: [u8; 2]) -> f32 {
    let h = u16::from_le_bytes(bytes) as u32;
    let negative = h & 0x8000 != 0;
    let sign_bit = (h & 0x8000) << 16; // sign relocated to the f32 position
    let exp = (h >> 10) & 0x1F;
    let frac = h & 0x3FF;
    match exp {
        // Zero or subnormal.
        0 => {
            if frac == 0 {
                f32::from_bits(sign_bit) // signed zero
            } else {
                // Subnormal: frac/1024 scaled by the minimum exponent 2^-14.
                let magnitude = (frac as f32) / 1024.0 * (2.0f32).powi(-14);
                if negative { -magnitude } else { magnitude }
            }
        }
        // Infinity or NaN.
        31 => {
            if frac != 0 {
                f32::NAN
            } else if negative {
                f32::NEG_INFINITY
            } else {
                f32::INFINITY
            }
        }
        // Normal: rebias exponent (15 -> 127) and widen the mantissa.
        _ => f32::from_bits(sign_bit | ((exp + 112) << 23) | (frac << 13)),
    }
}
/// Converts an `f32` to little-endian IEEE 754 half-precision (binary16).
///
/// Fixes over the previous version:
/// - NaN inputs stay NaN (previously any exp > 15 — including NaN — collapsed
///   to ±infinity).
/// - The mantissa is rounded to nearest-even instead of truncated, matching
///   hardware f32→f16 conversion and halving the worst-case error of the
///   repacked scales/mins.
///
/// Values below the half-precision normal range are still flushed to signed
/// zero (no subnormal outputs), preserving the original behavior.
fn f32_to_f16(val: f32) -> [u8; 2] {
    let bits = val.to_bits();
    let sign = ((bits >> 31) & 1) << 15;
    let exp = ((bits >> 23) & 0xFF) as i32;
    let frac = bits & 0x7F_FFFF;
    let h: u32 = if exp == 0xFF {
        // Infinity (frac == 0) or NaN: keep NaN-ness by setting a quiet-NaN
        // mantissa bit instead of collapsing NaN to infinity.
        if frac == 0 { sign | 0x7C00 } else { sign | 0x7E00 }
    } else {
        let unbiased = exp - 127;
        if unbiased > 15 {
            // Too large for f16: saturate to signed infinity.
            sign | 0x7C00
        } else if unbiased < -14 {
            // Below the f16 normal range: flush to signed zero.
            sign
        } else {
            // Normal range: rebias exponent and round the mantissa to
            // nearest, ties to even.
            let magnitude = (((unbiased + 15) as u32) << 10) | (frac >> 13);
            let round_bit = (frac >> 12) & 1;
            let sticky = frac & 0xFFF;
            if round_bit == 1 && (sticky != 0 || magnitude & 1 == 1) {
                // The +1 may carry into the exponent; 0x7BFF + 1 == 0x7C00
                // (infinity), which is the correct saturating result.
                sign | (magnitude + 1)
            } else {
                sign | magnitude
            }
        }
    };
    (h as u16).to_le_bytes()
}
/// Unpacks the 6-bit (scale, min) pair for one of the 8 sub-blocks from the
/// 12-byte packed Q4_K scales array.
///
/// Sub-blocks 0-3 store their 6 bits directly in the low bits of bytes 0-3
/// (scales) and 4-7 (mins); sub-blocks 4-7 split each value into a low nibble
/// in bytes 8-11 and two high bits borrowed from the top of bytes 0-7.
fn extract_scale_min(scales: &[u8], sub_block: usize) -> (u8, u8) {
    debug_assert!(scales.len() == 12);
    debug_assert!(sub_block < 8);
    match sub_block {
        j @ 0..=3 => (scales[j] & 0x3F, scales[j + 4] & 0x3F),
        j => {
            let idx = j - 4;
            let packed = scales[idx + 8];
            // Low nibble from the packed byte, top two bits reclaimed from
            // the corresponding low-sub-block byte.
            let scale = (packed & 0x0F) | ((scales[idx] >> 6) << 4);
            let min = (packed >> 4) | ((scales[idx + 4] >> 6) << 4);
            (scale, min)
        }
    }
}
/// Repacks GGUF `Q4_K` data into the W4A16 tile layout.
///
/// `src` holds `n` output columns; each column stores its `k / 256` Q4_K
/// super-blocks contiguously. The result groups `W4A16_TILE_COLS` columns into
/// tiles of `W4A16_TILE_BYTES`, one tile per (column group, super-block) pair:
/// - pre-multiplied f16 scales (`d * scale`) at `W4A16_SCALE_OFFSET`,
/// - pre-multiplied f16 mins (`dmin * min`) at `W4A16_MIN_OFFSET`,
/// - quant bytes interleaved by column at `W4A16_QS_OFFSET`,
/// so that 16 lanes reading the same sub-block/byte index access contiguous
/// memory.
///
/// When `n` is not a multiple of `W4A16_TILE_COLS`, the padding lanes of the
/// last tile replicate the final real column, so consumers may read a full
/// tile unconditionally.
///
/// # Panics
/// Panics if `k` is not a multiple of the Q4_K super-block size (256) or if
/// `src.len() != n * (k / 256) * Q4K_SUPER_BLOCK_BYTES`.
pub fn repack_q4k_w4a16(src: &[u8], n: usize, k: usize) -> Vec<u8> {
    assert!(
        k % Q4K_SUPER_BLOCK_SIZE as usize == 0,
        "K must be multiple of 256"
    );
    let num_sb = k / Q4K_SUPER_BLOCK_SIZE as usize;
    let sb_bytes = Q4K_SUPER_BLOCK_BYTES as usize;
    // Message fix: the last placeholder used to receive the full product while
    // the text read "× {}", printing a false equation. Name each factor.
    assert_eq!(
        src.len(),
        n * num_sb * sb_bytes,
        "src length {} != N({}) × num_sb({}) × sb_bytes({}) = {}",
        src.len(),
        n,
        num_sb,
        sb_bytes,
        n * num_sb * sb_bytes
    );
    // Round up so a trailing partial group of columns still gets a full tile.
    let n_tiles = (n + W4A16_TILE_COLS - 1) / W4A16_TILE_COLS;
    let mut dst = vec![0u8; n_tiles * num_sb * W4A16_TILE_BYTES];
    for tile_idx in 0..n_tiles {
        let col_base = tile_idx * W4A16_TILE_COLS;
        for sb_idx in 0..num_sb {
            // Tiles are laid out with all super-blocks of a column group
            // contiguous: (tile, sb) -> tile * num_sb + sb.
            let tile_offset = (tile_idx * num_sb + sb_idx) * W4A16_TILE_BYTES;
            for col_in_tile in 0..W4A16_TILE_COLS {
                let global_col = col_base + col_in_tile;
                // Padding lanes past N replicate the last real column.
                let clamped_col = global_col.min(n - 1);
                let sb_src_offset = (clamped_col * num_sb + sb_idx) * sb_bytes;
                let d = f16_to_f32([
                    src[sb_src_offset + SB_D_OFFSET],
                    src[sb_src_offset + SB_D_OFFSET + 1],
                ]);
                let dmin = f16_to_f32([
                    src[sb_src_offset + SB_DMIN_OFFSET],
                    src[sb_src_offset + SB_DMIN_OFFSET + 1],
                ]);
                let scales =
                    &src[sb_src_offset + SB_SCALES_OFFSET..sb_src_offset + SB_SCALES_OFFSET + 12];
                // Pre-multiply d/dmin into each sub-block's 6-bit scale/min so
                // the kernel does one f16 load instead of the Q4_K unpack.
                for sb_sub in 0..NUM_SUB_BLOCKS {
                    let (scale_int, min_int) = extract_scale_min(scales, sb_sub);
                    let eff_scale = d * scale_int as f32;
                    let eff_min = dmin * min_int as f32;
                    let scale_dst = tile_offset
                        + W4A16_SCALE_OFFSET
                        + (sb_sub * W4A16_TILE_COLS + col_in_tile) * 2;
                    let scale_bytes = f32_to_f16(eff_scale);
                    dst[scale_dst] = scale_bytes[0];
                    dst[scale_dst + 1] = scale_bytes[1];
                    let min_dst = tile_offset
                        + W4A16_MIN_OFFSET
                        + (sb_sub * W4A16_TILE_COLS + col_in_tile) * 2;
                    let min_bytes = f32_to_f16(eff_min);
                    dst[min_dst] = min_bytes[0];
                    dst[min_dst + 1] = min_bytes[1];
                }
                // Interleave the raw quant bytes: byte i of every column in
                // the tile lands at consecutive addresses.
                for byte_i in 0..SB_QS_SIZE {
                    let qs_dst =
                        tile_offset + W4A16_QS_OFFSET + byte_i * W4A16_TILE_COLS + col_in_tile;
                    dst[qs_dst] = src[sb_src_offset + SB_QS_OFFSET + byte_i];
                }
            }
        }
    }
    dst
}
#[must_use]
/// Returns the byte size of the W4A16 buffer produced by
/// [`repack_q4k_w4a16`] for an `n × k` weight matrix: one
/// `W4A16_TILE_BYTES` tile per 16-column group per 256-deep super-block.
pub fn w4a16_size(n: usize, k: usize) -> usize {
    let super_blocks_per_col = k / Q4K_SUPER_BLOCK_SIZE as usize;
    // Column groups round up: a partial group still occupies a full tile.
    let col_tiles = (n + W4A16_TILE_COLS - 1) / W4A16_TILE_COLS;
    col_tiles * super_blocks_per_col * W4A16_TILE_BYTES
}
#[cfg(test)]
#[allow(clippy::identity_op, clippy::erasing_op)]
mod tests {
    use super::*;

    #[test]
    fn test_w4a16_size() {
        // 1536 columns -> 96 tiles; K = 1536 -> 6 super-blocks per column.
        assert_eq!(w4a16_size(1536, 1536), 96 * 6 * W4A16_TILE_BYTES);
        assert_eq!(w4a16_size(8960, 1536), 560 * 6 * W4A16_TILE_BYTES);
    }

    #[test]
    fn test_w4a16_size_non_aligned() {
        // 17 columns round up to two 16-wide tiles.
        assert_eq!(w4a16_size(17, 256), 2 * 1 * W4A16_TILE_BYTES);
    }

    #[test]
    fn test_repack_preserves_total_size() {
        let (n, k) = (32, 256);
        let src = vec![0u8; n * 1 * Q4K_SUPER_BLOCK_BYTES as usize];
        let dst = repack_q4k_w4a16(&src, n, k);
        assert_eq!(dst.len(), 2 * 1 * W4A16_TILE_BYTES);
    }

    #[test]
    fn test_extract_scale_min_low() {
        let mut scales = [0u8; 12];
        scales[0] = 0x15;
        scales[4] = 0x0A;
        let (s, m) = extract_scale_min(&scales, 0);
        assert_eq!(s, 21);
        assert_eq!(m, 10);
    }

    #[test]
    fn test_extract_scale_min_high() {
        let mut scales = [0u8; 12];
        scales[8] = 0x37; // low nibbles: scale = 7, min = 3
        scales[0] = 0x80; // top two bits -> scale high bits = 2
        scales[4] = 0x40; // top two bits -> min high bits = 1
        let (s, m) = extract_scale_min(&scales, 4);
        assert_eq!(s, 7 | (2 << 4));
        assert_eq!(m, 3 | (1 << 4));
    }

    #[test]
    fn test_repack_roundtrip_dequant() {
        let (n, k) = (16, 256);
        let sb_bytes = Q4K_SUPER_BLOCK_BYTES as usize;
        let mut src = vec![0u8; n * sb_bytes];
        for col in 0..n {
            let base = col * sb_bytes;
            src[base..base + 2].copy_from_slice(&f32_to_f16(0.5)); // d
            src[base + 2..base + 4].copy_from_slice(&f32_to_f16(0.1)); // dmin
            for b in &mut src[base + 4..base + 16] {
                *b = 1; // every packed scale/min decodes to 1
            }
            for (i, b) in src[base + 16..base + 16 + 128].iter_mut().enumerate() {
                *b = ((i % 16) | ((i % 16) << 4)) as u8;
            }
        }
        let dst = repack_q4k_w4a16(&src, n, k);

        // Dequantize the first weight of sub-block 0 straight from the source…
        let col = 5;
        let sb = col * sb_bytes;
        let d = f16_to_f32([src[sb], src[sb + 1]]);
        let dmin = f16_to_f32([src[sb + 2], src[sb + 3]]);
        let (scale_int, min_int) = extract_scale_min(&src[sb + 4..sb + 16], 0);
        let quant = (src[sb + SB_QS_OFFSET] & 0x0F) as f32;
        let original_val = d * scale_int as f32 * quant - dmin * min_int as f32;

        // …and again from the repacked tile; the two must agree.
        let eff_scale_off = W4A16_SCALE_OFFSET + (0 * W4A16_TILE_COLS + col) * 2;
        let eff_scale = f16_to_f32([dst[eff_scale_off], dst[eff_scale_off + 1]]);
        let eff_min_off = W4A16_MIN_OFFSET + (0 * W4A16_TILE_COLS + col) * 2;
        let eff_min = f16_to_f32([dst[eff_min_off], dst[eff_min_off + 1]]);
        let quant_w4 = (dst[W4A16_QS_OFFSET + 0 * W4A16_TILE_COLS + col] & 0x0F) as f32;
        let w4a16_val = eff_scale * quant_w4 - eff_min;
        assert!(
            (original_val - w4a16_val).abs() < 0.01,
            "Roundtrip mismatch: original={} w4a16={}",
            original_val,
            w4a16_val
        );
    }

    #[test]
    fn test_repack_qs_interleaving() {
        let (n, k) = (16, 256);
        let sb_bytes = Q4K_SUPER_BLOCK_BYTES as usize;
        let mut src = vec![0u8; n * sb_bytes];
        // Tag the first qs byte of each column with a distinct value.
        for col in 0..16u8 {
            src[col as usize * sb_bytes + SB_QS_OFFSET] = col + 0x10;
        }
        let dst = repack_q4k_w4a16(&src, n, k);
        // Byte 0 of all 16 columns must land contiguously at the qs start.
        for col in 0..16u8 {
            let qs_offset = W4A16_QS_OFFSET + 0 * W4A16_TILE_COLS + col as usize;
            assert_eq!(
                dst[qs_offset],
                col + 0x10,
                "qs interleave failed col {}",
                col
            );
        }
    }

    #[test]
    fn test_repack_effective_scales_coalesced() {
        let (n, k) = (16, 256);
        let sb_bytes = Q4K_SUPER_BLOCK_BYTES as usize;
        let mut src = vec![0u8; n * sb_bytes];
        for col in 0..16 {
            let base = col * sb_bytes;
            src[base..base + 2].copy_from_slice(&f32_to_f16(2.0));
            src[base + SB_SCALES_OFFSET] = 3;
        }
        let dst = repack_q4k_w4a16(&src, n, k);
        // Effective scale = d * scale = 2.0 * 3 for every column.
        for col in 0..16 {
            let off = W4A16_SCALE_OFFSET + (0 * W4A16_TILE_COLS + col) * 2;
            let val = f16_to_f32([dst[off], dst[off + 1]]);
            assert!(
                (val - 6.0).abs() < 0.1,
                "col {} eff_scale = {} expected 6.0",
                col,
                val
            );
        }
    }

    #[test]
    fn test_repack_padding_columns() {
        // 17 columns: the second tile holds one real column plus 15 clones.
        let (n, k) = (17, 256);
        let sb_bytes = Q4K_SUPER_BLOCK_BYTES as usize;
        let mut src = vec![0u8; n * sb_bytes];
        let base = 16 * sb_bytes;
        src[base..base + 2].copy_from_slice(&f32_to_f16(1.0));
        src[base + SB_SCALES_OFFSET] = 5;
        let dst = repack_q4k_w4a16(&src, n, k);
        let tile1_offset = W4A16_TILE_BYTES;
        let off = tile1_offset + W4A16_SCALE_OFFSET + (0 * W4A16_TILE_COLS + 0) * 2;
        let val = f16_to_f32([dst[off], dst[off + 1]]);
        assert!(
            (val - 5.0).abs() < 0.1,
            "Padded tile col 0 eff_scale={}",
            val
        );
        // Lane 1 is past N and must replicate the last real column.
        let off_pad = tile1_offset + W4A16_SCALE_OFFSET + (0 * W4A16_TILE_COLS + 1) * 2;
        let val_pad = f16_to_f32([dst[off_pad], dst[off_pad + 1]]);
        assert!(
            (val_pad - 5.0).abs() < 0.1,
            "Padded col should clone: eff_scale={}",
            val_pad
        );
    }

    #[test]
    fn test_repack_multiple_sbs() {
        let (n, k) = (16, 512);
        let num_sb = 2;
        let sb_bytes = Q4K_SUPER_BLOCK_BYTES as usize;
        let mut src = vec![0u8; n * num_sb * sb_bytes];
        // Mark the second super-block of column 3.
        src[(3 * num_sb + 1) * sb_bytes + SB_QS_OFFSET] = 0xAB;
        let dst = repack_q4k_w4a16(&src, n, k);
        // With one column group, tile index equals super-block index.
        let tile_offset = 1 * W4A16_TILE_BYTES;
        let qs_byte = dst[tile_offset + W4A16_QS_OFFSET + 0 * W4A16_TILE_COLS + 3];
        assert_eq!(qs_byte, 0xAB, "Multi-SB repack failed");
    }
}