1use crate::{
4 assert, c32, c64, group_helpers::*, transmute_unchecked, unzipped, zipped, ComplexField, Conj,
5 Conjugate, DivCeil, MatMut, MatRef, Parallelism, SimdGroupFor,
6};
7use core::{iter::zip, marker::PhantomData, mem::MaybeUninit};
8use faer_entity::{SimdCtx, *};
9use pulp::Simd;
10use reborrow::*;
11
/// Inner (dot) products of two column vectors, with SIMD kernels for the contiguous case and a
/// scalar loop for strided inputs.
#[doc(hidden)]
pub mod inner_prod {
    use super::*;
    use crate::assert;

    /// SIMD dot-product kernel with a single vector accumulator.
    ///
    /// Returns the lane-wise partial sums of `a * b`, with the first operand (`a`) conjugated
    /// when `conj` is `YesConj` (the same convention as the scalar fallback in
    /// `inner_prod_with_conj_arch`). The horizontal reduction is left to the caller.
    /// Both slices are split into head/body/tail with the same `offset`. The partial head and
    /// tail vectors are read with `read_or(zero)`, which presumably substitutes `zero` for
    /// lanes outside the slice.
    #[inline(always)]
    fn a_x_b_accumulate1<C: ConjTy, E: ComplexField, S: Simd>(
        simd: SimdFor<E, S>,
        conj: C,
        a: SliceGroup<E>,
        b: SliceGroup<E>,
        offset: pulp::Offset<E::SimdMask<S>>,
    ) -> SimdGroupFor<E, S> {
        let (a_head, a_body, a_tail) = simd.as_aligned_simd(a, offset);
        let (b_head, b_body, b_tail) = simd.as_aligned_simd(b, offset);
        let zero = simd.splat(E::faer_zero());
        // The head product initializes the accumulator.
        let mut acc0 = simd.conditional_conj_mul(conj, a_head.read_or(zero), b_head.read_or(zero));

        let a_body1 = a_body;
        let b_body1 = b_body;
        for (a, b) in zip(a_body1.into_ref_iter(), b_body1.into_ref_iter()) {
            acc0 = simd.conditional_conj_mul_add_e(conj, a.read_or(zero), b.read_or(zero), acc0);
        }
        simd.conditional_conj_mul_add_e(conj, a_tail.read_or(zero), b_tail.read_or(zero), acc0)
    }

    /// Same contract as `a_x_b_accumulate1`, with the aligned body processed two vectors at a
    /// time into two independent accumulators (presumably to hide multiply-add latency).
    ///
    /// Body vectors left over after the pairs are folded into `acc0`, as are the head and tail.
    #[inline(always)]
    fn a_x_b_accumulate2<C: ConjTy, E: ComplexField, S: Simd>(
        simd: SimdFor<E, S>,
        conj: C,
        a: SliceGroup<E>,
        b: SliceGroup<E>,
        offset: pulp::Offset<E::SimdMask<S>>,
    ) -> SimdGroupFor<E, S> {
        let (a_head, a_body, a_tail) = simd.as_aligned_simd(a, offset);
        let (b_head, b_body, b_tail) = simd.as_aligned_simd(b, offset);
        let zero = simd.splat(E::faer_zero());
        let mut acc0 = simd.conditional_conj_mul(conj, a_head.read_or(zero), b_head.read_or(zero));
        let mut acc1 = zero;

        // Split the body into groups of 2 vectors plus a remainder.
        let (a_body2, a_body1) = a_body.as_arrays::<2>();
        let (b_body2, b_body1) = b_body.as_arrays::<2>();
        for ([a0, a1], [b0, b1]) in zip(
            a_body2.into_ref_iter().map(RefGroup::unzip),
            b_body2.into_ref_iter().map(RefGroup::unzip),
        ) {
            acc0 = simd.conditional_conj_mul_add_e(conj, a0.read_or(zero), b0.read_or(zero), acc0);
            acc1 = simd.conditional_conj_mul_add_e(conj, a1.read_or(zero), b1.read_or(zero), acc1);
        }
        for (a, b) in zip(a_body1.into_ref_iter(), b_body1.into_ref_iter()) {
            acc0 = simd.conditional_conj_mul_add_e(conj, a.read_or(zero), b.read_or(zero), acc0);
        }
        acc0 =
            simd.conditional_conj_mul_add_e(conj, a_tail.read_or(zero), b_tail.read_or(zero), acc0);
        simd.add(acc0, acc1)
    }

    /// Same contract as `a_x_b_accumulate1`, with the aligned body processed four vectors at a
    /// time into four independent accumulators, combined pairwise at the end.
    #[inline(always)]
    fn a_x_b_accumulate4<C: ConjTy, E: ComplexField, S: Simd>(
        simd: SimdFor<E, S>,
        conj: C,
        a: SliceGroup<E>,
        b: SliceGroup<E>,
        offset: pulp::Offset<E::SimdMask<S>>,
    ) -> SimdGroupFor<E, S> {
        let (a_head, a_body, a_tail) = simd.as_aligned_simd(a, offset);
        let (b_head, b_body, b_tail) = simd.as_aligned_simd(b, offset);
        let zero = simd.splat(E::faer_zero());
        let mut acc0 = simd.conditional_conj_mul(conj, a_head.read_or(zero), b_head.read_or(zero));
        let mut acc1 = zero;
        let mut acc2 = zero;
        let mut acc3 = zero;

        // Split the body into groups of 4 vectors plus a remainder.
        let (a_body4, a_body1) = a_body.as_arrays::<4>();
        let (b_body4, b_body1) = b_body.as_arrays::<4>();
        for ([a0, a1, a2, a3], [b0, b1, b2, b3]) in zip(
            a_body4.into_ref_iter().map(RefGroup::unzip),
            b_body4.into_ref_iter().map(RefGroup::unzip),
        ) {
            acc0 = simd.conditional_conj_mul_add_e(conj, a0.read_or(zero), b0.read_or(zero), acc0);
            acc1 = simd.conditional_conj_mul_add_e(conj, a1.read_or(zero), b1.read_or(zero), acc1);
            acc2 = simd.conditional_conj_mul_add_e(conj, a2.read_or(zero), b2.read_or(zero), acc2);
            acc3 = simd.conditional_conj_mul_add_e(conj, a3.read_or(zero), b3.read_or(zero), acc3);
        }
        for (a, b) in zip(a_body1.into_ref_iter(), b_body1.into_ref_iter()) {
            acc0 = simd.conditional_conj_mul_add_e(conj, a.read_or(zero), b.read_or(zero), acc0);
        }
        acc0 =
            simd.conditional_conj_mul_add_e(conj, a_tail.read_or(zero), b_tail.read_or(zero), acc0);
        simd.add(simd.add(acc0, acc1), simd.add(acc2, acc3))
    }

    /// Same contract as `a_x_b_accumulate1`, with the aligned body processed eight vectors at a
    /// time into eight independent accumulators, combined as a balanced tree at the end.
    #[inline(always)]
    fn a_x_b_accumulate8<C: ConjTy, E: ComplexField, S: Simd>(
        simd: SimdFor<E, S>,
        conj: C,
        a: SliceGroup<E>,
        b: SliceGroup<E>,
        offset: pulp::Offset<E::SimdMask<S>>,
    ) -> SimdGroupFor<E, S> {
        let (a_head, a_body, a_tail) = simd.as_aligned_simd(a, offset);
        let (b_head, b_body, b_tail) = simd.as_aligned_simd(b, offset);
        let zero = simd.splat(E::faer_zero());
        let mut acc0 = simd.conditional_conj_mul(conj, a_head.read_or(zero), b_head.read_or(zero));
        let mut acc1 = zero;
        let mut acc2 = zero;
        let mut acc3 = zero;
        let mut acc4 = zero;
        let mut acc5 = zero;
        let mut acc6 = zero;
        let mut acc7 = zero;

        // Split the body into groups of 8 vectors plus a remainder.
        let (a_body8, a_body1) = a_body.as_arrays::<8>();
        let (b_body8, b_body1) = b_body.as_arrays::<8>();
        for ([a0, a1, a2, a3, a4, a5, a6, a7], [b0, b1, b2, b3, b4, b5, b6, b7]) in zip(
            a_body8.into_ref_iter().map(RefGroup::unzip),
            b_body8.into_ref_iter().map(RefGroup::unzip),
        ) {
            acc0 = simd.conditional_conj_mul_add_e(conj, a0.read_or(zero), b0.read_or(zero), acc0);
            acc1 = simd.conditional_conj_mul_add_e(conj, a1.read_or(zero), b1.read_or(zero), acc1);
            acc2 = simd.conditional_conj_mul_add_e(conj, a2.read_or(zero), b2.read_or(zero), acc2);
            acc3 = simd.conditional_conj_mul_add_e(conj, a3.read_or(zero), b3.read_or(zero), acc3);
            acc4 = simd.conditional_conj_mul_add_e(conj, a4.read_or(zero), b4.read_or(zero), acc4);
            acc5 = simd.conditional_conj_mul_add_e(conj, a5.read_or(zero), b5.read_or(zero), acc5);
            acc6 = simd.conditional_conj_mul_add_e(conj, a6.read_or(zero), b6.read_or(zero), acc6);
            acc7 = simd.conditional_conj_mul_add_e(conj, a7.read_or(zero), b7.read_or(zero), acc7);
        }
        for (a, b) in zip(a_body1.into_ref_iter(), b_body1.into_ref_iter()) {
            acc0 = simd.conditional_conj_mul_add_e(conj, a.read_or(zero), b.read_or(zero), acc0);
        }
        acc0 =
            simd.conditional_conj_mul_add_e(conj, a_tail.read_or(zero), b_tail.read_or(zero), acc0);
        simd.add(
            simd.add(simd.add(acc0, acc1), simd.add(acc2, acc3)),
            simd.add(simd.add(acc4, acc5), simd.add(acc6, acc7)),
        )
    }

    /// Computes `sum_i a[i] * b[i]` (with `a` conjugated when `conj` is `YesConj`) on an
    /// already-selected SIMD backend, using `offset` to split both slices.
    ///
    /// The number of independent accumulators shrinks as `E::N_COMPONENTS` grows (8, 4, 2,
    /// then 1), presumably to keep the accumulators within the register file. The lane-wise
    /// partial sums are rotated by `offset.rotate_left_amount()` before the horizontal
    /// `reduce_add`. NOTE(review): this presumably undoes the lane shift introduced by the
    /// alignment offset, so the reduction order does not depend on `a`'s address — confirm.
    #[inline(always)]
    pub fn with_simd_and_offset<C: ConjTy, E: ComplexField, S: Simd>(
        simd: SimdFor<E, S>,
        conj: C,
        a: SliceGroup<E>,
        b: SliceGroup<E>,
        offset: pulp::Offset<E::SimdMask<S>>,
    ) -> E {
        {
            let prologue = if E::N_COMPONENTS == 1 {
                a_x_b_accumulate8(simd, conj, a, b, offset)
            } else if E::N_COMPONENTS == 2 {
                a_x_b_accumulate4(simd, conj, a, b, offset)
            } else if E::N_COMPONENTS == 4 {
                a_x_b_accumulate2(simd, conj, a, b, offset)
            } else {
                a_x_b_accumulate1(simd, conj, a, b, offset)
            };

            simd.reduce_add(simd.rotate_left(prologue, offset.rotate_left_amount()))
        }
    }

