use crate::f16_to_f32;
/// Dequantizes a Q4_K byte stream into `num_elements` `f32` values.
///
/// `data` is consumed as consecutive 144-byte super-blocks, each producing
/// 256 outputs. Within a super-block the bytes are read as:
///
/// * `0..2`    – scale factor `d` (little-endian f16)
/// * `2..4`    – min factor `dmin` (little-endian f16)
/// * `4..16`   – packed 6-bit sub-block scales and mins
/// * `16..144` – packed 4-bit quants
///
/// When `data` runs out before a complete super-block is available, decoding
/// stops and the outputs that were not reached stay `0.0`. The returned vector
/// always holds exactly `num_elements` entries.
#[must_use]
pub fn dequantize_q4_k_to_f32(data: &[u8], num_elements: usize) -> Vec<f32> {
    const SUPER_BLOCK_SIZE: usize = 256;
    const SUPER_BLOCK_BYTES: usize = 144;

    // Allocate whole super-blocks so every decoded block has a full 256-slot
    // window to write into; the padding is trimmed off at the end.
    let num_blocks = num_elements.div_ceil(SUPER_BLOCK_SIZE);
    let mut result = vec![0.0f32; num_blocks * SUPER_BLOCK_SIZE];

    // `chunks_exact` drops a trailing partial super-block, and `zip` stops at
    // whichever side is exhausted first.
    let encoded = data.chunks_exact(SUPER_BLOCK_BYTES);
    let decoded = result.chunks_exact_mut(SUPER_BLOCK_SIZE);
    for (block, out) in encoded.zip(decoded) {
        let d = sanitize_f16_scale(block[0], block[1]);
        let dmin = sanitize_f16_scale(block[2], block[3]);
        let (scales, mins) = unpack_q4k_scales(&block[4..16]);
        dequantize_q4k_block(d, dmin, &scales, &mins, &block[16..], out);
    }

    result.truncate(num_elements);
    result
}
/// Builds a 16-bit pattern from two bytes (`lo` is the least-significant
/// byte), converts it with `f16_to_f32`, and neutralises non-finite results.
///
/// Returns the converted value when it is finite; NaN and ±infinity are
/// replaced by `0.0` so that a corrupt scale cannot poison a whole block.
#[inline]
fn sanitize_f16_scale(lo: u8, hi: u8) -> f32 {
    let bits = u16::from(lo) | (u16::from(hi) << 8);
    let value = f16_to_f32(bits);
    if value.is_finite() {
        value
    } else {
        0.0
    }
}
/// Expands a 12-byte packed table into eight 6-bit scales and eight 6-bit
/// mins, returned as `(scales, mins)`.
///
/// Packing, by entry index `j`:
///
/// * `j < 4`  – scale is the low six bits of byte `j`; min is the low six
///   bits of byte `j + 4`.
/// * `j >= 4` – the low four bits come from byte `j + 4` (low nibble for the
///   scale, high nibble for the min); the upper two bits are the top two bits
///   of byte `j - 4` (scale) or byte `j` (min).
///
/// # Panics
///
/// Panics if `scales_bytes` holds fewer than 12 bytes.
fn unpack_q4k_scales(scales_bytes: &[u8]) -> ([u8; 8], [u8; 8]) {
    let packed = &scales_bytes[..12];

    let scales: [u8; 8] = std::array::from_fn(|j| {
        if j < 4 {
            packed[j] & 0x3F
        } else {
            (packed[j + 4] & 0x0F) | ((packed[j - 4] >> 6) << 4)
        }
    });

    let mins: [u8; 8] = std::array::from_fn(|j| {
        if j < 4 {
            packed[j + 4] & 0x3F
        } else {
            (packed[j + 4] >> 4) | ((packed[j] >> 6) << 4)
        }
    });

    (scales, mins)
}
/// Decodes the 4-bit quants of one super-block into the first 256 slots of
/// `output`.
///
/// The block is processed as eight sub-blocks of 32 values. Sub-blocks `2p`
/// and `2p + 1` read the same 32 bytes `qs[32p..32p + 32]`: the even one uses
/// the low nibbles, the odd one the high nibbles. Each value is
/// `d * scales[s] * nibble - dmin * mins[s]` for its sub-block `s`.
///
/// # Panics
///
/// Panics if `qs` has fewer than 128 bytes or `output` fewer than 256 slots.
fn dequantize_q4k_block(
    d: f32,
    dmin: f32,
    scales: &[u8; 8],
    mins: &[u8; 8],
    qs: &[u8],
    output: &mut [f32],
) {
    for sub in 0..8 {
        let scale = d * f32::from(scales[sub]);
        let offset = dmin * f32::from(mins[sub]);
        // Even sub-blocks take the low nibble (shift 0), odd ones the high
        // nibble (shift 4) of the byte group they share.
        let shift = 4 * (sub % 2);
        let byte_base = (sub / 2) * 32;
        let out_base = sub * 32;
        for l in 0..32 {
            let nibble = (qs[byte_base + l] >> shift) & 0x0F;
            output[out_base + l] = scale * f32::from(nibble) - offset;
        }
    }
}
#[must_use]
pub fn dequantize_q5_k_to_f32(data: &[u8], num_elements: usize) -> Vec<f32> {
const SUPER_BLOCK_SIZE: usize = 256;
const SUPER_BLOCK_BYTES: usize = 176;
let num_blocks = num_elements.div_ceil(SUPER_BLOCK_SIZE);
let mut result = vec![0.0f32; num_blocks * SUPER_BLOCK_SIZE];
for sb_idx in 0..num_blocks {
let sb_start = sb_idx * SUPER_BLOCK_BYTES;
let out_start = sb_idx * SUPER_BLOCK_SIZE;
if sb_start + SUPER_BLOCK_BYTES > data.len() {
break;
}
let d = sanitize_f16_scale(data[sb_start], data[sb_start + 1]);
let dmin = sanitize_f16_scale(data[sb_start + 2], data[sb_start + 3]);
let scales_bytes = &data[sb_start + 4..sb_start + 16];
let mut scales = [0u8; 8];
let mut mins = [0u8; 8];
for i in 0..4 {
scales[i] = scales_bytes[i] & 0x3F;
mins[i] = scales_bytes[i + 4] & 0x3F;
scales[i + 4] = (scales_bytes[i + 8] & 0x0F) | ((scales_bytes[i] >> 6) << 4);
mins[i + 4] = (scales_bytes[i + 8] >> 4) | ((scales_bytes[i + 4] >> 6) << 4);
}
let qh = &data[sb_start + 16..sb_start + 48];
let qs = &data[sb_start + 48..sb_start + 176];
for j in 0..8 {
let scale = d * f32::from(scales[j]);
let min_val = dmin * f32::from(mins[j]);
for k in 0..32 {
let idx = j * 32 + k;
let qs_idx = j * 16 + (k % 16);
let q_lo = if k < 16 {
qs[qs_idx] & 0x0F
} else {
(qs[qs_idx] >> 4) & 0x0F
};
let q_hi = (qh[k] >> j) & 1;
let q = q_lo | (q_hi << 4);
result[out_start + idx] = scale * f32::from(q) - min_val;
}
}
}
result.truncate(num_elements);
result
}
/// Dequantizes a Q6_K byte stream into `num_elements` `f32` values.
///
/// `data` is consumed as consecutive 210-byte super-blocks, each producing
/// 256 outputs. Within a super-block the bytes are read as:
///
/// * `0..128`   – `ql`: the low four bits of every quant
/// * `128..192` – `qh`: the upper two bits of every quant, four per byte
/// * `192..208` – sixteen scales, reinterpreted as signed 8-bit integers
/// * `208..210` – scale factor `d` (little-endian f16)
///
/// Each half of the super-block (128 outputs) uses 64 bytes of `ql`, 32 bytes
/// of `qh` and 8 scales. For position `l` in `0..32`, four outputs are
/// produced at `l`, `l + 32`, `l + 64` and `l + 96`; output `n` of the four
/// takes its low nibble from `ql[l]` / `ql[l + 32]` (low nibbles for
/// `n = 0, 1`, high nibbles for `n = 2, 3`) and its upper two bits from bits
/// `2n..2n + 2` of `qh[l]`. The 6-bit value is re-centred by subtracting 32
/// and multiplied by `d` and the scale at index `l / 16 + 2n`.
///
/// When `data` runs out before a complete super-block is available, decoding
/// stops and the outputs that were not reached stay `0.0`. The returned vector
/// always holds exactly `num_elements` entries.
#[must_use]
pub fn dequantize_q6_k_to_f32(data: &[u8], num_elements: usize) -> Vec<f32> {
    const SUPER_BLOCK_SIZE: usize = 256;
    const SUPER_BLOCK_BYTES: usize = 210;

    let num_blocks = num_elements.div_ceil(SUPER_BLOCK_SIZE);
    let mut result = vec![0.0f32; num_blocks * SUPER_BLOCK_SIZE];

    let encoded = data.chunks_exact(SUPER_BLOCK_BYTES);
    let decoded = result.chunks_exact_mut(SUPER_BLOCK_SIZE);
    for (block, out) in encoded.zip(decoded) {
        let (ql, rest) = block.split_at(128);
        let (qh, rest) = rest.split_at(64);
        let (scales, d_bytes) = rest.split_at(16);
        let d = sanitize_f16_scale(d_bytes[0], d_bytes[1]);

        for (half, out_half) in out.chunks_exact_mut(128).enumerate() {
            let ql_half = &ql[half * 64..(half + 1) * 64];
            let qh_half = &qh[half * 32..(half + 1) * 32];
            let scale_half = &scales[half * 8..(half + 1) * 8];

            for l in 0..32 {
                let high_bits = qh_half[l];
                let scale_row = l / 16;
                // Low nibbles of the four quants that share position `l`.
                let low_nibbles = [
                    ql_half[l] & 0x0F,
                    ql_half[l + 32] & 0x0F,
                    (ql_half[l] >> 4) & 0x0F,
                    (ql_half[l + 32] >> 4) & 0x0F,
                ];

                for (n, &low) in low_nibbles.iter().enumerate() {
                    let upper = ((high_bits >> (2 * n)) & 0x03) << 4;
                    // 6-bit value in 0..=63, shifted to the signed range -32..=31.
                    let q = (low | upper) as i8 - 32;
                    let s = scale_half[scale_row + 2 * n] as i8;
                    out_half[l + 32 * n] = d * f32::from(s) * f32::from(q);
                }
            }
        }
    }

    result.truncate(num_elements);
    result
}