
trueno/backends/q4k/gemv/avx2.rs

//! AVX2 SIMD Q4_K GEMV implementations.
//!
//! Contains the 8-wide AVX2+FMA optimized GEMV, the super-block processor,
//! a horizontal-sum helper, and the AVX2 chunk processor for parallel dispatch.

use super::super::{parse_q4k_header, SUPER_BLOCK_BYTES, SUPER_BLOCK_SIZE};

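// Q4_K layout as this module assumes it (a recap, not authoritative): each
// super-block packs SUPER_BLOCK_SIZE = 256 values into SUPER_BLOCK_BYTES = 144
// bytes: an f16 super-scale `d`, an f16 super-min `dmin`, 12 bytes of packed
// per-sub-block scales/mins (unpacked by `parse_q4k_header`), and 128 bytes of
// 4-bit quants. Dequantization is value = d*scale*q - dmin*min, which maps to
// the single fmsub in `process_q4k_superblock_avx2` below.
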
/// Fused Q4_K matrix-vector multiply with AVX2 SIMD (8-wide).
///
/// Processes 8 elements at a time using AVX2 intrinsics.
/// Delegates per-super-block work to `process_q4k_superblock_avx2`.
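///
/// # Example
///
/// A minimal call-site sketch (hypothetical caller; assumes runtime feature
/// detection before dispatching here):
///
/// ```ignore
/// if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
///     // SAFETY: AVX2+FMA verified above; `weights` must hold `out_dim` rows
///     // of ceil(in_dim / SUPER_BLOCK_SIZE) super-blocks each.
///     let y = unsafe { matmul_q4k_f32_avx2(&weights, &x, out_dim, in_dim) };
/// }
/// ```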
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
// SAFETY: Caller ensures AVX2+FMA are available and q4k_data is a valid Q4_K layout.
pub(crate) unsafe fn matmul_q4k_f32_avx2(
    q4k_data: &[u8],
    input: &[f32],
    out_dim: usize,
    in_dim: usize,
) -> Vec<f32> {
    unsafe {
        use std::arch::x86_64::*;

        let num_blocks_per_row = (in_dim + SUPER_BLOCK_SIZE - 1) / SUPER_BLOCK_SIZE;
        let row_bytes = num_blocks_per_row * SUPER_BLOCK_BYTES;
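        // 0x0F in every 32-bit lane: isolates the low nibble of each zero-extended quant byte.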
        let low_mask = _mm256_set1_epi32(0x0F);

        // Deferred initialization: the loop below sets output[out_idx] = hsum_avx2(acc)
        // for every out_idx, so no element is read before it is written.
        let mut output: Vec<f32> = Vec::with_capacity(out_dim);
        // SAFETY: each output[out_idx] is written from the local SIMD accumulator
        // before the vector is returned.
        output.set_len(out_dim);

        for out_idx in 0..out_dim {
            let row_start = out_idx * row_bytes;
            let mut acc = _mm256_setzero_ps();

            for sb_idx in 0..num_blocks_per_row {
                let sb_start = row_start + sb_idx * SUPER_BLOCK_BYTES;
                let sb_data = &q4k_data[sb_start..sb_start + SUPER_BLOCK_BYTES];
                let input_offset = sb_idx * SUPER_BLOCK_SIZE;
                process_q4k_superblock_avx2(
                    sb_data,
                    input,
                    input_offset,
                    in_dim,
                    low_mask,
                    &mut acc,
                );
            }

            output[out_idx] = hsum_avx2(acc);
        }

        output
    }
}

/// Process one Q4_K super-block row with AVX2 and accumulate into `acc`.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
// SAFETY: Caller ensures AVX2+FMA are available and sb_data is a valid super-block.
pub(crate) unsafe fn process_q4k_superblock_avx2(
    sb_data: &[u8],
    input: &[f32],
    input_offset: usize,
    in_dim: usize,
    low_mask: std::arch::x86_64::__m256i,
    acc: &mut std::arch::x86_64::__m256,
) {
    unsafe {
        use std::arch::x86_64::*;

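        // Header (first 16 bytes): f16 super-scale d, f16 super-min dmin, and the
        // packed per-sub-block scales/mins, all decoded by parse_q4k_header.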
        let (d, dmin, scales, mins) = parse_q4k_header(sb_data);
        let qs = sb_data.get(16..144).expect("Q4_K: need ≥144 bytes for qs");
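        // qs: 128 bytes of packed quants; each byte carries two 4-bit values,
        // and each 32-byte slice covers one 64-value chunk.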

        for chunk_i in 0..4 {
            let chunk_start = chunk_i * 64;
            let q_start = chunk_i * 32;

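            // Two (scale, min) pairs per 64-value chunk: pair 2*chunk_i applies to
            // the low nibbles, pair 2*chunk_i + 1 to the high nibbles.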
            let d1 = d * f32::from(scales[chunk_i * 2]);
            let dm1 = dmin * f32::from(mins[chunk_i * 2]);
            let d2 = d * f32::from(scales[chunk_i * 2 + 1]);
            let dm2 = dmin * f32::from(mins[chunk_i * 2 + 1]);

            let d1_vec = _mm256_set1_ps(d1);
            let dm1_vec = _mm256_set1_ps(dm1);
            let d2_vec = _mm256_set1_ps(d2);
            let dm2_vec = _mm256_set1_ps(dm2);

            // Process low nibbles (32 values) in groups of 8
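            // Pipeline per step: load 8 bytes, zero-extend to 8×i32, mask the low
            // nibble, convert to f32, dequantize with one fmsub, accumulate with one fmadd.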
            let mut i = 0;
            while i + 8 <= 32 {
                let input_base = input_offset + chunk_start + i;
                if input_base + 8 <= in_dim {
                    let q_bytes = _mm_loadl_epi64(qs.as_ptr().add(q_start + i) as *const __m128i);
                    let q_i32 = _mm256_cvtepu8_epi32(q_bytes);
                    let q_low = _mm256_and_si256(q_i32, low_mask);
                    let q_f32 = _mm256_cvtepi32_ps(q_low);
                    let x = _mm256_loadu_ps(input.as_ptr().add(input_base));
                    let dequant = _mm256_fmsub_ps(d1_vec, q_f32, dm1_vec);
                    *acc = _mm256_fmadd_ps(dequant, x, *acc);
                }
                i += 8;
            }

            // Process high nibbles (32 values) in groups of 8
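            // Same pipeline; the high nibble is extracted with a logical right shift,
            // and no mask is needed because _mm256_cvtepu8_epi32 zero-extends.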
            let mut i = 0;
            while i + 8 <= 32 {
                let input_base = input_offset + chunk_start + 32 + i;
                if input_base + 8 <= in_dim {
                    let q_bytes = _mm_loadl_epi64(qs.as_ptr().add(q_start + i) as *const __m128i);
                    let q_i32 = _mm256_cvtepu8_epi32(q_bytes);
                    let q_high = _mm256_srli_epi32(q_i32, 4);
                    let q_f32 = _mm256_cvtepi32_ps(q_high);
                    let x = _mm256_loadu_ps(input.as_ptr().add(input_base));
                    let dequant = _mm256_fmsub_ps(d2_vec, q_f32, dm2_vec);
                    *acc = _mm256_fmadd_ps(dequant, x, *acc);
                }
                i += 8;
            }
        }
    }
}

/// AVX2 horizontal sum of 8 f32 lanes to a single f32.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
// SAFETY: Caller ensures AVX2 is available; this operates on a register value only.
pub(crate) unsafe fn hsum_avx2(acc: std::arch::x86_64::__m256) -> f32 {
    use std::arch::x86_64::*;
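    // Reduce 8 lanes -> 4 -> 2 -> 1: add the high 128-bit half onto the low half,
    // then fold twice within the 128-bit register.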
    let hi128 = _mm256_extractf128_ps(acc, 1);
    let lo128 = _mm256_castps256_ps128(acc);
    let sum128 = _mm_add_ps(lo128, hi128);
    let hi64 = _mm_movehl_ps(sum128, sum128);
    let sum64 = _mm_add_ps(sum128, hi64);
    let hi32 = _mm_shuffle_ps(sum64, sum64, 1);
    let sum32 = _mm_add_ss(sum64, hi32);
    _mm_cvtss_f32(sum32)
}

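/// AVX2 chunk processor for parallel dispatch: fills `chunk` with the output
/// rows starting at `start_row`, stopping early at `out_dim`.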
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
// SAFETY: Caller ensures AVX2+FMA are available and chunk bounds are valid.
pub(crate) unsafe fn compute_chunk_q4k_avx2(
    q4k_data: &[u8],
    input: &[f32],
    chunk: &mut [f32],
    start_row: usize,
    out_dim: usize,
    in_dim: usize,
    num_blocks_per_row: usize,
    row_bytes: usize,
) {
    unsafe {
        use std::arch::x86_64::*;

        let low_mask = _mm256_set1_epi32(0x0F);

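        // One SIMD accumulator per output row; rows past out_dim (a ragged final
        // chunk) are left untouched.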
        for (local_idx, out_val) in chunk.iter_mut().enumerate() {
            let out_idx = start_row + local_idx;
            if out_idx >= out_dim {
                break;
            }

            let row_start = out_idx * row_bytes;
            let mut acc = _mm256_setzero_ps();

            for sb_idx in 0..num_blocks_per_row {
                let sb_start = row_start + sb_idx * SUPER_BLOCK_BYTES;
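                // Unlike matmul_q4k_f32_avx2, guard against a q4k_data buffer
                // shorter than the requested rows.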
                if sb_start + SUPER_BLOCK_BYTES > q4k_data.len() {
                    break;
                }
                let sb_data = &q4k_data[sb_start..sb_start + SUPER_BLOCK_BYTES];
                let input_offset = sb_idx * SUPER_BLOCK_SIZE;
                process_q4k_superblock_avx2(
                    sb_data,
                    input,
                    input_offset,
                    in_dim,
                    low_mask,
                    &mut acc,
                );
            }

            *out_val = hsum_avx2(acc);
        }
    }
}