    /// `pulp::WithSimd` job: dot product of the contiguous slices `a` and `b`.
    ///
    /// `a` is conjugated when `conj` is `YesConj`.
    pub struct Impl<'a, C: ConjTy, E: ComplexField> {
        pub a: SliceGroup<'a, E>,
        pub b: SliceGroup<'a, E>,
        pub conj: C,
    }

    impl<C: ConjTy, E: ComplexField> pulp::WithSimd for Impl<'_, C, E> {
        type Output = E;

        #[inline(always)]
        fn with_simd<S: Simd>(self, simd: S) -> Self::Output {
            let simd = SimdFor::new(simd);
            // The alignment offset is derived from `a`; `b` is split with the same offset.
            with_simd_and_offset(simd, self.conj, self.a, self.b, simd.align_offset(self.a))
        }
    }

    /// Computes `sum_i conj_lhs(lhs[i]) * conj_rhs(rhs[i])` for two column vectors, using
    /// `arch` for the SIMD fast path.
    ///
    /// Only the *relative* conjugation is applied inside the loop: `x * y` when the two flags
    /// match, `conj(x) * y` when they differ. When `conj_rhs` is `Yes`, the final sum is
    /// conjugated, using `conj(x) * conj(y) = conj(x * y)` and `x * conj(y) = conj(conj(x) * y)`.
    ///
    /// Both vectors are flipped together when `lhs` has a negative row stride. The SIMD kernel
    /// is used only when both vectors then have unit row stride; otherwise a scalar loop runs.
    ///
    /// # Panics
    /// Panics if `lhs` and `rhs` have different row counts or are not single columns.
    #[inline(always)]
    #[track_caller]
    pub fn inner_prod_with_conj_arch<E: ComplexField>(
        arch: E::Simd,
        lhs: MatRef<'_, E>,
        conj_lhs: Conj,
        rhs: MatRef<'_, E>,
        conj_rhs: Conj,
    ) -> E {
        assert!(all(
            lhs.nrows() == rhs.nrows(),
            lhs.ncols() == 1,
            rhs.ncols() == 1,
        ));
        let nrows = lhs.nrows();
        let mut a = lhs;
        let mut b = rhs;
        // Reversing both vectors keeps the pairing of elements, so the sum is unchanged.
        if a.row_stride() < 0 {
            a = a.reverse_rows();
            b = b.reverse_rows();
        }

        let res = if a.row_stride() == 1 && b.row_stride() == 1 {
            let a = SliceGroup::<'_, E>::new(a.try_get_contiguous_col(0));
            let b = SliceGroup::<'_, E>::new(b.try_get_contiguous_col(0));
            if conj_lhs == conj_rhs {
                arch.dispatch(Impl { a, b, conj: NoConj })
            } else {
                arch.dispatch(Impl {
                    a,
                    b,
                    conj: YesConj,
                })
            }
        } else {
            // Scalar fallback over the constrained-size API. NOTE(review): the branded
            // `nrows`/`zero_idx` presumably let the reads skip per-element bounds checks.
            crate::constrained::Size::with2(
                nrows,
                1,
                #[inline(always)]
                |nrows, ncols| {
                    let zero_idx = ncols.check(0);

                    let a = crate::constrained::MatRef::new(a, nrows, ncols);
                    let b = crate::constrained::MatRef::new(b, nrows, ncols);
                    let mut acc = E::faer_zero();
                    if conj_lhs == conj_rhs {
                        for i in nrows.indices() {
                            acc =
                                acc.faer_add(E::faer_mul(a.read(i, zero_idx), b.read(i, zero_idx)));
                        }
                    } else {
                        for i in nrows.indices() {
                            acc = acc.faer_add(E::faer_mul(
                                a.read(i, zero_idx).faer_conj(),
                                b.read(i, zero_idx),
                            ));
                        }
                    }
                    acc
                },
            )
        };

        // Apply the deferred conjugation (see the doc comment above).
        match conj_rhs {
            Conj::Yes => res.faer_conj(),
            Conj::No => res,
        }
    }

    /// Same as `inner_prod_with_conj_arch`, using the default SIMD context for `E`.
    ///
    /// # Panics
    /// Panics if `lhs` and `rhs` have different row counts or are not single columns.
    #[inline]
    #[track_caller]
    pub fn inner_prod_with_conj<E: ComplexField>(
        lhs: MatRef<'_, E>,
        conj_lhs: Conj,
        rhs: MatRef<'_, E>,
        conj_rhs: Conj,
    ) -> E {
        inner_prod_with_conj_arch(E::Simd::default(), lhs, conj_lhs, rhs, conj_rhs)
    }
}
269
270#[doc(hidden)]
271pub mod matvec_rowmajor {
272 use super::*;
273 use crate::assert;
274
275 fn matvec_with_conj_impl<E: ComplexField>(
276 acc: MatMut<'_, E>,
277 a: MatRef<'_, E>,
278 conj_a: Conj,
279 b: MatRef<'_, E>,
280 conj_b: Conj,
281 alpha: Option<E>,
282 beta: E,
283 ) {
284 let m = a.nrows();
285 let n = a.ncols();
286
287 assert!(all(
288 b.nrows() == n,
289 b.ncols() == 1,
290 acc.nrows() == m,
291 acc.ncols() == 1,
292 a.col_stride() == 1,
293 b.row_stride() == 1,
294 ));
295
296 let mut acc = acc;
297
298 for i in 0..m {
299 let a = a.submatrix(i, 0, 1, n);
300 let res = inner_prod::inner_prod_with_conj(a.transpose(), conj_a, b, conj_b);
301 match alpha {
302 Some(alpha) => acc.write(
303 i,
304 0,
305 E::faer_add(alpha.faer_mul(acc.read(i, 0)), beta.faer_mul(res)),
306 ),
307 None => acc.write(i, 0, beta.faer_mul(res)),
308 }
309 }
310 }
311
312 pub fn matvec_with_conj<E: ComplexField>(
313 acc: MatMut<'_, E>,
314 lhs: MatRef<'_, E>,
315 conj_lhs: Conj,
316 rhs: MatRef<'_, E>,
317 conj_rhs: Conj,
318 alpha: Option<E>,
319 beta: E,
320 ) {
321 if rhs.row_stride() == 1 {
322 matvec_with_conj_impl(acc, lhs, conj_lhs, rhs, conj_rhs, alpha, beta);
323 } else {
324 matvec_with_conj_impl(
325 acc,
326 lhs,
327 conj_lhs,
328 rhs.to_owned().as_ref(),
329 conj_rhs,
330 alpha,
331 beta,
332 );
333 }
334 }
335}
336
337#[doc(hidden)]
338pub mod matvec_colmajor {
339 use super::*;
340 use crate::assert;
341
342 pub struct Impl<'a, C: ConjTy, E: ComplexField> {
343 pub conj: C,
344 pub acc: SliceGroupMut<'a, E>,
345 pub a: SliceGroup<'a, E>,
346 pub b: E,
347 }
348
349 #[inline(always)]
350 pub fn with_simd_and_offset<C: ConjTy, E: ComplexField, S: Simd>(
351 simd: SimdFor<E, S>,
352 conj: C,
353 acc: SliceGroupMut<'_, E>,
354 a: SliceGroup<'_, E>,
355 b: E,
356 offset: pulp::Offset<SimdMaskFor<E, S>>,
357 ) {
358 let (a_head, a_body, a_tail) = simd.as_aligned_simd(a, offset);
359 let (acc_head, acc_body, acc_tail) = simd.as_aligned_simd_mut(acc, offset);
360 let b = simd.splat(b);
361
362 #[inline(always)]
363 pub fn process<C: ConjTy, E: ComplexField, S: Simd>(
364 simd: SimdFor<E, S>,
365 conj: C,
366 mut acc: impl Write<Output = SimdGroupFor<E, S>>,
367 a: impl Read<Output = SimdGroupFor<E, S>>,
368 b: SimdGroupFor<E, S>,
369 ) {
370 acc.write(simd.conditional_conj_mul_add_e(
371 conj,
372 a.read_or(simd.splat(E::faer_zero())),
373 b,
374 acc.read_or(simd.splat(E::faer_zero())),
375 ))
376 }
377
378 process(simd, conj, acc_head, a_head, b);
379 for (acc, a) in acc_body.into_mut_iter().zip(a_body.into_ref_iter()) {
380 process(simd, conj, acc, a, b);
381 }
382 process(simd, conj, acc_tail, a_tail, b);
383 }
384
385 impl<C: ConjTy, E: ComplexField> pulp::WithSimd for Impl<'_, C, E> {
386 type Output = ();
387
388 #[inline(always)]
389 fn with_simd<S: Simd>(self, simd: S) -> Self::Output {
390 let simd = SimdFor::new(simd);
391 with_simd_and_offset(
392 simd,
393 self.conj,
394 self.acc,
395 self.a,
396 self.b,
397 simd.align_offset(self.a),
398 )
399 }
400 }
401
402 fn matvec_with_conj_impl<E: ComplexField>(
403 acc: MatMut<'_, E>,
404 a: MatRef<'_, E>,
405 conj_a: Conj,
406 b: MatRef<'_, E>,
407 conj_b: Conj,
408 beta: E,
409 ) {
410 let m = a.nrows();
411 let n = a.ncols();
412
413 assert!(all(
414 b.nrows() == n,
415 b.ncols() == 1,
416 acc.nrows() == m,
417 acc.ncols() == 1,
418 a.row_stride() == 1,
419 acc.row_stride() == 1,
420 ));
421
422 let mut acc = SliceGroupMut::<'_, E>::new(acc.try_get_contiguous_col_mut(0));
423
424 let arch = E::Simd::default();
425 for j in 0..n {
426 let acc = acc.rb_mut();
427 let a = SliceGroup::<'_, E>::new(a.try_get_contiguous_col(j));
428 let b = b.read(j, 0);
429 let b = match conj_b {
430 Conj::Yes => b.faer_conj(),
431 Conj::No => b,
432 };
433 let b = b.faer_mul(beta);
434
435 match conj_a {
436 Conj::Yes => arch.dispatch(Impl {
437 conj: YesConj,
438 acc,
439 a,
440 b,
441 }),
442 Conj::No => arch.dispatch(Impl {
443 conj: NoConj,
444 acc,
445 a,
446 b,
447 }),
448 }
449 }
450 }
451
452 pub fn matvec_with_conj<E: ComplexField>(
453 acc: MatMut<'_, E>,
454 lhs: MatRef<'_, E>,
455 conj_lhs: Conj,
456 rhs: MatRef<'_, E>,
457 conj_rhs: Conj,
458 alpha: Option<E>,
459 beta: E,
460 ) {
461 let m = acc.nrows();
462 let mut acc = acc;
463 if acc.row_stride() == 1 {
464 match alpha {
465 Some(alpha) if alpha == E::faer_one() => {}
466 Some(alpha) => {
467 for i in 0..m {
468 acc.write(i, 0, acc.read(i, 0).faer_mul(alpha));
469 }
470 }
471 None => {
472 for i in 0..m {
473 acc.write(i, 0, E::faer_zero());
474 }
475 }
476 }
477
478 matvec_with_conj_impl(acc, lhs, conj_lhs, rhs, conj_rhs, beta);
479 } else {
480 let mut tmp = crate::Mat::<E>::zeros(m, 1);
481 matvec_with_conj_impl(tmp.as_mut(), lhs, conj_lhs, rhs, conj_rhs, beta);
482 match alpha {
483 Some(alpha) => {
484 for i in 0..m {
485 acc.write(
486 i,
487 0,
488 (acc.read(i, 0).faer_mul(alpha)).faer_add(tmp.read(i, 0)),
489 )
490 }
491 }
492 None => {
493 for i in 0..m {
494 acc.write(i, 0, tmp.read(i, 0))
495 }
496 }
497 }
498 }
499 }
500}
501
502#[doc(hidden)]
503pub mod matvec {
504 use super::*;
505
506 pub fn matvec_with_conj<E: ComplexField>(
507 acc: MatMut<'_, E>,
508 lhs: MatRef<'_, E>,
509 conj_lhs: Conj,
510 rhs: MatRef<'_, E>,
511 conj_rhs: Conj,
512 alpha: Option<E>,
513 beta: E,
514 ) {
515 let mut acc = acc;
516 let mut a = lhs;
517 let mut b = rhs;
518
519 if a.row_stride() < 0 {
520 a = a.reverse_rows();
521 acc = acc.reverse_rows_mut();
522 }
523 if a.col_stride() < 0 {
524 a = a.reverse_cols();
525 b = b.reverse_rows();
526 }
527
528 if a.row_stride() == 1 {
529 return matvec_colmajor::matvec_with_conj(acc, a, conj_lhs, b, conj_rhs, alpha, beta);
530 }
531 if a.col_stride() == 1 {
532 return matvec_rowmajor::matvec_with_conj(acc, a, conj_lhs, b, conj_rhs, alpha, beta);
533 }
534
535 let m = a.nrows();
536 let n = a.ncols();
537
538 match alpha {
539 Some(alpha) => {
540 for i in 0..m {
541 acc.write(i, 0, acc.read(i, 0).faer_mul(alpha));
542 }
543 }
544 None => {
545 for i in 0..m {
546 acc.write(i, 0, E::faer_zero());
547 }
548 }
549 }
550
551 for j in 0..n {
552 let b = b.read(j, 0);
553 let b = match conj_rhs {
554 Conj::Yes => b.faer_conj(),
555 Conj::No => b,
556 };
557 let b = b.faer_mul(beta);
558 for i in 0..m {
559 let mul = a.read(i, j).faer_mul(b);
560 acc.write(i, 0, acc.read(i, 0).faer_add(mul));
561 }
562 }
563 }
564}
565
566#[doc(hidden)]
567pub mod outer_prod {
568 use super::*;
569 use crate::assert;
570
571 pub struct Impl<'a, C: ConjTy, E: ComplexField> {
572 pub conj: C,
573 pub acc: SliceGroupMut<'a, E>,
574 pub a: SliceGroup<'a, E>,
575 pub b: E,
576 pub alpha: Option<E>,
577 }
578
579 #[inline(always)]
580 pub fn with_simd_and_offset<C: ConjTy, E: ComplexField, S: Simd>(
581 simd: SimdFor<E, S>,
582 conj: C,
583 acc: SliceGroupMut<'_, E>,
584 a: SliceGroup<'_, E>,
585 b: E,
586 alpha: Option<E>,
587 offset: pulp::Offset<SimdMaskFor<E, S>>,
588 ) {
589 match alpha {
590 Some(alpha) => {
591 if alpha == E::faer_one() {
592 return matvec_colmajor::with_simd_and_offset(simd, conj, acc, a, b, offset);
593 }
594
595 let (a_head, a_body, a_tail) = simd.as_aligned_simd(a, offset);
596 let (acc_head, acc_body, acc_tail) = simd.as_aligned_simd_mut(acc, offset);
597 let b = simd.splat(b);
598 let alpha = simd.splat(alpha);
599
600 #[inline(always)]
601 pub fn process<C: ConjTy, E: ComplexField, S: Simd>(
602 simd: SimdFor<E, S>,
603 conj: C,
604 mut acc: impl Write<Output = SimdGroupFor<E, S>>,
605 a: impl Read<Output = SimdGroupFor<E, S>>,
606 b: SimdGroupFor<E, S>,
607 alpha: SimdGroupFor<E, S>,
608 ) {
609 acc.write(simd.conditional_conj_mul_add_e(
610 conj,
611 a.read_or(simd.splat(E::faer_zero())),
612 b,
613 simd.mul(alpha, acc.read_or(simd.splat(E::faer_zero()))),
614 ))
615 }
616
617 process(simd, conj, acc_head, a_head, b, alpha);
618 for (acc, a) in acc_body.into_mut_iter().zip(a_body.into_ref_iter()) {
619 process(simd, conj, acc, a, b, alpha);
620 }
621 process(simd, conj, acc_tail, a_tail, b, alpha);
622 }
623 None => {
624 let (a_head, a_body, a_tail) = simd.as_aligned_simd(a, offset);
625 let (acc_head, acc_body, acc_tail) = simd.as_aligned_simd_mut(acc, offset);
626 let b = simd.splat(b);
627
628 #[inline(always)]
629 pub fn process<C: ConjTy, E: ComplexField, S: Simd>(
630 simd: SimdFor<E, S>,
631 conj: C,
632 mut acc: impl Write<Output = SimdGroupFor<E, S>>,
633 a: impl Read<Output = SimdGroupFor<E, S>>,
634 b: SimdGroupFor<E, S>,
635 ) {
636 acc.write(simd.conditional_conj_mul(
637 conj,
638 a.read_or(simd.splat(E::faer_zero())),
639 b,
640 ))
641 }
642
643 process(simd, conj, acc_head, a_head, b);
644 for (acc, a) in acc_body.into_mut_iter().zip(a_body.into_ref_iter()) {
645 process(simd, conj, acc, a, b);
646 }
647 process(simd, conj, acc_tail, a_tail, b);
648 }
649 }
650 }
651
652 impl<C: ConjTy, E: ComplexField> pulp::WithSimd for Impl<'_, C, E> {
653 type Output = ();
654
655 #[inline(always)]
656 fn with_simd<S: Simd>(self, simd: S) -> Self::Output {
657 let simd = SimdFor::new(simd);
658 with_simd_and_offset(
659 simd,
660 self.conj,
661 self.acc,
662 self.a,
663 self.b,
664 self.alpha,
665 simd.align_offset(self.a),
666 )
667 }
668 }
669
670 fn outer_prod_with_conj_impl<E: ComplexField>(
671 acc: MatMut<'_, E>,
672 a: MatRef<'_, E>,
673 conj_a: Conj,
674 b: MatRef<'_, E>,
675 conj_b: Conj,
676 alpha: Option<E>,
677 beta: E,
678 ) {
679 let m = acc.nrows();
680 let n = acc.ncols();
681
682 assert!(all(
683 a.nrows() == m,
684 a.ncols() == 1,
685 b.nrows() == n,
686 b.ncols() == 1,
687 acc.row_stride() == 1,
688 a.row_stride() == 1,
689 ));
690
691 let mut acc = acc;
692
693 let arch = E::Simd::default();
694
695 let a = SliceGroup::new(a.try_get_contiguous_col(0));
696
697 for j in 0..n {
698 let acc = SliceGroupMut::new(acc.rb_mut().try_get_contiguous_col_mut(j));
699 let b = b.read(j, 0);
700 let b = match conj_b {
701 Conj::Yes => b.faer_conj(),
702 Conj::No => b,
703 };
704 let b = b.faer_mul(beta);
705 match conj_a {
706 Conj::Yes => arch.dispatch(Impl {
707 conj: YesConj,
708 acc,
709 a,
710 b,
711 alpha,
712 }),
713 Conj::No => arch.dispatch(Impl {
714 conj: NoConj,
715 acc,
716 a,
717 b,
718 alpha,
719 }),
720 }
721 }
722 }
723
724 pub fn outer_prod_with_conj<E: ComplexField>(
725 acc: MatMut<'_, E>,
726 lhs: MatRef<'_, E>,
727 conj_lhs: Conj,
728 rhs: MatRef<'_, E>,
729 conj_rhs: Conj,
730 alpha: Option<E>,
731 beta: E,
732 ) {
733 let mut acc = acc;
734 let mut a = lhs;
735 let mut b = rhs;
736 let mut conj_a = conj_lhs;
737 let mut conj_b = conj_rhs;
738
739 if acc.row_stride() < 0 {
740 acc = acc.reverse_rows_mut();
741 a = a.reverse_rows();
742 }
743 if acc.col_stride() < 0 {
744 acc = acc.reverse_cols_mut();
745 b = b.reverse_rows();
746 }
747
748 if acc.row_stride() > a.col_stride() {
749 acc = acc.transpose_mut();
750 core::mem::swap(&mut a, &mut b);
751 core::mem::swap(&mut conj_a, &mut conj_b);
752 }
753
754 if acc.row_stride() == 1 {
755 if a.row_stride() == 1 {
756 outer_prod_with_conj_impl(acc, a, conj_a, b, conj_b, alpha, beta);
757 } else {
758 outer_prod_with_conj_impl(
759 acc,
760 a.to_owned().as_ref(),
761 conj_a,
762 b,
763 conj_b,
764 alpha,
765 beta,
766 );
767 }
768 } else {
769 let m = acc.nrows();
770 let n = acc.ncols();
771 match alpha {
772 Some(alpha) => {
773 for j in 0..n {
774 let b = b.read(j, 0);
775 let b = match conj_b {
776 Conj::Yes => b.faer_conj(),
777 Conj::No => b,
778 };
779 let b = b.faer_mul(beta);
780 match conj_a {
781 Conj::Yes => {
782 for i in 0..m {
783 let ab = a.read(i, 0).faer_conj().faer_mul(b);
784 acc.write(
785 i,
786 j,
787 E::faer_add(acc.read(i, j).faer_mul(alpha), ab),
788 );
789 }
790 }
791 Conj::No => {
792 for i in 0..m {
793 let ab = a.read(i, 0).faer_mul(b);
794 acc.write(
795 i,
796 j,
797 E::faer_add(acc.read(i, j).faer_mul(alpha), ab),
798 );
799 }
800 }
801 }
802 }
803 }
804 None => {
805 for j in 0..n {
806 let b = b.read(j, 0);
807 let b = match conj_b {
808 Conj::Yes => b.faer_conj(),
809 Conj::No => b,
810 };
811 let b = b.faer_mul(beta);
812 match conj_a {
813 Conj::Yes => {
814 for i in 0..m {
815 acc.write(i, j, a.read(i, 0).faer_conj().faer_mul(b));
816 }
817 }
818 Conj::No => {
819 for i in 0..m {
820 acc.write(i, j, a.read(i, 0).faer_mul(b));
821 }
822 }
823 }
824 }
825 }
826 }
827 }
828 }
829}
830
/// Maximum number of columns of `b`/`acc` processed per outer pass of `matmul_with_conj_impl`.
const NC: usize = 2048;
/// Maximum number of rows of `a`/`acc` covered by a single parallel job in
/// `matmul_with_conj_impl`.
const MC: usize = 48;
/// Maximum depth (the shared dimension `k`) processed per pass of `matmul_with_conj_impl`.
const KC: usize = 64;
834
835struct SimdLaneCount<E: ComplexField> {
836 __marker: PhantomData<fn() -> E>,
837}
838impl<E: ComplexField> pulp::WithSimd for SimdLaneCount<E> {
839 type Output = usize;
840
841 fn with_simd<S: Simd>(self, simd: S) -> Self::Output {
842 let _ = simd;
843 core::mem::size_of::<SimdUnitFor<E, S>>() / core::mem::size_of::<UnitFor<E>>()
844 }
845}
846
/// Register-blocked matmul microkernel: `acc += a * b` on an `mr x nr` tile, where
/// `mr = MR_DIV_N * lane_count` and `nr = NR`.
///
/// `b` is conjugated when `conj_b` is `YesConj`. `a` is `mr x k` and `b` is `k x nr`; `acc`,
/// `a` and `b` must all have unit row stride (checked by an assertion in `with_simd`).
struct Ukr<'a, const MR_DIV_N: usize, const NR: usize, CB: ConjTy, E: ComplexField> {
    conj_b: CB,
    acc: MatMut<'a, E>,
    a: MatRef<'a, E>,
    b: MatRef<'a, E>,
}

