gemm_common/
microkernel.rs

/// Function-pointer type of the register-tile microkernels generated in this
/// file; the functions emitted by `microkernel!` and `microkernel_amx!` use
/// exactly this parameter list.
///
/// The parameter names are documentation only and do not affect the type.
///
/// * `m`, `n` – number of rows / columns of the tile that are written to `dst`.
/// * `k` – number of accumulation (depth) steps.
/// * `dst`, `dst_cs`, `dst_rs` – destination pointer with its column and row
///   strides, counted in elements.
/// * `packed_lhs`, `lhs_cs` – left operand; `lhs_cs` is the distance between
///   two consecutive depth steps.
/// * `packed_rhs`, `rhs_rs`, `rhs_cs` – right operand; `rhs_rs` is the distance
///   between depth steps, `rhs_cs` the distance between columns.
/// * `alpha`, `beta`, `alpha_status` – scaling of the old destination and of
///   the accumulated product. As implemented by `microkernel!`:
///   `0` stores `mul(beta, acc)`, `1` stores `mul_add(beta, acc, dst)` and
///   `2` stores `alpha * dst + beta * acc`.
/// * `conj_dst`, `conj_lhs`, `conj_rhs` – conjugation flags; the real-valued
///   kernels in this file ignore them.
/// * `next_lhs` – advanced in lock-step with `packed_lhs` by the kernels in
///   this file but never dereferenced by them.
pub type MicroKernelFn<T> = unsafe fn(
    m: usize,
    n: usize,
    k: usize,
    dst: *mut T,
    packed_lhs: *const T,
    packed_rhs: *const T,
    dst_cs: isize,
    dst_rs: isize,
    lhs_cs: isize,
    rhs_rs: isize,
    rhs_cs: isize,
    alpha: T,
    beta: T,
    alpha_status: u8,
    conj_dst: bool,
    conj_lhs: bool,
    conj_rhs: bool,
    next_lhs: *const T,
);
21
/// Function-pointer type used for the entries of the `H_UKR` / `H_CPLX_UKR`
/// tables built by `hmicrokernel_fn_array!` and `hmicrokernel_cplx_fn_array!`.
///
/// Compared with `MicroKernelFn` there are no `m` / `n` extents and no
/// `next_lhs` pointer, and the operand strides are `lhs_rs` / `rhs_cs`.
/// The parameter names are documentation only and do not affect the type.
///
/// NOTE(review): no kernel of this type is defined in the visible part of the
/// file, so the exact meaning of each argument should be confirmed against an
/// implementation.
pub type HMicroKernelFn<T> = unsafe fn(
    k: usize,
    dst: *mut T,
    lhs: *const T,
    rhs: *const T,
    dst_cs: isize,
    dst_rs: isize,
    lhs_rs: isize,
    rhs_cs: isize,
    alpha: T,
    beta: T,
    alpha_status: u8,
    _conj_dst: bool,
    _conj_lhs: bool,
    _conj_rhs: bool,
);
38
// microkernel_fn_array! {
//     [ a, b, c, ],
//     [ d, e, f, ],
// }
//
// expands to
//
// pub const MR_DIV_N: usize = 2; // number of rows
// pub const NR: usize = 3;       // number of entries in the first row
// pub const UKR: [[MicroKernelFn<T>; NR]; MR_DIV_N] = [
//     [ a, b, c, ],
//     [ d, e, f, ],
// ];
/// Maps any single token tree to the literal `1`.
///
/// Used by the table-building macros in this file to count macro arguments:
/// each repetition contributes one `1` to a sum.
#[macro_export]
macro_rules! __one {
    ($tt:tt) => {
        1
    };
}
67
/// Counts the comma-terminated token trees inside one bracketed list.
///
/// `__microkernel_fn_array_helper!([a, b, c,])` evaluates to `3_usize`; an
/// empty list evaluates to `0`. The expansion is a plain sum of literals, so
/// it can be used to initialise a `const`.
#[macro_export]
macro_rules! __microkernel_fn_array_helper {
    ([ $($tt:tt,)* ]) => {
        0_usize $(+ $crate::__one!($tt))*
    };
}
80
/// Evaluates to the number of identifiers in the FIRST bracketed row.
///
/// Later rows are accepted but not measured; a mismatch in their length is
/// caught by the array type of the table that the calling macro declares.
/// Invoked with no rows at all, the macro evaluates to `0` instead of failing
/// to compile on an empty, untyped array.
#[macro_export]
macro_rules! __microkernel_fn_array_helper_nr {
    // No rows: the table has zero columns.
    () => {
        0_usize
    };
    // At least one row: its length is the column count.
    ([ $($first:ident,)* ], $([ $($rest:ident,)* ],)*) => {
        0_usize $(+ $crate::__one!($first))*
    };
}
97
/// Builds the dispatch table of real-valued microkernels.
///
/// Given rows of kernel function names, each row and each name terminated by
/// a comma, this emits three public constants at the invocation site:
///
/// * `MR_DIV_N` – the number of rows,
/// * `NR` – the number of names in the first row,
/// * `UKR` – `[[MicroKernelFn<T>; NR]; MR_DIV_N]` holding the named functions.
///
/// A type named `T` must be in scope where the macro is invoked, and every
/// named function must coerce to `$crate::microkernel::MicroKernelFn<T>`.
#[macro_export]
macro_rules! microkernel_fn_array {
    ($([
        $($ukr: ident,)*
    ],)*) => {
        pub const MR_DIV_N: usize =
            $crate::__microkernel_fn_array_helper!([$([$($ukr,)*],)*]);
        pub const NR: usize =
            $crate::__microkernel_fn_array_helper_nr!($([$($ukr,)*],)*);

        pub const UKR: [[$crate::microkernel::MicroKernelFn<T>; NR]; MR_DIV_N] =
            [ $([$($ukr,)*]),* ];
    };
}
112
/// Builds the dispatch table of real-valued `HMicroKernelFn` kernels.
///
/// Given rows of kernel function names, each row and each name terminated by
/// a comma, this emits three public constants at the invocation site:
///
/// * `H_M` – the number of rows,
/// * `H_N` – the number of names in the first row,
/// * `H_UKR` – `[[HMicroKernelFn<T>; H_N]; H_M]` holding the named functions.
///
/// A type named `T` must be in scope where the macro is invoked, and every
/// named function must coerce to `$crate::microkernel::HMicroKernelFn<T>`.
#[macro_export]
macro_rules! hmicrokernel_fn_array {
    ($([
        $($ukr: ident,)*
    ],)*) => {
        pub const H_M: usize =
            $crate::__microkernel_fn_array_helper!([$([$($ukr,)*],)*]);
        pub const H_N: usize =
            $crate::__microkernel_fn_array_helper_nr!($([$($ukr,)*],)*);

        pub const H_UKR: [[$crate::microkernel::HMicroKernelFn<T>; H_N]; H_M] =
            [ $([$($ukr,)*]),* ];
    };
}
127
/// Builds the dispatch table of complex-valued microkernels.
///
/// Given rows of kernel function names, each row and each name terminated by
/// a comma, this emits three public constants at the invocation site:
///
/// * `CPLX_MR_DIV_N` – the number of rows,
/// * `CPLX_NR` – the number of names in the first row,
/// * `CPLX_UKR` – `[[MicroKernelFn<num_complex::Complex<T>>; CPLX_NR]; CPLX_MR_DIV_N]`
///   holding the named functions.
///
/// A type named `T` and the path `num_complex::Complex` must both resolve
/// where the macro is invoked.
#[macro_export]
macro_rules! microkernel_cplx_fn_array {
    ($([
        $($ukr: ident,)*
    ],)*) => {
        pub const CPLX_MR_DIV_N: usize =
            $crate::__microkernel_fn_array_helper!([$([$($ukr,)*],)*]);
        pub const CPLX_NR: usize =
            $crate::__microkernel_fn_array_helper_nr!($([$($ukr,)*],)*);

        pub const CPLX_UKR: [[$crate::microkernel::MicroKernelFn<num_complex::Complex<T>>; CPLX_NR]; CPLX_MR_DIV_N] =
            [ $([$($ukr,)*]),* ];
    };
}
142
/// Builds the dispatch table of complex-valued `HMicroKernelFn` kernels.
///
/// Given rows of kernel function names, each row and each name terminated by
/// a comma, this emits three public constants at the invocation site:
///
/// * `H_CPLX_M` – the number of rows,
/// * `H_CPLX_N` – the number of names in the first row,
/// * `H_CPLX_UKR` – `[[HMicroKernelFn<num_complex::Complex<T>>; H_CPLX_N]; H_CPLX_M]`
///   holding the named functions.
///
/// A type named `T` and the path `num_complex::Complex` must both resolve
/// where the macro is invoked.
#[macro_export]
macro_rules! hmicrokernel_cplx_fn_array {
    ($([
        $($ukr: ident,)*
    ],)*) => {
        pub const H_CPLX_M: usize =
            $crate::__microkernel_fn_array_helper!([$([$($ukr,)*],)*]);
        pub const H_CPLX_N: usize =
            $crate::__microkernel_fn_array_helper_nr!($([$($ukr,)*],)*);

        pub const H_CPLX_UKR: [[$crate::microkernel::HMicroKernelFn<num_complex::Complex<T>>; H_CPLX_N]; H_CPLX_M] =
            [ $([$($ukr,)*]),* ];
    };
}
157
/// Emits one raw 32-bit instruction word, `0x201000 + ($op << 5)`, with the
/// value of `$gpr` bound to register `x0` as an input operand.
///
/// `$op` must be a single token (it is pasted into the assembler expression
/// via `stringify!`); `$gpr` is any expression accepted as an `in` operand.
/// The expansion is a bare `asm!` invocation, so it has to appear inside an
/// `unsafe` context on a target that has an `x0` register.
///
/// NOTE(review): the opcode numbering is presumably that of Apple's AMX
/// coprocessor as used by the referenced `blis_apple` project — confirm
/// against that implementation.
#[macro_export]
macro_rules! amx {
    ($op: tt, $gpr: expr) => {
        ::core::arch::asm!(
            ::core::concat!(".word (0x201000 + (", ::core::stringify!($op), " << 5))"),
            in("x0") $gpr
        )
    };
}
164
// Credit to RuQing Xu (https://github.com/xrq-phys/blis_apple) for the reference implementation.

