1pub type MicroKernelFn<T> = unsafe fn(
2 usize,
3 usize,
4 usize,
5 *mut T,
6 *const T,
7 *const T,
8 isize,
9 isize,
10 isize,
11 isize,
12 isize,
13 T,
14 T,
15 u8,
16 bool,
17 bool,
18 bool,
19 *const T,
20);
21
22pub type HMicroKernelFn<T> = unsafe fn(
23 k: usize,
24 dst: *mut T,
25 lhs: *const T,
26 rhs: *const T,
27 dst_cs: isize,
28 dst_rs: isize,
29 lhs_rs: isize,
30 rhs_cs: isize,
31 alpha: T,
32 beta: T,
33 alpha_status: u8,
34 _conj_dst: bool,
35 _conj_lhs: bool,
36 _conj_rhs: bool,
37);
38
39#[macro_export]
60macro_rules! __one {
61 (
62 $tt: tt
63 ) => {
64 1
65 };
66}
67
68#[macro_export]
69macro_rules! __microkernel_fn_array_helper {
70 (
71 [ $($tt: tt,)* ]
72 ) => {
73 {
74 let mut count = 0_usize;
75 $(count += $crate::__one!($tt);)*
76 count
77 }
78 }
79}
80
81#[macro_export]
82macro_rules! __microkernel_fn_array_helper_nr {
83 ($([
84 $($ukr: ident,)*
85 ],)*) => {
86 {
87 let counts = [$({
88 let mut count = 0_usize;
89 $(count += $crate::__one!($ukr);)*
90 count
91 },)*];
92
93 counts[0]
94 }
95 }
96}
97
98#[macro_export]
99macro_rules! microkernel_fn_array {
100 ($([
101 $($ukr: ident,)*
102 ],)*) => {
103 pub const MR_DIV_N: usize =
104 $crate::__microkernel_fn_array_helper!([$([$($ukr,)*],)*]);
105 pub const NR: usize =
106 $crate::__microkernel_fn_array_helper_nr!($([$($ukr,)*],)*);
107
108 pub const UKR: [[$crate::microkernel::MicroKernelFn<T>; NR]; MR_DIV_N] =
109 [ $([$($ukr,)*]),* ];
110 };
111}
112
113#[macro_export]
114macro_rules! hmicrokernel_fn_array {
115 ($([
116 $($ukr: ident,)*
117 ],)*) => {
118 pub const H_M: usize =
119 $crate::__microkernel_fn_array_helper!([$([$($ukr,)*],)*]);
120 pub const H_N: usize =
121 $crate::__microkernel_fn_array_helper_nr!($([$($ukr,)*],)*);
122
123 pub const H_UKR: [[$crate::microkernel::HMicroKernelFn<T>; H_N]; H_M] =
124 [ $([$($ukr,)*]),* ];
125 };
126}
127
128#[macro_export]
129macro_rules! microkernel_cplx_fn_array {
130 ($([
131 $($ukr: ident,)*
132 ],)*) => {
133 pub const CPLX_MR_DIV_N: usize =
134 $crate::__microkernel_fn_array_helper!([$([$($ukr,)*],)*]);
135 pub const CPLX_NR: usize =
136 $crate::__microkernel_fn_array_helper_nr!($([$($ukr,)*],)*);
137
138 pub const CPLX_UKR: [[$crate::microkernel::MicroKernelFn<num_complex::Complex<T>>; CPLX_NR]; CPLX_MR_DIV_N] =
139 [ $([$($ukr,)*]),* ];
140 };
141}
142
143#[macro_export]
144macro_rules! hmicrokernel_cplx_fn_array {
145 ($([
146 $($ukr: ident,)*
147 ],)*) => {
148 pub const H_CPLX_M: usize =
149 $crate::__microkernel_fn_array_helper!([$([$($ukr,)*],)*]);
150 pub const H_CPLX_N: usize =
151 $crate::__microkernel_fn_array_helper_nr!($([$($ukr,)*],)*);
152
153 pub const H_CPLX_UKR: [[$crate::microkernel::HMicroKernelFn<num_complex::Complex<T>>; H_CPLX_N]; H_CPLX_M] =
154 [ $([$($ukr,)*]),* ];
155 };
156}
157
158#[macro_export]
159macro_rules! amx {
160 ($op: tt, $gpr: expr) => {
161 ::core::arch::asm!(::core::concat!(".word (0x201000 + (", ::core::stringify!($op) ," << 5))") , in("x0")$gpr)
162 };
163}
164
165#[macro_export]
167macro_rules! microkernel_amx {
168 ($ty: tt, $([$target: tt])?, $unroll: tt, $name: ident, $mr_div_n: tt, $nr: tt , $nr_div_n: tt, $n: tt) => {
169 $(#[target_feature(enable = $target)])?
170 pub unsafe fn $name(
172 m: usize,
173 n: usize,
174 k: usize,
175 dst: *mut T,
176 mut packed_lhs: *const T,
177 mut packed_rhs: *const T,
178 dst_cs: isize,
179 dst_rs: isize,
180 lhs_cs: isize,
181 rhs_rs: isize,
182 rhs_cs: isize,
183 alpha: T,
184 beta: T,
185 alpha_status: u8,
186 _conj_dst: bool,
187 _conj_lhs: bool,
188 _conj_rhs: bool,
189 mut next_lhs: *const T,
190 ) {
191 assert_eq!(rhs_cs, 1);
192
193 macro_rules! amx_nop {
194 ($imm5: tt) => {
195 ::core::arch::asm!(
196 ::core::concat!(
197 "nop\nnop\nnop\n.word (0x201000 + (17 << 5) + ",
198 ::core::stringify!($imm5),
199 ")"
200 ),
201 options(nostack)
202 );
203 };
204 }
205 macro_rules! amx_set { () => { amx_nop!(0) }; }
206 macro_rules! amx_clr { () => { amx_nop!(1) }; }
207
208 macro_rules! __ldx { ($gpr: expr) => { $crate::amx!(0, $gpr) } }
209 macro_rules! __ldy { ($gpr: expr) => { $crate::amx!(1, $gpr) } }
210 macro_rules! __stz { ($gpr: expr) => { $crate::amx!(5, $gpr) } }
211 macro_rules! __extrx { ($gpr: expr) => { $crate::amx!(8, $gpr) } }
212 macro_rules! __fma {
213 (f64, $gpr: expr) => { $crate::amx!(10, $gpr) };
214 (f32, $gpr: expr) => { $crate::amx!(12, $gpr) };
215 (f16, $gpr: expr) => { $crate::amx!(15, $gpr) };
216 }
217
218 macro_rules! ldx { ($addr: expr, $idx: expr) => { __ldx!( ($idx << 56) | (($addr as usize) & ((1 << 56) - 1)) ) } }
219 macro_rules! ldy { ($addr: expr, $idx: expr) => { __ldy!( ($idx << 56) | (($addr as usize) & ((1 << 56) - 1)) ) } }
220 macro_rules! stz { ($addr: expr, $idx: expr) => { __stz!( ($idx << 56) | (($addr as usize) & ((1 << 56) - 1)) ) } }
221 macro_rules! extrx { ($z: expr, $x: expr) => { __extrx!( ($x << 16) | ($z << 20) ) } }
222 macro_rules! fma {
223 ($x: expr, $y: expr, $z: expr) => { __fma!($ty, ($y << 6) | ($x << 16) | ($z << 20)) };
224 }
225 macro_rules! mul_select {
226 ($col: expr, $x: expr, $y: expr, $z: expr) => {
227 __fma!(
228 $ty,
229 (1usize << 27)
230 | (1usize << 37)
231 | ($col << 32)
232 | ($y << 6)
233 | ($x << 16)
234 | ($z << 20)
235 )
236 };
237 }
238 macro_rules! fma_select {
239 ($col: expr, $x: expr, $y: expr, $z: expr) => {
240 __fma!(
241 $ty,
242 (1usize << 37)
243 | ($col << 32)
244 | ($y << 6)
245 | ($x << 16)
246 | ($z << 20)
247 )
248 };
249 }
250
251 amx_set!();
252
253 ldx!(packed_lhs, 0);
254
255 let k_unroll = k / $unroll;
256 let k_leftover = k % $unroll ;
257
258 let mut depth = k_unroll;
259 if depth != 0 {
260 loop {
261 seq_macro::seq!(UNROLL_ITER in 0..$unroll {{
262 seq_macro::seq!(M_ITER in 0..$mr_div_n {{
263 ldx!(packed_lhs.offset(lhs_cs * UNROLL_ITER + M_ITER * $n), UNROLL_ITER + $unroll * M_ITER);
264 }});
265 seq_macro::seq!(N_ITER in 0..$nr_div_n {{
266 ldy!(packed_rhs.offset(rhs_rs * UNROLL_ITER + N_ITER * $n), UNROLL_ITER + $unroll * N_ITER);
267 }});
268 }});
269 seq_macro::seq!(UNROLL_ITER in 0..$unroll {{
270 seq_macro::seq!(M_ITER in 0..$mr_div_n {{
271 seq_macro::seq!(N_ITER in 0..$nr_div_n {{
272 fma!(UNROLL_ITER + $unroll * M_ITER, UNROLL_ITER + $unroll * N_ITER, M_ITER + $mr_div_n * N_ITER);
273 }});
274 }});
275 }});
276
277 packed_lhs = packed_lhs.wrapping_offset($unroll * lhs_cs);
278 packed_rhs = packed_rhs.wrapping_offset($unroll * rhs_rs);
279 next_lhs = next_lhs.wrapping_offset($unroll * lhs_cs);
280
281 depth -= 1;
282 if depth == 0 {
283 break;
284 }
285 }
286 }
287 depth = k_leftover;
288 if depth != 0 {
289 loop {
290 seq_macro::seq!(M_ITER in 0..$mr_div_n {{
291 ldx!(packed_lhs.offset(lhs_cs * 0 + M_ITER * $n), 0 + $unroll * M_ITER);
292 }});
293 seq_macro::seq!(N_ITER in 0..$nr_div_n {{
294 ldy!(packed_rhs.offset(rhs_rs * 0 + N_ITER * $n), 0 + $unroll * N_ITER);
295 }});
296 seq_macro::seq!(M_ITER in 0..$mr_div_n {{
297 seq_macro::seq!(N_ITER in 0..$nr_div_n {{
298 fma!($unroll * M_ITER, $unroll * N_ITER, M_ITER + $mr_div_n * N_ITER);
299 }});
300 }});
301
302 packed_lhs = packed_lhs.wrapping_offset(lhs_cs);
303 packed_rhs = packed_rhs.wrapping_offset(rhs_rs);
304 next_lhs = next_lhs.wrapping_offset(lhs_cs);
305
306 depth -= 1;
307 if depth == 0 {
308 break;
309 }
310 }
311 }
312
313 let alpha_c = [alpha; $n];
314 let beta_c = [beta; $n];
315 ldy!(alpha_c.as_ptr(), 0);
316 ldy!(beta_c.as_ptr(), 1);
317
318 let stride = 64 / N;
319
320 for i in 0..N {
321 seq_macro::seq!(M_ITER in 0..$mr_div_n {{
322 seq_macro::seq!(N_ITER in 0..$nr_div_n {{
323 }});
325 }});
326 seq_macro::seq!(M_ITER in 0..$mr_div_n {{
327 seq_macro::seq!(N_ITER in 0..$nr_div_n {{
328 extrx!((i * stride + M_ITER + $mr_div_n * N_ITER), M_ITER + $mr_div_n * N_ITER);
329 mul_select!(i, M_ITER + $mr_div_n * N_ITER, 1, i * stride + M_ITER + $mr_div_n * N_ITER);
330 }});
331 }});
332 }
333
334 let mut tmp_dst = [::core::mem::MaybeUninit::<T>::uninit(); { $mr_div_n * $n * $nr }];
335 let mut tmp_dst = tmp_dst.as_mut_ptr() as *mut T;
336 let mut tmp_dst_cs = $mr_div_n * $n;
337
338 if dst_rs == 1 && m == $mr_div_n * N && n == $nr {
339 tmp_dst = dst;
340 tmp_dst_cs = dst_cs;
341 } else {
342 for j in 0..n {
343 for i in 0..m {
344 *tmp_dst.offset(i as isize + j as isize * tmp_dst_cs) =
345 *dst .offset(i as isize * dst_rs + j as isize * dst_cs);
346 }
347 for i in m..$mr_div_n * N {
348 *tmp_dst.offset(i as isize + j as isize * tmp_dst_cs) = ::core::mem::zeroed();
349 }
350 }
351 }
352
353 if alpha_status == 0 {
354 for i in 0..N {
355 seq_macro::seq!(M_ITER in 0..$mr_div_n {{
356 seq_macro::seq!(N_ITER in 0..$nr_div_n {{
357 stz!(
358 tmp_dst.offset((i as isize + N_ITER * $n) * tmp_dst_cs + M_ITER * $n),
359 (i * stride + M_ITER + $mr_div_n * N_ITER)
360 );
361 }});
362 }});
363 }
364 } else {
365 for i in 0..N {
366 seq_macro::seq!(M_ITER in 0..$mr_div_n {{
367 seq_macro::seq!(N_ITER in 0..$nr_div_n {{
368 ldx!(
369 tmp_dst.offset((i as isize + N_ITER * $n) * tmp_dst_cs + M_ITER * $n),
370 M_ITER + $mr_div_n * N_ITER
371 );
372 fma_select!(i, M_ITER + $mr_div_n * N_ITER, 0, i * stride + M_ITER + $mr_div_n * N_ITER);
374 stz!(
375 tmp_dst.offset((i as isize + N_ITER * $n) * tmp_dst_cs + M_ITER * $n),
376 i * stride + M_ITER + $mr_div_n * N_ITER
377 );
378 }});
379 }});
380 }
381 }
382
383 if !(dst_rs == 1 && m == $mr_div_n * N && n == $nr) {
384 for j in 0..n {
385 for i in 0..m {
386 *dst .offset(i as isize * dst_rs + j as isize * dst_cs) =
387 *tmp_dst.offset(i as isize + j as isize * tmp_dst_cs);
388 }
389 }
390 }
391
392 amx_clr!();
393 }
394 };
395}
396
397#[macro_export]
398macro_rules! microkernel {
399 ($([$target: tt])?, $unroll: tt, $name: ident, $mr_div_n: tt, $nr: tt $(, $nr_div_n: tt, $n: tt)?) => {
400 $(#[target_feature(enable = $target)])?
401 pub unsafe fn $name(
403 m: usize,
404 n: usize,
405 k: usize,
406 dst: *mut T,
407 mut packed_lhs: *const T,
408 mut packed_rhs: *const T,
409 dst_cs: isize,
410 dst_rs: isize,
411 lhs_cs: isize,
412 rhs_rs: isize,
413 rhs_cs: isize,
414 alpha: T,
415 beta: T,
416 alpha_status: u8,
417 _conj_dst: bool,
418 _conj_lhs: bool,
419 _conj_rhs: bool,
420 mut next_lhs: *const T,
421 ) {
422 let mut accum_storage = [[splat(::core::mem::zeroed()); $mr_div_n]; $nr];
423 let accum = accum_storage.as_mut_ptr() as *mut Pack;
424
425 let mut lhs = [::core::mem::MaybeUninit::<Pack>::uninit(); $mr_div_n];
426 let mut rhs = ::core::mem::MaybeUninit::<Pack>::uninit();
427
428 #[derive(Copy, Clone)]
429 struct KernelIter {
430 packed_lhs: *const T,
431 packed_rhs: *const T,
432 next_lhs: *const T,
433 lhs_cs: isize,
434 rhs_rs: isize,
435 rhs_cs: isize,
436 accum: *mut Pack,
437 lhs: *mut Pack,
438 rhs: *mut Pack,
439 }
440
441 impl KernelIter {
442 #[inline(always)]
443 unsafe fn execute(self, iter: usize) {
444 let packed_lhs = self.packed_lhs.wrapping_offset(iter as isize * self.lhs_cs);
445 let packed_rhs = self.packed_rhs.wrapping_offset(iter as isize * self.rhs_rs);
446 let next_lhs = self.next_lhs.wrapping_offset(iter as isize * self.lhs_cs);
447
448 seq_macro::seq!(M_ITER in 0..$mr_div_n {{
449 *self.lhs.add(M_ITER) = *(packed_lhs.add(M_ITER * N) as *const Pack);
450 }});
451
452 seq_macro::seq!(N_ITER in 0..$nr {{
453 *self.rhs = splat(*packed_rhs.wrapping_offset(N_ITER * self.rhs_cs));
454 let accum = self.accum.add(N_ITER * $mr_div_n);
455 seq_macro::seq!(M_ITER in 0..$mr_div_n {{
456 let accum = &mut *accum.add(M_ITER);
457 *accum = mul_add(
458 *self.lhs.add(M_ITER),
459 *self.rhs,
460 *accum,
461 );
462 }});
463 }});
464
465 let _ = next_lhs;
466 }
467
468 $(
469 #[inline(always)]
470 unsafe fn execute_neon(self, iter: usize) {
471 debug_assert_eq!(self.rhs_cs, 1);
472 let packed_lhs = self.packed_lhs.wrapping_offset(iter as isize * self.lhs_cs);
473 let packed_rhs = self.packed_rhs.wrapping_offset(iter as isize * self.rhs_rs);
474
475 load::<$mr_div_n>(self.lhs, packed_lhs);
476
477 seq_macro::seq!(N_ITER0 in 0..$nr_div_n {{
478 *self.rhs = *(packed_rhs.wrapping_offset(N_ITER0 * $n) as *const Pack);
479
480 seq_macro::seq!(N_ITER1 in 0..$n {{
481 const N_ITER: usize = N_ITER0 * $n + N_ITER1;
482 let accum = self.accum.add(N_ITER * $mr_div_n);
483 seq_macro::seq!(M_ITER in 0..$mr_div_n {{
484 let accum = &mut *accum.add(M_ITER);
485 *accum = mul_add_lane::<N_ITER1>(
486 *self.lhs.add(M_ITER),
487 *self.rhs,
488 *accum,
489 );
490 }});
491 }});
492 }});
493 }
494 )?
495 }
496
497 let k_unroll = k / $unroll;
498 let k_leftover = k % $unroll;
499
500 let mut main_loop = {
501 #[inline(always)]
502 || {
503 loop {
504 $(
505 let _ = $nr_div_n;
506 if rhs_cs == 1 {
507 let mut depth = k_unroll;
508 if depth != 0 {
509 loop {
510 let iter = KernelIter {
511 packed_lhs,
512 next_lhs,
513 packed_rhs,
514 lhs_cs,
515 rhs_rs,
516 rhs_cs,
517 accum,
518 lhs: lhs.as_mut_ptr() as _,
519 rhs: &mut rhs as *mut _ as _,
520 };
521
522 seq_macro::seq!(UNROLL_ITER in 0..$unroll {{
523 iter.execute_neon(UNROLL_ITER);
524 }});
525
526 packed_lhs = packed_lhs.wrapping_offset($unroll * lhs_cs);
527 packed_rhs = packed_rhs.wrapping_offset($unroll * rhs_rs);
528 next_lhs = next_lhs.wrapping_offset($unroll * lhs_cs);
529
530 depth -= 1;
531 if depth == 0 {
532 break;
533 }
534 }
535 }
536 depth = k_leftover;
537 if depth != 0 {
538 loop {
539 KernelIter {
540 packed_lhs,
541 next_lhs,
542 packed_rhs,
543 lhs_cs,
544 rhs_rs,
545 rhs_cs,
546 accum,
547 lhs: lhs.as_mut_ptr() as _,
548 rhs: &mut rhs as *mut _ as _,
549 }
550 .execute_neon(0);
551
552 packed_lhs = packed_lhs.wrapping_offset(lhs_cs);
553 packed_rhs = packed_rhs.wrapping_offset(rhs_rs);
554 next_lhs = next_lhs.wrapping_offset(lhs_cs);
555
556 depth -= 1;
557 if depth == 0 {
558 break;
559 }
560 }
561 }
562 break;
563 }
564 )?
565
566 let mut depth = k_unroll;
567 if depth != 0 {
568 loop {
569 let iter = KernelIter {
570 packed_lhs,
571 next_lhs,
572 packed_rhs,
573 lhs_cs,
574 rhs_rs,
575 rhs_cs,
576 accum,
577 lhs: lhs.as_mut_ptr() as _,
578 rhs: &mut rhs as *mut _ as _,
579 };
580
581 seq_macro::seq!(UNROLL_ITER in 0..$unroll {{
582 iter.execute(UNROLL_ITER);
583 }});
584
585 packed_lhs = packed_lhs.wrapping_offset($unroll * lhs_cs);
586 packed_rhs = packed_rhs.wrapping_offset($unroll * rhs_rs);
587 next_lhs = next_lhs.wrapping_offset($unroll * lhs_cs);
588
589 depth -= 1;
590 if depth == 0 {
591 break;
592 }
593 }
594 }
595 depth = k_leftover;
596 if depth != 0 {
597 loop {
598 KernelIter {
599 packed_lhs,
600 next_lhs,
601 packed_rhs,
602 lhs_cs,
603 rhs_rs,
604 rhs_cs,
605 accum,
606 lhs: lhs.as_mut_ptr() as _,
607 rhs: &mut rhs as *mut _ as _,
608 }
609 .execute(0);
610
611 packed_lhs = packed_lhs.wrapping_offset(lhs_cs);
612 packed_rhs = packed_rhs.wrapping_offset(rhs_rs);
613 next_lhs = next_lhs.wrapping_offset(lhs_cs);
614
615 depth -= 1;
616 if depth == 0 {
617 break;
618 }
619 }
620 }
621 break;
622 }
623 }
624 };
625
626 if rhs_rs == 1 {
627 main_loop();
628 } else {
629 main_loop();
630 }
631
632 if m == $mr_div_n * N && n == $nr && dst_rs == 1 {
633 let alpha = splat(alpha);
634 let beta = splat(beta);
635 if alpha_status == 2 {
636 seq_macro::seq!(N_ITER in 0..$nr {{
637 seq_macro::seq!(M_ITER in 0..$mr_div_n {{
638 let dst = dst.offset(M_ITER * N as isize + N_ITER * dst_cs) as *mut Pack;
639 dst.write_unaligned(add(
640 mul(alpha, *dst),
641 mul(beta, *accum.offset(M_ITER + $mr_div_n * N_ITER)),
642 ));
643 }});
644 }});
645 } else if alpha_status == 1 {
646 seq_macro::seq!(N_ITER in 0..$nr {{
647 seq_macro::seq!(M_ITER in 0..$mr_div_n {{
648 let dst = dst.offset(M_ITER * N as isize + N_ITER * dst_cs) as *mut Pack;
649 dst.write_unaligned(mul_add(
650 beta,
651 *accum.offset(M_ITER + $mr_div_n * N_ITER),
652 *dst,
653 ));
654 }});
655 }});
656 } else {
657 seq_macro::seq!(N_ITER in 0..$nr {{
658 seq_macro::seq!(M_ITER in 0..$mr_div_n {{
659 let dst = dst.offset(M_ITER * N as isize + N_ITER * dst_cs) as *mut Pack;
660 dst.write_unaligned(mul(beta, *accum.offset(M_ITER + $mr_div_n * N_ITER)));
661 }});
662 }});
663 }
664 } else {
665 let src = accum_storage; let src = src.as_ptr() as *const T;
667
668 if alpha_status == 2 {
669 for j in 0..n {
670 let dst_j = dst.offset(dst_cs * j as isize);
671 let src_j = src.add(j * $mr_div_n * N);
672
673 for i in 0..m {
674 let dst_ij = dst_j.offset(dst_rs * i as isize);
675 let src_ij = src_j.add(i);
676
677 *dst_ij = scalar_add(scalar_mul(alpha, *dst_ij), scalar_mul(beta, *src_ij));
678 }
679 }
680 } else if alpha_status == 1 {
681 for j in 0..n {
682 let dst_j = dst.offset(dst_cs * j as isize);
683 let src_j = src.add(j * $mr_div_n * N);
684
685 for i in 0..m {
686 let dst_ij = dst_j.offset(dst_rs * i as isize);
687 let src_ij = src_j.add(i);
688
689 *dst_ij = scalar_mul_add(beta, *src_ij, *dst_ij);
690 }
691 }
692 } else {
693 for j in 0..n {
694 let dst_j = dst.offset(dst_cs * j as isize);
695 let src_j = src.add(j * $mr_div_n * N);
696
697 for i in 0..m {
698 let dst_ij = dst_j.offset(dst_rs * i as isize);
699 let src_ij = src_j.add(i);
700
701 *dst_ij = scalar_mul(beta, *src_ij);
702 }
703 }
704 }
705 }
706
707 }
708 };
709}
710
711#[macro_export]
712macro_rules! microkernel_cplx_2step {
713 ($([$target: tt])?, $unroll: tt, $name: ident, $mr_div_n: tt, $nr: tt) => {
714 $(#[target_feature(enable = $target)])?
715 pub unsafe fn $name(
717 m: usize,
718 n: usize,
719 k: usize,
720 dst: *mut num_complex::Complex<T>,
721 mut packed_lhs: *const num_complex::Complex<T>,
722 mut packed_rhs: *const num_complex::Complex<T>,
723 dst_cs: isize,
724 dst_rs: isize,
725 lhs_cs: isize,
726 rhs_rs: isize,
727 rhs_cs: isize,
728 alpha: num_complex::Complex<T>,
729 beta: num_complex::Complex<T>,
730 alpha_status: u8,
731 conj_dst: bool,
732 conj_lhs: bool,
733 conj_rhs: bool,
734 mut next_lhs: *const num_complex::Complex<T>,
735 ) {
736 let mut accum_storage = [[splat(0.0); $mr_div_n]; $nr];
737 let accum = accum_storage.as_mut_ptr() as *mut Pack;
738
739 let (neg_conj_rhs, conj_all, neg_all) = match (conj_lhs, conj_rhs) {
740 (true, true) => (true, false, true),
741 (false, true) => (false, true, false),
742 (true, false) => (false, false, false),
743 (false, false) => (true, true, true),
744 };
745
746 let mut lhs_re_im = [::core::mem::MaybeUninit::<Pack>::uninit(); $mr_div_n];
747 let mut lhs_im_re = [::core::mem::MaybeUninit::<Pack>::uninit(); $mr_div_n];
748 let mut rhs_re = ::core::mem::MaybeUninit::<Pack>::uninit();
749 let mut rhs_im = ::core::mem::MaybeUninit::<Pack>::uninit();
750
751 #[derive(Copy, Clone)]
752 struct KernelIter {
753 packed_lhs: *const num_complex::Complex<T>,
754 next_lhs: *const num_complex::Complex<T>,
755 packed_rhs: *const num_complex::Complex<T>,
756 lhs_cs: isize,
757 rhs_rs: isize,
758 rhs_cs: isize,
759 accum: *mut Pack,
760 lhs_re_im: *mut Pack,
761 lhs_im_re: *mut Pack,
762 rhs_re: *mut Pack,
763 rhs_im: *mut Pack,
764 }
765
766 impl KernelIter {
767 #[inline(always)]
768 unsafe fn execute(self, iter: usize, neg_conj_rhs: bool) {
769 let packed_lhs = self.packed_lhs.wrapping_offset(iter as isize * self.lhs_cs);
770 let packed_rhs = self.packed_rhs.wrapping_offset(iter as isize * self.rhs_rs);
771 let next_lhs = self.next_lhs.wrapping_offset(iter as isize * self.lhs_cs);
772
773 seq_macro::seq!(M_ITER in 0..$mr_div_n {{
774 let tmp = *(packed_lhs.add(M_ITER * CPLX_N) as *const Pack);
775 *self.lhs_re_im.add(M_ITER) = tmp;
776 *self.lhs_im_re.add(M_ITER) = swap_re_im(tmp);
777 }});
778
779 seq_macro::seq!(N_ITER in 0..$nr {{
780 *self.rhs_re = splat((*packed_rhs.wrapping_offset(N_ITER * self.rhs_cs)).re);
781
782 let accum = self.accum.add(N_ITER * $mr_div_n);
783 seq_macro::seq!(M_ITER in 0..$mr_div_n {{
784 let accum = &mut *accum.add(M_ITER);
785 *accum = mul_add_cplx_step0(
786 *self.lhs_re_im.add(M_ITER),
787 *self.rhs_re,
788 *accum,
789 neg_conj_rhs,
790 );
791 }});
792 }});
793
794 seq_macro::seq!(N_ITER in 0..$nr {{
795 *self.rhs_im = splat((*packed_rhs.wrapping_offset(N_ITER * self.rhs_cs)).im);
796
797 let accum = self.accum.add(N_ITER * $mr_div_n);
798 seq_macro::seq!(M_ITER in 0..$mr_div_n {{
799 let accum = &mut *accum.add(M_ITER);
800 *accum = mul_add_cplx_step1(
801 *self.lhs_im_re.add(M_ITER),
802 *self.rhs_im,
803 *accum,
804 neg_conj_rhs,
805 );
806 }});
807 }});
808
809 let _ = next_lhs;
810 }
811 }
812
813 let k_unroll = k / $unroll;
814 let k_leftover = k % $unroll;
815
816 let mut main_loop = {
817 #[inline(always)]
818 || {
819 loop {
820 if neg_conj_rhs {
821 let mut depth = k_unroll;
822 if depth != 0 {
823 loop {
824 let iter = KernelIter {
825 packed_lhs,
826 next_lhs,
827 packed_rhs,
828 lhs_cs,
829 rhs_rs,
830 rhs_cs,
831 accum,
832 lhs_re_im: lhs_re_im.as_mut_ptr() as _,
833 lhs_im_re: lhs_im_re.as_mut_ptr() as _,
834 rhs_re: &mut rhs_re as *mut _ as _,
835 rhs_im: &mut rhs_im as *mut _ as _,
836 };
837
838 seq_macro::seq!(UNROLL_ITER in 0..$unroll {{
839 iter.execute(UNROLL_ITER, true);
840 }});
841
842 packed_lhs = packed_lhs.wrapping_offset($unroll * lhs_cs);
843 packed_rhs = packed_rhs.wrapping_offset($unroll * rhs_rs);
844 next_lhs = next_lhs.wrapping_offset($unroll * lhs_cs);
845
846 depth -= 1;
847 if depth == 0 {
848 break;
849 }
850 }
851 }
852 depth = k_leftover;
853 if depth != 0 {
854 loop {
855 KernelIter {
856 packed_lhs,
857 next_lhs,
858 packed_rhs,
859 lhs_cs,
860 rhs_rs,
861 rhs_cs,
862 accum,
863 lhs_re_im: lhs_re_im.as_mut_ptr() as _,
864 lhs_im_re: lhs_im_re.as_mut_ptr() as _,
865 rhs_re: &mut rhs_re as *mut _ as _,
866 rhs_im: &mut rhs_im as *mut _ as _,
867 }
868 .execute(0, true);
869
870 packed_lhs = packed_lhs.wrapping_offset(lhs_cs);
871 packed_rhs = packed_rhs.wrapping_offset(rhs_rs);
872 next_lhs = next_lhs.wrapping_offset(lhs_cs);
873
874 depth -= 1;
875 if depth == 0 {
876 break;
877 }
878 }
879 }
880 break;
881 } else {
882 let mut depth = k_unroll;
883 if depth != 0 {
884 loop {
885 let iter = KernelIter {
886 next_lhs,
887 packed_lhs,
888 packed_rhs,
889 lhs_cs,
890 rhs_rs,
891 rhs_cs,
892 accum,
893 lhs_re_im: lhs_re_im.as_mut_ptr() as _,
894 lhs_im_re: lhs_im_re.as_mut_ptr() as _,
895 rhs_re: &mut rhs_re as *mut _ as _,
896 rhs_im: &mut rhs_im as *mut _ as _,
897 };
898
899 seq_macro::seq!(UNROLL_ITER in 0..$unroll {{
900 iter.execute(UNROLL_ITER, false);
901 }});
902
903 packed_lhs = packed_lhs.wrapping_offset($unroll * lhs_cs);
904 packed_rhs = packed_rhs.wrapping_offset($unroll * rhs_rs);
905 next_lhs = next_lhs.wrapping_offset($unroll * lhs_cs);
906
907 depth -= 1;
908 if depth == 0 {
909 break;
910 }
911 }
912 }
913 depth = k_leftover;
914 if depth != 0 {
915 loop {
916 KernelIter {
917 next_lhs,
918 packed_lhs,
919 packed_rhs,
920 lhs_cs,
921 rhs_rs,
922 rhs_cs,
923 accum,
924 lhs_re_im: lhs_re_im.as_mut_ptr() as _,
925 lhs_im_re: lhs_im_re.as_mut_ptr() as _,
926 rhs_re: &mut rhs_re as *mut _ as _,
927 rhs_im: &mut rhs_im as *mut _ as _,
928 }
929 .execute(0, false);
930
931 packed_lhs = packed_lhs.wrapping_offset(lhs_cs);
932 packed_rhs = packed_rhs.wrapping_offset(rhs_rs);
933 next_lhs = next_lhs.wrapping_offset(lhs_cs);
934
935 depth -= 1;
936 if depth == 0 {
937 break;
938 }
939 }
940 }
941 break;
942 }
943 }
944 }
945 };
946
947 if rhs_rs == 1 {
948 main_loop();
949 } else {
950 main_loop();
951 }
952
953 if conj_all && neg_all {
954 seq_macro::seq!(N_ITER in 0..$nr {{
955 let accum = accum.add(N_ITER * $mr_div_n);
956 seq_macro::seq!(M_ITER in 0..$mr_div_n {{
957 let accum = &mut *accum.add(M_ITER);
958 *accum = neg_conj(*accum);
959 }});
960 }});
961 } else if !conj_all && neg_all {
962 seq_macro::seq!(N_ITER in 0..$nr {{
963 let accum = accum.add(N_ITER * $mr_div_n);
964 seq_macro::seq!(M_ITER in 0..$mr_div_n {{
965 let accum = &mut *accum.add(M_ITER);
966 *accum = neg(*accum);
967 }});
968 }});
969 } else if conj_all && !neg_all {
970 seq_macro::seq!(N_ITER in 0..$nr {{
971 let accum = accum.add(N_ITER * $mr_div_n);
972 seq_macro::seq!(M_ITER in 0..$mr_div_n {{
973 let accum = &mut *accum.add(M_ITER);
974 *accum = conj(*accum);
975 }});
976 }});
977 }
978
979 if m == $mr_div_n * CPLX_N && n == $nr && dst_rs == 1 {
980 let alpha_re = splat(alpha.re);
981 let alpha_im = splat(alpha.im);
982 let beta_re = splat(beta.re);
983 let beta_im = splat(beta.im);
984
985 if conj_dst {
986 if alpha_status == 2 {
987 seq_macro::seq!(N_ITER in 0..$nr {{
988 seq_macro::seq!(M_ITER in 0..$mr_div_n {{
989 let dst = dst.offset(M_ITER * CPLX_N as isize + N_ITER * dst_cs) as *mut Pack;
990 let accum = *accum.offset(M_ITER + $mr_div_n * N_ITER);
991 *dst = add(
992 mul_cplx(conj(*dst), swap_re_im(conj(*dst)), alpha_re, alpha_im),
993 mul_cplx(accum, swap_re_im(accum), beta_re, beta_im),
994 );
995 }});
996 }});
997 } else if alpha_status == 1 {
998 seq_macro::seq!(N_ITER in 0..$nr {{
999 seq_macro::seq!(M_ITER in 0..$mr_div_n {{
1000 let dst = dst.offset(M_ITER * CPLX_N as isize + N_ITER * dst_cs) as *mut Pack;
1001 let accum = *accum.offset(M_ITER + $mr_div_n * N_ITER);
1002 *dst = add(
1003 conj(*dst),
1004 mul_cplx(accum, swap_re_im(accum), beta_re, beta_im),
1005 );
1006 }});
1007 }});
1008 } else {
1009 seq_macro::seq!(N_ITER in 0..$nr {{
1010 seq_macro::seq!(M_ITER in 0..$mr_div_n {{
1011 let dst = dst.offset(M_ITER * CPLX_N as isize + N_ITER * dst_cs) as *mut Pack;
1012 let accum = *accum.offset(M_ITER + $mr_div_n * N_ITER);
1013 *dst = mul_cplx(accum, swap_re_im(accum), beta_re, beta_im);
1014 }});
1015 }});
1016 }
1017 } else {
1018 if alpha_status == 2 {
1019 seq_macro::seq!(N_ITER in 0..$nr {{
1020 seq_macro::seq!(M_ITER in 0..$mr_div_n {{
1021 let dst = dst.offset(M_ITER * CPLX_N as isize + N_ITER * dst_cs) as *mut Pack;
1022 let accum = *accum.offset(M_ITER + $mr_div_n * N_ITER);
1023 *dst = add(
1024 mul_cplx(*dst, swap_re_im(*dst), alpha_re, alpha_im),
1025 mul_cplx(accum, swap_re_im(accum), beta_re, beta_im),
1026 );
1027 }});
1028 }});
1029 } else if alpha_status == 1 {
1030 seq_macro::seq!(N_ITER in 0..$nr {{
1031 seq_macro::seq!(M_ITER in 0..$mr_div_n {{
1032 let dst = dst.offset(M_ITER * CPLX_N as isize + N_ITER * dst_cs) as *mut Pack;
1033 let accum = *accum.offset(M_ITER + $mr_div_n * N_ITER);
1034 *dst = add(
1035 *dst,
1036 mul_cplx(accum, swap_re_im(accum), beta_re, beta_im),
1037 );
1038 }});
1039 }});
1040 } else {
1041 seq_macro::seq!(N_ITER in 0..$nr {{
1042 seq_macro::seq!(M_ITER in 0..$mr_div_n {{
1043 let dst = dst.offset(M_ITER * CPLX_N as isize + N_ITER * dst_cs) as *mut Pack;
1044 let accum = *accum.offset(M_ITER + $mr_div_n * N_ITER);
1045 *dst = mul_cplx(accum, swap_re_im(accum), beta_re, beta_im);
1046 }});
1047 }});
1048 }
1049 }
1050 } else {
1051 let src = accum_storage; let src = src.as_ptr() as *const num_complex::Complex<T>;
1053
1054 if conj_dst {
1055 if alpha_status == 2 {
1056 for j in 0..n {
1057 let dst_j = dst.offset(dst_cs * j as isize);
1058 let src_j = src.add(j * $mr_div_n * CPLX_N);
1059
1060 for i in 0..m {
1061 let dst_ij = dst_j.offset(dst_rs * i as isize);
1062 let src_ij = src_j.add(i);
1063
1064 *dst_ij = alpha * (*dst_ij).conj() + beta * *src_ij;
1065 }
1066 }
1067 } else if alpha_status == 1 {
1068 for j in 0..n {
1069 let dst_j = dst.offset(dst_cs * j as isize);
1070 let src_j = src.add(j * $mr_div_n * CPLX_N);
1071
1072 for i in 0..m {
1073 let dst_ij = dst_j.offset(dst_rs * i as isize);
1074 let src_ij = src_j.add(i);
1075
1076 *dst_ij = (*dst_ij).conj() + beta * *src_ij;
1077 }
1078 }
1079 } else {
1080 for j in 0..n {
1081 let dst_j = dst.offset(dst_cs * j as isize);
1082 let src_j = src.add(j * $mr_div_n * CPLX_N);
1083
1084 for i in 0..m {
1085 let dst_ij = dst_j.offset(dst_rs * i as isize);
1086 let src_ij = src_j.add(i);
1087
1088 *dst_ij = beta * *src_ij;
1089 }
1090 }
1091 }
1092 } else {
1093 if alpha_status == 2 {
1094 for j in 0..n {
1095 let dst_j = dst.offset(dst_cs * j as isize);
1096 let src_j = src.add(j * $mr_div_n * CPLX_N);
1097
1098 for i in 0..m {
1099 let dst_ij = dst_j.offset(dst_rs * i as isize);
1100 let src_ij = src_j.add(i);
1101
1102 *dst_ij = alpha * *dst_ij + beta * *src_ij;
1103 }
1104 }
1105 } else if alpha_status == 1 {
1106 for j in 0..n {
1107 let dst_j = dst.offset(dst_cs * j as isize);
1108 let src_j = src.add(j * $mr_div_n * CPLX_N);
1109
1110 for i in 0..m {
1111 let dst_ij = dst_j.offset(dst_rs * i as isize);
1112 let src_ij = src_j.add(i);
1113
1114 *dst_ij = *dst_ij + beta * *src_ij;
1115 }
1116 }
1117 } else {
1118 for j in 0..n {
1119 let dst_j = dst.offset(dst_cs * j as isize);
1120 let src_j = src.add(j * $mr_div_n * CPLX_N);
1121
1122 for i in 0..m {
1123 let dst_ij = dst_j.offset(dst_rs * i as isize);
1124 let src_ij = src_j.add(i);
1125
1126 *dst_ij = beta * *src_ij;
1127 }
1128 }
1129 }
1130 }
1131 }
1132 }
1133 };
1134}
1135
1136#[macro_export]
1137macro_rules! microkernel_cplx {
1138 ($([$target: tt])?, $unroll: tt, $name: ident, $mr_div_n: tt, $nr: tt) => {
1139 $(#[target_feature(enable = $target)])?
1140 pub unsafe fn $name(
1142 m: usize,
1143 n: usize,
1144 k: usize,
1145 dst: *mut num_complex::Complex<T>,
1146 mut packed_lhs: *const num_complex::Complex<T>,
1147 mut packed_rhs: *const num_complex::Complex<T>,
1148 dst_cs: isize,
1149 dst_rs: isize,
1150 lhs_cs: isize,
1151 rhs_rs: isize,
1152 rhs_cs: isize,
1153 alpha: num_complex::Complex<T>,
1154 beta: num_complex::Complex<T>,
1155 alpha_status: u8,
1156 conj_dst: bool,
1157 conj_lhs: bool,
1158 conj_rhs: bool,
1159 mut next_lhs: *const num_complex::Complex<T>,
1160 ) {
1161 let mut accum_storage = [[splat(0.0); $mr_div_n]; $nr];
1162 let accum = accum_storage.as_mut_ptr() as *mut Pack;
1163
1164 let conj_both_lhs_rhs = conj_lhs;
1165 let conj_rhs = conj_lhs != conj_rhs;
1166
1167 let mut lhs_re_im = [::core::mem::MaybeUninit::<Pack>::uninit(); $mr_div_n];
1168 let mut lhs_im_re = [::core::mem::MaybeUninit::<Pack>::uninit(); $mr_div_n];
1169 let mut rhs_re = ::core::mem::MaybeUninit::<Pack>::uninit();
1170 let mut rhs_im = ::core::mem::MaybeUninit::<Pack>::uninit();
1171
1172 #[derive(Copy, Clone)]
1173 struct KernelIter {
1174 packed_lhs: *const num_complex::Complex<T>,
1175 next_lhs: *const num_complex::Complex<T>,
1176 packed_rhs: *const num_complex::Complex<T>,
1177 lhs_cs: isize,
1178 rhs_rs: isize,
1179 rhs_cs: isize,
1180 accum: *mut Pack,
1181 lhs_re_im: *mut Pack,
1182 lhs_im_re: *mut Pack,
1183 rhs_re: *mut Pack,
1184 rhs_im: *mut Pack,
1185 }
1186
1187 impl KernelIter {
1188 #[inline(always)]
1189 unsafe fn execute(self, iter: usize, conj_rhs: bool) {
1190 let packed_lhs = self.packed_lhs.wrapping_offset(iter as isize * self.lhs_cs);
1191 let packed_rhs = self.packed_rhs.wrapping_offset(iter as isize * self.rhs_rs);
1192 let next_lhs = self.next_lhs.wrapping_offset(iter as isize * self.lhs_cs);
1193
1194 seq_macro::seq!(M_ITER in 0..$mr_div_n {{
1195 let tmp = *(packed_lhs.add(M_ITER * CPLX_N) as *const Pack);
1196 *self.lhs_re_im.add(M_ITER) = tmp;
1197 *self.lhs_im_re.add(M_ITER) = swap_re_im(tmp);
1198 }});
1199
1200 seq_macro::seq!(N_ITER in 0..$nr {{
1201 *self.rhs_re = splat((*packed_rhs.wrapping_offset(N_ITER * self.rhs_cs)).re);
1202 *self.rhs_im = splat((*packed_rhs.wrapping_offset(N_ITER * self.rhs_cs)).im);
1203
1204 let accum = self.accum.add(N_ITER * $mr_div_n);
1205 seq_macro::seq!(M_ITER in 0..$mr_div_n {{
1206 let accum = &mut *accum.add(M_ITER);
1207 *accum = mul_add_cplx(
1208 *self.lhs_re_im.add(M_ITER),
1209 *self.lhs_im_re.add(M_ITER),
1210 *self.rhs_re,
1211 *self.rhs_im,
1212 *accum,
1213 conj_rhs,
1214 );
1215 }});
1216 }});
1217
1218 let _ = next_lhs;
1219 }
1220 }
1221
1222 let k_unroll = k / $unroll;
1223 let k_leftover = k % $unroll;
1224
1225 loop {
1226 if conj_rhs {
1227 let mut depth = k_unroll;
1228 if depth != 0 {
1229 loop {
1230 let iter = KernelIter {
1231 packed_lhs,
1232 next_lhs,
1233 packed_rhs,
1234 lhs_cs,
1235 rhs_rs,
1236 rhs_cs,
1237 accum,
1238 lhs_re_im: lhs_re_im.as_mut_ptr() as _,
1239 lhs_im_re: lhs_im_re.as_mut_ptr() as _,
1240 rhs_re: &mut rhs_re as *mut _ as _,
1241 rhs_im: &mut rhs_im as *mut _ as _,
1242 };
1243
1244 seq_macro::seq!(UNROLL_ITER in 0..$unroll {{
1245 iter.execute(UNROLL_ITER, true);
1246 }});
1247
1248 packed_lhs = packed_lhs.wrapping_offset($unroll * lhs_cs);
1249 packed_rhs = packed_rhs.wrapping_offset($unroll * rhs_rs);
1250 next_lhs = next_lhs.wrapping_offset($unroll * lhs_cs);
1251
1252 depth -= 1;
1253 if depth == 0 {
1254 break;
1255 }
1256 }
1257 }
1258 depth = k_leftover;
1259 if depth != 0 {
1260 loop {
1261 KernelIter {
1262 packed_lhs,
1263 next_lhs,
1264 packed_rhs,
1265 lhs_cs,
1266 rhs_rs,
1267 rhs_cs,
1268 accum,
1269 lhs_re_im: lhs_re_im.as_mut_ptr() as _,
1270 lhs_im_re: lhs_im_re.as_mut_ptr() as _,
1271 rhs_re: &mut rhs_re as *mut _ as _,
1272 rhs_im: &mut rhs_im as *mut _ as _,
1273 }
1274 .execute(0, true);
1275
1276 packed_lhs = packed_lhs.wrapping_offset(lhs_cs);
1277 packed_rhs = packed_rhs.wrapping_offset(rhs_rs);
1278 next_lhs = next_lhs.wrapping_offset(lhs_cs);
1279
1280 depth -= 1;
1281 if depth == 0 {
1282 break;
1283 }
1284 }
1285 }
1286 break;
1287 } else {
1288 let mut depth = k_unroll;
1289 if depth != 0 {
1290 loop {
1291 let iter = KernelIter {
1292 next_lhs,
1293 packed_lhs,
1294 packed_rhs,
1295 lhs_cs,
1296 rhs_rs,
1297 rhs_cs,
1298 accum,
1299 lhs_re_im: lhs_re_im.as_mut_ptr() as _,
1300 lhs_im_re: lhs_im_re.as_mut_ptr() as _,
1301 rhs_re: &mut rhs_re as *mut _ as _,
1302 rhs_im: &mut rhs_im as *mut _ as _,
1303 };
1304
1305 seq_macro::seq!(UNROLL_ITER in 0..$unroll {{
1306 iter.execute(UNROLL_ITER, false);
1307 }});
1308
1309 packed_lhs = packed_lhs.wrapping_offset($unroll * lhs_cs);
1310 packed_rhs = packed_rhs.wrapping_offset($unroll * rhs_rs);
1311 next_lhs = next_lhs.wrapping_offset($unroll * lhs_cs);
1312
1313 depth -= 1;
1314 if depth == 0 {
1315 break;
1316 }
1317 }
1318 }
1319 depth = k_leftover;
1320 if depth != 0 {
1321 loop {
1322 KernelIter {
1323 next_lhs,
1324 packed_lhs,
1325 packed_rhs,
1326 lhs_cs,
1327 rhs_rs,
1328 rhs_cs,
1329 accum,
1330 lhs_re_im: lhs_re_im.as_mut_ptr() as _,
1331 lhs_im_re: lhs_im_re.as_mut_ptr() as _,
1332 rhs_re: &mut rhs_re as *mut _ as _,
1333 rhs_im: &mut rhs_im as *mut _ as _,
1334 }
1335 .execute(0, false);
1336
1337 packed_lhs = packed_lhs.wrapping_offset(lhs_cs);
1338 packed_rhs = packed_rhs.wrapping_offset(rhs_rs);
1339 next_lhs = next_lhs.wrapping_offset(lhs_cs);
1340
1341 depth -= 1;
1342 if depth == 0 {
1343 break;
1344 }
1345 }
1346 }
1347 break;
1348 }
1349 }
1350
1351 if conj_both_lhs_rhs {
1352 seq_macro::seq!(N_ITER in 0..$nr {{
1353 let accum = accum.add(N_ITER * $mr_div_n);
1354 seq_macro::seq!(M_ITER in 0..$mr_div_n {{
1355 let accum = &mut *accum.add(M_ITER);
1356 *accum = conj(*accum);
1357 }});
1358 }});
1359 }
1360
1361 if m == $mr_div_n * CPLX_N && n == $nr && dst_rs == 1 {
1362 let alpha_re = splat(alpha.re);
1363 let alpha_im = splat(alpha.im);
1364 let beta_re = splat(beta.re);
1365 let beta_im = splat(beta.im);
1366
1367 if conj_dst {
1368 if alpha_status == 2 {
1369 seq_macro::seq!(N_ITER in 0..$nr {{
1370 seq_macro::seq!(M_ITER in 0..$mr_div_n {{
1371 let dst = dst.offset(M_ITER * CPLX_N as isize + N_ITER * dst_cs) as *mut Pack;
1372 let accum = *accum.offset(M_ITER + $mr_div_n * N_ITER);
1373 *dst = add(
1374 mul_cplx(conj(*dst), swap_re_im(conj(*dst)), alpha_re, alpha_im),
1375 mul_cplx(accum, swap_re_im(accum), beta_re, beta_im),
1376 );
1377 }});
1378 }});
1379 } else if alpha_status == 1 {
1380 seq_macro::seq!(N_ITER in 0..$nr {{
1381 seq_macro::seq!(M_ITER in 0..$mr_div_n {{
1382 let dst = dst.offset(M_ITER * CPLX_N as isize + N_ITER * dst_cs) as *mut Pack;
1383 let accum = *accum.offset(M_ITER + $mr_div_n * N_ITER);
1384 *dst = add(
1385 conj(*dst),
1386 mul_cplx(accum, swap_re_im(accum), beta_re, beta_im),
1387 );
1388 }});
1389 }});
1390 } else {
1391 seq_macro::seq!(N_ITER in 0..$nr {{
1392 seq_macro::seq!(M_ITER in 0..$mr_div_n {{
1393 let dst = dst.offset(M_ITER * CPLX_N as isize + N_ITER * dst_cs) as *mut Pack;
1394 let accum = *accum.offset(M_ITER + $mr_div_n * N_ITER);
1395 *dst = mul_cplx(accum, swap_re_im(accum), beta_re, beta_im);
1396 }});
1397 }});
1398 }
1399 } else {
1400 if alpha_status == 2 {
1401 seq_macro::seq!(N_ITER in 0..$nr {{
1402 seq_macro::seq!(M_ITER in 0..$mr_div_n {{
1403 let dst = dst.offset(M_ITER * CPLX_N as isize + N_ITER * dst_cs) as *mut Pack;
1404 let accum = *accum.offset(M_ITER + $mr_div_n * N_ITER);
1405 *dst = add(
1406 mul_cplx(*dst, swap_re_im(*dst), alpha_re, alpha_im),
1407 mul_cplx(accum, swap_re_im(accum), beta_re, beta_im),
1408 );
1409 }});
1410 }});
1411 } else if alpha_status == 1 {
1412 seq_macro::seq!(N_ITER in 0..$nr {{
1413 seq_macro::seq!(M_ITER in 0..$mr_div_n {{
1414 let dst = dst.offset(M_ITER * CPLX_N as isize + N_ITER * dst_cs) as *mut Pack;
1415 let accum = *accum.offset(M_ITER + $mr_div_n * N_ITER);
1416 *dst = add(
1417 *dst,
1418 mul_cplx(accum, swap_re_im(accum), beta_re, beta_im),
1419 );
1420 }});
1421 }});
1422 } else {
1423 seq_macro::seq!(N_ITER in 0..$nr {{
1424 seq_macro::seq!(M_ITER in 0..$mr_div_n {{
1425 let dst = dst.offset(M_ITER * CPLX_N as isize + N_ITER * dst_cs) as *mut Pack;
1426 let accum = *accum.offset(M_ITER + $mr_div_n * N_ITER);
1427 *dst = mul_cplx(accum, swap_re_im(accum), beta_re, beta_im);
1428 }});
1429 }});
1430 }
1431 }
1432 } else {
1433 let src = accum_storage; let src = src.as_ptr() as *const num_complex::Complex<T>;
1435
1436 if conj_dst {
1437 if alpha_status == 2 {
1438 for j in 0..n {
1439 let dst_j = dst.offset(dst_cs * j as isize);
1440 let src_j = src.add(j * $mr_div_n * CPLX_N);
1441
1442 for i in 0..m {
1443 let dst_ij = dst_j.offset(dst_rs * i as isize);
1444 let src_ij = src_j.add(i);
1445
1446 *dst_ij = alpha * (*dst_ij).conj() + beta * *src_ij;
1447 }
1448 }
1449 } else if alpha_status == 1 {
1450 for j in 0..n {
1451 let dst_j = dst.offset(dst_cs * j as isize);
1452 let src_j = src.add(j * $mr_div_n * CPLX_N);
1453
1454 for i in 0..m {
1455 let dst_ij = dst_j.offset(dst_rs * i as isize);
1456 let src_ij = src_j.add(i);
1457
1458 *dst_ij = (*dst_ij).conj() + beta * *src_ij;
1459 }
1460 }
1461 } else {
1462 for j in 0..n {
1463 let dst_j = dst.offset(dst_cs * j as isize);
1464 let src_j = src.add(j * $mr_div_n * CPLX_N);
1465
1466 for i in 0..m {
1467 let dst_ij = dst_j.offset(dst_rs * i as isize);
1468 let src_ij = src_j.add(i);
1469
1470 *dst_ij = beta * *src_ij;
1471 }
1472 }
1473 }
1474 } else {
1475 if alpha_status == 2 {
1476 for j in 0..n {
1477 let dst_j = dst.offset(dst_cs * j as isize);
1478 let src_j = src.add(j * $mr_div_n * CPLX_N);
1479
1480 for i in 0..m {
1481 let dst_ij = dst_j.offset(dst_rs * i as isize);
1482 let src_ij = src_j.add(i);
1483
1484 *dst_ij = alpha * *dst_ij + beta * *src_ij;
1485 }
1486 }
1487 } else if alpha_status == 1 {
1488 for j in 0..n {
1489 let dst_j = dst.offset(dst_cs * j as isize);
1490 let src_j = src.add(j * $mr_div_n * CPLX_N);
1491
1492 for i in 0..m {
1493 let dst_ij = dst_j.offset(dst_rs * i as isize);
1494 let src_ij = src_j.add(i);
1495
1496 *dst_ij = *dst_ij + beta * *src_ij;
1497 }
1498 }
1499 } else {
1500 for j in 0..n {
1501 let dst_j = dst.offset(dst_cs * j as isize);
1502 let src_j = src.add(j * $mr_div_n * CPLX_N);
1503
1504 for i in 0..m {
1505 let dst_ij = dst_j.offset(dst_rs * i as isize);
1506 let src_ij = src_j.add(i);
1507
1508 *dst_ij = beta * *src_ij;
1509 }
1510 }
1511 }
1512 }
1513 }
1514 }
1515 };
1516}
1517
1518#[macro_export]
1519macro_rules! microkernel_cplx_packed {
1520 ($([$target: tt])?, $unroll: tt, $name: ident, $mr_div_n: tt, $nr: tt) => {
1521 $(#[target_feature(enable = $target)])?
1522 pub unsafe fn $name(
1524 m: usize,
1525 n: usize,
1526 k: usize,
1527 dst: *mut T,
1528 mut packed_lhs: *const T,
1529 mut packed_rhs: *const T,
1530 dst_cs: isize,
1531 dst_rs: isize,
1532 lhs_cs: isize,
1533 rhs_rs: isize,
1534 rhs_cs: isize,
1535 alpha: T,
1536 beta: T,
1537 alpha_status: u8,
1538 conj_dst: bool,
1539 conj_lhs: bool,
1540 conj_rhs: bool,
1541 mut next_lhs: *const T,
1542 ) {
1543 let mut accum_storage = [[core::mem::zeroed::<Pack>(); $mr_div_n]; $nr];
1544 let accum = accum_storage.as_mut_ptr() as *mut Pack;
1545
1546 let conj_both_lhs_rhs = conj_lhs;
1547 let conj_rhs = conj_lhs != conj_rhs;
1548
1549 let mut lhs = [::core::mem::MaybeUninit::<Pack>::uninit(); $mr_div_n];
1550 let mut rhs = ::core::mem::MaybeUninit::<Pack>::uninit();
1551
1552 #[derive(Copy, Clone)]
1553 struct KernelIter {
1554 packed_lhs: *const T,
1555 next_lhs: *const T,
1556 packed_rhs: *const T,
1557 lhs_cs: isize,
1558 rhs_rs: isize,
1559 rhs_cs: isize,
1560 accum: *mut Pack,
1561 lhs: *mut Pack,
1562 rhs: *mut Pack,
1563 }
1564
1565 impl KernelIter {
1566 #[inline(always)]
1567 unsafe fn execute(self, iter: usize, conj_rhs: bool) {
1568 let packed_lhs = self.packed_lhs.wrapping_offset(iter as isize * self.lhs_cs);
1569 let packed_rhs = self.packed_rhs.wrapping_offset(iter as isize * self.rhs_rs);
1570 let next_lhs = self.next_lhs.wrapping_offset(iter as isize * self.lhs_cs);
1571
1572 seq_macro::seq!(M_ITER in 0..$mr_div_n {{
1573 *self.lhs.add(M_ITER) = *(packed_lhs.add(M_ITER * N) as *const Pack);
1574 }});
1575
1576 seq_macro::seq!(N_ITER in 0..$nr {{
1577 *self.rhs = splat(*packed_rhs.wrapping_offset(N_ITER * self.rhs_cs));
1578
1579 let accum = self.accum.add(N_ITER * $mr_div_n);
1580 seq_macro::seq!(M_ITER in 0..$mr_div_n {{
1581 let accum = &mut *accum.add(M_ITER);
1582 *accum = mul_add_cplx(
1583 *self.lhs.add(M_ITER),
1584 *self.rhs,
1585 *accum,
1586 conj_rhs,
1587 );
1588 }});
1589 }});
1590
1591 let _ = next_lhs;
1592 }
1593 }
1594
1595 let k_unroll = k / $unroll;
1596 let k_leftover = k % $unroll;
1597
1598 loop {
1599 if conj_rhs {
1600 let mut depth = k_unroll;
1601 if depth != 0 {
1602 loop {
1603 let iter = KernelIter {
1604 packed_lhs,
1605 next_lhs,
1606 packed_rhs,
1607 lhs_cs,
1608 rhs_rs,
1609 rhs_cs,
1610 accum,
1611 lhs: lhs.as_mut_ptr() as _,
1612 rhs: &mut rhs as *mut _ as _,
1613 };
1614
1615 seq_macro::seq!(UNROLL_ITER in 0..$unroll {{
1616 iter.execute(UNROLL_ITER, true);
1617 }});
1618
1619 packed_lhs = packed_lhs.wrapping_offset($unroll * lhs_cs);
1620 packed_rhs = packed_rhs.wrapping_offset($unroll * rhs_rs);
1621 next_lhs = next_lhs.wrapping_offset($unroll * lhs_cs);
1622
1623 depth -= 1;
1624 if depth == 0 {
1625 break;
1626 }
1627 }
1628 }
1629 depth = k_leftover;
1630 if depth != 0 {
1631 loop {
1632 KernelIter {
1633 packed_lhs,
1634 next_lhs,
1635 packed_rhs,
1636 lhs_cs,
1637 rhs_rs,
1638 rhs_cs,
1639 accum,
1640 lhs: lhs.as_mut_ptr() as _,
1641 rhs: &mut rhs as *mut _ as _,
1642 }
1643 .execute(0, true);
1644
1645 packed_lhs = packed_lhs.wrapping_offset(lhs_cs);
1646 packed_rhs = packed_rhs.wrapping_offset(rhs_rs);
1647 next_lhs = next_lhs.wrapping_offset(lhs_cs);
1648
1649 depth -= 1;
1650 if depth == 0 {
1651 break;
1652 }
1653 }
1654 }
1655 break;
1656 } else {
1657 let mut depth = k_unroll;
1658 if depth != 0 {
1659 loop {
1660 let iter = KernelIter {
1661 next_lhs,
1662 packed_lhs,
1663 packed_rhs,
1664 lhs_cs,
1665 rhs_rs,
1666 rhs_cs,
1667 accum,
1668 lhs: lhs.as_mut_ptr() as _,
1669 rhs: &mut rhs as *mut _ as _,
1670 };
1671
1672 seq_macro::seq!(UNROLL_ITER in 0..$unroll {{
1673 iter.execute(UNROLL_ITER, false);
1674 }});
1675
1676 packed_lhs = packed_lhs.wrapping_offset($unroll * lhs_cs);
1677 packed_rhs = packed_rhs.wrapping_offset($unroll * rhs_rs);
1678 next_lhs = next_lhs.wrapping_offset($unroll * lhs_cs);
1679
1680 depth -= 1;
1681 if depth == 0 {
1682 break;
1683 }
1684 }
1685 }
1686 depth = k_leftover;
1687 if depth != 0 {
1688 loop {
1689 KernelIter {
1690 next_lhs,
1691 packed_lhs,
1692 packed_rhs,
1693 lhs_cs,
1694 rhs_rs,
1695 rhs_cs,
1696 accum,
1697 lhs: lhs.as_mut_ptr() as _,
1698 rhs: &mut rhs as *mut _ as _,
1699 }
1700 .execute(0, false);
1701
1702 packed_lhs = packed_lhs.wrapping_offset(lhs_cs);
1703 packed_rhs = packed_rhs.wrapping_offset(rhs_rs);
1704 next_lhs = next_lhs.wrapping_offset(lhs_cs);
1705
1706 depth -= 1;
1707 if depth == 0 {
1708 break;
1709 }
1710 }
1711 }
1712 break;
1713 }
1714 }
1715
1716 if conj_both_lhs_rhs {
1717 seq_macro::seq!(N_ITER in 0..$nr {{
1718 let accum = accum.add(N_ITER * $mr_div_n);
1719 seq_macro::seq!(M_ITER in 0..$mr_div_n {{
1720 let accum = &mut *accum.add(M_ITER);
1721 *accum = conj(*accum);
1722 }});
1723 }});
1724 }
1725
1726 if m == $mr_div_n * N && n == $nr && dst_rs == 1 {
1727 let alpha = splat(alpha);
1728 let beta = splat(beta);
1729
1730 if conj_dst {
1731 if alpha_status == 2 {
1732 seq_macro::seq!(N_ITER in 0..$nr {{
1733 seq_macro::seq!(M_ITER in 0..$mr_div_n {{
1734 let dst = dst.offset(M_ITER * N as isize + N_ITER * dst_cs) as *mut Pack;
1735 let accum = *accum.offset(M_ITER + $mr_div_n * N_ITER);
1736 *dst = add(
1737 mul_cplx(conj(*dst), alpha),
1738 mul_cplx(accum, beta),
1739 );
1740 }});
1741 }});
1742 } else if alpha_status == 1 {
1743 seq_macro::seq!(N_ITER in 0..$nr {{
1744 seq_macro::seq!(M_ITER in 0..$mr_div_n {{
1745 let dst = dst.offset(M_ITER * N as isize + N_ITER * dst_cs) as *mut Pack;
1746 let accum = *accum.offset(M_ITER + $mr_div_n * N_ITER);
1747 *dst = add(
1748 conj(*dst),
1749 mul_cplx(accum, beta),
1750 );
1751 }});
1752 }});
1753 } else {
1754 seq_macro::seq!(N_ITER in 0..$nr {{
1755 seq_macro::seq!(M_ITER in 0..$mr_div_n {{
1756 let dst = dst.offset(M_ITER * N as isize + N_ITER * dst_cs) as *mut Pack;
1757 let accum = *accum.offset(M_ITER + $mr_div_n * N_ITER);
1758 *dst = mul_cplx(accum, beta);
1759 }});
1760 }});
1761 }
1762 } else {
1763 if alpha_status == 2 {
1764 seq_macro::seq!(N_ITER in 0..$nr {{
1765 seq_macro::seq!(M_ITER in 0..$mr_div_n {{
1766 let dst = dst.offset(M_ITER * N as isize + N_ITER * dst_cs) as *mut Pack;
1767 let accum = *accum.offset(M_ITER + $mr_div_n * N_ITER);
1768 *dst = add(
1769 mul_cplx(*dst, alpha),
1770 mul_cplx(accum, beta),
1771 );
1772 }});
1773 }});
1774 } else if alpha_status == 1 {
1775 seq_macro::seq!(N_ITER in 0..$nr {{
1776 seq_macro::seq!(M_ITER in 0..$mr_div_n {{
1777 let dst = dst.offset(M_ITER * N as isize + N_ITER * dst_cs) as *mut Pack;
1778 let accum = *accum.offset(M_ITER + $mr_div_n * N_ITER);
1779 *dst = add(
1780 *dst,
1781 mul_cplx(accum, beta),
1782 );
1783 }});
1784 }});
1785 } else {
1786 seq_macro::seq!(N_ITER in 0..$nr {{
1787 seq_macro::seq!(M_ITER in 0..$mr_div_n {{
1788 let dst = dst.offset(M_ITER * N as isize + N_ITER * dst_cs) as *mut Pack;
1789 let accum = *accum.offset(M_ITER + $mr_div_n * N_ITER);
1790 *dst = mul_cplx(accum, beta);
1791 }});
1792 }});
1793 }
1794 }
1795 } else {
1796 let src = accum_storage; let src = src.as_ptr() as *const T;
1798
1799 if conj_dst {
1800 if alpha_status == 2 {
1801 for j in 0..n {
1802 let dst_j = dst.offset(dst_cs * j as isize);
1803 let src_j = src.add(j * $mr_div_n * N);
1804
1805 for i in 0..m {
1806 let dst_ij = dst_j.offset(dst_rs * i as isize);
1807 let src_ij = src_j.add(i);
1808
1809 *dst_ij = alpha * (*dst_ij).conj() + beta * *src_ij;
1810 }
1811 }
1812 } else if alpha_status == 1 {
1813 for j in 0..n {
1814 let dst_j = dst.offset(dst_cs * j as isize);
1815 let src_j = src.add(j * $mr_div_n * N);
1816
1817 for i in 0..m {
1818 let dst_ij = dst_j.offset(dst_rs * i as isize);
1819 let src_ij = src_j.add(i);
1820
1821 *dst_ij = (*dst_ij).conj() + beta * *src_ij;
1822 }
1823 }
1824 } else {
1825 for j in 0..n {
1826 let dst_j = dst.offset(dst_cs * j as isize);
1827 let src_j = src.add(j * $mr_div_n * N);
1828
1829 for i in 0..m {
1830 let dst_ij = dst_j.offset(dst_rs * i as isize);
1831 let src_ij = src_j.add(i);
1832
1833 *dst_ij = beta * *src_ij;
1834 }
1835 }
1836 }
1837 } else {
1838 if alpha_status == 2 {
1839 for j in 0..n {
1840 let dst_j = dst.offset(dst_cs * j as isize);
1841 let src_j = src.add(j * $mr_div_n * N);
1842
1843 for i in 0..m {
1844 let dst_ij = dst_j.offset(dst_rs * i as isize);
1845 let src_ij = src_j.add(i);
1846
1847 *dst_ij = alpha * *dst_ij + beta * *src_ij;
1848 }
1849 }
1850 } else if alpha_status == 1 {
1851 for j in 0..n {
1852 let dst_j = dst.offset(dst_cs * j as isize);
1853 let src_j = src.add(j * $mr_div_n * N);
1854
1855 for i in 0..m {
1856 let dst_ij = dst_j.offset(dst_rs * i as isize);
1857 let src_ij = src_j.add(i);
1858
1859 *dst_ij = *dst_ij + beta * *src_ij;
1860 }
1861 }
1862 } else {
1863 for j in 0..n {
1864 let dst_j = dst.offset(dst_cs * j as isize);
1865 let src_j = src.add(j * $mr_div_n * N);
1866
1867 for i in 0..m {
1868 let dst_ij = dst_j.offset(dst_rs * i as isize);
1869 let src_ij = src_j.add(i);
1870
1871 *dst_ij = beta * *src_ij;
1872 }
1873 }
1874 }
1875 }
1876 }
1877 }
1878 };
1879}