impl<const MR_DIV_N: usize, const NR: usize, CB: ConjTy, E: ComplexField> pulp::WithSimd
    for Ukr<'_, MR_DIV_N, NR, CB, E>
{
    type Output = ();

    #[inline(always)]
    fn with_simd<S: Simd>(self, simd: S) -> Self::Output {
        let Self {
            mut acc,
            a,
            b,
            conj_b,
        } = self;
        // Number of scalar units of `E` per SIMD register for this backend.
        let lane_count =
            core::mem::size_of::<SimdUnitFor<E, S>>() / core::mem::size_of::<UnitFor<E>>();

        let mr = MR_DIV_N * lane_count;
        let nr = NR;

        // The raw pointer reads/writes below rely on these shapes and unit row strides.
        assert!(all(
            acc.nrows() == mr,
            acc.ncols() == nr,
            a.nrows() == mr,
            b.ncols() == nr,
            a.ncols() == b.nrows(),
            a.row_stride() == 1,
            b.row_stride() == 1,
            acc.row_stride() == 1
        ));

        let k = a.ncols();
        // `NR x MR_DIV_N` tile of SIMD registers, accumulated over the whole depth and only
        // added into `acc` at the end.
        let mut local_acc = [[E::faer_simd_splat(simd, E::faer_zero()); MR_DIV_N]; NR];
        let simd = SimdFor::<E, S>::new(simd);

        // SAFETY: after the assertion above, column `depth` of `a` holds `mr` contiguous units,
        // so the `MR_DIV_N` register loads at offsets `i * lane_count` stay inside it. The same
        // holds for the `acc` column writes at the end. `b.ptr_at(depth, j)` is in range for
        // `depth < k`, `j < NR`. NOTE(review): the `*const SimdUnitFor<E, S>` dereferences also
        // assume the SIMD unit type has no stricter alignment than `UnitFor<E>` — confirm
        // against pulp's type definitions.
        unsafe {
            // Rank-one update for one depth index: load `MR_DIV_N` registers from column
            // `depth` of `a`, then for each of the `NR` columns broadcast `b[depth, j]` and
            // multiply-add it into the tile.
            let mut one_iter = {
                #[inline(always)]
                |depth| {
                    let a = a.ptr_inbounds_at(0, depth);

                    let mut a_uninit = [MaybeUninit::<SimdGroupFor<E, S>>::uninit(); MR_DIV_N];

                    let mut i = 0usize;
                    loop {
                        if i == MR_DIV_N {
                            break;
                        }
                        a_uninit[i] = MaybeUninit::new(into_copy::<E, _>(E::faer_map(
                            E::faer_copy(&a),
                            #[inline(always)]
                            |ptr| *(ptr.add(i * lane_count) as *const SimdUnitFor<E, S>),
                        )));
                        i += 1;
                    }
                    // Every element of `a_uninit` was initialized by the loop above.
                    let a: [SimdGroupFor<E, S>; MR_DIV_N] = transmute_unchecked::<
                        [MaybeUninit<SimdGroupFor<E, S>>; MR_DIV_N],
                        [SimdGroupFor<E, S>; MR_DIV_N],
                    >(a_uninit);

                    let mut j = 0usize;
                    loop {
                        if j == NR {
                            break;
                        }
                        let b = simd.splat(E::faer_from_units(E::faer_map(
                            b.ptr_at(depth, j),
                            #[inline(always)]
                            |ptr| *ptr,
                        )));
                        let mut i = 0;
                        loop {
                            if i == MR_DIV_N {
                                break;
                            }
                            let local_acc = &mut local_acc[j][i];
                            // `b` is passed first, so it is the operand conjugated by `conj_b`.
                            *local_acc =
                                simd.conditional_conj_mul_add_e(conj_b, b, a[i], *local_acc);
                            i += 1;
                        }
                        j += 1;
                    }
                }
            };

            // Depth loop unrolled by 4, followed by the remainder.
            let mut depth = 0;
            while depth < k / 4 * 4 {
                one_iter(depth);
                one_iter(depth + 1);
                one_iter(depth + 2);
                one_iter(depth + 3);
                depth += 4;
            }
            while depth < k {
                one_iter(depth);
                depth += 1;
            }

            // Add the register tile into `acc` (read-modify-write through raw pointers).
            let mut j = 0usize;
            loop {
                if j == NR {
                    break;
                }
                let mut i = 0usize;
                loop {
                    if i == MR_DIV_N {
                        break;
                    }
                    let acc = acc.rb_mut().ptr_inbounds_at_mut(i * lane_count, j);
                    let mut acc_value = into_copy::<E, _>(E::faer_map(E::faer_copy(&acc), |acc| {
                        *(acc as *const SimdUnitFor<E, S>)
                    }));
                    acc_value = simd.add(acc_value, local_acc[j][i]);
                    E::faer_map(
                        E::faer_zip(acc, from_copy::<E, _>(acc_value)),
                        #[inline(always)]
                        |(acc, new_acc)| *(acc as *mut SimdUnitFor<E, S>) = new_acc,
                    );
                    i += 1;
                }
                j += 1;
            }
        }
    }
}
978
/// Returns the smaller of `a` and `b`.
#[inline]
fn min(a: usize, b: usize) -> usize {
    core::cmp::min(a, b)
}
983
984struct MicroKernelShape<E: ComplexField> {
985 __marker: PhantomData<fn() -> E>,
986}
987
988impl<E: ComplexField> MicroKernelShape<E> {
989 const SHAPE: (usize, usize) = {
990 if E::N_COMPONENTS <= 2 {
991 (2, 2)
992 } else if E::N_COMPONENTS == 4 {
993 (2, 1)
994 } else {
995 (1, 1)
996 }
997 };
998
999 const MAX_MR_DIV_N: usize = Self::SHAPE.0;
1000 const MAX_NR: usize = Self::SHAPE.1;
1001
1002 const IS_2X2: bool = Self::MAX_MR_DIV_N == 2 && Self::MAX_NR == 2;
1003 const IS_2X1: bool = Self::MAX_MR_DIV_N == 2 && Self::MAX_NR == 1;
1004 const IS_1X1: bool = Self::MAX_MR_DIV_N == 2 && Self::MAX_NR == 1;
1005}
1006
1007fn matmul_with_conj_impl<E: ComplexField>(
1012 acc: MatMut<'_, E>,
1013 a: MatRef<'_, E>,
1014 b: MatRef<'_, E>,
1015 conj_b: Conj,
1016 parallelism: Parallelism,
1017) {
1018 use coe::Coerce;
1019 use num_complex::Complex;
1020 if coe::is_same::<E, Complex<E::Real>>() {
1021 let acc: MatMut<'_, Complex<E::Real>> = acc.coerce();
1022 let a: MatRef<'_, Complex<E::Real>> = a.coerce();
1023 let b: MatRef<'_, Complex<E::Real>> = b.coerce();
1024
1025 let Complex {
1026 re: mut acc_re,
1027 im: mut acc_im,
1028 } = acc.real_imag_mut();
1029 let Complex { re: a_re, im: a_im } = a.real_imag();
1030 let Complex { re: b_re, im: b_im } = b.real_imag();
1031
1032 let real_matmul = |acc: MatMut<'_, E::Real>,
1033 a: MatRef<'_, E::Real>,
1034 b: MatRef<'_, E::Real>,
1035 beta: E::Real| {
1036 matmul_with_conj(
1037 acc,
1038 a,
1039 Conj::No,
1040 b,
1041 Conj::No,
1042 Some(E::Real::faer_one()),
1043 beta,
1044 parallelism,
1045 )
1046 };
1047
1048 match conj_b {
1049 Conj::Yes => {
1050 real_matmul(acc_re.rb_mut(), a_re, b_re, E::Real::faer_one());
1051 real_matmul(acc_re.rb_mut(), a_im, b_im, E::Real::faer_one());
1052 real_matmul(acc_im.rb_mut(), a_re, b_im, E::Real::faer_one().faer_neg());
1053 real_matmul(acc_im.rb_mut(), a_im, b_re, E::Real::faer_one());
1054 }
1055 Conj::No => {
1056 real_matmul(acc_re.rb_mut(), a_re, b_re, E::Real::faer_one());
1057 real_matmul(acc_re.rb_mut(), a_im, b_im, E::Real::faer_one().faer_neg());
1058 real_matmul(acc_im.rb_mut(), a_re, b_im, E::Real::faer_one());
1059 real_matmul(acc_im.rb_mut(), a_im, b_re, E::Real::faer_one());
1060 }
1061 }
1062
1063 return;
1064 }
1065
1066 let m = acc.nrows();
1067 let n = acc.ncols();
1068 let k = a.ncols();
1069
1070 let arch = E::Simd::default();
1071 let lane_count = arch.dispatch(SimdLaneCount::<E> {
1072 __marker: PhantomData,
1073 });
1074
1075 let nr = MicroKernelShape::<E>::MAX_NR;
1076 let mr_div_n = MicroKernelShape::<E>::MAX_MR_DIV_N;
1077 let mr = mr_div_n * lane_count;
1078
1079 assert!(all(
1080 acc.row_stride() == 1,
1081 a.row_stride() == 1,
1082 b.row_stride() == 1,
1083 m % lane_count == 0,
1084 ));
1085
1086 let mut acc = acc;
1087
1088 let mut col_outer = 0usize;
1089 while col_outer < n {
1090 let n_chunk = min(NC, n - col_outer);
1091
1092 let b_panel = b.submatrix(0, col_outer, k, n_chunk);
1093 let acc = acc.rb_mut().submatrix_mut(0, col_outer, m, n_chunk);
1094
1095 let mut depth_outer = 0usize;
1096 while depth_outer < k {
1097 let k_chunk = min(KC, k - depth_outer);
1098
1099 let a_panel = a.submatrix(0, depth_outer, m, k_chunk);
1100 let b_block = b_panel.submatrix(depth_outer, 0, k_chunk, n_chunk);
1101
1102 let n_job_count = n_chunk.msrv_div_ceil(nr);
1103 let chunk_count = m.msrv_div_ceil(MC);
1104
1105 let job_count = n_job_count * chunk_count;
1106
1107 let job = |idx: usize| {
1108 assert!(all(
1109 acc.row_stride() == 1,
1110 a.row_stride() == 1,
1111 b.row_stride() == 1,
1112 ));
1113
1114 let col_inner = (idx % n_job_count) * nr;
1115 let row_outer = (idx / n_job_count) * MC;
1116 let m_chunk = min(MC, m - row_outer);
1117
1118 let mut row_inner = 0;
1119 let ncols = min(nr, n_chunk - col_inner);
1120 let ukr_j = ncols;
1121
1122 while row_inner < m_chunk {
1123 let nrows = min(mr, m_chunk - row_inner);
1124
1125 let ukr_i = nrows / lane_count;
1126
1127 let a = a_panel.submatrix(row_outer + row_inner, 0, nrows, k_chunk);
1128 let b = b_block.submatrix(0, col_inner, k_chunk, ncols);
1129 let acc = acc
1130 .rb()
1131 .submatrix(row_outer + row_inner, col_inner, nrows, ncols);
1132 let acc = unsafe { acc.const_cast() };
1133
1134 match conj_b {
1135 Conj::Yes => {
1136 let conj_b = YesConj;
1137 if MicroKernelShape::<E>::IS_2X2 {
1138 match (ukr_i, ukr_j) {
1139 (2, 2) => {
1140 arch.dispatch(Ukr::<2, 2, _, E> { conj_b, acc, a, b })
1141 }
1142 (2, 1) => {
1143 arch.dispatch(Ukr::<2, 1, _, E> { conj_b, acc, a, b })
1144 }
1145 (1, 2) => {
1146 arch.dispatch(Ukr::<1, 2, _, E> { conj_b, acc, a, b })
1147 }
1148 (1, 1) => {
1149 arch.dispatch(Ukr::<1, 1, _, E> { conj_b, acc, a, b })
1150 }
1151 _ => unreachable!(),
1152 }
1153 } else if MicroKernelShape::<E>::IS_2X1 {
1154 match (ukr_i, ukr_j) {
1155 (2, 1) => {
1156 arch.dispatch(Ukr::<2, 1, _, E> { conj_b, acc, a, b })
1157 }
1158 (1, 1) => {
1159 arch.dispatch(Ukr::<1, 1, _, E> { conj_b, acc, a, b })
1160 }
1161 _ => unreachable!(),
1162 }
1163 } else if MicroKernelShape::<E>::IS_1X1 {
1164 match (ukr_i, ukr_j) {
1165 (1, 1) => {
1166 arch.dispatch(Ukr::<1, 1, _, E> { conj_b, acc, a, b })
1167 }
1168 _ => unreachable!(),
1169 }
1170 } else {
1171 unreachable!()
1172 }
1173 }
1174 Conj::No => {
1175 let conj_b = NoConj;
1176 if MicroKernelShape::<E>::IS_2X2 {
1177 match (ukr_i, ukr_j) {
1178 (2, 2) => {
1179 arch.dispatch(Ukr::<2, 2, _, E> { conj_b, acc, a, b })
1180 }
1181 (2, 1) => {
1182 arch.dispatch(Ukr::<2, 1, _, E> { conj_b, acc, a, b })
1183 }
1184 (1, 2) => {
1185 arch.dispatch(Ukr::<1, 2, _, E> { conj_b, acc, a, b })
1186 }
1187 (1, 1) => {
1188 arch.dispatch(Ukr::<1, 1, _, E> { conj_b, acc, a, b })
1189 }
1190 _ => unreachable!(),
1191 }
1192 } else if MicroKernelShape::<E>::IS_2X1 {
1193 match (ukr_i, ukr_j) {
1194 (2, 1) => {
1195 arch.dispatch(Ukr::<2, 1, _, E> { conj_b, acc, a, b })
1196 }
1197 (1, 1) => {
1198 arch.dispatch(Ukr::<1, 1, _, E> { conj_b, acc, a, b })
1199 }
1200 _ => unreachable!(),
1201 }
1202 } else if MicroKernelShape::<E>::IS_1X1 {
1203 match (ukr_i, ukr_j) {
1204 (1, 1) => {
1205 arch.dispatch(Ukr::<1, 1, _, E> { conj_b, acc, a, b })
1206 }
1207 _ => unreachable!(),
1208 }
1209 } else {
1210 unreachable!()
1211 }
1212 }
1213 }
1214 row_inner += nrows;
1215 }
1216 };
1217
1218 crate::for_each_raw(job_count, job, parallelism);
1219
1220 depth_outer += k_chunk;
1221 }
1222
1223 col_outer += n_chunk;
1224 }
1225}
1226
/// Dense matrix-multiply dispatcher used by `matmul_with_conj`.
///
/// Computes `acc := alpha * acc + beta * op(lhs) * op(rhs)` when `alpha` is `Some`,
/// and `acc := beta * op(lhs) * op(rhs)` when it is `None` (the old contents of
/// `acc` are then not read). `op(x)` is the conjugate of `x` when the matching
/// `Conj` flag is `Conj::Yes`, and `x` itself otherwise.
///
/// Strategy, tried in this order:
/// 1. empty output: nothing to do;
/// 2. `1 × 1` output: a single inner product;
/// 3. `k == 1`: outer product; `n == 1` / `m == 1`: matrix-vector product;
/// 4. tiny problems (`m + n < 32 && k <= 6`): fully unrolled scalar loops;
/// 5. `f32` / `f64` / `c32` / `c64` when `_use_gemm` is set: the external `gemm` crate;
/// 6. otherwise: the generic SIMD kernel `matmul_with_conj_impl` on a padded copy.
///
/// `_use_gemm` is only honoured in `cfg(test)` builds; in every other build it is
/// shadowed by `true` before it is read.
///
/// # Panics
/// Panics if `acc` is not `lhs.nrows() × rhs.ncols()` or if `lhs.ncols() != rhs.nrows()`.
#[doc(hidden)]
pub fn matmul_with_conj_gemm_dispatch<E: ComplexField>(
    mut acc: MatMut<'_, E>,
    lhs: MatRef<'_, E>,
    conj_lhs: Conj,
    rhs: MatRef<'_, E>,
    conj_rhs: Conj,
    alpha: Option<E>,
    beta: E,
    parallelism: Parallelism,
    _use_gemm: bool,
) {
    assert!(all(
        acc.nrows() == lhs.nrows(),
        acc.ncols() == rhs.ncols(),
        lhs.ncols() == rhs.nrows(),
    ));

    let m = acc.nrows();
    let n = acc.ncols();
    let k = lhs.ncols();

    // Nothing to write.
    if m == 0 || n == 0 {
        return;
    }

    // 1x1 output: a single (conjugation-aware) inner product over the k dimension.
    if m == 1 && n == 1 {
        let mut acc = acc;
        let ab = inner_prod::inner_prod_with_conj(lhs.transpose(), conj_lhs, rhs, conj_rhs);
        match alpha {
            Some(alpha) => {
                acc.write(
                    0,
                    0,
                    E::faer_add(acc.read(0, 0).faer_mul(alpha), ab.faer_mul(beta)),
                );
            }
            None => {
                acc.write(0, 0, ab.faer_mul(beta));
            }
        }
        return;
    }

    // Rank-1 update.
    if k == 1 {
        outer_prod::outer_prod_with_conj(
            acc,
            lhs,
            conj_lhs,
            rhs.transpose(),
            conj_rhs,
            alpha,
            beta,
        );
        return;
    }
    // Single output column: matrix-vector product.
    if n == 1 {
        matvec::matvec_with_conj(acc, lhs, conj_lhs, rhs, conj_rhs, alpha, beta);
        return;
    }
    // Single output row: transpose the problem, acc^T = rhs^T * lhs^T.
    if m == 1 {
        matvec::matvec_with_conj(
            acc.transpose_mut(),
            rhs.transpose(),
            conj_rhs,
            lhs.transpose(),
            conj_lhs,
            alpha,
            beta,
        );
        return;
    }

    unsafe {
        // Tiny problems: unrolled scalar loops. Indices are in bounds because
        // i < m, j < n and depth < k, which match the asserted shapes.
        if m + n < 32 && k <= 6 {
            // `$term(i, j, depth)` must return the (possibly conjugated) product
            // lhs[i, depth] * rhs[depth, j]. The `1` arm is kept although k == 1
            // already returned above.
            macro_rules! small_gemm {
                ($term: expr) => {
                    let term = $term;
                    match k {
                        0 => match alpha {
                            Some(alpha) => {
                                for i in 0..m {
                                    for j in 0..n {
                                        acc.write_unchecked(
                                            i,
                                            j,
                                            acc.read_unchecked(i, j).faer_mul(alpha),
                                        )
                                    }
                                }
                            }
                            None => {
                                for i in 0..m {
                                    for j in 0..n {
                                        acc.write_unchecked(i, j, E::faer_zero())
                                    }
                                }
                            }
                        },
                        1 => match alpha {
                            Some(alpha) => {
                                for i in 0..m {
                                    for j in 0..n {
                                        let dot = term(i, j, 0);
                                        acc.write_unchecked(
                                            i,
                                            j,
                                            E::faer_add(
                                                acc.read_unchecked(i, j).faer_mul(alpha),
                                                dot.faer_mul(beta),
                                            ),
                                        )
                                    }
                                }
                            }
                            None => {
                                for i in 0..m {
                                    for j in 0..n {
                                        let dot = term(i, j, 0);
                                        acc.write_unchecked(i, j, dot.faer_mul(beta))
                                    }
                                }
                            }
                        },
                        2 => match alpha {
                            Some(alpha) => {
                                for i in 0..m {
                                    for j in 0..n {
                                        let dot = term(i, j, 0).faer_add(term(i, j, 1));
                                        acc.write_unchecked(
                                            i,
                                            j,
                                            E::faer_add(
                                                acc.read_unchecked(i, j).faer_mul(alpha),
                                                dot.faer_mul(beta),
                                            ),
                                        )
                                    }
                                }
                            }
                            None => {
                                for i in 0..m {
                                    for j in 0..n {
                                        let dot = term(i, j, 0).faer_add(term(i, j, 1));
                                        acc.write_unchecked(i, j, dot.faer_mul(beta))
                                    }
                                }
                            }
                        },
                        3 => match alpha {
                            Some(alpha) => {
                                for i in 0..m {
                                    for j in 0..n {
                                        let dot = term(i, j, 0)
                                            .faer_add(term(i, j, 1))
                                            .faer_add(term(i, j, 2));
                                        acc.write_unchecked(
                                            i,
                                            j,
                                            E::faer_add(
                                                acc.read_unchecked(i, j).faer_mul(alpha),
                                                dot.faer_mul(beta),
                                            ),
                                        )
                                    }
                                }
                            }
                            None => {
                                for i in 0..m {
                                    for j in 0..n {
                                        let dot = term(i, j, 0)
                                            .faer_add(term(i, j, 1))
                                            .faer_add(term(i, j, 2));
                                        acc.write_unchecked(i, j, dot.faer_mul(beta))
                                    }
                                }
                            }
                        },
                        4 => match alpha {
                            Some(alpha) => {
                                for i in 0..m {
                                    for j in 0..n {
                                        let dot = E::faer_add(
                                            E::faer_add(term(i, j, 0), term(i, j, 1)),
                                            E::faer_add(term(i, j, 2), term(i, j, 3)),
                                        );

                                        acc.write_unchecked(
                                            i,
                                            j,
                                            E::faer_add(
                                                acc.read_unchecked(i, j).faer_mul(alpha),
                                                dot.faer_mul(beta),
                                            ),
                                        )
                                    }
                                }
                            }
                            None => {
                                for i in 0..m {
                                    for j in 0..n {
                                        let dot = E::faer_add(
                                            E::faer_add(term(i, j, 0), term(i, j, 1)),
                                            E::faer_add(term(i, j, 2), term(i, j, 3)),
                                        );
                                        acc.write_unchecked(i, j, dot.faer_mul(beta))
                                    }
                                }
                            }
                        },
                        5 => match alpha {
                            Some(alpha) => {
                                for i in 0..m {
                                    for j in 0..n {
                                        let dot = E::faer_add(
                                            E::faer_add(term(i, j, 0), term(i, j, 1))
                                                .faer_add(term(i, j, 2)),
                                            E::faer_add(term(i, j, 3), term(i, j, 4)),
                                        );

                                        acc.write_unchecked(
                                            i,
                                            j,
                                            E::faer_add(
                                                acc.read_unchecked(i, j).faer_mul(alpha),
                                                dot.faer_mul(beta),
                                            ),
                                        )
                                    }
                                }
                            }
                            None => {
                                for i in 0..m {
                                    for j in 0..n {
                                        let dot = E::faer_add(
                                            E::faer_add(term(i, j, 0), term(i, j, 1))
                                                .faer_add(term(i, j, 2)),
                                            E::faer_add(term(i, j, 3), term(i, j, 4)),
                                        );
                                        acc.write_unchecked(i, j, dot.faer_mul(beta))
                                    }
                                }
                            }
                        },
                        6 => match alpha {
                            Some(alpha) => {
                                for i in 0..m {
                                    for j in 0..n {
                                        let dot = E::faer_add(
                                            E::faer_add(term(i, j, 0), term(i, j, 1))
                                                .faer_add(term(i, j, 2)),
                                            E::faer_add(term(i, j, 3), term(i, j, 4))
                                                .faer_add(term(i, j, 5)),
                                        );

                                        acc.write_unchecked(
                                            i,
                                            j,
                                            E::faer_add(
                                                acc.read_unchecked(i, j).faer_mul(alpha),
                                                dot.faer_mul(beta),
                                            ),
                                        )
                                    }
                                }
                            }
                            None => {
                                for i in 0..m {
                                    for j in 0..n {
                                        let dot = E::faer_add(
                                            E::faer_add(term(i, j, 0), term(i, j, 1))
                                                .faer_add(term(i, j, 2)),
                                            E::faer_add(term(i, j, 3), term(i, j, 4))
                                                .faer_add(term(i, j, 5)),
                                        );
                                        acc.write_unchecked(i, j, dot.faer_mul(beta))
                                    }
                                }
                            }
                        },
                        _ => unreachable!(),
                    }
                };
            }

            // One monomorphized closure per conjugation combination, so that the
            // conjugation decision is not re-evaluated inside the loops.
            match (conj_lhs, conj_rhs) {
                (Conj::Yes, Conj::Yes) => {
                    // conj(a) * conj(b) == conj(a * b)
                    let term = {
                        #[inline(always)]
                        |i, j, depth| {
                            (lhs.read_unchecked(i, depth)
                                .faer_mul(rhs.read_unchecked(depth, j)))
                            .faer_conj()
                        }
                    };
                    small_gemm!(term);
                }
                (Conj::Yes, Conj::No) => {
                    let term = {
                        #[inline(always)]
                        |i, j, depth| {
                            lhs.read_unchecked(i, depth)
                                .faer_conj()
                                .faer_mul(rhs.read_unchecked(depth, j))
                        }
                    };
                    small_gemm!(term);
                }
                (Conj::No, Conj::Yes) => {
                    let term = {
                        #[inline(always)]
                        |i, j, depth| {
                            lhs.read_unchecked(i, depth)
                                .faer_mul(rhs.read_unchecked(depth, j).faer_conj())
                        }
                    };
                    small_gemm!(term);
                }
                (Conj::No, Conj::No) => {
                    let term = {
                        #[inline(always)]
                        |i, j, depth| {
                            lhs.read_unchecked(i, depth)
                                .faer_mul(rhs.read_unchecked(depth, j))
                        }
                    };
                    small_gemm!(term);
                }
            }
            return;
        }
    }

    // Outside of test builds the external `gemm` backend is always allowed.
    #[cfg(not(test))]
    let _use_gemm = true;

    if _use_gemm {
        // Translate faer's parallelism setting into the `gemm` crate's; `Rayon(0)`
        // means "use rayon's current thread count".
        let gemm_parallelism = match parallelism {
            Parallelism::None => gemm::Parallelism::None,
            #[cfg(feature = "rayon")]
            Parallelism::Rayon(0) => gemm::Parallelism::Rayon(rayon::current_num_threads()),
            #[cfg(feature = "rayon")]
            Parallelism::Rayon(n_threads) => gemm::Parallelism::Rayon(n_threads),
        };
        // For each natively supported scalar type, reinterpret the views and call
        // `gemm::gemm`; `alpha.is_some()` tells it whether to read the old `acc`.
        if coe::is_same::<f32, E>() {
            let mut acc: MatMut<'_, f32> = coe::coerce(acc);
            let a: MatRef<'_, f32> = coe::coerce(lhs);
            let b: MatRef<'_, f32> = coe::coerce(rhs);
            let alpha: Option<f32> = coe::coerce_static(alpha);
            let beta: f32 = coe::coerce_static(beta);
            unsafe {
                gemm::gemm(
                    m,
                    n,
                    k,
                    acc.rb_mut().as_ptr_mut(),
                    acc.col_stride(),
                    acc.row_stride(),
                    alpha.is_some(),
                    a.as_ptr(),
                    a.col_stride(),
                    a.row_stride(),
                    b.as_ptr(),
                    b.col_stride(),
                    b.row_stride(),
                    alpha.unwrap_or(0.0),
                    beta,
                    false,
                    conj_lhs == Conj::Yes,
                    conj_rhs == Conj::Yes,
                    gemm_parallelism,
                )
            };
            return;
        }
        if coe::is_same::<f64, E>() {
            let mut acc: MatMut<'_, f64> = coe::coerce(acc);
            let a: MatRef<'_, f64> = coe::coerce(lhs);
            let b: MatRef<'_, f64> = coe::coerce(rhs);
            let alpha: Option<f64> = coe::coerce_static(alpha);
            let beta: f64 = coe::coerce_static(beta);
            unsafe {
                gemm::gemm(
                    m,
                    n,
                    k,
                    acc.rb_mut().as_ptr_mut(),
                    acc.col_stride(),
                    acc.row_stride(),
                    alpha.is_some(),
                    a.as_ptr(),
                    a.col_stride(),
                    a.row_stride(),
                    b.as_ptr(),
                    b.col_stride(),
                    b.row_stride(),
                    alpha.unwrap_or(0.0),
                    beta,
                    false,
                    conj_lhs == Conj::Yes,
                    conj_rhs == Conj::Yes,
                    gemm_parallelism,
                )
            };
            return;
        }
        if coe::is_same::<c32, E>() {
            let mut acc: MatMut<'_, c32> = coe::coerce(acc);
            let a: MatRef<'_, c32> = coe::coerce(lhs);
            let b: MatRef<'_, c32> = coe::coerce(rhs);
            let alpha: Option<c32> = coe::coerce_static(alpha);
            let beta: c32 = coe::coerce_static(beta);
            unsafe {
                gemm::gemm(
                    m,
                    n,
                    k,
                    acc.rb_mut().as_ptr_mut() as *mut gemm::c32,
                    acc.col_stride(),
                    acc.row_stride(),
                    alpha.is_some(),
                    a.as_ptr() as *const gemm::c32,
                    a.col_stride(),
                    a.row_stride(),
                    b.as_ptr() as *const gemm::c32,
                    b.col_stride(),
                    b.row_stride(),
                    alpha.unwrap_or(c32 { re: 0.0, im: 0.0 }).into(),
                    beta.into(),
                    false,
                    conj_lhs == Conj::Yes,
                    conj_rhs == Conj::Yes,
                    gemm_parallelism,
                )
            };
            return;
        }
        if coe::is_same::<c64, E>() {
            let mut acc: MatMut<'_, c64> = coe::coerce(acc);
            let a: MatRef<'_, c64> = coe::coerce(lhs);
            let b: MatRef<'_, c64> = coe::coerce(rhs);
            let alpha: Option<c64> = coe::coerce_static(alpha);
            let beta: c64 = coe::coerce_static(beta);
            unsafe {
                gemm::gemm(
                    m,
                    n,
                    k,
                    acc.rb_mut().as_ptr_mut() as *mut gemm::c64,
                    acc.col_stride(),
                    acc.row_stride(),
                    alpha.is_some(),
                    a.as_ptr() as *const gemm::c64,
                    a.col_stride(),
                    a.row_stride(),
                    b.as_ptr() as *const gemm::c64,
                    b.col_stride(),
                    b.row_stride(),
                    alpha.unwrap_or(c64 { re: 0.0, im: 0.0 }).into(),
                    beta.into(),
                    false,
                    conj_lhs == Conj::Yes,
                    conj_rhs == Conj::Yes,
                    gemm_parallelism,
                )
            };
            return;
        }
    }

    // Generic SIMD fallback.
    let arch = E::Simd::default();
    let lane_count = arch.dispatch(SimdLaneCount::<E> {
        __marker: PhantomData,
    });

    let mut a = lhs;
    let mut b = rhs;
    let mut conj_a = conj_lhs;
    let mut conj_b = conj_rhs;

    // Work on the orientation where the output has no more rows than columns:
    // acc^T = rhs^T * lhs^T.
    if n < m {
        (a, b) = (b.transpose(), a.transpose());
        core::mem::swap(&mut conj_a, &mut conj_b);
        acc = acc.transpose_mut();
    }

    // Reversing the k dimension of both operands leaves the product unchanged and
    // makes b's row stride non-negative.
    if b.row_stride() < 0 {
        a = a.reverse_cols();
        b = b.reverse_rows();
    }

    let m = acc.nrows();
    let n = acc.ncols();

    // Pad the row count of `a` with zero rows up to a multiple of the SIMD lane count.
    let padded_m = m.msrv_checked_next_multiple_of(lane_count).unwrap();

    let mut a_copy = a.to_owned();
    a_copy.resize_with(padded_m, k, |_, _| E::faer_zero());
    let a_copy = a_copy.as_ref();
    let mut tmp = crate::Mat::<E>::zeros(padded_m, n);
    // The kernel can only conjugate `b`. It computes tmp = a * op'(b), where op'
    // conjugates iff exactly one operand is conjugated; when conj_a is set,
    // conj(tmp) == conj(a) * op(b) is applied while writing back below.
    let tmp_conj_b = match (conj_a, conj_b) {
        (Conj::Yes, Conj::Yes) | (Conj::No, Conj::No) => Conj::No,
        (Conj::Yes, Conj::No) | (Conj::No, Conj::Yes) => Conj::Yes,
    };
    // The kernel is called directly only when b has unit row stride; otherwise a
    // copy is made (presumably `to_owned` yields a contiguous column-major matrix).
    if b.row_stride() == 1 {
        matmul_with_conj_impl(tmp.as_mut(), a_copy, b, tmp_conj_b, parallelism);
    } else {
        let b = b.to_owned();
        matmul_with_conj_impl(tmp.as_mut(), a_copy, b.as_ref(), tmp_conj_b, parallelism);
    }

    // Drop the padding rows.
    let tmp = tmp.as_ref().subrows(0, m);

    // acc := alpha * acc + beta * op(tmp), or acc := beta * op(tmp) when alpha is None.
    match alpha {
        Some(alpha) => match conj_a {
            Conj::Yes => zipped!(acc, tmp).for_each(|unzipped!(mut acc, tmp)| {
                acc.write(E::faer_add(
                    acc.read().faer_mul(alpha),
                    tmp.read().faer_conj().faer_mul(beta),
                ))
            }),
            Conj::No => zipped!(acc, tmp).for_each(|unzipped!(mut acc, tmp)| {
                acc.write(E::faer_add(
                    acc.read().faer_mul(alpha),
                    tmp.read().faer_mul(beta),
                ))
            }),
        },
        None => match conj_a {
            Conj::Yes => {
                zipped!(acc, tmp).for_each(|unzipped!(mut acc, tmp)| {
                    acc.write(tmp.read().faer_conj().faer_mul(beta))
                });
            }
            Conj::No => {
                zipped!(acc, tmp)
                    .for_each(|unzipped!(mut acc, tmp)| acc.write(tmp.read().faer_mul(beta)));
            }
        },
    }
}
1768
1769#[inline]
1826#[track_caller]
1827pub fn matmul_with_conj<E: ComplexField>(
1828 acc: MatMut<'_, E>,
1829 lhs: MatRef<'_, E>,
1830 conj_lhs: Conj,
1831 rhs: MatRef<'_, E>,
1832 conj_rhs: Conj,
1833 alpha: Option<E>,
1834 beta: E,
1835 parallelism: Parallelism,
1836) {
1837 assert!(all(
1838 acc.nrows() == lhs.nrows(),
1839 acc.ncols() == rhs.ncols(),
1840 lhs.ncols() == rhs.nrows(),
1841 ));
1842 matmul_with_conj_gemm_dispatch(
1843 acc,
1844 lhs,
1845 conj_lhs,
1846 rhs,
1847 conj_rhs,
1848 alpha,
1849 beta,
1850 parallelism,
1851 true,
1852 );
1853}
1854
1855#[track_caller]
1904pub fn matmul<E: ComplexField, LhsE: Conjugate<Canonical = E>, RhsE: Conjugate<Canonical = E>>(
1905 acc: MatMut<'_, E>,
1906 lhs: MatRef<'_, LhsE>,
1907 rhs: MatRef<'_, RhsE>,
1908 alpha: Option<E>,
1909 beta: E,
1910 parallelism: Parallelism,
1911) {
1912 let (lhs, conj_lhs) = lhs.canonicalize();
1913 let (rhs, conj_rhs) = rhs.canonicalize();
1914 matmul_with_conj::<E>(acc, lhs, conj_lhs, rhs, conj_rhs, alpha, beta, parallelism);
1915}
1916
// Declares `$name` as a `MatMut<'_, $ty>` of size `$nrows × $ncols` backed by a
// local 16×16 buffer (one per entity unit), with every element zero-initialized.
//
// The buffer is laid out with row stride 1 and column stride 16. The view is then
// transposed when `|$cs| < |$rs|`, and rows/columns are reversed when the
// corresponding stride is -1. The intent appears to be mimicking the memory
// layout of a reference matrix with strides `($rs, $cs)`.
//
// NOTE(review): no check that `$nrows, $ncols <= 16` is visible here; the callers
// guard with `n <= 16`. The transpose swaps dimensions, so the result only has
// shape `$nrows × $ncols` when both are equal (all uses in this file pass `n, n`).
macro_rules! stack_mat_16x16_begin {
    ($name: ident, $nrows: expr, $ncols: expr, $rs: expr, $cs: expr, $ty: ty) => {
        let __nrows: usize = $nrows;
        let __ncols: usize = $ncols;
        let __rs: isize = $rs;
        let __cs: isize = $cs;
        // One uninitialized 256-element array per unit of the entity, viewed as
        // an array of `MaybeUninit` elements.
        let mut __data = <$ty as $crate::Entity>::faer_map(
            <$ty as $crate::Entity>::UNIT,
            #[inline(always)]
            |()| unsafe {
                $crate::transmute_unchecked::<
                    ::core::mem::MaybeUninit<[<$ty as $crate::Entity>::Unit; 16 * 16]>,
                    [::core::mem::MaybeUninit<<$ty as $crate::Entity>::Unit>; 16 * 16],
                >(::core::mem::MaybeUninit::<
                    [<$ty as $crate::Entity>::Unit; 16 * 16],
                >::uninit())
            },
        );

        // Fill every slot with the matching unit of zero, so the buffer is fully
        // initialized before it is exposed as a matrix.
        <$ty as $crate::Entity>::faer_map(
            <$ty as $crate::Entity>::faer_zip(
                <$ty as $crate::Entity>::faer_as_mut(&mut __data),
                <$ty as $crate::Entity>::faer_into_units(<$ty as $crate::ComplexField>::faer_zero()),
            ),
            #[inline(always)]
            |(__data, zero)| {
                let __data: &mut _ = __data;
                for __data in __data {
                    let __data : &mut _ = __data;
                    *__data = ::core::mem::MaybeUninit::new(::core::clone::Clone::clone(&zero));
                }
            },
        );
        // Raw element pointers into each unit's buffer.
        let mut __data =
            <$ty as $crate::Entity>::faer_map(<$ty as $crate::Entity>::faer_as_mut(&mut __data), |__data: &mut _| {
                (__data as *mut [::core::mem::MaybeUninit<<$ty as $crate::Entity>::Unit>; 16 * 16]
                    as *mut <$ty as $crate::Entity>::Unit)
            });

        // Column-major view with leading dimension 16.
        let mut $name = unsafe {
            $crate::mat::from_raw_parts_mut::<'_, $ty>(__data, __nrows, __ncols, 1isize, 16isize)
        };

        if __cs.unsigned_abs() < __rs.unsigned_abs() {
            $name = $name.transpose_mut();
        }
        if __rs == -1 {
            $name = $name.reverse_rows_mut();
        }
        if __cs == -1 {
            $name = $name.reverse_cols_mut();
        }
    };
}
1971
1972pub mod triangular {
1975 use super::*;
1976 use crate::{assert, debug_assert, join_raw, zip::Diag};
1977
    /// How the diagonal of a triangular operand is interpreted.
    #[repr(u8)]
    #[derive(Copy, Clone, Debug)]
    pub(crate) enum DiagonalKind {
        /// The diagonal is implicitly zero (strictly triangular); stored diagonal
        /// values are not used.
        Zero,
        /// The diagonal is implicitly one (unit triangular); stored diagonal values
        /// are not used.
        Unit,
        /// The diagonal is read from the stored matrix.
        Generic,
    }
1985
1986 unsafe fn copy_lower<E: ComplexField>(
1987 mut dst: MatMut<'_, E>,
1988 src: MatRef<'_, E>,
1989 src_diag: DiagonalKind,
1990 ) {
1991 let n = dst.nrows();
1992 debug_assert!(n == dst.nrows());
1993 debug_assert!(n == dst.ncols());
1994 debug_assert!(n == src.nrows());
1995 debug_assert!(n == src.ncols());
1996
1997 let strict = match src_diag {
1998 DiagonalKind::Zero => {
1999 for j in 0..n {
2000 dst.write_unchecked(j, j, E::faer_zero());
2001 }
2002 true
2003 }
2004 DiagonalKind::Unit => {
2005 for j in 0..n {
2006 dst.write_unchecked(j, j, E::faer_one());
2007 }
2008 true
2009 }
2010 DiagonalKind::Generic => false,
2011 };
2012
2013 zipped!(dst.rb_mut())
2014 .for_each_triangular_upper(Diag::Skip, |unzipped!(mut dst)| dst.write(E::faer_zero()));
2015 zipped!(dst, src).for_each_triangular_lower(
2016 if strict { Diag::Skip } else { Diag::Include },
2017 |unzipped!(mut dst, src)| dst.write(src.read()),
2018 );
2019 }
2020
2021 unsafe fn accum_lower<E: ComplexField>(
2022 dst: MatMut<'_, E>,
2023 src: MatRef<'_, E>,
2024 skip_diag: bool,
2025 alpha: Option<E>,
2026 ) {
2027 let n = dst.nrows();
2028 debug_assert!(n == dst.nrows());
2029 debug_assert!(n == dst.ncols());
2030 debug_assert!(n == src.nrows());
2031 debug_assert!(n == src.ncols());
2032
2033 match alpha {
2034 Some(alpha) => {
2035 zipped!(dst, src).for_each_triangular_lower(
2036 if skip_diag { Diag::Skip } else { Diag::Include },
2037 |unzipped!(mut dst, src)| {
2038 dst.write(alpha.faer_mul(dst.read().faer_add(src.read())))
2039 },
2040 );
2041 }
2042 None => {
2043 zipped!(dst, src).for_each_triangular_lower(
2044 if skip_diag { Diag::Skip } else { Diag::Include },
2045 |unzipped!(mut dst, src)| dst.write(src.read()),
2046 );
2047 }
2048 }
2049 }
2050
2051 #[inline]
2052 unsafe fn copy_upper<E: ComplexField>(
2053 dst: MatMut<'_, E>,
2054 src: MatRef<'_, E>,
2055 src_diag: DiagonalKind,
2056 ) {
2057 copy_lower(dst.transpose_mut(), src.transpose(), src_diag)
2058 }
2059
2060 #[inline]
2061 unsafe fn mul<E: ComplexField>(
2062 dst: MatMut<'_, E>,
2063 lhs: MatRef<'_, E>,
2064 rhs: MatRef<'_, E>,
2065 alpha: Option<E>,
2066 beta: E,
2067 conj_lhs: Conj,
2068 conj_rhs: Conj,
2069 parallelism: Parallelism,
2070 ) {
2071 super::matmul_with_conj(dst, lhs, conj_lhs, rhs, conj_rhs, alpha, beta, parallelism);
2072 }
2073
    /// Writes the lower-triangular part of `beta * op(lhs) * op(rhs)` into `dst`,
    /// where `lhs` is dense and `rhs` is lower triangular.
    ///
    /// `dst`, `lhs` and `rhs` are all `n × n`. The diagonal of `rhs` is interpreted
    /// according to `rhs_diag`. Only the lower triangle of `dst` is written
    /// (strictly lower when `skip_diag` is set). The existing contents are scaled
    /// by `alpha` (`None` = overwrite), following the convention of `mul`.
    ///
    /// # Safety
    /// The shape requirements are only checked with `debug_assert!`.
    unsafe fn mat_x_lower_into_lower_impl_unchecked<E: ComplexField>(
        dst: MatMut<'_, E>,
        skip_diag: bool,
        lhs: MatRef<'_, E>,
        rhs: MatRef<'_, E>,
        rhs_diag: DiagonalKind,
        alpha: Option<E>,
        beta: E,
        conj_lhs: Conj,
        conj_rhs: Conj,
        parallelism: Parallelism,
    ) {
        let n = dst.nrows();
        debug_assert!(n == dst.nrows());
        debug_assert!(n == dst.ncols());
        debug_assert!(n == lhs.nrows());
        debug_assert!(n == lhs.ncols());
        debug_assert!(n == rhs.nrows());
        debug_assert!(n == rhs.ncols());

        if n <= 16 {
            // Small case: materialize `rhs` densely in a local 16x16 buffer, compute
            // the full product into a second local temporary, then merge only its
            // lower triangle into `dst`.
            let op = {
                #[inline(never)]
                || {
                    stack_mat_16x16_begin!(temp_dst, n, n, dst.row_stride(), dst.col_stride(), E);
                    stack_mat_16x16_begin!(temp_rhs, n, n, rhs.row_stride(), rhs.col_stride(), E);

                    copy_lower(temp_rhs.rb_mut(), rhs, rhs_diag);
                    mul(
                        temp_dst.rb_mut(),
                        lhs,
                        temp_rhs.rb(),
                        None,
                        beta,
                        conj_lhs,
                        conj_rhs,
                        parallelism,
                    );
                    accum_lower(dst, temp_dst.rb(), skip_diag, alpha);
                }
            };
            op();
        } else {
            let bs = n / 2;

            let (mut dst_top_left, _, mut dst_bot_left, dst_bot_right) = dst.split_at_mut(bs, bs);
            let (lhs_top_left, lhs_top_right, lhs_bot_left, lhs_bot_right) = lhs.split_at(bs, bs);
            let (rhs_top_left, _, rhs_bot_left, rhs_bot_right) = rhs.split_at(bs, bs);

            // With rhs = [[rhs_top_left, 0], [rhs_bot_left, rhs_bot_right]]:
            //   dst_bot_left  <- lhs_bot_right * rhs_bot_left   (dense x dense)
            //                  + lhs_bot_left  * rhs_top_left   (dense x lower)
            //   dst_bot_right <- lhs_bot_right * rhs_bot_right  (lower part only)
            //   dst_top_left  <- lhs_top_left  * rhs_top_left   (lower part only)
            //                  + lhs_top_right * rhs_bot_left   (lower part only)
            // The first contribution to each block applies `alpha`; the second
            // accumulates on top of it with `alpha = 1`. The top-right block of
            // `dst` lies above the diagonal and is never touched.
            mul(
                dst_bot_left.rb_mut(),
                lhs_bot_right,
                rhs_bot_left,
                alpha,
                beta,
                conj_lhs,
                conj_rhs,
                parallelism,
            );
            mat_x_lower_into_lower_impl_unchecked(
                dst_bot_right,
                skip_diag,
                lhs_bot_right,
                rhs_bot_right,
                rhs_diag,
                alpha,
                beta,
                conj_lhs,
                conj_rhs,
                parallelism,
            );

            mat_x_lower_into_lower_impl_unchecked(
                dst_top_left.rb_mut(),
                skip_diag,
                lhs_top_left,
                rhs_top_left,
                rhs_diag,
                alpha,
                beta,
                conj_lhs,
                conj_rhs,
                parallelism,
            );
            mat_x_mat_into_lower_impl_unchecked(
                dst_top_left,
                skip_diag,
                lhs_top_right,
                rhs_bot_left,
                Some(E::faer_one()),
                beta,
                conj_lhs,
                conj_rhs,
                parallelism,
            );
            mat_x_lower_impl_unchecked(
                dst_bot_left,
                lhs_bot_left,
                rhs_top_left,
                rhs_diag,
                Some(E::faer_one()),
                beta,
                conj_lhs,
                conj_rhs,
                parallelism,
            );
        }
    }
2189
    /// Dense-times-lower product: `dst := alpha * dst + beta * op(lhs) * op(rhs)`
    /// (`alpha == None` overwrites).
    ///
    /// `lhs` and `dst` are `m × n` and dense; `rhs` is `n × n` lower triangular,
    /// with its diagonal interpreted according to `rhs_diag`. All of `dst` is written.
    ///
    /// # Safety
    /// The shape requirements are only checked with `debug_assert!`.
    unsafe fn mat_x_lower_impl_unchecked<E: ComplexField>(
        dst: MatMut<'_, E>,
        lhs: MatRef<'_, E>,
        rhs: MatRef<'_, E>,
        rhs_diag: DiagonalKind,
        alpha: Option<E>,
        beta: E,
        conj_lhs: Conj,
        conj_rhs: Conj,
        parallelism: Parallelism,
    ) {
        let n = rhs.nrows();
        let m = lhs.nrows();
        debug_assert!(m == lhs.nrows());
        debug_assert!(n == lhs.ncols());
        debug_assert!(n == rhs.nrows());
        debug_assert!(n == rhs.ncols());
        debug_assert!(m == dst.nrows());
        debug_assert!(n == dst.ncols());

        // Only run the two independent halves in parallel when the work is large
        // enough (n * n * m >= 128 * 128 * 64).
        let join_parallelism = if n * n * m < 128 * 128 * 64 {
            Parallelism::None
        } else {
            parallelism
        };

        if n <= 16 {
            // Small case: make `rhs` dense in a local buffer and use a dense product.
            let op = {
                #[inline(never)]
                || {
                    stack_mat_16x16_begin!(temp_rhs, n, n, rhs.row_stride(), rhs.col_stride(), E);

                    copy_lower(temp_rhs.rb_mut(), rhs, rhs_diag);

                    mul(
                        dst,
                        lhs,
                        temp_rhs.rb(),
                        alpha,
                        beta,
                        conj_lhs,
                        conj_rhs,
                        parallelism,
                    );
                }
            };
            op();
        } else {
            let bs = n / 2;

            // With rhs = [[rhs_top_left, 0], [rhs_bot_left, rhs_bot_right]]:
            //   dst_left  <- lhs_left  * rhs_top_left  + lhs_right * rhs_bot_left
            //   dst_right <- lhs_right * rhs_bot_right
            // The two triangular products are independent and go through
            // `join_raw`; the dense correction is then added to `dst_left` with
            // `alpha = 1`.
            let (rhs_top_left, _, rhs_bot_left, rhs_bot_right) = rhs.split_at(bs, bs);
            let (lhs_left, lhs_right) = lhs.split_at_col(bs);
            let (mut dst_left, mut dst_right) = dst.split_at_col_mut(bs);

            join_raw(
                |parallelism| {
                    mat_x_lower_impl_unchecked(
                        dst_left.rb_mut(),
                        lhs_left,
                        rhs_top_left,
                        rhs_diag,
                        alpha,
                        beta,
                        conj_lhs,
                        conj_rhs,
                        parallelism,
                    )
                },
                |parallelism| {
                    mat_x_lower_impl_unchecked(
                        dst_right.rb_mut(),
                        lhs_right,
                        rhs_bot_right,
                        rhs_diag,
                        alpha,
                        beta,
                        conj_lhs,
                        conj_rhs,
                        parallelism,
                    )
                },
                join_parallelism,
            );
            mul(
                dst_left,
                lhs_right,
                rhs_bot_left,
                Some(E::faer_one()),
                beta,
                conj_lhs,
                conj_rhs,
                parallelism,
            );
        }
    }
2288
    /// Lower-times-lower product written into the lower triangle of `dst`.
    ///
    /// `dst`, `lhs` and `rhs` are all `n × n`. `lhs` and `rhs` are lower triangular,
    /// with diagonals interpreted according to `lhs_diag` / `rhs_diag`. Only the
    /// lower triangle of `dst` is written (strictly lower when `skip_diag` is set).
    /// The existing contents are scaled by `alpha` (`None` = overwrite), following
    /// the convention of `mul`.
    ///
    /// # Safety
    /// The shape requirements are only checked with `debug_assert!`.
    unsafe fn lower_x_lower_into_lower_impl_unchecked<E: ComplexField>(
        dst: MatMut<'_, E>,
        skip_diag: bool,
        lhs: MatRef<'_, E>,
        lhs_diag: DiagonalKind,
        rhs: MatRef<'_, E>,
        rhs_diag: DiagonalKind,
        alpha: Option<E>,
        beta: E,
        conj_lhs: Conj,
        conj_rhs: Conj,
        parallelism: Parallelism,
    ) {
        let n = dst.nrows();
        debug_assert!(n == lhs.nrows());
        debug_assert!(n == lhs.ncols());
        debug_assert!(n == rhs.nrows());
        debug_assert!(n == rhs.ncols());
        debug_assert!(n == dst.nrows());
        debug_assert!(n == dst.ncols());

        if n <= 16 {
            // Small case: densify both operands locally, multiply into a local
            // temporary, then merge its lower triangle into `dst`.
            let op = {
                #[inline(never)]
                || {
                    stack_mat_16x16_begin!(temp_dst, n, n, dst.row_stride(), dst.col_stride(), E);
                    stack_mat_16x16_begin!(temp_lhs, n, n, lhs.row_stride(), lhs.col_stride(), E);
                    stack_mat_16x16_begin!(temp_rhs, n, n, rhs.row_stride(), rhs.col_stride(), E);

                    copy_lower(temp_lhs.rb_mut(), lhs, lhs_diag);
                    copy_lower(temp_rhs.rb_mut(), rhs, rhs_diag);

                    mul(
                        temp_dst.rb_mut(),
                        temp_lhs.rb(),
                        temp_rhs.rb(),
                        None,
                        beta,
                        conj_lhs,
                        conj_rhs,
                        parallelism,
                    );
                    accum_lower(dst, temp_dst.rb(), skip_diag, alpha);
                }
            };
            op();
        } else {
            let bs = n / 2;

            let (dst_top_left, _, mut dst_bot_left, dst_bot_right) = dst.split_at_mut(bs, bs);
            let (lhs_top_left, _, lhs_bot_left, lhs_bot_right) = lhs.split_at(bs, bs);
            let (rhs_top_left, _, rhs_bot_left, rhs_bot_right) = rhs.split_at(bs, bs);

            // With both operands lower triangular:
            //   dst_top_left  <- lhs_top_left  * rhs_top_left   (lower x lower, lower part)
            //   dst_bot_left  <- lhs_bot_left  * rhs_top_left   (dense x lower)
            //                  + lhs_bot_right * rhs_bot_left   (lower x dense)
            //   dst_bot_right <- lhs_bot_right * rhs_bot_right  (lower x lower, lower part)
            // The "lower x dense" term is rewritten as a dense x lower product:
            // with J the reversal permutation, (J X J)^T = J X^T J, and
            // J (lhs_bot_right)^T J is lower triangular, so
            // rev(dst)^T += rev(rhs_bot_left)^T * rev(lhs_bot_right)^T.
            // Conjugation flags are swapped accordingly.
            lower_x_lower_into_lower_impl_unchecked(
                dst_top_left,
                skip_diag,
                lhs_top_left,
                lhs_diag,
                rhs_top_left,
                rhs_diag,
                alpha,
                beta,
                conj_lhs,
                conj_rhs,
                parallelism,
            );
            mat_x_lower_impl_unchecked(
                dst_bot_left.rb_mut(),
                lhs_bot_left,
                rhs_top_left,
                rhs_diag,
                alpha,
                beta,
                conj_lhs,
                conj_rhs,
                parallelism,
            );
            mat_x_lower_impl_unchecked(
                dst_bot_left.reverse_rows_and_cols_mut().transpose_mut(),
                rhs_bot_left.reverse_rows_and_cols().transpose(),
                lhs_bot_right.reverse_rows_and_cols().transpose(),
                lhs_diag,
                Some(E::faer_one()),
                beta,
                conj_rhs,
                conj_lhs,
                parallelism,
            );
            lower_x_lower_into_lower_impl_unchecked(
                dst_bot_right,
                skip_diag,
                lhs_bot_right,
                lhs_diag,
                rhs_bot_right,
                rhs_diag,
                alpha,
                beta,
                conj_lhs,
                conj_rhs,
                parallelism,
            )
        }
    }
2397
    /// Upper-times-lower product into a dense `dst`:
    /// `dst := alpha * dst + beta * op(lhs) * op(rhs)` (`alpha == None` overwrites).
    ///
    /// All matrices are `n × n`. `lhs` is upper triangular (diagonal per
    /// `lhs_diag`), `rhs` is lower triangular (diagonal per `rhs_diag`), and all of
    /// `dst` is written.
    ///
    /// # Safety
    /// The shape requirements are only checked with `debug_assert!`.
    unsafe fn upper_x_lower_impl_unchecked<E: ComplexField>(
        dst: MatMut<'_, E>,
        lhs: MatRef<'_, E>,
        lhs_diag: DiagonalKind,
        rhs: MatRef<'_, E>,
        rhs_diag: DiagonalKind,
        alpha: Option<E>,
        beta: E,
        conj_lhs: Conj,
        conj_rhs: Conj,
        parallelism: Parallelism,
    ) {
        let n = dst.nrows();
        debug_assert!(n == lhs.nrows());
        debug_assert!(n == lhs.ncols());
        debug_assert!(n == rhs.nrows());
        debug_assert!(n == rhs.ncols());
        debug_assert!(n == dst.nrows());
        debug_assert!(n == dst.ncols());

        if n <= 16 {
            // Small case: densify both operands locally and use a dense product.
            let op = {
                #[inline(never)]
                || {
                    stack_mat_16x16_begin!(temp_lhs, n, n, lhs.row_stride(), lhs.col_stride(), E);
                    stack_mat_16x16_begin!(temp_rhs, n, n, rhs.row_stride(), rhs.col_stride(), E);

                    copy_upper(temp_lhs.rb_mut(), lhs, lhs_diag);
                    copy_lower(temp_rhs.rb_mut(), rhs, rhs_diag);

                    mul(
                        dst,
                        temp_lhs.rb(),
                        temp_rhs.rb(),
                        alpha,
                        beta,
                        conj_lhs,
                        conj_rhs,
                        parallelism,
                    );
                }
            };
            op();
        } else {
            let bs = n / 2;

            let (mut dst_top_left, dst_top_right, dst_bot_left, dst_bot_right) =
                dst.split_at_mut(bs, bs);
            let (lhs_top_left, lhs_top_right, _, lhs_bot_right) = lhs.split_at(bs, bs);
            let (rhs_top_left, _, rhs_bot_left, rhs_bot_right) = rhs.split_at(bs, bs);

            // With lhs upper and rhs lower:
            //   dst_top_left  <- lhs_top_right * rhs_bot_left (dense, applies alpha)
            //                  + lhs_top_left  * rhs_top_left (upper x lower, alpha = 1)
            //   dst_top_right <- lhs_top_right * rhs_bot_right          (dense x lower)
            //   dst_bot_left  <- lhs_bot_right * rhs_bot_left, computed transposed as
            //                    rhs_bot_left^T * lhs_bot_right^T        (dense x lower)
            //   dst_bot_right <- lhs_bot_right * rhs_bot_right          (upper x lower)
            // NOTE(review): the closures ignore the parallelism value handed in by
            // `join_raw` (`|_|`) and forward the outer `parallelism` instead.
            join_raw(
                |_| {
                    mul(
                        dst_top_left.rb_mut(),
                        lhs_top_right,
                        rhs_bot_left,
                        alpha,
                        beta,
                        conj_lhs,
                        conj_rhs,
                        parallelism,
                    );
                    upper_x_lower_impl_unchecked(
                        dst_top_left,
                        lhs_top_left,
                        lhs_diag,
                        rhs_top_left,
                        rhs_diag,
                        Some(E::faer_one()),
                        beta,
                        conj_lhs,
                        conj_rhs,
                        parallelism,
                    )
                },
                |_| {
                    join_raw(
                        |_| {
                            mat_x_lower_impl_unchecked(
                                dst_top_right,
                                lhs_top_right,
                                rhs_bot_right,
                                rhs_diag,
                                alpha,
                                beta,
                                conj_lhs,
                                conj_rhs,
                                parallelism,
                            )
                        },
                        |_| {
                            mat_x_lower_impl_unchecked(
                                dst_bot_left.transpose_mut(),
                                rhs_bot_left.transpose(),
                                lhs_bot_right.transpose(),
                                lhs_diag,
                                alpha,
                                beta,
                                conj_rhs,
                                conj_lhs,
                                parallelism,
                            )
                        },
                        parallelism,
                    );

                    upper_x_lower_impl_unchecked(
                        dst_bot_right,
                        lhs_bot_right,
                        lhs_diag,
                        rhs_bot_right,
                        rhs_diag,
                        alpha,
                        beta,
                        conj_lhs,
                        conj_rhs,
                        parallelism,
                    )
                },
                parallelism,
            );
        }
    }
2529
    /// Upper-times-lower product written into the lower triangle of `dst`.
    ///
    /// All matrices are `n × n`. `lhs` is upper triangular (diagonal per
    /// `lhs_diag`) and `rhs` is lower triangular (diagonal per `rhs_diag`). Only
    /// the lower triangle of `dst` is written (strictly lower when `skip_diag` is
    /// set). The existing contents are scaled by `alpha` (`None` = overwrite),
    /// following the convention of `mul`.
    ///
    /// # Safety
    /// The shape requirements are only checked with `debug_assert!`.
    unsafe fn upper_x_lower_into_lower_impl_unchecked<E: ComplexField>(
        dst: MatMut<'_, E>,
        skip_diag: bool,
        lhs: MatRef<'_, E>,
        lhs_diag: DiagonalKind,
        rhs: MatRef<'_, E>,
        rhs_diag: DiagonalKind,
        alpha: Option<E>,
        beta: E,
        conj_lhs: Conj,
        conj_rhs: Conj,
        parallelism: Parallelism,
    ) {
        let n = dst.nrows();
        debug_assert!(n == lhs.nrows());
        debug_assert!(n == lhs.ncols());
        debug_assert!(n == rhs.nrows());
        debug_assert!(n == rhs.ncols());
        debug_assert!(n == dst.nrows());
        debug_assert!(n == dst.ncols());

        if n <= 16 {
            // Small case: densify both operands locally, multiply into a local
            // temporary, then merge its lower triangle into `dst`.
            let op = {
                #[inline(never)]
                || {
                    stack_mat_16x16_begin!(temp_dst, n, n, dst.row_stride(), dst.col_stride(), E);
                    stack_mat_16x16_begin!(temp_lhs, n, n, lhs.row_stride(), lhs.col_stride(), E);
                    stack_mat_16x16_begin!(temp_rhs, n, n, rhs.row_stride(), rhs.col_stride(), E);

                    copy_upper(temp_lhs.rb_mut(), lhs, lhs_diag);
                    copy_lower(temp_rhs.rb_mut(), rhs, rhs_diag);

                    mul(
                        temp_dst.rb_mut(),
                        temp_lhs.rb(),
                        temp_rhs.rb(),
                        None,
                        beta,
                        conj_lhs,
                        conj_rhs,
                        parallelism,
                    );

                    accum_lower(dst, temp_dst.rb(), skip_diag, alpha);
                }
            };
            op();
        } else {
            let bs = n / 2;

            let (mut dst_top_left, _, dst_bot_left, dst_bot_right) = dst.split_at_mut(bs, bs);
            let (lhs_top_left, lhs_top_right, _, lhs_bot_right) = lhs.split_at(bs, bs);
            let (rhs_top_left, _, rhs_bot_left, rhs_bot_right) = rhs.split_at(bs, bs);

            // With lhs upper and rhs lower (lower part of the result only):
            //   dst_top_left  <- lhs_top_right * rhs_bot_left (applies alpha)
            //                  + lhs_top_left  * rhs_top_left (alpha = 1)
            //   dst_bot_left  <- lhs_bot_right * rhs_bot_left, computed transposed as
            //                    rhs_bot_left^T * lhs_bot_right^T (dense x lower)
            //   dst_bot_right <- lhs_bot_right * rhs_bot_right
            // NOTE(review): the closures ignore the parallelism value handed in by
            // `join_raw` (`|_|`) and forward the outer `parallelism` instead.
            join_raw(
                |_| {
                    mat_x_mat_into_lower_impl_unchecked(
                        dst_top_left.rb_mut(),
                        skip_diag,
                        lhs_top_right,
                        rhs_bot_left,
                        alpha,
                        beta,
                        conj_lhs,
                        conj_rhs,
                        parallelism,
                    );
                    upper_x_lower_into_lower_impl_unchecked(
                        dst_top_left,
                        skip_diag,
                        lhs_top_left,
                        lhs_diag,
                        rhs_top_left,
                        rhs_diag,
                        Some(E::faer_one()),
                        beta,
                        conj_lhs,
                        conj_rhs,
                        parallelism,
                    )
                },
                |_| {
                    mat_x_lower_impl_unchecked(
                        dst_bot_left.transpose_mut(),
                        rhs_bot_left.transpose(),
                        lhs_bot_right.transpose(),
                        lhs_diag,
                        alpha,
                        beta,
                        conj_rhs,
                        conj_lhs,
                        parallelism,
                    );
                    upper_x_lower_into_lower_impl_unchecked(
                        dst_bot_right,
                        skip_diag,
                        lhs_bot_right,
                        lhs_diag,
                        rhs_bot_right,
                        rhs_diag,
                        alpha,
                        beta,
                        conj_lhs,
                        conj_rhs,
                        parallelism,
                    )
                },
                parallelism,
            );
        }
    }
2647
    /// Dense-times-dense product written into the lower triangle of `dst`.
    ///
    /// `dst` is `n × n`, `lhs` is `n × k` and `rhs` is `k × n`. Only the lower
    /// triangle of `dst` is written (strictly lower when `skip_diag` is set). The
    /// existing contents are scaled by `alpha` (`None` = overwrite), following the
    /// convention of `mul`.
    ///
    /// # Safety
    /// The shape requirements are only checked with `debug_assert!`.
    unsafe fn mat_x_mat_into_lower_impl_unchecked<E: ComplexField>(
        dst: MatMut<'_, E>,
        skip_diag: bool,
        lhs: MatRef<'_, E>,
        rhs: MatRef<'_, E>,
        alpha: Option<E>,
        beta: E,
        conj_lhs: Conj,
        conj_rhs: Conj,
        parallelism: Parallelism,
    ) {
        debug_assert!(dst.nrows() == dst.ncols());
        debug_assert!(dst.nrows() == lhs.nrows());
        debug_assert!(dst.ncols() == rhs.ncols());
        debug_assert!(lhs.ncols() == rhs.nrows());

        let n = dst.nrows();
        let k = lhs.ncols();

        // Only split across threads when n * n * k >= 128^3.
        let join_parallelism = if n * n * k < 128 * 128 * 128 {
            Parallelism::None
        } else {
            parallelism
        };

        if n <= 16 {
            // Small case: full product into a local temporary, then merge its
            // lower triangle into `dst`.
            let op = {
                #[inline(never)]
                || {
                    stack_mat_16x16_begin!(temp_dst, n, n, dst.row_stride(), dst.col_stride(), E);

                    mul(
                        temp_dst.rb_mut(),
                        lhs,
                        rhs,
                        None,
                        beta,
                        conj_lhs,
                        conj_rhs,
                        parallelism,
                    );
                    accum_lower(dst, temp_dst.rb(), skip_diag, alpha);
                }
            };
            op();
        } else {
            // Split dst into 2x2 blocks (top-right is above the diagonal and skipped):
            //   dst_bot_left  <- lhs_bot * rhs_left   (dense)
            //   dst_top_left  <- lhs_top * rhs_left   (lower part, recursive)
            //   dst_bot_right <- lhs_bot * rhs_right  (lower part, recursive)
            // The three blocks are independent.
            let bs = n / 2;
            let (dst_top_left, _, dst_bot_left, dst_bot_right) = dst.split_at_mut(bs, bs);
            let (lhs_top, lhs_bot) = lhs.split_at_row(bs);
            let (rhs_left, rhs_right) = rhs.split_at_col(bs);

            join_raw(
                |_| {
                    mul(
                        dst_bot_left,
                        lhs_bot,
                        rhs_left,
                        alpha,
                        beta,
                        conj_lhs,
                        conj_rhs,
                        parallelism,
                    )
                },
                |_| {
                    join_raw(
                        |_| {
                            mat_x_mat_into_lower_impl_unchecked(
                                dst_top_left,
                                skip_diag,
                                lhs_top,
                                rhs_left,
                                alpha,
                                beta,
                                conj_lhs,
                                conj_rhs,
                                parallelism,
                            )
                        },
                        |_| {
                            mat_x_mat_into_lower_impl_unchecked(
                                dst_bot_right,
                                skip_diag,
                                lhs_bot,
                                rhs_right,
                                alpha,
                                beta,
                                conj_lhs,
                                conj_rhs,
                                parallelism,
                            )
                        },
                        join_parallelism,
                    )
                },
                join_parallelism,
            );
        }
    }
2747
    /// Structure of a (square, unless `Rectangular`) matrix operand for the
    /// triangular matrix-multiplication routines: which triangle is meaningful and
    /// how the diagonal is interpreted.
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    pub enum BlockStructure {
        /// Dense matrix; no triangular structure.
        Rectangular,
        /// Lower triangular; the diagonal is read from storage.
        TriangularLower,
        /// Strictly lower triangular; the diagonal is implicitly zero.
        StrictTriangularLower,
        /// Lower triangular with an implicit unit diagonal.
        UnitTriangularLower,
        /// Upper triangular; the diagonal is read from storage.
        TriangularUpper,
        /// Strictly upper triangular; the diagonal is implicitly zero.
        StrictTriangularUpper,
        /// Upper triangular with an implicit unit diagonal.
        UnitTriangularUpper,
    }
2768
2769 impl BlockStructure {
2770 #[inline]
2772 pub fn is_dense(self) -> bool {
2773 matches!(self, BlockStructure::Rectangular)
2774 }
2775
2776 #[inline]
2778 pub fn is_lower(self) -> bool {
2779 use BlockStructure::*;
2780 matches!(
2781 self,
2782 TriangularLower | StrictTriangularLower | UnitTriangularLower
2783 )
2784 }
2785
2786 #[inline]
2788 pub fn is_upper(self) -> bool {
2789 use BlockStructure::*;
2790 matches!(
2791 self,
2792 TriangularUpper | StrictTriangularUpper | UnitTriangularUpper
2793 )
2794 }
2795
2796 #[inline]
2798 pub fn transpose(self) -> Self {
2799 use BlockStructure::*;
2800 match self {
2801 Rectangular => Rectangular,
2802 TriangularLower => TriangularUpper,
2803 StrictTriangularLower => StrictTriangularUpper,
2804 UnitTriangularLower => UnitTriangularUpper,
2805 TriangularUpper => TriangularLower,
2806 StrictTriangularUpper => StrictTriangularLower,
2807 UnitTriangularUpper => UnitTriangularLower,
2808 }
2809 }
2810
2811 #[inline]
2812 pub(crate) fn diag_kind(self) -> DiagonalKind {
2813 use BlockStructure::*;
2814 match self {
2815 Rectangular | TriangularLower | TriangularUpper => DiagonalKind::Generic,
2816 StrictTriangularLower | StrictTriangularUpper => DiagonalKind::Zero,
2817 UnitTriangularLower | UnitTriangularUpper => DiagonalKind::Unit,
2818 }
2819 }
2820 }
2821
2822 #[track_caller]
2899 #[inline]
2900 pub fn matmul_with_conj<E: ComplexField>(
2901 acc: MatMut<'_, E>,
2902 acc_structure: BlockStructure,
2903 lhs: MatRef<'_, E>,
2904 lhs_structure: BlockStructure,
2905 conj_lhs: Conj,
2906 rhs: MatRef<'_, E>,
2907 rhs_structure: BlockStructure,
2908 conj_rhs: Conj,
2909 alpha: Option<E>,
2910 beta: E,
2911 parallelism: Parallelism,
2912 ) {
2913 assert!(all(
2914 acc.nrows() == lhs.nrows(),
2915 acc.ncols() == rhs.ncols(),
2916 lhs.ncols() == rhs.nrows(),
2917 ));
2918
2919 if !acc_structure.is_dense() {
2920 assert!(acc.nrows() == acc.ncols());
2921 }
2922 if !lhs_structure.is_dense() {
2923 assert!(lhs.nrows() == lhs.ncols());
2924 }
2925 if !rhs_structure.is_dense() {
2926 assert!(rhs.nrows() == rhs.ncols());
2927 }
2928
2929 unsafe {
2930 matmul_unchecked(
2931 acc,
2932 acc_structure,
2933 lhs,
2934 lhs_structure,
2935 conj_lhs,
2936 rhs,
2937 rhs_structure,
2938 conj_rhs,
2939 alpha,
2940 beta,
2941 parallelism,
2942 )
2943 }
2944 }
2945
2946 #[track_caller]
3015 #[inline]
3016 pub fn matmul<
3017 E: ComplexField,
3018 LhsE: Conjugate<Canonical = E>,
3019 RhsE: Conjugate<Canonical = E>,
3020 >(
3021 acc: MatMut<'_, E>,
3022 acc_structure: BlockStructure,
3023 lhs: MatRef<'_, LhsE>,
3024 lhs_structure: BlockStructure,
3025 rhs: MatRef<'_, RhsE>,
3026 rhs_structure: BlockStructure,
3027 alpha: Option<E>,
3028 beta: E,
3029 parallelism: Parallelism,
3030 ) {
3031 let (lhs, conj_lhs) = lhs.canonicalize();
3032 let (rhs, conj_rhs) = rhs.canonicalize();
3033 matmul_with_conj(
3034 acc,
3035 acc_structure,
3036 lhs,
3037 lhs_structure,
3038 conj_lhs,
3039 rhs,
3040 rhs_structure,
3041 conj_rhs,
3042 alpha,
3043 beta,
3044 parallelism,
3045 );
3046 }
3047
3048 unsafe fn matmul_unchecked<E: ComplexField>(
3049 acc: MatMut<'_, E>,
3050 acc_structure: BlockStructure,
3051 lhs: MatRef<'_, E>,
3052 lhs_structure: BlockStructure,
3053 conj_lhs: Conj,
3054 rhs: MatRef<'_, E>,
3055 rhs_structure: BlockStructure,
3056 conj_rhs: Conj,
3057 alpha: Option<E>,
3058 beta: E,
3059 parallelism: Parallelism,
3060 ) {
3061 debug_assert!(acc.nrows() == lhs.nrows());
3062 debug_assert!(acc.ncols() == rhs.ncols());
3063 debug_assert!(lhs.ncols() == rhs.nrows());
3064
3065 if !acc_structure.is_dense() {
3066 debug_assert!(acc.nrows() == acc.ncols());
3067 }
3068 if !lhs_structure.is_dense() {
3069 debug_assert!(lhs.nrows() == lhs.ncols());
3070 }
3071 if !rhs_structure.is_dense() {
3072 debug_assert!(rhs.nrows() == rhs.ncols());
3073 }
3074
3075 let mut acc = acc;
3076 let mut lhs = lhs;
3077 let mut rhs = rhs;
3078
3079 let mut acc_structure = acc_structure;
3080 let mut lhs_structure = lhs_structure;
3081 let mut rhs_structure = rhs_structure;
3082
3083 let mut conj_lhs = conj_lhs;
3084 let mut conj_rhs = conj_rhs;
3085
3086 if rhs_structure.is_lower() {
3088 false
3090 } else if rhs_structure.is_upper() {
3091 acc = acc.reverse_rows_and_cols_mut();
3093 lhs = lhs.reverse_rows_and_cols();
3094 rhs = rhs.reverse_rows_and_cols();
3095 acc_structure = acc_structure.transpose();
3096 lhs_structure = lhs_structure.transpose();
3097 rhs_structure = rhs_structure.transpose();
3098 false
3099 } else if lhs_structure.is_lower() {
3100 acc = acc.reverse_rows_and_cols_mut().transpose_mut();
3102 (lhs, rhs) = (
3103 rhs.reverse_rows_and_cols().transpose(),
3104 lhs.reverse_rows_and_cols().transpose(),
3105 );
3106 (conj_lhs, conj_rhs) = (conj_rhs, conj_lhs);
3107 (lhs_structure, rhs_structure) = (rhs_structure, lhs_structure);
3108 true
3109 } else if lhs_structure.is_upper() {
3110 acc_structure = acc_structure.transpose();
3112 acc = acc.transpose_mut();
3113 (lhs, rhs) = (rhs.transpose(), lhs.transpose());
3114 (conj_lhs, conj_rhs) = (conj_rhs, conj_lhs);
3115 (lhs_structure, rhs_structure) = (rhs_structure.transpose(), lhs_structure.transpose());
3116 true
3117 } else {
3118 false
3120 };
3121
3122 let clear_upper = |acc: MatMut<'_, E>, skip_diag: bool| match &alpha {
3123 &Some(alpha) => zipped!(acc).for_each_triangular_upper(
3124 if skip_diag { Diag::Skip } else { Diag::Include },
3125 |unzipped!(mut acc)| acc.write(alpha.faer_mul(acc.read())),
3126 ),
3127
3128 None => zipped!(acc).for_each_triangular_upper(
3129 if skip_diag { Diag::Skip } else { Diag::Include },
3130 |unzipped!(mut acc)| acc.write(E::faer_zero()),
3131 ),
3132 };
3133
3134 let skip_diag = matches!(
3135 acc_structure,
3136 BlockStructure::StrictTriangularLower
3137 | BlockStructure::StrictTriangularUpper
3138 | BlockStructure::UnitTriangularLower
3139 | BlockStructure::UnitTriangularUpper
3140 );
3141 let lhs_diag = lhs_structure.diag_kind();
3142 let rhs_diag = rhs_structure.diag_kind();
3143
3144 if acc_structure.is_dense() {
3145 if lhs_structure.is_dense() && rhs_structure.is_dense() {
3146 mul(acc, lhs, rhs, alpha, beta, conj_lhs, conj_rhs, parallelism);
3147 } else {
3148 debug_assert!(rhs_structure.is_lower());
3149
3150 if lhs_structure.is_dense() {
3151 mat_x_lower_impl_unchecked(
3152 acc,
3153 lhs,
3154 rhs,
3155 rhs_diag,
3156 alpha,
3157 beta,
3158 conj_lhs,
3159 conj_rhs,
3160 parallelism,
3161 )
3162 } else if lhs_structure.is_lower() {
3163 clear_upper(acc.rb_mut(), true);
3164 lower_x_lower_into_lower_impl_unchecked(
3165 acc,
3166 false,
3167 lhs,
3168 lhs_diag,
3169 rhs,
3170 rhs_diag,
3171 alpha,
3172 beta,
3173 conj_lhs,
3174 conj_rhs,
3175 parallelism,
3176 );
3177 } else {
3178 debug_assert!(lhs_structure.is_upper());
3179 upper_x_lower_impl_unchecked(
3180 acc,
3181 lhs,
3182 lhs_diag,
3183 rhs,
3184 rhs_diag,
3185 alpha,
3186 beta,
3187 conj_lhs,
3188 conj_rhs,
3189 parallelism,
3190 )
3191 }
3192 }
3193 } else if acc_structure.is_lower() {
3194 if lhs_structure.is_dense() && rhs_structure.is_dense() {
3195 mat_x_mat_into_lower_impl_unchecked(
3196 acc,
3197 skip_diag,
3198 lhs,
3199 rhs,
3200 alpha,
3201 beta,
3202 conj_lhs,
3203 conj_rhs,
3204 parallelism,
3205 )
3206 } else {
3207 debug_assert!(rhs_structure.is_lower());
3208 if lhs_structure.is_dense() {
3209 mat_x_lower_into_lower_impl_unchecked(
3210 acc,
3211 skip_diag,
3212 lhs,
3213 rhs,
3214 rhs_diag,
3215 alpha,
3216 beta,
3217 conj_lhs,
3218 conj_rhs,
3219 parallelism,
3220 );
3221 } else if lhs_structure.is_lower() {
3222 lower_x_lower_into_lower_impl_unchecked(
3223 acc,
3224 skip_diag,
3225 lhs,
3226 lhs_diag,
3227 rhs,
3228 rhs_diag,
3229 alpha,
3230 beta,
3231 conj_lhs,
3232 conj_rhs,
3233 parallelism,
3234 )
3235 } else {
3236 upper_x_lower_into_lower_impl_unchecked(
3237 acc,
3238 skip_diag,
3239 lhs,
3240 lhs_diag,
3241 rhs,
3242 rhs_diag,
3243 alpha,
3244 beta,
3245 conj_lhs,
3246 conj_rhs,
3247 parallelism,
3248 )
3249 }
3250 }
3251 } else if lhs_structure.is_dense() && rhs_structure.is_dense() {
3252 mat_x_mat_into_lower_impl_unchecked(
3253 acc.transpose_mut(),
3254 skip_diag,
3255 rhs.transpose(),
3256 lhs.transpose(),
3257 alpha,
3258 beta,
3259 conj_rhs,
3260 conj_lhs,
3261 parallelism,
3262 )
3263 } else {
3264 debug_assert!(rhs_structure.is_lower());
3265 if lhs_structure.is_dense() {
3266 upper_x_lower_into_lower_impl_unchecked(
3268 acc.transpose_mut(),
3269 skip_diag,
3270 rhs.transpose(),
3271 rhs_diag,
3272 lhs.transpose(),
3273 lhs_diag,
3274 alpha,
3275 beta,
3276 conj_rhs,
3277 conj_lhs,
3278 parallelism,
3279 )
3280 } else if lhs_structure.is_lower() {
3281 if !skip_diag {
3282 match &alpha {
3283 &Some(alpha) => {
3284 zipped!(
3285 acc.rb_mut().diagonal_mut().column_vector_mut().as_2d_mut(),
3286 lhs.diagonal().column_vector().as_2d(),
3287 rhs.diagonal().column_vector().as_2d(),
3288 )
3289 .for_each(
3290 |unzipped!(mut acc, lhs, rhs)| {
3291 acc.write(
3292 (alpha.faer_mul(acc.read())).faer_add(
3293 beta.faer_mul(lhs.read().faer_mul(rhs.read())),
3294 ),
3295 )
3296 },
3297 );
3298 }
3299 None => {
3300 zipped!(
3301 acc.rb_mut().diagonal_mut().column_vector_mut().as_2d_mut(),
3302 lhs.diagonal().column_vector().as_2d(),
3303 rhs.diagonal().column_vector().as_2d(),
3304 )
3305 .for_each(
3306 |unzipped!(mut acc, lhs, rhs)| {
3307 acc.write(beta.faer_mul(lhs.read().faer_mul(rhs.read())))
3308 },
3309 );
3310 }
3311 }
3312 }
3313 clear_upper(acc.rb_mut(), true);
3314 } else {
3315 debug_assert!(lhs_structure.is_upper());
3316 upper_x_lower_into_lower_impl_unchecked(
3317 acc.transpose_mut(),
3318 skip_diag,
3319 rhs.transpose(),
3320 rhs_diag,
3321 lhs.transpose(),
3322 lhs_diag,
3323 alpha,
3324 beta,
3325 conj_rhs,
3326 conj_lhs,
3327 parallelism,
3328 )
3329 }
3330 }
3331 }
3332}
3333
#[cfg(test)]
mod tests {
    use super::{
        triangular::{BlockStructure, DiagonalKind},
        *,
    };
    use crate::{assert, Mat};
    use assert_approx_eq::assert_approx_eq;
    use num_complex::Complex32;