/// Generates a microkernel, with the `MicroKernelFn<T>` parameter list, whose
/// accumulation is carried out by raw `.word`-encoded coprocessor
/// instructions issued through `amx!`.
///
/// Macro arguments, in order:
///
/// * `$ty` – `f64`, `f32` or `f16`; selects the fused-multiply-add opcode.
/// * `[$target]` – optional `target_feature` string (the following comma is
///   required even when the bracket is omitted).
/// * `$unroll` – unroll factor of the depth loop.
/// * `$name` – name of the generated function.
/// * `$mr_div_n` – number of 64-byte left-operand registers per depth step.
/// * `$nr` – number of destination columns of a full tile.
/// * `$nr_div_n` – number of 64-byte right-operand registers per depth step.
/// * `$n` – number of `T` elements per register.
///
/// The invocation site must provide a type `T`, a constant `N` and the
/// `seq_macro` crate.
///
/// NOTE(review): the operand bit layouts (`<< 56`, `<< 16`, `<< 20`,
/// `1 << 27`, `1 << 37`, `<< 32`) are reproduced unchanged; their meaning is
/// defined by the hardware and cannot be verified from this file.
#[macro_export]
macro_rules! microkernel_amx {
    ($ty: tt, $([$target: tt])?, $unroll: tt, $name: ident, $mr_div_n: tt, $nr: tt , $nr_div_n: tt, $n: tt) => {
        /// Computes one register tile and merges it into `dst`.
        ///
        /// `alpha_status == 0` stores the scaled accumulator without using the
        /// previous contents of the tile; any other value folds the previous
        /// contents in through an additional multiply-accumulate with `alpha`.
        ///
        /// # Panics
        ///
        /// Panics if `rhs_cs != 1`.
        ///
        /// # Safety
        ///
        /// All pointers must be valid for the accesses implied by `m`, `n`,
        /// `k` and the strides, and the executing CPU must support the
        /// emitted instruction words.
        $(#[target_feature(enable = $target)])?
        pub unsafe fn $name(
            m: usize,
            n: usize,
            k: usize,
            dst: *mut T,
            mut packed_lhs: *const T,
            mut packed_rhs: *const T,
            dst_cs: isize,
            dst_rs: isize,
            lhs_cs: isize,
            rhs_rs: isize,
            rhs_cs: isize,
            alpha: T,
            beta: T,
            // 0, 1, or 2 for generic alpha
            alpha_status: u8,
            _conj_dst: bool,
            _conj_lhs: bool,
            _conj_rhs: bool,
            mut next_lhs: *const T,
        ) {
            assert_eq!(rhs_cs, 1);

            // Opcode 17 with immediate 0 / 1, preceded by three `nop`s.
            macro_rules! amx_nop {
                ($imm5: tt) => {
                    ::core::arch::asm!(
                        ::core::concat!(
                            "nop\nnop\nnop\n.word (0x201000 + (17 << 5) + ",
                            ::core::stringify!($imm5),
                            ")"
                        ),
                        options(nostack)
                    );
                };
            }
            macro_rules! amx_set { () => { amx_nop!(0) }; }
            macro_rules! amx_clr { () => { amx_nop!(1) }; }

            // Raw instruction wrappers: the argument is the fully encoded operand word.
            macro_rules! __ldx { ($gpr: expr) => { $crate::amx!(0, $gpr) } }
            macro_rules! __ldy { ($gpr: expr) => { $crate::amx!(1, $gpr) } }
            macro_rules! __stz { ($gpr: expr) => { $crate::amx!(5, $gpr) } }
            macro_rules! __extrx { ($gpr: expr) => { $crate::amx!(8, $gpr) } }
            macro_rules! __fma {
                (f64, $gpr: expr) => { $crate::amx!(10, $gpr) };
                (f32, $gpr: expr) => { $crate::amx!(12, $gpr) };
                (f16, $gpr: expr) => { $crate::amx!(15, $gpr) };
            }

            // Operand encoders: register index in the top byte, address in the low 56 bits.
            macro_rules! ldx { ($addr: expr, $idx: expr) => { __ldx!( ($idx << 56) | (($addr as usize) & ((1 << 56) - 1)) ) } }
            macro_rules! ldy { ($addr: expr, $idx: expr) => { __ldy!( ($idx << 56) | (($addr as usize) & ((1 << 56) - 1)) ) } }
            macro_rules! stz { ($addr: expr, $idx: expr) => { __stz!( ($idx << 56) | (($addr as usize) & ((1 << 56) - 1)) ) } }
            macro_rules! extrx { ($z: expr, $x: expr) => { __extrx!( ($x << 16) | ($z << 20) ) } }
            macro_rules! fma {
                ($x: expr, $y: expr, $z: expr) => { __fma!($ty, ($y << 6) | ($x << 16) | ($z << 20)) };
            }
            macro_rules! mul_select {
                ($col: expr, $x: expr, $y: expr, $z: expr) => {
                    __fma!(
                        $ty,
                        (1usize << 27)
                        | (1usize << 37)
                        | ($col << 32)
                        | ($y << 6)
                        | ($x << 16)
                        | ($z << 20)
                    )
                };
            }
            macro_rules! fma_select {
                ($col: expr, $x: expr, $y: expr, $z: expr) => {
                    __fma!(
                        $ty,
                        (1usize << 37)
                        | ($col << 32)
                        | ($y << 6)
                        | ($x << 16)
                        | ($z << 20)
                    )
                };
            }

            amx_set!();

            ldx!(packed_lhs, 0);

            let k_unroll = k / $unroll;
            let k_leftover = k % $unroll ;

            // Main depth loop: `$unroll` steps per iteration. All operands of the
            // unrolled group are loaded first, then all multiply-accumulates are issued.
            let mut depth = k_unroll;
            if depth != 0 {
                loop {
                    seq_macro::seq!(UNROLL_ITER in 0..$unroll {{
                        seq_macro::seq!(M_ITER in 0..$mr_div_n {{
                            ldx!(packed_lhs.offset(lhs_cs * UNROLL_ITER + M_ITER * $n), UNROLL_ITER + $unroll * M_ITER);
                        }});
                        seq_macro::seq!(N_ITER in 0..$nr_div_n {{
                            ldy!(packed_rhs.offset(rhs_rs * UNROLL_ITER + N_ITER * $n), UNROLL_ITER + $unroll * N_ITER);
                        }});
                    }});
                    seq_macro::seq!(UNROLL_ITER in 0..$unroll {{
                        seq_macro::seq!(M_ITER in 0..$mr_div_n {{
                            seq_macro::seq!(N_ITER in 0..$nr_div_n {{
                                fma!(UNROLL_ITER + $unroll * M_ITER, UNROLL_ITER + $unroll * N_ITER, M_ITER + $mr_div_n * N_ITER);
                            }});
                        }});
                    }});

                    packed_lhs = packed_lhs.wrapping_offset($unroll * lhs_cs);
                    packed_rhs = packed_rhs.wrapping_offset($unroll * rhs_rs);
                    next_lhs = next_lhs.wrapping_offset($unroll * lhs_cs);

                    depth -= 1;
                    if depth == 0 {
                        break;
                    }
                }
            }

            // Remainder of the depth loop, one step at a time.
            depth = k_leftover;
            if depth != 0 {
                loop {
                    seq_macro::seq!(M_ITER in 0..$mr_div_n {{
                        ldx!(packed_lhs.offset(lhs_cs * 0 + M_ITER * $n), 0 + $unroll * M_ITER);
                    }});
                    seq_macro::seq!(N_ITER in 0..$nr_div_n {{
                        ldy!(packed_rhs.offset(rhs_rs * 0 + N_ITER * $n), 0 + $unroll * N_ITER);
                    }});
                    seq_macro::seq!(M_ITER in 0..$mr_div_n {{
                        seq_macro::seq!(N_ITER in 0..$nr_div_n {{
                            fma!($unroll * M_ITER, $unroll * N_ITER, M_ITER + $mr_div_n * N_ITER);
                        }});
                    }});

                    packed_lhs = packed_lhs.wrapping_offset(lhs_cs);
                    packed_rhs = packed_rhs.wrapping_offset(rhs_rs);
                    next_lhs = next_lhs.wrapping_offset(lhs_cs);

                    depth -= 1;
                    if depth == 0 {
                        break;
                    }
                }
            }

            // Broadcast the scalars into two right-operand registers:
            // register 0 holds `alpha`, register 1 holds `beta`.
            let alpha_c = [alpha; $n];
            let beta_c = [beta; $n];
            ldy!(alpha_c.as_ptr(), 0);
            ldy!(beta_c.as_ptr(), 1);

            let stride = 64 / N;

            // For every accumulator row: extract it, then re-multiply it with
            // right-operand register 1 (`beta`).
            for i in 0..N {
                seq_macro::seq!(M_ITER in 0..$mr_div_n {{
                    seq_macro::seq!(N_ITER in 0..$nr_div_n {{
                        extrx!((i * stride + M_ITER + $mr_div_n * N_ITER), M_ITER + $mr_div_n * N_ITER);
                        mul_select!(i, M_ITER + $mr_div_n * N_ITER, 1, i * stride + M_ITER + $mr_div_n * N_ITER);
                    }});
                }});
            }

            // Scratch tile, used when `dst` is not a full tile with unit row stride.
            let mut tmp_storage = [::core::mem::MaybeUninit::<T>::uninit(); { $mr_div_n * $n * $nr }];
            let mut tmp_dst = tmp_storage.as_mut_ptr() as *mut T;
            let mut tmp_dst_cs = $mr_div_n * $n;

            if dst_rs == 1 && m == $mr_div_n * N && n == $nr {
                // Full tile: operate on `dst` in place.
                tmp_dst = dst;
                tmp_dst_cs = dst_cs;
            } else {
                // Partial or strided tile: copy the live `m x n` part in and
                // zero-fill the unused rows of each copied column.
                // NOTE(review): this copy also runs when `alpha_status == 0`,
                // where the values are overwritten without being used.
                for j in 0..n {
                    for i in 0..m {
                        *tmp_dst.offset(i as isize          + j as isize * tmp_dst_cs) =
                        *dst    .offset(i as isize * dst_rs + j as isize * dst_cs);
                    }
                    for i in m..$mr_div_n * N {
                        *tmp_dst.offset(i as isize          + j as isize * tmp_dst_cs) = ::core::mem::zeroed();
                    }
                }
            }

            if alpha_status == 0 {
                // Overwrite: store the scaled accumulator rows directly.
                for i in 0..N {
                    seq_macro::seq!(M_ITER in 0..$mr_div_n {{
                        seq_macro::seq!(N_ITER in 0..$nr_div_n {{
                            stz!(
                                tmp_dst.offset((i as isize + N_ITER * $n) * tmp_dst_cs + M_ITER * $n),
                                (i * stride + M_ITER + $mr_div_n * N_ITER)
                            );
                        }});
                    }});
                }
            } else {
                // Update: load the old column, multiply-accumulate it with
                // right-operand register 0 (`alpha`), then store.
                for i in 0..N {
                    seq_macro::seq!(M_ITER in 0..$mr_div_n {{
                        seq_macro::seq!(N_ITER in 0..$nr_div_n {{
                            ldx!(
                                tmp_dst.offset((i as isize + N_ITER * $n) * tmp_dst_cs + M_ITER * $n),
                                M_ITER + $mr_div_n * N_ITER
                            );
                            fma_select!(i, M_ITER + $mr_div_n * N_ITER, 0, i * stride + M_ITER + $mr_div_n * N_ITER);
                            stz!(
                                tmp_dst.offset((i as isize + N_ITER * $n) * tmp_dst_cs + M_ITER * $n),
                                i * stride + M_ITER + $mr_div_n * N_ITER
                            );
                        }});
                    }});
                }
            }

            // Copy the live part of the scratch tile back to the real destination.
            if !(dst_rs == 1 && m == $mr_div_n * N && n == $nr) {
                for j in 0..n {
                    for i in 0..m {
                        *dst    .offset(i as isize * dst_rs + j as isize * dst_cs) =
                        *tmp_dst.offset(i as isize          + j as isize * tmp_dst_cs);
                    }
                }
            }

            amx_clr!();
        }
    };
}
396
/// Generates a real-valued SIMD microkernel with the `MicroKernelFn<T>`
/// parameter list.
///
/// Macro arguments, in order:
///
/// * `[$target]` – optional `target_feature` string (the following comma is
///   required even when the bracket is omitted).
/// * `$unroll` – unroll factor of the depth loop.
/// * `$name` – name of the generated function.
/// * `$mr_div_n` – number of `Pack` vectors per accumulator column.
/// * `$nr` – number of accumulator columns.
/// * `$nr_div_n, $n` – optional; when given, an `execute_neon` step is
///   generated and used whenever `rhs_cs == 1`. It loads `$n` consecutive
///   right-hand values as one `Pack` and multiplies by lane through
///   `mul_add_lane`.
///
/// The invocation site must provide `T`, `N`, `Pack`, `splat`, `mul`, `add`,
/// `mul_add`, `scalar_mul`, `scalar_add`, `scalar_mul_add` (plus `load` and
/// `mul_add_lane` for the optional path) and the `seq_macro` crate.
#[macro_export]
macro_rules! microkernel {
    ($([$target: tt])?, $unroll: tt, $name: ident, $mr_div_n: tt, $nr: tt $(, $nr_div_n: tt, $n: tt)?) => {
        /// Accumulates `k` depth steps of the packed operands into a register
        /// tile and merges the tile into `dst`.
        ///
        /// `alpha_status` selects the merge: `0` stores `mul(beta, acc)`,
        /// `1` stores `mul_add(beta, acc, dst)`, and `2` stores
        /// `alpha * dst + beta * acc`.
        ///
        /// # Safety
        ///
        /// All pointers must be valid for the accesses implied by `m`, `n`,
        /// `k` and the strides; `m` and `n` must not exceed the tile size.
        $(#[target_feature(enable = $target)])?
        pub unsafe fn $name(
            m: usize,
            n: usize,
            k: usize,
            dst: *mut T,
            mut packed_lhs: *const T,
            mut packed_rhs: *const T,
            dst_cs: isize,
            dst_rs: isize,
            lhs_cs: isize,
            rhs_rs: isize,
            rhs_cs: isize,
            alpha: T,
            beta: T,
            // 0, 1, or 2 for generic alpha
            alpha_status: u8,
            _conj_dst: bool,
            _conj_lhs: bool,
            _conj_rhs: bool,
            mut next_lhs: *const T,
        ) {
            // Accumulator tile: `$nr` columns of `$mr_div_n` vectors, zero-initialised.
            let mut accum_storage = [[splat(::core::mem::zeroed()); $mr_div_n]; $nr];
            let accum = accum_storage.as_mut_ptr() as *mut Pack;

            // Scratch slots the per-step code stages its operands in.
            let mut lhs = [::core::mem::MaybeUninit::<Pack>::uninit(); $mr_div_n];
            let mut rhs = ::core::mem::MaybeUninit::<Pack>::uninit();

            // Snapshot of the loop state, passed by value to one unrolled group.
            #[derive(Copy, Clone)]
            struct KernelIter {
                packed_lhs: *const T,
                packed_rhs: *const T,
                next_lhs: *const T,
                lhs_cs: isize,
                rhs_rs: isize,
                rhs_cs: isize,
                accum: *mut Pack,
                lhs: *mut Pack,
                rhs: *mut Pack,
            }

            impl KernelIter {
                // One depth step: broadcast each right-hand value and
                // multiply-accumulate it against every left-hand vector.
                #[inline(always)]
                unsafe fn execute(self, iter: usize) {
                    let packed_lhs = self.packed_lhs.wrapping_offset(iter as isize * self.lhs_cs);
                    let packed_rhs = self.packed_rhs.wrapping_offset(iter as isize * self.rhs_rs);
                    let next_lhs = self.next_lhs.wrapping_offset(iter as isize * self.lhs_cs);

                    seq_macro::seq!(M_ITER in 0..$mr_div_n {{
                        *self.lhs.add(M_ITER) = *(packed_lhs.add(M_ITER * N) as *const Pack);
                    }});

                    seq_macro::seq!(N_ITER in 0..$nr {{
                        *self.rhs = splat(*packed_rhs.wrapping_offset(N_ITER * self.rhs_cs));
                        let accum = self.accum.add(N_ITER * $mr_div_n);
                        seq_macro::seq!(M_ITER in 0..$mr_div_n {{
                            let accum = &mut *accum.add(M_ITER);
                            *accum = mul_add(
                                *self.lhs.add(M_ITER),
                                *self.rhs,
                                *accum,
                                );
                        }});
                    }});

                    // `next_lhs` is computed but intentionally not used here.
                    let _ = next_lhs;
                }

                $(
                    // One depth step for contiguous right-hand rows: load `$n`
                    // values at once and multiply by lane.
                    #[inline(always)]
                    unsafe fn execute_neon(self, iter: usize) {
                        debug_assert_eq!(self.rhs_cs, 1);
                        let packed_lhs = self.packed_lhs.wrapping_offset(iter as isize * self.lhs_cs);
                        let packed_rhs = self.packed_rhs.wrapping_offset(iter as isize * self.rhs_rs);

                        load::<$mr_div_n>(self.lhs, packed_lhs);

                        seq_macro::seq!(N_ITER0 in 0..$nr_div_n {{
                            *self.rhs = *(packed_rhs.wrapping_offset(N_ITER0 * $n) as *const Pack);

                            seq_macro::seq!(N_ITER1 in 0..$n {{
                                const N_ITER: usize = N_ITER0 * $n + N_ITER1;
                                let accum = self.accum.add(N_ITER * $mr_div_n);
                                seq_macro::seq!(M_ITER in 0..$mr_div_n {{
                                    let accum = &mut *accum.add(M_ITER);
                                    *accum = mul_add_lane::<N_ITER1>(
                                        *self.lhs.add(M_ITER),
                                        *self.rhs,
                                        *accum,
                                        );
                                }});
                            }});
                        }});
                    }
                )?
            }

            let k_unroll = k / $unroll;
            let k_leftover = k % $unroll;

            // The outer `loop` runs once; it exists so the optional lane path
            // can `break` out before the generic path is reached.
            let mut main_loop = {
                #[inline(always)]
                || {
                    loop {
                        $(
                        let _ = $nr_div_n;
                        if rhs_cs == 1 {
                            let mut depth = k_unroll;
                            if depth != 0 {
                                loop {
                                    let iter = KernelIter {
                                        packed_lhs,
                                        next_lhs,
                                        packed_rhs,
                                        lhs_cs,
                                        rhs_rs,
                                        rhs_cs,
                                        accum,
                                        lhs: lhs.as_mut_ptr() as _,
                                        rhs: &mut rhs as *mut _ as _,
                                    };

                                    seq_macro::seq!(UNROLL_ITER in 0..$unroll {{
                                        iter.execute_neon(UNROLL_ITER);
                                    }});

                                    packed_lhs = packed_lhs.wrapping_offset($unroll * lhs_cs);
                                    packed_rhs = packed_rhs.wrapping_offset($unroll * rhs_rs);
                                    next_lhs = next_lhs.wrapping_offset($unroll * lhs_cs);

                                    depth -= 1;
                                    if depth == 0 {
                                        break;
                                    }
                                }
                            }
                            depth = k_leftover;
                            if depth != 0 {
                                loop {
                                    KernelIter {
                                        packed_lhs,
                                        next_lhs,
                                        packed_rhs,
                                        lhs_cs,
                                        rhs_rs,
                                        rhs_cs,
                                        accum,
                                        lhs: lhs.as_mut_ptr() as _,
                                        rhs: &mut rhs as *mut _ as _,
                                    }
                                    .execute_neon(0);

                                    packed_lhs = packed_lhs.wrapping_offset(lhs_cs);
                                    packed_rhs = packed_rhs.wrapping_offset(rhs_rs);
                                    next_lhs = next_lhs.wrapping_offset(lhs_cs);

                                    depth -= 1;
                                    if depth == 0 {
                                        break;
                                    }
                                }
                            }
                            break;
                        }
                        )?

                        let mut depth = k_unroll;
                        if depth != 0 {
                            loop {
                                let iter = KernelIter {
                                    packed_lhs,
                                    next_lhs,
                                    packed_rhs,
                                    lhs_cs,
                                    rhs_rs,
                                    rhs_cs,
                                    accum,
                                    lhs: lhs.as_mut_ptr() as _,
                                    rhs: &mut rhs as *mut _ as _,
                                };

                                seq_macro::seq!(UNROLL_ITER in 0..$unroll {{
                                    iter.execute(UNROLL_ITER);
                                }});

                                packed_lhs = packed_lhs.wrapping_offset($unroll * lhs_cs);
                                packed_rhs = packed_rhs.wrapping_offset($unroll * rhs_rs);
                                next_lhs = next_lhs.wrapping_offset($unroll * lhs_cs);

                                depth -= 1;
                                if depth == 0 {
                                    break;
                                }
                            }
                        }
                        depth = k_leftover;
                        if depth != 0 {
                            loop {
                                KernelIter {
                                    packed_lhs,
                                    next_lhs,
                                    packed_rhs,
                                    lhs_cs,
                                    rhs_rs,
                                    rhs_cs,
                                    accum,
                                    lhs: lhs.as_mut_ptr() as _,
                                    rhs: &mut rhs as *mut _ as _,
                                }
                                .execute(0);

                                packed_lhs = packed_lhs.wrapping_offset(lhs_cs);
                                packed_rhs = packed_rhs.wrapping_offset(rhs_rs);
                                next_lhs = next_lhs.wrapping_offset(lhs_cs);

                                depth -= 1;
                                if depth == 0 {
                                    break;
                                }
                            }
                        }
                        break;
                    }
                }
            };

            // Both arms run the same closure.
            // NOTE(review): presumably this lets the optimiser specialise the
            // inlined loop for `rhs_rs == 1` — confirm before simplifying.
            if rhs_rs == 1 {
                main_loop();
            } else {
                main_loop();
            }

            if m == $mr_div_n * N && n == $nr && dst_rs == 1  {
                // Full tile with contiguous columns: merge vector by vector.
                // `dst` carries no alignment guarantee for `Pack`, so it is
                // both read and written with the unaligned accessors.
                let alpha = splat(alpha);
                let beta = splat(beta);
                if alpha_status == 2 {
                    seq_macro::seq!(N_ITER in 0..$nr {{
                        seq_macro::seq!(M_ITER in 0..$mr_div_n {{
                            let dst = dst.offset(M_ITER * N as isize + N_ITER * dst_cs) as *mut Pack;
                            dst.write_unaligned(add(
                                    mul(alpha, dst.read_unaligned()),
                                    mul(beta, *accum.offset(M_ITER + $mr_div_n * N_ITER)),
                                    ));
                        }});
                    }});
                } else if alpha_status == 1 {
                    seq_macro::seq!(N_ITER in 0..$nr {{
                        seq_macro::seq!(M_ITER in 0..$mr_div_n {{
                            let dst = dst.offset(M_ITER * N as isize + N_ITER * dst_cs) as *mut Pack;
                            dst.write_unaligned(mul_add(
                                    beta,
                                    *accum.offset(M_ITER + $mr_div_n * N_ITER),
                                    dst.read_unaligned(),
                                    ));
                        }});
                    }});
                } else {
                    seq_macro::seq!(N_ITER in 0..$nr {{
                        seq_macro::seq!(M_ITER in 0..$mr_div_n {{
                            let dst = dst.offset(M_ITER * N as isize + N_ITER * dst_cs) as *mut Pack;
                            dst.write_unaligned(mul(beta, *accum.offset(M_ITER + $mr_div_n * N_ITER)));
                        }});
                    }});
                }
            } else {
                // Partial or strided tile: merge element by element from a
                // copy of the accumulator tile viewed as scalars.
                let src = accum_storage; // write to stack
                let src = src.as_ptr() as *const T;

                if alpha_status == 2 {
                    for j in 0..n {
                        let dst_j = dst.offset(dst_cs * j as isize);
                        let src_j = src.add(j * $mr_div_n * N);

                        for i in 0..m {
                            let dst_ij = dst_j.offset(dst_rs * i as isize);
                            let src_ij = src_j.add(i);

                            *dst_ij = scalar_add(scalar_mul(alpha, *dst_ij), scalar_mul(beta, *src_ij));
                        }
                    }
                } else if alpha_status == 1 {
                    for j in 0..n {
                        let dst_j = dst.offset(dst_cs * j as isize);
                        let src_j = src.add(j * $mr_div_n * N);

                        for i in 0..m {
                            let dst_ij = dst_j.offset(dst_rs * i as isize);
                            let src_ij = src_j.add(i);

                            *dst_ij = scalar_mul_add(beta, *src_ij, *dst_ij);
                        }
                    }
                } else {
                    for j in 0..n {
                        let dst_j = dst.offset(dst_cs * j as isize);
                        let src_j = src.add(j * $mr_div_n * N);

                        for i in 0..m {
                            let dst_ij = dst_j.offset(dst_rs * i as isize);
                            let src_ij = src_j.add(i);

                            *dst_ij = scalar_mul(beta, *src_ij);
                        }
                    }
                }
            }

        }
    };
}
710
711#[macro_export]
712macro_rules! microkernel_cplx_2step {
713    ($([$target: tt])?, $unroll: tt, $name: ident, $mr_div_n: tt, $nr: tt) => {
714        $(#[target_feature(enable = $target)])?
715        // 0, 1, or 2 for generic alpha
716        pub unsafe fn $name(
717            m: usize,
718            n: usize,
719            k: usize,
720            dst: *mut num_complex::Complex<T>,
721            mut packed_lhs: *const num_complex::Complex<T>,
722            mut packed_rhs: *const num_complex::Complex<T>,
723            dst_cs: isize,
724            dst_rs: isize,
725            lhs_cs: isize,
726            rhs_rs: isize,
727            rhs_cs: isize,
728            alpha: num_complex::Complex<T>,
729            beta: num_complex::Complex<T>,
730            alpha_status: u8,
731            conj_dst: bool,
732            conj_lhs: bool,
733            conj_rhs: bool,
734            mut next_lhs: *const num_complex::Complex<T>,
735        ) {
736            let mut accum_storage = [[splat(0.0); $mr_div_n]; $nr];
737            let accum = accum_storage.as_mut_ptr() as *mut Pack;
738
739            let (neg_conj_rhs, conj_all, neg_all) = match (conj_lhs, conj_rhs) {
740                (true, true) => (true, false, true),
741                (false, true) => (false, true, false),
742                (true, false) => (false, false, false),
743                (false, false) => (true, true, true),
744            };
745
746            let mut lhs_re_im = [::core::mem::MaybeUninit::<Pack>::uninit(); $mr_div_n];
747            let mut lhs_im_re = [::core::mem::MaybeUninit::<Pack>::uninit(); $mr_div_n];
748            let mut rhs_re = ::core::mem::MaybeUninit::<Pack>::uninit();
749            let mut rhs_im = ::core::mem::MaybeUninit::<Pack>::uninit();
750
751            #[derive(Copy, Clone)]
               // Bundle of the raw pointers and strides that `execute` needs to process one depth
               // step of the two-pass complex accumulation. The type is `Copy`, so the same value
               // can be passed by value to `execute` several times (once per unrolled step).
752            struct KernelIter {
                   // Base of the LHS data; `execute` loads `$mr_div_n` `Pack`s starting here.
753                packed_lhs: *const num_complex::Complex<T>,
                   // Offset by `execute` exactly like `packed_lhs`, but never dereferenced there.
754                next_lhs: *const num_complex::Complex<T>,
                   // Base of the RHS data; `execute` reads `$nr` complex values starting here.
755                packed_rhs: *const num_complex::Complex<T>,
                   // Element offset applied to the LHS pointers per depth step (`iter * lhs_cs`).
756                lhs_cs: isize,
                   // Element offset applied to the RHS pointer per depth step (`iter * rhs_rs`).
757                rhs_rs: isize,
                   // Element offset between the `$nr` RHS values read within one depth step.
758                rhs_cs: isize,
                   // Accumulator packs, addressed as `accum[N_ITER * $mr_div_n + M_ITER]`.
759                accum: *mut Pack,
                   // Scratch: the `$mr_div_n` LHS packs exactly as loaded.
760                lhs_re_im: *mut Pack,
                   // Scratch: the same `$mr_div_n` LHS packs after `swap_re_im`.
761                lhs_im_re: *mut Pack,
                   // Scratch: broadcast (`splat`) of the real part of the current RHS value.
762                rhs_re: *mut Pack,
                   // Scratch: broadcast (`splat`) of the imaginary part of the current RHS value.
763                rhs_im: *mut Pack,
764            }
765
766            impl KernelIter {
                   // Accumulates the contribution of depth step `iter`, counted from the base
                   // pointers stored in `self`, into every accumulator pack.
                   //
                   // The work is done in two passes over all `$nr * $mr_div_n` accumulators:
                   //   1. `mul_add_cplx_step0` with the unswapped LHS packs and the RHS real part,
                   //   2. `mul_add_cplx_step1` with the swapped LHS packs and the RHS imaginary part.
                   // `neg_conj_rhs` is forwarded unchanged to both helpers.
                   // NOTE(review): the helpers are defined outside this block; presumably the two
                   // steps together form one complex multiply-add -- confirm against their definitions.
                   //
                   // Safety: nothing is checked here. Every pointer dereferenced below (LHS packs,
                   // RHS values, accumulators and the four scratch slots) must be valid for the
                   // access performed through it.
767                #[inline(always)]
768                unsafe fn execute(self, iter: usize, neg_conj_rhs: bool) {
                       // `wrapping_offset` keeps the address arithmetic itself free of UB; only the
                       // dereferences further down require the resulting addresses to be valid.
769                    let packed_lhs = self.packed_lhs.wrapping_offset(iter as isize * self.lhs_cs);
770                    let packed_rhs = self.packed_rhs.wrapping_offset(iter as isize * self.rhs_rs);
771                    let next_lhs = self.next_lhs.wrapping_offset(iter as isize * self.lhs_cs);
772
                       // Load the LHS once per depth step: consecutive packs are `CPLX_N` complex
                       // elements apart. Keep both the loaded pack and its `swap_re_im` image.
773                    seq_macro::seq!(M_ITER in 0..$mr_div_n {{
774                        let tmp = *(packed_lhs.add(M_ITER * CPLX_N) as *const Pack);
775                        *self.lhs_re_im.add(M_ITER) = tmp;
776                        *self.lhs_im_re.add(M_ITER) = swap_re_im(tmp);
777                    }});
778
                       // Pass 1: for each of the `$nr` RHS values, broadcast its real part and update
                       // the `$mr_div_n` accumulators of that group.
779                    seq_macro::seq!(N_ITER in 0..$nr {{
780                        *self.rhs_re = splat((*packed_rhs.wrapping_offset(N_ITER * self.rhs_cs)).re);
781
                           // First accumulator of group `N_ITER`.
782                        let accum = self.accum.add(N_ITER * $mr_div_n);
783                        seq_macro::seq!(M_ITER in 0..$mr_div_n {{
784                            let accum = &mut *accum.add(M_ITER);
785                            *accum = mul_add_cplx_step0(
786                                *self.lhs_re_im.add(M_ITER),
787                                *self.rhs_re,
788                                *accum,
789                                neg_conj_rhs,
790                                );
791                        }});
792                    }});
793
                       // Pass 2: same traversal, now with the imaginary part of each RHS value and
                       // the swapped LHS packs.
794                    seq_macro::seq!(N_ITER in 0..$nr {{
795                        *self.rhs_im = splat((*packed_rhs.wrapping_offset(N_ITER * self.rhs_cs)).im);
796
797                        let accum = self.accum.add(N_ITER * $mr_div_n);
798                        seq_macro::seq!(M_ITER in 0..$mr_div_n {{
799                            let accum = &mut *accum.add(M_ITER);
800                            *accum = mul_add_cplx_step1(
801                                *self.lhs_im_re.add(M_ITER),
802                                *self.rhs_im,
803                                *accum,
804                                neg_conj_rhs,
805                                );
806                        }});
807                    }});
808
                       // `next_lhs` is never read in this function; this binding only marks it as used.
809                    let _ = next_lhs;
810                }
811            }
812
813            let k_unroll = k / $unroll;
814            let k_leftover = k % $unroll;
815
816            let mut main_loop = {
817                #[inline(always)]
818                || {
819                    loop {
820                        if neg_conj_rhs {
821                            let mut depth = k_unroll;
822                            if depth != 0 {
823                                loop {
824                                    let iter = KernelIter {
825                                        packed_lhs,
826                                        next_lhs,
827                                        packed_rhs,
828                                        lhs_cs,
829                                        rhs_rs,
830                                        rhs_cs,
831                                        accum,
832                                        lhs_re_im: lhs_re_im.as_mut_ptr() as _,
833                                        lhs_im_re: lhs_im_re.as_mut_ptr() as _,
834                                        rhs_re: &mut rhs_re as *mut _ as _,
835                                        rhs_im: &mut rhs_im as *mut _ as _,
836                                    };
837
838                                    seq_macro::seq!(UNROLL_ITER in 0..$unroll {{
839                                        iter.execute(UNROLL_ITER, true);
840                                    }});
841
842                                    packed_lhs = packed_lhs.wrapping_offset($unroll * lhs_cs);
843                                    packed_rhs = packed_rhs.wrapping_offset($unroll * rhs_rs);
844                                    next_lhs = next_lhs.wrapping_offset($unroll * lhs_cs);
845
846                                    depth -= 1;
847                                    if depth == 0 {
848                                        break;
849                                    }
850                                }
851                            }
852                            depth = k_leftover;
853                            if depth != 0 {
854                                loop {
855                                    KernelIter {
856                                        packed_lhs,
857                                        next_lhs,
858                                        packed_rhs,
859                                        lhs_cs,
860                                        rhs_rs,
861                                        rhs_cs,
862                                        accum,
863                                        lhs_re_im: lhs_re_im.as_mut_ptr() as _,
864                                        lhs_im_re: lhs_im_re.as_mut_ptr() as _,
865                                        rhs_re: &mut rhs_re as *mut _ as _,
866                                        rhs_im: &mut rhs_im as *mut _ as _,
867                                    }
868                                    .execute(0, true);
869
870                                    packed_lhs = packed_lhs.wrapping_offset(lhs_cs);
871                                    packed_rhs = packed_rhs.wrapping_offset(rhs_rs);
872                                    next_lhs = next_lhs.wrapping_offset(lhs_cs);
873
874                                    depth -= 1;
875                                    if depth == 0 {
876                                        break;
877                                    }
878                                }
879                            }
880                            break;
881                        } else {
882                            let mut depth = k_unroll;
883                            if depth != 0 {
884                                loop {
885                                    let iter = KernelIter {
886                                        next_lhs,
887                                        packed_lhs,
888                                        packed_rhs,
889                                        lhs_cs,
890                                        rhs_rs,
891                                        rhs_cs,
892                                        accum,
893                                        lhs_re_im: lhs_re_im.as_mut_ptr() as _,
894                                        lhs_im_re: lhs_im_re.as_mut_ptr() as _,
895                                        rhs_re: &mut rhs_re as *mut _ as _,
896                                        rhs_im: &mut rhs_im as *mut _ as _,
897                                    };
898
899                                    seq_macro::seq!(UNROLL_ITER in 0..$unroll {{
900                                        iter.execute(UNROLL_ITER, false);
901                                    }});
902
903                                    packed_lhs = packed_lhs.wrapping_offset($unroll * lhs_cs);
904                                    packed_rhs = packed_rhs.wrapping_offset($unroll * rhs_rs);
905                                    next_lhs = next_lhs.wrapping_offset($unroll * lhs_cs);
906
907                                    depth -= 1;
908                                    if depth == 0 {
909                                        break;
910                                    }
911                                }
912                            }
913                            depth = k_leftover;
914                            if depth != 0 {
915                                loop {
916                                    KernelIter {
917                                        next_lhs,
918                                        packed_lhs,
919                                        packed_rhs,
920                                        lhs_cs,
921                                        rhs_rs,
922                                        rhs_cs,
923                                        accum,
924                                        lhs_re_im: lhs_re_im.as_mut_ptr() as _,
925                                        lhs_im_re: lhs_im_re.as_mut_ptr() as _,
926                                        rhs_re: &mut rhs_re as *mut _ as _,
927                                        rhs_im: &mut rhs_im as *mut _ as _,
928                                    }
929                                    .execute(0, false);
930
931                                    packed_lhs = packed_lhs.wrapping_offset(lhs_cs);
932                                    packed_rhs = packed_rhs.wrapping_offset(rhs_rs);
933                                    next_lhs = next_lhs.wrapping_offset(lhs_cs);
934
935                                    depth -= 1;
936                                    if depth == 0 {
937                                        break;
938                                    }
939                                }
940                            }
941                            break;
942                        }
943                    }
944                }
945            };
946
947            if rhs_rs == 1 {
948                main_loop();
949            } else {
950                main_loop();
951            }
952
953            if conj_all && neg_all {
954                seq_macro::seq!(N_ITER in 0..$nr {{
955                    let accum = accum.add(N_ITER * $mr_div_n);
956                    seq_macro::seq!(M_ITER in 0..$mr_div_n {{
957                        let accum = &mut *accum.add(M_ITER);
958                        *accum = neg_conj(*accum);
959                    }});
960                }});
961            } else if !conj_all && neg_all {
962                seq_macro::seq!(N_ITER in 0..$nr {{
963                    let accum = accum.add(N_ITER * $mr_div_n);
964                    seq_macro::seq!(M_ITER in 0..$mr_div_n {{
965                        let accum = &mut *accum.add(M_ITER);
966                        *accum = neg(*accum);
967                    }});
968                }});
969            } else if conj_all && !neg_all {
970                seq_macro::seq!(N_ITER in 0..$nr {{
971                    let accum = accum.add(N_ITER * $mr_div_n);
972                    seq_macro::seq!(M_ITER in 0..$mr_div_n {{
973                        let accum = &mut *accum.add(M_ITER);
974                        *accum = conj(*accum);
975                    }});
976                }});
977            }
978
979            if m == $mr_div_n * CPLX_N && n == $nr && dst_rs == 1 {
980                let alpha_re = splat(alpha.re);
981                let alpha_im = splat(alpha.im);
982                let beta_re = splat(beta.re);
983                let beta_im = splat(beta.im);
984
985                if conj_dst {
986                    if alpha_status == 2 {
987                        seq_macro::seq!(N_ITER in 0..$nr {{
988                            seq_macro::seq!(M_ITER in 0..$mr_div_n {{
989                                let dst = dst.offset(M_ITER * CPLX_N as isize + N_ITER * dst_cs) as *mut Pack;
990                                let accum = *accum.offset(M_ITER + $mr_div_n * N_ITER);
991                                *dst = add(
992                                    mul_cplx(conj(*dst), swap_re_im(conj(*dst)), alpha_re, alpha_im),
993                                    mul_cplx(accum, swap_re_im(accum), beta_re, beta_im),
994                                    );
995                            }});
996                        }});
997                    } else if alpha_status == 1 {
998                        seq_macro::seq!(N_ITER in 0..$nr {{
999                            seq_macro::seq!(M_ITER in 0..$mr_div_n {{
1000                                let dst = dst.offset(M_ITER * CPLX_N as isize + N_ITER * dst_cs) as *mut Pack;
1001                                let accum = *accum.offset(M_ITER + $mr_div_n * N_ITER);
1002                                *dst = add(
1003                                    conj(*dst),
1004                                    mul_cplx(accum, swap_re_im(accum), beta_re, beta_im),
1005                                    );
1006                            }});
1007                        }});
1008                    } else {
1009                        seq_macro::seq!(N_ITER in 0..$nr {{
1010                            seq_macro::seq!(M_ITER in 0..$mr_div_n {{
1011                                let dst = dst.offset(M_ITER * CPLX_N as isize + N_ITER * dst_cs) as *mut Pack;
1012                                let accum = *accum.offset(M_ITER + $mr_div_n * N_ITER);
1013                                *dst = mul_cplx(accum, swap_re_im(accum), beta_re, beta_im);
1014                            }});
1015                        }});
1016                    }
1017                } else {
1018                    if alpha_status == 2 {
1019                        seq_macro::seq!(N_ITER in 0..$nr {{
1020                            seq_macro::seq!(M_ITER in 0..$mr_div_n {{
1021                                let dst = dst.offset(M_ITER * CPLX_N as isize + N_ITER * dst_cs) as *mut Pack;
1022                                let accum = *accum.offset(M_ITER + $mr_div_n * N_ITER);
1023                                *dst = add(
1024                                    mul_cplx(*dst, swap_re_im(*dst), alpha_re, alpha_im),
1025                                    mul_cplx(accum, swap_re_im(accum), beta_re, beta_im),
1026                                );
1027                            }});
1028                        }});
1029                    } else if alpha_status == 1 {
1030                        seq_macro::seq!(N_ITER in 0..$nr {{
1031                            seq_macro::seq!(M_ITER in 0..$mr_div_n {{
1032                                let dst = dst.offset(M_ITER * CPLX_N as isize + N_ITER * dst_cs) as *mut Pack;
1033                                let accum = *accum.offset(M_ITER + $mr_div_n * N_ITER);
1034                                *dst = add(
1035                                    *dst,
1036                                    mul_cplx(accum, swap_re_im(accum), beta_re, beta_im),
1037                                );
1038                            }});
1039                        }});
1040                    } else {
1041                        seq_macro::seq!(N_ITER in 0..$nr {{
1042                            seq_macro::seq!(M_ITER in 0..$mr_div_n {{
1043                                let dst = dst.offset(M_ITER * CPLX_N as isize + N_ITER * dst_cs) as *mut Pack;
1044                                let accum = *accum.offset(M_ITER + $mr_div_n * N_ITER);
1045                                *dst = mul_cplx(accum, swap_re_im(accum), beta_re, beta_im);
1046                            }});
1047                        }});
1048                    }
1049                }
1050            } else {
1051                let src = accum_storage; // write to stack
1052                let src = src.as_ptr() as *const num_complex::Complex<T>;
1053
1054                if conj_dst {
1055                    if alpha_status == 2 {
1056                        for j in 0..n {
1057                            let dst_j = dst.offset(dst_cs * j as isize);
1058                            let src_j = src.add(j * $mr_div_n * CPLX_N);
1059
1060                            for i in 0..m {
1061                                let dst_ij = dst_j.offset(dst_rs * i as isize);
1062                                let src_ij = src_j.add(i);
1063
1064                                *dst_ij = alpha * (*dst_ij).conj() + beta * *src_ij;
1065                            }
1066                        }
1067                    } else if alpha_status == 1 {
1068                        for j in 0..n {
1069                            let dst_j = dst.offset(dst_cs * j as isize);
1070                            let src_j = src.add(j * $mr_div_n * CPLX_N);
1071
1072                            for i in 0..m {
1073                                let dst_ij = dst_j.offset(dst_rs * i as isize);
1074                                let src_ij = src_j.add(i);
1075
1076                                *dst_ij = (*dst_ij).conj() + beta * *src_ij;
1077                            }
1078                        }
1079                    } else {
1080                        for j in 0..n {
1081                            let dst_j = dst.offset(dst_cs * j as isize);
1082                            let src_j = src.add(j * $mr_div_n * CPLX_N);
1083
1084                            for i in 0..m {
1085                                let dst_ij = dst_j.offset(dst_rs * i as isize);
1086                                let src_ij = src_j.add(i);
1087
1088                                *dst_ij = beta * *src_ij;
1089                            }
1090                        }
1091                    }
1092                } else {
1093                    if alpha_status == 2 {
1094                        for j in 0..n {
1095                            let dst_j = dst.offset(dst_cs * j as isize);
1096                            let src_j = src.add(j * $mr_div_n * CPLX_N);
1097
1098                            for i in 0..m {
1099                                let dst_ij = dst_j.offset(dst_rs * i as isize);
1100                                let src_ij = src_j.add(i);
1101
1102                                *dst_ij = alpha * *dst_ij + beta * *src_ij;
1103                            }
1104                        }
1105                    } else if alpha_status == 1 {
1106                        for j in 0..n {
1107                            let dst_j = dst.offset(dst_cs * j as isize);
1108                            let src_j = src.add(j * $mr_div_n * CPLX_N);
1109
1110                            for i in 0..m {
1111                                let dst_ij = dst_j.offset(dst_rs * i as isize);
1112                                let src_ij = src_j.add(i);
1113
1114                                *dst_ij = *dst_ij + beta * *src_ij;
1115                            }
1116                        }
1117                    } else {
1118                        for j in 0..n {
1119                            let dst_j = dst.offset(dst_cs * j as isize);
1120                            let src_j = src.add(j * $mr_div_n * CPLX_N);
1121
1122                            for i in 0..m {
1123                                let dst_ij = dst_j.offset(dst_rs * i as isize);
1124                                let src_ij = src_j.add(i);
1125
1126                                *dst_ij = beta * *src_ij;
1127                            }
1128                        }
1129                    }
1130                }
1131            }
1132        }
1133    };
1134}
1135
1136#[macro_export]
1137macro_rules! microkernel_cplx {
1138    ($([$target: tt])?, $unroll: tt, $name: ident, $mr_div_n: tt, $nr: tt) => {
1139        $(#[target_feature(enable = $target)])?
1140        // `alpha_status`: 0 => `dst` is overwritten, 1 => `dst` is added unscaled, 2 => generic `alpha` scales `dst`
1141        pub unsafe fn $name(
1142            m: usize,
1143            n: usize,
1144            k: usize,
1145            dst: *mut num_complex::Complex<T>,
1146            mut packed_lhs: *const num_complex::Complex<T>,
1147            mut packed_rhs: *const num_complex::Complex<T>,
1148            dst_cs: isize,
1149            dst_rs: isize,
1150            lhs_cs: isize,
1151            rhs_rs: isize,
1152            rhs_cs: isize,
1153            alpha: num_complex::Complex<T>,
1154            beta: num_complex::Complex<T>,
1155            alpha_status: u8,
1156            conj_dst: bool,
1157            conj_lhs: bool,
1158            conj_rhs: bool,
1159            mut next_lhs: *const num_complex::Complex<T>,
1160        ) {
1161            let mut accum_storage = [[splat(0.0); $mr_div_n]; $nr];
1162            let accum = accum_storage.as_mut_ptr() as *mut Pack;
1163
1164            let conj_both_lhs_rhs = conj_lhs;
1165            let conj_rhs = conj_lhs != conj_rhs;
1166
1167            let mut lhs_re_im = [::core::mem::MaybeUninit::<Pack>::uninit(); $mr_div_n];
1168            let mut lhs_im_re = [::core::mem::MaybeUninit::<Pack>::uninit(); $mr_div_n];
1169            let mut rhs_re = ::core::mem::MaybeUninit::<Pack>::uninit();
1170            let mut rhs_im = ::core::mem::MaybeUninit::<Pack>::uninit();
1171
1172            #[derive(Copy, Clone)]
                // Bundle of the raw pointers and strides that `execute` needs to process one depth
                // step of the complex accumulation. The type is `Copy`, so the same value can be
                // passed by value to `execute` several times (once per unrolled step).
1173            struct KernelIter {
                    // Base of the LHS data; `execute` loads `$mr_div_n` `Pack`s starting here.
1174                packed_lhs: *const num_complex::Complex<T>,
                    // Offset by `execute` exactly like `packed_lhs`, but never dereferenced there.
1175                next_lhs: *const num_complex::Complex<T>,
                    // Base of the RHS data; `execute` reads `$nr` complex values starting here.
1176                packed_rhs: *const num_complex::Complex<T>,
                    // Element offset applied to the LHS pointers per depth step (`iter * lhs_cs`).
1177                lhs_cs: isize,
                    // Element offset applied to the RHS pointer per depth step (`iter * rhs_rs`).
1178                rhs_rs: isize,
                    // Element offset between the `$nr` RHS values read within one depth step.
1179                rhs_cs: isize,
                    // Accumulator packs, addressed as `accum[N_ITER * $mr_div_n + M_ITER]`.
1180                accum: *mut Pack,
                    // Scratch: the `$mr_div_n` LHS packs exactly as loaded.
1181                lhs_re_im: *mut Pack,
                    // Scratch: the same `$mr_div_n` LHS packs after `swap_re_im`.
1182                lhs_im_re: *mut Pack,
                    // Scratch: broadcast (`splat`) of the real part of the current RHS value.
1183                rhs_re: *mut Pack,
                    // Scratch: broadcast (`splat`) of the imaginary part of the current RHS value.
1184                rhs_im: *mut Pack,
1185            }
1186
1187            impl KernelIter {
                    // Accumulates the contribution of depth step `iter`, counted from the base
                    // pointers stored in `self`, into every accumulator pack.
                    //
                    // Each of the `$nr * $mr_div_n` accumulators is updated once through
                    // `mul_add_cplx`, which receives both LHS forms (as loaded and swapped), both
                    // broadcast RHS parts, the current accumulator and `conj_rhs` unchanged.
                    // NOTE(review): `mul_add_cplx` is defined outside this block; presumably it is a
                    // complex multiply-add with optional RHS conjugation -- confirm.
                    //
                    // Safety: nothing is checked here. Every pointer dereferenced below (LHS packs,
                    // RHS values, accumulators and the four scratch slots) must be valid for the
                    // access performed through it.
1188                #[inline(always)]
1189                unsafe fn execute(self, iter: usize, conj_rhs: bool) {
                        // `wrapping_offset` keeps the address arithmetic itself free of UB; only the
                        // dereferences further down require the resulting addresses to be valid.
1190                    let packed_lhs = self.packed_lhs.wrapping_offset(iter as isize * self.lhs_cs);
1191                    let packed_rhs = self.packed_rhs.wrapping_offset(iter as isize * self.rhs_rs);
1192                    let next_lhs = self.next_lhs.wrapping_offset(iter as isize * self.lhs_cs);
1193
                        // Load the LHS once per depth step: consecutive packs are `CPLX_N` complex
                        // elements apart. Keep both the loaded pack and its `swap_re_im` image.
1194                    seq_macro::seq!(M_ITER in 0..$mr_div_n {{
1195                        let tmp = *(packed_lhs.add(M_ITER * CPLX_N) as *const Pack);
1196                        *self.lhs_re_im.add(M_ITER) = tmp;
1197                        *self.lhs_im_re.add(M_ITER) = swap_re_im(tmp);
1198                    }});
1199
                        // For each of the `$nr` RHS values: broadcast its real and imaginary parts,
                        // then update the `$mr_div_n` accumulators of that group.
1200                    seq_macro::seq!(N_ITER in 0..$nr {{
1201                        *self.rhs_re = splat((*packed_rhs.wrapping_offset(N_ITER * self.rhs_cs)).re);
1202                        *self.rhs_im = splat((*packed_rhs.wrapping_offset(N_ITER * self.rhs_cs)).im);
1203
                            // First accumulator of group `N_ITER`.
1204                        let accum = self.accum.add(N_ITER * $mr_div_n);
1205                        seq_macro::seq!(M_ITER in 0..$mr_div_n {{
1206                            let accum = &mut *accum.add(M_ITER);
1207                            *accum = mul_add_cplx(
1208                                *self.lhs_re_im.add(M_ITER),
1209                                *self.lhs_im_re.add(M_ITER),
1210                                *self.rhs_re,
1211                                *self.rhs_im,
1212                                *accum,
1213                                conj_rhs,
1214                                );
1215                        }});
1216                    }});
1217
                        // `next_lhs` is never read in this function; this binding only marks it as used.
1218                    let _ = next_lhs;
1219                }
1220            }
1221
1222            let k_unroll = k / $unroll;
1223            let k_leftover = k % $unroll;
1224
1225            loop {
1226                if conj_rhs {
1227                    let mut depth = k_unroll;
1228                    if depth != 0 {
1229                        loop {
1230                            let iter = KernelIter {
1231                                packed_lhs,
1232                                next_lhs,
1233                                packed_rhs,
1234                                lhs_cs,
1235                                rhs_rs,
1236                                rhs_cs,
1237                                accum,
1238                                lhs_re_im: lhs_re_im.as_mut_ptr() as _,
1239                                lhs_im_re: lhs_im_re.as_mut_ptr() as _,
1240                                rhs_re: &mut rhs_re as *mut _ as _,
1241                                rhs_im: &mut rhs_im as *mut _ as _,
1242                            };
1243
1244                            seq_macro::seq!(UNROLL_ITER in 0..$unroll {{
1245                                iter.execute(UNROLL_ITER, true);
1246                            }});
1247
1248                            packed_lhs = packed_lhs.wrapping_offset($unroll * lhs_cs);
1249                            packed_rhs = packed_rhs.wrapping_offset($unroll * rhs_rs);
1250                            next_lhs = next_lhs.wrapping_offset($unroll * lhs_cs);
1251
1252                            depth -= 1;
1253                            if depth == 0 {
1254                                break;
1255                            }
1256                        }
1257                    }
1258                    depth = k_leftover;
1259                    if depth != 0 {
1260                        loop {
1261                            KernelIter {
1262                                packed_lhs,
1263                                next_lhs,
1264                                packed_rhs,
1265                                lhs_cs,
1266                                rhs_rs,
1267                                rhs_cs,
1268                                accum,
1269                                lhs_re_im: lhs_re_im.as_mut_ptr() as _,
1270                                lhs_im_re: lhs_im_re.as_mut_ptr() as _,
1271                                rhs_re: &mut rhs_re as *mut _ as _,
1272                                rhs_im: &mut rhs_im as *mut _ as _,
1273                            }
1274                            .execute(0, true);
1275
1276                            packed_lhs = packed_lhs.wrapping_offset(lhs_cs);
1277                            packed_rhs = packed_rhs.wrapping_offset(rhs_rs);
1278                            next_lhs = next_lhs.wrapping_offset(lhs_cs);
1279
1280                            depth -= 1;
1281                            if depth == 0 {
1282                                break;
1283                            }
1284                        }
1285                    }
1286                    break;
1287                } else {
1288                    let mut depth = k_unroll;
1289                    if depth != 0 {
1290                        loop {
1291                            let iter = KernelIter {
1292                                next_lhs,
1293                                packed_lhs,
1294                                packed_rhs,
1295                                lhs_cs,
1296                                rhs_rs,
1297                                rhs_cs,
1298                                accum,
1299                                lhs_re_im: lhs_re_im.as_mut_ptr() as _,
1300                                lhs_im_re: lhs_im_re.as_mut_ptr() as _,
1301                                rhs_re: &mut rhs_re as *mut _ as _,
1302                                rhs_im: &mut rhs_im as *mut _ as _,
1303                            };
1304
1305                            seq_macro::seq!(UNROLL_ITER in 0..$unroll {{
1306                                iter.execute(UNROLL_ITER, false);
1307                            }});
1308
1309                            packed_lhs = packed_lhs.wrapping_offset($unroll * lhs_cs);
1310                            packed_rhs = packed_rhs.wrapping_offset($unroll * rhs_rs);
1311                            next_lhs = next_lhs.wrapping_offset($unroll * lhs_cs);
1312
1313                            depth -= 1;
1314                            if depth == 0 {
1315                                break;
1316                            }
1317                        }
1318                    }
1319                    depth = k_leftover;
1320                    if depth != 0 {
1321                        loop {
1322                            KernelIter {
1323                                next_lhs,
1324                                packed_lhs,
1325                                packed_rhs,
1326                                lhs_cs,
1327                                rhs_rs,
1328                                rhs_cs,
1329                                accum,
1330                                lhs_re_im: lhs_re_im.as_mut_ptr() as _,
1331                                lhs_im_re: lhs_im_re.as_mut_ptr() as _,
1332                                rhs_re: &mut rhs_re as *mut _ as _,
1333                                rhs_im: &mut rhs_im as *mut _ as _,
1334                            }
1335                            .execute(0, false);
1336
1337                            packed_lhs = packed_lhs.wrapping_offset(lhs_cs);
1338                            packed_rhs = packed_rhs.wrapping_offset(rhs_rs);
1339                            next_lhs = next_lhs.wrapping_offset(lhs_cs);
1340
1341                            depth -= 1;
1342                            if depth == 0 {
1343                                break;
1344                            }
1345                        }
1346                    }
1347                    break;
1348                }
1349            }
1350
1351            if conj_both_lhs_rhs {
1352                seq_macro::seq!(N_ITER in 0..$nr {{
1353                    let accum = accum.add(N_ITER * $mr_div_n);
1354                    seq_macro::seq!(M_ITER in 0..$mr_div_n {{
1355                        let accum = &mut *accum.add(M_ITER);
1356                        *accum = conj(*accum);
1357                    }});
1358                }});
1359            }
1360
1361            if m == $mr_div_n * CPLX_N && n == $nr && dst_rs == 1 {
1362                let alpha_re = splat(alpha.re);
1363                let alpha_im = splat(alpha.im);
1364                let beta_re = splat(beta.re);
1365                let beta_im = splat(beta.im);
1366
1367                if conj_dst {
1368                    if alpha_status == 2 {
1369                        seq_macro::seq!(N_ITER in 0..$nr {{
1370                            seq_macro::seq!(M_ITER in 0..$mr_div_n {{
1371                                let dst = dst.offset(M_ITER * CPLX_N as isize + N_ITER * dst_cs) as *mut Pack;
1372                                let accum = *accum.offset(M_ITER + $mr_div_n * N_ITER);
1373                                *dst = add(
1374                                    mul_cplx(conj(*dst), swap_re_im(conj(*dst)), alpha_re, alpha_im),
1375                                    mul_cplx(accum, swap_re_im(accum), beta_re, beta_im),
1376                                    );
1377                            }});
1378                        }});
1379                    } else if alpha_status == 1 {
1380                        seq_macro::seq!(N_ITER in 0..$nr {{
1381                            seq_macro::seq!(M_ITER in 0..$mr_div_n {{
1382                                let dst = dst.offset(M_ITER * CPLX_N as isize + N_ITER * dst_cs) as *mut Pack;
1383                                let accum = *accum.offset(M_ITER + $mr_div_n * N_ITER);
1384                                *dst = add(
1385                                    conj(*dst),
1386                                    mul_cplx(accum, swap_re_im(accum), beta_re, beta_im),
1387                                    );
1388                            }});
1389                        }});
1390                    } else {
1391                        seq_macro::seq!(N_ITER in 0..$nr {{
1392                            seq_macro::seq!(M_ITER in 0..$mr_div_n {{
1393                                let dst = dst.offset(M_ITER * CPLX_N as isize + N_ITER * dst_cs) as *mut Pack;
1394                                let accum = *accum.offset(M_ITER + $mr_div_n * N_ITER);
1395                                *dst = mul_cplx(accum, swap_re_im(accum), beta_re, beta_im);
1396                            }});
1397                        }});
1398                    }
1399                } else {
1400                    if alpha_status == 2 {
1401                        seq_macro::seq!(N_ITER in 0..$nr {{
1402                            seq_macro::seq!(M_ITER in 0..$mr_div_n {{
1403                                let dst = dst.offset(M_ITER * CPLX_N as isize + N_ITER * dst_cs) as *mut Pack;
1404                                let accum = *accum.offset(M_ITER + $mr_div_n * N_ITER);
1405                                *dst = add(
1406                                    mul_cplx(*dst, swap_re_im(*dst), alpha_re, alpha_im),
1407                                    mul_cplx(accum, swap_re_im(accum), beta_re, beta_im),
1408                                );
1409                            }});
1410                        }});
1411                    } else if alpha_status == 1 {
1412                        seq_macro::seq!(N_ITER in 0..$nr {{
1413                            seq_macro::seq!(M_ITER in 0..$mr_div_n {{
1414                                let dst = dst.offset(M_ITER * CPLX_N as isize + N_ITER * dst_cs) as *mut Pack;
1415                                let accum = *accum.offset(M_ITER + $mr_div_n * N_ITER);
1416                                *dst = add(
1417                                    *dst,
1418                                    mul_cplx(accum, swap_re_im(accum), beta_re, beta_im),
1419                                );
1420                            }});
1421                        }});
1422                    } else {
1423                        seq_macro::seq!(N_ITER in 0..$nr {{
1424                            seq_macro::seq!(M_ITER in 0..$mr_div_n {{
1425                                let dst = dst.offset(M_ITER * CPLX_N as isize + N_ITER * dst_cs) as *mut Pack;
1426                                let accum = *accum.offset(M_ITER + $mr_div_n * N_ITER);
1427                                *dst = mul_cplx(accum, swap_re_im(accum), beta_re, beta_im);
1428                            }});
1429                        }});
1430                    }
1431                }
1432            } else {
1433                let src = accum_storage; // write to stack
1434                let src = src.as_ptr() as *const num_complex::Complex<T>;
1435
1436                if conj_dst {
1437                    if alpha_status == 2 {
1438                        for j in 0..n {
1439                            let dst_j = dst.offset(dst_cs * j as isize);
1440                            let src_j = src.add(j * $mr_div_n * CPLX_N);
1441
1442                            for i in 0..m {
1443                                let dst_ij = dst_j.offset(dst_rs * i as isize);
1444                                let src_ij = src_j.add(i);
1445
1446                                *dst_ij = alpha * (*dst_ij).conj() + beta * *src_ij;
1447                            }
1448                        }
1449                    } else if alpha_status == 1 {
1450                        for j in 0..n {
1451                            let dst_j = dst.offset(dst_cs * j as isize);
1452                            let src_j = src.add(j * $mr_div_n * CPLX_N);
1453
1454                            for i in 0..m {
1455                                let dst_ij = dst_j.offset(dst_rs * i as isize);
1456                                let src_ij = src_j.add(i);
1457
1458                                *dst_ij = (*dst_ij).conj() + beta * *src_ij;
1459                            }
1460                        }
1461                    } else {
1462                        for j in 0..n {
1463                            let dst_j = dst.offset(dst_cs * j as isize);
1464                            let src_j = src.add(j * $mr_div_n * CPLX_N);
1465
1466                            for i in 0..m {
1467                                let dst_ij = dst_j.offset(dst_rs * i as isize);
1468                                let src_ij = src_j.add(i);
1469
1470                                *dst_ij = beta * *src_ij;
1471                            }
1472                        }
1473                    }
1474                } else {
1475                    if alpha_status == 2 {
1476                        for j in 0..n {
1477                            let dst_j = dst.offset(dst_cs * j as isize);
1478                            let src_j = src.add(j * $mr_div_n * CPLX_N);
1479
1480                            for i in 0..m {
1481                                let dst_ij = dst_j.offset(dst_rs * i as isize);
1482                                let src_ij = src_j.add(i);
1483
1484                                *dst_ij = alpha * *dst_ij + beta * *src_ij;
1485                            }
1486                        }
1487                    } else if alpha_status == 1 {
1488                        for j in 0..n {
1489                            let dst_j = dst.offset(dst_cs * j as isize);
1490                            let src_j = src.add(j * $mr_div_n * CPLX_N);
1491
1492                            for i in 0..m {
1493                                let dst_ij = dst_j.offset(dst_rs * i as isize);
1494                                let src_ij = src_j.add(i);
1495
1496                                *dst_ij = *dst_ij + beta * *src_ij;
1497                            }
1498                        }
1499                    } else {
1500                        for j in 0..n {
1501                            let dst_j = dst.offset(dst_cs * j as isize);
1502                            let src_j = src.add(j * $mr_div_n * CPLX_N);
1503
1504                            for i in 0..m {
1505                                let dst_ij = dst_j.offset(dst_rs * i as isize);
1506                                let src_ij = src_j.add(i);
1507
1508                                *dst_ij = beta * *src_ij;
1509                            }
1510                        }
1511                    }
1512                }
1513            }
1514        }
1515    };
1516}
1517
/// Defines `$name`, an `unsafe` complex-valued microkernel over packed operands.
///
/// Macro arguments:
/// - `[$target]` (optional): feature string placed in `#[target_feature(enable = ...)]`
///   on the generated function.
/// - `$unroll`: number of depth steps executed per iteration of the unrolled loop.
/// - `$name`: identifier of the generated `pub unsafe fn`.
/// - `$mr_div_n`: number of `Pack` values per accumulator column.
/// - `$nr`: number of accumulator columns.
///
/// The expansion refers to `T`, `Pack`, `N`, `splat`, `add`, `conj`, `mul_cplx`,
/// `mul_add_cplx`, `num`-style `.conj()` on `T`, and `seq_macro::seq!` by bare name/path,
/// so all of them must resolve at the invocation site.
#[macro_export]
macro_rules! microkernel_cplx_packed {
    ($([$target: tt])?, $unroll: tt, $name: ident, $mr_div_n: tt, $nr: tt) => {
        $(#[target_feature(enable = $target)])?
        // Accumulates `k` depth steps of lhs * rhs into a local tile, then merges the tile
        // into `dst` according to `alpha_status`:
        //   2         => dst = alpha * dst + beta * tile
        //   1         => dst =         dst + beta * tile
        //   otherwise => dst =               beta * tile
        // When `conj_dst` is set, the previous `dst` value is conjugated before it is used
        // (cases 2 and 1 only; the overwrite case never reads `dst`).
        //
        // `m` / `n` bound the rows / columns written by the scalar fallback; the vectorised
        // store is taken only when `m == $mr_div_n * N`, `n == $nr` and `dst_rs == 1`.
        // `packed_lhs` moves by `lhs_cs` and `packed_rhs` by `rhs_rs` per depth step, while
        // `rhs_cs` separates the rhs values of consecutive columns. `next_lhs` is advanced in
        // lock-step with `packed_lhs`, but the computed address is not otherwise used here.
        pub unsafe fn $name(
            m: usize,
            n: usize,
            k: usize,
            dst: *mut T,
            mut packed_lhs: *const T,
            mut packed_rhs: *const T,
            dst_cs: isize,
            dst_rs: isize,
            lhs_cs: isize,
            rhs_rs: isize,
            rhs_cs: isize,
            alpha: T,
            beta: T,
            alpha_status: u8,
            conj_dst: bool,
            conj_lhs: bool,
            conj_rhs: bool,
            mut next_lhs: *const T,
        ) {
            // Zeroed accumulator tile: `$nr` columns holding `$mr_div_n` packs each.
            let mut accum_storage = [[core::mem::zeroed::<Pack>(); $mr_div_n]; $nr];
            let accum = accum_storage.as_mut_ptr() as *mut Pack;

            // A conjugated lhs is folded into the rhs flag plus one final conjugation of the
            // accumulators (conj(a) * conj(b) == conj(a * b), conj(a) * b == conj(a * conj(b))).
            // NOTE(review): this relies on the flag of `mul_add_cplx` conjugating its rhs
            // operand -- confirm against its definition.
            let conj_accum_at_end = conj_lhs;
            let conj_rhs_in_loop = conj_lhs != conj_rhs;

            // Scratch slots the depth step loads its operands into.
            let mut lhs_scratch = [::core::mem::MaybeUninit::<Pack>::uninit(); $mr_div_n];
            let mut rhs_scratch = ::core::mem::MaybeUninit::<Pack>::uninit();
            let lhs_ptr = lhs_scratch.as_mut_ptr() as *mut Pack;
            let rhs_ptr = &mut rhs_scratch as *mut _ as *mut Pack;

            // Snapshot of every pointer and stride one depth step needs.
            #[derive(Copy, Clone)]
            struct DepthStep {
                packed_lhs: *const T,
                next_lhs: *const T,
                packed_rhs: *const T,
                lhs_cs: isize,
                rhs_rs: isize,
                rhs_cs: isize,
                accum: *mut Pack,
                lhs: *mut Pack,
                rhs: *mut Pack,
            }

            impl DepthStep {
                // Performs the rank-1 update for the depth step located `step` strides past
                // the snapshot's base pointers.
                #[inline(always)]
                unsafe fn run(self, step: usize, conj_rhs: bool) {
                    let step = step as isize;
                    let lhs_src = self.packed_lhs.wrapping_offset(step * self.lhs_cs);
                    let rhs_src = self.packed_rhs.wrapping_offset(step * self.rhs_rs);
                    // Address is computed and then discarded.
                    let _ = self.next_lhs.wrapping_offset(step * self.lhs_cs);

                    // Load the lhs packs for this step.
                    seq_macro::seq!(M_ITER in 0..$mr_div_n {{
                        *self.lhs.add(M_ITER) = *(lhs_src.add(M_ITER * N) as *const Pack);
                    }});

                    // For each column: broadcast one rhs value and update that column.
                    seq_macro::seq!(N_ITER in 0..$nr {{
                        *self.rhs = splat(*rhs_src.wrapping_offset(N_ITER * self.rhs_cs));

                        let column = self.accum.add(N_ITER * $mr_div_n);
                        seq_macro::seq!(M_ITER in 0..$mr_div_n {{
                            let acc = &mut *column.add(M_ITER);
                            *acc = mul_add_cplx(*self.lhs.add(M_ITER), *self.rhs, *acc, conj_rhs);
                        }});
                    }});
                }
            }

            let unrolled_blocks = k / $unroll;
            let tail_steps = k % $unroll;

            // Both branches are the same apart from the literal flag handed to `run`
            // (presumably so the always-inlined step specialises on a constant).
            if conj_rhs_in_loop {
                for _ in 0..unrolled_blocks {
                    let block = DepthStep {
                        packed_lhs,
                        next_lhs,
                        packed_rhs,
                        lhs_cs,
                        rhs_rs,
                        rhs_cs,
                        accum,
                        lhs: lhs_ptr,
                        rhs: rhs_ptr,
                    };

                    seq_macro::seq!(UNROLL_ITER in 0..$unroll {{
                        block.run(UNROLL_ITER, true);
                    }});

                    packed_lhs = packed_lhs.wrapping_offset($unroll * lhs_cs);
                    packed_rhs = packed_rhs.wrapping_offset($unroll * rhs_rs);
                    next_lhs = next_lhs.wrapping_offset($unroll * lhs_cs);
                }

                for _ in 0..tail_steps {
                    DepthStep {
                        packed_lhs,
                        next_lhs,
                        packed_rhs,
                        lhs_cs,
                        rhs_rs,
                        rhs_cs,
                        accum,
                        lhs: lhs_ptr,
                        rhs: rhs_ptr,
                    }
                    .run(0, true);

                    packed_lhs = packed_lhs.wrapping_offset(lhs_cs);
                    packed_rhs = packed_rhs.wrapping_offset(rhs_rs);
                    next_lhs = next_lhs.wrapping_offset(lhs_cs);
                }
            } else {
                for _ in 0..unrolled_blocks {
                    let block = DepthStep {
                        packed_lhs,
                        next_lhs,
                        packed_rhs,
                        lhs_cs,
                        rhs_rs,
                        rhs_cs,
                        accum,
                        lhs: lhs_ptr,
                        rhs: rhs_ptr,
                    };

                    seq_macro::seq!(UNROLL_ITER in 0..$unroll {{
                        block.run(UNROLL_ITER, false);
                    }});

                    packed_lhs = packed_lhs.wrapping_offset($unroll * lhs_cs);
                    packed_rhs = packed_rhs.wrapping_offset($unroll * rhs_rs);
                    next_lhs = next_lhs.wrapping_offset($unroll * lhs_cs);
                }

                for _ in 0..tail_steps {
                    DepthStep {
                        packed_lhs,
                        next_lhs,
                        packed_rhs,
                        lhs_cs,
                        rhs_rs,
                        rhs_cs,
                        accum,
                        lhs: lhs_ptr,
                        rhs: rhs_ptr,
                    }
                    .run(0, false);

                    packed_lhs = packed_lhs.wrapping_offset(lhs_cs);
                    packed_rhs = packed_rhs.wrapping_offset(rhs_rs);
                    next_lhs = next_lhs.wrapping_offset(lhs_cs);
                }
            }

            // Deferred conjugation of the whole tile when the lhs was conjugated.
            if conj_accum_at_end {
                seq_macro::seq!(N_ITER in 0..$nr {{
                    let column = accum.add(N_ITER * $mr_div_n);
                    seq_macro::seq!(M_ITER in 0..$mr_div_n {{
                        let acc = &mut *column.add(M_ITER);
                        *acc = conj(*acc);
                    }});
                }});
            }

            let full_contiguous_tile = m == $mr_div_n * N && n == $nr && dst_rs == 1;

            if full_contiguous_tile {
                // Whole tile with unit row stride: merge one `Pack` at a time.
                let alpha_pack = splat(alpha);
                let beta_pack = splat(beta);

                match (conj_dst, alpha_status) {
                    (true, 2) => {
                        seq_macro::seq!(N_ITER in 0..$nr {{
                            seq_macro::seq!(M_ITER in 0..$mr_div_n {{
                                let out = dst.offset(M_ITER * N as isize + N_ITER * dst_cs) as *mut Pack;
                                let acc = *accum.offset(M_ITER + $mr_div_n * N_ITER);
                                *out = add(
                                    mul_cplx(conj(*out), alpha_pack),
                                    mul_cplx(acc, beta_pack),
                                );
                            }});
                        }});
                    }
                    (true, 1) => {
                        seq_macro::seq!(N_ITER in 0..$nr {{
                            seq_macro::seq!(M_ITER in 0..$mr_div_n {{
                                let out = dst.offset(M_ITER * N as isize + N_ITER * dst_cs) as *mut Pack;
                                let acc = *accum.offset(M_ITER + $mr_div_n * N_ITER);
                                *out = add(conj(*out), mul_cplx(acc, beta_pack));
                            }});
                        }});
                    }
                    (false, 2) => {
                        seq_macro::seq!(N_ITER in 0..$nr {{
                            seq_macro::seq!(M_ITER in 0..$mr_div_n {{
                                let out = dst.offset(M_ITER * N as isize + N_ITER * dst_cs) as *mut Pack;
                                let acc = *accum.offset(M_ITER + $mr_div_n * N_ITER);
                                *out = add(
                                    mul_cplx(*out, alpha_pack),
                                    mul_cplx(acc, beta_pack),
                                );
                            }});
                        }});
                    }
                    (false, 1) => {
                        seq_macro::seq!(N_ITER in 0..$nr {{
                            seq_macro::seq!(M_ITER in 0..$mr_div_n {{
                                let out = dst.offset(M_ITER * N as isize + N_ITER * dst_cs) as *mut Pack;
                                let acc = *accum.offset(M_ITER + $mr_div_n * N_ITER);
                                *out = add(*out, mul_cplx(acc, beta_pack));
                            }});
                        }});
                    }
                    // Overwrite: the old `dst` is never read, so `conj_dst` is irrelevant.
                    _ => {
                        seq_macro::seq!(N_ITER in 0..$nr {{
                            seq_macro::seq!(M_ITER in 0..$mr_div_n {{
                                let out = dst.offset(M_ITER * N as isize + N_ITER * dst_cs) as *mut Pack;
                                let acc = *accum.offset(M_ITER + $mr_div_n * N_ITER);
                                *out = mul_cplx(acc, beta_pack);
                            }});
                        }});
                    }
                }
            } else {
                // Partial tile or strided rows: copy the tile and merge it element by element,
                // reading it as `T` values with `$mr_div_n * N` elements per column.
                let tile = accum_storage;
                let src = tile.as_ptr() as *const T;

                match (conj_dst, alpha_status) {
                    (true, 2) => {
                        for j in 0..n {
                            let dst_col = dst.offset(dst_cs * j as isize);
                            let src_col = src.add(j * $mr_div_n * N);

                            for i in 0..m {
                                let out = dst_col.offset(dst_rs * i as isize);
                                let acc = src_col.add(i);

                                *out = alpha * (*out).conj() + beta * *acc;
                            }
                        }
                    }
                    (true, 1) => {
                        for j in 0..n {
                            let dst_col = dst.offset(dst_cs * j as isize);
                            let src_col = src.add(j * $mr_div_n * N);

                            for i in 0..m {
                                let out = dst_col.offset(dst_rs * i as isize);
                                let acc = src_col.add(i);

                                *out = (*out).conj() + beta * *acc;
                            }
                        }
                    }
                    (false, 2) => {
                        for j in 0..n {
                            let dst_col = dst.offset(dst_cs * j as isize);
                            let src_col = src.add(j * $mr_div_n * N);

                            for i in 0..m {
                                let out = dst_col.offset(dst_rs * i as isize);
                                let acc = src_col.add(i);

                                *out = alpha * *out + beta * *acc;
                            }
                        }
                    }
                    (false, 1) => {
                        for j in 0..n {
                            let dst_col = dst.offset(dst_cs * j as isize);
                            let src_col = src.add(j * $mr_div_n * N);

                            for i in 0..m {
                                let out = dst_col.offset(dst_rs * i as isize);
                                let acc = src_col.add(i);

                                *out = *out + beta * *acc;
                            }
                        }
                    }
                    // Overwrite: identical regardless of `conj_dst`.
                    _ => {
                        for j in 0..n {
                            let dst_col = dst.offset(dst_cs * j as isize);
                            let src_col = src.add(j * $mr_div_n * N);

                            for i in 0..m {
                                let out = dst_col.offset(dst_rs * i as isize);
                                let acc = src_col.add(i);

                                *out = beta * *acc;
                            }
                        }
                    }
                }
            }
        }
    };
}