1#![allow(missing_docs)]
2use super::{f16_to_f32, SUPER_BLOCK_BYTES, SUPER_BLOCK_SIZE};
8
/// Decodes one 6-bit quantized value at position `idx` within a super-block.
///
/// The low four bits live two-per-byte in `ql` (even index → low nibble),
/// the high two bits live four-per-byte in `qh`. The combined 6-bit unsigned
/// value is recentered from [0, 63] to [-32, 31].
#[inline(always)]
fn extract_q6k_scalar(ql: &[u8], qh: &[u8], idx: usize) -> i8 {
    let nibble = (ql[idx / 2] >> ((idx & 1) * 4)) & 0x0F;
    let hi_bits = (qh[idx / 4] >> ((idx & 3) * 2)) & 0x03;
    ((hi_bits << 4) | nibble) as i8 - 32
}
19
20#[inline(always)]
22fn process_q6k_superblock_scalar(
23 sb_data: &[u8],
24 input: &[f32],
25 input_offset: usize,
26 in_dim: usize,
27) -> f32 {
28 let ql = sb_data.get(0..128).expect("Q6_K: need ≥128 bytes for ql");
29 let qh = sb_data.get(128..192).expect("Q6_K: need ≥192 bytes for qh");
30 let scales = sb_data.get(192..208).expect("Q6_K: need ≥208 bytes for scales");
31 let d = f16_to_f32(u16::from_le_bytes([sb_data[208], sb_data[209]]));
32 let mut sum = 0.0f32;
33
34 for group in 0..16 {
35 let scale = (scales[group] as i8) as f32;
36 let group_offset = group * 16;
37
38 for j in 0..16 {
39 let idx = group_offset + j;
40 let input_idx = input_offset + idx;
41 if input_idx >= in_dim {
42 continue;
43 }
44 let q6 = extract_q6k_scalar(ql, qh, idx);
45 sum += d * scale * q6 as f32 * input[input_idx];
46 }
47 }
48 sum
49}
50
51pub fn matmul_q6k_f32_scalar(
52 q6k_data: &[u8],
53 input: &[f32],
54 out_dim: usize,
55 in_dim: usize,
56) -> Vec<f32> {
57 assert_eq!(input.len(), in_dim, "Input length mismatch");
58
59 let num_blocks_per_row = (in_dim + SUPER_BLOCK_SIZE - 1) / SUPER_BLOCK_SIZE;
60 let row_bytes = num_blocks_per_row * SUPER_BLOCK_BYTES;
61
62 let mut output: Vec<f32> = Vec::with_capacity(out_dim);
64 unsafe {
66 output.set_len(out_dim);
67 }
68
69 for out_idx in 0..out_dim {
70 let row_start = out_idx * row_bytes;
71 let mut sum = 0.0f32;
72
73 for sb_idx in 0..num_blocks_per_row {
74 let sb_start = row_start + sb_idx * SUPER_BLOCK_BYTES;
75 if sb_start + SUPER_BLOCK_BYTES > q6k_data.len() {
76 break;
77 }
78 let sb_data = &q6k_data[sb_start..sb_start + SUPER_BLOCK_BYTES];
79 let input_offset = sb_idx * SUPER_BLOCK_SIZE;
80 sum += process_q6k_superblock_scalar(sb_data, input, input_offset, in_dim);
81 }
82
83 output[out_idx] = sum;
84 }
85
86 output
87}
88
/// Decodes eight consecutive 6-bit values starting at `idx_base`, widened to
/// `i32` so the caller can load them straight into a SIMD register.
///
/// Bit packing is identical to `extract_q6k_scalar`: low nibble from `ql`
/// (two values per byte), high two bits from `qh` (four values per byte),
/// recentered by -32.
#[inline(always)]
fn extract_q6k_values(ql: &[u8], qh: &[u8], idx_base: usize) -> [i32; 8] {
    let mut out = [0i32; 8];
    for (i, slot) in out.iter_mut().enumerate() {
        let idx = idx_base + i;
        let nibble = (ql[idx / 2] >> ((idx & 1) * 4)) & 0x0F;
        let hi_bits = (qh[idx / 4] >> ((idx & 3) * 2)) & 0x03;
        *slot = i32::from((hi_bits << 4) | nibble) - 32;
    }
    out
}
104
/// Horizontal sum of the eight f32 lanes of an AVX2 accumulator.
///
/// # Safety
/// Caller must ensure the CPU supports AVX2 (guarded by runtime feature
/// detection at every call site in this file).
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn hsum_q6k_avx2(acc: std::arch::x86_64::__m256) -> f32 {
    use std::arch::x86_64::*;
    // Fold 256 -> 128 bits: add the upper 128-bit half onto the lower half.
    let hi128 = _mm256_extractf128_ps(acc, 1);
    let lo128 = _mm256_castps256_ps128(acc);
    let sum128 = _mm_add_ps(lo128, hi128);
    // Fold 128 -> 64 bits: add the upper two lanes onto the lower two.
    let hi64 = _mm_movehl_ps(sum128, sum128);
    let sum64 = _mm_add_ps(sum128, hi64);
    // Fold 64 -> 32 bits: add lane 1 onto lane 0 and extract.
    let hi32 = _mm_shuffle_ps(sum64, sum64, 1);
    let sum32 = _mm_add_ss(sum64, hi32);
    _mm_cvtss_f32(sum32)
}
120
/// Accumulates one Q6_K super-block's contribution to a row dot product into
/// `acc`, processing 8 f32 lanes per FMA.
///
/// `sb_data` layout (210 bytes): 128B low nibbles (`ql`) | 64B high bit pairs
/// (`qh`) | 16 signed 8-bit per-group scales | 2B little-endian f16 block
/// scale `d`. Groups of 8 that would read past `in_dim` are skipped, so a
/// trailing partial super-block may drop up to 7 tail elements here —
/// NOTE(review): the scalar path handles that tail element-by-element;
/// confirm callers guarantee `in_dim` is a multiple of 8 or accept the
/// discrepancy.
///
/// # Safety
/// Caller must ensure the CPU supports AVX2 and FMA.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
unsafe fn process_q6k_superblock_avx2(
    sb_data: &[u8],
    input: &[f32],
    input_offset: usize,
    in_dim: usize,
    acc: &mut std::arch::x86_64::__m256,
) {
    unsafe {
        use std::arch::x86_64::*;

        let ql = sb_data.get(0..128).expect("Q6_K: need ≥128 bytes for ql");
        let qh = sb_data.get(128..192).expect("Q6_K: need ≥192 bytes for qh");
        let scales = sb_data.get(192..208).expect("Q6_K: need ≥208 bytes for scales");
        let d = f16_to_f32(u16::from_le_bytes([sb_data[208], sb_data[209]]));
        let d_vec = _mm256_set1_ps(d);

        // 16 groups of 16 elements; each group has its own signed scale.
        for group in 0..16 {
            let scale = (scales[group] as i8) as f32;
            // Combined per-group multiplier d * scale, broadcast to all lanes.
            let ds_vec = _mm256_mul_ps(d_vec, _mm256_set1_ps(scale));
            let group_offset = group * 16;
            let input_group = input_offset + group_offset;

            // Two 8-lane halves per 16-element group.
            for half in 0..2 {
                let half_offset = half * 8;
                let input_base = input_group + half_offset;
                if input_base + 8 > in_dim {
                    continue;
                }

                let q6_vals = extract_q6k_values(ql, qh, group_offset + half_offset);
                let q6_i32 = _mm256_loadu_si256(q6_vals.as_ptr() as *const __m256i);
                let q6_f32 = _mm256_cvtepi32_ps(q6_i32);
                let x = _mm256_loadu_ps(input.as_ptr().add(input_base));
                let dequant = _mm256_mul_ps(ds_vec, q6_f32);
                // acc += (d * scale * q6) * input
                *acc = _mm256_fmadd_ps(dequant, x, *acc);
            }
        }
    }
}
164
165#[cfg(target_arch = "x86_64")]
170#[target_feature(enable = "avx2", enable = "fma")]
171unsafe fn matmul_q6k_f32_avx2(
173 q6k_data: &[u8],
174 input: &[f32],
175 out_dim: usize,
176 in_dim: usize,
177) -> Vec<f32> {
178 unsafe {
179 use std::arch::x86_64::*;
180
181 let num_blocks_per_row = (in_dim + SUPER_BLOCK_SIZE - 1) / SUPER_BLOCK_SIZE;
182 let row_bytes = num_blocks_per_row * SUPER_BLOCK_BYTES;
183
184 let mut output: Vec<f32> = Vec::with_capacity(out_dim);
186 output.set_len(out_dim);
188
189 for out_idx in 0..out_dim {
190 let row_start = out_idx * row_bytes;
191 let mut acc = _mm256_setzero_ps();
192
193 for sb_idx in 0..num_blocks_per_row {
194 let sb_start = row_start + sb_idx * SUPER_BLOCK_BYTES;
195 if sb_start + SUPER_BLOCK_BYTES > q6k_data.len() {
196 break;
197 }
198 let sb_data = &q6k_data[sb_start..sb_start + SUPER_BLOCK_BYTES];
199 let input_offset = sb_idx * SUPER_BLOCK_SIZE;
200 process_q6k_superblock_avx2(sb_data, input, input_offset, in_dim, &mut acc);
201 }
202
203 output[out_idx] = hsum_q6k_avx2(acc);
204 }
205
206 output
207 }
208}
209
210#[inline]
223pub fn matmul_q6k_f32_dispatch(
224 q6k_data: &[u8],
225 input: &[f32],
226 out_dim: usize,
227 in_dim: usize,
228) -> Vec<f32> {
229 debug_assert_eq!(input.len(), in_dim, "Q6K dispatch: input length mismatch");
233 debug_assert!(
234 q6k_data.len() >= crate::contracts::Q6_K.expected_bytes(out_dim, in_dim),
235 "Q6K dispatch: buffer too small: {} bytes for [{}, {}] (need {})",
236 q6k_data.len(),
237 out_dim,
238 in_dim,
239 crate::contracts::Q6_K.expected_bytes(out_dim, in_dim),
240 );
241
242 let total_work = out_dim * in_dim;
246 if total_work >= 8_000_000 {
247 return matmul_q6k_f32_parallel(q6k_data, input, out_dim, in_dim);
248 }
249
250 #[cfg(target_arch = "x86_64")]
251 {
252 if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
253 return unsafe { matmul_q6k_f32_avx2(q6k_data, input, out_dim, in_dim) };
255 }
256 }
257 matmul_q6k_f32_scalar(q6k_data, input, out_dim, in_dim)
258}
259
260#[cfg(target_arch = "x86_64")]
262fn matmul_q6k_f32_parallel(
263 q6k_data: &[u8],
264 input: &[f32],
265 out_dim: usize,
266 in_dim: usize,
267) -> Vec<f32> {
268 use std::thread;
269
270 let num_threads = thread::available_parallelism().map(|p| p.get()).unwrap_or(4).min(12); let chunk_size = (out_dim + num_threads - 1) / num_threads;
274 let num_blocks_per_row = (in_dim + SUPER_BLOCK_SIZE - 1) / SUPER_BLOCK_SIZE;
275 let row_bytes = num_blocks_per_row * SUPER_BLOCK_BYTES;
276
277 let mut output: Vec<f32> = Vec::with_capacity(out_dim);
279 unsafe {
281 output.set_len(out_dim);
282 }
283 let has_avx2 = is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma");
284
285 thread::scope(|s| {
286 let input_ref = input;
287 let q6k_ref = q6k_data;
288 for (chunk_idx, chunk) in output.chunks_mut(chunk_size).enumerate() {
290 let start_row = chunk_idx * chunk_size;
291
292 s.spawn(move || {
293 if has_avx2 {
294 unsafe {
297 compute_chunk_avx2(
298 q6k_ref,
299 input_ref,
300 chunk,
301 start_row,
302 out_dim,
303 in_dim,
304 num_blocks_per_row,
305 row_bytes,
306 );
307 }
308 } else {
309 compute_chunk_scalar(
310 q6k_ref,
311 input_ref,
312 chunk,
313 start_row,
314 out_dim,
315 in_dim,
316 num_blocks_per_row,
317 row_bytes,
318 );
319 }
320 });
321 }
322 });
323
324 output
325}
326
/// Non-x86_64 fallback: no threaded path implemented on these targets, so
/// delegate directly to the scalar kernel.
#[cfg(not(target_arch = "x86_64"))]
fn matmul_q6k_f32_parallel(
    q6k_data: &[u8],
    input: &[f32],
    out_dim: usize,
    in_dim: usize,
) -> Vec<f32> {
    matmul_q6k_f32_scalar(q6k_data, input, out_dim, in_dim)
}
337
/// Fills one contiguous chunk of output rows using the AVX2+FMA kernel.
/// Worker body for the parallel path; `start_row` is the global index of
/// `chunk[0]`.
///
/// # Safety
/// Caller must ensure the CPU supports AVX2 and FMA.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
unsafe fn compute_chunk_avx2(
    q6k_data: &[u8],
    input: &[f32],
    chunk: &mut [f32],
    start_row: usize,
    out_dim: usize,
    in_dim: usize,
    num_blocks_per_row: usize,
    row_bytes: usize,
) {
    unsafe {
        use std::arch::x86_64::*;

        for (local_idx, out_val) in chunk.iter_mut().enumerate() {
            let out_idx = start_row + local_idx;
            // Defensive: the last chunk cannot exceed out_dim via chunks_mut,
            // but guard anyway.
            if out_idx >= out_dim {
                break;
            }

            let row_start = out_idx * row_bytes;
            let mut acc = _mm256_setzero_ps();

            for sb_idx in 0..num_blocks_per_row {
                let sb_start = row_start + sb_idx * SUPER_BLOCK_BYTES;
                if sb_start + SUPER_BLOCK_BYTES > q6k_data.len() {
                    // Truncated buffer: stop accumulating this row.
                    break;
                }
                let sb_data = &q6k_data[sb_start..sb_start + SUPER_BLOCK_BYTES];
                let input_offset = sb_idx * SUPER_BLOCK_SIZE;
                process_q6k_superblock_avx2(sb_data, input, input_offset, in_dim, &mut acc);
            }

            *out_val = hsum_q6k_avx2(acc);
        }
    }
}
377
378pub(crate) fn compute_chunk_scalar(
379 q6k_data: &[u8],
380 input: &[f32],
381 chunk: &mut [f32],
382 start_row: usize,
383 out_dim: usize,
384 in_dim: usize,
385 num_blocks_per_row: usize,
386 row_bytes: usize,
387) {
388 for (local_idx, out_val) in chunk.iter_mut().enumerate() {
389 let out_idx = start_row + local_idx;
390 if out_idx >= out_dim {
391 break;
392 }
393
394 let row_start = out_idx * row_bytes;
395 let mut sum = 0.0f32;
396
397 for sb_idx in 0..num_blocks_per_row {
398 let sb_start = row_start + sb_idx * SUPER_BLOCK_BYTES;
399 if sb_start + SUPER_BLOCK_BYTES > q6k_data.len() {
400 break;
401 }
402 let sb_data = &q6k_data[sb_start..sb_start + SUPER_BLOCK_BYTES];
403 let input_offset = sb_idx * SUPER_BLOCK_SIZE;
404 sum += process_q6k_superblock_scalar(sb_data, input, input_offset, in_dim);
405 }
406
407 *out_val = sum;
408 }
409}
410
/// Public Q6_K × f32 matmul entry point; thin wrapper over the dispatcher,
/// which selects between the parallel, AVX2, and scalar kernels.
pub fn matmul_q6k_f32(q6k_data: &[u8], input: &[f32], out_dim: usize, in_dim: usize) -> Vec<f32> {
    matmul_q6k_f32_dispatch(q6k_data, input, out_dim, in_dim)
}