faer_core/
mul.rs

1//! Matrix multiplication.
2
3use crate::{
4    assert, c32, c64, group_helpers::*, transmute_unchecked, unzipped, zipped, ComplexField, Conj,
5    Conjugate, DivCeil, MatMut, MatRef, Parallelism, SimdGroupFor,
6};
7use core::{iter::zip, marker::PhantomData, mem::MaybeUninit};
8use faer_entity::{SimdCtx, *};
9use pulp::Simd;
10use reborrow::*;
11
12#[doc(hidden)]
13pub mod inner_prod {
14    use super::*;
15    use crate::assert;
16
17    #[inline(always)]
18    fn a_x_b_accumulate1<C: ConjTy, E: ComplexField, S: Simd>(
19        simd: SimdFor<E, S>,
20        conj: C,
21        a: SliceGroup<E>,
22        b: SliceGroup<E>,
23        offset: pulp::Offset<E::SimdMask<S>>,
24    ) -> SimdGroupFor<E, S> {
25        let (a_head, a_body, a_tail) = simd.as_aligned_simd(a, offset);
26        let (b_head, b_body, b_tail) = simd.as_aligned_simd(b, offset);
27        let zero = simd.splat(E::faer_zero());
28        let mut acc0 = simd.conditional_conj_mul(conj, a_head.read_or(zero), b_head.read_or(zero));
29
30        let a_body1 = a_body;
31        let b_body1 = b_body;
32        for (a, b) in zip(a_body1.into_ref_iter(), b_body1.into_ref_iter()) {
33            acc0 = simd.conditional_conj_mul_add_e(conj, a.read_or(zero), b.read_or(zero), acc0);
34        }
35        simd.conditional_conj_mul_add_e(conj, a_tail.read_or(zero), b_tail.read_or(zero), acc0)
36    }
37
38    #[inline(always)]
39    fn a_x_b_accumulate2<C: ConjTy, E: ComplexField, S: Simd>(
40        simd: SimdFor<E, S>,
41        conj: C,
42        a: SliceGroup<E>,
43        b: SliceGroup<E>,
44        offset: pulp::Offset<E::SimdMask<S>>,
45    ) -> SimdGroupFor<E, S> {
46        let (a_head, a_body, a_tail) = simd.as_aligned_simd(a, offset);
47        let (b_head, b_body, b_tail) = simd.as_aligned_simd(b, offset);
48        let zero = simd.splat(E::faer_zero());
49        let mut acc0 = simd.conditional_conj_mul(conj, a_head.read_or(zero), b_head.read_or(zero));
50        let mut acc1 = zero;
51
52        let (a_body2, a_body1) = a_body.as_arrays::<2>();
53        let (b_body2, b_body1) = b_body.as_arrays::<2>();
54        for ([a0, a1], [b0, b1]) in zip(
55            a_body2.into_ref_iter().map(RefGroup::unzip),
56            b_body2.into_ref_iter().map(RefGroup::unzip),
57        ) {
58            acc0 = simd.conditional_conj_mul_add_e(conj, a0.read_or(zero), b0.read_or(zero), acc0);
59            acc1 = simd.conditional_conj_mul_add_e(conj, a1.read_or(zero), b1.read_or(zero), acc1);
60        }
61        for (a, b) in zip(a_body1.into_ref_iter(), b_body1.into_ref_iter()) {
62            acc0 = simd.conditional_conj_mul_add_e(conj, a.read_or(zero), b.read_or(zero), acc0);
63        }
64        acc0 =
65            simd.conditional_conj_mul_add_e(conj, a_tail.read_or(zero), b_tail.read_or(zero), acc0);
66        simd.add(acc0, acc1)
67    }
68
    /// Lane-wise multiply-accumulate of `a` and `b`, unrolled over four accumulators.
    ///
    /// Both slices are split with the same `offset` into head, body and tail. The head product
    /// seeds `acc0`; the body is consumed four vectors per iteration (one per accumulator);
    /// leftover body vectors and the tail are folded into `acc0`.
    ///
    /// Returns the per-lane partial sums, combined pairwise as `(acc0 + acc1) + (acc2 + acc3)`.
    /// The horizontal reduction is left to the caller.
    #[inline(always)]
    fn a_x_b_accumulate4<C: ConjTy, E: ComplexField, S: Simd>(
        simd: SimdFor<E, S>,
        conj: C,
        a: SliceGroup<E>,
        b: SliceGroup<E>,
        offset: pulp::Offset<E::SimdMask<S>>,
    ) -> SimdGroupFor<E, S> {
        let (a_head, a_body, a_tail) = simd.as_aligned_simd(a, offset);
        let (b_head, b_body, b_tail) = simd.as_aligned_simd(b, offset);
        let zero = simd.splat(E::faer_zero());
        // Lanes not covered by the head read as zero and so do not affect the sum.
        let mut acc0 = simd.conditional_conj_mul(conj, a_head.read_or(zero), b_head.read_or(zero));
        let mut acc1 = zero;
        let mut acc2 = zero;
        let mut acc3 = zero;

        // Groups of four vectors, plus the remainder that does not fill a group.
        let (a_body4, a_body1) = a_body.as_arrays::<4>();
        let (b_body4, b_body1) = b_body.as_arrays::<4>();
        for ([a0, a1, a2, a3], [b0, b1, b2, b3]) in zip(
            a_body4.into_ref_iter().map(RefGroup::unzip),
            b_body4.into_ref_iter().map(RefGroup::unzip),
        ) {
            acc0 = simd.conditional_conj_mul_add_e(conj, a0.read_or(zero), b0.read_or(zero), acc0);
            acc1 = simd.conditional_conj_mul_add_e(conj, a1.read_or(zero), b1.read_or(zero), acc1);
            acc2 = simd.conditional_conj_mul_add_e(conj, a2.read_or(zero), b2.read_or(zero), acc2);
            acc3 = simd.conditional_conj_mul_add_e(conj, a3.read_or(zero), b3.read_or(zero), acc3);
        }
        // Remainder vectors go through the first accumulator only.
        for (a, b) in zip(a_body1.into_ref_iter(), b_body1.into_ref_iter()) {
            acc0 = simd.conditional_conj_mul_add_e(conj, a.read_or(zero), b.read_or(zero), acc0);
        }
        acc0 =
            simd.conditional_conj_mul_add_e(conj, a_tail.read_or(zero), b_tail.read_or(zero), acc0);
        simd.add(simd.add(acc0, acc1), simd.add(acc2, acc3))
    }
103
    /// Lane-wise multiply-accumulate of `a` and `b`, unrolled over eight accumulators.
    ///
    /// Both slices are split with the same `offset` into head, body and tail. The head product
    /// seeds `acc0`; the body is consumed eight vectors per iteration (one per accumulator);
    /// leftover body vectors and the tail are folded into `acc0`.
    ///
    /// Returns the per-lane partial sums, combined as a balanced tree of additions. The
    /// horizontal reduction is left to the caller.
    #[inline(always)]
    fn a_x_b_accumulate8<C: ConjTy, E: ComplexField, S: Simd>(
        simd: SimdFor<E, S>,
        conj: C,
        a: SliceGroup<E>,
        b: SliceGroup<E>,
        offset: pulp::Offset<E::SimdMask<S>>,
    ) -> SimdGroupFor<E, S> {
        let (a_head, a_body, a_tail) = simd.as_aligned_simd(a, offset);
        let (b_head, b_body, b_tail) = simd.as_aligned_simd(b, offset);
        let zero = simd.splat(E::faer_zero());
        // Lanes not covered by the head read as zero and so do not affect the sum.
        let mut acc0 = simd.conditional_conj_mul(conj, a_head.read_or(zero), b_head.read_or(zero));
        let mut acc1 = zero;
        let mut acc2 = zero;
        let mut acc3 = zero;
        let mut acc4 = zero;
        let mut acc5 = zero;
        let mut acc6 = zero;
        let mut acc7 = zero;

        // Groups of eight vectors, plus the remainder that does not fill a group.
        let (a_body8, a_body1) = a_body.as_arrays::<8>();
        let (b_body8, b_body1) = b_body.as_arrays::<8>();
        for ([a0, a1, a2, a3, a4, a5, a6, a7], [b0, b1, b2, b3, b4, b5, b6, b7]) in zip(
            a_body8.into_ref_iter().map(RefGroup::unzip),
            b_body8.into_ref_iter().map(RefGroup::unzip),
        ) {
            acc0 = simd.conditional_conj_mul_add_e(conj, a0.read_or(zero), b0.read_or(zero), acc0);
            acc1 = simd.conditional_conj_mul_add_e(conj, a1.read_or(zero), b1.read_or(zero), acc1);
            acc2 = simd.conditional_conj_mul_add_e(conj, a2.read_or(zero), b2.read_or(zero), acc2);
            acc3 = simd.conditional_conj_mul_add_e(conj, a3.read_or(zero), b3.read_or(zero), acc3);
            acc4 = simd.conditional_conj_mul_add_e(conj, a4.read_or(zero), b4.read_or(zero), acc4);
            acc5 = simd.conditional_conj_mul_add_e(conj, a5.read_or(zero), b5.read_or(zero), acc5);
            acc6 = simd.conditional_conj_mul_add_e(conj, a6.read_or(zero), b6.read_or(zero), acc6);
            acc7 = simd.conditional_conj_mul_add_e(conj, a7.read_or(zero), b7.read_or(zero), acc7);
        }
        // Remainder vectors go through the first accumulator only.
        for (a, b) in zip(a_body1.into_ref_iter(), b_body1.into_ref_iter()) {
            acc0 = simd.conditional_conj_mul_add_e(conj, a.read_or(zero), b.read_or(zero), acc0);
        }
        acc0 =
            simd.conditional_conj_mul_add_e(conj, a_tail.read_or(zero), b_tail.read_or(zero), acc0);
        simd.add(
            simd.add(simd.add(acc0, acc1), simd.add(acc2, acc3)),
            simd.add(simd.add(acc4, acc5), simd.add(acc6, acc7)),
        )
    }
149
150    #[inline(always)]
151    pub fn with_simd_and_offset<C: ConjTy, E: ComplexField, S: Simd>(
152        simd: SimdFor<E, S>,
153        conj: C,
154        a: SliceGroup<E>,
155        b: SliceGroup<E>,
156        offset: pulp::Offset<E::SimdMask<S>>,
157    ) -> E {
158        {
159            let prologue = if E::N_COMPONENTS == 1 {
160                a_x_b_accumulate8(simd, conj, a, b, offset)
161            } else if E::N_COMPONENTS == 2 {
162                a_x_b_accumulate4(simd, conj, a, b, offset)
163            } else if E::N_COMPONENTS == 4 {
164                a_x_b_accumulate2(simd, conj, a, b, offset)
165            } else {
166                a_x_b_accumulate1(simd, conj, a, b, offset)
167            };
168
169            simd.reduce_add(simd.rotate_left(prologue, offset.rotate_left_amount()))
170        }
171    }
172
173    pub struct Impl<'a, C: ConjTy, E: ComplexField> {
174        pub a: SliceGroup<'a, E>,
175        pub b: SliceGroup<'a, E>,
176        pub conj: C,
177    }
178
179    impl<C: ConjTy, E: ComplexField> pulp::WithSimd for Impl<'_, C, E> {
180        type Output = E;
181
182        #[inline(always)]
183        fn with_simd<S: Simd>(self, simd: S) -> Self::Output {
184            let simd = SimdFor::new(simd);
185            with_simd_and_offset(simd, self.conj, self.a, self.b, simd.align_offset(self.a))
186        }
187    }
188
189    #[inline(always)]
190    #[track_caller]
191    pub fn inner_prod_with_conj_arch<E: ComplexField>(
192        arch: E::Simd,
193        lhs: MatRef<'_, E>,
194        conj_lhs: Conj,
195        rhs: MatRef<'_, E>,
196        conj_rhs: Conj,
197    ) -> E {
198        assert!(all(
199            lhs.nrows() == rhs.nrows(),
200            lhs.ncols() == 1,
201            rhs.ncols() == 1,
202        ));
203        let nrows = lhs.nrows();
204        let mut a = lhs;
205        let mut b = rhs;
206        if a.row_stride() < 0 {
207            a = a.reverse_rows();
208            b = b.reverse_rows();
209        }
210
211        let res = if a.row_stride() == 1 && b.row_stride() == 1 {
212            let a = SliceGroup::<'_, E>::new(a.try_get_contiguous_col(0));
213            let b = SliceGroup::<'_, E>::new(b.try_get_contiguous_col(0));
214            if conj_lhs == conj_rhs {
215                arch.dispatch(Impl { a, b, conj: NoConj })
216            } else {
217                arch.dispatch(Impl {
218                    a,
219                    b,
220                    conj: YesConj,
221                })
222            }
223        } else {
224            crate::constrained::Size::with2(
225                nrows,
226                1,
227                #[inline(always)]
228                |nrows, ncols| {
229                    let zero_idx = ncols.check(0);
230
231                    let a = crate::constrained::MatRef::new(a, nrows, ncols);
232                    let b = crate::constrained::MatRef::new(b, nrows, ncols);
233                    let mut acc = E::faer_zero();
234                    if conj_lhs == conj_rhs {
235                        for i in nrows.indices() {
236                            acc =
237                                acc.faer_add(E::faer_mul(a.read(i, zero_idx), b.read(i, zero_idx)));
238                        }
239                    } else {
240                        for i in nrows.indices() {
241                            acc = acc.faer_add(E::faer_mul(
242                                a.read(i, zero_idx).faer_conj(),
243                                b.read(i, zero_idx),
244                            ));
245                        }
246                    }
247                    acc
248                },
249            )
250        };
251
252        match conj_rhs {
253            Conj::Yes => res.faer_conj(),
254            Conj::No => res,
255        }
256    }
257
258    #[inline]
259    #[track_caller]
260    pub fn inner_prod_with_conj<E: ComplexField>(
261        lhs: MatRef<'_, E>,
262        conj_lhs: Conj,
263        rhs: MatRef<'_, E>,
264        conj_rhs: Conj,
265    ) -> E {
266        inner_prod_with_conj_arch(E::Simd::default(), lhs, conj_lhs, rhs, conj_rhs)
267    }
268}
269
270#[doc(hidden)]
271pub mod matvec_rowmajor {
272    use super::*;
273    use crate::assert;
274
275    fn matvec_with_conj_impl<E: ComplexField>(
276        acc: MatMut<'_, E>,
277        a: MatRef<'_, E>,
278        conj_a: Conj,
279        b: MatRef<'_, E>,
280        conj_b: Conj,
281        alpha: Option<E>,
282        beta: E,
283    ) {
284        let m = a.nrows();
285        let n = a.ncols();
286
287        assert!(all(
288            b.nrows() == n,
289            b.ncols() == 1,
290            acc.nrows() == m,
291            acc.ncols() == 1,
292            a.col_stride() == 1,
293            b.row_stride() == 1,
294        ));
295
296        let mut acc = acc;
297
298        for i in 0..m {
299            let a = a.submatrix(i, 0, 1, n);
300            let res = inner_prod::inner_prod_with_conj(a.transpose(), conj_a, b, conj_b);
301            match alpha {
302                Some(alpha) => acc.write(
303                    i,
304                    0,
305                    E::faer_add(alpha.faer_mul(acc.read(i, 0)), beta.faer_mul(res)),
306                ),
307                None => acc.write(i, 0, beta.faer_mul(res)),
308            }
309        }
310    }
311
312    pub fn matvec_with_conj<E: ComplexField>(
313        acc: MatMut<'_, E>,
314        lhs: MatRef<'_, E>,
315        conj_lhs: Conj,
316        rhs: MatRef<'_, E>,
317        conj_rhs: Conj,
318        alpha: Option<E>,
319        beta: E,
320    ) {
321        if rhs.row_stride() == 1 {
322            matvec_with_conj_impl(acc, lhs, conj_lhs, rhs, conj_rhs, alpha, beta);
323        } else {
324            matvec_with_conj_impl(
325                acc,
326                lhs,
327                conj_lhs,
328                rhs.to_owned().as_ref(),
329                conj_rhs,
330                alpha,
331                beta,
332            );
333        }
334    }
335}
336
337#[doc(hidden)]
338pub mod matvec_colmajor {
339    use super::*;
340    use crate::assert;
341
    /// Job handed to a SIMD dispatcher: updates `acc` in place with the product of the slice
    /// `a` and the scalar `b`, added to the existing contents of `acc`.
    pub struct Impl<'a, C: ConjTy, E: ComplexField> {
        /// Conjugation marker forwarded to the kernel.
        pub conj: C,
        /// Destination slice, read and then overwritten.
        pub acc: SliceGroupMut<'a, E>,
        /// Source slice.
        pub a: SliceGroup<'a, E>,
        /// Scalar multiplier, broadcast to all lanes by the kernel.
        pub b: E,
    }
