trueno/backends/q4k/gemv/avx512.rs

use super::super::{parse_q4k_header, SUPER_BLOCK_BYTES, SUPER_BLOCK_SIZE};

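/// AVX-512 GEMV over Q4_K-quantized weights: computes `output = W * input`
/// for a row-major quantized matrix of shape `out_dim x in_dim`.
///
/// Each 144-byte super-block encodes 256 weights as 4-bit values plus packed
/// scale/min pairs; a weight dequantizes as `w = (d * scale) * q - (dmin * min)`,
/// which the kernel fuses into the dot product instead of materializing `w`.
///
/// # Safety
///
/// The caller must ensure the CPU supports AVX-512F, AVX-512BW, and FMA, and
/// that `q4k_data` holds at least `out_dim` full rows of super-blocks.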
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f", enable = "avx512bw", enable = "fma")]
pub(crate) unsafe fn matmul_q4k_f32_avx512(
    q4k_data: &[u8],
    input: &[f32],
    out_dim: usize,
    in_dim: usize,
) -> Vec<f32> {
    unsafe {
        use std::arch::x86_64::*;

        let num_blocks_per_row = (in_dim + SUPER_BLOCK_SIZE - 1) / SUPER_BLOCK_SIZE;
        let row_bytes = num_blocks_per_row * SUPER_BLOCK_BYTES;
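        // 0x0F in each 32-bit lane isolates the low nibble of a quantized byte.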
        let low_mask = _mm512_set1_epi32(0x0F);

        let mut output: Vec<f32> = Vec::with_capacity(out_dim);
        // SAFETY: every element is written in the loop below before it is read.
        output.set_len(out_dim);

        for out_idx in 0..out_dim {
            let row_start = out_idx * row_bytes;
            let mut acc = _mm512_setzero_ps();

            for sb_idx in 0..num_blocks_per_row {
                let sb_start = row_start + sb_idx * SUPER_BLOCK_BYTES;
                let sb_data = &q4k_data[sb_start..sb_start + SUPER_BLOCK_BYTES];
                let input_offset = sb_idx * SUPER_BLOCK_SIZE;
                process_q4k_superblock_avx512(
                    sb_data,
                    input,
                    input_offset,
                    in_dim,
                    low_mask,
                    &mut acc,
                );
            }

            output[out_idx] = hsum_avx512(acc);
        }

        contract_post_dequant!(output);
        output
    }
}
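
// A minimal dispatch sketch (illustrative only; the crate's real entry point
// may differ): gate on runtime feature detection before calling the kernel.
//
//     if is_x86_feature_detected!("avx512f")
//         && is_x86_feature_detected!("avx512bw")
//         && is_x86_feature_detected!("fma")
//     {
//         let y = unsafe { matmul_q4k_f32_avx512(&weights, &x, out_dim, in_dim) };
//     }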
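
/// Accumulates one 256-weight super-block into `acc`.
///
/// The 128 quant bytes are walked in four 64-value chunks: the low nibble of
/// each byte covers the chunk's first 32 values and the high nibble the next
/// 32, each half with its own scale/min pair from the packed header.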
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f", enable = "avx512bw", enable = "fma")]
unsafe fn process_q4k_superblock_avx512(
    sb_data: &[u8],
    input: &[f32],
    input_offset: usize,
    in_dim: usize,
    low_mask: std::arch::x86_64::__m512i,
    acc: &mut std::arch::x86_64::__m512,
) {
    unsafe {
        use std::arch::x86_64::*;

        let (d, dmin, scales, mins) = parse_q4k_header(sb_data);
        let qs = sb_data.get(16..144).expect("Q4_K: need ≥144 bytes for qs");
        let qs_ptr = qs.as_ptr();
        let input_ptr = input.as_ptr();

        // Prefetch the next super-block while this one is processed.
        // `wrapping_add` is used because the target lies outside `sb_data`'s
        // bounds, which `add` does not permit; prefetching is only a hint, so
        // the address need not be valid.
        _mm_prefetch(sb_data.as_ptr().wrapping_add(SUPER_BLOCK_BYTES) as *const i8, _MM_HINT_T0);
        _mm_prefetch(sb_data.as_ptr().wrapping_add(SUPER_BLOCK_BYTES + 64) as *const i8, _MM_HINT_T0);

        // Each super-block holds 256 weights, processed as four 64-value chunks.
        for chunk_i in 0..4 {
            let chunk_start = chunk_i * 64;
            let q_start = chunk_i * 32;

            // Per-chunk affine dequant factors: w = (d * scale) * q - (dmin * min).
            let d1 = d * f32::from(scales[chunk_i * 2]);
            let dm1 = dmin * f32::from(mins[chunk_i * 2]);
            let d2 = d * f32::from(scales[chunk_i * 2 + 1]);
            let dm2 = dmin * f32::from(mins[chunk_i * 2 + 1]);

            let d1_vec = _mm512_set1_ps(d1);
            let dm1_vec = _mm512_set1_ps(dm1);
            let d2_vec = _mm512_set1_ps(d2);
            let dm2_vec = _mm512_set1_ps(dm2);

            let input_base_lo0 = input_offset + chunk_start;
            // The body below consumes a full 64-value chunk (32 low-nibble +
            // 32 high-nibble inputs), so guard the whole range; ragged tails
            // are skipped here.
            if input_base_lo0 + 64 <= in_dim {
                // Low nibbles of bytes 0..16 -> chunk values 0..16.
                let q0 = _mm_loadu_si128(qs_ptr.add(q_start) as *const __m128i);
                let q0_i32 = _mm512_cvtepu8_epi32(q0);
                let q0_low = _mm512_and_si512(q0_i32, low_mask);
                let q0_f32 = _mm512_cvtepi32_ps(q0_low);
                let x0 = _mm512_loadu_ps(input_ptr.add(input_base_lo0));
                // dq = d1*q - dm1, then acc += dq * x, both fused.
                let dq0 = _mm512_fmsub_ps(d1_vec, q0_f32, dm1_vec);
                *acc = _mm512_fmadd_ps(dq0, x0, *acc);

                // Low nibbles of bytes 16..32 -> chunk values 16..32.
                let q1 = _mm_loadu_si128(qs_ptr.add(q_start + 16) as *const __m128i);
                let q1_i32 = _mm512_cvtepu8_epi32(q1);
                let q1_low = _mm512_and_si512(q1_i32, low_mask);
                let q1_f32 = _mm512_cvtepi32_ps(q1_low);
                let x1 = _mm512_loadu_ps(input_ptr.add(input_base_lo0 + 16));
                let dq1 = _mm512_fmsub_ps(d1_vec, q1_f32, dm1_vec);
                *acc = _mm512_fmadd_ps(dq1, x1, *acc);

                let input_base_hi0 = input_offset + chunk_start + 32;

                // High nibbles of the same bytes -> chunk values 32..64, using
                // the second scale/min pair. No mask is needed: each lane holds
                // a zero-extended byte, so `srli 4` leaves only 0..15.
                let q0_high = _mm512_srli_epi32(q0_i32, 4);
                let q0h_f32 = _mm512_cvtepi32_ps(q0_high);
                let xh0 = _mm512_loadu_ps(input_ptr.add(input_base_hi0));
                let dqh0 = _mm512_fmsub_ps(d2_vec, q0h_f32, dm2_vec);
                *acc = _mm512_fmadd_ps(dqh0, xh0, *acc);

                let q1_high = _mm512_srli_epi32(q1_i32, 4);
                let q1h_f32 = _mm512_cvtepi32_ps(q1_high);
                let xh1 = _mm512_loadu_ps(input_ptr.add(input_base_hi0 + 16));
                let dqh1 = _mm512_fmsub_ps(d2_vec, q1h_f32, dm2_vec);
                *acc = _mm512_fmadd_ps(dqh1, xh1, *acc);
            }
        }
    }
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
unsafe fn hsum_avx512(v: std::arch::x86_64::__m512) -> f32 {
    use std::arch::x86_64::*;
    // Pairwise horizontal reduction: fold 512 -> 256 -> 128 bits, then
    // collapse the final 128-bit lane down to a single f32.
    let lo256 = _mm512_castps512_ps256(v);
    let hi_shifted = _mm512_shuffle_f32x4(v, v, 0b_01_00_11_10);
    let hi256 = _mm512_castps512_ps256(hi_shifted);
    let sum256 = _mm256_add_ps(lo256, hi256);
    let hi128 = _mm256_extractf128_ps(sum256, 1);
    let lo128 = _mm256_castps256_ps128(sum256);
    let sum128 = _mm_add_ps(lo128, hi128);
    let hi64 = _mm_movehl_ps(sum128, sum128);
    let sum64 = _mm_add_ps(sum128, hi64);
    let hi32 = _mm_shuffle_ps(sum64, sum64, 1);
    let sum32 = _mm_add_ss(sum64, hi32);
    _mm_cvtss_f32(sum32)
}

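/// Row-chunked variant of `matmul_q4k_f32_avx512`: fills `chunk` with the
/// results for rows `start_row..start_row + chunk.len()` (clamped to
/// `out_dim`), so callers can split the output across worker threads.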
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f", enable = "avx512bw", enable = "fma")]
pub(crate) unsafe fn compute_chunk_q4k_avx512(
    q4k_data: &[u8],
    input: &[f32],
    chunk: &mut [f32],
    start_row: usize,
    out_dim: usize,
    in_dim: usize,
    num_blocks_per_row: usize,
    row_bytes: usize,
) {
    unsafe {
        use std::arch::x86_64::*;

        let low_mask = _mm512_set1_epi32(0x0F);

        for (local_idx, out_val) in chunk.iter_mut().enumerate() {
            let out_idx = start_row + local_idx;
            if out_idx >= out_dim {
                break;
            }
            let row_start = out_idx * row_bytes;
            let mut acc = _mm512_setzero_ps();

            for sb_idx in 0..num_blocks_per_row {
                let sb_start = row_start + sb_idx * SUPER_BLOCK_BYTES;
                let sb_data = &q4k_data[sb_start..sb_start + SUPER_BLOCK_BYTES];
                let input_offset = sb_idx * SUPER_BLOCK_SIZE;
                process_q4k_superblock_avx512(
                    sb_data,
                    input,
                    input_offset,
                    in_dim,
                    low_mask,
                    &mut acc,
                );
            }

            *out_val = hsum_avx512(acc);
        }
    }
    contract_post_dequant!(chunk);
}
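
// A hypothetical parallel driver over `compute_chunk_q4k_avx512` (sketch only;
// assumes a rayon-style `par_chunks_mut`, which this crate may not use):
//
//     output.par_chunks_mut(rows_per_chunk).enumerate().for_each(|(i, chunk)| {
//         unsafe {
//             compute_chunk_q4k_avx512(
//                 q4k_data, input, chunk,
//                 i * rows_per_chunk, out_dim, in_dim,
//                 num_blocks_per_row, row_bytes,
//             );
//         }
//     });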