Skip to main content

trueno/backends/q4k/
dequant.rs

1//! Q4_K dequantization to F32.
2//!
3//! Provides full dequantization of Q4_K compressed data for golden test comparison.
4
5use super::{parse_q4k_header, SUPER_BLOCK_BYTES, SUPER_BLOCK_SIZE};
6
7/// Dequantize Q4_K data to F32 (for golden test comparison)
8///
9/// This function fully dequantizes Q4K data to F32, matching the
10/// `dequantize_q4_k_to_f32` in aprender/src/format/converter.rs.
11pub fn dequantize_q4k_to_f32(data: &[u8], num_elements: usize) -> Vec<f32> {
12    contract_pre_dequant!();
13    let num_blocks = (num_elements + SUPER_BLOCK_SIZE - 1) / SUPER_BLOCK_SIZE;
14    let mut result = vec![0.0f32; num_blocks * SUPER_BLOCK_SIZE];
15
16    for sb_idx in 0..num_blocks {
17        let sb_start = sb_idx * SUPER_BLOCK_BYTES;
18        let out_start = sb_idx * SUPER_BLOCK_SIZE;
19
20        if sb_start + SUPER_BLOCK_BYTES > data.len() {
21            break;
22        }
23
24        let sb_data = &data[sb_start..sb_start + SUPER_BLOCK_BYTES];
25        let (d, dmin, scales, mins) = parse_q4k_header(sb_data);
26        let qs = sb_data.get(16..144).expect("Q4_K: need ≥144 bytes for qs");
27
28        let mut ys_index = out_start;
29
30        for chunk in 0..4 {
31            let q = &qs[chunk * 32..(chunk + 1) * 32];
32
33            let scale_idx_low = chunk * 2;
34            let scale_idx_high = chunk * 2 + 1;
35
36            let d1 = d * f32::from(scales[scale_idx_low]);
37            let dm1 = dmin * f32::from(mins[scale_idx_low]);
38            let d2 = d * f32::from(scales[scale_idx_high]);
39            let dm2 = dmin * f32::from(mins[scale_idx_high]);
40
41            // First pass: 32 low nibbles
42            for &byte in q {
43                result[ys_index] = d1 * (byte & 0xF) as f32 - dm1;
44                ys_index += 1;
45            }
46
47            // Second pass: 32 high nibbles
48            for &byte in q {
49                result[ys_index] = d2 * (byte >> 4) as f32 - dm2;
50                ys_index += 1;
51            }
52        }
53    }
54
55    result.truncate(num_elements);
56    contract_post_dequant!(result);
57    result
58}