348
349    #[inline(always)]
350    pub fn with_simd_and_offset<C: ConjTy, E: ComplexField, S: Simd>(
351        simd: SimdFor<E, S>,
352        conj: C,
353        acc: SliceGroupMut<'_, E>,
354        a: SliceGroup<'_, E>,
355        b: E,
356        offset: pulp::Offset<SimdMaskFor<E, S>>,
357    ) {
358        let (a_head, a_body, a_tail) = simd.as_aligned_simd(a, offset);
359        let (acc_head, acc_body, acc_tail) = simd.as_aligned_simd_mut(acc, offset);
360        let b = simd.splat(b);
361
362        #[inline(always)]
363        pub fn process<C: ConjTy, E: ComplexField, S: Simd>(
364            simd: SimdFor<E, S>,
365            conj: C,
366            mut acc: impl Write<Output = SimdGroupFor<E, S>>,
367            a: impl Read<Output = SimdGroupFor<E, S>>,
368            b: SimdGroupFor<E, S>,
369        ) {
370            acc.write(simd.conditional_conj_mul_add_e(
371                conj,
372                a.read_or(simd.splat(E::faer_zero())),
373                b,
374                acc.read_or(simd.splat(E::faer_zero())),
375            ))
376        }
377
378        process(simd, conj, acc_head, a_head, b);
379        for (acc, a) in acc_body.into_mut_iter().zip(a_body.into_ref_iter()) {
380            process(simd, conj, acc, a, b);
381        }
382        process(simd, conj, acc_tail, a_tail, b);
383    }
384
385    impl<C: ConjTy, E: ComplexField> pulp::WithSimd for Impl<'_, C, E> {
386        type Output = ();
387
388        #[inline(always)]
389        fn with_simd<S: Simd>(self, simd: S) -> Self::Output {
390            let simd = SimdFor::new(simd);
391            with_simd_and_offset(
392                simd,
393                self.conj,
394                self.acc,
395                self.a,
396                self.b,
397                simd.align_offset(self.a),
398            )
399        }
400    }
401
402    fn matvec_with_conj_impl<E: ComplexField>(
403        acc: MatMut<'_, E>,
404        a: MatRef<'_, E>,
405        conj_a: Conj,
406        b: MatRef<'_, E>,
407        conj_b: Conj,
408        beta: E,
409    ) {
410        let m = a.nrows();
411        let n = a.ncols();
412
413        assert!(all(
414            b.nrows() == n,
415            b.ncols() == 1,
416            acc.nrows() == m,
417            acc.ncols() == 1,
418            a.row_stride() == 1,
419            acc.row_stride() == 1,
420        ));
421
422        let mut acc = SliceGroupMut::<'_, E>::new(acc.try_get_contiguous_col_mut(0));
423
424        let arch = E::Simd::default();
425        for j in 0..n {
426            let acc = acc.rb_mut();
427            let a = SliceGroup::<'_, E>::new(a.try_get_contiguous_col(j));
428            let b = b.read(j, 0);
429            let b = match conj_b {
430                Conj::Yes => b.faer_conj(),
431                Conj::No => b,
432            };
433            let b = b.faer_mul(beta);
434
435            match conj_a {
436                Conj::Yes => arch.dispatch(Impl {
437                    conj: YesConj,
438                    acc,
439                    a,
440                    b,
441                }),
442                Conj::No => arch.dispatch(Impl {
443                    conj: NoConj,
444                    acc,
445                    a,
446                    b,
447                }),
448            }
449        }
450    }
451
452    pub fn matvec_with_conj<E: ComplexField>(
453        acc: MatMut<'_, E>,
454        lhs: MatRef<'_, E>,
455        conj_lhs: Conj,
456        rhs: MatRef<'_, E>,
457        conj_rhs: Conj,
458        alpha: Option<E>,
459        beta: E,
460    ) {
461        let m = acc.nrows();
462        let mut acc = acc;
463        if acc.row_stride() == 1 {
464            match alpha {
465                Some(alpha) if alpha == E::faer_one() => {}
466                Some(alpha) => {
467                    for i in 0..m {
468                        acc.write(i, 0, acc.read(i, 0).faer_mul(alpha));
469                    }
470                }
471                None => {
472                    for i in 0..m {
473                        acc.write(i, 0, E::faer_zero());
474                    }
475                }
476            }
477
478            matvec_with_conj_impl(acc, lhs, conj_lhs, rhs, conj_rhs, beta);
479        } else {
480            let mut tmp = crate::Mat::<E>::zeros(m, 1);
481            matvec_with_conj_impl(tmp.as_mut(), lhs, conj_lhs, rhs, conj_rhs, beta);
482            match alpha {
483                Some(alpha) => {
484                    for i in 0..m {
485                        acc.write(
486                            i,
487                            0,
488                            (acc.read(i, 0).faer_mul(alpha)).faer_add(tmp.read(i, 0)),
489                        )
490                    }
491                }
492                None => {
493                    for i in 0..m {
494                        acc.write(i, 0, tmp.read(i, 0))
495                    }
496                }
497            }
498        }
499    }
500}
501
502#[doc(hidden)]
503pub mod matvec {
504    use super::*;
505
506    pub fn matvec_with_conj<E: ComplexField>(
507        acc: MatMut<'_, E>,
508        lhs: MatRef<'_, E>,
509        conj_lhs: Conj,
510        rhs: MatRef<'_, E>,
511        conj_rhs: Conj,
512        alpha: Option<E>,
513        beta: E,
514    ) {
515        let mut acc = acc;
516        let mut a = lhs;
517        let mut b = rhs;
518
519        if a.row_stride() < 0 {
520            a = a.reverse_rows();
521            acc = acc.reverse_rows_mut();
522        }
523        if a.col_stride() < 0 {
524            a = a.reverse_cols();
525            b = b.reverse_rows();
526        }
527
528        if a.row_stride() == 1 {
529            return matvec_colmajor::matvec_with_conj(acc, a, conj_lhs, b, conj_rhs, alpha, beta);
530        }
531        if a.col_stride() == 1 {
532            return matvec_rowmajor::matvec_with_conj(acc, a, conj_lhs, b, conj_rhs, alpha, beta);
533        }
534
535        let m = a.nrows();
536        let n = a.ncols();
537
538        match alpha {
539            Some(alpha) => {
540                for i in 0..m {
541                    acc.write(i, 0, acc.read(i, 0).faer_mul(alpha));
542                }
543            }
544            None => {
545                for i in 0..m {
546                    acc.write(i, 0, E::faer_zero());
547                }
548            }
549        }
550
551        for j in 0..n {
552            let b = b.read(j, 0);
553            let b = match conj_rhs {
554                Conj::Yes => b.faer_conj(),
555                Conj::No => b,
556            };
557            let b = b.faer_mul(beta);
558            for i in 0..m {
559                let mul = a.read(i, j).faer_mul(b);
560                acc.write(i, 0, acc.read(i, 0).faer_add(mul));
561            }
562        }
563    }
564}
565
566#[doc(hidden)]
567pub mod outer_prod {
568    use super::*;
569    use crate::assert;
570
    /// Job handed to a SIMD dispatcher: updates one contiguous column `acc` of the outer
    /// product from the slice `a` and the scalar `b`.
    pub struct Impl<'a, C: ConjTy, E: ComplexField> {
        /// Conjugation marker forwarded to the kernel.
        pub conj: C,
        /// Destination column.
        pub acc: SliceGroupMut<'a, E>,
        /// Left-hand vector.
        pub a: SliceGroup<'a, E>,
        /// Scalar multiplier, broadcast to all lanes by the kernel.
        pub b: E,
        /// Scale applied to the previous contents of `acc`; `None` means `acc` is overwritten
        /// without being read.
        pub alpha: Option<E>,
    }
