
trueno/backends/q4k/gemv/scalar.rs

#![allow(missing_docs)]
//! Scalar Q4_K GEMV implementations.
//!
//! Contains the baseline scalar dot product, 4-way unrolled fused GEMV,
//! and the scalar chunk processor for parallel dispatch.

use super::super::{parse_q4k_header, SUPER_BLOCK_BYTES, SUPER_BLOCK_SIZE};
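
// Q4_K super-block layout (standard GGUF format, as consumed by
// `parse_q4k_header`): 144 bytes encode 256 weights (4.5 bits/weight).
//
//   bytes   0..2    d     -- f16 super-block scale
//   bytes   2..4    dmin  -- f16 super-block min scale
//   bytes   4..16   packed 6-bit sub-block scales/mins (8 of each)
//   bytes  16..144  128 bytes of 4-bit quants: each 32-byte group covers 64
//                   weights, low nibbles first 32, high nibbles next 32
//
// A weight in sub-block `j` dequantizes as `w = d * scales[j] * q - dmin * mins[j]`.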

/// Scalar dot product of one Q4_K super-block with the matching input slice.
#[inline(always)]
fn process_q4k_superblock_scalar(
    sb_data: &[u8],
    input: &[f32],
    input_offset: usize,
    in_dim: usize,
) -> f32 {
    let (d, dmin, scales, mins) = parse_q4k_header(sb_data);
    let qs = sb_data.get(16..144).expect("Q4_K: need ≥144 bytes for qs");
    let mut sum = 0.0f32;

    for chunk in 0..4 {
        let chunk_start = chunk * 64;
        let q_start = chunk * 32;

        let d1 = d * f32::from(scales[chunk * 2]);
        let dm1 = dmin * f32::from(mins[chunk * 2]);
        let d2 = d * f32::from(scales[chunk * 2 + 1]);
        let dm2 = dmin * f32::from(mins[chunk * 2 + 1]);

        for i in 0..32 {
            let input_idx = input_offset + chunk_start + i;
            if input_idx < in_dim {
                let q_val = (qs[q_start + i] & 0x0F) as f32;
                sum += (d1 * q_val - dm1) * input[input_idx];
            }
        }
        for i in 0..32 {
            let input_idx = input_offset + chunk_start + 32 + i;
            if input_idx < in_dim {
                let q_val = (qs[q_start + i] >> 4) as f32;
                sum += (d2 * q_val - dm2) * input[input_idx];
            }
        }
    }
    sum
}
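
// Worked example: with d = 0.5, scales[j] = 3, dmin = 0.25, mins[j] = 2, a
// stored nibble q = 7 dequantizes to 0.5 * 3.0 * 7.0 - 0.25 * 2.0 = 10.0,
// which is then weighted by the matching input element.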

/// Baseline scalar Q4_K matrix-vector multiply.
///
/// Treats `q4k_data` as a row-major `out_dim x in_dim` Q4_K weight matrix
/// (256 weights / 144 bytes per super-block) and returns its product with `input`.
pub fn matmul_q4k_f32_scalar(
    q4k_data: &[u8],
    input: &[f32],
    out_dim: usize,
    in_dim: usize,
) -> Vec<f32> {
    assert_eq!(input.len(), in_dim, "Input length mismatch");
    assert!(
        in_dim % SUPER_BLOCK_SIZE == 0 || in_dim < SUPER_BLOCK_SIZE,
        "in_dim must be a multiple of 256, or less than 256 (single padded super-block)"
    );

    let num_blocks_per_row = in_dim.div_ceil(SUPER_BLOCK_SIZE);
    let row_bytes = num_blocks_per_row * SUPER_BLOCK_BYTES;
    let expected_size = out_dim * row_bytes;

    assert!(
        q4k_data.len() >= expected_size,
        "Q4K data too small: {} < {}",
        q4k_data.len(),
        expected_size
    );

    // Uninit: output[out_idx] = sum (SET) for every out_idx.
    let mut output: Vec<f32> = Vec::with_capacity(out_dim);
    // SAFETY: Each output[out_idx] is SET to the accumulated sum. No reads before writes.
    unsafe {
        output.set_len(out_dim);
    }

    for out_idx in 0..out_dim {
        let row_start = out_idx * row_bytes;
        let mut sum = 0.0f32;

        for sb_idx in 0..num_blocks_per_row {
            let sb_start = row_start + sb_idx * SUPER_BLOCK_BYTES;
            let sb_data = &q4k_data[sb_start..sb_start + SUPER_BLOCK_BYTES];
            let input_offset = sb_idx * SUPER_BLOCK_SIZE;
            sum += process_q4k_superblock_scalar(sb_data, input, input_offset, in_dim);
        }

        output[out_idx] = sum;
    }

    output
}
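
// Sizing example: a 4096 x 4096 layer has 4096 / 256 = 16 super-blocks per
// row, so 16 * 144 = 2304 bytes per row and 4096 * 2304 = 9 MiB of Q4_K data,
// versus 64 MiB for the same weights in f32.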

/// Process one Q4_K nibble half (low or high) with 4-way unrolled accumulation.
///
/// Four independent accumulators break the serial dependency chain between
/// consecutive `mul_add`s, exposing more instruction-level parallelism.
#[inline(always)]
fn process_q4k_nibble_half(
    qs: &[u8],
    q_start: usize,
    input: &[f32],
    input_base: usize,
    in_dim: usize,
    d_val: f32,
    dm_val: f32,
    shift: u8,
    acc: &mut [f32; 4],
) {
    let mut i = 0;
    while i + 3 < 32 {
        let idx = input_base + i;
        if idx + 3 >= in_dim {
            // Partially in-bounds group (padded tail): fall through to the
            // per-element loop below instead of silently dropping its valid
            // elements, which would diverge from the scalar path.
            break;
        }
        let q0 = ((qs[q_start + i] >> shift) & 0x0F) as f32;
        let q1 = ((qs[q_start + i + 1] >> shift) & 0x0F) as f32;
        let q2 = ((qs[q_start + i + 2] >> shift) & 0x0F) as f32;
        let q3 = ((qs[q_start + i + 3] >> shift) & 0x0F) as f32;

        acc[0] = (d_val * q0 - dm_val).mul_add(input[idx], acc[0]);
        acc[1] = (d_val * q1 - dm_val).mul_add(input[idx + 1], acc[1]);
        acc[2] = (d_val * q2 - dm_val).mul_add(input[idx + 2], acc[2]);
        acc[3] = (d_val * q3 - dm_val).mul_add(input[idx + 3], acc[3]);
        i += 4;
    }
    while i < 32 {
        let idx = input_base + i;
        if idx < in_dim {
            let q_val = ((qs[q_start + i] >> shift) & 0x0F) as f32;
            acc[0] = (d_val * q_val - dm_val).mul_add(input[idx], acc[0]);
        }
        i += 1;
    }
}

/// Fused Q4_K matrix-vector multiply (optimized with 4-way unrolling).
///
/// This version uses 4 independent accumulators to improve instruction-level
/// parallelism while maintaining scalar correctness.
///
/// # Arguments
/// Same as `matmul_q4k_f32_scalar`.
pub fn matmul_q4k_f32(q4k_data: &[u8], input: &[f32], out_dim: usize, in_dim: usize) -> Vec<f32> {
    assert_eq!(input.len(), in_dim, "Input length mismatch");

    let num_blocks_per_row = in_dim.div_ceil(SUPER_BLOCK_SIZE);
    let row_bytes = num_blocks_per_row * SUPER_BLOCK_BYTES;

    // Uninit: output[out_idx] = (acc[0]+acc[1])+(acc[2]+acc[3]) (SET) for every out_idx.
    let mut output: Vec<f32> = Vec::with_capacity(out_dim);
    // SAFETY: Each output[out_idx] is SET from local accumulator. No reads before writes.
    unsafe {
        output.set_len(out_dim);
    }

    for out_idx in 0..out_dim {
        let row_start = out_idx * row_bytes;
        let mut acc = [0.0f32; 4];

        for sb_idx in 0..num_blocks_per_row {
            let sb_start = row_start + sb_idx * SUPER_BLOCK_BYTES;
            let sb_data = &q4k_data[sb_start..sb_start + SUPER_BLOCK_BYTES];

            let (d, dmin, scales, mins) = parse_q4k_header(sb_data);
            let qs = sb_data.get(16..144).expect("Q4_K: need ≥144 bytes for qs");
            let input_offset = sb_idx * SUPER_BLOCK_SIZE;

            for chunk in 0..4 {
                let chunk_start = chunk * 64;
                let q_start = chunk * 32;

                let d1 = d * f32::from(scales[chunk * 2]);
                let dm1 = dmin * f32::from(mins[chunk * 2]);
                let d2 = d * f32::from(scales[chunk * 2 + 1]);
                let dm2 = dmin * f32::from(mins[chunk * 2 + 1]);

                let base_low = input_offset + chunk_start;
                process_q4k_nibble_half(qs, q_start, input, base_low, in_dim, d1, dm1, 0, &mut acc);

                let base_high = input_offset + chunk_start + 32;
                process_q4k_nibble_half(
                    qs, q_start, input, base_high, in_dim, d2, dm2, 4, &mut acc,
                );
            }
        }

        output[out_idx] = (acc[0] + acc[1]) + (acc[2] + acc[3]);
    }

    output
}

/// Scalar chunk processor for parallel dispatch.
pub(crate) fn compute_chunk_q4k_scalar(
    q4k_data: &[u8],
    input: &[f32],
    chunk: &mut [f32],
    start_row: usize,
    out_dim: usize,
    in_dim: usize,
    num_blocks_per_row: usize,
    row_bytes: usize,
) {
    for (local_idx, out_val) in chunk.iter_mut().enumerate() {
        let out_idx = start_row + local_idx;
        if out_idx >= out_dim {
            break;
        }

        let row_start = out_idx * row_bytes;
        let mut sum = 0.0f32;

        for sb_idx in 0..num_blocks_per_row {
            let sb_start = row_start + sb_idx * SUPER_BLOCK_BYTES;
            if sb_start + SUPER_BLOCK_BYTES > q4k_data.len() {
                break;
            }
            let sb_data = &q4k_data[sb_start..sb_start + SUPER_BLOCK_BYTES];
            let input_offset = sb_idx * SUPER_BLOCK_SIZE;
            sum += process_q4k_superblock_scalar(sb_data, input, input_offset, in_dim);
        }

        *out_val = sum;
    }
}
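
// A minimal consistency sketch (an addition, not from the original source):
// it checks the fused 4-way-unrolled path against the baseline scalar path on
// synthetic blocks, including an `in_dim` that is neither a multiple of 256
// nor of 4, to exercise the padded-tail handling. It assumes
// `parse_q4k_header` reads `d`/`dmin` as little-endian f16 from bytes 0..4
// (the standard GGUF Q4_K layout); those bytes are pinned to finite constants
// so random data cannot decode to NaN/inf halves.
#[cfg(test)]
mod tests {
    use super::*;

    /// Tiny deterministic LCG so the test needs no external crates.
    fn lcg_bytes(n: usize, seed: &mut u64) -> Vec<u8> {
        (0..n)
            .map(|_| {
                *seed = seed.wrapping_mul(6364136223846793005).wrapping_add(1);
                (*seed >> 56) as u8
            })
            .collect()
    }

    /// One synthetic super-block: random scales/mins/quants, benign d/dmin.
    fn make_block(seed: &mut u64) -> Vec<u8> {
        let mut block = lcg_bytes(SUPER_BLOCK_BYTES, seed);
        block[0..2].copy_from_slice(&[0x00, 0x3C]); // d = f16 1.0
        block[2..4].copy_from_slice(&[0x00, 0x38]); // dmin = f16 0.5
        block
    }

    fn check(out_dim: usize, in_dim: usize) {
        let num_blocks = in_dim.div_ceil(SUPER_BLOCK_SIZE);
        let mut seed = 0x1234_5678_9abc_def0_u64;
        let mut q4k_data = Vec::new();
        for _ in 0..out_dim * num_blocks {
            q4k_data.extend_from_slice(&make_block(&mut seed));
        }
        let input: Vec<f32> = (0..in_dim).map(|i| (i as f32 * 0.37).sin()).collect();

        let reference = matmul_q4k_f32_scalar(&q4k_data, &input, out_dim, in_dim);
        let fused = matmul_q4k_f32(&q4k_data, &input, out_dim, in_dim);

        for (r, f) in reference.iter().zip(&fused) {
            // Loose tolerance: the two paths sum in different orders and the
            // fused path rounds differently via `mul_add`.
            let tol = 1e-3 * r.abs().max(100.0);
            assert!((r - f).abs() <= tol, "mismatch: {r} vs {f}");
        }
    }

    #[test]
    fn fused_matches_scalar_aligned() {
        check(4, 512); // two full super-blocks per row
    }

    #[test]
    fn fused_matches_scalar_padded_tail() {
        check(3, 102); // in_dim < 256 and not a multiple of 4
    }
}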