    /// Smoke test for `stack_mat_16x16_begin!`: the macro must declare a usable mutable
    /// binding `m`.
    // NOTE(review): the remaining macro arguments presumably give a 3x3 f64 matrix with
    // strides 1 and 3 — confirm against the macro definition.
    #[test]
    fn test_stack_mat() {
        stack_mat_16x16_begin!(m, 3, 3, 1, 3, f64);
        {
            let _ = &mut m;
            dbg!(&m);
        }
    }

    /// Compares the dense matmul dispatch against a naive reference over many shapes,
    /// memory layouts (row/column-major, reversed rows/columns), conjugations,
    /// `alpha`/`beta` values, parallelism settings and with/without the `gemm` backend.
    #[test]
    #[ignore = "takes too long in CI"]
    fn test_matmul() {
        let random = |_, _| c32 {
            re: rand::random(),
            im: rand::random(),
        };

        // `None` exercises the "overwrite" mode. The others exercise scaled accumulation.
        let alphas = [
            None,
            Some(c32::faer_one()),
            Some(c32::faer_zero()),
            Some(random(0, 0)),
        ];

        #[cfg(not(miri))]
        let bools = [false, true];
        #[cfg(not(miri))]
        let betas = [c32::faer_one(), c32::faer_zero(), random(0, 0)];
        #[cfg(not(miri))]
        let par = [Parallelism::None, Parallelism::Rayon(0)];
        #[cfg(not(miri))]
        let conjs = [Conj::Yes, Conj::No];

        // Under miri only a single configuration per axis is tried to keep runtime bearable.
        #[cfg(miri)]
        let bools = [true];
        #[cfg(miri)]
        let betas = [random(0, 0)];
        #[cfg(miri)]
        let par = [Parallelism::None];
        #[cfg(miri)]
        let conjs = [Conj::Yes];

        // Sizes just below, at and above 128 and 16.
        // NOTE(review): presumably chosen to hit blocking/SIMD boundaries — confirm.
        let big0 = 127;
        let big1 = 128;
        let big2 = 129;

        let mid0 = 15;
        let mid1 = 16;
        let mid2 = 17;
        for (m, n, k) in [
            (mid0, mid0, KC + 1),
            (big0, big1, 5),
            (big1, big0, 5),
            (big0, big2, 5),
            (big2, big0, 5),
            (mid0, mid0, 5),
            (mid1, mid1, 5),
            (mid2, mid2, 5),
            (mid0, mid1, 5),
            (mid1, mid0, 5),
            (mid0, mid2, 5),
            (mid2, mid0, 5),
            (mid0, 1, 1),
            (1, mid0, 1),
            (1, 1, mid0),
            (1, mid0, mid0),
            (mid0, 1, mid0),
            (mid0, mid0, 1),
            (1, 1, 1),
        ] {
            let a = Mat::from_fn(m, k, random);
            let b = Mat::from_fn(k, n, random);
            let acc_init = Mat::from_fn(m, n, random);

            for reverse_acc_cols in bools {
                for reverse_acc_rows in bools {
                    for reverse_b_cols in bools {
                        for reverse_b_rows in bools {
                            for reverse_a_cols in bools {
                                for reverse_a_rows in bools {
                                    for a_colmajor in bools {
                                        for b_colmajor in bools {
                                            for acc_colmajor in bools {
                                                // Row-major operands are obtained by storing
                                                // the transpose and viewing it transposed back.
                                                let a = if a_colmajor {
                                                    a.to_owned()
                                                } else {
                                                    a.transpose().to_owned()
                                                };
                                                let mut a = if a_colmajor {
                                                    a.as_ref()
                                                } else {
                                                    a.as_ref().transpose()
                                                };

                                                let b = if b_colmajor {
                                                    b.to_owned()
                                                } else {
                                                    b.transpose().to_owned()
                                                };
                                                let mut b = if b_colmajor {
                                                    b.as_ref()
                                                } else {
                                                    b.as_ref().transpose()
                                                };

                                                // Reversed views produce negative strides.
                                                if reverse_a_rows {
                                                    a = a.reverse_rows();
                                                }
                                                if reverse_a_cols {
                                                    a = a.reverse_cols();
                                                }
                                                if reverse_b_rows {
                                                    b = b.reverse_rows();
                                                }
                                                if reverse_b_cols {
                                                    b = b.reverse_cols();
                                                }
                                                for conj_a in conjs {
                                                    for conj_b in conjs {
                                                        for parallelism in par {
                                                            for alpha in alphas {
                                                                for beta in betas {
                                                                    for use_gemm in [true, false] {
                                                                        test_matmul_impl(
                                                                            reverse_acc_cols,
                                                                            reverse_acc_rows,
                                                                            acc_colmajor,
                                                                            m,
                                                                            n,
                                                                            conj_a,
                                                                            conj_b,
                                                                            parallelism,
                                                                            alpha,
                                                                            beta,
                                                                            use_gemm,
                                                                            &acc_init,
                                                                            a,
                                                                            b,
                                                                        );
                                                                    }
                                                                }
                                                            }
                                                        }
                                                    }
                                                }
                                            }
                                        }
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }
    }

    /// Naive reference product used as the oracle in `test_matmul`.
    ///
    /// Every entry is computed independently as `sum_depth op(a[i, depth]) * op(b[depth, j])`
    /// and stored as `alpha * acc + beta * sum`, or as `beta * sum` when `alpha` is `None`.
    fn matmul_with_conj_fallback<E: ComplexField>(
        acc: MatMut<'_, E>,
        a: MatRef<'_, E>,
        conj_a: Conj,
        b: MatRef<'_, E>,
        conj_b: Conj,
        alpha: Option<E>,
        beta: E,
        parallelism: Parallelism,
    ) {
        let m = acc.nrows();
        let n = acc.ncols();
        let k = a.ncols();

        // Job `idx` handles entry (idx % m, idx / m), a column-major enumeration of `acc`.
        let job = |idx: usize| {
            let i = idx % m;
            let j = idx / m;
            let acc = acc.rb().submatrix(i, j, 1, 1);
            // Each job writes only its own 1x1 submatrix, so writes through this
            // const-cast view never overlap between jobs.
            // NOTE(review): assumes `for_each_raw` runs each index in 0..m*n exactly once — confirm.
            let mut acc = unsafe { acc.const_cast() };

            let mut local_acc = E::faer_zero();
            for depth in 0..k {
                let a = a.read(i, depth);
                let b = b.read(depth, j);
                local_acc = local_acc.faer_add(E::faer_mul(
                    match conj_a {
                        Conj::Yes => a.faer_conj(),
                        Conj::No => a,
                    },
                    match conj_b {
                        Conj::Yes => b.faer_conj(),
                        Conj::No => b,
                    },
                ))
            }
            match alpha {
                Some(alpha) => acc.write(
                    0,
                    0,
                    E::faer_add(acc.read(0, 0).faer_mul(alpha), local_acc.faer_mul(beta)),
                ),
                None => acc.write(0, 0, local_acc.faer_mul(beta)),
            }
        };

        crate::for_each_raw(m * n, job, parallelism);
    }

    /// Runs one `test_matmul` configuration.
    ///
    /// Builds the destination in the requested layout, applies `matmul_with_conj_gemm_dispatch`
    /// and the naive fallback to identical copies, and compares the results entrywise with an
    /// absolute tolerance of `1e-3` on the real and imaginary parts.
    fn test_matmul_impl(
        reverse_acc_cols: bool,
        reverse_acc_rows: bool,
        acc_colmajor: bool,
        m: usize,
        n: usize,
        conj_a: Conj,
        conj_b: Conj,
        parallelism: Parallelism,
        alpha: Option<c32>,
        beta: c32,
        use_gemm: bool,
        acc_init: &Mat<c32>,
        a: MatRef<c32>,
        b: MatRef<c32>,
    ) {
        let mut acc = if acc_colmajor {
            acc_init.to_owned()
        } else {
            acc_init.transpose().to_owned()
        };

        let mut acc = if acc_colmajor {
            acc.as_mut()
        } else {
            acc.as_mut().transpose_mut()
        };
        if reverse_acc_rows {
            acc = acc.reverse_rows_mut();
        }
        if reverse_acc_cols {
            acc = acc.reverse_cols_mut();
        }
        // Snapshot taken before the product so both implementations start from the same data.
        let mut target = acc.to_owned();

        matmul_with_conj_gemm_dispatch(
            acc.rb_mut(),
            a,
            conj_a,
            b,
            conj_b,
            alpha,
            beta,
            parallelism,
            use_gemm,
        );
        matmul_with_conj_fallback(
            target.as_mut(),
            a,
            conj_a,
            b,
            conj_b,
            alpha,
            beta,
            parallelism,
        );

        for j in 0..n {
            for i in 0..m {
                let acc: Complex32 = acc.read(i, j).into();
                let target: Complex32 = target.read(i, j).into();
                assert_approx_eq!(acc.re, target.re, 1e-3);
                assert_approx_eq!(acc.im, target.im, 1e-3);
            }
        }
    }

    /// Random `nrows x ncols` matrix conforming to `structure`.
    ///
    /// Destinations (`is_dst == true`) stay fully random so the tests can check which entries
    /// are left untouched. Operands have the entries outside their triangle zeroed and their
    /// diagonal set to 0 / 1 for strict / unit structures.
    fn generate_structured_matrix(
        is_dst: bool,
        nrows: usize,
        ncols: usize,
        structure: BlockStructure,
    ) -> Mat<f64> {
        let mut mat = Mat::new();
        mat.resize_with(nrows, ncols, |_, _| rand::random());

        if !is_dst {
            let kind = structure.diag_kind();
            if structure.is_lower() {
                // Zero the strictly upper part.
                for j in 0..ncols {
                    for i in 0..j {
                        mat.write(i, j, 0.0);
                    }
                }
            } else if structure.is_upper() {
                // Zero the strictly lower part.
                for j in 0..ncols {
                    for i in j + 1..nrows {
                        mat.write(i, j, 0.0);
                    }
                }
            }

            match kind {
                triangular::DiagonalKind::Zero => {
                    for i in 0..nrows {
                        mat.write(i, i, 0.0);
                    }
                }
                triangular::DiagonalKind::Unit => {
                    for i in 0..nrows {
                        mat.write(i, i, 1.0);
                    }
                }
                triangular::DiagonalKind::Generic => (),
            }
        }
        mat
    }

    /// Checks `triangular::matmul_with_conj` against the dense `matmul_with_conj` brought in
    /// by `use super::*`, for one structure combination and both serial and parallel runs.
    ///
    /// For structured destinations, entries outside the written region (including the
    /// diagonal for strict/unit destinations) must be bit-identical to the original values.
    fn run_test_problem(
        m: usize,
        n: usize,
        k: usize,
        dst_structure: BlockStructure,
        lhs_structure: BlockStructure,
        rhs_structure: BlockStructure,
    ) {
        let mut dst = generate_structured_matrix(true, m, n, dst_structure);
        let mut dst_target = dst.to_owned();
        let dst_orig = dst.to_owned();
        let lhs = generate_structured_matrix(false, m, k, lhs_structure);
        let rhs = generate_structured_matrix(false, k, n, rhs_structure);

        for parallelism in [Parallelism::None, Parallelism::Rayon(8)] {
            // `alpha = None` overwrites, so repeating the product per parallelism mode is fine.
            triangular::matmul_with_conj(
                dst.as_mut(),
                dst_structure,
                lhs.as_ref(),
                lhs_structure,
                Conj::No,
                rhs.as_ref(),
                rhs_structure,
                Conj::No,
                None,
                2.5,
                parallelism,
            );

            matmul_with_conj(
                dst_target.as_mut(),
                lhs.as_ref(),
                Conj::No,
                rhs.as_ref(),
                Conj::No,
                None,
                2.5,
                parallelism,
            );

            // Structured destinations are square here (m == n), so `n` bounds both rows and cols.
            if dst_structure.is_dense() {
                for j in 0..n {
                    for i in 0..m {
                        assert_approx_eq!(dst.read(i, j), dst_target.read(i, j));
                    }
                }
            } else if dst_structure.is_lower() {
                for j in 0..n {
                    if matches!(dst_structure.diag_kind(), DiagonalKind::Generic) {
                        for i in 0..j {
                            assert_eq!(dst.read(i, j), dst_orig.read(i, j));
                        }
                        for i in j..n {
                            assert_approx_eq!(dst.read(i, j), dst_target.read(i, j));
                        }
                    } else {
                        for i in 0..=j {
                            assert_eq!(dst.read(i, j), dst_orig.read(i, j));
                        }
                        for i in j + 1..n {
                            assert_approx_eq!(dst.read(i, j), dst_target.read(i, j));
                        }
                    }
                }
            } else {
                for j in 0..n {
                    if matches!(dst_structure.diag_kind(), DiagonalKind::Generic) {
                        for i in 0..=j {
                            assert_approx_eq!(dst.read(i, j), dst_target.read(i, j));
                        }
                        for i in j + 1..n {
                            assert_eq!(dst.read(i, j), dst_orig.read(i, j));
                        }
                    } else {
                        for i in 0..j {
                            assert_approx_eq!(dst.read(i, j), dst_target.read(i, j));
                        }
                        for i in j..n {
                            assert_eq!(dst.read(i, j), dst_orig.read(i, j));
                        }
                    }
                }
            }
        }
    }

    /// Exercises every (dst, lhs, rhs) structure combination on a few random sizes.
    #[test]
    fn test_triangular() {
        use BlockStructure::*;
        let structures = [
            Rectangular,
            TriangularLower,
            TriangularUpper,
            StrictTriangularLower,
            StrictTriangularUpper,
            UnitTriangularLower,
            UnitTriangularUpper,
        ];

        for dst in structures {
            for lhs in structures {
                for rhs in structures {
                    #[cfg(not(miri))]
                    let big = 100;

                    #[cfg(miri)]
                    let big = 31;
                    for _ in 0..3 {
                        let m = rand::random::<usize>() % big;
                        let mut n = rand::random::<usize>() % big;
                        let mut k = rand::random::<usize>() % big;

                        #[cfg(miri)]
                        dbg!(m, n, k);

                        // With two or more structured matrices, everything has to be m x m.
                        match (!dst.is_dense(), !lhs.is_dense(), !rhs.is_dense()) {
                            (true, true, _) | (true, _, true) | (_, true, true) => {
                                n = m;
                                k = m;
                            }
                            _ => (),
                        }

                        // Otherwise, only the single structured matrix is forced to be square.
                        if !dst.is_dense() {
                            n = m;
                        }

                        if !lhs.is_dense() {
                            k = m;
                        }

                        if !rhs.is_dense() {
                            k = n;
                        }

                        run_test_problem(m, n, k, dst, lhs, rhs);
                    }
                }
            }
        }
    }
}