578
579    #[inline(always)]
580    pub fn with_simd_and_offset<C: ConjTy, E: ComplexField, S: Simd>(
581        simd: SimdFor<E, S>,
582        conj: C,
583        acc: SliceGroupMut<'_, E>,
584        a: SliceGroup<'_, E>,
585        b: E,
586        alpha: Option<E>,
587        offset: pulp::Offset<SimdMaskFor<E, S>>,
588    ) {
589        match alpha {
590            Some(alpha) => {
591                if alpha == E::faer_one() {
592                    return matvec_colmajor::with_simd_and_offset(simd, conj, acc, a, b, offset);
593                }
594
595                let (a_head, a_body, a_tail) = simd.as_aligned_simd(a, offset);
596                let (acc_head, acc_body, acc_tail) = simd.as_aligned_simd_mut(acc, offset);
597                let b = simd.splat(b);
598                let alpha = simd.splat(alpha);
599
600                #[inline(always)]
601                pub fn process<C: ConjTy, E: ComplexField, S: Simd>(
602                    simd: SimdFor<E, S>,
603                    conj: C,
604                    mut acc: impl Write<Output = SimdGroupFor<E, S>>,
605                    a: impl Read<Output = SimdGroupFor<E, S>>,
606                    b: SimdGroupFor<E, S>,
607                    alpha: SimdGroupFor<E, S>,
608                ) {
609                    acc.write(simd.conditional_conj_mul_add_e(
610                        conj,
611                        a.read_or(simd.splat(E::faer_zero())),
612                        b,
613                        simd.mul(alpha, acc.read_or(simd.splat(E::faer_zero()))),
614                    ))
615                }
616
617                process(simd, conj, acc_head, a_head, b, alpha);
618                for (acc, a) in acc_body.into_mut_iter().zip(a_body.into_ref_iter()) {
619                    process(simd, conj, acc, a, b, alpha);
620                }
621                process(simd, conj, acc_tail, a_tail, b, alpha);
622            }
623            None => {
624                let (a_head, a_body, a_tail) = simd.as_aligned_simd(a, offset);
625                let (acc_head, acc_body, acc_tail) = simd.as_aligned_simd_mut(acc, offset);
626                let b = simd.splat(b);
627
628                #[inline(always)]
629                pub fn process<C: ConjTy, E: ComplexField, S: Simd>(
630                    simd: SimdFor<E, S>,
631                    conj: C,
632                    mut acc: impl Write<Output = SimdGroupFor<E, S>>,
633                    a: impl Read<Output = SimdGroupFor<E, S>>,
634                    b: SimdGroupFor<E, S>,
635                ) {
636                    acc.write(simd.conditional_conj_mul(
637                        conj,
638                        a.read_or(simd.splat(E::faer_zero())),
639                        b,
640                    ))
641                }
642
643                process(simd, conj, acc_head, a_head, b);
644                for (acc, a) in acc_body.into_mut_iter().zip(a_body.into_ref_iter()) {
645                    process(simd, conj, acc, a, b);
646                }
647                process(simd, conj, acc_tail, a_tail, b);
648            }
649        }
650    }
651
652    impl<C: ConjTy, E: ComplexField> pulp::WithSimd for Impl<'_, C, E> {
653        type Output = ();
654
655        #[inline(always)]
656        fn with_simd<S: Simd>(self, simd: S) -> Self::Output {
657            let simd = SimdFor::new(simd);
658            with_simd_and_offset(
659                simd,
660                self.conj,
661                self.acc,
662                self.a,
663                self.b,
664                self.alpha,
665                simd.align_offset(self.a),
666            )
667        }
668    }
669
670    fn outer_prod_with_conj_impl<E: ComplexField>(
671        acc: MatMut<'_, E>,
672        a: MatRef<'_, E>,
673        conj_a: Conj,
674        b: MatRef<'_, E>,
675        conj_b: Conj,
676        alpha: Option<E>,
677        beta: E,
678    ) {
679        let m = acc.nrows();
680        let n = acc.ncols();
681
682        assert!(all(
683            a.nrows() == m,
684            a.ncols() == 1,
685            b.nrows() == n,
686            b.ncols() == 1,
687            acc.row_stride() == 1,
688            a.row_stride() == 1,
689        ));
690
691        let mut acc = acc;
692
693        let arch = E::Simd::default();
694
695        let a = SliceGroup::new(a.try_get_contiguous_col(0));
696
697        for j in 0..n {
698            let acc = SliceGroupMut::new(acc.rb_mut().try_get_contiguous_col_mut(j));
699            let b = b.read(j, 0);
700            let b = match conj_b {
701                Conj::Yes => b.faer_conj(),
702                Conj::No => b,
703            };
704            let b = b.faer_mul(beta);
705            match conj_a {
706                Conj::Yes => arch.dispatch(Impl {
707                    conj: YesConj,
708                    acc,
709                    a,
710                    b,
711                    alpha,
712                }),
713                Conj::No => arch.dispatch(Impl {
714                    conj: NoConj,
715                    acc,
716                    a,
717                    b,
718                    alpha,
719                }),
720            }
721        }
722    }
723
724    pub fn outer_prod_with_conj<E: ComplexField>(
725        acc: MatMut<'_, E>,
726        lhs: MatRef<'_, E>,
727        conj_lhs: Conj,
728        rhs: MatRef<'_, E>,
729        conj_rhs: Conj,
730        alpha: Option<E>,
731        beta: E,
732    ) {
733        let mut acc = acc;
734        let mut a = lhs;
735        let mut b = rhs;
736        let mut conj_a = conj_lhs;
737        let mut conj_b = conj_rhs;
738
739        if acc.row_stride() < 0 {
740            acc = acc.reverse_rows_mut();
741            a = a.reverse_rows();
742        }
743        if acc.col_stride() < 0 {
744            acc = acc.reverse_cols_mut();
745            b = b.reverse_rows();
746        }
747
748        if acc.row_stride() > a.col_stride() {
749            acc = acc.transpose_mut();
750            core::mem::swap(&mut a, &mut b);
751            core::mem::swap(&mut conj_a, &mut conj_b);
752        }
753
754        if acc.row_stride() == 1 {
755            if a.row_stride() == 1 {
756                outer_prod_with_conj_impl(acc, a, conj_a, b, conj_b, alpha, beta);
757            } else {
758                outer_prod_with_conj_impl(
759                    acc,
760                    a.to_owned().as_ref(),
761                    conj_a,
762                    b,
763                    conj_b,
764                    alpha,
765                    beta,
766                );
767            }
768        } else {
769            let m = acc.nrows();
770            let n = acc.ncols();
771            match alpha {
772                Some(alpha) => {
773                    for j in 0..n {
774                        let b = b.read(j, 0);
775                        let b = match conj_b {
776                            Conj::Yes => b.faer_conj(),
777                            Conj::No => b,
778                        };
779                        let b = b.faer_mul(beta);
780                        match conj_a {
781                            Conj::Yes => {
782                                for i in 0..m {
783                                    let ab = a.read(i, 0).faer_conj().faer_mul(b);
784                                    acc.write(
785                                        i,
786                                        j,
787                                        E::faer_add(acc.read(i, j).faer_mul(alpha), ab),
788                                    );
789                                }
790                            }
791                            Conj::No => {
792                                for i in 0..m {
793                                    let ab = a.read(i, 0).faer_mul(b);
794                                    acc.write(
795                                        i,
796                                        j,
797                                        E::faer_add(acc.read(i, j).faer_mul(alpha), ab),
798                                    );
799                                }
800                            }
801                        }
802                    }
803                }
804                None => {
805                    for j in 0..n {
806                        let b = b.read(j, 0);
807                        let b = match conj_b {
808                            Conj::Yes => b.faer_conj(),
809                            Conj::No => b,
810                        };
811                        let b = b.faer_mul(beta);
812                        match conj_a {
813                            Conj::Yes => {
814                                for i in 0..m {
815                                    acc.write(i, j, a.read(i, 0).faer_conj().faer_mul(b));
816                                }
817                            }
818                            Conj::No => {
819                                for i in 0..m {
820                                    acc.write(i, j, a.read(i, 0).faer_mul(b));
821                                }
822                            }
823                        }
824                    }
825                }
826            }
827        }
828    }
829}
830
// NOTE(review): `NC`, `MC` and `KC` look like blocking (tile) sizes for a blocked matrix
// product along the n, m and k dimensions respectively. Their uses are outside this view —
// confirm before relying on that.
const NC: usize = 2048;
const MC: usize = 48;
const KC: usize = 64;
834
835struct SimdLaneCount<E: ComplexField> {
836    __marker: PhantomData<fn() -> E>,
837}
838impl<E: ComplexField> pulp::WithSimd for SimdLaneCount<E> {
839    type Output = usize;
840
841    fn with_simd<S: Simd>(self, simd: S) -> Self::Output {
842        let _ = simd;
843        core::mem::size_of::<SimdUnitFor<E, S>>() / core::mem::size_of::<UnitFor<E>>()
844    }
845}
846
/// Micro-kernel job for a fixed-size tile of a matrix product.
///
/// The `WithSimd` implementation asserts that `acc` and `a` have `MR_DIV_N * lane_count` rows,
/// that `acc` and `b` have `NR` columns, that `a.ncols() == b.nrows()`, and that all three
/// views have unit row stride.
struct Ukr<'a, const MR_DIV_N: usize, const NR: usize, CB: ConjTy, E: ComplexField> {
    /// Conjugation marker passed to the multiply-add together with the broadcast `b` element.
    conj_b: CB,
    /// Destination tile.
    acc: MatMut<'a, E>,
    /// Left-hand operand; read one SIMD vector at a time along its rows.
    a: MatRef<'a, E>,
    /// Right-hand operand; read one element at a time and broadcast.
    b: MatRef<'a, E>,
}
853
854impl<const MR_DIV_N: usize, const NR: usize, CB: ConjTy, E: ComplexField> pulp::WithSimd
855    for Ukr<'_, MR_DIV_N, NR, CB, E>
856{
857    type Output = ();
858
859    #[inline(always)]
860    fn with_simd<S: Simd>(self, simd: S) -> Self::Output {
861        let Self {
862            mut acc,
863            a,
864            b,
865            conj_b,
866        } = self;
867        let lane_count =
868            core::mem::size_of::<SimdUnitFor<E, S>>() / core::mem::size_of::<UnitFor<E>>();
869
870        let mr = MR_DIV_N * lane_count;
871        let nr = NR;
872
873        assert!(all(
874            acc.nrows() == mr,
875            acc.ncols() == nr,
876            a.nrows() == mr,
877            b.ncols() == nr,
878            a.ncols() == b.nrows(),
879            a.row_stride() == 1,
880            b.row_stride() == 1,
881            acc.row_stride() == 1
882        ));
883
884        let k = a.ncols();
885        let mut local_acc = [[E::faer_simd_splat(simd, E::faer_zero()); MR_DIV_N]; NR];
886        let simd = SimdFor::<E, S>::new(simd);
887
888        unsafe {
889            let mut one_iter = {
890                #[inline(always)]
891                |depth| {
892                    let a = a.ptr_inbounds_at(0, depth);
893
894                    let mut a_uninit = [MaybeUninit::<SimdGroupFor<E, S>>::uninit(); MR_DIV_N];
895
896                    let mut i = 0usize;
897                    loop {
898                        if i == MR_DIV_N {
899                            break;
900                        }
901                        a_uninit[i] = MaybeUninit::new(into_copy::<E, _>(E::faer_map(
902                            E::faer_copy(&a),
903                            #[inline(always)]
904                            |ptr| *(ptr.add(i * lane_count) as *const SimdUnitFor<E, S>),
905                        )));
906                        i += 1;
907                    }
908                    let a: [SimdGroupFor<E, S>; MR_DIV_N] = transmute_unchecked::<
909                        [MaybeUninit<SimdGroupFor<E, S>>; MR_DIV_N],
910                        [SimdGroupFor<E, S>; MR_DIV_N],
911                    >(a_uninit);
912
913                    let mut j = 0usize;
914                    loop {
915                        if j == NR {
916                            break;
917                        }
918                        let b = simd.splat(E::faer_from_units(E::faer_map(
919                            b.ptr_at(depth, j),
920                            #[inline(always)]
921                            |ptr| *ptr,
922                        )));
923                        let mut i = 0;
924                        loop {
925                            if i == MR_DIV_N {
926                                break;
927                            }
928                            let local_acc = &mut local_acc[j][i];
929                            *local_acc =
930                                simd.conditional_conj_mul_add_e(conj_b, b, a[i], *local_acc);
931                            i += 1;
932                        }
933                        j += 1;
934                    }
935                }
936            };
937
938            let mut depth = 0;
939            while depth < k / 4 * 4 {
940                one_iter(depth);
941                one_iter(depth + 1);
942                one_iter(depth + 2);
943                one_iter(depth + 3);
944                depth += 4;
945            }
946            while depth < k {
947                one_iter(depth);
948                depth += 1;
949            }
950
951            let mut j = 0usize;
952            loop {
953                if j == NR {
954                    break;
955                }
956                let mut i = 0usize;
957                loop {
958                    if i == MR_DIV_N {
959                        break;
960                    }
961                    let acc = acc.rb_mut().ptr_inbounds_at_mut(i * lane_count, j);
962                    let mut acc_value = into_copy::<E, _>(E::faer_map(E::faer_copy(&acc), |acc| {
963                        *(acc as *const SimdUnitFor<E, S>)
964                    }));
965                    acc_value = simd.add(acc_value, local_acc[j][i]);
966                    E::faer_map(
967                        E::faer_zip(acc, from_copy::<E, _>(acc_value)),
968                        #[inline(always)]
969                        |(acc, new_acc)| *(acc as *mut SimdUnitFor<E, S>) = new_acc,
970                    );
971                    i += 1;
972                }
973                j += 1;
974            }
975        }
976    }
977}
978
/// Returns the smaller of `a` and `b`.
#[inline]
fn min(a: usize, b: usize) -> usize {
    if b < a {
        b
    } else {
        a
    }
}
983
984struct MicroKernelShape<E: ComplexField> {
985    __marker: PhantomData<fn() -> E>,
986}
987
988impl<E: ComplexField> MicroKernelShape<E> {
989    const SHAPE: (usize, usize) = {
990        if E::N_COMPONENTS <= 2 {
991            (2, 2)
992        } else if E::N_COMPONENTS == 4 {
993            (2, 1)
994        } else {
995            (1, 1)
996        }
997    };
998
999    const MAX_MR_DIV_N: usize = Self::SHAPE.0;
1000    const MAX_NR: usize = Self::SHAPE.1;
1001
1002    const IS_2X2: bool = Self::MAX_MR_DIV_N == 2 && Self::MAX_NR == 2;
1003    const IS_2X1: bool = Self::MAX_MR_DIV_N == 2 && Self::MAX_NR == 1;
1004    const IS_1X1: bool = Self::MAX_MR_DIV_N == 2 && Self::MAX_NR == 1;
1005}
1006
1007/// acc += a * maybe_conj(b)
1008///
1009/// acc, a, b are colmajor
1010/// m is a multiple of simd lane count
1011fn matmul_with_conj_impl<E: ComplexField>(
1012    acc: MatMut<'_, E>,
1013    a: MatRef<'_, E>,
1014    b: MatRef<'_, E>,
1015    conj_b: Conj,
1016    parallelism: Parallelism,
1017) {
1018    use coe::Coerce;
1019    use num_complex::Complex;
1020    if coe::is_same::<E, Complex<E::Real>>() {
1021        let acc: MatMut<'_, Complex<E::Real>> = acc.coerce();
1022        let a: MatRef<'_, Complex<E::Real>> = a.coerce();
1023        let b: MatRef<'_, Complex<E::Real>> = b.coerce();
1024
1025        let Complex {
1026            re: mut acc_re,
1027            im: mut acc_im,
1028        } = acc.real_imag_mut();
1029        let Complex { re: a_re, im: a_im } = a.real_imag();
1030        let Complex { re: b_re, im: b_im } = b.real_imag();
1031
1032        let real_matmul = |acc: MatMut<'_, E::Real>,
1033                           a: MatRef<'_, E::Real>,
1034                           b: MatRef<'_, E::Real>,
1035                           beta: E::Real| {
1036            matmul_with_conj(
1037                acc,
1038                a,
1039                Conj::No,
1040                b,
1041                Conj::No,
1042                Some(E::Real::faer_one()),
1043                beta,
1044                parallelism,
1045            )
1046        };
1047
1048        match conj_b {
1049            Conj::Yes => {
1050                real_matmul(acc_re.rb_mut(), a_re, b_re, E::Real::faer_one());
1051                real_matmul(acc_re.rb_mut(), a_im, b_im, E::Real::faer_one());
1052                real_matmul(acc_im.rb_mut(), a_re, b_im, E::Real::faer_one().faer_neg());
1053                real_matmul(acc_im.rb_mut(), a_im, b_re, E::Real::faer_one());
1054            }
1055            Conj::No => {
1056                real_matmul(acc_re.rb_mut(), a_re, b_re, E::Real::faer_one());
1057                real_matmul(acc_re.rb_mut(), a_im, b_im, E::Real::faer_one().faer_neg());
1058                real_matmul(acc_im.rb_mut(), a_re, b_im, E::Real::faer_one());
1059                real_matmul(acc_im.rb_mut(), a_im, b_re, E::Real::faer_one());
1060            }
1061        }
1062
1063        return;
1064    }
1065
1066    let m = acc.nrows();
1067    let n = acc.ncols();
1068    let k = a.ncols();
1069
1070    let arch = E::Simd::default();
1071    let lane_count = arch.dispatch(SimdLaneCount::<E> {
1072        __marker: PhantomData,
1073    });
1074
1075    let nr = MicroKernelShape::<E>::MAX_NR;
1076    let mr_div_n = MicroKernelShape::<E>::MAX_MR_DIV_N;
1077    let mr = mr_div_n * lane_count;
1078
1079    assert!(all(
1080        acc.row_stride() == 1,
1081        a.row_stride() == 1,
1082        b.row_stride() == 1,
1083        m % lane_count == 0,
1084    ));
1085
1086    let mut acc = acc;
1087
1088    let mut col_outer = 0usize;
1089    while col_outer < n {
1090        let n_chunk = min(NC, n - col_outer);
1091
1092        let b_panel = b.submatrix(0, col_outer, k, n_chunk);
1093        let acc = acc.rb_mut().submatrix_mut(0, col_outer, m, n_chunk);
1094
1095        let mut depth_outer = 0usize;
1096        while depth_outer < k {
1097            let k_chunk = min(KC, k - depth_outer);
1098
1099            let a_panel = a.submatrix(0, depth_outer, m, k_chunk);
1100            let b_block = b_panel.submatrix(depth_outer, 0, k_chunk, n_chunk);
1101
1102            let n_job_count = n_chunk.msrv_div_ceil(nr);
1103            let chunk_count = m.msrv_div_ceil(MC);
1104
1105            let job_count = n_job_count * chunk_count;
1106
1107            let job = |idx: usize| {
1108                assert!(all(
1109                    acc.row_stride() == 1,
1110                    a.row_stride() == 1,
1111                    b.row_stride() == 1,
1112                ));
1113
1114                let col_inner = (idx % n_job_count) * nr;
1115                let row_outer = (idx / n_job_count) * MC;
1116                let m_chunk = min(MC, m - row_outer);
1117
1118                let mut row_inner = 0;
1119                let ncols = min(nr, n_chunk - col_inner);
1120                let ukr_j = ncols;
1121
1122                while row_inner < m_chunk {
1123                    let nrows = min(mr, m_chunk - row_inner);
1124
1125                    let ukr_i = nrows / lane_count;
1126
1127                    let a = a_panel.submatrix(row_outer + row_inner, 0, nrows, k_chunk);
1128                    let b = b_block.submatrix(0, col_inner, k_chunk, ncols);
1129                    let acc = acc
1130                        .rb()
1131                        .submatrix(row_outer + row_inner, col_inner, nrows, ncols);
1132                    let acc = unsafe { acc.const_cast() };
1133
1134                    match conj_b {
1135                        Conj::Yes => {
1136                            let conj_b = YesConj;
1137                            if MicroKernelShape::<E>::IS_2X2 {
1138                                match (ukr_i, ukr_j) {
1139                                    (2, 2) => {
1140                                        arch.dispatch(Ukr::<2, 2, _, E> { conj_b, acc, a, b })
1141                                    }
1142                                    (2, 1) => {
1143                                        arch.dispatch(Ukr::<2, 1, _, E> { conj_b, acc, a, b })
1144                                    }
1145                                    (1, 2) => {
1146                                        arch.dispatch(Ukr::<1, 2, _, E> { conj_b, acc, a, b })
1147                                    }
1148                                    (1, 1) => {
1149                                        arch.dispatch(Ukr::<1, 1, _, E> { conj_b, acc, a, b })
1150                                    }
1151                                    _ => unreachable!(),
1152                                }
1153                            } else if MicroKernelShape::<E>::IS_2X1 {
1154                                match (ukr_i, ukr_j) {
1155                                    (2, 1) => {
1156                                        arch.dispatch(Ukr::<2, 1, _, E> { conj_b, acc, a, b })
1157                                    }
1158                                    (1, 1) => {
1159                                        arch.dispatch(Ukr::<1, 1, _, E> { conj_b, acc, a, b })
1160                                    }
1161                                    _ => unreachable!(),
1162                                }
1163                            } else if MicroKernelShape::<E>::IS_1X1 {
1164                                match (ukr_i, ukr_j) {
1165                                    (1, 1) => {
1166                                        arch.dispatch(Ukr::<1, 1, _, E> { conj_b, acc, a, b })
1167                                    }
1168                                    _ => unreachable!(),
1169                                }
1170                            } else {
1171                                unreachable!()
1172                            }
1173                        }
1174                        Conj::No => {
1175                            let conj_b = NoConj;
1176                            if MicroKernelShape::<E>::IS_2X2 {
1177                                match (ukr_i, ukr_j) {
1178                                    (2, 2) => {
1179                                        arch.dispatch(Ukr::<2, 2, _, E> { conj_b, acc, a, b })
1180                                    }
1181                                    (2, 1) => {
1182                                        arch.dispatch(Ukr::<2, 1, _, E> { conj_b, acc, a, b })
1183                                    }
1184                                    (1, 2) => {
1185                                        arch.dispatch(Ukr::<1, 2, _, E> { conj_b, acc, a, b })
1186                                    }
1187                                    (1, 1) => {
1188                                        arch.dispatch(Ukr::<1, 1, _, E> { conj_b, acc, a, b })
1189                                    }
1190                                    _ => unreachable!(),
1191                                }
1192                            } else if MicroKernelShape::<E>::IS_2X1 {
1193                                match (ukr_i, ukr_j) {
1194                                    (2, 1) => {
1195                                        arch.dispatch(Ukr::<2, 1, _, E> { conj_b, acc, a, b })
1196                                    }
1197                                    (1, 1) => {
1198                                        arch.dispatch(Ukr::<1, 1, _, E> { conj_b, acc, a, b })
1199                                    }
1200                                    _ => unreachable!(),
1201                                }
1202                            } else if MicroKernelShape::<E>::IS_1X1 {
1203                                match (ukr_i, ukr_j) {
1204                                    (1, 1) => {
1205                                        arch.dispatch(Ukr::<1, 1, _, E> { conj_b, acc, a, b })
1206                                    }
1207                                    _ => unreachable!(),
1208                                }
1209                            } else {
1210                                unreachable!()
1211                            }
1212                        }
1213                    }
1214                    row_inner += nrows;
1215                }
1216            };
1217
1218            crate::for_each_raw(job_count, job, parallelism);
1219
1220            depth_outer += k_chunk;
1221        }
1222
1223        col_outer += n_chunk;
1224    }
1225}
1226
1227#[doc(hidden)]
1228pub fn matmul_with_conj_gemm_dispatch<E: ComplexField>(
1229    mut acc: MatMut<'_, E>,
1230    lhs: MatRef<'_, E>,
1231    conj_lhs: Conj,
1232    rhs: MatRef<'_, E>,
1233    conj_rhs: Conj,
1234    alpha: Option<E>,
1235    beta: E,
1236    parallelism: Parallelism,
1237    _use_gemm: bool,
1238) {
1239    assert!(all(
1240        acc.nrows() == lhs.nrows(),
1241        acc.ncols() == rhs.ncols(),
1242        lhs.ncols() == rhs.nrows(),
1243    ));
1244
1245    let m = acc.nrows();
1246    let n = acc.ncols();
1247    let k = lhs.ncols();
1248
1249    if m == 0 || n == 0 {
1250        return;
1251    }
1252
1253    if m == 1 && n == 1 {
1254        let mut acc = acc;
1255        let ab = inner_prod::inner_prod_with_conj(lhs.transpose(), conj_lhs, rhs, conj_rhs);
1256        match alpha {
1257            Some(alpha) => {
1258                acc.write(
1259                    0,
1260                    0,
1261                    E::faer_add(acc.read(0, 0).faer_mul(alpha), ab.faer_mul(beta)),
1262                );
1263            }
1264            None => {
1265                acc.write(0, 0, ab.faer_mul(beta));
1266            }
1267        }
1268        return;
1269    }
1270
1271    if k == 1 {
1272        outer_prod::outer_prod_with_conj(
1273            acc,
1274            lhs,
1275            conj_lhs,
1276            rhs.transpose(),
1277            conj_rhs,
1278            alpha,
1279            beta,
1280        );
1281        return;
1282    }
1283    if n == 1 {
1284        matvec::matvec_with_conj(acc, lhs, conj_lhs, rhs, conj_rhs, alpha, beta);
1285        return;
1286    }
1287    if m == 1 {
1288        matvec::matvec_with_conj(
1289            acc.transpose_mut(),
1290            rhs.transpose(),
1291            conj_rhs,
1292            lhs.transpose(),
1293            conj_lhs,
1294            alpha,
1295            beta,
1296        );
1297        return;
1298    }
1299
1300    unsafe {
1301        if m + n < 32 && k <= 6 {
1302            macro_rules! small_gemm {
1303                ($term: expr) => {
1304                    let term = $term;
1305                    match k {
1306                        0 => match alpha {
1307                            Some(alpha) => {
1308                                for i in 0..m {
1309                                    for j in 0..n {
1310                                        acc.write_unchecked(
1311                                            i,
1312                                            j,
1313                                            acc.read_unchecked(i, j).faer_mul(alpha),
1314                                        )
1315                                    }
1316                                }
1317                            }
1318                            None => {
1319                                for i in 0..m {
1320                                    for j in 0..n {
1321                                        acc.write_unchecked(i, j, E::faer_zero())
1322                                    }
1323                                }
1324                            }
1325                        },
1326                        1 => match alpha {
1327                            Some(alpha) => {
1328                                for i in 0..m {
1329                                    for j in 0..n {
1330                                        let dot = term(i, j, 0);
1331                                        acc.write_unchecked(
1332                                            i,
1333                                            j,
1334                                            E::faer_add(
1335                                                acc.read_unchecked(i, j).faer_mul(alpha),
1336                                                dot.faer_mul(beta),
1337                                            ),
1338                                        )
1339                                    }
1340                                }
1341                            }
1342                            None => {
1343                                for i in 0..m {
1344                                    for j in 0..n {
1345                                        let dot = term(i, j, 0);
1346                                        acc.write_unchecked(i, j, dot.faer_mul(beta))
1347                                    }
1348                                }
1349                            }
1350                        },
1351                        2 => match alpha {
1352                            Some(alpha) => {
1353                                for i in 0..m {
1354                                    for j in 0..n {
1355                                        let dot = term(i, j, 0).faer_add(term(i, j, 1));
1356                                        acc.write_unchecked(
1357                                            i,
1358                                            j,
1359                                            E::faer_add(
1360                                                acc.read_unchecked(i, j).faer_mul(alpha),
1361                                                dot.faer_mul(beta),
1362                                            ),
1363                                        )
1364                                    }
1365                                }
1366                            }
1367                            None => {
1368                                for i in 0..m {
1369                                    for j in 0..n {
1370                                        let dot = term(i, j, 0).faer_add(term(i, j, 1));
1371                                        acc.write_unchecked(i, j, dot.faer_mul(beta))
1372                                    }
1373                                }
1374                            }
1375                        },
1376                        3 => match alpha {
1377                            Some(alpha) => {
1378                                for i in 0..m {
1379                                    for j in 0..n {
1380                                        let dot = term(i, j, 0)
1381                                            .faer_add(term(i, j, 1))
1382                                            .faer_add(term(i, j, 2));
1383                                        acc.write_unchecked(
1384                                            i,
1385                                            j,
1386                                            E::faer_add(
1387                                                acc.read_unchecked(i, j).faer_mul(alpha),
1388                                                dot.faer_mul(beta),
1389                                            ),
1390                                        )
1391                                    }
1392                                }
1393                            }
1394                            None => {
1395                                for i in 0..m {
1396                                    for j in 0..n {
1397                                        let dot = term(i, j, 0)
1398                                            .faer_add(term(i, j, 1))
1399                                            .faer_add(term(i, j, 2));
1400                                        acc.write_unchecked(i, j, dot.faer_mul(beta))
1401                                    }
1402                                }
1403                            }
1404                        },
1405                        4 => match alpha {
1406                            Some(alpha) => {
1407                                for i in 0..m {
1408                                    for j in 0..n {
1409                                        let dot = E::faer_add(
1410                                            E::faer_add(term(i, j, 0), term(i, j, 1)),
1411                                            E::faer_add(term(i, j, 2), term(i, j, 3)),
1412                                        );
1413
1414                                        acc.write_unchecked(
1415                                            i,
1416                                            j,
1417                                            E::faer_add(
1418                                                acc.read_unchecked(i, j).faer_mul(alpha),
1419                                                dot.faer_mul(beta),
1420                                            ),
1421                                        )
1422                                    }
1423                                }
1424                            }
1425                            None => {
1426                                for i in 0..m {
1427                                    for j in 0..n {
1428                                        let dot = E::faer_add(
1429                                            E::faer_add(term(i, j, 0), term(i, j, 1)),
1430                                            E::faer_add(term(i, j, 2), term(i, j, 3)),
1431                                        );
1432                                        acc.write_unchecked(i, j, dot.faer_mul(beta))
1433                                    }
1434                                }
1435                            }
1436                        },
1437                        5 => match alpha {
1438                            Some(alpha) => {
1439                                for i in 0..m {
1440                                    for j in 0..n {
1441                                        let dot = E::faer_add(
1442                                            E::faer_add(term(i, j, 0), term(i, j, 1))
1443                                                .faer_add(term(i, j, 2)),
1444                                            E::faer_add(term(i, j, 3), term(i, j, 4)),
1445                                        );
1446
1447                                        acc.write_unchecked(
1448                                            i,
1449                                            j,
1450                                            E::faer_add(
1451                                                acc.read_unchecked(i, j).faer_mul(alpha),
1452                                                dot.faer_mul(beta),
1453                                            ),
1454                                        )
1455                                    }
1456                                }
1457                            }
1458                            None => {
1459                                for i in 0..m {
1460                                    for j in 0..n {
1461                                        let dot = E::faer_add(
1462                                            E::faer_add(term(i, j, 0), term(i, j, 1))
1463                                                .faer_add(term(i, j, 2)),
1464                                            E::faer_add(term(i, j, 3), term(i, j, 4)),
1465                                        );
1466                                        acc.write_unchecked(i, j, dot.faer_mul(beta))
1467                                    }
1468                                }
1469                            }
1470                        },
1471                        6 => match alpha {
1472                            Some(alpha) => {
1473                                for i in 0..m {
1474                                    for j in 0..n {
1475                                        let dot = E::faer_add(
1476                                            E::faer_add(term(i, j, 0), term(i, j, 1))
1477                                                .faer_add(term(i, j, 2)),
1478                                            E::faer_add(term(i, j, 3), term(i, j, 4))
1479                                                .faer_add(term(i, j, 5)),
1480                                        );
1481
1482                                        acc.write_unchecked(
1483                                            i,
1484                                            j,
1485                                            E::faer_add(
1486                                                acc.read_unchecked(i, j).faer_mul(alpha),
1487                                                dot.faer_mul(beta),
1488                                            ),
1489                                        )
1490                                    }
1491                                }
1492                            }
1493                            None => {
1494                                for i in 0..m {
1495                                    for j in 0..n {
1496                                        let dot = E::faer_add(
1497                                            E::faer_add(term(i, j, 0), term(i, j, 1))
1498                                                .faer_add(term(i, j, 2)),
1499                                            E::faer_add(term(i, j, 3), term(i, j, 4))
1500                                                .faer_add(term(i, j, 5)),
1501                                        );
1502                                        acc.write_unchecked(i, j, dot.faer_mul(beta))
1503                                    }
1504                                }
1505                            }
1506                        },
1507                        _ => unreachable!(),
1508                    }
1509                };
1510            }
1511
1512            match (conj_lhs, conj_rhs) {
1513                (Conj::Yes, Conj::Yes) => {
1514                    let term = {
1515                        #[inline(always)]
1516                        |i, j, depth| {
1517                            (lhs.read_unchecked(i, depth)
1518                                .faer_mul(rhs.read_unchecked(depth, j)))
1519                            .faer_conj()
1520                        }
1521                    };
1522                    small_gemm!(term);
1523                }
1524                (Conj::Yes, Conj::No) => {
1525                    let term = {
1526                        #[inline(always)]
1527                        |i, j, depth| {
1528                            lhs.read_unchecked(i, depth)
1529                                .faer_conj()
1530                                .faer_mul(rhs.read_unchecked(depth, j))
1531                        }
1532                    };
1533                    small_gemm!(term);
1534                }
1535                (Conj::No, Conj::Yes) => {
1536                    let term = {
1537                        #[inline(always)]
1538                        |i, j, depth| {
1539                            lhs.read_unchecked(i, depth)
1540                                .faer_mul(rhs.read_unchecked(depth, j).faer_conj())
1541                        }
1542                    };
1543                    small_gemm!(term);
1544                }
1545                (Conj::No, Conj::No) => {
1546                    let term = {
1547                        #[inline(always)]
1548                        |i, j, depth| {
1549                            lhs.read_unchecked(i, depth)
1550                                .faer_mul(rhs.read_unchecked(depth, j))
1551                        }
1552                    };
1553                    small_gemm!(term);
1554                }
1555            }
1556            return;
1557        }
1558    }
1559
1560    #[cfg(not(test))]
1561    let _use_gemm = true;
1562
1563    if _use_gemm {
1564        let gemm_parallelism = match parallelism {
1565            Parallelism::None => gemm::Parallelism::None,
1566            #[cfg(feature = "rayon")]
1567            Parallelism::Rayon(0) => gemm::Parallelism::Rayon(rayon::current_num_threads()),
1568            #[cfg(feature = "rayon")]
1569            Parallelism::Rayon(n_threads) => gemm::Parallelism::Rayon(n_threads),
1570        };
1571        if coe::is_same::<f32, E>() {
1572            let mut acc: MatMut<'_, f32> = coe::coerce(acc);
1573            let a: MatRef<'_, f32> = coe::coerce(lhs);
1574            let b: MatRef<'_, f32> = coe::coerce(rhs);
1575            let alpha: Option<f32> = coe::coerce_static(alpha);
1576            let beta: f32 = coe::coerce_static(beta);
1577            unsafe {
1578                gemm::gemm(
1579                    m,
1580                    n,
1581                    k,
1582                    acc.rb_mut().as_ptr_mut(),
1583                    acc.col_stride(),
1584                    acc.row_stride(),
1585                    alpha.is_some(),
1586                    a.as_ptr(),
1587                    a.col_stride(),
1588                    a.row_stride(),
1589                    b.as_ptr(),
1590                    b.col_stride(),
1591                    b.row_stride(),
1592                    alpha.unwrap_or(0.0),
1593                    beta,
1594                    false,
1595                    conj_lhs == Conj::Yes,
1596                    conj_rhs == Conj::Yes,
1597                    gemm_parallelism,
1598                )
1599            };
1600            return;
1601        }
1602        if coe::is_same::<f64, E>() {
1603            let mut acc: MatMut<'_, f64> = coe::coerce(acc);
1604            let a: MatRef<'_, f64> = coe::coerce(lhs);
1605            let b: MatRef<'_, f64> = coe::coerce(rhs);
1606            let alpha: Option<f64> = coe::coerce_static(alpha);
1607            let beta: f64 = coe::coerce_static(beta);
1608            unsafe {
1609                gemm::gemm(
1610                    m,
1611                    n,
1612                    k,
1613                    acc.rb_mut().as_ptr_mut(),
1614                    acc.col_stride(),
1615                    acc.row_stride(),
1616                    alpha.is_some(),
1617                    a.as_ptr(),
1618                    a.col_stride(),
1619                    a.row_stride(),
1620                    b.as_ptr(),
1621                    b.col_stride(),
1622                    b.row_stride(),
1623                    alpha.unwrap_or(0.0),
1624                    beta,
1625                    false,
1626                    conj_lhs == Conj::Yes,
1627                    conj_rhs == Conj::Yes,
1628                    gemm_parallelism,
1629                )
1630            };
1631            return;
1632        }
1633        if coe::is_same::<c32, E>() {
1634            let mut acc: MatMut<'_, c32> = coe::coerce(acc);
1635            let a: MatRef<'_, c32> = coe::coerce(lhs);
1636            let b: MatRef<'_, c32> = coe::coerce(rhs);
1637            let alpha: Option<c32> = coe::coerce_static(alpha);
1638            let beta: c32 = coe::coerce_static(beta);
1639            unsafe {
1640                gemm::gemm(
1641                    m,
1642                    n,
1643                    k,
1644                    acc.rb_mut().as_ptr_mut() as *mut gemm::c32,
1645                    acc.col_stride(),
1646                    acc.row_stride(),
1647                    alpha.is_some(),
1648                    a.as_ptr() as *const gemm::c32,
1649                    a.col_stride(),
1650                    a.row_stride(),
1651                    b.as_ptr() as *const gemm::c32,
1652                    b.col_stride(),
1653                    b.row_stride(),
1654                    alpha.unwrap_or(c32 { re: 0.0, im: 0.0 }).into(),
1655                    beta.into(),
1656                    false,
1657                    conj_lhs == Conj::Yes,
1658                    conj_rhs == Conj::Yes,
1659                    gemm_parallelism,
1660                )
1661            };
1662            return;
1663        }
1664        if coe::is_same::<c64, E>() {
1665            let mut acc: MatMut<'_, c64> = coe::coerce(acc);
1666            let a: MatRef<'_, c64> = coe::coerce(lhs);
1667            let b: MatRef<'_, c64> = coe::coerce(rhs);
1668            let alpha: Option<c64> = coe::coerce_static(alpha);
1669            let beta: c64 = coe::coerce_static(beta);
1670            unsafe {
1671                gemm::gemm(
1672                    m,
1673                    n,
1674                    k,
1675                    acc.rb_mut().as_ptr_mut() as *mut gemm::c64,
1676                    acc.col_stride(),
1677                    acc.row_stride(),
1678                    alpha.is_some(),
1679                    a.as_ptr() as *const gemm::c64,
1680                    a.col_stride(),
1681                    a.row_stride(),
1682                    b.as_ptr() as *const gemm::c64,
1683                    b.col_stride(),
1684                    b.row_stride(),
1685                    alpha.unwrap_or(c64 { re: 0.0, im: 0.0 }).into(),
1686                    beta.into(),
1687                    false,
1688                    conj_lhs == Conj::Yes,
1689                    conj_rhs == Conj::Yes,
1690                    gemm_parallelism,
1691                )
1692            };
1693            return;
1694        }
1695    }
1696
1697    let arch = E::Simd::default();
1698    let lane_count = arch.dispatch(SimdLaneCount::<E> {
1699        __marker: PhantomData,
1700    });
1701
1702    let mut a = lhs;
1703    let mut b = rhs;
1704    let mut conj_a = conj_lhs;
1705    let mut conj_b = conj_rhs;
1706
1707    if n < m {
1708        (a, b) = (b.transpose(), a.transpose());
1709        core::mem::swap(&mut conj_a, &mut conj_b);
1710        acc = acc.transpose_mut();
1711    }
1712
1713    if b.row_stride() < 0 {
1714        a = a.reverse_cols();
1715        b = b.reverse_rows();
1716    }
1717
1718    let m = acc.nrows();
1719    let n = acc.ncols();
1720
1721    let padded_m = m.msrv_checked_next_multiple_of(lane_count).unwrap();
1722
1723    let mut a_copy = a.to_owned();
1724    a_copy.resize_with(padded_m, k, |_, _| E::faer_zero());
1725    let a_copy = a_copy.as_ref();
1726    let mut tmp = crate::Mat::<E>::zeros(padded_m, n);
1727    let tmp_conj_b = match (conj_a, conj_b) {
1728        (Conj::Yes, Conj::Yes) | (Conj::No, Conj::No) => Conj::No,
1729        (Conj::Yes, Conj::No) | (Conj::No, Conj::Yes) => Conj::Yes,
1730    };
1731    if b.row_stride() == 1 {
1732        matmul_with_conj_impl(tmp.as_mut(), a_copy, b, tmp_conj_b, parallelism);
1733    } else {
1734        let b = b.to_owned();
1735        matmul_with_conj_impl(tmp.as_mut(), a_copy, b.as_ref(), tmp_conj_b, parallelism);
1736    }
1737
1738    let tmp = tmp.as_ref().subrows(0, m);
1739
1740    match alpha {
1741        Some(alpha) => match conj_a {
1742            Conj::Yes => zipped!(acc, tmp).for_each(|unzipped!(mut acc, tmp)| {
1743                acc.write(E::faer_add(
1744                    acc.read().faer_mul(alpha),
1745                    tmp.read().faer_conj().faer_mul(beta),
1746                ))
1747            }),
1748            Conj::No => zipped!(acc, tmp).for_each(|unzipped!(mut acc, tmp)| {
1749                acc.write(E::faer_add(
1750                    acc.read().faer_mul(alpha),
1751                    tmp.read().faer_mul(beta),
1752                ))
1753            }),
1754        },
1755        None => match conj_a {
1756            Conj::Yes => {
1757                zipped!(acc, tmp).for_each(|unzipped!(mut acc, tmp)| {
1758                    acc.write(tmp.read().faer_conj().faer_mul(beta))
1759                });
1760            }
1761            Conj::No => {
1762                zipped!(acc, tmp)
1763                    .for_each(|unzipped!(mut acc, tmp)| acc.write(tmp.read().faer_mul(beta)));
1764            }
1765        },
1766    }
1767}
1768
/// Computes the matrix product `[alpha * acc] + beta * lhs * rhs` (while optionally conjugating
/// either or both of the input matrices) and stores the result in `acc`.
///
/// Performs the operation:
/// - `acc = beta * Op_lhs(lhs) * Op_rhs(rhs)` if `alpha` is `None` (in this case, the preexisting
///   values in `acc` are not read, so it is allowed to be a view over uninitialized values if
///   `E: Copy`),
/// - `acc = alpha * acc + beta * Op_lhs(lhs) * Op_rhs(rhs)` if `alpha` is `Some(_)`,
///
/// `Op_lhs` is the identity if `conj_lhs` is `Conj::No`, and the conjugation operation if it is
/// `Conj::Yes`.  
/// `Op_rhs` is the identity if `conj_rhs` is `Conj::No`, and the conjugation operation if it is
/// `Conj::Yes`.  
///
/// # Panics
///
/// Panics if the matrix dimensions are not compatible for matrix multiplication.  
/// All of the following must hold, otherwise the function panics:  
///  - `acc.nrows() == lhs.nrows()`
///  - `acc.ncols() == rhs.ncols()`
///  - `lhs.ncols() == rhs.nrows()`
///
/// # Example
///
/// ```
/// use faer_core::{mat, mul::matmul_with_conj, unzipped, zipped, Conj, Mat, Parallelism};
///
/// let lhs = mat![[0.0, 2.0], [1.0, 3.0]];
/// let rhs = mat![[4.0, 6.0], [5.0, 7.0]];
///
/// let mut acc = Mat::<f64>::zeros(2, 2);
/// let target = mat![
///     [
///         2.5 * (lhs.read(0, 0) * rhs.read(0, 0) + lhs.read(0, 1) * rhs.read(1, 0)),
///         2.5 * (lhs.read(0, 0) * rhs.read(0, 1) + lhs.read(0, 1) * rhs.read(1, 1)),
///     ],
///     [
///         2.5 * (lhs.read(1, 0) * rhs.read(0, 0) + lhs.read(1, 1) * rhs.read(1, 0)),
///         2.5 * (lhs.read(1, 0) * rhs.read(0, 1) + lhs.read(1, 1) * rhs.read(1, 1)),
///     ],
/// ];
///
/// matmul_with_conj(
///     acc.as_mut(),
///     lhs.as_ref(),
///     Conj::No,
///     rhs.as_ref(),
///     Conj::No,
///     None,
///     2.5,
///     Parallelism::None,
/// );
///
/// zipped!(acc.as_ref(), target.as_ref())
///     .for_each(|unzipped!(acc, target)| assert!((acc.read() - target.read()).abs() < 1e-10));
/// ```
#[inline]
#[track_caller]
pub fn matmul_with_conj<E: ComplexField>(
    acc: MatMut<'_, E>,
    lhs: MatRef<'_, E>,
    conj_lhs: Conj,
    rhs: MatRef<'_, E>,
    conj_rhs: Conj,
    alpha: Option<E>,
    beta: E,
    parallelism: Parallelism,
) {
    // Shape validation happens here, once, before any work is dispatched.
    assert!(all(
        acc.nrows() == lhs.nrows(),
        acc.ncols() == rhs.ncols(),
        lhs.ncols() == rhs.nrows(),
    ));
    // All arguments are forwarded unchanged; the trailing `true` is a flag of the dispatch
    // routine. NOTE(review): presumably it enables the `gemm` backend path - confirm against
    // the definition of `matmul_with_conj_gemm_dispatch`.
    matmul_with_conj_gemm_dispatch(
        acc,
        lhs,
        conj_lhs,
        rhs,
        conj_rhs,
        alpha,
        beta,
        parallelism,
        true,
    );
}
1854
1855/// Computes the matrix product `[alpha * acc] + beta * lhs * rhs` and
1856/// stores the result in `acc`.
1857///
1858/// Performs the operation:
1859/// - `acc = beta * lhs * rhs` if `alpha` is `None` (in this case, the preexisting values in `acc`
1860///   are not read, so it is allowed to be a view over uninitialized values if `E: Copy`),
1861/// - `acc = alpha * acc + beta * lhs * rhs` if `alpha` is `Some(_)`,
1862///
1863/// # Panics
1864///
1865/// Panics if the matrix dimensions are not compatible for matrix multiplication.  
1866/// i.e.  
1867///  - `acc.nrows() == lhs.nrows()`
1868///  - `acc.ncols() == rhs.ncols()`
1869///  - `lhs.ncols() == rhs.nrows()`
1870///
1871/// # Example
1872///
1873/// ```
1874/// use faer_core::{mat, mul::matmul, unzipped, zipped, Mat, Parallelism};
1875///
1876/// let lhs = mat![[0.0, 2.0], [1.0, 3.0]];
1877/// let rhs = mat![[4.0, 6.0], [5.0, 7.0]];
1878///
1879/// let mut acc = Mat::<f64>::zeros(2, 2);
1880/// let target = mat![
1881///     [
1882///         2.5 * (lhs.read(0, 0) * rhs.read(0, 0) + lhs.read(0, 1) * rhs.read(1, 0)),
1883///         2.5 * (lhs.read(0, 0) * rhs.read(0, 1) + lhs.read(0, 1) * rhs.read(1, 1)),
1884///     ],
1885///     [
1886///         2.5 * (lhs.read(1, 0) * rhs.read(0, 0) + lhs.read(1, 1) * rhs.read(1, 0)),
1887///         2.5 * (lhs.read(1, 0) * rhs.read(0, 1) + lhs.read(1, 1) * rhs.read(1, 1)),
1888///     ],
1889/// ];
1890///
1891/// matmul(
1892///     acc.as_mut(),
1893///     lhs.as_ref(),
1894///     rhs.as_ref(),
1895///     None,
1896///     2.5,
1897///     Parallelism::None,
1898/// );
1899///
1900/// zipped!(acc.as_ref(), target.as_ref())
1901///     .for_each(|unzipped!(acc, target)| assert!((acc.read() - target.read()).abs() < 1e-10));
1902/// ```
1903#[track_caller]
1904pub fn matmul<E: ComplexField, LhsE: Conjugate<Canonical = E>, RhsE: Conjugate<Canonical = E>>(
1905    acc: MatMut<'_, E>,
1906    lhs: MatRef<'_, LhsE>,
1907    rhs: MatRef<'_, RhsE>,
1908    alpha: Option<E>,
1909    beta: E,
1910    parallelism: Parallelism,
1911) {
1912    let (lhs, conj_lhs) = lhs.canonicalize();
1913    let (rhs, conj_rhs) = rhs.canonicalize();
1914    matmul_with_conj::<E>(acc, lhs, conj_lhs, rhs, conj_rhs, alpha, beta, parallelism);
1915}
1916
// Declares, in the caller's scope, a mutable matrix view named `$name` with `$nrows` rows and
// `$ncols` columns, backed by zero-filled local storage of `16 * 16` units per unit-component
// of `$ty`.
//
// `$rs` / `$cs` are the row / column strides of some reference matrix; they are only used to
// decide whether the fresh view gets transposed and/or reversed (see the end of the macro).
//
// The macro performs no check that `$nrows` and `$ncols` fit in the 16x16 storage; the
// expansion calls `from_raw_parts_mut` inside an `unsafe` block, so callers are responsible for
// passing dimensions that fit (the callers visible in this file guard with `n <= 16`).
macro_rules! stack_mat_16x16_begin {
    ($name: ident, $nrows: expr, $ncols: expr, $rs: expr, $cs: expr, $ty: ty) => {
        // Evaluate every argument exactly once and pin its type.
        let __nrows: usize = $nrows;
        let __ncols: usize = $ncols;
        let __rs: isize = $rs;
        let __cs: isize = $cs;
        // One uninitialized `[MaybeUninit<Unit>; 16 * 16]` buffer per unit-component of `$ty`.
        let mut __data = <$ty as $crate::Entity>::faer_map(
            <$ty as $crate::Entity>::UNIT,
            #[inline(always)]
            |()| unsafe {
                $crate::transmute_unchecked::<
                    ::core::mem::MaybeUninit<[<$ty as $crate::Entity>::Unit; 16 * 16]>,
                    [::core::mem::MaybeUninit<<$ty as $crate::Entity>::Unit>; 16 * 16],
                >(::core::mem::MaybeUninit::<
                    [<$ty as $crate::Entity>::Unit; 16 * 16],
                >::uninit())
            },
        );

        // Initialize every slot of every buffer with the matching unit of `faer_zero()`.
        <$ty as $crate::Entity>::faer_map(
            <$ty as $crate::Entity>::faer_zip(
                <$ty as $crate::Entity>::faer_as_mut(&mut __data),
                <$ty as $crate::Entity>::faer_into_units(<$ty as $crate::ComplexField>::faer_zero()),
            ),
            #[inline(always)]
            |(__data, zero)| {
                let __data: &mut _ = __data;
                for __data in __data {
                    let __data : &mut _ = __data;
                    *__data = ::core::mem::MaybeUninit::new(::core::clone::Clone::clone(&zero));
                }
            },
        );
        // Shadow the buffers with raw pointers to their first unit (all slots were written
        // above, so the `MaybeUninit` wrapper is cast away).
        let mut __data =
            <$ty as $crate::Entity>::faer_map(<$ty as $crate::Entity>::faer_as_mut(&mut __data), |__data: &mut _| {
                (__data as *mut [::core::mem::MaybeUninit<<$ty as $crate::Entity>::Unit>; 16 * 16]
                    as *mut <$ty as $crate::Entity>::Unit)
            });

        // Build the view over the buffers with the stride arguments `1` and `16`.
        // NOTE(review): presumably these are (row_stride, col_stride), i.e. a column-major
        // layout - confirm against `mat::from_raw_parts_mut`.
        let mut $name = unsafe {
            $crate::mat::from_raw_parts_mut::<'_, $ty>(__data, __nrows, __ncols, 1isize, 16isize)
        };

        // Adjust the view based on the reference strides: transpose when the reference column
        // stride is smaller in magnitude than its row stride, and reverse an axis whose
        // reference stride is exactly -1.
        // NOTE(review): the transpose is applied to a view created as `__nrows x __ncols`; every
        // invocation visible in this file passes equal row and column counts.
        if __cs.unsigned_abs() < __rs.unsigned_abs() {
            $name = $name.transpose_mut();
        }
        if __rs == -1 {
            $name = $name.reverse_rows_mut();
        }
        if __cs == -1 {
            $name = $name.reverse_cols_mut();
        }
    };
}
1971
1972/// Triangular matrix multiplication module, where some of the operands are treated as triangular
1973/// matrices.
1974pub mod triangular {
1975    use super::*;
1976    use crate::{assert, debug_assert, join_raw, zip::Diag};
1977
    /// Describes how the diagonal of a triangular operand is treated when the operand is
    /// materialized (see `copy_lower`).
    #[repr(u8)]
    #[derive(Copy, Clone, Debug)]
    pub(crate) enum DiagonalKind {
        /// The diagonal is taken to be all zeros; the stored diagonal values are not copied.
        Zero,
        /// The diagonal is taken to be all ones; the stored diagonal values are not copied.
        Unit,
        /// The diagonal values stored in the matrix are used as they are.
        Generic,
    }
1985
1986    unsafe fn copy_lower<E: ComplexField>(
1987        mut dst: MatMut<'_, E>,
1988        src: MatRef<'_, E>,
1989        src_diag: DiagonalKind,
1990    ) {
1991        let n = dst.nrows();
1992        debug_assert!(n == dst.nrows());
1993        debug_assert!(n == dst.ncols());
1994        debug_assert!(n == src.nrows());
1995        debug_assert!(n == src.ncols());
1996
1997        let strict = match src_diag {
1998            DiagonalKind::Zero => {
1999                for j in 0..n {
2000                    dst.write_unchecked(j, j, E::faer_zero());
2001                }
2002                true
2003            }
2004            DiagonalKind::Unit => {
2005                for j in 0..n {
2006                    dst.write_unchecked(j, j, E::faer_one());
2007                }
2008                true
2009            }
2010            DiagonalKind::Generic => false,
2011        };
2012
2013        zipped!(dst.rb_mut())
2014            .for_each_triangular_upper(Diag::Skip, |unzipped!(mut dst)| dst.write(E::faer_zero()));
2015        zipped!(dst, src).for_each_triangular_lower(
2016            if strict { Diag::Skip } else { Diag::Include },
2017            |unzipped!(mut dst, src)| dst.write(src.read()),
2018        );
2019    }
2020
2021    unsafe fn accum_lower<E: ComplexField>(
2022        dst: MatMut<'_, E>,
2023        src: MatRef<'_, E>,
2024        skip_diag: bool,
2025        alpha: Option<E>,
2026    ) {
2027        let n = dst.nrows();
2028        debug_assert!(n == dst.nrows());
2029        debug_assert!(n == dst.ncols());
2030        debug_assert!(n == src.nrows());
2031        debug_assert!(n == src.ncols());
2032
2033        match alpha {
2034            Some(alpha) => {
2035                zipped!(dst, src).for_each_triangular_lower(
2036                    if skip_diag { Diag::Skip } else { Diag::Include },
2037                    |unzipped!(mut dst, src)| {
2038                        dst.write(alpha.faer_mul(dst.read().faer_add(src.read())))
2039                    },
2040                );
2041            }
2042            None => {
2043                zipped!(dst, src).for_each_triangular_lower(
2044                    if skip_diag { Diag::Skip } else { Diag::Include },
2045                    |unzipped!(mut dst, src)| dst.write(src.read()),
2046                );
2047            }
2048        }
2049    }
2050
2051    #[inline]
2052    unsafe fn copy_upper<E: ComplexField>(
2053        dst: MatMut<'_, E>,
2054        src: MatRef<'_, E>,
2055        src_diag: DiagonalKind,
2056    ) {
2057        copy_lower(dst.transpose_mut(), src.transpose(), src_diag)
2058    }
2059
2060    #[inline]
2061    unsafe fn mul<E: ComplexField>(
2062        dst: MatMut<'_, E>,
2063        lhs: MatRef<'_, E>,
2064        rhs: MatRef<'_, E>,
2065        alpha: Option<E>,
2066        beta: E,
2067        conj_lhs: Conj,
2068        conj_rhs: Conj,
2069        parallelism: Parallelism,
2070    ) {
2071        super::matmul_with_conj(dst, lhs, conj_lhs, rhs, conj_rhs, alpha, beta, parallelism);
2072    }
2073
2074    unsafe fn mat_x_lower_into_lower_impl_unchecked<E: ComplexField>(
2075        dst: MatMut<'_, E>,
2076        skip_diag: bool,
2077        lhs: MatRef<'_, E>,
2078        rhs: MatRef<'_, E>,
2079        rhs_diag: DiagonalKind,
2080        alpha: Option<E>,
2081        beta: E,
2082        conj_lhs: Conj,
2083        conj_rhs: Conj,
2084        parallelism: Parallelism,
2085    ) {
2086        let n = dst.nrows();
2087        debug_assert!(n == dst.nrows());
2088        debug_assert!(n == dst.ncols());
2089        debug_assert!(n == lhs.nrows());
2090        debug_assert!(n == lhs.ncols());
2091        debug_assert!(n == rhs.nrows());
2092        debug_assert!(n == rhs.ncols());
2093
2094        if n <= 16 {
2095            let op = {
2096                #[inline(never)]
2097                || {
2098                    stack_mat_16x16_begin!(temp_dst, n, n, dst.row_stride(), dst.col_stride(), E);
2099                    stack_mat_16x16_begin!(temp_rhs, n, n, rhs.row_stride(), rhs.col_stride(), E);
2100
2101                    copy_lower(temp_rhs.rb_mut(), rhs, rhs_diag);
2102                    mul(
2103                        temp_dst.rb_mut(),
2104                        lhs,
2105                        temp_rhs.rb(),
2106                        None,
2107                        beta,
2108                        conj_lhs,
2109                        conj_rhs,
2110                        parallelism,
2111                    );
2112                    accum_lower(dst, temp_dst.rb(), skip_diag, alpha);
2113                }
2114            };
2115            op();
2116        } else {
2117            let bs = n / 2;
2118
2119            let (mut dst_top_left, _, mut dst_bot_left, dst_bot_right) = dst.split_at_mut(bs, bs);
2120            let (lhs_top_left, lhs_top_right, lhs_bot_left, lhs_bot_right) = lhs.split_at(bs, bs);
2121            let (rhs_top_left, _, rhs_bot_left, rhs_bot_right) = rhs.split_at(bs, bs);
2122
2123            // lhs_bot_right × rhs_bot_left  => dst_bot_left  | mat × mat => mat |   1
2124            // lhs_bot_right × rhs_bot_right => dst_bot_right | mat × low => low |   X
2125            //
2126            // lhs_top_left  × rhs_top_left  => dst_top_left  | mat × low => low |   X
2127            // lhs_top_right × rhs_bot_left  => dst_top_left  | mat × mat => low | 1/2
2128            // lhs_bot_left  × rhs_top_left  => dst_bot_left  | mat × low => mat | 1/2
2129
2130            mul(
2131                dst_bot_left.rb_mut(),
2132                lhs_bot_right,
2133                rhs_bot_left,
2134                alpha,
2135                beta,
2136                conj_lhs,
2137                conj_rhs,
2138                parallelism,
2139            );
2140            mat_x_lower_into_lower_impl_unchecked(
2141                dst_bot_right,
2142                skip_diag,
2143                lhs_bot_right,
2144                rhs_bot_right,
2145                rhs_diag,
2146                alpha,
2147                beta,
2148                conj_lhs,
2149                conj_rhs,
2150                parallelism,
2151            );
2152
2153            mat_x_lower_into_lower_impl_unchecked(
2154                dst_top_left.rb_mut(),
2155                skip_diag,
2156                lhs_top_left,
2157                rhs_top_left,
2158                rhs_diag,
2159                alpha,
2160                beta,
2161                conj_lhs,
2162                conj_rhs,
2163                parallelism,
2164            );
2165            mat_x_mat_into_lower_impl_unchecked(
2166                dst_top_left,
2167                skip_diag,
2168                lhs_top_right,
2169                rhs_bot_left,
2170                Some(E::faer_one()),
2171                beta,
2172                conj_lhs,
2173                conj_rhs,
2174                parallelism,
2175            );
2176            mat_x_lower_impl_unchecked(
2177                dst_bot_left,
2178                lhs_bot_left,
2179                rhs_top_left,
2180                rhs_diag,
2181                Some(E::faer_one()),
2182                beta,
2183                conj_lhs,
2184                conj_rhs,
2185                parallelism,
2186            );
2187        }
2188    }
2189
2190    unsafe fn mat_x_lower_impl_unchecked<E: ComplexField>(
2191        dst: MatMut<'_, E>,
2192        lhs: MatRef<'_, E>,
2193        rhs: MatRef<'_, E>,
2194        rhs_diag: DiagonalKind,
2195        alpha: Option<E>,
2196        beta: E,
2197        conj_lhs: Conj,
2198        conj_rhs: Conj,
2199        parallelism: Parallelism,
2200    ) {
2201        let n = rhs.nrows();
2202        let m = lhs.nrows();
2203        debug_assert!(m == lhs.nrows());
2204        debug_assert!(n == lhs.ncols());
2205        debug_assert!(n == rhs.nrows());
2206        debug_assert!(n == rhs.ncols());
2207        debug_assert!(m == dst.nrows());
2208        debug_assert!(n == dst.ncols());
2209
2210        let join_parallelism = if n * n * m < 128 * 128 * 64 {
2211            Parallelism::None
2212        } else {
2213            parallelism
2214        };
2215
2216        if n <= 16 {
2217            let op = {
2218                #[inline(never)]
2219                || {
2220                    stack_mat_16x16_begin!(temp_rhs, n, n, rhs.row_stride(), rhs.col_stride(), E);
2221
2222                    copy_lower(temp_rhs.rb_mut(), rhs, rhs_diag);
2223
2224                    mul(
2225                        dst,
2226                        lhs,
2227                        temp_rhs.rb(),
2228                        alpha,
2229                        beta,
2230                        conj_lhs,
2231                        conj_rhs,
2232                        parallelism,
2233                    );
2234                }
2235            };
2236            op();
2237        } else {
2238            // split rhs into 3 sections
2239            // split lhs and dst into 2 sections
2240
2241            let bs = n / 2;
2242
2243            let (rhs_top_left, _, rhs_bot_left, rhs_bot_right) = rhs.split_at(bs, bs);
2244            let (lhs_left, lhs_right) = lhs.split_at_col(bs);
2245            let (mut dst_left, mut dst_right) = dst.split_at_col_mut(bs);
2246
2247            join_raw(
2248                |parallelism| {
2249                    mat_x_lower_impl_unchecked(
2250                        dst_left.rb_mut(),
2251                        lhs_left,
2252                        rhs_top_left,
2253                        rhs_diag,
2254                        alpha,
2255                        beta,
2256                        conj_lhs,
2257                        conj_rhs,
2258                        parallelism,
2259                    )
2260                },
2261                |parallelism| {
2262                    mat_x_lower_impl_unchecked(
2263                        dst_right.rb_mut(),
2264                        lhs_right,
2265                        rhs_bot_right,
2266                        rhs_diag,
2267                        alpha,
2268                        beta,
2269                        conj_lhs,
2270                        conj_rhs,
2271                        parallelism,
2272                    )
2273                },
2274                join_parallelism,
2275            );
2276            mul(
2277                dst_left,
2278                lhs_right,
2279                rhs_bot_left,
2280                Some(E::faer_one()),
2281                beta,
2282                conj_lhs,
2283                conj_rhs,
2284                parallelism,
2285            );
2286        }
2287    }
2288
2289    unsafe fn lower_x_lower_into_lower_impl_unchecked<E: ComplexField>(
2290        dst: MatMut<'_, E>,
2291        skip_diag: bool,
2292        lhs: MatRef<'_, E>,
2293        lhs_diag: DiagonalKind,
2294        rhs: MatRef<'_, E>,
2295        rhs_diag: DiagonalKind,
2296        alpha: Option<E>,
2297        beta: E,
2298        conj_lhs: Conj,
2299        conj_rhs: Conj,
2300        parallelism: Parallelism,
2301    ) {
2302        let n = dst.nrows();
2303        debug_assert!(n == lhs.nrows());
2304        debug_assert!(n == lhs.ncols());
2305        debug_assert!(n == rhs.nrows());
2306        debug_assert!(n == rhs.ncols());
2307        debug_assert!(n == dst.nrows());
2308        debug_assert!(n == dst.ncols());
2309
2310        if n <= 16 {
2311            let op = {
2312                #[inline(never)]
2313                || {
2314                    stack_mat_16x16_begin!(temp_dst, n, n, dst.row_stride(), dst.col_stride(), E);
2315                    stack_mat_16x16_begin!(temp_lhs, n, n, lhs.row_stride(), lhs.col_stride(), E);
2316                    stack_mat_16x16_begin!(temp_rhs, n, n, rhs.row_stride(), rhs.col_stride(), E);
2317
2318                    copy_lower(temp_lhs.rb_mut(), lhs, lhs_diag);
2319                    copy_lower(temp_rhs.rb_mut(), rhs, rhs_diag);
2320
2321                    mul(
2322                        temp_dst.rb_mut(),
2323                        temp_lhs.rb(),
2324                        temp_rhs.rb(),
2325                        None,
2326                        beta,
2327                        conj_lhs,
2328                        conj_rhs,
2329                        parallelism,
2330                    );
2331                    accum_lower(dst, temp_dst.rb(), skip_diag, alpha);
2332                }
2333            };
2334            op();
2335        } else {
2336            let bs = n / 2;
2337
2338            let (dst_top_left, _, mut dst_bot_left, dst_bot_right) = dst.split_at_mut(bs, bs);
2339            let (lhs_top_left, _, lhs_bot_left, lhs_bot_right) = lhs.split_at(bs, bs);
2340            let (rhs_top_left, _, rhs_bot_left, rhs_bot_right) = rhs.split_at(bs, bs);
2341
2342            // lhs_top_left  × rhs_top_left  => dst_top_left  | low × low => low |   X
2343            // lhs_bot_left  × rhs_top_left  => dst_bot_left  | mat × low => mat | 1/2
2344            // lhs_bot_right × rhs_bot_left  => dst_bot_left  | low × mat => mat | 1/2
2345            // lhs_bot_right × rhs_bot_right => dst_bot_right | low × low => low |   X
2346
2347            lower_x_lower_into_lower_impl_unchecked(
2348                dst_top_left,
2349                skip_diag,
2350                lhs_top_left,
2351                lhs_diag,
2352                rhs_top_left,
2353                rhs_diag,
2354                alpha,
2355                beta,
2356                conj_lhs,
2357                conj_rhs,
2358                parallelism,
2359            );
2360            mat_x_lower_impl_unchecked(
2361                dst_bot_left.rb_mut(),
2362                lhs_bot_left,
2363                rhs_top_left,
2364                rhs_diag,
2365                alpha,
2366                beta,
2367                conj_lhs,
2368                conj_rhs,
2369                parallelism,
2370            );
2371            mat_x_lower_impl_unchecked(
2372                dst_bot_left.reverse_rows_and_cols_mut().transpose_mut(),
2373                rhs_bot_left.reverse_rows_and_cols().transpose(),
2374                lhs_bot_right.reverse_rows_and_cols().transpose(),
2375                lhs_diag,
2376                Some(E::faer_one()),
2377                beta,
2378                conj_rhs,
2379                conj_lhs,
2380                parallelism,
2381            );
2382            lower_x_lower_into_lower_impl_unchecked(
2383                dst_bot_right,
2384                skip_diag,
2385                lhs_bot_right,
2386                lhs_diag,
2387                rhs_bot_right,
2388                rhs_diag,
2389                alpha,
2390                beta,
2391                conj_lhs,
2392                conj_rhs,
2393                parallelism,
2394            )
2395        }
2396    }
2397
2398    unsafe fn upper_x_lower_impl_unchecked<E: ComplexField>(
2399        dst: MatMut<'_, E>,
2400        lhs: MatRef<'_, E>,
2401        lhs_diag: DiagonalKind,
2402        rhs: MatRef<'_, E>,
2403        rhs_diag: DiagonalKind,
2404        alpha: Option<E>,
2405        beta: E,
2406        conj_lhs: Conj,
2407        conj_rhs: Conj,
2408        parallelism: Parallelism,
2409    ) {
2410        let n = dst.nrows();
2411        debug_assert!(n == lhs.nrows());
2412        debug_assert!(n == lhs.ncols());
2413        debug_assert!(n == rhs.nrows());
2414        debug_assert!(n == rhs.ncols());
2415        debug_assert!(n == dst.nrows());
2416        debug_assert!(n == dst.ncols());
2417
2418        if n <= 16 {
2419            let op = {
2420                #[inline(never)]
2421                || {
2422                    stack_mat_16x16_begin!(temp_lhs, n, n, lhs.row_stride(), lhs.col_stride(), E);
2423                    stack_mat_16x16_begin!(temp_rhs, n, n, rhs.row_stride(), rhs.col_stride(), E);
2424
2425                    copy_upper(temp_lhs.rb_mut(), lhs, lhs_diag);
2426                    copy_lower(temp_rhs.rb_mut(), rhs, rhs_diag);
2427
2428                    mul(
2429                        dst,
2430                        temp_lhs.rb(),
2431                        temp_rhs.rb(),
2432                        alpha,
2433                        beta,
2434                        conj_lhs,
2435                        conj_rhs,
2436                        parallelism,
2437                    );
2438                }
2439            };
2440            op();
2441        } else {
2442            let bs = n / 2;
2443
2444            let (mut dst_top_left, dst_top_right, dst_bot_left, dst_bot_right) =
2445                dst.split_at_mut(bs, bs);
2446            let (lhs_top_left, lhs_top_right, _, lhs_bot_right) = lhs.split_at(bs, bs);
2447            let (rhs_top_left, _, rhs_bot_left, rhs_bot_right) = rhs.split_at(bs, bs);
2448
2449            // lhs_top_right × rhs_bot_left  => dst_top_left  | mat × mat => mat |   1
2450            // lhs_top_left  × rhs_top_left  => dst_top_left  | upp × low => mat |   X
2451            //
2452            // lhs_top_right × rhs_bot_right => dst_top_right | mat × low => mat | 1/2
2453            // lhs_bot_right × rhs_bot_left  => dst_bot_left  | upp × mat => mat | 1/2
2454            // lhs_bot_right × rhs_bot_right => dst_bot_right | upp × low => mat |   X
2455
2456            join_raw(
2457                |_| {
2458                    mul(
2459                        dst_top_left.rb_mut(),
2460                        lhs_top_right,
2461                        rhs_bot_left,
2462                        alpha,
2463                        beta,
2464                        conj_lhs,
2465                        conj_rhs,
2466                        parallelism,
2467                    );
2468                    upper_x_lower_impl_unchecked(
2469                        dst_top_left,
2470                        lhs_top_left,
2471                        lhs_diag,
2472                        rhs_top_left,
2473                        rhs_diag,
2474                        Some(E::faer_one()),
2475                        beta,
2476                        conj_lhs,
2477                        conj_rhs,
2478                        parallelism,
2479                    )
2480                },
2481                |_| {
2482                    join_raw(
2483                        |_| {
2484                            mat_x_lower_impl_unchecked(
2485                                dst_top_right,
2486                                lhs_top_right,
2487                                rhs_bot_right,
2488                                rhs_diag,
2489                                alpha,
2490                                beta,
2491                                conj_lhs,
2492                                conj_rhs,
2493                                parallelism,
2494                            )
2495                        },
2496                        |_| {
2497                            mat_x_lower_impl_unchecked(
2498                                dst_bot_left.transpose_mut(),
2499                                rhs_bot_left.transpose(),
2500                                lhs_bot_right.transpose(),
2501                                lhs_diag,
2502                                alpha,
2503                                beta,
2504                                conj_rhs,
2505                                conj_lhs,
2506                                parallelism,
2507                            )
2508                        },
2509                        parallelism,
2510                    );
2511
2512                    upper_x_lower_impl_unchecked(
2513                        dst_bot_right,
2514                        lhs_bot_right,
2515                        lhs_diag,
2516                        rhs_bot_right,
2517                        rhs_diag,
2518                        alpha,
2519                        beta,
2520                        conj_lhs,
2521                        conj_rhs,
2522                        parallelism,
2523                    )
2524                },
2525                parallelism,
2526            );
2527        }
2528    }
2529
2530    unsafe fn upper_x_lower_into_lower_impl_unchecked<E: ComplexField>(
2531        dst: MatMut<'_, E>,
2532        skip_diag: bool,
2533        lhs: MatRef<'_, E>,
2534        lhs_diag: DiagonalKind,
2535        rhs: MatRef<'_, E>,
2536        rhs_diag: DiagonalKind,
2537        alpha: Option<E>,
2538        beta: E,
2539        conj_lhs: Conj,
2540        conj_rhs: Conj,
2541        parallelism: Parallelism,
2542    ) {
2543        let n = dst.nrows();
2544        debug_assert!(n == lhs.nrows());
2545        debug_assert!(n == lhs.ncols());
2546        debug_assert!(n == rhs.nrows());
2547        debug_assert!(n == rhs.ncols());
2548        debug_assert!(n == dst.nrows());
2549        debug_assert!(n == dst.ncols());
2550
2551        if n <= 16 {
2552            let op = {
2553                #[inline(never)]
2554                || {
2555                    stack_mat_16x16_begin!(temp_dst, n, n, dst.row_stride(), dst.col_stride(), E);
2556                    stack_mat_16x16_begin!(temp_lhs, n, n, lhs.row_stride(), lhs.col_stride(), E);
2557                    stack_mat_16x16_begin!(temp_rhs, n, n, rhs.row_stride(), rhs.col_stride(), E);
2558
2559                    copy_upper(temp_lhs.rb_mut(), lhs, lhs_diag);
2560                    copy_lower(temp_rhs.rb_mut(), rhs, rhs_diag);
2561
2562                    mul(
2563                        temp_dst.rb_mut(),
2564                        temp_lhs.rb(),
2565                        temp_rhs.rb(),
2566                        None,
2567                        beta,
2568                        conj_lhs,
2569                        conj_rhs,
2570                        parallelism,
2571                    );
2572
2573                    accum_lower(dst, temp_dst.rb(), skip_diag, alpha);
2574                }
2575            };
2576            op();
2577        } else {
2578            let bs = n / 2;
2579
2580            let (mut dst_top_left, _, dst_bot_left, dst_bot_right) = dst.split_at_mut(bs, bs);
2581            let (lhs_top_left, lhs_top_right, _, lhs_bot_right) = lhs.split_at(bs, bs);
2582            let (rhs_top_left, _, rhs_bot_left, rhs_bot_right) = rhs.split_at(bs, bs);
2583
2584            // lhs_top_left  × rhs_top_left  => dst_top_left  | upp × low => low |   X
2585            // lhs_top_right × rhs_bot_left  => dst_top_left  | mat × mat => low | 1/2
2586            //
2587            // lhs_bot_right × rhs_bot_left  => dst_bot_left  | upp × mat => mat | 1/2
2588            // lhs_bot_right × rhs_bot_right => dst_bot_right | upp × low => low |   X
2589
2590            join_raw(
2591                |_| {
2592                    mat_x_mat_into_lower_impl_unchecked(
2593                        dst_top_left.rb_mut(),
2594                        skip_diag,
2595                        lhs_top_right,
2596                        rhs_bot_left,
2597                        alpha,
2598                        beta,
2599                        conj_lhs,
2600                        conj_rhs,
2601                        parallelism,
2602                    );
2603                    upper_x_lower_into_lower_impl_unchecked(
2604                        dst_top_left,
2605                        skip_diag,
2606                        lhs_top_left,
2607                        lhs_diag,
2608                        rhs_top_left,
2609                        rhs_diag,
2610                        Some(E::faer_one()),
2611                        beta,
2612                        conj_lhs,
2613                        conj_rhs,
2614                        parallelism,
2615                    )
2616                },
2617                |_| {
2618                    mat_x_lower_impl_unchecked(
2619                        dst_bot_left.transpose_mut(),
2620                        rhs_bot_left.transpose(),
2621                        lhs_bot_right.transpose(),
2622                        lhs_diag,
2623                        alpha,
2624                        beta,
2625                        conj_rhs,
2626                        conj_lhs,
2627                        parallelism,
2628                    );
2629                    upper_x_lower_into_lower_impl_unchecked(
2630                        dst_bot_right,
2631                        skip_diag,
2632                        lhs_bot_right,
2633                        lhs_diag,
2634                        rhs_bot_right,
2635                        rhs_diag,
2636                        alpha,
2637                        beta,
2638                        conj_lhs,
2639                        conj_rhs,
2640                        parallelism,
2641                    )
2642                },
2643                parallelism,
2644            );
2645        }
2646    }
2647
2648    unsafe fn mat_x_mat_into_lower_impl_unchecked<E: ComplexField>(
2649        dst: MatMut<'_, E>,
2650        skip_diag: bool,
2651        lhs: MatRef<'_, E>,
2652        rhs: MatRef<'_, E>,
2653        alpha: Option<E>,
2654        beta: E,
2655        conj_lhs: Conj,
2656        conj_rhs: Conj,
2657        parallelism: Parallelism,
2658    ) {
2659        debug_assert!(dst.nrows() == dst.ncols());
2660        debug_assert!(dst.nrows() == lhs.nrows());
2661        debug_assert!(dst.ncols() == rhs.ncols());
2662        debug_assert!(lhs.ncols() == rhs.nrows());
2663
2664        let n = dst.nrows();
2665        let k = lhs.ncols();
2666
2667        let join_parallelism = if n * n * k < 128 * 128 * 128 {
2668            Parallelism::None
2669        } else {
2670            parallelism
2671        };
2672
2673        if n <= 16 {
2674            let op = {
2675                #[inline(never)]
2676                || {
2677                    stack_mat_16x16_begin!(temp_dst, n, n, dst.row_stride(), dst.col_stride(), E);
2678
2679                    mul(
2680                        temp_dst.rb_mut(),
2681                        lhs,
2682                        rhs,
2683                        None,
2684                        beta,
2685                        conj_lhs,
2686                        conj_rhs,
2687                        parallelism,
2688                    );
2689                    accum_lower(dst, temp_dst.rb(), skip_diag, alpha);
2690                }
2691            };
2692            op();
2693        } else {
2694            let bs = n / 2;
2695            let (dst_top_left, _, dst_bot_left, dst_bot_right) = dst.split_at_mut(bs, bs);
2696            let (lhs_top, lhs_bot) = lhs.split_at_row(bs);
2697            let (rhs_left, rhs_right) = rhs.split_at_col(bs);
2698
2699            join_raw(
2700                |_| {
2701                    mul(
2702                        dst_bot_left,
2703                        lhs_bot,
2704                        rhs_left,
2705                        alpha,
2706                        beta,
2707                        conj_lhs,
2708                        conj_rhs,
2709                        parallelism,
2710                    )
2711                },
2712                |_| {
2713                    join_raw(
2714                        |_| {
2715                            mat_x_mat_into_lower_impl_unchecked(
2716                                dst_top_left,
2717                                skip_diag,
2718                                lhs_top,
2719                                rhs_left,
2720                                alpha,
2721                                beta,
2722                                conj_lhs,
2723                                conj_rhs,
2724                                parallelism,
2725                            )
2726                        },
2727                        |_| {
2728                            mat_x_mat_into_lower_impl_unchecked(
2729                                dst_bot_right,
2730                                skip_diag,
2731                                lhs_bot,
2732                                rhs_right,
2733                                alpha,
2734                                beta,
2735                                conj_lhs,
2736                                conj_rhs,
2737                                parallelism,
2738                            )
2739                        },
2740                        join_parallelism,
2741                    )
2742                },
2743                join_parallelism,
2744            );
2745        }
2746    }
2747
2748    /// Describes the parts of the matrix that must be accessed.
2749    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
2750    pub enum BlockStructure {
2751        /// The full matrix is accessed.
2752        Rectangular,
2753        /// The lower triangular half (including the diagonal) is accessed.
2754        TriangularLower,
2755        /// The lower triangular half (excluding the diagonal) is accessed.
2756        StrictTriangularLower,
2757        /// The lower triangular half (excluding the diagonal, which is assumed to be equal to
2758        /// `1.0`) is accessed.
2759        UnitTriangularLower,
2760        /// The upper triangular half (including the diagonal) is accessed.
2761        TriangularUpper,
2762        /// The upper triangular half (excluding the diagonal) is accessed.
2763        StrictTriangularUpper,
2764        /// The upper triangular half (excluding the diagonal, which is assumed to be equal to
2765        /// `1.0`) is accessed.
2766        UnitTriangularUpper,
2767    }
2768
2769    impl BlockStructure {
2770        /// Checks if `self` is full.
2771        #[inline]
2772        pub fn is_dense(self) -> bool {
2773            matches!(self, BlockStructure::Rectangular)
2774        }
2775
2776        /// Checks if `self` is triangular lower (either inclusive or exclusive).
2777        #[inline]
2778        pub fn is_lower(self) -> bool {
2779            use BlockStructure::*;
2780            matches!(
2781                self,
2782                TriangularLower | StrictTriangularLower | UnitTriangularLower
2783            )
2784        }
2785
2786        /// Checks if `self` is triangular upper (either inclusive or exclusive).
2787        #[inline]
2788        pub fn is_upper(self) -> bool {
2789            use BlockStructure::*;
2790            matches!(
2791                self,
2792                TriangularUpper | StrictTriangularUpper | UnitTriangularUpper
2793            )
2794        }
2795
2796        /// Returns the block structure corresponding to the transposed matrix.
2797        #[inline]
2798        pub fn transpose(self) -> Self {
2799            use BlockStructure::*;
2800            match self {
2801                Rectangular => Rectangular,
2802                TriangularLower => TriangularUpper,
2803                StrictTriangularLower => StrictTriangularUpper,
2804                UnitTriangularLower => UnitTriangularUpper,
2805                TriangularUpper => TriangularLower,
2806                StrictTriangularUpper => StrictTriangularLower,
2807                UnitTriangularUpper => UnitTriangularLower,
2808            }
2809        }
2810
2811        #[inline]
2812        pub(crate) fn diag_kind(self) -> DiagonalKind {
2813            use BlockStructure::*;
2814            match self {
2815                Rectangular | TriangularLower | TriangularUpper => DiagonalKind::Generic,
2816                StrictTriangularLower | StrictTriangularUpper => DiagonalKind::Zero,
2817                UnitTriangularLower | UnitTriangularUpper => DiagonalKind::Unit,
2818            }
2819        }
2820    }
2821
2822    /// Computes the matrix product `[alpha * acc] + beta * lhs * rhs` (while optionally conjugating
2823    /// either or both of the input matrices) and stores the result in `acc`.
2824    ///
2825    /// Performs the operation:
2826    /// - `acc = beta * Op_lhs(lhs) * Op_rhs(rhs)` if `alpha` is `None` (in this case, the
2827    ///   preexisting values in `acc` are not read, so it is allowed to be a view over uninitialized
2828    ///   values if `E: Copy`),
2829    /// - `acc = alpha * acc + beta * Op_lhs(lhs) * Op_rhs(rhs)` if `alpha` is `Some(_)`,
2830    ///
2831    /// The left hand side and right hand side may be interpreted as triangular depending on the
2832    /// given corresponding matrix structure.  
2833    ///
2834    /// For the destination matrix, the result is:
2835    /// - fully computed if the structure is rectangular,
2836    /// - only the triangular half (including the diagonal) is computed if the structure is
2837    /// triangular,
2838    /// - only the strict triangular half (excluding the diagonal) is computed if the structure is
2839    /// strictly triangular or unit triangular.
2840    ///
2841    /// `Op_lhs` is the identity if `conj_lhs` is `Conj::No`, and the conjugation operation if it is
2842    /// `Conj::Yes`.  
2843    /// `Op_rhs` is the identity if `conj_rhs` is `Conj::No`, and the conjugation operation if it is
2844    /// `Conj::Yes`.  
2845    ///
2846    /// # Panics
2847    ///
2848    /// Panics if the matrix dimensions are not compatible for matrix multiplication.  
2849    /// i.e.  
2850    ///  - `acc.nrows() == lhs.nrows()`
2851    ///  - `acc.ncols() == rhs.ncols()`
2852    ///  - `lhs.ncols() == rhs.nrows()`
2853    ///
2854    ///  Additionally, matrices that are marked as triangular must be square, i.e., they must have
2855    ///  the same number of rows and columns.
2856    ///
2857    /// # Example
2858    ///
2859    /// ```
2860    /// use faer_core::{
2861    ///     mat,
2862    ///     mul::triangular::{matmul_with_conj, BlockStructure},
2863    ///     unzipped, zipped, Conj, Mat, Parallelism,
2864    /// };
2865    ///
2866    /// let lhs = mat![[0.0, 2.0], [1.0, 3.0]];
2867    /// let rhs = mat![[4.0, 6.0], [5.0, 7.0]];
2868    ///
2869    /// let mut acc = Mat::<f64>::zeros(2, 2);
2870    /// let target = mat![
2871    ///     [
2872    ///         2.5 * (lhs.read(0, 0) * rhs.read(0, 0) + lhs.read(0, 1) * rhs.read(1, 0)),
2873    ///         0.0,
2874    ///     ],
2875    ///     [
2876    ///         2.5 * (lhs.read(1, 0) * rhs.read(0, 0) + lhs.read(1, 1) * rhs.read(1, 0)),
2877    ///         2.5 * (lhs.read(1, 0) * rhs.read(0, 1) + lhs.read(1, 1) * rhs.read(1, 1)),
2878    ///     ],
2879    /// ];
2880    ///
2881    /// matmul_with_conj(
2882    ///     acc.as_mut(),
2883    ///     BlockStructure::TriangularLower,
2884    ///     lhs.as_ref(),
2885    ///     BlockStructure::Rectangular,
2886    ///     Conj::No,
2887    ///     rhs.as_ref(),
2888    ///     BlockStructure::Rectangular,
2889    ///     Conj::No,
2890    ///     None,
2891    ///     2.5,
2892    ///     Parallelism::None,
2893    /// );
2894    ///
2895    /// zipped!(acc.as_ref(), target.as_ref())
2896    ///     .for_each(|unzipped!(acc, target)| assert!((acc.read() - target.read()).abs() < 1e-10));
2897    /// ```
2898    #[track_caller]
2899    #[inline]
2900    pub fn matmul_with_conj<E: ComplexField>(
2901        acc: MatMut<'_, E>,
2902        acc_structure: BlockStructure,
2903        lhs: MatRef<'_, E>,
2904        lhs_structure: BlockStructure,
2905        conj_lhs: Conj,
2906        rhs: MatRef<'_, E>,
2907        rhs_structure: BlockStructure,
2908        conj_rhs: Conj,
2909        alpha: Option<E>,
2910        beta: E,
2911        parallelism: Parallelism,
2912    ) {
2913        assert!(all(
2914            acc.nrows() == lhs.nrows(),
2915            acc.ncols() == rhs.ncols(),
2916            lhs.ncols() == rhs.nrows(),
2917        ));
2918
2919        if !acc_structure.is_dense() {
2920            assert!(acc.nrows() == acc.ncols());
2921        }
2922        if !lhs_structure.is_dense() {
2923            assert!(lhs.nrows() == lhs.ncols());
2924        }
2925        if !rhs_structure.is_dense() {
2926            assert!(rhs.nrows() == rhs.ncols());
2927        }
2928
2929        unsafe {
2930            matmul_unchecked(
2931                acc,
2932                acc_structure,
2933                lhs,
2934                lhs_structure,
2935                conj_lhs,
2936                rhs,
2937                rhs_structure,
2938                conj_rhs,
2939                alpha,
2940                beta,
2941                parallelism,
2942            )
2943        }
2944    }
2945
2946    /// Computes the matrix product `[alpha * acc] + beta * lhs * rhs` and stores the result in
2947    /// `acc`.
2948    ///
2949    /// Performs the operation:
2950    /// - `acc = beta * lhs * rhs` if `alpha` is `None` (in this case, the preexisting values in
2951    ///   `acc` are not read, so it is allowed to be a view over uninitialized values if `E: Copy`),
2952    /// - `acc = alpha * acc + beta * lhs * rhs` if `alpha` is `Some(_)`,
2953    ///
2954    /// The left hand side and right hand side may be interpreted as triangular depending on the
2955    /// given corresponding matrix structure.  
2956    ///
2957    /// For the destination matrix, the result is:
2958    /// - fully computed if the structure is rectangular,
2959    /// - only the triangular half (including the diagonal) is computed if the structure is
2960    /// triangular,
2961    /// - only the strict triangular half (excluding the diagonal) is computed if the structure is
2962    /// strictly triangular or unit triangular.
2963    ///
2964    /// # Panics
2965    ///
2966    /// Panics if the matrix dimensions are not compatible for matrix multiplication.  
2967    /// i.e.  
2968    ///  - `acc.nrows() == lhs.nrows()`
2969    ///  - `acc.ncols() == rhs.ncols()`
2970    ///  - `lhs.ncols() == rhs.nrows()`
2971    ///
2972    ///  Additionally, matrices that are marked as triangular must be square, i.e., they must have
2973    ///  the same number of rows and columns.
2974    ///
2975    /// # Example
2976    ///
2977    /// ```
2978    /// use faer_core::{
2979    ///     mat,
2980    ///     mul::triangular::{matmul, BlockStructure},
2981    ///     unzipped, zipped, Conj, Mat, Parallelism,
2982    /// };
2983    ///
2984    /// let lhs = mat![[0.0, 2.0], [1.0, 3.0]];
2985    /// let rhs = mat![[4.0, 6.0], [5.0, 7.0]];
2986    ///
2987    /// let mut acc = Mat::<f64>::zeros(2, 2);
2988    /// let target = mat![
2989    ///     [
2990    ///         2.5 * (lhs.read(0, 0) * rhs.read(0, 0) + lhs.read(0, 1) * rhs.read(1, 0)),
2991    ///         0.0,
2992    ///     ],
2993    ///     [
2994    ///         2.5 * (lhs.read(1, 0) * rhs.read(0, 0) + lhs.read(1, 1) * rhs.read(1, 0)),
2995    ///         2.5 * (lhs.read(1, 0) * rhs.read(0, 1) + lhs.read(1, 1) * rhs.read(1, 1)),
2996    ///     ],
2997    /// ];
2998    ///
2999    /// matmul(
3000    ///     acc.as_mut(),
3001    ///     BlockStructure::TriangularLower,
3002    ///     lhs.as_ref(),
3003    ///     BlockStructure::Rectangular,
3004    ///     rhs.as_ref(),
3005    ///     BlockStructure::Rectangular,
3006    ///     None,
3007    ///     2.5,
3008    ///     Parallelism::None,
3009    /// );
3010    ///
3011    /// zipped!(acc.as_ref(), target.as_ref())
3012    ///     .for_each(|unzipped!(acc, target)| assert!((acc.read() - target.read()).abs() < 1e-10));
3013    /// ```
3014    #[track_caller]
3015    #[inline]
3016    pub fn matmul<
3017        E: ComplexField,
3018        LhsE: Conjugate<Canonical = E>,
3019        RhsE: Conjugate<Canonical = E>,
3020    >(
3021        acc: MatMut<'_, E>,
3022        acc_structure: BlockStructure,
3023        lhs: MatRef<'_, LhsE>,
3024        lhs_structure: BlockStructure,
3025        rhs: MatRef<'_, RhsE>,
3026        rhs_structure: BlockStructure,
3027        alpha: Option<E>,
3028        beta: E,
3029        parallelism: Parallelism,
3030    ) {
3031        let (lhs, conj_lhs) = lhs.canonicalize();
3032        let (rhs, conj_rhs) = rhs.canonicalize();
3033        matmul_with_conj(
3034            acc,
3035            acc_structure,
3036            lhs,
3037            lhs_structure,
3038            conj_lhs,
3039            rhs,
3040            rhs_structure,
3041            conj_rhs,
3042            alpha,
3043            beta,
3044            parallelism,
3045        );
3046    }
3047
3048    unsafe fn matmul_unchecked<E: ComplexField>(
3049        acc: MatMut<'_, E>,
3050        acc_structure: BlockStructure,
3051        lhs: MatRef<'_, E>,
3052        lhs_structure: BlockStructure,
3053        conj_lhs: Conj,
3054        rhs: MatRef<'_, E>,
3055        rhs_structure: BlockStructure,
3056        conj_rhs: Conj,
3057        alpha: Option<E>,
3058        beta: E,
3059        parallelism: Parallelism,
3060    ) {
3061        debug_assert!(acc.nrows() == lhs.nrows());
3062        debug_assert!(acc.ncols() == rhs.ncols());
3063        debug_assert!(lhs.ncols() == rhs.nrows());
3064
3065        if !acc_structure.is_dense() {
3066            debug_assert!(acc.nrows() == acc.ncols());
3067        }
3068        if !lhs_structure.is_dense() {
3069            debug_assert!(lhs.nrows() == lhs.ncols());
3070        }
3071        if !rhs_structure.is_dense() {
3072            debug_assert!(rhs.nrows() == rhs.ncols());
3073        }
3074
3075        let mut acc = acc;
3076        let mut lhs = lhs;
3077        let mut rhs = rhs;
3078
3079        let mut acc_structure = acc_structure;
3080        let mut lhs_structure = lhs_structure;
3081        let mut rhs_structure = rhs_structure;
3082
3083        let mut conj_lhs = conj_lhs;
3084        let mut conj_rhs = conj_rhs;
3085
3086        // if either the lhs or the rhs is triangular
3087        if rhs_structure.is_lower() {
3088            // do nothing
3089            false
3090        } else if rhs_structure.is_upper() {
3091            // invert acc, lhs and rhs
3092            acc = acc.reverse_rows_and_cols_mut();
3093            lhs = lhs.reverse_rows_and_cols();
3094            rhs = rhs.reverse_rows_and_cols();
3095            acc_structure = acc_structure.transpose();
3096            lhs_structure = lhs_structure.transpose();
3097            rhs_structure = rhs_structure.transpose();
3098            false
3099        } else if lhs_structure.is_lower() {
3100            // invert and transpose
3101            acc = acc.reverse_rows_and_cols_mut().transpose_mut();
3102            (lhs, rhs) = (
3103                rhs.reverse_rows_and_cols().transpose(),
3104                lhs.reverse_rows_and_cols().transpose(),
3105            );
3106            (conj_lhs, conj_rhs) = (conj_rhs, conj_lhs);
3107            (lhs_structure, rhs_structure) = (rhs_structure, lhs_structure);
3108            true
3109        } else if lhs_structure.is_upper() {
3110            // transpose
3111            acc_structure = acc_structure.transpose();
3112            acc = acc.transpose_mut();
3113            (lhs, rhs) = (rhs.transpose(), lhs.transpose());
3114            (conj_lhs, conj_rhs) = (conj_rhs, conj_lhs);
3115            (lhs_structure, rhs_structure) = (rhs_structure.transpose(), lhs_structure.transpose());
3116            true
3117        } else {
3118            // do nothing
3119            false
3120        };
3121
3122        let clear_upper = |acc: MatMut<'_, E>, skip_diag: bool| match &alpha {
3123            &Some(alpha) => zipped!(acc).for_each_triangular_upper(
3124                if skip_diag { Diag::Skip } else { Diag::Include },
3125                |unzipped!(mut acc)| acc.write(alpha.faer_mul(acc.read())),
3126            ),
3127
3128            None => zipped!(acc).for_each_triangular_upper(
3129                if skip_diag { Diag::Skip } else { Diag::Include },
3130                |unzipped!(mut acc)| acc.write(E::faer_zero()),
3131            ),
3132        };
3133
3134        let skip_diag = matches!(
3135            acc_structure,
3136            BlockStructure::StrictTriangularLower
3137                | BlockStructure::StrictTriangularUpper
3138                | BlockStructure::UnitTriangularLower
3139                | BlockStructure::UnitTriangularUpper
3140        );
3141        let lhs_diag = lhs_structure.diag_kind();
3142        let rhs_diag = rhs_structure.diag_kind();
3143
3144        if acc_structure.is_dense() {
3145            if lhs_structure.is_dense() && rhs_structure.is_dense() {
3146                mul(acc, lhs, rhs, alpha, beta, conj_lhs, conj_rhs, parallelism);
3147            } else {
3148                debug_assert!(rhs_structure.is_lower());
3149
3150                if lhs_structure.is_dense() {
3151                    mat_x_lower_impl_unchecked(
3152                        acc,
3153                        lhs,
3154                        rhs,
3155                        rhs_diag,
3156                        alpha,
3157                        beta,
3158                        conj_lhs,
3159                        conj_rhs,
3160                        parallelism,
3161                    )
3162                } else if lhs_structure.is_lower() {
3163                    clear_upper(acc.rb_mut(), true);
3164                    lower_x_lower_into_lower_impl_unchecked(
3165                        acc,
3166                        false,
3167                        lhs,
3168                        lhs_diag,
3169                        rhs,
3170                        rhs_diag,
3171                        alpha,
3172                        beta,
3173                        conj_lhs,
3174                        conj_rhs,
3175                        parallelism,
3176                    );
3177                } else {
3178                    debug_assert!(lhs_structure.is_upper());
3179                    upper_x_lower_impl_unchecked(
3180                        acc,
3181                        lhs,
3182                        lhs_diag,
3183                        rhs,
3184                        rhs_diag,
3185                        alpha,
3186                        beta,
3187                        conj_lhs,
3188                        conj_rhs,
3189                        parallelism,
3190                    )
3191                }
3192            }
3193        } else if acc_structure.is_lower() {
3194            if lhs_structure.is_dense() && rhs_structure.is_dense() {
3195                mat_x_mat_into_lower_impl_unchecked(
3196                    acc,
3197                    skip_diag,
3198                    lhs,
3199                    rhs,
3200                    alpha,
3201                    beta,
3202                    conj_lhs,
3203                    conj_rhs,
3204                    parallelism,
3205                )
3206            } else {
3207                debug_assert!(rhs_structure.is_lower());
3208                if lhs_structure.is_dense() {
3209                    mat_x_lower_into_lower_impl_unchecked(
3210                        acc,
3211                        skip_diag,
3212                        lhs,
3213                        rhs,
3214                        rhs_diag,
3215                        alpha,
3216                        beta,
3217                        conj_lhs,
3218                        conj_rhs,
3219                        parallelism,
3220                    );
3221                } else if lhs_structure.is_lower() {
3222                    lower_x_lower_into_lower_impl_unchecked(
3223                        acc,
3224                        skip_diag,
3225                        lhs,
3226                        lhs_diag,
3227                        rhs,
3228                        rhs_diag,
3229                        alpha,
3230                        beta,
3231                        conj_lhs,
3232                        conj_rhs,
3233                        parallelism,
3234                    )
3235                } else {
3236                    upper_x_lower_into_lower_impl_unchecked(
3237                        acc,
3238                        skip_diag,
3239                        lhs,
3240                        lhs_diag,
3241                        rhs,
3242                        rhs_diag,
3243                        alpha,
3244                        beta,
3245                        conj_lhs,
3246                        conj_rhs,
3247                        parallelism,
3248                    )
3249                }
3250            }
3251        } else if lhs_structure.is_dense() && rhs_structure.is_dense() {
3252            mat_x_mat_into_lower_impl_unchecked(
3253                acc.transpose_mut(),
3254                skip_diag,
3255                rhs.transpose(),
3256                lhs.transpose(),
3257                alpha,
3258                beta,
3259                conj_rhs,
3260                conj_lhs,
3261                parallelism,
3262            )
3263        } else {
3264            debug_assert!(rhs_structure.is_lower());
3265            if lhs_structure.is_dense() {
3266                // lower part of lhs does not contribute to result
3267                upper_x_lower_into_lower_impl_unchecked(
3268                    acc.transpose_mut(),
3269                    skip_diag,
3270                    rhs.transpose(),
3271                    rhs_diag,
3272                    lhs.transpose(),
3273                    lhs_diag,
3274                    alpha,
3275                    beta,
3276                    conj_rhs,
3277                    conj_lhs,
3278                    parallelism,
3279                )
3280            } else if lhs_structure.is_lower() {
3281                if !skip_diag {
3282                    match &alpha {
3283                        &Some(alpha) => {
3284                            zipped!(
3285                                acc.rb_mut().diagonal_mut().column_vector_mut().as_2d_mut(),
3286                                lhs.diagonal().column_vector().as_2d(),
3287                                rhs.diagonal().column_vector().as_2d(),
3288                            )
3289                            .for_each(
3290                                |unzipped!(mut acc, lhs, rhs)| {
3291                                    acc.write(
3292                                        (alpha.faer_mul(acc.read())).faer_add(
3293                                            beta.faer_mul(lhs.read().faer_mul(rhs.read())),
3294                                        ),
3295                                    )
3296                                },
3297                            );
3298                        }
3299                        None => {
3300                            zipped!(
3301                                acc.rb_mut().diagonal_mut().column_vector_mut().as_2d_mut(),
3302                                lhs.diagonal().column_vector().as_2d(),
3303                                rhs.diagonal().column_vector().as_2d(),
3304                            )
3305                            .for_each(
3306                                |unzipped!(mut acc, lhs, rhs)| {
3307                                    acc.write(beta.faer_mul(lhs.read().faer_mul(rhs.read())))
3308                                },
3309                            );
3310                        }
3311                    }
3312                }
3313                clear_upper(acc.rb_mut(), true);
3314            } else {
3315                debug_assert!(lhs_structure.is_upper());
3316                upper_x_lower_into_lower_impl_unchecked(
3317                    acc.transpose_mut(),
3318                    skip_diag,
3319                    rhs.transpose(),
3320                    rhs_diag,
3321                    lhs.transpose(),
3322                    lhs_diag,
3323                    alpha,
3324                    beta,
3325                    conj_rhs,
3326                    conj_lhs,
3327                    parallelism,
3328                )
3329            }
3330        }
3331    }
3332}
3333
3334#[cfg(test)]
3335mod tests {
3336    use super::{
3337        triangular::{BlockStructure, DiagonalKind},
3338        *,
3339    };
3340    use crate::{assert, Mat};
3341    use assert_approx_eq::assert_approx_eq;
3342    use num_complex::Complex32;
3343
/// Smoke test for `stack_mat_16x16_begin!`: expands the macro with `f64` elements and
/// debug-prints the binding `m` that it introduces.
#[test]
fn test_stack_mat() {
    stack_mat_16x16_begin!(m, 3, 3, 1, 3, f64);
    // NOTE(review): the mutable borrow presumably only marks `m` as used mutably — confirm
    // against the macro definition.
    let _ = &mut m;
    dbg!(&m);
}
3352
/// Sweeps `test_matmul_impl` over a grid of problem shapes, storage orders, row/column
/// reversals, conjugation flags, parallelism settings and `alpha`/`beta` values.
///
/// Under miri the boolean, `beta`, parallelism and conjugation axes are each reduced to a
/// single value; the shape list and the `alpha` list are not reduced.
#[test]
#[ignore = "takes too long in CI"]
fn test_matmul() {
    let rand_entry = |_, _| c32 {
        re: rand::random(),
        im: rand::random(),
    };

    let alphas = [
        None,
        Some(c32::faer_one()),
        Some(c32::faer_zero()),
        Some(rand_entry(0, 0)),
    ];

    #[cfg(not(miri))]
    let (bools, betas, par, conjs) = (
        [false, true],
        [c32::faer_one(), c32::faer_zero(), rand_entry(0, 0)],
        [Parallelism::None, Parallelism::Rayon(0)],
        [Conj::Yes, Conj::No],
    );
    #[cfg(miri)]
    let (bools, betas, par, conjs) = (
        [true],
        [rand_entry(0, 0)],
        [Parallelism::None],
        [Conj::Yes],
    );

    let (big0, big1, big2) = (127, 128, 129);
    let (mid0, mid1, mid2) = (15, 16, 17);

    // (m, n, k) triples: `a` is m×k, `b` is k×n and the accumulator is m×n.
    let shapes = [
        (mid0, mid0, KC + 1),
        (big0, big1, 5),
        (big1, big0, 5),
        (big0, big2, 5),
        (big2, big0, 5),
        (mid0, mid0, 5),
        (mid1, mid1, 5),
        (mid2, mid2, 5),
        (mid0, mid1, 5),
        (mid1, mid0, 5),
        (mid0, mid2, 5),
        (mid2, mid0, 5),
        (mid0, 1, 1),
        (1, mid0, 1),
        (1, 1, mid0),
        (1, mid0, mid0),
        (mid0, 1, mid0),
        (mid0, mid0, 1),
        (1, 1, 1),
    ];

    for (m, n, k) in shapes {
        let a_init = Mat::from_fn(m, k, rand_entry);
        let b_init = Mat::from_fn(k, n, rand_entry);
        let acc_init = Mat::from_fn(m, n, rand_entry);

        for reverse_acc_cols in bools {
            for reverse_acc_rows in bools {
                for reverse_b_cols in bools {
                    for reverse_b_rows in bools {
                        for reverse_a_cols in bools {
                            for reverse_a_rows in bools {
                                for a_colmajor in bools {
                                    // Store `a` directly or transposed, then view it so that it
                                    // is logically m×k again in both cases.
                                    let a_storage = if a_colmajor {
                                        a_init.to_owned()
                                    } else {
                                        a_init.transpose().to_owned()
                                    };
                                    let a_view = if a_colmajor {
                                        a_storage.as_ref()
                                    } else {
                                        a_storage.as_ref().transpose()
                                    };
                                    let a_view = if reverse_a_rows {
                                        a_view.reverse_rows()
                                    } else {
                                        a_view
                                    };
                                    let a_view = if reverse_a_cols {
                                        a_view.reverse_cols()
                                    } else {
                                        a_view
                                    };

                                    for b_colmajor in bools {
                                        // Same treatment for `b`.
                                        let b_storage = if b_colmajor {
                                            b_init.to_owned()
                                        } else {
                                            b_init.transpose().to_owned()
                                        };
                                        let b_view = if b_colmajor {
                                            b_storage.as_ref()
                                        } else {
                                            b_storage.as_ref().transpose()
                                        };
                                        let b_view = if reverse_b_rows {
                                            b_view.reverse_rows()
                                        } else {
                                            b_view
                                        };
                                        let b_view = if reverse_b_cols {
                                            b_view.reverse_cols()
                                        } else {
                                            b_view
                                        };

                                        for acc_colmajor in bools {
                                            for conj_a in conjs {
                                                for conj_b in conjs {
                                                    for parallelism in par {
                                                        for alpha in alphas {
                                                            for beta in betas {
                                                                for use_gemm in [true, false] {
                                                                    test_matmul_impl(
                                                                        reverse_acc_cols,
                                                                        reverse_acc_rows,
                                                                        acc_colmajor,
                                                                        m,
                                                                        n,
                                                                        conj_a,
                                                                        conj_b,
                                                                        parallelism,
                                                                        alpha,
                                                                        beta,
                                                                        use_gemm,
                                                                        &acc_init,
                                                                        a_view,
                                                                        b_view,
                                                                    );
                                                                }
                                                            }
                                                        }
                                                    }
                                                }
                                            }
                                        }
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }
    }
}
3500
3501    fn matmul_with_conj_fallback<E: ComplexField>(
3502        acc: MatMut<'_, E>,
3503        a: MatRef<'_, E>,
3504        conj_a: Conj,
3505        b: MatRef<'_, E>,
3506        conj_b: Conj,
3507        alpha: Option<E>,
3508        beta: E,
3509        parallelism: Parallelism,
3510    ) {
3511        let m = acc.nrows();
3512        let n = acc.ncols();
3513        let k = a.ncols();
3514
3515        let job = |idx: usize| {
3516            let i = idx % m;
3517            let j = idx / m;
3518            let acc = acc.rb().submatrix(i, j, 1, 1);
3519            let mut acc = unsafe { acc.const_cast() };
3520
3521            let mut local_acc = E::faer_zero();
3522            for depth in 0..k {
3523                let a = a.read(i, depth);
3524                let b = b.read(depth, j);
3525                local_acc = local_acc.faer_add(E::faer_mul(
3526                    match conj_a {
3527                        Conj::Yes => a.faer_conj(),
3528                        Conj::No => a,
3529                    },
3530                    match conj_b {
3531                        Conj::Yes => b.faer_conj(),
3532                        Conj::No => b,
3533                    },
3534                ))
3535            }
3536            match alpha {
3537                Some(alpha) => acc.write(
3538                    0,
3539                    0,
3540                    E::faer_add(acc.read(0, 0).faer_mul(alpha), local_acc.faer_mul(beta)),
3541                ),
3542                None => acc.write(0, 0, local_acc.faer_mul(beta)),
3543            }
3544        };
3545
3546        crate::for_each_raw(m * n, job, parallelism);
3547    }
3548
3549    fn test_matmul_impl(
3550        reverse_acc_cols: bool,
3551        reverse_acc_rows: bool,
3552        acc_colmajor: bool,
3553        m: usize,
3554        n: usize,
3555        conj_a: Conj,
3556        conj_b: Conj,
3557        parallelism: Parallelism,
3558        alpha: Option<c32>,
3559        beta: c32,
3560        use_gemm: bool,
3561        acc_init: &Mat<c32>,
3562        a: MatRef<c32>,
3563        b: MatRef<c32>,
3564    ) {
3565        let mut acc = if acc_colmajor {
3566            acc_init.to_owned()
3567        } else {
3568            acc_init.transpose().to_owned()
3569        };
3570
3571        let mut acc = if acc_colmajor {
3572            acc.as_mut()
3573        } else {
3574            acc.as_mut().transpose_mut()
3575        };
3576        if reverse_acc_rows {
3577            acc = acc.reverse_rows_mut();
3578        }
3579        if reverse_acc_cols {
3580            acc = acc.reverse_cols_mut();
3581        }
3582        let mut target = acc.to_owned();
3583
3584        matmul_with_conj_gemm_dispatch(
3585            acc.rb_mut(),
3586            a,
3587            conj_a,
3588            b,
3589            conj_b,
3590            alpha,
3591            beta,
3592            parallelism,
3593            use_gemm,
3594        );
3595        matmul_with_conj_fallback(
3596            target.as_mut(),
3597            a,
3598            conj_a,
3599            b,
3600            conj_b,
3601            alpha,
3602            beta,
3603            parallelism,
3604        );
3605
3606        for j in 0..n {
3607            for i in 0..m {
3608                let acc: Complex32 = acc.read(i, j).into();
3609                let target: Complex32 = target.read(i, j).into();
3610                assert_approx_eq!(acc.re, target.re, 1e-3);
3611                assert_approx_eq!(acc.im, target.im, 1e-3);
3612            }
3613        }
3614    }
3615
3616    fn generate_structured_matrix(
3617        is_dst: bool,
3618        nrows: usize,
3619        ncols: usize,
3620        structure: BlockStructure,
3621    ) -> Mat<f64> {
3622        let mut mat = Mat::new();
3623        mat.resize_with(nrows, ncols, |_, _| rand::random());
3624
3625        if !is_dst {
3626            let kind = structure.diag_kind();
3627            if structure.is_lower() {
3628                for j in 0..ncols {
3629                    for i in 0..j {
3630                        mat.write(i, j, 0.0);
3631                    }
3632                }
3633            } else if structure.is_upper() {
3634                for j in 0..ncols {
3635                    for i in j + 1..nrows {
3636                        mat.write(i, j, 0.0);
3637                    }
3638                }
3639            }
3640
3641            match kind {
3642                triangular::DiagonalKind::Zero => {
3643                    for i in 0..nrows {
3644                        mat.write(i, i, 0.0);
3645                    }
3646                }
3647                triangular::DiagonalKind::Unit => {
3648                    for i in 0..nrows {
3649                        mat.write(i, i, 1.0);
3650                    }
3651                }
3652                triangular::DiagonalKind::Generic => (),
3653            }
3654        }
3655        mat
3656    }
3657
3658    fn run_test_problem(
3659        m: usize,
3660        n: usize,
3661        k: usize,
3662        dst_structure: BlockStructure,
3663        lhs_structure: BlockStructure,
3664        rhs_structure: BlockStructure,
3665    ) {
3666        let mut dst = generate_structured_matrix(true, m, n, dst_structure);
3667        let mut dst_target = dst.to_owned();
3668        let dst_orig = dst.to_owned();
3669        let lhs = generate_structured_matrix(false, m, k, lhs_structure);
3670        let rhs = generate_structured_matrix(false, k, n, rhs_structure);
3671
3672        for parallelism in [Parallelism::None, Parallelism::Rayon(8)] {
3673            triangular::matmul_with_conj(
3674                dst.as_mut(),
3675                dst_structure,
3676                lhs.as_ref(),
3677                lhs_structure,
3678                Conj::No,
3679                rhs.as_ref(),
3680                rhs_structure,
3681                Conj::No,
3682                None,
3683                2.5,
3684                parallelism,
3685            );
3686
3687            matmul_with_conj(
3688                dst_target.as_mut(),
3689                lhs.as_ref(),
3690                Conj::No,
3691                rhs.as_ref(),
3692                Conj::No,
3693                None,
3694                2.5,
3695                parallelism,
3696            );
3697
3698            if dst_structure.is_dense() {
3699                for j in 0..n {
3700                    for i in 0..m {
3701                        assert_approx_eq!(dst.read(i, j), dst_target.read(i, j));
3702                    }
3703                }
3704            } else if dst_structure.is_lower() {
3705                for j in 0..n {
3706                    if matches!(dst_structure.diag_kind(), DiagonalKind::Generic) {
3707                        for i in 0..j {
3708                            assert_eq!(dst.read(i, j), dst_orig.read(i, j));
3709                        }
3710                        for i in j..n {
3711                            assert_approx_eq!(dst.read(i, j), dst_target.read(i, j));
3712                        }
3713                    } else {
3714                        for i in 0..=j {
3715                            assert_eq!(dst.read(i, j), dst_orig.read(i, j));
3716                        }
3717                        for i in j + 1..n {
3718                            assert_approx_eq!(dst.read(i, j), dst_target.read(i, j));
3719                        }
3720                    }
3721                }
3722            } else {
3723                for j in 0..n {
3724                    if matches!(dst_structure.diag_kind(), DiagonalKind::Generic) {
3725                        for i in 0..=j {
3726                            assert_approx_eq!(dst.read(i, j), dst_target.read(i, j));
3727                        }
3728                        for i in j + 1..n {
3729                            assert_eq!(dst.read(i, j), dst_orig.read(i, j));
3730                        }
3731                    } else {
3732                        for i in 0..j {
3733                            assert_approx_eq!(dst.read(i, j), dst_target.read(i, j));
3734                        }
3735                        for i in j..n {
3736                            assert_eq!(dst.read(i, j), dst_orig.read(i, j));
3737                        }
3738                    }
3739                }
3740            }
3741        }
3742    }
3743
/// Runs `run_test_problem` for every (destination, lhs, rhs) combination of block structures,
/// three times each with random dimensions below `big`.
///
/// The dimensions are adjusted so that every non-dense operand is square: a non-dense
/// destination forces `n = m`, a non-dense lhs forces `k = m`, a non-dense rhs forces `k = n`,
/// and two or more non-dense operands force `m = n = k`.
#[test]
fn test_triangular() {
    use BlockStructure::*;
    let structures = [
        Rectangular,
        TriangularLower,
        TriangularUpper,
        StrictTriangularLower,
        StrictTriangularUpper,
        UnitTriangularLower,
        UnitTriangularUpper,
    ];

    // Exclusive upper bound on the random dimensions; smaller under miri.
    #[cfg(not(miri))]
    let big = 100;
    #[cfg(miri)]
    let big = 31;

    for dst in structures {
        for lhs in structures {
            for rhs in structures {
                for _ in 0..3 {
                    let m = rand::random::<usize>() % big;
                    let mut n = rand::random::<usize>() % big;
                    let mut k = rand::random::<usize>() % big;

                    // for keeping track of miri progress
                    #[cfg(miri)]
                    dbg!(m, n, k);

                    let dst_structured = !dst.is_dense();
                    let lhs_structured = !lhs.is_dense();
                    let rhs_structured = !rhs.is_dense();

                    // At least two structured operands: everything must be square.
                    if (dst_structured && lhs_structured)
                        || (dst_structured && rhs_structured)
                        || (lhs_structured && rhs_structured)
                    {
                        n = m;
                        k = m;
                    }

                    if dst_structured {
                        n = m;
                    }
                    if lhs_structured {
                        k = m;
                    }
                    if rhs_structured {
                        k = n;
                    }

                    run_test_problem(m, n, k, dst, lhs, rhs);
                }
            }
        }
    }
}
3800}