// trueno/backends/q6k/colmajor.rs
//! Column-major Q6_K matrix-vector multiplication.
//!
//! This module implements column-major GEMV for GGML/GGUF format weights,
//! where weights are stored column-first for cache-efficient streaming.

use super::{f16_to_f32, SUPER_BLOCK_BYTES, SUPER_BLOCK_SIZE};
/// Decode one 6-bit quantized value from the packed low-nibble (`ql`) and
/// high-two-bit (`qh`) arrays, recentered to the signed range [-32, 31].
#[inline(always)]
fn extract_q6k_value(ql: &[u8], qh: &[u8], idx: usize) -> i8 {
    // Two values share each ql byte: even indices occupy the low nibble,
    // odd indices the high nibble.
    let low4 = match idx % 2 {
        0 => ql[idx / 2] & 0x0F,
        _ => ql[idx / 2] >> 4,
    };
    // Four values share each qh byte, two bits apiece, little-end first.
    let high2 = (qh[idx / 4] >> (2 * (idx % 4))) & 0x03;
    // Combine to a 6-bit magnitude, then shift to signed.
    ((high2 << 4) | low4) as i8 - 32
}
18/// Accumulate one Q6_K superblock into output (column-major layout).
19#[inline]
20fn accumulate_q6k_superblock_colmajor(
21    sb_data: &[u8],
22    x_j: f32,
23    output: &mut [f32],
24    output_offset: usize,
25    ne0: usize,
26) {
27    let ql = sb_data.get(0..128).expect("Q6_K: need ≥128 bytes for ql");
28    let qh = sb_data.get(128..192).expect("Q6_K: need ≥192 bytes for qh");
29    let scales = sb_data.get(192..208).expect("Q6_K: need ≥208 bytes for scales");
30    let d = f16_to_f32(u16::from_le_bytes([sb_data[208], sb_data[209]]));
31
32    for group in 0..16 {
33        let scale = (scales[group] as i8) as f32;
34        let group_offset = group * 16;
35
36        for j in 0..16 {
37            let idx = group_offset + j;
38            let output_idx = output_offset + idx;
39            if output_idx >= ne0 {
40                continue;
41            }
42            let q6 = extract_q6k_value(ql, qh, idx);
43            output[output_idx] += x_j * d * scale * q6 as f32;
44        }
45    }
46}
48/// Fused Q6_K matrix-vector multiply for GGML column-major layout
49///
50/// Computes: output = input @ Q6K_weight (GGML convention: y = x @ W)
51/// where weight is stored in Q6_K format with GGML column-major super-block organization.
52///
53/// # Arguments
54/// * `q6k_data` - Raw Q6K bytes in GGML column-major layout [ne0, ne1]
55/// * `input` - F32 input vector [ne1] (input/reduction dimension)
56/// * `ne0` - Size of output dimension (rows in GGML, output size)
57/// * `ne1` - Size of input/reduction dimension (columns in GGML, input size)
58///
59/// # Returns
60/// F32 output vector [ne0]
61#[deprecated(
62    since = "0.15.0",
63    note = "LAYOUT-001: Use row-major kernels. APR/GGUF data is transposed at import boundary."
64)]
65pub fn matmul_q6k_f32_colmajor(
66    q6k_data: &[u8],
67    input: &[f32],
68    ne0: usize, // output dimension (rows)
69    ne1: usize, // input/reduction dimension (columns)
70) -> Vec<f32> {
71    assert_eq!(input.len(), ne1, "Input length must match ne1 (input dimension)");
72
73    let blocks_per_col = (ne0 + SUPER_BLOCK_SIZE - 1) / SUPER_BLOCK_SIZE;
74    let col_bytes = blocks_per_col * SUPER_BLOCK_BYTES;
75
76    let mut output = vec![0.0f32; ne0];
77
78    for col_idx in 0..ne1 {
79        let col_start = col_idx * col_bytes;
80        let x_j = input[col_idx];
81
82        if x_j == 0.0 {
83            continue;
84        }
85
86        for sb_idx in 0..blocks_per_col {
87            let sb_start = col_start + sb_idx * SUPER_BLOCK_BYTES;
88            if sb_start + SUPER_BLOCK_BYTES > q6k_data.len() {
89                break;
90            }
91            let sb_data = &q6k_data[sb_start..sb_start + SUPER_BLOCK_BYTES];
92            let output_offset = sb_idx * SUPER_BLOCK_SIZE;
93            accumulate_q6k_superblock_colmajor(sb_data, x_j, &mut output, output_offset, ne0);
94        }
95    }
96
97    output
98}
100/// Runtime dispatch for column-major Q6K matmul
101///
102/// Uses scalar implementation for correctness.
103/// Critical for lm_head which is typically 151936 x 1536 (233M elements).
104#[deprecated(
105    since = "0.15.0",
106    note = "LAYOUT-001: Use row-major kernels. APR/GGUF data is transposed at import boundary."
107)]
108#[inline]
109pub fn matmul_q6k_f32_colmajor_dispatch(
110    q6k_data: &[u8],
111    input: &[f32],
112    ne0: usize,
113    ne1: usize,
114) -> Vec<f32> {
115    #[allow(deprecated)]
116    matmul_q6k_f32_colmajor(q6k_data, input, ne0, ne1)
117}