1use super::{MR, NR};
13#[cfg(target_arch = "x86_64")]
14use super::{MR_512, NR_512};
15
16pub fn pack_a(
28 a: &[f32],
29 lda: usize, mc: usize, kc: usize, packed: &mut [f32],
33) {
34 let mut pack_idx = 0;
35
36 let full_panels = mc / MR;
38 let remainder = mc % MR;
39
40 for panel in 0..full_panels {
41 let row_start = panel * MR;
42
43 for col in 0..kc {
44 for row in 0..MR {
45 packed[pack_idx] = a[(row_start + row) * lda + col];
46 pack_idx += 1;
47 }
48 }
49 }
50
51 if remainder > 0 {
53 let row_start = full_panels * MR;
54
55 for col in 0..kc {
56 for row in 0..MR {
57 if row < remainder {
58 packed[pack_idx] = a[(row_start + row) * lda + col];
59 } else {
60 packed[pack_idx] = 0.0; }
62 pack_idx += 1;
63 }
64 }
65 }
66}
67
68pub fn pack_b(
75 b: &[f32],
76 ldb: usize, kc: usize, nc: usize, packed: &mut [f32],
80) {
81 let mut pack_idx = 0;
82
83 let full_panels = nc / NR;
84 let remainder = nc % NR;
85
86 for panel in 0..full_panels {
87 let col_start = panel * NR;
88
89 for row in 0..kc {
90 for col in 0..NR {
91 packed[pack_idx] = b[row * ldb + col_start + col];
92 pack_idx += 1;
93 }
94 }
95 }
96
97 if remainder > 0 {
99 let col_start = full_panels * NR;
100
101 for row in 0..kc {
102 for col in 0..NR {
103 if col < remainder {
104 packed[pack_idx] = b[row * ldb + col_start + col];
105 } else {
106 packed[pack_idx] = 0.0;
107 }
108 pack_idx += 1;
109 }
110 }
111 }
112}
113
114#[inline]
116pub fn packed_a_size(mc: usize, kc: usize) -> usize {
117 let panels = (mc + MR - 1) / MR;
118 panels * MR * kc
119}
120
121#[inline]
123pub fn packed_b_size(kc: usize, nc: usize) -> usize {
124 let panels = (nc + NR - 1) / NR;
125 panels * NR * kc
126}
127
/// Packs a `rows x cols` sub-block of the row-major matrix `a`
/// (leading dimension `lda`, origin at `row_start`/`col_start`) into
/// MR-row panels, column-major within each panel. A short final panel
/// is zero-padded up to MR rows.
pub(super) fn pack_a_block(
    a: &[f32],
    lda: usize,
    row_start: usize,
    col_start: usize,
    rows: usize,
    cols: usize,
    packed: &mut [f32],
) {
    // Ceiling division: panel count including a possible partial panel.
    let panels = (rows + MR - 1) / MR;

    // Runtime dispatch to the AVX2 transpose-based packer on x86_64.
    #[cfg(target_arch = "x86_64")]
    if is_x86_feature_detected!("avx2") {
        unsafe {
            // SAFETY: AVX2 availability was just verified at runtime.
            pack_a_block_avx2(a, lda, row_start, col_start, rows, cols, panels, packed);
        }
        return;
    }

    // Scalar fallback: panel -> column -> row-within-panel order.
    let mut pack_idx = 0;
    for panel in 0..panels {
        let ir = panel * MR;
        let mr_actual = MR.min(rows - ir);

        for col in 0..cols {
            for row in 0..MR {
                if row < mr_actual {
                    packed[pack_idx] = a[(row_start + ir + row) * lda + col_start + col];
                } else {
                    // Zero-pad missing rows so every panel is MR tall.
                    packed[pack_idx] = 0.0;
                }
                pack_idx += 1;
            }
        }
    }
}
168
/// AVX2 variant of [`pack_a_block`]: packs full MR-row panels via an
/// 8x8 in-register transpose (8 rows of 8 columns per step), falling
/// back to scalar copies for the column remainder and for partial
/// panels.
///
/// NOTE(review): the transpose path loads exactly 8 rows and stores
/// 64 floats per step, so it assumes MR == 8 — confirm against the
/// constants in the parent module.
///
/// # Safety
///
/// Caller must ensure AVX2 is available on the running CPU (checked by
/// `pack_a_block` before dispatching here).
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn pack_a_block_avx2(
    a: &[f32],
    lda: usize,
    row_start: usize,
    col_start: usize,
    rows: usize,
    cols: usize,
    panels: usize,
    packed: &mut [f32],
) {
    use std::arch::x86_64::*;

    let mut pack_idx = 0;

    for panel in 0..panels {
        let ir = panel * MR;
        let mr_actual = MR.min(rows - ir);

        if mr_actual == MR {
            // Fast path: process 8 columns at a time with a transpose.
            let k_blocks = cols / 8;
            let k_rem_start = k_blocks * 8;

            for kb in 0..k_blocks {
                let p = kb * 8;
                let base = row_start + ir;
                let col = col_start + p;

                unsafe {
                    // Load 8 rows x 8 columns of A (row-major).
                    let r0 = _mm256_loadu_ps(a.as_ptr().add(base * lda + col));
                    let r1 = _mm256_loadu_ps(a.as_ptr().add((base + 1) * lda + col));
                    let r2 = _mm256_loadu_ps(a.as_ptr().add((base + 2) * lda + col));
                    let r3 = _mm256_loadu_ps(a.as_ptr().add((base + 3) * lda + col));
                    let r4 = _mm256_loadu_ps(a.as_ptr().add((base + 4) * lda + col));
                    let r5 = _mm256_loadu_ps(a.as_ptr().add((base + 5) * lda + col));
                    let r6 = _mm256_loadu_ps(a.as_ptr().add((base + 6) * lda + col));
                    let r7 = _mm256_loadu_ps(a.as_ptr().add((base + 7) * lda + col));

                    // Stage 1 of the 8x8 transpose: interleave row pairs.
                    let t0 = _mm256_unpacklo_ps(r0, r1);
                    let t1 = _mm256_unpackhi_ps(r0, r1);
                    let t2 = _mm256_unpacklo_ps(r2, r3);
                    let t3 = _mm256_unpackhi_ps(r2, r3);
                    let t4 = _mm256_unpacklo_ps(r4, r5);
                    let t5 = _mm256_unpackhi_ps(r4, r5);
                    let t6 = _mm256_unpacklo_ps(r6, r7);
                    let t7 = _mm256_unpackhi_ps(r6, r7);

                    // Stage 2: shuffle interleaved pairs into 4-wide groups.
                    let u0 = _mm256_shuffle_ps(t0, t2, 0x44);
                    let u1 = _mm256_shuffle_ps(t0, t2, 0xEE);
                    let u2 = _mm256_shuffle_ps(t1, t3, 0x44);
                    let u3 = _mm256_shuffle_ps(t1, t3, 0xEE);
                    let u4 = _mm256_shuffle_ps(t4, t6, 0x44);
                    let u5 = _mm256_shuffle_ps(t4, t6, 0xEE);
                    let u6 = _mm256_shuffle_ps(t5, t7, 0x44);
                    let u7 = _mm256_shuffle_ps(t5, t7, 0xEE);

                    // Stage 3: 128-bit lane permutes complete the transpose;
                    // each store writes one column's 8 rows contiguously.
                    let dst = packed.as_mut_ptr().add(pack_idx);
                    _mm256_storeu_ps(dst, _mm256_permute2f128_ps(u0, u4, 0x20));
                    _mm256_storeu_ps(dst.add(8), _mm256_permute2f128_ps(u1, u5, 0x20));
                    _mm256_storeu_ps(dst.add(16), _mm256_permute2f128_ps(u2, u6, 0x20));
                    _mm256_storeu_ps(dst.add(24), _mm256_permute2f128_ps(u3, u7, 0x20));
                    _mm256_storeu_ps(dst.add(32), _mm256_permute2f128_ps(u0, u4, 0x31));
                    _mm256_storeu_ps(dst.add(40), _mm256_permute2f128_ps(u1, u5, 0x31));
                    _mm256_storeu_ps(dst.add(48), _mm256_permute2f128_ps(u2, u6, 0x31));
                    _mm256_storeu_ps(dst.add(56), _mm256_permute2f128_ps(u3, u7, 0x31));
                }
                // 8 columns x 8 rows written per transpose step.
                pack_idx += 64;
            }

            // Scalar copy for the <8 leftover columns of a full panel.
            for col in k_rem_start..cols {
                for row in 0..MR {
                    packed[pack_idx] = a[(row_start + ir + row) * lda + col_start + col];
                    pack_idx += 1;
                }
            }
        } else {
            // Partial panel: scalar copy with zero-padding to MR rows.
            for col in 0..cols {
                for row in 0..MR {
                    if row < mr_actual {
                        packed[pack_idx] = a[(row_start + ir + row) * lda + col_start + col];
                    } else {
                        packed[pack_idx] = 0.0;
                    }
                    pack_idx += 1;
                }
            }
        }
    }
}
270
271pub(super) fn pack_b_block(
273 b: &[f32],
274 ldb: usize,
275 row_start: usize,
276 col_start: usize,
277 rows: usize,
278 cols: usize,
279 packed: &mut [f32],
280) {
281 let mut pack_idx = 0;
282 let panels = (cols + NR - 1) / NR;
283
284 for panel in 0..panels {
285 let jr = panel * NR;
286 let nr_actual = NR.min(cols - jr);
287
288 for row in 0..rows {
289 for col in 0..NR {
290 if col < nr_actual {
291 packed[pack_idx] = b[(row_start + row) * ldb + col_start + jr + col];
292 } else {
293 packed[pack_idx] = 0.0;
294 }
295 pack_idx += 1;
296 }
297 }
298 }
299}
300
301#[cfg(target_arch = "x86_64")]
307#[inline]
308pub fn packed_a_size_512(mc: usize, kc: usize) -> usize {
309 let panels = (mc + MR_512 - 1) / MR_512;
310 panels * MR_512 * kc
311}
312
313#[cfg(target_arch = "x86_64")]
315#[inline]
316pub fn packed_b_size_512(kc: usize, nc: usize) -> usize {
317 let panels = (nc + NR_512 - 1) / NR_512;
318 panels * NR_512 * kc
319}
320
321#[cfg(target_arch = "x86_64")]
326#[allow(dead_code)] pub(super) fn pack_a_block_512(
328 a: &[f32],
329 lda: usize,
330 row_start: usize,
331 col_start: usize,
332 rows: usize,
333 cols: usize,
334 packed: &mut [f32],
335) {
336 let mut pack_idx = 0;
337 let panels = (rows + MR_512 - 1) / MR_512;
338
339 for panel in 0..panels {
340 let ir = panel * MR_512;
341 let mr_actual = MR_512.min(rows - ir);
342
343 if mr_actual == MR_512 {
344 for col in 0..cols {
346 for row in 0..MR_512 {
347 packed[pack_idx + row] = a[(row_start + ir + row) * lda + col_start + col];
348 }
349 pack_idx += MR_512;
350 }
351 } else {
352 for col in 0..cols {
354 for row in 0..mr_actual {
355 packed[pack_idx + row] = a[(row_start + ir + row) * lda + col_start + col];
356 }
357 for row in mr_actual..MR_512 {
358 packed[pack_idx + row] = 0.0;
359 }
360 pack_idx += MR_512;
361 }
362 }
363 }
364}
365
/// Packs a `rows x cols` sub-block of `b` into NR_512-column panels,
/// row-major within each panel, zero-padding a short final panel up to
/// NR_512 columns. Full panels are copied with 256-bit loads/stores
/// when AVX2 is available.
///
/// NOTE(review): the SIMD path moves exactly 8 floats per row, so it
/// assumes NR_512 == 8 — confirm against the constants in the parent
/// module.
#[cfg(target_arch = "x86_64")]
pub(super) fn pack_b_block_512(
    b: &[f32],
    ldb: usize,
    row_start: usize,
    col_start: usize,
    rows: usize,
    cols: usize,
    packed: &mut [f32],
) {
    // Ceiling division: panel count including a possible partial panel.
    let panels = (cols + NR_512 - 1) / NR_512;
    // Detect once, outside the panel loop.
    let use_simd = is_x86_feature_detected!("avx2");

    for panel in 0..panels {
        let jr = panel * NR_512;
        let nr_actual = NR_512.min(cols - jr);
        // Destination offset of this panel in the packed buffer.
        let dst_base = panel * NR_512 * rows;

        if nr_actual == NR_512 && use_simd {
            unsafe {
                // SAFETY: AVX2 was verified above; each row copy moves
                // 8 contiguous floats within the slices' bounds.
                use std::arch::x86_64::*;
                for row in 0..rows {
                    let src = b.as_ptr().add((row_start + row) * ldb + col_start + jr);
                    let dst = packed.as_mut_ptr().add(dst_base + row * NR_512);
                    _mm256_storeu_ps(dst, _mm256_loadu_ps(src));
                }
            }
        } else {
            // Scalar fallback; also handles the zero-padded partial panel.
            let mut pack_idx = dst_base;
            for row in 0..rows {
                for col in 0..NR_512 {
                    if col < nr_actual {
                        packed[pack_idx] = b[(row_start + row) * ldb + col_start + jr + col];
                    } else {
                        packed[pack_idx] = 0.0;
                    }
                    pack_idx += 1;
                }
            }
        }
    }
}
414
415use super::{MR_512V2, NR_512V2};
420
421#[cfg(target_arch = "x86_64")]
423#[inline]
424#[allow(dead_code)] pub fn packed_a_size_v2(mc: usize, kc: usize) -> usize {
426 let panels = (mc + MR_512V2 - 1) / MR_512V2;
427 panels * MR_512V2 * kc
428}
429
430#[cfg(target_arch = "x86_64")]
432#[inline]
433#[allow(dead_code)] pub fn packed_b_size_v2(kc: usize, nc: usize) -> usize {
435 let panels = (nc + NR_512V2 - 1) / NR_512V2;
436 panels * NR_512V2 * kc
437}
438
439#[cfg(target_arch = "x86_64")]
441#[allow(dead_code)] pub(super) fn pack_a_block_v2(
443 a: &[f32],
444 lda: usize,
445 row_start: usize,
446 col_start: usize,
447 rows: usize,
448 cols: usize,
449 packed: &mut [f32],
450) {
451 let mut pack_idx = 0;
452 let panels = (rows + MR_512V2 - 1) / MR_512V2;
453
454 for panel in 0..panels {
455 let ir = panel * MR_512V2;
456 let mr_actual = MR_512V2.min(rows - ir);
457
458 for col in 0..cols {
459 for row in 0..mr_actual {
460 packed[pack_idx + row] = a[(row_start + ir + row) * lda + col_start + col];
461 }
462 for row in mr_actual..MR_512V2 {
463 packed[pack_idx + row] = 0.0;
464 }
465 pack_idx += MR_512V2;
466 }
467 }
468}
469
470#[cfg(target_arch = "x86_64")]
472#[allow(dead_code)] pub(super) fn pack_b_block_v2(
474 b: &[f32],
475 ldb: usize,
476 row_start: usize,
477 col_start: usize,
478 rows: usize,
479 cols: usize,
480 packed: &mut [f32],
481) {
482 let panels = (cols + NR_512V2 - 1) / NR_512V2;
483
484 for panel in 0..panels {
485 let jr = panel * NR_512V2;
486 let nr_actual = NR_512V2.min(cols - jr);
487 let dst_base = panel * NR_512V2 * rows;
488
489 for row in 0..rows {
490 let pack_idx = dst_base + row * NR_512V2;
491 for col in 0..nr_actual {
492 packed[pack_idx + col] = b[(row_start + row) * ldb + col_start + jr + col];
493 }
494 for col in nr_actual..NR_512V2 {
495 packed[pack_idx + col] = 0.0;
496 }
497 }
498 }
499}