trueno/backends/q4k/gemv/
scalar.rs1#![allow(missing_docs)]
2use super::super::{parse_q4k_header, SUPER_BLOCK_BYTES, SUPER_BLOCK_SIZE};
8
9#[inline(always)]
11fn process_q4k_superblock_scalar(
12 sb_data: &[u8],
13 input: &[f32],
14 input_offset: usize,
15 in_dim: usize,
16) -> f32 {
17 let (d, dmin, scales, mins) = parse_q4k_header(sb_data);
18 let qs = sb_data.get(16..144).expect("Q4_K: need ≥144 bytes for qs");
19 let mut sum = 0.0f32;
20
21 for chunk in 0..4 {
22 let chunk_start = chunk * 64;
23 let q_start = chunk * 32;
24
25 let d1 = d * f32::from(scales[chunk * 2]);
26 let dm1 = dmin * f32::from(mins[chunk * 2]);
27 let d2 = d * f32::from(scales[chunk * 2 + 1]);
28 let dm2 = dmin * f32::from(mins[chunk * 2 + 1]);
29
30 for i in 0..32 {
31 let input_idx = input_offset + chunk_start + i;
32 if input_idx < in_dim {
33 let q_val = (qs[q_start + i] & 0x0F) as f32;
34 sum += (d1 * q_val - dm1) * input[input_idx];
35 }
36 }
37 for i in 0..32 {
38 let input_idx = input_offset + chunk_start + 32 + i;
39 if input_idx < in_dim {
40 let q_val = (qs[q_start + i] >> 4) as f32;
41 sum += (d2 * q_val - dm2) * input[input_idx];
42 }
43 }
44 }
45 sum
46}
47
48pub fn matmul_q4k_f32_scalar(
49 q4k_data: &[u8],
50 input: &[f32],
51 out_dim: usize,
52 in_dim: usize,
53) -> Vec<f32> {
54 assert_eq!(input.len(), in_dim, "Input length mismatch");
55 assert!(
56 in_dim % SUPER_BLOCK_SIZE == 0 || in_dim < SUPER_BLOCK_SIZE,
57 "in_dim must be multiple of 256 (or smaller for padding)"
58 );
59
60 let num_blocks_per_row = (in_dim + SUPER_BLOCK_SIZE - 1) / SUPER_BLOCK_SIZE;
61 let row_bytes = num_blocks_per_row * SUPER_BLOCK_BYTES;
62 let expected_size = out_dim * row_bytes;
63
64 assert!(
65 q4k_data.len() >= expected_size,
66 "Q4K data too small: {} < {}",
67 q4k_data.len(),
68 expected_size
69 );
70
71 let mut output: Vec<f32> = Vec::with_capacity(out_dim);
73 unsafe {
75 output.set_len(out_dim);
76 }
77
78 for out_idx in 0..out_dim {
79 let row_start = out_idx * row_bytes;
80 let mut sum = 0.0f32;
81
82 for sb_idx in 0..num_blocks_per_row {
83 let sb_start = row_start + sb_idx * SUPER_BLOCK_BYTES;
84 let sb_data = &q4k_data[sb_start..sb_start + SUPER_BLOCK_BYTES];
85 let input_offset = sb_idx * SUPER_BLOCK_SIZE;
86 sum += process_q4k_superblock_scalar(sb_data, input, input_offset, in_dim);
87 }
88
89 output[out_idx] = sum;
90 }
91
92 output
93}
94
#[inline(always)]
/// Accumulates one 32-element nibble half of a Q4_K chunk into four partial
/// sums (`acc`), dequantizing as `d_val * q - dm_val` per element.
///
/// `shift` selects the nibble: 0 for the low half, 4 for the high half.
/// Elements with input index `>= in_dim` are skipped (zero padding), matching
/// the scalar reference path.
fn process_q4k_nibble_half(
    qs: &[u8],
    q_start: usize,
    input: &[f32],
    input_base: usize,
    in_dim: usize,
    d_val: f32,
    dm_val: f32,
    shift: u8,
    acc: &mut [f32; 4],
) {
    // 32 is a multiple of 4, so this loop covers every element; the old
    // trailing scalar remainder loop was dead code and has been removed.
    let mut i = 0;
    while i + 3 < 32 {
        let idx = input_base + i;
        if idx + 3 < in_dim {
            // Fast path: whole 4-wide group is in bounds.
            let q0 = ((qs[q_start + i] >> shift) & 0x0F) as f32;
            let q1 = ((qs[q_start + i + 1] >> shift) & 0x0F) as f32;
            let q2 = ((qs[q_start + i + 2] >> shift) & 0x0F) as f32;
            let q3 = ((qs[q_start + i + 3] >> shift) & 0x0F) as f32;

            acc[0] = (d_val * q0 - dm_val).mul_add(input[idx], acc[0]);
            acc[1] = (d_val * q1 - dm_val).mul_add(input[idx + 1], acc[1]);
            acc[2] = (d_val * q2 - dm_val).mul_add(input[idx + 2], acc[2]);
            acc[3] = (d_val * q3 - dm_val).mul_add(input[idx + 3], acc[3]);
        } else {
            // Partial tail: the old code dropped the whole group when any of
            // its four indices reached in_dim, silently losing the in-bounds
            // elements. Handle them element-wise instead.
            for j in 0..4 {
                let k = idx + j;
                if k < in_dim {
                    let q = ((qs[q_start + i + j] >> shift) & 0x0F) as f32;
                    acc[j] = (d_val * q - dm_val).mul_add(input[k], acc[j]);
                }
            }
        }
        i += 4;
    }
}
133
134pub fn matmul_q4k_f32(q4k_data: &[u8], input: &[f32], out_dim: usize, in_dim: usize) -> Vec<f32> {
142 assert_eq!(input.len(), in_dim, "Input length mismatch");
143
144 let num_blocks_per_row = (in_dim + SUPER_BLOCK_SIZE - 1) / SUPER_BLOCK_SIZE;
145 let row_bytes = num_blocks_per_row * SUPER_BLOCK_BYTES;
146
147 let mut output: Vec<f32> = Vec::with_capacity(out_dim);
149 unsafe {
151 output.set_len(out_dim);
152 }
153
154 for out_idx in 0..out_dim {
155 let row_start = out_idx * row_bytes;
156 let mut acc = [0.0f32; 4];
157
158 for sb_idx in 0..num_blocks_per_row {
159 let sb_start = row_start + sb_idx * SUPER_BLOCK_BYTES;
160 let sb_data = &q4k_data[sb_start..sb_start + SUPER_BLOCK_BYTES];
161
162 let (d, dmin, scales, mins) = parse_q4k_header(sb_data);
163 let qs = sb_data.get(16..144).expect("Q4_K: need ≥144 bytes for qs");
164 let input_offset = sb_idx * SUPER_BLOCK_SIZE;
165
166 for chunk in 0..4 {
167 let chunk_start = chunk * 64;
168 let q_start = chunk * 32;
169
170 let d1 = d * f32::from(scales[chunk * 2]);
171 let dm1 = dmin * f32::from(mins[chunk * 2]);
172 let d2 = d * f32::from(scales[chunk * 2 + 1]);
173 let dm2 = dmin * f32::from(mins[chunk * 2 + 1]);
174
175 let base_low = input_offset + chunk_start;
176 process_q4k_nibble_half(qs, q_start, input, base_low, in_dim, d1, dm1, 0, &mut acc);
177
178 let base_high = input_offset + chunk_start + 32;
179 process_q4k_nibble_half(
180 qs, q_start, input, base_high, in_dim, d2, dm2, 4, &mut acc,
181 );
182 }
183 }
184
185 output[out_idx] = (acc[0] + acc[1]) + (acc[2] + acc[3]);
186 }
187
188 output
189}
190
191pub(crate) fn compute_chunk_q4k_scalar(
193 q4k_data: &[u8],
194 input: &[f32],
195 chunk: &mut [f32],
196 start_row: usize,
197 out_dim: usize,
198 in_dim: usize,
199 num_blocks_per_row: usize,
200 row_bytes: usize,
201) {
202 for (local_idx, out_val) in chunk.iter_mut().enumerate() {
203 let out_idx = start_row + local_idx;
204 if out_idx >= out_dim {
205 break;
206 }
207
208 let row_start = out_idx * row_bytes;
209 let mut sum = 0.0f32;
210
211 for sb_idx in 0..num_blocks_per_row {
212 let sb_start = row_start + sb_idx * SUPER_BLOCK_BYTES;
213 if sb_start + SUPER_BLOCK_BYTES > q4k_data.len() {
214 break;
215 }
216 let sb_data = &q4k_data[sb_start..sb_start + SUPER_BLOCK_BYTES];
217 let input_offset = sb_idx * SUPER_BLOCK_SIZE;
218 sum += process_q4k_superblock_scalar(sb_data, input, input_offset, in_dim);
219 }
220
221 *out_val = sum;
222 }
223}