// trueno/backends/q4k/gemv/avx2.rs
use super::super::{parse_q4k_header, SUPER_BLOCK_BYTES, SUPER_BLOCK_SIZE};

8#[cfg(target_arch = "x86_64")]
13#[target_feature(enable = "avx2", enable = "fma")]
14pub(crate) unsafe fn matmul_q4k_f32_avx2(
16 q4k_data: &[u8],
17 input: &[f32],
18 out_dim: usize,
19 in_dim: usize,
20) -> Vec<f32> {
21 unsafe {
22 use std::arch::x86_64::*;
23
24 let num_blocks_per_row = (in_dim + SUPER_BLOCK_SIZE - 1) / SUPER_BLOCK_SIZE;
25 let row_bytes = num_blocks_per_row * SUPER_BLOCK_BYTES;
26 let low_mask = _mm256_set1_epi32(0x0F);
27
28 let mut output: Vec<f32> = Vec::with_capacity(out_dim);
30 output.set_len(out_dim);
32
33 for out_idx in 0..out_dim {
34 let row_start = out_idx * row_bytes;
35 let mut acc = _mm256_setzero_ps();
36
37 for sb_idx in 0..num_blocks_per_row {
38 let sb_start = row_start + sb_idx * SUPER_BLOCK_BYTES;
39 let sb_data = &q4k_data[sb_start..sb_start + SUPER_BLOCK_BYTES];
40 let input_offset = sb_idx * SUPER_BLOCK_SIZE;
41 process_q4k_superblock_avx2(
42 sb_data,
43 input,
44 input_offset,
45 in_dim,
46 low_mask,
47 &mut acc,
48 );
49 }
50
51 output[out_idx] = hsum_avx2(acc);
52 }
53
54 output
55 }
56}
57
58#[cfg(target_arch = "x86_64")]
60#[target_feature(enable = "avx2", enable = "fma")]
61pub(crate) unsafe fn process_q4k_superblock_avx2(
63 sb_data: &[u8],
64 input: &[f32],
65 input_offset: usize,
66 in_dim: usize,
67 low_mask: std::arch::x86_64::__m256i,
68 acc: &mut std::arch::x86_64::__m256,
69) {
70 unsafe {
71 use std::arch::x86_64::*;
72
73 let (d, dmin, scales, mins) = parse_q4k_header(sb_data);
74 let qs = sb_data.get(16..144).expect("Q4_K: need ≥144 bytes for qs");
75
76 for chunk_i in 0..4 {
77 let chunk_start = chunk_i * 64;
78 let q_start = chunk_i * 32;
79
80 let d1 = d * f32::from(scales[chunk_i * 2]);
81 let dm1 = dmin * f32::from(mins[chunk_i * 2]);
82 let d2 = d * f32::from(scales[chunk_i * 2 + 1]);
83 let dm2 = dmin * f32::from(mins[chunk_i * 2 + 1]);
84
85 let d1_vec = _mm256_set1_ps(d1);
86 let dm1_vec = _mm256_set1_ps(dm1);
87 let d2_vec = _mm256_set1_ps(d2);
88 let dm2_vec = _mm256_set1_ps(dm2);
89
90 let mut i = 0;
92 while i + 8 <= 32 {
93 let input_base = input_offset + chunk_start + i;
94 if input_base + 8 <= in_dim {
95 let q_bytes = _mm_loadl_epi64(qs.as_ptr().add(q_start + i) as *const __m128i);
96 let q_i32 = _mm256_cvtepu8_epi32(q_bytes);
97 let q_low = _mm256_and_si256(q_i32, low_mask);
98 let q_f32 = _mm256_cvtepi32_ps(q_low);
99 let x = _mm256_loadu_ps(input.as_ptr().add(input_base));
100 let dequant = _mm256_fmsub_ps(d1_vec, q_f32, dm1_vec);
101 *acc = _mm256_fmadd_ps(dequant, x, *acc);
102 }
103 i += 8;
104 }
105
106 let mut i = 0;
108 while i + 8 <= 32 {
109 let input_base = input_offset + chunk_start + 32 + i;
110 if input_base + 8 <= in_dim {
111 let q_bytes = _mm_loadl_epi64(qs.as_ptr().add(q_start + i) as *const __m128i);
112 let q_i32 = _mm256_cvtepu8_epi32(q_bytes);
113 let q_high = _mm256_srli_epi32(q_i32, 4);
114 let q_f32 = _mm256_cvtepi32_ps(q_high);
115 let x = _mm256_loadu_ps(input.as_ptr().add(input_base));
116 let dequant = _mm256_fmsub_ps(d2_vec, q_f32, dm2_vec);
117 *acc = _mm256_fmadd_ps(dequant, x, *acc);
118 }
119 i += 8;
120 }
121 }
122 }
123}
124
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
/// Horizontal sum of all 8 lanes of `acc`.
///
/// Pairwise tree reduction (8 → 4 → 2 → 1 additions), so the floating-point
/// summation order is fixed and deterministic.
///
/// # Safety
/// Caller must ensure the CPU supports AVX2.
pub(crate) unsafe fn hsum_avx2(acc: std::arch::x86_64::__m256) -> f32 {
    use std::arch::x86_64::*;
    // Fold upper 128-bit half onto the lower half: 8 lanes -> 4.
    let upper = _mm256_extractf128_ps(acc, 1);
    let lower = _mm256_castps256_ps128(acc);
    let quad = _mm_add_ps(lower, upper);
    // 4 lanes -> 2: add the high pair onto the low pair.
    let high_pair = _mm_movehl_ps(quad, quad);
    let pair = _mm_add_ps(quad, high_pair);
    // 2 lanes -> 1: add lane 1 onto lane 0.
    let odd = _mm_shuffle_ps(pair, pair, 1);
    let total = _mm_add_ss(pair, odd);
    _mm_cvtss_f32(total)
}
140
141#[cfg(target_arch = "x86_64")]
142#[target_feature(enable = "avx2", enable = "fma")]
143pub(crate) unsafe fn compute_chunk_q4k_avx2(
145 q4k_data: &[u8],
146 input: &[f32],
147 chunk: &mut [f32],
148 start_row: usize,
149 out_dim: usize,
150 in_dim: usize,
151 num_blocks_per_row: usize,
152 row_bytes: usize,
153) {
154 unsafe {
155 use std::arch::x86_64::*;
156
157 let low_mask = _mm256_set1_epi32(0x0F);
158
159 for (local_idx, out_val) in chunk.iter_mut().enumerate() {
160 let out_idx = start_row + local_idx;
161 if out_idx >= out_dim {
162 break;
163 }
164
165 let row_start = out_idx * row_bytes;
166 let mut acc = _mm256_setzero_ps();
167
168 for sb_idx in 0..num_blocks_per_row {
169 let sb_start = row_start + sb_idx * SUPER_BLOCK_BYTES;
170 if sb_start + SUPER_BLOCK_BYTES > q4k_data.len() {
171 break;
172 }
173 let sb_data = &q4k_data[sb_start..sb_start + SUPER_BLOCK_BYTES];
174 let input_offset = sb_idx * SUPER_BLOCK_SIZE;
175 process_q4k_superblock_avx2(
176 sb_data,
177 input,
178 input_offset,
179 in_dim,
180 low_mask,
181 &mut acc,
182 );
183 }
184
185 *out_val = hsum_avx2(acc);
186 }
187 }
188}